merge: 合并 upstream/main

This commit is contained in:
song
2026-01-02 10:22:29 +08:00
135 changed files with 26830 additions and 1442 deletions

1
.gitignore vendored
View File

@@ -77,6 +77,7 @@ temp/
*.temp *.temp
*.log *.log
*.bak *.bak
.cache/
# =================== # ===================
# 构建产物 # 构建产物

View File

@@ -87,9 +87,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig) geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient() geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, configConfig) geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, configConfig)
rateLimitService := service.NewRateLimitService(accountRepository, configConfig) geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
rateLimitService := service.NewRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService)
claudeUsageFetcher := repository.NewClaudeUsageFetcher() claudeUsageFetcher := repository.NewClaudeUsageFetcher()
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher) accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService)
geminiTokenCache := repository.NewGeminiTokenCache(redisClient) geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService) geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
gatewayCache := repository.NewGatewayCache(redisClient) gatewayCache := repository.NewGatewayCache(redisClient)
@@ -99,7 +100,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream) antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream)
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, antigravityGatewayService, httpUpstream) accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, antigravityGatewayService, httpUpstream)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.NewConcurrencyService(concurrencyCache) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
oAuthHandler := admin.NewOAuthHandler(oAuthService) oAuthHandler := admin.NewOAuthHandler(oAuthService)
@@ -116,7 +117,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
systemHandler := handler.ProvideSystemHandler(updateService) systemHandler := handler.ProvideSystemHandler(updateService)
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService) adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService) adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler) userAttributeDefinitionRepository := repository.NewUserAttributeDefinitionRepository(client)
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler)
pricingRemoteClient := repository.NewPricingRemoteClient() pricingRemoteClient := repository.NewPricingRemoteClient()
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient) pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
if err != nil { if err != nil {
@@ -127,10 +132,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
identityService := service.NewIdentityService(identityCache) identityService := service.NewIdentityService(identityCache)
timingWheelService := service.ProvideTimingWheelService() timingWheelService := service.ProvideTimingWheelService()
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService) gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService) gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler) handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)

View File

@@ -25,6 +25,8 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
stdsql "database/sql" stdsql "database/sql"
@@ -55,6 +57,10 @@ type Client struct {
User *UserClient User *UserClient
// UserAllowedGroup is the client for interacting with the UserAllowedGroup builders. // UserAllowedGroup is the client for interacting with the UserAllowedGroup builders.
UserAllowedGroup *UserAllowedGroupClient UserAllowedGroup *UserAllowedGroupClient
// UserAttributeDefinition is the client for interacting with the UserAttributeDefinition builders.
UserAttributeDefinition *UserAttributeDefinitionClient
// UserAttributeValue is the client for interacting with the UserAttributeValue builders.
UserAttributeValue *UserAttributeValueClient
// UserSubscription is the client for interacting with the UserSubscription builders. // UserSubscription is the client for interacting with the UserSubscription builders.
UserSubscription *UserSubscriptionClient UserSubscription *UserSubscriptionClient
} }
@@ -78,6 +84,8 @@ func (c *Client) init() {
c.UsageLog = NewUsageLogClient(c.config) c.UsageLog = NewUsageLogClient(c.config)
c.User = NewUserClient(c.config) c.User = NewUserClient(c.config)
c.UserAllowedGroup = NewUserAllowedGroupClient(c.config) c.UserAllowedGroup = NewUserAllowedGroupClient(c.config)
c.UserAttributeDefinition = NewUserAttributeDefinitionClient(c.config)
c.UserAttributeValue = NewUserAttributeValueClient(c.config)
c.UserSubscription = NewUserSubscriptionClient(c.config) c.UserSubscription = NewUserSubscriptionClient(c.config)
} }
@@ -169,19 +177,21 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
cfg := c.config cfg := c.config
cfg.driver = tx cfg.driver = tx
return &Tx{ return &Tx{
ctx: ctx, ctx: ctx,
config: cfg, config: cfg,
Account: NewAccountClient(cfg), Account: NewAccountClient(cfg),
AccountGroup: NewAccountGroupClient(cfg), AccountGroup: NewAccountGroupClient(cfg),
ApiKey: NewApiKeyClient(cfg), ApiKey: NewApiKeyClient(cfg),
Group: NewGroupClient(cfg), Group: NewGroupClient(cfg),
Proxy: NewProxyClient(cfg), Proxy: NewProxyClient(cfg),
RedeemCode: NewRedeemCodeClient(cfg), RedeemCode: NewRedeemCodeClient(cfg),
Setting: NewSettingClient(cfg), Setting: NewSettingClient(cfg),
UsageLog: NewUsageLogClient(cfg), UsageLog: NewUsageLogClient(cfg),
User: NewUserClient(cfg), User: NewUserClient(cfg),
UserAllowedGroup: NewUserAllowedGroupClient(cfg), UserAllowedGroup: NewUserAllowedGroupClient(cfg),
UserSubscription: NewUserSubscriptionClient(cfg), UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
UserAttributeValue: NewUserAttributeValueClient(cfg),
UserSubscription: NewUserSubscriptionClient(cfg),
}, nil }, nil
} }
@@ -199,19 +209,21 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
cfg := c.config cfg := c.config
cfg.driver = &txDriver{tx: tx, drv: c.driver} cfg.driver = &txDriver{tx: tx, drv: c.driver}
return &Tx{ return &Tx{
ctx: ctx, ctx: ctx,
config: cfg, config: cfg,
Account: NewAccountClient(cfg), Account: NewAccountClient(cfg),
AccountGroup: NewAccountGroupClient(cfg), AccountGroup: NewAccountGroupClient(cfg),
ApiKey: NewApiKeyClient(cfg), ApiKey: NewApiKeyClient(cfg),
Group: NewGroupClient(cfg), Group: NewGroupClient(cfg),
Proxy: NewProxyClient(cfg), Proxy: NewProxyClient(cfg),
RedeemCode: NewRedeemCodeClient(cfg), RedeemCode: NewRedeemCodeClient(cfg),
Setting: NewSettingClient(cfg), Setting: NewSettingClient(cfg),
UsageLog: NewUsageLogClient(cfg), UsageLog: NewUsageLogClient(cfg),
User: NewUserClient(cfg), User: NewUserClient(cfg),
UserAllowedGroup: NewUserAllowedGroupClient(cfg), UserAllowedGroup: NewUserAllowedGroupClient(cfg),
UserSubscription: NewUserSubscriptionClient(cfg), UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
UserAttributeValue: NewUserAttributeValueClient(cfg),
UserSubscription: NewUserSubscriptionClient(cfg),
}, nil }, nil
} }
@@ -242,7 +254,8 @@ func (c *Client) Close() error {
func (c *Client) Use(hooks ...Hook) { func (c *Client) Use(hooks ...Hook) {
for _, n := range []interface{ Use(...Hook) }{ for _, n := range []interface{ Use(...Hook) }{
c.Account, c.AccountGroup, c.ApiKey, c.Group, c.Proxy, c.RedeemCode, c.Setting, c.Account, c.AccountGroup, c.ApiKey, c.Group, c.Proxy, c.RedeemCode, c.Setting,
c.UsageLog, c.User, c.UserAllowedGroup, c.UserSubscription, c.UsageLog, c.User, c.UserAllowedGroup, c.UserAttributeDefinition,
c.UserAttributeValue, c.UserSubscription,
} { } {
n.Use(hooks...) n.Use(hooks...)
} }
@@ -253,7 +266,8 @@ func (c *Client) Use(hooks ...Hook) {
func (c *Client) Intercept(interceptors ...Interceptor) { func (c *Client) Intercept(interceptors ...Interceptor) {
for _, n := range []interface{ Intercept(...Interceptor) }{ for _, n := range []interface{ Intercept(...Interceptor) }{
c.Account, c.AccountGroup, c.ApiKey, c.Group, c.Proxy, c.RedeemCode, c.Setting, c.Account, c.AccountGroup, c.ApiKey, c.Group, c.Proxy, c.RedeemCode, c.Setting,
c.UsageLog, c.User, c.UserAllowedGroup, c.UserSubscription, c.UsageLog, c.User, c.UserAllowedGroup, c.UserAttributeDefinition,
c.UserAttributeValue, c.UserSubscription,
} { } {
n.Intercept(interceptors...) n.Intercept(interceptors...)
} }
@@ -282,6 +296,10 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
return c.User.mutate(ctx, m) return c.User.mutate(ctx, m)
case *UserAllowedGroupMutation: case *UserAllowedGroupMutation:
return c.UserAllowedGroup.mutate(ctx, m) return c.UserAllowedGroup.mutate(ctx, m)
case *UserAttributeDefinitionMutation:
return c.UserAttributeDefinition.mutate(ctx, m)
case *UserAttributeValueMutation:
return c.UserAttributeValue.mutate(ctx, m)
case *UserSubscriptionMutation: case *UserSubscriptionMutation:
return c.UserSubscription.mutate(ctx, m) return c.UserSubscription.mutate(ctx, m)
default: default:
@@ -1916,6 +1934,22 @@ func (c *UserClient) QueryUsageLogs(_m *User) *UsageLogQuery {
return query return query
} }
// QueryAttributeValues queries the attribute_values edge of a User.
func (c *UserClient) QueryAttributeValues(_m *User) *UserAttributeValueQuery {
query := (&UserAttributeValueClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID, id),
sqlgraph.To(userattributevalue.Table, userattributevalue.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, user.AttributeValuesTable, user.AttributeValuesColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// QueryUserAllowedGroups queries the user_allowed_groups edge of a User. // QueryUserAllowedGroups queries the user_allowed_groups edge of a User.
func (c *UserClient) QueryUserAllowedGroups(_m *User) *UserAllowedGroupQuery { func (c *UserClient) QueryUserAllowedGroups(_m *User) *UserAllowedGroupQuery {
query := (&UserAllowedGroupClient{config: c.config}).Query() query := (&UserAllowedGroupClient{config: c.config}).Query()
@@ -2075,6 +2109,322 @@ func (c *UserAllowedGroupClient) mutate(ctx context.Context, m *UserAllowedGroup
} }
} }
// UserAttributeDefinitionClient is a client for the UserAttributeDefinition schema.
type UserAttributeDefinitionClient struct {
config
}
// NewUserAttributeDefinitionClient returns a client for the UserAttributeDefinition from the given config.
func NewUserAttributeDefinitionClient(c config) *UserAttributeDefinitionClient {
return &UserAttributeDefinitionClient{config: c}
}
// Use adds a list of mutation hooks to the hooks stack.
// A call to `Use(f, g, h)` equals to `userattributedefinition.Hooks(f(g(h())))`.
func (c *UserAttributeDefinitionClient) Use(hooks ...Hook) {
c.hooks.UserAttributeDefinition = append(c.hooks.UserAttributeDefinition, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `userattributedefinition.Intercept(f(g(h())))`.
func (c *UserAttributeDefinitionClient) Intercept(interceptors ...Interceptor) {
c.inters.UserAttributeDefinition = append(c.inters.UserAttributeDefinition, interceptors...)
}
// Create returns a builder for creating a UserAttributeDefinition entity.
func (c *UserAttributeDefinitionClient) Create() *UserAttributeDefinitionCreate {
mutation := newUserAttributeDefinitionMutation(c.config, OpCreate)
return &UserAttributeDefinitionCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// CreateBulk returns a builder for creating a bulk of UserAttributeDefinition entities.
func (c *UserAttributeDefinitionClient) CreateBulk(builders ...*UserAttributeDefinitionCreate) *UserAttributeDefinitionCreateBulk {
return &UserAttributeDefinitionCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *UserAttributeDefinitionClient) MapCreateBulk(slice any, setFunc func(*UserAttributeDefinitionCreate, int)) *UserAttributeDefinitionCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &UserAttributeDefinitionCreateBulk{err: fmt.Errorf("calling to UserAttributeDefinitionClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*UserAttributeDefinitionCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &UserAttributeDefinitionCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for UserAttributeDefinition.
func (c *UserAttributeDefinitionClient) Update() *UserAttributeDefinitionUpdate {
mutation := newUserAttributeDefinitionMutation(c.config, OpUpdate)
return &UserAttributeDefinitionUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOne returns an update builder for the given entity.
func (c *UserAttributeDefinitionClient) UpdateOne(_m *UserAttributeDefinition) *UserAttributeDefinitionUpdateOne {
mutation := newUserAttributeDefinitionMutation(c.config, OpUpdateOne, withUserAttributeDefinition(_m))
return &UserAttributeDefinitionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOneID returns an update builder for the given id.
func (c *UserAttributeDefinitionClient) UpdateOneID(id int64) *UserAttributeDefinitionUpdateOne {
mutation := newUserAttributeDefinitionMutation(c.config, OpUpdateOne, withUserAttributeDefinitionID(id))
return &UserAttributeDefinitionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// Delete returns a delete builder for UserAttributeDefinition.
func (c *UserAttributeDefinitionClient) Delete() *UserAttributeDefinitionDelete {
mutation := newUserAttributeDefinitionMutation(c.config, OpDelete)
return &UserAttributeDefinitionDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// DeleteOne returns a builder for deleting the given entity.
func (c *UserAttributeDefinitionClient) DeleteOne(_m *UserAttributeDefinition) *UserAttributeDefinitionDeleteOne {
return c.DeleteOneID(_m.ID)
}
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *UserAttributeDefinitionClient) DeleteOneID(id int64) *UserAttributeDefinitionDeleteOne {
builder := c.Delete().Where(userattributedefinition.ID(id))
builder.mutation.id = &id
builder.mutation.op = OpDeleteOne
return &UserAttributeDefinitionDeleteOne{builder}
}
// Query returns a query builder for UserAttributeDefinition.
func (c *UserAttributeDefinitionClient) Query() *UserAttributeDefinitionQuery {
return &UserAttributeDefinitionQuery{
config: c.config,
ctx: &QueryContext{Type: TypeUserAttributeDefinition},
inters: c.Interceptors(),
}
}
// Get returns a UserAttributeDefinition entity by its id.
func (c *UserAttributeDefinitionClient) Get(ctx context.Context, id int64) (*UserAttributeDefinition, error) {
return c.Query().Where(userattributedefinition.ID(id)).Only(ctx)
}
// GetX is like Get, but panics if an error occurs.
func (c *UserAttributeDefinitionClient) GetX(ctx context.Context, id int64) *UserAttributeDefinition {
obj, err := c.Get(ctx, id)
if err != nil {
panic(err)
}
return obj
}
// QueryValues queries the values edge of a UserAttributeDefinition.
func (c *UserAttributeDefinitionClient) QueryValues(_m *UserAttributeDefinition) *UserAttributeValueQuery {
query := (&UserAttributeValueClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(userattributedefinition.Table, userattributedefinition.FieldID, id),
sqlgraph.To(userattributevalue.Table, userattributevalue.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, userattributedefinition.ValuesTable, userattributedefinition.ValuesColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// Hooks returns the client hooks.
func (c *UserAttributeDefinitionClient) Hooks() []Hook {
hooks := c.hooks.UserAttributeDefinition
return append(hooks[:len(hooks):len(hooks)], userattributedefinition.Hooks[:]...)
}
// Interceptors returns the client interceptors.
func (c *UserAttributeDefinitionClient) Interceptors() []Interceptor {
inters := c.inters.UserAttributeDefinition
return append(inters[:len(inters):len(inters)], userattributedefinition.Interceptors[:]...)
}
func (c *UserAttributeDefinitionClient) mutate(ctx context.Context, m *UserAttributeDefinitionMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&UserAttributeDefinitionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&UserAttributeDefinitionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&UserAttributeDefinitionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&UserAttributeDefinitionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown UserAttributeDefinition mutation op: %q", m.Op())
}
}
// UserAttributeValueClient is a client for the UserAttributeValue schema.
type UserAttributeValueClient struct {
config
}
// NewUserAttributeValueClient returns a client for the UserAttributeValue from the given config.
func NewUserAttributeValueClient(c config) *UserAttributeValueClient {
return &UserAttributeValueClient{config: c}
}
// Use adds a list of mutation hooks to the hooks stack.
// A call to `Use(f, g, h)` equals to `userattributevalue.Hooks(f(g(h())))`.
func (c *UserAttributeValueClient) Use(hooks ...Hook) {
c.hooks.UserAttributeValue = append(c.hooks.UserAttributeValue, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `userattributevalue.Intercept(f(g(h())))`.
func (c *UserAttributeValueClient) Intercept(interceptors ...Interceptor) {
c.inters.UserAttributeValue = append(c.inters.UserAttributeValue, interceptors...)
}
// Create returns a builder for creating a UserAttributeValue entity.
func (c *UserAttributeValueClient) Create() *UserAttributeValueCreate {
mutation := newUserAttributeValueMutation(c.config, OpCreate)
return &UserAttributeValueCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// CreateBulk returns a builder for creating a bulk of UserAttributeValue entities.
func (c *UserAttributeValueClient) CreateBulk(builders ...*UserAttributeValueCreate) *UserAttributeValueCreateBulk {
return &UserAttributeValueCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *UserAttributeValueClient) MapCreateBulk(slice any, setFunc func(*UserAttributeValueCreate, int)) *UserAttributeValueCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &UserAttributeValueCreateBulk{err: fmt.Errorf("calling to UserAttributeValueClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*UserAttributeValueCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &UserAttributeValueCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for UserAttributeValue.
func (c *UserAttributeValueClient) Update() *UserAttributeValueUpdate {
mutation := newUserAttributeValueMutation(c.config, OpUpdate)
return &UserAttributeValueUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOne returns an update builder for the given entity.
func (c *UserAttributeValueClient) UpdateOne(_m *UserAttributeValue) *UserAttributeValueUpdateOne {
mutation := newUserAttributeValueMutation(c.config, OpUpdateOne, withUserAttributeValue(_m))
return &UserAttributeValueUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOneID returns an update builder for the given id.
func (c *UserAttributeValueClient) UpdateOneID(id int64) *UserAttributeValueUpdateOne {
mutation := newUserAttributeValueMutation(c.config, OpUpdateOne, withUserAttributeValueID(id))
return &UserAttributeValueUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// Delete returns a delete builder for UserAttributeValue.
func (c *UserAttributeValueClient) Delete() *UserAttributeValueDelete {
mutation := newUserAttributeValueMutation(c.config, OpDelete)
return &UserAttributeValueDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// DeleteOne returns a builder for deleting the given entity.
func (c *UserAttributeValueClient) DeleteOne(_m *UserAttributeValue) *UserAttributeValueDeleteOne {
return c.DeleteOneID(_m.ID)
}
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *UserAttributeValueClient) DeleteOneID(id int64) *UserAttributeValueDeleteOne {
builder := c.Delete().Where(userattributevalue.ID(id))
builder.mutation.id = &id
builder.mutation.op = OpDeleteOne
return &UserAttributeValueDeleteOne{builder}
}
// Query returns a query builder for UserAttributeValue.
func (c *UserAttributeValueClient) Query() *UserAttributeValueQuery {
return &UserAttributeValueQuery{
config: c.config,
ctx: &QueryContext{Type: TypeUserAttributeValue},
inters: c.Interceptors(),
}
}
// Get returns a UserAttributeValue entity by its id.
func (c *UserAttributeValueClient) Get(ctx context.Context, id int64) (*UserAttributeValue, error) {
return c.Query().Where(userattributevalue.ID(id)).Only(ctx)
}
// GetX is like Get, but panics if an error occurs.
func (c *UserAttributeValueClient) GetX(ctx context.Context, id int64) *UserAttributeValue {
obj, err := c.Get(ctx, id)
if err != nil {
panic(err)
}
return obj
}
// QueryUser queries the user edge of a UserAttributeValue.
func (c *UserAttributeValueClient) QueryUser(_m *UserAttributeValue) *UserQuery {
query := (&UserClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(userattributevalue.Table, userattributevalue.FieldID, id),
sqlgraph.To(user.Table, user.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, userattributevalue.UserTable, userattributevalue.UserColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// QueryDefinition queries the definition edge of a UserAttributeValue.
func (c *UserAttributeValueClient) QueryDefinition(_m *UserAttributeValue) *UserAttributeDefinitionQuery {
query := (&UserAttributeDefinitionClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(userattributevalue.Table, userattributevalue.FieldID, id),
sqlgraph.To(userattributedefinition.Table, userattributedefinition.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, userattributevalue.DefinitionTable, userattributevalue.DefinitionColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// Hooks returns the client hooks.
func (c *UserAttributeValueClient) Hooks() []Hook {
return c.hooks.UserAttributeValue
}
// Interceptors returns the client interceptors.
func (c *UserAttributeValueClient) Interceptors() []Interceptor {
return c.inters.UserAttributeValue
}
func (c *UserAttributeValueClient) mutate(ctx context.Context, m *UserAttributeValueMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&UserAttributeValueCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&UserAttributeValueUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&UserAttributeValueUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&UserAttributeValueDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown UserAttributeValue mutation op: %q", m.Op())
}
}
// UserSubscriptionClient is a client for the UserSubscription schema. // UserSubscriptionClient is a client for the UserSubscription schema.
type UserSubscriptionClient struct { type UserSubscriptionClient struct {
config config
@@ -2278,11 +2628,13 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription
type ( type (
hooks struct { hooks struct {
Account, AccountGroup, ApiKey, Group, Proxy, RedeemCode, Setting, UsageLog, Account, AccountGroup, ApiKey, Group, Proxy, RedeemCode, Setting, UsageLog,
User, UserAllowedGroup, UserSubscription []ent.Hook User, UserAllowedGroup, UserAttributeDefinition, UserAttributeValue,
UserSubscription []ent.Hook
} }
inters struct { inters struct {
Account, AccountGroup, ApiKey, Group, Proxy, RedeemCode, Setting, UsageLog, Account, AccountGroup, ApiKey, Group, Proxy, RedeemCode, Setting, UsageLog,
User, UserAllowedGroup, UserSubscription []ent.Interceptor User, UserAllowedGroup, UserAttributeDefinition, UserAttributeValue,
UserSubscription []ent.Interceptor
} }
) )

View File

@@ -22,6 +22,8 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
) )
@@ -83,17 +85,19 @@ var (
func checkColumn(t, c string) error { func checkColumn(t, c string) error {
initCheck.Do(func() { initCheck.Do(func() {
columnCheck = sql.NewColumnCheck(map[string]func(string) bool{ columnCheck = sql.NewColumnCheck(map[string]func(string) bool{
account.Table: account.ValidColumn, account.Table: account.ValidColumn,
accountgroup.Table: accountgroup.ValidColumn, accountgroup.Table: accountgroup.ValidColumn,
apikey.Table: apikey.ValidColumn, apikey.Table: apikey.ValidColumn,
group.Table: group.ValidColumn, group.Table: group.ValidColumn,
proxy.Table: proxy.ValidColumn, proxy.Table: proxy.ValidColumn,
redeemcode.Table: redeemcode.ValidColumn, redeemcode.Table: redeemcode.ValidColumn,
setting.Table: setting.ValidColumn, setting.Table: setting.ValidColumn,
usagelog.Table: usagelog.ValidColumn, usagelog.Table: usagelog.ValidColumn,
user.Table: user.ValidColumn, user.Table: user.ValidColumn,
userallowedgroup.Table: userallowedgroup.ValidColumn, userallowedgroup.Table: userallowedgroup.ValidColumn,
usersubscription.Table: usersubscription.ValidColumn, userattributedefinition.Table: userattributedefinition.ValidColumn,
userattributevalue.Table: userattributevalue.ValidColumn,
usersubscription.Table: usersubscription.ValidColumn,
}) })
}) })
return columnCheck(t, c) return columnCheck(t, c)

View File

@@ -129,6 +129,30 @@ func (f UserAllowedGroupFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.V
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UserAllowedGroupMutation", m) return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UserAllowedGroupMutation", m)
} }
// The UserAttributeDefinitionFunc type is an adapter to allow the use of ordinary
// function as UserAttributeDefinition mutator.
type UserAttributeDefinitionFunc func(context.Context, *ent.UserAttributeDefinitionMutation) (ent.Value, error)
// Mutate calls f(ctx, m).
func (f UserAttributeDefinitionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
if mv, ok := m.(*ent.UserAttributeDefinitionMutation); ok {
return f(ctx, mv)
}
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UserAttributeDefinitionMutation", m)
}
// The UserAttributeValueFunc type is an adapter to allow the use of ordinary
// function as UserAttributeValue mutator.
type UserAttributeValueFunc func(context.Context, *ent.UserAttributeValueMutation) (ent.Value, error)
// Mutate calls f(ctx, m).
func (f UserAttributeValueFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
if mv, ok := m.(*ent.UserAttributeValueMutation); ok {
return f(ctx, mv)
}
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UserAttributeValueMutation", m)
}
// The UserSubscriptionFunc type is an adapter to allow the use of ordinary // The UserSubscriptionFunc type is an adapter to allow the use of ordinary
// function as UserSubscription mutator. // function as UserSubscription mutator.
type UserSubscriptionFunc func(context.Context, *ent.UserSubscriptionMutation) (ent.Value, error) type UserSubscriptionFunc func(context.Context, *ent.UserSubscriptionMutation) (ent.Value, error)

View File

@@ -19,6 +19,8 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
) )
@@ -348,6 +350,60 @@ func (f TraverseUserAllowedGroup) Traverse(ctx context.Context, q ent.Query) err
return fmt.Errorf("unexpected query type %T. expect *ent.UserAllowedGroupQuery", q) return fmt.Errorf("unexpected query type %T. expect *ent.UserAllowedGroupQuery", q)
} }
// The UserAttributeDefinitionFunc type is an adapter to allow the use of ordinary function as a Querier.
type UserAttributeDefinitionFunc func(context.Context, *ent.UserAttributeDefinitionQuery) (ent.Value, error)
// Query calls f(ctx, q).
func (f UserAttributeDefinitionFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
if q, ok := q.(*ent.UserAttributeDefinitionQuery); ok {
return f(ctx, q)
}
return nil, fmt.Errorf("unexpected query type %T. expect *ent.UserAttributeDefinitionQuery", q)
}
// The TraverseUserAttributeDefinition type is an adapter to allow the use of ordinary function as Traverser.
type TraverseUserAttributeDefinition func(context.Context, *ent.UserAttributeDefinitionQuery) error
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
func (f TraverseUserAttributeDefinition) Intercept(next ent.Querier) ent.Querier {
return next
}
// Traverse calls f(ctx, q).
func (f TraverseUserAttributeDefinition) Traverse(ctx context.Context, q ent.Query) error {
if q, ok := q.(*ent.UserAttributeDefinitionQuery); ok {
return f(ctx, q)
}
return fmt.Errorf("unexpected query type %T. expect *ent.UserAttributeDefinitionQuery", q)
}
// The UserAttributeValueFunc type is an adapter to allow the use of ordinary function as a Querier.
type UserAttributeValueFunc func(context.Context, *ent.UserAttributeValueQuery) (ent.Value, error)
// Query calls f(ctx, q).
func (f UserAttributeValueFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
if q, ok := q.(*ent.UserAttributeValueQuery); ok {
return f(ctx, q)
}
return nil, fmt.Errorf("unexpected query type %T. expect *ent.UserAttributeValueQuery", q)
}
// The TraverseUserAttributeValue type is an adapter to allow the use of ordinary function as Traverser.
type TraverseUserAttributeValue func(context.Context, *ent.UserAttributeValueQuery) error
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
func (f TraverseUserAttributeValue) Intercept(next ent.Querier) ent.Querier {
return next
}
// Traverse calls f(ctx, q).
func (f TraverseUserAttributeValue) Traverse(ctx context.Context, q ent.Query) error {
if q, ok := q.(*ent.UserAttributeValueQuery); ok {
return f(ctx, q)
}
return fmt.Errorf("unexpected query type %T. expect *ent.UserAttributeValueQuery", q)
}
// The UserSubscriptionFunc type is an adapter to allow the use of ordinary function as a Querier. // The UserSubscriptionFunc type is an adapter to allow the use of ordinary function as a Querier.
type UserSubscriptionFunc func(context.Context, *ent.UserSubscriptionQuery) (ent.Value, error) type UserSubscriptionFunc func(context.Context, *ent.UserSubscriptionQuery) (ent.Value, error)
@@ -398,6 +454,10 @@ func NewQuery(q ent.Query) (Query, error) {
return &query[*ent.UserQuery, predicate.User, user.OrderOption]{typ: ent.TypeUser, tq: q}, nil return &query[*ent.UserQuery, predicate.User, user.OrderOption]{typ: ent.TypeUser, tq: q}, nil
case *ent.UserAllowedGroupQuery: case *ent.UserAllowedGroupQuery:
return &query[*ent.UserAllowedGroupQuery, predicate.UserAllowedGroup, userallowedgroup.OrderOption]{typ: ent.TypeUserAllowedGroup, tq: q}, nil return &query[*ent.UserAllowedGroupQuery, predicate.UserAllowedGroup, userallowedgroup.OrderOption]{typ: ent.TypeUserAllowedGroup, tq: q}, nil
case *ent.UserAttributeDefinitionQuery:
return &query[*ent.UserAttributeDefinitionQuery, predicate.UserAttributeDefinition, userattributedefinition.OrderOption]{typ: ent.TypeUserAttributeDefinition, tq: q}, nil
case *ent.UserAttributeValueQuery:
return &query[*ent.UserAttributeValueQuery, predicate.UserAttributeValue, userattributevalue.OrderOption]{typ: ent.TypeUserAttributeValue, tq: q}, nil
case *ent.UserSubscriptionQuery: case *ent.UserSubscriptionQuery:
return &query[*ent.UserSubscriptionQuery, predicate.UserSubscription, usersubscription.OrderOption]{typ: ent.TypeUserSubscription, tq: q}, nil return &query[*ent.UserSubscriptionQuery, predicate.UserSubscription, usersubscription.OrderOption]{typ: ent.TypeUserSubscription, tq: q}, nil
default: default:

View File

@@ -477,7 +477,6 @@ var (
{Name: "concurrency", Type: field.TypeInt, Default: 5}, {Name: "concurrency", Type: field.TypeInt, Default: 5},
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"}, {Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
{Name: "username", Type: field.TypeString, Size: 100, Default: ""}, {Name: "username", Type: field.TypeString, Size: 100, Default: ""},
{Name: "wechat", Type: field.TypeString, Size: 100, Default: ""},
{Name: "notes", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, {Name: "notes", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
} }
// UsersTable holds the schema information for the "users" table. // UsersTable holds the schema information for the "users" table.
@@ -531,6 +530,92 @@ var (
}, },
}, },
} }
// UserAttributeDefinitionsColumns holds the columns for the "user_attribute_definitions" table.
UserAttributeDefinitionsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "key", Type: field.TypeString, Size: 100},
{Name: "name", Type: field.TypeString, Size: 255},
{Name: "description", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
{Name: "type", Type: field.TypeString, Size: 20},
{Name: "options", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "required", Type: field.TypeBool, Default: false},
{Name: "validation", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "placeholder", Type: field.TypeString, Size: 255, Default: ""},
{Name: "display_order", Type: field.TypeInt, Default: 0},
{Name: "enabled", Type: field.TypeBool, Default: true},
}
// UserAttributeDefinitionsTable holds the schema information for the "user_attribute_definitions" table.
UserAttributeDefinitionsTable = &schema.Table{
Name: "user_attribute_definitions",
Columns: UserAttributeDefinitionsColumns,
PrimaryKey: []*schema.Column{UserAttributeDefinitionsColumns[0]},
Indexes: []*schema.Index{
{
Name: "userattributedefinition_key",
Unique: false,
Columns: []*schema.Column{UserAttributeDefinitionsColumns[4]},
},
{
Name: "userattributedefinition_enabled",
Unique: false,
Columns: []*schema.Column{UserAttributeDefinitionsColumns[13]},
},
{
Name: "userattributedefinition_display_order",
Unique: false,
Columns: []*schema.Column{UserAttributeDefinitionsColumns[12]},
},
{
Name: "userattributedefinition_deleted_at",
Unique: false,
Columns: []*schema.Column{UserAttributeDefinitionsColumns[3]},
},
},
}
// UserAttributeValuesColumns holds the columns for the "user_attribute_values" table.
UserAttributeValuesColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "value", Type: field.TypeString, Size: 2147483647, Default: ""},
{Name: "user_id", Type: field.TypeInt64},
{Name: "attribute_id", Type: field.TypeInt64},
}
// UserAttributeValuesTable holds the schema information for the "user_attribute_values" table.
UserAttributeValuesTable = &schema.Table{
Name: "user_attribute_values",
Columns: UserAttributeValuesColumns,
PrimaryKey: []*schema.Column{UserAttributeValuesColumns[0]},
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "user_attribute_values_users_attribute_values",
Columns: []*schema.Column{UserAttributeValuesColumns[4]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "user_attribute_values_user_attribute_definitions_values",
Columns: []*schema.Column{UserAttributeValuesColumns[5]},
RefColumns: []*schema.Column{UserAttributeDefinitionsColumns[0]},
OnDelete: schema.NoAction,
},
},
Indexes: []*schema.Index{
{
Name: "userattributevalue_user_id_attribute_id",
Unique: true,
Columns: []*schema.Column{UserAttributeValuesColumns[4], UserAttributeValuesColumns[5]},
},
{
Name: "userattributevalue_attribute_id",
Unique: false,
Columns: []*schema.Column{UserAttributeValuesColumns[5]},
},
},
}
// UserSubscriptionsColumns holds the columns for the "user_subscriptions" table. // UserSubscriptionsColumns holds the columns for the "user_subscriptions" table.
UserSubscriptionsColumns = []*schema.Column{ UserSubscriptionsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true}, {Name: "id", Type: field.TypeInt64, Increment: true},
@@ -627,6 +712,8 @@ var (
UsageLogsTable, UsageLogsTable,
UsersTable, UsersTable,
UserAllowedGroupsTable, UserAllowedGroupsTable,
UserAttributeDefinitionsTable,
UserAttributeValuesTable,
UserSubscriptionsTable, UserSubscriptionsTable,
} }
) )
@@ -676,6 +763,14 @@ func init() {
UserAllowedGroupsTable.Annotation = &entsql.Annotation{ UserAllowedGroupsTable.Annotation = &entsql.Annotation{
Table: "user_allowed_groups", Table: "user_allowed_groups",
} }
UserAttributeDefinitionsTable.Annotation = &entsql.Annotation{
Table: "user_attribute_definitions",
}
UserAttributeValuesTable.ForeignKeys[0].RefTable = UsersTable
UserAttributeValuesTable.ForeignKeys[1].RefTable = UserAttributeDefinitionsTable
UserAttributeValuesTable.Annotation = &entsql.Annotation{
Table: "user_attribute_values",
}
UserSubscriptionsTable.ForeignKeys[0].RefTable = GroupsTable UserSubscriptionsTable.ForeignKeys[0].RefTable = GroupsTable
UserSubscriptionsTable.ForeignKeys[1].RefTable = UsersTable UserSubscriptionsTable.ForeignKeys[1].RefTable = UsersTable
UserSubscriptionsTable.ForeignKeys[2].RefTable = UsersTable UserSubscriptionsTable.ForeignKeys[2].RefTable = UsersTable

File diff suppressed because it is too large Load Diff

View File

@@ -36,5 +36,11 @@ type User func(*sql.Selector)
// UserAllowedGroup is the predicate function for userallowedgroup builders. // UserAllowedGroup is the predicate function for userallowedgroup builders.
type UserAllowedGroup func(*sql.Selector) type UserAllowedGroup func(*sql.Selector)
// UserAttributeDefinition is the predicate function for userattributedefinition builders.
type UserAttributeDefinition func(*sql.Selector)
// UserAttributeValue is the predicate function for userattributevalue builders.
type UserAttributeValue func(*sql.Selector)
// UserSubscription is the predicate function for usersubscription builders. // UserSubscription is the predicate function for usersubscription builders.
type UserSubscription func(*sql.Selector) type UserSubscription func(*sql.Selector)

View File

@@ -16,6 +16,8 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
) )
@@ -604,14 +606,8 @@ func init() {
user.DefaultUsername = userDescUsername.Default.(string) user.DefaultUsername = userDescUsername.Default.(string)
// user.UsernameValidator is a validator for the "username" field. It is called by the builders before save. // user.UsernameValidator is a validator for the "username" field. It is called by the builders before save.
user.UsernameValidator = userDescUsername.Validators[0].(func(string) error) user.UsernameValidator = userDescUsername.Validators[0].(func(string) error)
// userDescWechat is the schema descriptor for wechat field.
userDescWechat := userFields[7].Descriptor()
// user.DefaultWechat holds the default value on creation for the wechat field.
user.DefaultWechat = userDescWechat.Default.(string)
// user.WechatValidator is a validator for the "wechat" field. It is called by the builders before save.
user.WechatValidator = userDescWechat.Validators[0].(func(string) error)
// userDescNotes is the schema descriptor for notes field. // userDescNotes is the schema descriptor for notes field.
userDescNotes := userFields[8].Descriptor() userDescNotes := userFields[7].Descriptor()
// user.DefaultNotes holds the default value on creation for the notes field. // user.DefaultNotes holds the default value on creation for the notes field.
user.DefaultNotes = userDescNotes.Default.(string) user.DefaultNotes = userDescNotes.Default.(string)
userallowedgroupFields := schema.UserAllowedGroup{}.Fields() userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
@@ -620,6 +616,128 @@ func init() {
userallowedgroupDescCreatedAt := userallowedgroupFields[2].Descriptor() userallowedgroupDescCreatedAt := userallowedgroupFields[2].Descriptor()
// userallowedgroup.DefaultCreatedAt holds the default value on creation for the created_at field. // userallowedgroup.DefaultCreatedAt holds the default value on creation for the created_at field.
userallowedgroup.DefaultCreatedAt = userallowedgroupDescCreatedAt.Default.(func() time.Time) userallowedgroup.DefaultCreatedAt = userallowedgroupDescCreatedAt.Default.(func() time.Time)
userattributedefinitionMixin := schema.UserAttributeDefinition{}.Mixin()
userattributedefinitionMixinHooks1 := userattributedefinitionMixin[1].Hooks()
userattributedefinition.Hooks[0] = userattributedefinitionMixinHooks1[0]
userattributedefinitionMixinInters1 := userattributedefinitionMixin[1].Interceptors()
userattributedefinition.Interceptors[0] = userattributedefinitionMixinInters1[0]
userattributedefinitionMixinFields0 := userattributedefinitionMixin[0].Fields()
_ = userattributedefinitionMixinFields0
userattributedefinitionFields := schema.UserAttributeDefinition{}.Fields()
_ = userattributedefinitionFields
// userattributedefinitionDescCreatedAt is the schema descriptor for created_at field.
userattributedefinitionDescCreatedAt := userattributedefinitionMixinFields0[0].Descriptor()
// userattributedefinition.DefaultCreatedAt holds the default value on creation for the created_at field.
userattributedefinition.DefaultCreatedAt = userattributedefinitionDescCreatedAt.Default.(func() time.Time)
// userattributedefinitionDescUpdatedAt is the schema descriptor for updated_at field.
userattributedefinitionDescUpdatedAt := userattributedefinitionMixinFields0[1].Descriptor()
// userattributedefinition.DefaultUpdatedAt holds the default value on creation for the updated_at field.
userattributedefinition.DefaultUpdatedAt = userattributedefinitionDescUpdatedAt.Default.(func() time.Time)
// userattributedefinition.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
userattributedefinition.UpdateDefaultUpdatedAt = userattributedefinitionDescUpdatedAt.UpdateDefault.(func() time.Time)
// userattributedefinitionDescKey is the schema descriptor for key field.
userattributedefinitionDescKey := userattributedefinitionFields[0].Descriptor()
// userattributedefinition.KeyValidator is a validator for the "key" field. It is called by the builders before save.
userattributedefinition.KeyValidator = func() func(string) error {
validators := userattributedefinitionDescKey.Validators
fns := [...]func(string) error{
validators[0].(func(string) error),
validators[1].(func(string) error),
}
return func(key string) error {
for _, fn := range fns {
if err := fn(key); err != nil {
return err
}
}
return nil
}
}()
// userattributedefinitionDescName is the schema descriptor for name field.
userattributedefinitionDescName := userattributedefinitionFields[1].Descriptor()
// userattributedefinition.NameValidator is a validator for the "name" field. It is called by the builders before save.
userattributedefinition.NameValidator = func() func(string) error {
validators := userattributedefinitionDescName.Validators
fns := [...]func(string) error{
validators[0].(func(string) error),
validators[1].(func(string) error),
}
return func(name string) error {
for _, fn := range fns {
if err := fn(name); err != nil {
return err
}
}
return nil
}
}()
// userattributedefinitionDescDescription is the schema descriptor for description field.
userattributedefinitionDescDescription := userattributedefinitionFields[2].Descriptor()
// userattributedefinition.DefaultDescription holds the default value on creation for the description field.
userattributedefinition.DefaultDescription = userattributedefinitionDescDescription.Default.(string)
// userattributedefinitionDescType is the schema descriptor for type field.
userattributedefinitionDescType := userattributedefinitionFields[3].Descriptor()
// userattributedefinition.TypeValidator is a validator for the "type" field. It is called by the builders before save.
userattributedefinition.TypeValidator = func() func(string) error {
validators := userattributedefinitionDescType.Validators
fns := [...]func(string) error{
validators[0].(func(string) error),
validators[1].(func(string) error),
}
return func(_type string) error {
for _, fn := range fns {
if err := fn(_type); err != nil {
return err
}
}
return nil
}
}()
// userattributedefinitionDescOptions is the schema descriptor for options field.
userattributedefinitionDescOptions := userattributedefinitionFields[4].Descriptor()
// userattributedefinition.DefaultOptions holds the default value on creation for the options field.
userattributedefinition.DefaultOptions = userattributedefinitionDescOptions.Default.([]map[string]interface{})
// userattributedefinitionDescRequired is the schema descriptor for required field.
userattributedefinitionDescRequired := userattributedefinitionFields[5].Descriptor()
// userattributedefinition.DefaultRequired holds the default value on creation for the required field.
userattributedefinition.DefaultRequired = userattributedefinitionDescRequired.Default.(bool)
// userattributedefinitionDescValidation is the schema descriptor for validation field.
userattributedefinitionDescValidation := userattributedefinitionFields[6].Descriptor()
// userattributedefinition.DefaultValidation holds the default value on creation for the validation field.
userattributedefinition.DefaultValidation = userattributedefinitionDescValidation.Default.(map[string]interface{})
// userattributedefinitionDescPlaceholder is the schema descriptor for placeholder field.
userattributedefinitionDescPlaceholder := userattributedefinitionFields[7].Descriptor()
// userattributedefinition.DefaultPlaceholder holds the default value on creation for the placeholder field.
userattributedefinition.DefaultPlaceholder = userattributedefinitionDescPlaceholder.Default.(string)
// userattributedefinition.PlaceholderValidator is a validator for the "placeholder" field. It is called by the builders before save.
userattributedefinition.PlaceholderValidator = userattributedefinitionDescPlaceholder.Validators[0].(func(string) error)
// userattributedefinitionDescDisplayOrder is the schema descriptor for display_order field.
userattributedefinitionDescDisplayOrder := userattributedefinitionFields[8].Descriptor()
// userattributedefinition.DefaultDisplayOrder holds the default value on creation for the display_order field.
userattributedefinition.DefaultDisplayOrder = userattributedefinitionDescDisplayOrder.Default.(int)
// userattributedefinitionDescEnabled is the schema descriptor for enabled field.
userattributedefinitionDescEnabled := userattributedefinitionFields[9].Descriptor()
// userattributedefinition.DefaultEnabled holds the default value on creation for the enabled field.
userattributedefinition.DefaultEnabled = userattributedefinitionDescEnabled.Default.(bool)
userattributevalueMixin := schema.UserAttributeValue{}.Mixin()
userattributevalueMixinFields0 := userattributevalueMixin[0].Fields()
_ = userattributevalueMixinFields0
userattributevalueFields := schema.UserAttributeValue{}.Fields()
_ = userattributevalueFields
// userattributevalueDescCreatedAt is the schema descriptor for created_at field.
userattributevalueDescCreatedAt := userattributevalueMixinFields0[0].Descriptor()
// userattributevalue.DefaultCreatedAt holds the default value on creation for the created_at field.
userattributevalue.DefaultCreatedAt = userattributevalueDescCreatedAt.Default.(func() time.Time)
// userattributevalueDescUpdatedAt is the schema descriptor for updated_at field.
userattributevalueDescUpdatedAt := userattributevalueMixinFields0[1].Descriptor()
// userattributevalue.DefaultUpdatedAt holds the default value on creation for the updated_at field.
userattributevalue.DefaultUpdatedAt = userattributevalueDescUpdatedAt.Default.(func() time.Time)
// userattributevalue.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
userattributevalue.UpdateDefaultUpdatedAt = userattributevalueDescUpdatedAt.UpdateDefault.(func() time.Time)
// userattributevalueDescValue is the schema descriptor for value field.
userattributevalueDescValue := userattributevalueFields[2].Descriptor()
// userattributevalue.DefaultValue holds the default value on creation for the value field.
userattributevalue.DefaultValue = userattributevalueDescValue.Default.(string)
usersubscriptionMixin := schema.UserSubscription{}.Mixin() usersubscriptionMixin := schema.UserSubscription{}.Mixin()
usersubscriptionMixinHooks1 := usersubscriptionMixin[1].Hooks() usersubscriptionMixinHooks1 := usersubscriptionMixin[1].Hooks()
usersubscription.Hooks[0] = usersubscriptionMixinHooks1[0] usersubscription.Hooks[0] = usersubscriptionMixinHooks1[0]

View File

@@ -57,9 +57,7 @@ func (User) Fields() []ent.Field {
field.String("username"). field.String("username").
MaxLen(100). MaxLen(100).
Default(""), Default(""),
field.String("wechat"). // wechat field migrated to user_attribute_values (see migration 019)
MaxLen(100).
Default(""),
field.String("notes"). field.String("notes").
SchemaType(map[string]string{dialect.Postgres: "text"}). SchemaType(map[string]string{dialect.Postgres: "text"}).
Default(""), Default(""),
@@ -75,6 +73,7 @@ func (User) Edges() []ent.Edge {
edge.To("allowed_groups", Group.Type). edge.To("allowed_groups", Group.Type).
Through("user_allowed_groups", UserAllowedGroup.Type), Through("user_allowed_groups", UserAllowedGroup.Type),
edge.To("usage_logs", UsageLog.Type), edge.To("usage_logs", UsageLog.Type),
edge.To("attribute_values", UserAttributeValue.Type),
} }
} }

View File

@@ -0,0 +1,109 @@
package schema
import (
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/edge"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
// UserAttributeDefinition holds the schema definition for custom user attributes.
//
// This entity defines the metadata for user attributes, such as:
// - Attribute key (unique identifier like "company_name")
// - Display name shown in forms
// - Field type (text, number, select, etc.)
// - Validation rules
// - Whether the field is required or enabled
type UserAttributeDefinition struct {
ent.Schema
}
func (UserAttributeDefinition) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "user_attribute_definitions"},
}
}
func (UserAttributeDefinition) Mixin() []ent.Mixin {
return []ent.Mixin{
mixins.TimeMixin{},
mixins.SoftDeleteMixin{},
}
}
func (UserAttributeDefinition) Fields() []ent.Field {
return []ent.Field{
// key: Unique identifier for the attribute (e.g., "company_name")
// Used for programmatic reference
field.String("key").
MaxLen(100).
NotEmpty(),
// name: Display name shown in forms (e.g., "Company Name")
field.String("name").
MaxLen(255).
NotEmpty(),
// description: Optional description/help text for the attribute
field.String("description").
SchemaType(map[string]string{dialect.Postgres: "text"}).
Default(""),
// type: Attribute type - text, textarea, number, email, url, date, select, multi_select
field.String("type").
MaxLen(20).
NotEmpty(),
// options: Select options for select/multi_select types (stored as JSONB)
// Format: [{"value": "xxx", "label": "XXX"}, ...]
field.JSON("options", []map[string]any{}).
Default([]map[string]any{}).
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
// required: Whether this attribute is required when editing a user
field.Bool("required").
Default(false),
// validation: Validation rules for the attribute value (stored as JSONB)
// Format: {"min_length": 1, "max_length": 100, "min": 0, "max": 100, "pattern": "^[a-z]+$", "message": "..."}
field.JSON("validation", map[string]any{}).
Default(map[string]any{}).
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
// placeholder: Placeholder text shown in input fields
field.String("placeholder").
MaxLen(255).
Default(""),
// display_order: Order in which attributes are displayed (lower = first)
field.Int("display_order").
Default(0),
// enabled: Whether this attribute is active and shown in forms
field.Bool("enabled").
Default(true),
}
}
func (UserAttributeDefinition) Edges() []ent.Edge {
return []ent.Edge{
// values: All user values for this attribute definition
edge.To("values", UserAttributeValue.Type),
}
}
func (UserAttributeDefinition) Indexes() []ent.Index {
return []ent.Index{
// Partial unique index on key (WHERE deleted_at IS NULL) via migration
index.Fields("key"),
index.Fields("enabled"),
index.Fields("display_order"),
index.Fields("deleted_at"),
}
}

View File

@@ -0,0 +1,74 @@
package schema
import (
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"entgo.io/ent"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/edge"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
// UserAttributeValue holds a user's value for a specific attribute.
//
// This entity stores the actual values that users have for each attribute definition.
// Values are stored as strings and converted to the appropriate type by the application.
type UserAttributeValue struct {
ent.Schema
}
func (UserAttributeValue) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "user_attribute_values"},
}
}
func (UserAttributeValue) Mixin() []ent.Mixin {
return []ent.Mixin{
// Only use TimeMixin, no soft delete - values are hard deleted
mixins.TimeMixin{},
}
}
func (UserAttributeValue) Fields() []ent.Field {
return []ent.Field{
// user_id: References the user this value belongs to
field.Int64("user_id"),
// attribute_id: References the attribute definition
field.Int64("attribute_id"),
// value: The actual value stored as a string
// For multi_select, this is a JSON array string
field.Text("value").
Default(""),
}
}
func (UserAttributeValue) Edges() []ent.Edge {
return []ent.Edge{
// user: The user who owns this attribute value
edge.From("user", User.Type).
Ref("attribute_values").
Field("user_id").
Required().
Unique(),
// definition: The attribute definition this value is for
edge.From("definition", UserAttributeDefinition.Type).
Ref("values").
Field("attribute_id").
Required().
Unique(),
}
}
func (UserAttributeValue) Indexes() []ent.Index {
return []ent.Index{
// Unique index on (user_id, attribute_id)
index.Fields("user_id", "attribute_id").Unique(),
index.Fields("attribute_id"),
}
}

View File

@@ -34,6 +34,10 @@ type Tx struct {
User *UserClient User *UserClient
// UserAllowedGroup is the client for interacting with the UserAllowedGroup builders. // UserAllowedGroup is the client for interacting with the UserAllowedGroup builders.
UserAllowedGroup *UserAllowedGroupClient UserAllowedGroup *UserAllowedGroupClient
// UserAttributeDefinition is the client for interacting with the UserAttributeDefinition builders.
UserAttributeDefinition *UserAttributeDefinitionClient
// UserAttributeValue is the client for interacting with the UserAttributeValue builders.
UserAttributeValue *UserAttributeValueClient
// UserSubscription is the client for interacting with the UserSubscription builders. // UserSubscription is the client for interacting with the UserSubscription builders.
UserSubscription *UserSubscriptionClient UserSubscription *UserSubscriptionClient
@@ -177,6 +181,8 @@ func (tx *Tx) init() {
tx.UsageLog = NewUsageLogClient(tx.config) tx.UsageLog = NewUsageLogClient(tx.config)
tx.User = NewUserClient(tx.config) tx.User = NewUserClient(tx.config)
tx.UserAllowedGroup = NewUserAllowedGroupClient(tx.config) tx.UserAllowedGroup = NewUserAllowedGroupClient(tx.config)
tx.UserAttributeDefinition = NewUserAttributeDefinitionClient(tx.config)
tx.UserAttributeValue = NewUserAttributeValueClient(tx.config)
tx.UserSubscription = NewUserSubscriptionClient(tx.config) tx.UserSubscription = NewUserSubscriptionClient(tx.config)
} }

View File

@@ -37,8 +37,6 @@ type User struct {
Status string `json:"status,omitempty"` Status string `json:"status,omitempty"`
// Username holds the value of the "username" field. // Username holds the value of the "username" field.
Username string `json:"username,omitempty"` Username string `json:"username,omitempty"`
// Wechat holds the value of the "wechat" field.
Wechat string `json:"wechat,omitempty"`
// Notes holds the value of the "notes" field. // Notes holds the value of the "notes" field.
Notes string `json:"notes,omitempty"` Notes string `json:"notes,omitempty"`
// Edges holds the relations/edges for other nodes in the graph. // Edges holds the relations/edges for other nodes in the graph.
@@ -61,11 +59,13 @@ type UserEdges struct {
AllowedGroups []*Group `json:"allowed_groups,omitempty"` AllowedGroups []*Group `json:"allowed_groups,omitempty"`
// UsageLogs holds the value of the usage_logs edge. // UsageLogs holds the value of the usage_logs edge.
UsageLogs []*UsageLog `json:"usage_logs,omitempty"` UsageLogs []*UsageLog `json:"usage_logs,omitempty"`
// AttributeValues holds the value of the attribute_values edge.
AttributeValues []*UserAttributeValue `json:"attribute_values,omitempty"`
// UserAllowedGroups holds the value of the user_allowed_groups edge. // UserAllowedGroups holds the value of the user_allowed_groups edge.
UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"` UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"`
// loadedTypes holds the information for reporting if a // loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not. // type was loaded (or requested) in eager-loading or not.
loadedTypes [7]bool loadedTypes [8]bool
} }
// APIKeysOrErr returns the APIKeys value or an error if the edge // APIKeysOrErr returns the APIKeys value or an error if the edge
@@ -122,10 +122,19 @@ func (e UserEdges) UsageLogsOrErr() ([]*UsageLog, error) {
return nil, &NotLoadedError{edge: "usage_logs"} return nil, &NotLoadedError{edge: "usage_logs"}
} }
// AttributeValuesOrErr returns the AttributeValues value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) AttributeValuesOrErr() ([]*UserAttributeValue, error) {
if e.loadedTypes[6] {
return e.AttributeValues, nil
}
return nil, &NotLoadedError{edge: "attribute_values"}
}
// UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge // UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge
// was not loaded in eager-loading. // was not loaded in eager-loading.
func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) { func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) {
if e.loadedTypes[6] { if e.loadedTypes[7] {
return e.UserAllowedGroups, nil return e.UserAllowedGroups, nil
} }
return nil, &NotLoadedError{edge: "user_allowed_groups"} return nil, &NotLoadedError{edge: "user_allowed_groups"}
@@ -140,7 +149,7 @@ func (*User) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullFloat64) values[i] = new(sql.NullFloat64)
case user.FieldID, user.FieldConcurrency: case user.FieldID, user.FieldConcurrency:
values[i] = new(sql.NullInt64) values[i] = new(sql.NullInt64)
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldWechat, user.FieldNotes: case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes:
values[i] = new(sql.NullString) values[i] = new(sql.NullString)
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt: case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt:
values[i] = new(sql.NullTime) values[i] = new(sql.NullTime)
@@ -226,12 +235,6 @@ func (_m *User) assignValues(columns []string, values []any) error {
} else if value.Valid { } else if value.Valid {
_m.Username = value.String _m.Username = value.String
} }
case user.FieldWechat:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field wechat", values[i])
} else if value.Valid {
_m.Wechat = value.String
}
case user.FieldNotes: case user.FieldNotes:
if value, ok := values[i].(*sql.NullString); !ok { if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field notes", values[i]) return fmt.Errorf("unexpected type %T for field notes", values[i])
@@ -281,6 +284,11 @@ func (_m *User) QueryUsageLogs() *UsageLogQuery {
return NewUserClient(_m.config).QueryUsageLogs(_m) return NewUserClient(_m.config).QueryUsageLogs(_m)
} }
// QueryAttributeValues queries the "attribute_values" edge of the User entity.
func (_m *User) QueryAttributeValues() *UserAttributeValueQuery {
return NewUserClient(_m.config).QueryAttributeValues(_m)
}
// QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity. // QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity.
func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery { func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery {
return NewUserClient(_m.config).QueryUserAllowedGroups(_m) return NewUserClient(_m.config).QueryUserAllowedGroups(_m)
@@ -341,9 +349,6 @@ func (_m *User) String() string {
builder.WriteString("username=") builder.WriteString("username=")
builder.WriteString(_m.Username) builder.WriteString(_m.Username)
builder.WriteString(", ") builder.WriteString(", ")
builder.WriteString("wechat=")
builder.WriteString(_m.Wechat)
builder.WriteString(", ")
builder.WriteString("notes=") builder.WriteString("notes=")
builder.WriteString(_m.Notes) builder.WriteString(_m.Notes)
builder.WriteByte(')') builder.WriteByte(')')

View File

@@ -35,8 +35,6 @@ const (
FieldStatus = "status" FieldStatus = "status"
// FieldUsername holds the string denoting the username field in the database. // FieldUsername holds the string denoting the username field in the database.
FieldUsername = "username" FieldUsername = "username"
// FieldWechat holds the string denoting the wechat field in the database.
FieldWechat = "wechat"
// FieldNotes holds the string denoting the notes field in the database. // FieldNotes holds the string denoting the notes field in the database.
FieldNotes = "notes" FieldNotes = "notes"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
@@ -51,6 +49,8 @@ const (
EdgeAllowedGroups = "allowed_groups" EdgeAllowedGroups = "allowed_groups"
// EdgeUsageLogs holds the string denoting the usage_logs edge name in mutations. // EdgeUsageLogs holds the string denoting the usage_logs edge name in mutations.
EdgeUsageLogs = "usage_logs" EdgeUsageLogs = "usage_logs"
// EdgeAttributeValues holds the string denoting the attribute_values edge name in mutations.
EdgeAttributeValues = "attribute_values"
// EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations. // EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations.
EdgeUserAllowedGroups = "user_allowed_groups" EdgeUserAllowedGroups = "user_allowed_groups"
// Table holds the table name of the user in the database. // Table holds the table name of the user in the database.
@@ -95,6 +95,13 @@ const (
UsageLogsInverseTable = "usage_logs" UsageLogsInverseTable = "usage_logs"
// UsageLogsColumn is the table column denoting the usage_logs relation/edge. // UsageLogsColumn is the table column denoting the usage_logs relation/edge.
UsageLogsColumn = "user_id" UsageLogsColumn = "user_id"
// AttributeValuesTable is the table that holds the attribute_values relation/edge.
AttributeValuesTable = "user_attribute_values"
// AttributeValuesInverseTable is the table name for the UserAttributeValue entity.
// It exists in this package in order to avoid circular dependency with the "userattributevalue" package.
AttributeValuesInverseTable = "user_attribute_values"
// AttributeValuesColumn is the table column denoting the attribute_values relation/edge.
AttributeValuesColumn = "user_id"
// UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge. // UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge.
UserAllowedGroupsTable = "user_allowed_groups" UserAllowedGroupsTable = "user_allowed_groups"
// UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity. // UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity.
@@ -117,7 +124,6 @@ var Columns = []string{
FieldConcurrency, FieldConcurrency,
FieldStatus, FieldStatus,
FieldUsername, FieldUsername,
FieldWechat,
FieldNotes, FieldNotes,
} }
@@ -171,10 +177,6 @@ var (
DefaultUsername string DefaultUsername string
// UsernameValidator is a validator for the "username" field. It is called by the builders before save. // UsernameValidator is a validator for the "username" field. It is called by the builders before save.
UsernameValidator func(string) error UsernameValidator func(string) error
// DefaultWechat holds the default value on creation for the "wechat" field.
DefaultWechat string
// WechatValidator is a validator for the "wechat" field. It is called by the builders before save.
WechatValidator func(string) error
// DefaultNotes holds the default value on creation for the "notes" field. // DefaultNotes holds the default value on creation for the "notes" field.
DefaultNotes string DefaultNotes string
) )
@@ -237,11 +239,6 @@ func ByUsername(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUsername, opts...).ToFunc() return sql.OrderByField(FieldUsername, opts...).ToFunc()
} }
// ByWechat orders the results by the wechat field.
func ByWechat(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldWechat, opts...).ToFunc()
}
// ByNotes orders the results by the notes field. // ByNotes orders the results by the notes field.
func ByNotes(opts ...sql.OrderTermOption) OrderOption { func ByNotes(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldNotes, opts...).ToFunc() return sql.OrderByField(FieldNotes, opts...).ToFunc()
@@ -331,6 +328,20 @@ func ByUsageLogs(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
} }
} }
// ByAttributeValuesCount orders the results by attribute_values count.
func ByAttributeValuesCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newAttributeValuesStep(), opts...)
}
}
// ByAttributeValues orders the results by attribute_values terms.
func ByAttributeValues(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newAttributeValuesStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
// ByUserAllowedGroupsCount orders the results by user_allowed_groups count. // ByUserAllowedGroupsCount orders the results by user_allowed_groups count.
func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption { func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) { return func(s *sql.Selector) {
@@ -386,6 +397,13 @@ func newUsageLogsStep() *sqlgraph.Step {
sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn), sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn),
) )
} }
func newAttributeValuesStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(AttributeValuesInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, AttributeValuesTable, AttributeValuesColumn),
)
}
func newUserAllowedGroupsStep() *sqlgraph.Step { func newUserAllowedGroupsStep() *sqlgraph.Step {
return sqlgraph.NewStep( return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID), sqlgraph.From(Table, FieldID),

View File

@@ -105,11 +105,6 @@ func Username(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldUsername, v)) return predicate.User(sql.FieldEQ(FieldUsername, v))
} }
// Wechat applies equality check predicate on the "wechat" field. It's identical to WechatEQ.
func Wechat(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldWechat, v))
}
// Notes applies equality check predicate on the "notes" field. It's identical to NotesEQ. // Notes applies equality check predicate on the "notes" field. It's identical to NotesEQ.
func Notes(v string) predicate.User { func Notes(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldNotes, v)) return predicate.User(sql.FieldEQ(FieldNotes, v))
@@ -650,71 +645,6 @@ func UsernameContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldUsername, v)) return predicate.User(sql.FieldContainsFold(FieldUsername, v))
} }
// WechatEQ applies the EQ predicate on the "wechat" field.
func WechatEQ(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldWechat, v))
}
// WechatNEQ applies the NEQ predicate on the "wechat" field.
func WechatNEQ(v string) predicate.User {
return predicate.User(sql.FieldNEQ(FieldWechat, v))
}
// WechatIn applies the In predicate on the "wechat" field.
func WechatIn(vs ...string) predicate.User {
return predicate.User(sql.FieldIn(FieldWechat, vs...))
}
// WechatNotIn applies the NotIn predicate on the "wechat" field.
func WechatNotIn(vs ...string) predicate.User {
return predicate.User(sql.FieldNotIn(FieldWechat, vs...))
}
// WechatGT applies the GT predicate on the "wechat" field.
func WechatGT(v string) predicate.User {
return predicate.User(sql.FieldGT(FieldWechat, v))
}
// WechatGTE applies the GTE predicate on the "wechat" field.
func WechatGTE(v string) predicate.User {
return predicate.User(sql.FieldGTE(FieldWechat, v))
}
// WechatLT applies the LT predicate on the "wechat" field.
func WechatLT(v string) predicate.User {
return predicate.User(sql.FieldLT(FieldWechat, v))
}
// WechatLTE applies the LTE predicate on the "wechat" field.
func WechatLTE(v string) predicate.User {
return predicate.User(sql.FieldLTE(FieldWechat, v))
}
// WechatContains applies the Contains predicate on the "wechat" field.
func WechatContains(v string) predicate.User {
return predicate.User(sql.FieldContains(FieldWechat, v))
}
// WechatHasPrefix applies the HasPrefix predicate on the "wechat" field.
func WechatHasPrefix(v string) predicate.User {
return predicate.User(sql.FieldHasPrefix(FieldWechat, v))
}
// WechatHasSuffix applies the HasSuffix predicate on the "wechat" field.
func WechatHasSuffix(v string) predicate.User {
return predicate.User(sql.FieldHasSuffix(FieldWechat, v))
}
// WechatEqualFold applies the EqualFold predicate on the "wechat" field.
func WechatEqualFold(v string) predicate.User {
return predicate.User(sql.FieldEqualFold(FieldWechat, v))
}
// WechatContainsFold applies the ContainsFold predicate on the "wechat" field.
func WechatContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldWechat, v))
}
// NotesEQ applies the EQ predicate on the "notes" field. // NotesEQ applies the EQ predicate on the "notes" field.
func NotesEQ(v string) predicate.User { func NotesEQ(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldNotes, v)) return predicate.User(sql.FieldEQ(FieldNotes, v))
@@ -918,6 +848,29 @@ func HasUsageLogsWith(preds ...predicate.UsageLog) predicate.User {
}) })
} }
// HasAttributeValues applies the HasEdge predicate on the "attribute_values" edge.
func HasAttributeValues() predicate.User {
return predicate.User(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, AttributeValuesTable, AttributeValuesColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasAttributeValuesWith applies the HasEdge predicate on the "attribute_values" edge with a given conditions (other predicates).
func HasAttributeValuesWith(preds ...predicate.UserAttributeValue) predicate.User {
return predicate.User(func(s *sql.Selector) {
step := newAttributeValuesStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge. // HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge.
func HasUserAllowedGroups() predicate.User { func HasUserAllowedGroups() predicate.User {
return predicate.User(func(s *sql.Selector) { return predicate.User(func(s *sql.Selector) {

View File

@@ -16,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
) )
@@ -151,20 +152,6 @@ func (_c *UserCreate) SetNillableUsername(v *string) *UserCreate {
return _c return _c
} }
// SetWechat sets the "wechat" field.
func (_c *UserCreate) SetWechat(v string) *UserCreate {
_c.mutation.SetWechat(v)
return _c
}
// SetNillableWechat sets the "wechat" field if the given value is not nil.
func (_c *UserCreate) SetNillableWechat(v *string) *UserCreate {
if v != nil {
_c.SetWechat(*v)
}
return _c
}
// SetNotes sets the "notes" field. // SetNotes sets the "notes" field.
func (_c *UserCreate) SetNotes(v string) *UserCreate { func (_c *UserCreate) SetNotes(v string) *UserCreate {
_c.mutation.SetNotes(v) _c.mutation.SetNotes(v)
@@ -269,6 +256,21 @@ func (_c *UserCreate) AddUsageLogs(v ...*UsageLog) *UserCreate {
return _c.AddUsageLogIDs(ids...) return _c.AddUsageLogIDs(ids...)
} }
// AddAttributeValueIDs adds the "attribute_values" edge to the UserAttributeValue entity by IDs.
func (_c *UserCreate) AddAttributeValueIDs(ids ...int64) *UserCreate {
_c.mutation.AddAttributeValueIDs(ids...)
return _c
}
// AddAttributeValues adds the "attribute_values" edges to the UserAttributeValue entity.
func (_c *UserCreate) AddAttributeValues(v ...*UserAttributeValue) *UserCreate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _c.AddAttributeValueIDs(ids...)
}
// Mutation returns the UserMutation object of the builder. // Mutation returns the UserMutation object of the builder.
func (_c *UserCreate) Mutation() *UserMutation { func (_c *UserCreate) Mutation() *UserMutation {
return _c.mutation return _c.mutation
@@ -340,10 +342,6 @@ func (_c *UserCreate) defaults() error {
v := user.DefaultUsername v := user.DefaultUsername
_c.mutation.SetUsername(v) _c.mutation.SetUsername(v)
} }
if _, ok := _c.mutation.Wechat(); !ok {
v := user.DefaultWechat
_c.mutation.SetWechat(v)
}
if _, ok := _c.mutation.Notes(); !ok { if _, ok := _c.mutation.Notes(); !ok {
v := user.DefaultNotes v := user.DefaultNotes
_c.mutation.SetNotes(v) _c.mutation.SetNotes(v)
@@ -405,14 +403,6 @@ func (_c *UserCreate) check() error {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)} return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
} }
} }
if _, ok := _c.mutation.Wechat(); !ok {
return &ValidationError{Name: "wechat", err: errors.New(`ent: missing required field "User.wechat"`)}
}
if v, ok := _c.mutation.Wechat(); ok {
if err := user.WechatValidator(v); err != nil {
return &ValidationError{Name: "wechat", err: fmt.Errorf(`ent: validator failed for field "User.wechat": %w`, err)}
}
}
if _, ok := _c.mutation.Notes(); !ok { if _, ok := _c.mutation.Notes(); !ok {
return &ValidationError{Name: "notes", err: errors.New(`ent: missing required field "User.notes"`)} return &ValidationError{Name: "notes", err: errors.New(`ent: missing required field "User.notes"`)}
} }
@@ -483,10 +473,6 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldUsername, field.TypeString, value) _spec.SetField(user.FieldUsername, field.TypeString, value)
_node.Username = value _node.Username = value
} }
if value, ok := _c.mutation.Wechat(); ok {
_spec.SetField(user.FieldWechat, field.TypeString, value)
_node.Wechat = value
}
if value, ok := _c.mutation.Notes(); ok { if value, ok := _c.mutation.Notes(); ok {
_spec.SetField(user.FieldNotes, field.TypeString, value) _spec.SetField(user.FieldNotes, field.TypeString, value)
_node.Notes = value _node.Notes = value
@@ -591,6 +577,22 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
} }
_spec.Edges = append(_spec.Edges, edge) _spec.Edges = append(_spec.Edges, edge)
} }
if nodes := _c.mutation.AttributeValuesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AttributeValuesTable,
Columns: []string{user.AttributeValuesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges = append(_spec.Edges, edge)
}
return _node, _spec return _node, _spec
} }
@@ -769,18 +771,6 @@ func (u *UserUpsert) UpdateUsername() *UserUpsert {
return u return u
} }
// SetWechat sets the "wechat" field.
func (u *UserUpsert) SetWechat(v string) *UserUpsert {
u.Set(user.FieldWechat, v)
return u
}
// UpdateWechat sets the "wechat" field to the value that was provided on create.
func (u *UserUpsert) UpdateWechat() *UserUpsert {
u.SetExcluded(user.FieldWechat)
return u
}
// SetNotes sets the "notes" field. // SetNotes sets the "notes" field.
func (u *UserUpsert) SetNotes(v string) *UserUpsert { func (u *UserUpsert) SetNotes(v string) *UserUpsert {
u.Set(user.FieldNotes, v) u.Set(user.FieldNotes, v)
@@ -985,20 +975,6 @@ func (u *UserUpsertOne) UpdateUsername() *UserUpsertOne {
}) })
} }
// SetWechat sets the "wechat" field.
func (u *UserUpsertOne) SetWechat(v string) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetWechat(v)
})
}
// UpdateWechat sets the "wechat" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateWechat() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateWechat()
})
}
// SetNotes sets the "notes" field. // SetNotes sets the "notes" field.
func (u *UserUpsertOne) SetNotes(v string) *UserUpsertOne { func (u *UserUpsertOne) SetNotes(v string) *UserUpsertOne {
return u.Update(func(s *UserUpsert) { return u.Update(func(s *UserUpsert) {
@@ -1371,20 +1347,6 @@ func (u *UserUpsertBulk) UpdateUsername() *UserUpsertBulk {
}) })
} }
// SetWechat sets the "wechat" field.
func (u *UserUpsertBulk) SetWechat(v string) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetWechat(v)
})
}
// UpdateWechat sets the "wechat" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateWechat() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateWechat()
})
}
// SetNotes sets the "notes" field. // SetNotes sets the "notes" field.
func (u *UserUpsertBulk) SetNotes(v string) *UserUpsertBulk { func (u *UserUpsertBulk) SetNotes(v string) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) { return u.Update(func(s *UserUpsert) {

View File

@@ -19,6 +19,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
) )
@@ -35,6 +36,7 @@ type UserQuery struct {
withAssignedSubscriptions *UserSubscriptionQuery withAssignedSubscriptions *UserSubscriptionQuery
withAllowedGroups *GroupQuery withAllowedGroups *GroupQuery
withUsageLogs *UsageLogQuery withUsageLogs *UsageLogQuery
withAttributeValues *UserAttributeValueQuery
withUserAllowedGroups *UserAllowedGroupQuery withUserAllowedGroups *UserAllowedGroupQuery
// intermediate query (i.e. traversal path). // intermediate query (i.e. traversal path).
sql *sql.Selector sql *sql.Selector
@@ -204,6 +206,28 @@ func (_q *UserQuery) QueryUsageLogs() *UsageLogQuery {
return query return query
} }
// QueryAttributeValues chains the current query on the "attribute_values" edge.
func (_q *UserQuery) QueryAttributeValues() *UserAttributeValueQuery {
query := (&UserAttributeValueClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID, selector),
sqlgraph.To(userattributevalue.Table, userattributevalue.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, user.AttributeValuesTable, user.AttributeValuesColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge. // QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge.
func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery { func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery {
query := (&UserAllowedGroupClient{config: _q.config}).Query() query := (&UserAllowedGroupClient{config: _q.config}).Query()
@@ -424,6 +448,7 @@ func (_q *UserQuery) Clone() *UserQuery {
withAssignedSubscriptions: _q.withAssignedSubscriptions.Clone(), withAssignedSubscriptions: _q.withAssignedSubscriptions.Clone(),
withAllowedGroups: _q.withAllowedGroups.Clone(), withAllowedGroups: _q.withAllowedGroups.Clone(),
withUsageLogs: _q.withUsageLogs.Clone(), withUsageLogs: _q.withUsageLogs.Clone(),
withAttributeValues: _q.withAttributeValues.Clone(),
withUserAllowedGroups: _q.withUserAllowedGroups.Clone(), withUserAllowedGroups: _q.withUserAllowedGroups.Clone(),
// clone intermediate query. // clone intermediate query.
sql: _q.sql.Clone(), sql: _q.sql.Clone(),
@@ -497,6 +522,17 @@ func (_q *UserQuery) WithUsageLogs(opts ...func(*UsageLogQuery)) *UserQuery {
return _q return _q
} }
// WithAttributeValues tells the query-builder to eager-load the nodes that are connected to
// the "attribute_values" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithAttributeValues(opts ...func(*UserAttributeValueQuery)) *UserQuery {
query := (&UserAttributeValueClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withAttributeValues = query
return _q
}
// WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to // WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to
// the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge. // the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery { func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery {
@@ -586,13 +622,14 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
var ( var (
nodes = []*User{} nodes = []*User{}
_spec = _q.querySpec() _spec = _q.querySpec()
loadedTypes = [7]bool{ loadedTypes = [8]bool{
_q.withAPIKeys != nil, _q.withAPIKeys != nil,
_q.withRedeemCodes != nil, _q.withRedeemCodes != nil,
_q.withSubscriptions != nil, _q.withSubscriptions != nil,
_q.withAssignedSubscriptions != nil, _q.withAssignedSubscriptions != nil,
_q.withAllowedGroups != nil, _q.withAllowedGroups != nil,
_q.withUsageLogs != nil, _q.withUsageLogs != nil,
_q.withAttributeValues != nil,
_q.withUserAllowedGroups != nil, _q.withUserAllowedGroups != nil,
} }
) )
@@ -658,6 +695,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
return nil, err return nil, err
} }
} }
if query := _q.withAttributeValues; query != nil {
if err := _q.loadAttributeValues(ctx, query, nodes,
func(n *User) { n.Edges.AttributeValues = []*UserAttributeValue{} },
func(n *User, e *UserAttributeValue) { n.Edges.AttributeValues = append(n.Edges.AttributeValues, e) }); err != nil {
return nil, err
}
}
if query := _q.withUserAllowedGroups; query != nil { if query := _q.withUserAllowedGroups; query != nil {
if err := _q.loadUserAllowedGroups(ctx, query, nodes, if err := _q.loadUserAllowedGroups(ctx, query, nodes,
func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} }, func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} },
@@ -885,6 +929,36 @@ func (_q *UserQuery) loadUsageLogs(ctx context.Context, query *UsageLogQuery, no
} }
return nil return nil
} }
func (_q *UserQuery) loadAttributeValues(ctx context.Context, query *UserAttributeValueQuery, nodes []*User, init func(*User), assign func(*User, *UserAttributeValue)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(userattributevalue.FieldUserID)
}
query.Where(predicate.UserAttributeValue(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(user.AttributeValuesColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.UserID
node, ok := nodeids[fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error { func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error {
fks := make([]driver.Value, 0, len(nodes)) fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User) nodeids := make(map[int64]*User)

View File

@@ -17,6 +17,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
) )
@@ -171,20 +172,6 @@ func (_u *UserUpdate) SetNillableUsername(v *string) *UserUpdate {
return _u return _u
} }
// SetWechat sets the "wechat" field.
func (_u *UserUpdate) SetWechat(v string) *UserUpdate {
_u.mutation.SetWechat(v)
return _u
}
// SetNillableWechat sets the "wechat" field if the given value is not nil.
func (_u *UserUpdate) SetNillableWechat(v *string) *UserUpdate {
if v != nil {
_u.SetWechat(*v)
}
return _u
}
// SetNotes sets the "notes" field. // SetNotes sets the "notes" field.
func (_u *UserUpdate) SetNotes(v string) *UserUpdate { func (_u *UserUpdate) SetNotes(v string) *UserUpdate {
_u.mutation.SetNotes(v) _u.mutation.SetNotes(v)
@@ -289,6 +276,21 @@ func (_u *UserUpdate) AddUsageLogs(v ...*UsageLog) *UserUpdate {
return _u.AddUsageLogIDs(ids...) return _u.AddUsageLogIDs(ids...)
} }
// AddAttributeValueIDs adds the "attribute_values" edge to the UserAttributeValue entity by IDs.
func (_u *UserUpdate) AddAttributeValueIDs(ids ...int64) *UserUpdate {
_u.mutation.AddAttributeValueIDs(ids...)
return _u
}
// AddAttributeValues adds the "attribute_values" edges to the UserAttributeValue entity.
func (_u *UserUpdate) AddAttributeValues(v ...*UserAttributeValue) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddAttributeValueIDs(ids...)
}
// Mutation returns the UserMutation object of the builder. // Mutation returns the UserMutation object of the builder.
func (_u *UserUpdate) Mutation() *UserMutation { func (_u *UserUpdate) Mutation() *UserMutation {
return _u.mutation return _u.mutation
@@ -420,6 +422,27 @@ func (_u *UserUpdate) RemoveUsageLogs(v ...*UsageLog) *UserUpdate {
return _u.RemoveUsageLogIDs(ids...) return _u.RemoveUsageLogIDs(ids...)
} }
// ClearAttributeValues clears all "attribute_values" edges to the UserAttributeValue entity.
func (_u *UserUpdate) ClearAttributeValues() *UserUpdate {
_u.mutation.ClearAttributeValues()
return _u
}
// RemoveAttributeValueIDs removes the "attribute_values" edge to UserAttributeValue entities by IDs.
func (_u *UserUpdate) RemoveAttributeValueIDs(ids ...int64) *UserUpdate {
_u.mutation.RemoveAttributeValueIDs(ids...)
return _u
}
// RemoveAttributeValues removes "attribute_values" edges to UserAttributeValue entities.
func (_u *UserUpdate) RemoveAttributeValues(v ...*UserAttributeValue) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveAttributeValueIDs(ids...)
}
// Save executes the query and returns the number of nodes affected by the update operation. // Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UserUpdate) Save(ctx context.Context) (int, error) { func (_u *UserUpdate) Save(ctx context.Context) (int, error) {
if err := _u.defaults(); err != nil { if err := _u.defaults(); err != nil {
@@ -489,11 +512,6 @@ func (_u *UserUpdate) check() error {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)} return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
} }
} }
if v, ok := _u.mutation.Wechat(); ok {
if err := user.WechatValidator(v); err != nil {
return &ValidationError{Name: "wechat", err: fmt.Errorf(`ent: validator failed for field "User.wechat": %w`, err)}
}
}
return nil return nil
} }
@@ -545,9 +563,6 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.Username(); ok { if value, ok := _u.mutation.Username(); ok {
_spec.SetField(user.FieldUsername, field.TypeString, value) _spec.SetField(user.FieldUsername, field.TypeString, value)
} }
if value, ok := _u.mutation.Wechat(); ok {
_spec.SetField(user.FieldWechat, field.TypeString, value)
}
if value, ok := _u.mutation.Notes(); ok { if value, ok := _u.mutation.Notes(); ok {
_spec.SetField(user.FieldNotes, field.TypeString, value) _spec.SetField(user.FieldNotes, field.TypeString, value)
} }
@@ -833,6 +848,51 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
} }
_spec.Edges.Add = append(_spec.Edges.Add, edge) _spec.Edges.Add = append(_spec.Edges.Add, edge)
} }
if _u.mutation.AttributeValuesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AttributeValuesTable,
Columns: []string{user.AttributeValuesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedAttributeValuesIDs(); len(nodes) > 0 && !_u.mutation.AttributeValuesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AttributeValuesTable,
Columns: []string{user.AttributeValuesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.AttributeValuesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AttributeValuesTable,
Columns: []string{user.AttributeValuesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok { if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{user.Label} err = &NotFoundError{user.Label}
@@ -991,20 +1051,6 @@ func (_u *UserUpdateOne) SetNillableUsername(v *string) *UserUpdateOne {
return _u return _u
} }
// SetWechat sets the "wechat" field.
func (_u *UserUpdateOne) SetWechat(v string) *UserUpdateOne {
_u.mutation.SetWechat(v)
return _u
}
// SetNillableWechat sets the "wechat" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableWechat(v *string) *UserUpdateOne {
if v != nil {
_u.SetWechat(*v)
}
return _u
}
// SetNotes sets the "notes" field. // SetNotes sets the "notes" field.
func (_u *UserUpdateOne) SetNotes(v string) *UserUpdateOne { func (_u *UserUpdateOne) SetNotes(v string) *UserUpdateOne {
_u.mutation.SetNotes(v) _u.mutation.SetNotes(v)
@@ -1109,6 +1155,21 @@ func (_u *UserUpdateOne) AddUsageLogs(v ...*UsageLog) *UserUpdateOne {
return _u.AddUsageLogIDs(ids...) return _u.AddUsageLogIDs(ids...)
} }
// AddAttributeValueIDs adds the "attribute_values" edge to the UserAttributeValue entity by IDs.
func (_u *UserUpdateOne) AddAttributeValueIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddAttributeValueIDs(ids...)
return _u
}
// AddAttributeValues adds the "attribute_values" edges to the UserAttributeValue entity.
func (_u *UserUpdateOne) AddAttributeValues(v ...*UserAttributeValue) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddAttributeValueIDs(ids...)
}
// Mutation returns the UserMutation object of the builder. // Mutation returns the UserMutation object of the builder.
func (_u *UserUpdateOne) Mutation() *UserMutation { func (_u *UserUpdateOne) Mutation() *UserMutation {
return _u.mutation return _u.mutation
@@ -1240,6 +1301,27 @@ func (_u *UserUpdateOne) RemoveUsageLogs(v ...*UsageLog) *UserUpdateOne {
return _u.RemoveUsageLogIDs(ids...) return _u.RemoveUsageLogIDs(ids...)
} }
// ClearAttributeValues clears all "attribute_values" edges to the UserAttributeValue entity.
func (_u *UserUpdateOne) ClearAttributeValues() *UserUpdateOne {
_u.mutation.ClearAttributeValues()
return _u
}
// RemoveAttributeValueIDs removes the "attribute_values" edge to UserAttributeValue entities by IDs.
func (_u *UserUpdateOne) RemoveAttributeValueIDs(ids ...int64) *UserUpdateOne {
_u.mutation.RemoveAttributeValueIDs(ids...)
return _u
}
// RemoveAttributeValues removes "attribute_values" edges to UserAttributeValue entities.
func (_u *UserUpdateOne) RemoveAttributeValues(v ...*UserAttributeValue) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveAttributeValueIDs(ids...)
}
// Where appends a list predicates to the UserUpdate builder. // Where appends a list predicates to the UserUpdate builder.
func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne { func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne {
_u.mutation.Where(ps...) _u.mutation.Where(ps...)
@@ -1322,11 +1404,6 @@ func (_u *UserUpdateOne) check() error {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)} return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
} }
} }
if v, ok := _u.mutation.Wechat(); ok {
if err := user.WechatValidator(v); err != nil {
return &ValidationError{Name: "wechat", err: fmt.Errorf(`ent: validator failed for field "User.wechat": %w`, err)}
}
}
return nil return nil
} }
@@ -1395,9 +1472,6 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if value, ok := _u.mutation.Username(); ok { if value, ok := _u.mutation.Username(); ok {
_spec.SetField(user.FieldUsername, field.TypeString, value) _spec.SetField(user.FieldUsername, field.TypeString, value)
} }
if value, ok := _u.mutation.Wechat(); ok {
_spec.SetField(user.FieldWechat, field.TypeString, value)
}
if value, ok := _u.mutation.Notes(); ok { if value, ok := _u.mutation.Notes(); ok {
_spec.SetField(user.FieldNotes, field.TypeString, value) _spec.SetField(user.FieldNotes, field.TypeString, value)
} }
@@ -1683,6 +1757,51 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
} }
_spec.Edges.Add = append(_spec.Edges.Add, edge) _spec.Edges.Add = append(_spec.Edges.Add, edge)
} }
if _u.mutation.AttributeValuesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AttributeValuesTable,
Columns: []string{user.AttributeValuesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedAttributeValuesIDs(); len(nodes) > 0 && !_u.mutation.AttributeValuesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AttributeValuesTable,
Columns: []string{user.AttributeValuesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.AttributeValuesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AttributeValuesTable,
Columns: []string{user.AttributeValuesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &User{config: _u.config} _node = &User{config: _u.config}
_spec.Assign = _node.assignValues _spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues _spec.ScanValues = _node.scanValues

View File

@@ -0,0 +1,276 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"encoding/json"
"fmt"
"strings"
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
)
// UserAttributeDefinition is the model entity for the UserAttributeDefinition schema.
type UserAttributeDefinition struct {
config `json:"-"`
// ID of the ent.
ID int64 `json:"id,omitempty"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"`
// UpdatedAt holds the value of the "updated_at" field.
UpdatedAt time.Time `json:"updated_at,omitempty"`
// DeletedAt holds the value of the "deleted_at" field.
DeletedAt *time.Time `json:"deleted_at,omitempty"`
// Key holds the value of the "key" field.
Key string `json:"key,omitempty"`
// Name holds the value of the "name" field.
Name string `json:"name,omitempty"`
// Description holds the value of the "description" field.
Description string `json:"description,omitempty"`
// Type holds the value of the "type" field.
Type string `json:"type,omitempty"`
// Options holds the value of the "options" field.
Options []map[string]interface{} `json:"options,omitempty"`
// Required holds the value of the "required" field.
Required bool `json:"required,omitempty"`
// Validation holds the value of the "validation" field.
Validation map[string]interface{} `json:"validation,omitempty"`
// Placeholder holds the value of the "placeholder" field.
Placeholder string `json:"placeholder,omitempty"`
// DisplayOrder holds the value of the "display_order" field.
DisplayOrder int `json:"display_order,omitempty"`
// Enabled holds the value of the "enabled" field.
Enabled bool `json:"enabled,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the UserAttributeDefinitionQuery when eager-loading is set.
Edges UserAttributeDefinitionEdges `json:"edges"`
selectValues sql.SelectValues
}
// UserAttributeDefinitionEdges holds the relations/edges for other nodes in the graph.
type UserAttributeDefinitionEdges struct {
// Values holds the value of the values edge.
Values []*UserAttributeValue `json:"values,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [1]bool
}
// ValuesOrErr returns the Values value or an error if the edge
// was not loaded in eager-loading.
func (e UserAttributeDefinitionEdges) ValuesOrErr() ([]*UserAttributeValue, error) {
if e.loadedTypes[0] {
return e.Values, nil
}
return nil, &NotLoadedError{edge: "values"}
}
// scanValues returns the types for scanning values from sql.Rows.
func (*UserAttributeDefinition) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case userattributedefinition.FieldOptions, userattributedefinition.FieldValidation:
values[i] = new([]byte)
case userattributedefinition.FieldRequired, userattributedefinition.FieldEnabled:
values[i] = new(sql.NullBool)
case userattributedefinition.FieldID, userattributedefinition.FieldDisplayOrder:
values[i] = new(sql.NullInt64)
case userattributedefinition.FieldKey, userattributedefinition.FieldName, userattributedefinition.FieldDescription, userattributedefinition.FieldType, userattributedefinition.FieldPlaceholder:
values[i] = new(sql.NullString)
case userattributedefinition.FieldCreatedAt, userattributedefinition.FieldUpdatedAt, userattributedefinition.FieldDeletedAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
}
}
return values, nil
}
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the UserAttributeDefinition fields.
func (_m *UserAttributeDefinition) assignValues(columns []string, values []any) error {
if m, n := len(values), len(columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
}
for i := range columns {
switch columns[i] {
case userattributedefinition.FieldID:
value, ok := values[i].(*sql.NullInt64)
if !ok {
return fmt.Errorf("unexpected type %T for field id", value)
}
_m.ID = int64(value.Int64)
case userattributedefinition.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i])
} else if value.Valid {
_m.CreatedAt = value.Time
}
case userattributedefinition.FieldUpdatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field updated_at", values[i])
} else if value.Valid {
_m.UpdatedAt = value.Time
}
case userattributedefinition.FieldDeletedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field deleted_at", values[i])
} else if value.Valid {
_m.DeletedAt = new(time.Time)
*_m.DeletedAt = value.Time
}
case userattributedefinition.FieldKey:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field key", values[i])
} else if value.Valid {
_m.Key = value.String
}
case userattributedefinition.FieldName:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field name", values[i])
} else if value.Valid {
_m.Name = value.String
}
case userattributedefinition.FieldDescription:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field description", values[i])
} else if value.Valid {
_m.Description = value.String
}
case userattributedefinition.FieldType:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field type", values[i])
} else if value.Valid {
_m.Type = value.String
}
case userattributedefinition.FieldOptions:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field options", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.Options); err != nil {
return fmt.Errorf("unmarshal field options: %w", err)
}
}
case userattributedefinition.FieldRequired:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field required", values[i])
} else if value.Valid {
_m.Required = value.Bool
}
case userattributedefinition.FieldValidation:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field validation", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.Validation); err != nil {
return fmt.Errorf("unmarshal field validation: %w", err)
}
}
case userattributedefinition.FieldPlaceholder:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field placeholder", values[i])
} else if value.Valid {
_m.Placeholder = value.String
}
case userattributedefinition.FieldDisplayOrder:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field display_order", values[i])
} else if value.Valid {
_m.DisplayOrder = int(value.Int64)
}
case userattributedefinition.FieldEnabled:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field enabled", values[i])
} else if value.Valid {
_m.Enabled = value.Bool
}
default:
_m.selectValues.Set(columns[i], values[i])
}
}
return nil
}
// Value returns the ent.Value that was dynamically selected and assigned to the UserAttributeDefinition.
// This includes values selected through modifiers, order, etc.
func (_m *UserAttributeDefinition) Value(name string) (ent.Value, error) {
return _m.selectValues.Get(name)
}
// QueryValues queries the "values" edge of the UserAttributeDefinition entity.
func (_m *UserAttributeDefinition) QueryValues() *UserAttributeValueQuery {
return NewUserAttributeDefinitionClient(_m.config).QueryValues(_m)
}
// Update returns a builder for updating this UserAttributeDefinition.
// Note that you need to call UserAttributeDefinition.Unwrap() before calling this method if this UserAttributeDefinition
// was returned from a transaction, and the transaction was committed or rolled back.
func (_m *UserAttributeDefinition) Update() *UserAttributeDefinitionUpdateOne {
return NewUserAttributeDefinitionClient(_m.config).UpdateOne(_m)
}
// Unwrap unwraps the UserAttributeDefinition entity that was returned from a transaction after it was closed,
// so that all future queries will be executed through the driver which created the transaction.
func (_m *UserAttributeDefinition) Unwrap() *UserAttributeDefinition {
_tx, ok := _m.config.driver.(*txDriver)
if !ok {
panic("ent: UserAttributeDefinition is not a transactional entity")
}
_m.config.driver = _tx.drv
return _m
}
// String implements the fmt.Stringer.
func (_m *UserAttributeDefinition) String() string {
var builder strings.Builder
builder.WriteString("UserAttributeDefinition(")
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("updated_at=")
builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
builder.WriteString(", ")
if v := _m.DeletedAt; v != nil {
builder.WriteString("deleted_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
builder.WriteString("key=")
builder.WriteString(_m.Key)
builder.WriteString(", ")
builder.WriteString("name=")
builder.WriteString(_m.Name)
builder.WriteString(", ")
builder.WriteString("description=")
builder.WriteString(_m.Description)
builder.WriteString(", ")
builder.WriteString("type=")
builder.WriteString(_m.Type)
builder.WriteString(", ")
builder.WriteString("options=")
builder.WriteString(fmt.Sprintf("%v", _m.Options))
builder.WriteString(", ")
builder.WriteString("required=")
builder.WriteString(fmt.Sprintf("%v", _m.Required))
builder.WriteString(", ")
builder.WriteString("validation=")
builder.WriteString(fmt.Sprintf("%v", _m.Validation))
builder.WriteString(", ")
builder.WriteString("placeholder=")
builder.WriteString(_m.Placeholder)
builder.WriteString(", ")
builder.WriteString("display_order=")
builder.WriteString(fmt.Sprintf("%v", _m.DisplayOrder))
builder.WriteString(", ")
builder.WriteString("enabled=")
builder.WriteString(fmt.Sprintf("%v", _m.Enabled))
builder.WriteByte(')')
return builder.String()
}
// UserAttributeDefinitions is a parsable slice of UserAttributeDefinition.
type UserAttributeDefinitions []*UserAttributeDefinition

View File

@@ -0,0 +1,205 @@
// Code generated by ent, DO NOT EDIT.
package userattributedefinition
import (
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
)
const (
// Label holds the string label denoting the userattributedefinition type in the database.
Label = "user_attribute_definition"
// FieldID holds the string denoting the id field in the database.
FieldID = "id"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
FieldUpdatedAt = "updated_at"
// FieldDeletedAt holds the string denoting the deleted_at field in the database.
FieldDeletedAt = "deleted_at"
// FieldKey holds the string denoting the key field in the database.
FieldKey = "key"
// FieldName holds the string denoting the name field in the database.
FieldName = "name"
// FieldDescription holds the string denoting the description field in the database.
FieldDescription = "description"
// FieldType holds the string denoting the type field in the database.
FieldType = "type"
// FieldOptions holds the string denoting the options field in the database.
FieldOptions = "options"
// FieldRequired holds the string denoting the required field in the database.
FieldRequired = "required"
// FieldValidation holds the string denoting the validation field in the database.
FieldValidation = "validation"
// FieldPlaceholder holds the string denoting the placeholder field in the database.
FieldPlaceholder = "placeholder"
// FieldDisplayOrder holds the string denoting the display_order field in the database.
FieldDisplayOrder = "display_order"
// FieldEnabled holds the string denoting the enabled field in the database.
FieldEnabled = "enabled"
// EdgeValues holds the string denoting the values edge name in mutations.
EdgeValues = "values"
// Table holds the table name of the userattributedefinition in the database.
Table = "user_attribute_definitions"
// ValuesTable is the table that holds the values relation/edge.
ValuesTable = "user_attribute_values"
// ValuesInverseTable is the table name for the UserAttributeValue entity.
// It exists in this package in order to avoid circular dependency with the "userattributevalue" package.
ValuesInverseTable = "user_attribute_values"
// ValuesColumn is the table column denoting the values relation/edge.
ValuesColumn = "attribute_id"
)
// Columns holds all SQL columns for userattributedefinition fields.
var Columns = []string{
FieldID,
FieldCreatedAt,
FieldUpdatedAt,
FieldDeletedAt,
FieldKey,
FieldName,
FieldDescription,
FieldType,
FieldOptions,
FieldRequired,
FieldValidation,
FieldPlaceholder,
FieldDisplayOrder,
FieldEnabled,
}
// ValidColumn reports if the column name is valid (part of the table columns).
func ValidColumn(column string) bool {
for i := range Columns {
if column == Columns[i] {
return true
}
}
return false
}
// Note that the variables below are initialized by the runtime
// package on the initialization of the application. Therefore,
// it should be imported in the main as follows:
//
// import _ "github.com/Wei-Shaw/sub2api/ent/runtime"
var (
Hooks [1]ent.Hook
Interceptors [1]ent.Interceptor
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
DefaultUpdatedAt func() time.Time
// UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
UpdateDefaultUpdatedAt func() time.Time
// KeyValidator is a validator for the "key" field. It is called by the builders before save.
KeyValidator func(string) error
// NameValidator is a validator for the "name" field. It is called by the builders before save.
NameValidator func(string) error
// DefaultDescription holds the default value on creation for the "description" field.
DefaultDescription string
// TypeValidator is a validator for the "type" field. It is called by the builders before save.
TypeValidator func(string) error
// DefaultOptions holds the default value on creation for the "options" field.
DefaultOptions []map[string]interface{}
// DefaultRequired holds the default value on creation for the "required" field.
DefaultRequired bool
// DefaultValidation holds the default value on creation for the "validation" field.
DefaultValidation map[string]interface{}
// DefaultPlaceholder holds the default value on creation for the "placeholder" field.
DefaultPlaceholder string
// PlaceholderValidator is a validator for the "placeholder" field. It is called by the builders before save.
PlaceholderValidator func(string) error
// DefaultDisplayOrder holds the default value on creation for the "display_order" field.
DefaultDisplayOrder int
// DefaultEnabled holds the default value on creation for the "enabled" field.
DefaultEnabled bool
)
// OrderOption defines the ordering options for the UserAttributeDefinition queries.
type OrderOption func(*sql.Selector)
// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
// ByUpdatedAt orders the results by the updated_at field.
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
}
// ByDeletedAt orders the results by the deleted_at field.
func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDeletedAt, opts...).ToFunc()
}
// ByKey orders the results by the key field.
func ByKey(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldKey, opts...).ToFunc()
}
// ByName orders the results by the name field.
func ByName(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldName, opts...).ToFunc()
}
// ByDescription orders the results by the description field.
func ByDescription(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDescription, opts...).ToFunc()
}
// ByType orders the results by the type field.
func ByType(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldType, opts...).ToFunc()
}
// ByRequired orders the results by the required field.
func ByRequired(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldRequired, opts...).ToFunc()
}
// ByPlaceholder orders the results by the placeholder field.
func ByPlaceholder(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldPlaceholder, opts...).ToFunc()
}
// ByDisplayOrder orders the results by the display_order field.
func ByDisplayOrder(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDisplayOrder, opts...).ToFunc()
}
// ByEnabled orders the results by the enabled field.
func ByEnabled(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldEnabled, opts...).ToFunc()
}
// ByValuesCount orders the results by values count.
func ByValuesCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newValuesStep(), opts...)
}
}
// ByValues orders the results by values terms.
func ByValues(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newValuesStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
func newValuesStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(ValuesInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, ValuesTable, ValuesColumn),
)
}

View File

@@ -0,0 +1,664 @@
// Code generated by ent, DO NOT EDIT.
package userattributedefinition
import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// ID filters vertices based on their ID field.
func ID(id int64) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldID, id))
}
// IDEQ applies the EQ predicate on the ID field.
func IDEQ(id int64) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldID, id))
}
// IDNEQ applies the NEQ predicate on the ID field.
func IDNEQ(id int64) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldID, id))
}
// IDIn applies the In predicate on the ID field.
func IDIn(ids ...int64) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldIn(FieldID, ids...))
}
// IDNotIn applies the NotIn predicate on the ID field.
func IDNotIn(ids ...int64) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldID, ids...))
}
// IDGT applies the GT predicate on the ID field.
func IDGT(id int64) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGT(FieldID, id))
}
// IDGTE applies the GTE predicate on the ID field.
func IDGTE(id int64) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldID, id))
}
// IDLT applies the LT predicate on the ID field.
func IDLT(id int64) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLT(FieldID, id))
}
// IDLTE applies the LTE predicate on the ID field.
func IDLTE(id int64) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldID, id))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldCreatedAt, v))
}
// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
func UpdatedAt(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldUpdatedAt, v))
}
// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ.
func DeletedAt(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldDeletedAt, v))
}
// Key applies equality check predicate on the "key" field. It's identical to KeyEQ.
func Key(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldKey, v))
}
// Name applies equality check predicate on the "name" field. It's identical to NameEQ.
func Name(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldName, v))
}
// Description applies equality check predicate on the "description" field. It's identical to DescriptionEQ.
func Description(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldDescription, v))
}
// Type applies equality check predicate on the "type" field. It's identical to TypeEQ.
func Type(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldType, v))
}
// Required applies equality check predicate on the "required" field. It's identical to RequiredEQ.
func Required(v bool) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldRequired, v))
}
// Placeholder applies equality check predicate on the "placeholder" field. It's identical to PlaceholderEQ.
func Placeholder(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldPlaceholder, v))
}
// DisplayOrder applies equality check predicate on the "display_order" field. It's identical to DisplayOrderEQ.
func DisplayOrder(v int) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldDisplayOrder, v))
}
// Enabled applies equality check predicate on the "enabled" field. It's identical to EnabledEQ.
func Enabled(v bool) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldEnabled, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldCreatedAt, v))
}
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
func CreatedAtNEQ(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldCreatedAt, v))
}
// CreatedAtIn applies the In predicate on the "created_at" field.
func CreatedAtIn(vs ...time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldIn(FieldCreatedAt, vs...))
}
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
func CreatedAtNotIn(vs ...time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldCreatedAt, vs...))
}
// CreatedAtGT applies the GT predicate on the "created_at" field.
func CreatedAtGT(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGT(FieldCreatedAt, v))
}
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
func CreatedAtGTE(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldCreatedAt, v))
}
// CreatedAtLT applies the LT predicate on the "created_at" field.
func CreatedAtLT(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLT(FieldCreatedAt, v))
}
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
func CreatedAtLTE(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldCreatedAt, v))
}
// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
func UpdatedAtEQ(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldUpdatedAt, v))
}
// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
func UpdatedAtNEQ(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldUpdatedAt, v))
}
// UpdatedAtIn applies the In predicate on the "updated_at" field.
func UpdatedAtIn(vs ...time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldIn(FieldUpdatedAt, vs...))
}
// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
func UpdatedAtNotIn(vs ...time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldUpdatedAt, vs...))
}
// UpdatedAtGT applies the GT predicate on the "updated_at" field.
func UpdatedAtGT(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGT(FieldUpdatedAt, v))
}
// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
func UpdatedAtGTE(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldUpdatedAt, v))
}
// UpdatedAtLT applies the LT predicate on the "updated_at" field.
func UpdatedAtLT(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLT(FieldUpdatedAt, v))
}
// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
func UpdatedAtLTE(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldUpdatedAt, v))
}
// DeletedAtEQ applies the EQ predicate on the "deleted_at" field.
func DeletedAtEQ(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldDeletedAt, v))
}
// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field.
func DeletedAtNEQ(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldDeletedAt, v))
}
// DeletedAtIn applies the In predicate on the "deleted_at" field.
func DeletedAtIn(vs ...time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldIn(FieldDeletedAt, vs...))
}
// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field.
func DeletedAtNotIn(vs ...time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldDeletedAt, vs...))
}
// DeletedAtGT applies the GT predicate on the "deleted_at" field.
func DeletedAtGT(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGT(FieldDeletedAt, v))
}
// DeletedAtGTE applies the GTE predicate on the "deleted_at" field.
func DeletedAtGTE(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldDeletedAt, v))
}
// DeletedAtLT applies the LT predicate on the "deleted_at" field.
func DeletedAtLT(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLT(FieldDeletedAt, v))
}
// DeletedAtLTE applies the LTE predicate on the "deleted_at" field.
func DeletedAtLTE(v time.Time) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldDeletedAt, v))
}
// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field.
func DeletedAtIsNil() predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldIsNull(FieldDeletedAt))
}
// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field.
func DeletedAtNotNil() predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNotNull(FieldDeletedAt))
}
// KeyEQ applies the EQ predicate on the "key" field.
func KeyEQ(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldKey, v))
}
// KeyNEQ applies the NEQ predicate on the "key" field.
func KeyNEQ(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldKey, v))
}
// KeyIn applies the In predicate on the "key" field.
func KeyIn(vs ...string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldIn(FieldKey, vs...))
}
// KeyNotIn applies the NotIn predicate on the "key" field.
func KeyNotIn(vs ...string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldKey, vs...))
}
// KeyGT applies the GT predicate on the "key" field.
func KeyGT(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGT(FieldKey, v))
}
// KeyGTE applies the GTE predicate on the "key" field.
func KeyGTE(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldKey, v))
}
// KeyLT applies the LT predicate on the "key" field.
func KeyLT(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLT(FieldKey, v))
}
// KeyLTE applies the LTE predicate on the "key" field.
func KeyLTE(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldKey, v))
}
// KeyContains applies the Contains predicate on the "key" field.
func KeyContains(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldContains(FieldKey, v))
}
// KeyHasPrefix applies the HasPrefix predicate on the "key" field.
func KeyHasPrefix(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldHasPrefix(FieldKey, v))
}
// KeyHasSuffix applies the HasSuffix predicate on the "key" field.
func KeyHasSuffix(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldHasSuffix(FieldKey, v))
}
// KeyEqualFold applies the EqualFold predicate on the "key" field.
func KeyEqualFold(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEqualFold(FieldKey, v))
}
// KeyContainsFold applies the ContainsFold predicate on the "key" field.
func KeyContainsFold(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldContainsFold(FieldKey, v))
}
// NameEQ applies the EQ predicate on the "name" field.
func NameEQ(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldName, v))
}
// NameNEQ applies the NEQ predicate on the "name" field.
func NameNEQ(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldName, v))
}
// NameIn applies the In predicate on the "name" field.
func NameIn(vs ...string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldIn(FieldName, vs...))
}
// NameNotIn applies the NotIn predicate on the "name" field.
func NameNotIn(vs ...string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldName, vs...))
}
// NameGT applies the GT predicate on the "name" field.
func NameGT(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGT(FieldName, v))
}
// NameGTE applies the GTE predicate on the "name" field.
func NameGTE(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldName, v))
}
// NameLT applies the LT predicate on the "name" field.
func NameLT(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLT(FieldName, v))
}
// NameLTE applies the LTE predicate on the "name" field.
func NameLTE(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldName, v))
}
// NameContains applies the Contains predicate on the "name" field.
func NameContains(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldContains(FieldName, v))
}
// NameHasPrefix applies the HasPrefix predicate on the "name" field.
func NameHasPrefix(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldHasPrefix(FieldName, v))
}
// NameHasSuffix applies the HasSuffix predicate on the "name" field.
func NameHasSuffix(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldHasSuffix(FieldName, v))
}
// NameEqualFold applies the EqualFold predicate on the "name" field.
func NameEqualFold(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEqualFold(FieldName, v))
}
// NameContainsFold applies the ContainsFold predicate on the "name" field.
func NameContainsFold(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldContainsFold(FieldName, v))
}
// DescriptionEQ applies the EQ predicate on the "description" field.
func DescriptionEQ(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldDescription, v))
}
// DescriptionNEQ applies the NEQ predicate on the "description" field.
func DescriptionNEQ(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldDescription, v))
}
// DescriptionIn applies the In predicate on the "description" field.
func DescriptionIn(vs ...string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldIn(FieldDescription, vs...))
}
// DescriptionNotIn applies the NotIn predicate on the "description" field.
func DescriptionNotIn(vs ...string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldDescription, vs...))
}
// DescriptionGT applies the GT predicate on the "description" field.
func DescriptionGT(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGT(FieldDescription, v))
}
// DescriptionGTE applies the GTE predicate on the "description" field.
func DescriptionGTE(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldDescription, v))
}
// DescriptionLT applies the LT predicate on the "description" field.
func DescriptionLT(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLT(FieldDescription, v))
}
// DescriptionLTE applies the LTE predicate on the "description" field.
func DescriptionLTE(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldDescription, v))
}
// DescriptionContains applies the Contains predicate on the "description" field.
func DescriptionContains(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldContains(FieldDescription, v))
}
// DescriptionHasPrefix applies the HasPrefix predicate on the "description" field.
func DescriptionHasPrefix(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldHasPrefix(FieldDescription, v))
}
// DescriptionHasSuffix applies the HasSuffix predicate on the "description" field.
func DescriptionHasSuffix(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldHasSuffix(FieldDescription, v))
}
// DescriptionEqualFold applies the EqualFold predicate on the "description" field.
func DescriptionEqualFold(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEqualFold(FieldDescription, v))
}
// DescriptionContainsFold applies the ContainsFold predicate on the "description" field.
func DescriptionContainsFold(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldContainsFold(FieldDescription, v))
}
// TypeEQ applies the EQ predicate on the "type" field.
func TypeEQ(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldType, v))
}
// TypeNEQ applies the NEQ predicate on the "type" field.
func TypeNEQ(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldType, v))
}
// TypeIn applies the In predicate on the "type" field.
func TypeIn(vs ...string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldIn(FieldType, vs...))
}
// TypeNotIn applies the NotIn predicate on the "type" field.
func TypeNotIn(vs ...string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldType, vs...))
}
// TypeGT applies the GT predicate on the "type" field.
func TypeGT(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGT(FieldType, v))
}
// TypeGTE applies the GTE predicate on the "type" field.
func TypeGTE(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldType, v))
}
// TypeLT applies the LT predicate on the "type" field.
func TypeLT(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLT(FieldType, v))
}
// TypeLTE applies the LTE predicate on the "type" field.
func TypeLTE(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldType, v))
}
// TypeContains applies the Contains predicate on the "type" field.
func TypeContains(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldContains(FieldType, v))
}
// TypeHasPrefix applies the HasPrefix predicate on the "type" field.
func TypeHasPrefix(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldHasPrefix(FieldType, v))
}
// TypeHasSuffix applies the HasSuffix predicate on the "type" field.
func TypeHasSuffix(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldHasSuffix(FieldType, v))
}
// TypeEqualFold applies the EqualFold predicate on the "type" field.
func TypeEqualFold(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEqualFold(FieldType, v))
}
// TypeContainsFold applies the ContainsFold predicate on the "type" field.
func TypeContainsFold(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldContainsFold(FieldType, v))
}
// RequiredEQ applies the EQ predicate on the "required" field.
func RequiredEQ(v bool) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldRequired, v))
}
// RequiredNEQ applies the NEQ predicate on the "required" field.
func RequiredNEQ(v bool) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldRequired, v))
}
// PlaceholderEQ applies the EQ predicate on the "placeholder" field.
func PlaceholderEQ(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldPlaceholder, v))
}
// PlaceholderNEQ applies the NEQ predicate on the "placeholder" field.
func PlaceholderNEQ(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldPlaceholder, v))
}
// PlaceholderIn applies the In predicate on the "placeholder" field.
func PlaceholderIn(vs ...string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldIn(FieldPlaceholder, vs...))
}
// PlaceholderNotIn applies the NotIn predicate on the "placeholder" field.
func PlaceholderNotIn(vs ...string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldPlaceholder, vs...))
}
// PlaceholderGT applies the GT predicate on the "placeholder" field.
func PlaceholderGT(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGT(FieldPlaceholder, v))
}
// PlaceholderGTE applies the GTE predicate on the "placeholder" field.
func PlaceholderGTE(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldPlaceholder, v))
}
// PlaceholderLT applies the LT predicate on the "placeholder" field.
func PlaceholderLT(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLT(FieldPlaceholder, v))
}
// PlaceholderLTE applies the LTE predicate on the "placeholder" field.
func PlaceholderLTE(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldPlaceholder, v))
}
// PlaceholderContains applies the Contains predicate on the "placeholder" field.
func PlaceholderContains(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldContains(FieldPlaceholder, v))
}
// PlaceholderHasPrefix applies the HasPrefix predicate on the "placeholder" field.
func PlaceholderHasPrefix(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldHasPrefix(FieldPlaceholder, v))
}
// PlaceholderHasSuffix applies the HasSuffix predicate on the "placeholder" field.
func PlaceholderHasSuffix(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldHasSuffix(FieldPlaceholder, v))
}
// PlaceholderEqualFold applies the EqualFold predicate on the "placeholder" field.
func PlaceholderEqualFold(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEqualFold(FieldPlaceholder, v))
}
// PlaceholderContainsFold applies the ContainsFold predicate on the "placeholder" field.
func PlaceholderContainsFold(v string) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldContainsFold(FieldPlaceholder, v))
}
// DisplayOrderEQ applies the EQ predicate on the "display_order" field.
func DisplayOrderEQ(v int) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldDisplayOrder, v))
}
// DisplayOrderNEQ applies the NEQ predicate on the "display_order" field.
func DisplayOrderNEQ(v int) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldDisplayOrder, v))
}
// DisplayOrderIn applies the In predicate on the "display_order" field.
func DisplayOrderIn(vs ...int) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldIn(FieldDisplayOrder, vs...))
}
// DisplayOrderNotIn applies the NotIn predicate on the "display_order" field.
func DisplayOrderNotIn(vs ...int) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNotIn(FieldDisplayOrder, vs...))
}
// DisplayOrderGT applies the GT predicate on the "display_order" field.
func DisplayOrderGT(v int) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGT(FieldDisplayOrder, v))
}
// DisplayOrderGTE applies the GTE predicate on the "display_order" field.
func DisplayOrderGTE(v int) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldGTE(FieldDisplayOrder, v))
}
// DisplayOrderLT applies the LT predicate on the "display_order" field.
func DisplayOrderLT(v int) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLT(FieldDisplayOrder, v))
}
// DisplayOrderLTE applies the LTE predicate on the "display_order" field.
func DisplayOrderLTE(v int) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldLTE(FieldDisplayOrder, v))
}
// EnabledEQ applies the EQ predicate on the "enabled" field.
func EnabledEQ(v bool) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldEQ(FieldEnabled, v))
}
// EnabledNEQ applies the NEQ predicate on the "enabled" field.
func EnabledNEQ(v bool) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.FieldNEQ(FieldEnabled, v))
}
// HasValues applies the HasEdge predicate on the "values" edge.
func HasValues() predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, ValuesTable, ValuesColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasValuesWith applies the HasEdge predicate on the "values" edge with a given conditions (other predicates).
func HasValuesWith(preds ...predicate.UserAttributeValue) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(func(s *sql.Selector) {
step := newValuesStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.UserAttributeDefinition) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.AndPredicates(predicates...))
}
// Or groups predicates with the OR operator between them.
func Or(predicates ...predicate.UserAttributeDefinition) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.OrPredicates(predicates...))
}
// Not applies the not operator on the given predicate.
func Not(p predicate.UserAttributeDefinition) predicate.UserAttributeDefinition {
return predicate.UserAttributeDefinition(sql.NotPredicates(p))
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,88 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
)
// UserAttributeDefinitionDelete is the builder for deleting a UserAttributeDefinition entity.
type UserAttributeDefinitionDelete struct {
config
hooks []Hook
mutation *UserAttributeDefinitionMutation
}
// Where appends a list predicates to the UserAttributeDefinitionDelete builder.
func (_d *UserAttributeDefinitionDelete) Where(ps ...predicate.UserAttributeDefinition) *UserAttributeDefinitionDelete {
_d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query and returns how many vertices were deleted.
func (_d *UserAttributeDefinitionDelete) Exec(ctx context.Context) (int, error) {
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *UserAttributeDefinitionDelete) ExecX(ctx context.Context) int {
n, err := _d.Exec(ctx)
if err != nil {
panic(err)
}
return n
}
func (_d *UserAttributeDefinitionDelete) sqlExec(ctx context.Context) (int, error) {
_spec := sqlgraph.NewDeleteSpec(userattributedefinition.Table, sqlgraph.NewFieldSpec(userattributedefinition.FieldID, field.TypeInt64))
if ps := _d.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
_d.mutation.done = true
return affected, err
}
// UserAttributeDefinitionDeleteOne is the builder for deleting a single UserAttributeDefinition entity.
type UserAttributeDefinitionDeleteOne struct {
_d *UserAttributeDefinitionDelete
}
// Where appends a list predicates to the UserAttributeDefinitionDelete builder.
func (_d *UserAttributeDefinitionDeleteOne) Where(ps ...predicate.UserAttributeDefinition) *UserAttributeDefinitionDeleteOne {
_d._d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query.
func (_d *UserAttributeDefinitionDeleteOne) Exec(ctx context.Context) error {
n, err := _d._d.Exec(ctx)
switch {
case err != nil:
return err
case n == 0:
return &NotFoundError{userattributedefinition.Label}
default:
return nil
}
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *UserAttributeDefinitionDeleteOne) ExecX(ctx context.Context) {
if err := _d.Exec(ctx); err != nil {
panic(err)
}
}

View File

@@ -0,0 +1,606 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"database/sql/driver"
"fmt"
"math"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
)
// UserAttributeDefinitionQuery is the builder for querying UserAttributeDefinition entities.
type UserAttributeDefinitionQuery struct {
config
ctx *QueryContext
order []userattributedefinition.OrderOption
inters []Interceptor
predicates []predicate.UserAttributeDefinition
withValues *UserAttributeValueQuery
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
}
// Where adds a new predicate for the UserAttributeDefinitionQuery builder.
func (_q *UserAttributeDefinitionQuery) Where(ps ...predicate.UserAttributeDefinition) *UserAttributeDefinitionQuery {
_q.predicates = append(_q.predicates, ps...)
return _q
}
// Limit the number of records to be returned by this query.
func (_q *UserAttributeDefinitionQuery) Limit(limit int) *UserAttributeDefinitionQuery {
_q.ctx.Limit = &limit
return _q
}
// Offset to start from.
func (_q *UserAttributeDefinitionQuery) Offset(offset int) *UserAttributeDefinitionQuery {
_q.ctx.Offset = &offset
return _q
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (_q *UserAttributeDefinitionQuery) Unique(unique bool) *UserAttributeDefinitionQuery {
_q.ctx.Unique = &unique
return _q
}
// Order specifies how the records should be ordered.
func (_q *UserAttributeDefinitionQuery) Order(o ...userattributedefinition.OrderOption) *UserAttributeDefinitionQuery {
_q.order = append(_q.order, o...)
return _q
}
// QueryValues chains the current query on the "values" edge.
func (_q *UserAttributeDefinitionQuery) QueryValues() *UserAttributeValueQuery {
query := (&UserAttributeValueClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(userattributedefinition.Table, userattributedefinition.FieldID, selector),
sqlgraph.To(userattributevalue.Table, userattributevalue.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, userattributedefinition.ValuesTable, userattributedefinition.ValuesColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// First returns the first UserAttributeDefinition entity from the query.
// Returns a *NotFoundError when no UserAttributeDefinition was found.
func (_q *UserAttributeDefinitionQuery) First(ctx context.Context) (*UserAttributeDefinition, error) {
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
if err != nil {
return nil, err
}
if len(nodes) == 0 {
return nil, &NotFoundError{userattributedefinition.Label}
}
return nodes[0], nil
}
// FirstX is like First, but panics if an error occurs.
func (_q *UserAttributeDefinitionQuery) FirstX(ctx context.Context) *UserAttributeDefinition {
node, err := _q.First(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return node
}
// FirstID returns the first UserAttributeDefinition ID from the query.
// Returns a *NotFoundError when no UserAttributeDefinition ID was found.
func (_q *UserAttributeDefinitionQuery) FirstID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
return
}
if len(ids) == 0 {
err = &NotFoundError{userattributedefinition.Label}
return
}
return ids[0], nil
}
// FirstIDX is like FirstID, but panics if an error occurs.
func (_q *UserAttributeDefinitionQuery) FirstIDX(ctx context.Context) int64 {
id, err := _q.FirstID(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return id
}
// Only returns a single UserAttributeDefinition entity found by the query, ensuring it only returns one.
// Returns a *NotSingularError when more than one UserAttributeDefinition entity is found.
// Returns a *NotFoundError when no UserAttributeDefinition entities are found.
func (_q *UserAttributeDefinitionQuery) Only(ctx context.Context) (*UserAttributeDefinition, error) {
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
if err != nil {
return nil, err
}
switch len(nodes) {
case 1:
return nodes[0], nil
case 0:
return nil, &NotFoundError{userattributedefinition.Label}
default:
return nil, &NotSingularError{userattributedefinition.Label}
}
}
// OnlyX is like Only, but panics if an error occurs.
func (_q *UserAttributeDefinitionQuery) OnlyX(ctx context.Context) *UserAttributeDefinition {
node, err := _q.Only(ctx)
if err != nil {
panic(err)
}
return node
}
// OnlyID is like Only, but returns the only UserAttributeDefinition ID in the query.
// Returns a *NotSingularError when more than one UserAttributeDefinition ID is found.
// Returns a *NotFoundError when no entities are found.
func (_q *UserAttributeDefinitionQuery) OnlyID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
return
}
switch len(ids) {
case 1:
id = ids[0]
case 0:
err = &NotFoundError{userattributedefinition.Label}
default:
err = &NotSingularError{userattributedefinition.Label}
}
return
}
// OnlyIDX is like OnlyID, but panics if an error occurs.
func (_q *UserAttributeDefinitionQuery) OnlyIDX(ctx context.Context) int64 {
id, err := _q.OnlyID(ctx)
if err != nil {
panic(err)
}
return id
}
// All executes the query and returns a list of UserAttributeDefinitions.
func (_q *UserAttributeDefinitionQuery) All(ctx context.Context) ([]*UserAttributeDefinition, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
qr := querierAll[[]*UserAttributeDefinition, *UserAttributeDefinitionQuery]()
return withInterceptors[[]*UserAttributeDefinition](ctx, _q, qr, _q.inters)
}
// AllX is like All, but panics if an error occurs.
func (_q *UserAttributeDefinitionQuery) AllX(ctx context.Context) []*UserAttributeDefinition {
nodes, err := _q.All(ctx)
if err != nil {
panic(err)
}
return nodes
}
// IDs executes the query and returns a list of UserAttributeDefinition IDs.
func (_q *UserAttributeDefinitionQuery) IDs(ctx context.Context) (ids []int64, err error) {
if _q.ctx.Unique == nil && _q.path != nil {
_q.Unique(true)
}
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
if err = _q.Select(userattributedefinition.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
}
// IDsX is like IDs, but panics if an error occurs.
func (_q *UserAttributeDefinitionQuery) IDsX(ctx context.Context) []int64 {
ids, err := _q.IDs(ctx)
if err != nil {
panic(err)
}
return ids
}
// Count returns the count of the given query.
func (_q *UserAttributeDefinitionQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
if err := _q.prepareQuery(ctx); err != nil {
return 0, err
}
return withInterceptors[int](ctx, _q, querierCount[*UserAttributeDefinitionQuery](), _q.inters)
}
// CountX is like Count, but panics if an error occurs.
func (_q *UserAttributeDefinitionQuery) CountX(ctx context.Context) int {
count, err := _q.Count(ctx)
if err != nil {
panic(err)
}
return count
}
// Exist returns true if the query has elements in the graph.
func (_q *UserAttributeDefinitionQuery) Exist(ctx context.Context) (bool, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
switch _, err := _q.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
// ExistX is like Exist, but panics if an error occurs.
func (_q *UserAttributeDefinitionQuery) ExistX(ctx context.Context) bool {
exist, err := _q.Exist(ctx)
if err != nil {
panic(err)
}
return exist
}
// Clone returns a duplicate of the UserAttributeDefinitionQuery builder, including all associated steps. It can be
// used to prepare common query builders and use them differently after the clone is made.
func (_q *UserAttributeDefinitionQuery) Clone() *UserAttributeDefinitionQuery {
if _q == nil {
return nil
}
return &UserAttributeDefinitionQuery{
config: _q.config,
ctx: _q.ctx.Clone(),
order: append([]userattributedefinition.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.UserAttributeDefinition{}, _q.predicates...),
withValues: _q.withValues.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
}
}
// WithValues tells the query-builder to eager-load the nodes that are connected to
// the "values" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserAttributeDefinitionQuery) WithValues(opts ...func(*UserAttributeValueQuery)) *UserAttributeDefinitionQuery {
query := (&UserAttributeValueClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withValues = query
return _q
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.UserAttributeDefinition.Query().
// GroupBy(userattributedefinition.FieldCreatedAt).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (_q *UserAttributeDefinitionQuery) GroupBy(field string, fields ...string) *UserAttributeDefinitionGroupBy {
_q.ctx.Fields = append([]string{field}, fields...)
grbuild := &UserAttributeDefinitionGroupBy{build: _q}
grbuild.flds = &_q.ctx.Fields
grbuild.label = userattributedefinition.Label
grbuild.scan = grbuild.Scan
return grbuild
}
// Select allows the selection one or more fields/columns for the given query,
// instead of selecting all fields in the entity.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// }
//
// client.UserAttributeDefinition.Query().
// Select(userattributedefinition.FieldCreatedAt).
// Scan(ctx, &v)
func (_q *UserAttributeDefinitionQuery) Select(fields ...string) *UserAttributeDefinitionSelect {
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
sbuild := &UserAttributeDefinitionSelect{UserAttributeDefinitionQuery: _q}
sbuild.label = userattributedefinition.Label
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a UserAttributeDefinitionSelect configured with the given aggregations.
func (_q *UserAttributeDefinitionQuery) Aggregate(fns ...AggregateFunc) *UserAttributeDefinitionSelect {
return _q.Select().Aggregate(fns...)
}
func (_q *UserAttributeDefinitionQuery) prepareQuery(ctx context.Context) error {
for _, inter := range _q.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, _q); err != nil {
return err
}
}
}
for _, f := range _q.ctx.Fields {
if !userattributedefinition.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
}
if _q.path != nil {
prev, err := _q.path(ctx)
if err != nil {
return err
}
_q.sql = prev
}
return nil
}
func (_q *UserAttributeDefinitionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*UserAttributeDefinition, error) {
var (
nodes = []*UserAttributeDefinition{}
_spec = _q.querySpec()
loadedTypes = [1]bool{
_q.withValues != nil,
}
)
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*UserAttributeDefinition).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []any) error {
node := &UserAttributeDefinition{config: _q.config}
nodes = append(nodes, node)
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
return nil, err
}
if len(nodes) == 0 {
return nodes, nil
}
if query := _q.withValues; query != nil {
if err := _q.loadValues(ctx, query, nodes,
func(n *UserAttributeDefinition) { n.Edges.Values = []*UserAttributeValue{} },
func(n *UserAttributeDefinition, e *UserAttributeValue) { n.Edges.Values = append(n.Edges.Values, e) }); err != nil {
return nil, err
}
}
return nodes, nil
}
func (_q *UserAttributeDefinitionQuery) loadValues(ctx context.Context, query *UserAttributeValueQuery, nodes []*UserAttributeDefinition, init func(*UserAttributeDefinition), assign func(*UserAttributeDefinition, *UserAttributeValue)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*UserAttributeDefinition)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(userattributevalue.FieldAttributeID)
}
query.Where(predicate.UserAttributeValue(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(userattributedefinition.ValuesColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.AttributeID
node, ok := nodeids[fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "attribute_id" returned %v for node %v`, fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *UserAttributeDefinitionQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
}
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
}
func (_q *UserAttributeDefinitionQuery) querySpec() *sqlgraph.QuerySpec {
_spec := sqlgraph.NewQuerySpec(userattributedefinition.Table, userattributedefinition.Columns, sqlgraph.NewFieldSpec(userattributedefinition.FieldID, field.TypeInt64))
_spec.From = _q.sql
if unique := _q.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if _q.path != nil {
_spec.Unique = true
}
if fields := _q.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, userattributedefinition.FieldID)
for i := range fields {
if fields[i] != userattributedefinition.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
}
if ps := _q.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if limit := _q.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := _q.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := _q.order; len(ps) > 0 {
_spec.Order = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
return _spec
}
func (_q *UserAttributeDefinitionQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(_q.driver.Dialect())
t1 := builder.Table(userattributedefinition.Table)
columns := _q.ctx.Fields
if len(columns) == 0 {
columns = userattributedefinition.Columns
}
selector := builder.Select(t1.Columns(columns...)...).From(t1)
if _q.sql != nil {
selector = _q.sql
selector.Select(selector.Columns(columns...)...)
}
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, p := range _q.predicates {
p(selector)
}
for _, p := range _q.order {
p(selector)
}
if offset := _q.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := _q.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
}
// UserAttributeDefinitionGroupBy is the group-by builder for UserAttributeDefinition entities.
type UserAttributeDefinitionGroupBy struct {
selector
build *UserAttributeDefinitionQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
func (_g *UserAttributeDefinitionGroupBy) Aggregate(fns ...AggregateFunc) *UserAttributeDefinitionGroupBy {
_g.fns = append(_g.fns, fns...)
return _g
}
// Scan applies the selector query and scans the result into the given value.
func (_g *UserAttributeDefinitionGroupBy) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
if err := _g.build.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*UserAttributeDefinitionQuery, *UserAttributeDefinitionGroupBy](ctx, _g.build, _g, _g.build.inters, v)
}
func (_g *UserAttributeDefinitionGroupBy) sqlScan(ctx context.Context, root *UserAttributeDefinitionQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(_g.fns))
for _, fn := range _g.fns {
aggregation = append(aggregation, fn(selector))
}
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
for _, f := range *_g.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
selector.GroupBy(selector.Columns(*_g.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// UserAttributeDefinitionSelect is the builder for selecting fields of UserAttributeDefinition entities.
type UserAttributeDefinitionSelect struct {
*UserAttributeDefinitionQuery
selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (_s *UserAttributeDefinitionSelect) Aggregate(fns ...AggregateFunc) *UserAttributeDefinitionSelect {
_s.fns = append(_s.fns, fns...)
return _s
}
// Scan applies the selector query and scans the result into the given value.
func (_s *UserAttributeDefinitionSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
if err := _s.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*UserAttributeDefinitionQuery, *UserAttributeDefinitionSelect](ctx, _s.UserAttributeDefinitionQuery, _s, _s.inters, v)
}
func (_s *UserAttributeDefinitionSelect) sqlScan(ctx context.Context, root *UserAttributeDefinitionQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(_s.fns))
for _, fn := range _s.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*_s.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}

View File

@@ -0,0 +1,846 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/dialect/sql/sqljson"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
)
// UserAttributeDefinitionUpdate is the builder for updating UserAttributeDefinition entities.
type UserAttributeDefinitionUpdate struct {
config
hooks []Hook
mutation *UserAttributeDefinitionMutation
}
// Where appends a list predicates to the UserAttributeDefinitionUpdate builder.
func (_u *UserAttributeDefinitionUpdate) Where(ps ...predicate.UserAttributeDefinition) *UserAttributeDefinitionUpdate {
_u.mutation.Where(ps...)
return _u
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *UserAttributeDefinitionUpdate) SetUpdatedAt(v time.Time) *UserAttributeDefinitionUpdate {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetDeletedAt sets the "deleted_at" field.
func (_u *UserAttributeDefinitionUpdate) SetDeletedAt(v time.Time) *UserAttributeDefinitionUpdate {
_u.mutation.SetDeletedAt(v)
return _u
}
// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdate) SetNillableDeletedAt(v *time.Time) *UserAttributeDefinitionUpdate {
if v != nil {
_u.SetDeletedAt(*v)
}
return _u
}
// ClearDeletedAt clears the value of the "deleted_at" field.
func (_u *UserAttributeDefinitionUpdate) ClearDeletedAt() *UserAttributeDefinitionUpdate {
_u.mutation.ClearDeletedAt()
return _u
}
// SetKey sets the "key" field.
func (_u *UserAttributeDefinitionUpdate) SetKey(v string) *UserAttributeDefinitionUpdate {
_u.mutation.SetKey(v)
return _u
}
// SetNillableKey sets the "key" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdate) SetNillableKey(v *string) *UserAttributeDefinitionUpdate {
if v != nil {
_u.SetKey(*v)
}
return _u
}
// SetName sets the "name" field.
func (_u *UserAttributeDefinitionUpdate) SetName(v string) *UserAttributeDefinitionUpdate {
_u.mutation.SetName(v)
return _u
}
// SetNillableName sets the "name" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdate) SetNillableName(v *string) *UserAttributeDefinitionUpdate {
if v != nil {
_u.SetName(*v)
}
return _u
}
// SetDescription sets the "description" field.
func (_u *UserAttributeDefinitionUpdate) SetDescription(v string) *UserAttributeDefinitionUpdate {
_u.mutation.SetDescription(v)
return _u
}
// SetNillableDescription sets the "description" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdate) SetNillableDescription(v *string) *UserAttributeDefinitionUpdate {
if v != nil {
_u.SetDescription(*v)
}
return _u
}
// SetType sets the "type" field.
func (_u *UserAttributeDefinitionUpdate) SetType(v string) *UserAttributeDefinitionUpdate {
_u.mutation.SetType(v)
return _u
}
// SetNillableType sets the "type" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdate) SetNillableType(v *string) *UserAttributeDefinitionUpdate {
if v != nil {
_u.SetType(*v)
}
return _u
}
// SetOptions sets the "options" field.
func (_u *UserAttributeDefinitionUpdate) SetOptions(v []map[string]interface{}) *UserAttributeDefinitionUpdate {
_u.mutation.SetOptions(v)
return _u
}
// AppendOptions appends value to the "options" field.
func (_u *UserAttributeDefinitionUpdate) AppendOptions(v []map[string]interface{}) *UserAttributeDefinitionUpdate {
_u.mutation.AppendOptions(v)
return _u
}
// SetRequired sets the "required" field.
func (_u *UserAttributeDefinitionUpdate) SetRequired(v bool) *UserAttributeDefinitionUpdate {
_u.mutation.SetRequired(v)
return _u
}
// SetNillableRequired sets the "required" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdate) SetNillableRequired(v *bool) *UserAttributeDefinitionUpdate {
if v != nil {
_u.SetRequired(*v)
}
return _u
}
// SetValidation sets the "validation" field.
func (_u *UserAttributeDefinitionUpdate) SetValidation(v map[string]interface{}) *UserAttributeDefinitionUpdate {
_u.mutation.SetValidation(v)
return _u
}
// SetPlaceholder sets the "placeholder" field.
func (_u *UserAttributeDefinitionUpdate) SetPlaceholder(v string) *UserAttributeDefinitionUpdate {
_u.mutation.SetPlaceholder(v)
return _u
}
// SetNillablePlaceholder sets the "placeholder" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdate) SetNillablePlaceholder(v *string) *UserAttributeDefinitionUpdate {
if v != nil {
_u.SetPlaceholder(*v)
}
return _u
}
// SetDisplayOrder sets the "display_order" field.
func (_u *UserAttributeDefinitionUpdate) SetDisplayOrder(v int) *UserAttributeDefinitionUpdate {
_u.mutation.ResetDisplayOrder()
_u.mutation.SetDisplayOrder(v)
return _u
}
// SetNillableDisplayOrder sets the "display_order" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdate) SetNillableDisplayOrder(v *int) *UserAttributeDefinitionUpdate {
if v != nil {
_u.SetDisplayOrder(*v)
}
return _u
}
// AddDisplayOrder adds value to the "display_order" field.
func (_u *UserAttributeDefinitionUpdate) AddDisplayOrder(v int) *UserAttributeDefinitionUpdate {
_u.mutation.AddDisplayOrder(v)
return _u
}
// SetEnabled sets the "enabled" field.
func (_u *UserAttributeDefinitionUpdate) SetEnabled(v bool) *UserAttributeDefinitionUpdate {
_u.mutation.SetEnabled(v)
return _u
}
// SetNillableEnabled sets the "enabled" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdate) SetNillableEnabled(v *bool) *UserAttributeDefinitionUpdate {
if v != nil {
_u.SetEnabled(*v)
}
return _u
}
// AddValueIDs adds the "values" edge to the UserAttributeValue entity by IDs.
func (_u *UserAttributeDefinitionUpdate) AddValueIDs(ids ...int64) *UserAttributeDefinitionUpdate {
_u.mutation.AddValueIDs(ids...)
return _u
}
// AddValues adds the "values" edges to the UserAttributeValue entity.
func (_u *UserAttributeDefinitionUpdate) AddValues(v ...*UserAttributeValue) *UserAttributeDefinitionUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddValueIDs(ids...)
}
// Mutation returns the UserAttributeDefinitionMutation object of the builder.
func (_u *UserAttributeDefinitionUpdate) Mutation() *UserAttributeDefinitionMutation {
return _u.mutation
}
// ClearValues clears all "values" edges to the UserAttributeValue entity.
func (_u *UserAttributeDefinitionUpdate) ClearValues() *UserAttributeDefinitionUpdate {
_u.mutation.ClearValues()
return _u
}
// RemoveValueIDs removes the "values" edge to UserAttributeValue entities by IDs.
func (_u *UserAttributeDefinitionUpdate) RemoveValueIDs(ids ...int64) *UserAttributeDefinitionUpdate {
_u.mutation.RemoveValueIDs(ids...)
return _u
}
// RemoveValues removes "values" edges to UserAttributeValue entities.
func (_u *UserAttributeDefinitionUpdate) RemoveValues(v ...*UserAttributeValue) *UserAttributeDefinitionUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveValueIDs(ids...)
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UserAttributeDefinitionUpdate) Save(ctx context.Context) (int, error) {
if err := _u.defaults(); err != nil {
return 0, err
}
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *UserAttributeDefinitionUpdate) SaveX(ctx context.Context) int {
affected, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return affected
}
// Exec executes the query.
func (_u *UserAttributeDefinitionUpdate) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *UserAttributeDefinitionUpdate) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *UserAttributeDefinitionUpdate) defaults() error {
if _, ok := _u.mutation.UpdatedAt(); !ok {
if userattributedefinition.UpdateDefaultUpdatedAt == nil {
return fmt.Errorf("ent: uninitialized userattributedefinition.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)")
}
v := userattributedefinition.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
return nil
}
// check runs all checks and user-defined validators on the builder.
func (_u *UserAttributeDefinitionUpdate) check() error {
if v, ok := _u.mutation.Key(); ok {
if err := userattributedefinition.KeyValidator(v); err != nil {
return &ValidationError{Name: "key", err: fmt.Errorf(`ent: validator failed for field "UserAttributeDefinition.key": %w`, err)}
}
}
if v, ok := _u.mutation.Name(); ok {
if err := userattributedefinition.NameValidator(v); err != nil {
return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "UserAttributeDefinition.name": %w`, err)}
}
}
if v, ok := _u.mutation.GetType(); ok {
if err := userattributedefinition.TypeValidator(v); err != nil {
return &ValidationError{Name: "type", err: fmt.Errorf(`ent: validator failed for field "UserAttributeDefinition.type": %w`, err)}
}
}
if v, ok := _u.mutation.Placeholder(); ok {
if err := userattributedefinition.PlaceholderValidator(v); err != nil {
return &ValidationError{Name: "placeholder", err: fmt.Errorf(`ent: validator failed for field "UserAttributeDefinition.placeholder": %w`, err)}
}
}
return nil
}
func (_u *UserAttributeDefinitionUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(userattributedefinition.Table, userattributedefinition.Columns, sqlgraph.NewFieldSpec(userattributedefinition.FieldID, field.TypeInt64))
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(userattributedefinition.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.DeletedAt(); ok {
_spec.SetField(userattributedefinition.FieldDeletedAt, field.TypeTime, value)
}
if _u.mutation.DeletedAtCleared() {
_spec.ClearField(userattributedefinition.FieldDeletedAt, field.TypeTime)
}
if value, ok := _u.mutation.Key(); ok {
_spec.SetField(userattributedefinition.FieldKey, field.TypeString, value)
}
if value, ok := _u.mutation.Name(); ok {
_spec.SetField(userattributedefinition.FieldName, field.TypeString, value)
}
if value, ok := _u.mutation.Description(); ok {
_spec.SetField(userattributedefinition.FieldDescription, field.TypeString, value)
}
if value, ok := _u.mutation.GetType(); ok {
_spec.SetField(userattributedefinition.FieldType, field.TypeString, value)
}
if value, ok := _u.mutation.Options(); ok {
_spec.SetField(userattributedefinition.FieldOptions, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedOptions(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, userattributedefinition.FieldOptions, value)
})
}
if value, ok := _u.mutation.Required(); ok {
_spec.SetField(userattributedefinition.FieldRequired, field.TypeBool, value)
}
if value, ok := _u.mutation.Validation(); ok {
_spec.SetField(userattributedefinition.FieldValidation, field.TypeJSON, value)
}
if value, ok := _u.mutation.Placeholder(); ok {
_spec.SetField(userattributedefinition.FieldPlaceholder, field.TypeString, value)
}
if value, ok := _u.mutation.DisplayOrder(); ok {
_spec.SetField(userattributedefinition.FieldDisplayOrder, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedDisplayOrder(); ok {
_spec.AddField(userattributedefinition.FieldDisplayOrder, field.TypeInt, value)
}
if value, ok := _u.mutation.Enabled(); ok {
_spec.SetField(userattributedefinition.FieldEnabled, field.TypeBool, value)
}
if _u.mutation.ValuesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: userattributedefinition.ValuesTable,
Columns: []string{userattributedefinition.ValuesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedValuesIDs(); len(nodes) > 0 && !_u.mutation.ValuesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: userattributedefinition.ValuesTable,
Columns: []string{userattributedefinition.ValuesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.ValuesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: userattributedefinition.ValuesTable,
Columns: []string{userattributedefinition.ValuesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{userattributedefinition.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return 0, err
}
_u.mutation.done = true
return _node, nil
}
// UserAttributeDefinitionUpdateOne is the builder for updating a single UserAttributeDefinition entity.
type UserAttributeDefinitionUpdateOne struct {
config
fields []string
hooks []Hook
mutation *UserAttributeDefinitionMutation
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *UserAttributeDefinitionUpdateOne) SetUpdatedAt(v time.Time) *UserAttributeDefinitionUpdateOne {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetDeletedAt sets the "deleted_at" field.
func (_u *UserAttributeDefinitionUpdateOne) SetDeletedAt(v time.Time) *UserAttributeDefinitionUpdateOne {
_u.mutation.SetDeletedAt(v)
return _u
}
// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdateOne) SetNillableDeletedAt(v *time.Time) *UserAttributeDefinitionUpdateOne {
if v != nil {
_u.SetDeletedAt(*v)
}
return _u
}
// ClearDeletedAt clears the value of the "deleted_at" field.
func (_u *UserAttributeDefinitionUpdateOne) ClearDeletedAt() *UserAttributeDefinitionUpdateOne {
_u.mutation.ClearDeletedAt()
return _u
}
// SetKey sets the "key" field.
func (_u *UserAttributeDefinitionUpdateOne) SetKey(v string) *UserAttributeDefinitionUpdateOne {
_u.mutation.SetKey(v)
return _u
}
// SetNillableKey sets the "key" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdateOne) SetNillableKey(v *string) *UserAttributeDefinitionUpdateOne {
if v != nil {
_u.SetKey(*v)
}
return _u
}
// SetName sets the "name" field.
func (_u *UserAttributeDefinitionUpdateOne) SetName(v string) *UserAttributeDefinitionUpdateOne {
_u.mutation.SetName(v)
return _u
}
// SetNillableName sets the "name" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdateOne) SetNillableName(v *string) *UserAttributeDefinitionUpdateOne {
if v != nil {
_u.SetName(*v)
}
return _u
}
// SetDescription sets the "description" field.
func (_u *UserAttributeDefinitionUpdateOne) SetDescription(v string) *UserAttributeDefinitionUpdateOne {
_u.mutation.SetDescription(v)
return _u
}
// SetNillableDescription sets the "description" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdateOne) SetNillableDescription(v *string) *UserAttributeDefinitionUpdateOne {
if v != nil {
_u.SetDescription(*v)
}
return _u
}
// SetType sets the "type" field.
func (_u *UserAttributeDefinitionUpdateOne) SetType(v string) *UserAttributeDefinitionUpdateOne {
_u.mutation.SetType(v)
return _u
}
// SetNillableType sets the "type" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdateOne) SetNillableType(v *string) *UserAttributeDefinitionUpdateOne {
if v != nil {
_u.SetType(*v)
}
return _u
}
// SetOptions sets the "options" field.
func (_u *UserAttributeDefinitionUpdateOne) SetOptions(v []map[string]interface{}) *UserAttributeDefinitionUpdateOne {
_u.mutation.SetOptions(v)
return _u
}
// AppendOptions appends value to the "options" field.
func (_u *UserAttributeDefinitionUpdateOne) AppendOptions(v []map[string]interface{}) *UserAttributeDefinitionUpdateOne {
_u.mutation.AppendOptions(v)
return _u
}
// SetRequired sets the "required" field.
func (_u *UserAttributeDefinitionUpdateOne) SetRequired(v bool) *UserAttributeDefinitionUpdateOne {
_u.mutation.SetRequired(v)
return _u
}
// SetNillableRequired sets the "required" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdateOne) SetNillableRequired(v *bool) *UserAttributeDefinitionUpdateOne {
if v != nil {
_u.SetRequired(*v)
}
return _u
}
// SetValidation sets the "validation" field.
func (_u *UserAttributeDefinitionUpdateOne) SetValidation(v map[string]interface{}) *UserAttributeDefinitionUpdateOne {
_u.mutation.SetValidation(v)
return _u
}
// SetPlaceholder sets the "placeholder" field.
func (_u *UserAttributeDefinitionUpdateOne) SetPlaceholder(v string) *UserAttributeDefinitionUpdateOne {
_u.mutation.SetPlaceholder(v)
return _u
}
// SetNillablePlaceholder sets the "placeholder" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdateOne) SetNillablePlaceholder(v *string) *UserAttributeDefinitionUpdateOne {
if v != nil {
_u.SetPlaceholder(*v)
}
return _u
}
// SetDisplayOrder sets the "display_order" field.
func (_u *UserAttributeDefinitionUpdateOne) SetDisplayOrder(v int) *UserAttributeDefinitionUpdateOne {
_u.mutation.ResetDisplayOrder()
_u.mutation.SetDisplayOrder(v)
return _u
}
// SetNillableDisplayOrder sets the "display_order" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdateOne) SetNillableDisplayOrder(v *int) *UserAttributeDefinitionUpdateOne {
if v != nil {
_u.SetDisplayOrder(*v)
}
return _u
}
// AddDisplayOrder adds value to the "display_order" field.
func (_u *UserAttributeDefinitionUpdateOne) AddDisplayOrder(v int) *UserAttributeDefinitionUpdateOne {
_u.mutation.AddDisplayOrder(v)
return _u
}
// SetEnabled sets the "enabled" field.
func (_u *UserAttributeDefinitionUpdateOne) SetEnabled(v bool) *UserAttributeDefinitionUpdateOne {
_u.mutation.SetEnabled(v)
return _u
}
// SetNillableEnabled sets the "enabled" field if the given value is not nil.
func (_u *UserAttributeDefinitionUpdateOne) SetNillableEnabled(v *bool) *UserAttributeDefinitionUpdateOne {
if v != nil {
_u.SetEnabled(*v)
}
return _u
}
// AddValueIDs adds the "values" edge to the UserAttributeValue entity by IDs.
func (_u *UserAttributeDefinitionUpdateOne) AddValueIDs(ids ...int64) *UserAttributeDefinitionUpdateOne {
_u.mutation.AddValueIDs(ids...)
return _u
}
// AddValues adds the "values" edges to the UserAttributeValue entity.
func (_u *UserAttributeDefinitionUpdateOne) AddValues(v ...*UserAttributeValue) *UserAttributeDefinitionUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddValueIDs(ids...)
}
// Mutation returns the UserAttributeDefinitionMutation object of the builder.
func (_u *UserAttributeDefinitionUpdateOne) Mutation() *UserAttributeDefinitionMutation {
return _u.mutation
}
// ClearValues clears all "values" edges to the UserAttributeValue entity.
func (_u *UserAttributeDefinitionUpdateOne) ClearValues() *UserAttributeDefinitionUpdateOne {
_u.mutation.ClearValues()
return _u
}
// RemoveValueIDs removes the "values" edge to UserAttributeValue entities by IDs.
func (_u *UserAttributeDefinitionUpdateOne) RemoveValueIDs(ids ...int64) *UserAttributeDefinitionUpdateOne {
_u.mutation.RemoveValueIDs(ids...)
return _u
}
// RemoveValues removes "values" edges to UserAttributeValue entities.
func (_u *UserAttributeDefinitionUpdateOne) RemoveValues(v ...*UserAttributeValue) *UserAttributeDefinitionUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveValueIDs(ids...)
}
// Where appends a list predicates to the UserAttributeDefinitionUpdate builder.
func (_u *UserAttributeDefinitionUpdateOne) Where(ps ...predicate.UserAttributeDefinition) *UserAttributeDefinitionUpdateOne {
_u.mutation.Where(ps...)
return _u
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (_u *UserAttributeDefinitionUpdateOne) Select(field string, fields ...string) *UserAttributeDefinitionUpdateOne {
_u.fields = append([]string{field}, fields...)
return _u
}
// Save executes the query and returns the updated UserAttributeDefinition entity.
func (_u *UserAttributeDefinitionUpdateOne) Save(ctx context.Context) (*UserAttributeDefinition, error) {
if err := _u.defaults(); err != nil {
return nil, err
}
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *UserAttributeDefinitionUpdateOne) SaveX(ctx context.Context) *UserAttributeDefinition {
node, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return node
}
// Exec executes the query on the entity.
func (_u *UserAttributeDefinitionUpdateOne) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *UserAttributeDefinitionUpdateOne) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *UserAttributeDefinitionUpdateOne) defaults() error {
if _, ok := _u.mutation.UpdatedAt(); !ok {
if userattributedefinition.UpdateDefaultUpdatedAt == nil {
return fmt.Errorf("ent: uninitialized userattributedefinition.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)")
}
v := userattributedefinition.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
return nil
}
// check runs all checks and user-defined validators on the builder.
func (_u *UserAttributeDefinitionUpdateOne) check() error {
if v, ok := _u.mutation.Key(); ok {
if err := userattributedefinition.KeyValidator(v); err != nil {
return &ValidationError{Name: "key", err: fmt.Errorf(`ent: validator failed for field "UserAttributeDefinition.key": %w`, err)}
}
}
if v, ok := _u.mutation.Name(); ok {
if err := userattributedefinition.NameValidator(v); err != nil {
return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "UserAttributeDefinition.name": %w`, err)}
}
}
if v, ok := _u.mutation.GetType(); ok {
if err := userattributedefinition.TypeValidator(v); err != nil {
return &ValidationError{Name: "type", err: fmt.Errorf(`ent: validator failed for field "UserAttributeDefinition.type": %w`, err)}
}
}
if v, ok := _u.mutation.Placeholder(); ok {
if err := userattributedefinition.PlaceholderValidator(v); err != nil {
return &ValidationError{Name: "placeholder", err: fmt.Errorf(`ent: validator failed for field "UserAttributeDefinition.placeholder": %w`, err)}
}
}
return nil
}
func (_u *UserAttributeDefinitionUpdateOne) sqlSave(ctx context.Context) (_node *UserAttributeDefinition, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(userattributedefinition.Table, userattributedefinition.Columns, sqlgraph.NewFieldSpec(userattributedefinition.FieldID, field.TypeInt64))
id, ok := _u.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "UserAttributeDefinition.id" for update`)}
}
_spec.Node.ID.Value = id
if fields := _u.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, userattributedefinition.FieldID)
for _, f := range fields {
if !userattributedefinition.ValidColumn(f) {
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
if f != userattributedefinition.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, f)
}
}
}
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(userattributedefinition.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.DeletedAt(); ok {
_spec.SetField(userattributedefinition.FieldDeletedAt, field.TypeTime, value)
}
if _u.mutation.DeletedAtCleared() {
_spec.ClearField(userattributedefinition.FieldDeletedAt, field.TypeTime)
}
if value, ok := _u.mutation.Key(); ok {
_spec.SetField(userattributedefinition.FieldKey, field.TypeString, value)
}
if value, ok := _u.mutation.Name(); ok {
_spec.SetField(userattributedefinition.FieldName, field.TypeString, value)
}
if value, ok := _u.mutation.Description(); ok {
_spec.SetField(userattributedefinition.FieldDescription, field.TypeString, value)
}
if value, ok := _u.mutation.GetType(); ok {
_spec.SetField(userattributedefinition.FieldType, field.TypeString, value)
}
if value, ok := _u.mutation.Options(); ok {
_spec.SetField(userattributedefinition.FieldOptions, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedOptions(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, userattributedefinition.FieldOptions, value)
})
}
if value, ok := _u.mutation.Required(); ok {
_spec.SetField(userattributedefinition.FieldRequired, field.TypeBool, value)
}
if value, ok := _u.mutation.Validation(); ok {
_spec.SetField(userattributedefinition.FieldValidation, field.TypeJSON, value)
}
if value, ok := _u.mutation.Placeholder(); ok {
_spec.SetField(userattributedefinition.FieldPlaceholder, field.TypeString, value)
}
if value, ok := _u.mutation.DisplayOrder(); ok {
_spec.SetField(userattributedefinition.FieldDisplayOrder, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedDisplayOrder(); ok {
_spec.AddField(userattributedefinition.FieldDisplayOrder, field.TypeInt, value)
}
if value, ok := _u.mutation.Enabled(); ok {
_spec.SetField(userattributedefinition.FieldEnabled, field.TypeBool, value)
}
if _u.mutation.ValuesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: userattributedefinition.ValuesTable,
Columns: []string{userattributedefinition.ValuesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedValuesIDs(); len(nodes) > 0 && !_u.mutation.ValuesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: userattributedefinition.ValuesTable,
Columns: []string{userattributedefinition.ValuesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.ValuesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: userattributedefinition.ValuesTable,
Columns: []string{userattributedefinition.ValuesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &UserAttributeDefinition{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{userattributedefinition.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
_u.mutation.done = true
return _node, nil
}

View File

@@ -0,0 +1,198 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"fmt"
"strings"
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
)
// UserAttributeValue is the model entity for the UserAttributeValue schema.
type UserAttributeValue struct {
config `json:"-"`
// ID of the ent.
ID int64 `json:"id,omitempty"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"`
// UpdatedAt holds the value of the "updated_at" field.
UpdatedAt time.Time `json:"updated_at,omitempty"`
// UserID holds the value of the "user_id" field.
UserID int64 `json:"user_id,omitempty"`
// AttributeID holds the value of the "attribute_id" field.
AttributeID int64 `json:"attribute_id,omitempty"`
// Value holds the value of the "value" field.
Value string `json:"value,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the UserAttributeValueQuery when eager-loading is set.
Edges UserAttributeValueEdges `json:"edges"`
selectValues sql.SelectValues
}
// UserAttributeValueEdges holds the relations/edges for other nodes in the graph.
type UserAttributeValueEdges struct {
// User holds the value of the user edge.
User *User `json:"user,omitempty"`
// Definition holds the value of the definition edge.
Definition *UserAttributeDefinition `json:"definition,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [2]bool
}
// UserOrErr returns the User value or an error if the edge
// was not loaded in eager-loading, or loaded but was not found.
func (e UserAttributeValueEdges) UserOrErr() (*User, error) {
if e.User != nil {
return e.User, nil
} else if e.loadedTypes[0] {
return nil, &NotFoundError{label: user.Label}
}
return nil, &NotLoadedError{edge: "user"}
}
// DefinitionOrErr returns the Definition value or an error if the edge
// was not loaded in eager-loading, or loaded but was not found.
func (e UserAttributeValueEdges) DefinitionOrErr() (*UserAttributeDefinition, error) {
if e.Definition != nil {
return e.Definition, nil
} else if e.loadedTypes[1] {
return nil, &NotFoundError{label: userattributedefinition.Label}
}
return nil, &NotLoadedError{edge: "definition"}
}
// scanValues returns the types for scanning values from sql.Rows.
func (*UserAttributeValue) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case userattributevalue.FieldID, userattributevalue.FieldUserID, userattributevalue.FieldAttributeID:
values[i] = new(sql.NullInt64)
case userattributevalue.FieldValue:
values[i] = new(sql.NullString)
case userattributevalue.FieldCreatedAt, userattributevalue.FieldUpdatedAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
}
}
return values, nil
}
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the UserAttributeValue fields.
func (_m *UserAttributeValue) assignValues(columns []string, values []any) error {
if m, n := len(values), len(columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
}
for i := range columns {
switch columns[i] {
case userattributevalue.FieldID:
value, ok := values[i].(*sql.NullInt64)
if !ok {
return fmt.Errorf("unexpected type %T for field id", value)
}
_m.ID = int64(value.Int64)
case userattributevalue.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i])
} else if value.Valid {
_m.CreatedAt = value.Time
}
case userattributevalue.FieldUpdatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field updated_at", values[i])
} else if value.Valid {
_m.UpdatedAt = value.Time
}
case userattributevalue.FieldUserID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field user_id", values[i])
} else if value.Valid {
_m.UserID = value.Int64
}
case userattributevalue.FieldAttributeID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field attribute_id", values[i])
} else if value.Valid {
_m.AttributeID = value.Int64
}
case userattributevalue.FieldValue:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field value", values[i])
} else if value.Valid {
_m.Value = value.String
}
default:
_m.selectValues.Set(columns[i], values[i])
}
}
return nil
}
// GetValue returns the ent.Value that was dynamically selected and assigned to the UserAttributeValue.
// This includes values selected through modifiers, order, etc.
func (_m *UserAttributeValue) GetValue(name string) (ent.Value, error) {
return _m.selectValues.Get(name)
}
// QueryUser queries the "user" edge of the UserAttributeValue entity.
func (_m *UserAttributeValue) QueryUser() *UserQuery {
return NewUserAttributeValueClient(_m.config).QueryUser(_m)
}
// QueryDefinition queries the "definition" edge of the UserAttributeValue entity.
func (_m *UserAttributeValue) QueryDefinition() *UserAttributeDefinitionQuery {
return NewUserAttributeValueClient(_m.config).QueryDefinition(_m)
}
// Update returns a builder for updating this UserAttributeValue.
// Note that you need to call UserAttributeValue.Unwrap() before calling this method if this UserAttributeValue
// was returned from a transaction, and the transaction was committed or rolled back.
func (_m *UserAttributeValue) Update() *UserAttributeValueUpdateOne {
return NewUserAttributeValueClient(_m.config).UpdateOne(_m)
}
// Unwrap unwraps the UserAttributeValue entity that was returned from a transaction after it was closed,
// so that all future queries will be executed through the driver which created the transaction.
func (_m *UserAttributeValue) Unwrap() *UserAttributeValue {
_tx, ok := _m.config.driver.(*txDriver)
if !ok {
panic("ent: UserAttributeValue is not a transactional entity")
}
_m.config.driver = _tx.drv
return _m
}
// String implements the fmt.Stringer.
func (_m *UserAttributeValue) String() string {
var builder strings.Builder
builder.WriteString("UserAttributeValue(")
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("updated_at=")
builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("user_id=")
builder.WriteString(fmt.Sprintf("%v", _m.UserID))
builder.WriteString(", ")
builder.WriteString("attribute_id=")
builder.WriteString(fmt.Sprintf("%v", _m.AttributeID))
builder.WriteString(", ")
builder.WriteString("value=")
builder.WriteString(_m.Value)
builder.WriteByte(')')
return builder.String()
}
// UserAttributeValues is a parsable slice of UserAttributeValue.
type UserAttributeValues []*UserAttributeValue

View File

@@ -0,0 +1,139 @@
// Code generated by ent, DO NOT EDIT.
package userattributevalue
import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
)
const (
// Label holds the string label denoting the userattributevalue type in the database.
Label = "user_attribute_value"
// FieldID holds the string denoting the id field in the database.
FieldID = "id"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
FieldUpdatedAt = "updated_at"
// FieldUserID holds the string denoting the user_id field in the database.
FieldUserID = "user_id"
// FieldAttributeID holds the string denoting the attribute_id field in the database.
FieldAttributeID = "attribute_id"
// FieldValue holds the string denoting the value field in the database.
FieldValue = "value"
// EdgeUser holds the string denoting the user edge name in mutations.
EdgeUser = "user"
// EdgeDefinition holds the string denoting the definition edge name in mutations.
EdgeDefinition = "definition"
// Table holds the table name of the userattributevalue in the database.
Table = "user_attribute_values"
// UserTable is the table that holds the user relation/edge.
UserTable = "user_attribute_values"
// UserInverseTable is the table name for the User entity.
// It exists in this package in order to avoid circular dependency with the "user" package.
UserInverseTable = "users"
// UserColumn is the table column denoting the user relation/edge.
UserColumn = "user_id"
// DefinitionTable is the table that holds the definition relation/edge.
DefinitionTable = "user_attribute_values"
// DefinitionInverseTable is the table name for the UserAttributeDefinition entity.
// It exists in this package in order to avoid circular dependency with the "userattributedefinition" package.
DefinitionInverseTable = "user_attribute_definitions"
// DefinitionColumn is the table column denoting the definition relation/edge.
DefinitionColumn = "attribute_id"
)
// Columns holds all SQL columns for userattributevalue fields.
var Columns = []string{
FieldID,
FieldCreatedAt,
FieldUpdatedAt,
FieldUserID,
FieldAttributeID,
FieldValue,
}
// ValidColumn reports if the column name is valid (part of the table columns).
func ValidColumn(column string) bool {
for i := range Columns {
if column == Columns[i] {
return true
}
}
return false
}
var (
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
DefaultUpdatedAt func() time.Time
// UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
UpdateDefaultUpdatedAt func() time.Time
// DefaultValue holds the default value on creation for the "value" field.
DefaultValue string
)
// OrderOption defines the ordering options for the UserAttributeValue queries.
type OrderOption func(*sql.Selector)
// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
// ByUpdatedAt orders the results by the updated_at field.
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
}
// ByUserID orders the results by the user_id field.
func ByUserID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUserID, opts...).ToFunc()
}
// ByAttributeID orders the results by the attribute_id field.
func ByAttributeID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldAttributeID, opts...).ToFunc()
}
// ByValue orders the results by the value field.
func ByValue(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldValue, opts...).ToFunc()
}
// ByUserField orders the results by user field.
func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...))
}
}
// ByDefinitionField orders the results by definition field.
func ByDefinitionField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newDefinitionStep(), sql.OrderByField(field, opts...))
}
}
func newUserStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(UserInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
)
}
func newDefinitionStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(DefinitionInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, DefinitionTable, DefinitionColumn),
)
}

View File

@@ -0,0 +1,327 @@
// Code generated by ent, DO NOT EDIT.
package userattributevalue
import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// ID filters vertices based on their ID field.
func ID(id int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldEQ(FieldID, id))
}
// IDEQ applies the EQ predicate on the ID field.
func IDEQ(id int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldEQ(FieldID, id))
}
// IDNEQ applies the NEQ predicate on the ID field.
func IDNEQ(id int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldNEQ(FieldID, id))
}
// IDIn applies the In predicate on the ID field.
func IDIn(ids ...int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldIn(FieldID, ids...))
}
// IDNotIn applies the NotIn predicate on the ID field.
func IDNotIn(ids ...int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldNotIn(FieldID, ids...))
}
// IDGT applies the GT predicate on the ID field.
func IDGT(id int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldGT(FieldID, id))
}
// IDGTE applies the GTE predicate on the ID field.
func IDGTE(id int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldGTE(FieldID, id))
}
// IDLT applies the LT predicate on the ID field.
func IDLT(id int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldLT(FieldID, id))
}
// IDLTE applies the LTE predicate on the ID field.
func IDLTE(id int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldLTE(FieldID, id))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldEQ(FieldCreatedAt, v))
}
// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
func UpdatedAt(v time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldEQ(FieldUpdatedAt, v))
}
// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ.
func UserID(v int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldEQ(FieldUserID, v))
}
// AttributeID applies equality check predicate on the "attribute_id" field. It's identical to AttributeIDEQ.
func AttributeID(v int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldEQ(FieldAttributeID, v))
}
// Value applies equality check predicate on the "value" field. It's identical to ValueEQ.
func Value(v string) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldEQ(FieldValue, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldEQ(FieldCreatedAt, v))
}
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
func CreatedAtNEQ(v time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldNEQ(FieldCreatedAt, v))
}
// CreatedAtIn applies the In predicate on the "created_at" field.
func CreatedAtIn(vs ...time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldIn(FieldCreatedAt, vs...))
}
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
func CreatedAtNotIn(vs ...time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldNotIn(FieldCreatedAt, vs...))
}
// CreatedAtGT applies the GT predicate on the "created_at" field.
func CreatedAtGT(v time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldGT(FieldCreatedAt, v))
}
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
func CreatedAtGTE(v time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldGTE(FieldCreatedAt, v))
}
// CreatedAtLT applies the LT predicate on the "created_at" field.
func CreatedAtLT(v time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldLT(FieldCreatedAt, v))
}
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
func CreatedAtLTE(v time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldLTE(FieldCreatedAt, v))
}
// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
func UpdatedAtEQ(v time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldEQ(FieldUpdatedAt, v))
}
// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
func UpdatedAtNEQ(v time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldNEQ(FieldUpdatedAt, v))
}
// UpdatedAtIn applies the In predicate on the "updated_at" field.
func UpdatedAtIn(vs ...time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldIn(FieldUpdatedAt, vs...))
}
// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
func UpdatedAtNotIn(vs ...time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldNotIn(FieldUpdatedAt, vs...))
}
// UpdatedAtGT applies the GT predicate on the "updated_at" field.
func UpdatedAtGT(v time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldGT(FieldUpdatedAt, v))
}
// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
func UpdatedAtGTE(v time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldGTE(FieldUpdatedAt, v))
}
// UpdatedAtLT applies the LT predicate on the "updated_at" field.
func UpdatedAtLT(v time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldLT(FieldUpdatedAt, v))
}
// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
func UpdatedAtLTE(v time.Time) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldLTE(FieldUpdatedAt, v))
}
// UserIDEQ applies the EQ predicate on the "user_id" field.
func UserIDEQ(v int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldEQ(FieldUserID, v))
}
// UserIDNEQ applies the NEQ predicate on the "user_id" field.
func UserIDNEQ(v int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldNEQ(FieldUserID, v))
}
// UserIDIn applies the In predicate on the "user_id" field.
func UserIDIn(vs ...int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldIn(FieldUserID, vs...))
}
// UserIDNotIn applies the NotIn predicate on the "user_id" field.
func UserIDNotIn(vs ...int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldNotIn(FieldUserID, vs...))
}
// AttributeIDEQ applies the EQ predicate on the "attribute_id" field.
func AttributeIDEQ(v int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldEQ(FieldAttributeID, v))
}
// AttributeIDNEQ applies the NEQ predicate on the "attribute_id" field.
func AttributeIDNEQ(v int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldNEQ(FieldAttributeID, v))
}
// AttributeIDIn applies the In predicate on the "attribute_id" field.
func AttributeIDIn(vs ...int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldIn(FieldAttributeID, vs...))
}
// AttributeIDNotIn applies the NotIn predicate on the "attribute_id" field.
func AttributeIDNotIn(vs ...int64) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldNotIn(FieldAttributeID, vs...))
}
// ValueEQ applies the EQ predicate on the "value" field.
func ValueEQ(v string) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldEQ(FieldValue, v))
}
// ValueNEQ applies the NEQ predicate on the "value" field.
func ValueNEQ(v string) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldNEQ(FieldValue, v))
}
// ValueIn applies the In predicate on the "value" field.
func ValueIn(vs ...string) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldIn(FieldValue, vs...))
}
// ValueNotIn applies the NotIn predicate on the "value" field.
func ValueNotIn(vs ...string) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldNotIn(FieldValue, vs...))
}
// ValueGT applies the GT predicate on the "value" field.
func ValueGT(v string) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldGT(FieldValue, v))
}
// ValueGTE applies the GTE predicate on the "value" field.
func ValueGTE(v string) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldGTE(FieldValue, v))
}
// ValueLT applies the LT predicate on the "value" field.
func ValueLT(v string) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldLT(FieldValue, v))
}
// ValueLTE applies the LTE predicate on the "value" field.
func ValueLTE(v string) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldLTE(FieldValue, v))
}
// ValueContains applies the Contains predicate on the "value" field.
func ValueContains(v string) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldContains(FieldValue, v))
}
// ValueHasPrefix applies the HasPrefix predicate on the "value" field.
func ValueHasPrefix(v string) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldHasPrefix(FieldValue, v))
}
// ValueHasSuffix applies the HasSuffix predicate on the "value" field.
func ValueHasSuffix(v string) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldHasSuffix(FieldValue, v))
}
// ValueEqualFold applies the EqualFold predicate on the "value" field.
func ValueEqualFold(v string) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldEqualFold(FieldValue, v))
}
// ValueContainsFold applies the ContainsFold predicate on the "value" field.
func ValueContainsFold(v string) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.FieldContainsFold(FieldValue, v))
}
// HasUser applies the HasEdge predicate on the "user" edge.
func HasUser() predicate.UserAttributeValue {
return predicate.UserAttributeValue(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates).
func HasUserWith(preds ...predicate.User) predicate.UserAttributeValue {
return predicate.UserAttributeValue(func(s *sql.Selector) {
step := newUserStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// HasDefinition applies the HasEdge predicate on the "definition" edge.
func HasDefinition() predicate.UserAttributeValue {
return predicate.UserAttributeValue(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, DefinitionTable, DefinitionColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasDefinitionWith applies the HasEdge predicate on the "definition" edge with a given conditions (other predicates).
func HasDefinitionWith(preds ...predicate.UserAttributeDefinition) predicate.UserAttributeValue {
return predicate.UserAttributeValue(func(s *sql.Selector) {
step := newDefinitionStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.UserAttributeValue) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.AndPredicates(predicates...))
}
// Or groups predicates with the OR operator between them.
func Or(predicates ...predicate.UserAttributeValue) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.OrPredicates(predicates...))
}
// Not applies the not operator on the given predicate.
func Not(p predicate.UserAttributeValue) predicate.UserAttributeValue {
return predicate.UserAttributeValue(sql.NotPredicates(p))
}

View File

@@ -0,0 +1,731 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
)
// UserAttributeValueCreate is the builder for creating a UserAttributeValue entity.
type UserAttributeValueCreate struct {
config
mutation *UserAttributeValueMutation
hooks []Hook
conflict []sql.ConflictOption
}
// SetCreatedAt sets the "created_at" field.
func (_c *UserAttributeValueCreate) SetCreatedAt(v time.Time) *UserAttributeValueCreate {
_c.mutation.SetCreatedAt(v)
return _c
}
// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
func (_c *UserAttributeValueCreate) SetNillableCreatedAt(v *time.Time) *UserAttributeValueCreate {
if v != nil {
_c.SetCreatedAt(*v)
}
return _c
}
// SetUpdatedAt sets the "updated_at" field.
func (_c *UserAttributeValueCreate) SetUpdatedAt(v time.Time) *UserAttributeValueCreate {
_c.mutation.SetUpdatedAt(v)
return _c
}
// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
func (_c *UserAttributeValueCreate) SetNillableUpdatedAt(v *time.Time) *UserAttributeValueCreate {
if v != nil {
_c.SetUpdatedAt(*v)
}
return _c
}
// SetUserID sets the "user_id" field.
func (_c *UserAttributeValueCreate) SetUserID(v int64) *UserAttributeValueCreate {
_c.mutation.SetUserID(v)
return _c
}
// SetAttributeID sets the "attribute_id" field.
func (_c *UserAttributeValueCreate) SetAttributeID(v int64) *UserAttributeValueCreate {
_c.mutation.SetAttributeID(v)
return _c
}
// SetValue sets the "value" field.
func (_c *UserAttributeValueCreate) SetValue(v string) *UserAttributeValueCreate {
_c.mutation.SetValue(v)
return _c
}
// SetNillableValue sets the "value" field if the given value is not nil.
func (_c *UserAttributeValueCreate) SetNillableValue(v *string) *UserAttributeValueCreate {
if v != nil {
_c.SetValue(*v)
}
return _c
}
// SetUser sets the "user" edge to the User entity.
func (_c *UserAttributeValueCreate) SetUser(v *User) *UserAttributeValueCreate {
return _c.SetUserID(v.ID)
}
// SetDefinitionID sets the "definition" edge to the UserAttributeDefinition entity by ID.
func (_c *UserAttributeValueCreate) SetDefinitionID(id int64) *UserAttributeValueCreate {
_c.mutation.SetDefinitionID(id)
return _c
}
// SetDefinition sets the "definition" edge to the UserAttributeDefinition entity.
func (_c *UserAttributeValueCreate) SetDefinition(v *UserAttributeDefinition) *UserAttributeValueCreate {
return _c.SetDefinitionID(v.ID)
}
// Mutation returns the UserAttributeValueMutation object of the builder.
func (_c *UserAttributeValueCreate) Mutation() *UserAttributeValueMutation {
return _c.mutation
}
// Save creates the UserAttributeValue in the database.
func (_c *UserAttributeValueCreate) Save(ctx context.Context) (*UserAttributeValue, error) {
_c.defaults()
return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
}
// SaveX calls Save and panics if Save returns an error.
func (_c *UserAttributeValueCreate) SaveX(ctx context.Context) *UserAttributeValue {
v, err := _c.Save(ctx)
if err != nil {
panic(err)
}
return v
}
// Exec executes the query.
func (_c *UserAttributeValueCreate) Exec(ctx context.Context) error {
_, err := _c.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_c *UserAttributeValueCreate) ExecX(ctx context.Context) {
if err := _c.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_c *UserAttributeValueCreate) defaults() {
if _, ok := _c.mutation.CreatedAt(); !ok {
v := userattributevalue.DefaultCreatedAt()
_c.mutation.SetCreatedAt(v)
}
if _, ok := _c.mutation.UpdatedAt(); !ok {
v := userattributevalue.DefaultUpdatedAt()
_c.mutation.SetUpdatedAt(v)
}
if _, ok := _c.mutation.Value(); !ok {
v := userattributevalue.DefaultValue
_c.mutation.SetValue(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_c *UserAttributeValueCreate) check() error {
if _, ok := _c.mutation.CreatedAt(); !ok {
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "UserAttributeValue.created_at"`)}
}
if _, ok := _c.mutation.UpdatedAt(); !ok {
return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "UserAttributeValue.updated_at"`)}
}
if _, ok := _c.mutation.UserID(); !ok {
return &ValidationError{Name: "user_id", err: errors.New(`ent: missing required field "UserAttributeValue.user_id"`)}
}
if _, ok := _c.mutation.AttributeID(); !ok {
return &ValidationError{Name: "attribute_id", err: errors.New(`ent: missing required field "UserAttributeValue.attribute_id"`)}
}
if _, ok := _c.mutation.Value(); !ok {
return &ValidationError{Name: "value", err: errors.New(`ent: missing required field "UserAttributeValue.value"`)}
}
if len(_c.mutation.UserIDs()) == 0 {
return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "UserAttributeValue.user"`)}
}
if len(_c.mutation.DefinitionIDs()) == 0 {
return &ValidationError{Name: "definition", err: errors.New(`ent: missing required edge "UserAttributeValue.definition"`)}
}
return nil
}
func (_c *UserAttributeValueCreate) sqlSave(ctx context.Context) (*UserAttributeValue, error) {
if err := _c.check(); err != nil {
return nil, err
}
_node, _spec := _c.createSpec()
if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
id := _spec.ID.Value.(int64)
_node.ID = int64(id)
_c.mutation.id = &_node.ID
_c.mutation.done = true
return _node, nil
}
func (_c *UserAttributeValueCreate) createSpec() (*UserAttributeValue, *sqlgraph.CreateSpec) {
var (
_node = &UserAttributeValue{config: _c.config}
_spec = sqlgraph.NewCreateSpec(userattributevalue.Table, sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64))
)
_spec.OnConflict = _c.conflict
if value, ok := _c.mutation.CreatedAt(); ok {
_spec.SetField(userattributevalue.FieldCreatedAt, field.TypeTime, value)
_node.CreatedAt = value
}
if value, ok := _c.mutation.UpdatedAt(); ok {
_spec.SetField(userattributevalue.FieldUpdatedAt, field.TypeTime, value)
_node.UpdatedAt = value
}
if value, ok := _c.mutation.Value(); ok {
_spec.SetField(userattributevalue.FieldValue, field.TypeString, value)
_node.Value = value
}
if nodes := _c.mutation.UserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: userattributevalue.UserTable,
Columns: []string{userattributevalue.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_node.UserID = nodes[0]
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.DefinitionIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: userattributevalue.DefinitionTable,
Columns: []string{userattributevalue.DefinitionColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributedefinition.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_node.AttributeID = nodes[0]
_spec.Edges = append(_spec.Edges, edge)
}
return _node, _spec
}
// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
// of the `INSERT` statement. For example:
//
// client.UserAttributeValue.Create().
// SetCreatedAt(v).
// OnConflict(
// // Update the row with the new values
// // the was proposed for insertion.
// sql.ResolveWithNewValues(),
// ).
// // Override some of the fields with custom
// // update values.
// Update(func(u *ent.UserAttributeValueUpsert) {
// SetCreatedAt(v+v).
// }).
// Exec(ctx)
func (_c *UserAttributeValueCreate) OnConflict(opts ...sql.ConflictOption) *UserAttributeValueUpsertOne {
_c.conflict = opts
return &UserAttributeValueUpsertOne{
create: _c,
}
}
// OnConflictColumns calls `OnConflict` and configures the columns
// as conflict target. Using this option is equivalent to using:
//
// client.UserAttributeValue.Create().
// OnConflict(sql.ConflictColumns(columns...)).
// Exec(ctx)
func (_c *UserAttributeValueCreate) OnConflictColumns(columns ...string) *UserAttributeValueUpsertOne {
_c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
return &UserAttributeValueUpsertOne{
create: _c,
}
}
type (
// UserAttributeValueUpsertOne is the builder for "upsert"-ing
// one UserAttributeValue node.
UserAttributeValueUpsertOne struct {
create *UserAttributeValueCreate
}
// UserAttributeValueUpsert is the "OnConflict" setter.
UserAttributeValueUpsert struct {
*sql.UpdateSet
}
)
// SetUpdatedAt sets the "updated_at" field.
func (u *UserAttributeValueUpsert) SetUpdatedAt(v time.Time) *UserAttributeValueUpsert {
u.Set(userattributevalue.FieldUpdatedAt, v)
return u
}
// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
func (u *UserAttributeValueUpsert) UpdateUpdatedAt() *UserAttributeValueUpsert {
u.SetExcluded(userattributevalue.FieldUpdatedAt)
return u
}
// SetUserID sets the "user_id" field.
func (u *UserAttributeValueUpsert) SetUserID(v int64) *UserAttributeValueUpsert {
u.Set(userattributevalue.FieldUserID, v)
return u
}
// UpdateUserID sets the "user_id" field to the value that was provided on create.
func (u *UserAttributeValueUpsert) UpdateUserID() *UserAttributeValueUpsert {
u.SetExcluded(userattributevalue.FieldUserID)
return u
}
// SetAttributeID sets the "attribute_id" field.
func (u *UserAttributeValueUpsert) SetAttributeID(v int64) *UserAttributeValueUpsert {
u.Set(userattributevalue.FieldAttributeID, v)
return u
}
// UpdateAttributeID sets the "attribute_id" field to the value that was provided on create.
func (u *UserAttributeValueUpsert) UpdateAttributeID() *UserAttributeValueUpsert {
u.SetExcluded(userattributevalue.FieldAttributeID)
return u
}
// SetValue sets the "value" field.
func (u *UserAttributeValueUpsert) SetValue(v string) *UserAttributeValueUpsert {
u.Set(userattributevalue.FieldValue, v)
return u
}
// UpdateValue sets the "value" field to the value that was provided on create.
func (u *UserAttributeValueUpsert) UpdateValue() *UserAttributeValueUpsert {
u.SetExcluded(userattributevalue.FieldValue)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
// client.UserAttributeValue.Create().
// OnConflict(
// sql.ResolveWithNewValues(),
// ).
// Exec(ctx)
func (u *UserAttributeValueUpsertOne) UpdateNewValues() *UserAttributeValueUpsertOne {
u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
if _, exists := u.create.mutation.CreatedAt(); exists {
s.SetIgnore(userattributevalue.FieldCreatedAt)
}
}))
return u
}
// Ignore sets each column to itself in case of conflict.
// Using this option is equivalent to using:
//
// client.UserAttributeValue.Create().
// OnConflict(sql.ResolveWithIgnore()).
// Exec(ctx)
func (u *UserAttributeValueUpsertOne) Ignore() *UserAttributeValueUpsertOne {
u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
return u
}
// DoNothing configures the conflict_action to `DO NOTHING`.
// Supported only by SQLite and PostgreSQL.
func (u *UserAttributeValueUpsertOne) DoNothing() *UserAttributeValueUpsertOne {
u.create.conflict = append(u.create.conflict, sql.DoNothing())
return u
}
// Update allows overriding fields `UPDATE` values. See the UserAttributeValueCreate.OnConflict
// documentation for more info.
func (u *UserAttributeValueUpsertOne) Update(set func(*UserAttributeValueUpsert)) *UserAttributeValueUpsertOne {
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
set(&UserAttributeValueUpsert{UpdateSet: update})
}))
return u
}
// SetUpdatedAt sets the "updated_at" field.
func (u *UserAttributeValueUpsertOne) SetUpdatedAt(v time.Time) *UserAttributeValueUpsertOne {
return u.Update(func(s *UserAttributeValueUpsert) {
s.SetUpdatedAt(v)
})
}
// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
func (u *UserAttributeValueUpsertOne) UpdateUpdatedAt() *UserAttributeValueUpsertOne {
return u.Update(func(s *UserAttributeValueUpsert) {
s.UpdateUpdatedAt()
})
}
// SetUserID sets the "user_id" field.
func (u *UserAttributeValueUpsertOne) SetUserID(v int64) *UserAttributeValueUpsertOne {
return u.Update(func(s *UserAttributeValueUpsert) {
s.SetUserID(v)
})
}
// UpdateUserID sets the "user_id" field to the value that was provided on create.
func (u *UserAttributeValueUpsertOne) UpdateUserID() *UserAttributeValueUpsertOne {
return u.Update(func(s *UserAttributeValueUpsert) {
s.UpdateUserID()
})
}
// SetAttributeID sets the "attribute_id" field.
func (u *UserAttributeValueUpsertOne) SetAttributeID(v int64) *UserAttributeValueUpsertOne {
return u.Update(func(s *UserAttributeValueUpsert) {
s.SetAttributeID(v)
})
}
// UpdateAttributeID sets the "attribute_id" field to the value that was provided on create.
func (u *UserAttributeValueUpsertOne) UpdateAttributeID() *UserAttributeValueUpsertOne {
return u.Update(func(s *UserAttributeValueUpsert) {
s.UpdateAttributeID()
})
}
// SetValue sets the "value" field.
func (u *UserAttributeValueUpsertOne) SetValue(v string) *UserAttributeValueUpsertOne {
return u.Update(func(s *UserAttributeValueUpsert) {
s.SetValue(v)
})
}
// UpdateValue sets the "value" field to the value that was provided on create.
func (u *UserAttributeValueUpsertOne) UpdateValue() *UserAttributeValueUpsertOne {
return u.Update(func(s *UserAttributeValueUpsert) {
s.UpdateValue()
})
}
// Exec executes the query.
func (u *UserAttributeValueUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
return errors.New("ent: missing options for UserAttributeValueCreate.OnConflict")
}
return u.create.Exec(ctx)
}
// ExecX is like Exec, but panics if an error occurs.
func (u *UserAttributeValueUpsertOne) ExecX(ctx context.Context) {
if err := u.create.Exec(ctx); err != nil {
panic(err)
}
}
// Exec executes the UPSERT query and returns the inserted/updated ID.
func (u *UserAttributeValueUpsertOne) ID(ctx context.Context) (id int64, err error) {
node, err := u.create.Save(ctx)
if err != nil {
return id, err
}
return node.ID, nil
}
// IDX is like ID, but panics if an error occurs.
func (u *UserAttributeValueUpsertOne) IDX(ctx context.Context) int64 {
id, err := u.ID(ctx)
if err != nil {
panic(err)
}
return id
}
// UserAttributeValueCreateBulk is the builder for creating many UserAttributeValue entities in bulk.
type UserAttributeValueCreateBulk struct {
config
err error
builders []*UserAttributeValueCreate
conflict []sql.ConflictOption
}
// Save creates the UserAttributeValue entities in the database.
func (_c *UserAttributeValueCreateBulk) Save(ctx context.Context) ([]*UserAttributeValue, error) {
if _c.err != nil {
return nil, _c.err
}
specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
nodes := make([]*UserAttributeValue, len(_c.builders))
mutators := make([]Mutator, len(_c.builders))
for i := range _c.builders {
func(i int, root context.Context) {
builder := _c.builders[i]
builder.defaults()
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*UserAttributeValueMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
if err := builder.check(); err != nil {
return nil, err
}
builder.mutation = mutation
var err error
nodes[i], specs[i] = builder.createSpec()
if i < len(mutators)-1 {
_, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
} else {
spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
spec.OnConflict = _c.conflict
// Invoke the actual operation on the latest mutation in the chain.
if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
}
}
if err != nil {
return nil, err
}
mutation.id = &nodes[i].ID
if specs[i].ID.Value != nil {
id := specs[i].ID.Value.(int64)
nodes[i].ID = int64(id)
}
mutation.done = true
return nodes[i], nil
})
for i := len(builder.hooks) - 1; i >= 0; i-- {
mut = builder.hooks[i](mut)
}
mutators[i] = mut
}(i, ctx)
}
if len(mutators) > 0 {
if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
return nil, err
}
}
return nodes, nil
}
// SaveX is like Save, but panics if an error occurs.
func (_c *UserAttributeValueCreateBulk) SaveX(ctx context.Context) []*UserAttributeValue {
v, err := _c.Save(ctx)
if err != nil {
panic(err)
}
return v
}
// Exec executes the query.
func (_c *UserAttributeValueCreateBulk) Exec(ctx context.Context) error {
_, err := _c.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_c *UserAttributeValueCreateBulk) ExecX(ctx context.Context) {
if err := _c.Exec(ctx); err != nil {
panic(err)
}
}
// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
// of the `INSERT` statement. For example:
//
// client.UserAttributeValue.CreateBulk(builders...).
// OnConflict(
// // Update the row with the new values
// // the was proposed for insertion.
// sql.ResolveWithNewValues(),
// ).
// // Override some of the fields with custom
// // update values.
// Update(func(u *ent.UserAttributeValueUpsert) {
// SetCreatedAt(v+v).
// }).
// Exec(ctx)
func (_c *UserAttributeValueCreateBulk) OnConflict(opts ...sql.ConflictOption) *UserAttributeValueUpsertBulk {
_c.conflict = opts
return &UserAttributeValueUpsertBulk{
create: _c,
}
}
// OnConflictColumns calls `OnConflict` and configures the columns
// as conflict target. Using this option is equivalent to using:
//
// client.UserAttributeValue.Create().
// OnConflict(sql.ConflictColumns(columns...)).
// Exec(ctx)
func (_c *UserAttributeValueCreateBulk) OnConflictColumns(columns ...string) *UserAttributeValueUpsertBulk {
_c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
return &UserAttributeValueUpsertBulk{
create: _c,
}
}
// UserAttributeValueUpsertBulk is the builder for "upsert"-ing
// a bulk of UserAttributeValue nodes.
type UserAttributeValueUpsertBulk struct {
create *UserAttributeValueCreateBulk
}
// UpdateNewValues updates the mutable fields using the new values that
// were set on create. Using this option is equivalent to using:
//
// client.UserAttributeValue.Create().
// OnConflict(
// sql.ResolveWithNewValues(),
// ).
// Exec(ctx)
func (u *UserAttributeValueUpsertBulk) UpdateNewValues() *UserAttributeValueUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
for _, b := range u.create.builders {
if _, exists := b.mutation.CreatedAt(); exists {
s.SetIgnore(userattributevalue.FieldCreatedAt)
}
}
}))
return u
}
// Ignore sets each column to itself in case of conflict.
// Using this option is equivalent to using:
//
// client.UserAttributeValue.Create().
// OnConflict(sql.ResolveWithIgnore()).
// Exec(ctx)
func (u *UserAttributeValueUpsertBulk) Ignore() *UserAttributeValueUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
return u
}
// DoNothing configures the conflict_action to `DO NOTHING`.
// Supported only by SQLite and PostgreSQL.
func (u *UserAttributeValueUpsertBulk) DoNothing() *UserAttributeValueUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.DoNothing())
return u
}
// Update allows overriding fields `UPDATE` values. See the UserAttributeValueCreateBulk.OnConflict
// documentation for more info.
func (u *UserAttributeValueUpsertBulk) Update(set func(*UserAttributeValueUpsert)) *UserAttributeValueUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
set(&UserAttributeValueUpsert{UpdateSet: update})
}))
return u
}
// SetUpdatedAt sets the "updated_at" field.
func (u *UserAttributeValueUpsertBulk) SetUpdatedAt(v time.Time) *UserAttributeValueUpsertBulk {
return u.Update(func(s *UserAttributeValueUpsert) {
s.SetUpdatedAt(v)
})
}
// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
func (u *UserAttributeValueUpsertBulk) UpdateUpdatedAt() *UserAttributeValueUpsertBulk {
return u.Update(func(s *UserAttributeValueUpsert) {
s.UpdateUpdatedAt()
})
}
// SetUserID sets the "user_id" field.
func (u *UserAttributeValueUpsertBulk) SetUserID(v int64) *UserAttributeValueUpsertBulk {
return u.Update(func(s *UserAttributeValueUpsert) {
s.SetUserID(v)
})
}
// UpdateUserID sets the "user_id" field to the value that was provided on create.
func (u *UserAttributeValueUpsertBulk) UpdateUserID() *UserAttributeValueUpsertBulk {
return u.Update(func(s *UserAttributeValueUpsert) {
s.UpdateUserID()
})
}
// SetAttributeID sets the "attribute_id" field.
func (u *UserAttributeValueUpsertBulk) SetAttributeID(v int64) *UserAttributeValueUpsertBulk {
return u.Update(func(s *UserAttributeValueUpsert) {
s.SetAttributeID(v)
})
}
// UpdateAttributeID sets the "attribute_id" field to the value that was provided on create.
func (u *UserAttributeValueUpsertBulk) UpdateAttributeID() *UserAttributeValueUpsertBulk {
return u.Update(func(s *UserAttributeValueUpsert) {
s.UpdateAttributeID()
})
}
// SetValue sets the "value" field.
func (u *UserAttributeValueUpsertBulk) SetValue(v string) *UserAttributeValueUpsertBulk {
return u.Update(func(s *UserAttributeValueUpsert) {
s.SetValue(v)
})
}
// UpdateValue sets the "value" field to the value that was provided on create.
func (u *UserAttributeValueUpsertBulk) UpdateValue() *UserAttributeValueUpsertBulk {
return u.Update(func(s *UserAttributeValueUpsert) {
s.UpdateValue()
})
}
// Exec executes the query.
func (u *UserAttributeValueUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {
return u.create.err
}
for i, b := range u.create.builders {
if len(b.conflict) != 0 {
return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the UserAttributeValueCreateBulk instead", i)
}
}
if len(u.create.conflict) == 0 {
return errors.New("ent: missing options for UserAttributeValueCreateBulk.OnConflict")
}
return u.create.Exec(ctx)
}
// ExecX is like Exec, but panics if an error occurs.
func (u *UserAttributeValueUpsertBulk) ExecX(ctx context.Context) {
if err := u.create.Exec(ctx); err != nil {
panic(err)
}
}

View File

@@ -0,0 +1,88 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
)
// UserAttributeValueDelete is the builder for deleting a UserAttributeValue entity.
type UserAttributeValueDelete struct {
config
hooks []Hook
mutation *UserAttributeValueMutation
}
// Where appends a list predicates to the UserAttributeValueDelete builder.
func (_d *UserAttributeValueDelete) Where(ps ...predicate.UserAttributeValue) *UserAttributeValueDelete {
_d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query and returns how many vertices were deleted.
func (_d *UserAttributeValueDelete) Exec(ctx context.Context) (int, error) {
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *UserAttributeValueDelete) ExecX(ctx context.Context) int {
n, err := _d.Exec(ctx)
if err != nil {
panic(err)
}
return n
}
func (_d *UserAttributeValueDelete) sqlExec(ctx context.Context) (int, error) {
_spec := sqlgraph.NewDeleteSpec(userattributevalue.Table, sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64))
if ps := _d.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
_d.mutation.done = true
return affected, err
}
// UserAttributeValueDeleteOne is the builder for deleting a single UserAttributeValue entity.
type UserAttributeValueDeleteOne struct {
_d *UserAttributeValueDelete
}
// Where appends a list predicates to the UserAttributeValueDelete builder.
func (_d *UserAttributeValueDeleteOne) Where(ps ...predicate.UserAttributeValue) *UserAttributeValueDeleteOne {
_d._d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query.
func (_d *UserAttributeValueDeleteOne) Exec(ctx context.Context) error {
n, err := _d._d.Exec(ctx)
switch {
case err != nil:
return err
case n == 0:
return &NotFoundError{userattributevalue.Label}
default:
return nil
}
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *UserAttributeValueDeleteOne) ExecX(ctx context.Context) {
if err := _d.Exec(ctx); err != nil {
panic(err)
}
}

View File

@@ -0,0 +1,681 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"fmt"
"math"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
)
// UserAttributeValueQuery is the builder for querying UserAttributeValue entities.
type UserAttributeValueQuery struct {
config
ctx *QueryContext
order []userattributevalue.OrderOption
inters []Interceptor
predicates []predicate.UserAttributeValue
withUser *UserQuery
withDefinition *UserAttributeDefinitionQuery
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
}
// Where adds a new predicate for the UserAttributeValueQuery builder.
func (_q *UserAttributeValueQuery) Where(ps ...predicate.UserAttributeValue) *UserAttributeValueQuery {
_q.predicates = append(_q.predicates, ps...)
return _q
}
// Limit the number of records to be returned by this query.
func (_q *UserAttributeValueQuery) Limit(limit int) *UserAttributeValueQuery {
_q.ctx.Limit = &limit
return _q
}
// Offset to start from.
func (_q *UserAttributeValueQuery) Offset(offset int) *UserAttributeValueQuery {
_q.ctx.Offset = &offset
return _q
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (_q *UserAttributeValueQuery) Unique(unique bool) *UserAttributeValueQuery {
_q.ctx.Unique = &unique
return _q
}
// Order specifies how the records should be ordered.
func (_q *UserAttributeValueQuery) Order(o ...userattributevalue.OrderOption) *UserAttributeValueQuery {
_q.order = append(_q.order, o...)
return _q
}
// QueryUser chains the current query on the "user" edge.
func (_q *UserAttributeValueQuery) QueryUser() *UserQuery {
query := (&UserClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(userattributevalue.Table, userattributevalue.FieldID, selector),
sqlgraph.To(user.Table, user.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, userattributevalue.UserTable, userattributevalue.UserColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryDefinition chains the current query on the "definition" edge.
func (_q *UserAttributeValueQuery) QueryDefinition() *UserAttributeDefinitionQuery {
query := (&UserAttributeDefinitionClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(userattributevalue.Table, userattributevalue.FieldID, selector),
sqlgraph.To(userattributedefinition.Table, userattributedefinition.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, userattributevalue.DefinitionTable, userattributevalue.DefinitionColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// First returns the first UserAttributeValue entity from the query.
// Returns a *NotFoundError when no UserAttributeValue was found.
func (_q *UserAttributeValueQuery) First(ctx context.Context) (*UserAttributeValue, error) {
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
if err != nil {
return nil, err
}
if len(nodes) == 0 {
return nil, &NotFoundError{userattributevalue.Label}
}
return nodes[0], nil
}
// FirstX is like First, but panics if an error occurs.
func (_q *UserAttributeValueQuery) FirstX(ctx context.Context) *UserAttributeValue {
node, err := _q.First(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return node
}
// FirstID returns the first UserAttributeValue ID from the query.
// Returns a *NotFoundError when no UserAttributeValue ID was found.
func (_q *UserAttributeValueQuery) FirstID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
return
}
if len(ids) == 0 {
err = &NotFoundError{userattributevalue.Label}
return
}
return ids[0], nil
}
// FirstIDX is like FirstID, but panics if an error occurs.
func (_q *UserAttributeValueQuery) FirstIDX(ctx context.Context) int64 {
id, err := _q.FirstID(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return id
}
// Only returns a single UserAttributeValue entity found by the query, ensuring it only returns one.
// Returns a *NotSingularError when more than one UserAttributeValue entity is found.
// Returns a *NotFoundError when no UserAttributeValue entities are found.
func (_q *UserAttributeValueQuery) Only(ctx context.Context) (*UserAttributeValue, error) {
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
if err != nil {
return nil, err
}
switch len(nodes) {
case 1:
return nodes[0], nil
case 0:
return nil, &NotFoundError{userattributevalue.Label}
default:
return nil, &NotSingularError{userattributevalue.Label}
}
}
// OnlyX is like Only, but panics if an error occurs.
func (_q *UserAttributeValueQuery) OnlyX(ctx context.Context) *UserAttributeValue {
node, err := _q.Only(ctx)
if err != nil {
panic(err)
}
return node
}
// OnlyID is like Only, but returns the only UserAttributeValue ID in the query.
// Returns a *NotSingularError when more than one UserAttributeValue ID is found.
// Returns a *NotFoundError when no entities are found.
func (_q *UserAttributeValueQuery) OnlyID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
return
}
switch len(ids) {
case 1:
id = ids[0]
case 0:
err = &NotFoundError{userattributevalue.Label}
default:
err = &NotSingularError{userattributevalue.Label}
}
return
}
// OnlyIDX is like OnlyID, but panics if an error occurs.
func (_q *UserAttributeValueQuery) OnlyIDX(ctx context.Context) int64 {
id, err := _q.OnlyID(ctx)
if err != nil {
panic(err)
}
return id
}
// All executes the query and returns a list of UserAttributeValues.
func (_q *UserAttributeValueQuery) All(ctx context.Context) ([]*UserAttributeValue, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
qr := querierAll[[]*UserAttributeValue, *UserAttributeValueQuery]()
return withInterceptors[[]*UserAttributeValue](ctx, _q, qr, _q.inters)
}
// AllX is like All, but panics if an error occurs.
func (_q *UserAttributeValueQuery) AllX(ctx context.Context) []*UserAttributeValue {
nodes, err := _q.All(ctx)
if err != nil {
panic(err)
}
return nodes
}
// IDs executes the query and returns a list of UserAttributeValue IDs.
func (_q *UserAttributeValueQuery) IDs(ctx context.Context) (ids []int64, err error) {
if _q.ctx.Unique == nil && _q.path != nil {
_q.Unique(true)
}
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
if err = _q.Select(userattributevalue.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
}
// IDsX is like IDs, but panics if an error occurs.
func (_q *UserAttributeValueQuery) IDsX(ctx context.Context) []int64 {
ids, err := _q.IDs(ctx)
if err != nil {
panic(err)
}
return ids
}
// Count returns the count of the given query.
func (_q *UserAttributeValueQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
if err := _q.prepareQuery(ctx); err != nil {
return 0, err
}
return withInterceptors[int](ctx, _q, querierCount[*UserAttributeValueQuery](), _q.inters)
}
// CountX is like Count, but panics if an error occurs.
func (_q *UserAttributeValueQuery) CountX(ctx context.Context) int {
count, err := _q.Count(ctx)
if err != nil {
panic(err)
}
return count
}
// Exist returns true if the query has elements in the graph.
func (_q *UserAttributeValueQuery) Exist(ctx context.Context) (bool, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
switch _, err := _q.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
// ExistX is like Exist, but panics if an error occurs.
func (_q *UserAttributeValueQuery) ExistX(ctx context.Context) bool {
exist, err := _q.Exist(ctx)
if err != nil {
panic(err)
}
return exist
}
// Clone returns a duplicate of the UserAttributeValueQuery builder, including all associated steps. It can be
// used to prepare common query builders and use them differently after the clone is made.
func (_q *UserAttributeValueQuery) Clone() *UserAttributeValueQuery {
if _q == nil {
return nil
}
return &UserAttributeValueQuery{
config: _q.config,
ctx: _q.ctx.Clone(),
order: append([]userattributevalue.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.UserAttributeValue{}, _q.predicates...),
withUser: _q.withUser.Clone(),
withDefinition: _q.withDefinition.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
}
}
// WithUser tells the query-builder to eager-load the nodes that are connected to
// the "user" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserAttributeValueQuery) WithUser(opts ...func(*UserQuery)) *UserAttributeValueQuery {
query := (&UserClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withUser = query
return _q
}
// WithDefinition tells the query-builder to eager-load the nodes that are connected to
// the "definition" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserAttributeValueQuery) WithDefinition(opts ...func(*UserAttributeDefinitionQuery)) *UserAttributeValueQuery {
query := (&UserAttributeDefinitionClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withDefinition = query
return _q
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.UserAttributeValue.Query().
// GroupBy(userattributevalue.FieldCreatedAt).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (_q *UserAttributeValueQuery) GroupBy(field string, fields ...string) *UserAttributeValueGroupBy {
_q.ctx.Fields = append([]string{field}, fields...)
grbuild := &UserAttributeValueGroupBy{build: _q}
grbuild.flds = &_q.ctx.Fields
grbuild.label = userattributevalue.Label
grbuild.scan = grbuild.Scan
return grbuild
}
// Select allows the selection one or more fields/columns for the given query,
// instead of selecting all fields in the entity.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// }
//
// client.UserAttributeValue.Query().
// Select(userattributevalue.FieldCreatedAt).
// Scan(ctx, &v)
func (_q *UserAttributeValueQuery) Select(fields ...string) *UserAttributeValueSelect {
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
sbuild := &UserAttributeValueSelect{UserAttributeValueQuery: _q}
sbuild.label = userattributevalue.Label
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a UserAttributeValueSelect configured with the given aggregations.
func (_q *UserAttributeValueQuery) Aggregate(fns ...AggregateFunc) *UserAttributeValueSelect {
return _q.Select().Aggregate(fns...)
}
func (_q *UserAttributeValueQuery) prepareQuery(ctx context.Context) error {
for _, inter := range _q.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, _q); err != nil {
return err
}
}
}
for _, f := range _q.ctx.Fields {
if !userattributevalue.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
}
if _q.path != nil {
prev, err := _q.path(ctx)
if err != nil {
return err
}
_q.sql = prev
}
return nil
}
func (_q *UserAttributeValueQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*UserAttributeValue, error) {
var (
nodes = []*UserAttributeValue{}
_spec = _q.querySpec()
loadedTypes = [2]bool{
_q.withUser != nil,
_q.withDefinition != nil,
}
)
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*UserAttributeValue).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []any) error {
node := &UserAttributeValue{config: _q.config}
nodes = append(nodes, node)
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
return nil, err
}
if len(nodes) == 0 {
return nodes, nil
}
if query := _q.withUser; query != nil {
if err := _q.loadUser(ctx, query, nodes, nil,
func(n *UserAttributeValue, e *User) { n.Edges.User = e }); err != nil {
return nil, err
}
}
if query := _q.withDefinition; query != nil {
if err := _q.loadDefinition(ctx, query, nodes, nil,
func(n *UserAttributeValue, e *UserAttributeDefinition) { n.Edges.Definition = e }); err != nil {
return nil, err
}
}
return nodes, nil
}
func (_q *UserAttributeValueQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*UserAttributeValue, init func(*UserAttributeValue), assign func(*UserAttributeValue, *User)) error {
ids := make([]int64, 0, len(nodes))
nodeids := make(map[int64][]*UserAttributeValue)
for i := range nodes {
fk := nodes[i].UserID
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(user.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
}
}
return nil
}
func (_q *UserAttributeValueQuery) loadDefinition(ctx context.Context, query *UserAttributeDefinitionQuery, nodes []*UserAttributeValue, init func(*UserAttributeValue), assign func(*UserAttributeValue, *UserAttributeDefinition)) error {
ids := make([]int64, 0, len(nodes))
nodeids := make(map[int64][]*UserAttributeValue)
for i := range nodes {
fk := nodes[i].AttributeID
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(userattributedefinition.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return fmt.Errorf(`unexpected foreign-key "attribute_id" returned %v`, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
}
}
return nil
}
func (_q *UserAttributeValueQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
}
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
}
func (_q *UserAttributeValueQuery) querySpec() *sqlgraph.QuerySpec {
_spec := sqlgraph.NewQuerySpec(userattributevalue.Table, userattributevalue.Columns, sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64))
_spec.From = _q.sql
if unique := _q.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if _q.path != nil {
_spec.Unique = true
}
if fields := _q.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, userattributevalue.FieldID)
for i := range fields {
if fields[i] != userattributevalue.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
if _q.withUser != nil {
_spec.Node.AddColumnOnce(userattributevalue.FieldUserID)
}
if _q.withDefinition != nil {
_spec.Node.AddColumnOnce(userattributevalue.FieldAttributeID)
}
}
if ps := _q.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if limit := _q.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := _q.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := _q.order; len(ps) > 0 {
_spec.Order = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
return _spec
}
func (_q *UserAttributeValueQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(_q.driver.Dialect())
t1 := builder.Table(userattributevalue.Table)
columns := _q.ctx.Fields
if len(columns) == 0 {
columns = userattributevalue.Columns
}
selector := builder.Select(t1.Columns(columns...)...).From(t1)
if _q.sql != nil {
selector = _q.sql
selector.Select(selector.Columns(columns...)...)
}
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, p := range _q.predicates {
p(selector)
}
for _, p := range _q.order {
p(selector)
}
if offset := _q.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := _q.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
}
// UserAttributeValueGroupBy is the group-by builder for UserAttributeValue entities.
type UserAttributeValueGroupBy struct {
selector
build *UserAttributeValueQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
func (_g *UserAttributeValueGroupBy) Aggregate(fns ...AggregateFunc) *UserAttributeValueGroupBy {
_g.fns = append(_g.fns, fns...)
return _g
}
// Scan applies the selector query and scans the result into the given value.
func (_g *UserAttributeValueGroupBy) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
if err := _g.build.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*UserAttributeValueQuery, *UserAttributeValueGroupBy](ctx, _g.build, _g, _g.build.inters, v)
}
func (_g *UserAttributeValueGroupBy) sqlScan(ctx context.Context, root *UserAttributeValueQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(_g.fns))
for _, fn := range _g.fns {
aggregation = append(aggregation, fn(selector))
}
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
for _, f := range *_g.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
selector.GroupBy(selector.Columns(*_g.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// UserAttributeValueSelect is the builder for selecting fields of UserAttributeValue entities.
type UserAttributeValueSelect struct {
*UserAttributeValueQuery
selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (_s *UserAttributeValueSelect) Aggregate(fns ...AggregateFunc) *UserAttributeValueSelect {
_s.fns = append(_s.fns, fns...)
return _s
}
// Scan applies the selector query and scans the result into the given value.
func (_s *UserAttributeValueSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
if err := _s.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*UserAttributeValueQuery, *UserAttributeValueSelect](ctx, _s.UserAttributeValueQuery, _s, _s.inters, v)
}
func (_s *UserAttributeValueSelect) sqlScan(ctx context.Context, root *UserAttributeValueQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(_s.fns))
for _, fn := range _s.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*_s.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}

View File

@@ -0,0 +1,504 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
)
// UserAttributeValueUpdate is the builder for updating UserAttributeValue entities.
type UserAttributeValueUpdate struct {
config
hooks []Hook
mutation *UserAttributeValueMutation
}
// Where appends a list predicates to the UserAttributeValueUpdate builder.
func (_u *UserAttributeValueUpdate) Where(ps ...predicate.UserAttributeValue) *UserAttributeValueUpdate {
_u.mutation.Where(ps...)
return _u
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *UserAttributeValueUpdate) SetUpdatedAt(v time.Time) *UserAttributeValueUpdate {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetUserID sets the "user_id" field.
func (_u *UserAttributeValueUpdate) SetUserID(v int64) *UserAttributeValueUpdate {
_u.mutation.SetUserID(v)
return _u
}
// SetNillableUserID sets the "user_id" field if the given value is not nil.
func (_u *UserAttributeValueUpdate) SetNillableUserID(v *int64) *UserAttributeValueUpdate {
if v != nil {
_u.SetUserID(*v)
}
return _u
}
// SetAttributeID sets the "attribute_id" field.
func (_u *UserAttributeValueUpdate) SetAttributeID(v int64) *UserAttributeValueUpdate {
_u.mutation.SetAttributeID(v)
return _u
}
// SetNillableAttributeID sets the "attribute_id" field if the given value is not nil.
func (_u *UserAttributeValueUpdate) SetNillableAttributeID(v *int64) *UserAttributeValueUpdate {
if v != nil {
_u.SetAttributeID(*v)
}
return _u
}
// SetValue sets the "value" field.
func (_u *UserAttributeValueUpdate) SetValue(v string) *UserAttributeValueUpdate {
_u.mutation.SetValue(v)
return _u
}
// SetNillableValue sets the "value" field if the given value is not nil.
func (_u *UserAttributeValueUpdate) SetNillableValue(v *string) *UserAttributeValueUpdate {
if v != nil {
_u.SetValue(*v)
}
return _u
}
// SetUser sets the "user" edge to the User entity.
func (_u *UserAttributeValueUpdate) SetUser(v *User) *UserAttributeValueUpdate {
return _u.SetUserID(v.ID)
}
// SetDefinitionID sets the "definition" edge to the UserAttributeDefinition entity by ID.
func (_u *UserAttributeValueUpdate) SetDefinitionID(id int64) *UserAttributeValueUpdate {
_u.mutation.SetDefinitionID(id)
return _u
}
// SetDefinition sets the "definition" edge to the UserAttributeDefinition entity.
func (_u *UserAttributeValueUpdate) SetDefinition(v *UserAttributeDefinition) *UserAttributeValueUpdate {
return _u.SetDefinitionID(v.ID)
}
// Mutation returns the UserAttributeValueMutation object of the builder.
func (_u *UserAttributeValueUpdate) Mutation() *UserAttributeValueMutation {
return _u.mutation
}
// ClearUser clears the "user" edge to the User entity.
func (_u *UserAttributeValueUpdate) ClearUser() *UserAttributeValueUpdate {
_u.mutation.ClearUser()
return _u
}
// ClearDefinition clears the "definition" edge to the UserAttributeDefinition entity.
func (_u *UserAttributeValueUpdate) ClearDefinition() *UserAttributeValueUpdate {
_u.mutation.ClearDefinition()
return _u
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UserAttributeValueUpdate) Save(ctx context.Context) (int, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *UserAttributeValueUpdate) SaveX(ctx context.Context) int {
affected, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return affected
}
// Exec executes the query.
func (_u *UserAttributeValueUpdate) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *UserAttributeValueUpdate) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *UserAttributeValueUpdate) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := userattributevalue.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *UserAttributeValueUpdate) check() error {
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UserAttributeValue.user"`)
}
if _u.mutation.DefinitionCleared() && len(_u.mutation.DefinitionIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UserAttributeValue.definition"`)
}
return nil
}
func (_u *UserAttributeValueUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(userattributevalue.Table, userattributevalue.Columns, sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64))
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(userattributevalue.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.Value(); ok {
_spec.SetField(userattributevalue.FieldValue, field.TypeString, value)
}
if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: userattributevalue.UserTable,
Columns: []string{userattributevalue.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: userattributevalue.UserTable,
Columns: []string{userattributevalue.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.DefinitionCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: userattributevalue.DefinitionTable,
Columns: []string{userattributevalue.DefinitionColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributedefinition.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.DefinitionIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: userattributevalue.DefinitionTable,
Columns: []string{userattributevalue.DefinitionColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributedefinition.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{userattributevalue.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return 0, err
}
_u.mutation.done = true
return _node, nil
}
// UserAttributeValueUpdateOne is the builder for updating a single UserAttributeValue entity.
type UserAttributeValueUpdateOne struct {
config
fields []string
hooks []Hook
mutation *UserAttributeValueMutation
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *UserAttributeValueUpdateOne) SetUpdatedAt(v time.Time) *UserAttributeValueUpdateOne {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetUserID sets the "user_id" field.
func (_u *UserAttributeValueUpdateOne) SetUserID(v int64) *UserAttributeValueUpdateOne {
_u.mutation.SetUserID(v)
return _u
}
// SetNillableUserID sets the "user_id" field if the given value is not nil.
func (_u *UserAttributeValueUpdateOne) SetNillableUserID(v *int64) *UserAttributeValueUpdateOne {
if v != nil {
_u.SetUserID(*v)
}
return _u
}
// SetAttributeID sets the "attribute_id" field.
func (_u *UserAttributeValueUpdateOne) SetAttributeID(v int64) *UserAttributeValueUpdateOne {
_u.mutation.SetAttributeID(v)
return _u
}
// SetNillableAttributeID sets the "attribute_id" field if the given value is not nil.
func (_u *UserAttributeValueUpdateOne) SetNillableAttributeID(v *int64) *UserAttributeValueUpdateOne {
if v != nil {
_u.SetAttributeID(*v)
}
return _u
}
// SetValue sets the "value" field.
func (_u *UserAttributeValueUpdateOne) SetValue(v string) *UserAttributeValueUpdateOne {
_u.mutation.SetValue(v)
return _u
}
// SetNillableValue sets the "value" field if the given value is not nil.
func (_u *UserAttributeValueUpdateOne) SetNillableValue(v *string) *UserAttributeValueUpdateOne {
if v != nil {
_u.SetValue(*v)
}
return _u
}
// SetUser sets the "user" edge to the User entity.
func (_u *UserAttributeValueUpdateOne) SetUser(v *User) *UserAttributeValueUpdateOne {
return _u.SetUserID(v.ID)
}
// SetDefinitionID sets the "definition" edge to the UserAttributeDefinition entity by ID.
func (_u *UserAttributeValueUpdateOne) SetDefinitionID(id int64) *UserAttributeValueUpdateOne {
_u.mutation.SetDefinitionID(id)
return _u
}
// SetDefinition sets the "definition" edge to the UserAttributeDefinition entity.
func (_u *UserAttributeValueUpdateOne) SetDefinition(v *UserAttributeDefinition) *UserAttributeValueUpdateOne {
return _u.SetDefinitionID(v.ID)
}
// Mutation returns the UserAttributeValueMutation object of the builder.
func (_u *UserAttributeValueUpdateOne) Mutation() *UserAttributeValueMutation {
return _u.mutation
}
// ClearUser clears the "user" edge to the User entity.
func (_u *UserAttributeValueUpdateOne) ClearUser() *UserAttributeValueUpdateOne {
_u.mutation.ClearUser()
return _u
}
// ClearDefinition clears the "definition" edge to the UserAttributeDefinition entity.
func (_u *UserAttributeValueUpdateOne) ClearDefinition() *UserAttributeValueUpdateOne {
_u.mutation.ClearDefinition()
return _u
}
// Where appends a list predicates to the UserAttributeValueUpdate builder.
func (_u *UserAttributeValueUpdateOne) Where(ps ...predicate.UserAttributeValue) *UserAttributeValueUpdateOne {
_u.mutation.Where(ps...)
return _u
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (_u *UserAttributeValueUpdateOne) Select(field string, fields ...string) *UserAttributeValueUpdateOne {
_u.fields = append([]string{field}, fields...)
return _u
}
// Save executes the query and returns the updated UserAttributeValue entity.
func (_u *UserAttributeValueUpdateOne) Save(ctx context.Context) (*UserAttributeValue, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *UserAttributeValueUpdateOne) SaveX(ctx context.Context) *UserAttributeValue {
node, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return node
}
// Exec executes the query on the entity.
func (_u *UserAttributeValueUpdateOne) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *UserAttributeValueUpdateOne) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *UserAttributeValueUpdateOne) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := userattributevalue.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *UserAttributeValueUpdateOne) check() error {
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UserAttributeValue.user"`)
}
if _u.mutation.DefinitionCleared() && len(_u.mutation.DefinitionIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UserAttributeValue.definition"`)
}
return nil
}
func (_u *UserAttributeValueUpdateOne) sqlSave(ctx context.Context) (_node *UserAttributeValue, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(userattributevalue.Table, userattributevalue.Columns, sqlgraph.NewFieldSpec(userattributevalue.FieldID, field.TypeInt64))
id, ok := _u.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "UserAttributeValue.id" for update`)}
}
_spec.Node.ID.Value = id
if fields := _u.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, userattributevalue.FieldID)
for _, f := range fields {
if !userattributevalue.ValidColumn(f) {
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
if f != userattributevalue.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, f)
}
}
}
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(userattributevalue.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.Value(); ok {
_spec.SetField(userattributevalue.FieldValue, field.TypeString, value)
}
if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: userattributevalue.UserTable,
Columns: []string{userattributevalue.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: userattributevalue.UserTable,
Columns: []string{userattributevalue.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.DefinitionCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: userattributevalue.DefinitionTable,
Columns: []string{userattributevalue.DefinitionColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributedefinition.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.DefinitionIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: userattributevalue.DefinitionTable,
Columns: []string{userattributevalue.DefinitionColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userattributedefinition.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &UserAttributeValue{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{userattributevalue.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
_u.mutation.done = true
return _node, nil
}

View File

@@ -3,6 +3,7 @@ package config
import ( import (
"fmt" "fmt"
"strings" "strings"
"time"
"github.com/spf13/viper" "github.com/spf13/viper"
) )
@@ -43,6 +44,7 @@ type Config struct {
type GeminiConfig struct { type GeminiConfig struct {
OAuth GeminiOAuthConfig `mapstructure:"oauth"` OAuth GeminiOAuthConfig `mapstructure:"oauth"`
Quota GeminiQuotaConfig `mapstructure:"quota"`
} }
type GeminiOAuthConfig struct { type GeminiOAuthConfig struct {
@@ -51,6 +53,17 @@ type GeminiOAuthConfig struct {
Scopes string `mapstructure:"scopes"` Scopes string `mapstructure:"scopes"`
} }
type GeminiQuotaConfig struct {
Tiers map[string]GeminiTierQuotaConfig `mapstructure:"tiers"`
Policy string `mapstructure:"policy"`
}
type GeminiTierQuotaConfig struct {
ProRPD *int64 `mapstructure:"pro_rpd" json:"pro_rpd"`
FlashRPD *int64 `mapstructure:"flash_rpd" json:"flash_rpd"`
CooldownMinutes *int `mapstructure:"cooldown_minutes" json:"cooldown_minutes"`
}
// TokenRefreshConfig OAuth token自动刷新配置 // TokenRefreshConfig OAuth token自动刷新配置
type TokenRefreshConfig struct { type TokenRefreshConfig struct {
// 是否启用自动刷新 // 是否启用自动刷新
@@ -119,6 +132,37 @@ type GatewayConfig struct {
// ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟) // ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟)
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期 // 应大于最长 LLM 请求时间,防止请求完成前槽位过期
ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"` ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"`
// 是否记录上游错误响应体摘要(避免输出请求内容)
LogUpstreamErrorBody bool `mapstructure:"log_upstream_error_body"`
// 上游错误响应体记录最大字节数(超过会截断)
LogUpstreamErrorBodyMaxBytes int `mapstructure:"log_upstream_error_body_max_bytes"`
// API-key 账号在客户端未提供 anthropic-beta 时,是否按需自动补齐(默认关闭以保持兼容)
InjectBetaForApiKey bool `mapstructure:"inject_beta_for_apikey"`
// 是否允许对部分 400 错误触发 failover默认关闭以避免改变语义
FailoverOn400 bool `mapstructure:"failover_on_400"`
// Scheduling: 账号调度相关配置
Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"`
}
// GatewaySchedulingConfig accounts scheduling configuration.
type GatewaySchedulingConfig struct {
// 粘性会话排队配置
StickySessionMaxWaiting int `mapstructure:"sticky_session_max_waiting"`
StickySessionWaitTimeout time.Duration `mapstructure:"sticky_session_wait_timeout"`
// 兜底排队配置
FallbackWaitTimeout time.Duration `mapstructure:"fallback_wait_timeout"`
FallbackMaxWaiting int `mapstructure:"fallback_max_waiting"`
// 负载计算
LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
// 过期槽位清理周期0 表示禁用)
SlotCleanupInterval time.Duration `mapstructure:"slot_cleanup_interval"`
} }
func (s *ServerConfig) Address() string { func (s *ServerConfig) Address() string {
@@ -313,6 +357,10 @@ func setDefaults() {
// Gateway // Gateway
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头LLM高负载时可能排队较久 viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头LLM高负载时可能排队较久
viper.SetDefault("gateway.log_upstream_error_body", false)
viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048)
viper.SetDefault("gateway.inject_beta_for_apikey", false)
viper.SetDefault("gateway.failover_on_400", false)
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化) // HTTP 上游连接池配置(针对 5000+ 并发用户优化)
@@ -323,6 +371,12 @@ func setDefaults() {
viper.SetDefault("gateway.max_upstream_clients", 5000) viper.SetDefault("gateway.max_upstream_clients", 5000)
viper.SetDefault("gateway.client_idle_ttl_seconds", 900) viper.SetDefault("gateway.client_idle_ttl_seconds", 900)
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 15) // 并发槽位过期时间(支持超长请求) viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 15) // 并发槽位过期时间(支持超长请求)
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second)
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
// TokenRefresh // TokenRefresh
viper.SetDefault("token_refresh.enabled", true) viper.SetDefault("token_refresh.enabled", true)
@@ -337,6 +391,7 @@ func setDefaults() {
viper.SetDefault("gemini.oauth.client_id", "") viper.SetDefault("gemini.oauth.client_id", "")
viper.SetDefault("gemini.oauth.client_secret", "") viper.SetDefault("gemini.oauth.client_secret", "")
viper.SetDefault("gemini.oauth.scopes", "") viper.SetDefault("gemini.oauth.scopes", "")
viper.SetDefault("gemini.quota.policy", "")
} }
func (c *Config) Validate() error { func (c *Config) Validate() error {
@@ -411,6 +466,21 @@ func (c *Config) Validate() error {
if c.Gateway.ConcurrencySlotTTLMinutes <= 0 { if c.Gateway.ConcurrencySlotTTLMinutes <= 0 {
return fmt.Errorf("gateway.concurrency_slot_ttl_minutes must be positive") return fmt.Errorf("gateway.concurrency_slot_ttl_minutes must be positive")
} }
if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 {
return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive")
}
if c.Gateway.Scheduling.StickySessionWaitTimeout <= 0 {
return fmt.Errorf("gateway.scheduling.sticky_session_wait_timeout must be positive")
}
if c.Gateway.Scheduling.FallbackWaitTimeout <= 0 {
return fmt.Errorf("gateway.scheduling.fallback_wait_timeout must be positive")
}
if c.Gateway.Scheduling.FallbackMaxWaiting <= 0 {
return fmt.Errorf("gateway.scheduling.fallback_max_waiting must be positive")
}
if c.Gateway.Scheduling.SlotCleanupInterval < 0 {
return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative")
}
return nil return nil
} }

View File

@@ -1,6 +1,11 @@
package config package config
import "testing" import (
"testing"
"time"
"github.com/spf13/viper"
)
func TestNormalizeRunMode(t *testing.T) { func TestNormalizeRunMode(t *testing.T) {
tests := []struct { tests := []struct {
@@ -21,3 +26,45 @@ func TestNormalizeRunMode(t *testing.T) {
} }
} }
} }
func TestLoadDefaultSchedulingConfig(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 3 {
t.Fatalf("StickySessionMaxWaiting = %d, want 3", cfg.Gateway.Scheduling.StickySessionMaxWaiting)
}
if cfg.Gateway.Scheduling.StickySessionWaitTimeout != 45*time.Second {
t.Fatalf("StickySessionWaitTimeout = %v, want 45s", cfg.Gateway.Scheduling.StickySessionWaitTimeout)
}
if cfg.Gateway.Scheduling.FallbackWaitTimeout != 30*time.Second {
t.Fatalf("FallbackWaitTimeout = %v, want 30s", cfg.Gateway.Scheduling.FallbackWaitTimeout)
}
if cfg.Gateway.Scheduling.FallbackMaxWaiting != 100 {
t.Fatalf("FallbackMaxWaiting = %d, want 100", cfg.Gateway.Scheduling.FallbackMaxWaiting)
}
if !cfg.Gateway.Scheduling.LoadBatchEnabled {
t.Fatalf("LoadBatchEnabled = false, want true")
}
if cfg.Gateway.Scheduling.SlotCleanupInterval != 30*time.Second {
t.Fatalf("SlotCleanupInterval = %v, want 30s", cfg.Gateway.Scheduling.SlotCleanupInterval)
}
}
func TestLoadSchedulingConfigFromEnv(t *testing.T) {
viper.Reset()
t.Setenv("GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING", "5")
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 5 {
t.Fatalf("StickySessionMaxWaiting = %d, want 5", cfg.Gateway.Scheduling.StickySessionMaxWaiting)
}
}

View File

@@ -3,6 +3,7 @@ package admin
import ( import (
"strconv" "strconv"
"strings" "strings"
"sync"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
@@ -13,6 +14,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"golang.org/x/sync/errgroup"
) )
// OAuthHandler handles OAuth-related operations for accounts // OAuthHandler handles OAuth-related operations for accounts
@@ -989,3 +991,164 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
response.Success(c, models) response.Success(c, models)
} }
// RefreshTier handles refreshing Google One tier for a single account
// POST /api/v1/admin/accounts/:id/refresh-tier
func (h *AccountHandler) RefreshTier(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
ctx := c.Request.Context()
account, err := h.adminService.GetAccount(ctx, accountID)
if err != nil {
response.NotFound(c, "Account not found")
return
}
if account.Platform != service.PlatformGemini || account.Type != service.AccountTypeOAuth {
response.BadRequest(c, "Only Gemini OAuth accounts support tier refresh")
return
}
oauthType, _ := account.Credentials["oauth_type"].(string)
if oauthType != "google_one" {
response.BadRequest(c, "Only google_one OAuth accounts support tier refresh")
return
}
tierID, extra, creds, err := h.geminiOAuthService.RefreshAccountGoogleOneTier(ctx, account)
if err != nil {
response.ErrorFrom(c, err)
return
}
_, updateErr := h.adminService.UpdateAccount(ctx, accountID, &service.UpdateAccountInput{
Credentials: creds,
Extra: extra,
})
if updateErr != nil {
response.ErrorFrom(c, updateErr)
return
}
response.Success(c, gin.H{
"tier_id": tierID,
"storage_info": extra,
"drive_storage_limit": extra["drive_storage_limit"],
"drive_storage_usage": extra["drive_storage_usage"],
"updated_at": extra["drive_tier_updated_at"],
})
}
// BatchRefreshTierRequest represents batch tier refresh request
type BatchRefreshTierRequest struct {
AccountIDs []int64 `json:"account_ids"`
}
// BatchRefreshTier handles batch refreshing Google One tier
// POST /api/v1/admin/accounts/batch-refresh-tier
func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
var req BatchRefreshTierRequest
if err := c.ShouldBindJSON(&req); err != nil {
req = BatchRefreshTierRequest{}
}
ctx := c.Request.Context()
accounts := make([]*service.Account, 0)
if len(req.AccountIDs) == 0 {
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "")
if err != nil {
response.ErrorFrom(c, err)
return
}
for i := range allAccounts {
acc := &allAccounts[i]
oauthType, _ := acc.Credentials["oauth_type"].(string)
if oauthType == "google_one" {
accounts = append(accounts, acc)
}
}
} else {
fetched, err := h.adminService.GetAccountsByIDs(ctx, req.AccountIDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
for _, acc := range fetched {
if acc == nil {
continue
}
if acc.Platform != service.PlatformGemini || acc.Type != service.AccountTypeOAuth {
continue
}
oauthType, _ := acc.Credentials["oauth_type"].(string)
if oauthType != "google_one" {
continue
}
accounts = append(accounts, acc)
}
}
const maxConcurrency = 10
g, gctx := errgroup.WithContext(ctx)
g.SetLimit(maxConcurrency)
var mu sync.Mutex
var successCount, failedCount int
var errors []gin.H
for _, account := range accounts {
acc := account // 闭包捕获
g.Go(func() error {
_, extra, creds, err := h.geminiOAuthService.RefreshAccountGoogleOneTier(gctx, acc)
if err != nil {
mu.Lock()
failedCount++
errors = append(errors, gin.H{
"account_id": acc.ID,
"error": err.Error(),
})
mu.Unlock()
return nil
}
_, updateErr := h.adminService.UpdateAccount(gctx, acc.ID, &service.UpdateAccountInput{
Credentials: creds,
Extra: extra,
})
mu.Lock()
if updateErr != nil {
failedCount++
errors = append(errors, gin.H{
"account_id": acc.ID,
"error": updateErr.Error(),
})
} else {
successCount++
}
mu.Unlock()
return nil
})
}
if err := g.Wait(); err != nil {
response.ErrorFrom(c, err)
return
}
results := gin.H{
"total": len(accounts),
"success": successCount,
"failed": failedCount,
"errors": errors,
}
response.Success(c, results)
}

View File

@@ -46,8 +46,8 @@ func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) {
if oauthType == "" { if oauthType == "" {
oauthType = "code_assist" oauthType = "code_assist"
} }
if oauthType != "code_assist" && oauthType != "ai_studio" { if oauthType != "code_assist" && oauthType != "google_one" && oauthType != "ai_studio" {
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist' or 'ai_studio'") response.BadRequest(c, "Invalid oauth_type: must be 'code_assist', 'google_one', or 'ai_studio'")
return return
} }
@@ -92,8 +92,8 @@ func (h *GeminiOAuthHandler) ExchangeCode(c *gin.Context) {
if oauthType == "" { if oauthType == "" {
oauthType = "code_assist" oauthType = "code_assist"
} }
if oauthType != "code_assist" && oauthType != "ai_studio" { if oauthType != "code_assist" && oauthType != "google_one" && oauthType != "ai_studio" {
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist' or 'ai_studio'") response.BadRequest(c, "Invalid oauth_type: must be 'code_assist', 'google_one', or 'ai_studio'")
return return
} }

View File

@@ -246,7 +246,7 @@ func (h *UsageHandler) SearchUsers(c *gin.Context) {
} }
// Limit to 30 results // Limit to 30 results
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, "", "", keyword) users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, service.UserListFilters{Search: keyword})
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return

View File

@@ -0,0 +1,342 @@
package admin
import (
"strconv"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// UserAttributeHandler handles user attribute management
type UserAttributeHandler struct {
attrService *service.UserAttributeService
}
// NewUserAttributeHandler creates a new handler
func NewUserAttributeHandler(attrService *service.UserAttributeService) *UserAttributeHandler {
return &UserAttributeHandler{attrService: attrService}
}
// --- Request/Response DTOs ---
// CreateAttributeDefinitionRequest represents create attribute definition request
type CreateAttributeDefinitionRequest struct {
Key string `json:"key" binding:"required,min=1,max=100"`
Name string `json:"name" binding:"required,min=1,max=255"`
Description string `json:"description"`
Type string `json:"type" binding:"required"`
Options []service.UserAttributeOption `json:"options"`
Required bool `json:"required"`
Validation service.UserAttributeValidation `json:"validation"`
Placeholder string `json:"placeholder"`
Enabled bool `json:"enabled"`
}
// UpdateAttributeDefinitionRequest represents update attribute definition request
type UpdateAttributeDefinitionRequest struct {
Name *string `json:"name"`
Description *string `json:"description"`
Type *string `json:"type"`
Options *[]service.UserAttributeOption `json:"options"`
Required *bool `json:"required"`
Validation *service.UserAttributeValidation `json:"validation"`
Placeholder *string `json:"placeholder"`
Enabled *bool `json:"enabled"`
}
// ReorderRequest represents reorder attribute definitions request
type ReorderRequest struct {
IDs []int64 `json:"ids" binding:"required"`
}
// UpdateUserAttributesRequest represents update user attributes request
type UpdateUserAttributesRequest struct {
Values map[int64]string `json:"values" binding:"required"`
}
// BatchGetUserAttributesRequest represents batch get user attributes request
type BatchGetUserAttributesRequest struct {
UserIDs []int64 `json:"user_ids" binding:"required"`
}
// BatchUserAttributesResponse represents batch user attributes response
type BatchUserAttributesResponse struct {
// Map of userID -> map of attributeID -> value
Attributes map[int64]map[int64]string `json:"attributes"`
}
// AttributeDefinitionResponse represents attribute definition response
type AttributeDefinitionResponse struct {
ID int64 `json:"id"`
Key string `json:"key"`
Name string `json:"name"`
Description string `json:"description"`
Type string `json:"type"`
Options []service.UserAttributeOption `json:"options"`
Required bool `json:"required"`
Validation service.UserAttributeValidation `json:"validation"`
Placeholder string `json:"placeholder"`
DisplayOrder int `json:"display_order"`
Enabled bool `json:"enabled"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
// AttributeValueResponse represents attribute value response
type AttributeValueResponse struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
AttributeID int64 `json:"attribute_id"`
Value string `json:"value"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
// --- Helpers ---
func defToResponse(def *service.UserAttributeDefinition) *AttributeDefinitionResponse {
return &AttributeDefinitionResponse{
ID: def.ID,
Key: def.Key,
Name: def.Name,
Description: def.Description,
Type: string(def.Type),
Options: def.Options,
Required: def.Required,
Validation: def.Validation,
Placeholder: def.Placeholder,
DisplayOrder: def.DisplayOrder,
Enabled: def.Enabled,
CreatedAt: def.CreatedAt.Format("2006-01-02T15:04:05Z07:00"),
UpdatedAt: def.UpdatedAt.Format("2006-01-02T15:04:05Z07:00"),
}
}
func valueToResponse(val *service.UserAttributeValue) *AttributeValueResponse {
return &AttributeValueResponse{
ID: val.ID,
UserID: val.UserID,
AttributeID: val.AttributeID,
Value: val.Value,
CreatedAt: val.CreatedAt.Format("2006-01-02T15:04:05Z07:00"),
UpdatedAt: val.UpdatedAt.Format("2006-01-02T15:04:05Z07:00"),
}
}
// --- Handlers ---
// ListDefinitions lists all attribute definitions
// GET /admin/user-attributes
func (h *UserAttributeHandler) ListDefinitions(c *gin.Context) {
enabledOnly := c.Query("enabled") == "true"
defs, err := h.attrService.ListDefinitions(c.Request.Context(), enabledOnly)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]*AttributeDefinitionResponse, 0, len(defs))
for i := range defs {
out = append(out, defToResponse(&defs[i]))
}
response.Success(c, out)
}
// CreateDefinition creates a new attribute definition
// POST /admin/user-attributes
func (h *UserAttributeHandler) CreateDefinition(c *gin.Context) {
var req CreateAttributeDefinitionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
def, err := h.attrService.CreateDefinition(c.Request.Context(), service.CreateAttributeDefinitionInput{
Key: req.Key,
Name: req.Name,
Description: req.Description,
Type: service.UserAttributeType(req.Type),
Options: req.Options,
Required: req.Required,
Validation: req.Validation,
Placeholder: req.Placeholder,
Enabled: req.Enabled,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, defToResponse(def))
}
// UpdateDefinition updates an attribute definition
// PUT /admin/user-attributes/:id
func (h *UserAttributeHandler) UpdateDefinition(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid attribute ID")
return
}
var req UpdateAttributeDefinitionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
input := service.UpdateAttributeDefinitionInput{
Name: req.Name,
Description: req.Description,
Options: req.Options,
Required: req.Required,
Validation: req.Validation,
Placeholder: req.Placeholder,
Enabled: req.Enabled,
}
if req.Type != nil {
t := service.UserAttributeType(*req.Type)
input.Type = &t
}
def, err := h.attrService.UpdateDefinition(c.Request.Context(), id, input)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, defToResponse(def))
}
// DeleteDefinition deletes an attribute definition
// DELETE /admin/user-attributes/:id
func (h *UserAttributeHandler) DeleteDefinition(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid attribute ID")
return
}
if err := h.attrService.DeleteDefinition(c.Request.Context(), id); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Attribute definition deleted successfully"})
}
// ReorderDefinitions reorders attribute definitions
// PUT /admin/user-attributes/reorder
func (h *UserAttributeHandler) ReorderDefinitions(c *gin.Context) {
var req ReorderRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Convert IDs array to orders map (position in array = display_order)
orders := make(map[int64]int, len(req.IDs))
for i, id := range req.IDs {
orders[id] = i
}
if err := h.attrService.ReorderDefinitions(c.Request.Context(), orders); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Reorder successful"})
}
// GetUserAttributes gets a user's attribute values
// GET /admin/users/:id/attributes
func (h *UserAttributeHandler) GetUserAttributes(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
values, err := h.attrService.GetUserAttributes(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]*AttributeValueResponse, 0, len(values))
for i := range values {
out = append(out, valueToResponse(&values[i]))
}
response.Success(c, out)
}
// UpdateUserAttributes updates a user's attribute values
// PUT /admin/users/:id/attributes
func (h *UserAttributeHandler) UpdateUserAttributes(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
var req UpdateUserAttributesRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
inputs := make([]service.UpdateUserAttributeInput, 0, len(req.Values))
for attrID, value := range req.Values {
inputs = append(inputs, service.UpdateUserAttributeInput{
AttributeID: attrID,
Value: value,
})
}
if err := h.attrService.UpdateUserAttributes(c.Request.Context(), userID, inputs); err != nil {
response.ErrorFrom(c, err)
return
}
// Return updated values
values, err := h.attrService.GetUserAttributes(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]*AttributeValueResponse, 0, len(values))
for i := range values {
out = append(out, valueToResponse(&values[i]))
}
response.Success(c, out)
}
// GetBatchUserAttributes gets attribute values for multiple users
// POST /admin/user-attributes/batch
func (h *UserAttributeHandler) GetBatchUserAttributes(c *gin.Context) {
var req BatchGetUserAttributesRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.UserIDs) == 0 {
response.Success(c, BatchUserAttributesResponse{Attributes: map[int64]map[int64]string{}})
return
}
attrs, err := h.attrService.GetBatchUserAttributes(c.Request.Context(), req.UserIDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, BatchUserAttributesResponse{Attributes: attrs})
}

View File

@@ -27,7 +27,6 @@ type CreateUserRequest struct {
Email string `json:"email" binding:"required,email"` Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"` Password string `json:"password" binding:"required,min=6"`
Username string `json:"username"` Username string `json:"username"`
Wechat string `json:"wechat"`
Notes string `json:"notes"` Notes string `json:"notes"`
Balance float64 `json:"balance"` Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"` Concurrency int `json:"concurrency"`
@@ -40,7 +39,6 @@ type UpdateUserRequest struct {
Email string `json:"email" binding:"omitempty,email"` Email string `json:"email" binding:"omitempty,email"`
Password string `json:"password" binding:"omitempty,min=6"` Password string `json:"password" binding:"omitempty,min=6"`
Username *string `json:"username"` Username *string `json:"username"`
Wechat *string `json:"wechat"`
Notes *string `json:"notes"` Notes *string `json:"notes"`
Balance *float64 `json:"balance"` Balance *float64 `json:"balance"`
Concurrency *int `json:"concurrency"` Concurrency *int `json:"concurrency"`
@@ -57,13 +55,22 @@ type UpdateBalanceRequest struct {
// List handles listing all users with pagination // List handles listing all users with pagination
// GET /api/v1/admin/users // GET /api/v1/admin/users
// Query params:
// - status: filter by user status
// - role: filter by user role
// - search: search in email, username
// - attr[{id}]: filter by custom attribute value, e.g. attr[1]=company
func (h *UserHandler) List(c *gin.Context) { func (h *UserHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c) page, pageSize := response.ParsePagination(c)
status := c.Query("status")
role := c.Query("role")
search := c.Query("search")
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, status, role, search) filters := service.UserListFilters{
Status: c.Query("status"),
Role: c.Query("role"),
Search: c.Query("search"),
Attributes: parseAttributeFilters(c),
}
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
@@ -76,6 +83,29 @@ func (h *UserHandler) List(c *gin.Context) {
response.Paginated(c, out, total, page, pageSize) response.Paginated(c, out, total, page, pageSize)
} }
// parseAttributeFilters extracts attribute filters from query params
// Format: attr[{attributeID}]=value, e.g. attr[1]=company&attr[2]=developer
func parseAttributeFilters(c *gin.Context) map[int64]string {
result := make(map[int64]string)
// Get all query params and look for attr[*] pattern
for key, values := range c.Request.URL.Query() {
if len(values) == 0 || values[0] == "" {
continue
}
// Check if key matches pattern attr[{id}]
if len(key) > 5 && key[:5] == "attr[" && key[len(key)-1] == ']' {
idStr := key[5 : len(key)-1]
id, err := strconv.ParseInt(idStr, 10, 64)
if err == nil && id > 0 {
result[id] = values[0]
}
}
}
return result
}
// GetByID handles getting a user by ID // GetByID handles getting a user by ID
// GET /api/v1/admin/users/:id // GET /api/v1/admin/users/:id
func (h *UserHandler) GetByID(c *gin.Context) { func (h *UserHandler) GetByID(c *gin.Context) {
@@ -107,7 +137,6 @@ func (h *UserHandler) Create(c *gin.Context) {
Email: req.Email, Email: req.Email,
Password: req.Password, Password: req.Password,
Username: req.Username, Username: req.Username,
Wechat: req.Wechat,
Notes: req.Notes, Notes: req.Notes,
Balance: req.Balance, Balance: req.Balance,
Concurrency: req.Concurrency, Concurrency: req.Concurrency,
@@ -141,7 +170,6 @@ func (h *UserHandler) Update(c *gin.Context) {
Email: req.Email, Email: req.Email,
Password: req.Password, Password: req.Password,
Username: req.Username, Username: req.Username,
Wechat: req.Wechat,
Notes: req.Notes, Notes: req.Notes,
Balance: req.Balance, Balance: req.Balance,
Concurrency: req.Concurrency, Concurrency: req.Concurrency,

View File

@@ -10,7 +10,6 @@ func UserFromServiceShallow(u *service.User) *User {
ID: u.ID, ID: u.ID,
Email: u.Email, Email: u.Email,
Username: u.Username, Username: u.Username,
Wechat: u.Wechat,
Notes: u.Notes, Notes: u.Notes,
Role: u.Role, Role: u.Role,
Balance: u.Balance, Balance: u.Balance,

View File

@@ -6,7 +6,6 @@ type User struct {
ID int64 `json:"id"` ID int64 `json:"id"`
Email string `json:"email"` Email string `json:"email"`
Username string `json:"username"` Username string `json:"username"`
Wechat string `json:"wechat"`
Notes string `json:"notes"` Notes string `json:"notes"`
Role string `json:"role"` Role string `json:"role"`
Balance float64 `json:"balance"` Balance float64 `json:"balance"`

View File

@@ -142,6 +142,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} else if apiKey.Group != nil { } else if apiKey.Group != nil {
platform = apiKey.Group.Platform platform = apiKey.Group.Platform
} }
sessionKey := sessionHash
if platform == service.PlatformGemini && sessionHash != "" {
sessionKey = "gemini:" + sessionHash
}
if platform == service.PlatformGemini { if platform == service.PlatformGemini {
const maxAccountSwitches = 3 const maxAccountSwitches = 3
@@ -150,7 +154,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
lastFailoverStatus := 0 lastFailoverStatus := 0
for { for {
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
if err != nil { if err != nil {
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
@@ -159,9 +163,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return return
} }
account := selection.Account
// 检查预热请求拦截(在账号选择后、转发前检查) // 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
if reqStream { if reqStream {
sendMockWarmupStream(c, reqModel) sendMockWarmupStream(c, reqModel)
} else { } else {
@@ -171,11 +179,46 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
// 3. 获取账号并发槽位 // 3. 获取账号并发槽位
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted) accountReleaseFunc := selection.ReleaseFunc
if err != nil { var accountWaitRelease func()
log.Printf("Account concurrency acquire failed: %v", err) if !selection.Acquired {
h.handleConcurrencyError(c, err, "account", streamStarted) if selection.WaitPlan == nil {
return h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
} }
// 转发请求 - 根据账号平台分流 // 转发请求 - 根据账号平台分流
@@ -188,6 +231,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
} }
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil { if err != nil {
var failoverErr *service.UpstreamFailoverError var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {
@@ -232,7 +278,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for { for {
// 选择支持该模型的账号 // 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
if err != nil { if err != nil {
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
@@ -241,9 +287,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return return
} }
account := selection.Account
// 检查预热请求拦截(在账号选择后、转发前检查) // 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
if reqStream { if reqStream {
sendMockWarmupStream(c, reqModel) sendMockWarmupStream(c, reqModel)
} else { } else {
@@ -253,11 +303,46 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
// 3. 获取账号并发槽位 // 3. 获取账号并发槽位
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted) accountReleaseFunc := selection.ReleaseFunc
if err != nil { var accountWaitRelease func()
log.Printf("Account concurrency acquire failed: %v", err) if !selection.Acquired {
h.handleConcurrencyError(c, err, "account", streamStarted) if selection.WaitPlan == nil {
return h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
} }
// 转发请求 - 根据账号平台分流 // 转发请求 - 根据账号平台分流
@@ -270,6 +355,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
} }
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil { if err != nil {
var failoverErr *service.UpstreamFailoverError var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {
@@ -309,7 +397,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// Models handles listing available models // Models handles listing available models
// GET /v1/models // GET /v1/models
// Returns different model lists based on the API key's group platform or forced platform // Returns models based on account configurations (model_mapping whitelist)
// Falls back to default models if no whitelist is configured
func (h *GatewayHandler) Models(c *gin.Context) { func (h *GatewayHandler) Models(c *gin.Context) {
apiKey, _ := middleware2.GetApiKeyFromContext(c) apiKey, _ := middleware2.GetApiKeyFromContext(c)
@@ -324,8 +413,37 @@ func (h *GatewayHandler) Models(c *gin.Context) {
} }
} }
// Return OpenAI models for OpenAI platform groups var groupID *int64
if apiKey != nil && apiKey.Group != nil && apiKey.Group.Platform == "openai" { var platform string
if apiKey != nil && apiKey.Group != nil {
groupID = &apiKey.Group.ID
platform = apiKey.Group.Platform
}
// Get available models from account configurations (without platform filter)
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "")
if len(availableModels) > 0 {
// Build model list from whitelist
models := make([]claude.Model, 0, len(availableModels))
for _, modelID := range availableModels {
models = append(models, claude.Model{
ID: modelID,
Type: "model",
DisplayName: modelID,
CreatedAt: "2024-01-01T00:00:00Z",
})
}
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": models,
})
return
}
// Fallback to default models
if platform == "openai" {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"object": "list", "object": "list",
"data": openai.DefaultModels, "data": openai.DefaultModels,
@@ -333,7 +451,6 @@ func (h *GatewayHandler) Models(c *gin.Context) {
return return
} }
// Default: Claude models
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"object": "list", "object": "list",
"data": claude.DefaultModels, "data": claude.DefaultModels,

View File

@@ -83,6 +83,16 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64
h.concurrencyService.DecrementWaitCount(ctx, userID) h.concurrencyService.DecrementWaitCount(ctx, userID)
} }
// IncrementAccountWaitCount increments the wait count for an account
func (h *ConcurrencyHelper) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
return h.concurrencyService.IncrementAccountWaitCount(ctx, accountID, maxWait)
}
// DecrementAccountWaitCount decrements the wait count for an account
func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
h.concurrencyService.DecrementAccountWaitCount(ctx, accountID)
}
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary. // AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait. // For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun. // streamStarted is updated if streaming response has begun.
@@ -126,7 +136,12 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests. // waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
// streamStarted pointer is updated when streaming begins (for proper error handling by caller). // streamStarted pointer is updated when streaming begins (for proper error handling by caller).
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) { func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait) return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted)
}
// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout.
func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel() defer cancel()
// Determine if ping is needed (streaming + ping format defined) // Determine if ping is needed (streaming + ping format defined)
@@ -200,6 +215,11 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
} }
} }
// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping).
func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted)
}
// nextBackoff 计算下一次退避时间 // nextBackoff 计算下一次退避时间
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应 // 性能优化:使用指数退避 + 随机抖动,避免惊群效应
// current: 当前退避时间 // current: 当前退避时间

View File

@@ -198,13 +198,17 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 3) select account (sticky session based on request body) // 3) select account (sticky session based on request body)
parsedReq, _ := service.ParseGatewayRequest(body) parsedReq, _ := service.ParseGatewayRequest(body)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
sessionKey := sessionHash
if sessionHash != "" {
sessionKey = "gemini:" + sessionHash
}
const maxAccountSwitches = 3 const maxAccountSwitches = 3
switchCount := 0 switchCount := 0
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0 lastFailoverStatus := 0
for { for {
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, modelName, failedAccountIDs) selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs)
if err != nil { if err != nil {
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
@@ -213,12 +217,48 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
handleGeminiFailoverExhausted(c, lastFailoverStatus) handleGeminiFailoverExhausted(c, lastFailoverStatus)
return return
} }
account := selection.Account
// 4) account concurrency slot // 4) account concurrency slot
accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted) accountReleaseFunc := selection.ReleaseFunc
if err != nil { var accountWaitRelease func()
googleError(c, http.StatusTooManyRequests, err.Error()) if !selection.Acquired {
return if selection.WaitPlan == nil {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts")
return
}
canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = geminiConcurrency.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
stream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
} }
// 5) forward (根据平台分流) // 5) forward (根据平台分流)
@@ -231,6 +271,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
} }
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil { if err != nil {
var failoverErr *service.UpstreamFailoverError var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {

View File

@@ -20,6 +20,7 @@ type AdminHandlers struct {
System *admin.SystemHandler System *admin.SystemHandler
Subscription *admin.SubscriptionHandler Subscription *admin.SubscriptionHandler
Usage *admin.UsageHandler Usage *admin.UsageHandler
UserAttribute *admin.UserAttributeHandler
} }
// Handlers contains all HTTP handlers // Handlers contains all HTTP handlers

View File

@@ -146,7 +146,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
for { for {
// Select account supporting the requested model // Select account supporting the requested model
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel) log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
if err != nil { if err != nil {
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err) log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
@@ -156,14 +156,50 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return return
} }
account := selection.Account
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name) log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
// 3. Acquire account concurrency slot // 3. Acquire account concurrency slot
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted) accountReleaseFunc := selection.ReleaseFunc
if err != nil { var accountWaitRelease func()
log.Printf("Account concurrency acquire failed: %v", err) if !selection.Acquired {
h.handleConcurrencyError(c, err, "account", streamStarted) if selection.WaitPlan == nil {
return h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionHash, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
} }
// Forward request // Forward request
@@ -171,6 +207,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
} }
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil { if err != nil {
var failoverErr *service.UpstreamFailoverError var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {

View File

@@ -30,7 +30,6 @@ type ChangePasswordRequest struct {
// UpdateProfileRequest represents the update profile request payload // UpdateProfileRequest represents the update profile request payload
type UpdateProfileRequest struct { type UpdateProfileRequest struct {
Username *string `json:"username"` Username *string `json:"username"`
Wechat *string `json:"wechat"`
} }
// GetProfile handles getting user profile // GetProfile handles getting user profile
@@ -99,7 +98,6 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
svcReq := service.UpdateProfileRequest{ svcReq := service.UpdateProfileRequest{
Username: req.Username, Username: req.Username,
Wechat: req.Wechat,
} }
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq) updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq)
if err != nil { if err != nil {

View File

@@ -23,6 +23,7 @@ func ProvideAdminHandlers(
systemHandler *admin.SystemHandler, systemHandler *admin.SystemHandler,
subscriptionHandler *admin.SubscriptionHandler, subscriptionHandler *admin.SubscriptionHandler,
usageHandler *admin.UsageHandler, usageHandler *admin.UsageHandler,
userAttributeHandler *admin.UserAttributeHandler,
) *AdminHandlers { ) *AdminHandlers {
return &AdminHandlers{ return &AdminHandlers{
Dashboard: dashboardHandler, Dashboard: dashboardHandler,
@@ -39,6 +40,7 @@ func ProvideAdminHandlers(
System: systemHandler, System: systemHandler,
Subscription: subscriptionHandler, Subscription: subscriptionHandler,
Usage: usageHandler, Usage: usageHandler,
UserAttribute: userAttributeHandler,
} }
} }
@@ -107,6 +109,7 @@ var ProviderSet = wire.NewSet(
ProvideSystemHandler, ProvideSystemHandler,
admin.NewSubscriptionHandler, admin.NewSubscriptionHandler,
admin.NewUsageHandler, admin.NewUsageHandler,
admin.NewUserAttributeHandler,
// AdminHandlers and Handlers constructors // AdminHandlers and Handlers constructors
ProvideAdminHandlers, ProvideAdminHandlers,

View File

@@ -57,6 +57,7 @@ var geminiModels = []string{
"gemini-2.5-flash-lite", "gemini-2.5-flash-lite",
"gemini-3-flash", "gemini-3-flash",
"gemini-3-pro-low", "gemini-3-pro-low",
"gemini-3-pro-high",
} }
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
@@ -641,6 +642,37 @@ func testClaudeThinkingWithToolHistory(t *testing.T, model string) {
t.Logf("✅ thinking 模式工具调用测试通过, id=%v", result["id"]) t.Logf("✅ thinking 模式工具调用测试通过, id=%v", result["id"])
} }
// TestClaudeMessagesWithGeminiModel 测试在 Claude 端点使用 Gemini 模型
// 验证:通过 /v1/messages 端点传入 gemini 模型名的场景(含前缀映射)
// 仅在 Antigravity 模式下运行ENDPOINT_PREFIX="/antigravity"
func TestClaudeMessagesWithGeminiModel(t *testing.T) {
if endpointPrefix != "/antigravity" {
t.Skip("仅在 Antigravity 模式下运行")
}
// 测试通过 Claude 端点调用 Gemini 模型
geminiViaClaude := []string{
"gemini-3-flash", // 直接支持
"gemini-3-pro-low", // 直接支持
"gemini-3-pro-high", // 直接支持
"gemini-3-pro", // 前缀映射 -> gemini-3-pro-high
"gemini-3-pro-preview", // 前缀映射 -> gemini-3-pro-high
}
for i, model := range geminiViaClaude {
if i > 0 {
time.Sleep(testInterval)
}
t.Run(model+"_通过Claude端点", func(t *testing.T) {
testClaudeMessage(t, model, false)
})
time.Sleep(testInterval)
t.Run(model+"_通过Claude端点_流式", func(t *testing.T) {
testClaudeMessage(t, model, true)
})
}
}
// TestClaudeMessagesWithNoSignature 测试历史 thinking block 不带 signature 的场景 // TestClaudeMessagesWithNoSignature 测试历史 thinking block 不带 signature 的场景
// 验证Gemini 模型接受没有 signature 的 thinking block // 验证Gemini 模型接受没有 signature 的 thinking block
func TestClaudeMessagesWithNoSignature(t *testing.T) { func TestClaudeMessagesWithNoSignature(t *testing.T) {
@@ -738,3 +770,30 @@ func testClaudeWithNoSignature(t *testing.T, model string) {
} }
t.Logf("✅ 无 signature thinking 处理测试通过, id=%v", result["id"]) t.Logf("✅ 无 signature thinking 处理测试通过, id=%v", result["id"])
} }
// TestGeminiEndpointWithClaudeModel 测试通过 Gemini 端点调用 Claude 模型
// 仅在 Antigravity 模式下运行ENDPOINT_PREFIX="/antigravity"
func TestGeminiEndpointWithClaudeModel(t *testing.T) {
if endpointPrefix != "/antigravity" {
t.Skip("仅在 Antigravity 模式下运行")
}
// 测试通过 Gemini 端点调用 Claude 模型
claudeViaGemini := []string{
"claude-sonnet-4-5",
"claude-opus-4-5-thinking",
}
for i, model := range claudeViaGemini {
if i > 0 {
time.Sleep(testInterval)
}
t.Run(model+"_通过Gemini端点", func(t *testing.T) {
testGeminiGenerate(t, model, false)
})
time.Sleep(testInterval)
t.Run(model+"_通过Gemini端点_流式", func(t *testing.T) {
testGeminiGenerate(t, model, true)
})
}
}

View File

@@ -54,6 +54,9 @@ type CustomToolSpec struct {
InputSchema map[string]any `json:"input_schema"` InputSchema map[string]any `json:"input_schema"`
} }
// ClaudeCustomToolSpec 兼容旧命名MCP custom 工具规格)
type ClaudeCustomToolSpec = CustomToolSpec
// SystemBlock system prompt 数组形式的元素 // SystemBlock system prompt 数组形式的元素
type SystemBlock struct { type SystemBlock struct {
Type string `json:"type"` Type string `json:"type"`

View File

@@ -14,13 +14,16 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
// 用于存储 tool_use id -> name 映射 // 用于存储 tool_use id -> name 映射
toolIDToName := make(map[string]string) toolIDToName := make(map[string]string)
// 检测是否启用 thinking
isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
// 只有 Gemini 模型支持 dummy thought workaround // 只有 Gemini 模型支持 dummy thought workaround
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures // Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
allowDummyThought := strings.HasPrefix(mappedModel, "gemini-") allowDummyThought := strings.HasPrefix(mappedModel, "gemini-")
// 检测是否启用 thinking
requestedThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
// 为避免 Claude 模型的 thought signature/消息块约束导致 400上游要求 thinking 块开头等),
// 非 Gemini 模型默认不启用 thinking除非未来支持完整签名链路
isThinkingEnabled := requestedThinkingEnabled && allowDummyThought
// 1. 构建 contents // 1. 构建 contents
contents, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought) contents, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
if err != nil { if err != nil {
@@ -31,7 +34,15 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model) systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model)
// 3. 构建 generationConfig // 3. 构建 generationConfig
generationConfig := buildGenerationConfig(claudeReq) reqForGen := claudeReq
if requestedThinkingEnabled && !allowDummyThought {
log.Printf("[Warning] Disabling thinking for non-Gemini model in antigravity transform: model=%s", mappedModel)
// shallow copy to avoid mutating caller's request
clone := *claudeReq
clone.Thinking = nil
reqForGen = &clone
}
generationConfig := buildGenerationConfig(reqForGen)
// 4. 构建 tools // 4. 构建 tools
tools := buildTools(claudeReq.Tools) tools := buildTools(claudeReq.Tools)
@@ -148,8 +159,9 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
if !hasThoughtPart && len(parts) > 0 { if !hasThoughtPart && len(parts) > 0 {
// 在开头添加 dummy thinking block // 在开头添加 dummy thinking block
parts = append([]GeminiPart{{ parts = append([]GeminiPart{{
Text: "Thinking...", Text: "Thinking...",
Thought: true, Thought: true,
ThoughtSignature: dummyThoughtSignature,
}}, parts...) }}, parts...)
} }
} }
@@ -171,6 +183,34 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
// 参考: https://ai.google.dev/gemini-api/docs/thought-signatures // 参考: https://ai.google.dev/gemini-api/docs/thought-signatures
const dummyThoughtSignature = "skip_thought_signature_validator" const dummyThoughtSignature = "skip_thought_signature_validator"
// isValidThoughtSignature 验证 thought signature 是否有效
// Claude API 要求 signature 必须是 base64 编码的字符串,长度至少 32 字节
func isValidThoughtSignature(signature string) bool {
// 空字符串无效
if signature == "" {
return false
}
// signature 应该是 base64 编码,长度至少 40 个字符(约 30 字节)
// 参考 Claude API 文档和实际观察到的有效 signature
if len(signature) < 40 {
log.Printf("[Debug] Signature too short: len=%d", len(signature))
return false
}
// 检查是否是有效的 base64 字符
// base64 字符集: A-Z, a-z, 0-9, +, /, =
for i, c := range signature {
if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') &&
(c < '0' || c > '9') && c != '+' && c != '/' && c != '=' {
log.Printf("[Debug] Invalid base64 character at position %d: %c (code=%d)", i, c, c)
return false
}
}
return true
}
// buildParts 构建消息的 parts // buildParts 构建消息的 parts
// allowDummyThought: 只有 Gemini 模型支持 dummy thought signature // allowDummyThought: 只有 Gemini 模型支持 dummy thought signature
func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, error) { func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, error) {
@@ -199,22 +239,30 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
} }
case "thinking": case "thinking":
part := GeminiPart{ if allowDummyThought {
Text: block.Thinking, // Gemini 模型可以使用 dummy signature
Thought: true, parts = append(parts, GeminiPart{
} Text: block.Thinking,
// 保留原有 signatureClaude 模型需要有效的 signature Thought: true,
if block.Signature != "" { ThoughtSignature: dummyThoughtSignature,
part.ThoughtSignature = block.Signature })
} else if !allowDummyThought {
// Claude 模型需要有效 signature跳过无 signature 的 thinking block
log.Printf("Warning: skipping thinking block without signature for Claude model")
continue continue
} else {
// Gemini 模型使用 dummy signature
part.ThoughtSignature = dummyThoughtSignature
} }
parts = append(parts, part)
// Claude 模型:仅在提供有效 signature 时保留 thinking block否则跳过以避免上游校验失败。
signature := strings.TrimSpace(block.Signature)
if signature == "" || signature == dummyThoughtSignature {
log.Printf("[Warning] Skipping thinking block for Claude model (missing or dummy signature)")
continue
}
if !isValidThoughtSignature(signature) {
log.Printf("[Debug] Thinking signature may be invalid (passing through anyway): len=%d", len(signature))
}
parts = append(parts, GeminiPart{
Text: block.Thinking,
Thought: true,
ThoughtSignature: signature,
})
case "image": case "image":
if block.Source != nil && block.Source.Type == "base64" { if block.Source != nil && block.Source.Type == "base64" {
@@ -239,10 +287,9 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
ID: block.ID, ID: block.ID,
}, },
} }
// 保留原有 signature或对 Gemini 模型使用 dummy signature // 只有 Gemini 模型使用 dummy signature
if block.Signature != "" { // Claude 模型不设置 signature避免验证问题
part.ThoughtSignature = block.Signature if allowDummyThought {
} else if allowDummyThought {
part.ThoughtSignature = dummyThoughtSignature part.ThoughtSignature = dummyThoughtSignature
} }
parts = append(parts, part) parts = append(parts, part)
@@ -386,9 +433,9 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
// 普通工具 // 普通工具
var funcDecls []GeminiFunctionDecl var funcDecls []GeminiFunctionDecl
for _, tool := range tools { for i, tool := range tools {
// 跳过无效工具名称 // 跳过无效工具名称
if tool.Name == "" { if strings.TrimSpace(tool.Name) == "" {
log.Printf("Warning: skipping tool with empty name") log.Printf("Warning: skipping tool with empty name")
continue continue
} }
@@ -397,10 +444,18 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
var inputSchema map[string]any var inputSchema map[string]any
// 检查是否为 custom 类型工具 (MCP) // 检查是否为 custom 类型工具 (MCP)
if tool.Type == "custom" && tool.Custom != nil { if tool.Type == "custom" {
// Custom 格式: 从 custom 字段获取 description 和 input_schema if tool.Custom == nil || tool.Custom.InputSchema == nil {
log.Printf("[Warning] Skipping invalid custom tool '%s': missing custom spec or input_schema", tool.Name)
continue
}
description = tool.Custom.Description description = tool.Custom.Description
inputSchema = tool.Custom.InputSchema inputSchema = tool.Custom.InputSchema
// 调试日志:记录 custom 工具的 schema
if schemaJSON, err := json.Marshal(inputSchema); err == nil {
log.Printf("[Debug] Tool[%d] '%s' (custom) original schema: %s", i, tool.Name, string(schemaJSON))
}
} else { } else {
// 标准格式: 从顶层字段获取 // 标准格式: 从顶层字段获取
description = tool.Description description = tool.Description
@@ -409,7 +464,6 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
// 清理 JSON Schema // 清理 JSON Schema
params := cleanJSONSchema(inputSchema) params := cleanJSONSchema(inputSchema)
// 为 nil schema 提供默认值 // 为 nil schema 提供默认值
if params == nil { if params == nil {
params = map[string]any{ params = map[string]any{
@@ -418,6 +472,11 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
} }
} }
// 调试日志:记录清理后的 schema
if paramsJSON, err := json.Marshal(params); err == nil {
log.Printf("[Debug] Tool[%d] '%s' cleaned schema: %s", i, tool.Name, string(paramsJSON))
}
funcDecls = append(funcDecls, GeminiFunctionDecl{ funcDecls = append(funcDecls, GeminiFunctionDecl{
Name: tool.Name, Name: tool.Name,
Description: description, Description: description,
@@ -479,31 +538,64 @@ func cleanJSONSchema(schema map[string]any) map[string]any {
} }
// excludedSchemaKeys 不支持的 schema 字段 // excludedSchemaKeys 不支持的 schema 字段
// 基于 Claude API (Vertex AI) 的实际支持情况
// 支持: type, description, enum, properties, required, additionalProperties, items
// 不支持: minItems, maxItems, minLength, maxLength, pattern, minimum, maximum 等验证字段
var excludedSchemaKeys = map[string]bool{ var excludedSchemaKeys = map[string]bool{
"$schema": true, // 元 schema 字段
"$id": true, "$schema": true,
"$ref": true, "$id": true,
"additionalProperties": true, "$ref": true,
"minLength": true,
"maxLength": true, // 字符串验证Gemini 不支持)
"minItems": true, "minLength": true,
"maxItems": true, "maxLength": true,
"uniqueItems": true, "pattern": true,
"minimum": true,
"maximum": true, // 数字验证Claude API 通过 Vertex AI 不支持这些字段)
"exclusiveMinimum": true, "minimum": true,
"exclusiveMaximum": true, "maximum": true,
"pattern": true, "exclusiveMinimum": true,
"format": true, "exclusiveMaximum": true,
"default": true, "multipleOf": true,
"strict": true,
"const": true, // 数组验证Claude API 通过 Vertex AI 不支持这些字段)
"examples": true, "uniqueItems": true,
"deprecated": true, "minItems": true,
"readOnly": true, "maxItems": true,
"writeOnly": true,
"contentMediaType": true, // 组合 schemaGemini 不支持)
"contentEncoding": true, "oneOf": true,
"anyOf": true,
"allOf": true,
"not": true,
"if": true,
"then": true,
"else": true,
"$defs": true,
"definitions": true,
// 对象验证(仅保留 properties/required/additionalProperties
"minProperties": true,
"maxProperties": true,
"patternProperties": true,
"propertyNames": true,
"dependencies": true,
"dependentSchemas": true,
"dependentRequired": true,
// 其他不支持的字段
"default": true,
"const": true,
"examples": true,
"deprecated": true,
"readOnly": true,
"writeOnly": true,
"contentMediaType": true,
"contentEncoding": true,
// Claude 特有字段
"strict": true,
} }
// cleanSchemaValue 递归清理 schema 值 // cleanSchemaValue 递归清理 schema 值
@@ -523,6 +615,31 @@ func cleanSchemaValue(value any) any {
continue continue
} }
// 特殊处理 format 字段:只保留 Gemini 支持的 format 值
if k == "format" {
if formatStr, ok := val.(string); ok {
// Gemini 只支持 date-time, date, time
if formatStr == "date-time" || formatStr == "date" || formatStr == "time" {
result[k] = val
}
// 其他 format 值直接跳过
}
continue
}
// 特殊处理 additionalPropertiesClaude API 只支持布尔值,不支持 schema 对象
if k == "additionalProperties" {
if boolVal, ok := val.(bool); ok {
result[k] = boolVal
log.Printf("[Debug] additionalProperties is bool: %v", boolVal)
} else {
// 如果是 schema 对象,转换为 false更安全的默认值
result[k] = false
log.Printf("[Debug] additionalProperties is not bool (type: %T), converting to false", val)
}
continue
}
// 递归清理所有值 // 递归清理所有值
result[k] = cleanSchemaValue(val) result[k] = cleanSchemaValue(val)
} }

View File

@@ -0,0 +1,179 @@
package antigravity
import (
"encoding/json"
"testing"
)
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
tests := []struct {
name string
content string
allowDummyThought bool
expectedParts int
description string
}{
{
name: "Claude model - skip thinking block without signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`,
allowDummyThought: false,
expectedParts: 2, // 只有两个text block
description: "Claude模型应该跳过无signature的thinking block",
},
{
name: "Claude model - keep thinking block with signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": "valid_sig"},
{"type": "text", "text": "World"}
]`,
allowDummyThought: false,
expectedParts: 3, // 三个block都保留
description: "Claude模型应该保留有signature的thinking block",
},
{
name: "Gemini model - use dummy signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`,
allowDummyThought: true,
expectedParts: 3, // 三个block都保留thinking使用dummy signature
description: "Gemini模型应该为无signature的thinking block使用dummy signature",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
toolIDToName := make(map[string]string)
parts, err := buildParts(json.RawMessage(tt.content), toolIDToName, tt.allowDummyThought)
if err != nil {
t.Fatalf("buildParts() error = %v", err)
}
if len(parts) != tt.expectedParts {
t.Errorf("%s: got %d parts, want %d parts", tt.description, len(parts), tt.expectedParts)
}
})
}
}
// TestBuildTools_CustomTypeTools 测试custom类型工具转换
func TestBuildTools_CustomTypeTools(t *testing.T) {
tests := []struct {
name string
tools []ClaudeTool
expectedLen int
description string
}{
{
name: "Standard tool format",
tools: []ClaudeTool{
{
Name: "get_weather",
Description: "Get weather information",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"location": map[string]any{"type": "string"},
},
},
},
},
expectedLen: 1,
description: "标准工具格式应该正常转换",
},
{
name: "Custom type tool (MCP format)",
tools: []ClaudeTool{
{
Type: "custom",
Name: "mcp_tool",
Custom: &CustomToolSpec{
Description: "MCP tool description",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"param": map[string]any{"type": "string"},
},
},
},
},
},
expectedLen: 1,
description: "Custom类型工具应该从Custom字段读取description和input_schema",
},
{
name: "Mixed standard and custom tools",
tools: []ClaudeTool{
{
Name: "standard_tool",
Description: "Standard tool",
InputSchema: map[string]any{"type": "object"},
},
{
Type: "custom",
Name: "custom_tool",
Custom: &CustomToolSpec{
Description: "Custom tool",
InputSchema: map[string]any{"type": "object"},
},
},
},
expectedLen: 1, // 返回一个GeminiToolDeclaration包含2个function declarations
description: "混合标准和custom工具应该都能正确转换",
},
{
name: "Invalid custom tool - nil Custom field",
tools: []ClaudeTool{
{
Type: "custom",
Name: "invalid_custom",
// Custom 为 nil
},
},
expectedLen: 0, // 应该被跳过
description: "Custom字段为nil的custom工具应该被跳过",
},
{
name: "Invalid custom tool - nil InputSchema",
tools: []ClaudeTool{
{
Type: "custom",
Name: "invalid_custom",
Custom: &CustomToolSpec{
Description: "Invalid",
// InputSchema 为 nil
},
},
},
expectedLen: 0, // 应该被跳过
description: "InputSchema为nil的custom工具应该被跳过",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := buildTools(tt.tools)
if len(result) != tt.expectedLen {
t.Errorf("%s: got %d tool declarations, want %d", tt.description, len(result), tt.expectedLen)
}
// 验证function declarations存在
if len(result) > 0 && result[0].FunctionDeclarations != nil {
if len(result[0].FunctionDeclarations) != len(tt.tools) {
t.Errorf("%s: got %d function declarations, want %d",
tt.description, len(result[0].FunctionDeclarations), len(tt.tools))
}
}
})
}
}

View File

@@ -16,6 +16,12 @@ const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleav
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header不需要 claude-code beta // HaikuBetaHeader Haiku 模型使用的 anthropic-beta header不需要 claude-code beta
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
// ApiKeyBetaHeader API-key 账号建议使用的 anthropic-beta header不包含 oauth
const ApiKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// ApiKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header不包含 oauth / claude-code
const ApiKeyHaikuBetaHeader = BetaInterleavedThinking
// Claude Code 客户端默认请求头 // Claude Code 客户端默认请求头
var DefaultHeaders = map[string]string{ var DefaultHeaders = map[string]string{
"User-Agent": "claude-cli/2.0.62 (external, cli)", "User-Agent": "claude-cli/2.0.62 (external, cli)",

View File

@@ -0,0 +1,157 @@
package geminicli
import (
"context"
"encoding/json"
"fmt"
"math/rand"
"net/http"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
)
// DriveStorageInfo represents Google Drive storage quota information
type DriveStorageInfo struct {
Limit int64 `json:"limit"` // Storage limit in bytes
Usage int64 `json:"usage"` // Current usage in bytes
}
// DriveClient interface for Google Drive API operations
type DriveClient interface {
GetStorageQuota(ctx context.Context, accessToken, proxyURL string) (*DriveStorageInfo, error)
}
type driveClient struct{}
// NewDriveClient creates a new Drive API client
func NewDriveClient() DriveClient {
return &driveClient{}
}
// GetStorageQuota fetches storage quota from Google Drive API
func (c *driveClient) GetStorageQuota(ctx context.Context, accessToken, proxyURL string) (*DriveStorageInfo, error) {
const driveAPIURL = "https://www.googleapis.com/drive/v3/about?fields=storageQuota"
req, err := http.NewRequestWithContext(ctx, "GET", driveAPIURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
// Get HTTP client with proxy support
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 10 * time.Second,
})
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client: %w", err)
}
sleepWithContext := func(d time.Duration) error {
timer := time.NewTimer(d)
defer timer.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
return nil
}
}
// Retry logic with exponential backoff (+ jitter) for rate limits and transient failures
var resp *http.Response
maxRetries := 3
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
for attempt := 0; attempt < maxRetries; attempt++ {
if ctx.Err() != nil {
return nil, fmt.Errorf("request cancelled: %w", ctx.Err())
}
resp, err = client.Do(req)
if err != nil {
// Network error retry
if attempt < maxRetries-1 {
backoff := time.Duration(1<<uint(attempt)) * time.Second
jitter := time.Duration(rng.Intn(1000)) * time.Millisecond
if err := sleepWithContext(backoff + jitter); err != nil {
return nil, fmt.Errorf("request cancelled: %w", err)
}
continue
}
return nil, fmt.Errorf("network error after %d attempts: %w", maxRetries, err)
}
// Success
if resp.StatusCode == http.StatusOK {
break
}
// Retry 429, 500, 502, 503 with exponential backoff + jitter
if (resp.StatusCode == http.StatusTooManyRequests ||
resp.StatusCode == http.StatusInternalServerError ||
resp.StatusCode == http.StatusBadGateway ||
resp.StatusCode == http.StatusServiceUnavailable) && attempt < maxRetries-1 {
if err := func() error {
defer func() { _ = resp.Body.Close() }()
backoff := time.Duration(1<<uint(attempt)) * time.Second
jitter := time.Duration(rng.Intn(1000)) * time.Millisecond
return sleepWithContext(backoff + jitter)
}(); err != nil {
return nil, fmt.Errorf("request cancelled: %w", err)
}
continue
}
break
}
if resp == nil {
return nil, fmt.Errorf("request failed: no response received")
}
if resp.StatusCode != http.StatusOK {
_ = resp.Body.Close()
statusText := http.StatusText(resp.StatusCode)
if statusText == "" {
statusText = resp.Status
}
fmt.Printf("[DriveClient] Drive API error: status=%d, msg=%s\n", resp.StatusCode, statusText)
// 只返回通用错误
return nil, fmt.Errorf("drive API error: status %d", resp.StatusCode)
}
defer func() { _ = resp.Body.Close() }()
// Parse response
var result struct {
StorageQuota struct {
Limit string `json:"limit"` // Can be string or number
Usage string `json:"usage"`
} `json:"storageQuota"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
// Parse limit and usage (handle both string and number formats)
var limit, usage int64
if result.StorageQuota.Limit != "" {
if val, err := strconv.ParseInt(result.StorageQuota.Limit, 10, 64); err == nil {
limit = val
}
}
if result.StorageQuota.Usage != "" {
if val, err := strconv.ParseInt(result.StorageQuota.Usage, 10, 64); err == nil {
usage = val
}
}
return &DriveStorageInfo{
Limit: limit,
Usage: usage,
}, nil
}

View File

@@ -0,0 +1,18 @@
package geminicli
import "testing"
func TestDriveStorageInfo(t *testing.T) {
// 测试 DriveStorageInfo 结构体
info := &DriveStorageInfo{
Limit: 100 * 1024 * 1024 * 1024, // 100GB
Usage: 50 * 1024 * 1024 * 1024, // 50GB
}
if info.Limit != 100*1024*1024*1024 {
t.Errorf("Expected limit 100GB, got %d", info.Limit)
}
if info.Usage != 50*1024*1024*1024 {
t.Errorf("Expected usage 50GB, got %d", info.Usage)
}
}

View File

@@ -124,6 +124,90 @@ func (r *accountRepository) GetByID(ctx context.Context, id int64) (*service.Acc
return &accounts[0], nil return &accounts[0], nil
} }
func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
if len(ids) == 0 {
return []*service.Account{}, nil
}
// De-duplicate while preserving order of first occurrence.
uniqueIDs := make([]int64, 0, len(ids))
seen := make(map[int64]struct{}, len(ids))
for _, id := range ids {
if id <= 0 {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
uniqueIDs = append(uniqueIDs, id)
}
if len(uniqueIDs) == 0 {
return []*service.Account{}, nil
}
entAccounts, err := r.client.Account.
Query().
Where(dbaccount.IDIn(uniqueIDs...)).
WithProxy().
All(ctx)
if err != nil {
return nil, err
}
if len(entAccounts) == 0 {
return []*service.Account{}, nil
}
accountIDs := make([]int64, 0, len(entAccounts))
entByID := make(map[int64]*dbent.Account, len(entAccounts))
for _, acc := range entAccounts {
entByID[acc.ID] = acc
accountIDs = append(accountIDs, acc.ID)
}
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
if err != nil {
return nil, err
}
outByID := make(map[int64]*service.Account, len(entAccounts))
for _, entAcc := range entAccounts {
out := accountEntityToService(entAcc)
if out == nil {
continue
}
// Prefer the preloaded proxy edge when available.
if entAcc.Edges.Proxy != nil {
out.Proxy = proxyEntityToService(entAcc.Edges.Proxy)
}
if groups, ok := groupsByAccount[entAcc.ID]; ok {
out.Groups = groups
}
if groupIDs, ok := groupIDsByAccount[entAcc.ID]; ok {
out.GroupIDs = groupIDs
}
if ags, ok := accountGroupsByAccount[entAcc.ID]; ok {
out.AccountGroups = ags
}
outByID[entAcc.ID] = out
}
// Preserve input order (first occurrence), and ignore missing IDs.
out := make([]*service.Account, 0, len(uniqueIDs))
for _, id := range uniqueIDs {
if _, ok := entByID[id]; !ok {
continue
}
if acc, ok := outByID[id]; ok && acc != nil {
out = append(out, acc)
}
}
return out, nil
}
// ExistsByID 检查指定 ID 的账号是否存在。 // ExistsByID 检查指定 ID 的账号是否存在。
// 相比 GetByID此方法性能更优因为 // 相比 GetByID此方法性能更优因为
// - 使用 Exist() 方法生成 SELECT EXISTS 查询,只返回布尔值 // - 使用 Exist() 方法生成 SELECT EXISTS 查询,只返回布尔值

View File

@@ -294,7 +294,6 @@ func userEntityToService(u *dbent.User) *service.User {
ID: u.ID, ID: u.ID,
Email: u.Email, Email: u.Email,
Username: u.Username, Username: u.Username,
Wechat: u.Wechat,
Notes: u.Notes, Notes: u.Notes,
PasswordHash: u.PasswordHash, PasswordHash: u.PasswordHash,
Role: u.Role, Role: u.Role,

View File

@@ -2,7 +2,9 @@ package repository
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"strconv"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
@@ -27,6 +29,8 @@ const (
userSlotKeyPrefix = "concurrency:user:" userSlotKeyPrefix = "concurrency:user:"
// 等待队列计数器格式: concurrency:wait:{userID} // 等待队列计数器格式: concurrency:wait:{userID}
waitQueueKeyPrefix = "concurrency:wait:" waitQueueKeyPrefix = "concurrency:wait:"
// 账号级等待队列计数器格式: wait:account:{accountID}
accountWaitKeyPrefix = "wait:account:"
// 默认槽位过期时间(分钟),可通过配置覆盖 // 默认槽位过期时间(分钟),可通过配置覆盖
defaultSlotTTLMinutes = 15 defaultSlotTTLMinutes = 15
@@ -112,33 +116,112 @@ var (
redis.call('EXPIRE', KEYS[1], ARGV[2]) redis.call('EXPIRE', KEYS[1], ARGV[2])
end end
return 1 return 1
`) `)
// incrementAccountWaitScript - account-level wait queue count
incrementAccountWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
end
if current >= tonumber(ARGV[1]) then
return 0
end
local newVal = redis.call('INCR', KEYS[1])
-- Only set TTL on first creation to avoid refreshing zombie data
if newVal == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return 1
`)
// decrementWaitScript - same as before // decrementWaitScript - same as before
decrementWaitScript = redis.NewScript(` decrementWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1]) local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1]) redis.call('DECR', KEYS[1])
end end
return 1 return 1
`) `)
// getAccountsLoadBatchScript - batch load query (read-only)
// ARGV[1] = slot TTL (seconds, retained for compatibility)
// ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ...
getAccountsLoadBatchScript = redis.NewScript(`
local result = {}
local i = 2
while i <= #ARGV do
local accountID = ARGV[i]
local maxConcurrency = tonumber(ARGV[i + 1])
local slotKey = 'concurrency:account:' .. accountID
local currentConcurrency = redis.call('ZCARD', slotKey)
local waitKey = 'wait:account:' .. accountID
local waitingCount = redis.call('GET', waitKey)
if waitingCount == false then
waitingCount = 0
else
waitingCount = tonumber(waitingCount)
end
local loadRate = 0
if maxConcurrency > 0 then
loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency)
end
table.insert(result, accountID)
table.insert(result, currentConcurrency)
table.insert(result, waitingCount)
table.insert(result, loadRate)
i = i + 2
end
return result
`)
// cleanupExpiredSlotsScript - remove expired slots
// KEYS[1] = concurrency:account:{accountID}
// ARGV[1] = TTL (seconds)
cleanupExpiredSlotsScript = redis.NewScript(`
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
`)
) )
type concurrencyCache struct { type concurrencyCache struct {
rdb *redis.Client rdb *redis.Client
slotTTLSeconds int // 槽位过期时间(秒) slotTTLSeconds int // 槽位过期时间(秒)
waitQueueTTLSeconds int // 等待队列过期时间(秒)
} }
// NewConcurrencyCache 创建并发控制缓存 // NewConcurrencyCache 创建并发控制缓存
// slotTTLMinutes: 槽位过期时间分钟0 或负数使用默认值 15 分钟 // slotTTLMinutes: 槽位过期时间分钟0 或负数使用默认值 15 分钟
func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int) service.ConcurrencyCache { // waitQueueTTLSeconds: 等待队列过期时间0 或负数使用 slot TTL
func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int, waitQueueTTLSeconds int) service.ConcurrencyCache {
if slotTTLMinutes <= 0 { if slotTTLMinutes <= 0 {
slotTTLMinutes = defaultSlotTTLMinutes slotTTLMinutes = defaultSlotTTLMinutes
} }
if waitQueueTTLSeconds <= 0 {
waitQueueTTLSeconds = slotTTLMinutes * 60
}
return &concurrencyCache{ return &concurrencyCache{
rdb: rdb, rdb: rdb,
slotTTLSeconds: slotTTLMinutes * 60, slotTTLSeconds: slotTTLMinutes * 60,
waitQueueTTLSeconds: waitQueueTTLSeconds,
} }
} }
@@ -155,6 +238,10 @@ func waitQueueKey(userID int64) string {
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID) return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
} }
func accountWaitKey(accountID int64) string {
return fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
}
// Account slot operations // Account slot operations
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
@@ -225,3 +312,75 @@ func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result() _, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err return err
} }
// Account wait queue operations
func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
key := accountWaitKey(accountID)
result, err := incrementAccountWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
key := accountWaitKey(accountID)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}
func (c *concurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
key := accountWaitKey(accountID)
val, err := c.rdb.Get(ctx, key).Int()
if err != nil && !errors.Is(err, redis.Nil) {
return 0, err
}
if errors.Is(err, redis.Nil) {
return 0, nil
}
return val, nil
}
func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
if len(accounts) == 0 {
return map[int64]*service.AccountLoadInfo{}, nil
}
args := []any{c.slotTTLSeconds}
for _, acc := range accounts {
args = append(args, acc.ID, acc.MaxConcurrency)
}
result, err := getAccountsLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice()
if err != nil {
return nil, err
}
loadMap := make(map[int64]*service.AccountLoadInfo)
for i := 0; i < len(result); i += 4 {
if i+3 >= len(result) {
break
}
accountID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64)
currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1]))
waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2]))
loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3]))
loadMap[accountID] = &service.AccountLoadInfo{
AccountID: accountID,
CurrentConcurrency: currentConcurrency,
WaitingCount: waitingCount,
LoadRate: loadRate,
}
}
return loadMap, nil
}
func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
key := accountSlotKey(accountID)
_, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result()
return err
}

View File

@@ -22,7 +22,7 @@ func BenchmarkAccountConcurrency(b *testing.B) {
_ = rdb.Close() _ = rdb.Close()
}() }()
cache, _ := NewConcurrencyCache(rdb, benchSlotTTLMinutes).(*concurrencyCache) cache, _ := NewConcurrencyCache(rdb, benchSlotTTLMinutes, int(benchSlotTTL.Seconds())).(*concurrencyCache)
ctx := context.Background() ctx := context.Background()
for _, size := range []int{10, 100, 1000} { for _, size := range []int{10, 100, 1000} {

View File

@@ -27,7 +27,7 @@ type ConcurrencyCacheSuite struct {
func (s *ConcurrencyCacheSuite) SetupTest() { func (s *ConcurrencyCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest() s.IntegrationRedisSuite.SetupTest()
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes) s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds()))
} }
func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() { func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
@@ -218,6 +218,48 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count") require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
} }
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {
accountID := int64(30)
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
ok, err := s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 1")
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 2")
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 3")
require.False(s.T(), ok, "expected account wait increment over max to fail")
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
require.NoError(s.T(), err, "TTL account waitKey")
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.Equal(s.T(), 1, val, "expected account wait count 1")
}
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() {
accountID := int64(301)
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty")
}
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() { func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
// When no slots exist, GetAccountConcurrency should return 0 // When no slots exist, GetAccountConcurrency should return 0
cur, err := s.cache.GetAccountConcurrency(s.ctx, 999) cur, err := s.cache.GetAccountConcurrency(s.ctx, 999)
@@ -232,6 +274,139 @@ func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() {
require.Equal(s.T(), 0, cur) require.Equal(s.T(), 0, cur)
} }
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch() {
s.T().Skip("TODO: Fix this test - CurrentConcurrency returns 0 instead of expected value in CI")
// Setup: Create accounts with different load states
account1 := int64(100)
account2 := int64(101)
account3 := int64(102)
// Account 1: 2/3 slots used, 1 waiting
ok, err := s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, account1, 5)
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Account 2: 1/2 slots used, 0 waiting
ok, err = s.cache.AcquireAccountSlot(s.ctx, account2, 2, "req3")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Account 3: 0/1 slots used, 0 waiting (idle)
// Query batch load
accounts := []service.AccountWithConcurrency{
{ID: account1, MaxConcurrency: 3},
{ID: account2, MaxConcurrency: 2},
{ID: account3, MaxConcurrency: 1},
}
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, accounts)
require.NoError(s.T(), err)
require.Len(s.T(), loadMap, 3)
// Verify account1: (2 + 1) / 3 = 100%
load1 := loadMap[account1]
require.NotNil(s.T(), load1)
require.Equal(s.T(), account1, load1.AccountID)
require.Equal(s.T(), 2, load1.CurrentConcurrency)
require.Equal(s.T(), 1, load1.WaitingCount)
require.Equal(s.T(), 100, load1.LoadRate)
// Verify account2: (1 + 0) / 2 = 50%
load2 := loadMap[account2]
require.NotNil(s.T(), load2)
require.Equal(s.T(), account2, load2.AccountID)
require.Equal(s.T(), 1, load2.CurrentConcurrency)
require.Equal(s.T(), 0, load2.WaitingCount)
require.Equal(s.T(), 50, load2.LoadRate)
// Verify account3: (0 + 0) / 1 = 0%
load3 := loadMap[account3]
require.NotNil(s.T(), load3)
require.Equal(s.T(), account3, load3.AccountID)
require.Equal(s.T(), 0, load3.CurrentConcurrency)
require.Equal(s.T(), 0, load3.WaitingCount)
require.Equal(s.T(), 0, load3.LoadRate)
}
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch_Empty() {
// Test with empty account list
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, []service.AccountWithConcurrency{})
require.NoError(s.T(), err)
require.Empty(s.T(), loadMap)
}
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots() {
accountID := int64(200)
slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
// Acquire 3 slots
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req3")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Verify 3 slots exist
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 3, cur)
// Manually set old timestamps for req1 and req2 (simulate expired slots)
now := time.Now().Unix()
expiredTime := now - int64(testSlotTTL.Seconds()) - 10 // 10 seconds past TTL
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req1"}).Err()
require.NoError(s.T(), err)
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req2"}).Err()
require.NoError(s.T(), err)
// Run cleanup
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
require.NoError(s.T(), err)
// Verify only 1 slot remains (req3)
cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 1, cur)
// Verify req3 still exists
members, err := s.rdb.ZRange(s.ctx, slotKey, 0, -1).Result()
require.NoError(s.T(), err)
require.Len(s.T(), members, 1)
require.Equal(s.T(), "req3", members[0])
}
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() {
accountID := int64(201)
// Acquire 2 fresh slots
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Run cleanup (should not remove anything)
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
require.NoError(s.T(), err)
// Verify both slots still exist
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 2, cur)
}
func TestConcurrencyCacheSuite(t *testing.T) { func TestConcurrencyCacheSuite(t *testing.T) {
suite.Run(t, new(ConcurrencyCacheSuite)) suite.Run(t, new(ConcurrencyCacheSuite))
} }

View File

@@ -40,7 +40,6 @@ func mustCreateUser(t *testing.T, client *dbent.Client, u *service.User) *servic
SetBalance(u.Balance). SetBalance(u.Balance).
SetConcurrency(u.Concurrency). SetConcurrency(u.Concurrency).
SetUsername(u.Username). SetUsername(u.Username).
SetWechat(u.Wechat).
SetNotes(u.Notes) SetNotes(u.Notes)
if !u.CreatedAt.IsZero() { if !u.CreatedAt.IsZero() {
create.SetCreatedAt(u.CreatedAt) create.SetCreatedAt(u.CreatedAt)

View File

@@ -127,7 +127,15 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
if existing != checksum { if existing != checksum {
// 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。 // 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。
// 正确的做法是创建新的迁移文件来进行变更。 // 正确的做法是创建新的迁移文件来进行变更。
return fmt.Errorf("migration %s checksum mismatch (db=%s file=%s)", name, existing, checksum) return fmt.Errorf(
"migration %s checksum mismatch (db=%s file=%s)\n"+
"This means the migration file was modified after being applied to the database.\n"+
"Solutions:\n"+
" 1. Revert to original: git log --oneline -- migrations/%s && git checkout <commit> -- migrations/%s\n"+
" 2. For new changes, create a new migration file instead of modifying existing ones\n"+
"Note: Modifying applied migrations breaks the immutability principle and can cause inconsistencies across environments",
name, existing, checksum, name, name,
)
} }
continue // 迁移已应用且校验和匹配,跳过 continue // 迁移已应用且校验和匹配,跳过
} }

View File

@@ -23,7 +23,6 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
// users: columns required by repository queries // users: columns required by repository queries
requireColumn(t, tx, "users", "username", "character varying", 100, false) requireColumn(t, tx, "users", "username", "character varying", 100, false)
requireColumn(t, tx, "users", "wechat", "character varying", 100, false)
requireColumn(t, tx, "users", "notes", "text", 0, false) requireColumn(t, tx, "users", "notes", "text", 0, false)
// accounts: schedulable and rate-limit fields // accounts: schedulable and rate-limit fields

View File

@@ -0,0 +1,385 @@
package repository
import (
"context"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// UserAttributeDefinitionRepository implementation
type userAttributeDefinitionRepository struct {
client *dbent.Client
}
// NewUserAttributeDefinitionRepository creates a new repository instance
func NewUserAttributeDefinitionRepository(client *dbent.Client) service.UserAttributeDefinitionRepository {
return &userAttributeDefinitionRepository{client: client}
}
func (r *userAttributeDefinitionRepository) Create(ctx context.Context, def *service.UserAttributeDefinition) error {
client := clientFromContext(ctx, r.client)
created, err := client.UserAttributeDefinition.Create().
SetKey(def.Key).
SetName(def.Name).
SetDescription(def.Description).
SetType(string(def.Type)).
SetOptions(toEntOptions(def.Options)).
SetRequired(def.Required).
SetValidation(toEntValidation(def.Validation)).
SetPlaceholder(def.Placeholder).
SetEnabled(def.Enabled).
Save(ctx)
if err != nil {
return translatePersistenceError(err, nil, service.ErrAttributeKeyExists)
}
def.ID = created.ID
def.DisplayOrder = created.DisplayOrder
def.CreatedAt = created.CreatedAt
def.UpdatedAt = created.UpdatedAt
return nil
}
func (r *userAttributeDefinitionRepository) GetByID(ctx context.Context, id int64) (*service.UserAttributeDefinition, error) {
client := clientFromContext(ctx, r.client)
e, err := client.UserAttributeDefinition.Query().
Where(userattributedefinition.IDEQ(id)).
Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, nil)
}
return defEntityToService(e), nil
}
func (r *userAttributeDefinitionRepository) GetByKey(ctx context.Context, key string) (*service.UserAttributeDefinition, error) {
client := clientFromContext(ctx, r.client)
e, err := client.UserAttributeDefinition.Query().
Where(userattributedefinition.KeyEQ(key)).
Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, nil)
}
return defEntityToService(e), nil
}
func (r *userAttributeDefinitionRepository) Update(ctx context.Context, def *service.UserAttributeDefinition) error {
client := clientFromContext(ctx, r.client)
updated, err := client.UserAttributeDefinition.UpdateOneID(def.ID).
SetName(def.Name).
SetDescription(def.Description).
SetType(string(def.Type)).
SetOptions(toEntOptions(def.Options)).
SetRequired(def.Required).
SetValidation(toEntValidation(def.Validation)).
SetPlaceholder(def.Placeholder).
SetDisplayOrder(def.DisplayOrder).
SetEnabled(def.Enabled).
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, service.ErrAttributeKeyExists)
}
def.UpdatedAt = updated.UpdatedAt
return nil
}
func (r *userAttributeDefinitionRepository) Delete(ctx context.Context, id int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.UserAttributeDefinition.Delete().
Where(userattributedefinition.IDEQ(id)).
Exec(ctx)
return translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, nil)
}
func (r *userAttributeDefinitionRepository) List(ctx context.Context, enabledOnly bool) ([]service.UserAttributeDefinition, error) {
client := clientFromContext(ctx, r.client)
q := client.UserAttributeDefinition.Query()
if enabledOnly {
q = q.Where(userattributedefinition.EnabledEQ(true))
}
entities, err := q.Order(dbent.Asc(userattributedefinition.FieldDisplayOrder)).All(ctx)
if err != nil {
return nil, err
}
result := make([]service.UserAttributeDefinition, 0, len(entities))
for _, e := range entities {
result = append(result, *defEntityToService(e))
}
return result, nil
}
func (r *userAttributeDefinitionRepository) UpdateDisplayOrders(ctx context.Context, orders map[int64]int) error {
tx, err := r.client.Tx(ctx)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
for id, order := range orders {
if _, err := tx.UserAttributeDefinition.UpdateOneID(id).
SetDisplayOrder(order).
Save(ctx); err != nil {
return translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, nil)
}
}
return tx.Commit()
}
func (r *userAttributeDefinitionRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
client := clientFromContext(ctx, r.client)
return client.UserAttributeDefinition.Query().
Where(userattributedefinition.KeyEQ(key)).
Exist(ctx)
}
// UserAttributeValueRepository implementation
type userAttributeValueRepository struct {
client *dbent.Client
}
// NewUserAttributeValueRepository creates a new repository instance
func NewUserAttributeValueRepository(client *dbent.Client) service.UserAttributeValueRepository {
return &userAttributeValueRepository{client: client}
}
func (r *userAttributeValueRepository) GetByUserID(ctx context.Context, userID int64) ([]service.UserAttributeValue, error) {
client := clientFromContext(ctx, r.client)
entities, err := client.UserAttributeValue.Query().
Where(userattributevalue.UserIDEQ(userID)).
All(ctx)
if err != nil {
return nil, err
}
result := make([]service.UserAttributeValue, 0, len(entities))
for _, e := range entities {
result = append(result, service.UserAttributeValue{
ID: e.ID,
UserID: e.UserID,
AttributeID: e.AttributeID,
Value: e.Value,
CreatedAt: e.CreatedAt,
UpdatedAt: e.UpdatedAt,
})
}
return result, nil
}
func (r *userAttributeValueRepository) GetByUserIDs(ctx context.Context, userIDs []int64) ([]service.UserAttributeValue, error) {
if len(userIDs) == 0 {
return []service.UserAttributeValue{}, nil
}
client := clientFromContext(ctx, r.client)
entities, err := client.UserAttributeValue.Query().
Where(userattributevalue.UserIDIn(userIDs...)).
All(ctx)
if err != nil {
return nil, err
}
result := make([]service.UserAttributeValue, 0, len(entities))
for _, e := range entities {
result = append(result, service.UserAttributeValue{
ID: e.ID,
UserID: e.UserID,
AttributeID: e.AttributeID,
Value: e.Value,
CreatedAt: e.CreatedAt,
UpdatedAt: e.UpdatedAt,
})
}
return result, nil
}
func (r *userAttributeValueRepository) UpsertBatch(ctx context.Context, userID int64, inputs []service.UpdateUserAttributeInput) error {
if len(inputs) == 0 {
return nil
}
tx, err := r.client.Tx(ctx)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
for _, input := range inputs {
// Use upsert (ON CONFLICT DO UPDATE)
err := tx.UserAttributeValue.Create().
SetUserID(userID).
SetAttributeID(input.AttributeID).
SetValue(input.Value).
OnConflictColumns(userattributevalue.FieldUserID, userattributevalue.FieldAttributeID).
UpdateValue().
UpdateUpdatedAt().
Exec(ctx)
if err != nil {
return err
}
}
return tx.Commit()
}
func (r *userAttributeValueRepository) DeleteByAttributeID(ctx context.Context, attributeID int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.UserAttributeValue.Delete().
Where(userattributevalue.AttributeIDEQ(attributeID)).
Exec(ctx)
return err
}
func (r *userAttributeValueRepository) DeleteByUserID(ctx context.Context, userID int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.UserAttributeValue.Delete().
Where(userattributevalue.UserIDEQ(userID)).
Exec(ctx)
return err
}
// Helper functions for entity to service conversion
func defEntityToService(e *dbent.UserAttributeDefinition) *service.UserAttributeDefinition {
if e == nil {
return nil
}
return &service.UserAttributeDefinition{
ID: e.ID,
Key: e.Key,
Name: e.Name,
Description: e.Description,
Type: service.UserAttributeType(e.Type),
Options: toServiceOptions(e.Options),
Required: e.Required,
Validation: toServiceValidation(e.Validation),
Placeholder: e.Placeholder,
DisplayOrder: e.DisplayOrder,
Enabled: e.Enabled,
CreatedAt: e.CreatedAt,
UpdatedAt: e.UpdatedAt,
}
}
// Type conversion helpers (map types <-> service types)
func toEntOptions(opts []service.UserAttributeOption) []map[string]any {
if opts == nil {
return []map[string]any{}
}
result := make([]map[string]any, len(opts))
for i, o := range opts {
result[i] = map[string]any{"value": o.Value, "label": o.Label}
}
return result
}
func toServiceOptions(opts []map[string]any) []service.UserAttributeOption {
if opts == nil {
return []service.UserAttributeOption{}
}
result := make([]service.UserAttributeOption, len(opts))
for i, o := range opts {
result[i] = service.UserAttributeOption{
Value: getString(o, "value"),
Label: getString(o, "label"),
}
}
return result
}
func toEntValidation(v service.UserAttributeValidation) map[string]any {
result := map[string]any{}
if v.MinLength != nil {
result["min_length"] = *v.MinLength
}
if v.MaxLength != nil {
result["max_length"] = *v.MaxLength
}
if v.Min != nil {
result["min"] = *v.Min
}
if v.Max != nil {
result["max"] = *v.Max
}
if v.Pattern != nil {
result["pattern"] = *v.Pattern
}
if v.Message != nil {
result["message"] = *v.Message
}
return result
}
func toServiceValidation(v map[string]any) service.UserAttributeValidation {
result := service.UserAttributeValidation{}
if val := getInt(v, "min_length"); val != nil {
result.MinLength = val
}
if val := getInt(v, "max_length"); val != nil {
result.MaxLength = val
}
if val := getInt(v, "min"); val != nil {
result.Min = val
}
if val := getInt(v, "max"); val != nil {
result.Max = val
}
if val := getStringPtr(v, "pattern"); val != nil {
result.Pattern = val
}
if val := getStringPtr(v, "message"); val != nil {
result.Message = val
}
return result
}
// Helper functions for type conversion
func getString(m map[string]any, key string) string {
if v, ok := m[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
func getStringPtr(m map[string]any, key string) *string {
if v, ok := m[key]; ok {
if s, ok := v.(string); ok {
return &s
}
}
return nil
}
func getInt(m map[string]any, key string) *int {
if v, ok := m[key]; ok {
switch n := v.(type) {
case int:
return &n
case int64:
i := int(n)
return &i
case float64:
i := int(n)
return &i
}
}
return nil
}

View File

@@ -9,6 +9,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
dbuser "github.com/Wei-Shaw/sub2api/ent/user" dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
@@ -50,7 +51,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
created, err := txClient.User.Create(). created, err := txClient.User.Create().
SetEmail(userIn.Email). SetEmail(userIn.Email).
SetUsername(userIn.Username). SetUsername(userIn.Username).
SetWechat(userIn.Wechat).
SetNotes(userIn.Notes). SetNotes(userIn.Notes).
SetPasswordHash(userIn.PasswordHash). SetPasswordHash(userIn.PasswordHash).
SetRole(userIn.Role). SetRole(userIn.Role).
@@ -133,7 +133,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
updated, err := txClient.User.UpdateOneID(userIn.ID). updated, err := txClient.User.UpdateOneID(userIn.ID).
SetEmail(userIn.Email). SetEmail(userIn.Email).
SetUsername(userIn.Username). SetUsername(userIn.Username).
SetWechat(userIn.Wechat).
SetNotes(userIn.Notes). SetNotes(userIn.Notes).
SetPasswordHash(userIn.PasswordHash). SetPasswordHash(userIn.PasswordHash).
SetRole(userIn.Role). SetRole(userIn.Role).
@@ -171,28 +170,38 @@ func (r *userRepository) Delete(ctx context.Context, id int64) error {
} }
func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "") return r.ListWithFilters(ctx, params, service.UserListFilters{})
} }
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]service.User, *pagination.PaginationResult, error) { func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
q := r.client.User.Query() q := r.client.User.Query()
if status != "" { if filters.Status != "" {
q = q.Where(dbuser.StatusEQ(status)) q = q.Where(dbuser.StatusEQ(filters.Status))
} }
if role != "" { if filters.Role != "" {
q = q.Where(dbuser.RoleEQ(role)) q = q.Where(dbuser.RoleEQ(filters.Role))
} }
if search != "" { if filters.Search != "" {
q = q.Where( q = q.Where(
dbuser.Or( dbuser.Or(
dbuser.EmailContainsFold(search), dbuser.EmailContainsFold(filters.Search),
dbuser.UsernameContainsFold(search), dbuser.UsernameContainsFold(filters.Search),
dbuser.WechatContainsFold(search),
), ),
) )
} }
// If attribute filters are specified, we need to filter by user IDs first
var allowedUserIDs []int64
if len(filters.Attributes) > 0 {
allowedUserIDs = r.filterUsersByAttributes(ctx, filters.Attributes)
if len(allowedUserIDs) == 0 {
// No users match the attribute filters
return []service.User{}, paginationResultFromTotal(0, params), nil
}
q = q.Where(dbuser.IDIn(allowedUserIDs...))
}
total, err := q.Clone().Count(ctx) total, err := q.Clone().Count(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@@ -252,6 +261,59 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
return outUsers, paginationResultFromTotal(int64(total), params), nil return outUsers, paginationResultFromTotal(int64(total), params), nil
} }
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) []int64 {
if len(attrs) == 0 {
return nil
}
// For each attribute filter, get the set of matching user IDs
// Then intersect all sets to get users matching ALL filters
var resultSet map[int64]struct{}
first := true
for attrID, value := range attrs {
// Query user_attribute_values for this attribute
values, err := r.client.UserAttributeValue.Query().
Where(
userattributevalue.AttributeIDEQ(attrID),
userattributevalue.ValueContainsFold(value),
).
All(ctx)
if err != nil {
continue
}
currentSet := make(map[int64]struct{}, len(values))
for _, v := range values {
currentSet[v.UserID] = struct{}{}
}
if first {
resultSet = currentSet
first = false
} else {
// Intersect with previous results
for userID := range resultSet {
if _, ok := currentSet[userID]; !ok {
delete(resultSet, userID)
}
}
}
// Early exit if no users match
if len(resultSet) == 0 {
return nil
}
}
result := make([]int64, 0, len(resultSet))
for userID := range resultSet {
result = append(result, userID)
}
return result
}
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error { func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
client := clientFromContext(ctx, r.client) client := clientFromContext(ctx, r.client)
n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx) n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx)

View File

@@ -166,7 +166,7 @@ func (s *UserRepoSuite) TestListWithFilters_Status() {
s.mustCreateUser(&service.User{Email: "active@test.com", Status: service.StatusActive}) s.mustCreateUser(&service.User{Email: "active@test.com", Status: service.StatusActive})
s.mustCreateUser(&service.User{Email: "disabled@test.com", Status: service.StatusDisabled}) s.mustCreateUser(&service.User{Email: "disabled@test.com", Status: service.StatusDisabled})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.StatusActive, "", "") users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Status: service.StatusActive})
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(users, 1) s.Require().Len(users, 1)
s.Require().Equal(service.StatusActive, users[0].Status) s.Require().Equal(service.StatusActive, users[0].Status)
@@ -176,7 +176,7 @@ func (s *UserRepoSuite) TestListWithFilters_Role() {
s.mustCreateUser(&service.User{Email: "user@test.com", Role: service.RoleUser}) s.mustCreateUser(&service.User{Email: "user@test.com", Role: service.RoleUser})
s.mustCreateUser(&service.User{Email: "admin@test.com", Role: service.RoleAdmin}) s.mustCreateUser(&service.User{Email: "admin@test.com", Role: service.RoleAdmin})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.RoleAdmin, "") users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Role: service.RoleAdmin})
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(users, 1) s.Require().Len(users, 1)
s.Require().Equal(service.RoleAdmin, users[0].Role) s.Require().Equal(service.RoleAdmin, users[0].Role)
@@ -186,7 +186,7 @@ func (s *UserRepoSuite) TestListWithFilters_Search() {
s.mustCreateUser(&service.User{Email: "alice@test.com", Username: "Alice"}) s.mustCreateUser(&service.User{Email: "alice@test.com", Username: "Alice"})
s.mustCreateUser(&service.User{Email: "bob@test.com", Username: "Bob"}) s.mustCreateUser(&service.User{Email: "bob@test.com", Username: "Bob"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alice") users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Search: "alice"})
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(users, 1) s.Require().Len(users, 1)
s.Require().Contains(users[0].Email, "alice") s.Require().Contains(users[0].Email, "alice")
@@ -196,22 +196,12 @@ func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
s.mustCreateUser(&service.User{Email: "u1@test.com", Username: "JohnDoe"}) s.mustCreateUser(&service.User{Email: "u1@test.com", Username: "JohnDoe"})
s.mustCreateUser(&service.User{Email: "u2@test.com", Username: "JaneSmith"}) s.mustCreateUser(&service.User{Email: "u2@test.com", Username: "JaneSmith"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "john") users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Search: "john"})
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(users, 1) s.Require().Len(users, 1)
s.Require().Equal("JohnDoe", users[0].Username) s.Require().Equal("JohnDoe", users[0].Username)
} }
func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() {
s.mustCreateUser(&service.User{Email: "w1@test.com", Wechat: "wx_hello"})
s.mustCreateUser(&service.User{Email: "w2@test.com", Wechat: "wx_world"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "wx_hello")
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal("wx_hello", users[0].Wechat)
}
func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() { func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
user := s.mustCreateUser(&service.User{Email: "sub@test.com", Status: service.StatusActive}) user := s.mustCreateUser(&service.User{Email: "sub@test.com", Status: service.StatusActive})
groupActive := s.mustCreateGroup("g-sub-active") groupActive := s.mustCreateGroup("g-sub-active")
@@ -226,7 +216,7 @@ func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
c.SetExpiresAt(time.Now().Add(-1 * time.Hour)) c.SetExpiresAt(time.Now().Add(-1 * time.Hour))
}) })
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "sub@") users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Search: "sub@"})
s.Require().NoError(err, "ListWithFilters") s.Require().NoError(err, "ListWithFilters")
s.Require().Len(users, 1, "expected 1 user") s.Require().Len(users, 1, "expected 1 user")
s.Require().Len(users[0].Subscriptions, 1, "expected 1 active subscription") s.Require().Len(users[0].Subscriptions, 1, "expected 1 active subscription")
@@ -238,7 +228,6 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
s.mustCreateUser(&service.User{ s.mustCreateUser(&service.User{
Email: "a@example.com", Email: "a@example.com",
Username: "Alice", Username: "Alice",
Wechat: "wx_a",
Role: service.RoleUser, Role: service.RoleUser,
Status: service.StatusActive, Status: service.StatusActive,
Balance: 10, Balance: 10,
@@ -246,7 +235,6 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
target := s.mustCreateUser(&service.User{ target := s.mustCreateUser(&service.User{
Email: "b@example.com", Email: "b@example.com",
Username: "Bob", Username: "Bob",
Wechat: "wx_b",
Role: service.RoleAdmin, Role: service.RoleAdmin,
Status: service.StatusActive, Status: service.StatusActive,
Balance: 1, Balance: 1,
@@ -257,7 +245,7 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
Status: service.StatusDisabled, Status: service.StatusDisabled,
}) })
users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.StatusActive, service.RoleAdmin, "b@") users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Status: service.StatusActive, Role: service.RoleAdmin, Search: "b@"})
s.Require().NoError(err, "ListWithFilters") s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch") s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
s.Require().Len(users, 1, "ListWithFilters len mismatch") s.Require().Len(users, 1, "ListWithFilters len mismatch")
@@ -448,7 +436,6 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
user1 := s.mustCreateUser(&service.User{ user1 := s.mustCreateUser(&service.User{
Email: "a@example.com", Email: "a@example.com",
Username: "Alice", Username: "Alice",
Wechat: "wx_a",
Role: service.RoleUser, Role: service.RoleUser,
Status: service.StatusActive, Status: service.StatusActive,
Balance: 10, Balance: 10,
@@ -456,7 +443,6 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
user2 := s.mustCreateUser(&service.User{ user2 := s.mustCreateUser(&service.User{
Email: "b@example.com", Email: "b@example.com",
Username: "Bob", Username: "Bob",
Wechat: "wx_b",
Role: service.RoleAdmin, Role: service.RoleAdmin,
Status: service.StatusActive, Status: service.StatusActive,
Balance: 1, Balance: 1,
@@ -501,7 +487,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
s.Require().Equal(user1.Concurrency+3, got5.Concurrency) s.Require().Equal(user1.Concurrency+3, got5.Concurrency)
params := pagination.PaginationParams{Page: 1, PageSize: 10} params := pagination.PaginationParams{Page: 1, PageSize: 10}
users, page, err := s.repo.ListWithFilters(s.ctx, params, service.StatusActive, service.RoleAdmin, "b@") users, page, err := s.repo.ListWithFilters(s.ctx, params, service.UserListFilters{Status: service.StatusActive, Role: service.RoleAdmin, Search: "b@"})
s.Require().NoError(err, "ListWithFilters") s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch") s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
s.Require().Len(users, 1, "ListWithFilters len mismatch") s.Require().Len(users, 1, "ListWithFilters len mismatch")

View File

@@ -15,7 +15,14 @@ import (
// ProvideConcurrencyCache 创建并发控制缓存,从配置读取 TTL 参数 // ProvideConcurrencyCache 创建并发控制缓存,从配置读取 TTL 参数
// 性能优化TTL 可配置,支持长时间运行的 LLM 请求场景 // 性能优化TTL 可配置,支持长时间运行的 LLM 请求场景
func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.ConcurrencyCache { func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.ConcurrencyCache {
return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes) waitTTLSeconds := int(cfg.Gateway.Scheduling.StickySessionWaitTimeout.Seconds())
if cfg.Gateway.Scheduling.FallbackWaitTimeout > cfg.Gateway.Scheduling.StickySessionWaitTimeout {
waitTTLSeconds = int(cfg.Gateway.Scheduling.FallbackWaitTimeout.Seconds())
}
if waitTTLSeconds <= 0 {
waitTTLSeconds = cfg.Gateway.ConcurrencySlotTTLMinutes * 60
}
return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes, waitTTLSeconds)
} }
// ProviderSet is the Wire provider set for all repositories // ProviderSet is the Wire provider set for all repositories
@@ -29,6 +36,8 @@ var ProviderSet = wire.NewSet(
NewUsageLogRepository, NewUsageLogRepository,
NewSettingRepository, NewSettingRepository,
NewUserSubscriptionRepository, NewUserSubscriptionRepository,
NewUserAttributeDefinitionRepository,
NewUserAttributeValueRepository,
// Cache implementations // Cache implementations
NewGatewayCache, NewGatewayCache,

View File

@@ -51,7 +51,6 @@ func TestAPIContracts(t *testing.T) {
"id": 1, "id": 1,
"email": "alice@example.com", "email": "alice@example.com",
"username": "alice", "username": "alice",
"wechat": "wx_alice",
"notes": "hello", "notes": "hello",
"role": "user", "role": "user",
"balance": 12.5, "balance": 12.5,
@@ -348,7 +347,6 @@ func newContractDeps(t *testing.T) *contractDeps {
ID: 1, ID: 1,
Email: "alice@example.com", Email: "alice@example.com",
Username: "alice", Username: "alice",
Wechat: "wx_alice",
Notes: "hello", Notes: "hello",
Role: service.RoleUser, Role: service.RoleUser,
Balance: 12.5, Balance: 12.5,
@@ -503,7 +501,7 @@ func (r *stubUserRepo) List(ctx context.Context, params pagination.PaginationPar
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (r *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]service.User, *pagination.PaginationResult, error) { func (r *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }

View File

@@ -54,6 +54,9 @@ func RegisterAdminRoutes(
// 使用记录管理 // 使用记录管理
registerUsageRoutes(admin, h) registerUsageRoutes(admin, h)
// 用户属性管理
registerUserAttributeRoutes(admin, h)
} }
} }
@@ -82,6 +85,10 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
users.POST("/:id/balance", h.Admin.User.UpdateBalance) users.POST("/:id/balance", h.Admin.User.UpdateBalance)
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys) users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
users.GET("/:id/usage", h.Admin.User.GetUserUsage) users.GET("/:id/usage", h.Admin.User.GetUserUsage)
// User attribute values
users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes)
users.PUT("/:id/attributes", h.Admin.UserAttribute.UpdateUserAttributes)
} }
} }
@@ -110,6 +117,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.DELETE("/:id", h.Admin.Account.Delete) accounts.DELETE("/:id", h.Admin.Account.Delete)
accounts.POST("/:id/test", h.Admin.Account.Test) accounts.POST("/:id/test", h.Admin.Account.Test)
accounts.POST("/:id/refresh", h.Admin.Account.Refresh) accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
accounts.POST("/:id/refresh-tier", h.Admin.Account.RefreshTier)
accounts.GET("/:id/stats", h.Admin.Account.GetStats) accounts.GET("/:id/stats", h.Admin.Account.GetStats)
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError) accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
accounts.GET("/:id/usage", h.Admin.Account.GetUsage) accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
@@ -119,6 +127,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels) accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
accounts.POST("/batch", h.Admin.Account.BatchCreate) accounts.POST("/batch", h.Admin.Account.BatchCreate)
accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials) accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials)
accounts.POST("/batch-refresh-tier", h.Admin.Account.BatchRefreshTier)
accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate) accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate)
// Claude OAuth routes // Claude OAuth routes
@@ -242,3 +251,15 @@ func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys) usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
} }
} }
func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
attrs := admin.Group("/user-attributes")
{
attrs.GET("", h.Admin.UserAttribute.ListDefinitions)
attrs.POST("", h.Admin.UserAttribute.CreateDefinition)
attrs.POST("/batch", h.Admin.UserAttribute.GetBatchUserAttributes)
attrs.PUT("/reorder", h.Admin.UserAttribute.ReorderDefinitions)
attrs.PUT("/:id", h.Admin.UserAttribute.UpdateDefinition)
attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition)
}
}

View File

@@ -3,6 +3,7 @@ package service
import ( import (
"encoding/json" "encoding/json"
"strconv" "strconv"
"strings"
"time" "time"
) )
@@ -78,6 +79,36 @@ func (a *Account) IsGemini() bool {
return a.Platform == PlatformGemini return a.Platform == PlatformGemini
} }
func (a *Account) GeminiOAuthType() string {
if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth {
return ""
}
oauthType := strings.TrimSpace(a.GetCredential("oauth_type"))
if oauthType == "" && strings.TrimSpace(a.GetCredential("project_id")) != "" {
return "code_assist"
}
return oauthType
}
func (a *Account) GeminiTierID() string {
tierID := strings.TrimSpace(a.GetCredential("tier_id"))
if tierID == "" {
return ""
}
return strings.ToUpper(tierID)
}
func (a *Account) IsGeminiCodeAssist() bool {
if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth {
return false
}
oauthType := a.GeminiOAuthType()
if oauthType == "" {
return strings.TrimSpace(a.GetCredential("project_id")) != ""
}
return oauthType == "code_assist"
}
func (a *Account) CanGetUsage() bool { func (a *Account) CanGetUsage() bool {
return a.Type == AccountTypeOAuth return a.Type == AccountTypeOAuth
} }

View File

@@ -17,6 +17,9 @@ var (
type AccountRepository interface { type AccountRepository interface {
Create(ctx context.Context, account *Account) error Create(ctx context.Context, account *Account) error
GetByID(ctx context.Context, id int64) (*Account, error) GetByID(ctx context.Context, id int64) (*Account, error)
// GetByIDs fetches accounts by IDs in a single query.
// It should return all accounts found (missing IDs are ignored).
GetByIDs(ctx context.Context, ids []int64) ([]*Account, error)
// ExistsByID 检查账号是否存在,仅返回布尔值,用于删除前的轻量级存在性检查 // ExistsByID 检查账号是否存在,仅返回布尔值,用于删除前的轻量级存在性检查
ExistsByID(ctx context.Context, id int64) (bool, error) ExistsByID(ctx context.Context, id int64) (bool, error)
// GetByCRSAccountID finds an account previously synced from CRS. // GetByCRSAccountID finds an account previously synced from CRS.

View File

@@ -40,6 +40,10 @@ func (s *accountRepoStub) GetByID(ctx context.Context, id int64) (*Account, erro
panic("unexpected GetByID call") panic("unexpected GetByID call")
} }
func (s *accountRepoStub) GetByIDs(ctx context.Context, ids []int64) ([]*Account, error) {
panic("unexpected GetByIDs call")
}
// ExistsByID 返回预设的存在性检查结果。 // ExistsByID 返回预设的存在性检查结果。
// 这是 Delete 方法调用的第一个仓储方法,用于验证账号是否存在。 // 这是 Delete 方法调用的第一个仓储方法,用于验证账号是否存在。
func (s *accountRepoStub) ExistsByID(ctx context.Context, id int64) (bool, error) { func (s *accountRepoStub) ExistsByID(ctx context.Context, id int64) (bool, error) {

View File

@@ -93,10 +93,12 @@ type UsageProgress struct {
// UsageInfo 账号使用量信息 // UsageInfo 账号使用量信息
type UsageInfo struct { type UsageInfo struct {
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间 UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口 FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口 SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口 SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额
GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额
} }
// ClaudeUsageResponse Anthropic API返回的usage结构 // ClaudeUsageResponse Anthropic API返回的usage结构
@@ -122,17 +124,19 @@ type ClaudeUsageFetcher interface {
// AccountUsageService 账号使用量查询服务 // AccountUsageService 账号使用量查询服务
type AccountUsageService struct { type AccountUsageService struct {
accountRepo AccountRepository accountRepo AccountRepository
usageLogRepo UsageLogRepository usageLogRepo UsageLogRepository
usageFetcher ClaudeUsageFetcher usageFetcher ClaudeUsageFetcher
geminiQuotaService *GeminiQuotaService
} }
// NewAccountUsageService 创建AccountUsageService实例 // NewAccountUsageService 创建AccountUsageService实例
func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher) *AccountUsageService { func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher, geminiQuotaService *GeminiQuotaService) *AccountUsageService {
return &AccountUsageService{ return &AccountUsageService{
accountRepo: accountRepo, accountRepo: accountRepo,
usageLogRepo: usageLogRepo, usageLogRepo: usageLogRepo,
usageFetcher: usageFetcher, usageFetcher: usageFetcher,
geminiQuotaService: geminiQuotaService,
} }
} }
@@ -146,6 +150,10 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return nil, fmt.Errorf("get account failed: %w", err) return nil, fmt.Errorf("get account failed: %w", err)
} }
if account.Platform == PlatformGemini {
return s.getGeminiUsage(ctx, account)
}
// 只有oauth类型账号可以通过API获取usage有profile scope // 只有oauth类型账号可以通过API获取usage有profile scope
if account.CanGetUsage() { if account.CanGetUsage() {
var apiResp *ClaudeUsageResponse var apiResp *ClaudeUsageResponse
@@ -192,6 +200,36 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return nil, fmt.Errorf("account type %s does not support usage query", account.Type) return nil, fmt.Errorf("account type %s does not support usage query", account.Type)
} }
func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
now := time.Now()
usage := &UsageInfo{
UpdatedAt: &now,
}
if s.geminiQuotaService == nil || s.usageLogRepo == nil {
return usage, nil
}
quota, ok := s.geminiQuotaService.QuotaForAccount(ctx, account)
if !ok {
return usage, nil
}
start := geminiDailyWindowStart(now)
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
if err != nil {
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
}
totals := geminiAggregateUsage(stats)
resetAt := geminiDailyResetTime(now)
usage.GeminiProDaily = buildGeminiUsageProgress(totals.ProRequests, quota.ProRPD, resetAt, totals.ProTokens, totals.ProCost, now)
usage.GeminiFlashDaily = buildGeminiUsageProgress(totals.FlashRequests, quota.FlashRPD, resetAt, totals.FlashTokens, totals.FlashCost, now)
return usage, nil
}
// addWindowStats 为 usage 数据添加窗口期统计 // addWindowStats 为 usage 数据添加窗口期统计
// 使用独立缓存1 分钟),与 API 缓存分离 // 使用独立缓存1 分钟),与 API 缓存分离
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) { func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) {
@@ -388,3 +426,25 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
// Setup Token无法获取7d数据 // Setup Token无法获取7d数据
return info return info
} }
func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64, now time.Time) *UsageProgress {
if limit <= 0 {
return nil
}
utilization := (float64(used) / float64(limit)) * 100
remainingSeconds := int(resetAt.Sub(now).Seconds())
if remainingSeconds < 0 {
remainingSeconds = 0
}
resetCopy := resetAt
return &UsageProgress{
Utilization: utilization,
ResetsAt: &resetCopy,
RemainingSeconds: remainingSeconds,
WindowStats: &WindowStats{
Requests: used,
Tokens: tokens,
Cost: cost,
},
}
}

View File

@@ -13,7 +13,7 @@ import (
// AdminService interface defines admin management operations // AdminService interface defines admin management operations
type AdminService interface { type AdminService interface {
// User management // User management
ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]User, int64, error) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error)
GetUser(ctx context.Context, id int64) (*User, error) GetUser(ctx context.Context, id int64) (*User, error)
CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error)
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
@@ -35,6 +35,7 @@ type AdminService interface {
// Account management // Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
GetAccount(ctx context.Context, id int64) (*Account, error) GetAccount(ctx context.Context, id int64) (*Account, error)
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error)
DeleteAccount(ctx context.Context, id int64) error DeleteAccount(ctx context.Context, id int64) error
@@ -69,7 +70,6 @@ type CreateUserInput struct {
Email string Email string
Password string Password string
Username string Username string
Wechat string
Notes string Notes string
Balance float64 Balance float64
Concurrency int Concurrency int
@@ -80,7 +80,6 @@ type UpdateUserInput struct {
Email string Email string
Password string Password string
Username *string Username *string
Wechat *string
Notes *string Notes *string
Balance *float64 // 使用指针区分"未提供"和"设置为0" Balance *float64 // 使用指针区分"未提供"和"设置为0"
Concurrency *int // 使用指针区分"未提供"和"设置为0" Concurrency *int // 使用指针区分"未提供"和"设置为0"
@@ -251,9 +250,9 @@ func NewAdminService(
} }
// User management implementations // User management implementations
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]User, int64, error) { func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
users, result, err := s.userRepo.ListWithFilters(ctx, params, status, role, search) users, result, err := s.userRepo.ListWithFilters(ctx, params, filters)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
@@ -268,7 +267,6 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
user := &User{ user := &User{
Email: input.Email, Email: input.Email,
Username: input.Username, Username: input.Username,
Wechat: input.Wechat,
Notes: input.Notes, Notes: input.Notes,
Role: RoleUser, // Always create as regular user, never admin Role: RoleUser, // Always create as regular user, never admin
Balance: input.Balance, Balance: input.Balance,
@@ -310,9 +308,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
if input.Username != nil { if input.Username != nil {
user.Username = *input.Username user.Username = *input.Username
} }
if input.Wechat != nil {
user.Wechat = *input.Wechat
}
if input.Notes != nil { if input.Notes != nil {
user.Notes = *input.Notes user.Notes = *input.Notes
} }
@@ -611,6 +606,19 @@ func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*Account,
return s.accountRepo.GetByID(ctx, id) return s.accountRepo.GetByID(ctx, id)
} }
func (s *adminServiceImpl) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) {
if len(ids) == 0 {
return []*Account{}, nil
}
accounts, err := s.accountRepo.GetByIDs(ctx, ids)
if err != nil {
return nil, fmt.Errorf("failed to get accounts by IDs: %w", err)
}
return accounts, nil
}
func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) { func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) {
account := &Account{ account := &Account{
Name: input.Name, Name: input.Name,

View File

@@ -18,7 +18,6 @@ func TestAdminService_CreateUser_Success(t *testing.T) {
Email: "user@test.com", Email: "user@test.com",
Password: "strong-pass", Password: "strong-pass",
Username: "tester", Username: "tester",
Wechat: "wx",
Notes: "note", Notes: "note",
Balance: 12.5, Balance: 12.5,
Concurrency: 7, Concurrency: 7,
@@ -31,7 +30,6 @@ func TestAdminService_CreateUser_Success(t *testing.T) {
require.Equal(t, int64(10), user.ID) require.Equal(t, int64(10), user.ID)
require.Equal(t, input.Email, user.Email) require.Equal(t, input.Email, user.Email)
require.Equal(t, input.Username, user.Username) require.Equal(t, input.Username, user.Username)
require.Equal(t, input.Wechat, user.Wechat)
require.Equal(t, input.Notes, user.Notes) require.Equal(t, input.Notes, user.Notes)
require.Equal(t, input.Balance, user.Balance) require.Equal(t, input.Balance, user.Balance)
require.Equal(t, input.Concurrency, user.Concurrency) require.Equal(t, input.Concurrency, user.Concurrency)

View File

@@ -66,7 +66,7 @@ func (s *userRepoStub) List(ctx context.Context, params pagination.PaginationPar
panic("unexpected List call") panic("unexpected List call")
} }
func (s *userRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]User, *pagination.PaginationResult, error) { func (s *userRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call") panic("unexpected ListWithFilters call")
} }

View File

@@ -25,7 +25,7 @@ const (
antigravityRetryMaxDelay = 16 * time.Second antigravityRetryMaxDelay = 16 * time.Second
) )
// Antigravity 直接支持的模型 // Antigravity 直接支持的模型(精确匹配透传)
var antigravitySupportedModels = map[string]bool{ var antigravitySupportedModels = map[string]bool{
"claude-opus-4-5-thinking": true, "claude-opus-4-5-thinking": true,
"claude-sonnet-4-5": true, "claude-sonnet-4-5": true,
@@ -36,23 +36,26 @@ var antigravitySupportedModels = map[string]bool{
"gemini-3-flash": true, "gemini-3-flash": true,
"gemini-3-pro-low": true, "gemini-3-pro-low": true,
"gemini-3-pro-high": true, "gemini-3-pro-high": true,
"gemini-3-pro-preview": true,
"gemini-3-pro-image": true, "gemini-3-pro-image": true,
} }
// Antigravity 系统默认模型映射表(不支持 → 支持 // Antigravity 前缀映射表(按前缀长度降序排列,确保最长匹配优先
var antigravityModelMapping = map[string]string{ // 用于处理模型版本号变化(如 -20251111, -thinking, -preview 等后缀)
"claude-3-5-sonnet-20241022": "claude-sonnet-4-5", var antigravityPrefixMapping = []struct {
"claude-3-5-sonnet-20240620": "claude-sonnet-4-5", prefix string
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5-thinking", target string
"claude-opus-4": "claude-opus-4-5-thinking", }{
"claude-opus-4-5-20251101": "claude-opus-4-5-thinking", // 长前缀优先
"claude-haiku-4": "gemini-3-flash", {"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
"claude-haiku-4-5": "gemini-3-flash", {"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
"claude-3-haiku-20240307": "gemini-3-flash", {"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
"claude-haiku-4-5-20251001": "gemini-3-flash", {"claude-haiku-4-5", "gemini-3-flash"}, // claude-haiku-4-5-xxx
// 生图模型:官方名 → Antigravity 内部名 {"claude-opus-4-5", "claude-opus-4-5-thinking"},
"gemini-3-pro-image-preview": "gemini-3-pro-image", {"claude-3-haiku", "gemini-3-flash"}, // 旧版 claude-3-haiku-xxx
{"claude-sonnet-4", "claude-sonnet-4-5"},
{"claude-haiku-4", "gemini-3-flash"},
{"claude-opus-4", "claude-opus-4-5-thinking"},
{"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview 等
} }
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发 // AntigravityGatewayService 处理 Antigravity 平台的 API 转发
@@ -84,24 +87,27 @@ func (s *AntigravityGatewayService) GetTokenProvider() *AntigravityTokenProvider
} }
// getMappedModel 获取映射后的模型名 // getMappedModel 获取映射后的模型名
// 逻辑:账户映射 → 直接支持透传 → 前缀映射 → gemini透传 → 默认值
func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedModel string) string { func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedModel string) string {
// 1. 优先使用账户级映射(复用现有方法 // 1. 账户级映射(用户自定义优先
if mapped := account.GetMappedModel(requestedModel); mapped != requestedModel { if mapped := account.GetMappedModel(requestedModel); mapped != requestedModel {
return mapped return mapped
} }
// 2. 系统默认映射 // 2. 直接支持的模型透传
if mapped, ok := antigravityModelMapping[requestedModel]; ok { if antigravitySupportedModels[requestedModel] {
return mapped
}
// 3. Gemini 模型透传
if strings.HasPrefix(requestedModel, "gemini-") {
return requestedModel return requestedModel
} }
// 4. Claude 前缀透传直接支持的模型 // 3. 前缀映射(处理版本号变化,如 -20251111, -thinking, -preview
if antigravitySupportedModels[requestedModel] { for _, pm := range antigravityPrefixMapping {
if strings.HasPrefix(requestedModel, pm.prefix) {
return pm.target
}
}
// 4. Gemini 模型透传(未匹配到前缀的 gemini 模型)
if strings.HasPrefix(requestedModel, "gemini-") {
return requestedModel return requestedModel
} }
@@ -110,24 +116,10 @@ func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedMo
} }
// IsModelSupported 检查模型是否被支持 // IsModelSupported 检查模型是否被支持
// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持
func (s *AntigravityGatewayService) IsModelSupported(requestedModel string) bool { func (s *AntigravityGatewayService) IsModelSupported(requestedModel string) bool {
// 直接支持的模型 return strings.HasPrefix(requestedModel, "claude-") ||
if antigravitySupportedModels[requestedModel] { strings.HasPrefix(requestedModel, "gemini-")
return true
}
// 可映射的模型
if _, ok := antigravityModelMapping[requestedModel]; ok {
return true
}
// Gemini 前缀透传
if strings.HasPrefix(requestedModel, "gemini-") {
return true
}
// Claude 模型支持(通过默认映射)
if strings.HasPrefix(requestedModel, "claude-") {
return true
}
return false
} }
// TestConnectionResult 测试连接结果 // TestConnectionResult 测试连接结果
@@ -358,6 +350,15 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return nil, fmt.Errorf("transform request: %w", err) return nil, fmt.Errorf("transform request: %w", err)
} }
// 调试:记录转换后的请求体(仅记录前 2000 字符)
if bodyJSON, err := json.Marshal(geminiBody); err == nil {
truncated := string(bodyJSON)
if len(truncated) > 2000 {
truncated = truncated[:2000] + "..."
}
log.Printf("[Debug] Transformed Gemini request: %s", truncated)
}
// 构建上游 action // 构建上游 action
action := "generateContent" action := "generateContent"
if claudeReq.Stream { if claudeReq.Stream {

View File

@@ -131,7 +131,7 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
name: "系统映射 - claude-sonnet-4-5-20250929", name: "系统映射 - claude-sonnet-4-5-20250929",
requestedModel: "claude-sonnet-4-5-20250929", requestedModel: "claude-sonnet-4-5-20250929",
accountMapping: nil, accountMapping: nil,
expected: "claude-sonnet-4-5-thinking", expected: "claude-sonnet-4-5",
}, },
// 3. Gemini 透传 // 3. Gemini 透传

View File

@@ -18,6 +18,11 @@ type ConcurrencyCache interface {
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
// 账号等待队列(账号级)
IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error)
DecrementAccountWaitCount(ctx context.Context, accountID int64) error
GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error)
// 用户槽位管理 // 用户槽位管理
// 键格式: concurrency:user:{userID}(有序集合,成员为 requestID // 键格式: concurrency:user:{userID}(有序集合,成员为 requestID
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
@@ -27,6 +32,12 @@ type ConcurrencyCache interface {
// 等待队列计数(只在首次创建时设置 TTL // 等待队列计数(只在首次创建时设置 TTL
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
DecrementWaitCount(ctx context.Context, userID int64) error DecrementWaitCount(ctx context.Context, userID int64) error
// 批量负载查询(只读)
GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
// 清理过期槽位(后台任务)
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
} }
// generateRequestID generates a unique request ID for concurrency slot tracking // generateRequestID generates a unique request ID for concurrency slot tracking
@@ -61,6 +72,18 @@ type AcquireResult struct {
ReleaseFunc func() // Must be called when done (typically via defer) ReleaseFunc func() // Must be called when done (typically via defer)
} }
type AccountWithConcurrency struct {
ID int64
MaxConcurrency int
}
type AccountLoadInfo struct {
AccountID int64
CurrentConcurrency int
WaitingCount int
LoadRate int // 0-100+ (percent)
}
// AcquireAccountSlot attempts to acquire a concurrency slot for an account. // AcquireAccountSlot attempts to acquire a concurrency slot for an account.
// If the account is at max concurrency, it waits until a slot is available or timeout. // If the account is at max concurrency, it waits until a slot is available or timeout.
// Returns a release function that MUST be called when the request completes. // Returns a release function that MUST be called when the request completes.
@@ -177,6 +200,42 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6
} }
} }
// IncrementAccountWaitCount increments the wait queue counter for an account.
func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
if s.cache == nil {
return true, nil
}
result, err := s.cache.IncrementAccountWaitCount(ctx, accountID, maxWait)
if err != nil {
log.Printf("Warning: increment wait count failed for account %d: %v", accountID, err)
return true, nil
}
return result, nil
}
// DecrementAccountWaitCount decrements the wait queue counter for an account.
func (s *ConcurrencyService) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
if s.cache == nil {
return
}
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.cache.DecrementAccountWaitCount(bgCtx, accountID); err != nil {
log.Printf("Warning: decrement wait count failed for account %d: %v", accountID, err)
}
}
// GetAccountWaitingCount gets current wait queue count for an account.
func (s *ConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
if s.cache == nil {
return 0, nil
}
return s.cache.GetAccountWaitingCount(ctx, accountID)
}
// CalculateMaxWait calculates the maximum wait queue size for a user // CalculateMaxWait calculates the maximum wait queue size for a user
// maxWait = userConcurrency + defaultExtraWaitSlots // maxWait = userConcurrency + defaultExtraWaitSlots
func CalculateMaxWait(userConcurrency int) int { func CalculateMaxWait(userConcurrency int) int {
@@ -186,6 +245,57 @@ func CalculateMaxWait(userConcurrency int) int {
return userConcurrency + defaultExtraWaitSlots return userConcurrency + defaultExtraWaitSlots
} }
// GetAccountsLoadBatch returns load info for multiple accounts.
func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
if s.cache == nil {
return map[int64]*AccountLoadInfo{}, nil
}
return s.cache.GetAccountsLoadBatch(ctx, accounts)
}
// CleanupExpiredAccountSlots removes expired slots for one account (background task).
func (s *ConcurrencyService) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
if s.cache == nil {
return nil
}
return s.cache.CleanupExpiredAccountSlots(ctx, accountID)
}
// StartSlotCleanupWorker starts a background cleanup worker for expired account slots.
func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepository, interval time.Duration) {
if s == nil || s.cache == nil || accountRepo == nil || interval <= 0 {
return
}
runCleanup := func() {
listCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
accounts, err := accountRepo.ListSchedulable(listCtx)
cancel()
if err != nil {
log.Printf("Warning: list schedulable accounts failed: %v", err)
return
}
for _, account := range accounts {
accountCtx, accountCancel := context.WithTimeout(context.Background(), 2*time.Second)
err := s.cache.CleanupExpiredAccountSlots(accountCtx, account.ID)
accountCancel()
if err != nil {
log.Printf("Warning: cleanup expired slots failed for account %d: %v", account.ID, err)
}
}
}
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
runCleanup()
for range ticker.C {
runCleanup()
}
}()
}
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts // GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
// Returns a map of accountID -> current concurrency count // Returns a map of accountID -> current concurrency count
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {

View File

@@ -91,6 +91,9 @@ const (
// 管理员 API Key // 管理员 API Key
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key用于外部系统集成 SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key用于外部系统集成
// Gemini 配额策略JSON
SettingKeyGeminiQuotaPolicy = "gemini_quota_policy"
) )
// Admin API Key prefix (distinct from user "sk-" keys) // Admin API Key prefix (distinct from user "sk-" keys)

View File

@@ -32,6 +32,16 @@ func (m *mockAccountRepoForPlatform) GetByID(ctx context.Context, id int64) (*Ac
return nil, errors.New("account not found") return nil, errors.New("account not found")
} }
func (m *mockAccountRepoForPlatform) GetByIDs(ctx context.Context, ids []int64) ([]*Account, error) {
var result []*Account
for _, id := range ids {
if acc, ok := m.accountsByID[id]; ok {
result = append(result, acc)
}
}
return result, nil
}
func (m *mockAccountRepoForPlatform) ExistsByID(ctx context.Context, id int64) (bool, error) { func (m *mockAccountRepoForPlatform) ExistsByID(ctx context.Context, id int64) (bool, error) {
if m.accountsByID == nil { if m.accountsByID == nil {
return false, nil return false, nil
@@ -261,6 +271,34 @@ func TestGatewayService_SelectAccountForModelWithPlatform_PriorityAndLastUsed(t
require.Equal(t, int64(2), acc.ID, "同优先级应选择最久未用的账户") require.Equal(t, int64(2), acc.ID, "同优先级应选择最久未用的账户")
} }
func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户")
}
// TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts 测试无可用账户 // TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts 测试无可用账户
func TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts(t *testing.T) { func TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts(t *testing.T) {
ctx := context.Background() ctx := context.Background()
@@ -576,6 +614,32 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
ctx := context.Background() ctx := context.Background()
t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户")
})
t.Run("混合调度-包含启用mixed_scheduling的antigravity账户", func(t *testing.T) { t.Run("混合调度-包含启用mixed_scheduling的antigravity账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{ repo := &mockAccountRepoForPlatform{
accounts: []Account{ accounts: []Account{
@@ -783,3 +847,160 @@ func TestAccount_IsMixedSchedulingEnabled(t *testing.T) {
}) })
} }
} }
// mockConcurrencyService for testing
type mockConcurrencyService struct {
accountLoads map[int64]*AccountLoadInfo
accountWaitCounts map[int64]int
acquireResults map[int64]bool
}
func (m *mockConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
if m.accountLoads == nil {
return map[int64]*AccountLoadInfo{}, nil
}
result := make(map[int64]*AccountLoadInfo)
for _, acc := range accounts {
if load, ok := m.accountLoads[acc.ID]; ok {
result[acc.ID] = load
} else {
result[acc.ID] = &AccountLoadInfo{
AccountID: acc.ID,
CurrentConcurrency: 0,
WaitingCount: 0,
LoadRate: 0,
}
}
}
return result, nil
}
func (m *mockConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
if m.accountWaitCounts == nil {
return 0, nil
}
return m.accountWaitCounts[accountID], nil
}
// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection
func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
ctx := context.Background()
t.Run("禁用负载批量查询-降级到传统选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil, // No concurrency service
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(1), result.Account.ID, "应选择优先级最高的账号")
})
t.Run("无ConcurrencyService-降级到传统选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "应选择优先级最高的账号")
})
t.Run("排除账号-不选择被排除的账号", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
excludedIDs := map[int64]struct{}{1: {}}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "不应选择被排除的账号")
})
t.Run("无可用账号-返回错误", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{},
accountsByID: map[int64]*Account{},
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
require.Error(t, err)
require.Nil(t, result)
require.Contains(t, err.Error(), "no available accounts")
})
}

View File

@@ -13,12 +13,14 @@ import (
"log" "log"
"net/http" "net/http"
"regexp" "regexp"
"sort"
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -66,6 +68,20 @@ type GatewayCache interface {
RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error
} }
type AccountWaitPlan struct {
AccountID int64
MaxConcurrency int
Timeout time.Duration
MaxWaiting int
}
type AccountSelectionResult struct {
Account *Account
Acquired bool
ReleaseFunc func()
WaitPlan *AccountWaitPlan // nil means no wait allowed
}
// ClaudeUsage 表示Claude API返回的usage信息 // ClaudeUsage 表示Claude API返回的usage信息
type ClaudeUsage struct { type ClaudeUsage struct {
InputTokens int `json:"input_tokens"` InputTokens int `json:"input_tokens"`
@@ -108,6 +124,7 @@ type GatewayService struct {
identityService *IdentityService identityService *IdentityService
httpUpstream HTTPUpstream httpUpstream HTTPUpstream
deferredService *DeferredService deferredService *DeferredService
concurrencyService *ConcurrencyService
} }
// NewGatewayService creates a new GatewayService // NewGatewayService creates a new GatewayService
@@ -119,6 +136,7 @@ func NewGatewayService(
userSubRepo UserSubscriptionRepository, userSubRepo UserSubscriptionRepository,
cache GatewayCache, cache GatewayCache,
cfg *config.Config, cfg *config.Config,
concurrencyService *ConcurrencyService,
billingService *BillingService, billingService *BillingService,
rateLimitService *RateLimitService, rateLimitService *RateLimitService,
billingCacheService *BillingCacheService, billingCacheService *BillingCacheService,
@@ -134,6 +152,7 @@ func NewGatewayService(
userSubRepo: userSubRepo, userSubRepo: userSubRepo,
cache: cache, cache: cache,
cfg: cfg, cfg: cfg,
concurrencyService: concurrencyService,
billingService: billingService, billingService: billingService,
rateLimitService: rateLimitService, rateLimitService: rateLimitService,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
@@ -183,6 +202,14 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
return "" return ""
} }
// BindStickySession sets session -> account binding with standard TTL.
func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error {
if sessionHash == "" || accountID <= 0 {
return nil
}
return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL)
}
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string { func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
if parsed == nil { if parsed == nil {
return "" return ""
@@ -332,8 +359,354 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
} }
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
cfg := s.schedulingConfig()
var stickyAccountID int64
if sessionHash != "" && s.cache != nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash); err == nil {
stickyAccountID = accountID
}
}
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
if err != nil {
return nil, err
}
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err == nil && result.Acquired {
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
}
platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID)
if err != nil {
return nil, err
}
preferOAuth := platform == PlatformGemini
accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err != nil {
return nil, err
}
if len(accounts) == 0 {
return nil, errors.New("no available accounts")
}
isExcluded := func(accountID int64) bool {
if excludedIDs == nil {
return false
}
_, excluded := excludedIDs[accountID]
return excluded
}
// ============ Layer 1: 粘性会话优先 ============
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err == nil && s.isAccountAllowedForPlatform(account, platform, useMixed) &&
account.IsSchedulable() &&
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
_ = s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
}
}
// ============ Layer 2: 负载感知选择 ============
candidates := make([]*Account, 0, len(accounts))
for i := range accounts {
acc := &accounts[i]
if isExcluded(acc.ID) {
continue
}
if !s.isAccountAllowedForPlatform(acc, platform, useMixed) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
candidates = append(candidates, acc)
}
if len(candidates) == 0 {
return nil, errors.New("no available accounts")
}
accountLoads := make([]AccountWithConcurrency, 0, len(candidates))
for _, acc := range candidates {
accountLoads = append(accountLoads, AccountWithConcurrency{
ID: acc.ID,
MaxConcurrency: acc.Concurrency,
})
}
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil {
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, sessionHash, preferOAuth); ok {
return result, nil
}
} else {
type accountWithLoad struct {
account *Account
loadInfo *AccountLoadInfo
}
var available []accountWithLoad
for _, acc := range candidates {
loadInfo := loadMap[acc.ID]
if loadInfo == nil {
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
}
if loadInfo.LoadRate < 100 {
available = append(available, accountWithLoad{
account: acc,
loadInfo: loadInfo,
})
}
}
if len(available) > 0 {
sort.SliceStable(available, func(i, j int) bool {
a, b := available[i], available[j]
if a.account.Priority != b.account.Priority {
return a.account.Priority < b.account.Priority
}
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
}
switch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
return true
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
return false
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
if preferOAuth && a.account.Type != b.account.Type {
return a.account.Type == AccountTypeOAuth
}
return false
default:
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
for _, item := range available {
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: item.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
}
}
}
// ============ Layer 3: 兜底排队 ============
sortAccountsByPriorityAndLastUsed(candidates, preferOAuth)
for _, acc := range candidates {
return &AccountSelectionResult{
Account: acc,
WaitPlan: &AccountWaitPlan{
AccountID: acc.ID,
MaxConcurrency: acc.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
}
return nil, errors.New("no available accounts")
}
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
ordered := append([]*Account(nil), candidates...)
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
for _, acc := range ordered {
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: acc,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, true
}
}
return nil, false
}
func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
if s.cfg != nil {
return s.cfg.Gateway.Scheduling
}
return config.GatewaySchedulingConfig{
StickySessionMaxWaiting: 3,
StickySessionWaitTimeout: 45 * time.Second,
FallbackWaitTimeout: 30 * time.Second,
FallbackMaxWaiting: 100,
LoadBatchEnabled: true,
SlotCleanupInterval: 30 * time.Second,
}
}
func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) {
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" {
return forcePlatform, true, nil
}
if groupID != nil {
group, err := s.groupRepo.GetByID(ctx, *groupID)
if err != nil {
return "", false, fmt.Errorf("get group failed: %w", err)
}
return group.Platform, false, nil
}
return PlatformAnthropic, false, nil
}
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform
if useMixed {
platforms := []string{platform, PlatformAntigravity}
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
}
if err != nil {
return nil, useMixed, err
}
filtered := make([]Account, 0, len(accounts))
for _, acc := range accounts {
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
}
filtered = append(filtered, acc)
}
return filtered, useMixed, nil
}
var accounts []Account
var err error
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
} else if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
if err == nil && len(accounts) == 0 && hasForcePlatform {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
}
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
}
if err != nil {
return nil, useMixed, err
}
return accounts, useMixed, nil
}
func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool {
if account == nil {
return false
}
if useMixed {
if account.Platform == platform {
return true
}
return account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()
}
return account.Platform == platform
}
func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
if s.concurrencyService == nil {
return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil
}
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
}
func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
sort.SliceStable(accounts, func(i, j int) bool {
a, b := accounts[i], accounts[j]
if a.Priority != b.Priority {
return a.Priority < b.Priority
}
switch {
case a.LastUsedAt == nil && b.LastUsedAt != nil:
return true
case a.LastUsedAt != nil && b.LastUsedAt == nil:
return false
case a.LastUsedAt == nil && b.LastUsedAt == nil:
if preferOAuth && a.Type != b.Type {
return a.Type == AccountTypeOAuth
}
return false
default:
return a.LastUsedAt.Before(*b.LastUsedAt)
}
})
}
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离) // selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) { func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
preferOAuth := platform == PlatformGemini
// 1. 查询粘性会话 // 1. 查询粘性会话
if sessionHash != "" { if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
@@ -389,7 +762,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
case acc.LastUsedAt != nil && selected.LastUsedAt == nil: case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred) // keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil: case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
// keep selected (both never used) if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
selected = acc
}
default: default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) { if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc selected = acc
@@ -419,6 +794,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户 // 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户
func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) { func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) {
platforms := []string{nativePlatform, PlatformAntigravity} platforms := []string{nativePlatform, PlatformAntigravity}
preferOAuth := nativePlatform == PlatformGemini
// 1. 查询粘性会话 // 1. 查询粘性会话
if sessionHash != "" { if sessionHash != "" {
@@ -478,7 +854,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
case acc.LastUsedAt != nil && selected.LastUsedAt == nil: case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred) // keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil: case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
// keep selected (both never used) if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
selected = acc
}
default: default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) { if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc selected = acc
@@ -515,24 +893,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
} }
// IsAntigravityModelSupported 检查 Antigravity 平台是否支持指定模型 // IsAntigravityModelSupported 检查 Antigravity 平台是否支持指定模型
// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持
func IsAntigravityModelSupported(requestedModel string) bool { func IsAntigravityModelSupported(requestedModel string) bool {
// 直接支持的模型 return strings.HasPrefix(requestedModel, "claude-") ||
if antigravitySupportedModels[requestedModel] { strings.HasPrefix(requestedModel, "gemini-")
return true
}
// 可映射的模型
if _, ok := antigravityModelMapping[requestedModel]; ok {
return true
}
// Gemini 前缀透传
if strings.HasPrefix(requestedModel, "gemini-") {
return true
}
// Claude 模型支持(通过默认映射到 claude-sonnet-4-5
if strings.HasPrefix(requestedModel, "claude-") {
return true
}
return false
} }
// GetAccessToken 获取账号凭证 // GetAccessToken 获取账号凭证
@@ -684,6 +1048,30 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 处理错误响应(不可重试的错误) // 处理错误响应(不可重试的错误)
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
// 可选:对部分 400 触发 failover默认关闭以保持语义
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
respBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
// ReadAll failed, fall back to normal error handling without consuming the stream
return s.handleErrorResponse(ctx, resp, c, account)
}
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
if s.shouldFailoverOn400(respBody) {
if s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"Account %d: 400 error, attempting failover: %s",
account.ID,
truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
} else {
log.Printf("Account %d: 400 error, attempting failover", account.ID)
}
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
}
return s.handleErrorResponse(ctx, resp, c, account) return s.handleErrorResponse(ctx, resp, c, account)
} }
@@ -786,6 +1174,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 处理anthropic-beta headerOAuth账号需要特殊处理 // 处理anthropic-beta headerOAuth账号需要特殊处理
if tokenType == "oauth" { if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
// API-key仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
} }
return req, nil return req, nil
@@ -838,6 +1233,83 @@ func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string)
return claude.DefaultBetaHeader return claude.DefaultBetaHeader
} }
func requestNeedsBetaFeatures(body []byte) bool {
tools := gjson.GetBytes(body, "tools")
if tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 {
return true
}
if strings.EqualFold(gjson.GetBytes(body, "thinking.type").String(), "enabled") {
return true
}
return false
}
func defaultApiKeyBetaHeader(body []byte) string {
modelID := gjson.GetBytes(body, "model").String()
if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.ApiKeyHaikuBetaHeader
}
return claude.ApiKeyBetaHeader
}
func truncateForLog(b []byte, maxBytes int) string {
if maxBytes <= 0 {
maxBytes = 2048
}
if len(b) > maxBytes {
b = b[:maxBytes]
}
s := string(b)
// 保持一行,避免污染日志格式
s = strings.ReplaceAll(s, "\n", "\\n")
s = strings.ReplaceAll(s, "\r", "\\r")
return s
}
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
// 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
// 默认保守:无法识别则不切换。
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
if msg == "" {
return false
}
// 缺少/错误的 beta header换账号/链路可能成功(尤其是混合调度时)。
// 更精确匹配 beta 相关的兼容性问题,避免误触发切换。
if strings.Contains(msg, "anthropic-beta") ||
strings.Contains(msg, "beta feature") ||
strings.Contains(msg, "requires beta") {
return true
}
// thinking/tool streaming 等兼容性约束(常见于中间转换链路)
if strings.Contains(msg, "thinking") || strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") {
return true
}
if strings.Contains(msg, "tool_use") || strings.Contains(msg, "tool_result") || strings.Contains(msg, "tools") {
return true
}
return false
}
func extractUpstreamErrorMessage(body []byte) string {
// Claude 风格:{"type":"error","error":{"type":"...","message":"..."}}
if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" {
inner := strings.TrimSpace(m)
// 有些上游会把完整 JSON 作为字符串塞进 message
if strings.HasPrefix(inner, "{") {
if innerMsg := gjson.Get(inner, "error.message").String(); strings.TrimSpace(innerMsg) != "" {
return innerMsg
}
}
return m
}
// 兜底:尝试顶层 message
return gjson.GetBytes(body, "message").String()
}
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
@@ -850,6 +1322,16 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
switch resp.StatusCode { switch resp.StatusCode {
case 400: case 400:
// 仅记录上游错误摘要(避免输出请求内容);需要时可通过配置打开
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"Upstream 400 error (account=%d platform=%s type=%s): %s",
account.ID,
account.Platform,
account.Type,
truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
}
c.Data(http.StatusBadRequest, "application/json", body) c.Data(http.StatusBadRequest, "application/json", body)
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
case 401: case 401:
@@ -1329,6 +1811,18 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
// 标记账号状态429/529等 // 标记账号状态429/529等
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// 记录上游错误摘要便于排障(不回显请求内容)
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"count_tokens upstream error %d (account=%d platform=%s type=%s): %s",
resp.StatusCode,
account.ID,
account.Platform,
account.Type,
truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
}
// 返回简化的错误响应 // 返回简化的错误响应
errMsg := "Upstream request failed" errMsg := "Upstream request failed"
switch resp.StatusCode { switch resp.StatusCode {
@@ -1409,6 +1903,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:处理 anthropic-beta header // OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" { if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
// API-key与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
} }
return req, nil return req, nil
@@ -1424,3 +1925,58 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m
}, },
}) })
} }
// GetAvailableModels returns the list of models available for a group
// It aggregates model_mapping keys from all schedulable accounts in the group
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupID(ctx, *groupID)
} else {
accounts, err = s.accountRepo.ListSchedulable(ctx)
}
if err != nil || len(accounts) == 0 {
return nil
}
// Filter by platform if specified
if platform != "" {
filtered := make([]Account, 0)
for _, acc := range accounts {
if acc.Platform == platform {
filtered = append(filtered, acc)
}
}
accounts = filtered
}
// Collect unique models from all accounts
modelSet := make(map[string]struct{})
hasAnyMapping := false
for _, acc := range accounts {
mapping := acc.GetModelMapping()
if len(mapping) > 0 {
hasAnyMapping = true
for model := range mapping {
modelSet[model] = struct{}{}
}
}
}
// If no account has model_mapping, return nil (use default)
if !hasAnyMapping {
return nil
}
// Convert to slice
models := make([]string, 0, len(modelSet))
for model := range modelSet {
models = append(models, model)
}
return models
}

View File

@@ -116,8 +116,20 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
valid = true valid = true
} }
if valid { if valid {
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL) usable := true
return account, nil if s.rateLimitService != nil && requestedModel != "" {
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
if err != nil {
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
}
if !ok {
usable = false
}
}
if usable {
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
return account, nil
}
} }
} }
} }
@@ -157,6 +169,15 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue continue
} }
if s.rateLimitService != nil && requestedModel != "" {
ok, err := s.rateLimitService.PreCheckUsage(ctx, acc, requestedModel)
if err != nil {
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", acc.ID, err)
}
if !ok {
continue
}
}
if selected == nil { if selected == nil {
selected = acc selected = acc
continue continue
@@ -1886,13 +1907,44 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont
if statusCode != 429 { if statusCode != 429 {
return return
} }
oauthType := account.GeminiOAuthType()
tierID := account.GeminiTierID()
projectID := strings.TrimSpace(account.GetCredential("project_id"))
isCodeAssist := account.IsGeminiCodeAssist()
resetAt := ParseGeminiRateLimitResetTime(body) resetAt := ParseGeminiRateLimitResetTime(body)
if resetAt == nil { if resetAt == nil {
ra := time.Now().Add(5 * time.Minute) // 根据账号类型使用不同的默认重置时间
var ra time.Time
if isCodeAssist {
// Code Assist: fallback cooldown by tier
cooldown := geminiCooldownForTier(tierID)
if s.rateLimitService != nil {
cooldown = s.rateLimitService.GeminiCooldown(ctx, account)
}
ra = time.Now().Add(cooldown)
log.Printf("[Gemini 429] Account %d (Code Assist, tier=%s, project=%s) rate limited, cooldown=%v", account.ID, tierID, projectID, time.Until(ra).Truncate(time.Second))
} else {
// API Key / AI Studio OAuth: PST 午夜
if ts := nextGeminiDailyResetUnix(); ts != nil {
ra = time.Unix(*ts, 0)
log.Printf("[Gemini 429] Account %d (API Key/AI Studio, type=%s) rate limited, reset at PST midnight (%v)", account.ID, account.Type, ra)
} else {
// 兜底5 分钟
ra = time.Now().Add(5 * time.Minute)
log.Printf("[Gemini 429] Account %d rate limited, fallback to 5min", account.ID)
}
}
_ = s.accountRepo.SetRateLimited(ctx, account.ID, ra) _ = s.accountRepo.SetRateLimited(ctx, account.ID, ra)
return return
} }
_ = s.accountRepo.SetRateLimited(ctx, account.ID, time.Unix(*resetAt, 0))
// 使用解析到的重置时间
resetTime := time.Unix(*resetAt, 0)
_ = s.accountRepo.SetRateLimited(ctx, account.ID, resetTime)
log.Printf("[Gemini 429] Account %d rate limited until %v (oauth_type=%s, tier=%s)",
account.ID, resetTime, oauthType, tierID)
} }
// ParseGeminiRateLimitResetTime 解析 Gemini 格式的 429 响应,返回重置时间的 Unix 时间戳 // ParseGeminiRateLimitResetTime 解析 Gemini 格式的 429 响应,返回重置时间的 Unix 时间戳
@@ -1948,16 +2000,7 @@ func looksLikeGeminiDailyQuota(message string) bool {
} }
func nextGeminiDailyResetUnix() *int64 { func nextGeminiDailyResetUnix() *int64 {
loc, err := time.LoadLocation("America/Los_Angeles") reset := geminiDailyResetTime(time.Now())
if err != nil {
// Fallback: PST without DST.
loc = time.FixedZone("PST", -8*3600)
}
now := time.Now().In(loc)
reset := time.Date(now.Year(), now.Month(), now.Day(), 0, 5, 0, 0, loc)
if !reset.After(now) {
reset = reset.Add(24 * time.Hour)
}
ts := reset.Unix() ts := reset.Unix()
return &ts return &ts
} }
@@ -2278,11 +2321,13 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
"properties": map[string]any{}, "properties": map[string]any{},
} }
} }
// 清理 JSON Schema
cleanedParams := cleanToolSchema(params)
funcDecls = append(funcDecls, map[string]any{ funcDecls = append(funcDecls, map[string]any{
"name": name, "name": name,
"description": desc, "description": desc,
"parameters": params, "parameters": cleanedParams,
}) })
} }
@@ -2296,6 +2341,41 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
} }
} }
// cleanToolSchema 清理工具的 JSON Schema移除 Gemini 不支持的字段
func cleanToolSchema(schema any) any {
if schema == nil {
return nil
}
switch v := schema.(type) {
case map[string]any:
cleaned := make(map[string]any)
for key, value := range v {
// 跳过不支持的字段
if key == "$schema" || key == "$id" || key == "$ref" ||
key == "additionalProperties" || key == "minLength" ||
key == "maxLength" || key == "minItems" || key == "maxItems" {
continue
}
// 递归清理嵌套对象
cleaned[key] = cleanToolSchema(value)
}
// 规范化 type 字段为大写
if typeVal, ok := cleaned["type"].(string); ok {
cleaned["type"] = strings.ToUpper(typeVal)
}
return cleaned
case []any:
cleaned := make([]any, len(v))
for i, item := range v {
cleaned[i] = cleanToolSchema(item)
}
return cleaned
default:
return v
}
}
func convertClaudeGenerationConfig(req map[string]any) map[string]any { func convertClaudeGenerationConfig(req map[string]any) map[string]any {
out := make(map[string]any) out := make(map[string]any)
if mt, ok := asInt(req["max_tokens"]); ok && mt > 0 { if mt, ok := asInt(req["max_tokens"]); ok && mt > 0 {

View File

@@ -0,0 +1,128 @@
package service
import (
"testing"
)
// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换
func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
tests := []struct {
name string
tools any
expectedLen int
description string
}{
{
name: "Standard tools",
tools: []any{
map[string]any{
"name": "get_weather",
"description": "Get weather info",
"input_schema": map[string]any{"type": "object"},
},
},
expectedLen: 1,
description: "标准工具格式应该正常转换",
},
{
name: "Custom type tool (MCP format)",
tools: []any{
map[string]any{
"type": "custom",
"name": "mcp_tool",
"custom": map[string]any{
"description": "MCP tool description",
"input_schema": map[string]any{"type": "object"},
},
},
},
expectedLen: 1,
description: "Custom类型工具应该从custom字段读取",
},
{
name: "Mixed standard and custom tools",
tools: []any{
map[string]any{
"name": "standard_tool",
"description": "Standard",
"input_schema": map[string]any{"type": "object"},
},
map[string]any{
"type": "custom",
"name": "custom_tool",
"custom": map[string]any{
"description": "Custom",
"input_schema": map[string]any{"type": "object"},
},
},
},
expectedLen: 1,
description: "混合工具应该都能正确转换",
},
{
name: "Custom tool without custom field",
tools: []any{
map[string]any{
"type": "custom",
"name": "invalid_custom",
// 缺少 custom 字段
},
},
expectedLen: 0, // 应该被跳过
description: "缺少custom字段的custom工具应该被跳过",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := convertClaudeToolsToGeminiTools(tt.tools)
if tt.expectedLen == 0 {
if result != nil {
t.Errorf("%s: expected nil result, got %v", tt.description, result)
}
return
}
if result == nil {
t.Fatalf("%s: expected non-nil result", tt.description)
}
if len(result) != 1 {
t.Errorf("%s: expected 1 tool declaration, got %d", tt.description, len(result))
return
}
toolDecl, ok := result[0].(map[string]any)
if !ok {
t.Fatalf("%s: result[0] is not map[string]any", tt.description)
}
funcDecls, ok := toolDecl["functionDeclarations"].([]any)
if !ok {
t.Fatalf("%s: functionDeclarations is not []any", tt.description)
}
toolsArr, _ := tt.tools.([]any)
expectedFuncCount := 0
for _, tool := range toolsArr {
toolMap, _ := tool.(map[string]any)
if toolMap["name"] != "" {
// 检查是否为有效的custom工具
if toolMap["type"] == "custom" {
if toolMap["custom"] != nil {
expectedFuncCount++
}
} else {
expectedFuncCount++
}
}
}
if len(funcDecls) != expectedFuncCount {
t.Errorf("%s: expected %d function declarations, got %d",
tt.description, expectedFuncCount, len(funcDecls))
}
})
}
}

View File

@@ -25,6 +25,16 @@ func (m *mockAccountRepoForGemini) GetByID(ctx context.Context, id int64) (*Acco
return nil, errors.New("account not found") return nil, errors.New("account not found")
} }
func (m *mockAccountRepoForGemini) GetByIDs(ctx context.Context, ids []int64) ([]*Account, error) {
var result []*Account
for _, id := range ids {
if acc, ok := m.accountsByID[id]; ok {
result = append(result, acc)
}
}
return result, nil
}
func (m *mockAccountRepoForGemini) ExistsByID(ctx context.Context, id int64) (bool, error) { func (m *mockAccountRepoForGemini) ExistsByID(ctx context.Context, id int64) (bool, error) {
if m.accountsByID == nil { if m.accountsByID == nil {
return false, nil return false, nil

View File

@@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"regexp"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -16,6 +17,26 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
) )
const (
TierAIPremium = "AI_PREMIUM"
TierGoogleOneStandard = "GOOGLE_ONE_STANDARD"
TierGoogleOneBasic = "GOOGLE_ONE_BASIC"
TierFree = "FREE"
TierGoogleOneUnknown = "GOOGLE_ONE_UNKNOWN"
TierGoogleOneUnlimited = "GOOGLE_ONE_UNLIMITED"
)
const (
GB = 1024 * 1024 * 1024
TB = 1024 * GB
StorageTierUnlimited = 100 * TB // 100TB
StorageTierAIPremium = 2 * TB // 2TB
StorageTierStandard = 200 * GB // 200GB
StorageTierBasic = 100 * GB // 100GB
StorageTierFree = 15 * GB // 15GB
)
type GeminiOAuthService struct { type GeminiOAuthService struct {
sessionStore *geminicli.SessionStore sessionStore *geminicli.SessionStore
proxyRepo ProxyRepository proxyRepo ProxyRepository
@@ -88,13 +109,14 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
// OAuth client selection: // OAuth client selection:
// - code_assist: always use built-in Gemini CLI OAuth client (public), regardless of configured client_id/secret. // - code_assist: always use built-in Gemini CLI OAuth client (public), regardless of configured client_id/secret.
// - google_one: same as code_assist, uses built-in client for personal Google accounts.
// - ai_studio: requires a user-provided OAuth client. // - ai_studio: requires a user-provided OAuth client.
oauthCfg := geminicli.OAuthConfig{ oauthCfg := geminicli.OAuthConfig{
ClientID: s.cfg.Gemini.OAuth.ClientID, ClientID: s.cfg.Gemini.OAuth.ClientID,
ClientSecret: s.cfg.Gemini.OAuth.ClientSecret, ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
Scopes: s.cfg.Gemini.OAuth.Scopes, Scopes: s.cfg.Gemini.OAuth.Scopes,
} }
if oauthType == "code_assist" { if oauthType == "code_assist" || oauthType == "google_one" {
oauthCfg.ClientID = "" oauthCfg.ClientID = ""
oauthCfg.ClientSecret = "" oauthCfg.ClientSecret = ""
} }
@@ -155,14 +177,152 @@ type GeminiExchangeCodeInput struct {
} }
type GeminiTokenInfo struct { type GeminiTokenInfo struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"` ExpiresIn int64 `json:"expires_in"`
ExpiresAt int64 `json:"expires_at"` ExpiresAt int64 `json:"expires_at"`
TokenType string `json:"token_type"` TokenType string `json:"token_type"`
Scope string `json:"scope,omitempty"` Scope string `json:"scope,omitempty"`
ProjectID string `json:"project_id,omitempty"` ProjectID string `json:"project_id,omitempty"`
OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio" OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio"
TierID string `json:"tier_id,omitempty"` // Gemini Code Assist tier: LEGACY/PRO/ULTRA
Extra map[string]any `json:"extra,omitempty"` // Drive metadata
}
// validateTierID validates tier_id format and length
func validateTierID(tierID string) error {
if tierID == "" {
return nil // Empty is allowed
}
if len(tierID) > 64 {
return fmt.Errorf("tier_id exceeds maximum length of 64 characters")
}
// Allow alphanumeric, underscore, hyphen, and slash (for tier paths)
if !regexp.MustCompile(`^[a-zA-Z0-9_/-]+$`).MatchString(tierID) {
return fmt.Errorf("tier_id contains invalid characters")
}
return nil
}
// extractTierIDFromAllowedTiers extracts tierID from LoadCodeAssist response
// Prioritizes IsDefault tier, falls back to first non-empty tier
func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string {
tierID := "LEGACY"
// First pass: look for default tier
for _, tier := range allowedTiers {
if tier.IsDefault && strings.TrimSpace(tier.ID) != "" {
tierID = strings.TrimSpace(tier.ID)
break
}
}
// Second pass: if still LEGACY, take first non-empty tier
if tierID == "LEGACY" {
for _, tier := range allowedTiers {
if strings.TrimSpace(tier.ID) != "" {
tierID = strings.TrimSpace(tier.ID)
break
}
}
}
return tierID
}
// inferGoogleOneTier infers Google One tier from Drive storage limit
func inferGoogleOneTier(storageBytes int64) string {
if storageBytes <= 0 {
return TierGoogleOneUnknown
}
if storageBytes > StorageTierUnlimited {
return TierGoogleOneUnlimited
}
if storageBytes >= StorageTierAIPremium {
return TierAIPremium
}
if storageBytes >= StorageTierStandard {
return TierGoogleOneStandard
}
if storageBytes >= StorageTierBasic {
return TierGoogleOneBasic
}
if storageBytes >= StorageTierFree {
return TierFree
}
return TierGoogleOneUnknown
}
// fetchGoogleOneTier fetches Google One tier from Drive API
func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken, proxyURL string) (string, *geminicli.DriveStorageInfo, error) {
driveClient := geminicli.NewDriveClient()
storageInfo, err := driveClient.GetStorageQuota(ctx, accessToken, proxyURL)
if err != nil {
// Check if it's a 403 (scope not granted)
if strings.Contains(err.Error(), "status 403") {
fmt.Printf("[GeminiOAuth] Drive API scope not available: %v\n", err)
return TierGoogleOneUnknown, nil, err
}
// Other errors
fmt.Printf("[GeminiOAuth] Failed to fetch Drive storage: %v\n", err)
return TierGoogleOneUnknown, nil, err
}
tierID := inferGoogleOneTier(storageInfo.Limit)
return tierID, storageInfo, nil
}
// RefreshAccountGoogleOneTier 刷新单个账号的 Google One Tier
func (s *GeminiOAuthService) RefreshAccountGoogleOneTier(
ctx context.Context,
account *Account,
) (tierID string, extra map[string]any, credentials map[string]any, err error) {
if account == nil {
return "", nil, nil, fmt.Errorf("account is nil")
}
// 验证账号类型
oauthType, ok := account.Credentials["oauth_type"].(string)
if !ok || oauthType != "google_one" {
return "", nil, nil, fmt.Errorf("not a google_one OAuth account")
}
// 获取 access_token
accessToken, ok := account.Credentials["access_token"].(string)
if !ok || accessToken == "" {
return "", nil, nil, fmt.Errorf("missing access_token")
}
// 获取 proxy URL
var proxyURL string
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 调用 Drive API
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, accessToken, proxyURL)
if err != nil {
return "", nil, nil, err
}
// 构建 extra 数据(保留原有 extra 字段)
extra = make(map[string]any)
for k, v := range account.Extra {
extra[k] = v
}
if storageInfo != nil {
extra["drive_storage_limit"] = storageInfo.Limit
extra["drive_storage_usage"] = storageInfo.Usage
extra["drive_tier_updated_at"] = time.Now().Format(time.RFC3339)
}
// 构建 credentials 数据
credentials = make(map[string]any)
for k, v := range account.Credentials {
credentials[k] = v
}
credentials["tier_id"] = tierID
return tierID, extra, credentials, nil
} }
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) { func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
@@ -219,26 +379,78 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
sessionProjectID := strings.TrimSpace(session.ProjectID) sessionProjectID := strings.TrimSpace(session.ProjectID)
s.sessionStore.Delete(input.SessionID) s.sessionStore.Delete(input.SessionID)
// 计算过期时间减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差 // 计算过期时间减去 5 分钟安全时间窗口考虑网络延迟和时钟偏差
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300 // 同时设置下界保护,防止 expires_in 过小导致过去时间(引发刷新风暴)
const safetyWindow = 300 // 5 minutes
const minTTL = 30 // minimum 30 seconds
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - safetyWindow
minExpiresAt := time.Now().Unix() + minTTL
if expiresAt < minExpiresAt {
expiresAt = minExpiresAt
}
projectID := sessionProjectID projectID := sessionProjectID
var tierID string
// 对于 code_assist 模式project_id 是必需的 // 对于 code_assist 模式project_id 是必需的,需要调用 Code Assist API
// 对于 google_one 模式,使用个人 Google 账号,不需要 project_id配额由 Google 网关自动识别
// 对于 ai_studio 模式project_id 是可选的(不影响使用 AI Studio API // 对于 ai_studio 模式project_id 是可选的(不影响使用 AI Studio API
if oauthType == "code_assist" { switch oauthType {
case "code_assist":
if projectID == "" { if projectID == "" {
var err error var err error
projectID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) projectID, tierID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
if err != nil { if err != nil {
// 记录警告但不阻断流程,允许后续补充 project_id // 记录警告但不阻断流程,允许后续补充 project_id
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err) fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err)
} }
} else {
// 用户手动填了 project_id仍需调用 LoadCodeAssist 获取 tierID
_, fetchedTierID, err := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
if err != nil {
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch tierID: %v\n", err)
} else {
tierID = fetchedTierID
}
} }
if strings.TrimSpace(projectID) == "" { if strings.TrimSpace(projectID) == "" {
return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project") return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project")
} }
// tierID 缺失时使用默认值
if tierID == "" {
tierID = "LEGACY"
}
case "google_one":
// Attempt to fetch Drive storage tier
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenResp.AccessToken, proxyURL)
if err != nil {
// Log warning but don't block - use fallback
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch Drive tier: %v\n", err)
tierID = TierGoogleOneUnknown
}
// Store Drive info in extra field for caching
if storageInfo != nil {
tokenInfo := &GeminiTokenInfo{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
TokenType: tokenResp.TokenType,
ExpiresIn: tokenResp.ExpiresIn,
ExpiresAt: expiresAt,
Scope: tokenResp.Scope,
ProjectID: projectID,
TierID: tierID,
OAuthType: oauthType,
Extra: map[string]any{
"drive_storage_limit": storageInfo.Limit,
"drive_storage_usage": storageInfo.Usage,
"drive_tier_updated_at": time.Now().Format(time.RFC3339),
},
}
return tokenInfo, nil
}
} }
// ai_studio 模式不设置 tierID保持为空
return &GeminiTokenInfo{ return &GeminiTokenInfo{
AccessToken: tokenResp.AccessToken, AccessToken: tokenResp.AccessToken,
@@ -248,6 +460,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
ExpiresAt: expiresAt, ExpiresAt: expiresAt,
Scope: tokenResp.Scope, Scope: tokenResp.Scope,
ProjectID: projectID, ProjectID: projectID,
TierID: tierID,
OAuthType: oauthType, OAuthType: oauthType,
}, nil }, nil
} }
@@ -266,8 +479,15 @@ func (s *GeminiOAuthService) RefreshToken(ctx context.Context, oauthType, refres
tokenResp, err := s.oauthClient.RefreshToken(ctx, oauthType, refreshToken, proxyURL) tokenResp, err := s.oauthClient.RefreshToken(ctx, oauthType, refreshToken, proxyURL)
if err == nil { if err == nil {
// 计算过期时间减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差 // 计算过期时间减去 5 分钟安全时间窗口考虑网络延迟和时钟偏差
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300 // 同时设置下界保护,防止 expires_in 过小导致过去时间(引发刷新风暴)
const safetyWindow = 300 // 5 minutes
const minTTL = 30 // minimum 30 seconds
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - safetyWindow
minExpiresAt := time.Now().Unix() + minTTL
if expiresAt < minExpiresAt {
expiresAt = minExpiresAt
}
return &GeminiTokenInfo{ return &GeminiTokenInfo{
AccessToken: tokenResp.AccessToken, AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken, RefreshToken: tokenResp.RefreshToken,
@@ -354,18 +574,75 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
tokenInfo.ProjectID = existingProjectID tokenInfo.ProjectID = existingProjectID
} }
// 尝试从账号凭证获取 tierID向后兼容
existingTierID := strings.TrimSpace(account.GetCredential("tier_id"))
// For Code Assist, project_id is required. Auto-detect if missing. // For Code Assist, project_id is required. Auto-detect if missing.
// For AI Studio OAuth, project_id is optional and should not block refresh. // For AI Studio OAuth, project_id is optional and should not block refresh.
if oauthType == "code_assist" && strings.TrimSpace(tokenInfo.ProjectID) == "" { switch oauthType {
projectID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL) case "code_assist":
if err != nil { // 先设置默认值或保留旧值,确保 tier_id 始终有值
return nil, fmt.Errorf("failed to auto-detect project_id: %w", err) if existingTierID != "" {
tokenInfo.TierID = existingTierID
} else {
tokenInfo.TierID = "LEGACY" // 默认值
} }
projectID = strings.TrimSpace(projectID)
if projectID == "" { // 尝试自动探测 project_id 和 tier_id
needDetect := strings.TrimSpace(tokenInfo.ProjectID) == "" || existingTierID == ""
if needDetect {
projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
if err != nil {
fmt.Printf("[GeminiOAuth] Warning: failed to auto-detect project/tier: %v\n", err)
} else {
if strings.TrimSpace(tokenInfo.ProjectID) == "" && projectID != "" {
tokenInfo.ProjectID = projectID
}
// 只有当原来没有 tier_id 且探测成功时才更新
if existingTierID == "" && tierID != "" {
tokenInfo.TierID = tierID
}
}
}
if strings.TrimSpace(tokenInfo.ProjectID) == "" {
return nil, fmt.Errorf("failed to auto-detect project_id: empty result") return nil, fmt.Errorf("failed to auto-detect project_id: empty result")
} }
tokenInfo.ProjectID = projectID case "google_one":
// Check if tier cache is stale (> 24 hours)
needsRefresh := true
if account.Extra != nil {
if updatedAtStr, ok := account.Extra["drive_tier_updated_at"].(string); ok {
if updatedAt, err := time.Parse(time.RFC3339, updatedAtStr); err == nil {
if time.Since(updatedAt) <= 24*time.Hour {
needsRefresh = false
// Use cached tier
if existingTierID != "" {
tokenInfo.TierID = existingTierID
}
}
}
}
}
if needsRefresh {
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenInfo.AccessToken, proxyURL)
if err == nil && storageInfo != nil {
tokenInfo.TierID = tierID
tokenInfo.Extra = map[string]any{
"drive_storage_limit": storageInfo.Limit,
"drive_storage_usage": storageInfo.Usage,
"drive_tier_updated_at": time.Now().Format(time.RFC3339),
}
} else {
// Fallback to cached or unknown
if existingTierID != "" {
tokenInfo.TierID = existingTierID
} else {
tokenInfo.TierID = TierGoogleOneUnknown
}
}
}
} }
return tokenInfo, nil return tokenInfo, nil
@@ -388,9 +665,22 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo)
if tokenInfo.ProjectID != "" { if tokenInfo.ProjectID != "" {
creds["project_id"] = tokenInfo.ProjectID creds["project_id"] = tokenInfo.ProjectID
} }
if tokenInfo.TierID != "" {
// Validate tier_id before storing
if err := validateTierID(tokenInfo.TierID); err == nil {
creds["tier_id"] = tokenInfo.TierID
}
// Silently skip invalid tier_id (don't block account creation)
}
if tokenInfo.OAuthType != "" { if tokenInfo.OAuthType != "" {
creds["oauth_type"] = tokenInfo.OAuthType creds["oauth_type"] = tokenInfo.OAuthType
} }
// Store extra metadata (Drive info) if present
if len(tokenInfo.Extra) > 0 {
for k, v := range tokenInfo.Extra {
creds[k] = v
}
}
return creds return creds
} }
@@ -398,33 +688,22 @@ func (s *GeminiOAuthService) Stop() {
s.sessionStore.Stop() s.sessionStore.Stop()
} }
func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, error) { func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, string, error) {
if s.codeAssist == nil { if s.codeAssist == nil {
return "", errors.New("code assist client not configured") return "", "", errors.New("code assist client not configured")
} }
loadResp, loadErr := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil) loadResp, loadErr := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil)
if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" {
return strings.TrimSpace(loadResp.CloudAICompanionProject), nil
}
// Pick tier from allowedTiers; if no default tier is marked, pick the first non-empty tier ID. // Extract tierID from response (works whether CloudAICompanionProject is set or not)
tierID := "LEGACY" tierID := "LEGACY"
if loadResp != nil { if loadResp != nil {
for _, tier := range loadResp.AllowedTiers { tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers)
if tier.IsDefault && strings.TrimSpace(tier.ID) != "" { }
tierID = strings.TrimSpace(tier.ID)
break // If LoadCodeAssist returned a project, use it
} if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" {
} return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil
if strings.TrimSpace(tierID) == "" || tierID == "LEGACY" {
for _, tier := range loadResp.AllowedTiers {
if strings.TrimSpace(tier.ID) != "" {
tierID = strings.TrimSpace(tier.ID)
break
}
}
}
} }
req := &geminicli.OnboardUserRequest{ req := &geminicli.OnboardUserRequest{
@@ -443,39 +722,39 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
// If Code Assist onboarding fails (e.g. INVALID_ARGUMENT), fallback to Cloud Resource Manager projects. // If Code Assist onboarding fails (e.g. INVALID_ARGUMENT), fallback to Cloud Resource Manager projects.
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" { if fbErr == nil && strings.TrimSpace(fallback) != "" {
return strings.TrimSpace(fallback), nil return strings.TrimSpace(fallback), tierID, nil
} }
return "", err return "", tierID, err
} }
if resp.Done { if resp.Done {
if resp.Response != nil && resp.Response.CloudAICompanionProject != nil { if resp.Response != nil && resp.Response.CloudAICompanionProject != nil {
switch v := resp.Response.CloudAICompanionProject.(type) { switch v := resp.Response.CloudAICompanionProject.(type) {
case string: case string:
return strings.TrimSpace(v), nil return strings.TrimSpace(v), tierID, nil
case map[string]any: case map[string]any:
if id, ok := v["id"].(string); ok { if id, ok := v["id"].(string); ok {
return strings.TrimSpace(id), nil return strings.TrimSpace(id), tierID, nil
} }
} }
} }
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" { if fbErr == nil && strings.TrimSpace(fallback) != "" {
return strings.TrimSpace(fallback), nil return strings.TrimSpace(fallback), tierID, nil
} }
return "", errors.New("onboardUser completed but no project_id returned") return "", tierID, errors.New("onboardUser completed but no project_id returned")
} }
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
} }
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" { if fbErr == nil && strings.TrimSpace(fallback) != "" {
return strings.TrimSpace(fallback), nil return strings.TrimSpace(fallback), tierID, nil
} }
if loadErr != nil { if loadErr != nil {
return "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts) return "", tierID, fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts)
} }
return "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts) return "", tierID, fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts)
} }
type googleCloudProject struct { type googleCloudProject struct {

View File

@@ -0,0 +1,51 @@
package service
import "testing"
func TestInferGoogleOneTier(t *testing.T) {
tests := []struct {
name string
storageBytes int64
expectedTier string
}{
{"Negative storage", -1, TierGoogleOneUnknown},
{"Zero storage", 0, TierGoogleOneUnknown},
// Free tier boundary (15GB)
{"Below free tier", 10 * GB, TierGoogleOneUnknown},
{"Just below free tier", StorageTierFree - 1, TierGoogleOneUnknown},
{"Free tier (15GB)", StorageTierFree, TierFree},
// Basic tier boundary (100GB)
{"Between free and basic", 50 * GB, TierFree},
{"Just below basic tier", StorageTierBasic - 1, TierFree},
{"Basic tier (100GB)", StorageTierBasic, TierGoogleOneBasic},
// Standard tier boundary (200GB)
{"Between basic and standard", 150 * GB, TierGoogleOneBasic},
{"Just below standard tier", StorageTierStandard - 1, TierGoogleOneBasic},
{"Standard tier (200GB)", StorageTierStandard, TierGoogleOneStandard},
// AI Premium tier boundary (2TB)
{"Between standard and premium", 1 * TB, TierGoogleOneStandard},
{"Just below AI Premium tier", StorageTierAIPremium - 1, TierGoogleOneStandard},
{"AI Premium tier (2TB)", StorageTierAIPremium, TierAIPremium},
// Unlimited tier boundary (> 100TB)
{"Between premium and unlimited", 50 * TB, TierAIPremium},
{"At unlimited threshold (100TB)", StorageTierUnlimited, TierAIPremium},
{"Unlimited tier (100TB+)", StorageTierUnlimited + 1, TierGoogleOneUnlimited},
{"Unlimited tier (101TB+)", 101 * TB, TierGoogleOneUnlimited},
{"Very large storage", 1000 * TB, TierGoogleOneUnlimited},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := inferGoogleOneTier(tt.storageBytes)
if result != tt.expectedTier {
t.Errorf("inferGoogleOneTier(%d) = %s, want %s",
tt.storageBytes, result, tt.expectedTier)
}
})
}
}

View File

@@ -0,0 +1,268 @@
package service
import (
"context"
"encoding/json"
"errors"
"log"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
)
type geminiModelClass string
const (
geminiModelPro geminiModelClass = "pro"
geminiModelFlash geminiModelClass = "flash"
)
type GeminiDailyQuota struct {
ProRPD int64
FlashRPD int64
}
type GeminiTierPolicy struct {
Quota GeminiDailyQuota
Cooldown time.Duration
}
type GeminiQuotaPolicy struct {
tiers map[string]GeminiTierPolicy
}
type GeminiUsageTotals struct {
ProRequests int64
FlashRequests int64
ProTokens int64
FlashTokens int64
ProCost float64
FlashCost float64
}
const geminiQuotaCacheTTL = time.Minute
type geminiQuotaOverrides struct {
Tiers map[string]config.GeminiTierQuotaConfig `json:"tiers"`
}
type GeminiQuotaService struct {
cfg *config.Config
settingRepo SettingRepository
mu sync.Mutex
cachedAt time.Time
policy *GeminiQuotaPolicy
}
func NewGeminiQuotaService(cfg *config.Config, settingRepo SettingRepository) *GeminiQuotaService {
return &GeminiQuotaService{
cfg: cfg,
settingRepo: settingRepo,
}
}
func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
if s == nil {
return newGeminiQuotaPolicy()
}
now := time.Now()
s.mu.Lock()
if s.policy != nil && now.Sub(s.cachedAt) < geminiQuotaCacheTTL {
policy := s.policy
s.mu.Unlock()
return policy
}
s.mu.Unlock()
policy := newGeminiQuotaPolicy()
if s.cfg != nil {
policy.ApplyOverrides(s.cfg.Gemini.Quota.Tiers)
if strings.TrimSpace(s.cfg.Gemini.Quota.Policy) != "" {
var overrides geminiQuotaOverrides
if err := json.Unmarshal([]byte(s.cfg.Gemini.Quota.Policy), &overrides); err != nil {
log.Printf("gemini quota: parse config policy failed: %v", err)
} else {
policy.ApplyOverrides(overrides.Tiers)
}
}
}
if s.settingRepo != nil {
value, err := s.settingRepo.GetValue(ctx, SettingKeyGeminiQuotaPolicy)
if err != nil && !errors.Is(err, ErrSettingNotFound) {
log.Printf("gemini quota: load setting failed: %v", err)
} else if strings.TrimSpace(value) != "" {
var overrides geminiQuotaOverrides
if err := json.Unmarshal([]byte(value), &overrides); err != nil {
log.Printf("gemini quota: parse setting failed: %v", err)
} else {
policy.ApplyOverrides(overrides.Tiers)
}
}
}
s.mu.Lock()
s.policy = policy
s.cachedAt = now
s.mu.Unlock()
return policy
}
func (s *GeminiQuotaService) QuotaForAccount(ctx context.Context, account *Account) (GeminiDailyQuota, bool) {
if account == nil || !account.IsGeminiCodeAssist() {
return GeminiDailyQuota{}, false
}
policy := s.Policy(ctx)
return policy.QuotaForTier(account.GeminiTierID())
}
func (s *GeminiQuotaService) CooldownForTier(ctx context.Context, tierID string) time.Duration {
policy := s.Policy(ctx)
return policy.CooldownForTier(tierID)
}
func newGeminiQuotaPolicy() *GeminiQuotaPolicy {
return &GeminiQuotaPolicy{
tiers: map[string]GeminiTierPolicy{
"LEGACY": {Quota: GeminiDailyQuota{ProRPD: 50, FlashRPD: 1500}, Cooldown: 30 * time.Minute},
"PRO": {Quota: GeminiDailyQuota{ProRPD: 1500, FlashRPD: 4000}, Cooldown: 5 * time.Minute},
"ULTRA": {Quota: GeminiDailyQuota{ProRPD: 2000, FlashRPD: 0}, Cooldown: 5 * time.Minute},
},
}
}
func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuotaConfig) {
if p == nil || len(tiers) == 0 {
return
}
for rawID, override := range tiers {
tierID := normalizeGeminiTierID(rawID)
if tierID == "" {
continue
}
policy, ok := p.tiers[tierID]
if !ok {
policy = GeminiTierPolicy{Cooldown: 5 * time.Minute}
}
if override.ProRPD != nil {
policy.Quota.ProRPD = clampGeminiQuotaInt64(*override.ProRPD)
}
if override.FlashRPD != nil {
policy.Quota.FlashRPD = clampGeminiQuotaInt64(*override.FlashRPD)
}
if override.CooldownMinutes != nil {
minutes := clampGeminiQuotaInt(*override.CooldownMinutes)
policy.Cooldown = time.Duration(minutes) * time.Minute
}
p.tiers[tierID] = policy
}
}
func (p *GeminiQuotaPolicy) QuotaForTier(tierID string) (GeminiDailyQuota, bool) {
policy, ok := p.policyForTier(tierID)
if !ok {
return GeminiDailyQuota{}, false
}
return policy.Quota, true
}
func (p *GeminiQuotaPolicy) CooldownForTier(tierID string) time.Duration {
policy, ok := p.policyForTier(tierID)
if ok && policy.Cooldown > 0 {
return policy.Cooldown
}
return 5 * time.Minute
}
func (p *GeminiQuotaPolicy) policyForTier(tierID string) (GeminiTierPolicy, bool) {
if p == nil {
return GeminiTierPolicy{}, false
}
normalized := normalizeGeminiTierID(tierID)
if normalized == "" {
normalized = "LEGACY"
}
if policy, ok := p.tiers[normalized]; ok {
return policy, true
}
policy, ok := p.tiers["LEGACY"]
return policy, ok
}
func normalizeGeminiTierID(tierID string) string {
return strings.ToUpper(strings.TrimSpace(tierID))
}
func clampGeminiQuotaInt64(value int64) int64 {
if value < 0 {
return 0
}
return value
}
func clampGeminiQuotaInt(value int) int {
if value < 0 {
return 0
}
return value
}
func geminiCooldownForTier(tierID string) time.Duration {
policy := newGeminiQuotaPolicy()
return policy.CooldownForTier(tierID)
}
func geminiModelClassFromName(model string) geminiModelClass {
name := strings.ToLower(strings.TrimSpace(model))
if strings.Contains(name, "flash") || strings.Contains(name, "lite") {
return geminiModelFlash
}
return geminiModelPro
}
func geminiAggregateUsage(stats []usagestats.ModelStat) GeminiUsageTotals {
var totals GeminiUsageTotals
for _, stat := range stats {
switch geminiModelClassFromName(stat.Model) {
case geminiModelFlash:
totals.FlashRequests += stat.Requests
totals.FlashTokens += stat.TotalTokens
totals.FlashCost += stat.ActualCost
default:
totals.ProRequests += stat.Requests
totals.ProTokens += stat.TotalTokens
totals.ProCost += stat.ActualCost
}
}
return totals
}
func geminiQuotaLocation() *time.Location {
loc, err := time.LoadLocation("America/Los_Angeles")
if err != nil {
return time.FixedZone("PST", -8*3600)
}
return loc
}
func geminiDailyWindowStart(now time.Time) time.Time {
loc := geminiQuotaLocation()
localNow := now.In(loc)
return time.Date(localNow.Year(), localNow.Month(), localNow.Day(), 0, 0, 0, 0, loc)
}
func geminiDailyResetTime(now time.Time) time.Time {
loc := geminiQuotaLocation()
localNow := now.In(loc)
start := time.Date(localNow.Year(), localNow.Month(), localNow.Day(), 0, 0, 0, 0, loc)
reset := start.Add(24 * time.Hour)
if !reset.After(localNow) {
reset = reset.Add(24 * time.Hour)
}
return reset
}

View File

@@ -112,17 +112,21 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
} }
} }
detected, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL) detected, tierID, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL)
if err != nil { if err != nil {
log.Printf("[GeminiTokenProvider] Auto-detect project_id failed: %v, fallback to AI Studio API mode", err) log.Printf("[GeminiTokenProvider] Auto-detect project_id failed: %v, fallback to AI Studio API mode", err)
return accessToken, nil return accessToken, nil
} }
detected = strings.TrimSpace(detected) detected = strings.TrimSpace(detected)
tierID = strings.TrimSpace(tierID)
if detected != "" { if detected != "" {
if account.Credentials == nil { if account.Credentials == nil {
account.Credentials = make(map[string]any) account.Credentials = make(map[string]any)
} }
account.Credentials["project_id"] = detected account.Credentials["project_id"] = detected
if tierID != "" {
account.Credentials["tier_id"] = tierID
}
_ = p.accountRepo.Update(ctx, account) _ = p.accountRepo.Update(ctx, account)
} }
} }

View File

@@ -13,6 +13,7 @@ import (
"log" "log"
"net/http" "net/http"
"regexp" "regexp"
"sort"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -80,6 +81,7 @@ type OpenAIGatewayService struct {
userSubRepo UserSubscriptionRepository userSubRepo UserSubscriptionRepository
cache GatewayCache cache GatewayCache
cfg *config.Config cfg *config.Config
concurrencyService *ConcurrencyService
billingService *BillingService billingService *BillingService
rateLimitService *RateLimitService rateLimitService *RateLimitService
billingCacheService *BillingCacheService billingCacheService *BillingCacheService
@@ -95,6 +97,7 @@ func NewOpenAIGatewayService(
userSubRepo UserSubscriptionRepository, userSubRepo UserSubscriptionRepository,
cache GatewayCache, cache GatewayCache,
cfg *config.Config, cfg *config.Config,
concurrencyService *ConcurrencyService,
billingService *BillingService, billingService *BillingService,
rateLimitService *RateLimitService, rateLimitService *RateLimitService,
billingCacheService *BillingCacheService, billingCacheService *BillingCacheService,
@@ -108,6 +111,7 @@ func NewOpenAIGatewayService(
userSubRepo: userSubRepo, userSubRepo: userSubRepo,
cache: cache, cache: cache,
cfg: cfg, cfg: cfg,
concurrencyService: concurrencyService,
billingService: billingService, billingService: billingService,
rateLimitService: rateLimitService, rateLimitService: rateLimitService,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
@@ -126,6 +130,14 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string {
return hex.EncodeToString(hash[:]) return hex.EncodeToString(hash[:])
} }
// BindStickySession sets session -> account binding with standard TTL.
func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error {
if sessionHash == "" || accountID <= 0 {
return nil
}
return s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, accountID, openaiStickySessionTTL)
}
// SelectAccount selects an OpenAI account with sticky session support // SelectAccount selects an OpenAI account with sticky session support
func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) { func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
return s.SelectAccountForModel(ctx, groupID, sessionHash, "") return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
@@ -218,6 +230,254 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
return selected, nil return selected, nil
} }
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
cfg := s.schedulingConfig()
var stickyAccountID int64
if sessionHash != "" && s.cache != nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash); err == nil {
stickyAccountID = accountID
}
}
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
if err != nil {
return nil, err
}
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err == nil && result.Acquired {
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
}
accounts, err := s.listSchedulableAccounts(ctx, groupID)
if err != nil {
return nil, err
}
if len(accounts) == 0 {
return nil, errors.New("no available accounts")
}
isExcluded := func(accountID int64) bool {
if excludedIDs == nil {
return false
}
_, excluded := excludedIDs[accountID]
return excluded
}
// ============ Layer 1: Sticky session ============
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err == nil && account.IsSchedulable() && account.IsOpenAI() &&
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
_ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
}
}
// ============ Layer 2: Load-aware selection ============
candidates := make([]*Account, 0, len(accounts))
for i := range accounts {
acc := &accounts[i]
if isExcluded(acc.ID) {
continue
}
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue
}
candidates = append(candidates, acc)
}
if len(candidates) == 0 {
return nil, errors.New("no available accounts")
}
accountLoads := make([]AccountWithConcurrency, 0, len(candidates))
for _, acc := range candidates {
accountLoads = append(accountLoads, AccountWithConcurrency{
ID: acc.ID,
MaxConcurrency: acc.Concurrency,
})
}
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil {
ordered := append([]*Account(nil), candidates...)
sortAccountsByPriorityAndLastUsed(ordered, false)
for _, acc := range ordered {
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, acc.ID, openaiStickySessionTTL)
}
return &AccountSelectionResult{
Account: acc,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
}
} else {
type accountWithLoad struct {
account *Account
loadInfo *AccountLoadInfo
}
var available []accountWithLoad
for _, acc := range candidates {
loadInfo := loadMap[acc.ID]
if loadInfo == nil {
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
}
if loadInfo.LoadRate < 100 {
available = append(available, accountWithLoad{
account: acc,
loadInfo: loadInfo,
})
}
}
if len(available) > 0 {
sort.SliceStable(available, func(i, j int) bool {
a, b := available[i], available[j]
if a.account.Priority != b.account.Priority {
return a.account.Priority < b.account.Priority
}
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
}
switch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
return true
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
return false
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
return false
default:
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
for _, item := range available {
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL)
}
return &AccountSelectionResult{
Account: item.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
}
}
}
// ============ Layer 3: Fallback wait ============
sortAccountsByPriorityAndLastUsed(candidates, false)
for _, acc := range candidates {
return &AccountSelectionResult{
Account: acc,
WaitPlan: &AccountWaitPlan{
AccountID: acc.ID,
MaxConcurrency: acc.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
}
return nil, errors.New("no available accounts")
}
func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) {
var accounts []Account
var err error
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
} else if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
}
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
return accounts, nil
}
func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
if s.concurrencyService == nil {
return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil
}
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
}
func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
if s.cfg != nil {
return s.cfg.Gateway.Scheduling
}
return config.GatewaySchedulingConfig{
StickySessionMaxWaiting: 3,
StickySessionWaitTimeout: 45 * time.Second,
FallbackWaitTimeout: 30 * time.Second,
FallbackMaxWaiting: 100,
LoadBatchEnabled: true,
SlotCleanupInterval: 30 * time.Second,
}
}
// GetAccessToken gets the access token for an OpenAI account // GetAccessToken gets the access token for an OpenAI account
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type { switch account.Type {

View File

@@ -5,6 +5,8 @@ import (
"log" "log"
"net/http" "net/http"
"strconv" "strconv"
"strings"
"sync"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
@@ -12,15 +14,30 @@ import (
// RateLimitService 处理限流和过载状态管理 // RateLimitService 处理限流和过载状态管理
type RateLimitService struct { type RateLimitService struct {
accountRepo AccountRepository accountRepo AccountRepository
cfg *config.Config usageRepo UsageLogRepository
cfg *config.Config
geminiQuotaService *GeminiQuotaService
usageCacheMu sync.RWMutex
usageCache map[int64]*geminiUsageCacheEntry
} }
type geminiUsageCacheEntry struct {
windowStart time.Time
cachedAt time.Time
totals GeminiUsageTotals
}
const geminiPrecheckCacheTTL = time.Minute
// NewRateLimitService 创建RateLimitService实例 // NewRateLimitService 创建RateLimitService实例
func NewRateLimitService(accountRepo AccountRepository, cfg *config.Config) *RateLimitService { func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogRepository, cfg *config.Config, geminiQuotaService *GeminiQuotaService) *RateLimitService {
return &RateLimitService{ return &RateLimitService{
accountRepo: accountRepo, accountRepo: accountRepo,
cfg: cfg, usageRepo: usageRepo,
cfg: cfg,
geminiQuotaService: geminiQuotaService,
usageCache: make(map[int64]*geminiUsageCacheEntry),
} }
} }
@@ -62,6 +79,106 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
} }
} }
// PreCheckUsage proactively checks local quota before dispatching a request.
// Returns false when the account should be skipped.
func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, requestedModel string) (bool, error) {
if account == nil || !account.IsGeminiCodeAssist() || strings.TrimSpace(requestedModel) == "" {
return true, nil
}
if s.usageRepo == nil || s.geminiQuotaService == nil {
return true, nil
}
quota, ok := s.geminiQuotaService.QuotaForAccount(ctx, account)
if !ok {
return true, nil
}
var limit int64
switch geminiModelClassFromName(requestedModel) {
case geminiModelFlash:
limit = quota.FlashRPD
default:
limit = quota.ProRPD
}
if limit <= 0 {
return true, nil
}
now := time.Now()
start := geminiDailyWindowStart(now)
totals, ok := s.getGeminiUsageTotals(account.ID, start, now)
if !ok {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
if err != nil {
return true, err
}
totals = geminiAggregateUsage(stats)
s.setGeminiUsageTotals(account.ID, start, now, totals)
}
var used int64
switch geminiModelClassFromName(requestedModel) {
case geminiModelFlash:
used = totals.FlashRequests
default:
used = totals.ProRequests
}
if used >= limit {
resetAt := geminiDailyResetTime(now)
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
}
log.Printf("[Gemini PreCheck] Account %d reached daily quota (%d/%d), rate limited until %v", account.ID, used, limit, resetAt)
return false, nil
}
return true, nil
}
func (s *RateLimitService) getGeminiUsageTotals(accountID int64, windowStart, now time.Time) (GeminiUsageTotals, bool) {
s.usageCacheMu.RLock()
defer s.usageCacheMu.RUnlock()
if s.usageCache == nil {
return GeminiUsageTotals{}, false
}
entry, ok := s.usageCache[accountID]
if !ok || entry == nil {
return GeminiUsageTotals{}, false
}
if !entry.windowStart.Equal(windowStart) {
return GeminiUsageTotals{}, false
}
if now.Sub(entry.cachedAt) >= geminiPrecheckCacheTTL {
return GeminiUsageTotals{}, false
}
return entry.totals, true
}
func (s *RateLimitService) setGeminiUsageTotals(accountID int64, windowStart, now time.Time, totals GeminiUsageTotals) {
s.usageCacheMu.Lock()
defer s.usageCacheMu.Unlock()
if s.usageCache == nil {
s.usageCache = make(map[int64]*geminiUsageCacheEntry)
}
s.usageCache[accountID] = &geminiUsageCacheEntry{
windowStart: windowStart,
cachedAt: now,
totals: totals,
}
}
// GeminiCooldown returns the fallback cooldown duration for Gemini 429s based on tier.
func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account) time.Duration {
if account == nil {
return 5 * time.Minute
}
return s.geminiQuotaService.CooldownForTier(ctx, account.GeminiTierID())
}
// handleAuthError 处理认证类错误(401/403),停止账号调度 // handleAuthError 处理认证类错误(401/403),停止账号调度
func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) { func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) {
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil { if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {

View File

@@ -10,7 +10,6 @@ type User struct {
ID int64 ID int64
Email string Email string
Username string Username string
Wechat string
Notes string Notes string
PasswordHash string PasswordHash string
Role string Role string

View File

@@ -0,0 +1,125 @@
package service
import (
"context"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// Error definitions for user attribute operations
var (
ErrAttributeDefinitionNotFound = infraerrors.NotFound("ATTRIBUTE_DEFINITION_NOT_FOUND", "attribute definition not found")
ErrAttributeKeyExists = infraerrors.Conflict("ATTRIBUTE_KEY_EXISTS", "attribute key already exists")
ErrInvalidAttributeType = infraerrors.BadRequest("INVALID_ATTRIBUTE_TYPE", "invalid attribute type")
ErrAttributeValidationFailed = infraerrors.BadRequest("ATTRIBUTE_VALIDATION_FAILED", "attribute value validation failed")
)
// UserAttributeType represents supported attribute types
type UserAttributeType string
const (
AttributeTypeText UserAttributeType = "text"
AttributeTypeTextarea UserAttributeType = "textarea"
AttributeTypeNumber UserAttributeType = "number"
AttributeTypeEmail UserAttributeType = "email"
AttributeTypeURL UserAttributeType = "url"
AttributeTypeDate UserAttributeType = "date"
AttributeTypeSelect UserAttributeType = "select"
AttributeTypeMultiSelect UserAttributeType = "multi_select"
)
// UserAttributeOption represents a select option for select/multi_select types
type UserAttributeOption struct {
Value string `json:"value"`
Label string `json:"label"`
}
// UserAttributeValidation represents validation rules for an attribute
type UserAttributeValidation struct {
MinLength *int `json:"min_length,omitempty"`
MaxLength *int `json:"max_length,omitempty"`
Min *int `json:"min,omitempty"`
Max *int `json:"max,omitempty"`
Pattern *string `json:"pattern,omitempty"`
Message *string `json:"message,omitempty"`
}
// UserAttributeDefinition represents a custom attribute definition
type UserAttributeDefinition struct {
ID int64
Key string
Name string
Description string
Type UserAttributeType
Options []UserAttributeOption
Required bool
Validation UserAttributeValidation
Placeholder string
DisplayOrder int
Enabled bool
CreatedAt time.Time
UpdatedAt time.Time
}
// UserAttributeValue represents a user's attribute value
type UserAttributeValue struct {
ID int64
UserID int64
AttributeID int64
Value string
CreatedAt time.Time
UpdatedAt time.Time
}
// CreateAttributeDefinitionInput for creating new definition
type CreateAttributeDefinitionInput struct {
Key string
Name string
Description string
Type UserAttributeType
Options []UserAttributeOption
Required bool
Validation UserAttributeValidation
Placeholder string
Enabled bool
}
// UpdateAttributeDefinitionInput for updating definition
type UpdateAttributeDefinitionInput struct {
Name *string
Description *string
Type *UserAttributeType
Options *[]UserAttributeOption
Required *bool
Validation *UserAttributeValidation
Placeholder *string
Enabled *bool
}
// UpdateUserAttributeInput for updating a single attribute value
type UpdateUserAttributeInput struct {
AttributeID int64
Value string
}
// UserAttributeDefinitionRepository interface for attribute definition persistence
type UserAttributeDefinitionRepository interface {
Create(ctx context.Context, def *UserAttributeDefinition) error
GetByID(ctx context.Context, id int64) (*UserAttributeDefinition, error)
GetByKey(ctx context.Context, key string) (*UserAttributeDefinition, error)
Update(ctx context.Context, def *UserAttributeDefinition) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, enabledOnly bool) ([]UserAttributeDefinition, error)
UpdateDisplayOrders(ctx context.Context, orders map[int64]int) error
ExistsByKey(ctx context.Context, key string) (bool, error)
}
// UserAttributeValueRepository interface for user attribute value persistence
type UserAttributeValueRepository interface {
GetByUserID(ctx context.Context, userID int64) ([]UserAttributeValue, error)
GetByUserIDs(ctx context.Context, userIDs []int64) ([]UserAttributeValue, error)
UpsertBatch(ctx context.Context, userID int64, values []UpdateUserAttributeInput) error
DeleteByAttributeID(ctx context.Context, attributeID int64) error
DeleteByUserID(ctx context.Context, userID int64) error
}

View File

@@ -0,0 +1,295 @@
package service
import (
"context"
"encoding/json"
"fmt"
"regexp"
"strconv"
"strings"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// UserAttributeService handles attribute management
type UserAttributeService struct {
defRepo UserAttributeDefinitionRepository
valueRepo UserAttributeValueRepository
}
// NewUserAttributeService creates a new service instance
func NewUserAttributeService(
defRepo UserAttributeDefinitionRepository,
valueRepo UserAttributeValueRepository,
) *UserAttributeService {
return &UserAttributeService{
defRepo: defRepo,
valueRepo: valueRepo,
}
}
// CreateDefinition creates a new attribute definition
func (s *UserAttributeService) CreateDefinition(ctx context.Context, input CreateAttributeDefinitionInput) (*UserAttributeDefinition, error) {
// Validate type
if !isValidAttributeType(input.Type) {
return nil, ErrInvalidAttributeType
}
// Check if key exists
exists, err := s.defRepo.ExistsByKey(ctx, input.Key)
if err != nil {
return nil, fmt.Errorf("check key exists: %w", err)
}
if exists {
return nil, ErrAttributeKeyExists
}
def := &UserAttributeDefinition{
Key: input.Key,
Name: input.Name,
Description: input.Description,
Type: input.Type,
Options: input.Options,
Required: input.Required,
Validation: input.Validation,
Placeholder: input.Placeholder,
Enabled: input.Enabled,
}
if err := s.defRepo.Create(ctx, def); err != nil {
return nil, fmt.Errorf("create definition: %w", err)
}
return def, nil
}
// GetDefinition retrieves a definition by ID
func (s *UserAttributeService) GetDefinition(ctx context.Context, id int64) (*UserAttributeDefinition, error) {
return s.defRepo.GetByID(ctx, id)
}
// ListDefinitions lists all definitions
func (s *UserAttributeService) ListDefinitions(ctx context.Context, enabledOnly bool) ([]UserAttributeDefinition, error) {
return s.defRepo.List(ctx, enabledOnly)
}
// UpdateDefinition updates an existing definition
func (s *UserAttributeService) UpdateDefinition(ctx context.Context, id int64, input UpdateAttributeDefinitionInput) (*UserAttributeDefinition, error) {
def, err := s.defRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
if input.Name != nil {
def.Name = *input.Name
}
if input.Description != nil {
def.Description = *input.Description
}
if input.Type != nil {
if !isValidAttributeType(*input.Type) {
return nil, ErrInvalidAttributeType
}
def.Type = *input.Type
}
if input.Options != nil {
def.Options = *input.Options
}
if input.Required != nil {
def.Required = *input.Required
}
if input.Validation != nil {
def.Validation = *input.Validation
}
if input.Placeholder != nil {
def.Placeholder = *input.Placeholder
}
if input.Enabled != nil {
def.Enabled = *input.Enabled
}
if err := s.defRepo.Update(ctx, def); err != nil {
return nil, fmt.Errorf("update definition: %w", err)
}
return def, nil
}
// DeleteDefinition soft-deletes a definition and hard-deletes associated values
func (s *UserAttributeService) DeleteDefinition(ctx context.Context, id int64) error {
// Check if definition exists
_, err := s.defRepo.GetByID(ctx, id)
if err != nil {
return err
}
// First delete all values (hard delete)
if err := s.valueRepo.DeleteByAttributeID(ctx, id); err != nil {
return fmt.Errorf("delete values: %w", err)
}
// Then soft-delete the definition
if err := s.defRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete definition: %w", err)
}
return nil
}
// ReorderDefinitions updates display order for multiple definitions
func (s *UserAttributeService) ReorderDefinitions(ctx context.Context, orders map[int64]int) error {
return s.defRepo.UpdateDisplayOrders(ctx, orders)
}
// GetUserAttributes retrieves all attribute values for a user
func (s *UserAttributeService) GetUserAttributes(ctx context.Context, userID int64) ([]UserAttributeValue, error) {
return s.valueRepo.GetByUserID(ctx, userID)
}
// GetBatchUserAttributes retrieves attribute values for multiple users
// Returns a map of userID -> map of attributeID -> value
func (s *UserAttributeService) GetBatchUserAttributes(ctx context.Context, userIDs []int64) (map[int64]map[int64]string, error) {
values, err := s.valueRepo.GetByUserIDs(ctx, userIDs)
if err != nil {
return nil, err
}
result := make(map[int64]map[int64]string)
for _, v := range values {
if result[v.UserID] == nil {
result[v.UserID] = make(map[int64]string)
}
result[v.UserID][v.AttributeID] = v.Value
}
return result, nil
}
// UpdateUserAttributes batch updates attribute values for a user
func (s *UserAttributeService) UpdateUserAttributes(ctx context.Context, userID int64, inputs []UpdateUserAttributeInput) error {
// Validate all values before updating
defs, err := s.defRepo.List(ctx, true)
if err != nil {
return fmt.Errorf("list definitions: %w", err)
}
defMap := make(map[int64]*UserAttributeDefinition, len(defs))
for i := range defs {
defMap[defs[i].ID] = &defs[i]
}
for _, input := range inputs {
def, ok := defMap[input.AttributeID]
if !ok {
return ErrAttributeDefinitionNotFound
}
if err := s.validateValue(def, input.Value); err != nil {
return err
}
}
return s.valueRepo.UpsertBatch(ctx, userID, inputs)
}
// validateValue validates a value against its definition
func (s *UserAttributeService) validateValue(def *UserAttributeDefinition, value string) error {
// Skip validation for empty non-required fields
if value == "" && !def.Required {
return nil
}
// Required check
if def.Required && value == "" {
return validationError(fmt.Sprintf("%s is required", def.Name))
}
v := def.Validation
// String length validation
if v.MinLength != nil && len(value) < *v.MinLength {
return validationError(fmt.Sprintf("%s must be at least %d characters", def.Name, *v.MinLength))
}
if v.MaxLength != nil && len(value) > *v.MaxLength {
return validationError(fmt.Sprintf("%s must be at most %d characters", def.Name, *v.MaxLength))
}
// Number validation
if def.Type == AttributeTypeNumber && value != "" {
num, err := strconv.Atoi(value)
if err != nil {
return validationError(fmt.Sprintf("%s must be a number", def.Name))
}
if v.Min != nil && num < *v.Min {
return validationError(fmt.Sprintf("%s must be at least %d", def.Name, *v.Min))
}
if v.Max != nil && num > *v.Max {
return validationError(fmt.Sprintf("%s must be at most %d", def.Name, *v.Max))
}
}
// Pattern validation
if v.Pattern != nil && *v.Pattern != "" && value != "" {
re, err := regexp.Compile(*v.Pattern)
if err == nil && !re.MatchString(value) {
msg := def.Name + " format is invalid"
if v.Message != nil && *v.Message != "" {
msg = *v.Message
}
return validationError(msg)
}
}
// Select validation
if def.Type == AttributeTypeSelect && value != "" {
found := false
for _, opt := range def.Options {
if opt.Value == value {
found = true
break
}
}
if !found {
return validationError(fmt.Sprintf("%s: invalid option", def.Name))
}
}
// Multi-select validation (stored as JSON array)
if def.Type == AttributeTypeMultiSelect && value != "" {
var values []string
if err := json.Unmarshal([]byte(value), &values); err != nil {
// Try comma-separated fallback
values = strings.Split(value, ",")
}
for _, val := range values {
val = strings.TrimSpace(val)
found := false
for _, opt := range def.Options {
if opt.Value == val {
found = true
break
}
}
if !found {
return validationError(fmt.Sprintf("%s: invalid option %s", def.Name, val))
}
}
}
return nil
}
// validationError creates a validation error with a custom message
func validationError(msg string) error {
return infraerrors.BadRequest("ATTRIBUTE_VALIDATION_FAILED", msg)
}
func isValidAttributeType(t UserAttributeType) bool {
switch t {
case AttributeTypeText, AttributeTypeTextarea, AttributeTypeNumber,
AttributeTypeEmail, AttributeTypeURL, AttributeTypeDate,
AttributeTypeSelect, AttributeTypeMultiSelect:
return true
}
return false
}

View File

@@ -14,6 +14,14 @@ var (
ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions") ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
) )
// UserListFilters contains all filter options for listing users
type UserListFilters struct {
Status string // User status filter
Role string // User role filter
Search string // Search in email, username
Attributes map[int64]string // Custom attribute filters: attributeID -> value
}
type UserRepository interface { type UserRepository interface {
Create(ctx context.Context, user *User) error Create(ctx context.Context, user *User) error
GetByID(ctx context.Context, id int64) (*User, error) GetByID(ctx context.Context, id int64) (*User, error)
@@ -23,7 +31,7 @@ type UserRepository interface {
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]User, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error)
UpdateBalance(ctx context.Context, id int64, amount float64) error UpdateBalance(ctx context.Context, id int64, amount float64) error
DeductBalance(ctx context.Context, id int64, amount float64) error DeductBalance(ctx context.Context, id int64, amount float64) error
@@ -36,7 +44,6 @@ type UserRepository interface {
type UpdateProfileRequest struct { type UpdateProfileRequest struct {
Email *string `json:"email"` Email *string `json:"email"`
Username *string `json:"username"` Username *string `json:"username"`
Wechat *string `json:"wechat"`
Concurrency *int `json:"concurrency"` Concurrency *int `json:"concurrency"`
} }
@@ -100,10 +107,6 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
user.Username = *req.Username user.Username = *req.Username
} }
if req.Wechat != nil {
user.Wechat = *req.Wechat
}
if req.Concurrency != nil { if req.Concurrency != nil {
user.Concurrency = *req.Concurrency user.Concurrency = *req.Concurrency
} }

View File

@@ -73,6 +73,15 @@ func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWh
return svc return svc
} }
// ProvideConcurrencyService creates ConcurrencyService and starts slot cleanup worker.
func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountRepository, cfg *config.Config) *ConcurrencyService {
svc := NewConcurrencyService(cache)
if cfg != nil {
svc.StartSlotCleanupWorker(accountRepo, cfg.Gateway.Scheduling.SlotCleanupInterval)
}
return svc
}
// ProviderSet is the Wire provider set for all services // ProviderSet is the Wire provider set for all services
var ProviderSet = wire.NewSet( var ProviderSet = wire.NewSet(
// Core services // Core services
@@ -94,6 +103,7 @@ var ProviderSet = wire.NewSet(
NewOAuthService, NewOAuthService,
NewOpenAIOAuthService, NewOpenAIOAuthService,
NewGeminiOAuthService, NewGeminiOAuthService,
NewGeminiQuotaService,
NewAntigravityOAuthService, NewAntigravityOAuthService,
NewGeminiTokenProvider, NewGeminiTokenProvider,
NewGeminiMessagesCompatService, NewGeminiMessagesCompatService,
@@ -107,7 +117,7 @@ var ProviderSet = wire.NewSet(
ProvideEmailQueueService, ProvideEmailQueueService,
NewTurnstileService, NewTurnstileService,
NewSubscriptionService, NewSubscriptionService,
NewConcurrencyService, ProvideConcurrencyService,
NewIdentityService, NewIdentityService,
NewCRSSyncService, NewCRSSyncService,
ProvideUpdateService, ProvideUpdateService,
@@ -115,4 +125,5 @@ var ProviderSet = wire.NewSet(
ProvideTimingWheelService, ProvideTimingWheelService,
ProvideDeferredService, ProvideDeferredService,
ProvideAntigravityQuotaRefresher, ProvideAntigravityQuotaRefresher,
NewUserAttributeService,
) )

View File

@@ -0,0 +1,48 @@
-- Add user attribute definitions and values tables for custom user attributes.
-- User Attribute Definitions table (with soft delete support)
CREATE TABLE IF NOT EXISTS user_attribute_definitions (
id BIGSERIAL PRIMARY KEY,
key VARCHAR(100) NOT NULL,
name VARCHAR(255) NOT NULL,
description TEXT DEFAULT '',
type VARCHAR(20) NOT NULL,
options JSONB DEFAULT '[]'::jsonb,
required BOOLEAN NOT NULL DEFAULT FALSE,
validation JSONB DEFAULT '{}'::jsonb,
placeholder VARCHAR(255) DEFAULT '',
display_order INT NOT NULL DEFAULT 0,
enabled BOOLEAN NOT NULL DEFAULT TRUE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
deleted_at TIMESTAMPTZ
);
-- Partial unique index for key (only for non-deleted records)
-- Allows reusing keys after soft delete
CREATE UNIQUE INDEX IF NOT EXISTS idx_user_attribute_definitions_key_unique
ON user_attribute_definitions(key) WHERE deleted_at IS NULL;
CREATE INDEX IF NOT EXISTS idx_user_attribute_definitions_enabled
ON user_attribute_definitions(enabled);
CREATE INDEX IF NOT EXISTS idx_user_attribute_definitions_display_order
ON user_attribute_definitions(display_order);
CREATE INDEX IF NOT EXISTS idx_user_attribute_definitions_deleted_at
ON user_attribute_definitions(deleted_at);
-- User Attribute Values table (hard delete only, no deleted_at)
CREATE TABLE IF NOT EXISTS user_attribute_values (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
attribute_id BIGINT NOT NULL REFERENCES user_attribute_definitions(id) ON DELETE CASCADE,
value TEXT DEFAULT '',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
UNIQUE(user_id, attribute_id)
);
CREATE INDEX IF NOT EXISTS idx_user_attribute_values_user_id
ON user_attribute_values(user_id);
CREATE INDEX IF NOT EXISTS idx_user_attribute_values_attribute_id
ON user_attribute_values(attribute_id);

View File

@@ -0,0 +1,83 @@
-- Migration: Move wechat field from users table to user_attribute_values
-- This migration:
-- 1. Creates a "wechat" attribute definition
-- 2. Migrates existing wechat data to user_attribute_values
-- 3. Does NOT drop the wechat column (for rollback safety, can be done in a later migration)
-- +goose Up
-- +goose StatementBegin
-- Step 1: Insert wechat attribute definition if not exists
INSERT INTO user_attribute_definitions (key, name, description, type, options, required, validation, placeholder, display_order, enabled, created_at, updated_at)
SELECT 'wechat', '微信', '用户微信号', 'text', '[]'::jsonb, false, '{}'::jsonb, '请输入微信号', 0, true, NOW(), NOW()
WHERE NOT EXISTS (
SELECT 1 FROM user_attribute_definitions WHERE key = 'wechat' AND deleted_at IS NULL
);
-- Step 2: Migrate existing wechat values to user_attribute_values
-- Only migrate non-empty values
INSERT INTO user_attribute_values (user_id, attribute_id, value, created_at, updated_at)
SELECT
u.id,
(SELECT id FROM user_attribute_definitions WHERE key = 'wechat' AND deleted_at IS NULL LIMIT 1),
u.wechat,
NOW(),
NOW()
FROM users u
WHERE u.wechat IS NOT NULL
AND u.wechat != ''
AND u.deleted_at IS NULL
AND NOT EXISTS (
SELECT 1 FROM user_attribute_values uav
WHERE uav.user_id = u.id
AND uav.attribute_id = (SELECT id FROM user_attribute_definitions WHERE key = 'wechat' AND deleted_at IS NULL LIMIT 1)
);
-- Step 3: Update display_order to ensure wechat appears first
UPDATE user_attribute_definitions
SET display_order = -1
WHERE key = 'wechat' AND deleted_at IS NULL;
-- Reorder all attributes starting from 0
WITH ordered AS (
SELECT id, ROW_NUMBER() OVER (ORDER BY display_order, id) - 1 as new_order
FROM user_attribute_definitions
WHERE deleted_at IS NULL
)
UPDATE user_attribute_definitions
SET display_order = ordered.new_order
FROM ordered
WHERE user_attribute_definitions.id = ordered.id;
-- Step 4: Drop the redundant wechat column from users table
ALTER TABLE users DROP COLUMN IF EXISTS wechat;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
-- Restore wechat column
ALTER TABLE users ADD COLUMN IF NOT EXISTS wechat VARCHAR(100) DEFAULT '';
-- Copy attribute values back to users.wechat column
UPDATE users u
SET wechat = uav.value
FROM user_attribute_values uav
JOIN user_attribute_definitions uad ON uav.attribute_id = uad.id
WHERE uav.user_id = u.id
AND uad.key = 'wechat'
AND uad.deleted_at IS NULL;
-- Delete migrated attribute values
DELETE FROM user_attribute_values
WHERE attribute_id IN (
SELECT id FROM user_attribute_definitions WHERE key = 'wechat' AND deleted_at IS NULL
);
-- Soft-delete the wechat attribute definition
UPDATE user_attribute_definitions
SET deleted_at = NOW()
WHERE key = 'wechat' AND deleted_at IS NULL;
-- +goose StatementEnd

Some files were not shown because too many files have changed in this diff Show More