Merge remote-tracking branch 'upstream/main'

# Conflicts:
#	frontend/src/components/account/CreateAccountModal.vue
This commit is contained in:
Edric Li
2026-01-01 16:15:16 +08:00
215 changed files with 22998 additions and 1641 deletions

View File

@@ -129,6 +129,10 @@ jobs:
echo "$TAG_MESSAGE" >> $GITHUB_OUTPUT
echo "EOF" >> $GITHUB_OUTPUT
- name: Set lowercase owner for GHCR
id: lowercase
run: echo "owner=$(echo '${{ github.repository_owner }}' | tr '[:upper:]' '[:lower:]')" >> $GITHUB_OUTPUT
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v6
with:
@@ -138,6 +142,7 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
TAG_MESSAGE: ${{ steps.tag_message.outputs.message }}
GITHUB_REPO_OWNER: ${{ github.repository_owner }}
GITHUB_REPO_OWNER_LOWER: ${{ steps.lowercase.outputs.owner }}
GITHUB_REPO_NAME: ${{ github.event.repository.name }}
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
@@ -168,7 +173,7 @@ jobs:
VERSION=${TAG_NAME#v}
REPO="${{ github.repository }}"
DOCKER_IMAGE="${{ secrets.DOCKERHUB_USERNAME }}/sub2api"
GHCR_IMAGE="ghcr.io/${REPO}"
GHCR_IMAGE="ghcr.io/${REPO,,}" # ${,,} converts to lowercase
# 获取 tag message 内容
TAG_MESSAGE='${{ steps.tag_message.outputs.message }}'

2
.gitignore vendored
View File

@@ -77,6 +77,7 @@ temp/
*.temp
*.log
*.bak
.cache/
# ===================
# 构建产物
@@ -116,3 +117,4 @@ openspec/
docs/
code-reviews/
AGENTS.md
backend/cmd/server/server

View File

@@ -78,12 +78,12 @@ dockers:
- "--label=org.opencontainers.image.version={{ .Version }}"
- "--label=org.opencontainers.image.revision={{ .Commit }}"
# GHCR images
# GHCR images (owner must be lowercase)
- id: ghcr-amd64
goos: linux
goarch: amd64
image_templates:
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-amd64"
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-amd64"
dockerfile: Dockerfile.goreleaser
use: buildx
build_flag_templates:
@@ -96,7 +96,7 @@ dockers:
goos: linux
goarch: arm64
image_templates:
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-arm64"
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-arm64"
dockerfile: Dockerfile.goreleaser
use: buildx
build_flag_templates:
@@ -127,26 +127,26 @@ docker_manifests:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
# GHCR manifests
- name_template: "ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}"
# GHCR manifests (owner must be lowercase)
- name_template: "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}"
image_templates:
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-amd64"
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-arm64"
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-amd64"
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-arm64"
- name_template: "ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:latest"
- name_template: "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:latest"
image_templates:
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-amd64"
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-arm64"
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-amd64"
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-arm64"
- name_template: "ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Major }}.{{ .Minor }}"
- name_template: "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Major }}.{{ .Minor }}"
image_templates:
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-amd64"
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-arm64"
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-amd64"
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-arm64"
- name_template: "ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Major }}"
- name_template: "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Major }}"
image_templates:
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-amd64"
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-arm64"
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-amd64"
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-arm64"
release:
github:
@@ -173,7 +173,7 @@ release:
docker pull {{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}
# GitHub Container Registry
docker pull ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}
docker pull ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}
```
**One-line install (Linux):**

View File

@@ -12,7 +12,6 @@ import (
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/server"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -31,7 +30,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
wire.Build(
// Infrastructure layer ProviderSets
config.ProviderSet,
infrastructure.ProviderSet,
// Business layer ProviderSets
repository.ProviderSet,
@@ -67,6 +65,7 @@ func provideCleanup(
tokenRefresh *service.TokenRefreshService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
billingCache *service.BillingCacheService,
oauth *service.OAuthService,
openaiOAuth *service.OpenAIOAuthService,
geminiOAuth *service.GeminiOAuthService,
@@ -94,6 +93,10 @@ func provideCleanup(
emailQueue.Stop()
return nil
}},
{"BillingCacheService", func() error {
billingCache.Stop()
return nil
}},
{"OAuthService", func() error {
oauth.Stop()
return nil

View File

@@ -12,7 +12,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/server"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -35,18 +34,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
if err != nil {
return nil, err
}
client, err := infrastructure.ProvideEnt(configConfig)
client, err := repository.ProvideEnt(configConfig)
if err != nil {
return nil, err
}
db, err := infrastructure.ProvideSQLDB(client)
db, err := repository.ProvideSQLDB(client)
if err != nil {
return nil, err
}
userRepository := repository.NewUserRepository(client, db)
settingRepository := repository.NewSettingRepository(client)
settingService := service.NewSettingService(settingRepository, configConfig)
redisClient := infrastructure.ProvideRedis(configConfig)
redisClient := repository.ProvideRedis(configConfig)
emailCache := repository.NewEmailCache(redisClient)
emailService := service.NewEmailService(settingRepository, emailCache)
turnstileVerifier := repository.NewTurnstileVerifier()
@@ -70,7 +69,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client)
redeemHandler := handler.NewRedeemHandler(redeemService)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
dashboardService := service.NewDashboardService(usageLogRepository)
@@ -88,9 +87,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, configConfig)
rateLimitService := service.NewRateLimitService(accountRepository, configConfig)
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
rateLimitService := service.NewRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService)
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher)
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService)
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
gatewayCache := repository.NewGatewayCache(redisClient)
@@ -99,8 +99,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
httpUpstream := repository.NewHTTPUpstream(configConfig)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream)
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, antigravityGatewayService, httpUpstream)
concurrencyCache := repository.NewConcurrencyCache(redisClient)
concurrencyService := service.NewConcurrencyService(concurrencyCache)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
oAuthHandler := admin.NewOAuthHandler(oAuthService)
@@ -109,7 +109,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
antigravityOAuthHandler := admin.NewAntigravityOAuthHandler(antigravityOAuthService)
proxyHandler := admin.NewProxyHandler(adminService)
adminRedeemHandler := admin.NewRedeemHandler(adminService)
settingHandler := admin.NewSettingHandler(settingService, emailService)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService)
updateCache := repository.NewUpdateCache(redisClient)
gitHubReleaseClient := repository.NewGitHubReleaseClient()
serviceBuildInfo := provideServiceBuildInfo(buildInfo)
@@ -128,10 +128,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
identityService := service.NewIdentityService(identityCache)
timingWheelService := service.ProvideTimingWheelService()
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
@@ -142,7 +142,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
httpServer := server.ProvideHTTPServer(configConfig, engine)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig)
antigravityQuotaRefresher := service.ProvideAntigravityQuotaRefresher(accountRepository, proxyRepository, antigravityOAuthService, configConfig)
v := provideCleanup(client, redisClient, tokenRefreshService, pricingService, emailQueueService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, antigravityQuotaRefresher)
v := provideCleanup(client, redisClient, tokenRefreshService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, antigravityQuotaRefresher)
application := &Application{
Server: httpServer,
Cleanup: v,
@@ -170,6 +170,7 @@ func provideCleanup(
tokenRefresh *service.TokenRefreshService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
billingCache *service.BillingCacheService,
oauth *service.OAuthService,
openaiOAuth *service.OpenAIOAuthService,
geminiOAuth *service.GeminiOAuthService,
@@ -196,6 +197,10 @@ func provideCleanup(
emailQueue.Stop()
return nil
}},
{"BillingCacheService", func() error {
billingCache.Stop()
return nil
}},
{"OAuthService", func() error {
oauth.Stop()
return nil

View File

@@ -11,6 +11,7 @@ import (
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/account"
"github.com/Wei-Shaw/sub2api/ent/proxy"
)
// Account is the model entity for the Account schema.
@@ -70,11 +71,15 @@ type Account struct {
type AccountEdges struct {
// Groups holds the value of the groups edge.
Groups []*Group `json:"groups,omitempty"`
// Proxy holds the value of the proxy edge.
Proxy *Proxy `json:"proxy,omitempty"`
// UsageLogs holds the value of the usage_logs edge.
UsageLogs []*UsageLog `json:"usage_logs,omitempty"`
// AccountGroups holds the value of the account_groups edge.
AccountGroups []*AccountGroup `json:"account_groups,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [2]bool
loadedTypes [4]bool
}
// GroupsOrErr returns the Groups value or an error if the edge
@@ -86,10 +91,30 @@ func (e AccountEdges) GroupsOrErr() ([]*Group, error) {
return nil, &NotLoadedError{edge: "groups"}
}
// ProxyOrErr returns the Proxy value or an error if the edge
// was not loaded in eager-loading, or loaded but was not found.
func (e AccountEdges) ProxyOrErr() (*Proxy, error) {
if e.Proxy != nil {
return e.Proxy, nil
} else if e.loadedTypes[1] {
return nil, &NotFoundError{label: proxy.Label}
}
return nil, &NotLoadedError{edge: "proxy"}
}
// UsageLogsOrErr returns the UsageLogs value or an error if the edge
// was not loaded in eager-loading.
func (e AccountEdges) UsageLogsOrErr() ([]*UsageLog, error) {
if e.loadedTypes[2] {
return e.UsageLogs, nil
}
return nil, &NotLoadedError{edge: "usage_logs"}
}
// AccountGroupsOrErr returns the AccountGroups value or an error if the edge
// was not loaded in eager-loading.
func (e AccountEdges) AccountGroupsOrErr() ([]*AccountGroup, error) {
if e.loadedTypes[1] {
if e.loadedTypes[3] {
return e.AccountGroups, nil
}
return nil, &NotLoadedError{edge: "account_groups"}
@@ -289,6 +314,16 @@ func (_m *Account) QueryGroups() *GroupQuery {
return NewAccountClient(_m.config).QueryGroups(_m)
}
// QueryProxy queries the "proxy" edge of the Account entity.
func (_m *Account) QueryProxy() *ProxyQuery {
return NewAccountClient(_m.config).QueryProxy(_m)
}
// QueryUsageLogs queries the "usage_logs" edge of the Account entity.
func (_m *Account) QueryUsageLogs() *UsageLogQuery {
return NewAccountClient(_m.config).QueryUsageLogs(_m)
}
// QueryAccountGroups queries the "account_groups" edge of the Account entity.
func (_m *Account) QueryAccountGroups() *AccountGroupQuery {
return NewAccountClient(_m.config).QueryAccountGroups(_m)

View File

@@ -59,6 +59,10 @@ const (
FieldSessionWindowStatus = "session_window_status"
// EdgeGroups holds the string denoting the groups edge name in mutations.
EdgeGroups = "groups"
// EdgeProxy holds the string denoting the proxy edge name in mutations.
EdgeProxy = "proxy"
// EdgeUsageLogs holds the string denoting the usage_logs edge name in mutations.
EdgeUsageLogs = "usage_logs"
// EdgeAccountGroups holds the string denoting the account_groups edge name in mutations.
EdgeAccountGroups = "account_groups"
// Table holds the table name of the account in the database.
@@ -68,6 +72,20 @@ const (
// GroupsInverseTable is the table name for the Group entity.
// It exists in this package in order to avoid circular dependency with the "group" package.
GroupsInverseTable = "groups"
// ProxyTable is the table that holds the proxy relation/edge.
ProxyTable = "accounts"
// ProxyInverseTable is the table name for the Proxy entity.
// It exists in this package in order to avoid circular dependency with the "proxy" package.
ProxyInverseTable = "proxies"
// ProxyColumn is the table column denoting the proxy relation/edge.
ProxyColumn = "proxy_id"
// UsageLogsTable is the table that holds the usage_logs relation/edge.
UsageLogsTable = "usage_logs"
// UsageLogsInverseTable is the table name for the UsageLog entity.
// It exists in this package in order to avoid circular dependency with the "usagelog" package.
UsageLogsInverseTable = "usage_logs"
// UsageLogsColumn is the table column denoting the usage_logs relation/edge.
UsageLogsColumn = "account_id"
// AccountGroupsTable is the table that holds the account_groups relation/edge.
AccountGroupsTable = "account_groups"
// AccountGroupsInverseTable is the table name for the AccountGroup entity.
@@ -274,6 +292,27 @@ func ByGroups(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
}
}
// ByProxyField orders the results by proxy field.
func ByProxyField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newProxyStep(), sql.OrderByField(field, opts...))
}
}
// ByUsageLogsCount orders the results by usage_logs count.
func ByUsageLogsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newUsageLogsStep(), opts...)
}
}
// ByUsageLogs orders the results by usage_logs terms.
func ByUsageLogs(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newUsageLogsStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
// ByAccountGroupsCount orders the results by account_groups count.
func ByAccountGroupsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
@@ -294,6 +333,20 @@ func newGroupsStep() *sqlgraph.Step {
sqlgraph.Edge(sqlgraph.M2M, false, GroupsTable, GroupsPrimaryKey...),
)
}
func newProxyStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(ProxyInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, false, ProxyTable, ProxyColumn),
)
}
func newUsageLogsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(UsageLogsInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn),
)
}
func newAccountGroupsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),

View File

@@ -495,26 +495,6 @@ func ProxyIDNotIn(vs ...int64) predicate.Account {
return predicate.Account(sql.FieldNotIn(FieldProxyID, vs...))
}
// ProxyIDGT applies the GT predicate on the "proxy_id" field.
func ProxyIDGT(v int64) predicate.Account {
return predicate.Account(sql.FieldGT(FieldProxyID, v))
}
// ProxyIDGTE applies the GTE predicate on the "proxy_id" field.
func ProxyIDGTE(v int64) predicate.Account {
return predicate.Account(sql.FieldGTE(FieldProxyID, v))
}
// ProxyIDLT applies the LT predicate on the "proxy_id" field.
func ProxyIDLT(v int64) predicate.Account {
return predicate.Account(sql.FieldLT(FieldProxyID, v))
}
// ProxyIDLTE applies the LTE predicate on the "proxy_id" field.
func ProxyIDLTE(v int64) predicate.Account {
return predicate.Account(sql.FieldLTE(FieldProxyID, v))
}
// ProxyIDIsNil applies the IsNil predicate on the "proxy_id" field.
func ProxyIDIsNil() predicate.Account {
return predicate.Account(sql.FieldIsNull(FieldProxyID))
@@ -1153,6 +1133,52 @@ func HasGroupsWith(preds ...predicate.Group) predicate.Account {
})
}
// HasProxy applies the HasEdge predicate on the "proxy" edge.
func HasProxy() predicate.Account {
return predicate.Account(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.M2O, false, ProxyTable, ProxyColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasProxyWith applies the HasEdge predicate on the "proxy" edge with a given conditions (other predicates).
func HasProxyWith(preds ...predicate.Proxy) predicate.Account {
return predicate.Account(func(s *sql.Selector) {
step := newProxyStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// HasUsageLogs applies the HasEdge predicate on the "usage_logs" edge.
func HasUsageLogs() predicate.Account {
return predicate.Account(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasUsageLogsWith applies the HasEdge predicate on the "usage_logs" edge with a given conditions (other predicates).
func HasUsageLogsWith(preds ...predicate.UsageLog) predicate.Account {
return predicate.Account(func(s *sql.Selector) {
step := newUsageLogsStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// HasAccountGroups applies the HasEdge predicate on the "account_groups" edge.
func HasAccountGroups() predicate.Account {
return predicate.Account(func(s *sql.Selector) {

View File

@@ -13,6 +13,8 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/account"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
)
// AccountCreate is the builder for creating a Account entity.
@@ -292,6 +294,26 @@ func (_c *AccountCreate) AddGroups(v ...*Group) *AccountCreate {
return _c.AddGroupIDs(ids...)
}
// SetProxy sets the "proxy" edge to the Proxy entity.
func (_c *AccountCreate) SetProxy(v *Proxy) *AccountCreate {
return _c.SetProxyID(v.ID)
}
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func (_c *AccountCreate) AddUsageLogIDs(ids ...int64) *AccountCreate {
_c.mutation.AddUsageLogIDs(ids...)
return _c
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func (_c *AccountCreate) AddUsageLogs(v ...*UsageLog) *AccountCreate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _c.AddUsageLogIDs(ids...)
}
// Mutation returns the AccountMutation object of the builder.
func (_c *AccountCreate) Mutation() *AccountMutation {
return _c.mutation
@@ -495,10 +517,6 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) {
_spec.SetField(account.FieldExtra, field.TypeJSON, value)
_node.Extra = value
}
if value, ok := _c.mutation.ProxyID(); ok {
_spec.SetField(account.FieldProxyID, field.TypeInt64, value)
_node.ProxyID = &value
}
if value, ok := _c.mutation.Concurrency(); ok {
_spec.SetField(account.FieldConcurrency, field.TypeInt, value)
_node.Concurrency = value
@@ -567,6 +585,39 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) {
edge.Target.Fields = specE.Fields
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.ProxyIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: false,
Table: account.ProxyTable,
Columns: []string{account.ProxyColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(proxy.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_node.ProxyID = &nodes[0]
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.UsageLogsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: account.UsageLogsTable,
Columns: []string{account.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges = append(_spec.Edges, edge)
}
return _node, _spec
}
@@ -721,12 +772,6 @@ func (u *AccountUpsert) UpdateProxyID() *AccountUpsert {
return u
}
// AddProxyID adds v to the "proxy_id" field.
func (u *AccountUpsert) AddProxyID(v int64) *AccountUpsert {
u.Add(account.FieldProxyID, v)
return u
}
// ClearProxyID clears the value of the "proxy_id" field.
func (u *AccountUpsert) ClearProxyID() *AccountUpsert {
u.SetNull(account.FieldProxyID)
@@ -1094,13 +1139,6 @@ func (u *AccountUpsertOne) SetProxyID(v int64) *AccountUpsertOne {
})
}
// AddProxyID adds v to the "proxy_id" field.
func (u *AccountUpsertOne) AddProxyID(v int64) *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) {
s.AddProxyID(v)
})
}
// UpdateProxyID sets the "proxy_id" field to the value that was provided on create.
func (u *AccountUpsertOne) UpdateProxyID() *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) {
@@ -1676,13 +1714,6 @@ func (u *AccountUpsertBulk) SetProxyID(v int64) *AccountUpsertBulk {
})
}
// AddProxyID adds v to the "proxy_id" field.
func (u *AccountUpsertBulk) AddProxyID(v int64) *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) {
s.AddProxyID(v)
})
}
// UpdateProxyID sets the "proxy_id" field to the value that was provided on create.
func (u *AccountUpsertBulk) UpdateProxyID() *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) {

View File

@@ -16,6 +16,8 @@ import (
"github.com/Wei-Shaw/sub2api/ent/accountgroup"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
)
// AccountQuery is the builder for querying Account entities.
@@ -26,6 +28,8 @@ type AccountQuery struct {
inters []Interceptor
predicates []predicate.Account
withGroups *GroupQuery
withProxy *ProxyQuery
withUsageLogs *UsageLogQuery
withAccountGroups *AccountGroupQuery
// intermediate query (i.e. traversal path).
sql *sql.Selector
@@ -85,6 +89,50 @@ func (_q *AccountQuery) QueryGroups() *GroupQuery {
return query
}
// QueryProxy chains the current query on the "proxy" edge.
func (_q *AccountQuery) QueryProxy() *ProxyQuery {
query := (&ProxyClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(account.Table, account.FieldID, selector),
sqlgraph.To(proxy.Table, proxy.FieldID),
sqlgraph.Edge(sqlgraph.M2O, false, account.ProxyTable, account.ProxyColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryUsageLogs chains the current query on the "usage_logs" edge.
func (_q *AccountQuery) QueryUsageLogs() *UsageLogQuery {
query := (&UsageLogClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(account.Table, account.FieldID, selector),
sqlgraph.To(usagelog.Table, usagelog.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, account.UsageLogsTable, account.UsageLogsColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryAccountGroups chains the current query on the "account_groups" edge.
func (_q *AccountQuery) QueryAccountGroups() *AccountGroupQuery {
query := (&AccountGroupClient{config: _q.config}).Query()
@@ -300,6 +348,8 @@ func (_q *AccountQuery) Clone() *AccountQuery {
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.Account{}, _q.predicates...),
withGroups: _q.withGroups.Clone(),
withProxy: _q.withProxy.Clone(),
withUsageLogs: _q.withUsageLogs.Clone(),
withAccountGroups: _q.withAccountGroups.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
@@ -318,6 +368,28 @@ func (_q *AccountQuery) WithGroups(opts ...func(*GroupQuery)) *AccountQuery {
return _q
}
// WithProxy tells the query-builder to eager-load the nodes that are connected to
// the "proxy" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *AccountQuery) WithProxy(opts ...func(*ProxyQuery)) *AccountQuery {
query := (&ProxyClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withProxy = query
return _q
}
// WithUsageLogs tells the query-builder to eager-load the nodes that are connected to
// the "usage_logs" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *AccountQuery) WithUsageLogs(opts ...func(*UsageLogQuery)) *AccountQuery {
query := (&UsageLogClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withUsageLogs = query
return _q
}
// WithAccountGroups tells the query-builder to eager-load the nodes that are connected to
// the "account_groups" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *AccountQuery) WithAccountGroups(opts ...func(*AccountGroupQuery)) *AccountQuery {
@@ -407,8 +479,10 @@ func (_q *AccountQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Acco
var (
nodes = []*Account{}
_spec = _q.querySpec()
loadedTypes = [2]bool{
loadedTypes = [4]bool{
_q.withGroups != nil,
_q.withProxy != nil,
_q.withUsageLogs != nil,
_q.withAccountGroups != nil,
}
)
@@ -437,6 +511,19 @@ func (_q *AccountQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Acco
return nil, err
}
}
if query := _q.withProxy; query != nil {
if err := _q.loadProxy(ctx, query, nodes, nil,
func(n *Account, e *Proxy) { n.Edges.Proxy = e }); err != nil {
return nil, err
}
}
if query := _q.withUsageLogs; query != nil {
if err := _q.loadUsageLogs(ctx, query, nodes,
func(n *Account) { n.Edges.UsageLogs = []*UsageLog{} },
func(n *Account, e *UsageLog) { n.Edges.UsageLogs = append(n.Edges.UsageLogs, e) }); err != nil {
return nil, err
}
}
if query := _q.withAccountGroups; query != nil {
if err := _q.loadAccountGroups(ctx, query, nodes,
func(n *Account) { n.Edges.AccountGroups = []*AccountGroup{} },
@@ -508,6 +595,68 @@ func (_q *AccountQuery) loadGroups(ctx context.Context, query *GroupQuery, nodes
}
return nil
}
func (_q *AccountQuery) loadProxy(ctx context.Context, query *ProxyQuery, nodes []*Account, init func(*Account), assign func(*Account, *Proxy)) error {
ids := make([]int64, 0, len(nodes))
nodeids := make(map[int64][]*Account)
for i := range nodes {
if nodes[i].ProxyID == nil {
continue
}
fk := *nodes[i].ProxyID
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(proxy.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return fmt.Errorf(`unexpected foreign-key "proxy_id" returned %v`, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
}
}
return nil
}
func (_q *AccountQuery) loadUsageLogs(ctx context.Context, query *UsageLogQuery, nodes []*Account, init func(*Account), assign func(*Account, *UsageLog)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*Account)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(usagelog.FieldAccountID)
}
query.Where(predicate.UsageLog(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(account.UsageLogsColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.AccountID
node, ok := nodeids[fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "account_id" returned %v for node %v`, fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *AccountQuery) loadAccountGroups(ctx context.Context, query *AccountGroupQuery, nodes []*Account, init func(*Account), assign func(*Account, *AccountGroup)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*Account)
@@ -564,6 +713,9 @@ func (_q *AccountQuery) querySpec() *sqlgraph.QuerySpec {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
if _q.withProxy != nil {
_spec.Node.AddColumnOnce(account.FieldProxyID)
}
}
if ps := _q.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {

View File

@@ -14,6 +14,8 @@ import (
"github.com/Wei-Shaw/sub2api/ent/account"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
)
// AccountUpdate is the builder for updating Account entities.
@@ -111,7 +113,6 @@ func (_u *AccountUpdate) SetExtra(v map[string]interface{}) *AccountUpdate {
// SetProxyID sets the "proxy_id" field.
func (_u *AccountUpdate) SetProxyID(v int64) *AccountUpdate {
_u.mutation.ResetProxyID()
_u.mutation.SetProxyID(v)
return _u
}
@@ -124,12 +125,6 @@ func (_u *AccountUpdate) SetNillableProxyID(v *int64) *AccountUpdate {
return _u
}
// AddProxyID adds value to the "proxy_id" field.
func (_u *AccountUpdate) AddProxyID(v int64) *AccountUpdate {
_u.mutation.AddProxyID(v)
return _u
}
// ClearProxyID clears the value of the "proxy_id" field.
func (_u *AccountUpdate) ClearProxyID() *AccountUpdate {
_u.mutation.ClearProxyID()
@@ -381,6 +376,26 @@ func (_u *AccountUpdate) AddGroups(v ...*Group) *AccountUpdate {
return _u.AddGroupIDs(ids...)
}
// SetProxy sets the "proxy" edge to the Proxy entity.
func (_u *AccountUpdate) SetProxy(v *Proxy) *AccountUpdate {
return _u.SetProxyID(v.ID)
}
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func (_u *AccountUpdate) AddUsageLogIDs(ids ...int64) *AccountUpdate {
_u.mutation.AddUsageLogIDs(ids...)
return _u
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func (_u *AccountUpdate) AddUsageLogs(v ...*UsageLog) *AccountUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddUsageLogIDs(ids...)
}
// Mutation returns the AccountMutation object of the builder.
func (_u *AccountUpdate) Mutation() *AccountMutation {
return _u.mutation
@@ -407,6 +422,33 @@ func (_u *AccountUpdate) RemoveGroups(v ...*Group) *AccountUpdate {
return _u.RemoveGroupIDs(ids...)
}
// ClearProxy clears the "proxy" edge to the Proxy entity.
func (_u *AccountUpdate) ClearProxy() *AccountUpdate {
_u.mutation.ClearProxy()
return _u
}
// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity.
func (_u *AccountUpdate) ClearUsageLogs() *AccountUpdate {
_u.mutation.ClearUsageLogs()
return _u
}
// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs.
func (_u *AccountUpdate) RemoveUsageLogIDs(ids ...int64) *AccountUpdate {
_u.mutation.RemoveUsageLogIDs(ids...)
return _u
}
// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities.
func (_u *AccountUpdate) RemoveUsageLogs(v ...*UsageLog) *AccountUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveUsageLogIDs(ids...)
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *AccountUpdate) Save(ctx context.Context) (int, error) {
if err := _u.defaults(); err != nil {
@@ -515,15 +557,6 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.Extra(); ok {
_spec.SetField(account.FieldExtra, field.TypeJSON, value)
}
if value, ok := _u.mutation.ProxyID(); ok {
_spec.SetField(account.FieldProxyID, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedProxyID(); ok {
_spec.AddField(account.FieldProxyID, field.TypeInt64, value)
}
if _u.mutation.ProxyIDCleared() {
_spec.ClearField(account.FieldProxyID, field.TypeInt64)
}
if value, ok := _u.mutation.Concurrency(); ok {
_spec.SetField(account.FieldConcurrency, field.TypeInt, value)
}
@@ -647,6 +680,80 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) {
edge.Target.Fields = specE.Fields
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.ProxyCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: false,
Table: account.ProxyTable,
Columns: []string{account.ProxyColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(proxy.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.ProxyIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: false,
Table: account.ProxyTable,
Columns: []string{account.ProxyColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(proxy.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: account.UsageLogsTable,
Columns: []string{account.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: account.UsageLogsTable,
Columns: []string{account.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: account.UsageLogsTable,
Columns: []string{account.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{account.Label}
@@ -749,7 +856,6 @@ func (_u *AccountUpdateOne) SetExtra(v map[string]interface{}) *AccountUpdateOne
// SetProxyID sets the "proxy_id" field.
func (_u *AccountUpdateOne) SetProxyID(v int64) *AccountUpdateOne {
_u.mutation.ResetProxyID()
_u.mutation.SetProxyID(v)
return _u
}
@@ -762,12 +868,6 @@ func (_u *AccountUpdateOne) SetNillableProxyID(v *int64) *AccountUpdateOne {
return _u
}
// AddProxyID adds value to the "proxy_id" field.
func (_u *AccountUpdateOne) AddProxyID(v int64) *AccountUpdateOne {
_u.mutation.AddProxyID(v)
return _u
}
// ClearProxyID clears the value of the "proxy_id" field.
func (_u *AccountUpdateOne) ClearProxyID() *AccountUpdateOne {
_u.mutation.ClearProxyID()
@@ -1019,6 +1119,26 @@ func (_u *AccountUpdateOne) AddGroups(v ...*Group) *AccountUpdateOne {
return _u.AddGroupIDs(ids...)
}
// SetProxy sets the "proxy" edge to the Proxy entity.
func (_u *AccountUpdateOne) SetProxy(v *Proxy) *AccountUpdateOne {
return _u.SetProxyID(v.ID)
}
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func (_u *AccountUpdateOne) AddUsageLogIDs(ids ...int64) *AccountUpdateOne {
_u.mutation.AddUsageLogIDs(ids...)
return _u
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func (_u *AccountUpdateOne) AddUsageLogs(v ...*UsageLog) *AccountUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddUsageLogIDs(ids...)
}
// Mutation returns the AccountMutation object of the builder.
func (_u *AccountUpdateOne) Mutation() *AccountMutation {
return _u.mutation
@@ -1045,6 +1165,33 @@ func (_u *AccountUpdateOne) RemoveGroups(v ...*Group) *AccountUpdateOne {
return _u.RemoveGroupIDs(ids...)
}
// ClearProxy clears the "proxy" edge to the Proxy entity.
func (_u *AccountUpdateOne) ClearProxy() *AccountUpdateOne {
_u.mutation.ClearProxy()
return _u
}
// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity.
func (_u *AccountUpdateOne) ClearUsageLogs() *AccountUpdateOne {
_u.mutation.ClearUsageLogs()
return _u
}
// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs.
func (_u *AccountUpdateOne) RemoveUsageLogIDs(ids ...int64) *AccountUpdateOne {
_u.mutation.RemoveUsageLogIDs(ids...)
return _u
}
// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities.
func (_u *AccountUpdateOne) RemoveUsageLogs(v ...*UsageLog) *AccountUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveUsageLogIDs(ids...)
}
// Where appends a list predicates to the AccountUpdate builder.
func (_u *AccountUpdateOne) Where(ps ...predicate.Account) *AccountUpdateOne {
_u.mutation.Where(ps...)
@@ -1183,15 +1330,6 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er
if value, ok := _u.mutation.Extra(); ok {
_spec.SetField(account.FieldExtra, field.TypeJSON, value)
}
if value, ok := _u.mutation.ProxyID(); ok {
_spec.SetField(account.FieldProxyID, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedProxyID(); ok {
_spec.AddField(account.FieldProxyID, field.TypeInt64, value)
}
if _u.mutation.ProxyIDCleared() {
_spec.ClearField(account.FieldProxyID, field.TypeInt64)
}
if value, ok := _u.mutation.Concurrency(); ok {
_spec.SetField(account.FieldConcurrency, field.TypeInt, value)
}
@@ -1315,6 +1453,80 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er
edge.Target.Fields = specE.Fields
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.ProxyCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: false,
Table: account.ProxyTable,
Columns: []string{account.ProxyColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(proxy.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.ProxyIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: false,
Table: account.ProxyTable,
Columns: []string{account.ProxyColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(proxy.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: account.UsageLogsTable,
Columns: []string{account.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: account.UsageLogsTable,
Columns: []string{account.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: account.UsageLogsTable,
Columns: []string{account.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &Account{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues

View File

@@ -47,9 +47,11 @@ type ApiKeyEdges struct {
User *User `json:"user,omitempty"`
// Group holds the value of the group edge.
Group *Group `json:"group,omitempty"`
// UsageLogs holds the value of the usage_logs edge.
UsageLogs []*UsageLog `json:"usage_logs,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [2]bool
loadedTypes [3]bool
}
// UserOrErr returns the User value or an error if the edge
@@ -74,6 +76,15 @@ func (e ApiKeyEdges) GroupOrErr() (*Group, error) {
return nil, &NotLoadedError{edge: "group"}
}
// UsageLogsOrErr returns the UsageLogs value or an error if the edge
// was not loaded in eager-loading.
func (e ApiKeyEdges) UsageLogsOrErr() ([]*UsageLog, error) {
if e.loadedTypes[2] {
return e.UsageLogs, nil
}
return nil, &NotLoadedError{edge: "usage_logs"}
}
// scanValues returns the types for scanning values from sql.Rows.
func (*ApiKey) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
@@ -179,6 +190,11 @@ func (_m *ApiKey) QueryGroup() *GroupQuery {
return NewApiKeyClient(_m.config).QueryGroup(_m)
}
// QueryUsageLogs queries the "usage_logs" edge of the ApiKey entity.
func (_m *ApiKey) QueryUsageLogs() *UsageLogQuery {
return NewApiKeyClient(_m.config).QueryUsageLogs(_m)
}
// Update returns a builder for updating this ApiKey.
// Note that you need to call ApiKey.Unwrap() before calling this method if this ApiKey
// was returned from a transaction, and the transaction was committed or rolled back.

View File

@@ -35,6 +35,8 @@ const (
EdgeUser = "user"
// EdgeGroup holds the string denoting the group edge name in mutations.
EdgeGroup = "group"
// EdgeUsageLogs holds the string denoting the usage_logs edge name in mutations.
EdgeUsageLogs = "usage_logs"
// Table holds the table name of the apikey in the database.
Table = "api_keys"
// UserTable is the table that holds the user relation/edge.
@@ -51,6 +53,13 @@ const (
GroupInverseTable = "groups"
// GroupColumn is the table column denoting the group relation/edge.
GroupColumn = "group_id"
// UsageLogsTable is the table that holds the usage_logs relation/edge.
UsageLogsTable = "usage_logs"
// UsageLogsInverseTable is the table name for the UsageLog entity.
// It exists in this package in order to avoid circular dependency with the "usagelog" package.
UsageLogsInverseTable = "usage_logs"
// UsageLogsColumn is the table column denoting the usage_logs relation/edge.
UsageLogsColumn = "api_key_id"
)
// Columns holds all SQL columns for apikey fields.
@@ -161,6 +170,20 @@ func ByGroupField(field string, opts ...sql.OrderTermOption) OrderOption {
sqlgraph.OrderByNeighborTerms(s, newGroupStep(), sql.OrderByField(field, opts...))
}
}
// ByUsageLogsCount orders the results by usage_logs count.
func ByUsageLogsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newUsageLogsStep(), opts...)
}
}
// ByUsageLogs orders the results by usage_logs terms.
func ByUsageLogs(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newUsageLogsStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
func newUserStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
@@ -175,3 +198,10 @@ func newGroupStep() *sqlgraph.Step {
sqlgraph.Edge(sqlgraph.M2O, true, GroupTable, GroupColumn),
)
}
func newUsageLogsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(UsageLogsInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn),
)
}

View File

@@ -516,6 +516,29 @@ func HasGroupWith(preds ...predicate.Group) predicate.ApiKey {
})
}
// HasUsageLogs applies the HasEdge predicate on the "usage_logs" edge.
func HasUsageLogs() predicate.ApiKey {
return predicate.ApiKey(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasUsageLogsWith applies the HasEdge predicate on the "usage_logs" edge with a given conditions (other predicates).
func HasUsageLogsWith(preds ...predicate.UsageLog) predicate.ApiKey {
return predicate.ApiKey(func(s *sql.Selector) {
step := newUsageLogsStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.ApiKey) predicate.ApiKey {
return predicate.ApiKey(sql.AndPredicates(predicates...))

View File

@@ -13,6 +13,7 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
)
@@ -122,6 +123,21 @@ func (_c *ApiKeyCreate) SetGroup(v *Group) *ApiKeyCreate {
return _c.SetGroupID(v.ID)
}
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func (_c *ApiKeyCreate) AddUsageLogIDs(ids ...int64) *ApiKeyCreate {
_c.mutation.AddUsageLogIDs(ids...)
return _c
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func (_c *ApiKeyCreate) AddUsageLogs(v ...*UsageLog) *ApiKeyCreate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _c.AddUsageLogIDs(ids...)
}
// Mutation returns the ApiKeyMutation object of the builder.
func (_c *ApiKeyCreate) Mutation() *ApiKeyMutation {
return _c.mutation
@@ -303,6 +319,22 @@ func (_c *ApiKeyCreate) createSpec() (*ApiKey, *sqlgraph.CreateSpec) {
_node.GroupID = &nodes[0]
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.UsageLogsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: apikey.UsageLogsTable,
Columns: []string{apikey.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges = append(_spec.Edges, edge)
}
return _node, _spec
}

View File

@@ -4,6 +4,7 @@ package ent
import (
"context"
"database/sql/driver"
"fmt"
"math"
@@ -14,18 +15,20 @@ import (
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
)
// ApiKeyQuery is the builder for querying ApiKey entities.
type ApiKeyQuery struct {
config
ctx *QueryContext
order []apikey.OrderOption
inters []Interceptor
predicates []predicate.ApiKey
withUser *UserQuery
withGroup *GroupQuery
ctx *QueryContext
order []apikey.OrderOption
inters []Interceptor
predicates []predicate.ApiKey
withUser *UserQuery
withGroup *GroupQuery
withUsageLogs *UsageLogQuery
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
@@ -106,6 +109,28 @@ func (_q *ApiKeyQuery) QueryGroup() *GroupQuery {
return query
}
// QueryUsageLogs chains the current query on the "usage_logs" edge.
func (_q *ApiKeyQuery) QueryUsageLogs() *UsageLogQuery {
query := (&UsageLogClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(apikey.Table, apikey.FieldID, selector),
sqlgraph.To(usagelog.Table, usagelog.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, apikey.UsageLogsTable, apikey.UsageLogsColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// First returns the first ApiKey entity from the query.
// Returns a *NotFoundError when no ApiKey was found.
func (_q *ApiKeyQuery) First(ctx context.Context) (*ApiKey, error) {
@@ -293,13 +318,14 @@ func (_q *ApiKeyQuery) Clone() *ApiKeyQuery {
return nil
}
return &ApiKeyQuery{
config: _q.config,
ctx: _q.ctx.Clone(),
order: append([]apikey.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.ApiKey{}, _q.predicates...),
withUser: _q.withUser.Clone(),
withGroup: _q.withGroup.Clone(),
config: _q.config,
ctx: _q.ctx.Clone(),
order: append([]apikey.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.ApiKey{}, _q.predicates...),
withUser: _q.withUser.Clone(),
withGroup: _q.withGroup.Clone(),
withUsageLogs: _q.withUsageLogs.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
@@ -328,6 +354,17 @@ func (_q *ApiKeyQuery) WithGroup(opts ...func(*GroupQuery)) *ApiKeyQuery {
return _q
}
// WithUsageLogs tells the query-builder to eager-load the nodes that are connected to
// the "usage_logs" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *ApiKeyQuery) WithUsageLogs(opts ...func(*UsageLogQuery)) *ApiKeyQuery {
query := (&UsageLogClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withUsageLogs = query
return _q
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
@@ -406,9 +443,10 @@ func (_q *ApiKeyQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*ApiKe
var (
nodes = []*ApiKey{}
_spec = _q.querySpec()
loadedTypes = [2]bool{
loadedTypes = [3]bool{
_q.withUser != nil,
_q.withGroup != nil,
_q.withUsageLogs != nil,
}
)
_spec.ScanValues = func(columns []string) ([]any, error) {
@@ -441,6 +479,13 @@ func (_q *ApiKeyQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*ApiKe
return nil, err
}
}
if query := _q.withUsageLogs; query != nil {
if err := _q.loadUsageLogs(ctx, query, nodes,
func(n *ApiKey) { n.Edges.UsageLogs = []*UsageLog{} },
func(n *ApiKey, e *UsageLog) { n.Edges.UsageLogs = append(n.Edges.UsageLogs, e) }); err != nil {
return nil, err
}
}
return nodes, nil
}
@@ -505,6 +550,36 @@ func (_q *ApiKeyQuery) loadGroup(ctx context.Context, query *GroupQuery, nodes [
}
return nil
}
func (_q *ApiKeyQuery) loadUsageLogs(ctx context.Context, query *UsageLogQuery, nodes []*ApiKey, init func(*ApiKey), assign func(*ApiKey, *UsageLog)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*ApiKey)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(usagelog.FieldAPIKeyID)
}
query.Where(predicate.UsageLog(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(apikey.UsageLogsColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.APIKeyID
node, ok := nodeids[fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "api_key_id" returned %v for node %v`, fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *ApiKeyQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()

View File

@@ -14,6 +14,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
)
@@ -142,6 +143,21 @@ func (_u *ApiKeyUpdate) SetGroup(v *Group) *ApiKeyUpdate {
return _u.SetGroupID(v.ID)
}
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func (_u *ApiKeyUpdate) AddUsageLogIDs(ids ...int64) *ApiKeyUpdate {
_u.mutation.AddUsageLogIDs(ids...)
return _u
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func (_u *ApiKeyUpdate) AddUsageLogs(v ...*UsageLog) *ApiKeyUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddUsageLogIDs(ids...)
}
// Mutation returns the ApiKeyMutation object of the builder.
func (_u *ApiKeyUpdate) Mutation() *ApiKeyMutation {
return _u.mutation
@@ -159,6 +175,27 @@ func (_u *ApiKeyUpdate) ClearGroup() *ApiKeyUpdate {
return _u
}
// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity.
func (_u *ApiKeyUpdate) ClearUsageLogs() *ApiKeyUpdate {
_u.mutation.ClearUsageLogs()
return _u
}
// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs.
func (_u *ApiKeyUpdate) RemoveUsageLogIDs(ids ...int64) *ApiKeyUpdate {
_u.mutation.RemoveUsageLogIDs(ids...)
return _u
}
// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities.
func (_u *ApiKeyUpdate) RemoveUsageLogs(v ...*UsageLog) *ApiKeyUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveUsageLogIDs(ids...)
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *ApiKeyUpdate) Save(ctx context.Context) (int, error) {
if err := _u.defaults(); err != nil {
@@ -312,6 +349,51 @@ func (_u *ApiKeyUpdate) sqlSave(ctx context.Context) (_node int, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: apikey.UsageLogsTable,
Columns: []string{apikey.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: apikey.UsageLogsTable,
Columns: []string{apikey.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: apikey.UsageLogsTable,
Columns: []string{apikey.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{apikey.Label}
@@ -444,6 +526,21 @@ func (_u *ApiKeyUpdateOne) SetGroup(v *Group) *ApiKeyUpdateOne {
return _u.SetGroupID(v.ID)
}
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func (_u *ApiKeyUpdateOne) AddUsageLogIDs(ids ...int64) *ApiKeyUpdateOne {
_u.mutation.AddUsageLogIDs(ids...)
return _u
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func (_u *ApiKeyUpdateOne) AddUsageLogs(v ...*UsageLog) *ApiKeyUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddUsageLogIDs(ids...)
}
// Mutation returns the ApiKeyMutation object of the builder.
func (_u *ApiKeyUpdateOne) Mutation() *ApiKeyMutation {
return _u.mutation
@@ -461,6 +558,27 @@ func (_u *ApiKeyUpdateOne) ClearGroup() *ApiKeyUpdateOne {
return _u
}
// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity.
func (_u *ApiKeyUpdateOne) ClearUsageLogs() *ApiKeyUpdateOne {
_u.mutation.ClearUsageLogs()
return _u
}
// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs.
func (_u *ApiKeyUpdateOne) RemoveUsageLogIDs(ids ...int64) *ApiKeyUpdateOne {
_u.mutation.RemoveUsageLogIDs(ids...)
return _u
}
// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities.
func (_u *ApiKeyUpdateOne) RemoveUsageLogs(v ...*UsageLog) *ApiKeyUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveUsageLogIDs(ids...)
}
// Where appends a list predicates to the ApiKeyUpdate builder.
func (_u *ApiKeyUpdateOne) Where(ps ...predicate.ApiKey) *ApiKeyUpdateOne {
_u.mutation.Where(ps...)
@@ -644,6 +762,51 @@ func (_u *ApiKeyUpdateOne) sqlSave(ctx context.Context) (_node *ApiKey, err erro
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: apikey.UsageLogsTable,
Columns: []string{apikey.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: apikey.UsageLogsTable,
Columns: []string{apikey.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: apikey.UsageLogsTable,
Columns: []string{apikey.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &ApiKey{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues

View File

@@ -22,6 +22,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/setting"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
@@ -48,6 +49,8 @@ type Client struct {
RedeemCode *RedeemCodeClient
// Setting is the client for interacting with the Setting builders.
Setting *SettingClient
// UsageLog is the client for interacting with the UsageLog builders.
UsageLog *UsageLogClient
// User is the client for interacting with the User builders.
User *UserClient
// UserAllowedGroup is the client for interacting with the UserAllowedGroup builders.
@@ -72,6 +75,7 @@ func (c *Client) init() {
c.Proxy = NewProxyClient(c.config)
c.RedeemCode = NewRedeemCodeClient(c.config)
c.Setting = NewSettingClient(c.config)
c.UsageLog = NewUsageLogClient(c.config)
c.User = NewUserClient(c.config)
c.UserAllowedGroup = NewUserAllowedGroupClient(c.config)
c.UserSubscription = NewUserSubscriptionClient(c.config)
@@ -174,6 +178,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
Proxy: NewProxyClient(cfg),
RedeemCode: NewRedeemCodeClient(cfg),
Setting: NewSettingClient(cfg),
UsageLog: NewUsageLogClient(cfg),
User: NewUserClient(cfg),
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
UserSubscription: NewUserSubscriptionClient(cfg),
@@ -203,6 +208,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
Proxy: NewProxyClient(cfg),
RedeemCode: NewRedeemCodeClient(cfg),
Setting: NewSettingClient(cfg),
UsageLog: NewUsageLogClient(cfg),
User: NewUserClient(cfg),
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
UserSubscription: NewUserSubscriptionClient(cfg),
@@ -236,7 +242,7 @@ func (c *Client) Close() error {
func (c *Client) Use(hooks ...Hook) {
for _, n := range []interface{ Use(...Hook) }{
c.Account, c.AccountGroup, c.ApiKey, c.Group, c.Proxy, c.RedeemCode, c.Setting,
c.User, c.UserAllowedGroup, c.UserSubscription,
c.UsageLog, c.User, c.UserAllowedGroup, c.UserSubscription,
} {
n.Use(hooks...)
}
@@ -247,7 +253,7 @@ func (c *Client) Use(hooks ...Hook) {
func (c *Client) Intercept(interceptors ...Interceptor) {
for _, n := range []interface{ Intercept(...Interceptor) }{
c.Account, c.AccountGroup, c.ApiKey, c.Group, c.Proxy, c.RedeemCode, c.Setting,
c.User, c.UserAllowedGroup, c.UserSubscription,
c.UsageLog, c.User, c.UserAllowedGroup, c.UserSubscription,
} {
n.Intercept(interceptors...)
}
@@ -270,6 +276,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
return c.RedeemCode.mutate(ctx, m)
case *SettingMutation:
return c.Setting.mutate(ctx, m)
case *UsageLogMutation:
return c.UsageLog.mutate(ctx, m)
case *UserMutation:
return c.User.mutate(ctx, m)
case *UserAllowedGroupMutation:
@@ -405,6 +413,38 @@ func (c *AccountClient) QueryGroups(_m *Account) *GroupQuery {
return query
}
// QueryProxy queries the proxy edge of a Account.
func (c *AccountClient) QueryProxy(_m *Account) *ProxyQuery {
query := (&ProxyClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(account.Table, account.FieldID, id),
sqlgraph.To(proxy.Table, proxy.FieldID),
sqlgraph.Edge(sqlgraph.M2O, false, account.ProxyTable, account.ProxyColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// QueryUsageLogs queries the usage_logs edge of a Account.
func (c *AccountClient) QueryUsageLogs(_m *Account) *UsageLogQuery {
query := (&UsageLogClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(account.Table, account.FieldID, id),
sqlgraph.To(usagelog.Table, usagelog.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, account.UsageLogsTable, account.UsageLogsColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// QueryAccountGroups queries the account_groups edge of a Account.
func (c *AccountClient) QueryAccountGroups(_m *Account) *AccountGroupQuery {
query := (&AccountGroupClient{config: c.config}).Query()
@@ -704,6 +744,22 @@ func (c *ApiKeyClient) QueryGroup(_m *ApiKey) *GroupQuery {
return query
}
// QueryUsageLogs queries the usage_logs edge of a ApiKey.
func (c *ApiKeyClient) QueryUsageLogs(_m *ApiKey) *UsageLogQuery {
query := (&UsageLogClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(apikey.Table, apikey.FieldID, id),
sqlgraph.To(usagelog.Table, usagelog.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, apikey.UsageLogsTable, apikey.UsageLogsColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// Hooks returns the client hooks.
func (c *ApiKeyClient) Hooks() []Hook {
hooks := c.hooks.ApiKey
@@ -887,6 +943,22 @@ func (c *GroupClient) QuerySubscriptions(_m *Group) *UserSubscriptionQuery {
return query
}
// QueryUsageLogs queries the usage_logs edge of a Group.
func (c *GroupClient) QueryUsageLogs(_m *Group) *UsageLogQuery {
query := (&UsageLogClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(group.Table, group.FieldID, id),
sqlgraph.To(usagelog.Table, usagelog.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, group.UsageLogsTable, group.UsageLogsColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// QueryAccounts queries the accounts edge of a Group.
func (c *GroupClient) QueryAccounts(_m *Group) *AccountQuery {
query := (&AccountClient{config: c.config}).Query()
@@ -1086,6 +1158,22 @@ func (c *ProxyClient) GetX(ctx context.Context, id int64) *Proxy {
return obj
}
// QueryAccounts queries the accounts edge of a Proxy.
func (c *ProxyClient) QueryAccounts(_m *Proxy) *AccountQuery {
query := (&AccountClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(proxy.Table, proxy.FieldID, id),
sqlgraph.To(account.Table, account.FieldID),
sqlgraph.Edge(sqlgraph.O2M, true, proxy.AccountsTable, proxy.AccountsColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// Hooks returns the client hooks.
func (c *ProxyClient) Hooks() []Hook {
hooks := c.hooks.Proxy
@@ -1411,6 +1499,219 @@ func (c *SettingClient) mutate(ctx context.Context, m *SettingMutation) (Value,
}
}
// UsageLogClient is a client for the UsageLog schema.
type UsageLogClient struct {
config
}
// NewUsageLogClient returns a client for the UsageLog from the given config.
func NewUsageLogClient(c config) *UsageLogClient {
return &UsageLogClient{config: c}
}
// Use adds a list of mutation hooks to the hooks stack.
// A call to `Use(f, g, h)` equals to `usagelog.Hooks(f(g(h())))`.
func (c *UsageLogClient) Use(hooks ...Hook) {
c.hooks.UsageLog = append(c.hooks.UsageLog, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `usagelog.Intercept(f(g(h())))`.
func (c *UsageLogClient) Intercept(interceptors ...Interceptor) {
c.inters.UsageLog = append(c.inters.UsageLog, interceptors...)
}
// Create returns a builder for creating a UsageLog entity.
func (c *UsageLogClient) Create() *UsageLogCreate {
mutation := newUsageLogMutation(c.config, OpCreate)
return &UsageLogCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// CreateBulk returns a builder for creating a bulk of UsageLog entities.
func (c *UsageLogClient) CreateBulk(builders ...*UsageLogCreate) *UsageLogCreateBulk {
return &UsageLogCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *UsageLogClient) MapCreateBulk(slice any, setFunc func(*UsageLogCreate, int)) *UsageLogCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &UsageLogCreateBulk{err: fmt.Errorf("calling to UsageLogClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*UsageLogCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &UsageLogCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for UsageLog.
func (c *UsageLogClient) Update() *UsageLogUpdate {
mutation := newUsageLogMutation(c.config, OpUpdate)
return &UsageLogUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOne returns an update builder for the given entity.
func (c *UsageLogClient) UpdateOne(_m *UsageLog) *UsageLogUpdateOne {
mutation := newUsageLogMutation(c.config, OpUpdateOne, withUsageLog(_m))
return &UsageLogUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOneID returns an update builder for the given id.
func (c *UsageLogClient) UpdateOneID(id int64) *UsageLogUpdateOne {
mutation := newUsageLogMutation(c.config, OpUpdateOne, withUsageLogID(id))
return &UsageLogUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// Delete returns a delete builder for UsageLog.
func (c *UsageLogClient) Delete() *UsageLogDelete {
mutation := newUsageLogMutation(c.config, OpDelete)
return &UsageLogDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// DeleteOne returns a builder for deleting the given entity.
func (c *UsageLogClient) DeleteOne(_m *UsageLog) *UsageLogDeleteOne {
return c.DeleteOneID(_m.ID)
}
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *UsageLogClient) DeleteOneID(id int64) *UsageLogDeleteOne {
builder := c.Delete().Where(usagelog.ID(id))
builder.mutation.id = &id
builder.mutation.op = OpDeleteOne
return &UsageLogDeleteOne{builder}
}
// Query returns a query builder for UsageLog.
func (c *UsageLogClient) Query() *UsageLogQuery {
return &UsageLogQuery{
config: c.config,
ctx: &QueryContext{Type: TypeUsageLog},
inters: c.Interceptors(),
}
}
// Get returns a UsageLog entity by its id.
func (c *UsageLogClient) Get(ctx context.Context, id int64) (*UsageLog, error) {
return c.Query().Where(usagelog.ID(id)).Only(ctx)
}
// GetX is like Get, but panics if an error occurs.
func (c *UsageLogClient) GetX(ctx context.Context, id int64) *UsageLog {
obj, err := c.Get(ctx, id)
if err != nil {
panic(err)
}
return obj
}
// QueryUser queries the user edge of a UsageLog.
func (c *UsageLogClient) QueryUser(_m *UsageLog) *UserQuery {
query := (&UserClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(usagelog.Table, usagelog.FieldID, id),
sqlgraph.To(user.Table, user.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, usagelog.UserTable, usagelog.UserColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// QueryAPIKey queries the api_key edge of a UsageLog.
func (c *UsageLogClient) QueryAPIKey(_m *UsageLog) *ApiKeyQuery {
query := (&ApiKeyClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(usagelog.Table, usagelog.FieldID, id),
sqlgraph.To(apikey.Table, apikey.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, usagelog.APIKeyTable, usagelog.APIKeyColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// QueryAccount queries the account edge of a UsageLog.
func (c *UsageLogClient) QueryAccount(_m *UsageLog) *AccountQuery {
query := (&AccountClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(usagelog.Table, usagelog.FieldID, id),
sqlgraph.To(account.Table, account.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, usagelog.AccountTable, usagelog.AccountColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// QueryGroup queries the group edge of a UsageLog.
func (c *UsageLogClient) QueryGroup(_m *UsageLog) *GroupQuery {
query := (&GroupClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(usagelog.Table, usagelog.FieldID, id),
sqlgraph.To(group.Table, group.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, usagelog.GroupTable, usagelog.GroupColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// QuerySubscription queries the subscription edge of a UsageLog.
func (c *UsageLogClient) QuerySubscription(_m *UsageLog) *UserSubscriptionQuery {
query := (&UserSubscriptionClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(usagelog.Table, usagelog.FieldID, id),
sqlgraph.To(usersubscription.Table, usersubscription.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, usagelog.SubscriptionTable, usagelog.SubscriptionColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// Hooks returns the client hooks.
func (c *UsageLogClient) Hooks() []Hook {
return c.hooks.UsageLog
}
// Interceptors returns the client interceptors.
func (c *UsageLogClient) Interceptors() []Interceptor {
return c.inters.UsageLog
}
func (c *UsageLogClient) mutate(ctx context.Context, m *UsageLogMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&UsageLogCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&UsageLogUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&UsageLogUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&UsageLogDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown UsageLog mutation op: %q", m.Op())
}
}
// UserClient is a client for the User schema.
type UserClient struct {
config
@@ -1599,6 +1900,22 @@ func (c *UserClient) QueryAllowedGroups(_m *User) *GroupQuery {
return query
}
// QueryUsageLogs queries the usage_logs edge of a User.
func (c *UserClient) QueryUsageLogs(_m *User) *UsageLogQuery {
query := (&UsageLogClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID, id),
sqlgraph.To(usagelog.Table, usagelog.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, user.UsageLogsTable, user.UsageLogsColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// QueryUserAllowedGroups queries the user_allowed_groups edge of a User.
func (c *UserClient) QueryUserAllowedGroups(_m *User) *UserAllowedGroupQuery {
query := (&UserAllowedGroupClient{config: c.config}).Query()
@@ -1914,14 +2231,32 @@ func (c *UserSubscriptionClient) QueryAssignedByUser(_m *UserSubscription) *User
return query
}
// QueryUsageLogs queries the usage_logs edge of a UserSubscription.
func (c *UserSubscriptionClient) QueryUsageLogs(_m *UserSubscription) *UsageLogQuery {
query := (&UsageLogClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(usersubscription.Table, usersubscription.FieldID, id),
sqlgraph.To(usagelog.Table, usagelog.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, usersubscription.UsageLogsTable, usersubscription.UsageLogsColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// Hooks returns the client hooks.
func (c *UserSubscriptionClient) Hooks() []Hook {
return c.hooks.UserSubscription
hooks := c.hooks.UserSubscription
return append(hooks[:len(hooks):len(hooks)], usersubscription.Hooks[:]...)
}
// Interceptors returns the client interceptors.
func (c *UserSubscriptionClient) Interceptors() []Interceptor {
return c.inters.UserSubscription
inters := c.inters.UserSubscription
return append(inters[:len(inters):len(inters)], usersubscription.Interceptors[:]...)
}
func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscriptionMutation) (Value, error) {
@@ -1942,16 +2277,15 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription
// hooks and interceptors per client, for fast access.
type (
hooks struct {
Account, AccountGroup, ApiKey, Group, Proxy, RedeemCode, Setting, User,
UserAllowedGroup, UserSubscription []ent.Hook
Account, AccountGroup, ApiKey, Group, Proxy, RedeemCode, Setting, UsageLog,
User, UserAllowedGroup, UserSubscription []ent.Hook
}
inters struct {
Account, AccountGroup, ApiKey, Group, Proxy, RedeemCode, Setting, User,
UserAllowedGroup, UserSubscription []ent.Interceptor
Account, AccountGroup, ApiKey, Group, Proxy, RedeemCode, Setting, UsageLog,
User, UserAllowedGroup, UserSubscription []ent.Interceptor
}
)
// ExecContext 透传到底层 driver用于在 ent 事务中执行原生 SQL例如同步 legacy 字段)。
// ExecContext allows calling the underlying ExecContext method of the driver if it is supported by it.
// See, database/sql#DB.ExecContext for more information.
func (c *config) ExecContext(ctx context.Context, query string, args ...any) (stdsql.Result, error) {
@@ -1964,7 +2298,6 @@ func (c *config) ExecContext(ctx context.Context, query string, args ...any) (st
return ex.ExecContext(ctx, query, args...)
}
// QueryContext 透传到底层 driver用于在事务内执行原生查询并共享锁/一致性语义。
// QueryContext allows calling the underlying QueryContext method of the driver if it is supported by it.
// See, database/sql#DB.QueryContext for more information.
func (c *config) QueryContext(ctx context.Context, query string, args ...any) (*stdsql.Rows, error) {

View File

@@ -19,6 +19,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/setting"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
@@ -89,6 +90,7 @@ func checkColumn(t, c string) error {
proxy.Table: proxy.ValidColumn,
redeemcode.Table: redeemcode.ValidColumn,
setting.Table: setting.ValidColumn,
usagelog.Table: usagelog.ValidColumn,
user.Table: user.ValidColumn,
userallowedgroup.Table: userallowedgroup.ValidColumn,
usersubscription.Table: usersubscription.ValidColumn,

View File

@@ -43,6 +43,8 @@ type Group struct {
WeeklyLimitUsd *float64 `json:"weekly_limit_usd,omitempty"`
// MonthlyLimitUsd holds the value of the "monthly_limit_usd" field.
MonthlyLimitUsd *float64 `json:"monthly_limit_usd,omitempty"`
// DefaultValidityDays holds the value of the "default_validity_days" field.
DefaultValidityDays int `json:"default_validity_days,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the GroupQuery when eager-loading is set.
Edges GroupEdges `json:"edges"`
@@ -57,6 +59,8 @@ type GroupEdges struct {
RedeemCodes []*RedeemCode `json:"redeem_codes,omitempty"`
// Subscriptions holds the value of the subscriptions edge.
Subscriptions []*UserSubscription `json:"subscriptions,omitempty"`
// UsageLogs holds the value of the usage_logs edge.
UsageLogs []*UsageLog `json:"usage_logs,omitempty"`
// Accounts holds the value of the accounts edge.
Accounts []*Account `json:"accounts,omitempty"`
// AllowedUsers holds the value of the allowed_users edge.
@@ -67,7 +71,7 @@ type GroupEdges struct {
UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [7]bool
loadedTypes [8]bool
}
// APIKeysOrErr returns the APIKeys value or an error if the edge
@@ -97,10 +101,19 @@ func (e GroupEdges) SubscriptionsOrErr() ([]*UserSubscription, error) {
return nil, &NotLoadedError{edge: "subscriptions"}
}
// UsageLogsOrErr returns the UsageLogs value or an error if the edge
// was not loaded in eager-loading.
func (e GroupEdges) UsageLogsOrErr() ([]*UsageLog, error) {
if e.loadedTypes[3] {
return e.UsageLogs, nil
}
return nil, &NotLoadedError{edge: "usage_logs"}
}
// AccountsOrErr returns the Accounts value or an error if the edge
// was not loaded in eager-loading.
func (e GroupEdges) AccountsOrErr() ([]*Account, error) {
if e.loadedTypes[3] {
if e.loadedTypes[4] {
return e.Accounts, nil
}
return nil, &NotLoadedError{edge: "accounts"}
@@ -109,7 +122,7 @@ func (e GroupEdges) AccountsOrErr() ([]*Account, error) {
// AllowedUsersOrErr returns the AllowedUsers value or an error if the edge
// was not loaded in eager-loading.
func (e GroupEdges) AllowedUsersOrErr() ([]*User, error) {
if e.loadedTypes[4] {
if e.loadedTypes[5] {
return e.AllowedUsers, nil
}
return nil, &NotLoadedError{edge: "allowed_users"}
@@ -118,7 +131,7 @@ func (e GroupEdges) AllowedUsersOrErr() ([]*User, error) {
// AccountGroupsOrErr returns the AccountGroups value or an error if the edge
// was not loaded in eager-loading.
func (e GroupEdges) AccountGroupsOrErr() ([]*AccountGroup, error) {
if e.loadedTypes[5] {
if e.loadedTypes[6] {
return e.AccountGroups, nil
}
return nil, &NotLoadedError{edge: "account_groups"}
@@ -127,7 +140,7 @@ func (e GroupEdges) AccountGroupsOrErr() ([]*AccountGroup, error) {
// UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge
// was not loaded in eager-loading.
func (e GroupEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) {
if e.loadedTypes[6] {
if e.loadedTypes[7] {
return e.UserAllowedGroups, nil
}
return nil, &NotLoadedError{edge: "user_allowed_groups"}
@@ -142,7 +155,7 @@ func (*Group) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullBool)
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd:
values[i] = new(sql.NullFloat64)
case group.FieldID:
case group.FieldID, group.FieldDefaultValidityDays:
values[i] = new(sql.NullInt64)
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType:
values[i] = new(sql.NullString)
@@ -252,6 +265,12 @@ func (_m *Group) assignValues(columns []string, values []any) error {
_m.MonthlyLimitUsd = new(float64)
*_m.MonthlyLimitUsd = value.Float64
}
case group.FieldDefaultValidityDays:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field default_validity_days", values[i])
} else if value.Valid {
_m.DefaultValidityDays = int(value.Int64)
}
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -280,6 +299,11 @@ func (_m *Group) QuerySubscriptions() *UserSubscriptionQuery {
return NewGroupClient(_m.config).QuerySubscriptions(_m)
}
// QueryUsageLogs queries the "usage_logs" edge of the Group entity.
func (_m *Group) QueryUsageLogs() *UsageLogQuery {
return NewGroupClient(_m.config).QueryUsageLogs(_m)
}
// QueryAccounts queries the "accounts" edge of the Group entity.
func (_m *Group) QueryAccounts() *AccountQuery {
return NewGroupClient(_m.config).QueryAccounts(_m)
@@ -371,6 +395,9 @@ func (_m *Group) String() string {
builder.WriteString("monthly_limit_usd=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
builder.WriteString("default_validity_days=")
builder.WriteString(fmt.Sprintf("%v", _m.DefaultValidityDays))
builder.WriteByte(')')
return builder.String()
}

View File

@@ -41,12 +41,16 @@ const (
FieldWeeklyLimitUsd = "weekly_limit_usd"
// FieldMonthlyLimitUsd holds the string denoting the monthly_limit_usd field in the database.
FieldMonthlyLimitUsd = "monthly_limit_usd"
// FieldDefaultValidityDays holds the string denoting the default_validity_days field in the database.
FieldDefaultValidityDays = "default_validity_days"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
EdgeRedeemCodes = "redeem_codes"
// EdgeSubscriptions holds the string denoting the subscriptions edge name in mutations.
EdgeSubscriptions = "subscriptions"
// EdgeUsageLogs holds the string denoting the usage_logs edge name in mutations.
EdgeUsageLogs = "usage_logs"
// EdgeAccounts holds the string denoting the accounts edge name in mutations.
EdgeAccounts = "accounts"
// EdgeAllowedUsers holds the string denoting the allowed_users edge name in mutations.
@@ -78,6 +82,13 @@ const (
SubscriptionsInverseTable = "user_subscriptions"
// SubscriptionsColumn is the table column denoting the subscriptions relation/edge.
SubscriptionsColumn = "group_id"
// UsageLogsTable is the table that holds the usage_logs relation/edge.
UsageLogsTable = "usage_logs"
// UsageLogsInverseTable is the table name for the UsageLog entity.
// It exists in this package in order to avoid circular dependency with the "usagelog" package.
UsageLogsInverseTable = "usage_logs"
// UsageLogsColumn is the table column denoting the usage_logs relation/edge.
UsageLogsColumn = "group_id"
// AccountsTable is the table that holds the accounts relation/edge. The primary key declared below.
AccountsTable = "account_groups"
// AccountsInverseTable is the table name for the Account entity.
@@ -120,6 +131,7 @@ var Columns = []string{
FieldDailyLimitUsd,
FieldWeeklyLimitUsd,
FieldMonthlyLimitUsd,
FieldDefaultValidityDays,
}
var (
@@ -173,6 +185,8 @@ var (
DefaultSubscriptionType string
// SubscriptionTypeValidator is a validator for the "subscription_type" field. It is called by the builders before save.
SubscriptionTypeValidator func(string) error
// DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field.
DefaultDefaultValidityDays int
)
// OrderOption defines the ordering options for the Group queries.
@@ -248,6 +262,11 @@ func ByMonthlyLimitUsd(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldMonthlyLimitUsd, opts...).ToFunc()
}
// ByDefaultValidityDays orders the results by the default_validity_days field.
func ByDefaultValidityDays(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDefaultValidityDays, opts...).ToFunc()
}
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
@@ -290,6 +309,20 @@ func BySubscriptions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
}
}
// ByUsageLogsCount orders the results by usage_logs count.
func ByUsageLogsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newUsageLogsStep(), opts...)
}
}
// ByUsageLogs orders the results by usage_logs terms.
func ByUsageLogs(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newUsageLogsStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
// ByAccountsCount orders the results by accounts count.
func ByAccountsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
@@ -366,6 +399,13 @@ func newSubscriptionsStep() *sqlgraph.Step {
sqlgraph.Edge(sqlgraph.O2M, false, SubscriptionsTable, SubscriptionsColumn),
)
}
func newUsageLogsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(UsageLogsInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn),
)
}
func newAccountsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),

View File

@@ -120,6 +120,11 @@ func MonthlyLimitUsd(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldMonthlyLimitUsd, v))
}
// DefaultValidityDays applies equality check predicate on the "default_validity_days" field. It's identical to DefaultValidityDaysEQ.
func DefaultValidityDays(v int) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldDefaultValidityDays, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
@@ -785,6 +790,46 @@ func MonthlyLimitUsdNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldMonthlyLimitUsd))
}
// DefaultValidityDaysEQ applies the EQ predicate on the "default_validity_days" field.
func DefaultValidityDaysEQ(v int) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldDefaultValidityDays, v))
}
// DefaultValidityDaysNEQ applies the NEQ predicate on the "default_validity_days" field.
func DefaultValidityDaysNEQ(v int) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldDefaultValidityDays, v))
}
// DefaultValidityDaysIn applies the In predicate on the "default_validity_days" field.
func DefaultValidityDaysIn(vs ...int) predicate.Group {
return predicate.Group(sql.FieldIn(FieldDefaultValidityDays, vs...))
}
// DefaultValidityDaysNotIn applies the NotIn predicate on the "default_validity_days" field.
func DefaultValidityDaysNotIn(vs ...int) predicate.Group {
return predicate.Group(sql.FieldNotIn(FieldDefaultValidityDays, vs...))
}
// DefaultValidityDaysGT applies the GT predicate on the "default_validity_days" field.
func DefaultValidityDaysGT(v int) predicate.Group {
return predicate.Group(sql.FieldGT(FieldDefaultValidityDays, v))
}
// DefaultValidityDaysGTE applies the GTE predicate on the "default_validity_days" field.
func DefaultValidityDaysGTE(v int) predicate.Group {
return predicate.Group(sql.FieldGTE(FieldDefaultValidityDays, v))
}
// DefaultValidityDaysLT applies the LT predicate on the "default_validity_days" field.
func DefaultValidityDaysLT(v int) predicate.Group {
return predicate.Group(sql.FieldLT(FieldDefaultValidityDays, v))
}
// DefaultValidityDaysLTE applies the LTE predicate on the "default_validity_days" field.
func DefaultValidityDaysLTE(v int) predicate.Group {
return predicate.Group(sql.FieldLTE(FieldDefaultValidityDays, v))
}
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.Group {
return predicate.Group(func(s *sql.Selector) {
@@ -854,6 +899,29 @@ func HasSubscriptionsWith(preds ...predicate.UserSubscription) predicate.Group {
})
}
// HasUsageLogs applies the HasEdge predicate on the "usage_logs" edge.
func HasUsageLogs() predicate.Group {
return predicate.Group(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasUsageLogsWith applies the HasEdge predicate on the "usage_logs" edge with a given conditions (other predicates).
func HasUsageLogsWith(preds ...predicate.UsageLog) predicate.Group {
return predicate.Group(func(s *sql.Selector) {
step := newUsageLogsStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// HasAccounts applies the HasEdge predicate on the "accounts" edge.
func HasAccounts() predicate.Group {
return predicate.Group(func(s *sql.Selector) {

View File

@@ -15,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
@@ -201,6 +202,20 @@ func (_c *GroupCreate) SetNillableMonthlyLimitUsd(v *float64) *GroupCreate {
return _c
}
// SetDefaultValidityDays sets the "default_validity_days" field.
func (_c *GroupCreate) SetDefaultValidityDays(v int) *GroupCreate {
_c.mutation.SetDefaultValidityDays(v)
return _c
}
// SetNillableDefaultValidityDays sets the "default_validity_days" field if the given value is not nil.
func (_c *GroupCreate) SetNillableDefaultValidityDays(v *int) *GroupCreate {
if v != nil {
_c.SetDefaultValidityDays(*v)
}
return _c
}
// AddAPIKeyIDs adds the "api_keys" edge to the ApiKey entity by IDs.
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
_c.mutation.AddAPIKeyIDs(ids...)
@@ -246,6 +261,21 @@ func (_c *GroupCreate) AddSubscriptions(v ...*UserSubscription) *GroupCreate {
return _c.AddSubscriptionIDs(ids...)
}
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func (_c *GroupCreate) AddUsageLogIDs(ids ...int64) *GroupCreate {
_c.mutation.AddUsageLogIDs(ids...)
return _c
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func (_c *GroupCreate) AddUsageLogs(v ...*UsageLog) *GroupCreate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _c.AddUsageLogIDs(ids...)
}
// AddAccountIDs adds the "accounts" edge to the Account entity by IDs.
func (_c *GroupCreate) AddAccountIDs(ids ...int64) *GroupCreate {
_c.mutation.AddAccountIDs(ids...)
@@ -347,6 +377,10 @@ func (_c *GroupCreate) defaults() error {
v := group.DefaultSubscriptionType
_c.mutation.SetSubscriptionType(v)
}
if _, ok := _c.mutation.DefaultValidityDays(); !ok {
v := group.DefaultDefaultValidityDays
_c.mutation.SetDefaultValidityDays(v)
}
return nil
}
@@ -396,6 +430,9 @@ func (_c *GroupCreate) check() error {
return &ValidationError{Name: "subscription_type", err: fmt.Errorf(`ent: validator failed for field "Group.subscription_type": %w`, err)}
}
}
if _, ok := _c.mutation.DefaultValidityDays(); !ok {
return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)}
}
return nil
}
@@ -475,6 +512,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldMonthlyLimitUsd, field.TypeFloat64, value)
_node.MonthlyLimitUsd = &value
}
if value, ok := _c.mutation.DefaultValidityDays(); ok {
_spec.SetField(group.FieldDefaultValidityDays, field.TypeInt, value)
_node.DefaultValidityDays = value
}
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -523,6 +564,22 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
}
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.UsageLogsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: group.UsageLogsTable,
Columns: []string{group.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.AccountsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2M,
@@ -813,6 +870,24 @@ func (u *GroupUpsert) ClearMonthlyLimitUsd() *GroupUpsert {
return u
}
// SetDefaultValidityDays sets the "default_validity_days" field.
func (u *GroupUpsert) SetDefaultValidityDays(v int) *GroupUpsert {
u.Set(group.FieldDefaultValidityDays, v)
return u
}
// UpdateDefaultValidityDays sets the "default_validity_days" field to the value that was provided on create.
func (u *GroupUpsert) UpdateDefaultValidityDays() *GroupUpsert {
u.SetExcluded(group.FieldDefaultValidityDays)
return u
}
// AddDefaultValidityDays adds v to the "default_validity_days" field.
func (u *GroupUpsert) AddDefaultValidityDays(v int) *GroupUpsert {
u.Add(group.FieldDefaultValidityDays, v)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -1089,6 +1164,27 @@ func (u *GroupUpsertOne) ClearMonthlyLimitUsd() *GroupUpsertOne {
})
}
// SetDefaultValidityDays sets the "default_validity_days" field.
func (u *GroupUpsertOne) SetDefaultValidityDays(v int) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetDefaultValidityDays(v)
})
}
// AddDefaultValidityDays adds v to the "default_validity_days" field.
func (u *GroupUpsertOne) AddDefaultValidityDays(v int) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.AddDefaultValidityDays(v)
})
}
// UpdateDefaultValidityDays sets the "default_validity_days" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateDefaultValidityDays() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateDefaultValidityDays()
})
}
// Exec executes the query.
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -1531,6 +1627,27 @@ func (u *GroupUpsertBulk) ClearMonthlyLimitUsd() *GroupUpsertBulk {
})
}
// SetDefaultValidityDays sets the "default_validity_days" field.
func (u *GroupUpsertBulk) SetDefaultValidityDays(v int) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetDefaultValidityDays(v)
})
}
// AddDefaultValidityDays adds v to the "default_validity_days" field.
func (u *GroupUpsertBulk) AddDefaultValidityDays(v int) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.AddDefaultValidityDays(v)
})
}
// UpdateDefaultValidityDays sets the "default_validity_days" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateDefaultValidityDays() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateDefaultValidityDays()
})
}
// Exec executes the query.
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {

View File

@@ -18,6 +18,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
@@ -33,6 +34,7 @@ type GroupQuery struct {
withAPIKeys *ApiKeyQuery
withRedeemCodes *RedeemCodeQuery
withSubscriptions *UserSubscriptionQuery
withUsageLogs *UsageLogQuery
withAccounts *AccountQuery
withAllowedUsers *UserQuery
withAccountGroups *AccountGroupQuery
@@ -139,6 +141,28 @@ func (_q *GroupQuery) QuerySubscriptions() *UserSubscriptionQuery {
return query
}
// QueryUsageLogs chains the current query on the "usage_logs" edge.
func (_q *GroupQuery) QueryUsageLogs() *UsageLogQuery {
query := (&UsageLogClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(group.Table, group.FieldID, selector),
sqlgraph.To(usagelog.Table, usagelog.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, group.UsageLogsTable, group.UsageLogsColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryAccounts chains the current query on the "accounts" edge.
func (_q *GroupQuery) QueryAccounts() *AccountQuery {
query := (&AccountClient{config: _q.config}).Query()
@@ -422,6 +446,7 @@ func (_q *GroupQuery) Clone() *GroupQuery {
withAPIKeys: _q.withAPIKeys.Clone(),
withRedeemCodes: _q.withRedeemCodes.Clone(),
withSubscriptions: _q.withSubscriptions.Clone(),
withUsageLogs: _q.withUsageLogs.Clone(),
withAccounts: _q.withAccounts.Clone(),
withAllowedUsers: _q.withAllowedUsers.Clone(),
withAccountGroups: _q.withAccountGroups.Clone(),
@@ -465,6 +490,17 @@ func (_q *GroupQuery) WithSubscriptions(opts ...func(*UserSubscriptionQuery)) *G
return _q
}
// WithUsageLogs tells the query-builder to eager-load the nodes that are connected to
// the "usage_logs" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *GroupQuery) WithUsageLogs(opts ...func(*UsageLogQuery)) *GroupQuery {
query := (&UsageLogClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withUsageLogs = query
return _q
}
// WithAccounts tells the query-builder to eager-load the nodes that are connected to
// the "accounts" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *GroupQuery) WithAccounts(opts ...func(*AccountQuery)) *GroupQuery {
@@ -587,10 +623,11 @@ func (_q *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group,
var (
nodes = []*Group{}
_spec = _q.querySpec()
loadedTypes = [7]bool{
loadedTypes = [8]bool{
_q.withAPIKeys != nil,
_q.withRedeemCodes != nil,
_q.withSubscriptions != nil,
_q.withUsageLogs != nil,
_q.withAccounts != nil,
_q.withAllowedUsers != nil,
_q.withAccountGroups != nil,
@@ -636,6 +673,13 @@ func (_q *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group,
return nil, err
}
}
if query := _q.withUsageLogs; query != nil {
if err := _q.loadUsageLogs(ctx, query, nodes,
func(n *Group) { n.Edges.UsageLogs = []*UsageLog{} },
func(n *Group, e *UsageLog) { n.Edges.UsageLogs = append(n.Edges.UsageLogs, e) }); err != nil {
return nil, err
}
}
if query := _q.withAccounts; query != nil {
if err := _q.loadAccounts(ctx, query, nodes,
func(n *Group) { n.Edges.Accounts = []*Account{} },
@@ -763,6 +807,39 @@ func (_q *GroupQuery) loadSubscriptions(ctx context.Context, query *UserSubscrip
}
return nil
}
func (_q *GroupQuery) loadUsageLogs(ctx context.Context, query *UsageLogQuery, nodes []*Group, init func(*Group), assign func(*Group, *UsageLog)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*Group)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(usagelog.FieldGroupID)
}
query.Where(predicate.UsageLog(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(group.UsageLogsColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.GroupID
if fk == nil {
return fmt.Errorf(`foreign-key "group_id" is nil for node %v`, n.ID)
}
node, ok := nodeids[*fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "group_id" returned %v for node %v`, *fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *GroupQuery) loadAccounts(ctx context.Context, query *AccountQuery, nodes []*Group, init func(*Group), assign func(*Group, *Account)) error {
edgeIDs := make([]driver.Value, len(nodes))
byID := make(map[int64]*Group)

View File

@@ -16,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
@@ -251,6 +252,27 @@ func (_u *GroupUpdate) ClearMonthlyLimitUsd() *GroupUpdate {
return _u
}
// SetDefaultValidityDays sets the "default_validity_days" field.
func (_u *GroupUpdate) SetDefaultValidityDays(v int) *GroupUpdate {
_u.mutation.ResetDefaultValidityDays()
_u.mutation.SetDefaultValidityDays(v)
return _u
}
// SetNillableDefaultValidityDays sets the "default_validity_days" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableDefaultValidityDays(v *int) *GroupUpdate {
if v != nil {
_u.SetDefaultValidityDays(*v)
}
return _u
}
// AddDefaultValidityDays adds value to the "default_validity_days" field.
func (_u *GroupUpdate) AddDefaultValidityDays(v int) *GroupUpdate {
_u.mutation.AddDefaultValidityDays(v)
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the ApiKey entity by IDs.
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -296,6 +318,21 @@ func (_u *GroupUpdate) AddSubscriptions(v ...*UserSubscription) *GroupUpdate {
return _u.AddSubscriptionIDs(ids...)
}
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func (_u *GroupUpdate) AddUsageLogIDs(ids ...int64) *GroupUpdate {
_u.mutation.AddUsageLogIDs(ids...)
return _u
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func (_u *GroupUpdate) AddUsageLogs(v ...*UsageLog) *GroupUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddUsageLogIDs(ids...)
}
// AddAccountIDs adds the "accounts" edge to the Account entity by IDs.
func (_u *GroupUpdate) AddAccountIDs(ids ...int64) *GroupUpdate {
_u.mutation.AddAccountIDs(ids...)
@@ -394,6 +431,27 @@ func (_u *GroupUpdate) RemoveSubscriptions(v ...*UserSubscription) *GroupUpdate
return _u.RemoveSubscriptionIDs(ids...)
}
// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity.
func (_u *GroupUpdate) ClearUsageLogs() *GroupUpdate {
_u.mutation.ClearUsageLogs()
return _u
}
// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs.
func (_u *GroupUpdate) RemoveUsageLogIDs(ids ...int64) *GroupUpdate {
_u.mutation.RemoveUsageLogIDs(ids...)
return _u
}
// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities.
func (_u *GroupUpdate) RemoveUsageLogs(v ...*UsageLog) *GroupUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveUsageLogIDs(ids...)
}
// ClearAccounts clears all "accounts" edges to the Account entity.
func (_u *GroupUpdate) ClearAccounts() *GroupUpdate {
_u.mutation.ClearAccounts()
@@ -578,6 +636,12 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.MonthlyLimitUsdCleared() {
_spec.ClearField(group.FieldMonthlyLimitUsd, field.TypeFloat64)
}
if value, ok := _u.mutation.DefaultValidityDays(); ok {
_spec.SetField(group.FieldDefaultValidityDays, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedDefaultValidityDays(); ok {
_spec.AddField(group.FieldDefaultValidityDays, field.TypeInt, value)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -713,6 +777,51 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: group.UsageLogsTable,
Columns: []string{group.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: group.UsageLogsTable,
Columns: []string{group.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: group.UsageLogsTable,
Columns: []string{group.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.AccountsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2M,
@@ -1065,6 +1174,27 @@ func (_u *GroupUpdateOne) ClearMonthlyLimitUsd() *GroupUpdateOne {
return _u
}
// SetDefaultValidityDays sets the "default_validity_days" field.
func (_u *GroupUpdateOne) SetDefaultValidityDays(v int) *GroupUpdateOne {
_u.mutation.ResetDefaultValidityDays()
_u.mutation.SetDefaultValidityDays(v)
return _u
}
// SetNillableDefaultValidityDays sets the "default_validity_days" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableDefaultValidityDays(v *int) *GroupUpdateOne {
if v != nil {
_u.SetDefaultValidityDays(*v)
}
return _u
}
// AddDefaultValidityDays adds value to the "default_validity_days" field.
func (_u *GroupUpdateOne) AddDefaultValidityDays(v int) *GroupUpdateOne {
_u.mutation.AddDefaultValidityDays(v)
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the ApiKey entity by IDs.
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -1110,6 +1240,21 @@ func (_u *GroupUpdateOne) AddSubscriptions(v ...*UserSubscription) *GroupUpdateO
return _u.AddSubscriptionIDs(ids...)
}
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func (_u *GroupUpdateOne) AddUsageLogIDs(ids ...int64) *GroupUpdateOne {
_u.mutation.AddUsageLogIDs(ids...)
return _u
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func (_u *GroupUpdateOne) AddUsageLogs(v ...*UsageLog) *GroupUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddUsageLogIDs(ids...)
}
// AddAccountIDs adds the "accounts" edge to the Account entity by IDs.
func (_u *GroupUpdateOne) AddAccountIDs(ids ...int64) *GroupUpdateOne {
_u.mutation.AddAccountIDs(ids...)
@@ -1208,6 +1353,27 @@ func (_u *GroupUpdateOne) RemoveSubscriptions(v ...*UserSubscription) *GroupUpda
return _u.RemoveSubscriptionIDs(ids...)
}
// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity.
func (_u *GroupUpdateOne) ClearUsageLogs() *GroupUpdateOne {
_u.mutation.ClearUsageLogs()
return _u
}
// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs.
func (_u *GroupUpdateOne) RemoveUsageLogIDs(ids ...int64) *GroupUpdateOne {
_u.mutation.RemoveUsageLogIDs(ids...)
return _u
}
// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities.
func (_u *GroupUpdateOne) RemoveUsageLogs(v ...*UsageLog) *GroupUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveUsageLogIDs(ids...)
}
// ClearAccounts clears all "accounts" edges to the Account entity.
func (_u *GroupUpdateOne) ClearAccounts() *GroupUpdateOne {
_u.mutation.ClearAccounts()
@@ -1422,6 +1588,12 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if _u.mutation.MonthlyLimitUsdCleared() {
_spec.ClearField(group.FieldMonthlyLimitUsd, field.TypeFloat64)
}
if value, ok := _u.mutation.DefaultValidityDays(); ok {
_spec.SetField(group.FieldDefaultValidityDays, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedDefaultValidityDays(); ok {
_spec.AddField(group.FieldDefaultValidityDays, field.TypeInt, value)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1557,6 +1729,51 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: group.UsageLogsTable,
Columns: []string{group.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: group.UsageLogsTable,
Columns: []string{group.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: group.UsageLogsTable,
Columns: []string{group.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.AccountsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2M,

View File

@@ -93,6 +93,18 @@ func (f SettingFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, err
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.SettingMutation", m)
}
// The UsageLogFunc type is an adapter to allow the use of ordinary
// function as UsageLog mutator.
type UsageLogFunc func(context.Context, *ent.UsageLogMutation) (ent.Value, error)
// Mutate calls f(ctx, m).
func (f UsageLogFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
if mv, ok := m.(*ent.UsageLogMutation); ok {
return f(ctx, mv)
}
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UsageLogMutation", m)
}
// The UserFunc type is an adapter to allow the use of ordinary
// function as User mutator.
type UserFunc func(context.Context, *ent.UserMutation) (ent.Value, error)

View File

@@ -16,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/setting"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
@@ -266,6 +267,33 @@ func (f TraverseSetting) Traverse(ctx context.Context, q ent.Query) error {
return fmt.Errorf("unexpected query type %T. expect *ent.SettingQuery", q)
}
// The UsageLogFunc type is an adapter to allow the use of ordinary function as a Querier.
type UsageLogFunc func(context.Context, *ent.UsageLogQuery) (ent.Value, error)
// Query calls f(ctx, q).
func (f UsageLogFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
if q, ok := q.(*ent.UsageLogQuery); ok {
return f(ctx, q)
}
return nil, fmt.Errorf("unexpected query type %T. expect *ent.UsageLogQuery", q)
}
// The TraverseUsageLog type is an adapter to allow the use of ordinary function as Traverser.
type TraverseUsageLog func(context.Context, *ent.UsageLogQuery) error
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
func (f TraverseUsageLog) Intercept(next ent.Querier) ent.Querier {
return next
}
// Traverse calls f(ctx, q).
func (f TraverseUsageLog) Traverse(ctx context.Context, q ent.Query) error {
if q, ok := q.(*ent.UsageLogQuery); ok {
return f(ctx, q)
}
return fmt.Errorf("unexpected query type %T. expect *ent.UsageLogQuery", q)
}
// The UserFunc type is an adapter to allow the use of ordinary function as a Querier.
type UserFunc func(context.Context, *ent.UserQuery) (ent.Value, error)
@@ -364,6 +392,8 @@ func NewQuery(q ent.Query) (Query, error) {
return &query[*ent.RedeemCodeQuery, predicate.RedeemCode, redeemcode.OrderOption]{typ: ent.TypeRedeemCode, tq: q}, nil
case *ent.SettingQuery:
return &query[*ent.SettingQuery, predicate.Setting, setting.OrderOption]{typ: ent.TypeSetting, tq: q}, nil
case *ent.UsageLogQuery:
return &query[*ent.UsageLogQuery, predicate.UsageLog, usagelog.OrderOption]{typ: ent.TypeUsageLog, tq: q}, nil
case *ent.UserQuery:
return &query[*ent.UserQuery, predicate.User, user.OrderOption]{typ: ent.TypeUser, tq: q}, nil
case *ent.UserAllowedGroupQuery:

View File

@@ -20,7 +20,6 @@ var (
{Name: "type", Type: field.TypeString, Size: 20},
{Name: "credentials", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "extra", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "proxy_id", Type: field.TypeInt64, Nullable: true},
{Name: "concurrency", Type: field.TypeInt, Default: 3},
{Name: "priority", Type: field.TypeInt, Default: 50},
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
@@ -33,12 +32,21 @@ var (
{Name: "session_window_start", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "session_window_end", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "session_window_status", Type: field.TypeString, Nullable: true, Size: 20},
{Name: "proxy_id", Type: field.TypeInt64, Nullable: true},
}
// AccountsTable holds the schema information for the "accounts" table.
AccountsTable = &schema.Table{
Name: "accounts",
Columns: AccountsColumns,
PrimaryKey: []*schema.Column{AccountsColumns[0]},
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "accounts_proxies_proxy",
Columns: []*schema.Column{AccountsColumns[21]},
RefColumns: []*schema.Column{ProxiesColumns[0]},
OnDelete: schema.SetNull,
},
},
Indexes: []*schema.Index{
{
Name: "account_platform",
@@ -53,42 +61,42 @@ var (
{
Name: "account_status",
Unique: false,
Columns: []*schema.Column{AccountsColumns[12]},
Columns: []*schema.Column{AccountsColumns[11]},
},
{
Name: "account_proxy_id",
Unique: false,
Columns: []*schema.Column{AccountsColumns[9]},
Columns: []*schema.Column{AccountsColumns[21]},
},
{
Name: "account_priority",
Unique: false,
Columns: []*schema.Column{AccountsColumns[11]},
Columns: []*schema.Column{AccountsColumns[10]},
},
{
Name: "account_last_used_at",
Unique: false,
Columns: []*schema.Column{AccountsColumns[14]},
Columns: []*schema.Column{AccountsColumns[13]},
},
{
Name: "account_schedulable",
Unique: false,
Columns: []*schema.Column{AccountsColumns[15]},
Columns: []*schema.Column{AccountsColumns[14]},
},
{
Name: "account_rate_limited_at",
Unique: false,
Columns: []*schema.Column{AccountsColumns[16]},
Columns: []*schema.Column{AccountsColumns[15]},
},
{
Name: "account_rate_limit_reset_at",
Unique: false,
Columns: []*schema.Column{AccountsColumns[17]},
Columns: []*schema.Column{AccountsColumns[16]},
},
{
Name: "account_overload_until",
Unique: false,
Columns: []*schema.Column{AccountsColumns[18]},
Columns: []*schema.Column{AccountsColumns[17]},
},
{
Name: "account_deleted_at",
@@ -100,7 +108,7 @@ var (
// AccountGroupsColumns holds the columns for the "account_groups" table.
AccountGroupsColumns = []*schema.Column{
{Name: "priority", Type: field.TypeInt, Default: 50},
{Name: "created_at", Type: field.TypeTime},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "account_id", Type: field.TypeInt64},
{Name: "group_id", Type: field.TypeInt64},
}
@@ -168,11 +176,6 @@ var (
},
},
Indexes: []*schema.Index{
{
Name: "apikey_key",
Unique: true,
Columns: []*schema.Column{APIKeysColumns[4]},
},
{
Name: "apikey_user_id",
Unique: false,
@@ -201,7 +204,7 @@ var (
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "name", Type: field.TypeString, Unique: true, Size: 100},
{Name: "name", Type: field.TypeString, Size: 100},
{Name: "description", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
{Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
{Name: "is_exclusive", Type: field.TypeBool, Default: false},
@@ -211,6 +214,7 @@ var (
{Name: "daily_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "weekly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "monthly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "default_validity_days", Type: field.TypeInt, Default: 30},
}
// GroupsTable holds the schema information for the "groups" table.
GroupsTable = &schema.Table{
@@ -218,11 +222,6 @@ var (
Columns: GroupsColumns,
PrimaryKey: []*schema.Column{GroupsColumns[0]},
Indexes: []*schema.Index{
{
Name: "group_name",
Unique: true,
Columns: []*schema.Column{GroupsColumns[4]},
},
{
Name: "group_status",
Unique: false,
@@ -316,11 +315,6 @@ var (
},
},
Indexes: []*schema.Index{
{
Name: "redeemcode_code",
Unique: true,
Columns: []*schema.Column{RedeemCodesColumns[1]},
},
{
Name: "redeemcode_status",
Unique: false,
@@ -350,11 +344,123 @@ var (
Name: "settings",
Columns: SettingsColumns,
PrimaryKey: []*schema.Column{SettingsColumns[0]},
}
// UsageLogsColumns holds the columns for the "usage_logs" table.
UsageLogsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "request_id", Type: field.TypeString, Size: 64},
{Name: "model", Type: field.TypeString, Size: 100},
{Name: "input_tokens", Type: field.TypeInt, Default: 0},
{Name: "output_tokens", Type: field.TypeInt, Default: 0},
{Name: "cache_creation_tokens", Type: field.TypeInt, Default: 0},
{Name: "cache_read_tokens", Type: field.TypeInt, Default: 0},
{Name: "cache_creation_5m_tokens", Type: field.TypeInt, Default: 0},
{Name: "cache_creation_1h_tokens", Type: field.TypeInt, Default: 0},
{Name: "input_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
{Name: "output_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
{Name: "cache_creation_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
{Name: "cache_read_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
{Name: "total_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
{Name: "actual_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
{Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
{Name: "billing_type", Type: field.TypeInt8, Default: 0},
{Name: "stream", Type: field.TypeBool, Default: false},
{Name: "duration_ms", Type: field.TypeInt, Nullable: true},
{Name: "first_token_ms", Type: field.TypeInt, Nullable: true},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "account_id", Type: field.TypeInt64},
{Name: "api_key_id", Type: field.TypeInt64},
{Name: "group_id", Type: field.TypeInt64, Nullable: true},
{Name: "user_id", Type: field.TypeInt64},
{Name: "subscription_id", Type: field.TypeInt64, Nullable: true},
}
// UsageLogsTable holds the schema information for the "usage_logs" table.
UsageLogsTable = &schema.Table{
Name: "usage_logs",
Columns: UsageLogsColumns,
PrimaryKey: []*schema.Column{UsageLogsColumns[0]},
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "usage_logs_accounts_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[21]},
RefColumns: []*schema.Column{AccountsColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_api_keys_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[22]},
RefColumns: []*schema.Column{APIKeysColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_groups_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[23]},
RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull,
},
{
Symbol: "usage_logs_users_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[24]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_user_subscriptions_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[25]},
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
OnDelete: schema.SetNull,
},
},
Indexes: []*schema.Index{
{
Name: "setting_key",
Unique: true,
Columns: []*schema.Column{SettingsColumns[1]},
Name: "usagelog_user_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[24]},
},
{
Name: "usagelog_api_key_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[22]},
},
{
Name: "usagelog_account_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[21]},
},
{
Name: "usagelog_group_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[23]},
},
{
Name: "usagelog_subscription_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[25]},
},
{
Name: "usagelog_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[20]},
},
{
Name: "usagelog_model",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[2]},
},
{
Name: "usagelog_request_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[1]},
},
{
Name: "usagelog_user_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[24], UsageLogsColumns[20]},
},
{
Name: "usagelog_api_key_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[22], UsageLogsColumns[20]},
},
},
}
@@ -364,7 +470,7 @@ var (
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "email", Type: field.TypeString, Unique: true, Size: 255},
{Name: "email", Type: field.TypeString, Size: 255},
{Name: "password_hash", Type: field.TypeString, Size: 255},
{Name: "role", Type: field.TypeString, Size: 20, Default: "user"},
{Name: "balance", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
@@ -380,11 +486,6 @@ var (
Columns: UsersColumns,
PrimaryKey: []*schema.Column{UsersColumns[0]},
Indexes: []*schema.Index{
{
Name: "user_email",
Unique: true,
Columns: []*schema.Column{UsersColumns[4]},
},
{
Name: "user_status",
Unique: false,
@@ -399,7 +500,7 @@ var (
}
// UserAllowedGroupsColumns holds the columns for the "user_allowed_groups" table.
UserAllowedGroupsColumns = []*schema.Column{
{Name: "created_at", Type: field.TypeTime},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "user_id", Type: field.TypeInt64},
{Name: "group_id", Type: field.TypeInt64},
}
@@ -435,6 +536,7 @@ var (
{Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "starts_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "expires_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
@@ -458,19 +560,19 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "user_subscriptions_groups_subscriptions",
Columns: []*schema.Column{UserSubscriptionsColumns[14]},
Columns: []*schema.Column{UserSubscriptionsColumns[15]},
RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "user_subscriptions_users_subscriptions",
Columns: []*schema.Column{UserSubscriptionsColumns[15]},
Columns: []*schema.Column{UserSubscriptionsColumns[16]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "user_subscriptions_users_assigned_subscriptions",
Columns: []*schema.Column{UserSubscriptionsColumns[16]},
Columns: []*schema.Column{UserSubscriptionsColumns[17]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.SetNull,
},
@@ -479,32 +581,37 @@ var (
{
Name: "usersubscription_user_id",
Unique: false,
Columns: []*schema.Column{UserSubscriptionsColumns[15]},
Columns: []*schema.Column{UserSubscriptionsColumns[16]},
},
{
Name: "usersubscription_group_id",
Unique: false,
Columns: []*schema.Column{UserSubscriptionsColumns[14]},
Columns: []*schema.Column{UserSubscriptionsColumns[15]},
},
{
Name: "usersubscription_status",
Unique: false,
Columns: []*schema.Column{UserSubscriptionsColumns[5]},
Columns: []*schema.Column{UserSubscriptionsColumns[6]},
},
{
Name: "usersubscription_expires_at",
Unique: false,
Columns: []*schema.Column{UserSubscriptionsColumns[4]},
Columns: []*schema.Column{UserSubscriptionsColumns[5]},
},
{
Name: "usersubscription_assigned_by",
Unique: false,
Columns: []*schema.Column{UserSubscriptionsColumns[16]},
Columns: []*schema.Column{UserSubscriptionsColumns[17]},
},
{
Name: "usersubscription_user_id_group_id",
Unique: true,
Columns: []*schema.Column{UserSubscriptionsColumns[15], UserSubscriptionsColumns[14]},
Unique: false,
Columns: []*schema.Column{UserSubscriptionsColumns[16], UserSubscriptionsColumns[15]},
},
{
Name: "usersubscription_deleted_at",
Unique: false,
Columns: []*schema.Column{UserSubscriptionsColumns[3]},
},
},
}
@@ -517,6 +624,7 @@ var (
ProxiesTable,
RedeemCodesTable,
SettingsTable,
UsageLogsTable,
UsersTable,
UserAllowedGroupsTable,
UserSubscriptionsTable,
@@ -524,6 +632,7 @@ var (
)
func init() {
AccountsTable.ForeignKeys[0].RefTable = ProxiesTable
AccountsTable.Annotation = &entsql.Annotation{
Table: "accounts",
}
@@ -551,6 +660,14 @@ func init() {
SettingsTable.Annotation = &entsql.Annotation{
Table: "settings",
}
UsageLogsTable.ForeignKeys[0].RefTable = AccountsTable
UsageLogsTable.ForeignKeys[1].RefTable = APIKeysTable
UsageLogsTable.ForeignKeys[2].RefTable = GroupsTable
UsageLogsTable.ForeignKeys[3].RefTable = UsersTable
UsageLogsTable.ForeignKeys[4].RefTable = UserSubscriptionsTable
UsageLogsTable.Annotation = &entsql.Annotation{
Table: "usage_logs",
}
UsersTable.Annotation = &entsql.Annotation{
Table: "users",
}

File diff suppressed because it is too large Load Diff

View File

@@ -27,6 +27,9 @@ type RedeemCode func(*sql.Selector)
// Setting is the predicate function for setting builders.
type Setting func(*sql.Selector)
// UsageLog is the predicate function for usagelog builders.
type UsageLog func(*sql.Selector)
// User is the predicate function for user builders.
type User func(*sql.Selector)

View File

@@ -36,10 +36,31 @@ type Proxy struct {
// Password holds the value of the "password" field.
Password *string `json:"password,omitempty"`
// Status holds the value of the "status" field.
Status string `json:"status,omitempty"`
Status string `json:"status,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the ProxyQuery when eager-loading is set.
Edges ProxyEdges `json:"edges"`
selectValues sql.SelectValues
}
// ProxyEdges holds the relations/edges for other nodes in the graph.
type ProxyEdges struct {
// Accounts holds the value of the accounts edge.
Accounts []*Account `json:"accounts,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [1]bool
}
// AccountsOrErr returns the Accounts value or an error if the edge
// was not loaded in eager-loading.
func (e ProxyEdges) AccountsOrErr() ([]*Account, error) {
if e.loadedTypes[0] {
return e.Accounts, nil
}
return nil, &NotLoadedError{edge: "accounts"}
}
// scanValues returns the types for scanning values from sql.Rows.
func (*Proxy) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
@@ -148,6 +169,11 @@ func (_m *Proxy) Value(name string) (ent.Value, error) {
return _m.selectValues.Get(name)
}
// QueryAccounts queries the "accounts" edge of the Proxy entity.
func (_m *Proxy) QueryAccounts() *AccountQuery {
return NewProxyClient(_m.config).QueryAccounts(_m)
}
// Update returns a builder for updating this Proxy.
// Note that you need to call Proxy.Unwrap() before calling this method if this Proxy
// was returned from a transaction, and the transaction was committed or rolled back.

View File

@@ -7,6 +7,7 @@ import (
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
)
const (
@@ -34,8 +35,17 @@ const (
FieldPassword = "password"
// FieldStatus holds the string denoting the status field in the database.
FieldStatus = "status"
// EdgeAccounts holds the string denoting the accounts edge name in mutations.
EdgeAccounts = "accounts"
// Table holds the table name of the proxy in the database.
Table = "proxies"
// AccountsTable is the table that holds the accounts relation/edge.
AccountsTable = "accounts"
// AccountsInverseTable is the table name for the Account entity.
// It exists in this package in order to avoid circular dependency with the "account" package.
AccountsInverseTable = "accounts"
// AccountsColumn is the table column denoting the accounts relation/edge.
AccountsColumn = "proxy_id"
)
// Columns holds all SQL columns for proxy fields.
@@ -150,3 +160,24 @@ func ByPassword(opts ...sql.OrderTermOption) OrderOption {
func ByStatus(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldStatus, opts...).ToFunc()
}
// ByAccountsCount orders the results by accounts count.
func ByAccountsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newAccountsStep(), opts...)
}
}
// ByAccounts orders the results by accounts terms.
func ByAccounts(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newAccountsStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
func newAccountsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(AccountsInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, true, AccountsTable, AccountsColumn),
)
}

View File

@@ -6,6 +6,7 @@ import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
@@ -684,6 +685,29 @@ func StatusContainsFold(v string) predicate.Proxy {
return predicate.Proxy(sql.FieldContainsFold(FieldStatus, v))
}
// HasAccounts applies the HasEdge predicate on the "accounts" edge.
func HasAccounts() predicate.Proxy {
return predicate.Proxy(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, true, AccountsTable, AccountsColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasAccountsWith applies the HasEdge predicate on the "accounts" edge with a given conditions (other predicates).
func HasAccountsWith(preds ...predicate.Account) predicate.Proxy {
return predicate.Proxy(func(s *sql.Selector) {
step := newAccountsStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.Proxy) predicate.Proxy {
return predicate.Proxy(sql.AndPredicates(predicates...))

View File

@@ -11,6 +11,7 @@ import (
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/account"
"github.com/Wei-Shaw/sub2api/ent/proxy"
)
@@ -130,6 +131,21 @@ func (_c *ProxyCreate) SetNillableStatus(v *string) *ProxyCreate {
return _c
}
// AddAccountIDs adds the "accounts" edge to the Account entity by IDs.
func (_c *ProxyCreate) AddAccountIDs(ids ...int64) *ProxyCreate {
_c.mutation.AddAccountIDs(ids...)
return _c
}
// AddAccounts adds the "accounts" edges to the Account entity.
func (_c *ProxyCreate) AddAccounts(v ...*Account) *ProxyCreate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _c.AddAccountIDs(ids...)
}
// Mutation returns the ProxyMutation object of the builder.
func (_c *ProxyCreate) Mutation() *ProxyMutation {
return _c.mutation
@@ -308,6 +324,22 @@ func (_c *ProxyCreate) createSpec() (*Proxy, *sqlgraph.CreateSpec) {
_spec.SetField(proxy.FieldStatus, field.TypeString, value)
_node.Status = value
}
if nodes := _c.mutation.AccountsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: true,
Table: proxy.AccountsTable,
Columns: []string{proxy.AccountsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges = append(_spec.Edges, edge)
}
return _node, _spec
}

View File

@@ -4,6 +4,7 @@ package ent
import (
"context"
"database/sql/driver"
"fmt"
"math"
@@ -11,6 +12,7 @@ import (
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/account"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/proxy"
)
@@ -18,10 +20,11 @@ import (
// ProxyQuery is the builder for querying Proxy entities.
type ProxyQuery struct {
config
ctx *QueryContext
order []proxy.OrderOption
inters []Interceptor
predicates []predicate.Proxy
ctx *QueryContext
order []proxy.OrderOption
inters []Interceptor
predicates []predicate.Proxy
withAccounts *AccountQuery
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
@@ -58,6 +61,28 @@ func (_q *ProxyQuery) Order(o ...proxy.OrderOption) *ProxyQuery {
return _q
}
// QueryAccounts chains the current query on the "accounts" edge.
func (_q *ProxyQuery) QueryAccounts() *AccountQuery {
query := (&AccountClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(proxy.Table, proxy.FieldID, selector),
sqlgraph.To(account.Table, account.FieldID),
sqlgraph.Edge(sqlgraph.O2M, true, proxy.AccountsTable, proxy.AccountsColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// First returns the first Proxy entity from the query.
// Returns a *NotFoundError when no Proxy was found.
func (_q *ProxyQuery) First(ctx context.Context) (*Proxy, error) {
@@ -245,17 +270,29 @@ func (_q *ProxyQuery) Clone() *ProxyQuery {
return nil
}
return &ProxyQuery{
config: _q.config,
ctx: _q.ctx.Clone(),
order: append([]proxy.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.Proxy{}, _q.predicates...),
config: _q.config,
ctx: _q.ctx.Clone(),
order: append([]proxy.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.Proxy{}, _q.predicates...),
withAccounts: _q.withAccounts.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
}
}
// WithAccounts tells the query-builder to eager-load the nodes that are connected to
// the "accounts" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *ProxyQuery) WithAccounts(opts ...func(*AccountQuery)) *ProxyQuery {
query := (&AccountClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withAccounts = query
return _q
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
@@ -332,8 +369,11 @@ func (_q *ProxyQuery) prepareQuery(ctx context.Context) error {
func (_q *ProxyQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Proxy, error) {
var (
nodes = []*Proxy{}
_spec = _q.querySpec()
nodes = []*Proxy{}
_spec = _q.querySpec()
loadedTypes = [1]bool{
_q.withAccounts != nil,
}
)
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*Proxy).scanValues(nil, columns)
@@ -341,6 +381,7 @@ func (_q *ProxyQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Proxy,
_spec.Assign = func(columns []string, values []any) error {
node := &Proxy{config: _q.config}
nodes = append(nodes, node)
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
for i := range hooks {
@@ -352,9 +393,50 @@ func (_q *ProxyQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Proxy,
if len(nodes) == 0 {
return nodes, nil
}
if query := _q.withAccounts; query != nil {
if err := _q.loadAccounts(ctx, query, nodes,
func(n *Proxy) { n.Edges.Accounts = []*Account{} },
func(n *Proxy, e *Account) { n.Edges.Accounts = append(n.Edges.Accounts, e) }); err != nil {
return nil, err
}
}
return nodes, nil
}
func (_q *ProxyQuery) loadAccounts(ctx context.Context, query *AccountQuery, nodes []*Proxy, init func(*Proxy), assign func(*Proxy, *Account)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*Proxy)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(account.FieldProxyID)
}
query.Where(predicate.Account(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(proxy.AccountsColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.ProxyID
if fk == nil {
return fmt.Errorf(`foreign-key "proxy_id" is nil for node %v`, n.ID)
}
node, ok := nodeids[*fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "proxy_id" returned %v for node %v`, *fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *ProxyQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
_spec.Node.Columns = _q.ctx.Fields

View File

@@ -11,6 +11,7 @@ import (
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/account"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/proxy"
)
@@ -171,11 +172,47 @@ func (_u *ProxyUpdate) SetNillableStatus(v *string) *ProxyUpdate {
return _u
}
// AddAccountIDs adds the "accounts" edge to the Account entity by IDs.
func (_u *ProxyUpdate) AddAccountIDs(ids ...int64) *ProxyUpdate {
_u.mutation.AddAccountIDs(ids...)
return _u
}
// AddAccounts adds the "accounts" edges to the Account entity.
func (_u *ProxyUpdate) AddAccounts(v ...*Account) *ProxyUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddAccountIDs(ids...)
}
// Mutation returns the ProxyMutation object of the builder.
func (_u *ProxyUpdate) Mutation() *ProxyMutation {
return _u.mutation
}
// ClearAccounts clears all "accounts" edges to the Account entity.
func (_u *ProxyUpdate) ClearAccounts() *ProxyUpdate {
_u.mutation.ClearAccounts()
return _u
}
// RemoveAccountIDs removes the "accounts" edge to Account entities by IDs.
func (_u *ProxyUpdate) RemoveAccountIDs(ids ...int64) *ProxyUpdate {
_u.mutation.RemoveAccountIDs(ids...)
return _u
}
// RemoveAccounts removes "accounts" edges to Account entities.
func (_u *ProxyUpdate) RemoveAccounts(v ...*Account) *ProxyUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveAccountIDs(ids...)
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *ProxyUpdate) Save(ctx context.Context) (int, error) {
if err := _u.defaults(); err != nil {
@@ -304,6 +341,51 @@ func (_u *ProxyUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(proxy.FieldStatus, field.TypeString, value)
}
if _u.mutation.AccountsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: true,
Table: proxy.AccountsTable,
Columns: []string{proxy.AccountsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedAccountsIDs(); len(nodes) > 0 && !_u.mutation.AccountsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: true,
Table: proxy.AccountsTable,
Columns: []string{proxy.AccountsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.AccountsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: true,
Table: proxy.AccountsTable,
Columns: []string{proxy.AccountsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{proxy.Label}
@@ -467,11 +549,47 @@ func (_u *ProxyUpdateOne) SetNillableStatus(v *string) *ProxyUpdateOne {
return _u
}
// AddAccountIDs adds the "accounts" edge to the Account entity by IDs.
func (_u *ProxyUpdateOne) AddAccountIDs(ids ...int64) *ProxyUpdateOne {
_u.mutation.AddAccountIDs(ids...)
return _u
}
// AddAccounts adds the "accounts" edges to the Account entity.
func (_u *ProxyUpdateOne) AddAccounts(v ...*Account) *ProxyUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddAccountIDs(ids...)
}
// Mutation returns the ProxyMutation object of the builder.
func (_u *ProxyUpdateOne) Mutation() *ProxyMutation {
return _u.mutation
}
// ClearAccounts clears all "accounts" edges to the Account entity.
func (_u *ProxyUpdateOne) ClearAccounts() *ProxyUpdateOne {
_u.mutation.ClearAccounts()
return _u
}
// RemoveAccountIDs removes the "accounts" edge to Account entities by IDs.
func (_u *ProxyUpdateOne) RemoveAccountIDs(ids ...int64) *ProxyUpdateOne {
_u.mutation.RemoveAccountIDs(ids...)
return _u
}
// RemoveAccounts removes "accounts" edges to Account entities.
func (_u *ProxyUpdateOne) RemoveAccounts(v ...*Account) *ProxyUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveAccountIDs(ids...)
}
// Where appends a list predicates to the ProxyUpdate builder.
func (_u *ProxyUpdateOne) Where(ps ...predicate.Proxy) *ProxyUpdateOne {
_u.mutation.Where(ps...)
@@ -630,6 +748,51 @@ func (_u *ProxyUpdateOne) sqlSave(ctx context.Context) (_node *Proxy, err error)
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(proxy.FieldStatus, field.TypeString, value)
}
if _u.mutation.AccountsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: true,
Table: proxy.AccountsTable,
Columns: []string{proxy.AccountsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedAccountsIDs(); len(nodes) > 0 && !_u.mutation.AccountsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: true,
Table: proxy.AccountsTable,
Columns: []string{proxy.AccountsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.AccountsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: true,
Table: proxy.AccountsTable,
Columns: []string{proxy.AccountsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &Proxy{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues

View File

@@ -13,6 +13,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/schema"
"github.com/Wei-Shaw/sub2api/ent/setting"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
@@ -259,6 +260,10 @@ func init() {
group.DefaultSubscriptionType = groupDescSubscriptionType.Default.(string)
// group.SubscriptionTypeValidator is a validator for the "subscription_type" field. It is called by the builders before save.
group.SubscriptionTypeValidator = groupDescSubscriptionType.Validators[0].(func(string) error)
// groupDescDefaultValidityDays is the schema descriptor for default_validity_days field.
groupDescDefaultValidityDays := groupFields[10].Descriptor()
// group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field.
group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int)
proxyMixin := schema.Proxy{}.Mixin()
proxyMixinHooks1 := proxyMixin[1].Hooks()
proxy.Hooks[0] = proxyMixinHooks1[0]
@@ -410,16 +415,114 @@ func init() {
return nil
}
}()
// settingDescValue is the schema descriptor for value field.
settingDescValue := settingFields[1].Descriptor()
// setting.ValueValidator is a validator for the "value" field. It is called by the builders before save.
setting.ValueValidator = settingDescValue.Validators[0].(func(string) error)
// settingDescUpdatedAt is the schema descriptor for updated_at field.
settingDescUpdatedAt := settingFields[2].Descriptor()
// setting.DefaultUpdatedAt holds the default value on creation for the updated_at field.
setting.DefaultUpdatedAt = settingDescUpdatedAt.Default.(func() time.Time)
// setting.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
setting.UpdateDefaultUpdatedAt = settingDescUpdatedAt.UpdateDefault.(func() time.Time)
usagelogFields := schema.UsageLog{}.Fields()
_ = usagelogFields
// usagelogDescRequestID is the schema descriptor for request_id field.
usagelogDescRequestID := usagelogFields[3].Descriptor()
// usagelog.RequestIDValidator is a validator for the "request_id" field. It is called by the builders before save.
usagelog.RequestIDValidator = func() func(string) error {
validators := usagelogDescRequestID.Validators
fns := [...]func(string) error{
validators[0].(func(string) error),
validators[1].(func(string) error),
}
return func(request_id string) error {
for _, fn := range fns {
if err := fn(request_id); err != nil {
return err
}
}
return nil
}
}()
// usagelogDescModel is the schema descriptor for model field.
usagelogDescModel := usagelogFields[4].Descriptor()
// usagelog.ModelValidator is a validator for the "model" field. It is called by the builders before save.
usagelog.ModelValidator = func() func(string) error {
validators := usagelogDescModel.Validators
fns := [...]func(string) error{
validators[0].(func(string) error),
validators[1].(func(string) error),
}
return func(model string) error {
for _, fn := range fns {
if err := fn(model); err != nil {
return err
}
}
return nil
}
}()
// usagelogDescInputTokens is the schema descriptor for input_tokens field.
usagelogDescInputTokens := usagelogFields[7].Descriptor()
// usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field.
usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int)
// usagelogDescOutputTokens is the schema descriptor for output_tokens field.
usagelogDescOutputTokens := usagelogFields[8].Descriptor()
// usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field.
usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int)
// usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field.
usagelogDescCacheCreationTokens := usagelogFields[9].Descriptor()
// usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field.
usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int)
// usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field.
usagelogDescCacheReadTokens := usagelogFields[10].Descriptor()
// usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field.
usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int)
// usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field.
usagelogDescCacheCreation5mTokens := usagelogFields[11].Descriptor()
// usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field.
usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int)
// usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field.
usagelogDescCacheCreation1hTokens := usagelogFields[12].Descriptor()
// usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field.
usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int)
// usagelogDescInputCost is the schema descriptor for input_cost field.
usagelogDescInputCost := usagelogFields[13].Descriptor()
// usagelog.DefaultInputCost holds the default value on creation for the input_cost field.
usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64)
// usagelogDescOutputCost is the schema descriptor for output_cost field.
usagelogDescOutputCost := usagelogFields[14].Descriptor()
// usagelog.DefaultOutputCost holds the default value on creation for the output_cost field.
usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64)
// usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field.
usagelogDescCacheCreationCost := usagelogFields[15].Descriptor()
// usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field.
usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64)
// usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field.
usagelogDescCacheReadCost := usagelogFields[16].Descriptor()
// usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field.
usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64)
// usagelogDescTotalCost is the schema descriptor for total_cost field.
usagelogDescTotalCost := usagelogFields[17].Descriptor()
// usagelog.DefaultTotalCost holds the default value on creation for the total_cost field.
usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64)
// usagelogDescActualCost is the schema descriptor for actual_cost field.
usagelogDescActualCost := usagelogFields[18].Descriptor()
// usagelog.DefaultActualCost holds the default value on creation for the actual_cost field.
usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64)
// usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field.
usagelogDescRateMultiplier := usagelogFields[19].Descriptor()
// usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64)
// usagelogDescBillingType is the schema descriptor for billing_type field.
usagelogDescBillingType := usagelogFields[20].Descriptor()
// usagelog.DefaultBillingType holds the default value on creation for the billing_type field.
usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8)
// usagelogDescStream is the schema descriptor for stream field.
usagelogDescStream := usagelogFields[21].Descriptor()
// usagelog.DefaultStream holds the default value on creation for the stream field.
usagelog.DefaultStream = usagelogDescStream.Default.(bool)
// usagelogDescCreatedAt is the schema descriptor for created_at field.
usagelogDescCreatedAt := usagelogFields[24].Descriptor()
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
userMixin := schema.User{}.Mixin()
userMixinHooks1 := userMixin[1].Hooks()
user.Hooks[0] = userMixinHooks1[0]
@@ -518,6 +621,10 @@ func init() {
// userallowedgroup.DefaultCreatedAt holds the default value on creation for the created_at field.
userallowedgroup.DefaultCreatedAt = userallowedgroupDescCreatedAt.Default.(func() time.Time)
usersubscriptionMixin := schema.UserSubscription{}.Mixin()
usersubscriptionMixinHooks1 := usersubscriptionMixin[1].Hooks()
usersubscription.Hooks[0] = usersubscriptionMixinHooks1[0]
usersubscriptionMixinInters1 := usersubscriptionMixin[1].Interceptors()
usersubscription.Interceptors[0] = usersubscriptionMixinInters1[0]
usersubscriptionMixinFields0 := usersubscriptionMixin[0].Fields()
_ = usersubscriptionMixinFields0
usersubscriptionFields := schema.UserSubscription{}.Fields()

View File

@@ -168,6 +168,13 @@ func (Account) Edges() []ent.Edge {
// 一个账户可以属于多个分组,一个分组可以包含多个账户
edge.To("groups", Group.Type).
Through("account_groups", AccountGroup.Type),
// proxy: 账户使用的代理配置(可选的一对一关系)
// 使用已有的 proxy_id 外键字段
edge.To("proxy", Proxy.Type).
Field("proxy_id").
Unique(),
// usage_logs: 该账户的使用日志
edge.To("usage_logs", UsageLog.Type),
}
}

View File

@@ -4,6 +4,7 @@ import (
"time"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/edge"
@@ -33,7 +34,8 @@ func (AccountGroup) Fields() []ent.Field {
Default(50),
field.Time("created_at").
Immutable().
Default(time.Now),
Default(time.Now).
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
}
}

View File

@@ -60,12 +60,13 @@ func (ApiKey) Edges() []ent.Edge {
Ref("api_keys").
Field("group_id").
Unique(),
edge.To("usage_logs", UsageLog.Type),
}
}
func (ApiKey) Indexes() []ent.Index {
return []ent.Index{
index.Fields("key").Unique(),
// key 字段已在 Fields() 中声明 Unique(),无需重复索引
index.Fields("user_id"),
index.Fields("group_id"),
index.Fields("status"),

View File

@@ -33,10 +33,11 @@ func (Group) Mixin() []ent.Mixin {
func (Group) Fields() []ent.Field {
return []ent.Field{
// 唯一约束通过部分索引实现WHERE deleted_at IS NULL支持软删除后重用
// 见迁移文件 016_soft_delete_partial_unique_indexes.sql
field.String("name").
MaxLen(100).
NotEmpty().
Unique(),
NotEmpty(),
field.String("description").
Optional().
Nillable().
@@ -69,6 +70,8 @@ func (Group) Fields() []ent.Field {
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
field.Int("default_validity_days").
Default(30),
}
}
@@ -77,6 +80,7 @@ func (Group) Edges() []ent.Edge {
edge.To("api_keys", ApiKey.Type),
edge.To("redeem_codes", RedeemCode.Type),
edge.To("subscriptions", UserSubscription.Type),
edge.To("usage_logs", UsageLog.Type),
edge.From("accounts", Account.Type).
Ref("groups").
Through("account_groups", AccountGroup.Type),
@@ -88,7 +92,7 @@ func (Group) Edges() []ent.Edge {
func (Group) Indexes() []ent.Index {
return []ent.Index{
index.Fields("name").Unique(),
// name 字段已在 Fields() 中声明 Unique(),无需重复索引
index.Fields("status"),
index.Fields("platform"),
index.Fields("subscription_type"),

View File

@@ -12,6 +12,7 @@ import (
"entgo.io/ent/dialect/sql"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/mixin"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/intercept"
)
@@ -112,6 +113,7 @@ func (d SoftDeleteMixin) Hooks() []ent.Hook {
SetOp(ent.Op)
SetDeletedAt(time.Time)
WhereP(...func(*sql.Selector))
Client() *dbent.Client
})
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
@@ -122,7 +124,7 @@ func (d SoftDeleteMixin) Hooks() []ent.Hook {
mx.SetOp(ent.OpUpdate)
// 设置删除时间为当前时间
mx.SetDeletedAt(time.Now())
return next.Mutate(ctx, m)
return mx.Client().Mutate(ctx, m)
})
},
}

View File

@@ -6,6 +6,7 @@ import (
"entgo.io/ent"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/edge"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
@@ -54,6 +55,15 @@ func (Proxy) Fields() []ent.Field {
}
}
// Edges 定义代理实体的关联关系。
func (Proxy) Edges() []ent.Edge {
return []ent.Edge{
// accounts: 使用此代理的账户(反向边)
edge.From("accounts", Account.Type).
Ref("proxy"),
}
}
func (Proxy) Indexes() []ent.Index {
return []ent.Index{
index.Fields("status"),

View File

@@ -15,6 +15,14 @@ import (
)
// RedeemCode holds the schema definition for the RedeemCode entity.
//
// 删除策略:硬删除
// RedeemCode 使用硬删除而非软删除,原因如下:
// - 兑换码具有一次性使用特性,删除后无需保留历史记录
// - 已使用的兑换码通过 status 和 used_at 字段追踪,无需依赖软删除
// - 减少数据库存储压力和查询复杂度
//
// 如需审计已删除的兑换码,建议在删除前将关键信息写入审计日志表。
type RedeemCode struct {
ent.Schema
}
@@ -78,7 +86,7 @@ func (RedeemCode) Edges() []ent.Edge {
func (RedeemCode) Indexes() []ent.Index {
return []ent.Index{
index.Fields("code").Unique(),
// code 字段已在 Fields() 中声明 Unique(),无需重复索引
index.Fields("status"),
index.Fields("used_by"),
index.Fields("group_id"),

View File

@@ -8,10 +8,17 @@ import (
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
// Setting holds the schema definition for the Setting entity.
//
// 删除策略:硬删除
// Setting 使用硬删除而非软删除,原因如下:
// - 系统设置是简单的键值对,删除即意味着恢复默认值
// - 设置变更通常通过应用日志追踪,无需在数据库层面保留历史
// - 保持表结构简洁,避免无效数据积累
//
// 如需设置变更审计,建议在更新/删除前将变更记录写入审计日志表。
type Setting struct {
ent.Schema
}
@@ -29,7 +36,6 @@ func (Setting) Fields() []ent.Field {
NotEmpty().
Unique(),
field.String("value").
NotEmpty().
SchemaType(map[string]string{
dialect.Postgres: "text",
}),
@@ -43,7 +49,6 @@ func (Setting) Fields() []ent.Field {
}
func (Setting) Indexes() []ent.Index {
return []ent.Index{
index.Fields("key").Unique(),
}
// key 字段已在 Fields() 中声明 Unique(),无需额外索引
return nil
}

View File

@@ -0,0 +1,152 @@
// Package schema 定义 Ent ORM 的数据库 schema。
package schema
import (
"time"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/edge"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
// UsageLog 定义使用日志实体的 schema。
//
// 使用日志记录每次 API 调用的详细信息,包括 token 使用量、成本计算等。
// 这是一个只追加的表,不支持更新和删除。
type UsageLog struct {
ent.Schema
}
// Annotations 返回 schema 的注解配置。
func (UsageLog) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "usage_logs"},
}
}
// Fields 定义使用日志实体的所有字段。
func (UsageLog) Fields() []ent.Field {
return []ent.Field{
// 关联字段
field.Int64("user_id"),
field.Int64("api_key_id"),
field.Int64("account_id"),
field.String("request_id").
MaxLen(64).
NotEmpty(),
field.String("model").
MaxLen(100).
NotEmpty(),
field.Int64("group_id").
Optional().
Nillable(),
field.Int64("subscription_id").
Optional().
Nillable(),
// Token 计数字段
field.Int("input_tokens").
Default(0),
field.Int("output_tokens").
Default(0),
field.Int("cache_creation_tokens").
Default(0),
field.Int("cache_read_tokens").
Default(0),
field.Int("cache_creation_5m_tokens").
Default(0),
field.Int("cache_creation_1h_tokens").
Default(0),
// 成本字段
field.Float("input_cost").
Default(0).
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
field.Float("output_cost").
Default(0).
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
field.Float("cache_creation_cost").
Default(0).
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
field.Float("cache_read_cost").
Default(0).
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
field.Float("total_cost").
Default(0).
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
field.Float("actual_cost").
Default(0).
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
field.Float("rate_multiplier").
Default(1).
SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}),
// 其他字段
field.Int8("billing_type").
Default(0),
field.Bool("stream").
Default(false),
field.Int("duration_ms").
Optional().
Nillable(),
field.Int("first_token_ms").
Optional().
Nillable(),
// 时间戳(只有 created_at日志不可修改
field.Time("created_at").
Default(time.Now).
Immutable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
}
}
// Edges 定义使用日志实体的关联关系。
func (UsageLog) Edges() []ent.Edge {
return []ent.Edge{
edge.From("user", User.Type).
Ref("usage_logs").
Field("user_id").
Required().
Unique(),
edge.From("api_key", ApiKey.Type).
Ref("usage_logs").
Field("api_key_id").
Required().
Unique(),
edge.From("account", Account.Type).
Ref("usage_logs").
Field("account_id").
Required().
Unique(),
edge.From("group", Group.Type).
Ref("usage_logs").
Field("group_id").
Unique(),
edge.From("subscription", UserSubscription.Type).
Ref("usage_logs").
Field("subscription_id").
Unique(),
}
}
// Indexes 定义数据库索引,优化查询性能。
func (UsageLog) Indexes() []ent.Index {
return []ent.Index{
index.Fields("user_id"),
index.Fields("api_key_id"),
index.Fields("account_id"),
index.Fields("group_id"),
index.Fields("subscription_id"),
index.Fields("created_at"),
index.Fields("model"),
index.Fields("request_id"),
// 复合索引用于时间范围查询
index.Fields("user_id", "created_at"),
index.Fields("api_key_id", "created_at"),
}
}

View File

@@ -33,10 +33,11 @@ func (User) Mixin() []ent.Mixin {
func (User) Fields() []ent.Field {
return []ent.Field{
// 唯一约束通过部分索引实现WHERE deleted_at IS NULL支持软删除后重用
// 见迁移文件 016_soft_delete_partial_unique_indexes.sql
field.String("email").
MaxLen(255).
NotEmpty().
Unique(),
NotEmpty(),
field.String("password_hash").
MaxLen(255).
NotEmpty(),
@@ -73,12 +74,13 @@ func (User) Edges() []ent.Edge {
edge.To("assigned_subscriptions", UserSubscription.Type),
edge.To("allowed_groups", Group.Type).
Through("user_allowed_groups", UserAllowedGroup.Type),
edge.To("usage_logs", UsageLog.Type),
}
}
func (User) Indexes() []ent.Index {
return []ent.Index{
index.Fields("email").Unique(),
// email 字段已在 Fields() 中声明 Unique(),无需重复索引
index.Fields("status"),
index.Fields("deleted_at"),
}

View File

@@ -4,6 +4,7 @@ import (
"time"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/edge"
@@ -31,7 +32,8 @@ func (UserAllowedGroup) Fields() []ent.Field {
field.Int64("group_id"),
field.Time("created_at").
Immutable().
Default(time.Now),
Default(time.Now).
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
}
}

View File

@@ -29,6 +29,7 @@ func (UserSubscription) Annotations() []schema.Annotation {
func (UserSubscription) Mixin() []ent.Mixin {
return []ent.Mixin{
mixins.TimeMixin{},
mixins.SoftDeleteMixin{},
}
}
@@ -97,6 +98,7 @@ func (UserSubscription) Edges() []ent.Edge {
Ref("assigned_subscriptions").
Field("assigned_by").
Unique(),
edge.To("usage_logs", UsageLog.Type),
}
}
@@ -107,6 +109,9 @@ func (UserSubscription) Indexes() []ent.Index {
index.Fields("status"),
index.Fields("expires_at"),
index.Fields("assigned_by"),
index.Fields("user_id", "group_id").Unique(),
// 唯一约束通过部分索引实现WHERE deleted_at IS NULL支持软删除后重新订阅
// 见迁移文件 016_soft_delete_partial_unique_indexes.sql
index.Fields("user_id", "group_id"),
index.Fields("deleted_at"),
}
}

View File

@@ -44,8 +44,6 @@ func ValidColumn(column string) bool {
var (
// KeyValidator is a validator for the "key" field. It is called by the builders before save.
KeyValidator func(string) error
// ValueValidator is a validator for the "value" field. It is called by the builders before save.
ValueValidator func(string) error
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
DefaultUpdatedAt func() time.Time
// UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.

View File

@@ -102,11 +102,6 @@ func (_c *SettingCreate) check() error {
if _, ok := _c.mutation.Value(); !ok {
return &ValidationError{Name: "value", err: errors.New(`ent: missing required field "Setting.value"`)}
}
if v, ok := _c.mutation.Value(); ok {
if err := setting.ValueValidator(v); err != nil {
return &ValidationError{Name: "value", err: fmt.Errorf(`ent: validator failed for field "Setting.value": %w`, err)}
}
}
if _, ok := _c.mutation.UpdatedAt(); !ok {
return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "Setting.updated_at"`)}
}

View File

@@ -110,11 +110,6 @@ func (_u *SettingUpdate) check() error {
return &ValidationError{Name: "key", err: fmt.Errorf(`ent: validator failed for field "Setting.key": %w`, err)}
}
}
if v, ok := _u.mutation.Value(); ok {
if err := setting.ValueValidator(v); err != nil {
return &ValidationError{Name: "value", err: fmt.Errorf(`ent: validator failed for field "Setting.value": %w`, err)}
}
}
return nil
}
@@ -254,11 +249,6 @@ func (_u *SettingUpdateOne) check() error {
return &ValidationError{Name: "key", err: fmt.Errorf(`ent: validator failed for field "Setting.key": %w`, err)}
}
}
if v, ok := _u.mutation.Value(); ok {
if err := setting.ValueValidator(v); err != nil {
return &ValidationError{Name: "value", err: fmt.Errorf(`ent: validator failed for field "Setting.value": %w`, err)}
}
}
return nil
}

View File

@@ -28,6 +28,8 @@ type Tx struct {
RedeemCode *RedeemCodeClient
// Setting is the client for interacting with the Setting builders.
Setting *SettingClient
// UsageLog is the client for interacting with the UsageLog builders.
UsageLog *UsageLogClient
// User is the client for interacting with the User builders.
User *UserClient
// UserAllowedGroup is the client for interacting with the UserAllowedGroup builders.
@@ -172,6 +174,7 @@ func (tx *Tx) init() {
tx.Proxy = NewProxyClient(tx.config)
tx.RedeemCode = NewRedeemCodeClient(tx.config)
tx.Setting = NewSettingClient(tx.config)
tx.UsageLog = NewUsageLogClient(tx.config)
tx.User = NewUserClient(tx.config)
tx.UserAllowedGroup = NewUserAllowedGroupClient(tx.config)
tx.UserSubscription = NewUserSubscriptionClient(tx.config)
@@ -238,7 +241,6 @@ func (tx *txDriver) Query(ctx context.Context, query string, args, v any) error
var _ dialect.Driver = (*txDriver)(nil)
// ExecContext 透传到底层事务,用于在 ent 事务中执行原生 SQL与 ent 写入保持同一事务)。
// ExecContext allows calling the underlying ExecContext method of the transaction if it is supported by it.
// See, database/sql#Tx.ExecContext for more information.
func (tx *txDriver) ExecContext(ctx context.Context, query string, args ...any) (stdsql.Result, error) {
@@ -251,7 +253,6 @@ func (tx *txDriver) ExecContext(ctx context.Context, query string, args ...any)
return ex.ExecContext(ctx, query, args...)
}
// QueryContext 透传到底层事务,用于在 ent 事务中执行原生查询并共享锁语义。
// QueryContext allows calling the underlying QueryContext method of the transaction if it is supported by it.
// See, database/sql#Tx.QueryContext for more information.
func (tx *txDriver) QueryContext(ctx context.Context, query string, args ...any) (*stdsql.Rows, error) {

491
backend/ent/usagelog.go Normal file
View File

@@ -0,0 +1,491 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"fmt"
"strings"
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/account"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
// UsageLog is the model entity for the UsageLog schema.
type UsageLog struct {
config `json:"-"`
// ID of the ent.
ID int64 `json:"id,omitempty"`
// UserID holds the value of the "user_id" field.
UserID int64 `json:"user_id,omitempty"`
// APIKeyID holds the value of the "api_key_id" field.
APIKeyID int64 `json:"api_key_id,omitempty"`
// AccountID holds the value of the "account_id" field.
AccountID int64 `json:"account_id,omitempty"`
// RequestID holds the value of the "request_id" field.
RequestID string `json:"request_id,omitempty"`
// Model holds the value of the "model" field.
Model string `json:"model,omitempty"`
// GroupID holds the value of the "group_id" field.
GroupID *int64 `json:"group_id,omitempty"`
// SubscriptionID holds the value of the "subscription_id" field.
SubscriptionID *int64 `json:"subscription_id,omitempty"`
// InputTokens holds the value of the "input_tokens" field.
InputTokens int `json:"input_tokens,omitempty"`
// OutputTokens holds the value of the "output_tokens" field.
OutputTokens int `json:"output_tokens,omitempty"`
// CacheCreationTokens holds the value of the "cache_creation_tokens" field.
CacheCreationTokens int `json:"cache_creation_tokens,omitempty"`
// CacheReadTokens holds the value of the "cache_read_tokens" field.
CacheReadTokens int `json:"cache_read_tokens,omitempty"`
// CacheCreation5mTokens holds the value of the "cache_creation_5m_tokens" field.
CacheCreation5mTokens int `json:"cache_creation_5m_tokens,omitempty"`
// CacheCreation1hTokens holds the value of the "cache_creation_1h_tokens" field.
CacheCreation1hTokens int `json:"cache_creation_1h_tokens,omitempty"`
// InputCost holds the value of the "input_cost" field.
InputCost float64 `json:"input_cost,omitempty"`
// OutputCost holds the value of the "output_cost" field.
OutputCost float64 `json:"output_cost,omitempty"`
// CacheCreationCost holds the value of the "cache_creation_cost" field.
CacheCreationCost float64 `json:"cache_creation_cost,omitempty"`
// CacheReadCost holds the value of the "cache_read_cost" field.
CacheReadCost float64 `json:"cache_read_cost,omitempty"`
// TotalCost holds the value of the "total_cost" field.
TotalCost float64 `json:"total_cost,omitempty"`
// ActualCost holds the value of the "actual_cost" field.
ActualCost float64 `json:"actual_cost,omitempty"`
// RateMultiplier holds the value of the "rate_multiplier" field.
RateMultiplier float64 `json:"rate_multiplier,omitempty"`
// BillingType holds the value of the "billing_type" field.
BillingType int8 `json:"billing_type,omitempty"`
// Stream holds the value of the "stream" field.
Stream bool `json:"stream,omitempty"`
// DurationMs holds the value of the "duration_ms" field.
DurationMs *int `json:"duration_ms,omitempty"`
// FirstTokenMs holds the value of the "first_token_ms" field.
FirstTokenMs *int `json:"first_token_ms,omitempty"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the UsageLogQuery when eager-loading is set.
Edges UsageLogEdges `json:"edges"`
selectValues sql.SelectValues
}
// UsageLogEdges holds the relations/edges for other nodes in the graph.
type UsageLogEdges struct {
// User holds the value of the user edge.
User *User `json:"user,omitempty"`
// APIKey holds the value of the api_key edge.
APIKey *ApiKey `json:"api_key,omitempty"`
// Account holds the value of the account edge.
Account *Account `json:"account,omitempty"`
// Group holds the value of the group edge.
Group *Group `json:"group,omitempty"`
// Subscription holds the value of the subscription edge.
Subscription *UserSubscription `json:"subscription,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [5]bool
}
// UserOrErr returns the User value or an error if the edge
// was not loaded in eager-loading, or loaded but was not found.
func (e UsageLogEdges) UserOrErr() (*User, error) {
if e.User != nil {
return e.User, nil
} else if e.loadedTypes[0] {
return nil, &NotFoundError{label: user.Label}
}
return nil, &NotLoadedError{edge: "user"}
}
// APIKeyOrErr returns the APIKey value or an error if the edge
// was not loaded in eager-loading, or loaded but was not found.
func (e UsageLogEdges) APIKeyOrErr() (*ApiKey, error) {
if e.APIKey != nil {
return e.APIKey, nil
} else if e.loadedTypes[1] {
return nil, &NotFoundError{label: apikey.Label}
}
return nil, &NotLoadedError{edge: "api_key"}
}
// AccountOrErr returns the Account value or an error if the edge
// was not loaded in eager-loading, or loaded but was not found.
func (e UsageLogEdges) AccountOrErr() (*Account, error) {
if e.Account != nil {
return e.Account, nil
} else if e.loadedTypes[2] {
return nil, &NotFoundError{label: account.Label}
}
return nil, &NotLoadedError{edge: "account"}
}
// GroupOrErr returns the Group value or an error if the edge
// was not loaded in eager-loading, or loaded but was not found.
func (e UsageLogEdges) GroupOrErr() (*Group, error) {
if e.Group != nil {
return e.Group, nil
} else if e.loadedTypes[3] {
return nil, &NotFoundError{label: group.Label}
}
return nil, &NotLoadedError{edge: "group"}
}
// SubscriptionOrErr returns the Subscription value or an error if the edge
// was not loaded in eager-loading, or loaded but was not found.
func (e UsageLogEdges) SubscriptionOrErr() (*UserSubscription, error) {
if e.Subscription != nil {
return e.Subscription, nil
} else if e.loadedTypes[4] {
return nil, &NotFoundError{label: usersubscription.Label}
}
return nil, &NotLoadedError{edge: "subscription"}
}
// scanValues returns the types for scanning values from sql.Rows.
func (*UsageLog) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case usagelog.FieldStream:
values[i] = new(sql.NullBool)
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier:
values[i] = new(sql.NullFloat64)
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs:
values[i] = new(sql.NullInt64)
case usagelog.FieldRequestID, usagelog.FieldModel:
values[i] = new(sql.NullString)
case usagelog.FieldCreatedAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
}
}
return values, nil
}
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the UsageLog fields.
func (_m *UsageLog) assignValues(columns []string, values []any) error {
if m, n := len(values), len(columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
}
for i := range columns {
switch columns[i] {
case usagelog.FieldID:
value, ok := values[i].(*sql.NullInt64)
if !ok {
return fmt.Errorf("unexpected type %T for field id", value)
}
_m.ID = int64(value.Int64)
case usagelog.FieldUserID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field user_id", values[i])
} else if value.Valid {
_m.UserID = value.Int64
}
case usagelog.FieldAPIKeyID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field api_key_id", values[i])
} else if value.Valid {
_m.APIKeyID = value.Int64
}
case usagelog.FieldAccountID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field account_id", values[i])
} else if value.Valid {
_m.AccountID = value.Int64
}
case usagelog.FieldRequestID:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field request_id", values[i])
} else if value.Valid {
_m.RequestID = value.String
}
case usagelog.FieldModel:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field model", values[i])
} else if value.Valid {
_m.Model = value.String
}
case usagelog.FieldGroupID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field group_id", values[i])
} else if value.Valid {
_m.GroupID = new(int64)
*_m.GroupID = value.Int64
}
case usagelog.FieldSubscriptionID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field subscription_id", values[i])
} else if value.Valid {
_m.SubscriptionID = new(int64)
*_m.SubscriptionID = value.Int64
}
case usagelog.FieldInputTokens:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field input_tokens", values[i])
} else if value.Valid {
_m.InputTokens = int(value.Int64)
}
case usagelog.FieldOutputTokens:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field output_tokens", values[i])
} else if value.Valid {
_m.OutputTokens = int(value.Int64)
}
case usagelog.FieldCacheCreationTokens:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field cache_creation_tokens", values[i])
} else if value.Valid {
_m.CacheCreationTokens = int(value.Int64)
}
case usagelog.FieldCacheReadTokens:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field cache_read_tokens", values[i])
} else if value.Valid {
_m.CacheReadTokens = int(value.Int64)
}
case usagelog.FieldCacheCreation5mTokens:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field cache_creation_5m_tokens", values[i])
} else if value.Valid {
_m.CacheCreation5mTokens = int(value.Int64)
}
case usagelog.FieldCacheCreation1hTokens:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field cache_creation_1h_tokens", values[i])
} else if value.Valid {
_m.CacheCreation1hTokens = int(value.Int64)
}
case usagelog.FieldInputCost:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field input_cost", values[i])
} else if value.Valid {
_m.InputCost = value.Float64
}
case usagelog.FieldOutputCost:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field output_cost", values[i])
} else if value.Valid {
_m.OutputCost = value.Float64
}
case usagelog.FieldCacheCreationCost:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field cache_creation_cost", values[i])
} else if value.Valid {
_m.CacheCreationCost = value.Float64
}
case usagelog.FieldCacheReadCost:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field cache_read_cost", values[i])
} else if value.Valid {
_m.CacheReadCost = value.Float64
}
case usagelog.FieldTotalCost:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field total_cost", values[i])
} else if value.Valid {
_m.TotalCost = value.Float64
}
case usagelog.FieldActualCost:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field actual_cost", values[i])
} else if value.Valid {
_m.ActualCost = value.Float64
}
case usagelog.FieldRateMultiplier:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field rate_multiplier", values[i])
} else if value.Valid {
_m.RateMultiplier = value.Float64
}
case usagelog.FieldBillingType:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field billing_type", values[i])
} else if value.Valid {
_m.BillingType = int8(value.Int64)
}
case usagelog.FieldStream:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field stream", values[i])
} else if value.Valid {
_m.Stream = value.Bool
}
case usagelog.FieldDurationMs:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field duration_ms", values[i])
} else if value.Valid {
_m.DurationMs = new(int)
*_m.DurationMs = int(value.Int64)
}
case usagelog.FieldFirstTokenMs:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field first_token_ms", values[i])
} else if value.Valid {
_m.FirstTokenMs = new(int)
*_m.FirstTokenMs = int(value.Int64)
}
case usagelog.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i])
} else if value.Valid {
_m.CreatedAt = value.Time
}
default:
_m.selectValues.Set(columns[i], values[i])
}
}
return nil
}
// Value returns the ent.Value that was dynamically selected and assigned to the UsageLog.
// This includes values selected through modifiers, order, etc.
func (_m *UsageLog) Value(name string) (ent.Value, error) {
return _m.selectValues.Get(name)
}
// QueryUser queries the "user" edge of the UsageLog entity.
func (_m *UsageLog) QueryUser() *UserQuery {
return NewUsageLogClient(_m.config).QueryUser(_m)
}
// QueryAPIKey queries the "api_key" edge of the UsageLog entity.
func (_m *UsageLog) QueryAPIKey() *ApiKeyQuery {
return NewUsageLogClient(_m.config).QueryAPIKey(_m)
}
// QueryAccount queries the "account" edge of the UsageLog entity.
func (_m *UsageLog) QueryAccount() *AccountQuery {
return NewUsageLogClient(_m.config).QueryAccount(_m)
}
// QueryGroup queries the "group" edge of the UsageLog entity.
func (_m *UsageLog) QueryGroup() *GroupQuery {
return NewUsageLogClient(_m.config).QueryGroup(_m)
}
// QuerySubscription queries the "subscription" edge of the UsageLog entity.
func (_m *UsageLog) QuerySubscription() *UserSubscriptionQuery {
return NewUsageLogClient(_m.config).QuerySubscription(_m)
}
// Update returns a builder for updating this UsageLog.
// Note that you need to call UsageLog.Unwrap() before calling this method if this UsageLog
// was returned from a transaction, and the transaction was committed or rolled back.
func (_m *UsageLog) Update() *UsageLogUpdateOne {
return NewUsageLogClient(_m.config).UpdateOne(_m)
}
// Unwrap unwraps the UsageLog entity that was returned from a transaction after it was closed,
// so that all future queries will be executed through the driver which created the transaction.
func (_m *UsageLog) Unwrap() *UsageLog {
_tx, ok := _m.config.driver.(*txDriver)
if !ok {
panic("ent: UsageLog is not a transactional entity")
}
_m.config.driver = _tx.drv
return _m
}
// String implements the fmt.Stringer.
func (_m *UsageLog) String() string {
var builder strings.Builder
builder.WriteString("UsageLog(")
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
builder.WriteString("user_id=")
builder.WriteString(fmt.Sprintf("%v", _m.UserID))
builder.WriteString(", ")
builder.WriteString("api_key_id=")
builder.WriteString(fmt.Sprintf("%v", _m.APIKeyID))
builder.WriteString(", ")
builder.WriteString("account_id=")
builder.WriteString(fmt.Sprintf("%v", _m.AccountID))
builder.WriteString(", ")
builder.WriteString("request_id=")
builder.WriteString(_m.RequestID)
builder.WriteString(", ")
builder.WriteString("model=")
builder.WriteString(_m.Model)
builder.WriteString(", ")
if v := _m.GroupID; v != nil {
builder.WriteString("group_id=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
if v := _m.SubscriptionID; v != nil {
builder.WriteString("subscription_id=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
builder.WriteString("input_tokens=")
builder.WriteString(fmt.Sprintf("%v", _m.InputTokens))
builder.WriteString(", ")
builder.WriteString("output_tokens=")
builder.WriteString(fmt.Sprintf("%v", _m.OutputTokens))
builder.WriteString(", ")
builder.WriteString("cache_creation_tokens=")
builder.WriteString(fmt.Sprintf("%v", _m.CacheCreationTokens))
builder.WriteString(", ")
builder.WriteString("cache_read_tokens=")
builder.WriteString(fmt.Sprintf("%v", _m.CacheReadTokens))
builder.WriteString(", ")
builder.WriteString("cache_creation_5m_tokens=")
builder.WriteString(fmt.Sprintf("%v", _m.CacheCreation5mTokens))
builder.WriteString(", ")
builder.WriteString("cache_creation_1h_tokens=")
builder.WriteString(fmt.Sprintf("%v", _m.CacheCreation1hTokens))
builder.WriteString(", ")
builder.WriteString("input_cost=")
builder.WriteString(fmt.Sprintf("%v", _m.InputCost))
builder.WriteString(", ")
builder.WriteString("output_cost=")
builder.WriteString(fmt.Sprintf("%v", _m.OutputCost))
builder.WriteString(", ")
builder.WriteString("cache_creation_cost=")
builder.WriteString(fmt.Sprintf("%v", _m.CacheCreationCost))
builder.WriteString(", ")
builder.WriteString("cache_read_cost=")
builder.WriteString(fmt.Sprintf("%v", _m.CacheReadCost))
builder.WriteString(", ")
builder.WriteString("total_cost=")
builder.WriteString(fmt.Sprintf("%v", _m.TotalCost))
builder.WriteString(", ")
builder.WriteString("actual_cost=")
builder.WriteString(fmt.Sprintf("%v", _m.ActualCost))
builder.WriteString(", ")
builder.WriteString("rate_multiplier=")
builder.WriteString(fmt.Sprintf("%v", _m.RateMultiplier))
builder.WriteString(", ")
builder.WriteString("billing_type=")
builder.WriteString(fmt.Sprintf("%v", _m.BillingType))
builder.WriteString(", ")
builder.WriteString("stream=")
builder.WriteString(fmt.Sprintf("%v", _m.Stream))
builder.WriteString(", ")
if v := _m.DurationMs; v != nil {
builder.WriteString("duration_ms=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
if v := _m.FirstTokenMs; v != nil {
builder.WriteString("first_token_ms=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteByte(')')
return builder.String()
}
// UsageLogs is a parsable slice of UsageLog.
type UsageLogs []*UsageLog

View File

@@ -0,0 +1,396 @@
// Code generated by ent, DO NOT EDIT.
package usagelog
import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
)
const (
// Label holds the string label denoting the usagelog type in the database.
Label = "usage_log"
// FieldID holds the string denoting the id field in the database.
FieldID = "id"
// FieldUserID holds the string denoting the user_id field in the database.
FieldUserID = "user_id"
// FieldAPIKeyID holds the string denoting the api_key_id field in the database.
FieldAPIKeyID = "api_key_id"
// FieldAccountID holds the string denoting the account_id field in the database.
FieldAccountID = "account_id"
// FieldRequestID holds the string denoting the request_id field in the database.
FieldRequestID = "request_id"
// FieldModel holds the string denoting the model field in the database.
FieldModel = "model"
// FieldGroupID holds the string denoting the group_id field in the database.
FieldGroupID = "group_id"
// FieldSubscriptionID holds the string denoting the subscription_id field in the database.
FieldSubscriptionID = "subscription_id"
// FieldInputTokens holds the string denoting the input_tokens field in the database.
FieldInputTokens = "input_tokens"
// FieldOutputTokens holds the string denoting the output_tokens field in the database.
FieldOutputTokens = "output_tokens"
// FieldCacheCreationTokens holds the string denoting the cache_creation_tokens field in the database.
FieldCacheCreationTokens = "cache_creation_tokens"
// FieldCacheReadTokens holds the string denoting the cache_read_tokens field in the database.
FieldCacheReadTokens = "cache_read_tokens"
// FieldCacheCreation5mTokens holds the string denoting the cache_creation_5m_tokens field in the database.
FieldCacheCreation5mTokens = "cache_creation_5m_tokens"
// FieldCacheCreation1hTokens holds the string denoting the cache_creation_1h_tokens field in the database.
FieldCacheCreation1hTokens = "cache_creation_1h_tokens"
// FieldInputCost holds the string denoting the input_cost field in the database.
FieldInputCost = "input_cost"
// FieldOutputCost holds the string denoting the output_cost field in the database.
FieldOutputCost = "output_cost"
// FieldCacheCreationCost holds the string denoting the cache_creation_cost field in the database.
FieldCacheCreationCost = "cache_creation_cost"
// FieldCacheReadCost holds the string denoting the cache_read_cost field in the database.
FieldCacheReadCost = "cache_read_cost"
// FieldTotalCost holds the string denoting the total_cost field in the database.
FieldTotalCost = "total_cost"
// FieldActualCost holds the string denoting the actual_cost field in the database.
FieldActualCost = "actual_cost"
// FieldRateMultiplier holds the string denoting the rate_multiplier field in the database.
FieldRateMultiplier = "rate_multiplier"
// FieldBillingType holds the string denoting the billing_type field in the database.
FieldBillingType = "billing_type"
// FieldStream holds the string denoting the stream field in the database.
FieldStream = "stream"
// FieldDurationMs holds the string denoting the duration_ms field in the database.
FieldDurationMs = "duration_ms"
// FieldFirstTokenMs holds the string denoting the first_token_ms field in the database.
FieldFirstTokenMs = "first_token_ms"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
// EdgeUser holds the string denoting the user edge name in mutations.
EdgeUser = "user"
// EdgeAPIKey holds the string denoting the api_key edge name in mutations.
EdgeAPIKey = "api_key"
// EdgeAccount holds the string denoting the account edge name in mutations.
EdgeAccount = "account"
// EdgeGroup holds the string denoting the group edge name in mutations.
EdgeGroup = "group"
// EdgeSubscription holds the string denoting the subscription edge name in mutations.
EdgeSubscription = "subscription"
// Table holds the table name of the usagelog in the database.
Table = "usage_logs"
// UserTable is the table that holds the user relation/edge.
UserTable = "usage_logs"
// UserInverseTable is the table name for the User entity.
// It exists in this package in order to avoid circular dependency with the "user" package.
UserInverseTable = "users"
// UserColumn is the table column denoting the user relation/edge.
UserColumn = "user_id"
// APIKeyTable is the table that holds the api_key relation/edge.
APIKeyTable = "usage_logs"
// APIKeyInverseTable is the table name for the ApiKey entity.
// It exists in this package in order to avoid circular dependency with the "apikey" package.
APIKeyInverseTable = "api_keys"
// APIKeyColumn is the table column denoting the api_key relation/edge.
APIKeyColumn = "api_key_id"
// AccountTable is the table that holds the account relation/edge.
AccountTable = "usage_logs"
// AccountInverseTable is the table name for the Account entity.
// It exists in this package in order to avoid circular dependency with the "account" package.
AccountInverseTable = "accounts"
// AccountColumn is the table column denoting the account relation/edge.
AccountColumn = "account_id"
// GroupTable is the table that holds the group relation/edge.
GroupTable = "usage_logs"
// GroupInverseTable is the table name for the Group entity.
// It exists in this package in order to avoid circular dependency with the "group" package.
GroupInverseTable = "groups"
// GroupColumn is the table column denoting the group relation/edge.
GroupColumn = "group_id"
// SubscriptionTable is the table that holds the subscription relation/edge.
SubscriptionTable = "usage_logs"
// SubscriptionInverseTable is the table name for the UserSubscription entity.
// It exists in this package in order to avoid circular dependency with the "usersubscription" package.
SubscriptionInverseTable = "user_subscriptions"
// SubscriptionColumn is the table column denoting the subscription relation/edge.
SubscriptionColumn = "subscription_id"
)
// Columns holds all SQL columns for usagelog fields.
var Columns = []string{
FieldID,
FieldUserID,
FieldAPIKeyID,
FieldAccountID,
FieldRequestID,
FieldModel,
FieldGroupID,
FieldSubscriptionID,
FieldInputTokens,
FieldOutputTokens,
FieldCacheCreationTokens,
FieldCacheReadTokens,
FieldCacheCreation5mTokens,
FieldCacheCreation1hTokens,
FieldInputCost,
FieldOutputCost,
FieldCacheCreationCost,
FieldCacheReadCost,
FieldTotalCost,
FieldActualCost,
FieldRateMultiplier,
FieldBillingType,
FieldStream,
FieldDurationMs,
FieldFirstTokenMs,
FieldCreatedAt,
}
// ValidColumn reports if the column name is valid (part of the table columns).
func ValidColumn(column string) bool {
for i := range Columns {
if column == Columns[i] {
return true
}
}
return false
}
var (
// RequestIDValidator is a validator for the "request_id" field. It is called by the builders before save.
RequestIDValidator func(string) error
// ModelValidator is a validator for the "model" field. It is called by the builders before save.
ModelValidator func(string) error
// DefaultInputTokens holds the default value on creation for the "input_tokens" field.
DefaultInputTokens int
// DefaultOutputTokens holds the default value on creation for the "output_tokens" field.
DefaultOutputTokens int
// DefaultCacheCreationTokens holds the default value on creation for the "cache_creation_tokens" field.
DefaultCacheCreationTokens int
// DefaultCacheReadTokens holds the default value on creation for the "cache_read_tokens" field.
DefaultCacheReadTokens int
// DefaultCacheCreation5mTokens holds the default value on creation for the "cache_creation_5m_tokens" field.
DefaultCacheCreation5mTokens int
// DefaultCacheCreation1hTokens holds the default value on creation for the "cache_creation_1h_tokens" field.
DefaultCacheCreation1hTokens int
// DefaultInputCost holds the default value on creation for the "input_cost" field.
DefaultInputCost float64
// DefaultOutputCost holds the default value on creation for the "output_cost" field.
DefaultOutputCost float64
// DefaultCacheCreationCost holds the default value on creation for the "cache_creation_cost" field.
DefaultCacheCreationCost float64
// DefaultCacheReadCost holds the default value on creation for the "cache_read_cost" field.
DefaultCacheReadCost float64
// DefaultTotalCost holds the default value on creation for the "total_cost" field.
DefaultTotalCost float64
// DefaultActualCost holds the default value on creation for the "actual_cost" field.
DefaultActualCost float64
// DefaultRateMultiplier holds the default value on creation for the "rate_multiplier" field.
DefaultRateMultiplier float64
// DefaultBillingType holds the default value on creation for the "billing_type" field.
DefaultBillingType int8
// DefaultStream holds the default value on creation for the "stream" field.
DefaultStream bool
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
)
// OrderOption defines the ordering options for the UsageLog queries.
type OrderOption func(*sql.Selector)
// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByUserID orders the results by the user_id field.
func ByUserID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUserID, opts...).ToFunc()
}
// ByAPIKeyID orders the results by the api_key_id field.
func ByAPIKeyID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldAPIKeyID, opts...).ToFunc()
}
// ByAccountID orders the results by the account_id field.
func ByAccountID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldAccountID, opts...).ToFunc()
}
// ByRequestID orders the results by the request_id field.
func ByRequestID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldRequestID, opts...).ToFunc()
}
// ByModel orders the results by the model field.
func ByModel(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldModel, opts...).ToFunc()
}
// ByGroupID orders the results by the group_id field.
func ByGroupID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldGroupID, opts...).ToFunc()
}
// BySubscriptionID orders the results by the subscription_id field.
func BySubscriptionID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSubscriptionID, opts...).ToFunc()
}
// ByInputTokens orders the results by the input_tokens field.
func ByInputTokens(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldInputTokens, opts...).ToFunc()
}
// ByOutputTokens orders the results by the output_tokens field.
func ByOutputTokens(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldOutputTokens, opts...).ToFunc()
}
// ByCacheCreationTokens orders the results by the cache_creation_tokens field.
func ByCacheCreationTokens(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCacheCreationTokens, opts...).ToFunc()
}
// ByCacheReadTokens orders the results by the cache_read_tokens field.
func ByCacheReadTokens(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCacheReadTokens, opts...).ToFunc()
}
// ByCacheCreation5mTokens orders the results by the cache_creation_5m_tokens field.
func ByCacheCreation5mTokens(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCacheCreation5mTokens, opts...).ToFunc()
}
// ByCacheCreation1hTokens orders the results by the cache_creation_1h_tokens field.
func ByCacheCreation1hTokens(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCacheCreation1hTokens, opts...).ToFunc()
}
// ByInputCost orders the results by the input_cost field.
func ByInputCost(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldInputCost, opts...).ToFunc()
}
// ByOutputCost orders the results by the output_cost field.
func ByOutputCost(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldOutputCost, opts...).ToFunc()
}
// ByCacheCreationCost orders the results by the cache_creation_cost field.
func ByCacheCreationCost(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCacheCreationCost, opts...).ToFunc()
}
// ByCacheReadCost orders the results by the cache_read_cost field.
func ByCacheReadCost(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCacheReadCost, opts...).ToFunc()
}
// ByTotalCost orders the results by the total_cost field.
func ByTotalCost(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotalCost, opts...).ToFunc()
}
// ByActualCost orders the results by the actual_cost field.
func ByActualCost(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldActualCost, opts...).ToFunc()
}
// ByRateMultiplier orders the results by the rate_multiplier field.
func ByRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldRateMultiplier, opts...).ToFunc()
}
// ByBillingType orders the results by the billing_type field.
func ByBillingType(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBillingType, opts...).ToFunc()
}
// ByStream orders the results by the stream field.
func ByStream(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldStream, opts...).ToFunc()
}
// ByDurationMs orders the results by the duration_ms field.
func ByDurationMs(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDurationMs, opts...).ToFunc()
}
// ByFirstTokenMs orders the results by the first_token_ms field.
func ByFirstTokenMs(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldFirstTokenMs, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
// ByUserField orders the results by user field.
func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...))
}
}
// ByAPIKeyField orders the results by api_key field.
func ByAPIKeyField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newAPIKeyStep(), sql.OrderByField(field, opts...))
}
}
// ByAccountField orders the results by account field.
func ByAccountField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newAccountStep(), sql.OrderByField(field, opts...))
}
}
// ByGroupField orders the results by group field.
func ByGroupField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newGroupStep(), sql.OrderByField(field, opts...))
}
}
// BySubscriptionField orders the results by subscription field.
func BySubscriptionField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newSubscriptionStep(), sql.OrderByField(field, opts...))
}
}
func newUserStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(UserInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
)
}
func newAPIKeyStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(APIKeyInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, APIKeyTable, APIKeyColumn),
)
}
func newAccountStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(AccountInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, AccountTable, AccountColumn),
)
}
func newGroupStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(GroupInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, GroupTable, GroupColumn),
)
}
func newSubscriptionStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(SubscriptionInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, SubscriptionTable, SubscriptionColumn),
)
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,88 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
)
// UsageLogDelete is the builder for deleting a UsageLog entity.
type UsageLogDelete struct {
config
hooks []Hook
mutation *UsageLogMutation
}
// Where appends a list predicates to the UsageLogDelete builder.
func (_d *UsageLogDelete) Where(ps ...predicate.UsageLog) *UsageLogDelete {
_d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query and returns how many vertices were deleted.
func (_d *UsageLogDelete) Exec(ctx context.Context) (int, error) {
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *UsageLogDelete) ExecX(ctx context.Context) int {
n, err := _d.Exec(ctx)
if err != nil {
panic(err)
}
return n
}
func (_d *UsageLogDelete) sqlExec(ctx context.Context) (int, error) {
_spec := sqlgraph.NewDeleteSpec(usagelog.Table, sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64))
if ps := _d.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
_d.mutation.done = true
return affected, err
}
// UsageLogDeleteOne is the builder for deleting a single UsageLog entity.
type UsageLogDeleteOne struct {
_d *UsageLogDelete
}
// Where appends a list predicates to the UsageLogDelete builder.
func (_d *UsageLogDeleteOne) Where(ps ...predicate.UsageLog) *UsageLogDeleteOne {
_d._d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query.
func (_d *UsageLogDeleteOne) Exec(ctx context.Context) error {
n, err := _d._d.Exec(ctx)
switch {
case err != nil:
return err
case n == 0:
return &NotFoundError{usagelog.Label}
default:
return nil
}
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *UsageLogDeleteOne) ExecX(ctx context.Context) {
if err := _d.Exec(ctx); err != nil {
panic(err)
}
}

View File

@@ -0,0 +1,912 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"fmt"
"math"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/account"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
// UsageLogQuery is the builder for querying UsageLog entities.
type UsageLogQuery struct {
config
ctx *QueryContext
order []usagelog.OrderOption
inters []Interceptor
predicates []predicate.UsageLog
withUser *UserQuery
withAPIKey *ApiKeyQuery
withAccount *AccountQuery
withGroup *GroupQuery
withSubscription *UserSubscriptionQuery
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
}
// Where adds a new predicate for the UsageLogQuery builder.
func (_q *UsageLogQuery) Where(ps ...predicate.UsageLog) *UsageLogQuery {
_q.predicates = append(_q.predicates, ps...)
return _q
}
// Limit the number of records to be returned by this query.
func (_q *UsageLogQuery) Limit(limit int) *UsageLogQuery {
_q.ctx.Limit = &limit
return _q
}
// Offset to start from.
func (_q *UsageLogQuery) Offset(offset int) *UsageLogQuery {
_q.ctx.Offset = &offset
return _q
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (_q *UsageLogQuery) Unique(unique bool) *UsageLogQuery {
_q.ctx.Unique = &unique
return _q
}
// Order specifies how the records should be ordered.
func (_q *UsageLogQuery) Order(o ...usagelog.OrderOption) *UsageLogQuery {
_q.order = append(_q.order, o...)
return _q
}
// QueryUser chains the current query on the "user" edge.
func (_q *UsageLogQuery) QueryUser() *UserQuery {
query := (&UserClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(usagelog.Table, usagelog.FieldID, selector),
sqlgraph.To(user.Table, user.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, usagelog.UserTable, usagelog.UserColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryAPIKey chains the current query on the "api_key" edge.
func (_q *UsageLogQuery) QueryAPIKey() *ApiKeyQuery {
query := (&ApiKeyClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(usagelog.Table, usagelog.FieldID, selector),
sqlgraph.To(apikey.Table, apikey.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, usagelog.APIKeyTable, usagelog.APIKeyColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryAccount chains the current query on the "account" edge.
func (_q *UsageLogQuery) QueryAccount() *AccountQuery {
query := (&AccountClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(usagelog.Table, usagelog.FieldID, selector),
sqlgraph.To(account.Table, account.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, usagelog.AccountTable, usagelog.AccountColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryGroup chains the current query on the "group" edge.
func (_q *UsageLogQuery) QueryGroup() *GroupQuery {
query := (&GroupClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(usagelog.Table, usagelog.FieldID, selector),
sqlgraph.To(group.Table, group.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, usagelog.GroupTable, usagelog.GroupColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QuerySubscription chains the current query on the "subscription" edge.
func (_q *UsageLogQuery) QuerySubscription() *UserSubscriptionQuery {
query := (&UserSubscriptionClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(usagelog.Table, usagelog.FieldID, selector),
sqlgraph.To(usersubscription.Table, usersubscription.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, usagelog.SubscriptionTable, usagelog.SubscriptionColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// First returns the first UsageLog entity from the query.
// Returns a *NotFoundError when no UsageLog was found.
func (_q *UsageLogQuery) First(ctx context.Context) (*UsageLog, error) {
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
if err != nil {
return nil, err
}
if len(nodes) == 0 {
return nil, &NotFoundError{usagelog.Label}
}
return nodes[0], nil
}
// FirstX is like First, but panics if an error occurs.
func (_q *UsageLogQuery) FirstX(ctx context.Context) *UsageLog {
node, err := _q.First(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return node
}
// FirstID returns the first UsageLog ID from the query.
// Returns a *NotFoundError when no UsageLog ID was found.
func (_q *UsageLogQuery) FirstID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
return
}
if len(ids) == 0 {
err = &NotFoundError{usagelog.Label}
return
}
return ids[0], nil
}
// FirstIDX is like FirstID, but panics if an error occurs.
func (_q *UsageLogQuery) FirstIDX(ctx context.Context) int64 {
id, err := _q.FirstID(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return id
}
// Only returns a single UsageLog entity found by the query, ensuring it only returns one.
// Returns a *NotSingularError when more than one UsageLog entity is found.
// Returns a *NotFoundError when no UsageLog entities are found.
func (_q *UsageLogQuery) Only(ctx context.Context) (*UsageLog, error) {
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
if err != nil {
return nil, err
}
switch len(nodes) {
case 1:
return nodes[0], nil
case 0:
return nil, &NotFoundError{usagelog.Label}
default:
return nil, &NotSingularError{usagelog.Label}
}
}
// OnlyX is like Only, but panics if an error occurs.
func (_q *UsageLogQuery) OnlyX(ctx context.Context) *UsageLog {
node, err := _q.Only(ctx)
if err != nil {
panic(err)
}
return node
}
// OnlyID is like Only, but returns the only UsageLog ID in the query.
// Returns a *NotSingularError when more than one UsageLog ID is found.
// Returns a *NotFoundError when no entities are found.
func (_q *UsageLogQuery) OnlyID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
return
}
switch len(ids) {
case 1:
id = ids[0]
case 0:
err = &NotFoundError{usagelog.Label}
default:
err = &NotSingularError{usagelog.Label}
}
return
}
// OnlyIDX is like OnlyID, but panics if an error occurs.
func (_q *UsageLogQuery) OnlyIDX(ctx context.Context) int64 {
id, err := _q.OnlyID(ctx)
if err != nil {
panic(err)
}
return id
}
// All executes the query and returns a list of UsageLogs.
func (_q *UsageLogQuery) All(ctx context.Context) ([]*UsageLog, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
qr := querierAll[[]*UsageLog, *UsageLogQuery]()
return withInterceptors[[]*UsageLog](ctx, _q, qr, _q.inters)
}
// AllX is like All, but panics if an error occurs.
func (_q *UsageLogQuery) AllX(ctx context.Context) []*UsageLog {
nodes, err := _q.All(ctx)
if err != nil {
panic(err)
}
return nodes
}
// IDs executes the query and returns a list of UsageLog IDs.
func (_q *UsageLogQuery) IDs(ctx context.Context) (ids []int64, err error) {
if _q.ctx.Unique == nil && _q.path != nil {
_q.Unique(true)
}
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
if err = _q.Select(usagelog.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
}
// IDsX is like IDs, but panics if an error occurs.
func (_q *UsageLogQuery) IDsX(ctx context.Context) []int64 {
ids, err := _q.IDs(ctx)
if err != nil {
panic(err)
}
return ids
}
// Count returns the count of the given query.
func (_q *UsageLogQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
if err := _q.prepareQuery(ctx); err != nil {
return 0, err
}
return withInterceptors[int](ctx, _q, querierCount[*UsageLogQuery](), _q.inters)
}
// CountX is like Count, but panics if an error occurs.
func (_q *UsageLogQuery) CountX(ctx context.Context) int {
count, err := _q.Count(ctx)
if err != nil {
panic(err)
}
return count
}
// Exist returns true if the query has elements in the graph.
func (_q *UsageLogQuery) Exist(ctx context.Context) (bool, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
switch _, err := _q.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
// ExistX is like Exist, but panics if an error occurs.
func (_q *UsageLogQuery) ExistX(ctx context.Context) bool {
exist, err := _q.Exist(ctx)
if err != nil {
panic(err)
}
return exist
}
// Clone returns a duplicate of the UsageLogQuery builder, including all associated steps. It can be
// used to prepare common query builders and use them differently after the clone is made.
func (_q *UsageLogQuery) Clone() *UsageLogQuery {
if _q == nil {
return nil
}
return &UsageLogQuery{
config: _q.config,
ctx: _q.ctx.Clone(),
order: append([]usagelog.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.UsageLog{}, _q.predicates...),
withUser: _q.withUser.Clone(),
withAPIKey: _q.withAPIKey.Clone(),
withAccount: _q.withAccount.Clone(),
withGroup: _q.withGroup.Clone(),
withSubscription: _q.withSubscription.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
}
}
// WithUser tells the query-builder to eager-load the nodes that are connected to
// the "user" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UsageLogQuery) WithUser(opts ...func(*UserQuery)) *UsageLogQuery {
query := (&UserClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withUser = query
return _q
}
// WithAPIKey tells the query-builder to eager-load the nodes that are connected to
// the "api_key" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UsageLogQuery) WithAPIKey(opts ...func(*ApiKeyQuery)) *UsageLogQuery {
query := (&ApiKeyClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withAPIKey = query
return _q
}
// WithAccount tells the query-builder to eager-load the nodes that are connected to
// the "account" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UsageLogQuery) WithAccount(opts ...func(*AccountQuery)) *UsageLogQuery {
query := (&AccountClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withAccount = query
return _q
}
// WithGroup tells the query-builder to eager-load the nodes that are connected to
// the "group" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UsageLogQuery) WithGroup(opts ...func(*GroupQuery)) *UsageLogQuery {
query := (&GroupClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withGroup = query
return _q
}
// WithSubscription tells the query-builder to eager-load the nodes that are connected to
// the "subscription" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UsageLogQuery) WithSubscription(opts ...func(*UserSubscriptionQuery)) *UsageLogQuery {
query := (&UserSubscriptionClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withSubscription = query
return _q
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
// Example:
//
// var v []struct {
// UserID int64 `json:"user_id,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.UsageLog.Query().
// GroupBy(usagelog.FieldUserID).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (_q *UsageLogQuery) GroupBy(field string, fields ...string) *UsageLogGroupBy {
_q.ctx.Fields = append([]string{field}, fields...)
grbuild := &UsageLogGroupBy{build: _q}
grbuild.flds = &_q.ctx.Fields
grbuild.label = usagelog.Label
grbuild.scan = grbuild.Scan
return grbuild
}
// Select allows the selection one or more fields/columns for the given query,
// instead of selecting all fields in the entity.
//
// Example:
//
// var v []struct {
// UserID int64 `json:"user_id,omitempty"`
// }
//
// client.UsageLog.Query().
// Select(usagelog.FieldUserID).
// Scan(ctx, &v)
func (_q *UsageLogQuery) Select(fields ...string) *UsageLogSelect {
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
sbuild := &UsageLogSelect{UsageLogQuery: _q}
sbuild.label = usagelog.Label
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a UsageLogSelect configured with the given aggregations.
func (_q *UsageLogQuery) Aggregate(fns ...AggregateFunc) *UsageLogSelect {
return _q.Select().Aggregate(fns...)
}
func (_q *UsageLogQuery) prepareQuery(ctx context.Context) error {
for _, inter := range _q.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, _q); err != nil {
return err
}
}
}
for _, f := range _q.ctx.Fields {
if !usagelog.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
}
if _q.path != nil {
prev, err := _q.path(ctx)
if err != nil {
return err
}
_q.sql = prev
}
return nil
}
func (_q *UsageLogQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*UsageLog, error) {
var (
nodes = []*UsageLog{}
_spec = _q.querySpec()
loadedTypes = [5]bool{
_q.withUser != nil,
_q.withAPIKey != nil,
_q.withAccount != nil,
_q.withGroup != nil,
_q.withSubscription != nil,
}
)
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*UsageLog).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []any) error {
node := &UsageLog{config: _q.config}
nodes = append(nodes, node)
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
return nil, err
}
if len(nodes) == 0 {
return nodes, nil
}
if query := _q.withUser; query != nil {
if err := _q.loadUser(ctx, query, nodes, nil,
func(n *UsageLog, e *User) { n.Edges.User = e }); err != nil {
return nil, err
}
}
if query := _q.withAPIKey; query != nil {
if err := _q.loadAPIKey(ctx, query, nodes, nil,
func(n *UsageLog, e *ApiKey) { n.Edges.APIKey = e }); err != nil {
return nil, err
}
}
if query := _q.withAccount; query != nil {
if err := _q.loadAccount(ctx, query, nodes, nil,
func(n *UsageLog, e *Account) { n.Edges.Account = e }); err != nil {
return nil, err
}
}
if query := _q.withGroup; query != nil {
if err := _q.loadGroup(ctx, query, nodes, nil,
func(n *UsageLog, e *Group) { n.Edges.Group = e }); err != nil {
return nil, err
}
}
if query := _q.withSubscription; query != nil {
if err := _q.loadSubscription(ctx, query, nodes, nil,
func(n *UsageLog, e *UserSubscription) { n.Edges.Subscription = e }); err != nil {
return nil, err
}
}
return nodes, nil
}
func (_q *UsageLogQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*UsageLog, init func(*UsageLog), assign func(*UsageLog, *User)) error {
ids := make([]int64, 0, len(nodes))
nodeids := make(map[int64][]*UsageLog)
for i := range nodes {
fk := nodes[i].UserID
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(user.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
}
}
return nil
}
func (_q *UsageLogQuery) loadAPIKey(ctx context.Context, query *ApiKeyQuery, nodes []*UsageLog, init func(*UsageLog), assign func(*UsageLog, *ApiKey)) error {
ids := make([]int64, 0, len(nodes))
nodeids := make(map[int64][]*UsageLog)
for i := range nodes {
fk := nodes[i].APIKeyID
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(apikey.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return fmt.Errorf(`unexpected foreign-key "api_key_id" returned %v`, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
}
}
return nil
}
func (_q *UsageLogQuery) loadAccount(ctx context.Context, query *AccountQuery, nodes []*UsageLog, init func(*UsageLog), assign func(*UsageLog, *Account)) error {
ids := make([]int64, 0, len(nodes))
nodeids := make(map[int64][]*UsageLog)
for i := range nodes {
fk := nodes[i].AccountID
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(account.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return fmt.Errorf(`unexpected foreign-key "account_id" returned %v`, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
}
}
return nil
}
func (_q *UsageLogQuery) loadGroup(ctx context.Context, query *GroupQuery, nodes []*UsageLog, init func(*UsageLog), assign func(*UsageLog, *Group)) error {
ids := make([]int64, 0, len(nodes))
nodeids := make(map[int64][]*UsageLog)
for i := range nodes {
if nodes[i].GroupID == nil {
continue
}
fk := *nodes[i].GroupID
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(group.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return fmt.Errorf(`unexpected foreign-key "group_id" returned %v`, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
}
}
return nil
}
func (_q *UsageLogQuery) loadSubscription(ctx context.Context, query *UserSubscriptionQuery, nodes []*UsageLog, init func(*UsageLog), assign func(*UsageLog, *UserSubscription)) error {
ids := make([]int64, 0, len(nodes))
nodeids := make(map[int64][]*UsageLog)
for i := range nodes {
if nodes[i].SubscriptionID == nil {
continue
}
fk := *nodes[i].SubscriptionID
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(usersubscription.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return fmt.Errorf(`unexpected foreign-key "subscription_id" returned %v`, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
}
}
return nil
}
func (_q *UsageLogQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
}
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
}
func (_q *UsageLogQuery) querySpec() *sqlgraph.QuerySpec {
_spec := sqlgraph.NewQuerySpec(usagelog.Table, usagelog.Columns, sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64))
_spec.From = _q.sql
if unique := _q.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if _q.path != nil {
_spec.Unique = true
}
if fields := _q.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, usagelog.FieldID)
for i := range fields {
if fields[i] != usagelog.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
if _q.withUser != nil {
_spec.Node.AddColumnOnce(usagelog.FieldUserID)
}
if _q.withAPIKey != nil {
_spec.Node.AddColumnOnce(usagelog.FieldAPIKeyID)
}
if _q.withAccount != nil {
_spec.Node.AddColumnOnce(usagelog.FieldAccountID)
}
if _q.withGroup != nil {
_spec.Node.AddColumnOnce(usagelog.FieldGroupID)
}
if _q.withSubscription != nil {
_spec.Node.AddColumnOnce(usagelog.FieldSubscriptionID)
}
}
if ps := _q.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if limit := _q.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := _q.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := _q.order; len(ps) > 0 {
_spec.Order = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
return _spec
}
func (_q *UsageLogQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(_q.driver.Dialect())
t1 := builder.Table(usagelog.Table)
columns := _q.ctx.Fields
if len(columns) == 0 {
columns = usagelog.Columns
}
selector := builder.Select(t1.Columns(columns...)...).From(t1)
if _q.sql != nil {
selector = _q.sql
selector.Select(selector.Columns(columns...)...)
}
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, p := range _q.predicates {
p(selector)
}
for _, p := range _q.order {
p(selector)
}
if offset := _q.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := _q.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
}
// UsageLogGroupBy is the group-by builder for UsageLog entities.
type UsageLogGroupBy struct {
selector
build *UsageLogQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
func (_g *UsageLogGroupBy) Aggregate(fns ...AggregateFunc) *UsageLogGroupBy {
_g.fns = append(_g.fns, fns...)
return _g
}
// Scan applies the selector query and scans the result into the given value.
func (_g *UsageLogGroupBy) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
if err := _g.build.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*UsageLogQuery, *UsageLogGroupBy](ctx, _g.build, _g, _g.build.inters, v)
}
func (_g *UsageLogGroupBy) sqlScan(ctx context.Context, root *UsageLogQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(_g.fns))
for _, fn := range _g.fns {
aggregation = append(aggregation, fn(selector))
}
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
for _, f := range *_g.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
selector.GroupBy(selector.Columns(*_g.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// UsageLogSelect is the builder for selecting fields of UsageLog entities.
type UsageLogSelect struct {
*UsageLogQuery
selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (_s *UsageLogSelect) Aggregate(fns ...AggregateFunc) *UsageLogSelect {
_s.fns = append(_s.fns, fns...)
return _s
}
// Scan applies the selector query and scans the result into the given value.
func (_s *UsageLogSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
if err := _s.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*UsageLogQuery, *UsageLogSelect](ctx, _s.UsageLogQuery, _s, _s.inters, v)
}
func (_s *UsageLogSelect) sqlScan(ctx context.Context, root *UsageLogQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(_s.fns))
for _, fn := range _s.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*_s.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}

File diff suppressed because it is too large Load Diff

View File

@@ -59,11 +59,13 @@ type UserEdges struct {
AssignedSubscriptions []*UserSubscription `json:"assigned_subscriptions,omitempty"`
// AllowedGroups holds the value of the allowed_groups edge.
AllowedGroups []*Group `json:"allowed_groups,omitempty"`
// UsageLogs holds the value of the usage_logs edge.
UsageLogs []*UsageLog `json:"usage_logs,omitempty"`
// UserAllowedGroups holds the value of the user_allowed_groups edge.
UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [6]bool
loadedTypes [7]bool
}
// APIKeysOrErr returns the APIKeys value or an error if the edge
@@ -111,10 +113,19 @@ func (e UserEdges) AllowedGroupsOrErr() ([]*Group, error) {
return nil, &NotLoadedError{edge: "allowed_groups"}
}
// UsageLogsOrErr returns the UsageLogs value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) UsageLogsOrErr() ([]*UsageLog, error) {
if e.loadedTypes[5] {
return e.UsageLogs, nil
}
return nil, &NotLoadedError{edge: "usage_logs"}
}
// UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) {
if e.loadedTypes[5] {
if e.loadedTypes[6] {
return e.UserAllowedGroups, nil
}
return nil, &NotLoadedError{edge: "user_allowed_groups"}
@@ -265,6 +276,11 @@ func (_m *User) QueryAllowedGroups() *GroupQuery {
return NewUserClient(_m.config).QueryAllowedGroups(_m)
}
// QueryUsageLogs queries the "usage_logs" edge of the User entity.
func (_m *User) QueryUsageLogs() *UsageLogQuery {
return NewUserClient(_m.config).QueryUsageLogs(_m)
}
// QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity.
func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery {
return NewUserClient(_m.config).QueryUserAllowedGroups(_m)

View File

@@ -49,6 +49,8 @@ const (
EdgeAssignedSubscriptions = "assigned_subscriptions"
// EdgeAllowedGroups holds the string denoting the allowed_groups edge name in mutations.
EdgeAllowedGroups = "allowed_groups"
// EdgeUsageLogs holds the string denoting the usage_logs edge name in mutations.
EdgeUsageLogs = "usage_logs"
// EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations.
EdgeUserAllowedGroups = "user_allowed_groups"
// Table holds the table name of the user in the database.
@@ -86,6 +88,13 @@ const (
// AllowedGroupsInverseTable is the table name for the Group entity.
// It exists in this package in order to avoid circular dependency with the "group" package.
AllowedGroupsInverseTable = "groups"
// UsageLogsTable is the table that holds the usage_logs relation/edge.
UsageLogsTable = "usage_logs"
// UsageLogsInverseTable is the table name for the UsageLog entity.
// It exists in this package in order to avoid circular dependency with the "usagelog" package.
UsageLogsInverseTable = "usage_logs"
// UsageLogsColumn is the table column denoting the usage_logs relation/edge.
UsageLogsColumn = "user_id"
// UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge.
UserAllowedGroupsTable = "user_allowed_groups"
// UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity.
@@ -308,6 +317,20 @@ func ByAllowedGroups(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
}
}
// ByUsageLogsCount orders the results by usage_logs count.
func ByUsageLogsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newUsageLogsStep(), opts...)
}
}
// ByUsageLogs orders the results by usage_logs terms.
func ByUsageLogs(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newUsageLogsStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
// ByUserAllowedGroupsCount orders the results by user_allowed_groups count.
func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
@@ -356,6 +379,13 @@ func newAllowedGroupsStep() *sqlgraph.Step {
sqlgraph.Edge(sqlgraph.M2M, false, AllowedGroupsTable, AllowedGroupsPrimaryKey...),
)
}
func newUsageLogsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(UsageLogsInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn),
)
}
func newUserAllowedGroupsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),

View File

@@ -895,6 +895,29 @@ func HasAllowedGroupsWith(preds ...predicate.Group) predicate.User {
})
}
// HasUsageLogs applies the HasEdge predicate on the "usage_logs" edge.
func HasUsageLogs() predicate.User {
return predicate.User(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasUsageLogsWith applies the HasEdge predicate on the "usage_logs" edge with a given conditions (other predicates).
func HasUsageLogsWith(preds ...predicate.UsageLog) predicate.User {
return predicate.User(func(s *sql.Selector) {
step := newUsageLogsStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge.
func HasUserAllowedGroups() predicate.User {
return predicate.User(func(s *sql.Selector) {

View File

@@ -14,6 +14,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
@@ -253,6 +254,21 @@ func (_c *UserCreate) AddAllowedGroups(v ...*Group) *UserCreate {
return _c.AddAllowedGroupIDs(ids...)
}
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func (_c *UserCreate) AddUsageLogIDs(ids ...int64) *UserCreate {
_c.mutation.AddUsageLogIDs(ids...)
return _c
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func (_c *UserCreate) AddUsageLogs(v ...*UsageLog) *UserCreate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _c.AddUsageLogIDs(ids...)
}
// Mutation returns the UserMutation object of the builder.
func (_c *UserCreate) Mutation() *UserMutation {
return _c.mutation
@@ -559,6 +575,22 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
edge.Target.Fields = specE.Fields
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.UsageLogsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.UsageLogsTable,
Columns: []string{user.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges = append(_spec.Edges, edge)
}
return _node, _spec
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
@@ -33,6 +34,7 @@ type UserQuery struct {
withSubscriptions *UserSubscriptionQuery
withAssignedSubscriptions *UserSubscriptionQuery
withAllowedGroups *GroupQuery
withUsageLogs *UsageLogQuery
withUserAllowedGroups *UserAllowedGroupQuery
// intermediate query (i.e. traversal path).
sql *sql.Selector
@@ -180,6 +182,28 @@ func (_q *UserQuery) QueryAllowedGroups() *GroupQuery {
return query
}
// QueryUsageLogs chains the current query on the "usage_logs" edge.
func (_q *UserQuery) QueryUsageLogs() *UsageLogQuery {
query := (&UsageLogClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID, selector),
sqlgraph.To(usagelog.Table, usagelog.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, user.UsageLogsTable, user.UsageLogsColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge.
func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery {
query := (&UserAllowedGroupClient{config: _q.config}).Query()
@@ -399,6 +423,7 @@ func (_q *UserQuery) Clone() *UserQuery {
withSubscriptions: _q.withSubscriptions.Clone(),
withAssignedSubscriptions: _q.withAssignedSubscriptions.Clone(),
withAllowedGroups: _q.withAllowedGroups.Clone(),
withUsageLogs: _q.withUsageLogs.Clone(),
withUserAllowedGroups: _q.withUserAllowedGroups.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
@@ -461,6 +486,17 @@ func (_q *UserQuery) WithAllowedGroups(opts ...func(*GroupQuery)) *UserQuery {
return _q
}
// WithUsageLogs tells the query-builder to eager-load the nodes that are connected to
// the "usage_logs" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithUsageLogs(opts ...func(*UsageLogQuery)) *UserQuery {
query := (&UsageLogClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withUsageLogs = query
return _q
}
// WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to
// the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery {
@@ -550,12 +586,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
var (
nodes = []*User{}
_spec = _q.querySpec()
loadedTypes = [6]bool{
loadedTypes = [7]bool{
_q.withAPIKeys != nil,
_q.withRedeemCodes != nil,
_q.withSubscriptions != nil,
_q.withAssignedSubscriptions != nil,
_q.withAllowedGroups != nil,
_q.withUsageLogs != nil,
_q.withUserAllowedGroups != nil,
}
)
@@ -614,6 +651,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
return nil, err
}
}
if query := _q.withUsageLogs; query != nil {
if err := _q.loadUsageLogs(ctx, query, nodes,
func(n *User) { n.Edges.UsageLogs = []*UsageLog{} },
func(n *User, e *UsageLog) { n.Edges.UsageLogs = append(n.Edges.UsageLogs, e) }); err != nil {
return nil, err
}
}
if query := _q.withUserAllowedGroups; query != nil {
if err := _q.loadUserAllowedGroups(ctx, query, nodes,
func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} },
@@ -811,6 +855,36 @@ func (_q *UserQuery) loadAllowedGroups(ctx context.Context, query *GroupQuery, n
}
return nil
}
func (_q *UserQuery) loadUsageLogs(ctx context.Context, query *UsageLogQuery, nodes []*User, init func(*User), assign func(*User, *UsageLog)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(usagelog.FieldUserID)
}
query.Where(predicate.UsageLog(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(user.UsageLogsColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.UserID
node, ok := nodeids[fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)

View File

@@ -15,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
@@ -273,6 +274,21 @@ func (_u *UserUpdate) AddAllowedGroups(v ...*Group) *UserUpdate {
return _u.AddAllowedGroupIDs(ids...)
}
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func (_u *UserUpdate) AddUsageLogIDs(ids ...int64) *UserUpdate {
_u.mutation.AddUsageLogIDs(ids...)
return _u
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func (_u *UserUpdate) AddUsageLogs(v ...*UsageLog) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddUsageLogIDs(ids...)
}
// Mutation returns the UserMutation object of the builder.
func (_u *UserUpdate) Mutation() *UserMutation {
return _u.mutation
@@ -383,6 +399,27 @@ func (_u *UserUpdate) RemoveAllowedGroups(v ...*Group) *UserUpdate {
return _u.RemoveAllowedGroupIDs(ids...)
}
// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity.
func (_u *UserUpdate) ClearUsageLogs() *UserUpdate {
_u.mutation.ClearUsageLogs()
return _u
}
// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs.
func (_u *UserUpdate) RemoveUsageLogIDs(ids ...int64) *UserUpdate {
_u.mutation.RemoveUsageLogIDs(ids...)
return _u
}
// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities.
func (_u *UserUpdate) RemoveUsageLogs(v ...*UsageLog) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveUsageLogIDs(ids...)
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UserUpdate) Save(ctx context.Context) (int, error) {
if err := _u.defaults(); err != nil {
@@ -751,6 +788,51 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
edge.Target.Fields = specE.Fields
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.UsageLogsTable,
Columns: []string{user.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.UsageLogsTable,
Columns: []string{user.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.UsageLogsTable,
Columns: []string{user.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{user.Label}
@@ -1012,6 +1094,21 @@ func (_u *UserUpdateOne) AddAllowedGroups(v ...*Group) *UserUpdateOne {
return _u.AddAllowedGroupIDs(ids...)
}
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func (_u *UserUpdateOne) AddUsageLogIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddUsageLogIDs(ids...)
return _u
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func (_u *UserUpdateOne) AddUsageLogs(v ...*UsageLog) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddUsageLogIDs(ids...)
}
// Mutation returns the UserMutation object of the builder.
func (_u *UserUpdateOne) Mutation() *UserMutation {
return _u.mutation
@@ -1122,6 +1219,27 @@ func (_u *UserUpdateOne) RemoveAllowedGroups(v ...*Group) *UserUpdateOne {
return _u.RemoveAllowedGroupIDs(ids...)
}
// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity.
func (_u *UserUpdateOne) ClearUsageLogs() *UserUpdateOne {
_u.mutation.ClearUsageLogs()
return _u
}
// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs.
func (_u *UserUpdateOne) RemoveUsageLogIDs(ids ...int64) *UserUpdateOne {
_u.mutation.RemoveUsageLogIDs(ids...)
return _u
}
// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities.
func (_u *UserUpdateOne) RemoveUsageLogs(v ...*UsageLog) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveUsageLogIDs(ids...)
}
// Where appends a list predicates to the UserUpdate builder.
func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne {
_u.mutation.Where(ps...)
@@ -1520,6 +1638,51 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
edge.Target.Fields = specE.Fields
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.UsageLogsTable,
Columns: []string{user.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.UsageLogsTable,
Columns: []string{user.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.UsageLogsTable,
Columns: []string{user.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &User{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues

View File

@@ -23,6 +23,8 @@ type UserSubscription struct {
CreatedAt time.Time `json:"created_at,omitempty"`
// UpdatedAt holds the value of the "updated_at" field.
UpdatedAt time.Time `json:"updated_at,omitempty"`
// DeletedAt holds the value of the "deleted_at" field.
DeletedAt *time.Time `json:"deleted_at,omitempty"`
// UserID holds the value of the "user_id" field.
UserID int64 `json:"user_id,omitempty"`
// GroupID holds the value of the "group_id" field.
@@ -65,9 +67,11 @@ type UserSubscriptionEdges struct {
Group *Group `json:"group,omitempty"`
// AssignedByUser holds the value of the assigned_by_user edge.
AssignedByUser *User `json:"assigned_by_user,omitempty"`
// UsageLogs holds the value of the usage_logs edge.
UsageLogs []*UsageLog `json:"usage_logs,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [3]bool
loadedTypes [4]bool
}
// UserOrErr returns the User value or an error if the edge
@@ -103,6 +107,15 @@ func (e UserSubscriptionEdges) AssignedByUserOrErr() (*User, error) {
return nil, &NotLoadedError{edge: "assigned_by_user"}
}
// UsageLogsOrErr returns the UsageLogs value or an error if the edge
// was not loaded in eager-loading.
func (e UserSubscriptionEdges) UsageLogsOrErr() ([]*UsageLog, error) {
if e.loadedTypes[3] {
return e.UsageLogs, nil
}
return nil, &NotLoadedError{edge: "usage_logs"}
}
// scanValues returns the types for scanning values from sql.Rows.
func (*UserSubscription) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
@@ -114,7 +127,7 @@ func (*UserSubscription) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullInt64)
case usersubscription.FieldStatus, usersubscription.FieldNotes:
values[i] = new(sql.NullString)
case usersubscription.FieldCreatedAt, usersubscription.FieldUpdatedAt, usersubscription.FieldStartsAt, usersubscription.FieldExpiresAt, usersubscription.FieldDailyWindowStart, usersubscription.FieldWeeklyWindowStart, usersubscription.FieldMonthlyWindowStart, usersubscription.FieldAssignedAt:
case usersubscription.FieldCreatedAt, usersubscription.FieldUpdatedAt, usersubscription.FieldDeletedAt, usersubscription.FieldStartsAt, usersubscription.FieldExpiresAt, usersubscription.FieldDailyWindowStart, usersubscription.FieldWeeklyWindowStart, usersubscription.FieldMonthlyWindowStart, usersubscription.FieldAssignedAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
@@ -149,6 +162,13 @@ func (_m *UserSubscription) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.UpdatedAt = value.Time
}
case usersubscription.FieldDeletedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field deleted_at", values[i])
} else if value.Valid {
_m.DeletedAt = new(time.Time)
*_m.DeletedAt = value.Time
}
case usersubscription.FieldUserID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field user_id", values[i])
@@ -266,6 +286,11 @@ func (_m *UserSubscription) QueryAssignedByUser() *UserQuery {
return NewUserSubscriptionClient(_m.config).QueryAssignedByUser(_m)
}
// QueryUsageLogs queries the "usage_logs" edge of the UserSubscription entity.
func (_m *UserSubscription) QueryUsageLogs() *UsageLogQuery {
return NewUserSubscriptionClient(_m.config).QueryUsageLogs(_m)
}
// Update returns a builder for updating this UserSubscription.
// Note that you need to call UserSubscription.Unwrap() before calling this method if this UserSubscription
// was returned from a transaction, and the transaction was committed or rolled back.
@@ -295,6 +320,11 @@ func (_m *UserSubscription) String() string {
builder.WriteString("updated_at=")
builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
builder.WriteString(", ")
if v := _m.DeletedAt; v != nil {
builder.WriteString("deleted_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
builder.WriteString("user_id=")
builder.WriteString(fmt.Sprintf("%v", _m.UserID))
builder.WriteString(", ")

View File

@@ -5,6 +5,7 @@ package usersubscription
import (
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
)
@@ -18,6 +19,8 @@ const (
FieldCreatedAt = "created_at"
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
FieldUpdatedAt = "updated_at"
// FieldDeletedAt holds the string denoting the deleted_at field in the database.
FieldDeletedAt = "deleted_at"
// FieldUserID holds the string denoting the user_id field in the database.
FieldUserID = "user_id"
// FieldGroupID holds the string denoting the group_id field in the database.
@@ -52,6 +55,8 @@ const (
EdgeGroup = "group"
// EdgeAssignedByUser holds the string denoting the assigned_by_user edge name in mutations.
EdgeAssignedByUser = "assigned_by_user"
// EdgeUsageLogs holds the string denoting the usage_logs edge name in mutations.
EdgeUsageLogs = "usage_logs"
// Table holds the table name of the usersubscription in the database.
Table = "user_subscriptions"
// UserTable is the table that holds the user relation/edge.
@@ -75,6 +80,13 @@ const (
AssignedByUserInverseTable = "users"
// AssignedByUserColumn is the table column denoting the assigned_by_user relation/edge.
AssignedByUserColumn = "assigned_by"
// UsageLogsTable is the table that holds the usage_logs relation/edge.
UsageLogsTable = "usage_logs"
// UsageLogsInverseTable is the table name for the UsageLog entity.
// It exists in this package in order to avoid circular dependency with the "usagelog" package.
UsageLogsInverseTable = "usage_logs"
// UsageLogsColumn is the table column denoting the usage_logs relation/edge.
UsageLogsColumn = "subscription_id"
)
// Columns holds all SQL columns for usersubscription fields.
@@ -82,6 +94,7 @@ var Columns = []string{
FieldID,
FieldCreatedAt,
FieldUpdatedAt,
FieldDeletedAt,
FieldUserID,
FieldGroupID,
FieldStartsAt,
@@ -108,7 +121,14 @@ func ValidColumn(column string) bool {
return false
}
// Note that the variables below are initialized by the runtime
// package on the initialization of the application. Therefore,
// it should be imported in the main as follows:
//
// import _ "github.com/Wei-Shaw/sub2api/ent/runtime"
var (
Hooks [1]ent.Hook
Interceptors [1]ent.Interceptor
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
@@ -147,6 +167,11 @@ func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
}
// ByDeletedAt orders the results by the deleted_at field.
func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDeletedAt, opts...).ToFunc()
}
// ByUserID orders the results by the user_id field.
func ByUserID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUserID, opts...).ToFunc()
@@ -237,6 +262,20 @@ func ByAssignedByUserField(field string, opts ...sql.OrderTermOption) OrderOptio
sqlgraph.OrderByNeighborTerms(s, newAssignedByUserStep(), sql.OrderByField(field, opts...))
}
}
// ByUsageLogsCount orders the results by usage_logs count.
func ByUsageLogsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newUsageLogsStep(), opts...)
}
}
// ByUsageLogs orders the results by usage_logs terms.
func ByUsageLogs(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newUsageLogsStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
func newUserStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
@@ -258,3 +297,10 @@ func newAssignedByUserStep() *sqlgraph.Step {
sqlgraph.Edge(sqlgraph.M2O, true, AssignedByUserTable, AssignedByUserColumn),
)
}
func newUsageLogsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(UsageLogsInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn),
)
}

View File

@@ -65,6 +65,11 @@ func UpdatedAt(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldUpdatedAt, v))
}
// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ.
func DeletedAt(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldDeletedAt, v))
}
// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ.
func UserID(v int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldUserID, v))
@@ -215,6 +220,56 @@ func UpdatedAtLTE(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLTE(FieldUpdatedAt, v))
}
// DeletedAtEQ applies the EQ predicate on the "deleted_at" field.
func DeletedAtEQ(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldDeletedAt, v))
}
// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field.
func DeletedAtNEQ(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNEQ(FieldDeletedAt, v))
}
// DeletedAtIn applies the In predicate on the "deleted_at" field.
func DeletedAtIn(vs ...time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldIn(FieldDeletedAt, vs...))
}
// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field.
func DeletedAtNotIn(vs ...time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNotIn(FieldDeletedAt, vs...))
}
// DeletedAtGT applies the GT predicate on the "deleted_at" field.
func DeletedAtGT(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGT(FieldDeletedAt, v))
}
// DeletedAtGTE applies the GTE predicate on the "deleted_at" field.
func DeletedAtGTE(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGTE(FieldDeletedAt, v))
}
// DeletedAtLT applies the LT predicate on the "deleted_at" field.
func DeletedAtLT(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLT(FieldDeletedAt, v))
}
// DeletedAtLTE applies the LTE predicate on the "deleted_at" field.
func DeletedAtLTE(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLTE(FieldDeletedAt, v))
}
// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field.
func DeletedAtIsNil() predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldIsNull(FieldDeletedAt))
}
// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field.
func DeletedAtNotNil() predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNotNull(FieldDeletedAt))
}
// UserIDEQ applies the EQ predicate on the "user_id" field.
func UserIDEQ(v int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldUserID, v))
@@ -884,6 +939,29 @@ func HasAssignedByUserWith(preds ...predicate.User) predicate.UserSubscription {
})
}
// HasUsageLogs applies the HasEdge predicate on the "usage_logs" edge.
func HasUsageLogs() predicate.UserSubscription {
return predicate.UserSubscription(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasUsageLogsWith applies the HasEdge predicate on the "usage_logs" edge with a given conditions (other predicates).
func HasUsageLogsWith(preds ...predicate.UsageLog) predicate.UserSubscription {
return predicate.UserSubscription(func(s *sql.Selector) {
step := newUsageLogsStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.UserSubscription) predicate.UserSubscription {
return predicate.UserSubscription(sql.AndPredicates(predicates...))

View File

@@ -12,6 +12,7 @@ import (
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
@@ -52,6 +53,20 @@ func (_c *UserSubscriptionCreate) SetNillableUpdatedAt(v *time.Time) *UserSubscr
return _c
}
// SetDeletedAt sets the "deleted_at" field.
func (_c *UserSubscriptionCreate) SetDeletedAt(v time.Time) *UserSubscriptionCreate {
_c.mutation.SetDeletedAt(v)
return _c
}
// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
func (_c *UserSubscriptionCreate) SetNillableDeletedAt(v *time.Time) *UserSubscriptionCreate {
if v != nil {
_c.SetDeletedAt(*v)
}
return _c
}
// SetUserID sets the "user_id" field.
func (_c *UserSubscriptionCreate) SetUserID(v int64) *UserSubscriptionCreate {
_c.mutation.SetUserID(v)
@@ -245,6 +260,21 @@ func (_c *UserSubscriptionCreate) SetAssignedByUser(v *User) *UserSubscriptionCr
return _c.SetAssignedByUserID(v.ID)
}
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func (_c *UserSubscriptionCreate) AddUsageLogIDs(ids ...int64) *UserSubscriptionCreate {
_c.mutation.AddUsageLogIDs(ids...)
return _c
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func (_c *UserSubscriptionCreate) AddUsageLogs(v ...*UsageLog) *UserSubscriptionCreate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _c.AddUsageLogIDs(ids...)
}
// Mutation returns the UserSubscriptionMutation object of the builder.
func (_c *UserSubscriptionCreate) Mutation() *UserSubscriptionMutation {
return _c.mutation
@@ -252,7 +282,9 @@ func (_c *UserSubscriptionCreate) Mutation() *UserSubscriptionMutation {
// Save creates the UserSubscription in the database.
func (_c *UserSubscriptionCreate) Save(ctx context.Context) (*UserSubscription, error) {
_c.defaults()
if err := _c.defaults(); err != nil {
return nil, err
}
return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
}
@@ -279,12 +311,18 @@ func (_c *UserSubscriptionCreate) ExecX(ctx context.Context) {
}
// defaults sets the default values of the builder before save.
func (_c *UserSubscriptionCreate) defaults() {
func (_c *UserSubscriptionCreate) defaults() error {
if _, ok := _c.mutation.CreatedAt(); !ok {
if usersubscription.DefaultCreatedAt == nil {
return fmt.Errorf("ent: uninitialized usersubscription.DefaultCreatedAt (forgotten import ent/runtime?)")
}
v := usersubscription.DefaultCreatedAt()
_c.mutation.SetCreatedAt(v)
}
if _, ok := _c.mutation.UpdatedAt(); !ok {
if usersubscription.DefaultUpdatedAt == nil {
return fmt.Errorf("ent: uninitialized usersubscription.DefaultUpdatedAt (forgotten import ent/runtime?)")
}
v := usersubscription.DefaultUpdatedAt()
_c.mutation.SetUpdatedAt(v)
}
@@ -305,9 +343,13 @@ func (_c *UserSubscriptionCreate) defaults() {
_c.mutation.SetMonthlyUsageUsd(v)
}
if _, ok := _c.mutation.AssignedAt(); !ok {
if usersubscription.DefaultAssignedAt == nil {
return fmt.Errorf("ent: uninitialized usersubscription.DefaultAssignedAt (forgotten import ent/runtime?)")
}
v := usersubscription.DefaultAssignedAt()
_c.mutation.SetAssignedAt(v)
}
return nil
}
// check runs all checks and user-defined validators on the builder.
@@ -391,6 +433,10 @@ func (_c *UserSubscriptionCreate) createSpec() (*UserSubscription, *sqlgraph.Cre
_spec.SetField(usersubscription.FieldUpdatedAt, field.TypeTime, value)
_node.UpdatedAt = value
}
if value, ok := _c.mutation.DeletedAt(); ok {
_spec.SetField(usersubscription.FieldDeletedAt, field.TypeTime, value)
_node.DeletedAt = &value
}
if value, ok := _c.mutation.StartsAt(); ok {
_spec.SetField(usersubscription.FieldStartsAt, field.TypeTime, value)
_node.StartsAt = value
@@ -486,6 +532,22 @@ func (_c *UserSubscriptionCreate) createSpec() (*UserSubscription, *sqlgraph.Cre
_node.AssignedBy = &nodes[0]
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.UsageLogsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: usersubscription.UsageLogsTable,
Columns: []string{usersubscription.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges = append(_spec.Edges, edge)
}
return _node, _spec
}
@@ -550,6 +612,24 @@ func (u *UserSubscriptionUpsert) UpdateUpdatedAt() *UserSubscriptionUpsert {
return u
}
// SetDeletedAt sets the "deleted_at" field.
func (u *UserSubscriptionUpsert) SetDeletedAt(v time.Time) *UserSubscriptionUpsert {
u.Set(usersubscription.FieldDeletedAt, v)
return u
}
// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create.
func (u *UserSubscriptionUpsert) UpdateDeletedAt() *UserSubscriptionUpsert {
u.SetExcluded(usersubscription.FieldDeletedAt)
return u
}
// ClearDeletedAt clears the value of the "deleted_at" field.
func (u *UserSubscriptionUpsert) ClearDeletedAt() *UserSubscriptionUpsert {
u.SetNull(usersubscription.FieldDeletedAt)
return u
}
// SetUserID sets the "user_id" field.
func (u *UserSubscriptionUpsert) SetUserID(v int64) *UserSubscriptionUpsert {
u.Set(usersubscription.FieldUserID, v)
@@ -825,6 +905,27 @@ func (u *UserSubscriptionUpsertOne) UpdateUpdatedAt() *UserSubscriptionUpsertOne
})
}
// SetDeletedAt sets the "deleted_at" field.
func (u *UserSubscriptionUpsertOne) SetDeletedAt(v time.Time) *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetDeletedAt(v)
})
}
// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create.
func (u *UserSubscriptionUpsertOne) UpdateDeletedAt() *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateDeletedAt()
})
}
// ClearDeletedAt clears the value of the "deleted_at" field.
func (u *UserSubscriptionUpsertOne) ClearDeletedAt() *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.ClearDeletedAt()
})
}
// SetUserID sets the "user_id" field.
func (u *UserSubscriptionUpsertOne) SetUserID(v int64) *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
@@ -1302,6 +1403,27 @@ func (u *UserSubscriptionUpsertBulk) UpdateUpdatedAt() *UserSubscriptionUpsertBu
})
}
// SetDeletedAt sets the "deleted_at" field.
func (u *UserSubscriptionUpsertBulk) SetDeletedAt(v time.Time) *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetDeletedAt(v)
})
}
// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create.
func (u *UserSubscriptionUpsertBulk) UpdateDeletedAt() *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateDeletedAt()
})
}
// ClearDeletedAt clears the value of the "deleted_at" field.
func (u *UserSubscriptionUpsertBulk) ClearDeletedAt() *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.ClearDeletedAt()
})
}
// SetUserID sets the "user_id" field.
func (u *UserSubscriptionUpsertBulk) SetUserID(v int64) *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {

View File

@@ -4,6 +4,7 @@ package ent
import (
"context"
"database/sql/driver"
"fmt"
"math"
@@ -13,6 +14,7 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
@@ -27,6 +29,7 @@ type UserSubscriptionQuery struct {
withUser *UserQuery
withGroup *GroupQuery
withAssignedByUser *UserQuery
withUsageLogs *UsageLogQuery
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
@@ -129,6 +132,28 @@ func (_q *UserSubscriptionQuery) QueryAssignedByUser() *UserQuery {
return query
}
// QueryUsageLogs chains the current query on the "usage_logs" edge.
func (_q *UserSubscriptionQuery) QueryUsageLogs() *UsageLogQuery {
query := (&UsageLogClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(usersubscription.Table, usersubscription.FieldID, selector),
sqlgraph.To(usagelog.Table, usagelog.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, usersubscription.UsageLogsTable, usersubscription.UsageLogsColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// First returns the first UserSubscription entity from the query.
// Returns a *NotFoundError when no UserSubscription was found.
func (_q *UserSubscriptionQuery) First(ctx context.Context) (*UserSubscription, error) {
@@ -324,6 +349,7 @@ func (_q *UserSubscriptionQuery) Clone() *UserSubscriptionQuery {
withUser: _q.withUser.Clone(),
withGroup: _q.withGroup.Clone(),
withAssignedByUser: _q.withAssignedByUser.Clone(),
withUsageLogs: _q.withUsageLogs.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
@@ -363,6 +389,17 @@ func (_q *UserSubscriptionQuery) WithAssignedByUser(opts ...func(*UserQuery)) *U
return _q
}
// WithUsageLogs tells the query-builder to eager-load the nodes that are connected to
// the "usage_logs" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserSubscriptionQuery) WithUsageLogs(opts ...func(*UsageLogQuery)) *UserSubscriptionQuery {
query := (&UsageLogClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withUsageLogs = query
return _q
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
@@ -441,10 +478,11 @@ func (_q *UserSubscriptionQuery) sqlAll(ctx context.Context, hooks ...queryHook)
var (
nodes = []*UserSubscription{}
_spec = _q.querySpec()
loadedTypes = [3]bool{
loadedTypes = [4]bool{
_q.withUser != nil,
_q.withGroup != nil,
_q.withAssignedByUser != nil,
_q.withUsageLogs != nil,
}
)
_spec.ScanValues = func(columns []string) ([]any, error) {
@@ -483,6 +521,13 @@ func (_q *UserSubscriptionQuery) sqlAll(ctx context.Context, hooks ...queryHook)
return nil, err
}
}
if query := _q.withUsageLogs; query != nil {
if err := _q.loadUsageLogs(ctx, query, nodes,
func(n *UserSubscription) { n.Edges.UsageLogs = []*UsageLog{} },
func(n *UserSubscription, e *UsageLog) { n.Edges.UsageLogs = append(n.Edges.UsageLogs, e) }); err != nil {
return nil, err
}
}
return nodes, nil
}
@@ -576,6 +621,39 @@ func (_q *UserSubscriptionQuery) loadAssignedByUser(ctx context.Context, query *
}
return nil
}
func (_q *UserSubscriptionQuery) loadUsageLogs(ctx context.Context, query *UsageLogQuery, nodes []*UserSubscription, init func(*UserSubscription), assign func(*UserSubscription, *UsageLog)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*UserSubscription)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(usagelog.FieldSubscriptionID)
}
query.Where(predicate.UsageLog(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(usersubscription.UsageLogsColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.SubscriptionID
if fk == nil {
return fmt.Errorf(`foreign-key "subscription_id" is nil for node %v`, n.ID)
}
node, ok := nodeids[*fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "subscription_id" returned %v for node %v`, *fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *UserSubscriptionQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()

View File

@@ -13,6 +13,7 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
@@ -36,6 +37,26 @@ func (_u *UserSubscriptionUpdate) SetUpdatedAt(v time.Time) *UserSubscriptionUpd
return _u
}
// SetDeletedAt sets the "deleted_at" field.
func (_u *UserSubscriptionUpdate) SetDeletedAt(v time.Time) *UserSubscriptionUpdate {
_u.mutation.SetDeletedAt(v)
return _u
}
// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableDeletedAt(v *time.Time) *UserSubscriptionUpdate {
if v != nil {
_u.SetDeletedAt(*v)
}
return _u
}
// ClearDeletedAt clears the value of the "deleted_at" field.
func (_u *UserSubscriptionUpdate) ClearDeletedAt() *UserSubscriptionUpdate {
_u.mutation.ClearDeletedAt()
return _u
}
// SetUserID sets the "user_id" field.
func (_u *UserSubscriptionUpdate) SetUserID(v int64) *UserSubscriptionUpdate {
_u.mutation.SetUserID(v)
@@ -312,6 +333,21 @@ func (_u *UserSubscriptionUpdate) SetAssignedByUser(v *User) *UserSubscriptionUp
return _u.SetAssignedByUserID(v.ID)
}
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func (_u *UserSubscriptionUpdate) AddUsageLogIDs(ids ...int64) *UserSubscriptionUpdate {
_u.mutation.AddUsageLogIDs(ids...)
return _u
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func (_u *UserSubscriptionUpdate) AddUsageLogs(v ...*UsageLog) *UserSubscriptionUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddUsageLogIDs(ids...)
}
// Mutation returns the UserSubscriptionMutation object of the builder.
func (_u *UserSubscriptionUpdate) Mutation() *UserSubscriptionMutation {
return _u.mutation
@@ -335,9 +371,32 @@ func (_u *UserSubscriptionUpdate) ClearAssignedByUser() *UserSubscriptionUpdate
return _u
}
// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity.
func (_u *UserSubscriptionUpdate) ClearUsageLogs() *UserSubscriptionUpdate {
_u.mutation.ClearUsageLogs()
return _u
}
// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs.
func (_u *UserSubscriptionUpdate) RemoveUsageLogIDs(ids ...int64) *UserSubscriptionUpdate {
_u.mutation.RemoveUsageLogIDs(ids...)
return _u
}
// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities.
func (_u *UserSubscriptionUpdate) RemoveUsageLogs(v ...*UsageLog) *UserSubscriptionUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveUsageLogIDs(ids...)
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UserSubscriptionUpdate) Save(ctx context.Context) (int, error) {
_u.defaults()
if err := _u.defaults(); err != nil {
return 0, err
}
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
@@ -364,11 +423,15 @@ func (_u *UserSubscriptionUpdate) ExecX(ctx context.Context) {
}
// defaults sets the default values of the builder before save.
func (_u *UserSubscriptionUpdate) defaults() {
func (_u *UserSubscriptionUpdate) defaults() error {
if _, ok := _u.mutation.UpdatedAt(); !ok {
if usersubscription.UpdateDefaultUpdatedAt == nil {
return fmt.Errorf("ent: uninitialized usersubscription.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)")
}
v := usersubscription.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
return nil
}
// check runs all checks and user-defined validators on the builder.
@@ -402,6 +465,12 @@ func (_u *UserSubscriptionUpdate) sqlSave(ctx context.Context) (_node int, err e
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(usersubscription.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.DeletedAt(); ok {
_spec.SetField(usersubscription.FieldDeletedAt, field.TypeTime, value)
}
if _u.mutation.DeletedAtCleared() {
_spec.ClearField(usersubscription.FieldDeletedAt, field.TypeTime)
}
if value, ok := _u.mutation.StartsAt(); ok {
_spec.SetField(usersubscription.FieldStartsAt, field.TypeTime, value)
}
@@ -543,6 +612,51 @@ func (_u *UserSubscriptionUpdate) sqlSave(ctx context.Context) (_node int, err e
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: usersubscription.UsageLogsTable,
Columns: []string{usersubscription.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: usersubscription.UsageLogsTable,
Columns: []string{usersubscription.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: usersubscription.UsageLogsTable,
Columns: []string{usersubscription.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{usersubscription.Label}
@@ -569,6 +683,26 @@ func (_u *UserSubscriptionUpdateOne) SetUpdatedAt(v time.Time) *UserSubscription
return _u
}
// SetDeletedAt sets the "deleted_at" field.
func (_u *UserSubscriptionUpdateOne) SetDeletedAt(v time.Time) *UserSubscriptionUpdateOne {
_u.mutation.SetDeletedAt(v)
return _u
}
// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableDeletedAt(v *time.Time) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetDeletedAt(*v)
}
return _u
}
// ClearDeletedAt clears the value of the "deleted_at" field.
func (_u *UserSubscriptionUpdateOne) ClearDeletedAt() *UserSubscriptionUpdateOne {
_u.mutation.ClearDeletedAt()
return _u
}
// SetUserID sets the "user_id" field.
func (_u *UserSubscriptionUpdateOne) SetUserID(v int64) *UserSubscriptionUpdateOne {
_u.mutation.SetUserID(v)
@@ -845,6 +979,21 @@ func (_u *UserSubscriptionUpdateOne) SetAssignedByUser(v *User) *UserSubscriptio
return _u.SetAssignedByUserID(v.ID)
}
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func (_u *UserSubscriptionUpdateOne) AddUsageLogIDs(ids ...int64) *UserSubscriptionUpdateOne {
_u.mutation.AddUsageLogIDs(ids...)
return _u
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func (_u *UserSubscriptionUpdateOne) AddUsageLogs(v ...*UsageLog) *UserSubscriptionUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddUsageLogIDs(ids...)
}
// Mutation returns the UserSubscriptionMutation object of the builder.
func (_u *UserSubscriptionUpdateOne) Mutation() *UserSubscriptionMutation {
return _u.mutation
@@ -868,6 +1017,27 @@ func (_u *UserSubscriptionUpdateOne) ClearAssignedByUser() *UserSubscriptionUpda
return _u
}
// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity.
func (_u *UserSubscriptionUpdateOne) ClearUsageLogs() *UserSubscriptionUpdateOne {
_u.mutation.ClearUsageLogs()
return _u
}
// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs.
func (_u *UserSubscriptionUpdateOne) RemoveUsageLogIDs(ids ...int64) *UserSubscriptionUpdateOne {
_u.mutation.RemoveUsageLogIDs(ids...)
return _u
}
// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities.
func (_u *UserSubscriptionUpdateOne) RemoveUsageLogs(v ...*UsageLog) *UserSubscriptionUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveUsageLogIDs(ids...)
}
// Where appends a list predicates to the UserSubscriptionUpdate builder.
func (_u *UserSubscriptionUpdateOne) Where(ps ...predicate.UserSubscription) *UserSubscriptionUpdateOne {
_u.mutation.Where(ps...)
@@ -883,7 +1053,9 @@ func (_u *UserSubscriptionUpdateOne) Select(field string, fields ...string) *Use
// Save executes the query and returns the updated UserSubscription entity.
func (_u *UserSubscriptionUpdateOne) Save(ctx context.Context) (*UserSubscription, error) {
_u.defaults()
if err := _u.defaults(); err != nil {
return nil, err
}
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
@@ -910,11 +1082,15 @@ func (_u *UserSubscriptionUpdateOne) ExecX(ctx context.Context) {
}
// defaults sets the default values of the builder before save.
func (_u *UserSubscriptionUpdateOne) defaults() {
func (_u *UserSubscriptionUpdateOne) defaults() error {
if _, ok := _u.mutation.UpdatedAt(); !ok {
if usersubscription.UpdateDefaultUpdatedAt == nil {
return fmt.Errorf("ent: uninitialized usersubscription.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)")
}
v := usersubscription.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
return nil
}
// check runs all checks and user-defined validators on the builder.
@@ -965,6 +1141,12 @@ func (_u *UserSubscriptionUpdateOne) sqlSave(ctx context.Context) (_node *UserSu
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(usersubscription.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.DeletedAt(); ok {
_spec.SetField(usersubscription.FieldDeletedAt, field.TypeTime, value)
}
if _u.mutation.DeletedAtCleared() {
_spec.ClearField(usersubscription.FieldDeletedAt, field.TypeTime)
}
if value, ok := _u.mutation.StartsAt(); ok {
_spec.SetField(usersubscription.FieldStartsAt, field.TypeTime, value)
}
@@ -1106,6 +1288,51 @@ func (_u *UserSubscriptionUpdateOne) sqlSave(ctx context.Context) (_node *UserSu
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: usersubscription.UsageLogsTable,
Columns: []string{usersubscription.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: usersubscription.UsageLogsTable,
Columns: []string{usersubscription.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: usersubscription.UsageLogsTable,
Columns: []string{usersubscription.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &UserSubscription{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues

View File

@@ -3,6 +3,7 @@ package config
import (
"fmt"
"strings"
"time"
"github.com/spf13/viper"
)
@@ -12,6 +13,20 @@ const (
RunModeSimple = "simple"
)
// 连接池隔离策略常量
// 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗
const (
// ConnectionPoolIsolationProxy: 按代理隔离
// 同一代理地址共享连接池,适合代理数量少、账户数量多的场景
ConnectionPoolIsolationProxy = "proxy"
// ConnectionPoolIsolationAccount: 按账户隔离
// 每个账户独立连接池,适合账户数量少、需要严格隔离的场景
ConnectionPoolIsolationAccount = "account"
// ConnectionPoolIsolationAccountProxy: 按账户+代理组合隔离(默认)
// 同一账户+代理组合共享连接池,提供最细粒度的隔离
ConnectionPoolIsolationAccountProxy = "account_proxy"
)
type Config struct {
Server ServerConfig `mapstructure:"server"`
Database DatabaseConfig `mapstructure:"database"`
@@ -29,6 +44,7 @@ type Config struct {
type GeminiConfig struct {
OAuth GeminiOAuthConfig `mapstructure:"oauth"`
Quota GeminiQuotaConfig `mapstructure:"quota"`
}
type GeminiOAuthConfig struct {
@@ -37,6 +53,17 @@ type GeminiOAuthConfig struct {
Scopes string `mapstructure:"scopes"`
}
type GeminiQuotaConfig struct {
Tiers map[string]GeminiTierQuotaConfig `mapstructure:"tiers"`
Policy string `mapstructure:"policy"`
}
type GeminiTierQuotaConfig struct {
ProRPD *int64 `mapstructure:"pro_rpd" json:"pro_rpd"`
FlashRPD *int64 `mapstructure:"flash_rpd" json:"flash_rpd"`
CooldownMinutes *int `mapstructure:"cooldown_minutes" json:"cooldown_minutes"`
}
// TokenRefreshConfig OAuth token自动刷新配置
type TokenRefreshConfig struct {
// 是否启用自动刷新
@@ -79,12 +106,71 @@ type GatewayConfig struct {
// 等待上游响应头的超时时间0表示无超时
// 注意:这不影响流式数据传输,只控制等待响应头的时间
ResponseHeaderTimeout int `mapstructure:"response_header_timeout"`
// 请求体最大字节数,用于网关请求体大小限制
MaxBodySize int64 `mapstructure:"max_body_size"`
// ConnectionPoolIsolation: 上游连接池隔离策略proxy/account/account_proxy
ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"`
// HTTP 上游连接池配置(性能优化:支持高并发场景调优)
// MaxIdleConns: 所有主机的最大空闲连接总数
MaxIdleConns int `mapstructure:"max_idle_conns"`
// MaxIdleConnsPerHost: 每个主机的最大空闲连接数(关键参数,影响连接复用率)
MaxIdleConnsPerHost int `mapstructure:"max_idle_conns_per_host"`
// MaxConnsPerHost: 每个主机的最大连接数(包括活跃+空闲0表示无限制
MaxConnsPerHost int `mapstructure:"max_conns_per_host"`
// IdleConnTimeoutSeconds: 空闲连接超时时间(秒)
IdleConnTimeoutSeconds int `mapstructure:"idle_conn_timeout_seconds"`
// MaxUpstreamClients: 上游连接池客户端最大缓存数量
// 当使用连接池隔离策略时,系统会为不同的账户/代理组合创建独立的 HTTP 客户端
// 此参数限制缓存的客户端数量,超出后会淘汰最久未使用的客户端
// 建议值:预估的活跃账户数 * 1.2(留有余量)
MaxUpstreamClients int `mapstructure:"max_upstream_clients"`
// ClientIdleTTLSeconds: 上游连接池客户端空闲回收阈值(秒)
// 超过此时间未使用的客户端会被标记为可回收
// 建议值:根据用户访问频率设置,一般 10-30 分钟
ClientIdleTTLSeconds int `mapstructure:"client_idle_ttl_seconds"`
// ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟)
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期
ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"`
// 是否记录上游错误响应体摘要(避免输出请求内容)
LogUpstreamErrorBody bool `mapstructure:"log_upstream_error_body"`
// 上游错误响应体记录最大字节数(超过会截断)
LogUpstreamErrorBodyMaxBytes int `mapstructure:"log_upstream_error_body_max_bytes"`
// API-key 账号在客户端未提供 anthropic-beta 时,是否按需自动补齐(默认关闭以保持兼容)
InjectBetaForApiKey bool `mapstructure:"inject_beta_for_apikey"`
// 是否允许对部分 400 错误触发 failover默认关闭以避免改变语义
FailoverOn400 bool `mapstructure:"failover_on_400"`
// Scheduling: 账号调度相关配置
Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"`
}
// GatewaySchedulingConfig accounts scheduling configuration.
type GatewaySchedulingConfig struct {
// 粘性会话排队配置
StickySessionMaxWaiting int `mapstructure:"sticky_session_max_waiting"`
StickySessionWaitTimeout time.Duration `mapstructure:"sticky_session_wait_timeout"`
// 兜底排队配置
FallbackWaitTimeout time.Duration `mapstructure:"fallback_wait_timeout"`
FallbackMaxWaiting int `mapstructure:"fallback_max_waiting"`
// 负载计算
LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
// 过期槽位清理周期0 表示禁用)
SlotCleanupInterval time.Duration `mapstructure:"slot_cleanup_interval"`
}
func (s *ServerConfig) Address() string {
return fmt.Sprintf("%s:%d", s.Host, s.Port)
}
// DatabaseConfig 数据库连接配置
// 性能优化:新增连接池参数,避免频繁创建/销毁连接
type DatabaseConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
@@ -92,6 +178,15 @@ type DatabaseConfig struct {
Password string `mapstructure:"password"`
DBName string `mapstructure:"dbname"`
SSLMode string `mapstructure:"sslmode"`
// 连接池配置(性能优化:可配置化连接池参数)
// MaxOpenConns: 最大打开连接数,控制数据库连接上限,防止资源耗尽
MaxOpenConns int `mapstructure:"max_open_conns"`
// MaxIdleConns: 最大空闲连接数,保持热连接减少建连延迟
MaxIdleConns int `mapstructure:"max_idle_conns"`
// ConnMaxLifetimeMinutes: 连接最大存活时间,防止长连接导致的资源泄漏
ConnMaxLifetimeMinutes int `mapstructure:"conn_max_lifetime_minutes"`
// ConnMaxIdleTimeMinutes: 空闲连接最大存活时间,及时释放不活跃连接
ConnMaxIdleTimeMinutes int `mapstructure:"conn_max_idle_time_minutes"`
}
func (d *DatabaseConfig) DSN() string {
@@ -112,11 +207,24 @@ func (d *DatabaseConfig) DSNWithTimezone(tz string) string {
)
}
// RedisConfig Redis 连接配置
// 性能优化:新增连接池和超时参数,提升高并发场景下的吞吐量
type RedisConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Password string `mapstructure:"password"`
DB int `mapstructure:"db"`
// 连接池与超时配置(性能优化:可配置化连接池参数)
// DialTimeoutSeconds: 建立连接超时,防止慢连接阻塞
DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"`
// ReadTimeoutSeconds: 读取超时,避免慢查询阻塞连接池
ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"`
// WriteTimeoutSeconds: 写入超时,避免慢写入阻塞连接池
WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"`
// PoolSize: 连接池大小,控制最大并发连接数
PoolSize int `mapstructure:"pool_size"`
// MinIdleConns: 最小空闲连接数,保持热连接减少冷启动延迟
MinIdleConns int `mapstructure:"min_idle_conns"`
}
func (r *RedisConfig) Address() string {
@@ -203,12 +311,21 @@ func setDefaults() {
viper.SetDefault("database.password", "postgres")
viper.SetDefault("database.dbname", "sub2api")
viper.SetDefault("database.sslmode", "disable")
viper.SetDefault("database.max_open_conns", 50)
viper.SetDefault("database.max_idle_conns", 10)
viper.SetDefault("database.conn_max_lifetime_minutes", 30)
viper.SetDefault("database.conn_max_idle_time_minutes", 5)
// Redis
viper.SetDefault("redis.host", "localhost")
viper.SetDefault("redis.port", 6379)
viper.SetDefault("redis.password", "")
viper.SetDefault("redis.db", 0)
viper.SetDefault("redis.dial_timeout_seconds", 5)
viper.SetDefault("redis.read_timeout_seconds", 3)
viper.SetDefault("redis.write_timeout_seconds", 3)
viper.SetDefault("redis.pool_size", 128)
viper.SetDefault("redis.min_idle_conns", 10)
// JWT
viper.SetDefault("jwt.secret", "change-me-in-production")
@@ -240,6 +357,26 @@ func setDefaults() {
// Gateway
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头LLM高负载时可能排队较久
viper.SetDefault("gateway.log_upstream_error_body", false)
viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048)
viper.SetDefault("gateway.inject_beta_for_apikey", false)
viper.SetDefault("gateway.failover_on_400", false)
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数HTTP/2 场景默认)
viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接HTTP/2 场景默认)
viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数含活跃HTTP/2 场景默认)
viper.SetDefault("gateway.idle_conn_timeout_seconds", 300) // 空闲连接超时(秒)
viper.SetDefault("gateway.max_upstream_clients", 5000)
viper.SetDefault("gateway.client_idle_ttl_seconds", 900)
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 15) // 并发槽位过期时间(支持超长请求)
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second)
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
// TokenRefresh
viper.SetDefault("token_refresh.enabled", true)
@@ -254,6 +391,7 @@ func setDefaults() {
viper.SetDefault("gemini.oauth.client_id", "")
viper.SetDefault("gemini.oauth.client_secret", "")
viper.SetDefault("gemini.oauth.scopes", "")
viper.SetDefault("gemini.quota.policy", "")
}
func (c *Config) Validate() error {
@@ -263,6 +401,86 @@ func (c *Config) Validate() error {
if c.JWT.Secret == "change-me-in-production" && c.Server.Mode == "release" {
return fmt.Errorf("jwt.secret must be changed in production")
}
if c.Database.MaxOpenConns <= 0 {
return fmt.Errorf("database.max_open_conns must be positive")
}
if c.Database.MaxIdleConns < 0 {
return fmt.Errorf("database.max_idle_conns must be non-negative")
}
if c.Database.MaxIdleConns > c.Database.MaxOpenConns {
return fmt.Errorf("database.max_idle_conns cannot exceed database.max_open_conns")
}
if c.Database.ConnMaxLifetimeMinutes < 0 {
return fmt.Errorf("database.conn_max_lifetime_minutes must be non-negative")
}
if c.Database.ConnMaxIdleTimeMinutes < 0 {
return fmt.Errorf("database.conn_max_idle_time_minutes must be non-negative")
}
if c.Redis.DialTimeoutSeconds <= 0 {
return fmt.Errorf("redis.dial_timeout_seconds must be positive")
}
if c.Redis.ReadTimeoutSeconds <= 0 {
return fmt.Errorf("redis.read_timeout_seconds must be positive")
}
if c.Redis.WriteTimeoutSeconds <= 0 {
return fmt.Errorf("redis.write_timeout_seconds must be positive")
}
if c.Redis.PoolSize <= 0 {
return fmt.Errorf("redis.pool_size must be positive")
}
if c.Redis.MinIdleConns < 0 {
return fmt.Errorf("redis.min_idle_conns must be non-negative")
}
if c.Redis.MinIdleConns > c.Redis.PoolSize {
return fmt.Errorf("redis.min_idle_conns cannot exceed redis.pool_size")
}
if c.Gateway.MaxBodySize <= 0 {
return fmt.Errorf("gateway.max_body_size must be positive")
}
if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" {
switch c.Gateway.ConnectionPoolIsolation {
case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy:
default:
return fmt.Errorf("gateway.connection_pool_isolation must be one of: %s/%s/%s",
ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy)
}
}
if c.Gateway.MaxIdleConns <= 0 {
return fmt.Errorf("gateway.max_idle_conns must be positive")
}
if c.Gateway.MaxIdleConnsPerHost <= 0 {
return fmt.Errorf("gateway.max_idle_conns_per_host must be positive")
}
if c.Gateway.MaxConnsPerHost < 0 {
return fmt.Errorf("gateway.max_conns_per_host must be non-negative")
}
if c.Gateway.IdleConnTimeoutSeconds <= 0 {
return fmt.Errorf("gateway.idle_conn_timeout_seconds must be positive")
}
if c.Gateway.MaxUpstreamClients <= 0 {
return fmt.Errorf("gateway.max_upstream_clients must be positive")
}
if c.Gateway.ClientIdleTTLSeconds <= 0 {
return fmt.Errorf("gateway.client_idle_ttl_seconds must be positive")
}
if c.Gateway.ConcurrencySlotTTLMinutes <= 0 {
return fmt.Errorf("gateway.concurrency_slot_ttl_minutes must be positive")
}
if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 {
return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive")
}
if c.Gateway.Scheduling.StickySessionWaitTimeout <= 0 {
return fmt.Errorf("gateway.scheduling.sticky_session_wait_timeout must be positive")
}
if c.Gateway.Scheduling.FallbackWaitTimeout <= 0 {
return fmt.Errorf("gateway.scheduling.fallback_wait_timeout must be positive")
}
if c.Gateway.Scheduling.FallbackMaxWaiting <= 0 {
return fmt.Errorf("gateway.scheduling.fallback_max_waiting must be positive")
}
if c.Gateway.Scheduling.SlotCleanupInterval < 0 {
return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative")
}
return nil
}

View File

@@ -1,6 +1,11 @@
package config
import "testing"
import (
"testing"
"time"
"github.com/spf13/viper"
)
func TestNormalizeRunMode(t *testing.T) {
tests := []struct {
@@ -21,3 +26,45 @@ func TestNormalizeRunMode(t *testing.T) {
}
}
}
func TestLoadDefaultSchedulingConfig(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 3 {
t.Fatalf("StickySessionMaxWaiting = %d, want 3", cfg.Gateway.Scheduling.StickySessionMaxWaiting)
}
if cfg.Gateway.Scheduling.StickySessionWaitTimeout != 45*time.Second {
t.Fatalf("StickySessionWaitTimeout = %v, want 45s", cfg.Gateway.Scheduling.StickySessionWaitTimeout)
}
if cfg.Gateway.Scheduling.FallbackWaitTimeout != 30*time.Second {
t.Fatalf("FallbackWaitTimeout = %v, want 30s", cfg.Gateway.Scheduling.FallbackWaitTimeout)
}
if cfg.Gateway.Scheduling.FallbackMaxWaiting != 100 {
t.Fatalf("FallbackMaxWaiting = %d, want 100", cfg.Gateway.Scheduling.FallbackMaxWaiting)
}
if !cfg.Gateway.Scheduling.LoadBatchEnabled {
t.Fatalf("LoadBatchEnabled = false, want true")
}
if cfg.Gateway.Scheduling.SlotCleanupInterval != 30*time.Second {
t.Fatalf("SlotCleanupInterval = %v, want 30s", cfg.Gateway.Scheduling.SlotCleanupInterval)
}
}
func TestLoadSchedulingConfigFromEnv(t *testing.T) {
viper.Reset()
t.Setenv("GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING", "5")
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 5 {
t.Fatalf("StickySessionMaxWaiting = %d, want 5", cfg.Gateway.Scheduling.StickySessionMaxWaiting)
}
}

View File

@@ -10,15 +10,17 @@ import (
// SettingHandler 系统设置处理器
type SettingHandler struct {
settingService *service.SettingService
emailService *service.EmailService
settingService *service.SettingService
emailService *service.EmailService
turnstileService *service.TurnstileService
}
// NewSettingHandler 创建系统设置处理器
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService) *SettingHandler {
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService) *SettingHandler {
return &SettingHandler{
settingService: settingService,
emailService: emailService,
settingService: settingService,
emailService: emailService,
turnstileService: turnstileService,
}
}
@@ -108,6 +110,36 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.SmtpPort = 587
}
// Turnstile 参数验证
if req.TurnstileEnabled {
// 检查必填字段
if req.TurnstileSiteKey == "" {
response.BadRequest(c, "Turnstile Site Key is required when enabled")
return
}
if req.TurnstileSecretKey == "" {
response.BadRequest(c, "Turnstile Secret Key is required when enabled")
return
}
// 获取当前设置,检查参数是否有变化
currentSettings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
// 当 site_key 或 secret_key 任一变化时验证(避免配置错误导致无法登录)
siteKeyChanged := currentSettings.TurnstileSiteKey != req.TurnstileSiteKey
secretKeyChanged := currentSettings.TurnstileSecretKey != req.TurnstileSecretKey
if siteKeyChanged || secretKeyChanged {
if err := h.turnstileService.ValidateSecretKey(c.Request.Context(), req.TurnstileSecretKey); err != nil {
response.ErrorFrom(c, err)
return
}
}
}
settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,

View File

@@ -67,6 +67,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
@@ -76,15 +80,19 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 解析请求获取模型名和stream
var req struct {
Model string `json:"model"`
Stream bool `json:"stream"`
}
if err := json.Unmarshal(body, &req); err != nil {
parsedReq, err := service.ParseGatewayRequest(body)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
reqModel := parsedReq.Model
reqStream := parsedReq.Stream
// 验证 model 必填
if reqModel == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
// Track if we've started streaming (for error handling)
streamStarted := false
@@ -106,7 +114,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
// 1. 首先获取用户并发槽位
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, req.Stream, &streamStarted)
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil {
log.Printf("User concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "user", streamStarted)
@@ -124,7 +132,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 计算粘性会话hash
sessionHash := h.gatewayService.GenerateSessionHash(body)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context否则使用分组平台
platform := ""
@@ -133,6 +141,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} else if apiKey.Group != nil {
platform = apiKey.Group.Platform
}
sessionKey := sessionHash
if platform == service.PlatformGemini && sessionHash != "" {
sessionKey = "gemini:" + sessionHash
}
if platform == service.PlatformGemini {
const maxAccountSwitches = 3
@@ -141,7 +153,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
lastFailoverStatus := 0
for {
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
@@ -150,35 +162,77 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
account := selection.Account
// 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if req.Stream {
sendMockWarmupStream(c, req.Model)
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
if reqStream {
sendMockWarmupStream(c, reqModel)
} else {
sendMockWarmupResponse(c, req.Model)
sendMockWarmupResponse(c, reqModel)
}
return
}
// 3. 获取账号并发槽位
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
accountReleaseFunc := selection.ReleaseFunc
var accountWaitRelease func()
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, req.Model, "generateContent", req.Stream, body)
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, body)
} else {
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
}
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
@@ -223,7 +277,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for {
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
@@ -232,23 +286,62 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
account := selection.Account
// 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if req.Stream {
sendMockWarmupStream(c, req.Model)
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
if reqStream {
sendMockWarmupStream(c, reqModel)
} else {
sendMockWarmupResponse(c, req.Model)
sendMockWarmupResponse(c, reqModel)
}
return
}
// 3. 获取账号并发槽位
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
accountReleaseFunc := selection.ReleaseFunc
var accountWaitRelease func()
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 转发请求 - 根据账号平台分流
@@ -256,11 +349,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body)
} else {
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, body)
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq)
}
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
@@ -525,6 +621,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
@@ -534,15 +634,18 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
}
// 解析请求获取模型名
var req struct {
Model string `json:"model"`
}
if err := json.Unmarshal(body, &req); err != nil {
parsedReq, err := service.ParseGatewayRequest(body)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// 验证 model 必填
if parsedReq.Model == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
// 获取订阅信息可能为nil
subscription, _ := middleware2.GetSubscriptionFromContext(c)
@@ -554,17 +657,17 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
}
// 计算粘性会话 hash
sessionHash := h.gatewayService.GenerateSessionHash(body)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model)
if err != nil {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return
}
// 转发请求(不记录使用量)
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, body); err != nil {
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, parsedReq); err != nil {
log.Printf("Forward count_tokens request failed: %v", err)
// 错误响应已在 ForwardCountTokens 中处理
return

View File

@@ -3,6 +3,7 @@ package handler
import (
"context"
"fmt"
"math/rand"
"net/http"
"time"
@@ -11,11 +12,28 @@ import (
"github.com/gin-gonic/gin"
)
// 并发槽位等待相关常量
//
// 性能优化说明:
// 原实现使用固定间隔100ms轮询并发槽位存在以下问题
// 1. 高并发时频繁轮询增加 Redis 压力
// 2. 固定间隔可能导致多个请求同时重试(惊群效应)
//
// 新实现使用指数退避 + 抖动算法:
// 1. 初始退避 100ms每次乘以 1.5,最大 2s
// 2. 添加 ±20% 的随机抖动,分散重试时间点
// 3. 减少 Redis 压力,避免惊群效应
const (
// maxConcurrencyWait is the maximum time to wait for a concurrency slot
// maxConcurrencyWait 等待并发槽位的最大时间
maxConcurrencyWait = 30 * time.Second
// pingInterval is the interval for sending ping events during slot wait
// pingInterval 流式响应等待时发送 ping 的间隔
pingInterval = 15 * time.Second
// initialBackoff 初始退避时间
initialBackoff = 100 * time.Millisecond
// backoffMultiplier 退避时间乘数(指数退避)
backoffMultiplier = 1.5
// maxBackoff 最大退避时间
maxBackoff = 2 * time.Second
)
// SSEPingFormat defines the format of SSE ping events for different platforms
@@ -65,6 +83,16 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64
h.concurrencyService.DecrementWaitCount(ctx, userID)
}
// IncrementAccountWaitCount increments the wait count for an account
func (h *ConcurrencyHelper) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
return h.concurrencyService.IncrementAccountWaitCount(ctx, accountID, maxWait)
}
// DecrementAccountWaitCount decrements the wait count for an account
func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
h.concurrencyService.DecrementAccountWaitCount(ctx, accountID)
}
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun.
@@ -108,7 +136,12 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait)
return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted)
}
// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout.
func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
// Determine if ping is needed (streaming + ping format defined)
@@ -131,8 +164,10 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
pingCh = pingTicker.C
}
pollTicker := time.NewTicker(100 * time.Millisecond)
defer pollTicker.Stop()
backoff := initialBackoff
timer := time.NewTimer(backoff)
defer timer.Stop()
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
for {
select {
@@ -156,7 +191,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
}
flusher.Flush()
case <-pollTicker.C:
case <-timer.C:
// Try to acquire slot
var result *service.AcquireResult
var err error
@@ -174,6 +209,40 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
if result.Acquired {
return result.ReleaseFunc, nil
}
backoff = nextBackoff(backoff, rng)
timer.Reset(backoff)
}
}
}
// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping).
func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted)
}
// nextBackoff 计算下一次退避时间
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应
// current: 当前退避时间
// rng: 随机数生成器(可为 nil此时不添加抖动
// 返回值下一次退避时间100ms ~ 2s 之间)
func nextBackoff(current time.Duration, rng *rand.Rand) time.Duration {
// 指数退避:当前时间 * 1.5
next := time.Duration(float64(current) * backoffMultiplier)
if next > maxBackoff {
next = maxBackoff
}
if rng == nil {
return next
}
// 添加 ±20% 的随机抖动jitter 范围 0.8 ~ 1.2
// 抖动可以分散多个请求的重试时间点,避免同时冲击 Redis
jitter := 0.8 + rng.Float64()*0.4
jittered := time.Duration(float64(next) * jitter)
if jittered < initialBackoff {
return initialBackoff
}
if jittered > maxBackoff {
return maxBackoff
}
return jittered
}

View File

@@ -148,6 +148,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
body, err := io.ReadAll(c.Request.Body)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
googleError(c, http.StatusRequestEntityTooLarge, buildBodyTooLargeMessage(maxErr.Limit))
return
}
googleError(c, http.StatusBadRequest, "Failed to read request body")
return
}
@@ -191,14 +195,19 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
// 3) select account (sticky session based on request body)
sessionHash := h.gatewayService.GenerateSessionHash(body)
parsedReq, _ := service.ParseGatewayRequest(body)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
sessionKey := sessionHash
if sessionHash != "" {
sessionKey = "gemini:" + sessionHash
}
const maxAccountSwitches = 3
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
for {
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, modelName, failedAccountIDs)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs)
if err != nil {
if len(failedAccountIDs) == 0 {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
@@ -207,12 +216,48 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
handleGeminiFailoverExhausted(c, lastFailoverStatus)
return
}
account := selection.Account
// 4) account concurrency slot
accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted)
if err != nil {
googleError(c, http.StatusTooManyRequests, err.Error())
return
accountReleaseFunc := selection.ReleaseFunc
var accountWaitRelease func()
if !selection.Acquired {
if selection.WaitPlan == nil {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts")
return
}
canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = geminiConcurrency.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
stream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 5) forward (根据平台分流)
@@ -225,6 +270,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {

View File

@@ -56,6 +56,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Read request body
body, err := io.ReadAll(c.Request.Body)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
@@ -76,6 +80,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
reqModel, _ := reqBody["model"].(string)
reqStream, _ := reqBody["stream"].(bool)
// 验证 model 必填
if reqModel == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
// For non-Codex CLI requests, set default instructions
userAgent := c.GetHeader("User-Agent")
if !openai.IsCodexCLIRequest(userAgent) {
@@ -136,7 +146,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
for {
// Select account supporting the requested model
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
if err != nil {
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
if len(failedAccountIDs) == 0 {
@@ -146,14 +156,50 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
account := selection.Account
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
// 3. Acquire account concurrency slot
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
accountReleaseFunc := selection.ReleaseFunc
var accountWaitRelease func()
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionHash, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// Forward request
@@ -161,6 +207,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {

View File

@@ -0,0 +1,27 @@
package handler
import (
"errors"
"fmt"
"net/http"
)
func extractMaxBytesError(err error) (*http.MaxBytesError, bool) {
var maxErr *http.MaxBytesError
if errors.As(err, &maxErr) {
return maxErr, true
}
return nil, false
}
func formatBodyLimit(limit int64) string {
const mb = 1024 * 1024
if limit >= mb {
return fmt.Sprintf("%dMB", limit/mb)
}
return fmt.Sprintf("%dB", limit)
}
func buildBodyTooLargeMessage(limit int64) string {
return fmt.Sprintf("Request body too large, limit is %s", formatBodyLimit(limit))
}

View File

@@ -0,0 +1,45 @@
package handler
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestRequestBodyLimitTooLarge(t *testing.T) {
gin.SetMode(gin.TestMode)
limit := int64(16)
router := gin.New()
router.Use(middleware.RequestBodyLimit(limit))
router.POST("/test", func(c *gin.Context) {
_, err := io.ReadAll(c.Request.Body)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
c.JSON(http.StatusRequestEntityTooLarge, gin.H{
"error": buildBodyTooLargeMessage(maxErr.Limit),
})
return
}
c.JSON(http.StatusBadRequest, gin.H{
"error": "read_failed",
})
return
}
c.JSON(http.StatusOK, gin.H{"ok": true})
})
payload := bytes.Repeat([]byte("a"), int(limit+1))
req := httptest.NewRequest(http.MethodPost, "/test", bytes.NewReader(payload))
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusRequestEntityTooLarge, recorder.Code)
require.Contains(t, recorder.Body.String(), buildBodyTooLargeMessage(limit))
}

View File

@@ -1,16 +0,0 @@
package infrastructure
import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/redis/go-redis/v9"
)
// InitRedis 初始化 Redis 客户端
func InitRedis(cfg *config.Config) *redis.Client {
return redis.NewClient(&redis.Options{
Addr: cfg.Redis.Address(),
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
})
}

View File

@@ -1,79 +0,0 @@
package infrastructure
import (
"database/sql"
"errors"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/wire"
"github.com/redis/go-redis/v9"
entsql "entgo.io/ent/dialect/sql"
)
// ProviderSet 是基础设施层的 Wire 依赖提供者集合。
//
// Wire 是 Google 开发的编译时依赖注入工具。ProviderSet 将相关的依赖提供函数
// 组织在一起,便于在应用程序启动时自动组装依赖关系。
//
// 包含的提供者:
// - ProvideEnt: 提供 Ent ORM 客户端
// - ProvideSQLDB: 提供底层 SQL 数据库连接
// - ProvideRedis: 提供 Redis 客户端
var ProviderSet = wire.NewSet(
ProvideEnt,
ProvideSQLDB,
ProvideRedis,
)
// ProvideEnt 为依赖注入提供 Ent 客户端。
//
// 该函数是 InitEnt 的包装器,符合 Wire 的依赖提供函数签名要求。
// Wire 会在编译时分析依赖关系,自动生成初始化代码。
//
// 依赖config.Config
// 提供:*ent.Client
func ProvideEnt(cfg *config.Config) (*ent.Client, error) {
client, _, err := InitEnt(cfg)
return client, err
}
// ProvideSQLDB 从 Ent 客户端提取底层的 *sql.DB 连接。
//
// 某些 Repository 需要直接执行原生 SQL如复杂的批量更新、聚合查询
// 此时需要访问底层的 sql.DB 而不是通过 Ent ORM。
//
// 设计说明:
// - Ent 底层使用 sql.DB通过 Driver 接口可以访问
// - 这种设计允许在同一事务中混用 Ent 和原生 SQL
//
// 依赖:*ent.Client
// 提供:*sql.DB
func ProvideSQLDB(client *ent.Client) (*sql.DB, error) {
if client == nil {
return nil, errors.New("nil ent client")
}
// 从 Ent 客户端获取底层驱动
drv, ok := client.Driver().(*entsql.Driver)
if !ok {
return nil, errors.New("ent driver does not expose *sql.DB")
}
// 返回驱动持有的 sql.DB 实例
return drv.DB(), nil
}
// ProvideRedis 为依赖注入提供 Redis 客户端。
//
// Redis 用于:
// - 分布式锁(如并发控制)
// - 缓存如用户会话、API 响应缓存)
// - 速率限制
// - 实时统计数据
//
// 依赖config.Config
// 提供:*redis.Client
func ProvideRedis(cfg *config.Config) *redis.Client {
return InitRedis(cfg)
}

View File

@@ -57,6 +57,7 @@ var geminiModels = []string{
"gemini-2.5-flash-lite",
"gemini-3-flash",
"gemini-3-pro-low",
"gemini-3-pro-high",
}
func TestMain(m *testing.M) {
@@ -641,6 +642,37 @@ func testClaudeThinkingWithToolHistory(t *testing.T, model string) {
t.Logf("✅ thinking 模式工具调用测试通过, id=%v", result["id"])
}
// TestClaudeMessagesWithGeminiModel 测试在 Claude 端点使用 Gemini 模型
// 验证:通过 /v1/messages 端点传入 gemini 模型名的场景(含前缀映射)
// 仅在 Antigravity 模式下运行ENDPOINT_PREFIX="/antigravity"
func TestClaudeMessagesWithGeminiModel(t *testing.T) {
if endpointPrefix != "/antigravity" {
t.Skip("仅在 Antigravity 模式下运行")
}
// 测试通过 Claude 端点调用 Gemini 模型
geminiViaClaude := []string{
"gemini-3-flash", // 直接支持
"gemini-3-pro-low", // 直接支持
"gemini-3-pro-high", // 直接支持
"gemini-3-pro", // 前缀映射 -> gemini-3-pro-high
"gemini-3-pro-preview", // 前缀映射 -> gemini-3-pro-high
}
for i, model := range geminiViaClaude {
if i > 0 {
time.Sleep(testInterval)
}
t.Run(model+"_通过Claude端点", func(t *testing.T) {
testClaudeMessage(t, model, false)
})
time.Sleep(testInterval)
t.Run(model+"_通过Claude端点_流式", func(t *testing.T) {
testClaudeMessage(t, model, true)
})
}
}
// TestClaudeMessagesWithNoSignature 测试历史 thinking block 不带 signature 的场景
// 验证Gemini 模型接受没有 signature 的 thinking block
func TestClaudeMessagesWithNoSignature(t *testing.T) {
@@ -738,3 +770,30 @@ func testClaudeWithNoSignature(t *testing.T, model string) {
}
t.Logf("✅ 无 signature thinking 处理测试通过, id=%v", result["id"])
}
// TestGeminiEndpointWithClaudeModel 测试通过 Gemini 端点调用 Claude 模型
// 仅在 Antigravity 模式下运行ENDPOINT_PREFIX="/antigravity"
func TestGeminiEndpointWithClaudeModel(t *testing.T) {
if endpointPrefix != "/antigravity" {
t.Skip("仅在 Antigravity 模式下运行")
}
// 测试通过 Gemini 端点调用 Claude 模型
claudeViaGemini := []string{
"claude-sonnet-4-5",
"claude-opus-4-5-thinking",
}
for i, model := range claudeViaGemini {
if i > 0 {
time.Sleep(testInterval)
}
t.Run(model+"_通过Gemini端点", func(t *testing.T) {
testGeminiGenerate(t, model, false)
})
time.Sleep(testInterval)
t.Run(model+"_通过Gemini端点_流式", func(t *testing.T) {
testGeminiGenerate(t, model, true)
})
}
}

View File

@@ -37,12 +37,26 @@ type ClaudeMetadata struct {
}
// ClaudeTool Claude 工具定义
// 支持两种格式:
// 1. 标准格式: { "name": "...", "description": "...", "input_schema": {...} }
// 2. Custom 格式 (MCP): { "type": "custom", "name": "...", "custom": { "description": "...", "input_schema": {...} } }
type ClaudeTool struct {
Name string `json:"name"`
Type string `json:"type,omitempty"` // "custom" 或空(标准格式)
Name string `json:"name"`
Description string `json:"description,omitempty"` // 标准格式使用
InputSchema map[string]any `json:"input_schema,omitempty"` // 标准格式使用
Custom *CustomToolSpec `json:"custom,omitempty"` // custom 格式使用
}
// CustomToolSpec MCP custom 工具规格
type CustomToolSpec struct {
Description string `json:"description,omitempty"`
InputSchema map[string]any `json:"input_schema"`
}
// ClaudeCustomToolSpec 兼容旧命名MCP custom 工具规格)
type ClaudeCustomToolSpec = CustomToolSpec
// SystemBlock system prompt 数组形式的元素
type SystemBlock struct {
Type string `json:"type"`

View File

@@ -3,6 +3,7 @@ package antigravity
import (
"encoding/json"
"fmt"
"log"
"strings"
"github.com/google/uuid"
@@ -13,13 +14,16 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
// 用于存储 tool_use id -> name 映射
toolIDToName := make(map[string]string)
// 检测是否启用 thinking
isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
// 只有 Gemini 模型支持 dummy thought workaround
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
allowDummyThought := strings.HasPrefix(mappedModel, "gemini-")
// 检测是否启用 thinking
requestedThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
// 为避免 Claude 模型的 thought signature/消息块约束导致 400上游要求 thinking 块开头等),
// 非 Gemini 模型默认不启用 thinking除非未来支持完整签名链路
isThinkingEnabled := requestedThinkingEnabled && allowDummyThought
// 1. 构建 contents
contents, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
if err != nil {
@@ -30,7 +34,15 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model)
// 3. 构建 generationConfig
generationConfig := buildGenerationConfig(claudeReq)
reqForGen := claudeReq
if requestedThinkingEnabled && !allowDummyThought {
log.Printf("[Warning] Disabling thinking for non-Gemini model in antigravity transform: model=%s", mappedModel)
// shallow copy to avoid mutating caller's request
clone := *claudeReq
clone.Thinking = nil
reqForGen = &clone
}
generationConfig := buildGenerationConfig(reqForGen)
// 4. 构建 tools
tools := buildTools(claudeReq.Tools)
@@ -147,8 +159,9 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
if !hasThoughtPart && len(parts) > 0 {
// 在开头添加 dummy thinking block
parts = append([]GeminiPart{{
Text: "Thinking...",
Thought: true,
Text: "Thinking...",
Thought: true,
ThoughtSignature: dummyThoughtSignature,
}}, parts...)
}
}
@@ -170,6 +183,34 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
// 参考: https://ai.google.dev/gemini-api/docs/thought-signatures
const dummyThoughtSignature = "skip_thought_signature_validator"
// isValidThoughtSignature 验证 thought signature 是否有效
// Claude API 要求 signature 必须是 base64 编码的字符串,长度至少 32 字节
func isValidThoughtSignature(signature string) bool {
// 空字符串无效
if signature == "" {
return false
}
// signature 应该是 base64 编码,长度至少 40 个字符(约 30 字节)
// 参考 Claude API 文档和实际观察到的有效 signature
if len(signature) < 40 {
log.Printf("[Debug] Signature too short: len=%d", len(signature))
return false
}
// 检查是否是有效的 base64 字符
// base64 字符集: A-Z, a-z, 0-9, +, /, =
for i, c := range signature {
if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') &&
(c < '0' || c > '9') && c != '+' && c != '/' && c != '=' {
log.Printf("[Debug] Invalid base64 character at position %d: %c (code=%d)", i, c, c)
return false
}
}
return true
}
// buildParts 构建消息的 parts
// allowDummyThought: 只有 Gemini 模型支持 dummy thought signature
func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, error) {
@@ -198,15 +239,30 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
}
case "thinking":
part := GeminiPart{
Text: block.Thinking,
Thought: true,
if allowDummyThought {
// Gemini 模型可以使用 dummy signature
parts = append(parts, GeminiPart{
Text: block.Thinking,
Thought: true,
ThoughtSignature: dummyThoughtSignature,
})
continue
}
// 保留原有 signatureClaude 模型需要有效的 signature
if block.Signature != "" {
part.ThoughtSignature = block.Signature
// Claude 模型:仅在提供有效 signature 时保留 thinking block否则跳过以避免上游校验失败。
signature := strings.TrimSpace(block.Signature)
if signature == "" || signature == dummyThoughtSignature {
log.Printf("[Warning] Skipping thinking block for Claude model (missing or dummy signature)")
continue
}
parts = append(parts, part)
if !isValidThoughtSignature(signature) {
log.Printf("[Debug] Thinking signature may be invalid (passing through anyway): len=%d", len(signature))
}
parts = append(parts, GeminiPart{
Text: block.Thinking,
Thought: true,
ThoughtSignature: signature,
})
case "image":
if block.Source != nil && block.Source.Type == "base64" {
@@ -231,10 +287,9 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
ID: block.ID,
},
}
// 保留原有 signature或对 Gemini 模型使用 dummy signature
if block.Signature != "" {
part.ThoughtSignature = block.Signature
} else if allowDummyThought {
// 只有 Gemini 模型使用 dummy signature
// Claude 模型不设置 signature避免验证问题
if allowDummyThought {
part.ThoughtSignature = dummyThoughtSignature
}
parts = append(parts, part)
@@ -378,13 +433,53 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
// 普通工具
var funcDecls []GeminiFunctionDecl
for _, tool := range tools {
for i, tool := range tools {
// 跳过无效工具名称
if strings.TrimSpace(tool.Name) == "" {
log.Printf("Warning: skipping tool with empty name")
continue
}
var description string
var inputSchema map[string]any
// 检查是否为 custom 类型工具 (MCP)
if tool.Type == "custom" {
if tool.Custom == nil || tool.Custom.InputSchema == nil {
log.Printf("[Warning] Skipping invalid custom tool '%s': missing custom spec or input_schema", tool.Name)
continue
}
description = tool.Custom.Description
inputSchema = tool.Custom.InputSchema
// 调试日志:记录 custom 工具的 schema
if schemaJSON, err := json.Marshal(inputSchema); err == nil {
log.Printf("[Debug] Tool[%d] '%s' (custom) original schema: %s", i, tool.Name, string(schemaJSON))
}
} else {
// 标准格式: 从顶层字段获取
description = tool.Description
inputSchema = tool.InputSchema
}
// 清理 JSON Schema
params := cleanJSONSchema(tool.InputSchema)
params := cleanJSONSchema(inputSchema)
// 为 nil schema 提供默认值
if params == nil {
params = map[string]any{
"type": "OBJECT",
"properties": map[string]any{},
}
}
// 调试日志:记录清理后的 schema
if paramsJSON, err := json.Marshal(params); err == nil {
log.Printf("[Debug] Tool[%d] '%s' cleaned schema: %s", i, tool.Name, string(paramsJSON))
}
funcDecls = append(funcDecls, GeminiFunctionDecl{
Name: tool.Name,
Description: tool.Description,
Description: description,
Parameters: params,
})
}
@@ -443,31 +538,64 @@ func cleanJSONSchema(schema map[string]any) map[string]any {
}
// excludedSchemaKeys 不支持的 schema 字段
// 基于 Claude API (Vertex AI) 的实际支持情况
// 支持: type, description, enum, properties, required, additionalProperties, items
// 不支持: minItems, maxItems, minLength, maxLength, pattern, minimum, maximum 等验证字段
var excludedSchemaKeys = map[string]bool{
"$schema": true,
"$id": true,
"$ref": true,
"additionalProperties": true,
"minLength": true,
"maxLength": true,
"minItems": true,
"maxItems": true,
"uniqueItems": true,
"minimum": true,
"maximum": true,
"exclusiveMinimum": true,
"exclusiveMaximum": true,
"pattern": true,
"format": true,
"default": true,
"strict": true,
"const": true,
"examples": true,
"deprecated": true,
"readOnly": true,
"writeOnly": true,
"contentMediaType": true,
"contentEncoding": true,
// 元 schema 字段
"$schema": true,
"$id": true,
"$ref": true,
// 字符串验证Gemini 不支持)
"minLength": true,
"maxLength": true,
"pattern": true,
// 数字验证Claude API 通过 Vertex AI 不支持这些字段)
"minimum": true,
"maximum": true,
"exclusiveMinimum": true,
"exclusiveMaximum": true,
"multipleOf": true,
// 数组验证Claude API 通过 Vertex AI 不支持这些字段)
"uniqueItems": true,
"minItems": true,
"maxItems": true,
// 组合 schemaGemini 不支持)
"oneOf": true,
"anyOf": true,
"allOf": true,
"not": true,
"if": true,
"then": true,
"else": true,
"$defs": true,
"definitions": true,
// 对象验证(仅保留 properties/required/additionalProperties
"minProperties": true,
"maxProperties": true,
"patternProperties": true,
"propertyNames": true,
"dependencies": true,
"dependentSchemas": true,
"dependentRequired": true,
// 其他不支持的字段
"default": true,
"const": true,
"examples": true,
"deprecated": true,
"readOnly": true,
"writeOnly": true,
"contentMediaType": true,
"contentEncoding": true,
// Claude 特有字段
"strict": true,
}
// cleanSchemaValue 递归清理 schema 值
@@ -487,6 +615,31 @@ func cleanSchemaValue(value any) any {
continue
}
// 特殊处理 format 字段:只保留 Gemini 支持的 format 值
if k == "format" {
if formatStr, ok := val.(string); ok {
// Gemini 只支持 date-time, date, time
if formatStr == "date-time" || formatStr == "date" || formatStr == "time" {
result[k] = val
}
// 其他 format 值直接跳过
}
continue
}
// 特殊处理 additionalPropertiesClaude API 只支持布尔值,不支持 schema 对象
if k == "additionalProperties" {
if boolVal, ok := val.(bool); ok {
result[k] = boolVal
log.Printf("[Debug] additionalProperties is bool: %v", boolVal)
} else {
// 如果是 schema 对象,转换为 false更安全的默认值
result[k] = false
log.Printf("[Debug] additionalProperties is not bool (type: %T), converting to false", val)
}
continue
}
// 递归清理所有值
result[k] = cleanSchemaValue(val)
}

View File

@@ -0,0 +1,179 @@
package antigravity
import (
"encoding/json"
"testing"
)
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
tests := []struct {
name string
content string
allowDummyThought bool
expectedParts int
description string
}{
{
name: "Claude model - skip thinking block without signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`,
allowDummyThought: false,
expectedParts: 2, // 只有两个text block
description: "Claude模型应该跳过无signature的thinking block",
},
{
name: "Claude model - keep thinking block with signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": "valid_sig"},
{"type": "text", "text": "World"}
]`,
allowDummyThought: false,
expectedParts: 3, // 三个block都保留
description: "Claude模型应该保留有signature的thinking block",
},
{
name: "Gemini model - use dummy signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`,
allowDummyThought: true,
expectedParts: 3, // 三个block都保留thinking使用dummy signature
description: "Gemini模型应该为无signature的thinking block使用dummy signature",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
toolIDToName := make(map[string]string)
parts, err := buildParts(json.RawMessage(tt.content), toolIDToName, tt.allowDummyThought)
if err != nil {
t.Fatalf("buildParts() error = %v", err)
}
if len(parts) != tt.expectedParts {
t.Errorf("%s: got %d parts, want %d parts", tt.description, len(parts), tt.expectedParts)
}
})
}
}
// TestBuildTools_CustomTypeTools 测试custom类型工具转换
func TestBuildTools_CustomTypeTools(t *testing.T) {
tests := []struct {
name string
tools []ClaudeTool
expectedLen int
description string
}{
{
name: "Standard tool format",
tools: []ClaudeTool{
{
Name: "get_weather",
Description: "Get weather information",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"location": map[string]any{"type": "string"},
},
},
},
},
expectedLen: 1,
description: "标准工具格式应该正常转换",
},
{
name: "Custom type tool (MCP format)",
tools: []ClaudeTool{
{
Type: "custom",
Name: "mcp_tool",
Custom: &CustomToolSpec{
Description: "MCP tool description",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"param": map[string]any{"type": "string"},
},
},
},
},
},
expectedLen: 1,
description: "Custom类型工具应该从Custom字段读取description和input_schema",
},
{
name: "Mixed standard and custom tools",
tools: []ClaudeTool{
{
Name: "standard_tool",
Description: "Standard tool",
InputSchema: map[string]any{"type": "object"},
},
{
Type: "custom",
Name: "custom_tool",
Custom: &CustomToolSpec{
Description: "Custom tool",
InputSchema: map[string]any{"type": "object"},
},
},
},
expectedLen: 1, // 返回一个GeminiToolDeclaration包含2个function declarations
description: "混合标准和custom工具应该都能正确转换",
},
{
name: "Invalid custom tool - nil Custom field",
tools: []ClaudeTool{
{
Type: "custom",
Name: "invalid_custom",
// Custom 为 nil
},
},
expectedLen: 0, // 应该被跳过
description: "Custom字段为nil的custom工具应该被跳过",
},
{
name: "Invalid custom tool - nil InputSchema",
tools: []ClaudeTool{
{
Type: "custom",
Name: "invalid_custom",
Custom: &CustomToolSpec{
Description: "Invalid",
// InputSchema 为 nil
},
},
},
expectedLen: 0, // 应该被跳过
description: "InputSchema为nil的custom工具应该被跳过",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := buildTools(tt.tools)
if len(result) != tt.expectedLen {
t.Errorf("%s: got %d tool declarations, want %d", tt.description, len(result), tt.expectedLen)
}
// 验证function declarations存在
if len(result) > 0 && result[0].FunctionDeclarations != nil {
if len(result[0].FunctionDeclarations) != len(tt.tools) {
t.Errorf("%s: got %d function declarations, want %d",
tt.description, len(result[0].FunctionDeclarations), len(tt.tools))
}
}
})
}
}

View File

@@ -16,6 +16,12 @@ const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleav
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header不需要 claude-code beta
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
// ApiKeyBetaHeader API-key 账号建议使用的 anthropic-beta header不包含 oauth
const ApiKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// ApiKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header不包含 oauth / claude-code
const ApiKeyHaikuBetaHeader = BetaInterleavedThinking
// Claude Code 客户端默认请求头
var DefaultHeaders = map[string]string{
"User-Agent": "claude-cli/2.0.62 (external, cli)",

View File

@@ -0,0 +1,157 @@
// Package httpclient 提供共享 HTTP 客户端池
//
// 性能优化说明:
// 原实现在多个服务中重复创建 http.Client
// 1. proxy_probe_service.go: 每次探测创建新客户端
// 2. pricing_service.go: 每次请求创建新客户端
// 3. turnstile_service.go: 每次验证创建新客户端
// 4. github_release_service.go: 每次请求创建新客户端
// 5. claude_usage_service.go: 每次请求创建新客户端
//
// 新实现使用统一的客户端池:
// 1. 相同配置复用同一 http.Client 实例
// 2. 复用 Transport 连接池,减少 TCP/TLS 握手开销
// 3. 支持 HTTP/HTTPS/SOCKS5 代理
// 4. 支持严格代理模式(代理失败则返回错误)
package httpclient
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
"golang.org/x/net/proxy"
)
// Transport 连接池默认配置
const (
defaultMaxIdleConns = 100 // 最大空闲连接数
defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数
defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间
)
// Options 定义共享 HTTP 客户端的构建参数
type Options struct {
ProxyURL string // 代理 URL支持 http/https/socks5
Timeout time.Duration // 请求总超时时间
ResponseHeaderTimeout time.Duration // 等待响应头超时时间
InsecureSkipVerify bool // 是否跳过 TLS 证书验证
ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退
// 可选的连接池参数(不设置则使用默认值)
MaxIdleConns int // 最大空闲连接总数(默认 100
MaxIdleConnsPerHost int // 每主机最大空闲连接(默认 10
MaxConnsPerHost int // 每主机最大连接数(默认 0 无限制)
}
// sharedClients 存储按配置参数缓存的 http.Client 实例
var sharedClients sync.Map
// GetClient 返回共享的 HTTP 客户端实例
// 性能优化:相同配置复用同一客户端,避免重复创建 Transport
func GetClient(opts Options) (*http.Client, error) {
key := buildClientKey(opts)
if cached, ok := sharedClients.Load(key); ok {
if client, ok := cached.(*http.Client); ok {
return client, nil
}
}
client, err := buildClient(opts)
if err != nil {
if opts.ProxyStrict {
return nil, err
}
fallback := opts
fallback.ProxyURL = ""
client, _ = buildClient(fallback)
}
actual, _ := sharedClients.LoadOrStore(key, client)
if c, ok := actual.(*http.Client); ok {
return c, nil
}
return client, nil
}
func buildClient(opts Options) (*http.Client, error) {
transport, err := buildTransport(opts)
if err != nil {
return nil, err
}
return &http.Client{
Transport: transport,
Timeout: opts.Timeout,
}, nil
}
func buildTransport(opts Options) (*http.Transport, error) {
// 使用自定义值或默认值
maxIdleConns := opts.MaxIdleConns
if maxIdleConns <= 0 {
maxIdleConns = defaultMaxIdleConns
}
maxIdleConnsPerHost := opts.MaxIdleConnsPerHost
if maxIdleConnsPerHost <= 0 {
maxIdleConnsPerHost = defaultMaxIdleConnsPerHost
}
transport := &http.Transport{
MaxIdleConns: maxIdleConns,
MaxIdleConnsPerHost: maxIdleConnsPerHost,
MaxConnsPerHost: opts.MaxConnsPerHost, // 0 表示无限制
IdleConnTimeout: defaultIdleConnTimeout,
ResponseHeaderTimeout: opts.ResponseHeaderTimeout,
}
if opts.InsecureSkipVerify {
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
}
proxyURL := strings.TrimSpace(opts.ProxyURL)
if proxyURL == "" {
return transport, nil
}
parsed, err := url.Parse(proxyURL)
if err != nil {
return nil, err
}
switch strings.ToLower(parsed.Scheme) {
case "http", "https":
transport.Proxy = http.ProxyURL(parsed)
case "socks5", "socks5h":
dialer, err := proxy.FromURL(parsed, proxy.Direct)
if err != nil {
return nil, err
}
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
}
default:
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsed.Scheme)
}
return transport, nil
}
func buildClientKey(opts Options) string {
return fmt.Sprintf("%s|%s|%s|%t|%t|%d|%d|%d",
strings.TrimSpace(opts.ProxyURL),
opts.Timeout.String(),
opts.ResponseHeaderTimeout.String(),
opts.InsecureSkipVerify,
opts.ProxyStrict,
opts.MaxIdleConns,
opts.MaxIdleConnsPerHost,
opts.MaxConnsPerHost,
)
}

View File

@@ -4,7 +4,7 @@ import (
"math"
"net/http"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/gin-gonic/gin"
)

View File

@@ -9,7 +9,7 @@ import (
"net/http/httptest"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
errors2 "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
@@ -82,7 +82,7 @@ func TestErrorFrom(t *testing.T) {
},
{
name: "application_error",
err: infraerrors.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}),
err: errors2.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}),
wantWritten: true,
wantHTTPCode: http.StatusForbidden,
wantBody: Response{
@@ -94,7 +94,7 @@ func TestErrorFrom(t *testing.T) {
},
{
name: "bad_request_error",
err: infraerrors.BadRequest("INVALID_REQUEST", "invalid request"),
err: errors2.BadRequest("INVALID_REQUEST", "invalid request"),
wantWritten: true,
wantHTTPCode: http.StatusBadRequest,
wantBody: Response{
@@ -105,7 +105,7 @@ func TestErrorFrom(t *testing.T) {
},
{
name: "unauthorized_error",
err: infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized"),
err: errors2.Unauthorized("UNAUTHORIZED", "unauthorized"),
wantWritten: true,
wantHTTPCode: http.StatusUnauthorized,
wantBody: Response{
@@ -116,7 +116,7 @@ func TestErrorFrom(t *testing.T) {
},
{
name: "not_found_error",
err: infraerrors.NotFound("NOT_FOUND", "not found"),
err: errors2.NotFound("NOT_FOUND", "not found"),
wantWritten: true,
wantHTTPCode: http.StatusNotFound,
wantBody: Response{
@@ -127,7 +127,7 @@ func TestErrorFrom(t *testing.T) {
},
{
name: "conflict_error",
err: infraerrors.Conflict("CONFLICT", "conflict"),
err: errors2.Conflict("CONFLICT", "conflict"),
wantWritten: true,
wantHTTPCode: http.StatusConflict,
wantBody: Response{
@@ -143,7 +143,7 @@ func TestErrorFrom(t *testing.T) {
wantHTTPCode: http.StatusInternalServerError,
wantBody: Response{
Code: http.StatusInternalServerError,
Message: infraerrors.UnknownMessage,
Message: errors2.UnknownMessage,
},
},
}

View File

@@ -14,6 +14,7 @@ import (
"context"
"database/sql"
"encoding/json"
"errors"
"strconv"
"time"
@@ -56,7 +57,7 @@ func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accoun
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
if account == nil {
return nil
return service.ErrAccountNilInput
}
builder := r.client.Account.Create().
@@ -98,7 +99,7 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
created, err := builder.Save(ctx)
if err != nil {
return err
return translatePersistenceError(err, service.ErrAccountNotFound, nil)
}
account.ID = created.ID
@@ -231,11 +232,32 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
}
func (r *accountRepository) Delete(ctx context.Context, id int64) error {
if _, err := r.client.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(id)).Exec(ctx); err != nil {
// 使用事务保证账号与关联分组的删除原子性
tx, err := r.client.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return err
}
_, err := r.client.Account.Delete().Where(dbaccount.IDEQ(id)).Exec(ctx)
return err
var txClient *dbent.Client
if err == nil {
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
} else {
// 已处于外部事务中ErrTxStarted复用当前 client
txClient = r.client
}
if _, err := txClient.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(id)).Exec(ctx); err != nil {
return err
}
if _, err := txClient.Account.Delete().Where(dbaccount.IDEQ(id)).Exec(ctx); err != nil {
return err
}
if tx != nil {
return tx.Commit()
}
return nil
}
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
@@ -393,25 +415,49 @@ func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]s
}
func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
if _, err := r.client.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(accountID)).Exec(ctx); err != nil {
// 使用事务保证删除旧绑定与创建新绑定的原子性
tx, err := r.client.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return err
}
var txClient *dbent.Client
if err == nil {
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
} else {
// 已处于外部事务中ErrTxStarted复用当前 client
txClient = r.client
}
if _, err := txClient.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(accountID)).Exec(ctx); err != nil {
return err
}
if len(groupIDs) == 0 {
if tx != nil {
return tx.Commit()
}
return nil
}
builders := make([]*dbent.AccountGroupCreate, 0, len(groupIDs))
for i, groupID := range groupIDs {
builders = append(builders, r.client.AccountGroup.Create().
builders = append(builders, txClient.AccountGroup.Create().
SetAccountID(accountID).
SetGroupID(groupID).
SetPriority(i+1),
)
}
_, err := r.client.AccountGroup.CreateBulk(builders...).Save(ctx)
return err
if _, err := txClient.AccountGroup.CreateBulk(builders...).Save(ctx); err != nil {
return err
}
if tx != nil {
return tx.Commit()
}
return nil
}
func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Account, error) {
@@ -555,24 +601,30 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
return nil
}
accountExtra, err := r.client.Account.Query().
Where(dbaccount.IDEQ(id)).
Select(dbaccount.FieldExtra).
Only(ctx)
// 使用 JSONB 合并操作实现原子更新,避免读-改-写的并发丢失更新问题
payload, err := json.Marshal(updates)
if err != nil {
return translatePersistenceError(err, service.ErrAccountNotFound, nil)
return err
}
extra := normalizeJSONMap(accountExtra.Extra)
for k, v := range updates {
extra[k] = v
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(
ctx,
"UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) || $1::jsonb, updated_at = NOW() WHERE id = $2 AND deleted_at IS NULL",
payload, id,
)
if err != nil {
return err
}
_, err = r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
SetExtra(extra).
Save(ctx)
return err
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return service.ErrAccountNotFound
}
return nil
}
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {

View File

@@ -311,19 +311,20 @@ func groupEntityToService(g *dbent.Group) *service.Group {
return nil
}
return &service.Group{
ID: g.ID,
Name: g.Name,
Description: derefString(g.Description),
Platform: g.Platform,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUsd,
WeeklyLimitUSD: g.WeeklyLimitUsd,
MonthlyLimitUSD: g.MonthlyLimitUsd,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
ID: g.ID,
Name: g.Name,
Description: derefString(g.Description),
Platform: g.Platform,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUsd,
WeeklyLimitUSD: g.WeeklyLimitUsd,
MonthlyLimitUSD: g.MonthlyLimitUsd,
DefaultValidityDays: g.DefaultValidityDays,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}
}

View File

@@ -233,15 +233,11 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
}
func createReqClient(proxyURL string) *req.Client {
client := req.C().
ImpersonateChrome().
SetTimeout(60 * time.Second)
if proxyURL != "" {
client.SetProxyURL(proxyURL)
}
return client
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 60 * time.Second,
Impersonate: true,
})
}
func prefix(s string, n int) string {

View File

@@ -6,9 +6,9 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
@@ -23,20 +23,12 @@ func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
}
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
transport, ok := http.DefaultTransport.(*http.Transport)
if !ok {
return nil, fmt.Errorf("failed to get default transport")
}
transport = transport.Clone()
if proxyURL != "" {
if parsedURL, err := url.Parse(proxyURL); err == nil {
transport.Proxy = http.ProxyURL(parsedURL)
}
}
client := &http.Client{
Transport: transport,
Timeout: 30 * time.Second,
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 30 * time.Second,
})
if err != nil {
client = &http.Client{Timeout: 30 * time.Second}
}
req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil)

View File

@@ -2,68 +2,95 @@ package repository
import (
"context"
"errors"
"fmt"
"time"
"strconv"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// 并发控制缓存常量定义
//
// 性能优化说明:
// 原实现使用 SCAN 命令遍历独立的槽位键concurrency:account:{id}:{requestID}
// 在高并发场景下 SCAN 需要多次往返,且遍历大量键时性能下降明显。
//
// 新实现改用 Redis 有序集合Sorted Set
// 1. 每个账号/用户只有一个键,成员为 requestID分数为时间戳
// 2. 使用 ZCARD 原子获取并发数,时间复杂度 O(1)
// 3. 使用 ZREMRANGEBYSCORE 清理过期槽位,避免手动管理 TTL
// 4. 单次 Redis 调用完成计数,减少网络往返
const (
// Key prefixes for independent slot keys
// Format: concurrency:account:{accountID}:{requestID}
// 并发槽位键前缀(有序集合)
// 格式: concurrency:account:{accountID}
accountSlotKeyPrefix = "concurrency:account:"
// Format: concurrency:user:{userID}:{requestID}
// 格式: concurrency:user:{userID}
userSlotKeyPrefix = "concurrency:user:"
// Wait queue keeps counter format: concurrency:wait:{userID}
// 等待队列计数器格式: concurrency:wait:{userID}
waitQueueKeyPrefix = "concurrency:wait:"
// 账号级等待队列计数器格式: wait:account:{accountID}
accountWaitKeyPrefix = "wait:account:"
// Slot TTL - each slot expires independently
slotTTL = 5 * time.Minute
// 默认槽位过期时间(分钟),可通过配置覆盖
defaultSlotTTLMinutes = 15
)
var (
// acquireScript uses SCAN to count existing slots and creates new slot if under limit
// KEYS[1] = pattern for SCAN (e.g., "concurrency:account:2:*")
// KEYS[2] = full slot key (e.g., "concurrency:account:2:req_xxx")
// acquireScript 使用有序集合计数并在未达上限时添加槽位
// 使用 Redis TIME 命令获取服务器时间,避免多实例时钟不同步问题
// KEYS[1] = 有序集合键 (concurrency:account:{id} / concurrency:user:{id})
// ARGV[1] = maxConcurrency
// ARGV[2] = TTL in seconds
// ARGV[2] = TTL(秒)
// ARGV[3] = requestID
acquireScript = redis.NewScript(`
local pattern = KEYS[1]
local slotKey = KEYS[2]
local key = KEYS[1]
local maxConcurrency = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local requestID = ARGV[3]
-- Count existing slots using SCAN
local cursor = "0"
local count = 0
repeat
local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100)
cursor = result[1]
count = count + #result[2]
until cursor == "0"
-- 使用 Redis 服务器时间,确保多实例时钟一致
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
-- Check if we can acquire a slot
-- 清理过期槽位
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
-- 检查是否已存在(支持重试场景刷新时间戳)
local exists = redis.call('ZSCORE', key, requestID)
if exists ~= false then
redis.call('ZADD', key, now, requestID)
redis.call('EXPIRE', key, ttl)
return 1
end
-- 检查是否达到并发上限
local count = redis.call('ZCARD', key)
if count < maxConcurrency then
redis.call('SET', slotKey, '1', 'EX', ttl)
redis.call('ZADD', key, now, requestID)
redis.call('EXPIRE', key, ttl)
return 1
end
return 0
`)
// getCountScript counts slots using SCAN
// KEYS[1] = pattern for SCAN
// getCountScript 统计有序集合中的槽位数量并清理过期条目
// 使用 Redis TIME 命令获取服务器时间
// KEYS[1] = 有序集合键
// ARGV[1] = TTL
getCountScript = redis.NewScript(`
local pattern = KEYS[1]
local cursor = "0"
local count = 0
repeat
local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100)
cursor = result[1]
count = count + #result[2]
until cursor == "0"
return count
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
-- 使用 Redis 服务器时间
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
return redis.call('ZCARD', key)
`)
// incrementWaitScript - only sets TTL on first creation to avoid refreshing
@@ -89,55 +116,138 @@ var (
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return 1
`)
return 1
`)
// incrementAccountWaitScript - account-level wait queue count
incrementAccountWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
end
if current >= tonumber(ARGV[1]) then
return 0
end
local newVal = redis.call('INCR', KEYS[1])
-- Only set TTL on first creation to avoid refreshing zombie data
if newVal == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return 1
`)
// decrementWaitScript - same as before
decrementWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
// getAccountsLoadBatchScript - batch load query (read-only)
// ARGV[1] = slot TTL (seconds, retained for compatibility)
// ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ...
getAccountsLoadBatchScript = redis.NewScript(`
local result = {}
local i = 2
while i <= #ARGV do
local accountID = ARGV[i]
local maxConcurrency = tonumber(ARGV[i + 1])
local slotKey = 'concurrency:account:' .. accountID
local currentConcurrency = redis.call('ZCARD', slotKey)
local waitKey = 'wait:account:' .. accountID
local waitingCount = redis.call('GET', waitKey)
if waitingCount == false then
waitingCount = 0
else
waitingCount = tonumber(waitingCount)
end
local loadRate = 0
if maxConcurrency > 0 then
loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency)
end
table.insert(result, accountID)
table.insert(result, currentConcurrency)
table.insert(result, waitingCount)
table.insert(result, loadRate)
i = i + 2
end
return result
`)
// cleanupExpiredSlotsScript - remove expired slots
// KEYS[1] = concurrency:account:{accountID}
// ARGV[1] = TTL (seconds)
cleanupExpiredSlotsScript = redis.NewScript(`
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
`)
)
type concurrencyCache struct {
rdb *redis.Client
rdb *redis.Client
slotTTLSeconds int // 槽位过期时间(秒)
waitQueueTTLSeconds int // 等待队列过期时间(秒)
}
func NewConcurrencyCache(rdb *redis.Client) service.ConcurrencyCache {
return &concurrencyCache{rdb: rdb}
// NewConcurrencyCache 创建并发控制缓存
// slotTTLMinutes: 槽位过期时间分钟0 或负数使用默认值 15 分钟
// waitQueueTTLSeconds: 等待队列过期时间0 或负数使用 slot TTL
func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int, waitQueueTTLSeconds int) service.ConcurrencyCache {
if slotTTLMinutes <= 0 {
slotTTLMinutes = defaultSlotTTLMinutes
}
if waitQueueTTLSeconds <= 0 {
waitQueueTTLSeconds = slotTTLMinutes * 60
}
return &concurrencyCache{
rdb: rdb,
slotTTLSeconds: slotTTLMinutes * 60,
waitQueueTTLSeconds: waitQueueTTLSeconds,
}
}
// Helper functions for key generation
func accountSlotKey(accountID int64, requestID string) string {
return fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, requestID)
func accountSlotKey(accountID int64) string {
return fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
}
func accountSlotPattern(accountID int64) string {
return fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID)
}
func userSlotKey(userID int64, requestID string) string {
return fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, requestID)
}
func userSlotPattern(userID int64) string {
return fmt.Sprintf("%s%d:*", userSlotKeyPrefix, userID)
func userSlotKey(userID int64) string {
return fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
}
func waitQueueKey(userID int64) string {
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
}
func accountWaitKey(accountID int64) string {
return fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
}
// Account slot operations
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
pattern := accountSlotPattern(accountID)
slotKey := accountSlotKey(accountID, requestID)
result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int()
key := accountSlotKey(accountID)
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int()
if err != nil {
return false, err
}
@@ -145,13 +255,14 @@ func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int
}
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
slotKey := accountSlotKey(accountID, requestID)
return c.rdb.Del(ctx, slotKey).Err()
key := accountSlotKey(accountID)
return c.rdb.ZRem(ctx, key, requestID).Err()
}
func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
pattern := accountSlotPattern(accountID)
result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int()
key := accountSlotKey(accountID)
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取
result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int()
if err != nil {
return 0, err
}
@@ -161,10 +272,9 @@ func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID
// User slot operations
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
pattern := userSlotPattern(userID)
slotKey := userSlotKey(userID, requestID)
result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int()
key := userSlotKey(userID)
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int()
if err != nil {
return false, err
}
@@ -172,13 +282,14 @@ func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, ma
}
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
slotKey := userSlotKey(userID, requestID)
return c.rdb.Del(ctx, slotKey).Err()
key := userSlotKey(userID)
return c.rdb.ZRem(ctx, key, requestID).Err()
}
func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
pattern := userSlotPattern(userID)
result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int()
key := userSlotKey(userID)
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取
result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int()
if err != nil {
return 0, err
}
@@ -189,7 +300,7 @@ func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64)
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
key := waitQueueKey(userID)
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(slotTTL.Seconds())).Int()
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.slotTTLSeconds).Int()
if err != nil {
return false, err
}
@@ -201,3 +312,75 @@ func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}
// Account wait queue operations
func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
key := accountWaitKey(accountID)
result, err := incrementAccountWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
key := accountWaitKey(accountID)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}
func (c *concurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
key := accountWaitKey(accountID)
val, err := c.rdb.Get(ctx, key).Int()
if err != nil && !errors.Is(err, redis.Nil) {
return 0, err
}
if errors.Is(err, redis.Nil) {
return 0, nil
}
return val, nil
}
func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
if len(accounts) == 0 {
return map[int64]*service.AccountLoadInfo{}, nil
}
args := []any{c.slotTTLSeconds}
for _, acc := range accounts {
args = append(args, acc.ID, acc.MaxConcurrency)
}
result, err := getAccountsLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice()
if err != nil {
return nil, err
}
loadMap := make(map[int64]*service.AccountLoadInfo)
for i := 0; i < len(result); i += 4 {
if i+3 >= len(result) {
break
}
accountID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64)
currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1]))
waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2]))
loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3]))
loadMap[accountID] = &service.AccountLoadInfo{
AccountID: accountID,
CurrentConcurrency: currentConcurrency,
WaitingCount: waitingCount,
LoadRate: loadRate,
}
}
return loadMap, nil
}
func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
key := accountSlotKey(accountID)
_, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result()
return err
}

Some files were not shown because too many files have changed in this diff Show More