mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-05-04 21:20:51 +08:00
Merge pull request #1455 from touwaeriol/feat/channel-management
feat(channel): add channel management with multi-mode pricing and billing integration
This commit is contained in:
@@ -49,6 +49,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
refreshTokenCache := repository.NewRefreshTokenCache(redisClient)
|
||||
settingRepository := repository.NewSettingRepository(client)
|
||||
groupRepository := repository.NewGroupRepository(client, db)
|
||||
channelRepository := repository.NewChannelRepository(db)
|
||||
settingService := service.ProvideSettingService(settingRepository, groupRepository, configConfig)
|
||||
emailCache := repository.NewEmailCache(redisClient)
|
||||
emailService := service.NewEmailService(settingRepository, emailCache)
|
||||
@@ -138,11 +139,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI, tempUnschedCache)
|
||||
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
|
||||
tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client)
|
||||
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
|
||||
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
|
||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
|
||||
@@ -175,9 +176,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI)
|
||||
digestSessionStore := service.NewDigestSessionStore()
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService)
|
||||
channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator)
|
||||
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver)
|
||||
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||
@@ -213,7 +216,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
|
||||
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
|
||||
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler)
|
||||
channelHandler := admin.NewChannelHandler(channelService, billingService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler)
|
||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||
|
||||
@@ -744,6 +744,10 @@ var (
|
||||
{Name: "model", Type: field.TypeString, Size: 100},
|
||||
{Name: "requested_model", Type: field.TypeString, Nullable: true, Size: 100},
|
||||
{Name: "upstream_model", Type: field.TypeString, Nullable: true, Size: 100},
|
||||
{Name: "channel_id", Type: field.TypeInt64, Nullable: true},
|
||||
{Name: "model_mapping_chain", Type: field.TypeString, Nullable: true, Size: 500},
|
||||
{Name: "billing_tier", Type: field.TypeString, Nullable: true, Size: 50},
|
||||
{Name: "billing_mode", Type: field.TypeString, Nullable: true, Size: 20},
|
||||
{Name: "input_tokens", Type: field.TypeInt, Default: 0},
|
||||
{Name: "output_tokens", Type: field.TypeInt, Default: 0},
|
||||
{Name: "cache_creation_tokens", Type: field.TypeInt, Default: 0},
|
||||
@@ -783,31 +787,31 @@ var (
|
||||
ForeignKeys: []*schema.ForeignKey{
|
||||
{
|
||||
Symbol: "usage_logs_api_keys_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[34]},
|
||||
RefColumns: []*schema.Column{APIKeysColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_accounts_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[31]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[35]},
|
||||
RefColumns: []*schema.Column{AccountsColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_groups_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[32]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[36]},
|
||||
RefColumns: []*schema.Column{GroupsColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_users_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[33]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[37]},
|
||||
RefColumns: []*schema.Column{UsersColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_user_subscriptions_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[34]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[38]},
|
||||
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
@@ -816,32 +820,32 @@ var (
|
||||
{
|
||||
Name: "usagelog_user_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[33]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[37]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_api_key_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[34]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_account_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[31]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[35]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_group_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[32]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[36]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_subscription_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[34]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[38]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[33]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_model",
|
||||
@@ -861,17 +865,17 @@ var (
|
||||
{
|
||||
Name: "usagelog_user_id_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[33], UsageLogsColumns[29]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[37], UsageLogsColumns[33]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_api_key_id_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[29]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[34], UsageLogsColumns[33]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_group_id_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[32], UsageLogsColumns[29]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[36], UsageLogsColumns[33]},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -19725,6 +19725,11 @@ type UsageLogMutation struct {
|
||||
model *string
|
||||
requested_model *string
|
||||
upstream_model *string
|
||||
channel_id *int64
|
||||
addchannel_id *int64
|
||||
model_mapping_chain *string
|
||||
billing_tier *string
|
||||
billing_mode *string
|
||||
input_tokens *int
|
||||
addinput_tokens *int
|
||||
output_tokens *int
|
||||
@@ -20160,6 +20165,223 @@ func (m *UsageLogMutation) ResetUpstreamModel() {
|
||||
delete(m.clearedFields, usagelog.FieldUpstreamModel)
|
||||
}
|
||||
|
||||
// SetChannelID sets the "channel_id" field.
|
||||
func (m *UsageLogMutation) SetChannelID(i int64) {
|
||||
m.channel_id = &i
|
||||
m.addchannel_id = nil
|
||||
}
|
||||
|
||||
// ChannelID returns the value of the "channel_id" field in the mutation.
|
||||
func (m *UsageLogMutation) ChannelID() (r int64, exists bool) {
|
||||
v := m.channel_id
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldChannelID returns the old "channel_id" field's value of the UsageLog entity.
|
||||
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *UsageLogMutation) OldChannelID(ctx context.Context) (v *int64, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldChannelID is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldChannelID requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldChannelID: %w", err)
|
||||
}
|
||||
return oldValue.ChannelID, nil
|
||||
}
|
||||
|
||||
// AddChannelID adds i to the "channel_id" field.
|
||||
func (m *UsageLogMutation) AddChannelID(i int64) {
|
||||
if m.addchannel_id != nil {
|
||||
*m.addchannel_id += i
|
||||
} else {
|
||||
m.addchannel_id = &i
|
||||
}
|
||||
}
|
||||
|
||||
// AddedChannelID returns the value that was added to the "channel_id" field in this mutation.
|
||||
func (m *UsageLogMutation) AddedChannelID() (r int64, exists bool) {
|
||||
v := m.addchannel_id
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// ClearChannelID clears the value of the "channel_id" field.
|
||||
func (m *UsageLogMutation) ClearChannelID() {
|
||||
m.channel_id = nil
|
||||
m.addchannel_id = nil
|
||||
m.clearedFields[usagelog.FieldChannelID] = struct{}{}
|
||||
}
|
||||
|
||||
// ChannelIDCleared returns if the "channel_id" field was cleared in this mutation.
|
||||
func (m *UsageLogMutation) ChannelIDCleared() bool {
|
||||
_, ok := m.clearedFields[usagelog.FieldChannelID]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ResetChannelID resets all changes to the "channel_id" field.
|
||||
func (m *UsageLogMutation) ResetChannelID() {
|
||||
m.channel_id = nil
|
||||
m.addchannel_id = nil
|
||||
delete(m.clearedFields, usagelog.FieldChannelID)
|
||||
}
|
||||
|
||||
// SetModelMappingChain sets the "model_mapping_chain" field.
|
||||
func (m *UsageLogMutation) SetModelMappingChain(s string) {
|
||||
m.model_mapping_chain = &s
|
||||
}
|
||||
|
||||
// ModelMappingChain returns the value of the "model_mapping_chain" field in the mutation.
|
||||
func (m *UsageLogMutation) ModelMappingChain() (r string, exists bool) {
|
||||
v := m.model_mapping_chain
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldModelMappingChain returns the old "model_mapping_chain" field's value of the UsageLog entity.
|
||||
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *UsageLogMutation) OldModelMappingChain(ctx context.Context) (v *string, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldModelMappingChain is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldModelMappingChain requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldModelMappingChain: %w", err)
|
||||
}
|
||||
return oldValue.ModelMappingChain, nil
|
||||
}
|
||||
|
||||
// ClearModelMappingChain clears the value of the "model_mapping_chain" field.
|
||||
func (m *UsageLogMutation) ClearModelMappingChain() {
|
||||
m.model_mapping_chain = nil
|
||||
m.clearedFields[usagelog.FieldModelMappingChain] = struct{}{}
|
||||
}
|
||||
|
||||
// ModelMappingChainCleared returns if the "model_mapping_chain" field was cleared in this mutation.
|
||||
func (m *UsageLogMutation) ModelMappingChainCleared() bool {
|
||||
_, ok := m.clearedFields[usagelog.FieldModelMappingChain]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ResetModelMappingChain resets all changes to the "model_mapping_chain" field.
|
||||
func (m *UsageLogMutation) ResetModelMappingChain() {
|
||||
m.model_mapping_chain = nil
|
||||
delete(m.clearedFields, usagelog.FieldModelMappingChain)
|
||||
}
|
||||
|
||||
// SetBillingTier sets the "billing_tier" field.
|
||||
func (m *UsageLogMutation) SetBillingTier(s string) {
|
||||
m.billing_tier = &s
|
||||
}
|
||||
|
||||
// BillingTier returns the value of the "billing_tier" field in the mutation.
|
||||
func (m *UsageLogMutation) BillingTier() (r string, exists bool) {
|
||||
v := m.billing_tier
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldBillingTier returns the old "billing_tier" field's value of the UsageLog entity.
|
||||
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *UsageLogMutation) OldBillingTier(ctx context.Context) (v *string, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldBillingTier is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldBillingTier requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldBillingTier: %w", err)
|
||||
}
|
||||
return oldValue.BillingTier, nil
|
||||
}
|
||||
|
||||
// ClearBillingTier clears the value of the "billing_tier" field.
|
||||
func (m *UsageLogMutation) ClearBillingTier() {
|
||||
m.billing_tier = nil
|
||||
m.clearedFields[usagelog.FieldBillingTier] = struct{}{}
|
||||
}
|
||||
|
||||
// BillingTierCleared returns if the "billing_tier" field was cleared in this mutation.
|
||||
func (m *UsageLogMutation) BillingTierCleared() bool {
|
||||
_, ok := m.clearedFields[usagelog.FieldBillingTier]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ResetBillingTier resets all changes to the "billing_tier" field.
|
||||
func (m *UsageLogMutation) ResetBillingTier() {
|
||||
m.billing_tier = nil
|
||||
delete(m.clearedFields, usagelog.FieldBillingTier)
|
||||
}
|
||||
|
||||
// SetBillingMode sets the "billing_mode" field.
|
||||
func (m *UsageLogMutation) SetBillingMode(s string) {
|
||||
m.billing_mode = &s
|
||||
}
|
||||
|
||||
// BillingMode returns the value of the "billing_mode" field in the mutation.
|
||||
func (m *UsageLogMutation) BillingMode() (r string, exists bool) {
|
||||
v := m.billing_mode
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldBillingMode returns the old "billing_mode" field's value of the UsageLog entity.
|
||||
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *UsageLogMutation) OldBillingMode(ctx context.Context) (v *string, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldBillingMode is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldBillingMode requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldBillingMode: %w", err)
|
||||
}
|
||||
return oldValue.BillingMode, nil
|
||||
}
|
||||
|
||||
// ClearBillingMode clears the value of the "billing_mode" field.
|
||||
func (m *UsageLogMutation) ClearBillingMode() {
|
||||
m.billing_mode = nil
|
||||
m.clearedFields[usagelog.FieldBillingMode] = struct{}{}
|
||||
}
|
||||
|
||||
// BillingModeCleared returns if the "billing_mode" field was cleared in this mutation.
|
||||
func (m *UsageLogMutation) BillingModeCleared() bool {
|
||||
_, ok := m.clearedFields[usagelog.FieldBillingMode]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ResetBillingMode resets all changes to the "billing_mode" field.
|
||||
func (m *UsageLogMutation) ResetBillingMode() {
|
||||
m.billing_mode = nil
|
||||
delete(m.clearedFields, usagelog.FieldBillingMode)
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (m *UsageLogMutation) SetGroupID(i int64) {
|
||||
m.group = &i
|
||||
@@ -21781,7 +22003,7 @@ func (m *UsageLogMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *UsageLogMutation) Fields() []string {
|
||||
fields := make([]string, 0, 34)
|
||||
fields := make([]string, 0, 38)
|
||||
if m.user != nil {
|
||||
fields = append(fields, usagelog.FieldUserID)
|
||||
}
|
||||
@@ -21803,6 +22025,18 @@ func (m *UsageLogMutation) Fields() []string {
|
||||
if m.upstream_model != nil {
|
||||
fields = append(fields, usagelog.FieldUpstreamModel)
|
||||
}
|
||||
if m.channel_id != nil {
|
||||
fields = append(fields, usagelog.FieldChannelID)
|
||||
}
|
||||
if m.model_mapping_chain != nil {
|
||||
fields = append(fields, usagelog.FieldModelMappingChain)
|
||||
}
|
||||
if m.billing_tier != nil {
|
||||
fields = append(fields, usagelog.FieldBillingTier)
|
||||
}
|
||||
if m.billing_mode != nil {
|
||||
fields = append(fields, usagelog.FieldBillingMode)
|
||||
}
|
||||
if m.group != nil {
|
||||
fields = append(fields, usagelog.FieldGroupID)
|
||||
}
|
||||
@@ -21906,6 +22140,14 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.RequestedModel()
|
||||
case usagelog.FieldUpstreamModel:
|
||||
return m.UpstreamModel()
|
||||
case usagelog.FieldChannelID:
|
||||
return m.ChannelID()
|
||||
case usagelog.FieldModelMappingChain:
|
||||
return m.ModelMappingChain()
|
||||
case usagelog.FieldBillingTier:
|
||||
return m.BillingTier()
|
||||
case usagelog.FieldBillingMode:
|
||||
return m.BillingMode()
|
||||
case usagelog.FieldGroupID:
|
||||
return m.GroupID()
|
||||
case usagelog.FieldSubscriptionID:
|
||||
@@ -21983,6 +22225,14 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
|
||||
return m.OldRequestedModel(ctx)
|
||||
case usagelog.FieldUpstreamModel:
|
||||
return m.OldUpstreamModel(ctx)
|
||||
case usagelog.FieldChannelID:
|
||||
return m.OldChannelID(ctx)
|
||||
case usagelog.FieldModelMappingChain:
|
||||
return m.OldModelMappingChain(ctx)
|
||||
case usagelog.FieldBillingTier:
|
||||
return m.OldBillingTier(ctx)
|
||||
case usagelog.FieldBillingMode:
|
||||
return m.OldBillingMode(ctx)
|
||||
case usagelog.FieldGroupID:
|
||||
return m.OldGroupID(ctx)
|
||||
case usagelog.FieldSubscriptionID:
|
||||
@@ -22095,6 +22345,34 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetUpstreamModel(v)
|
||||
return nil
|
||||
case usagelog.FieldChannelID:
|
||||
v, ok := value.(int64)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetChannelID(v)
|
||||
return nil
|
||||
case usagelog.FieldModelMappingChain:
|
||||
v, ok := value.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetModelMappingChain(v)
|
||||
return nil
|
||||
case usagelog.FieldBillingTier:
|
||||
v, ok := value.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetBillingTier(v)
|
||||
return nil
|
||||
case usagelog.FieldBillingMode:
|
||||
v, ok := value.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetBillingMode(v)
|
||||
return nil
|
||||
case usagelog.FieldGroupID:
|
||||
v, ok := value.(int64)
|
||||
if !ok {
|
||||
@@ -22292,6 +22570,9 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
|
||||
// this mutation.
|
||||
func (m *UsageLogMutation) AddedFields() []string {
|
||||
var fields []string
|
||||
if m.addchannel_id != nil {
|
||||
fields = append(fields, usagelog.FieldChannelID)
|
||||
}
|
||||
if m.addinput_tokens != nil {
|
||||
fields = append(fields, usagelog.FieldInputTokens)
|
||||
}
|
||||
@@ -22354,6 +22635,8 @@ func (m *UsageLogMutation) AddedFields() []string {
|
||||
// was not set, or was not defined in the schema.
|
||||
func (m *UsageLogMutation) AddedField(name string) (ent.Value, bool) {
|
||||
switch name {
|
||||
case usagelog.FieldChannelID:
|
||||
return m.AddedChannelID()
|
||||
case usagelog.FieldInputTokens:
|
||||
return m.AddedInputTokens()
|
||||
case usagelog.FieldOutputTokens:
|
||||
@@ -22399,6 +22682,13 @@ func (m *UsageLogMutation) AddedField(name string) (ent.Value, bool) {
|
||||
// type.
|
||||
func (m *UsageLogMutation) AddField(name string, value ent.Value) error {
|
||||
switch name {
|
||||
case usagelog.FieldChannelID:
|
||||
v, ok := value.(int64)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.AddChannelID(v)
|
||||
return nil
|
||||
case usagelog.FieldInputTokens:
|
||||
v, ok := value.(int)
|
||||
if !ok {
|
||||
@@ -22539,6 +22829,18 @@ func (m *UsageLogMutation) ClearedFields() []string {
|
||||
if m.FieldCleared(usagelog.FieldUpstreamModel) {
|
||||
fields = append(fields, usagelog.FieldUpstreamModel)
|
||||
}
|
||||
if m.FieldCleared(usagelog.FieldChannelID) {
|
||||
fields = append(fields, usagelog.FieldChannelID)
|
||||
}
|
||||
if m.FieldCleared(usagelog.FieldModelMappingChain) {
|
||||
fields = append(fields, usagelog.FieldModelMappingChain)
|
||||
}
|
||||
if m.FieldCleared(usagelog.FieldBillingTier) {
|
||||
fields = append(fields, usagelog.FieldBillingTier)
|
||||
}
|
||||
if m.FieldCleared(usagelog.FieldBillingMode) {
|
||||
fields = append(fields, usagelog.FieldBillingMode)
|
||||
}
|
||||
if m.FieldCleared(usagelog.FieldGroupID) {
|
||||
fields = append(fields, usagelog.FieldGroupID)
|
||||
}
|
||||
@@ -22586,6 +22888,18 @@ func (m *UsageLogMutation) ClearField(name string) error {
|
||||
case usagelog.FieldUpstreamModel:
|
||||
m.ClearUpstreamModel()
|
||||
return nil
|
||||
case usagelog.FieldChannelID:
|
||||
m.ClearChannelID()
|
||||
return nil
|
||||
case usagelog.FieldModelMappingChain:
|
||||
m.ClearModelMappingChain()
|
||||
return nil
|
||||
case usagelog.FieldBillingTier:
|
||||
m.ClearBillingTier()
|
||||
return nil
|
||||
case usagelog.FieldBillingMode:
|
||||
m.ClearBillingMode()
|
||||
return nil
|
||||
case usagelog.FieldGroupID:
|
||||
m.ClearGroupID()
|
||||
return nil
|
||||
@@ -22642,6 +22956,18 @@ func (m *UsageLogMutation) ResetField(name string) error {
|
||||
case usagelog.FieldUpstreamModel:
|
||||
m.ResetUpstreamModel()
|
||||
return nil
|
||||
case usagelog.FieldChannelID:
|
||||
m.ResetChannelID()
|
||||
return nil
|
||||
case usagelog.FieldModelMappingChain:
|
||||
m.ResetModelMappingChain()
|
||||
return nil
|
||||
case usagelog.FieldBillingTier:
|
||||
m.ResetBillingTier()
|
||||
return nil
|
||||
case usagelog.FieldBillingMode:
|
||||
m.ResetBillingMode()
|
||||
return nil
|
||||
case usagelog.FieldGroupID:
|
||||
m.ResetGroupID()
|
||||
return nil
|
||||
|
||||
@@ -875,92 +875,104 @@ func init() {
|
||||
usagelogDescUpstreamModel := usagelogFields[6].Descriptor()
|
||||
// usagelog.UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
|
||||
usagelog.UpstreamModelValidator = usagelogDescUpstreamModel.Validators[0].(func(string) error)
|
||||
// usagelogDescModelMappingChain is the schema descriptor for model_mapping_chain field.
|
||||
usagelogDescModelMappingChain := usagelogFields[8].Descriptor()
|
||||
// usagelog.ModelMappingChainValidator is a validator for the "model_mapping_chain" field. It is called by the builders before save.
|
||||
usagelog.ModelMappingChainValidator = usagelogDescModelMappingChain.Validators[0].(func(string) error)
|
||||
// usagelogDescBillingTier is the schema descriptor for billing_tier field.
|
||||
usagelogDescBillingTier := usagelogFields[9].Descriptor()
|
||||
// usagelog.BillingTierValidator is a validator for the "billing_tier" field. It is called by the builders before save.
|
||||
usagelog.BillingTierValidator = usagelogDescBillingTier.Validators[0].(func(string) error)
|
||||
// usagelogDescBillingMode is the schema descriptor for billing_mode field.
|
||||
usagelogDescBillingMode := usagelogFields[10].Descriptor()
|
||||
// usagelog.BillingModeValidator is a validator for the "billing_mode" field. It is called by the builders before save.
|
||||
usagelog.BillingModeValidator = usagelogDescBillingMode.Validators[0].(func(string) error)
|
||||
// usagelogDescInputTokens is the schema descriptor for input_tokens field.
|
||||
usagelogDescInputTokens := usagelogFields[9].Descriptor()
|
||||
usagelogDescInputTokens := usagelogFields[13].Descriptor()
|
||||
// usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field.
|
||||
usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int)
|
||||
// usagelogDescOutputTokens is the schema descriptor for output_tokens field.
|
||||
usagelogDescOutputTokens := usagelogFields[10].Descriptor()
|
||||
usagelogDescOutputTokens := usagelogFields[14].Descriptor()
|
||||
// usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field.
|
||||
usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int)
|
||||
// usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field.
|
||||
usagelogDescCacheCreationTokens := usagelogFields[11].Descriptor()
|
||||
usagelogDescCacheCreationTokens := usagelogFields[15].Descriptor()
|
||||
// usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field.
|
||||
usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int)
|
||||
// usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field.
|
||||
usagelogDescCacheReadTokens := usagelogFields[12].Descriptor()
|
||||
usagelogDescCacheReadTokens := usagelogFields[16].Descriptor()
|
||||
// usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field.
|
||||
usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int)
|
||||
// usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field.
|
||||
usagelogDescCacheCreation5mTokens := usagelogFields[13].Descriptor()
|
||||
usagelogDescCacheCreation5mTokens := usagelogFields[17].Descriptor()
|
||||
// usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field.
|
||||
usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int)
|
||||
// usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field.
|
||||
usagelogDescCacheCreation1hTokens := usagelogFields[14].Descriptor()
|
||||
usagelogDescCacheCreation1hTokens := usagelogFields[18].Descriptor()
|
||||
// usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field.
|
||||
usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int)
|
||||
// usagelogDescInputCost is the schema descriptor for input_cost field.
|
||||
usagelogDescInputCost := usagelogFields[15].Descriptor()
|
||||
usagelogDescInputCost := usagelogFields[19].Descriptor()
|
||||
// usagelog.DefaultInputCost holds the default value on creation for the input_cost field.
|
||||
usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64)
|
||||
// usagelogDescOutputCost is the schema descriptor for output_cost field.
|
||||
usagelogDescOutputCost := usagelogFields[16].Descriptor()
|
||||
usagelogDescOutputCost := usagelogFields[20].Descriptor()
|
||||
// usagelog.DefaultOutputCost holds the default value on creation for the output_cost field.
|
||||
usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64)
|
||||
// usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field.
|
||||
usagelogDescCacheCreationCost := usagelogFields[17].Descriptor()
|
||||
usagelogDescCacheCreationCost := usagelogFields[21].Descriptor()
|
||||
// usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field.
|
||||
usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64)
|
||||
// usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field.
|
||||
usagelogDescCacheReadCost := usagelogFields[18].Descriptor()
|
||||
usagelogDescCacheReadCost := usagelogFields[22].Descriptor()
|
||||
// usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field.
|
||||
usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64)
|
||||
// usagelogDescTotalCost is the schema descriptor for total_cost field.
|
||||
usagelogDescTotalCost := usagelogFields[19].Descriptor()
|
||||
usagelogDescTotalCost := usagelogFields[23].Descriptor()
|
||||
// usagelog.DefaultTotalCost holds the default value on creation for the total_cost field.
|
||||
usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64)
|
||||
// usagelogDescActualCost is the schema descriptor for actual_cost field.
|
||||
usagelogDescActualCost := usagelogFields[20].Descriptor()
|
||||
usagelogDescActualCost := usagelogFields[24].Descriptor()
|
||||
// usagelog.DefaultActualCost holds the default value on creation for the actual_cost field.
|
||||
usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64)
|
||||
// usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field.
|
||||
usagelogDescRateMultiplier := usagelogFields[21].Descriptor()
|
||||
usagelogDescRateMultiplier := usagelogFields[25].Descriptor()
|
||||
// usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
|
||||
usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64)
|
||||
// usagelogDescBillingType is the schema descriptor for billing_type field.
|
||||
usagelogDescBillingType := usagelogFields[23].Descriptor()
|
||||
usagelogDescBillingType := usagelogFields[27].Descriptor()
|
||||
// usagelog.DefaultBillingType holds the default value on creation for the billing_type field.
|
||||
usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8)
|
||||
// usagelogDescStream is the schema descriptor for stream field.
|
||||
usagelogDescStream := usagelogFields[24].Descriptor()
|
||||
usagelogDescStream := usagelogFields[28].Descriptor()
|
||||
// usagelog.DefaultStream holds the default value on creation for the stream field.
|
||||
usagelog.DefaultStream = usagelogDescStream.Default.(bool)
|
||||
// usagelogDescUserAgent is the schema descriptor for user_agent field.
|
||||
usagelogDescUserAgent := usagelogFields[27].Descriptor()
|
||||
usagelogDescUserAgent := usagelogFields[31].Descriptor()
|
||||
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
|
||||
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
|
||||
// usagelogDescIPAddress is the schema descriptor for ip_address field.
|
||||
usagelogDescIPAddress := usagelogFields[28].Descriptor()
|
||||
usagelogDescIPAddress := usagelogFields[32].Descriptor()
|
||||
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
|
||||
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
|
||||
// usagelogDescImageCount is the schema descriptor for image_count field.
|
||||
usagelogDescImageCount := usagelogFields[29].Descriptor()
|
||||
usagelogDescImageCount := usagelogFields[33].Descriptor()
|
||||
// usagelog.DefaultImageCount holds the default value on creation for the image_count field.
|
||||
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
|
||||
// usagelogDescImageSize is the schema descriptor for image_size field.
|
||||
usagelogDescImageSize := usagelogFields[30].Descriptor()
|
||||
usagelogDescImageSize := usagelogFields[34].Descriptor()
|
||||
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
||||
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
|
||||
// usagelogDescMediaType is the schema descriptor for media_type field.
|
||||
usagelogDescMediaType := usagelogFields[31].Descriptor()
|
||||
usagelogDescMediaType := usagelogFields[35].Descriptor()
|
||||
// usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
|
||||
usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
|
||||
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
|
||||
usagelogDescCacheTTLOverridden := usagelogFields[32].Descriptor()
|
||||
usagelogDescCacheTTLOverridden := usagelogFields[36].Descriptor()
|
||||
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
|
||||
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
|
||||
// usagelogDescCreatedAt is the schema descriptor for created_at field.
|
||||
usagelogDescCreatedAt := usagelogFields[33].Descriptor()
|
||||
usagelogDescCreatedAt := usagelogFields[37].Descriptor()
|
||||
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
|
||||
userMixin := schema.User{}.Mixin()
|
||||
|
||||
@@ -53,6 +53,10 @@ func (UsageLog) Fields() []ent.Field {
|
||||
MaxLen(100).
|
||||
Optional().
|
||||
Nillable(),
|
||||
field.Int64("channel_id").Optional().Nillable().Comment("渠道 ID"),
|
||||
field.String("model_mapping_chain").MaxLen(500).Optional().Nillable().Comment("模型映射链"),
|
||||
field.String("billing_tier").MaxLen(50).Optional().Nillable().Comment("计费层级标签"),
|
||||
field.String("billing_mode").MaxLen(20).Optional().Nillable().Comment("计费模式:token/per_request/image"),
|
||||
field.Int64("group_id").
|
||||
Optional().
|
||||
Nillable(),
|
||||
|
||||
@@ -36,6 +36,14 @@ type UsageLog struct {
|
||||
RequestedModel *string `json:"requested_model,omitempty"`
|
||||
// UpstreamModel holds the value of the "upstream_model" field.
|
||||
UpstreamModel *string `json:"upstream_model,omitempty"`
|
||||
// 渠道 ID
|
||||
ChannelID *int64 `json:"channel_id,omitempty"`
|
||||
// 模型映射链
|
||||
ModelMappingChain *string `json:"model_mapping_chain,omitempty"`
|
||||
// 计费层级标签
|
||||
BillingTier *string `json:"billing_tier,omitempty"`
|
||||
// 计费模式:token/per_request/image
|
||||
BillingMode *string `json:"billing_mode,omitempty"`
|
||||
// GroupID holds the value of the "group_id" field.
|
||||
GroupID *int64 `json:"group_id,omitempty"`
|
||||
// SubscriptionID holds the value of the "subscription_id" field.
|
||||
@@ -177,9 +185,9 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
|
||||
values[i] = new(sql.NullBool)
|
||||
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
|
||||
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldChannelID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
|
||||
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldModelMappingChain, usagelog.FieldBillingTier, usagelog.FieldBillingMode, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
|
||||
values[i] = new(sql.NullString)
|
||||
case usagelog.FieldCreatedAt:
|
||||
values[i] = new(sql.NullTime)
|
||||
@@ -248,6 +256,34 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
|
||||
_m.UpstreamModel = new(string)
|
||||
*_m.UpstreamModel = value.String
|
||||
}
|
||||
case usagelog.FieldChannelID:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field channel_id", values[i])
|
||||
} else if value.Valid {
|
||||
_m.ChannelID = new(int64)
|
||||
*_m.ChannelID = value.Int64
|
||||
}
|
||||
case usagelog.FieldModelMappingChain:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field model_mapping_chain", values[i])
|
||||
} else if value.Valid {
|
||||
_m.ModelMappingChain = new(string)
|
||||
*_m.ModelMappingChain = value.String
|
||||
}
|
||||
case usagelog.FieldBillingTier:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field billing_tier", values[i])
|
||||
} else if value.Valid {
|
||||
_m.BillingTier = new(string)
|
||||
*_m.BillingTier = value.String
|
||||
}
|
||||
case usagelog.FieldBillingMode:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field billing_mode", values[i])
|
||||
} else if value.Valid {
|
||||
_m.BillingMode = new(string)
|
||||
*_m.BillingMode = value.String
|
||||
}
|
||||
case usagelog.FieldGroupID:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field group_id", values[i])
|
||||
@@ -505,6 +541,26 @@ func (_m *UsageLog) String() string {
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.ChannelID; v != nil {
|
||||
builder.WriteString("channel_id=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.ModelMappingChain; v != nil {
|
||||
builder.WriteString("model_mapping_chain=")
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.BillingTier; v != nil {
|
||||
builder.WriteString("billing_tier=")
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.BillingMode; v != nil {
|
||||
builder.WriteString("billing_mode=")
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.GroupID; v != nil {
|
||||
builder.WriteString("group_id=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
|
||||
@@ -28,6 +28,14 @@ const (
|
||||
FieldRequestedModel = "requested_model"
|
||||
// FieldUpstreamModel holds the string denoting the upstream_model field in the database.
|
||||
FieldUpstreamModel = "upstream_model"
|
||||
// FieldChannelID holds the string denoting the channel_id field in the database.
|
||||
FieldChannelID = "channel_id"
|
||||
// FieldModelMappingChain holds the string denoting the model_mapping_chain field in the database.
|
||||
FieldModelMappingChain = "model_mapping_chain"
|
||||
// FieldBillingTier holds the string denoting the billing_tier field in the database.
|
||||
FieldBillingTier = "billing_tier"
|
||||
// FieldBillingMode holds the string denoting the billing_mode field in the database.
|
||||
FieldBillingMode = "billing_mode"
|
||||
// FieldGroupID holds the string denoting the group_id field in the database.
|
||||
FieldGroupID = "group_id"
|
||||
// FieldSubscriptionID holds the string denoting the subscription_id field in the database.
|
||||
@@ -141,6 +149,10 @@ var Columns = []string{
|
||||
FieldModel,
|
||||
FieldRequestedModel,
|
||||
FieldUpstreamModel,
|
||||
FieldChannelID,
|
||||
FieldModelMappingChain,
|
||||
FieldBillingTier,
|
||||
FieldBillingMode,
|
||||
FieldGroupID,
|
||||
FieldSubscriptionID,
|
||||
FieldInputTokens,
|
||||
@@ -189,6 +201,12 @@ var (
|
||||
RequestedModelValidator func(string) error
|
||||
// UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
|
||||
UpstreamModelValidator func(string) error
|
||||
// ModelMappingChainValidator is a validator for the "model_mapping_chain" field. It is called by the builders before save.
|
||||
ModelMappingChainValidator func(string) error
|
||||
// BillingTierValidator is a validator for the "billing_tier" field. It is called by the builders before save.
|
||||
BillingTierValidator func(string) error
|
||||
// BillingModeValidator is a validator for the "billing_mode" field. It is called by the builders before save.
|
||||
BillingModeValidator func(string) error
|
||||
// DefaultInputTokens holds the default value on creation for the "input_tokens" field.
|
||||
DefaultInputTokens int
|
||||
// DefaultOutputTokens holds the default value on creation for the "output_tokens" field.
|
||||
@@ -278,6 +296,26 @@ func ByUpstreamModel(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldUpstreamModel, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByChannelID orders the results by the channel_id field.
|
||||
func ByChannelID(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldChannelID, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByModelMappingChain orders the results by the model_mapping_chain field.
|
||||
func ByModelMappingChain(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldModelMappingChain, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByBillingTier orders the results by the billing_tier field.
|
||||
func ByBillingTier(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldBillingTier, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByBillingMode orders the results by the billing_mode field.
|
||||
func ByBillingMode(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldBillingMode, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByGroupID orders the results by the group_id field.
|
||||
func ByGroupID(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldGroupID, opts...).ToFunc()
|
||||
|
||||
@@ -90,6 +90,26 @@ func UpstreamModel(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v))
|
||||
}
|
||||
|
||||
// ChannelID applies equality check predicate on the "channel_id" field. It's identical to ChannelIDEQ.
|
||||
func ChannelID(v int64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldChannelID, v))
|
||||
}
|
||||
|
||||
// ModelMappingChain applies equality check predicate on the "model_mapping_chain" field. It's identical to ModelMappingChainEQ.
|
||||
func ModelMappingChain(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldModelMappingChain, v))
|
||||
}
|
||||
|
||||
// BillingTier applies equality check predicate on the "billing_tier" field. It's identical to BillingTierEQ.
|
||||
func BillingTier(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldBillingTier, v))
|
||||
}
|
||||
|
||||
// BillingMode applies equality check predicate on the "billing_mode" field. It's identical to BillingModeEQ.
|
||||
func BillingMode(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldBillingMode, v))
|
||||
}
|
||||
|
||||
// GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ.
|
||||
func GroupID(v int64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))
|
||||
@@ -565,6 +585,281 @@ func UpstreamModelContainsFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContainsFold(FieldUpstreamModel, v))
|
||||
}
|
||||
|
||||
// ChannelIDEQ applies the EQ predicate on the "channel_id" field.
|
||||
func ChannelIDEQ(v int64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldChannelID, v))
|
||||
}
|
||||
|
||||
// ChannelIDNEQ applies the NEQ predicate on the "channel_id" field.
|
||||
func ChannelIDNEQ(v int64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNEQ(FieldChannelID, v))
|
||||
}
|
||||
|
||||
// ChannelIDIn applies the In predicate on the "channel_id" field.
|
||||
func ChannelIDIn(vs ...int64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIn(FieldChannelID, vs...))
|
||||
}
|
||||
|
||||
// ChannelIDNotIn applies the NotIn predicate on the "channel_id" field.
|
||||
func ChannelIDNotIn(vs ...int64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotIn(FieldChannelID, vs...))
|
||||
}
|
||||
|
||||
// ChannelIDGT applies the GT predicate on the "channel_id" field.
|
||||
func ChannelIDGT(v int64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGT(FieldChannelID, v))
|
||||
}
|
||||
|
||||
// ChannelIDGTE applies the GTE predicate on the "channel_id" field.
|
||||
func ChannelIDGTE(v int64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGTE(FieldChannelID, v))
|
||||
}
|
||||
|
||||
// ChannelIDLT applies the LT predicate on the "channel_id" field.
|
||||
func ChannelIDLT(v int64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLT(FieldChannelID, v))
|
||||
}
|
||||
|
||||
// ChannelIDLTE applies the LTE predicate on the "channel_id" field.
|
||||
func ChannelIDLTE(v int64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLTE(FieldChannelID, v))
|
||||
}
|
||||
|
||||
// ChannelIDIsNil applies the IsNil predicate on the "channel_id" field.
|
||||
func ChannelIDIsNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIsNull(FieldChannelID))
|
||||
}
|
||||
|
||||
// ChannelIDNotNil applies the NotNil predicate on the "channel_id" field.
|
||||
func ChannelIDNotNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotNull(FieldChannelID))
|
||||
}
|
||||
|
||||
// ModelMappingChainEQ applies the EQ predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldModelMappingChain, v))
|
||||
}
|
||||
|
||||
// ModelMappingChainNEQ applies the NEQ predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainNEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNEQ(FieldModelMappingChain, v))
|
||||
}
|
||||
|
||||
// ModelMappingChainIn applies the In predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIn(FieldModelMappingChain, vs...))
|
||||
}
|
||||
|
||||
// ModelMappingChainNotIn applies the NotIn predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainNotIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotIn(FieldModelMappingChain, vs...))
|
||||
}
|
||||
|
||||
// ModelMappingChainGT applies the GT predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainGT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGT(FieldModelMappingChain, v))
|
||||
}
|
||||
|
||||
// ModelMappingChainGTE applies the GTE predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainGTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGTE(FieldModelMappingChain, v))
|
||||
}
|
||||
|
||||
// ModelMappingChainLT applies the LT predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainLT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLT(FieldModelMappingChain, v))
|
||||
}
|
||||
|
||||
// ModelMappingChainLTE applies the LTE predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainLTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLTE(FieldModelMappingChain, v))
|
||||
}
|
||||
|
||||
// ModelMappingChainContains applies the Contains predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainContains(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContains(FieldModelMappingChain, v))
|
||||
}
|
||||
|
||||
// ModelMappingChainHasPrefix applies the HasPrefix predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainHasPrefix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasPrefix(FieldModelMappingChain, v))
|
||||
}
|
||||
|
||||
// ModelMappingChainHasSuffix applies the HasSuffix predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainHasSuffix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasSuffix(FieldModelMappingChain, v))
|
||||
}
|
||||
|
||||
// ModelMappingChainIsNil applies the IsNil predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainIsNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIsNull(FieldModelMappingChain))
|
||||
}
|
||||
|
||||
// ModelMappingChainNotNil applies the NotNil predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainNotNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotNull(FieldModelMappingChain))
|
||||
}
|
||||
|
||||
// ModelMappingChainEqualFold applies the EqualFold predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainEqualFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEqualFold(FieldModelMappingChain, v))
|
||||
}
|
||||
|
||||
// ModelMappingChainContainsFold applies the ContainsFold predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainContainsFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContainsFold(FieldModelMappingChain, v))
|
||||
}
|
||||
|
||||
// BillingTierEQ applies the EQ predicate on the "billing_tier" field.
|
||||
func BillingTierEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldBillingTier, v))
|
||||
}
|
||||
|
||||
// BillingTierNEQ applies the NEQ predicate on the "billing_tier" field.
|
||||
func BillingTierNEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNEQ(FieldBillingTier, v))
|
||||
}
|
||||
|
||||
// BillingTierIn applies the In predicate on the "billing_tier" field.
|
||||
func BillingTierIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIn(FieldBillingTier, vs...))
|
||||
}
|
||||
|
||||
// BillingTierNotIn applies the NotIn predicate on the "billing_tier" field.
|
||||
func BillingTierNotIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotIn(FieldBillingTier, vs...))
|
||||
}
|
||||
|
||||
// BillingTierGT applies the GT predicate on the "billing_tier" field.
|
||||
func BillingTierGT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGT(FieldBillingTier, v))
|
||||
}
|
||||
|
||||
// BillingTierGTE applies the GTE predicate on the "billing_tier" field.
|
||||
func BillingTierGTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGTE(FieldBillingTier, v))
|
||||
}
|
||||
|
||||
// BillingTierLT applies the LT predicate on the "billing_tier" field.
|
||||
func BillingTierLT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLT(FieldBillingTier, v))
|
||||
}
|
||||
|
||||
// BillingTierLTE applies the LTE predicate on the "billing_tier" field.
|
||||
func BillingTierLTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLTE(FieldBillingTier, v))
|
||||
}
|
||||
|
||||
// BillingTierContains applies the Contains predicate on the "billing_tier" field.
|
||||
func BillingTierContains(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContains(FieldBillingTier, v))
|
||||
}
|
||||
|
||||
// BillingTierHasPrefix applies the HasPrefix predicate on the "billing_tier" field.
|
||||
func BillingTierHasPrefix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasPrefix(FieldBillingTier, v))
|
||||
}
|
||||
|
||||
// BillingTierHasSuffix applies the HasSuffix predicate on the "billing_tier" field.
|
||||
func BillingTierHasSuffix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasSuffix(FieldBillingTier, v))
|
||||
}
|
||||
|
||||
// BillingTierIsNil applies the IsNil predicate on the "billing_tier" field.
|
||||
func BillingTierIsNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIsNull(FieldBillingTier))
|
||||
}
|
||||
|
||||
// BillingTierNotNil applies the NotNil predicate on the "billing_tier" field.
|
||||
func BillingTierNotNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotNull(FieldBillingTier))
|
||||
}
|
||||
|
||||
// BillingTierEqualFold applies the EqualFold predicate on the "billing_tier" field.
|
||||
func BillingTierEqualFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEqualFold(FieldBillingTier, v))
|
||||
}
|
||||
|
||||
// BillingTierContainsFold applies the ContainsFold predicate on the "billing_tier" field.
|
||||
func BillingTierContainsFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContainsFold(FieldBillingTier, v))
|
||||
}
|
||||
|
||||
// BillingModeEQ applies the EQ predicate on the "billing_mode" field.
|
||||
func BillingModeEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldBillingMode, v))
|
||||
}
|
||||
|
||||
// BillingModeNEQ applies the NEQ predicate on the "billing_mode" field.
|
||||
func BillingModeNEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNEQ(FieldBillingMode, v))
|
||||
}
|
||||
|
||||
// BillingModeIn applies the In predicate on the "billing_mode" field.
|
||||
func BillingModeIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIn(FieldBillingMode, vs...))
|
||||
}
|
||||
|
||||
// BillingModeNotIn applies the NotIn predicate on the "billing_mode" field.
|
||||
func BillingModeNotIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotIn(FieldBillingMode, vs...))
|
||||
}
|
||||
|
||||
// BillingModeGT applies the GT predicate on the "billing_mode" field.
|
||||
func BillingModeGT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGT(FieldBillingMode, v))
|
||||
}
|
||||
|
||||
// BillingModeGTE applies the GTE predicate on the "billing_mode" field.
|
||||
func BillingModeGTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGTE(FieldBillingMode, v))
|
||||
}
|
||||
|
||||
// BillingModeLT applies the LT predicate on the "billing_mode" field.
|
||||
func BillingModeLT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLT(FieldBillingMode, v))
|
||||
}
|
||||
|
||||
// BillingModeLTE applies the LTE predicate on the "billing_mode" field.
|
||||
func BillingModeLTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLTE(FieldBillingMode, v))
|
||||
}
|
||||
|
||||
// BillingModeContains applies the Contains predicate on the "billing_mode" field.
|
||||
func BillingModeContains(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContains(FieldBillingMode, v))
|
||||
}
|
||||
|
||||
// BillingModeHasPrefix applies the HasPrefix predicate on the "billing_mode" field.
|
||||
func BillingModeHasPrefix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasPrefix(FieldBillingMode, v))
|
||||
}
|
||||
|
||||
// BillingModeHasSuffix applies the HasSuffix predicate on the "billing_mode" field.
|
||||
func BillingModeHasSuffix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasSuffix(FieldBillingMode, v))
|
||||
}
|
||||
|
||||
// BillingModeIsNil applies the IsNil predicate on the "billing_mode" field.
|
||||
func BillingModeIsNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIsNull(FieldBillingMode))
|
||||
}
|
||||
|
||||
// BillingModeNotNil applies the NotNil predicate on the "billing_mode" field.
|
||||
func BillingModeNotNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotNull(FieldBillingMode))
|
||||
}
|
||||
|
||||
// BillingModeEqualFold applies the EqualFold predicate on the "billing_mode" field.
|
||||
func BillingModeEqualFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEqualFold(FieldBillingMode, v))
|
||||
}
|
||||
|
||||
// BillingModeContainsFold applies the ContainsFold predicate on the "billing_mode" field.
|
||||
func BillingModeContainsFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContainsFold(FieldBillingMode, v))
|
||||
}
|
||||
|
||||
// GroupIDEQ applies the EQ predicate on the "group_id" field.
|
||||
func GroupIDEQ(v int64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))
|
||||
|
||||
@@ -85,6 +85,62 @@ func (_c *UsageLogCreate) SetNillableUpstreamModel(v *string) *UsageLogCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetChannelID sets the "channel_id" field.
|
||||
func (_c *UsageLogCreate) SetChannelID(v int64) *UsageLogCreate {
|
||||
_c.mutation.SetChannelID(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableChannelID sets the "channel_id" field if the given value is not nil.
|
||||
func (_c *UsageLogCreate) SetNillableChannelID(v *int64) *UsageLogCreate {
|
||||
if v != nil {
|
||||
_c.SetChannelID(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetModelMappingChain sets the "model_mapping_chain" field.
|
||||
func (_c *UsageLogCreate) SetModelMappingChain(v string) *UsageLogCreate {
|
||||
_c.mutation.SetModelMappingChain(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableModelMappingChain sets the "model_mapping_chain" field if the given value is not nil.
|
||||
func (_c *UsageLogCreate) SetNillableModelMappingChain(v *string) *UsageLogCreate {
|
||||
if v != nil {
|
||||
_c.SetModelMappingChain(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetBillingTier sets the "billing_tier" field.
|
||||
func (_c *UsageLogCreate) SetBillingTier(v string) *UsageLogCreate {
|
||||
_c.mutation.SetBillingTier(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableBillingTier sets the "billing_tier" field if the given value is not nil.
|
||||
func (_c *UsageLogCreate) SetNillableBillingTier(v *string) *UsageLogCreate {
|
||||
if v != nil {
|
||||
_c.SetBillingTier(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetBillingMode sets the "billing_mode" field.
|
||||
func (_c *UsageLogCreate) SetBillingMode(v string) *UsageLogCreate {
|
||||
_c.mutation.SetBillingMode(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableBillingMode sets the "billing_mode" field if the given value is not nil.
|
||||
func (_c *UsageLogCreate) SetNillableBillingMode(v *string) *UsageLogCreate {
|
||||
if v != nil {
|
||||
_c.SetBillingMode(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (_c *UsageLogCreate) SetGroupID(v int64) *UsageLogCreate {
|
||||
_c.mutation.SetGroupID(v)
|
||||
@@ -634,6 +690,21 @@ func (_c *UsageLogCreate) check() error {
|
||||
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _c.mutation.ModelMappingChain(); ok {
|
||||
if err := usagelog.ModelMappingChainValidator(v); err != nil {
|
||||
return &ValidationError{Name: "model_mapping_chain", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model_mapping_chain": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _c.mutation.BillingTier(); ok {
|
||||
if err := usagelog.BillingTierValidator(v); err != nil {
|
||||
return &ValidationError{Name: "billing_tier", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_tier": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _c.mutation.BillingMode(); ok {
|
||||
if err := usagelog.BillingModeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "billing_mode", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_mode": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _, ok := _c.mutation.InputTokens(); !ok {
|
||||
return &ValidationError{Name: "input_tokens", err: errors.New(`ent: missing required field "UsageLog.input_tokens"`)}
|
||||
}
|
||||
@@ -760,6 +831,22 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
|
||||
_node.UpstreamModel = &value
|
||||
}
|
||||
if value, ok := _c.mutation.ChannelID(); ok {
|
||||
_spec.SetField(usagelog.FieldChannelID, field.TypeInt64, value)
|
||||
_node.ChannelID = &value
|
||||
}
|
||||
if value, ok := _c.mutation.ModelMappingChain(); ok {
|
||||
_spec.SetField(usagelog.FieldModelMappingChain, field.TypeString, value)
|
||||
_node.ModelMappingChain = &value
|
||||
}
|
||||
if value, ok := _c.mutation.BillingTier(); ok {
|
||||
_spec.SetField(usagelog.FieldBillingTier, field.TypeString, value)
|
||||
_node.BillingTier = &value
|
||||
}
|
||||
if value, ok := _c.mutation.BillingMode(); ok {
|
||||
_spec.SetField(usagelog.FieldBillingMode, field.TypeString, value)
|
||||
_node.BillingMode = &value
|
||||
}
|
||||
if value, ok := _c.mutation.InputTokens(); ok {
|
||||
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
|
||||
_node.InputTokens = value
|
||||
@@ -1093,6 +1180,84 @@ func (u *UsageLogUpsert) ClearUpstreamModel() *UsageLogUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetChannelID sets the "channel_id" field.
|
||||
func (u *UsageLogUpsert) SetChannelID(v int64) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldChannelID, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateChannelID sets the "channel_id" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsert) UpdateChannelID() *UsageLogUpsert {
|
||||
u.SetExcluded(usagelog.FieldChannelID)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddChannelID adds v to the "channel_id" field.
|
||||
func (u *UsageLogUpsert) AddChannelID(v int64) *UsageLogUpsert {
|
||||
u.Add(usagelog.FieldChannelID, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearChannelID clears the value of the "channel_id" field.
|
||||
func (u *UsageLogUpsert) ClearChannelID() *UsageLogUpsert {
|
||||
u.SetNull(usagelog.FieldChannelID)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetModelMappingChain sets the "model_mapping_chain" field.
|
||||
func (u *UsageLogUpsert) SetModelMappingChain(v string) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldModelMappingChain, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateModelMappingChain sets the "model_mapping_chain" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsert) UpdateModelMappingChain() *UsageLogUpsert {
|
||||
u.SetExcluded(usagelog.FieldModelMappingChain)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearModelMappingChain clears the value of the "model_mapping_chain" field.
|
||||
func (u *UsageLogUpsert) ClearModelMappingChain() *UsageLogUpsert {
|
||||
u.SetNull(usagelog.FieldModelMappingChain)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetBillingTier sets the "billing_tier" field.
|
||||
func (u *UsageLogUpsert) SetBillingTier(v string) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldBillingTier, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateBillingTier sets the "billing_tier" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsert) UpdateBillingTier() *UsageLogUpsert {
|
||||
u.SetExcluded(usagelog.FieldBillingTier)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearBillingTier clears the value of the "billing_tier" field.
|
||||
func (u *UsageLogUpsert) ClearBillingTier() *UsageLogUpsert {
|
||||
u.SetNull(usagelog.FieldBillingTier)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetBillingMode sets the "billing_mode" field.
|
||||
func (u *UsageLogUpsert) SetBillingMode(v string) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldBillingMode, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateBillingMode sets the "billing_mode" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsert) UpdateBillingMode() *UsageLogUpsert {
|
||||
u.SetExcluded(usagelog.FieldBillingMode)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearBillingMode clears the value of the "billing_mode" field.
|
||||
func (u *UsageLogUpsert) ClearBillingMode() *UsageLogUpsert {
|
||||
u.SetNull(usagelog.FieldBillingMode)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (u *UsageLogUpsert) SetGroupID(v int64) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldGroupID, v)
|
||||
@@ -1724,6 +1889,97 @@ func (u *UsageLogUpsertOne) ClearUpstreamModel() *UsageLogUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetChannelID sets the "channel_id" field.
|
||||
func (u *UsageLogUpsertOne) SetChannelID(v int64) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetChannelID(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddChannelID adds v to the "channel_id" field.
|
||||
func (u *UsageLogUpsertOne) AddChannelID(v int64) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.AddChannelID(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateChannelID sets the "channel_id" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertOne) UpdateChannelID() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateChannelID()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearChannelID clears the value of the "channel_id" field.
|
||||
func (u *UsageLogUpsertOne) ClearChannelID() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearChannelID()
|
||||
})
|
||||
}
|
||||
|
||||
// SetModelMappingChain sets the "model_mapping_chain" field.
|
||||
func (u *UsageLogUpsertOne) SetModelMappingChain(v string) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetModelMappingChain(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateModelMappingChain sets the "model_mapping_chain" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertOne) UpdateModelMappingChain() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateModelMappingChain()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearModelMappingChain clears the value of the "model_mapping_chain" field.
|
||||
func (u *UsageLogUpsertOne) ClearModelMappingChain() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearModelMappingChain()
|
||||
})
|
||||
}
|
||||
|
||||
// SetBillingTier sets the "billing_tier" field.
|
||||
func (u *UsageLogUpsertOne) SetBillingTier(v string) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetBillingTier(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateBillingTier sets the "billing_tier" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertOne) UpdateBillingTier() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateBillingTier()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearBillingTier clears the value of the "billing_tier" field.
|
||||
func (u *UsageLogUpsertOne) ClearBillingTier() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearBillingTier()
|
||||
})
|
||||
}
|
||||
|
||||
// SetBillingMode sets the "billing_mode" field.
|
||||
func (u *UsageLogUpsertOne) SetBillingMode(v string) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetBillingMode(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateBillingMode sets the "billing_mode" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertOne) UpdateBillingMode() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateBillingMode()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearBillingMode clears the value of the "billing_mode" field.
|
||||
func (u *UsageLogUpsertOne) ClearBillingMode() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearBillingMode()
|
||||
})
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (u *UsageLogUpsertOne) SetGroupID(v int64) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
@@ -2600,6 +2856,97 @@ func (u *UsageLogUpsertBulk) ClearUpstreamModel() *UsageLogUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetChannelID sets the "channel_id" field.
|
||||
func (u *UsageLogUpsertBulk) SetChannelID(v int64) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetChannelID(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddChannelID adds v to the "channel_id" field.
|
||||
func (u *UsageLogUpsertBulk) AddChannelID(v int64) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.AddChannelID(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateChannelID sets the "channel_id" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertBulk) UpdateChannelID() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateChannelID()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearChannelID clears the value of the "channel_id" field.
|
||||
func (u *UsageLogUpsertBulk) ClearChannelID() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearChannelID()
|
||||
})
|
||||
}
|
||||
|
||||
// SetModelMappingChain sets the "model_mapping_chain" field.
|
||||
func (u *UsageLogUpsertBulk) SetModelMappingChain(v string) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetModelMappingChain(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateModelMappingChain sets the "model_mapping_chain" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertBulk) UpdateModelMappingChain() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateModelMappingChain()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearModelMappingChain clears the value of the "model_mapping_chain" field.
|
||||
func (u *UsageLogUpsertBulk) ClearModelMappingChain() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearModelMappingChain()
|
||||
})
|
||||
}
|
||||
|
||||
// SetBillingTier sets the "billing_tier" field.
|
||||
func (u *UsageLogUpsertBulk) SetBillingTier(v string) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetBillingTier(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateBillingTier sets the "billing_tier" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertBulk) UpdateBillingTier() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateBillingTier()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearBillingTier clears the value of the "billing_tier" field.
|
||||
func (u *UsageLogUpsertBulk) ClearBillingTier() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearBillingTier()
|
||||
})
|
||||
}
|
||||
|
||||
// SetBillingMode sets the "billing_mode" field.
|
||||
func (u *UsageLogUpsertBulk) SetBillingMode(v string) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetBillingMode(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateBillingMode sets the "billing_mode" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertBulk) UpdateBillingMode() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateBillingMode()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearBillingMode clears the value of the "billing_mode" field.
|
||||
func (u *UsageLogUpsertBulk) ClearBillingMode() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearBillingMode()
|
||||
})
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (u *UsageLogUpsertBulk) SetGroupID(v int64) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
|
||||
@@ -142,6 +142,93 @@ func (_u *UsageLogUpdate) ClearUpstreamModel() *UsageLogUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetChannelID sets the "channel_id" field.
|
||||
func (_u *UsageLogUpdate) SetChannelID(v int64) *UsageLogUpdate {
|
||||
_u.mutation.ResetChannelID()
|
||||
_u.mutation.SetChannelID(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableChannelID sets the "channel_id" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdate) SetNillableChannelID(v *int64) *UsageLogUpdate {
|
||||
if v != nil {
|
||||
_u.SetChannelID(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddChannelID adds value to the "channel_id" field.
|
||||
func (_u *UsageLogUpdate) AddChannelID(v int64) *UsageLogUpdate {
|
||||
_u.mutation.AddChannelID(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearChannelID clears the value of the "channel_id" field.
|
||||
func (_u *UsageLogUpdate) ClearChannelID() *UsageLogUpdate {
|
||||
_u.mutation.ClearChannelID()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetModelMappingChain sets the "model_mapping_chain" field.
|
||||
func (_u *UsageLogUpdate) SetModelMappingChain(v string) *UsageLogUpdate {
|
||||
_u.mutation.SetModelMappingChain(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableModelMappingChain sets the "model_mapping_chain" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdate) SetNillableModelMappingChain(v *string) *UsageLogUpdate {
|
||||
if v != nil {
|
||||
_u.SetModelMappingChain(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearModelMappingChain clears the value of the "model_mapping_chain" field.
|
||||
func (_u *UsageLogUpdate) ClearModelMappingChain() *UsageLogUpdate {
|
||||
_u.mutation.ClearModelMappingChain()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetBillingTier sets the "billing_tier" field.
|
||||
func (_u *UsageLogUpdate) SetBillingTier(v string) *UsageLogUpdate {
|
||||
_u.mutation.SetBillingTier(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableBillingTier sets the "billing_tier" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdate) SetNillableBillingTier(v *string) *UsageLogUpdate {
|
||||
if v != nil {
|
||||
_u.SetBillingTier(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearBillingTier clears the value of the "billing_tier" field.
|
||||
func (_u *UsageLogUpdate) ClearBillingTier() *UsageLogUpdate {
|
||||
_u.mutation.ClearBillingTier()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetBillingMode sets the "billing_mode" field.
|
||||
func (_u *UsageLogUpdate) SetBillingMode(v string) *UsageLogUpdate {
|
||||
_u.mutation.SetBillingMode(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableBillingMode sets the "billing_mode" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdate) SetNillableBillingMode(v *string) *UsageLogUpdate {
|
||||
if v != nil {
|
||||
_u.SetBillingMode(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearBillingMode clears the value of the "billing_mode" field.
|
||||
func (_u *UsageLogUpdate) ClearBillingMode() *UsageLogUpdate {
|
||||
_u.mutation.ClearBillingMode()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (_u *UsageLogUpdate) SetGroupID(v int64) *UsageLogUpdate {
|
||||
_u.mutation.SetGroupID(v)
|
||||
@@ -795,6 +882,21 @@ func (_u *UsageLogUpdate) check() error {
|
||||
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.ModelMappingChain(); ok {
|
||||
if err := usagelog.ModelMappingChainValidator(v); err != nil {
|
||||
return &ValidationError{Name: "model_mapping_chain", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model_mapping_chain": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.BillingTier(); ok {
|
||||
if err := usagelog.BillingTierValidator(v); err != nil {
|
||||
return &ValidationError{Name: "billing_tier", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_tier": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.BillingMode(); ok {
|
||||
if err := usagelog.BillingModeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "billing_mode", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_mode": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.UserAgent(); ok {
|
||||
if err := usagelog.UserAgentValidator(v); err != nil {
|
||||
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
||||
@@ -857,6 +959,33 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if _u.mutation.UpstreamModelCleared() {
|
||||
_spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.ChannelID(); ok {
|
||||
_spec.SetField(usagelog.FieldChannelID, field.TypeInt64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedChannelID(); ok {
|
||||
_spec.AddField(usagelog.FieldChannelID, field.TypeInt64, value)
|
||||
}
|
||||
if _u.mutation.ChannelIDCleared() {
|
||||
_spec.ClearField(usagelog.FieldChannelID, field.TypeInt64)
|
||||
}
|
||||
if value, ok := _u.mutation.ModelMappingChain(); ok {
|
||||
_spec.SetField(usagelog.FieldModelMappingChain, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.ModelMappingChainCleared() {
|
||||
_spec.ClearField(usagelog.FieldModelMappingChain, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.BillingTier(); ok {
|
||||
_spec.SetField(usagelog.FieldBillingTier, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.BillingTierCleared() {
|
||||
_spec.ClearField(usagelog.FieldBillingTier, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.BillingMode(); ok {
|
||||
_spec.SetField(usagelog.FieldBillingMode, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.BillingModeCleared() {
|
||||
_spec.ClearField(usagelog.FieldBillingMode, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.InputTokens(); ok {
|
||||
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
|
||||
}
|
||||
@@ -1279,6 +1408,93 @@ func (_u *UsageLogUpdateOne) ClearUpstreamModel() *UsageLogUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetChannelID sets the "channel_id" field.
|
||||
func (_u *UsageLogUpdateOne) SetChannelID(v int64) *UsageLogUpdateOne {
|
||||
_u.mutation.ResetChannelID()
|
||||
_u.mutation.SetChannelID(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableChannelID sets the "channel_id" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdateOne) SetNillableChannelID(v *int64) *UsageLogUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetChannelID(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddChannelID adds value to the "channel_id" field.
|
||||
func (_u *UsageLogUpdateOne) AddChannelID(v int64) *UsageLogUpdateOne {
|
||||
_u.mutation.AddChannelID(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearChannelID clears the value of the "channel_id" field.
|
||||
func (_u *UsageLogUpdateOne) ClearChannelID() *UsageLogUpdateOne {
|
||||
_u.mutation.ClearChannelID()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetModelMappingChain sets the "model_mapping_chain" field.
|
||||
func (_u *UsageLogUpdateOne) SetModelMappingChain(v string) *UsageLogUpdateOne {
|
||||
_u.mutation.SetModelMappingChain(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableModelMappingChain sets the "model_mapping_chain" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdateOne) SetNillableModelMappingChain(v *string) *UsageLogUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetModelMappingChain(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearModelMappingChain clears the value of the "model_mapping_chain" field.
|
||||
func (_u *UsageLogUpdateOne) ClearModelMappingChain() *UsageLogUpdateOne {
|
||||
_u.mutation.ClearModelMappingChain()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetBillingTier sets the "billing_tier" field.
|
||||
func (_u *UsageLogUpdateOne) SetBillingTier(v string) *UsageLogUpdateOne {
|
||||
_u.mutation.SetBillingTier(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableBillingTier sets the "billing_tier" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdateOne) SetNillableBillingTier(v *string) *UsageLogUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetBillingTier(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearBillingTier clears the value of the "billing_tier" field.
|
||||
func (_u *UsageLogUpdateOne) ClearBillingTier() *UsageLogUpdateOne {
|
||||
_u.mutation.ClearBillingTier()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetBillingMode sets the "billing_mode" field.
|
||||
func (_u *UsageLogUpdateOne) SetBillingMode(v string) *UsageLogUpdateOne {
|
||||
_u.mutation.SetBillingMode(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableBillingMode sets the "billing_mode" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdateOne) SetNillableBillingMode(v *string) *UsageLogUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetBillingMode(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearBillingMode clears the value of the "billing_mode" field.
|
||||
func (_u *UsageLogUpdateOne) ClearBillingMode() *UsageLogUpdateOne {
|
||||
_u.mutation.ClearBillingMode()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (_u *UsageLogUpdateOne) SetGroupID(v int64) *UsageLogUpdateOne {
|
||||
_u.mutation.SetGroupID(v)
|
||||
@@ -1945,6 +2161,21 @@ func (_u *UsageLogUpdateOne) check() error {
|
||||
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.ModelMappingChain(); ok {
|
||||
if err := usagelog.ModelMappingChainValidator(v); err != nil {
|
||||
return &ValidationError{Name: "model_mapping_chain", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model_mapping_chain": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.BillingTier(); ok {
|
||||
if err := usagelog.BillingTierValidator(v); err != nil {
|
||||
return &ValidationError{Name: "billing_tier", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_tier": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.BillingMode(); ok {
|
||||
if err := usagelog.BillingModeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "billing_mode", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_mode": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.UserAgent(); ok {
|
||||
if err := usagelog.UserAgentValidator(v); err != nil {
|
||||
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
||||
@@ -2024,6 +2255,33 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
|
||||
if _u.mutation.UpstreamModelCleared() {
|
||||
_spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.ChannelID(); ok {
|
||||
_spec.SetField(usagelog.FieldChannelID, field.TypeInt64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedChannelID(); ok {
|
||||
_spec.AddField(usagelog.FieldChannelID, field.TypeInt64, value)
|
||||
}
|
||||
if _u.mutation.ChannelIDCleared() {
|
||||
_spec.ClearField(usagelog.FieldChannelID, field.TypeInt64)
|
||||
}
|
||||
if value, ok := _u.mutation.ModelMappingChain(); ok {
|
||||
_spec.SetField(usagelog.FieldModelMappingChain, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.ModelMappingChainCleared() {
|
||||
_spec.ClearField(usagelog.FieldModelMappingChain, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.BillingTier(); ok {
|
||||
_spec.SetField(usagelog.FieldBillingTier, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.BillingTierCleared() {
|
||||
_spec.ClearField(usagelog.FieldBillingTier, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.BillingMode(); ok {
|
||||
_spec.SetField(usagelog.FieldBillingMode, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.BillingModeCleared() {
|
||||
_spec.ClearField(usagelog.FieldBillingMode, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.InputTokens(); ok {
|
||||
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
|
||||
}
|
||||
|
||||
452
backend/internal/handler/admin/channel_handler.go
Normal file
452
backend/internal/handler/admin/channel_handler.go
Normal file
@@ -0,0 +1,452 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ChannelHandler handles admin channel management
|
||||
type ChannelHandler struct {
|
||||
channelService *service.ChannelService
|
||||
billingService *service.BillingService
|
||||
}
|
||||
|
||||
// NewChannelHandler creates a new admin channel handler
|
||||
func NewChannelHandler(channelService *service.ChannelService, billingService *service.BillingService) *ChannelHandler {
|
||||
return &ChannelHandler{channelService: channelService, billingService: billingService}
|
||||
}
|
||||
|
||||
// --- Request / Response types ---
|
||||
|
||||
type createChannelRequest struct {
|
||||
Name string `json:"name" binding:"required,max=100"`
|
||||
Description string `json:"description"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ModelPricing []channelModelPricingRequest `json:"model_pricing"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
|
||||
RestrictModels bool `json:"restrict_models"`
|
||||
}
|
||||
|
||||
type updateChannelRequest struct {
|
||||
Name string `json:"name" binding:"omitempty,max=100"`
|
||||
Description *string `json:"description"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
|
||||
RestrictModels *bool `json:"restrict_models"`
|
||||
}
|
||||
|
||||
type channelModelPricingRequest struct {
|
||||
Platform string `json:"platform" binding:"omitempty,max=50"`
|
||||
Models []string `json:"models" binding:"required,min=1,max=100"`
|
||||
BillingMode string `json:"billing_mode" binding:"omitempty,oneof=token per_request image"`
|
||||
InputPrice *float64 `json:"input_price" binding:"omitempty,min=0"`
|
||||
OutputPrice *float64 `json:"output_price" binding:"omitempty,min=0"`
|
||||
CacheWritePrice *float64 `json:"cache_write_price" binding:"omitempty,min=0"`
|
||||
CacheReadPrice *float64 `json:"cache_read_price" binding:"omitempty,min=0"`
|
||||
ImageOutputPrice *float64 `json:"image_output_price" binding:"omitempty,min=0"`
|
||||
PerRequestPrice *float64 `json:"per_request_price" binding:"omitempty,min=0"`
|
||||
Intervals []pricingIntervalRequest `json:"intervals"`
|
||||
}
|
||||
|
||||
type pricingIntervalRequest struct {
|
||||
MinTokens int `json:"min_tokens"`
|
||||
MaxTokens *int `json:"max_tokens"`
|
||||
TierLabel string `json:"tier_label"`
|
||||
InputPrice *float64 `json:"input_price"`
|
||||
OutputPrice *float64 `json:"output_price"`
|
||||
CacheWritePrice *float64 `json:"cache_write_price"`
|
||||
CacheReadPrice *float64 `json:"cache_read_price"`
|
||||
PerRequestPrice *float64 `json:"per_request_price"`
|
||||
SortOrder int `json:"sort_order"`
|
||||
}
|
||||
|
||||
type channelResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Status string `json:"status"`
|
||||
BillingModelSource string `json:"billing_model_source"`
|
||||
RestrictModels bool `json:"restrict_models"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ModelPricing []channelModelPricingResponse `json:"model_pricing"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
type channelModelPricingResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
Platform string `json:"platform"`
|
||||
Models []string `json:"models"`
|
||||
BillingMode string `json:"billing_mode"`
|
||||
InputPrice *float64 `json:"input_price"`
|
||||
OutputPrice *float64 `json:"output_price"`
|
||||
CacheWritePrice *float64 `json:"cache_write_price"`
|
||||
CacheReadPrice *float64 `json:"cache_read_price"`
|
||||
ImageOutputPrice *float64 `json:"image_output_price"`
|
||||
PerRequestPrice *float64 `json:"per_request_price"`
|
||||
Intervals []pricingIntervalResponse `json:"intervals"`
|
||||
}
|
||||
|
||||
type pricingIntervalResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
MinTokens int `json:"min_tokens"`
|
||||
MaxTokens *int `json:"max_tokens"`
|
||||
TierLabel string `json:"tier_label,omitempty"`
|
||||
InputPrice *float64 `json:"input_price"`
|
||||
OutputPrice *float64 `json:"output_price"`
|
||||
CacheWritePrice *float64 `json:"cache_write_price"`
|
||||
CacheReadPrice *float64 `json:"cache_read_price"`
|
||||
PerRequestPrice *float64 `json:"per_request_price"`
|
||||
SortOrder int `json:"sort_order"`
|
||||
}
|
||||
|
||||
func channelToResponse(ch *service.Channel) *channelResponse {
|
||||
if ch == nil {
|
||||
return nil
|
||||
}
|
||||
resp := &channelResponse{
|
||||
ID: ch.ID,
|
||||
Name: ch.Name,
|
||||
Description: ch.Description,
|
||||
Status: ch.Status,
|
||||
RestrictModels: ch.RestrictModels,
|
||||
GroupIDs: ch.GroupIDs,
|
||||
ModelMapping: ch.ModelMapping,
|
||||
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
||||
UpdatedAt: ch.UpdatedAt.Format("2006-01-02T15:04:05Z"),
|
||||
}
|
||||
resp.BillingModelSource = ch.BillingModelSource
|
||||
if resp.BillingModelSource == "" {
|
||||
resp.BillingModelSource = service.BillingModelSourceChannelMapped
|
||||
}
|
||||
if resp.GroupIDs == nil {
|
||||
resp.GroupIDs = []int64{}
|
||||
}
|
||||
if resp.ModelMapping == nil {
|
||||
resp.ModelMapping = map[string]map[string]string{}
|
||||
}
|
||||
|
||||
resp.ModelPricing = make([]channelModelPricingResponse, 0, len(ch.ModelPricing))
|
||||
for _, p := range ch.ModelPricing {
|
||||
resp.ModelPricing = append(resp.ModelPricing, pricingToResponse(&p))
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func pricingToResponse(p *service.ChannelModelPricing) channelModelPricingResponse {
|
||||
models := p.Models
|
||||
if models == nil {
|
||||
models = []string{}
|
||||
}
|
||||
billingMode := string(p.BillingMode)
|
||||
if billingMode == "" {
|
||||
billingMode = string(service.BillingModeToken)
|
||||
}
|
||||
platform := p.Platform
|
||||
if platform == "" {
|
||||
platform = service.PlatformAnthropic
|
||||
}
|
||||
intervals := make([]pricingIntervalResponse, 0, len(p.Intervals))
|
||||
for _, iv := range p.Intervals {
|
||||
intervals = append(intervals, intervalToResponse(iv))
|
||||
}
|
||||
return channelModelPricingResponse{
|
||||
ID: p.ID,
|
||||
Platform: platform,
|
||||
Models: models,
|
||||
BillingMode: billingMode,
|
||||
InputPrice: p.InputPrice,
|
||||
OutputPrice: p.OutputPrice,
|
||||
CacheWritePrice: p.CacheWritePrice,
|
||||
CacheReadPrice: p.CacheReadPrice,
|
||||
ImageOutputPrice: p.ImageOutputPrice,
|
||||
PerRequestPrice: p.PerRequestPrice,
|
||||
Intervals: intervals,
|
||||
}
|
||||
}
|
||||
|
||||
func intervalToResponse(iv service.PricingInterval) pricingIntervalResponse {
|
||||
return pricingIntervalResponse{
|
||||
ID: iv.ID,
|
||||
MinTokens: iv.MinTokens,
|
||||
MaxTokens: iv.MaxTokens,
|
||||
TierLabel: iv.TierLabel,
|
||||
InputPrice: iv.InputPrice,
|
||||
OutputPrice: iv.OutputPrice,
|
||||
CacheWritePrice: iv.CacheWritePrice,
|
||||
CacheReadPrice: iv.CacheReadPrice,
|
||||
PerRequestPrice: iv.PerRequestPrice,
|
||||
SortOrder: iv.SortOrder,
|
||||
}
|
||||
}
|
||||
|
||||
func pricingRequestToService(reqs []channelModelPricingRequest) []service.ChannelModelPricing {
|
||||
result := make([]service.ChannelModelPricing, 0, len(reqs))
|
||||
for _, r := range reqs {
|
||||
billingMode := service.BillingMode(r.BillingMode)
|
||||
if billingMode == "" {
|
||||
billingMode = service.BillingModeToken
|
||||
}
|
||||
platform := r.Platform
|
||||
if platform == "" {
|
||||
platform = service.PlatformAnthropic
|
||||
}
|
||||
intervals := make([]service.PricingInterval, 0, len(r.Intervals))
|
||||
for _, iv := range r.Intervals {
|
||||
intervals = append(intervals, service.PricingInterval{
|
||||
MinTokens: iv.MinTokens,
|
||||
MaxTokens: iv.MaxTokens,
|
||||
TierLabel: iv.TierLabel,
|
||||
InputPrice: iv.InputPrice,
|
||||
OutputPrice: iv.OutputPrice,
|
||||
CacheWritePrice: iv.CacheWritePrice,
|
||||
CacheReadPrice: iv.CacheReadPrice,
|
||||
PerRequestPrice: iv.PerRequestPrice,
|
||||
SortOrder: iv.SortOrder,
|
||||
})
|
||||
}
|
||||
result = append(result, service.ChannelModelPricing{
|
||||
Platform: platform,
|
||||
Models: r.Models,
|
||||
BillingMode: billingMode,
|
||||
InputPrice: r.InputPrice,
|
||||
OutputPrice: r.OutputPrice,
|
||||
CacheWritePrice: r.CacheWritePrice,
|
||||
CacheReadPrice: r.CacheReadPrice,
|
||||
ImageOutputPrice: r.ImageOutputPrice,
|
||||
PerRequestPrice: r.PerRequestPrice,
|
||||
Intervals: intervals,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// validatePricingBillingMode 校验计费配置
|
||||
func validatePricingBillingMode(pricing []service.ChannelModelPricing) error {
|
||||
for _, p := range pricing {
|
||||
// 按次/图片模式必须配置默认价格或区间
|
||||
if p.BillingMode == service.BillingModePerRequest || p.BillingMode == service.BillingModeImage {
|
||||
if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
|
||||
return errors.New("per-request price or intervals required for per_request/image billing mode")
|
||||
}
|
||||
}
|
||||
// 校验价格不能为负
|
||||
if err := validatePriceNotNegative("input_price", p.InputPrice); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validatePriceNotNegative("output_price", p.OutputPrice); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validatePriceNotNegative("cache_write_price", p.CacheWritePrice); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validatePriceNotNegative("cache_read_price", p.CacheReadPrice); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validatePriceNotNegative("image_output_price", p.ImageOutputPrice); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validatePriceNotNegative("per_request_price", p.PerRequestPrice); err != nil {
|
||||
return err
|
||||
}
|
||||
// 校验 interval:至少有一个价格字段非空
|
||||
for _, iv := range p.Intervals {
|
||||
if iv.InputPrice == nil && iv.OutputPrice == nil &&
|
||||
iv.CacheWritePrice == nil && iv.CacheReadPrice == nil &&
|
||||
iv.PerRequestPrice == nil {
|
||||
return fmt.Errorf("interval [%d, %s] has no price fields set for model %v",
|
||||
iv.MinTokens, formatMaxTokens(iv.MaxTokens), p.Models)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validatePriceNotNegative(field string, val *float64) error {
|
||||
if val != nil && *val < 0 {
|
||||
return fmt.Errorf("%s must be >= 0", field)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func formatMaxTokens(max *int) string {
|
||||
if max == nil {
|
||||
return "∞"
|
||||
}
|
||||
return fmt.Sprintf("%d", *max)
|
||||
}
|
||||
|
||||
// --- Handlers ---
|
||||
|
||||
// List handles listing channels with pagination
|
||||
// GET /api/v1/admin/channels
|
||||
func (h *ChannelHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
status := c.Query("status")
|
||||
search := strings.TrimSpace(c.Query("search"))
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
|
||||
channels, pag, err := h.channelService.List(c.Request.Context(), pagination.PaginationParams{Page: page, PageSize: pageSize}, status, search)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]*channelResponse, 0, len(channels))
|
||||
for i := range channels {
|
||||
out = append(out, channelToResponse(&channels[i]))
|
||||
}
|
||||
response.Paginated(c, out, pag.Total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetByID handles getting a channel by ID
|
||||
// GET /api/v1/admin/channels/:id
|
||||
func (h *ChannelHandler) GetByID(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_CHANNEL_ID", "Invalid channel ID"))
|
||||
return
|
||||
}
|
||||
|
||||
channel, err := h.channelService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, channelToResponse(channel))
|
||||
}
|
||||
|
||||
// Create handles creating a new channel
|
||||
// POST /api/v1/admin/channels
|
||||
func (h *ChannelHandler) Create(c *gin.Context) {
|
||||
var req createChannelRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
pricing := pricingRequestToService(req.ModelPricing)
|
||||
if err := validatePricingBillingMode(pricing); err != nil {
|
||||
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ModelPricing: pricing,
|
||||
ModelMapping: req.ModelMapping,
|
||||
BillingModelSource: req.BillingModelSource,
|
||||
RestrictModels: req.RestrictModels,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, channelToResponse(channel))
|
||||
}
|
||||
|
||||
// Update handles updating a channel
|
||||
// PUT /api/v1/admin/channels/:id
|
||||
func (h *ChannelHandler) Update(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_CHANNEL_ID", "Invalid channel ID"))
|
||||
return
|
||||
}
|
||||
|
||||
var req updateChannelRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
input := &service.UpdateChannelInput{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Status: req.Status,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ModelMapping: req.ModelMapping,
|
||||
BillingModelSource: req.BillingModelSource,
|
||||
RestrictModels: req.RestrictModels,
|
||||
}
|
||||
if req.ModelPricing != nil {
|
||||
pricing := pricingRequestToService(*req.ModelPricing)
|
||||
if err := validatePricingBillingMode(pricing); err != nil {
|
||||
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
|
||||
return
|
||||
}
|
||||
input.ModelPricing = &pricing
|
||||
}
|
||||
|
||||
channel, err := h.channelService.Update(c.Request.Context(), id, input)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, channelToResponse(channel))
|
||||
}
|
||||
|
||||
// Delete handles deleting a channel
|
||||
// DELETE /api/v1/admin/channels/:id
|
||||
func (h *ChannelHandler) Delete(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_CHANNEL_ID", "Invalid channel ID"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.channelService.Delete(c.Request.Context(), id); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Channel deleted successfully"})
|
||||
}
|
||||
|
||||
// GetModelDefaultPricing 获取模型的默认定价(用于前端自动填充)
|
||||
// GET /api/v1/admin/channels/model-pricing?model=claude-sonnet-4
|
||||
func (h *ChannelHandler) GetModelDefaultPricing(c *gin.Context) {
|
||||
model := strings.TrimSpace(c.Query("model"))
|
||||
if model == "" {
|
||||
response.ErrorFrom(c, infraerrors.BadRequest("MISSING_PARAMETER", "model parameter is required").
|
||||
WithMetadata(map[string]string{"param": "model"}))
|
||||
return
|
||||
}
|
||||
|
||||
pricing, err := h.billingService.GetModelPricing(model)
|
||||
if err != nil {
|
||||
// 模型不在定价列表中
|
||||
response.Success(c, gin.H{"found": false})
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"found": true,
|
||||
"input_price": pricing.InputPricePerToken,
|
||||
"output_price": pricing.OutputPricePerToken,
|
||||
"cache_write_price": pricing.CacheCreationPricePerToken,
|
||||
"cache_read_price": pricing.CacheReadPricePerToken,
|
||||
"image_output_price": pricing.ImageOutputPricePerToken,
|
||||
})
|
||||
}
|
||||
502
backend/internal/handler/admin/channel_handler_test.go
Normal file
502
backend/internal/handler/admin/channel_handler_test.go
Normal file
@@ -0,0 +1,502 @@
|
||||
//go:build unit
|
||||
|
||||
package admin
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func float64Ptr(v float64) *float64 { return &v }
|
||||
func intPtr(v int) *int { return &v }
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 1. channelToResponse
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestChannelToResponse_NilInput(t *testing.T) {
|
||||
require.Nil(t, channelToResponse(nil))
|
||||
}
|
||||
|
||||
func TestChannelToResponse_FullChannel(t *testing.T) {
|
||||
now := time.Date(2025, 6, 1, 12, 0, 0, 0, time.UTC)
|
||||
ch := &service.Channel{
|
||||
ID: 42,
|
||||
Name: "test-channel",
|
||||
Description: "desc",
|
||||
Status: "active",
|
||||
BillingModelSource: "upstream",
|
||||
RestrictModels: true,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now.Add(time.Hour),
|
||||
GroupIDs: []int64{1, 2, 3},
|
||||
ModelPricing: []service.ChannelModelPricing{
|
||||
{
|
||||
ID: 10,
|
||||
Platform: "openai",
|
||||
Models: []string{"gpt-4"},
|
||||
BillingMode: service.BillingModeToken,
|
||||
InputPrice: float64Ptr(0.01),
|
||||
OutputPrice: float64Ptr(0.03),
|
||||
CacheWritePrice: float64Ptr(0.005),
|
||||
CacheReadPrice: float64Ptr(0.002),
|
||||
PerRequestPrice: float64Ptr(0.5),
|
||||
},
|
||||
},
|
||||
ModelMapping: map[string]map[string]string{
|
||||
"anthropic": {"claude-3-haiku": "claude-haiku-3"},
|
||||
},
|
||||
}
|
||||
|
||||
resp := channelToResponse(ch)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, int64(42), resp.ID)
|
||||
require.Equal(t, "test-channel", resp.Name)
|
||||
require.Equal(t, "desc", resp.Description)
|
||||
require.Equal(t, "active", resp.Status)
|
||||
require.Equal(t, "upstream", resp.BillingModelSource)
|
||||
require.True(t, resp.RestrictModels)
|
||||
require.Equal(t, []int64{1, 2, 3}, resp.GroupIDs)
|
||||
require.Equal(t, "2025-06-01T12:00:00Z", resp.CreatedAt)
|
||||
require.Equal(t, "2025-06-01T13:00:00Z", resp.UpdatedAt)
|
||||
|
||||
// model mapping
|
||||
require.Len(t, resp.ModelMapping, 1)
|
||||
require.Equal(t, "claude-haiku-3", resp.ModelMapping["anthropic"]["claude-3-haiku"])
|
||||
|
||||
// pricing
|
||||
require.Len(t, resp.ModelPricing, 1)
|
||||
p := resp.ModelPricing[0]
|
||||
require.Equal(t, int64(10), p.ID)
|
||||
require.Equal(t, "openai", p.Platform)
|
||||
require.Equal(t, []string{"gpt-4"}, p.Models)
|
||||
require.Equal(t, "token", p.BillingMode)
|
||||
require.Equal(t, float64Ptr(0.01), p.InputPrice)
|
||||
require.Equal(t, float64Ptr(0.03), p.OutputPrice)
|
||||
require.Equal(t, float64Ptr(0.005), p.CacheWritePrice)
|
||||
require.Equal(t, float64Ptr(0.002), p.CacheReadPrice)
|
||||
require.Equal(t, float64Ptr(0.5), p.PerRequestPrice)
|
||||
require.Empty(t, p.Intervals)
|
||||
}
|
||||
|
||||
func TestChannelToResponse_EmptyDefaults(t *testing.T) {
|
||||
now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
ch := &service.Channel{
|
||||
ID: 1,
|
||||
Name: "ch",
|
||||
BillingModelSource: "",
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
GroupIDs: nil,
|
||||
ModelMapping: nil,
|
||||
ModelPricing: []service.ChannelModelPricing{
|
||||
{
|
||||
Platform: "",
|
||||
BillingMode: "",
|
||||
Models: []string{"m1"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp := channelToResponse(ch)
|
||||
require.Equal(t, "channel_mapped", resp.BillingModelSource)
|
||||
require.NotNil(t, resp.GroupIDs)
|
||||
require.Empty(t, resp.GroupIDs)
|
||||
require.NotNil(t, resp.ModelMapping)
|
||||
require.Empty(t, resp.ModelMapping)
|
||||
|
||||
require.Len(t, resp.ModelPricing, 1)
|
||||
require.Equal(t, "anthropic", resp.ModelPricing[0].Platform)
|
||||
require.Equal(t, "token", resp.ModelPricing[0].BillingMode)
|
||||
}
|
||||
|
||||
func TestChannelToResponse_NilModels(t *testing.T) {
|
||||
now := time.Now()
|
||||
ch := &service.Channel{
|
||||
ID: 1,
|
||||
Name: "ch",
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
ModelPricing: []service.ChannelModelPricing{
|
||||
{
|
||||
Models: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp := channelToResponse(ch)
|
||||
require.Len(t, resp.ModelPricing, 1)
|
||||
require.NotNil(t, resp.ModelPricing[0].Models)
|
||||
require.Empty(t, resp.ModelPricing[0].Models)
|
||||
}
|
||||
|
||||
func TestChannelToResponse_WithIntervals(t *testing.T) {
|
||||
now := time.Now()
|
||||
ch := &service.Channel{
|
||||
ID: 1,
|
||||
Name: "ch",
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
ModelPricing: []service.ChannelModelPricing{
|
||||
{
|
||||
Models: []string{"m1"},
|
||||
BillingMode: service.BillingModePerRequest,
|
||||
Intervals: []service.PricingInterval{
|
||||
{
|
||||
ID: 100,
|
||||
MinTokens: 0,
|
||||
MaxTokens: intPtr(1000),
|
||||
TierLabel: "1K",
|
||||
InputPrice: float64Ptr(0.01),
|
||||
OutputPrice: float64Ptr(0.02),
|
||||
CacheWritePrice: float64Ptr(0.003),
|
||||
CacheReadPrice: float64Ptr(0.001),
|
||||
PerRequestPrice: float64Ptr(0.1),
|
||||
SortOrder: 1,
|
||||
},
|
||||
{
|
||||
ID: 101,
|
||||
MinTokens: 1000,
|
||||
MaxTokens: nil,
|
||||
TierLabel: "unlimited",
|
||||
SortOrder: 2,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp := channelToResponse(ch)
|
||||
require.Len(t, resp.ModelPricing, 1)
|
||||
intervals := resp.ModelPricing[0].Intervals
|
||||
require.Len(t, intervals, 2)
|
||||
|
||||
iv0 := intervals[0]
|
||||
require.Equal(t, int64(100), iv0.ID)
|
||||
require.Equal(t, 0, iv0.MinTokens)
|
||||
require.Equal(t, intPtr(1000), iv0.MaxTokens)
|
||||
require.Equal(t, "1K", iv0.TierLabel)
|
||||
require.Equal(t, float64Ptr(0.01), iv0.InputPrice)
|
||||
require.Equal(t, float64Ptr(0.02), iv0.OutputPrice)
|
||||
require.Equal(t, float64Ptr(0.003), iv0.CacheWritePrice)
|
||||
require.Equal(t, float64Ptr(0.001), iv0.CacheReadPrice)
|
||||
require.Equal(t, float64Ptr(0.1), iv0.PerRequestPrice)
|
||||
require.Equal(t, 1, iv0.SortOrder)
|
||||
|
||||
iv1 := intervals[1]
|
||||
require.Equal(t, int64(101), iv1.ID)
|
||||
require.Equal(t, 1000, iv1.MinTokens)
|
||||
require.Nil(t, iv1.MaxTokens)
|
||||
require.Equal(t, "unlimited", iv1.TierLabel)
|
||||
require.Equal(t, 2, iv1.SortOrder)
|
||||
}
|
||||
|
||||
func TestChannelToResponse_MultipleEntries(t *testing.T) {
|
||||
now := time.Now()
|
||||
ch := &service.Channel{
|
||||
ID: 1,
|
||||
Name: "multi",
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
ModelPricing: []service.ChannelModelPricing{
|
||||
{
|
||||
ID: 1,
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: service.BillingModeToken,
|
||||
InputPrice: float64Ptr(0.003),
|
||||
OutputPrice: float64Ptr(0.015),
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Platform: "openai",
|
||||
Models: []string{"gpt-4", "gpt-4o"},
|
||||
BillingMode: service.BillingModePerRequest,
|
||||
PerRequestPrice: float64Ptr(1.0),
|
||||
},
|
||||
{
|
||||
ID: 3,
|
||||
Platform: "gemini",
|
||||
Models: []string{"gemini-2.5-pro"},
|
||||
BillingMode: service.BillingModeImage,
|
||||
ImageOutputPrice: float64Ptr(0.05),
|
||||
PerRequestPrice: float64Ptr(0.2),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp := channelToResponse(ch)
|
||||
require.Len(t, resp.ModelPricing, 3)
|
||||
|
||||
require.Equal(t, int64(1), resp.ModelPricing[0].ID)
|
||||
require.Equal(t, "anthropic", resp.ModelPricing[0].Platform)
|
||||
require.Equal(t, []string{"claude-sonnet-4"}, resp.ModelPricing[0].Models)
|
||||
require.Equal(t, "token", resp.ModelPricing[0].BillingMode)
|
||||
|
||||
require.Equal(t, int64(2), resp.ModelPricing[1].ID)
|
||||
require.Equal(t, "openai", resp.ModelPricing[1].Platform)
|
||||
require.Equal(t, []string{"gpt-4", "gpt-4o"}, resp.ModelPricing[1].Models)
|
||||
require.Equal(t, "per_request", resp.ModelPricing[1].BillingMode)
|
||||
|
||||
require.Equal(t, int64(3), resp.ModelPricing[2].ID)
|
||||
require.Equal(t, "gemini", resp.ModelPricing[2].Platform)
|
||||
require.Equal(t, []string{"gemini-2.5-pro"}, resp.ModelPricing[2].Models)
|
||||
require.Equal(t, "image", resp.ModelPricing[2].BillingMode)
|
||||
require.Equal(t, float64Ptr(0.05), resp.ModelPricing[2].ImageOutputPrice)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 2. pricingRequestToService
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestPricingRequestToService_Defaults(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req channelModelPricingRequest
|
||||
wantField string // which default field to check
|
||||
wantValue string
|
||||
}{
|
||||
{
|
||||
name: "empty billing mode defaults to token",
|
||||
req: channelModelPricingRequest{
|
||||
Models: []string{"m1"},
|
||||
BillingMode: "",
|
||||
},
|
||||
wantField: "BillingMode",
|
||||
wantValue: string(service.BillingModeToken),
|
||||
},
|
||||
{
|
||||
name: "empty platform defaults to anthropic",
|
||||
req: channelModelPricingRequest{
|
||||
Models: []string{"m1"},
|
||||
Platform: "",
|
||||
},
|
||||
wantField: "Platform",
|
||||
wantValue: "anthropic",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := pricingRequestToService([]channelModelPricingRequest{tt.req})
|
||||
require.Len(t, result, 1)
|
||||
switch tt.wantField {
|
||||
case "BillingMode":
|
||||
require.Equal(t, service.BillingMode(tt.wantValue), result[0].BillingMode)
|
||||
case "Platform":
|
||||
require.Equal(t, tt.wantValue, result[0].Platform)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPricingRequestToService_WithAllFields(t *testing.T) {
|
||||
reqs := []channelModelPricingRequest{
|
||||
{
|
||||
Platform: "openai",
|
||||
Models: []string{"gpt-4", "gpt-4o"},
|
||||
BillingMode: "per_request",
|
||||
InputPrice: float64Ptr(0.01),
|
||||
OutputPrice: float64Ptr(0.03),
|
||||
CacheWritePrice: float64Ptr(0.005),
|
||||
CacheReadPrice: float64Ptr(0.002),
|
||||
ImageOutputPrice: float64Ptr(0.04),
|
||||
PerRequestPrice: float64Ptr(0.5),
|
||||
},
|
||||
}
|
||||
|
||||
result := pricingRequestToService(reqs)
|
||||
require.Len(t, result, 1)
|
||||
r := result[0]
|
||||
require.Equal(t, "openai", r.Platform)
|
||||
require.Equal(t, []string{"gpt-4", "gpt-4o"}, r.Models)
|
||||
require.Equal(t, service.BillingModePerRequest, r.BillingMode)
|
||||
require.Equal(t, float64Ptr(0.01), r.InputPrice)
|
||||
require.Equal(t, float64Ptr(0.03), r.OutputPrice)
|
||||
require.Equal(t, float64Ptr(0.005), r.CacheWritePrice)
|
||||
require.Equal(t, float64Ptr(0.002), r.CacheReadPrice)
|
||||
require.Equal(t, float64Ptr(0.04), r.ImageOutputPrice)
|
||||
require.Equal(t, float64Ptr(0.5), r.PerRequestPrice)
|
||||
}
|
||||
|
||||
func TestPricingRequestToService_WithIntervals(t *testing.T) {
|
||||
reqs := []channelModelPricingRequest{
|
||||
{
|
||||
Models: []string{"m1"},
|
||||
BillingMode: "per_request",
|
||||
Intervals: []pricingIntervalRequest{
|
||||
{
|
||||
MinTokens: 0,
|
||||
MaxTokens: intPtr(2000),
|
||||
TierLabel: "small",
|
||||
InputPrice: float64Ptr(0.01),
|
||||
OutputPrice: float64Ptr(0.02),
|
||||
CacheWritePrice: float64Ptr(0.003),
|
||||
CacheReadPrice: float64Ptr(0.001),
|
||||
PerRequestPrice: float64Ptr(0.1),
|
||||
SortOrder: 1,
|
||||
},
|
||||
{
|
||||
MinTokens: 2000,
|
||||
MaxTokens: nil,
|
||||
TierLabel: "large",
|
||||
SortOrder: 2,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := pricingRequestToService(reqs)
|
||||
require.Len(t, result, 1)
|
||||
require.Len(t, result[0].Intervals, 2)
|
||||
|
||||
iv0 := result[0].Intervals[0]
|
||||
require.Equal(t, 0, iv0.MinTokens)
|
||||
require.Equal(t, intPtr(2000), iv0.MaxTokens)
|
||||
require.Equal(t, "small", iv0.TierLabel)
|
||||
require.Equal(t, float64Ptr(0.01), iv0.InputPrice)
|
||||
require.Equal(t, float64Ptr(0.02), iv0.OutputPrice)
|
||||
require.Equal(t, float64Ptr(0.003), iv0.CacheWritePrice)
|
||||
require.Equal(t, float64Ptr(0.001), iv0.CacheReadPrice)
|
||||
require.Equal(t, float64Ptr(0.1), iv0.PerRequestPrice)
|
||||
require.Equal(t, 1, iv0.SortOrder)
|
||||
|
||||
iv1 := result[0].Intervals[1]
|
||||
require.Equal(t, 2000, iv1.MinTokens)
|
||||
require.Nil(t, iv1.MaxTokens)
|
||||
require.Equal(t, "large", iv1.TierLabel)
|
||||
require.Equal(t, 2, iv1.SortOrder)
|
||||
}
|
||||
|
||||
func TestPricingRequestToService_EmptySlice(t *testing.T) {
|
||||
result := pricingRequestToService([]channelModelPricingRequest{})
|
||||
require.NotNil(t, result)
|
||||
require.Empty(t, result)
|
||||
}
|
||||
|
||||
func TestPricingRequestToService_NilPriceFields(t *testing.T) {
|
||||
reqs := []channelModelPricingRequest{
|
||||
{
|
||||
Models: []string{"m1"},
|
||||
BillingMode: "token",
|
||||
// all price fields are nil by default
|
||||
},
|
||||
}
|
||||
|
||||
result := pricingRequestToService(reqs)
|
||||
require.Len(t, result, 1)
|
||||
r := result[0]
|
||||
require.Nil(t, r.InputPrice)
|
||||
require.Nil(t, r.OutputPrice)
|
||||
require.Nil(t, r.CacheWritePrice)
|
||||
require.Nil(t, r.CacheReadPrice)
|
||||
require.Nil(t, r.ImageOutputPrice)
|
||||
require.Nil(t, r.PerRequestPrice)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 3. validatePricingBillingMode
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestValidatePricingBillingMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
pricing []service.ChannelModelPricing
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "token mode - valid",
|
||||
pricing: []service.ChannelModelPricing{
|
||||
{BillingMode: service.BillingModeToken},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "per_request with price - valid",
|
||||
pricing: []service.ChannelModelPricing{
|
||||
{
|
||||
BillingMode: service.BillingModePerRequest,
|
||||
PerRequestPrice: float64Ptr(0.5),
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "per_request with intervals - valid",
|
||||
pricing: []service.ChannelModelPricing{
|
||||
{
|
||||
BillingMode: service.BillingModePerRequest,
|
||||
Intervals: []service.PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: intPtr(1000), PerRequestPrice: float64Ptr(0.1)},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "per_request no price no intervals - invalid",
|
||||
pricing: []service.ChannelModelPricing{
|
||||
{BillingMode: service.BillingModePerRequest},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "image with price - valid",
|
||||
pricing: []service.ChannelModelPricing{
|
||||
{
|
||||
BillingMode: service.BillingModeImage,
|
||||
PerRequestPrice: float64Ptr(0.2),
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "image no price no intervals - invalid",
|
||||
pricing: []service.ChannelModelPricing{
|
||||
{BillingMode: service.BillingModeImage},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty list - valid",
|
||||
pricing: []service.ChannelModelPricing{},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "mixed modes with invalid image - invalid",
|
||||
pricing: []service.ChannelModelPricing{
|
||||
{
|
||||
BillingMode: service.BillingModeToken,
|
||||
InputPrice: float64Ptr(0.01),
|
||||
},
|
||||
{
|
||||
BillingMode: service.BillingModePerRequest,
|
||||
PerRequestPrice: float64Ptr(0.5),
|
||||
},
|
||||
{
|
||||
BillingMode: service.BillingModeImage,
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validatePricingBillingMode(tt.pricing)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "per-request price or intervals required")
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -636,6 +636,40 @@ func (h *DashboardHandler) GetUserBreakdown(c *gin.Context) {
|
||||
dim.Endpoint = c.Query("endpoint")
|
||||
dim.EndpointType = c.DefaultQuery("endpoint_type", "inbound")
|
||||
|
||||
// Additional filter conditions
|
||||
if v := c.Query("user_id"); v != "" {
|
||||
if id, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
dim.UserID = id
|
||||
}
|
||||
}
|
||||
if v := c.Query("api_key_id"); v != "" {
|
||||
if id, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
dim.APIKeyID = id
|
||||
}
|
||||
}
|
||||
if v := c.Query("account_id"); v != "" {
|
||||
if id, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
dim.AccountID = id
|
||||
}
|
||||
}
|
||||
if v := c.Query("request_type"); v != "" {
|
||||
if rt, err := strconv.ParseInt(v, 10, 16); err == nil {
|
||||
rtVal := int16(rt)
|
||||
dim.RequestType = &rtVal
|
||||
}
|
||||
}
|
||||
if v := c.Query("stream"); v != "" {
|
||||
if s, err := strconv.ParseBool(v); err == nil {
|
||||
dim.Stream = &s
|
||||
}
|
||||
}
|
||||
if v := c.Query("billing_type"); v != "" {
|
||||
if bt, err := strconv.ParseInt(v, 10, 8); err == nil {
|
||||
btVal := int8(bt)
|
||||
dim.BillingType = &btVal
|
||||
}
|
||||
}
|
||||
|
||||
limit := 50
|
||||
if v := c.Query("limit"); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= 200 {
|
||||
|
||||
@@ -110,6 +110,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
}
|
||||
|
||||
model := c.Query("model")
|
||||
billingMode := strings.TrimSpace(c.Query("billing_mode"))
|
||||
|
||||
var requestType *int16
|
||||
var stream *bool
|
||||
@@ -174,6 +175,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
BillingMode: billingMode,
|
||||
StartTime: startTime,
|
||||
EndTime: endTime,
|
||||
ExactTotal: exactTotal,
|
||||
@@ -234,6 +236,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
}
|
||||
|
||||
model := c.Query("model")
|
||||
billingMode := strings.TrimSpace(c.Query("billing_mode"))
|
||||
|
||||
var requestType *int16
|
||||
var stream *bool
|
||||
@@ -312,6 +315,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
BillingMode: billingMode,
|
||||
StartTime: &startTime,
|
||||
EndTime: &endTime,
|
||||
}
|
||||
|
||||
@@ -577,6 +577,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
MediaType: l.MediaType,
|
||||
UserAgent: l.UserAgent,
|
||||
CacheTTLOverridden: l.CacheTTLOverridden,
|
||||
BillingMode: l.BillingMode,
|
||||
CreatedAt: l.CreatedAt,
|
||||
User: UserFromServiceShallow(l.User),
|
||||
APIKey: APIKeyFromService(l.APIKey),
|
||||
@@ -604,6 +605,9 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog {
|
||||
return &AdminUsageLog{
|
||||
UsageLog: usageLogFromServiceUser(l),
|
||||
UpstreamModel: l.UpstreamModel,
|
||||
ChannelID: l.ChannelID,
|
||||
ModelMappingChain: l.ModelMappingChain,
|
||||
BillingTier: l.BillingTier,
|
||||
AccountRateMultiplier: l.AccountRateMultiplier,
|
||||
IPAddress: l.IPAddress,
|
||||
Account: AccountSummaryFromService(l.Account),
|
||||
|
||||
@@ -390,6 +390,9 @@ type UsageLog struct {
|
||||
// Cache TTL Override 标记
|
||||
CacheTTLOverridden bool `json:"cache_ttl_overridden"`
|
||||
|
||||
// BillingMode 计费模式:token/image
|
||||
BillingMode *string `json:"billing_mode,omitempty"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
@@ -406,6 +409,13 @@ type AdminUsageLog struct {
|
||||
// Omitted when no mapping was applied (requested model was used as-is).
|
||||
UpstreamModel *string `json:"upstream_model,omitempty"`
|
||||
|
||||
// ChannelID 渠道 ID
|
||||
ChannelID *int64 `json:"channel_id,omitempty"`
|
||||
// ModelMappingChain 模型映射链,如 "a→b→c"
|
||||
ModelMappingChain *string `json:"model_mapping_chain,omitempty"`
|
||||
// BillingTier 计费层级标签(per_request/image 模式)
|
||||
BillingTier *string `json:"billing_tier,omitempty"`
|
||||
|
||||
// AccountRateMultiplier 账号计费倍率快照(nil 表示按 1.0 处理)
|
||||
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
|
||||
|
||||
|
||||
@@ -158,6 +158,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
reqStream := parsedReq.Stream
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
// 解析渠道级模型映射
|
||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
|
||||
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
|
||||
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
|
||||
if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) {
|
||||
@@ -292,7 +295,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "") // Gemini 不使用会话限制
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制
|
||||
if err != nil {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
@@ -478,6 +481,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.messages"),
|
||||
@@ -514,7 +518,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
for {
|
||||
// 选择支持该模型的账号
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID)
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, int64(0))
|
||||
if err != nil {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
@@ -660,6 +664,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
parsedReq.OnUpstreamAccepted = queueRelease
|
||||
// ===== 用户消息串行队列 END =====
|
||||
|
||||
// 应用渠道模型映射到请求
|
||||
if channelMapping.Mapped {
|
||||
parsedReq.Model = channelMapping.MappedModel
|
||||
parsedReq.Body = h.gatewayService.ReplaceModelInBody(parsedReq.Body, channelMapping.MappedModel)
|
||||
body = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
||||
}
|
||||
|
||||
// 转发请求 - 根据账号平台分流
|
||||
var result *service.ForwardResult
|
||||
requestCtx := c.Request.Context()
|
||||
@@ -810,6 +821,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.messages"),
|
||||
|
||||
@@ -80,6 +80,9 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
// 解析渠道级模型映射
|
||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
|
||||
// Claude Code only restriction
|
||||
if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly {
|
||||
h.chatCompletionsErrorResponse(c, http.StatusForbidden, "permission_error",
|
||||
@@ -154,7 +157,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
fs := NewFailoverState(h.maxAccountSwitches, false)
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "")
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
|
||||
if err != nil {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
|
||||
@@ -203,7 +206,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
|
||||
// 5. Forward request
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, parsedReq)
|
||||
forwardBody := body
|
||||
if channelMapping.Mapped {
|
||||
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
||||
}
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq)
|
||||
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
@@ -255,6 +262,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
reqLog.Error("gateway.cc.record_usage_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
|
||||
@@ -80,6 +80,9 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
// 解析渠道级模型映射
|
||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
|
||||
// Claude Code only restriction:
|
||||
// /v1/responses is never a Claude Code endpoint.
|
||||
// When claude_code_only is enabled, this endpoint is rejected.
|
||||
@@ -159,7 +162,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
||||
fs := NewFailoverState(h.maxAccountSwitches, false)
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "")
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
|
||||
if err != nil {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
|
||||
@@ -208,7 +211,11 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
||||
|
||||
// 5. Forward request
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, body, parsedReq)
|
||||
forwardBody := body
|
||||
if channelMapping.Mapped {
|
||||
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
||||
}
|
||||
result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, forwardBody, parsedReq)
|
||||
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
@@ -261,6 +268,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
reqLog.Error("gateway.responses.record_usage_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
|
||||
@@ -161,6 +161,8 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
|
||||
nil, // digestStore
|
||||
nil, // settingService
|
||||
nil, // tlsFPProfileService
|
||||
nil, // channelService
|
||||
nil, // resolver
|
||||
)
|
||||
|
||||
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
|
||||
|
||||
@@ -184,6 +184,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
setOpsRequestContext(c, modelName, stream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
|
||||
|
||||
// 解析渠道级模型映射
|
||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName)
|
||||
reqModel := modelName // 保存映射前的原始模型名
|
||||
if channelMapping.Mapped {
|
||||
modelName = channelMapping.MappedModel
|
||||
}
|
||||
|
||||
// Get subscription (may be nil)
|
||||
subscription, _ := middleware.GetSubscriptionFromContext(c)
|
||||
|
||||
@@ -353,7 +360,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "") // Gemini 不使用会话限制
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制
|
||||
if err != nil {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
@@ -523,6 +530,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gemini_v1beta.models"),
|
||||
|
||||
@@ -30,6 +30,7 @@ type AdminHandlers struct {
|
||||
TLSFingerprintProfile *admin.TLSFingerprintProfileHandler
|
||||
APIKey *admin.AdminAPIKeyHandler
|
||||
ScheduledTest *admin.ScheduledTestHandler
|
||||
Channel *admin.ChannelHandler
|
||||
}
|
||||
|
||||
// Handlers contains all HTTP handlers
|
||||
|
||||
@@ -79,6 +79,9 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
// 解析渠道级模型映射
|
||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
}
|
||||
@@ -183,7 +186,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
forwardStart := time.Now()
|
||||
|
||||
defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_chat_completions_fallback_model"))
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
forwardBody := body
|
||||
if channelMapping.Mapped {
|
||||
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
||||
}
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel)
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
@@ -257,16 +264,17 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.chat_completions"),
|
||||
|
||||
@@ -185,6 +185,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
// 解析渠道级模型映射
|
||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
|
||||
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
|
||||
if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
|
||||
return
|
||||
@@ -284,7 +287,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// Forward request
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
||||
// 应用渠道模型映射到请求体
|
||||
forwardBody := body
|
||||
if channelMapping.Mapped {
|
||||
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
||||
}
|
||||
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody)
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
@@ -379,6 +387,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.responses"),
|
||||
@@ -549,6 +558,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
// 解析渠道级模型映射
|
||||
channelMappingMsg, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
|
||||
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
@@ -673,7 +685,12 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
// Forward 层需要始终拿到 group 默认映射模型,这样未命中账号级映射的
|
||||
// Claude 兼容模型才不会在后续 Codex 规范化中意外退化到 gpt-5.1。
|
||||
defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_messages_fallback_model"))
|
||||
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
// 应用渠道模型映射到请求体
|
||||
forwardBody := body
|
||||
if channelMappingMsg.Mapped {
|
||||
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMappingMsg.MappedModel)
|
||||
}
|
||||
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel)
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
@@ -759,6 +776,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelUsageFields: channelMappingMsg.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.messages"),
|
||||
@@ -1101,6 +1119,9 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
setOpsRequestContext(c, reqModel, true, firstMessage)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
|
||||
|
||||
// 解析渠道级模型映射
|
||||
channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel)
|
||||
|
||||
var currentUserRelease func()
|
||||
var currentAccountRelease func()
|
||||
releaseTurnSlots := func() {
|
||||
@@ -1259,6 +1280,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelUsageFields: channelMappingWS.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
reqLog.Error("openai.websocket_record_usage_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
@@ -1270,7 +1292,13 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
},
|
||||
}
|
||||
|
||||
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, firstMessage, hooks); err != nil {
|
||||
// 应用渠道模型映射到 WebSocket 首条消息
|
||||
wsFirstMessage := firstMessage
|
||||
if channelMappingWS.Mapped {
|
||||
wsFirstMessage = h.gatewayService.ReplaceModelInBody(firstMessage, channelMappingWS.MappedModel)
|
||||
}
|
||||
|
||||
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, wsFirstMessage, hooks); err != nil {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
|
||||
reqLog.Warn("openai.websocket_proxy_failed",
|
||||
|
||||
@@ -2225,6 +2225,7 @@ func newMinimalGatewayService(accountRepo service.AccountRepository) *service.Ga
|
||||
return service.NewGatewayService(
|
||||
accountRepo, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
nil, nil,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -30,6 +30,8 @@ import (
|
||||
)
|
||||
|
||||
// SoraGatewayHandler handles Sora chat completions requests
|
||||
//
|
||||
// NOTE: Sora 平台计划后续移除,不集成渠道(Channel)功能。
|
||||
type SoraGatewayHandler struct {
|
||||
gatewayService *service.GatewayService
|
||||
soraGatewayService *service.SoraGatewayService
|
||||
@@ -226,7 +228,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
var lastFailoverHeaders http.Header
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "")
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "", int64(0))
|
||||
if err != nil {
|
||||
reqLog.Warn("sora.account_select_failed",
|
||||
zap.Error(err),
|
||||
|
||||
@@ -465,6 +465,8 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
|
||||
nil, // digestStore
|
||||
nil, // settingService
|
||||
nil, // tlsFPProfileService
|
||||
nil, // channelService
|
||||
nil, // resolver
|
||||
)
|
||||
|
||||
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}
|
||||
|
||||
@@ -33,6 +33,7 @@ func ProvideAdminHandlers(
|
||||
tlsFingerprintProfileHandler *admin.TLSFingerprintProfileHandler,
|
||||
apiKeyHandler *admin.AdminAPIKeyHandler,
|
||||
scheduledTestHandler *admin.ScheduledTestHandler,
|
||||
channelHandler *admin.ChannelHandler,
|
||||
) *AdminHandlers {
|
||||
return &AdminHandlers{
|
||||
Dashboard: dashboardHandler,
|
||||
@@ -59,6 +60,7 @@ func ProvideAdminHandlers(
|
||||
TLSFingerprintProfile: tlsFingerprintProfileHandler,
|
||||
APIKey: apiKeyHandler,
|
||||
ScheduledTest: scheduledTestHandler,
|
||||
Channel: channelHandler,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -150,6 +152,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewTLSFingerprintProfileHandler,
|
||||
admin.NewAdminAPIKeyHandler,
|
||||
admin.NewScheduledTestHandler,
|
||||
admin.NewChannelHandler,
|
||||
|
||||
// AdminHandlers and Handlers constructors
|
||||
ProvideAdminHandlers,
|
||||
|
||||
@@ -125,6 +125,7 @@ type ClaudeUsage struct {
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
|
||||
ImageOutputTokens int `json:"image_output_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// ClaudeError Claude 错误响应
|
||||
|
||||
@@ -149,13 +149,31 @@ type GeminiCandidate struct {
|
||||
GroundingMetadata *GeminiGroundingMetadata `json:"groundingMetadata,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiTokenDetail Gemini token 详情(按模态分类)
|
||||
type GeminiTokenDetail struct {
|
||||
Modality string `json:"modality"`
|
||||
TokenCount int `json:"tokenCount"`
|
||||
}
|
||||
|
||||
// GeminiUsageMetadata Gemini 用量元数据
|
||||
type GeminiUsageMetadata struct {
|
||||
PromptTokenCount int `json:"promptTokenCount,omitempty"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
|
||||
CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"`
|
||||
TotalTokenCount int `json:"totalTokenCount,omitempty"`
|
||||
ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"` // thinking tokens(按输出价格计费)
|
||||
PromptTokenCount int `json:"promptTokenCount,omitempty"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
|
||||
CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"`
|
||||
TotalTokenCount int `json:"totalTokenCount,omitempty"`
|
||||
ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"` // thinking tokens(按输出价格计费)
|
||||
CandidatesTokensDetails []GeminiTokenDetail `json:"candidatesTokensDetails,omitempty"`
|
||||
PromptTokensDetails []GeminiTokenDetail `json:"promptTokensDetails,omitempty"`
|
||||
}
|
||||
|
||||
// ImageOutputTokens 从 CandidatesTokensDetails 中提取 IMAGE 模态的 token 数
|
||||
func (m *GeminiUsageMetadata) ImageOutputTokens() int {
|
||||
for _, d := range m.CandidatesTokensDetails {
|
||||
if d.Modality == "IMAGE" {
|
||||
return d.TokenCount
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GeminiGroundingMetadata Gemini grounding 元数据(Web Search)
|
||||
|
||||
@@ -284,6 +284,7 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
|
||||
usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
|
||||
usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount
|
||||
usage.CacheReadInputTokens = cached
|
||||
usage.ImageOutputTokens = geminiResp.UsageMetadata.ImageOutputTokens()
|
||||
}
|
||||
|
||||
// 生成响应 ID
|
||||
|
||||
@@ -32,9 +32,10 @@ type StreamingProcessor struct {
|
||||
groundingChunks []GeminiGroundingChunk
|
||||
|
||||
// 累计 usage
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
cacheReadTokens int
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
cacheReadTokens int
|
||||
imageOutputTokens int
|
||||
}
|
||||
|
||||
// NewStreamingProcessor 创建流式响应处理器
|
||||
@@ -87,6 +88,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
||||
p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
|
||||
p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount
|
||||
p.cacheReadTokens = cached
|
||||
p.imageOutputTokens = geminiResp.UsageMetadata.ImageOutputTokens()
|
||||
}
|
||||
|
||||
// 处理 parts
|
||||
@@ -127,6 +129,7 @@ func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) {
|
||||
InputTokens: p.inputTokens,
|
||||
OutputTokens: p.outputTokens,
|
||||
CacheReadInputTokens: p.cacheReadTokens,
|
||||
ImageOutputTokens: p.imageOutputTokens,
|
||||
}
|
||||
|
||||
if !p.messageStartSent {
|
||||
@@ -158,6 +161,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
|
||||
usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached
|
||||
usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount + v1Resp.Response.UsageMetadata.ThoughtsTokenCount
|
||||
usage.CacheReadInputTokens = cached
|
||||
usage.ImageOutputTokens = v1Resp.Response.UsageMetadata.ImageOutputTokens()
|
||||
}
|
||||
|
||||
responseID := v1Resp.ResponseID
|
||||
@@ -485,6 +489,7 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
|
||||
InputTokens: p.inputTokens,
|
||||
OutputTokens: p.outputTokens,
|
||||
CacheReadInputTokens: p.cacheReadTokens,
|
||||
ImageOutputTokens: p.imageOutputTokens,
|
||||
}
|
||||
|
||||
deltaEvent := map[string]any{
|
||||
|
||||
@@ -175,6 +175,13 @@ type UserBreakdownDimension struct {
|
||||
ModelType string // "requested", "upstream", or "mapping"
|
||||
Endpoint string // filter by endpoint value (non-empty to enable)
|
||||
EndpointType string // "inbound", "upstream", or "path"
|
||||
// Additional filter conditions
|
||||
UserID int64 // filter by user_id (>0 to enable)
|
||||
APIKeyID int64 // filter by api_key_id (>0 to enable)
|
||||
AccountID int64 // filter by account_id (>0 to enable)
|
||||
RequestType *int16 // filter by request_type (non-nil to enable)
|
||||
Stream *bool // filter by stream flag (non-nil to enable)
|
||||
BillingType *int8 // filter by billing_type (non-nil to enable)
|
||||
}
|
||||
|
||||
// APIKeyUsageTrendPoint represents API key usage trend data point
|
||||
@@ -230,6 +237,7 @@ type UsageLogFilters struct {
|
||||
RequestType *int16
|
||||
Stream *bool
|
||||
BillingType *int8
|
||||
BillingMode string
|
||||
StartTime *time.Time
|
||||
EndTime *time.Time
|
||||
// ExactTotal requests exact COUNT(*) for pagination. Default false for fast large-table paging.
|
||||
|
||||
461
backend/internal/repository/channel_repo.go
Normal file
461
backend/internal/repository/channel_repo.go
Normal file
@@ -0,0 +1,461 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
type channelRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewChannelRepository 创建渠道数据访问实例
|
||||
func NewChannelRepository(db *sql.DB) service.ChannelRepository {
|
||||
return &channelRepository{db: db}
|
||||
}
|
||||
|
||||
// runInTx 在事务中执行 fn,成功 commit,失败 rollback。
|
||||
func (r *channelRepository) runInTx(ctx context.Context, fn func(tx *sql.Tx) error) error {
|
||||
tx, err := r.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin tx: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
if err := fn(tx); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (r *channelRepository) Create(ctx context.Context, channel *service.Channel) error {
|
||||
return r.runInTx(ctx, func(tx *sql.Tx) error {
|
||||
modelMappingJSON, err := marshalModelMapping(channel.ModelMapping)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = tx.QueryRowContext(ctx,
|
||||
`INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models) VALUES ($1, $2, $3, $4, $5, $6)
|
||||
RETURNING id, created_at, updated_at`,
|
||||
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels,
|
||||
).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
|
||||
if err != nil {
|
||||
if isUniqueViolation(err) {
|
||||
return service.ErrChannelExists
|
||||
}
|
||||
return fmt.Errorf("insert channel: %w", err)
|
||||
}
|
||||
|
||||
// 设置分组关联
|
||||
if len(channel.GroupIDs) > 0 {
|
||||
if err := setGroupIDsTx(ctx, tx, channel.ID, channel.GroupIDs); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 设置模型定价
|
||||
if len(channel.ModelPricing) > 0 {
|
||||
if err := replaceModelPricingTx(ctx, tx, channel.ID, channel.ModelPricing); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) {
|
||||
ch := &service.Channel{}
|
||||
var modelMappingJSON []byte
|
||||
err := r.db.QueryRowContext(ctx,
|
||||
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at
|
||||
FROM channels WHERE id = $1`, id,
|
||||
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, service.ErrChannelNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get channel: %w", err)
|
||||
}
|
||||
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
||||
|
||||
groupIDs, err := r.GetGroupIDs(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ch.GroupIDs = groupIDs
|
||||
|
||||
pricing, err := r.ListModelPricing(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ch.ModelPricing = pricing
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (r *channelRepository) Update(ctx context.Context, channel *service.Channel) error {
|
||||
return r.runInTx(ctx, func(tx *sql.Tx) error {
|
||||
modelMappingJSON, err := marshalModelMapping(channel.ModelMapping)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result, err := tx.ExecContext(ctx,
|
||||
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, updated_at = NOW()
|
||||
WHERE id = $7`,
|
||||
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.ID,
|
||||
)
|
||||
if err != nil {
|
||||
if isUniqueViolation(err) {
|
||||
return service.ErrChannelExists
|
||||
}
|
||||
return fmt.Errorf("update channel: %w", err)
|
||||
}
|
||||
rows, _ := result.RowsAffected()
|
||||
if rows == 0 {
|
||||
return service.ErrChannelNotFound
|
||||
}
|
||||
|
||||
// 更新分组关联
|
||||
if channel.GroupIDs != nil {
|
||||
if err := setGroupIDsTx(ctx, tx, channel.ID, channel.GroupIDs); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 更新模型定价
|
||||
if channel.ModelPricing != nil {
|
||||
if err := replaceModelPricingTx(ctx, tx, channel.ID, channel.ModelPricing); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (r *channelRepository) Delete(ctx context.Context, id int64) error {
|
||||
result, err := r.db.ExecContext(ctx, `DELETE FROM channels WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete channel: %w", err)
|
||||
}
|
||||
rows, _ := result.RowsAffected()
|
||||
if rows == 0 {
|
||||
return service.ErrChannelNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *channelRepository) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]service.Channel, *pagination.PaginationResult, error) {
|
||||
where := []string{"1=1"}
|
||||
args := []any{}
|
||||
argIdx := 1
|
||||
|
||||
if status != "" {
|
||||
where = append(where, fmt.Sprintf("c.status = $%d", argIdx))
|
||||
args = append(args, status)
|
||||
argIdx++
|
||||
}
|
||||
if search != "" {
|
||||
where = append(where, fmt.Sprintf("(c.name ILIKE $%d OR c.description ILIKE $%d)", argIdx, argIdx))
|
||||
args = append(args, "%"+escapeLike(search)+"%")
|
||||
argIdx++
|
||||
}
|
||||
|
||||
whereClause := strings.Join(where, " AND ")
|
||||
|
||||
// 计数
|
||||
var total int64
|
||||
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM channels c WHERE %s", whereClause)
|
||||
if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
return nil, nil, fmt.Errorf("count channels: %w", err)
|
||||
}
|
||||
|
||||
pageSize := params.Limit() // 约束在 [1, 100]
|
||||
page := params.Page
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
offset := (page - 1) * pageSize
|
||||
|
||||
// 查询 channel 列表
|
||||
dataQuery := fmt.Sprintf(
|
||||
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.created_at, c.updated_at
|
||||
FROM channels c WHERE %s ORDER BY c.id ASC LIMIT $%d OFFSET $%d`,
|
||||
whereClause, argIdx, argIdx+1,
|
||||
)
|
||||
args = append(args, pageSize, offset)
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, dataQuery, args...)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("query channels: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var channels []service.Channel
|
||||
var channelIDs []int64
|
||||
for rows.Next() {
|
||||
var ch service.Channel
|
||||
var modelMappingJSON []byte
|
||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||
return nil, nil, fmt.Errorf("scan channel: %w", err)
|
||||
}
|
||||
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
||||
channels = append(channels, ch)
|
||||
channelIDs = append(channelIDs, ch.ID)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, nil, fmt.Errorf("iterate channels: %w", err)
|
||||
}
|
||||
|
||||
// 批量加载分组 ID 和模型定价(避免 N+1)
|
||||
if len(channelIDs) > 0 {
|
||||
groupMap, err := r.batchLoadGroupIDs(ctx, channelIDs)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
pricingMap, err := r.batchLoadModelPricing(ctx, channelIDs)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for i := range channels {
|
||||
channels[i].GroupIDs = groupMap[channels[i].ID]
|
||||
channels[i].ModelPricing = pricingMap[channels[i].ID]
|
||||
}
|
||||
}
|
||||
|
||||
pages := 0
|
||||
if total > 0 {
|
||||
pages = int((total + int64(pageSize) - 1) / int64(pageSize))
|
||||
}
|
||||
|
||||
paginationResult := &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
Pages: pages,
|
||||
}
|
||||
|
||||
return channels, paginationResult, nil
|
||||
}
|
||||
|
||||
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at FROM channels ORDER BY id`,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query all channels: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var channels []service.Channel
|
||||
var channelIDs []int64
|
||||
for rows.Next() {
|
||||
var ch service.Channel
|
||||
var modelMappingJSON []byte
|
||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan channel: %w", err)
|
||||
}
|
||||
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
||||
channels = append(channels, ch)
|
||||
channelIDs = append(channelIDs, ch.ID)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate channels: %w", err)
|
||||
}
|
||||
|
||||
if len(channelIDs) == 0 {
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
// 批量加载分组 ID
|
||||
groupMap, err := r.batchLoadGroupIDs(ctx, channelIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 批量加载模型定价
|
||||
pricingMap, err := r.batchLoadModelPricing(ctx, channelIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range channels {
|
||||
channels[i].GroupIDs = groupMap[channels[i].ID]
|
||||
channels[i].ModelPricing = pricingMap[channels[i].ID]
|
||||
}
|
||||
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
// --- 批量加载辅助方法 ---
|
||||
|
||||
// batchLoadGroupIDs 批量加载多个渠道的分组 ID
|
||||
func (r *channelRepository) batchLoadGroupIDs(ctx context.Context, channelIDs []int64) (map[int64][]int64, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT channel_id, group_id FROM channel_groups
|
||||
WHERE channel_id = ANY($1) ORDER BY channel_id, group_id`,
|
||||
pq.Array(channelIDs),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("batch load group ids: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
groupMap := make(map[int64][]int64, len(channelIDs))
|
||||
for rows.Next() {
|
||||
var channelID, groupID int64
|
||||
if err := rows.Scan(&channelID, &groupID); err != nil {
|
||||
return nil, fmt.Errorf("scan group id: %w", err)
|
||||
}
|
||||
groupMap[channelID] = append(groupMap[channelID], groupID)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate group ids: %w", err)
|
||||
}
|
||||
return groupMap, nil
|
||||
}
|
||||
|
||||
func (r *channelRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||
var exists bool
|
||||
err := r.db.QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM channels WHERE name = $1)`, name,
|
||||
).Scan(&exists)
|
||||
return exists, err
|
||||
}
|
||||
|
||||
func (r *channelRepository) ExistsByNameExcluding(ctx context.Context, name string, excludeID int64) (bool, error) {
|
||||
var exists bool
|
||||
err := r.db.QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM channels WHERE name = $1 AND id != $2)`, name, excludeID,
|
||||
).Scan(&exists)
|
||||
return exists, err
|
||||
}
|
||||
|
||||
// --- 分组关联 ---
|
||||
|
||||
func (r *channelRepository) GetGroupIDs(ctx context.Context, channelID int64) ([]int64, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT group_id FROM channel_groups WHERE channel_id = $1 ORDER BY group_id`, channelID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group ids: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var ids []int64
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, fmt.Errorf("scan group id: %w", err)
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate group ids: %w", err)
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (r *channelRepository) SetGroupIDs(ctx context.Context, channelID int64, groupIDs []int64) error {
|
||||
return setGroupIDsTx(ctx, r.db, channelID, groupIDs)
|
||||
}
|
||||
|
||||
func (r *channelRepository) GetChannelIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
var channelID int64
|
||||
err := r.db.QueryRowContext(ctx,
|
||||
`SELECT channel_id FROM channel_groups WHERE group_id = $1`, groupID,
|
||||
).Scan(&channelID)
|
||||
if err == sql.ErrNoRows {
|
||||
return 0, nil
|
||||
}
|
||||
return channelID, err
|
||||
}
|
||||
|
||||
func (r *channelRepository) GetGroupsInOtherChannels(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT group_id FROM channel_groups WHERE group_id = ANY($1) AND channel_id != $2`,
|
||||
pq.Array(groupIDs), channelID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get groups in other channels: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var conflicting []int64
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, fmt.Errorf("scan conflicting group id: %w", err)
|
||||
}
|
||||
conflicting = append(conflicting, id)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate conflicting group ids: %w", err)
|
||||
}
|
||||
return conflicting, nil
|
||||
}
|
||||
|
||||
// marshalModelMapping 将 model mapping 序列化为嵌套 JSON 字节
|
||||
// 格式:{"platform": {"src": "dst"}, ...}
|
||||
func marshalModelMapping(m map[string]map[string]string) ([]byte, error) {
|
||||
if len(m) == 0 {
|
||||
return []byte("{}"), nil
|
||||
}
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal model_mapping: %w", err)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// unmarshalModelMapping 将 JSON 字节反序列化为嵌套 model mapping
|
||||
func unmarshalModelMapping(data []byte) map[string]map[string]string {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
var m map[string]map[string]string
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
return nil
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// GetGroupPlatforms 批量查询分组 ID 对应的平台
|
||||
func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
return make(map[int64]string), nil
|
||||
}
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT id, platform FROM groups WHERE id = ANY($1)`,
|
||||
pq.Array(groupIDs),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group platforms: %w", err)
|
||||
}
|
||||
defer rows.Close() //nolint:errcheck
|
||||
|
||||
result := make(map[int64]string, len(groupIDs))
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
var platform string
|
||||
if err := rows.Scan(&id, &platform); err != nil {
|
||||
return nil, fmt.Errorf("scan group platform: %w", err)
|
||||
}
|
||||
result[id] = platform
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate group platforms: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
291
backend/internal/repository/channel_repo_pricing.go
Normal file
291
backend/internal/repository/channel_repo_pricing.go
Normal file
@@ -0,0 +1,291 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
// --- 模型定价 ---
|
||||
|
||||
func (r *channelRepository) ListModelPricing(ctx context.Context, channelID int64) ([]service.ChannelModelPricing, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT id, channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
|
||||
FROM channel_model_pricing WHERE channel_id = $1 ORDER BY id`, channelID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list model pricing: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
result, pricingIDs, err := scanModelPricingRows(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(pricingIDs) > 0 {
|
||||
intervalMap, err := r.batchLoadIntervals(ctx, pricingIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range result {
|
||||
result[i].Intervals = intervalMap[result[i].ID]
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *channelRepository) CreateModelPricing(ctx context.Context, pricing *service.ChannelModelPricing) error {
|
||||
return createModelPricingExec(ctx, r.db, pricing)
|
||||
}
|
||||
|
||||
func (r *channelRepository) UpdateModelPricing(ctx context.Context, pricing *service.ChannelModelPricing) error {
|
||||
modelsJSON, err := json.Marshal(pricing.Models)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal models: %w", err)
|
||||
}
|
||||
billingMode := pricing.BillingMode
|
||||
if billingMode == "" {
|
||||
billingMode = service.BillingModeToken
|
||||
}
|
||||
result, err := r.db.ExecContext(ctx,
|
||||
`UPDATE channel_model_pricing
|
||||
SET models = $1, billing_mode = $2, input_price = $3, output_price = $4, cache_write_price = $5, cache_read_price = $6, image_output_price = $7, per_request_price = $8, platform = $9, updated_at = NOW()
|
||||
WHERE id = $10`,
|
||||
modelsJSON, billingMode, pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
|
||||
pricing.ImageOutputPrice, pricing.PerRequestPrice, pricing.Platform, pricing.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update model pricing: %w", err)
|
||||
}
|
||||
rows, _ := result.RowsAffected()
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("pricing entry not found: %d", pricing.ID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *channelRepository) DeleteModelPricing(ctx context.Context, id int64) error {
|
||||
_, err := r.db.ExecContext(ctx, `DELETE FROM channel_model_pricing WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete model pricing: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *channelRepository) ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []service.ChannelModelPricing) error {
|
||||
return r.runInTx(ctx, func(tx *sql.Tx) error {
|
||||
return replaceModelPricingTx(ctx, tx, channelID, pricingList)
|
||||
})
|
||||
}
|
||||
|
||||
// --- 批量加载辅助方法 ---
|
||||
|
||||
// batchLoadModelPricing 批量加载多个渠道的模型定价(含区间)
|
||||
func (r *channelRepository) batchLoadModelPricing(ctx context.Context, channelIDs []int64) (map[int64][]service.ChannelModelPricing, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT id, channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
|
||||
FROM channel_model_pricing WHERE channel_id = ANY($1) ORDER BY channel_id, id`,
|
||||
pq.Array(channelIDs),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("batch load model pricing: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
allPricing, allPricingIDs, err := scanModelPricingRows(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 按 channelID 分组
|
||||
pricingMap := make(map[int64][]service.ChannelModelPricing, len(channelIDs))
|
||||
for _, p := range allPricing {
|
||||
pricingMap[p.ChannelID] = append(pricingMap[p.ChannelID], p)
|
||||
}
|
||||
|
||||
// 批量加载所有区间
|
||||
if len(allPricingIDs) > 0 {
|
||||
intervalMap, err := r.batchLoadIntervals(ctx, allPricingIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for chID := range pricingMap {
|
||||
for i := range pricingMap[chID] {
|
||||
pricingMap[chID][i].Intervals = intervalMap[pricingMap[chID][i].ID]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return pricingMap, nil
|
||||
}
|
||||
|
||||
// batchLoadIntervals 批量加载多个定价条目的区间
|
||||
func (r *channelRepository) batchLoadIntervals(ctx context.Context, pricingIDs []int64) (map[int64][]service.PricingInterval, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT id, pricing_id, min_tokens, max_tokens, tier_label,
|
||||
input_price, output_price, cache_write_price, cache_read_price,
|
||||
per_request_price, sort_order, created_at, updated_at
|
||||
FROM channel_pricing_intervals
|
||||
WHERE pricing_id = ANY($1) ORDER BY pricing_id, sort_order, id`,
|
||||
pq.Array(pricingIDs),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("batch load intervals: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
intervalMap := make(map[int64][]service.PricingInterval, len(pricingIDs))
|
||||
for rows.Next() {
|
||||
var iv service.PricingInterval
|
||||
if err := rows.Scan(
|
||||
&iv.ID, &iv.PricingID, &iv.MinTokens, &iv.MaxTokens, &iv.TierLabel,
|
||||
&iv.InputPrice, &iv.OutputPrice, &iv.CacheWritePrice, &iv.CacheReadPrice,
|
||||
&iv.PerRequestPrice, &iv.SortOrder, &iv.CreatedAt, &iv.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("scan interval: %w", err)
|
||||
}
|
||||
intervalMap[iv.PricingID] = append(intervalMap[iv.PricingID], iv)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate intervals: %w", err)
|
||||
}
|
||||
return intervalMap, nil
|
||||
}
|
||||
|
||||
// --- 共享 scan 辅助 ---
|
||||
|
||||
// scanModelPricingRows 扫描 model pricing 行,返回结果列表和 ID 列表
|
||||
func scanModelPricingRows(rows *sql.Rows) ([]service.ChannelModelPricing, []int64, error) {
|
||||
var result []service.ChannelModelPricing
|
||||
var pricingIDs []int64
|
||||
for rows.Next() {
|
||||
var p service.ChannelModelPricing
|
||||
var modelsJSON []byte
|
||||
if err := rows.Scan(
|
||||
&p.ID, &p.ChannelID, &p.Platform, &modelsJSON, &p.BillingMode,
|
||||
&p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice,
|
||||
&p.ImageOutputPrice, &p.PerRequestPrice, &p.CreatedAt, &p.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, nil, fmt.Errorf("scan model pricing: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(modelsJSON, &p.Models); err != nil {
|
||||
p.Models = []string{}
|
||||
}
|
||||
pricingIDs = append(pricingIDs, p.ID)
|
||||
result = append(result, p)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, nil, fmt.Errorf("iterate model pricing: %w", err)
|
||||
}
|
||||
return result, pricingIDs, nil
|
||||
}
|
||||
|
||||
// --- 事务内辅助方法 ---
|
||||
|
||||
// dbExec 是 *sql.DB 和 *sql.Tx 共享的最小 SQL 执行接口
|
||||
type dbExec interface {
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
|
||||
}
|
||||
|
||||
func setGroupIDsTx(ctx context.Context, exec dbExec, channelID int64, groupIDs []int64) error {
|
||||
if _, err := exec.ExecContext(ctx, `DELETE FROM channel_groups WHERE channel_id = $1`, channelID); err != nil {
|
||||
return fmt.Errorf("delete old group associations: %w", err)
|
||||
}
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
_, err := exec.ExecContext(ctx,
|
||||
`INSERT INTO channel_groups (channel_id, group_id)
|
||||
SELECT $1, unnest($2::bigint[])`,
|
||||
channelID, pq.Array(groupIDs),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert group associations: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func createModelPricingExec(ctx context.Context, exec dbExec, pricing *service.ChannelModelPricing) error {
|
||||
modelsJSON, err := json.Marshal(pricing.Models)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal models: %w", err)
|
||||
}
|
||||
billingMode := pricing.BillingMode
|
||||
if billingMode == "" {
|
||||
billingMode = service.BillingModeToken
|
||||
}
|
||||
platform := pricing.Platform
|
||||
if platform == "" {
|
||||
platform = "anthropic"
|
||||
}
|
||||
err = exec.QueryRowContext(ctx,
|
||||
`INSERT INTO channel_model_pricing (channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
|
||||
pricing.ChannelID, platform, modelsJSON, billingMode,
|
||||
pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
|
||||
pricing.ImageOutputPrice, pricing.PerRequestPrice,
|
||||
).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert model pricing: %w", err)
|
||||
}
|
||||
|
||||
for i := range pricing.Intervals {
|
||||
pricing.Intervals[i].PricingID = pricing.ID
|
||||
if err := createIntervalExec(ctx, exec, &pricing.Intervals[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func createIntervalExec(ctx context.Context, exec dbExec, iv *service.PricingInterval) error {
|
||||
return exec.QueryRowContext(ctx,
|
||||
`INSERT INTO channel_pricing_intervals
|
||||
(pricing_id, min_tokens, max_tokens, tier_label, input_price, output_price, cache_write_price, cache_read_price, per_request_price, sort_order)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
|
||||
iv.PricingID, iv.MinTokens, iv.MaxTokens, iv.TierLabel,
|
||||
iv.InputPrice, iv.OutputPrice, iv.CacheWritePrice, iv.CacheReadPrice,
|
||||
iv.PerRequestPrice, iv.SortOrder,
|
||||
).Scan(&iv.ID, &iv.CreatedAt, &iv.UpdatedAt)
|
||||
}
|
||||
|
||||
func replaceModelPricingTx(ctx context.Context, exec dbExec, channelID int64, pricingList []service.ChannelModelPricing) error {
|
||||
if _, err := exec.ExecContext(ctx, `DELETE FROM channel_model_pricing WHERE channel_id = $1`, channelID); err != nil {
|
||||
return fmt.Errorf("delete old model pricing: %w", err)
|
||||
}
|
||||
for i := range pricingList {
|
||||
pricingList[i].ChannelID = channelID
|
||||
if err := createModelPricingExec(ctx, exec, &pricingList[i]); err != nil {
|
||||
return fmt.Errorf("insert model pricing: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// isUniqueViolation 检查 pq 唯一约束违反错误
|
||||
func isUniqueViolation(err error) bool {
|
||||
var pqErr *pq.Error
|
||||
if errors.As(err, &pqErr) && pqErr != nil {
|
||||
return pqErr.Code == "23505"
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// escapeLike 转义 LIKE/ILIKE 模式中的特殊字符
|
||||
func escapeLike(s string) string {
|
||||
s = strings.ReplaceAll(s, `\`, `\\`)
|
||||
s = strings.ReplaceAll(s, `%`, `\%`)
|
||||
s = strings.ReplaceAll(s, `_`, `\_`)
|
||||
return s
|
||||
}
|
||||
227
backend/internal/repository/channel_repo_test.go
Normal file
227
backend/internal/repository/channel_repo_test.go
Normal file
@@ -0,0 +1,227 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- marshalModelMapping ---
|
||||
|
||||
func TestMarshalModelMapping(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input map[string]map[string]string
|
||||
wantJSON string // expected JSON output (exact match)
|
||||
}{
|
||||
{
|
||||
name: "empty map",
|
||||
input: map[string]map[string]string{},
|
||||
wantJSON: "{}",
|
||||
},
|
||||
{
|
||||
name: "nil map",
|
||||
input: nil,
|
||||
wantJSON: "{}",
|
||||
},
|
||||
{
|
||||
name: "populated map",
|
||||
input: map[string]map[string]string{
|
||||
"openai": {"gpt-4": "gpt-4-turbo"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nested values",
|
||||
input: map[string]map[string]string{
|
||||
"openai": {"*": "gpt-5.4"},
|
||||
"anthropic": {"claude-old": "claude-new"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := marshalModelMapping(tt.input)
|
||||
require.NoError(t, err)
|
||||
|
||||
if tt.wantJSON != "" {
|
||||
require.Equal(t, []byte(tt.wantJSON), result)
|
||||
} else {
|
||||
// round-trip: unmarshal and compare with input
|
||||
var parsed map[string]map[string]string
|
||||
require.NoError(t, json.Unmarshal(result, &parsed))
|
||||
require.Equal(t, tt.input, parsed)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- unmarshalModelMapping ---
|
||||
|
||||
func TestUnmarshalModelMapping(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
wantNil bool
|
||||
want map[string]map[string]string
|
||||
}{
|
||||
{
|
||||
name: "nil data",
|
||||
input: nil,
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "empty data",
|
||||
input: []byte{},
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
input: []byte("not-json"),
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "type error - number",
|
||||
input: []byte("42"),
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "type error - array",
|
||||
input: []byte("[1,2,3]"),
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "valid JSON",
|
||||
input: []byte(`{"openai":{"gpt-4":"gpt-4-turbo"},"anthropic":{"old":"new"}}`),
|
||||
want: map[string]map[string]string{
|
||||
"openai": {"gpt-4": "gpt-4-turbo"},
|
||||
"anthropic": {"old": "new"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty object",
|
||||
input: []byte("{}"),
|
||||
want: map[string]map[string]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := unmarshalModelMapping(tt.input)
|
||||
if tt.wantNil {
|
||||
require.Nil(t, result)
|
||||
} else {
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, tt.want, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- escapeLike ---
|
||||
|
||||
func TestEscapeLike(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "no special chars",
|
||||
input: "hello",
|
||||
want: "hello",
|
||||
},
|
||||
{
|
||||
name: "backslash",
|
||||
input: `a\b`,
|
||||
want: `a\\b`,
|
||||
},
|
||||
{
|
||||
name: "percent",
|
||||
input: "50%",
|
||||
want: `50\%`,
|
||||
},
|
||||
{
|
||||
name: "underscore",
|
||||
input: "a_b",
|
||||
want: `a\_b`,
|
||||
},
|
||||
{
|
||||
name: "all special chars",
|
||||
input: `a\b%c_d`,
|
||||
want: `a\\b\%c\_d`,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "consecutive special chars",
|
||||
input: "%_%",
|
||||
want: `\%\_\%`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, escapeLike(tt.input))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- isUniqueViolation ---
|
||||
|
||||
func TestIsUniqueViolation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "unique violation code 23505",
|
||||
err: &pq.Error{Code: "23505"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "different pq error code",
|
||||
err: &pq.Error{Code: "23503"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "non-pq error",
|
||||
err: errors.New("some generic error"),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "typed nil pq.Error",
|
||||
err: func() error {
|
||||
var pqErr *pq.Error
|
||||
return pqErr
|
||||
}(),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "bare nil",
|
||||
err: nil,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "wrapped pq error with 23505",
|
||||
err: fmt.Errorf("wrapped: %w", &pq.Error{Code: "23505"}),
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, isUniqueViolation(tt.err))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -28,7 +28,7 @@ import (
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
)
|
||||
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at"
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at"
|
||||
|
||||
// usageLogInsertArgTypes must stay in the same order as:
|
||||
// 1. prepareUsageLogInsert().args
|
||||
@@ -53,6 +53,8 @@ var usageLogInsertArgTypes = [...]string{
|
||||
"integer", // cache_read_tokens
|
||||
"integer", // cache_creation_5m_tokens
|
||||
"integer", // cache_creation_1h_tokens
|
||||
"integer", // image_output_tokens
|
||||
"numeric", // image_output_cost
|
||||
"numeric", // input_cost
|
||||
"numeric", // output_cost
|
||||
"numeric", // cache_creation_cost
|
||||
@@ -77,6 +79,10 @@ var usageLogInsertArgTypes = [...]string{
|
||||
"text", // inbound_endpoint
|
||||
"text", // upstream_endpoint
|
||||
"boolean", // cache_ttl_overridden
|
||||
"bigint", // channel_id
|
||||
"text", // model_mapping_chain
|
||||
"text", // billing_tier
|
||||
"text", // billing_mode
|
||||
"timestamptz", // created_at
|
||||
}
|
||||
|
||||
@@ -326,6 +332,8 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
|
||||
cache_read_tokens,
|
||||
cache_creation_5m_tokens,
|
||||
cache_creation_1h_tokens,
|
||||
image_output_tokens,
|
||||
image_output_cost,
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
@@ -350,14 +358,18 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
channel_id,
|
||||
model_mapping_chain,
|
||||
billing_tier,
|
||||
billing_mode,
|
||||
created_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7,
|
||||
$8, $9,
|
||||
$10, $11, $12, $13,
|
||||
$14, $15,
|
||||
$16, $17, $18, $19, $20, $21,
|
||||
$22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40
|
||||
$14, $15, $16, $17,
|
||||
$18, $19, $20, $21, $22, $23,
|
||||
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
@@ -758,6 +770,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
cache_read_tokens,
|
||||
cache_creation_5m_tokens,
|
||||
cache_creation_1h_tokens,
|
||||
image_output_tokens,
|
||||
image_output_cost,
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
@@ -782,10 +796,14 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
channel_id,
|
||||
model_mapping_chain,
|
||||
billing_tier,
|
||||
billing_mode,
|
||||
created_at
|
||||
) AS (VALUES `)
|
||||
|
||||
args := make([]any, 0, len(keys)*39)
|
||||
args := make([]any, 0, len(keys)*47)
|
||||
argPos := 1
|
||||
for idx, key := range keys {
|
||||
if idx > 0 {
|
||||
@@ -829,6 +847,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
cache_read_tokens,
|
||||
cache_creation_5m_tokens,
|
||||
cache_creation_1h_tokens,
|
||||
image_output_tokens,
|
||||
image_output_cost,
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
@@ -853,6 +873,10 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
channel_id,
|
||||
model_mapping_chain,
|
||||
billing_tier,
|
||||
billing_mode,
|
||||
created_at
|
||||
)
|
||||
SELECT
|
||||
@@ -871,6 +895,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
cache_read_tokens,
|
||||
cache_creation_5m_tokens,
|
||||
cache_creation_1h_tokens,
|
||||
image_output_tokens,
|
||||
image_output_cost,
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
@@ -895,6 +921,10 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
channel_id,
|
||||
model_mapping_chain,
|
||||
billing_tier,
|
||||
billing_mode,
|
||||
created_at
|
||||
FROM input
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
@@ -953,6 +983,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
cache_read_tokens,
|
||||
cache_creation_5m_tokens,
|
||||
cache_creation_1h_tokens,
|
||||
image_output_tokens,
|
||||
image_output_cost,
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
@@ -977,10 +1009,14 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
channel_id,
|
||||
model_mapping_chain,
|
||||
billing_tier,
|
||||
billing_mode,
|
||||
created_at
|
||||
) AS (VALUES `)
|
||||
|
||||
args := make([]any, 0, len(preparedList)*40)
|
||||
args := make([]any, 0, len(preparedList)*46)
|
||||
argPos := 1
|
||||
for idx, prepared := range preparedList {
|
||||
if idx > 0 {
|
||||
@@ -1021,6 +1057,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
cache_read_tokens,
|
||||
cache_creation_5m_tokens,
|
||||
cache_creation_1h_tokens,
|
||||
image_output_tokens,
|
||||
image_output_cost,
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
@@ -1045,6 +1083,10 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
channel_id,
|
||||
model_mapping_chain,
|
||||
billing_tier,
|
||||
billing_mode,
|
||||
created_at
|
||||
)
|
||||
SELECT
|
||||
@@ -1063,6 +1105,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
cache_read_tokens,
|
||||
cache_creation_5m_tokens,
|
||||
cache_creation_1h_tokens,
|
||||
image_output_tokens,
|
||||
image_output_cost,
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
@@ -1087,6 +1131,10 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
channel_id,
|
||||
model_mapping_chain,
|
||||
billing_tier,
|
||||
billing_mode,
|
||||
created_at
|
||||
FROM input
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
@@ -1113,6 +1161,8 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
|
||||
cache_read_tokens,
|
||||
cache_creation_5m_tokens,
|
||||
cache_creation_1h_tokens,
|
||||
image_output_tokens,
|
||||
image_output_cost,
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
@@ -1137,14 +1187,18 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
channel_id,
|
||||
model_mapping_chain,
|
||||
billing_tier,
|
||||
billing_mode,
|
||||
created_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7,
|
||||
$8, $9,
|
||||
$10, $11, $12, $13,
|
||||
$14, $15,
|
||||
$16, $17, $18, $19, $20, $21,
|
||||
$22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40
|
||||
$14, $15, $16, $17,
|
||||
$18, $19, $20, $21, $22, $23,
|
||||
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
`, prepared.args...)
|
||||
@@ -1176,6 +1230,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
||||
reasoningEffort := nullString(log.ReasoningEffort)
|
||||
inboundEndpoint := nullString(log.InboundEndpoint)
|
||||
upstreamEndpoint := nullString(log.UpstreamEndpoint)
|
||||
channelID := nullInt64(log.ChannelID)
|
||||
modelMappingChain := nullString(log.ModelMappingChain)
|
||||
billingTier := nullString(log.BillingTier)
|
||||
billingMode := nullString(log.BillingMode)
|
||||
requestedModel := strings.TrimSpace(log.RequestedModel)
|
||||
if requestedModel == "" {
|
||||
requestedModel = strings.TrimSpace(log.Model)
|
||||
@@ -1208,6 +1266,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
||||
log.CacheReadTokens,
|
||||
log.CacheCreation5mTokens,
|
||||
log.CacheCreation1hTokens,
|
||||
log.ImageOutputTokens,
|
||||
log.ImageOutputCost,
|
||||
log.InputCost,
|
||||
log.OutputCost,
|
||||
log.CacheCreationCost,
|
||||
@@ -1232,6 +1292,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
||||
inboundEndpoint,
|
||||
upstreamEndpoint,
|
||||
log.CacheTTLOverridden,
|
||||
channelID,
|
||||
modelMappingChain,
|
||||
billingTier,
|
||||
billingMode,
|
||||
createdAt,
|
||||
},
|
||||
}
|
||||
@@ -2564,8 +2628,8 @@ type UsageLogFilters = usagestats.UsageLogFilters
|
||||
|
||||
// ListWithFilters lists usage logs with optional filters (for admin)
|
||||
func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
conditions := make([]string, 0, 8)
|
||||
args := make([]any, 0, 8)
|
||||
conditions := make([]string, 0, 9)
|
||||
args := make([]any, 0, 9)
|
||||
|
||||
if filters.UserID > 0 {
|
||||
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
|
||||
@@ -2589,6 +2653,10 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
|
||||
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
|
||||
args = append(args, int16(*filters.BillingType))
|
||||
}
|
||||
if filters.BillingMode != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("billing_mode = $%d", len(args)+1))
|
||||
args = append(args, filters.BillingMode)
|
||||
}
|
||||
if filters.StartTime != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
|
||||
args = append(args, *filters.StartTime)
|
||||
@@ -3096,6 +3164,30 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim
|
||||
query += fmt.Sprintf(" AND %s = $%d", col, len(args)+1)
|
||||
args = append(args, dim.Endpoint)
|
||||
}
|
||||
if dim.UserID > 0 {
|
||||
query += fmt.Sprintf(" AND ul.user_id = $%d", len(args)+1)
|
||||
args = append(args, dim.UserID)
|
||||
}
|
||||
if dim.APIKeyID > 0 {
|
||||
query += fmt.Sprintf(" AND ul.api_key_id = $%d", len(args)+1)
|
||||
args = append(args, dim.APIKeyID)
|
||||
}
|
||||
if dim.AccountID > 0 {
|
||||
query += fmt.Sprintf(" AND ul.account_id = $%d", len(args)+1)
|
||||
args = append(args, dim.AccountID)
|
||||
}
|
||||
if dim.RequestType != nil {
|
||||
query += fmt.Sprintf(" AND ul.request_type = $%d", len(args)+1)
|
||||
args = append(args, *dim.RequestType)
|
||||
}
|
||||
if dim.Stream != nil {
|
||||
query += fmt.Sprintf(" AND ul.stream = $%d", len(args)+1)
|
||||
args = append(args, *dim.Stream)
|
||||
}
|
||||
if dim.BillingType != nil {
|
||||
query += fmt.Sprintf(" AND ul.billing_type = $%d", len(args)+1)
|
||||
args = append(args, *dim.BillingType)
|
||||
}
|
||||
|
||||
query += " GROUP BY ul.user_id, u.email ORDER BY actual_cost DESC"
|
||||
if limit > 0 {
|
||||
@@ -3256,6 +3348,10 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
|
||||
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
|
||||
args = append(args, int16(*filters.BillingType))
|
||||
}
|
||||
if filters.BillingMode != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("billing_mode = $%d", len(args)+1))
|
||||
args = append(args, filters.BillingMode)
|
||||
}
|
||||
if filters.StartTime != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
|
||||
args = append(args, *filters.StartTime)
|
||||
@@ -3935,6 +4031,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
cacheReadTokens int
|
||||
cacheCreation5m int
|
||||
cacheCreation1h int
|
||||
imageOutputTokens int
|
||||
imageOutputCost float64
|
||||
inputCost float64
|
||||
outputCost float64
|
||||
cacheCreationCost float64
|
||||
@@ -3959,6 +4057,10 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
inboundEndpoint sql.NullString
|
||||
upstreamEndpoint sql.NullString
|
||||
cacheTTLOverridden bool
|
||||
channelID sql.NullInt64
|
||||
modelMappingChain sql.NullString
|
||||
billingTier sql.NullString
|
||||
billingMode sql.NullString
|
||||
createdAt time.Time
|
||||
)
|
||||
|
||||
@@ -3979,6 +4081,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
&cacheReadTokens,
|
||||
&cacheCreation5m,
|
||||
&cacheCreation1h,
|
||||
&imageOutputTokens,
|
||||
&imageOutputCost,
|
||||
&inputCost,
|
||||
&outputCost,
|
||||
&cacheCreationCost,
|
||||
@@ -4003,6 +4107,10 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
&inboundEndpoint,
|
||||
&upstreamEndpoint,
|
||||
&cacheTTLOverridden,
|
||||
&channelID,
|
||||
&modelMappingChain,
|
||||
&billingTier,
|
||||
&billingMode,
|
||||
&createdAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
@@ -4021,6 +4129,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
CacheReadTokens: cacheReadTokens,
|
||||
CacheCreation5mTokens: cacheCreation5m,
|
||||
CacheCreation1hTokens: cacheCreation1h,
|
||||
ImageOutputTokens: imageOutputTokens,
|
||||
ImageOutputCost: imageOutputCost,
|
||||
InputCost: inputCost,
|
||||
OutputCost: outputCost,
|
||||
CacheCreationCost: cacheCreationCost,
|
||||
@@ -4087,6 +4197,19 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
if upstreamModel.Valid {
|
||||
log.UpstreamModel = &upstreamModel.String
|
||||
}
|
||||
if channelID.Valid {
|
||||
value := channelID.Int64
|
||||
log.ChannelID = &value
|
||||
}
|
||||
if modelMappingChain.Valid {
|
||||
log.ModelMappingChain = &modelMappingChain.String
|
||||
}
|
||||
if billingTier.Valid {
|
||||
log.BillingTier = &billingTier.String
|
||||
}
|
||||
if billingMode.Valid {
|
||||
log.BillingMode = &billingMode.String
|
||||
}
|
||||
|
||||
return log, nil
|
||||
}
|
||||
|
||||
@@ -56,6 +56,8 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
||||
log.CacheReadTokens,
|
||||
log.CacheCreation5mTokens,
|
||||
log.CacheCreation1hTokens,
|
||||
log.ImageOutputTokens,
|
||||
log.ImageOutputCost,
|
||||
log.InputCost,
|
||||
log.OutputCost,
|
||||
log.CacheCreationCost,
|
||||
@@ -80,6 +82,10 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
||||
sqlmock.AnyArg(), // inbound_endpoint
|
||||
sqlmock.AnyArg(), // upstream_endpoint
|
||||
log.CacheTTLOverridden,
|
||||
sqlmock.AnyArg(), // channel_id
|
||||
sqlmock.AnyArg(), // model_mapping_chain
|
||||
sqlmock.AnyArg(), // billing_tier
|
||||
sqlmock.AnyArg(), // billing_mode
|
||||
createdAt,
|
||||
).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt))
|
||||
@@ -129,6 +135,8 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
|
||||
log.CacheReadTokens,
|
||||
log.CacheCreation5mTokens,
|
||||
log.CacheCreation1hTokens,
|
||||
log.ImageOutputTokens,
|
||||
log.ImageOutputCost,
|
||||
log.InputCost,
|
||||
log.OutputCost,
|
||||
log.CacheCreationCost,
|
||||
@@ -153,6 +161,10 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
log.CacheTTLOverridden,
|
||||
sqlmock.AnyArg(), // channel_id
|
||||
sqlmock.AnyArg(), // model_mapping_chain
|
||||
sqlmock.AnyArg(), // billing_tier
|
||||
sqlmock.AnyArg(), // billing_mode
|
||||
createdAt,
|
||||
).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt))
|
||||
@@ -439,6 +451,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
4, // cache_read_tokens
|
||||
5, // cache_creation_5m_tokens
|
||||
6, // cache_creation_1h_tokens
|
||||
0, // image_output_tokens
|
||||
0.0, // image_output_cost
|
||||
0.1, // input_cost
|
||||
0.2, // output_cost
|
||||
0.3, // cache_creation_cost
|
||||
@@ -463,6 +477,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
false,
|
||||
sql.NullInt64{}, // channel_id
|
||||
sql.NullString{}, // model_mapping_chain
|
||||
sql.NullString{}, // billing_tier
|
||||
sql.NullString{}, // billing_mode
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
@@ -487,6 +505,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullInt64{},
|
||||
sql.NullInt64{},
|
||||
1, 2, 3, 4, 5, 6,
|
||||
0, 0.0, // image_output_tokens, image_output_cost
|
||||
0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
|
||||
1.0,
|
||||
sql.NullFloat64{},
|
||||
@@ -506,6 +525,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
false,
|
||||
sql.NullInt64{}, // channel_id
|
||||
sql.NullString{}, // model_mapping_chain
|
||||
sql.NullString{}, // billing_tier
|
||||
sql.NullString{}, // billing_mode
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
@@ -530,6 +553,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullInt64{},
|
||||
sql.NullInt64{},
|
||||
1, 2, 3, 4, 5, 6,
|
||||
0, 0.0, // image_output_tokens, image_output_cost
|
||||
0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
|
||||
1.0,
|
||||
sql.NullFloat64{},
|
||||
@@ -549,6 +573,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
false,
|
||||
sql.NullInt64{}, // channel_id
|
||||
sql.NullString{}, // model_mapping_chain
|
||||
sql.NullString{}, // billing_tier
|
||||
sql.NullString{}, // billing_mode
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -74,6 +74,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewUserGroupRateRepository,
|
||||
NewErrorPassthroughRepository,
|
||||
NewTLSFingerprintProfileRepository,
|
||||
NewChannelRepository,
|
||||
|
||||
// Cache implementations
|
||||
NewGatewayCache,
|
||||
|
||||
@@ -87,6 +87,9 @@ func RegisterAdminRoutes(
|
||||
|
||||
// 定时测试计划
|
||||
registerScheduledTestRoutes(admin, h)
|
||||
|
||||
// 渠道管理
|
||||
registerChannelRoutes(admin, h)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -567,3 +570,15 @@ func registerTLSFingerprintProfileRoutes(admin *gin.RouterGroup, h *handler.Hand
|
||||
profiles.DELETE("/:id", h.Admin.TLSFingerprintProfile.Delete)
|
||||
}
|
||||
}
|
||||
|
||||
func registerChannelRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
channels := admin.Group("/channels")
|
||||
{
|
||||
channels.GET("", h.Admin.Channel.List)
|
||||
channels.GET("/model-pricing", h.Admin.Channel.GetModelDefaultPricing)
|
||||
channels.GET("/:id", h.Admin.Channel.GetByID)
|
||||
channels.POST("", h.Admin.Channel.Create)
|
||||
channels.PUT("/:id", h.Admin.Channel.Update)
|
||||
channels.DELETE("/:id", h.Admin.Channel.Delete)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,6 +56,7 @@ type ModelPricing struct {
|
||||
LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格
|
||||
LongContextInputMultiplier float64 // 长上下文整次会话输入倍率
|
||||
LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率
|
||||
ImageOutputPricePerToken float64 // 图片输出 token 价格 (USD)
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -94,16 +95,19 @@ type UsageTokens struct {
|
||||
CacheReadTokens int
|
||||
CacheCreation5mTokens int
|
||||
CacheCreation1hTokens int
|
||||
ImageOutputTokens int
|
||||
}
|
||||
|
||||
// CostBreakdown 费用明细
|
||||
type CostBreakdown struct {
|
||||
InputCost float64
|
||||
OutputCost float64
|
||||
ImageOutputCost float64
|
||||
CacheCreationCost float64
|
||||
CacheReadCost float64
|
||||
TotalCost float64
|
||||
ActualCost float64 // 应用倍率后的实际费用
|
||||
BillingMode string // 计费模式("token"/"per_request"/"image"),由 CalculateCostUnified 填充
|
||||
}
|
||||
|
||||
// BillingService 计费服务
|
||||
@@ -357,6 +361,7 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
|
||||
LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold,
|
||||
LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier,
|
||||
LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier,
|
||||
ImageOutputPricePerToken: litellmPricing.OutputCostPerImageToken,
|
||||
}), nil
|
||||
}
|
||||
}
|
||||
@@ -371,81 +376,252 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
|
||||
return nil, fmt.Errorf("pricing not found for model: %s", model)
|
||||
}
|
||||
|
||||
// CalculateCost 计算使用费用
|
||||
func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMultiplier float64) (*CostBreakdown, error) {
|
||||
return s.CalculateCostWithServiceTier(model, tokens, rateMultiplier, "")
|
||||
}
|
||||
|
||||
func (s *BillingService) CalculateCostWithServiceTier(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string) (*CostBreakdown, error) {
|
||||
// GetModelPricingWithChannel 获取模型定价,渠道配置的价格覆盖默认值
|
||||
// 仅覆盖渠道中非 nil 的价格字段,nil 字段使用默认定价
|
||||
func (s *BillingService) GetModelPricingWithChannel(model string, channelPricing *ChannelModelPricing) (*ModelPricing, error) {
|
||||
pricing, err := s.GetModelPricing(model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if channelPricing == nil {
|
||||
return pricing, nil
|
||||
}
|
||||
if channelPricing.InputPrice != nil {
|
||||
pricing.InputPricePerToken = *channelPricing.InputPrice
|
||||
pricing.InputPricePerTokenPriority = *channelPricing.InputPrice
|
||||
}
|
||||
if channelPricing.OutputPrice != nil {
|
||||
pricing.OutputPricePerToken = *channelPricing.OutputPrice
|
||||
pricing.OutputPricePerTokenPriority = *channelPricing.OutputPrice
|
||||
}
|
||||
if channelPricing.CacheWritePrice != nil {
|
||||
pricing.CacheCreationPricePerToken = *channelPricing.CacheWritePrice
|
||||
pricing.CacheCreation5mPrice = *channelPricing.CacheWritePrice
|
||||
pricing.CacheCreation1hPrice = *channelPricing.CacheWritePrice
|
||||
}
|
||||
if channelPricing.CacheReadPrice != nil {
|
||||
pricing.CacheReadPricePerToken = *channelPricing.CacheReadPrice
|
||||
pricing.CacheReadPricePerTokenPriority = *channelPricing.CacheReadPrice
|
||||
}
|
||||
if channelPricing.ImageOutputPrice != nil {
|
||||
pricing.ImageOutputPricePerToken = *channelPricing.ImageOutputPrice
|
||||
}
|
||||
return pricing, nil
|
||||
}
|
||||
|
||||
breakdown := &CostBreakdown{}
|
||||
inputPricePerToken := pricing.InputPricePerToken
|
||||
outputPricePerToken := pricing.OutputPricePerToken
|
||||
cacheReadPricePerToken := pricing.CacheReadPricePerToken
|
||||
// --- 统一计费入口 ---
|
||||
|
||||
// CostInput 统一计费输入
|
||||
type CostInput struct {
|
||||
Ctx context.Context
|
||||
Model string
|
||||
GroupID *int64 // 用于渠道定价查找
|
||||
Tokens UsageTokens
|
||||
RequestCount int // 按次计费时使用
|
||||
SizeTier string // 按次/图片模式的层级标签("1K","2K","4K","HD" 等)
|
||||
RateMultiplier float64
|
||||
ServiceTier string // "priority","flex","" 等
|
||||
Resolver *ModelPricingResolver // 定价解析器
|
||||
Resolved *ResolvedPricing // 可选:预解析的定价结果(避免重复 Resolve 调用)
|
||||
}
|
||||
|
||||
// CalculateCostUnified 统一计费入口,支持三种计费模式。
|
||||
// 使用 ModelPricingResolver 解析定价,然后根据 BillingMode 分发计算。
|
||||
func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown, error) {
|
||||
if input.Resolver == nil {
|
||||
// 无 Resolver,回退到旧路径
|
||||
return s.calculateCostInternal(input.Model, input.Tokens, input.RateMultiplier, input.ServiceTier, nil)
|
||||
}
|
||||
|
||||
// 优先使用预解析结果,避免重复 Resolve 调用
|
||||
resolved := input.Resolved
|
||||
if resolved == nil {
|
||||
resolved = input.Resolver.Resolve(input.Ctx, PricingInput{
|
||||
Model: input.Model,
|
||||
GroupID: input.GroupID,
|
||||
})
|
||||
}
|
||||
|
||||
if input.RateMultiplier <= 0 {
|
||||
input.RateMultiplier = 1.0
|
||||
}
|
||||
|
||||
var breakdown *CostBreakdown
|
||||
var err error
|
||||
switch resolved.Mode {
|
||||
case BillingModePerRequest, BillingModeImage:
|
||||
breakdown, err = s.calculatePerRequestCost(resolved, input)
|
||||
default: // BillingModeToken
|
||||
breakdown, err = s.calculateTokenCost(resolved, input)
|
||||
}
|
||||
if err == nil && breakdown != nil {
|
||||
breakdown.BillingMode = string(resolved.Mode)
|
||||
if breakdown.BillingMode == "" {
|
||||
breakdown.BillingMode = string(BillingModeToken)
|
||||
}
|
||||
}
|
||||
return breakdown, err
|
||||
}
|
||||
|
||||
// calculateTokenCost 按 token 区间计费
|
||||
func (s *BillingService) calculateTokenCost(resolved *ResolvedPricing, input CostInput) (*CostBreakdown, error) {
|
||||
totalContext := input.Tokens.InputTokens + input.Tokens.CacheReadTokens
|
||||
|
||||
pricing := input.Resolver.GetIntervalPricing(resolved, totalContext)
|
||||
if pricing == nil {
|
||||
return nil, fmt.Errorf("no pricing available for model: %s", input.Model)
|
||||
}
|
||||
|
||||
pricing = s.applyModelSpecificPricingPolicy(input.Model, pricing)
|
||||
|
||||
// 长上下文定价仅在无区间定价时应用(区间定价已包含上下文分层)
|
||||
applyLongCtx := len(resolved.Intervals) == 0
|
||||
|
||||
return s.computeTokenBreakdown(pricing, input.Tokens, input.RateMultiplier, input.ServiceTier, applyLongCtx), nil
|
||||
}
|
||||
|
||||
// computeTokenBreakdown 是 token 计费的核心逻辑,由 calculateTokenCost 和 calculateCostInternal 共用。
|
||||
// applyLongCtx 控制是否检查长上下文定价(区间定价已自含上下文分层,不需要额外应用)。
|
||||
func (s *BillingService) computeTokenBreakdown(
|
||||
pricing *ModelPricing, tokens UsageTokens,
|
||||
rateMultiplier float64, serviceTier string,
|
||||
applyLongCtx bool,
|
||||
) *CostBreakdown {
|
||||
if rateMultiplier <= 0 {
|
||||
rateMultiplier = 1.0
|
||||
}
|
||||
|
||||
inputPrice := pricing.InputPricePerToken
|
||||
outputPrice := pricing.OutputPricePerToken
|
||||
cacheReadPrice := pricing.CacheReadPricePerToken
|
||||
tierMultiplier := 1.0
|
||||
|
||||
if usePriorityServiceTierPricing(serviceTier, pricing) {
|
||||
if pricing.InputPricePerTokenPriority > 0 {
|
||||
inputPricePerToken = pricing.InputPricePerTokenPriority
|
||||
inputPrice = pricing.InputPricePerTokenPriority
|
||||
}
|
||||
if pricing.OutputPricePerTokenPriority > 0 {
|
||||
outputPricePerToken = pricing.OutputPricePerTokenPriority
|
||||
outputPrice = pricing.OutputPricePerTokenPriority
|
||||
}
|
||||
if pricing.CacheReadPricePerTokenPriority > 0 {
|
||||
cacheReadPricePerToken = pricing.CacheReadPricePerTokenPriority
|
||||
cacheReadPrice = pricing.CacheReadPricePerTokenPriority
|
||||
}
|
||||
} else {
|
||||
tierMultiplier = serviceTierCostMultiplier(serviceTier)
|
||||
}
|
||||
if s.shouldApplySessionLongContextPricing(tokens, pricing) {
|
||||
inputPricePerToken *= pricing.LongContextInputMultiplier
|
||||
outputPricePerToken *= pricing.LongContextOutputMultiplier
|
||||
|
||||
if applyLongCtx && s.shouldApplySessionLongContextPricing(tokens, pricing) {
|
||||
inputPrice *= pricing.LongContextInputMultiplier
|
||||
outputPrice *= pricing.LongContextOutputMultiplier
|
||||
}
|
||||
|
||||
// 计算输入token费用(使用per-token价格)
|
||||
breakdown.InputCost = float64(tokens.InputTokens) * inputPricePerToken
|
||||
bd := &CostBreakdown{}
|
||||
bd.InputCost = float64(tokens.InputTokens) * inputPrice
|
||||
|
||||
// 计算输出token费用
|
||||
breakdown.OutputCost = float64(tokens.OutputTokens) * outputPricePerToken
|
||||
// 分离图片输出 token 与文本输出 token
|
||||
textOutputTokens := tokens.OutputTokens - tokens.ImageOutputTokens
|
||||
if textOutputTokens < 0 {
|
||||
textOutputTokens = 0
|
||||
}
|
||||
bd.OutputCost = float64(textOutputTokens) * outputPrice
|
||||
|
||||
// 计算缓存费用
|
||||
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
|
||||
// 支持详细缓存分类的模型(5分钟/1小时缓存,价格为 per-token)
|
||||
if tokens.CacheCreation5mTokens == 0 && tokens.CacheCreation1hTokens == 0 && tokens.CacheCreationTokens > 0 {
|
||||
// API 未返回 ephemeral 明细,回退到全部按 5m 单价计费
|
||||
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice
|
||||
} else {
|
||||
breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice +
|
||||
float64(tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice
|
||||
// 图片输出 token 费用(独立费率)
|
||||
if tokens.ImageOutputTokens > 0 {
|
||||
imgPrice := pricing.ImageOutputPricePerToken
|
||||
if imgPrice == 0 {
|
||||
imgPrice = outputPrice // 回退到常规输出价格
|
||||
}
|
||||
} else {
|
||||
// 标准缓存创建价格(per-token)
|
||||
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
|
||||
bd.ImageOutputCost = float64(tokens.ImageOutputTokens) * imgPrice
|
||||
}
|
||||
|
||||
breakdown.CacheReadCost = float64(tokens.CacheReadTokens) * cacheReadPricePerToken
|
||||
// 缓存创建费用
|
||||
bd.CacheCreationCost = s.computeCacheCreationCost(pricing, tokens)
|
||||
|
||||
bd.CacheReadCost = float64(tokens.CacheReadTokens) * cacheReadPrice
|
||||
|
||||
if tierMultiplier != 1.0 {
|
||||
breakdown.InputCost *= tierMultiplier
|
||||
breakdown.OutputCost *= tierMultiplier
|
||||
breakdown.CacheCreationCost *= tierMultiplier
|
||||
breakdown.CacheReadCost *= tierMultiplier
|
||||
bd.InputCost *= tierMultiplier
|
||||
bd.OutputCost *= tierMultiplier
|
||||
bd.ImageOutputCost *= tierMultiplier
|
||||
bd.CacheCreationCost *= tierMultiplier
|
||||
bd.CacheReadCost *= tierMultiplier
|
||||
}
|
||||
|
||||
// 计算总费用
|
||||
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost +
|
||||
breakdown.CacheCreationCost + breakdown.CacheReadCost
|
||||
bd.TotalCost = bd.InputCost + bd.OutputCost + bd.ImageOutputCost +
|
||||
bd.CacheCreationCost + bd.CacheReadCost
|
||||
bd.ActualCost = bd.TotalCost * rateMultiplier
|
||||
|
||||
// 应用倍率计算实际费用
|
||||
if rateMultiplier <= 0 {
|
||||
rateMultiplier = 1.0
|
||||
return bd
|
||||
}
|
||||
|
||||
// computeCacheCreationCost 计算缓存创建费用(支持 5m/1h 分类或标准计费)。
|
||||
func (s *BillingService) computeCacheCreationCost(pricing *ModelPricing, tokens UsageTokens) float64 {
|
||||
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
|
||||
if tokens.CacheCreation5mTokens == 0 && tokens.CacheCreation1hTokens == 0 && tokens.CacheCreationTokens > 0 {
|
||||
// API 未返回 ephemeral 明细,回退到全部按 5m 单价计费
|
||||
return float64(tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice
|
||||
}
|
||||
return float64(tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice +
|
||||
float64(tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice
|
||||
}
|
||||
breakdown.ActualCost = breakdown.TotalCost * rateMultiplier
|
||||
return float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
|
||||
}
|
||||
|
||||
return breakdown, nil
|
||||
// calculatePerRequestCost 按次/图片计费
|
||||
func (s *BillingService) calculatePerRequestCost(resolved *ResolvedPricing, input CostInput) (*CostBreakdown, error) {
|
||||
count := input.RequestCount
|
||||
if count <= 0 {
|
||||
count = 1
|
||||
}
|
||||
|
||||
var unitPrice float64
|
||||
|
||||
if input.SizeTier != "" {
|
||||
unitPrice = input.Resolver.GetRequestTierPrice(resolved, input.SizeTier)
|
||||
}
|
||||
|
||||
if unitPrice == 0 {
|
||||
totalContext := input.Tokens.InputTokens + input.Tokens.CacheReadTokens
|
||||
unitPrice = input.Resolver.GetRequestTierPriceByContext(resolved, totalContext)
|
||||
}
|
||||
|
||||
// 回退到默认按次价格
|
||||
if unitPrice == 0 {
|
||||
unitPrice = resolved.DefaultPerRequestPrice
|
||||
}
|
||||
|
||||
totalCost := unitPrice * float64(count)
|
||||
actualCost := totalCost * input.RateMultiplier
|
||||
|
||||
return &CostBreakdown{
|
||||
TotalCost: totalCost,
|
||||
ActualCost: actualCost,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CalculateCost 计算使用费用
|
||||
func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMultiplier float64) (*CostBreakdown, error) {
|
||||
return s.calculateCostInternal(model, tokens, rateMultiplier, "", nil)
|
||||
}
|
||||
|
||||
func (s *BillingService) CalculateCostWithServiceTier(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string) (*CostBreakdown, error) {
|
||||
return s.calculateCostInternal(model, tokens, rateMultiplier, serviceTier, nil)
|
||||
}
|
||||
|
||||
func (s *BillingService) calculateCostInternal(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string, channelPricing *ChannelModelPricing) (*CostBreakdown, error) {
|
||||
var pricing *ModelPricing
|
||||
var err error
|
||||
if channelPricing != nil {
|
||||
pricing, err = s.GetModelPricingWithChannel(model, channelPricing)
|
||||
} else {
|
||||
pricing, err = s.GetModelPricing(model)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 旧路径始终检查长上下文定价(无区间定价概念)
|
||||
return s.computeTokenBreakdown(pricing, tokens, rateMultiplier, serviceTier, true), nil
|
||||
}
|
||||
|
||||
func (s *BillingService) applyModelSpecificPricingPolicy(model string, pricing *ModelPricing) *ModelPricing {
|
||||
@@ -541,6 +717,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage
|
||||
CacheReadTokens: inRangeCacheTokens,
|
||||
CacheCreation5mTokens: tokens.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: tokens.CacheCreation1hTokens,
|
||||
ImageOutputTokens: tokens.ImageOutputTokens,
|
||||
}
|
||||
inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier)
|
||||
if err != nil {
|
||||
@@ -561,6 +738,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage
|
||||
return &CostBreakdown{
|
||||
InputCost: inRangeCost.InputCost + outRangeCost.InputCost,
|
||||
OutputCost: inRangeCost.OutputCost,
|
||||
ImageOutputCost: inRangeCost.ImageOutputCost,
|
||||
CacheCreationCost: inRangeCost.CacheCreationCost,
|
||||
CacheReadCost: inRangeCost.CacheReadCost + outRangeCost.CacheReadCost,
|
||||
TotalCost: inRangeCost.TotalCost + outRangeCost.TotalCost,
|
||||
@@ -662,8 +840,9 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag
|
||||
actualCost := totalCost * rateMultiplier
|
||||
|
||||
return &CostBreakdown{
|
||||
TotalCost: totalCost,
|
||||
ActualCost: actualCost,
|
||||
TotalCost: totalCost,
|
||||
ActualCost: actualCost,
|
||||
BillingMode: string(BillingModeImage),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
277
backend/internal/service/channel.go
Normal file
277
backend/internal/service/channel.go
Normal file
@@ -0,0 +1,277 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// BillingMode 计费模式
|
||||
type BillingMode string
|
||||
|
||||
const (
|
||||
BillingModeToken BillingMode = "token" // 按 token 区间计费
|
||||
BillingModePerRequest BillingMode = "per_request" // 按次计费(支持上下文窗口分层)
|
||||
BillingModeImage BillingMode = "image" // 图片计费(当前按次,预留 token 计费)
|
||||
)
|
||||
|
||||
// IsValid 检查 BillingMode 是否为合法值
|
||||
func (m BillingMode) IsValid() bool {
|
||||
switch m {
|
||||
case BillingModeToken, BillingModePerRequest, BillingModeImage, "":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
const (
|
||||
BillingModelSourceRequested = "requested"
|
||||
BillingModelSourceUpstream = "upstream"
|
||||
BillingModelSourceChannelMapped = "channel_mapped"
|
||||
)
|
||||
|
||||
// Channel 渠道实体
|
||||
type Channel struct {
|
||||
ID int64
|
||||
Name string
|
||||
Description string
|
||||
Status string
|
||||
BillingModelSource string // "requested", "upstream", or "channel_mapped"
|
||||
RestrictModels bool // 是否限制模型(仅允许定价列表中的模型)
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
// 关联的分组 ID 列表
|
||||
GroupIDs []int64
|
||||
// 模型定价列表(每条含 Platform 字段)
|
||||
ModelPricing []ChannelModelPricing
|
||||
// 渠道级模型映射(按平台分组:platform → {src→dst})
|
||||
ModelMapping map[string]map[string]string
|
||||
}
|
||||
|
||||
// ChannelModelPricing 渠道模型定价条目
|
||||
type ChannelModelPricing struct {
|
||||
ID int64
|
||||
ChannelID int64
|
||||
Platform string // 所属平台(anthropic/openai/gemini/...)
|
||||
Models []string // 绑定的模型列表
|
||||
BillingMode BillingMode // 计费模式
|
||||
InputPrice *float64 // 每 token 输入价格(USD)— 向后兼容 flat 定价
|
||||
OutputPrice *float64 // 每 token 输出价格(USD)
|
||||
CacheWritePrice *float64 // 缓存写入价格
|
||||
CacheReadPrice *float64 // 缓存读取价格
|
||||
ImageOutputPrice *float64 // 图片输出价格(向后兼容)
|
||||
PerRequestPrice *float64 // 默认按次计费价格(USD)
|
||||
Intervals []PricingInterval // 区间定价列表
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// PricingInterval 定价区间(token 区间 / 按次分层 / 图片分辨率分层)
|
||||
type PricingInterval struct {
|
||||
ID int64
|
||||
PricingID int64
|
||||
MinTokens int // 区间下界(含)
|
||||
MaxTokens *int // 区间上界(不含),nil = 无上限
|
||||
TierLabel string // 层级标签(按次/图片模式:1K, 2K, 4K, HD 等)
|
||||
InputPrice *float64 // token 模式:每 token 输入价
|
||||
OutputPrice *float64 // token 模式:每 token 输出价
|
||||
CacheWritePrice *float64 // token 模式:缓存写入价
|
||||
CacheReadPrice *float64 // token 模式:缓存读取价
|
||||
PerRequestPrice *float64 // 按次/图片模式:每次请求价格
|
||||
SortOrder int
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// IsActive 判断渠道是否启用
|
||||
func (c *Channel) IsActive() bool {
|
||||
return c.Status == StatusActive
|
||||
}
|
||||
|
||||
// GetModelPricing 根据模型名查找渠道定价,未找到返回 nil。
|
||||
// 精确匹配,大小写不敏感。返回值拷贝,不污染缓存。
|
||||
func (c *Channel) GetModelPricing(model string) *ChannelModelPricing {
|
||||
modelLower := strings.ToLower(model)
|
||||
|
||||
for i := range c.ModelPricing {
|
||||
for _, m := range c.ModelPricing[i].Models {
|
||||
if strings.ToLower(m) == modelLower {
|
||||
cp := c.ModelPricing[i].Clone()
|
||||
return &cp
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindMatchingInterval 在区间列表中查找匹配 totalTokens 的区间。
|
||||
// 区间为左开右闭 (min, max]:min 不含,max 包含。
|
||||
// 第一个区间 min=0 时,0 token 不匹配任何区间(回退到默认价格)。
|
||||
func FindMatchingInterval(intervals []PricingInterval, totalTokens int) *PricingInterval {
|
||||
for i := range intervals {
|
||||
iv := &intervals[i]
|
||||
if totalTokens > iv.MinTokens && (iv.MaxTokens == nil || totalTokens <= *iv.MaxTokens) {
|
||||
return iv
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetIntervalForContext 根据总 context token 数查找匹配的区间。
|
||||
func (p *ChannelModelPricing) GetIntervalForContext(totalTokens int) *PricingInterval {
|
||||
return FindMatchingInterval(p.Intervals, totalTokens)
|
||||
}
|
||||
|
||||
// GetTierByLabel 根据标签查找层级(用于 per_request / image 模式)
|
||||
func (p *ChannelModelPricing) GetTierByLabel(label string) *PricingInterval {
|
||||
labelLower := strings.ToLower(label)
|
||||
for i := range p.Intervals {
|
||||
if strings.ToLower(p.Intervals[i].TierLabel) == labelLower {
|
||||
return &p.Intervals[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clone 返回 ChannelModelPricing 的拷贝(切片独立,指针字段共享,调用方只读安全)
|
||||
func (p ChannelModelPricing) Clone() ChannelModelPricing {
|
||||
cp := p
|
||||
if p.Models != nil {
|
||||
cp.Models = make([]string, len(p.Models))
|
||||
copy(cp.Models, p.Models)
|
||||
}
|
||||
if p.Intervals != nil {
|
||||
cp.Intervals = make([]PricingInterval, len(p.Intervals))
|
||||
copy(cp.Intervals, p.Intervals)
|
||||
}
|
||||
return cp
|
||||
}
|
||||
|
||||
// Clone 返回 Channel 的深拷贝
|
||||
func (c *Channel) Clone() *Channel {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
cp := *c
|
||||
if c.GroupIDs != nil {
|
||||
cp.GroupIDs = make([]int64, len(c.GroupIDs))
|
||||
copy(cp.GroupIDs, c.GroupIDs)
|
||||
}
|
||||
if c.ModelPricing != nil {
|
||||
cp.ModelPricing = make([]ChannelModelPricing, len(c.ModelPricing))
|
||||
for i := range c.ModelPricing {
|
||||
cp.ModelPricing[i] = c.ModelPricing[i].Clone()
|
||||
}
|
||||
}
|
||||
if c.ModelMapping != nil {
|
||||
cp.ModelMapping = make(map[string]map[string]string, len(c.ModelMapping))
|
||||
for platform, mapping := range c.ModelMapping {
|
||||
inner := make(map[string]string, len(mapping))
|
||||
for k, v := range mapping {
|
||||
inner[k] = v
|
||||
}
|
||||
cp.ModelMapping[platform] = inner
|
||||
}
|
||||
}
|
||||
return &cp
|
||||
}
|
||||
|
||||
// ValidateIntervals 校验区间列表的合法性。
|
||||
// 规则:MinTokens >= 0;MaxTokens 若非 nil 则 > 0 且 > MinTokens;
|
||||
// 所有价格字段 >= 0;区间按 MinTokens 排序后无重叠((min, max] 语义);
|
||||
// 无界区间(MaxTokens=nil)必须是最后一个。间隙允许(回退默认价格)。
|
||||
func ValidateIntervals(intervals []PricingInterval) error {
|
||||
if len(intervals) == 0 {
|
||||
return nil
|
||||
}
|
||||
sorted := make([]PricingInterval, len(intervals))
|
||||
copy(sorted, intervals)
|
||||
sort.Slice(sorted, func(i, j int) bool {
|
||||
return sorted[i].MinTokens < sorted[j].MinTokens
|
||||
})
|
||||
|
||||
for i := range sorted {
|
||||
if err := validateSingleInterval(&sorted[i], i); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return validateIntervalOverlap(sorted)
|
||||
}
|
||||
|
||||
// validateSingleInterval 校验单个区间的字段合法性
|
||||
func validateSingleInterval(iv *PricingInterval, idx int) error {
|
||||
if iv.MinTokens < 0 {
|
||||
return fmt.Errorf("interval #%d: min_tokens (%d) must be >= 0", idx+1, iv.MinTokens)
|
||||
}
|
||||
if iv.MaxTokens != nil {
|
||||
if *iv.MaxTokens <= 0 {
|
||||
return fmt.Errorf("interval #%d: max_tokens (%d) must be > 0", idx+1, *iv.MaxTokens)
|
||||
}
|
||||
if *iv.MaxTokens <= iv.MinTokens {
|
||||
return fmt.Errorf("interval #%d: max_tokens (%d) must be > min_tokens (%d)",
|
||||
idx+1, *iv.MaxTokens, iv.MinTokens)
|
||||
}
|
||||
}
|
||||
return validateIntervalPrices(iv, idx)
|
||||
}
|
||||
|
||||
// validateIntervalPrices 校验区间内所有价格字段 >= 0
|
||||
func validateIntervalPrices(iv *PricingInterval, idx int) error {
|
||||
prices := []struct {
|
||||
name string
|
||||
val *float64
|
||||
}{
|
||||
{"input_price", iv.InputPrice},
|
||||
{"output_price", iv.OutputPrice},
|
||||
{"cache_write_price", iv.CacheWritePrice},
|
||||
{"cache_read_price", iv.CacheReadPrice},
|
||||
{"per_request_price", iv.PerRequestPrice},
|
||||
}
|
||||
for _, p := range prices {
|
||||
if p.val != nil && *p.val < 0 {
|
||||
return fmt.Errorf("interval #%d: %s must be >= 0", idx+1, p.name)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateIntervalOverlap 校验排序后的区间列表无重叠,且无界区间在最后
|
||||
func validateIntervalOverlap(sorted []PricingInterval) error {
|
||||
for i, iv := range sorted {
|
||||
// 无界区间必须是最后一个
|
||||
if iv.MaxTokens == nil && i < len(sorted)-1 {
|
||||
return fmt.Errorf("interval #%d: unbounded interval (max_tokens=null) must be the last one",
|
||||
i+1)
|
||||
}
|
||||
if i == 0 {
|
||||
continue
|
||||
}
|
||||
prev := sorted[i-1]
|
||||
// 检查重叠:前一个区间的上界 > 当前区间的下界则重叠
|
||||
// (min, max] 语义:prev 覆盖 (prev.Min, prev.Max],cur 覆盖 (cur.Min, cur.Max]
|
||||
if prev.MaxTokens == nil || *prev.MaxTokens > iv.MinTokens {
|
||||
return fmt.Errorf("interval #%d and #%d overlap: prev max=%s > cur min=%d",
|
||||
i, i+1, formatMaxTokensLabel(prev.MaxTokens), iv.MinTokens)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func formatMaxTokensLabel(max *int) string {
|
||||
if max == nil {
|
||||
return "∞"
|
||||
}
|
||||
return fmt.Sprintf("%d", *max)
|
||||
}
|
||||
|
||||
// ChannelUsageFields 渠道相关的使用记录字段(嵌入到各平台的 RecordUsageInput 中)
|
||||
type ChannelUsageFields struct {
|
||||
ChannelID int64 // 渠道 ID(0 = 无渠道)
|
||||
OriginalModel string // 用户原始请求模型(渠道映射前)
|
||||
ChannelMappedModel string // 渠道映射后的模型名(无映射时等于 OriginalModel)
|
||||
BillingModelSource string // 计费模型来源:"requested" / "upstream" / "channel_mapped"
|
||||
ModelMappingChain string // 映射链描述,如 "a→b→c"
|
||||
}
|
||||
857
backend/internal/service/channel_service.go
Normal file
857
backend/internal/service/channel_service.go
Normal file
@@ -0,0 +1,857 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrChannelNotFound = infraerrors.NotFound("CHANNEL_NOT_FOUND", "channel not found")
|
||||
ErrChannelExists = infraerrors.Conflict("CHANNEL_EXISTS", "channel name already exists")
|
||||
ErrGroupAlreadyInChannel = infraerrors.Conflict(
|
||||
"GROUP_ALREADY_IN_CHANNEL",
|
||||
"one or more groups already belong to another channel",
|
||||
)
|
||||
)
|
||||
|
||||
// ChannelRepository 渠道数据访问接口
|
||||
type ChannelRepository interface {
|
||||
Create(ctx context.Context, channel *Channel) error
|
||||
GetByID(ctx context.Context, id int64) (*Channel, error)
|
||||
Update(ctx context.Context, channel *Channel) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error)
|
||||
ListAll(ctx context.Context) ([]Channel, error)
|
||||
ExistsByName(ctx context.Context, name string) (bool, error)
|
||||
ExistsByNameExcluding(ctx context.Context, name string, excludeID int64) (bool, error)
|
||||
|
||||
// 分组关联
|
||||
GetGroupIDs(ctx context.Context, channelID int64) ([]int64, error)
|
||||
SetGroupIDs(ctx context.Context, channelID int64, groupIDs []int64) error
|
||||
GetChannelIDByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
GetGroupsInOtherChannels(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error)
|
||||
|
||||
// 分组平台查询
|
||||
GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error)
|
||||
|
||||
// 模型定价
|
||||
ListModelPricing(ctx context.Context, channelID int64) ([]ChannelModelPricing, error)
|
||||
CreateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error
|
||||
UpdateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error
|
||||
DeleteModelPricing(ctx context.Context, id int64) error
|
||||
ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error
|
||||
}
|
||||
|
||||
// channelModelKey 渠道缓存复合键(显式包含 platform 防止跨平台同名模型冲突)
|
||||
type channelModelKey struct {
|
||||
groupID int64
|
||||
platform string // 平台标识
|
||||
model string // lowercase
|
||||
}
|
||||
|
||||
// channelGroupPlatformKey 通配符定价缓存键
|
||||
type channelGroupPlatformKey struct {
|
||||
groupID int64
|
||||
platform string
|
||||
}
|
||||
|
||||
// wildcardPricingEntry 通配符定价条目
|
||||
type wildcardPricingEntry struct {
|
||||
prefix string
|
||||
pricing *ChannelModelPricing
|
||||
}
|
||||
|
||||
// wildcardMappingEntry 通配符映射条目
|
||||
type wildcardMappingEntry struct {
|
||||
prefix string
|
||||
target string
|
||||
}
|
||||
|
||||
// channelCache 渠道缓存快照(扁平化哈希结构,热路径 O(1) 查找)
|
||||
type channelCache struct {
|
||||
// 热路径查找
|
||||
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价
|
||||
wildcardByGroupPlatform map[channelGroupPlatformKey][]*wildcardPricingEntry // (groupID, platform) → 通配符定价(前缀长度降序)
|
||||
mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标
|
||||
wildcardMappingByGP map[channelGroupPlatformKey][]*wildcardMappingEntry // (groupID, platform) → 通配符映射(前缀长度降序)
|
||||
channelByGroupID map[int64]*Channel // groupID → 渠道
|
||||
groupPlatform map[int64]string // groupID → platform
|
||||
|
||||
// 冷路径(CRUD 操作)
|
||||
byID map[int64]*Channel
|
||||
loadedAt time.Time
|
||||
}
|
||||
|
||||
// ChannelMappingResult 渠道映射查找结果
|
||||
type ChannelMappingResult struct {
|
||||
MappedModel string // 映射后的模型名(无映射时等于原始模型名)
|
||||
ChannelID int64 // 渠道 ID(0 = 无渠道关联)
|
||||
Mapped bool // 是否发生了映射
|
||||
BillingModelSource string // 计费模型来源("requested" / "upstream" / "channel_mapped")
|
||||
}
|
||||
|
||||
// BuildModelMappingChain 根据映射结果和上游实际模型构建映射链描述。
|
||||
// reqModel: 客户端请求的原始模型名。
|
||||
// upstreamModel: 上游实际使用的模型名(ForwardResult.UpstreamModel)。
|
||||
// 返回空字符串表示无映射。
|
||||
func (r ChannelMappingResult) BuildModelMappingChain(reqModel, upstreamModel string) string {
|
||||
if !r.Mapped {
|
||||
if upstreamModel != "" && upstreamModel != reqModel {
|
||||
return reqModel + "→" + upstreamModel
|
||||
}
|
||||
return ""
|
||||
}
|
||||
if upstreamModel != "" && upstreamModel != r.MappedModel {
|
||||
return reqModel + "→" + r.MappedModel + "→" + upstreamModel
|
||||
}
|
||||
return reqModel + "→" + r.MappedModel
|
||||
}
|
||||
|
||||
// ToUsageFields 将渠道映射结果转为使用记录字段
|
||||
func (r ChannelMappingResult) ToUsageFields(reqModel, upstreamModel string) ChannelUsageFields {
|
||||
channelMappedModel := reqModel
|
||||
if r.Mapped {
|
||||
channelMappedModel = r.MappedModel
|
||||
}
|
||||
return ChannelUsageFields{
|
||||
ChannelID: r.ChannelID,
|
||||
OriginalModel: reqModel,
|
||||
ChannelMappedModel: channelMappedModel,
|
||||
BillingModelSource: r.BillingModelSource,
|
||||
ModelMappingChain: r.BuildModelMappingChain(reqModel, upstreamModel),
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
channelCacheTTL = 10 * time.Minute
|
||||
channelErrorTTL = 5 * time.Second // DB 错误时的短缓存
|
||||
channelCacheDBTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
// ChannelService 渠道管理服务
|
||||
type ChannelService struct {
|
||||
repo ChannelRepository
|
||||
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||
|
||||
cache atomic.Value // *channelCache
|
||||
cacheSF singleflight.Group
|
||||
}
|
||||
|
||||
// NewChannelService 创建渠道服务实例
|
||||
func NewChannelService(repo ChannelRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *ChannelService {
|
||||
s := &ChannelService{
|
||||
repo: repo,
|
||||
authCacheInvalidator: authCacheInvalidator,
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// loadCache 加载或返回缓存的渠道数据
|
||||
func (s *ChannelService) loadCache(ctx context.Context) (*channelCache, error) {
|
||||
if cached, ok := s.cache.Load().(*channelCache); ok && cached != nil {
|
||||
if time.Since(cached.loadedAt) < channelCacheTTL {
|
||||
return cached, nil
|
||||
}
|
||||
}
|
||||
|
||||
result, err, _ := s.cacheSF.Do("channel_cache", func() (any, error) {
|
||||
// 双重检查
|
||||
if cached, ok := s.cache.Load().(*channelCache); ok && cached != nil {
|
||||
if time.Since(cached.loadedAt) < channelCacheTTL {
|
||||
return cached, nil
|
||||
}
|
||||
}
|
||||
return s.buildCache(ctx)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cache, ok := result.(*channelCache)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected cache type")
|
||||
}
|
||||
return cache, nil
|
||||
}
|
||||
|
||||
// newEmptyChannelCache 创建空的渠道缓存(所有 map 已初始化)
|
||||
func newEmptyChannelCache() *channelCache {
|
||||
return &channelCache{
|
||||
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
||||
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
|
||||
mappingByGroupModel: make(map[channelModelKey]string),
|
||||
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
|
||||
channelByGroupID: make(map[int64]*Channel),
|
||||
groupPlatform: make(map[int64]string),
|
||||
byID: make(map[int64]*Channel),
|
||||
}
|
||||
}
|
||||
|
||||
// expandPricingToCache 将渠道的模型定价展开到缓存(按分组+平台维度)。
|
||||
// antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目。
|
||||
// 缓存 key 使用定价条目的原始平台(pricing.Platform),而非分组平台,
|
||||
// 避免跨平台同名模型(如 anthropic 和 gemini 都有 "model-x")互相覆盖。
|
||||
// 查找时通过 lookupPricingAcrossPlatforms() 依次尝试所有匹配平台。
|
||||
func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
|
||||
for j := range ch.ModelPricing {
|
||||
pricing := &ch.ModelPricing[j]
|
||||
if !isPlatformPricingMatch(platform, pricing.Platform) {
|
||||
continue // 跳过非本平台的定价
|
||||
}
|
||||
// 使用定价条目的原始平台作为缓存 key,防止跨平台同名模型冲突
|
||||
pricingPlatform := pricing.Platform
|
||||
gpKey := channelGroupPlatformKey{groupID: gid, platform: pricingPlatform}
|
||||
for _, model := range pricing.Models {
|
||||
if strings.HasSuffix(model, "*") {
|
||||
prefix := strings.ToLower(strings.TrimSuffix(model, "*"))
|
||||
cache.wildcardByGroupPlatform[gpKey] = append(cache.wildcardByGroupPlatform[gpKey], &wildcardPricingEntry{
|
||||
prefix: prefix,
|
||||
pricing: pricing,
|
||||
})
|
||||
} else {
|
||||
key := channelModelKey{groupID: gid, platform: pricingPlatform, model: strings.ToLower(model)}
|
||||
cache.pricingByGroupModel[key] = pricing
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// expandMappingToCache 将渠道的模型映射展开到缓存(按分组+平台维度)。
|
||||
// antigravity 平台同时服务 Claude 和 Gemini 模型。
|
||||
// 缓存 key 使用映射条目的原始平台(mappingPlatform),避免跨平台同名映射覆盖。
|
||||
func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
|
||||
for _, mappingPlatform := range matchingPlatforms(platform) {
|
||||
platformMapping, ok := ch.ModelMapping[mappingPlatform]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
// 使用映射条目的原始平台作为缓存 key,防止跨平台同名映射冲突
|
||||
gpKey := channelGroupPlatformKey{groupID: gid, platform: mappingPlatform}
|
||||
for src, dst := range platformMapping {
|
||||
if strings.HasSuffix(src, "*") {
|
||||
prefix := strings.ToLower(strings.TrimSuffix(src, "*"))
|
||||
cache.wildcardMappingByGP[gpKey] = append(cache.wildcardMappingByGP[gpKey], &wildcardMappingEntry{
|
||||
prefix: prefix,
|
||||
target: dst,
|
||||
})
|
||||
} else {
|
||||
key := channelModelKey{groupID: gid, platform: mappingPlatform, model: strings.ToLower(src)}
|
||||
cache.mappingByGroupModel[key] = dst
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// buildCache 从数据库构建渠道缓存。
|
||||
// 使用独立 context 避免请求取消导致空值被长期缓存。
|
||||
func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) {
|
||||
// 断开请求取消链,避免客户端断连导致空值被长期缓存
|
||||
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), channelCacheDBTimeout)
|
||||
defer cancel()
|
||||
|
||||
channels, err := s.repo.ListAll(dbCtx)
|
||||
if err != nil {
|
||||
// error-TTL:失败时存入短 TTL 空缓存,防止紧密重试
|
||||
slog.Warn("failed to build channel cache", "error", err)
|
||||
errorCache := newEmptyChannelCache()
|
||||
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) // 使剩余 TTL = errorTTL
|
||||
s.cache.Store(errorCache)
|
||||
return nil, fmt.Errorf("list all channels: %w", err)
|
||||
}
|
||||
|
||||
// 收集所有 groupID,批量查询 platform
|
||||
var allGroupIDs []int64
|
||||
for i := range channels {
|
||||
allGroupIDs = append(allGroupIDs, channels[i].GroupIDs...)
|
||||
}
|
||||
groupPlatforms := make(map[int64]string)
|
||||
if len(allGroupIDs) > 0 {
|
||||
groupPlatforms, err = s.repo.GetGroupPlatforms(dbCtx, allGroupIDs)
|
||||
if err != nil {
|
||||
slog.Warn("failed to load group platforms for channel cache", "error", err)
|
||||
errorCache := newEmptyChannelCache()
|
||||
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL))
|
||||
s.cache.Store(errorCache)
|
||||
return nil, fmt.Errorf("get group platforms: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
cache := newEmptyChannelCache()
|
||||
cache.groupPlatform = groupPlatforms
|
||||
cache.byID = make(map[int64]*Channel, len(channels))
|
||||
cache.loadedAt = time.Now()
|
||||
|
||||
for i := range channels {
|
||||
ch := &channels[i]
|
||||
cache.byID[ch.ID] = ch
|
||||
|
||||
for _, gid := range ch.GroupIDs {
|
||||
cache.channelByGroupID[gid] = ch
|
||||
platform := groupPlatforms[gid]
|
||||
expandPricingToCache(cache, ch, gid, platform)
|
||||
expandMappingToCache(cache, ch, gid, platform)
|
||||
}
|
||||
}
|
||||
|
||||
// 通配符条目保持配置顺序(最先匹配到优先)
|
||||
|
||||
s.cache.Store(cache)
|
||||
return cache, nil
|
||||
}
|
||||
|
||||
// invalidateCache 使缓存失效,让下次读取时自然重建
|
||||
|
||||
// isPlatformPricingMatch 判断定价条目的平台是否匹配分组平台。
|
||||
// antigravity 平台同时服务 Claude(anthropic)和 Gemini(gemini)模型,
|
||||
// 因此 antigravity 分组应匹配 anthropic 和 gemini 的定价条目。
|
||||
func isPlatformPricingMatch(groupPlatform, pricingPlatform string) bool {
|
||||
if groupPlatform == pricingPlatform {
|
||||
return true
|
||||
}
|
||||
if groupPlatform == PlatformAntigravity {
|
||||
return pricingPlatform == PlatformAnthropic || pricingPlatform == PlatformGemini
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// matchingPlatforms 返回分组平台对应的所有可匹配平台列表。
|
||||
func matchingPlatforms(groupPlatform string) []string {
|
||||
if groupPlatform == PlatformAntigravity {
|
||||
return []string{PlatformAntigravity, PlatformAnthropic, PlatformGemini}
|
||||
}
|
||||
return []string{groupPlatform}
|
||||
}
|
||||
func (s *ChannelService) invalidateCache() {
|
||||
s.cache.Store((*channelCache)(nil))
|
||||
s.cacheSF.Forget("channel_cache")
|
||||
|
||||
// 主动重建缓存,确保 CRUD 后立即生效
|
||||
if _, err := s.buildCache(context.Background()); err != nil {
|
||||
slog.Warn("failed to rebuild channel cache after invalidation", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// matchWildcard 在通配符定价中查找匹配项(最先匹配到优先)
|
||||
func (c *channelCache) matchWildcard(groupID int64, platform, modelLower string) *ChannelModelPricing {
|
||||
gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform}
|
||||
wildcards := c.wildcardByGroupPlatform[gpKey]
|
||||
for _, wc := range wildcards {
|
||||
if strings.HasPrefix(modelLower, wc.prefix) {
|
||||
return wc.pricing
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// matchWildcardMapping 在通配符映射中查找匹配项(最先匹配到优先)
|
||||
func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower string) string {
|
||||
gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform}
|
||||
wildcards := c.wildcardMappingByGP[gpKey]
|
||||
for _, wc := range wildcards {
|
||||
if strings.HasPrefix(modelLower, wc.prefix) {
|
||||
return wc.target
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// lookupPricingAcrossPlatforms 在所有匹配平台中查找模型定价。
|
||||
// antigravity 分组的缓存 key 使用定价条目的原始平台,因此查找时需依次尝试
|
||||
// matchingPlatforms() 返回的所有平台(antigravity → anthropic → gemini),
|
||||
// 返回第一个命中的结果。非 antigravity 平台只尝试自身。
|
||||
func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) *ChannelModelPricing {
|
||||
for _, p := range matchingPlatforms(groupPlatform) {
|
||||
key := channelModelKey{groupID: groupID, platform: p, model: modelLower}
|
||||
if pricing, ok := cache.pricingByGroupModel[key]; ok {
|
||||
return pricing
|
||||
}
|
||||
}
|
||||
// 精确查找全部失败,依次尝试通配符匹配
|
||||
for _, p := range matchingPlatforms(groupPlatform) {
|
||||
if pricing := cache.matchWildcard(groupID, p, modelLower); pricing != nil {
|
||||
return pricing
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// lookupMappingAcrossPlatforms 在所有匹配平台中查找模型映射。
|
||||
// 逻辑与 lookupPricingAcrossPlatforms 相同:先精确查找,再通配符。
|
||||
func lookupMappingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) string {
|
||||
for _, p := range matchingPlatforms(groupPlatform) {
|
||||
key := channelModelKey{groupID: groupID, platform: p, model: modelLower}
|
||||
if mapped, ok := cache.mappingByGroupModel[key]; ok {
|
||||
return mapped
|
||||
}
|
||||
}
|
||||
for _, p := range matchingPlatforms(groupPlatform) {
|
||||
if mapped := cache.matchWildcardMapping(groupID, p, modelLower); mapped != "" {
|
||||
return mapped
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetChannelForGroup 获取分组关联的渠道(热路径 O(1))
|
||||
func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) (*Channel, error) {
|
||||
cache, err := s.loadCache(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ch, ok := cache.channelByGroupID[groupID]
|
||||
if !ok || !ch.IsActive() {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return ch.Clone(), nil
|
||||
}
|
||||
|
||||
// channelLookup 热路径公共查找结果
|
||||
type channelLookup struct {
|
||||
cache *channelCache
|
||||
channel *Channel
|
||||
platform string
|
||||
}
|
||||
|
||||
// lookupGroupChannel 加载缓存并查找分组对应的渠道信息(公共热路径前置逻辑)。
|
||||
// 返回 nil 且 err==nil 表示分组无活跃渠道;err!=nil 表示缓存加载失败。
|
||||
func (s *ChannelService) lookupGroupChannel(ctx context.Context, groupID int64) (*channelLookup, error) {
|
||||
cache, err := s.loadCache(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ch, ok := cache.channelByGroupID[groupID]
|
||||
if !ok || !ch.IsActive() {
|
||||
return nil, nil
|
||||
}
|
||||
return &channelLookup{
|
||||
cache: cache,
|
||||
channel: ch,
|
||||
platform: cache.groupPlatform[groupID],
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1))。
|
||||
// antigravity 分组依次尝试所有匹配平台(antigravity → anthropic → gemini),
|
||||
// 确保跨平台同名模型各自独立匹配。
|
||||
func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing {
|
||||
lk, err := s.lookupGroupChannel(ctx, groupID)
|
||||
if err != nil {
|
||||
slog.Warn("failed to load channel cache", "group_id", groupID, "error", err)
|
||||
return nil
|
||||
}
|
||||
if lk == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
modelLower := strings.ToLower(model)
|
||||
pricing := lookupPricingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower)
|
||||
if pricing == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cp := pricing.Clone()
|
||||
return &cp
|
||||
}
|
||||
|
||||
// ResolveChannelMapping 解析渠道级模型映射(热路径 O(1))
|
||||
// 返回映射结果,包含映射后的模型名、渠道 ID、计费模型来源。
|
||||
func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult {
|
||||
lk, err := s.lookupGroupChannel(ctx, groupID)
|
||||
if err != nil {
|
||||
slog.Warn("failed to load channel cache for mapping", "group_id", groupID, "error", err)
|
||||
}
|
||||
if lk == nil {
|
||||
return ChannelMappingResult{MappedModel: model}
|
||||
}
|
||||
return resolveMapping(lk, groupID, model)
|
||||
}
|
||||
|
||||
// IsModelRestricted 检查模型是否被渠道限制。
|
||||
// 返回 true 表示模型被限制(不在允许列表中)。
|
||||
// 如果渠道未启用模型限制或分组无渠道关联,返回 false。
|
||||
func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
|
||||
lk, _ := s.lookupGroupChannel(ctx, groupID)
|
||||
if lk == nil {
|
||||
return false
|
||||
}
|
||||
return checkRestricted(lk, groupID, model)
|
||||
}
|
||||
|
||||
// ResolveChannelMappingAndRestrict 解析渠道映射。
|
||||
// 返回映射结果。模型限制检查已移至调度阶段(GatewayService.checkChannelPricingRestriction),
|
||||
// restricted 始终返回 false,保留签名兼容性。
|
||||
func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
|
||||
if groupID == nil {
|
||||
return ChannelMappingResult{MappedModel: model}, false
|
||||
}
|
||||
lk, _ := s.lookupGroupChannel(ctx, *groupID)
|
||||
if lk == nil {
|
||||
return ChannelMappingResult{MappedModel: model}, false
|
||||
}
|
||||
return resolveMapping(lk, *groupID, model), false
|
||||
}
|
||||
|
||||
// resolveMapping 基于已查找的渠道信息解析模型映射。
|
||||
// antigravity 分组依次尝试所有匹配平台,确保跨平台同名映射各自独立。
|
||||
func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappingResult {
|
||||
result := ChannelMappingResult{
|
||||
MappedModel: model,
|
||||
ChannelID: lk.channel.ID,
|
||||
BillingModelSource: lk.channel.BillingModelSource,
|
||||
}
|
||||
if result.BillingModelSource == "" {
|
||||
result.BillingModelSource = BillingModelSourceChannelMapped
|
||||
}
|
||||
|
||||
modelLower := strings.ToLower(model)
|
||||
if mapped := lookupMappingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower); mapped != "" {
|
||||
result.MappedModel = mapped
|
||||
result.Mapped = true
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// checkRestricted 基于已查找的渠道信息检查模型是否被限制。
|
||||
// antigravity 分组依次尝试所有匹配平台的定价列表。
|
||||
func checkRestricted(lk *channelLookup, groupID int64, model string) bool {
|
||||
if !lk.channel.RestrictModels {
|
||||
return false
|
||||
}
|
||||
modelLower := strings.ToLower(model)
|
||||
// 使用与查找定价相同的跨平台逻辑
|
||||
if lookupPricingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower) != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// ReplaceModelInBody 替换请求体 JSON 中的 model 字段。
|
||||
func ReplaceModelInBody(body []byte, newModel string) []byte {
|
||||
if len(body) == 0 {
|
||||
return body
|
||||
}
|
||||
if current := gjson.GetBytes(body, "model"); current.Exists() && current.String() == newModel {
|
||||
return body
|
||||
}
|
||||
newBody, err := sjson.SetBytes(body, "model", newModel)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return newBody
|
||||
}
|
||||
|
||||
// --- CRUD ---
|
||||
|
||||
// Create 创建渠道
|
||||
func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) (*Channel, error) {
|
||||
exists, err := s.repo.ExistsByName(ctx, input.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check channel exists: %w", err)
|
||||
}
|
||||
if exists {
|
||||
return nil, ErrChannelExists
|
||||
}
|
||||
|
||||
// 检查分组冲突
|
||||
if len(input.GroupIDs) > 0 {
|
||||
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, 0, input.GroupIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check group conflicts: %w", err)
|
||||
}
|
||||
if len(conflicting) > 0 {
|
||||
return nil, ErrGroupAlreadyInChannel
|
||||
}
|
||||
}
|
||||
|
||||
channel := &Channel{
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
Status: StatusActive,
|
||||
BillingModelSource: input.BillingModelSource,
|
||||
RestrictModels: input.RestrictModels,
|
||||
GroupIDs: input.GroupIDs,
|
||||
ModelPricing: input.ModelPricing,
|
||||
ModelMapping: input.ModelMapping,
|
||||
}
|
||||
if channel.BillingModelSource == "" {
|
||||
channel.BillingModelSource = BillingModelSourceChannelMapped
|
||||
}
|
||||
|
||||
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validatePricingIntervals(channel.ModelPricing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.repo.Create(ctx, channel); err != nil {
|
||||
return nil, fmt.Errorf("create channel: %w", err)
|
||||
}
|
||||
|
||||
s.invalidateCache()
|
||||
return s.repo.GetByID(ctx, channel.ID)
|
||||
}
|
||||
|
||||
// GetByID 获取渠道详情
|
||||
func (s *ChannelService) GetByID(ctx context.Context, id int64) (*Channel, error) {
|
||||
return s.repo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
// Update 更新渠道
|
||||
func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChannelInput) (*Channel, error) {
|
||||
channel, err := s.repo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get channel: %w", err)
|
||||
}
|
||||
|
||||
if input.Name != "" && input.Name != channel.Name {
|
||||
exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check channel exists: %w", err)
|
||||
}
|
||||
if exists {
|
||||
return nil, ErrChannelExists
|
||||
}
|
||||
channel.Name = input.Name
|
||||
}
|
||||
|
||||
if input.Description != nil {
|
||||
channel.Description = *input.Description
|
||||
}
|
||||
|
||||
if input.Status != "" {
|
||||
channel.Status = input.Status
|
||||
}
|
||||
|
||||
if input.RestrictModels != nil {
|
||||
channel.RestrictModels = *input.RestrictModels
|
||||
}
|
||||
|
||||
// 检查分组冲突
|
||||
if input.GroupIDs != nil {
|
||||
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, id, *input.GroupIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check group conflicts: %w", err)
|
||||
}
|
||||
if len(conflicting) > 0 {
|
||||
return nil, ErrGroupAlreadyInChannel
|
||||
}
|
||||
channel.GroupIDs = *input.GroupIDs
|
||||
}
|
||||
|
||||
if input.ModelPricing != nil {
|
||||
channel.ModelPricing = *input.ModelPricing
|
||||
}
|
||||
|
||||
if input.ModelMapping != nil {
|
||||
channel.ModelMapping = input.ModelMapping
|
||||
}
|
||||
|
||||
if input.BillingModelSource != "" {
|
||||
channel.BillingModelSource = input.BillingModelSource
|
||||
}
|
||||
|
||||
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validatePricingIntervals(channel.ModelPricing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 先获取旧分组,Update 后旧分组关联已删除,无法再查到
|
||||
var oldGroupIDs []int64
|
||||
if s.authCacheInvalidator != nil {
|
||||
var err2 error
|
||||
oldGroupIDs, err2 = s.repo.GetGroupIDs(ctx, id)
|
||||
if err2 != nil {
|
||||
slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", id, "error", err2)
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.repo.Update(ctx, channel); err != nil {
|
||||
return nil, fmt.Errorf("update channel: %w", err)
|
||||
}
|
||||
|
||||
s.invalidateCache()
|
||||
|
||||
// 失效新旧分组的 auth 缓存
|
||||
if s.authCacheInvalidator != nil {
|
||||
seen := make(map[int64]struct{}, len(oldGroupIDs)+len(channel.GroupIDs))
|
||||
for _, gid := range oldGroupIDs {
|
||||
if _, ok := seen[gid]; !ok {
|
||||
seen[gid] = struct{}{}
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
||||
}
|
||||
}
|
||||
for _, gid := range channel.GroupIDs {
|
||||
if _, ok := seen[gid]; !ok {
|
||||
seen[gid] = struct{}{}
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return s.repo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
// Delete 删除渠道
|
||||
func (s *ChannelService) Delete(ctx context.Context, id int64) error {
|
||||
// 先获取关联分组用于失效缓存
|
||||
groupIDs, err := s.repo.GetGroupIDs(ctx, id)
|
||||
if err != nil {
|
||||
slog.Warn("failed to get group IDs before delete", "channel_id", id, "error", err)
|
||||
}
|
||||
|
||||
if err := s.repo.Delete(ctx, id); err != nil {
|
||||
return fmt.Errorf("delete channel: %w", err)
|
||||
}
|
||||
|
||||
s.invalidateCache()
|
||||
|
||||
if s.authCacheInvalidator != nil {
|
||||
for _, gid := range groupIDs {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// List 获取渠道列表
|
||||
func (s *ChannelService) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error) {
|
||||
return s.repo.List(ctx, params, status, search)
|
||||
}
|
||||
|
||||
// modelEntry 表示一个模型模式条目(用于冲突检测)
|
||||
type modelEntry struct {
|
||||
pattern string // 原始模式(如 "claude-*" 或 "claude-opus-4")
|
||||
prefix string // lowercase 前缀(通配符去掉 *,精确名保持原样)
|
||||
wildcard bool
|
||||
}
|
||||
|
||||
// conflictsBetween 检查两个模型模式是否冲突
|
||||
func conflictsBetween(a, b modelEntry) bool {
|
||||
switch {
|
||||
case !a.wildcard && !b.wildcard:
|
||||
return a.prefix == b.prefix
|
||||
case a.wildcard && !b.wildcard:
|
||||
return strings.HasPrefix(b.prefix, a.prefix)
|
||||
case !a.wildcard && b.wildcard:
|
||||
return strings.HasPrefix(a.prefix, b.prefix)
|
||||
default:
|
||||
return strings.HasPrefix(a.prefix, b.prefix) ||
|
||||
strings.HasPrefix(b.prefix, a.prefix)
|
||||
}
|
||||
}
|
||||
|
||||
// toModelEntry 将模型名转换为 modelEntry
|
||||
func toModelEntry(pattern string) modelEntry {
|
||||
lower := strings.ToLower(pattern)
|
||||
isWild := strings.HasSuffix(lower, "*")
|
||||
prefix := lower
|
||||
if isWild {
|
||||
prefix = strings.TrimSuffix(lower, "*")
|
||||
}
|
||||
return modelEntry{pattern: pattern, prefix: prefix, wildcard: isWild}
|
||||
}
|
||||
|
||||
// validateNoConflictingModels 检查定价列表中是否有冲突模型模式(同一平台下)。
|
||||
// 冲突包括:精确重复、通配符之间的前缀包含、通配符与精确名的前缀匹配。
|
||||
func validateNoConflictingModels(pricingList []ChannelModelPricing) error {
|
||||
byPlatform := make(map[string][]modelEntry)
|
||||
for _, p := range pricingList {
|
||||
for _, model := range p.Models {
|
||||
byPlatform[p.Platform] = append(byPlatform[p.Platform], toModelEntry(model))
|
||||
}
|
||||
}
|
||||
for platform, entries := range byPlatform {
|
||||
if err := detectConflicts(entries, platform, "MODEL_PATTERN_CONFLICT", "model patterns"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateNoConflictingMappings 检查模型映射中是否有冲突的源模式
|
||||
func validateNoConflictingMappings(mapping map[string]map[string]string) error {
|
||||
for platform, platformMapping := range mapping {
|
||||
entries := make([]modelEntry, 0, len(platformMapping))
|
||||
for src := range platformMapping {
|
||||
entries = append(entries, toModelEntry(src))
|
||||
}
|
||||
if err := detectConflicts(entries, platform, "MAPPING_PATTERN_CONFLICT", "mapping source patterns"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validatePricingIntervals(pricingList []ChannelModelPricing) error {
|
||||
for _, pricing := range pricingList {
|
||||
if err := ValidateIntervals(pricing.Intervals); err != nil {
|
||||
return infraerrors.BadRequest(
|
||||
"INVALID_PRICING_INTERVALS",
|
||||
fmt.Sprintf("invalid pricing intervals for platform '%s' models %v: %v",
|
||||
pricing.Platform, pricing.Models, err),
|
||||
)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// detectConflicts 在一组 modelEntry 中检测冲突,返回带有 errCode 和 label 的错误
|
||||
func detectConflicts(entries []modelEntry, platform, errCode, label string) error {
|
||||
for i := 0; i < len(entries); i++ {
|
||||
for j := i + 1; j < len(entries); j++ {
|
||||
if conflictsBetween(entries[i], entries[j]) {
|
||||
return infraerrors.BadRequest(errCode,
|
||||
fmt.Sprintf("%s '%s' and '%s' conflict in platform '%s': overlapping match range",
|
||||
label, entries[i].pattern, entries[j].pattern, platform))
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Input types ---
|
||||
|
||||
// CreateChannelInput 创建渠道输入
|
||||
type CreateChannelInput struct {
|
||||
Name string
|
||||
Description string
|
||||
GroupIDs []int64
|
||||
ModelPricing []ChannelModelPricing
|
||||
ModelMapping map[string]map[string]string // platform → {src→dst}
|
||||
BillingModelSource string
|
||||
RestrictModels bool
|
||||
}
|
||||
|
||||
// UpdateChannelInput 更新渠道输入
|
||||
type UpdateChannelInput struct {
|
||||
Name string
|
||||
Description *string
|
||||
Status string
|
||||
GroupIDs *[]int64
|
||||
ModelPricing *[]ChannelModelPricing
|
||||
ModelMapping map[string]map[string]string // platform → {src→dst}
|
||||
BillingModelSource string
|
||||
RestrictModels *bool
|
||||
}
|
||||
2187
backend/internal/service/channel_service_test.go
Normal file
2187
backend/internal/service/channel_service_test.go
Normal file
File diff suppressed because it is too large
Load Diff
435
backend/internal/service/channel_test.go
Normal file
435
backend/internal/service/channel_test.go
Normal file
@@ -0,0 +1,435 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetModelPricing(t *testing.T) {
|
||||
ch := &Channel{
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{ID: 1, Models: []string{"claude-sonnet-4"}, BillingMode: BillingModeToken, InputPrice: testPtrFloat64(3e-6)},
|
||||
{ID: 3, Models: []string{"gpt-5.1"}, BillingMode: BillingModePerRequest},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
wantID int64
|
||||
wantNil bool
|
||||
}{
|
||||
{"exact match", "claude-sonnet-4", 1, false},
|
||||
{"case insensitive", "Claude-Sonnet-4", 1, false},
|
||||
{"not found", "gemini-3.1-pro", 0, true},
|
||||
{"wildcard pattern not matched", "claude-opus-4-20250514", 0, true},
|
||||
{"per_request model", "gpt-5.1", 3, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ch.GetModelPricing(tt.model)
|
||||
if tt.wantNil {
|
||||
require.Nil(t, result)
|
||||
return
|
||||
}
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, tt.wantID, result.ID)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelPricing_ReturnsCopy(t *testing.T) {
|
||||
ch := &Channel{
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{ID: 1, Models: []string{"claude-sonnet-4"}, InputPrice: testPtrFloat64(3e-6)},
|
||||
},
|
||||
}
|
||||
|
||||
result := ch.GetModelPricing("claude-sonnet-4")
|
||||
require.NotNil(t, result)
|
||||
|
||||
// Modify the returned copy's slice — original should be unchanged
|
||||
result.Models = append(result.Models, "hacked")
|
||||
|
||||
// Original should be unchanged
|
||||
require.Equal(t, 1, len(ch.ModelPricing[0].Models))
|
||||
}
|
||||
|
||||
func TestGetModelPricing_EmptyPricing(t *testing.T) {
|
||||
ch := &Channel{ModelPricing: nil}
|
||||
require.Nil(t, ch.GetModelPricing("any-model"))
|
||||
|
||||
ch2 := &Channel{ModelPricing: []ChannelModelPricing{}}
|
||||
require.Nil(t, ch2.GetModelPricing("any-model"))
|
||||
}
|
||||
|
||||
func TestGetIntervalForContext(t *testing.T) {
|
||||
p := &ChannelModelPricing{
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
|
||||
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokens int
|
||||
wantPrice *float64
|
||||
wantNil bool
|
||||
}{
|
||||
{"first interval", 50000, testPtrFloat64(1e-6), false},
|
||||
// (min, max] — 128000 在第一个区间的 max,包含,所以匹配第一个
|
||||
{"boundary: max of first (inclusive)", 128000, testPtrFloat64(1e-6), false},
|
||||
// 128001 > 128000,匹配第二个区间
|
||||
{"boundary: just above first max", 128001, testPtrFloat64(2e-6), false},
|
||||
{"unbounded interval", 500000, testPtrFloat64(2e-6), false},
|
||||
// (0, max] — 0 不匹配任何区间(左开)
|
||||
{"zero tokens: no match", 0, nil, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := p.GetIntervalForContext(tt.tokens)
|
||||
if tt.wantNil {
|
||||
require.Nil(t, result)
|
||||
return
|
||||
}
|
||||
require.NotNil(t, result)
|
||||
require.InDelta(t, *tt.wantPrice, *result.InputPrice, 1e-12)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetIntervalForContext_NoMatch(t *testing.T) {
|
||||
p := &ChannelModelPricing{
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 10000, MaxTokens: testPtrInt(50000)},
|
||||
},
|
||||
}
|
||||
require.Nil(t, p.GetIntervalForContext(5000)) // 5000 <= 10000, not > min
|
||||
require.Nil(t, p.GetIntervalForContext(10000)) // 10000 not > 10000 (left-open)
|
||||
require.NotNil(t, p.GetIntervalForContext(50000)) // 50000 <= 50000 (right-closed)
|
||||
require.Nil(t, p.GetIntervalForContext(50001)) // 50001 > 50000
|
||||
}
|
||||
|
||||
func TestGetIntervalForContext_Empty(t *testing.T) {
|
||||
p := &ChannelModelPricing{Intervals: nil}
|
||||
require.Nil(t, p.GetIntervalForContext(1000))
|
||||
}
|
||||
|
||||
func TestGetTierByLabel(t *testing.T) {
|
||||
p := &ChannelModelPricing{
|
||||
Intervals: []PricingInterval{
|
||||
{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
|
||||
{TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)},
|
||||
{TierLabel: "HD", PerRequestPrice: testPtrFloat64(0.12)},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
label string
|
||||
wantNil bool
|
||||
want float64
|
||||
}{
|
||||
{"exact match", "1K", false, 0.04},
|
||||
{"case insensitive", "hd", false, 0.12},
|
||||
{"not found", "4K", true, 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := p.GetTierByLabel(tt.label)
|
||||
if tt.wantNil {
|
||||
require.Nil(t, result)
|
||||
return
|
||||
}
|
||||
require.NotNil(t, result)
|
||||
require.InDelta(t, tt.want, *result.PerRequestPrice, 1e-12)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTierByLabel_Empty(t *testing.T) {
|
||||
p := &ChannelModelPricing{Intervals: nil}
|
||||
require.Nil(t, p.GetTierByLabel("1K"))
|
||||
}
|
||||
|
||||
func TestChannelClone(t *testing.T) {
|
||||
original := &Channel{
|
||||
ID: 1,
|
||||
Name: "test",
|
||||
GroupIDs: []int64{10, 20},
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{
|
||||
ID: 100,
|
||||
Models: []string{"model-a"},
|
||||
InputPrice: testPtrFloat64(5e-6),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cloned := original.Clone()
|
||||
require.NotNil(t, cloned)
|
||||
require.Equal(t, original.ID, cloned.ID)
|
||||
require.Equal(t, original.Name, cloned.Name)
|
||||
|
||||
// Modify clone slices — original should not change
|
||||
cloned.GroupIDs[0] = 999
|
||||
require.Equal(t, int64(10), original.GroupIDs[0])
|
||||
|
||||
cloned.ModelPricing[0].Models[0] = "hacked"
|
||||
require.Equal(t, "model-a", original.ModelPricing[0].Models[0])
|
||||
}
|
||||
|
||||
func TestChannelClone_Nil(t *testing.T) {
|
||||
var ch *Channel
|
||||
require.Nil(t, ch.Clone())
|
||||
}
|
||||
|
||||
func TestChannelModelPricingClone(t *testing.T) {
|
||||
original := ChannelModelPricing{
|
||||
Models: []string{"a", "b"},
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 0, TierLabel: "tier1"},
|
||||
},
|
||||
}
|
||||
|
||||
cloned := original.Clone()
|
||||
|
||||
// Modify clone slices — original unchanged
|
||||
cloned.Models[0] = "hacked"
|
||||
require.Equal(t, "a", original.Models[0])
|
||||
|
||||
cloned.Intervals[0].TierLabel = "hacked"
|
||||
require.Equal(t, "tier1", original.Intervals[0].TierLabel)
|
||||
}
|
||||
|
||||
// --- BillingMode.IsValid ---
|
||||
|
||||
func TestBillingModeIsValid(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mode BillingMode
|
||||
want bool
|
||||
}{
|
||||
{"token", BillingModeToken, true},
|
||||
{"per_request", BillingModePerRequest, true},
|
||||
{"image", BillingModeImage, true},
|
||||
{"empty", BillingMode(""), true},
|
||||
{"unknown", BillingMode("unknown"), false},
|
||||
{"random", BillingMode("xyz"), false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, tt.mode.IsValid())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- Channel.IsActive ---
|
||||
|
||||
func TestChannelIsActive(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status string
|
||||
want bool
|
||||
}{
|
||||
{"active", StatusActive, true},
|
||||
{"disabled", "disabled", false},
|
||||
{"empty", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ch := &Channel{Status: tt.status}
|
||||
require.Equal(t, tt.want, ch.IsActive())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- ChannelModelPricing.Clone edge cases ---
|
||||
|
||||
func TestChannelModelPricingClone_EdgeCases(t *testing.T) {
|
||||
t.Run("nil models", func(t *testing.T) {
|
||||
original := ChannelModelPricing{Models: nil}
|
||||
cloned := original.Clone()
|
||||
require.Nil(t, cloned.Models)
|
||||
})
|
||||
|
||||
t.Run("nil intervals", func(t *testing.T) {
|
||||
original := ChannelModelPricing{Intervals: nil}
|
||||
cloned := original.Clone()
|
||||
require.Nil(t, cloned.Intervals)
|
||||
})
|
||||
|
||||
t.Run("empty models", func(t *testing.T) {
|
||||
original := ChannelModelPricing{Models: []string{}}
|
||||
cloned := original.Clone()
|
||||
require.NotNil(t, cloned.Models)
|
||||
require.Empty(t, cloned.Models)
|
||||
})
|
||||
}
|
||||
|
||||
// --- Channel.Clone edge cases ---
|
||||
|
||||
func TestChannelClone_EdgeCases(t *testing.T) {
|
||||
t.Run("nil model mapping", func(t *testing.T) {
|
||||
original := &Channel{ID: 1, ModelMapping: nil}
|
||||
cloned := original.Clone()
|
||||
require.Nil(t, cloned.ModelMapping)
|
||||
})
|
||||
|
||||
t.Run("nil model pricing", func(t *testing.T) {
|
||||
original := &Channel{ID: 1, ModelPricing: nil}
|
||||
cloned := original.Clone()
|
||||
require.Nil(t, cloned.ModelPricing)
|
||||
})
|
||||
|
||||
t.Run("deep copy model mapping", func(t *testing.T) {
|
||||
original := &Channel{
|
||||
ID: 1,
|
||||
ModelMapping: map[string]map[string]string{
|
||||
"openai": {"gpt-4": "gpt-4-turbo"},
|
||||
},
|
||||
}
|
||||
cloned := original.Clone()
|
||||
|
||||
// Modify the cloned nested map
|
||||
cloned.ModelMapping["openai"]["gpt-4"] = "hacked"
|
||||
|
||||
// Original must remain unchanged
|
||||
require.Equal(t, "gpt-4-turbo", original.ModelMapping["openai"]["gpt-4"])
|
||||
})
|
||||
}
|
||||
|
||||
// --- ValidateIntervals ---
|
||||
|
||||
func TestValidateIntervals_Empty(t *testing.T) {
|
||||
require.NoError(t, ValidateIntervals(nil))
|
||||
require.NoError(t, ValidateIntervals([]PricingInterval{}))
|
||||
}
|
||||
|
||||
func TestValidateIntervals_ValidIntervals(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
intervals []PricingInterval
|
||||
}{
|
||||
{
|
||||
name: "single bounded interval",
|
||||
intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "two intervals with gap",
|
||||
intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(100000), InputPrice: testPtrFloat64(1e-6)},
|
||||
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "two contiguous intervals",
|
||||
intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
|
||||
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unsorted input (auto-sorted by validator)",
|
||||
intervals: []PricingInterval{
|
||||
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)},
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single unbounded interval",
|
||||
intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: nil, InputPrice: testPtrFloat64(1e-6)},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.NoError(t, ValidateIntervals(tt.intervals))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateIntervals_NegativeMinTokens(t *testing.T) {
|
||||
intervals := []PricingInterval{
|
||||
{MinTokens: -1, MaxTokens: testPtrInt(100), InputPrice: testPtrFloat64(1e-6)},
|
||||
}
|
||||
err := ValidateIntervals(intervals)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "min_tokens")
|
||||
require.Contains(t, err.Error(), ">= 0")
|
||||
}
|
||||
|
||||
func TestValidateIntervals_MaxTokensZero(t *testing.T) {
|
||||
intervals := []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(0), InputPrice: testPtrFloat64(1e-6)},
|
||||
}
|
||||
err := ValidateIntervals(intervals)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "max_tokens")
|
||||
require.Contains(t, err.Error(), "> 0")
|
||||
}
|
||||
|
||||
func TestValidateIntervals_MaxLessThanMin(t *testing.T) {
|
||||
intervals := []PricingInterval{
|
||||
{MinTokens: 100, MaxTokens: testPtrInt(50), InputPrice: testPtrFloat64(1e-6)},
|
||||
}
|
||||
err := ValidateIntervals(intervals)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "max_tokens")
|
||||
require.Contains(t, err.Error(), "> min_tokens")
|
||||
}
|
||||
|
||||
func TestValidateIntervals_MaxEqualsMin(t *testing.T) {
|
||||
intervals := []PricingInterval{
|
||||
{MinTokens: 100, MaxTokens: testPtrInt(100), InputPrice: testPtrFloat64(1e-6)},
|
||||
}
|
||||
err := ValidateIntervals(intervals)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "max_tokens")
|
||||
require.Contains(t, err.Error(), "> min_tokens")
|
||||
}
|
||||
|
||||
func TestValidateIntervals_NegativePrice(t *testing.T) {
|
||||
negPrice := -0.01
|
||||
intervals := []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(100), InputPrice: &negPrice},
|
||||
}
|
||||
err := ValidateIntervals(intervals)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "input_price")
|
||||
require.Contains(t, err.Error(), ">= 0")
|
||||
}
|
||||
|
||||
func TestValidateIntervals_OverlappingIntervals(t *testing.T) {
|
||||
intervals := []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(200), InputPrice: testPtrFloat64(1e-6)},
|
||||
{MinTokens: 100, MaxTokens: testPtrInt(300), InputPrice: testPtrFloat64(2e-6)},
|
||||
}
|
||||
err := ValidateIntervals(intervals)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "overlap")
|
||||
}
|
||||
|
||||
func TestValidateIntervals_UnboundedNotLast(t *testing.T) {
|
||||
intervals := []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: nil, InputPrice: testPtrFloat64(1e-6)},
|
||||
{MinTokens: 128000, MaxTokens: testPtrInt(256000), InputPrice: testPtrFloat64(2e-6)},
|
||||
}
|
||||
err := ValidateIntervals(intervals)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "unbounded")
|
||||
require.Contains(t, err.Error(), "last")
|
||||
}
|
||||
@@ -0,0 +1,130 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSelectAccountForModelWithExclusions_UsesFallbackGroupForChannelRestriction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
groupID := int64(10)
|
||||
fallbackID := int64(11)
|
||||
ch := Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{fallbackID},
|
||||
RestrictModels: true,
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{Platform: PlatformAnthropic, Models: []string{"claude-sonnet-4-6"}},
|
||||
},
|
||||
}
|
||||
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{
|
||||
fallbackID: PlatformAnthropic,
|
||||
}))
|
||||
accountRepo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range accountRepo.accounts {
|
||||
accountRepo.accountsByID[accountRepo.accounts[i].ID] = &accountRepo.accounts[i]
|
||||
}
|
||||
groupRepo := &mockGroupRepoForGateway{
|
||||
groups: map[int64]*Group{
|
||||
groupID: {
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
ClaudeCodeOnly: true,
|
||||
FallbackGroupID: &fallbackID,
|
||||
Hydrated: true,
|
||||
},
|
||||
fallbackID: {
|
||||
ID: fallbackID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
Hydrated: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: accountRepo,
|
||||
groupRepo: groupRepo,
|
||||
channelService: channelSvc,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), ctxkey.Group, groupRepo.groups[groupID])
|
||||
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-sonnet-4-6", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, account)
|
||||
require.Equal(t, int64(1), account.ID)
|
||||
}
|
||||
|
||||
func TestSelectAccountWithLoadAwareness_UsesFallbackGroupForChannelRestriction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
groupID := int64(10)
|
||||
fallbackID := int64(11)
|
||||
ch := Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{fallbackID},
|
||||
RestrictModels: true,
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{Platform: PlatformAnthropic, Models: []string{"claude-sonnet-4-6"}},
|
||||
},
|
||||
}
|
||||
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{
|
||||
fallbackID: PlatformAnthropic,
|
||||
}))
|
||||
accountRepo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range accountRepo.accounts {
|
||||
accountRepo.accountsByID[accountRepo.accounts[i].ID] = &accountRepo.accounts[i]
|
||||
}
|
||||
groupRepo := &mockGroupRepoForGateway{
|
||||
groups: map[int64]*Group{
|
||||
groupID: {
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
ClaudeCodeOnly: true,
|
||||
FallbackGroupID: &fallbackID,
|
||||
Hydrated: true,
|
||||
},
|
||||
fallbackID: {
|
||||
ID: fallbackID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
Hydrated: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: accountRepo,
|
||||
groupRepo: groupRepo,
|
||||
channelService: channelSvc,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), ctxkey.Group, groupRepo.groups[groupID])
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-sonnet-4-6", nil, "", 0)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
require.Equal(t, int64(1), result.Account.ID)
|
||||
}
|
||||
293
backend/internal/service/gateway_channel_restriction_test.go
Normal file
293
backend/internal/service/gateway_channel_restriction_test.go
Normal file
@@ -0,0 +1,293 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- billingModelForRestriction ---
|
||||
|
||||
func TestBillingModelForRestriction_Requested(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := billingModelForRestriction(BillingModelSourceRequested, "claude-sonnet-4-5", "claude-sonnet-4-6")
|
||||
require.Equal(t, "claude-sonnet-4-5", got)
|
||||
}
|
||||
|
||||
func TestBillingModelForRestriction_ChannelMapped(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := billingModelForRestriction(BillingModelSourceChannelMapped, "claude-sonnet-4-5", "claude-sonnet-4-6")
|
||||
require.Equal(t, "claude-sonnet-4-6", got)
|
||||
}
|
||||
|
||||
func TestBillingModelForRestriction_Upstream(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := billingModelForRestriction(BillingModelSourceUpstream, "claude-sonnet-4-5", "claude-sonnet-4-6")
|
||||
require.Equal(t, "", got, "upstream should return empty (per-account check needed)")
|
||||
}
|
||||
|
||||
func TestBillingModelForRestriction_Empty(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := billingModelForRestriction("", "claude-sonnet-4-5", "claude-sonnet-4-6")
|
||||
require.Equal(t, "claude-sonnet-4-6", got, "empty source defaults to channel_mapped")
|
||||
}
|
||||
|
||||
// --- resolveAccountUpstreamModel ---
|
||||
|
||||
func TestResolveAccountUpstreamModel_Antigravity(t *testing.T) {
|
||||
t.Parallel()
|
||||
account := &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
}
|
||||
// Antigravity 平台使用 DefaultAntigravityModelMapping
|
||||
got := resolveAccountUpstreamModel(account, "claude-sonnet-4-6")
|
||||
require.Equal(t, "claude-sonnet-4-6", got)
|
||||
}
|
||||
|
||||
func TestResolveAccountUpstreamModel_Antigravity_Unsupported(t *testing.T) {
|
||||
t.Parallel()
|
||||
account := &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
}
|
||||
got := resolveAccountUpstreamModel(account, "totally-unknown-model")
|
||||
require.Equal(t, "", got, "unsupported model should return empty")
|
||||
}
|
||||
|
||||
func TestResolveAccountUpstreamModel_NonAntigravity(t *testing.T) {
|
||||
t.Parallel()
|
||||
account := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
}
|
||||
got := resolveAccountUpstreamModel(account, "claude-sonnet-4-6")
|
||||
require.Equal(t, "claude-sonnet-4-6", got, "no mapping = passthrough")
|
||||
}
|
||||
|
||||
// --- checkChannelPricingRestriction ---
|
||||
|
||||
func TestCheckChannelPricingRestriction_NilGroupID(t *testing.T) {
|
||||
t.Parallel()
|
||||
svc := &GatewayService{channelService: &ChannelService{}}
|
||||
require.False(t, svc.checkChannelPricingRestriction(context.Background(), nil, "claude-sonnet-4"))
|
||||
}
|
||||
|
||||
func TestCheckChannelPricingRestriction_NilChannelService(t *testing.T) {
|
||||
t.Parallel()
|
||||
svc := &GatewayService{}
|
||||
gid := int64(10)
|
||||
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4"))
|
||||
}
|
||||
|
||||
func TestCheckChannelPricingRestriction_EmptyModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
svc := &GatewayService{channelService: &ChannelService{}}
|
||||
gid := int64(10)
|
||||
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, ""))
|
||||
}
|
||||
|
||||
func TestCheckChannelPricingRestriction_ChannelMapped_Restricted(t *testing.T) {
|
||||
t.Parallel()
|
||||
// 渠道映射 claude-sonnet-4-5 → claude-sonnet-4-6,但定价列表只有 claude-opus-4-6
|
||||
ch := Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{10},
|
||||
RestrictModels: true,
|
||||
BillingModelSource: BillingModelSourceChannelMapped,
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
|
||||
},
|
||||
ModelMapping: map[string]map[string]string{
|
||||
"anthropic": {"claude-sonnet-4-5": "claude-sonnet-4-6"},
|
||||
},
|
||||
}
|
||||
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
|
||||
svc := &GatewayService{channelService: channelSvc}
|
||||
|
||||
gid := int64(10)
|
||||
require.True(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"),
|
||||
"mapped model claude-sonnet-4-6 is NOT in pricing → restricted")
|
||||
}
|
||||
|
||||
func TestCheckChannelPricingRestriction_ChannelMapped_Allowed(t *testing.T) {
|
||||
t.Parallel()
|
||||
// 渠道映射 claude-sonnet-4-5 → claude-sonnet-4-6,定价列表包含 claude-sonnet-4-6
|
||||
ch := Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{10},
|
||||
RestrictModels: true,
|
||||
BillingModelSource: BillingModelSourceChannelMapped,
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}},
|
||||
},
|
||||
ModelMapping: map[string]map[string]string{
|
||||
"anthropic": {"claude-sonnet-4-5": "claude-sonnet-4-6"},
|
||||
},
|
||||
}
|
||||
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
|
||||
svc := &GatewayService{channelService: channelSvc}
|
||||
|
||||
gid := int64(10)
|
||||
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"),
|
||||
"mapped model claude-sonnet-4-6 IS in pricing → allowed")
|
||||
}
|
||||
|
||||
func TestCheckChannelPricingRestriction_Requested_Restricted(t *testing.T) {
|
||||
t.Parallel()
|
||||
// billing_model_source=requested,定价列表有 claude-sonnet-4-6 但请求的是 claude-sonnet-4-5
|
||||
ch := Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{10},
|
||||
RestrictModels: true,
|
||||
BillingModelSource: BillingModelSourceRequested,
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}},
|
||||
},
|
||||
}
|
||||
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
|
||||
svc := &GatewayService{channelService: channelSvc}
|
||||
|
||||
gid := int64(10)
|
||||
require.True(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"),
|
||||
"requested model claude-sonnet-4-5 is NOT in pricing → restricted")
|
||||
}
|
||||
|
||||
func TestCheckChannelPricingRestriction_Requested_Allowed(t *testing.T) {
|
||||
t.Parallel()
|
||||
ch := Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{10},
|
||||
RestrictModels: true,
|
||||
BillingModelSource: BillingModelSourceRequested,
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{Platform: "anthropic", Models: []string{"claude-sonnet-4-5"}},
|
||||
},
|
||||
}
|
||||
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
|
||||
svc := &GatewayService{channelService: channelSvc}
|
||||
|
||||
gid := int64(10)
|
||||
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"),
|
||||
"requested model IS in pricing → allowed")
|
||||
}
|
||||
|
||||
func TestCheckChannelPricingRestriction_Upstream_SkipsPreCheck(t *testing.T) {
|
||||
t.Parallel()
|
||||
// upstream 模式:预检查始终跳过(返回 false),需逐账号检查
|
||||
ch := Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{10},
|
||||
RestrictModels: true,
|
||||
BillingModelSource: BillingModelSourceUpstream,
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
|
||||
},
|
||||
}
|
||||
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
|
||||
svc := &GatewayService{channelService: channelSvc}
|
||||
|
||||
gid := int64(10)
|
||||
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "unknown-model"),
|
||||
"upstream mode should skip pre-check (per-account check needed)")
|
||||
}
|
||||
|
||||
func TestCheckChannelPricingRestriction_RestrictModelsDisabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
ch := Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{10},
|
||||
RestrictModels: false, // 未开启模型限制
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
|
||||
},
|
||||
}
|
||||
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
|
||||
svc := &GatewayService{channelService: channelSvc}
|
||||
|
||||
gid := int64(10)
|
||||
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "any-model"),
|
||||
"RestrictModels=false → always allowed")
|
||||
}
|
||||
|
||||
func TestCheckChannelPricingRestriction_NoChannel(t *testing.T) {
|
||||
t.Parallel()
|
||||
// 分组没有关联渠道
|
||||
repo := &mockChannelRepository{
|
||||
listAllFn: func(_ context.Context) ([]Channel, error) { return nil, nil },
|
||||
}
|
||||
channelSvc := newTestChannelService(repo)
|
||||
svc := &GatewayService{channelService: channelSvc}
|
||||
|
||||
gid := int64(999)
|
||||
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "any-model"),
|
||||
"no channel for group → allowed")
|
||||
}
|
||||
|
||||
// --- isUpstreamModelRestrictedByChannel ---
|
||||
|
||||
func TestIsUpstreamModelRestrictedByChannel_Restricted(t *testing.T) {
|
||||
t.Parallel()
|
||||
ch := Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{10},
|
||||
RestrictModels: true,
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
|
||||
},
|
||||
}
|
||||
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
|
||||
svc := &GatewayService{channelService: channelSvc}
|
||||
|
||||
account := &Account{Platform: PlatformAntigravity}
|
||||
// claude-sonnet-4-6 在 DefaultAntigravityModelMapping 中,映射后仍为 claude-sonnet-4-6
|
||||
// 但定价列表只有 claude-opus-4-6
|
||||
require.True(t, svc.isUpstreamModelRestrictedByChannel(context.Background(), 10, account, "claude-sonnet-4-6"),
|
||||
"upstream model claude-sonnet-4-6 NOT in pricing → restricted")
|
||||
}
|
||||
|
||||
func TestIsUpstreamModelRestrictedByChannel_Allowed(t *testing.T) {
|
||||
t.Parallel()
|
||||
ch := Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{10},
|
||||
RestrictModels: true,
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}},
|
||||
},
|
||||
}
|
||||
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
|
||||
svc := &GatewayService{channelService: channelSvc}
|
||||
|
||||
account := &Account{Platform: PlatformAntigravity}
|
||||
require.False(t, svc.isUpstreamModelRestrictedByChannel(context.Background(), 10, account, "claude-sonnet-4-6"),
|
||||
"upstream model claude-sonnet-4-6 IS in pricing → allowed")
|
||||
}
|
||||
|
||||
func TestIsUpstreamModelRestrictedByChannel_UnsupportedModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
ch := Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{10},
|
||||
RestrictModels: true,
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
|
||||
},
|
||||
}
|
||||
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
|
||||
svc := &GatewayService{channelService: channelSvc}
|
||||
|
||||
account := &Account{Platform: PlatformAntigravity}
|
||||
// totally-unknown-model 不在 DefaultAntigravityModelMapping 中 → 映射结果为空
|
||||
require.False(t, svc.isUpstreamModelRestrictedByChannel(context.Background(), 10, account, "totally-unknown-model"),
|
||||
"unmappable model → upstream model empty → not restricted (account filter handles this)")
|
||||
}
|
||||
@@ -732,7 +732,7 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) {
|
||||
modelsListCacheTTL: time.Minute,
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(baseCtx, nil, "sess-hash", "", nil, "")
|
||||
result, err := svc.SelectAccountWithLoadAwareness(baseCtx, nil, "sess-hash", "", nil, "", int64(0))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -754,7 +754,7 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) {
|
||||
|
||||
ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, account.ID)
|
||||
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(0))
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "")
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "", int64(0))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -776,7 +776,7 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) {
|
||||
|
||||
ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, int64(999))
|
||||
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(77))
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "")
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "", int64(0))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
|
||||
@@ -2031,7 +2031,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: nil, // No concurrency service
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -2084,7 +2084,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: nil, // legacy path
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil, "")
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil, "", int64(0))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -2116,7 +2116,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: nil,
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -2148,7 +2148,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
}
|
||||
|
||||
excludedIDs := map[int64]struct{}{1: {}}
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, "")
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, "", int64(0))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -2182,7 +2182,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -2218,7 +2218,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -2259,7 +2259,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(testCtx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
|
||||
result, err := svc.SelectAccountWithLoadAwareness(testCtx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -2287,7 +2287,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: nil,
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.ErrorIs(t, err, ErrNoAvailableAccounts)
|
||||
@@ -2319,7 +2319,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: nil,
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -2352,7 +2352,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: nil,
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -2390,7 +2390,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.WaitPlan)
|
||||
@@ -2426,7 +2426,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "legacy", "claude-3-5-sonnet-20241022", nil, "")
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "legacy", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -2485,7 +2485,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "")
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.WaitPlan)
|
||||
@@ -2539,7 +2539,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "")
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -2593,7 +2593,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "")
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -2651,7 +2651,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route", "claude-3-5-sonnet-20241022", nil, "")
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -2709,7 +2709,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route-full", "claude-3-5-sonnet-20241022", nil, "")
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route-full", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.WaitPlan)
|
||||
@@ -2767,7 +2767,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "fallback", "claude-3-5-sonnet-20241022", nil, "")
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "fallback", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -2804,7 +2804,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.WaitPlan)
|
||||
@@ -2856,7 +2856,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "gemini", "gemini-2.5-pro", nil, "")
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "gemini", "gemini-2.5-pro", nil, "", int64(0))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -2934,7 +2934,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
}
|
||||
|
||||
excluded := map[int64]struct{}{1: {}}
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", excluded, "")
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", excluded, "", int64(0))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -2988,7 +2988,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: nil,
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "gemini-2.5-pro", nil, "")
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "gemini-2.5-pro", nil, "", int64(0))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -3021,7 +3021,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: nil,
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil, "")
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.ErrorIs(t, err, ErrClaudeCodeOnly)
|
||||
@@ -3059,7 +3059,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "wait", "claude-3-5-sonnet-20241022", nil, "")
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "wait", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.WaitPlan)
|
||||
@@ -3097,7 +3097,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "missing-load", "claude-3-5-sonnet-20241022", nil, "")
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "missing-load", "claude-3-5-sonnet-20241022", nil, "", int64(0))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
|
||||
@@ -41,6 +41,8 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -60,6 +60,13 @@ const (
|
||||
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
|
||||
)
|
||||
|
||||
// MediaType 媒体类型常量
|
||||
const (
|
||||
MediaTypeImage = "image"
|
||||
MediaTypeVideo = "video"
|
||||
MediaTypePrompt = "prompt"
|
||||
)
|
||||
|
||||
// ForceCacheBillingContextKey 强制缓存计费上下文键
|
||||
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
|
||||
type forceCacheBillingKeyType struct{}
|
||||
@@ -483,6 +490,7 @@ type ClaudeUsage struct {
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
||||
CacheCreation5mTokens int // 5分钟缓存创建token(来自嵌套 cache_creation 对象)
|
||||
CacheCreation1hTokens int // 1小时缓存创建token(来自嵌套 cache_creation 对象)
|
||||
ImageOutputTokens int `json:"image_output_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// ForwardResult 转发结果
|
||||
@@ -568,6 +576,8 @@ type GatewayService struct {
|
||||
responseHeaderFilter *responseheaders.CompiledHeaderFilter
|
||||
debugModelRouting atomic.Bool
|
||||
debugClaudeMimic atomic.Bool
|
||||
channelService *ChannelService
|
||||
resolver *ModelPricingResolver
|
||||
debugGatewayBodyFile atomic.Pointer[os.File] // non-nil when SUB2API_DEBUG_GATEWAY_BODY is set
|
||||
tlsFPProfileService *TLSFingerprintProfileService
|
||||
}
|
||||
@@ -597,6 +607,8 @@ func NewGatewayService(
|
||||
digestStore *DigestSessionStore,
|
||||
settingService *SettingService,
|
||||
tlsFPProfileService *TLSFingerprintProfileService,
|
||||
channelService *ChannelService,
|
||||
resolver *ModelPricingResolver,
|
||||
) *GatewayService {
|
||||
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
|
||||
modelsListTTL := resolveModelsListCacheTTL(cfg)
|
||||
@@ -629,6 +641,8 @@ func NewGatewayService(
|
||||
modelsListCacheTTL: modelsListTTL,
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
tlsFPProfileService: tlsFPProfileService,
|
||||
channelService: channelService,
|
||||
resolver: resolver,
|
||||
}
|
||||
svc.userGroupRateResolver = newUserGroupRateResolver(
|
||||
userGroupRateRepo,
|
||||
@@ -866,17 +880,7 @@ type anthropicMetadataPayload struct {
|
||||
// replaceModelInBody 替换请求体中的model字段
|
||||
// 优先使用定点修改,尽量保持客户端原始字段顺序。
|
||||
func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte {
|
||||
if len(body) == 0 {
|
||||
return body
|
||||
}
|
||||
if current := gjson.GetBytes(body, "model"); current.Exists() && current.String() == newModel {
|
||||
return body
|
||||
}
|
||||
newBody, err := sjson.SetBytes(body, "model", newModel)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return newBody
|
||||
return ReplaceModelInBody(body, newModel)
|
||||
}
|
||||
|
||||
type claudeOAuthNormalizeOptions struct {
|
||||
@@ -1186,6 +1190,15 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
platform = PlatformAnthropic
|
||||
}
|
||||
|
||||
// Claude Code 限制可能已将 groupID 解析为 fallback group,
|
||||
// 渠道限制预检查必须使用解析后的分组。
|
||||
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
|
||||
slog.Warn("channel pricing restriction blocked request",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"model", requestedModel)
|
||||
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
|
||||
}
|
||||
|
||||
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
||||
// 注意:强制平台模式不走混合调度
|
||||
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
|
||||
@@ -1198,8 +1211,10 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
}
|
||||
|
||||
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
|
||||
// metadataUserID: 已废弃参数,会话限制现在统一使用 sessionHash
|
||||
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) {
|
||||
// 调度流程文档见 docs/ACCOUNT_SCHEDULING_FLOW.md 。
|
||||
// metadataUserID: 用于客户端亲和调度,从中提取客户端 ID
|
||||
// sub2apiUserID: 系统用户 ID,用于二维亲和调度
|
||||
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string, sub2apiUserID int64) (*AccountSelectionResult, error) {
|
||||
// 调试日志:记录调度入口参数
|
||||
excludedIDsList := make([]int64, 0, len(excludedIDs))
|
||||
for id := range excludedIDs {
|
||||
@@ -1220,6 +1235,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
ctx = s.withGroupContext(ctx, group)
|
||||
|
||||
// Claude Code 限制可能已将 groupID 解析为 fallback group,
|
||||
// 渠道限制预检查必须使用解析后的分组。
|
||||
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
|
||||
slog.Warn("channel pricing restriction blocked request",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"model", requestedModel)
|
||||
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
|
||||
}
|
||||
|
||||
var stickyAccountID int64
|
||||
if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 {
|
||||
stickyAccountID = prefetch
|
||||
@@ -2945,6 +2969,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
ctx = s.withRPMPrefetch(ctx, accounts)
|
||||
|
||||
// 3. 按优先级+最久未用选择(考虑模型支持)
|
||||
// needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查,
|
||||
// 因为粘性会话优先保持连接一致性,且 upstream 计费基准极少使用。
|
||||
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
@@ -2965,6 +2992,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
@@ -3197,6 +3227,8 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
ctx = s.withRPMPrefetch(ctx, accounts)
|
||||
|
||||
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
|
||||
// needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查。
|
||||
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
@@ -3221,6 +3253,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
@@ -7410,6 +7445,8 @@ type RecordUsageInput struct {
|
||||
RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险
|
||||
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
|
||||
APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额
|
||||
|
||||
ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析)
|
||||
}
|
||||
|
||||
// APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage
|
||||
@@ -7439,6 +7476,18 @@ type postUsageBillingParams struct {
|
||||
APIKeyService APIKeyQuotaUpdater
|
||||
}
|
||||
|
||||
func (p *postUsageBillingParams) shouldDeductAPIKeyQuota() bool {
|
||||
return p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil
|
||||
}
|
||||
|
||||
func (p *postUsageBillingParams) shouldUpdateRateLimits() bool {
|
||||
return p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil
|
||||
}
|
||||
|
||||
func (p *postUsageBillingParams) shouldUpdateAccountQuota() bool {
|
||||
return p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit()
|
||||
}
|
||||
|
||||
// postUsageBilling 统一处理使用量记录后的扣费逻辑:
|
||||
// - 订阅/余额扣费
|
||||
// - API Key 配额更新
|
||||
@@ -7468,21 +7517,21 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
|
||||
}
|
||||
|
||||
// 2. API Key 配额
|
||||
if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil {
|
||||
if p.shouldDeductAPIKeyQuota() {
|
||||
if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||
slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. API Key 限速用量
|
||||
if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
|
||||
if p.shouldUpdateRateLimits() {
|
||||
if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||
slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率)
|
||||
if cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() {
|
||||
if p.shouldUpdateAccountQuota() {
|
||||
accountCost := cost.TotalCost * p.AccountRateMultiplier
|
||||
if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil {
|
||||
slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
|
||||
@@ -7564,13 +7613,13 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
|
||||
cmd.BalanceCost = p.Cost.ActualCost
|
||||
}
|
||||
|
||||
if p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil {
|
||||
if p.shouldDeductAPIKeyQuota() {
|
||||
cmd.APIKeyQuotaCost = p.Cost.ActualCost
|
||||
}
|
||||
if p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
|
||||
if p.shouldUpdateRateLimits() {
|
||||
cmd.APIKeyRateLimitCost = p.Cost.ActualCost
|
||||
}
|
||||
if p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() {
|
||||
if p.shouldUpdateAccountQuota() {
|
||||
cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier
|
||||
}
|
||||
|
||||
@@ -7694,191 +7743,41 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage
|
||||
}
|
||||
}
|
||||
|
||||
// recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。
|
||||
type recordUsageOpts struct {
|
||||
// Claude Max 策略所需的 ParsedRequest(可选,仅 Claude 路径传入)
|
||||
ParsedRequest *ParsedRequest
|
||||
|
||||
// EnableClaudePath 启用 Claude 路径特有逻辑:
|
||||
// - Claude Max 缓存计费策略
|
||||
// - Sora 媒体类型分支(image/video/prompt)
|
||||
// - MediaType 字段写入使用日志
|
||||
EnableClaudePath bool
|
||||
|
||||
// 长上下文计费(仅 Gemini 路径需要)
|
||||
LongContextThreshold int
|
||||
LongContextMultiplier float64
|
||||
}
|
||||
|
||||
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
||||
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
|
||||
result := input.Result
|
||||
apiKey := input.APIKey
|
||||
user := input.User
|
||||
account := input.Account
|
||||
subscription := input.Subscription
|
||||
|
||||
// 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens
|
||||
// 用于粘性会话切换时的特殊计费处理
|
||||
if input.ForceCacheBilling && result.Usage.InputTokens > 0 {
|
||||
logger.LegacyPrintf("service.gateway", "force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)",
|
||||
result.Usage.InputTokens, account.ID)
|
||||
result.Usage.CacheReadInputTokens += result.Usage.InputTokens
|
||||
result.Usage.InputTokens = 0
|
||||
}
|
||||
|
||||
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
|
||||
cacheTTLOverridden := false
|
||||
if account.IsCacheTTLOverrideEnabled() {
|
||||
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
|
||||
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
|
||||
}
|
||||
|
||||
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
|
||||
multiplier := 1.0
|
||||
if s.cfg != nil {
|
||||
multiplier = s.cfg.Default.RateMultiplier
|
||||
}
|
||||
if apiKey.GroupID != nil && apiKey.Group != nil {
|
||||
groupDefault := apiKey.Group.RateMultiplier
|
||||
multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault)
|
||||
}
|
||||
|
||||
var cost *CostBreakdown
|
||||
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
|
||||
|
||||
// 根据请求类型选择计费方式
|
||||
if result.MediaType == "image" || result.MediaType == "video" {
|
||||
var soraConfig *SoraPriceConfig
|
||||
if apiKey.Group != nil {
|
||||
soraConfig = &SoraPriceConfig{
|
||||
ImagePrice360: apiKey.Group.SoraImagePrice360,
|
||||
ImagePrice540: apiKey.Group.SoraImagePrice540,
|
||||
VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
|
||||
VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
|
||||
}
|
||||
}
|
||||
if result.MediaType == "image" {
|
||||
cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
|
||||
} else {
|
||||
cost = s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
|
||||
}
|
||||
} else if result.MediaType == "prompt" {
|
||||
cost = &CostBreakdown{}
|
||||
} else if result.ImageCount > 0 {
|
||||
// 图片生成计费
|
||||
var groupConfig *ImagePriceConfig
|
||||
if apiKey.Group != nil {
|
||||
groupConfig = &ImagePriceConfig{
|
||||
Price1K: apiKey.Group.ImagePrice1K,
|
||||
Price2K: apiKey.Group.ImagePrice2K,
|
||||
Price4K: apiKey.Group.ImagePrice4K,
|
||||
}
|
||||
}
|
||||
cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
|
||||
} else {
|
||||
// Token 计费
|
||||
tokens := UsageTokens{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
}
|
||||
var err error
|
||||
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
|
||||
cost = &CostBreakdown{ActualCost: 0}
|
||||
}
|
||||
}
|
||||
|
||||
// 判断计费方式:订阅模式 vs 余额模式
|
||||
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
||||
billingType := BillingTypeBalance
|
||||
if isSubscriptionBilling {
|
||||
billingType = BillingTypeSubscription
|
||||
}
|
||||
|
||||
// 创建使用日志
|
||||
durationMs := int(result.Duration.Milliseconds())
|
||||
var imageSize *string
|
||||
if result.ImageSize != "" {
|
||||
imageSize = &result.ImageSize
|
||||
}
|
||||
var mediaType *string
|
||||
if strings.TrimSpace(result.MediaType) != "" {
|
||||
mediaType = &result.MediaType
|
||||
}
|
||||
accountRateMultiplier := account.BillingRateMultiplier()
|
||||
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: result.Model,
|
||||
RequestedModel: result.Model,
|
||||
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
|
||||
ReasoningEffort: result.ReasoningEffort,
|
||||
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
|
||||
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
InputCost: cost.InputCost,
|
||||
OutputCost: cost.OutputCost,
|
||||
CacheCreationCost: cost.CacheCreationCost,
|
||||
CacheReadCost: cost.CacheReadCost,
|
||||
TotalCost: cost.TotalCost,
|
||||
ActualCost: cost.ActualCost,
|
||||
RateMultiplier: multiplier,
|
||||
AccountRateMultiplier: &accountRateMultiplier,
|
||||
BillingType: billingType,
|
||||
Stream: result.Stream,
|
||||
DurationMs: &durationMs,
|
||||
FirstTokenMs: result.FirstTokenMs,
|
||||
ImageCount: result.ImageCount,
|
||||
ImageSize: imageSize,
|
||||
MediaType: mediaType,
|
||||
CacheTTLOverridden: cacheTTLOverridden,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 添加 UserAgent
|
||||
if input.UserAgent != "" {
|
||||
usageLog.UserAgent = &input.UserAgent
|
||||
}
|
||||
|
||||
// 添加 IPAddress
|
||||
if input.IPAddress != "" {
|
||||
usageLog.IPAddress = &input.IPAddress
|
||||
}
|
||||
|
||||
// 添加分组和订阅关联
|
||||
if apiKey.GroupID != nil {
|
||||
usageLog.GroupID = apiKey.GroupID
|
||||
}
|
||||
if subscription != nil {
|
||||
usageLog.SubscriptionID = &subscription.ID
|
||||
}
|
||||
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
||||
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
billingErr := func() error {
|
||||
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
|
||||
Cost: cost,
|
||||
User: user,
|
||||
APIKey: apiKey,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
|
||||
IsSubscriptionBill: isSubscriptionBilling,
|
||||
AccountRateMultiplier: accountRateMultiplier,
|
||||
APIKeyService: input.APIKeyService,
|
||||
}, s.billingDeps(), s.usageBillingRepo)
|
||||
return err
|
||||
}()
|
||||
|
||||
if billingErr != nil {
|
||||
return billingErr
|
||||
}
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
||||
|
||||
return nil
|
||||
return s.recordUsageCore(ctx, &recordUsageCoreInput{
|
||||
Result: input.Result,
|
||||
APIKey: input.APIKey,
|
||||
User: input.User,
|
||||
Account: input.Account,
|
||||
Subscription: input.Subscription,
|
||||
InboundEndpoint: input.InboundEndpoint,
|
||||
UpstreamEndpoint: input.UpstreamEndpoint,
|
||||
UserAgent: input.UserAgent,
|
||||
IPAddress: input.IPAddress,
|
||||
RequestPayloadHash: input.RequestPayloadHash,
|
||||
ForceCacheBilling: input.ForceCacheBilling,
|
||||
APIKeyService: input.APIKeyService,
|
||||
ChannelUsageFields: input.ChannelUsageFields,
|
||||
}, &recordUsageOpts{
|
||||
EnableClaudePath: true,
|
||||
})
|
||||
}
|
||||
|
||||
// RecordUsageLongContextInput 记录使用量的输入参数(支持长上下文双倍计费)
|
||||
@@ -7897,10 +7796,55 @@ type RecordUsageLongContextInput struct {
|
||||
LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0)
|
||||
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
|
||||
APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选)
|
||||
|
||||
ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析)
|
||||
}
|
||||
|
||||
// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini)
|
||||
func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *RecordUsageLongContextInput) error {
|
||||
return s.recordUsageCore(ctx, &recordUsageCoreInput{
|
||||
Result: input.Result,
|
||||
APIKey: input.APIKey,
|
||||
User: input.User,
|
||||
Account: input.Account,
|
||||
Subscription: input.Subscription,
|
||||
InboundEndpoint: input.InboundEndpoint,
|
||||
UpstreamEndpoint: input.UpstreamEndpoint,
|
||||
UserAgent: input.UserAgent,
|
||||
IPAddress: input.IPAddress,
|
||||
RequestPayloadHash: input.RequestPayloadHash,
|
||||
ForceCacheBilling: input.ForceCacheBilling,
|
||||
APIKeyService: input.APIKeyService,
|
||||
ChannelUsageFields: input.ChannelUsageFields,
|
||||
}, &recordUsageOpts{
|
||||
LongContextThreshold: input.LongContextThreshold,
|
||||
LongContextMultiplier: input.LongContextMultiplier,
|
||||
})
|
||||
}
|
||||
|
||||
// recordUsageCoreInput 是 recordUsageCore 的公共输入字段,从两种输入结构体中提取。
|
||||
type recordUsageCoreInput struct {
|
||||
Result *ForwardResult
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription
|
||||
InboundEndpoint string
|
||||
UpstreamEndpoint string
|
||||
UserAgent string
|
||||
IPAddress string
|
||||
RequestPayloadHash string
|
||||
ForceCacheBilling bool
|
||||
APIKeyService APIKeyQuotaUpdater
|
||||
ChannelUsageFields
|
||||
}
|
||||
|
||||
// recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。
|
||||
// opts 中的字段控制两者之间的差异行为:
|
||||
// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略
|
||||
// - EnableSoraMedia → 启用 Sora MediaType 分支(image/video/prompt)
|
||||
// - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext
|
||||
func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error {
|
||||
result := input.Result
|
||||
apiKey := input.APIKey
|
||||
user := input.User
|
||||
@@ -7933,38 +7877,23 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault)
|
||||
}
|
||||
|
||||
var cost *CostBreakdown
|
||||
// 确定计费模型
|
||||
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
|
||||
|
||||
// 根据请求类型选择计费方式
|
||||
if result.ImageCount > 0 {
|
||||
// 图片生成计费
|
||||
var groupConfig *ImagePriceConfig
|
||||
if apiKey.Group != nil {
|
||||
groupConfig = &ImagePriceConfig{
|
||||
Price1K: apiKey.Group.ImagePrice1K,
|
||||
Price2K: apiKey.Group.ImagePrice2K,
|
||||
Price4K: apiKey.Group.ImagePrice4K,
|
||||
}
|
||||
}
|
||||
cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
|
||||
} else {
|
||||
// Token 计费(使用长上下文计费方法)
|
||||
tokens := UsageTokens{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
}
|
||||
var err error
|
||||
cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
|
||||
cost = &CostBreakdown{ActualCost: 0}
|
||||
}
|
||||
if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" {
|
||||
billingModel = input.ChannelMappedModel
|
||||
}
|
||||
if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
|
||||
billingModel = input.OriginalModel
|
||||
}
|
||||
|
||||
// 确定 RequestedModel(渠道映射前的原始模型)
|
||||
requestedModel := result.Model
|
||||
if input.OriginalModel != "" {
|
||||
requestedModel = input.OriginalModel
|
||||
}
|
||||
|
||||
// 计算费用
|
||||
cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, opts)
|
||||
|
||||
// 判断计费方式:订阅模式 vs 余额模式
|
||||
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
||||
@@ -7974,12 +7903,214 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
}
|
||||
|
||||
// 创建使用日志
|
||||
durationMs := int(result.Duration.Milliseconds())
|
||||
var imageSize *string
|
||||
if result.ImageSize != "" {
|
||||
imageSize = &result.ImageSize
|
||||
}
|
||||
accountRateMultiplier := account.BillingRateMultiplier()
|
||||
usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription,
|
||||
requestedModel, multiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts)
|
||||
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
||||
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
requestID := usageLog.RequestID
|
||||
_, billingErr := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
|
||||
Cost: cost,
|
||||
User: user,
|
||||
APIKey: apiKey,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
|
||||
IsSubscriptionBill: isSubscriptionBilling,
|
||||
AccountRateMultiplier: accountRateMultiplier,
|
||||
APIKeyService: input.APIKeyService,
|
||||
}, s.billingDeps(), s.usageBillingRepo)
|
||||
|
||||
if billingErr != nil {
|
||||
return billingErr
|
||||
}
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateRecordUsageCost 根据请求类型和选项计算费用。
|
||||
func (s *GatewayService) calculateRecordUsageCost(
|
||||
ctx context.Context,
|
||||
result *ForwardResult,
|
||||
apiKey *APIKey,
|
||||
billingModel string,
|
||||
multiplier float64,
|
||||
opts *recordUsageOpts,
|
||||
) *CostBreakdown {
|
||||
// Sora 媒体类型分支(仅 Claude 路径启用)
|
||||
if opts.EnableClaudePath {
|
||||
if result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo {
|
||||
return s.calculateSoraMediaCost(result, apiKey, billingModel, multiplier)
|
||||
}
|
||||
if result.MediaType == MediaTypePrompt {
|
||||
return &CostBreakdown{}
|
||||
}
|
||||
}
|
||||
|
||||
// 图片生成计费
|
||||
if result.ImageCount > 0 {
|
||||
return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier)
|
||||
}
|
||||
|
||||
// Token 计费
|
||||
return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts)
|
||||
}
|
||||
|
||||
// calculateSoraMediaCost 计算 Sora 图片/视频的费用。
|
||||
func (s *GatewayService) calculateSoraMediaCost(
|
||||
result *ForwardResult,
|
||||
apiKey *APIKey,
|
||||
billingModel string,
|
||||
multiplier float64,
|
||||
) *CostBreakdown {
|
||||
var soraConfig *SoraPriceConfig
|
||||
if apiKey.Group != nil {
|
||||
soraConfig = &SoraPriceConfig{
|
||||
ImagePrice360: apiKey.Group.SoraImagePrice360,
|
||||
ImagePrice540: apiKey.Group.SoraImagePrice540,
|
||||
VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
|
||||
VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
|
||||
}
|
||||
}
|
||||
if result.MediaType == MediaTypeImage {
|
||||
return s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
|
||||
}
|
||||
return s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
|
||||
}
|
||||
|
||||
// resolveChannelPricing 检查指定模型是否存在渠道级别定价。
|
||||
// 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。
|
||||
func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
|
||||
if s.resolver == nil || apiKey.Group == nil {
|
||||
return nil
|
||||
}
|
||||
gid := apiKey.Group.ID
|
||||
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
|
||||
if resolved.Source == PricingSourceChannel {
|
||||
return resolved
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateImageCost 计算图片生成费用:渠道级别定价优先,否则走按次计费。
|
||||
func (s *GatewayService) calculateImageCost(
|
||||
ctx context.Context,
|
||||
result *ForwardResult,
|
||||
apiKey *APIKey,
|
||||
billingModel string,
|
||||
multiplier float64,
|
||||
) *CostBreakdown {
|
||||
if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil {
|
||||
tokens := UsageTokens{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
ImageOutputTokens: result.Usage.ImageOutputTokens,
|
||||
}
|
||||
gid := apiKey.Group.ID
|
||||
cost, err := s.billingService.CalculateCostUnified(CostInput{
|
||||
Ctx: ctx,
|
||||
Model: billingModel,
|
||||
GroupID: &gid,
|
||||
Tokens: tokens,
|
||||
RequestCount: 1,
|
||||
RateMultiplier: multiplier,
|
||||
Resolver: s.resolver,
|
||||
Resolved: resolved,
|
||||
})
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Calculate image token cost failed: %v", err)
|
||||
return &CostBreakdown{ActualCost: 0}
|
||||
}
|
||||
return cost
|
||||
}
|
||||
|
||||
var groupConfig *ImagePriceConfig
|
||||
if apiKey.Group != nil {
|
||||
groupConfig = &ImagePriceConfig{
|
||||
Price1K: apiKey.Group.ImagePrice1K,
|
||||
Price2K: apiKey.Group.ImagePrice2K,
|
||||
Price4K: apiKey.Group.ImagePrice4K,
|
||||
}
|
||||
}
|
||||
return s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
|
||||
}
|
||||
|
||||
// calculateTokenCost 计算 Token 计费:根据 opts 决定走普通/长上下文/渠道统一计费。
|
||||
func (s *GatewayService) calculateTokenCost(
|
||||
ctx context.Context,
|
||||
result *ForwardResult,
|
||||
apiKey *APIKey,
|
||||
billingModel string,
|
||||
multiplier float64,
|
||||
opts *recordUsageOpts,
|
||||
) *CostBreakdown {
|
||||
tokens := UsageTokens{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
ImageOutputTokens: result.Usage.ImageOutputTokens,
|
||||
}
|
||||
|
||||
var cost *CostBreakdown
|
||||
var err error
|
||||
|
||||
// 优先尝试渠道定价 → CalculateCostUnified
|
||||
if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil {
|
||||
gid := apiKey.Group.ID
|
||||
cost, err = s.billingService.CalculateCostUnified(CostInput{
|
||||
Ctx: ctx,
|
||||
Model: billingModel,
|
||||
GroupID: &gid,
|
||||
Tokens: tokens,
|
||||
RequestCount: 1,
|
||||
RateMultiplier: multiplier,
|
||||
Resolver: s.resolver,
|
||||
Resolved: resolved,
|
||||
})
|
||||
} else if opts.LongContextThreshold > 0 {
|
||||
// 长上下文双倍计费(如 Gemini 200K 阈值)
|
||||
cost, err = s.billingService.CalculateCostWithLongContext(
|
||||
billingModel, tokens, multiplier,
|
||||
opts.LongContextThreshold, opts.LongContextMultiplier,
|
||||
)
|
||||
} else {
|
||||
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
|
||||
}
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
|
||||
return &CostBreakdown{ActualCost: 0}
|
||||
}
|
||||
return cost
|
||||
}
|
||||
|
||||
// buildRecordUsageLog 构建使用日志并设置计费模式。
|
||||
func (s *GatewayService) buildRecordUsageLog(
|
||||
ctx context.Context,
|
||||
input *recordUsageCoreInput,
|
||||
result *ForwardResult,
|
||||
apiKey *APIKey,
|
||||
user *User,
|
||||
account *Account,
|
||||
subscription *UserSubscription,
|
||||
requestedModel string,
|
||||
multiplier float64,
|
||||
accountRateMultiplier float64,
|
||||
billingType int8,
|
||||
cacheTTLOverridden bool,
|
||||
cost *CostBreakdown,
|
||||
opts *recordUsageOpts,
|
||||
) *UsageLog {
|
||||
durationMs := int(result.Duration.Milliseconds())
|
||||
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
@@ -7987,7 +8118,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: result.Model,
|
||||
RequestedModel: result.Model,
|
||||
RequestedModel: requestedModel,
|
||||
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
|
||||
ReasoningEffort: result.ReasoningEffort,
|
||||
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
|
||||
@@ -7998,72 +8129,170 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
InputCost: cost.InputCost,
|
||||
OutputCost: cost.OutputCost,
|
||||
CacheCreationCost: cost.CacheCreationCost,
|
||||
CacheReadCost: cost.CacheReadCost,
|
||||
TotalCost: cost.TotalCost,
|
||||
ActualCost: cost.ActualCost,
|
||||
ImageOutputTokens: result.Usage.ImageOutputTokens,
|
||||
RateMultiplier: multiplier,
|
||||
AccountRateMultiplier: &accountRateMultiplier,
|
||||
BillingType: billingType,
|
||||
BillingMode: resolveBillingMode(opts, result, cost),
|
||||
Stream: result.Stream,
|
||||
DurationMs: &durationMs,
|
||||
FirstTokenMs: result.FirstTokenMs,
|
||||
ImageCount: result.ImageCount,
|
||||
ImageSize: imageSize,
|
||||
ImageSize: optionalTrimmedStringPtr(result.ImageSize),
|
||||
MediaType: resolveMediaType(opts, result),
|
||||
CacheTTLOverridden: cacheTTLOverridden,
|
||||
ChannelID: optionalInt64Ptr(input.ChannelID),
|
||||
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
|
||||
UserAgent: optionalTrimmedStringPtr(input.UserAgent),
|
||||
IPAddress: optionalTrimmedStringPtr(input.IPAddress),
|
||||
GroupID: apiKey.GroupID,
|
||||
SubscriptionID: optionalSubscriptionID(subscription),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 添加 UserAgent
|
||||
if input.UserAgent != "" {
|
||||
usageLog.UserAgent = &input.UserAgent
|
||||
if cost != nil {
|
||||
usageLog.InputCost = cost.InputCost
|
||||
usageLog.OutputCost = cost.OutputCost
|
||||
usageLog.ImageOutputCost = cost.ImageOutputCost
|
||||
usageLog.CacheCreationCost = cost.CacheCreationCost
|
||||
usageLog.CacheReadCost = cost.CacheReadCost
|
||||
usageLog.TotalCost = cost.TotalCost
|
||||
usageLog.ActualCost = cost.ActualCost
|
||||
}
|
||||
|
||||
// 添加 IPAddress
|
||||
if input.IPAddress != "" {
|
||||
usageLog.IPAddress = &input.IPAddress
|
||||
}
|
||||
return usageLog
|
||||
}
|
||||
|
||||
// 添加分组和订阅关联
|
||||
if apiKey.GroupID != nil {
|
||||
usageLog.GroupID = apiKey.GroupID
|
||||
}
|
||||
if subscription != nil {
|
||||
usageLog.SubscriptionID = &subscription.ID
|
||||
}
|
||||
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
||||
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
// resolveBillingMode 根据计费结果和请求类型确定计费模式。
|
||||
// Sora 媒体类型自身已确定计费模式(由上游处理),返回 nil 跳过。
|
||||
func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string {
|
||||
isSoraMedia := opts.EnableClaudePath &&
|
||||
(result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo || result.MediaType == MediaTypePrompt)
|
||||
if isSoraMedia {
|
||||
return nil
|
||||
}
|
||||
|
||||
billingErr := func() error {
|
||||
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
|
||||
Cost: cost,
|
||||
User: user,
|
||||
APIKey: apiKey,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
|
||||
IsSubscriptionBill: isSubscriptionBilling,
|
||||
AccountRateMultiplier: accountRateMultiplier,
|
||||
APIKeyService: input.APIKeyService,
|
||||
}, s.billingDeps(), s.usageBillingRepo)
|
||||
return err
|
||||
}()
|
||||
|
||||
if billingErr != nil {
|
||||
return billingErr
|
||||
var mode string
|
||||
switch {
|
||||
case cost != nil && cost.BillingMode != "":
|
||||
mode = cost.BillingMode
|
||||
case result.ImageCount > 0:
|
||||
mode = string(BillingModeImage)
|
||||
default:
|
||||
mode = string(BillingModeToken)
|
||||
}
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
||||
return &mode
|
||||
}
|
||||
|
||||
func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string {
|
||||
if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" {
|
||||
return &result.MediaType
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func optionalSubscriptionID(subscription *UserSubscription) *int64 {
|
||||
if subscription != nil {
|
||||
return &subscription.ID
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResolveChannelMapping 委托渠道服务解析模型映射
|
||||
func (s *GatewayService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult {
|
||||
if s.channelService == nil {
|
||||
return ChannelMappingResult{MappedModel: model}
|
||||
}
|
||||
return s.channelService.ResolveChannelMapping(ctx, groupID, model)
|
||||
}
|
||||
|
||||
// ReplaceModelInBody 替换请求体中的模型名(导出供 handler 使用)
|
||||
func (s *GatewayService) ReplaceModelInBody(body []byte, newModel string) []byte {
|
||||
return ReplaceModelInBody(body, newModel)
|
||||
}
|
||||
|
||||
// IsModelRestricted 检查模型是否被渠道限制
|
||||
func (s *GatewayService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
|
||||
if s.channelService == nil {
|
||||
return false
|
||||
}
|
||||
return s.channelService.IsModelRestricted(ctx, groupID, model)
|
||||
}
|
||||
|
||||
// ResolveChannelMappingAndRestrict 解析渠道映射。
|
||||
// 模型限制检查已移至调度阶段(checkChannelPricingRestriction),restricted 始终返回 false。
|
||||
func (s *GatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
|
||||
if s.channelService == nil {
|
||||
return ChannelMappingResult{MappedModel: model}, false
|
||||
}
|
||||
return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model)
|
||||
}
|
||||
|
||||
// checkChannelPricingRestriction 根据渠道计费基准检查模型是否受定价列表限制。
|
||||
// 供调度阶段预检查(requested / channel_mapped)。
|
||||
// upstream 需逐账号检查,此处返回 false。
|
||||
func (s *GatewayService) checkChannelPricingRestriction(ctx context.Context, groupID *int64, requestedModel string) bool {
|
||||
if groupID == nil || s.channelService == nil || requestedModel == "" {
|
||||
return false
|
||||
}
|
||||
mapping := s.channelService.ResolveChannelMapping(ctx, *groupID, requestedModel)
|
||||
billingModel := billingModelForRestriction(mapping.BillingModelSource, requestedModel, mapping.MappedModel)
|
||||
if billingModel == "" {
|
||||
return false
|
||||
}
|
||||
return s.channelService.IsModelRestricted(ctx, *groupID, billingModel)
|
||||
}
|
||||
|
||||
// billingModelForRestriction 根据计费基准确定限制检查使用的模型。
|
||||
// upstream 返回空(需逐账号检查)。
|
||||
func billingModelForRestriction(source, requestedModel, channelMappedModel string) string {
|
||||
switch source {
|
||||
case BillingModelSourceRequested:
|
||||
return requestedModel
|
||||
case BillingModelSourceUpstream:
|
||||
return ""
|
||||
case BillingModelSourceChannelMapped:
|
||||
return channelMappedModel
|
||||
default:
|
||||
return channelMappedModel
|
||||
}
|
||||
}
|
||||
|
||||
// isUpstreamModelRestrictedByChannel 检查账号映射后的上游模型是否受渠道定价限制。
|
||||
// 仅在 BillingModelSource="upstream" 且 RestrictModels=true 时由调度循环调用。
|
||||
func (s *GatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context, groupID int64, account *Account, requestedModel string) bool {
|
||||
if s.channelService == nil {
|
||||
return false
|
||||
}
|
||||
upstreamModel := resolveAccountUpstreamModel(account, requestedModel)
|
||||
if upstreamModel == "" {
|
||||
return false
|
||||
}
|
||||
return s.channelService.IsModelRestricted(ctx, groupID, upstreamModel)
|
||||
}
|
||||
|
||||
// resolveAccountUpstreamModel 确定账号将请求模型映射为什么上游模型。
|
||||
func resolveAccountUpstreamModel(account *Account, requestedModel string) string {
|
||||
if account.Platform == PlatformAntigravity {
|
||||
return mapAntigravityModel(account, requestedModel)
|
||||
}
|
||||
return account.GetMappedModel(requestedModel)
|
||||
}
|
||||
|
||||
// needsUpstreamChannelRestrictionCheck 判断是否需要在调度循环中逐账号检查上游模型的渠道限制。
|
||||
func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Context, groupID *int64) bool {
|
||||
if groupID == nil || s.channelService == nil {
|
||||
return false
|
||||
}
|
||||
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
|
||||
if err != nil {
|
||||
slog.Warn("failed to check channel upstream restriction", "group_id", *groupID, "error", err)
|
||||
return false
|
||||
}
|
||||
if ch == nil || !ch.RestrictModels {
|
||||
return false
|
||||
}
|
||||
return ch.BillingModelSource == BillingModelSourceUpstream
|
||||
}
|
||||
|
||||
// ForwardCountTokens 转发 count_tokens 请求到上游 API
|
||||
// 特点:不记录使用量、仅支持非流式响应
|
||||
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {
|
||||
|
||||
@@ -2692,12 +2692,27 @@ func extractGeminiUsage(data []byte) *ClaudeUsage {
|
||||
cand := int(usage.Get("candidatesTokenCount").Int())
|
||||
cached := int(usage.Get("cachedContentTokenCount").Int())
|
||||
thoughts := int(usage.Get("thoughtsTokenCount").Int())
|
||||
|
||||
// 从 candidatesTokensDetails 提取 IMAGE 模态 token 数
|
||||
imageTokens := 0
|
||||
candidateDetails := usage.Get("candidatesTokensDetails")
|
||||
if candidateDetails.Exists() {
|
||||
candidateDetails.ForEach(func(_, detail gjson.Result) bool {
|
||||
if detail.Get("modality").String() == "IMAGE" {
|
||||
imageTokens = int(detail.Get("tokenCount").Int())
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount,
|
||||
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去
|
||||
return &ClaudeUsage{
|
||||
InputTokens: prompt - cached,
|
||||
OutputTokens: cand + thoughts,
|
||||
CacheReadInputTokens: cached,
|
||||
ImageOutputTokens: imageTokens,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
231
backend/internal/service/model_pricing_resolver.go
Normal file
231
backend/internal/service/model_pricing_resolver.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// PricingSource 定价来源标识
|
||||
const (
|
||||
PricingSourceChannel = "channel"
|
||||
PricingSourceLiteLLM = "litellm"
|
||||
PricingSourceFallback = "fallback"
|
||||
)
|
||||
|
||||
// ResolvedPricing 统一定价解析结果
|
||||
type ResolvedPricing struct {
|
||||
// Mode 计费模式
|
||||
Mode BillingMode
|
||||
|
||||
// Token 模式:基础定价(来自 LiteLLM 或 fallback)
|
||||
BasePricing *ModelPricing
|
||||
|
||||
// Token 模式:区间定价列表(如有,覆盖 BasePricing 中的对应字段)
|
||||
Intervals []PricingInterval
|
||||
|
||||
// 按次/图片模式:分层定价
|
||||
RequestTiers []PricingInterval
|
||||
|
||||
// 按次/图片模式:默认价格(未命中层级时使用)
|
||||
DefaultPerRequestPrice float64
|
||||
|
||||
// 来源标识
|
||||
Source string // "channel", "litellm", "fallback"
|
||||
|
||||
// 是否支持缓存细分
|
||||
SupportsCacheBreakdown bool
|
||||
}
|
||||
|
||||
// ModelPricingResolver 统一模型定价解析器。
|
||||
// 解析链:Channel → LiteLLM → Fallback。
|
||||
type ModelPricingResolver struct {
|
||||
channelService *ChannelService
|
||||
billingService *BillingService
|
||||
}
|
||||
|
||||
// NewModelPricingResolver 创建定价解析器实例
|
||||
func NewModelPricingResolver(channelService *ChannelService, billingService *BillingService) *ModelPricingResolver {
|
||||
return &ModelPricingResolver{
|
||||
channelService: channelService,
|
||||
billingService: billingService,
|
||||
}
|
||||
}
|
||||
|
||||
// PricingInput 定价解析输入
|
||||
type PricingInput struct {
|
||||
Model string
|
||||
GroupID *int64 // nil 表示不检查渠道
|
||||
}
|
||||
|
||||
// Resolve 解析模型定价。
|
||||
// 1. 获取基础定价(LiteLLM → Fallback)
|
||||
// 2. 如果指定了 GroupID,查找渠道定价并覆盖
|
||||
func (r *ModelPricingResolver) Resolve(ctx context.Context, input PricingInput) *ResolvedPricing {
|
||||
// 1. 获取基础定价
|
||||
basePricing, source := r.resolveBasePricing(input.Model)
|
||||
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModeToken,
|
||||
BasePricing: basePricing,
|
||||
Source: source,
|
||||
SupportsCacheBreakdown: basePricing != nil && basePricing.SupportsCacheBreakdown,
|
||||
}
|
||||
|
||||
// 2. 如果有 GroupID,尝试渠道覆盖
|
||||
if input.GroupID != nil {
|
||||
r.applyChannelOverrides(ctx, *input.GroupID, input.Model, resolved)
|
||||
}
|
||||
|
||||
return resolved
|
||||
}
|
||||
|
||||
// resolveBasePricing 从 LiteLLM 或 Fallback 获取基础定价
|
||||
func (r *ModelPricingResolver) resolveBasePricing(model string) (*ModelPricing, string) {
|
||||
pricing, err := r.billingService.GetModelPricing(model)
|
||||
if err != nil {
|
||||
slog.Debug("failed to get model pricing from LiteLLM, using fallback",
|
||||
"model", model, "error", err)
|
||||
return nil, PricingSourceFallback
|
||||
}
|
||||
return pricing, PricingSourceLiteLLM
|
||||
}
|
||||
|
||||
// applyChannelOverrides 应用渠道定价覆盖
|
||||
func (r *ModelPricingResolver) applyChannelOverrides(ctx context.Context, groupID int64, model string, resolved *ResolvedPricing) {
|
||||
chPricing := r.channelService.GetChannelModelPricing(ctx, groupID, model)
|
||||
if chPricing == nil {
|
||||
return
|
||||
}
|
||||
|
||||
resolved.Source = PricingSourceChannel
|
||||
resolved.Mode = chPricing.BillingMode
|
||||
if resolved.Mode == "" {
|
||||
resolved.Mode = BillingModeToken
|
||||
}
|
||||
|
||||
switch resolved.Mode {
|
||||
case BillingModeToken:
|
||||
r.applyTokenOverrides(chPricing, resolved)
|
||||
case BillingModePerRequest, BillingModeImage:
|
||||
r.applyRequestTierOverrides(chPricing, resolved)
|
||||
}
|
||||
}
|
||||
|
||||
// applyTokenOverrides 应用 token 模式的渠道覆盖
|
||||
func (r *ModelPricingResolver) applyTokenOverrides(chPricing *ChannelModelPricing, resolved *ResolvedPricing) {
|
||||
// 过滤掉所有价格字段都为空的无效 interval
|
||||
validIntervals := filterValidIntervals(chPricing.Intervals)
|
||||
|
||||
// 如果有有效的区间定价,使用区间
|
||||
if len(validIntervals) > 0 {
|
||||
resolved.Intervals = validIntervals
|
||||
return
|
||||
}
|
||||
|
||||
// 否则用 flat 字段覆盖 BasePricing
|
||||
if resolved.BasePricing == nil {
|
||||
resolved.BasePricing = &ModelPricing{}
|
||||
}
|
||||
|
||||
if chPricing.InputPrice != nil {
|
||||
resolved.BasePricing.InputPricePerToken = *chPricing.InputPrice
|
||||
resolved.BasePricing.InputPricePerTokenPriority = *chPricing.InputPrice
|
||||
}
|
||||
if chPricing.OutputPrice != nil {
|
||||
resolved.BasePricing.OutputPricePerToken = *chPricing.OutputPrice
|
||||
resolved.BasePricing.OutputPricePerTokenPriority = *chPricing.OutputPrice
|
||||
}
|
||||
if chPricing.CacheWritePrice != nil {
|
||||
resolved.BasePricing.CacheCreationPricePerToken = *chPricing.CacheWritePrice
|
||||
resolved.BasePricing.CacheCreation5mPrice = *chPricing.CacheWritePrice
|
||||
resolved.BasePricing.CacheCreation1hPrice = *chPricing.CacheWritePrice
|
||||
}
|
||||
if chPricing.CacheReadPrice != nil {
|
||||
resolved.BasePricing.CacheReadPricePerToken = *chPricing.CacheReadPrice
|
||||
resolved.BasePricing.CacheReadPricePerTokenPriority = *chPricing.CacheReadPrice
|
||||
}
|
||||
if chPricing.ImageOutputPrice != nil {
|
||||
resolved.BasePricing.ImageOutputPricePerToken = *chPricing.ImageOutputPrice
|
||||
}
|
||||
}
|
||||
|
||||
// applyRequestTierOverrides 应用按次/图片模式的渠道覆盖
|
||||
func (r *ModelPricingResolver) applyRequestTierOverrides(chPricing *ChannelModelPricing, resolved *ResolvedPricing) {
|
||||
resolved.RequestTiers = filterValidIntervals(chPricing.Intervals)
|
||||
if chPricing.PerRequestPrice != nil {
|
||||
resolved.DefaultPerRequestPrice = *chPricing.PerRequestPrice
|
||||
}
|
||||
}
|
||||
|
||||
// filterValidIntervals 过滤掉所有价格字段都为空的无效 interval。
|
||||
// 前端可能创建了只有 min/max 但无价格的空 interval。
|
||||
func filterValidIntervals(intervals []PricingInterval) []PricingInterval {
|
||||
var valid []PricingInterval
|
||||
for _, iv := range intervals {
|
||||
if iv.InputPrice != nil || iv.OutputPrice != nil ||
|
||||
iv.CacheWritePrice != nil || iv.CacheReadPrice != nil ||
|
||||
iv.PerRequestPrice != nil {
|
||||
valid = append(valid, iv)
|
||||
}
|
||||
}
|
||||
return valid
|
||||
}
|
||||
|
||||
// GetIntervalPricing 根据 context token 数获取区间定价。
|
||||
// 如果有区间列表,找到匹配区间并构造 ModelPricing;否则直接返回 BasePricing。
|
||||
func (r *ModelPricingResolver) GetIntervalPricing(resolved *ResolvedPricing, totalContextTokens int) *ModelPricing {
|
||||
if len(resolved.Intervals) == 0 {
|
||||
return resolved.BasePricing
|
||||
}
|
||||
|
||||
iv := FindMatchingInterval(resolved.Intervals, totalContextTokens)
|
||||
if iv == nil {
|
||||
return resolved.BasePricing
|
||||
}
|
||||
|
||||
return intervalToModelPricing(iv, resolved.SupportsCacheBreakdown)
|
||||
}
|
||||
|
||||
// intervalToModelPricing 将区间定价转换为 ModelPricing
|
||||
func intervalToModelPricing(iv *PricingInterval, supportsCacheBreakdown bool) *ModelPricing {
|
||||
pricing := &ModelPricing{
|
||||
SupportsCacheBreakdown: supportsCacheBreakdown,
|
||||
}
|
||||
if iv.InputPrice != nil {
|
||||
pricing.InputPricePerToken = *iv.InputPrice
|
||||
pricing.InputPricePerTokenPriority = *iv.InputPrice
|
||||
}
|
||||
if iv.OutputPrice != nil {
|
||||
pricing.OutputPricePerToken = *iv.OutputPrice
|
||||
pricing.OutputPricePerTokenPriority = *iv.OutputPrice
|
||||
}
|
||||
if iv.CacheWritePrice != nil {
|
||||
pricing.CacheCreationPricePerToken = *iv.CacheWritePrice
|
||||
pricing.CacheCreation5mPrice = *iv.CacheWritePrice
|
||||
pricing.CacheCreation1hPrice = *iv.CacheWritePrice
|
||||
}
|
||||
if iv.CacheReadPrice != nil {
|
||||
pricing.CacheReadPricePerToken = *iv.CacheReadPrice
|
||||
pricing.CacheReadPricePerTokenPriority = *iv.CacheReadPrice
|
||||
}
|
||||
return pricing
|
||||
}
|
||||
|
||||
// GetRequestTierPrice 根据层级标签获取按次价格
|
||||
func (r *ModelPricingResolver) GetRequestTierPrice(resolved *ResolvedPricing, tierLabel string) float64 {
|
||||
for _, tier := range resolved.RequestTiers {
|
||||
if tier.TierLabel == tierLabel && tier.PerRequestPrice != nil {
|
||||
return *tier.PerRequestPrice
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetRequestTierPriceByContext 根据 context token 数获取按次价格
|
||||
func (r *ModelPricingResolver) GetRequestTierPriceByContext(resolved *ResolvedPricing, totalContextTokens int) float64 {
|
||||
iv := FindMatchingInterval(resolved.RequestTiers, totalContextTokens)
|
||||
if iv != nil && iv.PerRequestPrice != nil {
|
||||
return *iv.PerRequestPrice
|
||||
}
|
||||
return 0
|
||||
}
|
||||
663
backend/internal/service/model_pricing_resolver_test.go
Normal file
663
backend/internal/service/model_pricing_resolver_test.go
Normal file
@@ -0,0 +1,663 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newTestBillingServiceForResolver() *BillingService {
|
||||
bs := &BillingService{
|
||||
fallbackPrices: make(map[string]*ModelPricing),
|
||||
}
|
||||
bs.fallbackPrices["claude-sonnet-4"] = &ModelPricing{
|
||||
InputPricePerToken: 3e-6,
|
||||
OutputPricePerToken: 15e-6,
|
||||
CacheCreationPricePerToken: 3.75e-6,
|
||||
CacheReadPricePerToken: 0.3e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
return bs
|
||||
}
|
||||
|
||||
func TestResolve_NoGroupID(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: nil,
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
require.Equal(t, BillingModeToken, resolved.Mode)
|
||||
require.NotNil(t, resolved.BasePricing)
|
||||
require.InDelta(t, 3e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
|
||||
// BillingService.GetModelPricing uses fallback internally, but resolveBasePricing
|
||||
// reports "litellm" when GetModelPricing succeeds (regardless of internal source)
|
||||
require.Equal(t, "litellm", resolved.Source)
|
||||
}
|
||||
|
||||
func TestResolve_UnknownModel(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "unknown-model-xyz",
|
||||
GroupID: nil,
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
require.Nil(t, resolved.BasePricing)
|
||||
// Unknown model: GetModelPricing returns error, source is "fallback"
|
||||
require.Equal(t, "fallback", resolved.Source)
|
||||
}
|
||||
|
||||
func TestGetIntervalPricing_NoIntervals(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
basePricing := &ModelPricing{InputPricePerToken: 5e-6}
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModeToken,
|
||||
BasePricing: basePricing,
|
||||
Intervals: nil,
|
||||
}
|
||||
|
||||
result := r.GetIntervalPricing(resolved, 50000)
|
||||
require.Equal(t, basePricing, result)
|
||||
}
|
||||
|
||||
func TestGetIntervalPricing_MatchesInterval(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModeToken,
|
||||
BasePricing: &ModelPricing{InputPricePerToken: 5e-6},
|
||||
SupportsCacheBreakdown: true,
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6), OutputPrice: testPtrFloat64(2e-6)},
|
||||
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(3e-6), OutputPrice: testPtrFloat64(6e-6)},
|
||||
},
|
||||
}
|
||||
|
||||
result := r.GetIntervalPricing(resolved, 50000)
|
||||
require.NotNil(t, result)
|
||||
require.InDelta(t, 1e-6, result.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 2e-6, result.OutputPricePerToken, 1e-12)
|
||||
require.True(t, result.SupportsCacheBreakdown)
|
||||
|
||||
result2 := r.GetIntervalPricing(resolved, 200000)
|
||||
require.NotNil(t, result2)
|
||||
require.InDelta(t, 3e-6, result2.InputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetIntervalPricing_NoMatch_FallsBackToBase(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
basePricing := &ModelPricing{InputPricePerToken: 99e-6}
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModeToken,
|
||||
BasePricing: basePricing,
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 10000, MaxTokens: testPtrInt(50000), InputPrice: testPtrFloat64(1e-6)},
|
||||
},
|
||||
}
|
||||
|
||||
result := r.GetIntervalPricing(resolved, 5000)
|
||||
require.Equal(t, basePricing, result)
|
||||
}
|
||||
|
||||
func TestGetRequestTierPrice(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModePerRequest,
|
||||
RequestTiers: []PricingInterval{
|
||||
{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
|
||||
{TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)},
|
||||
},
|
||||
}
|
||||
|
||||
require.InDelta(t, 0.04, r.GetRequestTierPrice(resolved, "1K"), 1e-12)
|
||||
require.InDelta(t, 0.08, r.GetRequestTierPrice(resolved, "2K"), 1e-12)
|
||||
require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "4K"), 1e-12)
|
||||
}
|
||||
|
||||
func TestGetRequestTierPriceByContext(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModePerRequest,
|
||||
RequestTiers: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.05)},
|
||||
{MinTokens: 128000, MaxTokens: nil, PerRequestPrice: testPtrFloat64(0.10)},
|
||||
},
|
||||
}
|
||||
|
||||
require.InDelta(t, 0.05, r.GetRequestTierPriceByContext(resolved, 50000), 1e-12)
|
||||
require.InDelta(t, 0.10, r.GetRequestTierPriceByContext(resolved, 200000), 1e-12)
|
||||
}
|
||||
|
||||
func TestGetRequestTierPrice_NilPerRequestPrice(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModePerRequest,
|
||||
RequestTiers: []PricingInterval{
|
||||
{TierLabel: "1K", PerRequestPrice: nil},
|
||||
},
|
||||
}
|
||||
|
||||
require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "1K"), 1e-12)
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// Channel override tests — exercises applyChannelOverrides via Resolve
|
||||
// ===========================================================================
|
||||
|
||||
// helper: creates a resolver wired to a ChannelService that returns the given
|
||||
// channel (active, groupID=100, platform=anthropic) with the specified pricing.
|
||||
func newResolverWithChannel(t *testing.T, pricing []ChannelModelPricing) *ModelPricingResolver {
|
||||
t.Helper()
|
||||
const groupID = 100
|
||||
repo := &mockChannelRepository{
|
||||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||||
return []Channel{{
|
||||
ID: 1,
|
||||
Name: "test-channel",
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{groupID},
|
||||
ModelPricing: pricing,
|
||||
}}, nil
|
||||
},
|
||||
getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) {
|
||||
return map[int64]string{groupID: "anthropic"}, nil
|
||||
},
|
||||
}
|
||||
cs := NewChannelService(repo, nil)
|
||||
bs := newTestBillingServiceForResolver()
|
||||
return NewModelPricingResolver(cs, bs)
|
||||
}
|
||||
|
||||
// groupIDPtr returns a pointer to groupID 100 (the test constant).
|
||||
func groupIDPtr() *int64 { v := int64(100); return &v }
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 1. Token mode overrides
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestResolve_WithChannelOverride_TokenFlat(t *testing.T) {
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModeToken,
|
||||
InputPrice: testPtrFloat64(10e-6),
|
||||
OutputPrice: testPtrFloat64(50e-6),
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
require.Equal(t, BillingModeToken, resolved.Mode)
|
||||
require.Equal(t, "channel", resolved.Source)
|
||||
require.NotNil(t, resolved.BasePricing)
|
||||
require.InDelta(t, 10e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 10e-6, resolved.BasePricing.InputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 50e-6, resolved.BasePricing.OutputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 50e-6, resolved.BasePricing.OutputPricePerTokenPriority, 1e-12)
|
||||
}
|
||||
|
||||
func TestResolve_WithChannelOverride_TokenPartialOverride(t *testing.T) {
|
||||
// Channel only sets InputPrice; OutputPrice should remain from the base (LiteLLM/fallback).
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModeToken,
|
||||
InputPrice: testPtrFloat64(20e-6),
|
||||
// OutputPrice intentionally nil
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
require.Equal(t, "channel", resolved.Source)
|
||||
require.NotNil(t, resolved.BasePricing)
|
||||
// InputPrice overridden by channel
|
||||
require.InDelta(t, 20e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
|
||||
// OutputPrice kept from base (fallback: 15e-6)
|
||||
require.InDelta(t, 15e-6, resolved.BasePricing.OutputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
func TestResolve_WithChannelOverride_TokenWithIntervals(t *testing.T) {
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModeToken,
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(2e-6), OutputPrice: testPtrFloat64(8e-6)},
|
||||
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(4e-6), OutputPrice: testPtrFloat64(16e-6)},
|
||||
},
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
require.Equal(t, "channel", resolved.Source)
|
||||
require.Len(t, resolved.Intervals, 2)
|
||||
|
||||
// GetIntervalPricing should use channel intervals
|
||||
iv := r.GetIntervalPricing(resolved, 50000)
|
||||
require.NotNil(t, iv)
|
||||
require.InDelta(t, 2e-6, iv.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 8e-6, iv.OutputPricePerToken, 1e-12)
|
||||
|
||||
iv2 := r.GetIntervalPricing(resolved, 200000)
|
||||
require.NotNil(t, iv2)
|
||||
require.InDelta(t, 4e-6, iv2.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 16e-6, iv2.OutputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
func TestResolve_WithChannelOverride_TokenNilBasePricing(t *testing.T) {
|
||||
// Base pricing is nil (unknown model), channel has flat prices → creates new BasePricing.
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"unknown-model-xyz"},
|
||||
BillingMode: BillingModeToken,
|
||||
InputPrice: testPtrFloat64(7e-6),
|
||||
OutputPrice: testPtrFloat64(21e-6),
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "unknown-model-xyz",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
require.Equal(t, "channel", resolved.Source)
|
||||
// BasePricing was nil from resolveBasePricing but applyTokenOverrides creates a new one
|
||||
require.NotNil(t, resolved.BasePricing)
|
||||
require.InDelta(t, 7e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 21e-6, resolved.BasePricing.OutputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 2. Per-request mode overrides
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestResolve_WithChannelOverride_PerRequest(t *testing.T) {
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModePerRequest,
|
||||
PerRequestPrice: testPtrFloat64(0.05),
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.03)},
|
||||
{MinTokens: 128000, MaxTokens: nil, PerRequestPrice: testPtrFloat64(0.10)},
|
||||
},
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
require.Equal(t, BillingModePerRequest, resolved.Mode)
|
||||
require.Equal(t, "channel", resolved.Source)
|
||||
require.InDelta(t, 0.05, resolved.DefaultPerRequestPrice, 1e-12)
|
||||
require.Len(t, resolved.RequestTiers, 2)
|
||||
|
||||
// Verify tier lookups
|
||||
require.InDelta(t, 0.03, r.GetRequestTierPriceByContext(resolved, 50000), 1e-12)
|
||||
require.InDelta(t, 0.10, r.GetRequestTierPriceByContext(resolved, 200000), 1e-12)
|
||||
}
|
||||
|
||||
func TestResolve_WithChannelOverride_PerRequestNilPrice(t *testing.T) {
|
||||
// PerRequestPrice nil → DefaultPerRequestPrice stays 0.
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModePerRequest,
|
||||
// PerRequestPrice intentionally nil
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.02)},
|
||||
},
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
require.Equal(t, BillingModePerRequest, resolved.Mode)
|
||||
require.InDelta(t, 0.0, resolved.DefaultPerRequestPrice, 1e-12)
|
||||
require.Len(t, resolved.RequestTiers, 1)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 3. Image mode overrides
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestResolve_WithChannelOverride_Image(t *testing.T) {
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModeImage,
|
||||
PerRequestPrice: testPtrFloat64(0.08),
|
||||
Intervals: []PricingInterval{
|
||||
{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
|
||||
{TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)},
|
||||
{TierLabel: "4K", PerRequestPrice: testPtrFloat64(0.16)},
|
||||
},
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
require.Equal(t, BillingModeImage, resolved.Mode)
|
||||
require.Equal(t, "channel", resolved.Source)
|
||||
require.InDelta(t, 0.08, resolved.DefaultPerRequestPrice, 1e-12)
|
||||
require.Len(t, resolved.RequestTiers, 3)
|
||||
}
|
||||
|
||||
func TestResolve_WithChannelOverride_ImageTierLabels(t *testing.T) {
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModeImage,
|
||||
Intervals: []PricingInterval{
|
||||
{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
|
||||
{TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)},
|
||||
{TierLabel: "4K", PerRequestPrice: testPtrFloat64(0.16)},
|
||||
},
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.InDelta(t, 0.04, r.GetRequestTierPrice(resolved, "1K"), 1e-12)
|
||||
require.InDelta(t, 0.08, r.GetRequestTierPrice(resolved, "2K"), 1e-12)
|
||||
require.InDelta(t, 0.16, r.GetRequestTierPrice(resolved, "4K"), 1e-12)
|
||||
require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "8K"), 1e-12) // not found
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 4. Source tracking & default mode
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestResolve_WithChannelOverride_SourceIsChannel(t *testing.T) {
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModeToken,
|
||||
InputPrice: testPtrFloat64(1e-6),
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.Equal(t, "channel", resolved.Source)
|
||||
}
|
||||
|
||||
func TestResolve_WithChannelOverride_DefaultMode(t *testing.T) {
|
||||
// Channel pricing with empty BillingMode → defaults to BillingModeToken.
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: "", // intentionally empty
|
||||
InputPrice: testPtrFloat64(5e-6),
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.Equal(t, "channel", resolved.Source)
|
||||
require.Equal(t, BillingModeToken, resolved.Mode)
|
||||
require.NotNil(t, resolved.BasePricing)
|
||||
require.InDelta(t, 5e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 5. GetIntervalPricing integration after channel override
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGetIntervalPricing_WithChannelIntervals(t *testing.T) {
|
||||
// Channel provides intervals that override the base pricing path.
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModeToken,
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(100000), InputPrice: testPtrFloat64(1e-6), OutputPrice: testPtrFloat64(5e-6)},
|
||||
{MinTokens: 100000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6), OutputPrice: testPtrFloat64(10e-6)},
|
||||
},
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
// Token count 50000 matches first interval
|
||||
pricing := r.GetIntervalPricing(resolved, 50000)
|
||||
require.NotNil(t, pricing)
|
||||
require.InDelta(t, 1e-6, pricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 5e-6, pricing.OutputPricePerToken, 1e-12)
|
||||
|
||||
// Token count 150000 matches second interval
|
||||
pricing2 := r.GetIntervalPricing(resolved, 150000)
|
||||
require.NotNil(t, pricing2)
|
||||
require.InDelta(t, 2e-6, pricing2.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 10e-6, pricing2.OutputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetIntervalPricing_ChannelIntervalsNoMatch(t *testing.T) {
|
||||
// Channel intervals don't match token count → falls back to BasePricing.
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModeToken,
|
||||
Intervals: []PricingInterval{
|
||||
// Only covers tokens > 50000
|
||||
{MinTokens: 50000, MaxTokens: testPtrInt(200000), InputPrice: testPtrFloat64(9e-6)},
|
||||
},
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
// Token count 1000 doesn't match any interval (1000 <= 50000 minTokens)
|
||||
pricing := r.GetIntervalPricing(resolved, 1000)
|
||||
// Should fall back to BasePricing (from the billing service fallback)
|
||||
require.NotNil(t, pricing)
|
||||
require.Equal(t, resolved.BasePricing, pricing)
|
||||
require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12) // original base price
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// 6. Error path tests
|
||||
// ===========================================================================
|
||||
|
||||
func TestResolve_WithChannelOverride_CacheError(t *testing.T) {
|
||||
// When ListAll returns an error, the ChannelService cache build fails.
|
||||
// Resolve should gracefully fall back to base pricing without panicking.
|
||||
repo := &mockChannelRepository{
|
||||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||||
return nil, errors.New("database unavailable")
|
||||
},
|
||||
}
|
||||
cs := NewChannelService(repo, nil)
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(cs, bs)
|
||||
|
||||
gid := int64(100)
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: &gid,
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
// Should NOT panic, should NOT have source "channel"
|
||||
require.NotEqual(t, "channel", resolved.Source)
|
||||
// Base pricing should still be present (from BillingService fallback)
|
||||
require.NotNil(t, resolved.BasePricing)
|
||||
require.InDelta(t, 3e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// 7. GetRequestTierPriceByContext boundary tests
|
||||
// ===========================================================================
|
||||
|
||||
func TestGetRequestTierPriceByContext_EmptyTiers(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModePerRequest,
|
||||
RequestTiers: nil, // empty
|
||||
}
|
||||
|
||||
price := r.GetRequestTierPriceByContext(resolved, 50000)
|
||||
require.InDelta(t, 0.0, price, 1e-12)
|
||||
|
||||
// Also test with explicit empty slice
|
||||
resolved2 := &ResolvedPricing{
|
||||
Mode: BillingModePerRequest,
|
||||
RequestTiers: []PricingInterval{},
|
||||
}
|
||||
|
||||
price2 := r.GetRequestTierPriceByContext(resolved2, 50000)
|
||||
require.InDelta(t, 0.0, price2, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetRequestTierPriceByContext_ExactBoundary(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModePerRequest,
|
||||
RequestTiers: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.05)},
|
||||
{MinTokens: 128000, MaxTokens: nil, PerRequestPrice: testPtrFloat64(0.10)},
|
||||
},
|
||||
}
|
||||
|
||||
// totalContextTokens = 128000 exactly:
|
||||
// FindMatchingInterval checks: totalTokens > MinTokens && totalTokens <= MaxTokens
|
||||
// For first interval: 128000 > 0 (true) && 128000 <= 128000 (true) → matches first interval
|
||||
price := r.GetRequestTierPriceByContext(resolved, 128000)
|
||||
require.InDelta(t, 0.05, price, 1e-12)
|
||||
|
||||
// totalContextTokens = 128001 should match second interval
|
||||
// For first interval: 128001 > 0 (true) && 128001 <= 128000 (false) → no match
|
||||
// For second interval: 128001 > 128000 (true) && MaxTokens == nil → matches
|
||||
price2 := r.GetRequestTierPriceByContext(resolved, 128001)
|
||||
require.InDelta(t, 0.10, price2, 1e-12)
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// 8. filterValidIntervals
|
||||
// ===========================================================================
|
||||
|
||||
func TestFilterValidIntervals(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
intervals []PricingInterval
|
||||
wantLen int
|
||||
}{
|
||||
{
|
||||
name: "empty list",
|
||||
intervals: nil,
|
||||
wantLen: 0,
|
||||
},
|
||||
{
|
||||
name: "all-nil interval filtered out",
|
||||
intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(128000)},
|
||||
},
|
||||
wantLen: 0,
|
||||
},
|
||||
{
|
||||
name: "interval with only InputPrice kept",
|
||||
intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
|
||||
},
|
||||
wantLen: 1,
|
||||
},
|
||||
{
|
||||
name: "interval with only OutputPrice kept",
|
||||
intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(128000), OutputPrice: testPtrFloat64(2e-6)},
|
||||
},
|
||||
wantLen: 1,
|
||||
},
|
||||
{
|
||||
name: "interval with only CacheWritePrice kept",
|
||||
intervals: []PricingInterval{
|
||||
{MinTokens: 0, CacheWritePrice: testPtrFloat64(3e-6)},
|
||||
},
|
||||
wantLen: 1,
|
||||
},
|
||||
{
|
||||
name: "interval with only CacheReadPrice kept",
|
||||
intervals: []PricingInterval{
|
||||
{MinTokens: 0, CacheReadPrice: testPtrFloat64(0.5e-6)},
|
||||
},
|
||||
wantLen: 1,
|
||||
},
|
||||
{
|
||||
name: "interval with only PerRequestPrice kept",
|
||||
intervals: []PricingInterval{
|
||||
{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
|
||||
},
|
||||
wantLen: 1,
|
||||
},
|
||||
{
|
||||
name: "mixed valid and invalid",
|
||||
intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
|
||||
{MinTokens: 128000, MaxTokens: nil}, // all-nil → filtered out
|
||||
{MinTokens: 256000, OutputPrice: testPtrFloat64(5e-6)},
|
||||
},
|
||||
wantLen: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := filterValidIntervals(tt.intervals)
|
||||
require.Len(t, result, tt.wantLen)
|
||||
})
|
||||
}
|
||||
}
|
||||
140
backend/internal/service/openai_channel_restriction_test.go
Normal file
140
backend/internal/service/openai_channel_restriction_test.go
Normal file
@@ -0,0 +1,140 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_ChannelMappedRestrictionRejectsEarly(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
channelSvc := newTestChannelService(makeStandardRepo(Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{10},
|
||||
RestrictModels: true,
|
||||
BillingModelSource: BillingModelSourceChannelMapped,
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{Platform: PlatformOpenAI, Models: []string{"gpt-4o"}},
|
||||
},
|
||||
ModelMapping: map[string]map[string]string{
|
||||
PlatformOpenAI: {"gpt-4.1": "o3-mini"},
|
||||
},
|
||||
}, map[int64]string{10: PlatformOpenAI}))
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true},
|
||||
}},
|
||||
channelService: channelSvc,
|
||||
}
|
||||
|
||||
groupID := int64(10)
|
||||
_, err := svc.SelectAccountForModelWithExclusions(context.Background(), &groupID, "", "gpt-4.1", nil)
|
||||
require.ErrorIs(t, err, ErrNoAvailableAccounts)
|
||||
require.Contains(t, err.Error(), "channel pricing restriction")
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_UpstreamRestrictionSkipsDisallowedAccount(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
channelSvc := newTestChannelService(makeStandardRepo(Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{10},
|
||||
RestrictModels: true,
|
||||
BillingModelSource: BillingModelSourceUpstream,
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{Platform: PlatformOpenAI, Models: []string{"o3-mini"}},
|
||||
},
|
||||
}, map[int64]string{10: PlatformOpenAI}))
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{
|
||||
{
|
||||
ID: 1,
|
||||
Platform: PlatformOpenAI,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Priority: 10,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{"gpt-4.1": "gpt-4o"},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Platform: PlatformOpenAI,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Priority: 20,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{"gpt-4.1": "o3-mini"},
|
||||
},
|
||||
},
|
||||
}},
|
||||
channelService: channelSvc,
|
||||
}
|
||||
|
||||
groupID := int64(10)
|
||||
account, err := svc.SelectAccountForModelWithExclusions(context.Background(), &groupID, "", "gpt-4.1", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, account)
|
||||
require.Equal(t, int64(2), account.ID)
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_StickyRestrictedUpstreamFallsBack(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
channelSvc := newTestChannelService(makeStandardRepo(Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{10},
|
||||
RestrictModels: true,
|
||||
BillingModelSource: BillingModelSourceUpstream,
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{Platform: PlatformOpenAI, Models: []string{"o3-mini"}},
|
||||
},
|
||||
}, map[int64]string{10: PlatformOpenAI}))
|
||||
|
||||
cache := &stubGatewayCache{
|
||||
sessionBindings: map[string]int64{"openai:sticky-session": 1},
|
||||
}
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{
|
||||
{
|
||||
ID: 1,
|
||||
Platform: PlatformOpenAI,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Priority: 10,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{"gpt-4.1": "gpt-4o"},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Platform: PlatformOpenAI,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Priority: 20,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{"gpt-4.1": "o3-mini"},
|
||||
},
|
||||
},
|
||||
}},
|
||||
channelService: channelSvc,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
groupID := int64(10)
|
||||
account, err := svc.SelectAccountForModelWithExclusions(context.Background(), &groupID, "sticky-session", "gpt-4.1", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, account)
|
||||
require.Equal(t, int64(2), account.ID)
|
||||
require.Equal(t, 1, cache.deletedSessions["openai:sticky-session"])
|
||||
require.Equal(t, int64(2), cache.sessionBindings["openai:sticky-session"])
|
||||
}
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
const compatPromptCacheKeyPrefix = "compat_cc_"
|
||||
|
||||
func shouldAutoInjectPromptCacheKeyForCompat(model string) bool {
|
||||
switch resolveOpenAIUpstreamModel(strings.TrimSpace(model)) {
|
||||
case "gpt-5.4", "gpt-5.3-codex", "gpt-5.3-codex-spark":
|
||||
switch normalizeCodexModel(strings.TrimSpace(model)) {
|
||||
case "gpt-5.4", "gpt-5.3-codex":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
@@ -23,9 +23,9 @@ func deriveCompatPromptCacheKey(req *apicompat.ChatCompletionsRequest, mappedMod
|
||||
return ""
|
||||
}
|
||||
|
||||
normalizedModel := resolveOpenAIUpstreamModel(strings.TrimSpace(mappedModel))
|
||||
normalizedModel := normalizeCodexModel(strings.TrimSpace(mappedModel))
|
||||
if normalizedModel == "" {
|
||||
normalizedModel = resolveOpenAIUpstreamModel(strings.TrimSpace(req.Model))
|
||||
normalizedModel = normalizeCodexModel(strings.TrimSpace(req.Model))
|
||||
}
|
||||
if normalizedModel == "" {
|
||||
normalizedModel = strings.TrimSpace(req.Model)
|
||||
|
||||
@@ -46,7 +46,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
||||
// 2. Resolve model mapping early so compat prompt_cache_key injection can
|
||||
// derive a stable seed from the final upstream model family.
|
||||
billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
|
||||
upstreamModel := resolveOpenAIUpstreamModel(billingModel)
|
||||
upstreamModel := normalizeCodexModel(billingModel)
|
||||
|
||||
promptCacheKey = strings.TrimSpace(promptCacheKey)
|
||||
compatPromptCacheInjected := false
|
||||
|
||||
@@ -62,7 +62,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
|
||||
// 3. Model mapping
|
||||
billingModel := resolveOpenAIForwardModel(account, normalizedModel, defaultMappedModel)
|
||||
upstreamModel := resolveOpenAIUpstreamModel(billingModel)
|
||||
upstreamModel := normalizeCodexModel(billingModel)
|
||||
responsesReq.Model = upstreamModel
|
||||
|
||||
logger.L().Debug("openai messages: model mapping applied",
|
||||
|
||||
@@ -145,6 +145,8 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
|
||||
nil,
|
||||
&DeferredService{},
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
svc.userGroupRateResolver = newUserGroupRateResolver(
|
||||
rateRepo,
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"sort"
|
||||
@@ -204,6 +205,7 @@ type OpenAIUsage struct {
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
|
||||
ImageOutputTokens int `json:"image_output_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIForwardResult represents the result of forwarding
|
||||
@@ -322,6 +324,8 @@ type OpenAIGatewayService struct {
|
||||
openAITokenProvider *OpenAITokenProvider
|
||||
toolCorrector *CodexToolCorrector
|
||||
openaiWSResolver OpenAIWSProtocolResolver
|
||||
resolver *ModelPricingResolver
|
||||
channelService *ChannelService
|
||||
|
||||
openaiWSPoolOnce sync.Once
|
||||
openaiWSStateStoreOnce sync.Once
|
||||
@@ -357,6 +361,8 @@ func NewOpenAIGatewayService(
|
||||
httpUpstream HTTPUpstream,
|
||||
deferredService *DeferredService,
|
||||
openAITokenProvider *OpenAITokenProvider,
|
||||
resolver *ModelPricingResolver,
|
||||
channelService *ChannelService,
|
||||
) *OpenAIGatewayService {
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -384,6 +390,8 @@ func NewOpenAIGatewayService(
|
||||
openAITokenProvider: openAITokenProvider,
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
resolver: resolver,
|
||||
channelService: channelService,
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
|
||||
}
|
||||
@@ -391,6 +399,74 @@ func NewOpenAIGatewayService(
|
||||
return svc
|
||||
}
|
||||
|
||||
// ResolveChannelMapping 解析渠道级模型映射(代理到 ChannelService)
|
||||
func (s *OpenAIGatewayService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult {
|
||||
if s.channelService == nil {
|
||||
return ChannelMappingResult{MappedModel: model}
|
||||
}
|
||||
return s.channelService.ResolveChannelMapping(ctx, groupID, model)
|
||||
}
|
||||
|
||||
// IsModelRestricted 检查模型是否被渠道限制(代理到 ChannelService)
|
||||
func (s *OpenAIGatewayService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
|
||||
if s.channelService == nil {
|
||||
return false
|
||||
}
|
||||
return s.channelService.IsModelRestricted(ctx, groupID, model)
|
||||
}
|
||||
|
||||
// ResolveChannelMappingAndRestrict 解析渠道映射。
|
||||
// 模型限制检查已移至调度阶段,restricted 始终返回 false。
|
||||
func (s *OpenAIGatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
|
||||
if s.channelService == nil {
|
||||
return ChannelMappingResult{MappedModel: model}, false
|
||||
}
|
||||
return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) checkChannelPricingRestriction(ctx context.Context, groupID *int64, requestedModel string) bool {
|
||||
if groupID == nil || s.channelService == nil || requestedModel == "" {
|
||||
return false
|
||||
}
|
||||
mapping := s.channelService.ResolveChannelMapping(ctx, *groupID, requestedModel)
|
||||
billingModel := billingModelForRestriction(mapping.BillingModelSource, requestedModel, mapping.MappedModel)
|
||||
if billingModel == "" {
|
||||
return false
|
||||
}
|
||||
return s.channelService.IsModelRestricted(ctx, *groupID, billingModel)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context, groupID int64, account *Account, requestedModel string) bool {
|
||||
if s.channelService == nil {
|
||||
return false
|
||||
}
|
||||
upstreamModel := resolveOpenAIForwardModel(account, requestedModel, "")
|
||||
if upstreamModel == "" {
|
||||
return false
|
||||
}
|
||||
return s.channelService.IsModelRestricted(ctx, groupID, upstreamModel)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Context, groupID *int64) bool {
|
||||
if groupID == nil || s.channelService == nil {
|
||||
return false
|
||||
}
|
||||
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
|
||||
if err != nil {
|
||||
slog.Warn("failed to check openai channel upstream restriction", "group_id", *groupID, "error", err)
|
||||
return false
|
||||
}
|
||||
if ch == nil || !ch.RestrictModels {
|
||||
return false
|
||||
}
|
||||
return ch.BillingModelSource == BillingModelSourceUpstream
|
||||
}
|
||||
|
||||
// ReplaceModelInBody 替换请求体中的 JSON model 字段(通用 gjson/sjson 实现)。
|
||||
func (s *OpenAIGatewayService) ReplaceModelInBody(body []byte, newModel string) []byte {
|
||||
return ReplaceModelInBody(body, newModel)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle {
|
||||
if s != nil && s.codexSnapshotThrottle != nil {
|
||||
return s.codexSnapshotThrottle
|
||||
@@ -1125,6 +1201,13 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) (*Account, error) {
|
||||
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
|
||||
slog.Warn("channel pricing restriction blocked request",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"model", requestedModel)
|
||||
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
|
||||
}
|
||||
|
||||
// 1. 尝试粘性会话命中
|
||||
// Try sticky session hit
|
||||
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID); account != nil {
|
||||
@@ -1140,7 +1223,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
|
||||
|
||||
// 3. 按优先级 + LRU 选择最佳账号
|
||||
// Select by priority + LRU
|
||||
selected := s.selectBestAccount(ctx, accounts, requestedModel, excludedIDs)
|
||||
selected := s.selectBestAccount(ctx, groupID, accounts, requestedModel, excludedIDs)
|
||||
|
||||
if selected == nil {
|
||||
if requestedModel != "" {
|
||||
@@ -1206,6 +1289,11 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
|
||||
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||
return nil
|
||||
}
|
||||
if groupID != nil && s.needsUpstreamChannelRestrictionCheck(ctx, groupID) &&
|
||||
s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) {
|
||||
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 刷新会话 TTL 并返回账号
|
||||
// Refresh session TTL and return account
|
||||
@@ -1218,8 +1306,9 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
|
||||
//
|
||||
// selectBestAccount selects the best account from candidates (priority + LRU).
|
||||
// Returns nil if no available account.
|
||||
func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
|
||||
func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *int64, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
|
||||
var selected *Account
|
||||
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
|
||||
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
@@ -1238,6 +1327,9 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts [
|
||||
if fresh == nil {
|
||||
continue
|
||||
}
|
||||
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 选择优先级最高且最久未使用的账号
|
||||
// Select highest priority and least recently used
|
||||
@@ -1289,7 +1381,15 @@ func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool
|
||||
|
||||
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
|
||||
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
|
||||
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
|
||||
slog.Warn("channel pricing restriction blocked request",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"model", requestedModel)
|
||||
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
|
||||
}
|
||||
|
||||
cfg := s.schedulingConfig()
|
||||
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
|
||||
var stickyAccountID int64
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil {
|
||||
@@ -1365,6 +1465,8 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
|
||||
if account == nil {
|
||||
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||
} else if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) {
|
||||
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||
} else {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
@@ -1410,6 +1512,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
continue
|
||||
}
|
||||
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
candidates = append(candidates, acc)
|
||||
}
|
||||
|
||||
@@ -1434,6 +1539,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
if fresh == nil {
|
||||
continue
|
||||
}
|
||||
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
|
||||
continue
|
||||
}
|
||||
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if sessionHash != "" {
|
||||
@@ -1488,6 +1596,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
if fresh == nil {
|
||||
continue
|
||||
}
|
||||
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
|
||||
continue
|
||||
}
|
||||
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if sessionHash != "" {
|
||||
@@ -1510,6 +1621,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
if fresh == nil {
|
||||
continue
|
||||
}
|
||||
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
|
||||
continue
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: fresh,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
@@ -1825,7 +1939,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
|
||||
// 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。
|
||||
if model, ok := reqBody["model"].(string); ok {
|
||||
upstreamModel = resolveOpenAIUpstreamModel(model)
|
||||
upstreamModel = normalizeCodexModel(model)
|
||||
if upstreamModel != "" && upstreamModel != model {
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Upstream model resolved: %s -> %s (account: %s, type: %s, isCodexCLI: %v)",
|
||||
model, upstreamModel, account.Name, account.Type, isCodexCLI)
|
||||
@@ -4110,6 +4224,7 @@ type OpenAIRecordUsageInput struct {
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
RequestPayloadHash string
|
||||
APIKeyService APIKeyQuotaUpdater
|
||||
ChannelUsageFields
|
||||
}
|
||||
|
||||
// RecordUsage records usage and deducts balance
|
||||
@@ -4140,10 +4255,14 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
ImageOutputTokens: result.Usage.ImageOutputTokens,
|
||||
}
|
||||
|
||||
// Get rate multiplier
|
||||
multiplier := s.cfg.Default.RateMultiplier
|
||||
multiplier := 1.0
|
||||
if s.cfg != nil {
|
||||
multiplier = s.cfg.Default.RateMultiplier
|
||||
}
|
||||
if apiKey.GroupID != nil && apiKey.Group != nil {
|
||||
resolver := s.userGroupRateResolver
|
||||
if resolver == nil {
|
||||
@@ -4152,12 +4271,37 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
|
||||
}
|
||||
|
||||
var cost *CostBreakdown
|
||||
var err error
|
||||
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
|
||||
if result.BillingModel != "" {
|
||||
billingModel = strings.TrimSpace(result.BillingModel)
|
||||
}
|
||||
if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" {
|
||||
billingModel = input.ChannelMappedModel
|
||||
}
|
||||
if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
|
||||
billingModel = input.OriginalModel
|
||||
}
|
||||
serviceTier := ""
|
||||
if result.ServiceTier != nil {
|
||||
serviceTier = strings.TrimSpace(*result.ServiceTier)
|
||||
}
|
||||
cost, err := s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
|
||||
if s.resolver != nil && apiKey.Group != nil {
|
||||
gid := apiKey.Group.ID
|
||||
cost, err = s.billingService.CalculateCostUnified(CostInput{
|
||||
Ctx: ctx,
|
||||
Model: billingModel,
|
||||
GroupID: &gid,
|
||||
Tokens: tokens,
|
||||
RequestCount: 1,
|
||||
RateMultiplier: multiplier,
|
||||
ServiceTier: serviceTier,
|
||||
Resolver: s.resolver,
|
||||
})
|
||||
} else {
|
||||
cost, err = s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
|
||||
}
|
||||
if err != nil {
|
||||
cost = &CostBreakdown{ActualCost: 0}
|
||||
}
|
||||
@@ -4173,36 +4317,58 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
durationMs := int(result.Duration.Milliseconds())
|
||||
accountRateMultiplier := account.BillingRateMultiplier()
|
||||
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
|
||||
|
||||
// 确定 RequestedModel(渠道映射前的原始模型)
|
||||
requestedModel := result.Model
|
||||
if input.OriginalModel != "" {
|
||||
requestedModel = input.OriginalModel
|
||||
}
|
||||
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: result.Model,
|
||||
RequestedModel: result.Model,
|
||||
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
|
||||
ServiceTier: result.ServiceTier,
|
||||
ReasoningEffort: result.ReasoningEffort,
|
||||
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
|
||||
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
|
||||
InputTokens: actualInputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
InputCost: cost.InputCost,
|
||||
OutputCost: cost.OutputCost,
|
||||
CacheCreationCost: cost.CacheCreationCost,
|
||||
CacheReadCost: cost.CacheReadCost,
|
||||
TotalCost: cost.TotalCost,
|
||||
ActualCost: cost.ActualCost,
|
||||
RateMultiplier: multiplier,
|
||||
AccountRateMultiplier: &accountRateMultiplier,
|
||||
BillingType: billingType,
|
||||
Stream: result.Stream,
|
||||
OpenAIWSMode: result.OpenAIWSMode,
|
||||
DurationMs: &durationMs,
|
||||
FirstTokenMs: result.FirstTokenMs,
|
||||
CreatedAt: time.Now(),
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: result.Model,
|
||||
RequestedModel: requestedModel,
|
||||
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
|
||||
ServiceTier: result.ServiceTier,
|
||||
ReasoningEffort: result.ReasoningEffort,
|
||||
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
|
||||
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
|
||||
InputTokens: actualInputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
ImageOutputTokens: result.Usage.ImageOutputTokens,
|
||||
}
|
||||
if cost != nil {
|
||||
usageLog.InputCost = cost.InputCost
|
||||
usageLog.OutputCost = cost.OutputCost
|
||||
usageLog.ImageOutputCost = cost.ImageOutputCost
|
||||
usageLog.CacheCreationCost = cost.CacheCreationCost
|
||||
usageLog.CacheReadCost = cost.CacheReadCost
|
||||
usageLog.TotalCost = cost.TotalCost
|
||||
usageLog.ActualCost = cost.ActualCost
|
||||
}
|
||||
usageLog.RateMultiplier = multiplier
|
||||
usageLog.AccountRateMultiplier = &accountRateMultiplier
|
||||
usageLog.BillingType = billingType
|
||||
usageLog.Stream = result.Stream
|
||||
usageLog.OpenAIWSMode = result.OpenAIWSMode
|
||||
usageLog.DurationMs = &durationMs
|
||||
usageLog.FirstTokenMs = result.FirstTokenMs
|
||||
usageLog.CreatedAt = time.Now()
|
||||
// 设置渠道信息
|
||||
usageLog.ChannelID = optionalInt64Ptr(input.ChannelID)
|
||||
usageLog.ModelMappingChain = optionalTrimmedStringPtr(input.ModelMappingChain)
|
||||
// 设置计费模式
|
||||
if cost != nil && cost.BillingMode != "" {
|
||||
billingMode := cost.BillingMode
|
||||
usageLog.BillingMode = &billingMode
|
||||
} else {
|
||||
billingMode := string(BillingModeToken)
|
||||
usageLog.BillingMode = &billingMode
|
||||
}
|
||||
// 添加 UserAgent
|
||||
if input.UserAgent != "" {
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
package service
|
||||
|
||||
import "strings"
|
||||
|
||||
// resolveOpenAIForwardModel resolves the account/group mapping result for
|
||||
// OpenAI-compatible forwarding. Group-level default mapping only applies when
|
||||
// the account itself did not match any explicit model_mapping rule.
|
||||
// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible
|
||||
// forwarding. Group-level default mapping only applies when the account itself
|
||||
// did not match any explicit model_mapping rule.
|
||||
func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string {
|
||||
if account == nil {
|
||||
if defaultMappedModel != "" {
|
||||
@@ -19,23 +17,3 @@ func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedMo
|
||||
}
|
||||
return mappedModel
|
||||
}
|
||||
|
||||
func resolveOpenAIUpstreamModel(model string) string {
|
||||
if isBareGPT53CodexSparkModel(model) {
|
||||
return "gpt-5.3-codex-spark"
|
||||
}
|
||||
return normalizeCodexModel(strings.TrimSpace(model))
|
||||
}
|
||||
|
||||
func isBareGPT53CodexSparkModel(model string) bool {
|
||||
modelID := strings.TrimSpace(model)
|
||||
if modelID == "" {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(modelID, "/") {
|
||||
parts := strings.Split(modelID, "/")
|
||||
modelID = parts[len(parts)-1]
|
||||
}
|
||||
normalized := strings.ToLower(strings.TrimSpace(modelID))
|
||||
return normalized == "gpt-5.3-codex-spark" || normalized == "gpt 5.3 codex spark"
|
||||
}
|
||||
|
||||
@@ -74,30 +74,28 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t *
|
||||
Credentials: map[string]any{},
|
||||
}
|
||||
|
||||
withoutDefault := resolveOpenAIUpstreamModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", ""))
|
||||
withoutDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", ""))
|
||||
if withoutDefault != "gpt-5.1" {
|
||||
t.Fatalf("resolveOpenAIUpstreamModel(...) = %q, want %q", withoutDefault, "gpt-5.1")
|
||||
t.Fatalf("normalizeCodexModel(...) = %q, want %q", withoutDefault, "gpt-5.1")
|
||||
}
|
||||
|
||||
withDefault := resolveOpenAIUpstreamModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4"))
|
||||
withDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4"))
|
||||
if withDefault != "gpt-5.4" {
|
||||
t.Fatalf("resolveOpenAIUpstreamModel(...) = %q, want %q", withDefault, "gpt-5.4")
|
||||
t.Fatalf("normalizeCodexModel(...) = %q, want %q", withDefault, "gpt-5.4")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveOpenAIUpstreamModel(t *testing.T) {
|
||||
func TestNormalizeCodexModel(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"gpt-5.3-codex-spark": "gpt-5.3-codex-spark",
|
||||
"gpt 5.3 codex spark": "gpt-5.3-codex-spark",
|
||||
" openai/gpt-5.3-codex-spark ": "gpt-5.3-codex-spark",
|
||||
"gpt-5.3-codex-spark-high": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex",
|
||||
"gpt-5.3": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-spark": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-spark-high": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex",
|
||||
"gpt-5.3": "gpt-5.3-codex",
|
||||
}
|
||||
|
||||
for input, expected := range cases {
|
||||
if got := resolveOpenAIUpstreamModel(input); got != expected {
|
||||
t.Fatalf("resolveOpenAIUpstreamModel(%q) = %q, want %q", input, got, expected)
|
||||
if got := normalizeCodexModel(input); got != expected {
|
||||
t.Fatalf("normalizeCodexModel(%q) = %q, want %q", input, got, expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2515,7 +2515,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
}
|
||||
normalized = next
|
||||
}
|
||||
upstreamModel := resolveOpenAIUpstreamModel(account.GetMappedModel(originalModel))
|
||||
upstreamModel := normalizeCodexModel(account.GetMappedModel(originalModel))
|
||||
if upstreamModel != originalModel {
|
||||
next, setErr := applyPayloadMutation(normalized, "model", upstreamModel)
|
||||
if setErr != nil {
|
||||
@@ -2773,7 +2773,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
mappedModel := ""
|
||||
var mappedModelBytes []byte
|
||||
if originalModel != "" {
|
||||
mappedModel = resolveOpenAIUpstreamModel(account.GetMappedModel(originalModel))
|
||||
mappedModel = normalizeCodexModel(account.GetMappedModel(originalModel))
|
||||
needModelReplace = mappedModel != "" && mappedModel != originalModel
|
||||
if needModelReplace {
|
||||
mappedModelBytes = []byte(mappedModel)
|
||||
|
||||
@@ -615,6 +615,8 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)
|
||||
|
||||
@@ -519,7 +519,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry
|
||||
if s.gatewayService == nil {
|
||||
return nil, fmt.Errorf("gateway service not available")
|
||||
}
|
||||
return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs, "") // 重试不使用会话限制
|
||||
return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs, "", int64(0)) // 重试不使用会话限制
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported retry type: %s", reqType)
|
||||
}
|
||||
|
||||
@@ -70,7 +70,8 @@ type LiteLLMModelPricing struct {
|
||||
LiteLLMProvider string `json:"litellm_provider"`
|
||||
Mode string `json:"mode"`
|
||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||
OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格
|
||||
OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格
|
||||
OutputCostPerImageToken float64 `json:"output_cost_per_image_token"` // 图片输出 token 价格
|
||||
}
|
||||
|
||||
// PricingRemoteClient 远程价格数据获取接口
|
||||
@@ -94,6 +95,7 @@ type LiteLLMRawEntry struct {
|
||||
Mode string `json:"mode"`
|
||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||
OutputCostPerImage *float64 `json:"output_cost_per_image"`
|
||||
OutputCostPerImageToken *float64 `json:"output_cost_per_image_token"`
|
||||
}
|
||||
|
||||
// PricingService 动态价格服务
|
||||
@@ -408,6 +410,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel
|
||||
if entry.OutputCostPerImage != nil {
|
||||
pricing.OutputCostPerImage = *entry.OutputCostPerImage
|
||||
}
|
||||
if entry.OutputCostPerImageToken != nil {
|
||||
pricing.OutputCostPerImageToken = *entry.OutputCostPerImageToken
|
||||
}
|
||||
|
||||
result[modelName] = pricing
|
||||
}
|
||||
|
||||
15
backend/internal/service/testhelpers_test.go
Normal file
15
backend/internal/service/testhelpers_test.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
// testPtrFloat64 returns a pointer to the given float64 value.
|
||||
func testPtrFloat64(v float64) *float64 { return &v }
|
||||
|
||||
// testPtrInt returns a pointer to the given int value.
|
||||
func testPtrInt(v int) *int { return &v }
|
||||
|
||||
// testPtrString returns a pointer to the given string value.
|
||||
func testPtrString(v string) *string { return &v }
|
||||
|
||||
// testPtrBool returns a pointer to the given bool value.
|
||||
func testPtrBool(v bool) *bool { return &v }
|
||||
@@ -104,6 +104,14 @@ type UsageLog struct {
|
||||
// UpstreamModel is the actual model sent to the upstream provider after mapping.
|
||||
// Nil means no mapping was applied (requested model was used as-is).
|
||||
UpstreamModel *string
|
||||
// ChannelID 渠道 ID
|
||||
ChannelID *int64
|
||||
// ModelMappingChain 模型映射链,如 "a→b→c"
|
||||
ModelMappingChain *string
|
||||
// BillingTier 计费层级标签(per_request/image 模式)
|
||||
BillingTier *string
|
||||
// BillingMode 计费模式:token/image(sora 路径为 nil)
|
||||
BillingMode *string
|
||||
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
|
||||
ServiceTier *string
|
||||
// ReasoningEffort is the request's reasoning effort level.
|
||||
@@ -126,6 +134,9 @@ type UsageLog struct {
|
||||
CacheCreation5mTokens int `gorm:"column:cache_creation_5m_tokens"`
|
||||
CacheCreation1hTokens int `gorm:"column:cache_creation_1h_tokens"`
|
||||
|
||||
ImageOutputTokens int
|
||||
ImageOutputCost float64
|
||||
|
||||
InputCost float64
|
||||
OutputCost float64
|
||||
CacheCreationCost float64
|
||||
|
||||
@@ -26,3 +26,10 @@ func forwardResultBillingModel(requestedModel, upstreamModel string) string {
|
||||
}
|
||||
return strings.TrimSpace(upstreamModel)
|
||||
}
|
||||
|
||||
func optionalInt64Ptr(v int64) *int64 {
|
||||
if v == 0 {
|
||||
return nil
|
||||
}
|
||||
return &v
|
||||
}
|
||||
|
||||
@@ -490,4 +490,6 @@ var ProviderSet = wire.NewSet(
|
||||
ProvideScheduledTestService,
|
||||
ProvideScheduledTestRunnerService,
|
||||
NewGroupCapacityService,
|
||||
NewChannelService,
|
||||
NewModelPricingResolver,
|
||||
)
|
||||
|
||||
56
backend/migrations/081_create_channels.sql
Normal file
56
backend/migrations/081_create_channels.sql
Normal file
@@ -0,0 +1,56 @@
|
||||
-- Create channels table for managing pricing channels.
|
||||
-- A channel groups multiple groups together and provides custom model pricing.
|
||||
|
||||
SET LOCAL lock_timeout = '5s';
|
||||
SET LOCAL statement_timeout = '10min';
|
||||
|
||||
-- 渠道表
|
||||
CREATE TABLE IF NOT EXISTS channels (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
name VARCHAR(100) NOT NULL,
|
||||
description TEXT DEFAULT '',
|
||||
status VARCHAR(20) NOT NULL DEFAULT 'active',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
-- 渠道名称唯一索引
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_channels_name ON channels (name);
|
||||
CREATE INDEX IF NOT EXISTS idx_channels_status ON channels (status);
|
||||
|
||||
-- 渠道-分组关联表(每个分组只能属于一个渠道)
|
||||
CREATE TABLE IF NOT EXISTS channel_groups (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
|
||||
group_id BIGINT NOT NULL REFERENCES groups(id) ON DELETE CASCADE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_channel_groups_group_id ON channel_groups (group_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_channel_groups_channel_id ON channel_groups (channel_id);
|
||||
|
||||
-- 渠道模型定价表(一条定价可绑定多个模型)
|
||||
CREATE TABLE IF NOT EXISTS channel_model_pricing (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
|
||||
models JSONB NOT NULL DEFAULT '[]',
|
||||
input_price NUMERIC(20,12),
|
||||
output_price NUMERIC(20,12),
|
||||
cache_write_price NUMERIC(20,12),
|
||||
cache_read_price NUMERIC(20,12),
|
||||
image_output_price NUMERIC(20,8),
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_channel_model_pricing_channel_id ON channel_model_pricing (channel_id);
|
||||
|
||||
COMMENT ON TABLE channels IS '渠道管理:关联多个分组,提供自定义模型定价';
|
||||
COMMENT ON TABLE channel_groups IS '渠道-分组关联表:每个分组最多属于一个渠道';
|
||||
COMMENT ON TABLE channel_model_pricing IS '渠道模型定价:一条定价可绑定多个模型,价格一致';
|
||||
COMMENT ON COLUMN channel_model_pricing.models IS '绑定的模型列表,JSON 数组,如 ["claude-opus-4-6","claude-opus-4-6-thinking"]';
|
||||
COMMENT ON COLUMN channel_model_pricing.input_price IS '每 token 输入价格(USD),NULL 表示使用默认';
|
||||
COMMENT ON COLUMN channel_model_pricing.output_price IS '每 token 输出价格(USD),NULL 表示使用默认';
|
||||
COMMENT ON COLUMN channel_model_pricing.cache_write_price IS '缓存写入每 token 价格,NULL 表示使用默认';
|
||||
COMMENT ON COLUMN channel_model_pricing.cache_read_price IS '缓存读取每 token 价格,NULL 表示使用默认';
|
||||
COMMENT ON COLUMN channel_model_pricing.image_output_price IS '图片输出价格(Gemini Image 等),NULL 表示使用默认';
|
||||
67
backend/migrations/082_refactor_channel_pricing.sql
Normal file
67
backend/migrations/082_refactor_channel_pricing.sql
Normal file
@@ -0,0 +1,67 @@
|
||||
-- Extend channel_model_pricing with billing_mode and add context-interval child table.
|
||||
-- Supports three billing modes: token (per-token with context intervals),
|
||||
-- per_request (per-request with context-size tiers), and image (per-image).
|
||||
|
||||
SET LOCAL lock_timeout = '5s';
|
||||
SET LOCAL statement_timeout = '10min';
|
||||
|
||||
-- 1. 为 channel_model_pricing 添加 billing_mode 列
|
||||
ALTER TABLE channel_model_pricing
|
||||
ADD COLUMN IF NOT EXISTS billing_mode VARCHAR(20) NOT NULL DEFAULT 'token';
|
||||
|
||||
COMMENT ON COLUMN channel_model_pricing.billing_mode IS '计费模式:token(按 token 区间计费)、per_request(按次计费)、image(图片计费)';
|
||||
|
||||
-- 2. 创建区间定价子表
|
||||
CREATE TABLE IF NOT EXISTS channel_pricing_intervals (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
pricing_id BIGINT NOT NULL REFERENCES channel_model_pricing(id) ON DELETE CASCADE,
|
||||
min_tokens INT NOT NULL DEFAULT 0,
|
||||
max_tokens INT,
|
||||
tier_label VARCHAR(50),
|
||||
input_price NUMERIC(20,12),
|
||||
output_price NUMERIC(20,12),
|
||||
cache_write_price NUMERIC(20,12),
|
||||
cache_read_price NUMERIC(20,12),
|
||||
per_request_price NUMERIC(20,12),
|
||||
sort_order INT NOT NULL DEFAULT 0,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_channel_pricing_intervals_pricing_id
|
||||
ON channel_pricing_intervals (pricing_id);
|
||||
|
||||
COMMENT ON TABLE channel_pricing_intervals IS '渠道定价区间:支持按 token 区间、按次分层、图片分辨率分层';
|
||||
COMMENT ON COLUMN channel_pricing_intervals.min_tokens IS '区间下界(含),token 模式使用';
|
||||
COMMENT ON COLUMN channel_pricing_intervals.max_tokens IS '区间上界(不含),NULL 表示无上限';
|
||||
COMMENT ON COLUMN channel_pricing_intervals.tier_label IS '层级标签,按次/图片模式使用(如 1K、2K、4K、HD)';
|
||||
COMMENT ON COLUMN channel_pricing_intervals.input_price IS 'token 模式:每 token 输入价';
|
||||
COMMENT ON COLUMN channel_pricing_intervals.output_price IS 'token 模式:每 token 输出价';
|
||||
COMMENT ON COLUMN channel_pricing_intervals.cache_write_price IS 'token 模式:缓存写入价';
|
||||
COMMENT ON COLUMN channel_pricing_intervals.cache_read_price IS 'token 模式:缓存读取价';
|
||||
COMMENT ON COLUMN channel_pricing_intervals.per_request_price IS '按次/图片模式:每次请求价格';
|
||||
|
||||
-- 3. 迁移现有 flat 定价为单区间 [0, +inf)
|
||||
-- 仅迁移有明确定价(至少一个价格字段非 NULL)的条目
|
||||
INSERT INTO channel_pricing_intervals (pricing_id, min_tokens, max_tokens, input_price, output_price, cache_write_price, cache_read_price, sort_order)
|
||||
SELECT
|
||||
cmp.id,
|
||||
0,
|
||||
NULL,
|
||||
cmp.input_price,
|
||||
cmp.output_price,
|
||||
cmp.cache_write_price,
|
||||
cmp.cache_read_price,
|
||||
0
|
||||
FROM channel_model_pricing cmp
|
||||
WHERE cmp.billing_mode = 'token'
|
||||
AND (cmp.input_price IS NOT NULL OR cmp.output_price IS NOT NULL
|
||||
OR cmp.cache_write_price IS NOT NULL OR cmp.cache_read_price IS NOT NULL)
|
||||
AND NOT EXISTS (
|
||||
SELECT 1 FROM channel_pricing_intervals cpi WHERE cpi.pricing_id = cmp.id
|
||||
);
|
||||
|
||||
-- 4. 迁移 image_output_price 为 image 模式的区间条目
|
||||
-- 将有 image_output_price 的现有条目复制为 billing_mode='image' 的独立条目
|
||||
-- 注意:这里不改变原条目的 billing_mode,而是将 image_output_price 作为向后兼容字段保留
|
||||
-- 实际的 image 计费在未来由独立的 billing_mode='image' 条目处理
|
||||
5
backend/migrations/083_channel_model_mapping.sql
Normal file
5
backend/migrations/083_channel_model_mapping.sql
Normal file
@@ -0,0 +1,5 @@
|
||||
SET LOCAL lock_timeout = '5s';
|
||||
SET LOCAL statement_timeout = '10min';
|
||||
|
||||
ALTER TABLE channels ADD COLUMN IF NOT EXISTS model_mapping JSONB DEFAULT '{}';
|
||||
COMMENT ON COLUMN channels.model_mapping IS '渠道级模型映射,在账号映射之前执行。格式:{"source_model": "target_model"}';
|
||||
7
backend/migrations/084_channel_billing_model_source.sql
Normal file
7
backend/migrations/084_channel_billing_model_source.sql
Normal file
@@ -0,0 +1,7 @@
|
||||
-- Add billing_model_source to channels (controls whether billing uses requested or upstream model)
|
||||
ALTER TABLE channels ADD COLUMN IF NOT EXISTS billing_model_source VARCHAR(20) DEFAULT 'requested';
|
||||
|
||||
-- Add channel tracking fields to usage_logs
|
||||
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS channel_id BIGINT;
|
||||
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS model_mapping_chain VARCHAR(500);
|
||||
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS billing_tier VARCHAR(50);
|
||||
@@ -0,0 +1,5 @@
|
||||
-- Add model restriction switch to channels
|
||||
ALTER TABLE channels ADD COLUMN IF NOT EXISTS restrict_models BOOLEAN DEFAULT false;
|
||||
|
||||
-- Add default per_request_price to channel_model_pricing (fallback when no tier matches)
|
||||
ALTER TABLE channel_model_pricing ADD COLUMN IF NOT EXISTS per_request_price NUMERIC(20,10);
|
||||
21
backend/migrations/086_channel_platform_pricing.sql
Normal file
21
backend/migrations/086_channel_platform_pricing.sql
Normal file
@@ -0,0 +1,21 @@
|
||||
-- 086_channel_platform_pricing.sql
|
||||
-- 渠道按平台维度:model_pricing 加 platform 列,model_mapping 改为嵌套格式
|
||||
|
||||
-- 1. channel_model_pricing 加 platform 列
|
||||
ALTER TABLE channel_model_pricing
|
||||
ADD COLUMN IF NOT EXISTS platform VARCHAR(50) NOT NULL DEFAULT 'anthropic';
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_channel_model_pricing_platform
|
||||
ON channel_model_pricing (platform);
|
||||
|
||||
-- 2. model_mapping: 从扁平 {"src":"dst"} 迁移为嵌套 {"anthropic":{"src":"dst"}}
|
||||
-- 仅迁移非空、非 '{}' 的旧格式数据(通过检查第一个 value 是否为字符串来判断是否为旧格式)
|
||||
UPDATE channels
|
||||
SET model_mapping = jsonb_build_object('anthropic', model_mapping)
|
||||
WHERE model_mapping IS NOT NULL
|
||||
AND model_mapping::text NOT IN ('{}', 'null', '')
|
||||
AND NOT EXISTS (
|
||||
SELECT 1 FROM jsonb_each(model_mapping) AS kv
|
||||
WHERE jsonb_typeof(kv.value) = 'object'
|
||||
LIMIT 1
|
||||
);
|
||||
2
backend/migrations/087_usage_log_billing_mode.sql
Normal file
2
backend/migrations/087_usage_log_billing_mode.sql
Normal file
@@ -0,0 +1,2 @@
|
||||
-- Add billing_mode to usage_logs (records the billing mode: token/per_request/image)
|
||||
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS billing_mode VARCHAR(20);
|
||||
@@ -0,0 +1,3 @@
|
||||
-- Change default billing_model_source for new channels to 'channel_mapped'
|
||||
-- Existing channels keep their current setting (no UPDATE on existing rows)
|
||||
ALTER TABLE channels ALTER COLUMN billing_model_source SET DEFAULT 'channel_mapped';
|
||||
2
backend/migrations/089_usage_log_image_output_tokens.sql
Normal file
2
backend/migrations/089_usage_log_image_output_tokens.sql
Normal file
@@ -0,0 +1,2 @@
|
||||
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS image_output_tokens INTEGER NOT NULL DEFAULT 0;
|
||||
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS image_output_cost DECIMAL(20, 10) NOT NULL DEFAULT 0;
|
||||
148
frontend/src/api/admin/channels.ts
Normal file
148
frontend/src/api/admin/channels.ts
Normal file
@@ -0,0 +1,148 @@
|
||||
/**
|
||||
* Admin Channels API endpoints
|
||||
* Handles channel management for administrators
|
||||
*/
|
||||
|
||||
import { apiClient } from '../client'
|
||||
|
||||
export type BillingMode = 'token' | 'per_request' | 'image'
|
||||
|
||||
export interface PricingInterval {
|
||||
id?: number
|
||||
min_tokens: number
|
||||
max_tokens: number | null
|
||||
tier_label: string
|
||||
input_price: number | null
|
||||
output_price: number | null
|
||||
cache_write_price: number | null
|
||||
cache_read_price: number | null
|
||||
per_request_price: number | null
|
||||
sort_order: number
|
||||
}
|
||||
|
||||
export interface ChannelModelPricing {
|
||||
id?: number
|
||||
platform: string
|
||||
models: string[]
|
||||
billing_mode: BillingMode
|
||||
input_price: number | null
|
||||
output_price: number | null
|
||||
cache_write_price: number | null
|
||||
cache_read_price: number | null
|
||||
image_output_price: number | null
|
||||
per_request_price: number | null
|
||||
intervals: PricingInterval[]
|
||||
}
|
||||
|
||||
export interface Channel {
|
||||
id: number
|
||||
name: string
|
||||
description: string
|
||||
status: string
|
||||
billing_model_source: string // "requested" | "upstream"
|
||||
restrict_models: boolean
|
||||
group_ids: number[]
|
||||
model_pricing: ChannelModelPricing[]
|
||||
model_mapping: Record<string, Record<string, string>> // platform → {src→dst}
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
export interface CreateChannelRequest {
|
||||
name: string
|
||||
description?: string
|
||||
group_ids?: number[]
|
||||
model_pricing?: ChannelModelPricing[]
|
||||
model_mapping?: Record<string, Record<string, string>>
|
||||
billing_model_source?: string
|
||||
restrict_models?: boolean
|
||||
}
|
||||
|
||||
export interface UpdateChannelRequest {
|
||||
name?: string
|
||||
description?: string
|
||||
status?: string
|
||||
group_ids?: number[]
|
||||
model_pricing?: ChannelModelPricing[]
|
||||
model_mapping?: Record<string, Record<string, string>>
|
||||
billing_model_source?: string
|
||||
restrict_models?: boolean
|
||||
}
|
||||
|
||||
interface PaginatedResponse<T> {
|
||||
items: T[]
|
||||
total: number
|
||||
}
|
||||
|
||||
/**
|
||||
* List channels with pagination
|
||||
*/
|
||||
export async function list(
|
||||
page: number = 1,
|
||||
pageSize: number = 20,
|
||||
filters?: {
|
||||
status?: string
|
||||
search?: string
|
||||
},
|
||||
options?: { signal?: AbortSignal }
|
||||
): Promise<PaginatedResponse<Channel>> {
|
||||
const { data } = await apiClient.get<PaginatedResponse<Channel>>('/admin/channels', {
|
||||
params: {
|
||||
page,
|
||||
page_size: pageSize,
|
||||
...filters
|
||||
},
|
||||
signal: options?.signal
|
||||
})
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
* Get channel by ID
|
||||
*/
|
||||
export async function getById(id: number): Promise<Channel> {
|
||||
const { data } = await apiClient.get<Channel>(`/admin/channels/${id}`)
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new channel
|
||||
*/
|
||||
export async function create(req: CreateChannelRequest): Promise<Channel> {
|
||||
const { data } = await apiClient.post<Channel>('/admin/channels', req)
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
* Update a channel
|
||||
*/
|
||||
export async function update(id: number, req: UpdateChannelRequest): Promise<Channel> {
|
||||
const { data } = await apiClient.put<Channel>(`/admin/channels/${id}`, req)
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete a channel
|
||||
*/
|
||||
export async function remove(id: number): Promise<void> {
|
||||
await apiClient.delete(`/admin/channels/${id}`)
|
||||
}
|
||||
|
||||
export interface ModelDefaultPricing {
|
||||
found: boolean
|
||||
input_price?: number // per-token price
|
||||
output_price?: number
|
||||
cache_write_price?: number
|
||||
cache_read_price?: number
|
||||
image_output_price?: number
|
||||
}
|
||||
|
||||
export async function getModelDefaultPricing(model: string): Promise<ModelDefaultPricing> {
|
||||
const { data } = await apiClient.get<ModelDefaultPricing>('/admin/channels/model-pricing', {
|
||||
params: { model }
|
||||
})
|
||||
return data
|
||||
}
|
||||
|
||||
const channelsAPI = { list, getById, create, update, remove, getModelDefaultPricing }
|
||||
export default channelsAPI
|
||||
@@ -167,6 +167,13 @@ export interface UserBreakdownParams {
|
||||
endpoint?: string
|
||||
endpoint_type?: 'inbound' | 'upstream' | 'path'
|
||||
limit?: number
|
||||
// Additional filter conditions
|
||||
user_id?: number
|
||||
api_key_id?: number
|
||||
account_id?: number
|
||||
request_type?: number
|
||||
stream?: boolean
|
||||
billing_type?: number | null
|
||||
}
|
||||
|
||||
export interface UserBreakdownResponse {
|
||||
|
||||
@@ -25,6 +25,7 @@ import apiKeysAPI from './apiKeys'
|
||||
import scheduledTestsAPI from './scheduledTests'
|
||||
import backupAPI from './backup'
|
||||
import tlsFingerprintProfileAPI from './tlsFingerprintProfile'
|
||||
import channelsAPI from './channels'
|
||||
|
||||
/**
|
||||
* Unified admin API object for convenient access
|
||||
@@ -51,7 +52,8 @@ export const adminAPI = {
|
||||
apiKeys: apiKeysAPI,
|
||||
scheduledTests: scheduledTestsAPI,
|
||||
backup: backupAPI,
|
||||
tlsFingerprintProfiles: tlsFingerprintProfileAPI
|
||||
tlsFingerprintProfiles: tlsFingerprintProfileAPI,
|
||||
channels: channelsAPI
|
||||
}
|
||||
|
||||
export {
|
||||
@@ -76,7 +78,8 @@ export {
|
||||
apiKeysAPI,
|
||||
scheduledTestsAPI,
|
||||
backupAPI,
|
||||
tlsFingerprintProfileAPI
|
||||
tlsFingerprintProfileAPI,
|
||||
channelsAPI
|
||||
}
|
||||
|
||||
export default adminAPI
|
||||
|
||||
@@ -80,6 +80,7 @@ export interface CreateUsageCleanupTaskRequest {
|
||||
export interface AdminUsageQueryParams extends UsageQueryParams {
|
||||
user_id?: number
|
||||
exact_total?: boolean
|
||||
billing_mode?: string
|
||||
}
|
||||
|
||||
// ==================== API Functions ====================
|
||||
|
||||
113
frontend/src/components/admin/channel/IntervalRow.vue
Normal file
113
frontend/src/components/admin/channel/IntervalRow.vue
Normal file
@@ -0,0 +1,113 @@
|
||||
<template>
|
||||
<div class="flex items-start gap-2 rounded border p-2"
|
||||
:class="isEmpty ? 'border-red-400 bg-red-50 dark:border-red-500 dark:bg-red-950/20' : 'border-gray-200 bg-white dark:border-dark-500 dark:bg-dark-700'">
|
||||
<!-- Token mode: context range + prices ($/MTok) -->
|
||||
<template v-if="mode === 'token'">
|
||||
<div class="w-20">
|
||||
<label class="text-xs text-gray-400">Min</label>
|
||||
<input :value="interval.min_tokens" @input="emitField('min_tokens', toInt(($event.target as HTMLInputElement).value))"
|
||||
type="number" min="0" class="input mt-0.5 text-xs" />
|
||||
</div>
|
||||
<div class="w-20">
|
||||
<label class="text-xs text-gray-400">Max <span class="text-gray-300">(含)</span></label>
|
||||
<input :value="interval.max_tokens ?? ''" @input="emitField('max_tokens', toIntOrNull(($event.target as HTMLInputElement).value))"
|
||||
type="number" min="0" class="input mt-0.5 text-xs" :placeholder="'∞'" />
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<label class="text-xs text-gray-400">{{ t('admin.channels.form.inputPrice', '输入') }} <span v-if="isEmpty" class="text-red-500">*</span> <span class="text-gray-300">$/M</span></label>
|
||||
<input :value="interval.input_price" @input="emitField('input_price', ($event.target as HTMLInputElement).value)"
|
||||
type="number" step="any" min="0" class="input mt-0.5 text-xs" />
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<label class="text-xs text-gray-400">{{ t('admin.channels.form.outputPrice', '输出') }} <span v-if="isEmpty" class="text-red-500">*</span> <span class="text-gray-300">$/M</span></label>
|
||||
<input :value="interval.output_price" @input="emitField('output_price', ($event.target as HTMLInputElement).value)"
|
||||
type="number" step="any" min="0" class="input mt-0.5 text-xs" />
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<label class="text-xs text-gray-400">{{ t('admin.channels.form.cacheWritePrice', '缓存W') }} <span class="text-gray-300">$/M</span></label>
|
||||
<input :value="interval.cache_write_price" @input="emitField('cache_write_price', ($event.target as HTMLInputElement).value)"
|
||||
type="number" step="any" min="0" class="input mt-0.5 text-xs" />
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<label class="text-xs text-gray-400">{{ t('admin.channels.form.cacheReadPrice', '缓存R') }} <span class="text-gray-300">$/M</span></label>
|
||||
<input :value="interval.cache_read_price" @input="emitField('cache_read_price', ($event.target as HTMLInputElement).value)"
|
||||
type="number" step="any" min="0" class="input mt-0.5 text-xs" />
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<!-- Per-request / Image mode: tier label + context range + price -->
|
||||
<template v-else>
|
||||
<div class="w-24">
|
||||
<label class="text-xs text-gray-400">
|
||||
{{ mode === 'image' ? t('admin.channels.form.resolution', '分辨率') : t('admin.channels.form.tierLabel', '层级') }}
|
||||
</label>
|
||||
<input :value="interval.tier_label" @input="emitField('tier_label', ($event.target as HTMLInputElement).value)"
|
||||
type="text" class="input mt-0.5 text-xs" :placeholder="mode === 'image' ? '1K / 2K / 4K' : ''" />
|
||||
</div>
|
||||
<div class="w-20">
|
||||
<label class="text-xs text-gray-400">Min</label>
|
||||
<input :value="interval.min_tokens" @input="emitField('min_tokens', toInt(($event.target as HTMLInputElement).value))"
|
||||
type="number" min="0" class="input mt-0.5 text-xs" />
|
||||
</div>
|
||||
<div class="w-20">
|
||||
<label class="text-xs text-gray-400">Max <span class="text-gray-300">(含)</span></label>
|
||||
<input :value="interval.max_tokens ?? ''" @input="emitField('max_tokens', toIntOrNull(($event.target as HTMLInputElement).value))"
|
||||
type="number" min="0" class="input mt-0.5 text-xs" :placeholder="'∞'" />
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<label class="text-xs text-gray-400">{{ t('admin.channels.form.perRequestPrice', '单次价格') }} <span v-if="isEmpty" class="text-red-500">*</span> <span class="text-gray-300">$</span></label>
|
||||
<input :value="interval.per_request_price" @input="emitField('per_request_price', ($event.target as HTMLInputElement).value)"
|
||||
type="number" step="any" min="0" class="input mt-0.5 text-xs" />
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<button type="button" @click="emit('remove')" class="mt-4 rounded p-0.5 text-gray-400 hover:text-red-500">
|
||||
<Icon name="x" size="sm" />
|
||||
</button>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
import type { IntervalFormEntry } from './types'
|
||||
import type { BillingMode } from '@/api/admin/channels'
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
const props = defineProps<{
|
||||
interval: IntervalFormEntry
|
||||
mode: BillingMode
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
update: [interval: IntervalFormEntry]
|
||||
remove: []
|
||||
}>()
|
||||
|
||||
// 检测所有价格字段是否都为空
|
||||
const isEmpty = computed(() => {
|
||||
const iv = props.interval
|
||||
return (iv.input_price == null || iv.input_price === '') &&
|
||||
(iv.output_price == null || iv.output_price === '') &&
|
||||
(iv.cache_write_price == null || iv.cache_write_price === '') &&
|
||||
(iv.cache_read_price == null || iv.cache_read_price === '') &&
|
||||
(iv.per_request_price == null || iv.per_request_price === '')
|
||||
})
|
||||
|
||||
function emitField(field: keyof IntervalFormEntry, value: string | number | null) {
|
||||
emit('update', { ...props.interval, [field]: value === '' ? null : value })
|
||||
}
|
||||
|
||||
function toInt(val: string): number {
|
||||
const n = parseInt(val, 10)
|
||||
return isNaN(n) ? 0 : n
|
||||
}
|
||||
|
||||
function toIntOrNull(val: string): number | null {
|
||||
if (val === '') return null
|
||||
const n = parseInt(val, 10)
|
||||
return isNaN(n) ? null : n
|
||||
}
|
||||
</script>
|
||||
89
frontend/src/components/admin/channel/ModelTagInput.vue
Normal file
89
frontend/src/components/admin/channel/ModelTagInput.vue
Normal file
@@ -0,0 +1,89 @@
|
||||
<template>
|
||||
<div>
|
||||
<!-- Tags display -->
|
||||
<div class="flex flex-wrap gap-1.5 rounded-lg border border-gray-200 bg-white p-2 dark:border-dark-600 dark:bg-dark-800 min-h-[2.5rem]">
|
||||
<span
|
||||
v-for="(model, idx) in models"
|
||||
:key="idx"
|
||||
class="inline-flex items-center gap-1 rounded-md px-2 py-0.5 text-sm"
|
||||
:class="getPlatformTagClass(props.platform || '')"
|
||||
>
|
||||
{{ model }}
|
||||
<button
|
||||
type="button"
|
||||
@click="removeModel(idx)"
|
||||
class="ml-0.5 rounded-full p-0.5 hover:bg-primary-200 dark:hover:bg-primary-800"
|
||||
>
|
||||
<Icon name="x" size="xs" />
|
||||
</button>
|
||||
</span>
|
||||
<input
|
||||
ref="inputRef"
|
||||
v-model="inputValue"
|
||||
type="text"
|
||||
class="flex-1 min-w-[120px] border-none bg-transparent text-sm outline-none placeholder:text-gray-400 dark:text-white"
|
||||
:placeholder="models.length === 0 ? placeholder : ''"
|
||||
@keydown.enter.prevent="addModel"
|
||||
@keydown.tab.prevent="addModel"
|
||||
@keydown.delete="handleBackspace"
|
||||
@paste="handlePaste"
|
||||
/>
|
||||
</div>
|
||||
<p class="mt-1 text-xs text-gray-400">
|
||||
{{ t('admin.channels.form.modelInputHint', 'Press Enter to add, supports paste for batch import.') }}
|
||||
</p>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
import { getPlatformTagClass } from './types'
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
const props = defineProps<{
|
||||
models: string[]
|
||||
placeholder?: string
|
||||
platform?: string
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
'update:models': [models: string[]]
|
||||
}>()
|
||||
|
||||
const inputValue = ref('')
|
||||
const inputRef = ref<HTMLInputElement>()
|
||||
|
||||
function addModel() {
|
||||
const val = inputValue.value.trim()
|
||||
if (!val) return
|
||||
if (!props.models.includes(val)) {
|
||||
emit('update:models', [...props.models, val])
|
||||
}
|
||||
inputValue.value = ''
|
||||
}
|
||||
|
||||
function removeModel(idx: number) {
|
||||
const newModels = [...props.models]
|
||||
newModels.splice(idx, 1)
|
||||
emit('update:models', newModels)
|
||||
}
|
||||
|
||||
function handleBackspace() {
|
||||
if (inputValue.value === '' && props.models.length > 0) {
|
||||
removeModel(props.models.length - 1)
|
||||
}
|
||||
}
|
||||
|
||||
function handlePaste(e: ClipboardEvent) {
|
||||
e.preventDefault()
|
||||
const text = e.clipboardData?.getData('text') || ''
|
||||
const items = text.split(/[,\n;]+/).map(s => s.trim()).filter(Boolean)
|
||||
if (items.length === 0) return
|
||||
const unique = [...new Set([...props.models, ...items])]
|
||||
emit('update:models', unique)
|
||||
inputValue.value = ''
|
||||
}
|
||||
</script>
|
||||
354
frontend/src/components/admin/channel/PricingEntryCard.vue
Normal file
354
frontend/src/components/admin/channel/PricingEntryCard.vue
Normal file
@@ -0,0 +1,354 @@
|
||||
<template>
|
||||
<div class="rounded-lg border border-gray-200 bg-gray-50 p-3 dark:border-dark-600 dark:bg-dark-800">
|
||||
<!-- Collapsed summary header (clickable) -->
|
||||
<div
|
||||
class="flex cursor-pointer select-none items-center gap-2"
|
||||
@click="collapsed = !collapsed"
|
||||
>
|
||||
<Icon
|
||||
:name="collapsed ? 'chevronRight' : 'chevronDown'"
|
||||
size="sm"
|
||||
:stroke-width="2"
|
||||
class="flex-shrink-0 text-gray-400 transition-transform duration-200"
|
||||
/>
|
||||
|
||||
<!-- Summary: model tags + billing badge -->
|
||||
<div v-if="collapsed" class="flex min-w-0 flex-1 items-center gap-2 overflow-hidden">
|
||||
<!-- Compact model tags (show first 3) -->
|
||||
<div class="flex min-w-0 flex-1 flex-wrap items-center gap-1">
|
||||
<span
|
||||
v-for="(m, i) in entry.models.slice(0, 3)"
|
||||
:key="i"
|
||||
class="inline-flex shrink-0 rounded px-1.5 py-0.5 text-xs"
|
||||
:class="getPlatformTagClass(props.platform || '')"
|
||||
>
|
||||
{{ m }}
|
||||
</span>
|
||||
<span
|
||||
v-if="entry.models.length > 3"
|
||||
class="whitespace-nowrap text-xs text-gray-400"
|
||||
>
|
||||
+{{ entry.models.length - 3 }}
|
||||
</span>
|
||||
<span
|
||||
v-if="entry.models.length === 0"
|
||||
class="text-xs italic text-gray-400"
|
||||
>
|
||||
{{ t('admin.channels.form.noModels', '未添加模型') }}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Billing mode badge -->
|
||||
<span
|
||||
class="flex-shrink-0 rounded-full bg-primary-100 px-2 py-0.5 text-xs font-medium text-primary-700 dark:bg-primary-900/30 dark:text-primary-300"
|
||||
>
|
||||
{{ billingModeLabel }}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Expanded: show the label "Pricing Entry" or similar -->
|
||||
<div v-else class="flex-1 text-xs font-medium text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.channels.form.pricingEntry', '定价配置') }}
|
||||
</div>
|
||||
|
||||
<!-- Remove button (always visible, stop propagation) -->
|
||||
<button
|
||||
type="button"
|
||||
@click.stop="emit('remove')"
|
||||
class="flex-shrink-0 rounded p-1 text-gray-400 hover:text-red-500"
|
||||
>
|
||||
<Icon name="trash" size="sm" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Expandable content with transition -->
|
||||
<div
|
||||
class="collapsible-content"
|
||||
:class="{ 'collapsible-content--collapsed': collapsed }"
|
||||
>
|
||||
<div class="collapsible-inner">
|
||||
<!-- Header: Models + Billing Mode -->
|
||||
<div class="mt-3 flex items-start gap-2">
|
||||
<div class="flex-1">
|
||||
<label class="text-xs font-medium text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.channels.form.models', '模型列表') }} <span class="text-red-500">*</span>
|
||||
</label>
|
||||
<ModelTagInput
|
||||
:models="entry.models"
|
||||
:platform="props.platform"
|
||||
@update:models="onModelsUpdate($event)"
|
||||
:placeholder="t('admin.channels.form.modelsPlaceholder', '输入模型名后按回车添加,支持通配符 *')"
|
||||
class="mt-1"
|
||||
/>
|
||||
</div>
|
||||
<div class="w-40">
|
||||
<label class="text-xs font-medium text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.channels.form.billingMode', '计费模式') }}
|
||||
</label>
|
||||
<Select
|
||||
:modelValue="entry.billing_mode"
|
||||
@update:modelValue="emit('update', { ...entry, billing_mode: $event as BillingMode, intervals: [] })"
|
||||
:options="billingModeOptions"
|
||||
class="mt-1"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Token mode -->
|
||||
<div v-if="entry.billing_mode === 'token'">
|
||||
<!-- Default prices (fallback when no interval matches) -->
|
||||
<label class="mt-3 block text-xs font-medium text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.channels.form.defaultPrices', '默认价格(未命中区间时使用)') }}
|
||||
<span class="ml-1 font-normal text-gray-400">$/MTok</span>
|
||||
</label>
|
||||
<div class="mt-1 grid grid-cols-2 gap-2 sm:grid-cols-5">
|
||||
<div>
|
||||
<label class="text-xs text-gray-400">{{ t('admin.channels.form.inputPrice', '输入') }}</label>
|
||||
<input :value="entry.input_price" @input="emitField('input_price', ($event.target as HTMLInputElement).value)"
|
||||
type="number" step="any" min="0" class="input mt-0.5 text-sm" :placeholder="t('admin.channels.form.pricePlaceholder', '默认')" />
|
||||
</div>
|
||||
<div>
|
||||
<label class="text-xs text-gray-400">{{ t('admin.channels.form.outputPrice', '输出') }}</label>
|
||||
<input :value="entry.output_price" @input="emitField('output_price', ($event.target as HTMLInputElement).value)"
|
||||
type="number" step="any" min="0" class="input mt-0.5 text-sm" :placeholder="t('admin.channels.form.pricePlaceholder', '默认')" />
|
||||
</div>
|
||||
<div>
|
||||
<label class="text-xs text-gray-400">{{ t('admin.channels.form.cacheWritePrice', '缓存写入') }}</label>
|
||||
<input :value="entry.cache_write_price" @input="emitField('cache_write_price', ($event.target as HTMLInputElement).value)"
|
||||
type="number" step="any" min="0" class="input mt-0.5 text-sm" :placeholder="t('admin.channels.form.pricePlaceholder', '默认')" />
|
||||
</div>
|
||||
<div>
|
||||
<label class="text-xs text-gray-400">{{ t('admin.channels.form.cacheReadPrice', '缓存读取') }}</label>
|
||||
<input :value="entry.cache_read_price" @input="emitField('cache_read_price', ($event.target as HTMLInputElement).value)"
|
||||
type="number" step="any" min="0" class="input mt-0.5 text-sm" :placeholder="t('admin.channels.form.pricePlaceholder', '默认')" />
|
||||
</div>
|
||||
<div>
|
||||
<label class="text-xs text-gray-400">{{ t('admin.channels.form.imageTokenPrice', '图片输出') }}</label>
|
||||
<input :value="entry.image_output_price" @input="emitField('image_output_price', ($event.target as HTMLInputElement).value)"
|
||||
type="number" step="any" min="0" class="input mt-0.5 text-sm" :placeholder="t('admin.channels.form.pricePlaceholder', '默认')" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Token intervals -->
|
||||
<div class="mt-3">
|
||||
<div class="flex items-center justify-between">
|
||||
<label class="text-xs font-medium text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.channels.form.intervals', '上下文区间定价(可选)') }}
|
||||
<span class="ml-1 font-normal text-gray-400">(min, max]</span>
|
||||
</label>
|
||||
<button type="button" @click="addInterval" class="text-xs text-primary-600 hover:text-primary-700">
|
||||
+ {{ t('admin.channels.form.addInterval', '添加区间') }}
|
||||
</button>
|
||||
</div>
|
||||
<div v-if="entry.intervals && entry.intervals.length > 0" class="mt-2 space-y-2">
|
||||
<IntervalRow
|
||||
v-for="(iv, idx) in entry.intervals"
|
||||
:key="idx"
|
||||
:interval="iv"
|
||||
:mode="entry.billing_mode"
|
||||
@update="updateInterval(idx, $event)"
|
||||
@remove="removeInterval(idx)"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Per-request mode -->
|
||||
<div v-else-if="entry.billing_mode === 'per_request'">
|
||||
<!-- Default per-request price -->
|
||||
<label class="mt-3 block text-xs font-medium text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.channels.form.defaultPerRequestPrice', '默认单次价格(未命中层级时使用)') }}
|
||||
<span class="ml-1 font-normal text-gray-400">$</span>
|
||||
</label>
|
||||
<div class="mt-1 w-48">
|
||||
<input :value="entry.per_request_price" @input="emitField('per_request_price', ($event.target as HTMLInputElement).value)"
|
||||
type="number" step="any" min="0" class="input text-sm" :placeholder="t('admin.channels.form.pricePlaceholder', '默认')" />
|
||||
</div>
|
||||
|
||||
<!-- Tiers -->
|
||||
<div class="mt-3 flex items-center justify-between">
|
||||
<label class="text-xs font-medium text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.channels.form.requestTiers', '按次计费层级') }}
|
||||
</label>
|
||||
<button type="button" @click="addInterval" class="text-xs text-primary-600 hover:text-primary-700">
|
||||
+ {{ t('admin.channels.form.addTier', '添加层级') }}
|
||||
</button>
|
||||
</div>
|
||||
<div v-if="entry.intervals && entry.intervals.length > 0" class="mt-2 space-y-2">
|
||||
<IntervalRow
|
||||
v-for="(iv, idx) in entry.intervals"
|
||||
:key="idx"
|
||||
:interval="iv"
|
||||
:mode="entry.billing_mode"
|
||||
@update="updateInterval(idx, $event)"
|
||||
@remove="removeInterval(idx)"
|
||||
/>
|
||||
</div>
|
||||
<div v-else class="mt-2 rounded border border-dashed border-gray-300 p-3 text-center text-xs text-gray-400 dark:border-dark-500">
|
||||
{{ t('admin.channels.form.noTiersYet', '暂无层级,点击添加配置按次计费价格') }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Image mode -->
|
||||
<div v-else-if="entry.billing_mode === 'image'">
|
||||
<!-- Default image price (per-request, same as per_request mode) -->
|
||||
<label class="mt-3 block text-xs font-medium text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.channels.form.defaultImagePrice', '默认图片价格(未命中层级时使用)') }}
|
||||
<span class="ml-1 font-normal text-gray-400">$</span>
|
||||
</label>
|
||||
<div class="mt-1 w-48">
|
||||
<input :value="entry.per_request_price" @input="emitField('per_request_price', ($event.target as HTMLInputElement).value)"
|
||||
type="number" step="any" min="0" class="input text-sm" :placeholder="t('admin.channels.form.pricePlaceholder', '默认')" />
|
||||
</div>
|
||||
|
||||
<!-- Image tiers -->
|
||||
<div class="mt-3 flex items-center justify-between">
|
||||
<label class="text-xs font-medium text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.channels.form.imageTiers', '图片计费层级(按次)') }}
|
||||
</label>
|
||||
<button type="button" @click="addImageTier" class="text-xs text-primary-600 hover:text-primary-700">
|
||||
+ {{ t('admin.channels.form.addTier', '添加层级') }}
|
||||
</button>
|
||||
</div>
|
||||
<div v-if="entry.intervals && entry.intervals.length > 0" class="mt-2 space-y-2">
|
||||
<IntervalRow
|
||||
v-for="(iv, idx) in entry.intervals"
|
||||
:key="idx"
|
||||
:interval="iv"
|
||||
:mode="entry.billing_mode"
|
||||
@update="updateInterval(idx, $event)"
|
||||
@remove="removeInterval(idx)"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import Select from '@/components/common/Select.vue'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
import IntervalRow from './IntervalRow.vue'
|
||||
import ModelTagInput from './ModelTagInput.vue'
|
||||
import type { PricingFormEntry, IntervalFormEntry } from './types'
|
||||
import { perTokenToMTok, getPlatformTagClass } from './types'
|
||||
import type { BillingMode } from '@/api/admin/channels'
|
||||
import channelsAPI from '@/api/admin/channels'
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
const props = defineProps<{
|
||||
entry: PricingFormEntry
|
||||
platform?: string
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
update: [entry: PricingFormEntry]
|
||||
remove: []
|
||||
}>()
|
||||
|
||||
// Collapse state: entries with existing models default to collapsed
|
||||
const collapsed = ref(props.entry.models.length > 0)
|
||||
|
||||
const billingModeOptions = computed(() => [
|
||||
{ value: 'token', label: 'Token' },
|
||||
{ value: 'per_request', label: t('admin.channels.billingMode.perRequest', '按次') },
|
||||
{ value: 'image', label: t('admin.channels.billingMode.image', '图片(按次)') }
|
||||
])
|
||||
|
||||
const billingModeLabel = computed(() => {
|
||||
const opt = billingModeOptions.value.find(o => o.value === props.entry.billing_mode)
|
||||
return opt ? opt.label : props.entry.billing_mode
|
||||
})
|
||||
|
||||
function emitField(field: keyof PricingFormEntry, value: string) {
|
||||
emit('update', { ...props.entry, [field]: value === '' ? null : value })
|
||||
}
|
||||
|
||||
function addInterval() {
|
||||
const intervals = [...(props.entry.intervals || [])]
|
||||
intervals.push({
|
||||
min_tokens: 0, max_tokens: null, tier_label: '',
|
||||
input_price: null, output_price: null, cache_write_price: null,
|
||||
cache_read_price: null, per_request_price: null,
|
||||
sort_order: intervals.length
|
||||
})
|
||||
emit('update', { ...props.entry, intervals })
|
||||
}
|
||||
|
||||
function addImageTier() {
|
||||
const intervals = [...(props.entry.intervals || [])]
|
||||
const labels = ['1K', '2K', '4K', 'HD']
|
||||
intervals.push({
|
||||
min_tokens: 0, max_tokens: null, tier_label: labels[intervals.length] || '',
|
||||
input_price: null, output_price: null, cache_write_price: null,
|
||||
cache_read_price: null, per_request_price: null,
|
||||
sort_order: intervals.length
|
||||
})
|
||||
emit('update', { ...props.entry, intervals })
|
||||
}
|
||||
|
||||
function updateInterval(idx: number, updated: IntervalFormEntry) {
|
||||
const intervals = [...(props.entry.intervals || [])]
|
||||
intervals[idx] = updated
|
||||
emit('update', { ...props.entry, intervals })
|
||||
}
|
||||
|
||||
function removeInterval(idx: number) {
|
||||
const intervals = [...(props.entry.intervals || [])]
|
||||
intervals.splice(idx, 1)
|
||||
emit('update', { ...props.entry, intervals })
|
||||
}
|
||||
|
||||
async function onModelsUpdate(newModels: string[]) {
|
||||
const oldModels = props.entry.models
|
||||
emit('update', { ...props.entry, models: newModels })
|
||||
|
||||
// 只在新增模型且当前无价格时自动填充
|
||||
const addedModels = newModels.filter(m => !oldModels.includes(m))
|
||||
if (addedModels.length === 0) return
|
||||
|
||||
// 检查是否所有价格字段都为空
|
||||
const e = props.entry
|
||||
const hasPrice = e.input_price != null || e.output_price != null ||
|
||||
e.cache_write_price != null || e.cache_read_price != null
|
||||
if (hasPrice) return
|
||||
|
||||
// 查询第一个新增模型的默认价格
|
||||
try {
|
||||
const result = await channelsAPI.getModelDefaultPricing(addedModels[0])
|
||||
if (result.found) {
|
||||
emit('update', {
|
||||
...props.entry,
|
||||
models: newModels,
|
||||
input_price: perTokenToMTok(result.input_price ?? null),
|
||||
output_price: perTokenToMTok(result.output_price ?? null),
|
||||
cache_write_price: perTokenToMTok(result.cache_write_price ?? null),
|
||||
cache_read_price: perTokenToMTok(result.cache_read_price ?? null),
|
||||
image_output_price: perTokenToMTok(result.image_output_price ?? null),
|
||||
})
|
||||
}
|
||||
} catch {
|
||||
// 查询失败不影响用户操作
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.collapsible-content {
|
||||
display: grid;
|
||||
grid-template-rows: 1fr;
|
||||
transition: grid-template-rows 0.25s ease;
|
||||
}
|
||||
|
||||
.collapsible-content--collapsed {
|
||||
grid-template-rows: 0fr;
|
||||
}
|
||||
|
||||
.collapsible-inner {
|
||||
overflow: hidden;
|
||||
}
|
||||
</style>
|
||||
190
frontend/src/components/admin/channel/types.ts
Normal file
190
frontend/src/components/admin/channel/types.ts
Normal file
@@ -0,0 +1,190 @@
|
||||
import type { BillingMode, PricingInterval } from '@/api/admin/channels'
|
||||
|
||||
export interface IntervalFormEntry {
|
||||
min_tokens: number
|
||||
max_tokens: number | null
|
||||
tier_label: string
|
||||
input_price: number | string | null
|
||||
output_price: number | string | null
|
||||
cache_write_price: number | string | null
|
||||
cache_read_price: number | string | null
|
||||
per_request_price: number | string | null
|
||||
sort_order: number
|
||||
}
|
||||
|
||||
export interface PricingFormEntry {
|
||||
models: string[]
|
||||
billing_mode: BillingMode
|
||||
input_price: number | string | null
|
||||
output_price: number | string | null
|
||||
cache_write_price: number | string | null
|
||||
cache_read_price: number | string | null
|
||||
image_output_price: number | string | null
|
||||
per_request_price: number | string | null
|
||||
intervals: IntervalFormEntry[]
|
||||
}
|
||||
|
||||
// 价格转换:后端存 per-token,前端显示 per-MTok ($/1M tokens)
|
||||
const MTOK = 1_000_000
|
||||
|
||||
export function toNullableNumber(val: number | string | null | undefined): number | null {
|
||||
if (val === null || val === undefined || val === '') return null
|
||||
const num = Number(val)
|
||||
return isNaN(num) ? null : num
|
||||
}
|
||||
|
||||
/** 前端显示值($/MTok) → 后端存储值(per-token) */
|
||||
export function mTokToPerToken(val: number | string | null | undefined): number | null {
|
||||
const num = toNullableNumber(val)
|
||||
return num === null ? null : parseFloat((num / MTOK).toPrecision(10))
|
||||
}
|
||||
|
||||
/** 后端存储值(per-token) → 前端显示值($/MTok) */
|
||||
export function perTokenToMTok(val: number | null | undefined): number | null {
|
||||
if (val === null || val === undefined) return null
|
||||
// toPrecision(10) 消除 IEEE 754 浮点乘法精度误差,如 5e-8 * 1e6 = 0.04999...96 → 0.05
|
||||
return parseFloat((val * MTOK).toPrecision(10))
|
||||
}
|
||||
|
||||
export function apiIntervalsToForm(intervals: PricingInterval[]): IntervalFormEntry[] {
|
||||
return (intervals || []).map(iv => ({
|
||||
min_tokens: iv.min_tokens,
|
||||
max_tokens: iv.max_tokens,
|
||||
tier_label: iv.tier_label || '',
|
||||
input_price: perTokenToMTok(iv.input_price),
|
||||
output_price: perTokenToMTok(iv.output_price),
|
||||
cache_write_price: perTokenToMTok(iv.cache_write_price),
|
||||
cache_read_price: perTokenToMTok(iv.cache_read_price),
|
||||
per_request_price: iv.per_request_price,
|
||||
sort_order: iv.sort_order
|
||||
}))
|
||||
}
|
||||
|
||||
export function formIntervalsToAPI(intervals: IntervalFormEntry[]): PricingInterval[] {
|
||||
return (intervals || []).map(iv => ({
|
||||
min_tokens: iv.min_tokens,
|
||||
max_tokens: iv.max_tokens,
|
||||
tier_label: iv.tier_label,
|
||||
input_price: mTokToPerToken(iv.input_price),
|
||||
output_price: mTokToPerToken(iv.output_price),
|
||||
cache_write_price: mTokToPerToken(iv.cache_write_price),
|
||||
cache_read_price: mTokToPerToken(iv.cache_read_price),
|
||||
per_request_price: toNullableNumber(iv.per_request_price),
|
||||
sort_order: iv.sort_order
|
||||
}))
|
||||
}
|
||||
|
||||
// ── 模型模式冲突检测 ──────────────────────────────────────
|
||||
|
||||
interface ModelPattern {
|
||||
pattern: string
|
||||
prefix: string // lowercase, 通配符去掉尾部 *
|
||||
wildcard: boolean
|
||||
}
|
||||
|
||||
function toModelPattern(model: string): ModelPattern {
|
||||
const lower = model.toLowerCase()
|
||||
const wildcard = lower.endsWith('*')
|
||||
return {
|
||||
pattern: model,
|
||||
prefix: wildcard ? lower.slice(0, -1) : lower,
|
||||
wildcard,
|
||||
}
|
||||
}
|
||||
|
||||
function patternsConflict(a: ModelPattern, b: ModelPattern): boolean {
|
||||
if (!a.wildcard && !b.wildcard) return a.prefix === b.prefix
|
||||
if (a.wildcard && !b.wildcard) return b.prefix.startsWith(a.prefix)
|
||||
if (!a.wildcard && b.wildcard) return a.prefix.startsWith(b.prefix)
|
||||
// 双通配符:任一前缀是另一前缀的前缀即冲突
|
||||
return a.prefix.startsWith(b.prefix) || b.prefix.startsWith(a.prefix)
|
||||
}
|
||||
|
||||
/** 检测模型模式列表中的冲突,返回冲突的两个模式名;无冲突返回 null */
|
||||
export function findModelConflict(models: string[]): [string, string] | null {
|
||||
const patterns = models.map(toModelPattern)
|
||||
for (let i = 0; i < patterns.length; i++) {
|
||||
for (let j = i + 1; j < patterns.length; j++) {
|
||||
if (patternsConflict(patterns[i], patterns[j])) {
|
||||
return [patterns[i].pattern, patterns[j].pattern]
|
||||
}
|
||||
}
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
// ── 区间校验 ──────────────────────────────────────────────
|
||||
|
||||
/** 校验区间列表的合法性,返回错误消息;通过则返回 null */
|
||||
export function validateIntervals(intervals: IntervalFormEntry[]): string | null {
|
||||
if (!intervals || intervals.length === 0) return null
|
||||
|
||||
// 按 min_tokens 排序(不修改原数组)
|
||||
const sorted = [...intervals].sort((a, b) => a.min_tokens - b.min_tokens)
|
||||
|
||||
for (let i = 0; i < sorted.length; i++) {
|
||||
const err = validateSingleInterval(sorted[i], i)
|
||||
if (err) return err
|
||||
}
|
||||
return checkIntervalOverlap(sorted)
|
||||
}
|
||||
|
||||
function validateSingleInterval(iv: IntervalFormEntry, idx: number): string | null {
|
||||
if (iv.min_tokens < 0) {
|
||||
return `区间 #${idx + 1}: 最小 token 数 (${iv.min_tokens}) 不能为负数`
|
||||
}
|
||||
if (iv.max_tokens != null) {
|
||||
if (iv.max_tokens <= 0) {
|
||||
return `区间 #${idx + 1}: 最大 token 数 (${iv.max_tokens}) 必须大于 0`
|
||||
}
|
||||
if (iv.max_tokens <= iv.min_tokens) {
|
||||
return `区间 #${idx + 1}: 最大 token 数 (${iv.max_tokens}) 必须大于最小 token 数 (${iv.min_tokens})`
|
||||
}
|
||||
}
|
||||
return validateIntervalPrices(iv, idx)
|
||||
}
|
||||
|
||||
function validateIntervalPrices(iv: IntervalFormEntry, idx: number): string | null {
|
||||
const prices: [string, number | string | null][] = [
|
||||
['输入价格', iv.input_price],
|
||||
['输出价格', iv.output_price],
|
||||
['缓存写入价格', iv.cache_write_price],
|
||||
['缓存读取价格', iv.cache_read_price],
|
||||
['单次价格', iv.per_request_price],
|
||||
]
|
||||
for (const [name, val] of prices) {
|
||||
if (val != null && val !== '' && Number(val) < 0) {
|
||||
return `区间 #${idx + 1}: ${name}不能为负数`
|
||||
}
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
function checkIntervalOverlap(sorted: IntervalFormEntry[]): string | null {
|
||||
for (let i = 0; i < sorted.length; i++) {
|
||||
// 无上限区间必须是最后一个
|
||||
if (sorted[i].max_tokens == null && i < sorted.length - 1) {
|
||||
return `区间 #${i + 1}: 无上限区间(最大 token 数为空)只能是最后一个`
|
||||
}
|
||||
if (i === 0) continue
|
||||
const prev = sorted[i - 1]
|
||||
// (min, max] 语义:前一个区间上界 > 当前区间下界则重叠
|
||||
if (prev.max_tokens == null || prev.max_tokens > sorted[i].min_tokens) {
|
||||
const prevMax = prev.max_tokens == null ? '∞' : String(prev.max_tokens)
|
||||
return `区间 #${i} 和 #${i + 1} 重叠:前一个区间上界 (${prevMax}) 大于当前区间下界 (${sorted[i].min_tokens})`
|
||||
}
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
/** 平台对应的模型 tag 样式(背景+文字) */
|
||||
export function getPlatformTagClass(platform: string): string {
|
||||
switch (platform) {
|
||||
case 'anthropic': return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400'
|
||||
case 'openai': return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-400'
|
||||
case 'gemini': return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-400'
|
||||
case 'antigravity': return 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400'
|
||||
case 'sora': return 'bg-rose-100 text-rose-700 dark:bg-rose-900/30 dark:text-rose-400'
|
||||
default: return 'bg-gray-100 text-gray-700 dark:bg-gray-900/30 dark:text-gray-400'
|
||||
}
|
||||
}
|
||||
@@ -133,6 +133,12 @@
|
||||
<Select v-model="filters.billing_type" :options="billingTypeOptions" @change="emitChange" />
|
||||
</div>
|
||||
|
||||
<!-- Billing Mode Filter -->
|
||||
<div class="w-full sm:w-auto sm:min-w-[200px]">
|
||||
<label class="input-label">{{ t('admin.usage.billingMode') }}</label>
|
||||
<Select v-model="filters.billing_mode" :options="billingModeOptions" @change="emitChange" />
|
||||
</div>
|
||||
|
||||
<!-- Group Filter -->
|
||||
<div class="w-full sm:w-auto sm:min-w-[200px]">
|
||||
<label class="input-label">{{ t('admin.usage.group') }}</label>
|
||||
@@ -232,6 +238,13 @@ const billingTypeOptions = ref<SelectOption[]>([
|
||||
{ value: 1, label: t('admin.usage.billingTypeSubscription') }
|
||||
])
|
||||
|
||||
const billingModeOptions = ref<SelectOption[]>([
|
||||
{ value: null, label: t('admin.usage.allBillingModes') },
|
||||
{ value: 'token', label: t('admin.usage.billingModeToken') },
|
||||
{ value: 'per_request', label: t('admin.usage.billingModePerRequest') },
|
||||
{ value: 'image', label: t('admin.usage.billingModeImage') }
|
||||
])
|
||||
|
||||
const emitChange = () => emit('change')
|
||||
|
||||
const debounceUserSearch = () => {
|
||||
|
||||
@@ -26,7 +26,15 @@
|
||||
</template>
|
||||
|
||||
<template #cell-model="{ row }">
|
||||
<div v-if="row.upstream_model && row.upstream_model !== row.model" class="space-y-0.5 text-xs">
|
||||
<div v-if="row.model_mapping_chain && row.model_mapping_chain.includes('→')" class="space-y-0.5 text-xs">
|
||||
<div v-for="(step, i) in row.model_mapping_chain.split('→')" :key="i"
|
||||
class="break-all"
|
||||
:class="i === 0 ? 'font-medium text-gray-900 dark:text-white' : 'text-gray-500 dark:text-gray-400'"
|
||||
:style="i > 0 ? `padding-left: ${i * 0.75}rem` : ''">
|
||||
<span v-if="i > 0" class="mr-0.5">↳</span>{{ step }}
|
||||
</div>
|
||||
</div>
|
||||
<div v-else-if="row.upstream_model && row.upstream_model !== row.model" class="space-y-0.5 text-xs">
|
||||
<div class="break-all font-medium text-gray-900 dark:text-white">
|
||||
{{ row.model }}
|
||||
</div>
|
||||
@@ -69,9 +77,15 @@
|
||||
</span>
|
||||
</template>
|
||||
|
||||
<template #cell-billing_mode="{ row }">
|
||||
<span class="inline-flex items-center rounded px-2 py-0.5 text-xs font-medium" :class="getBillingModeBadgeClass(row.billing_mode)">
|
||||
{{ getBillingModeLabel(row.billing_mode) }}
|
||||
</span>
|
||||
</template>
|
||||
|
||||
<template #cell-tokens="{ row }">
|
||||
<!-- 图片生成请求 -->
|
||||
<div v-if="row.image_count > 0" class="flex items-center gap-1.5">
|
||||
<!-- 图片生成请求(仅按次计费时显示图片格式) -->
|
||||
<div v-if="row.image_count > 0 && row.billing_mode === 'image'" class="flex items-center gap-1.5">
|
||||
<svg class="h-4 w-4 text-indigo-500" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 16l4.586-4.586a2 2 0 012.828 0L16 16m-2-2l1.586-1.586a2 2 0 012.828 0L20 14m-6-6h.01M6 20h12a2 2 0 002-2V6a2 2 0 00-2-2H6a2 2 0 00-2 2v12a2 2 0 002 2z" />
|
||||
</svg>
|
||||
@@ -281,11 +295,11 @@
|
||||
</div>
|
||||
<div class="flex items-center justify-between gap-6">
|
||||
<span class="text-gray-400">{{ t('usage.rate') }}</span>
|
||||
<span class="font-semibold text-blue-400">{{ (tooltipData?.rate_multiplier || 1).toFixed(2) }}x</span>
|
||||
<span class="font-semibold text-blue-400">{{ formatMultiplier(tooltipData?.rate_multiplier || 1) }}x</span>
|
||||
</div>
|
||||
<div class="flex items-center justify-between gap-6">
|
||||
<span class="text-gray-400">{{ t('usage.accountMultiplier') }}</span>
|
||||
<span class="font-semibold text-blue-400">{{ (tooltipData?.account_rate_multiplier ?? 1).toFixed(2) }}x</span>
|
||||
<span class="font-semibold text-blue-400">{{ formatMultiplier(tooltipData?.account_rate_multiplier ?? 1) }}x</span>
|
||||
</div>
|
||||
<div class="flex items-center justify-between gap-6">
|
||||
<span class="text-gray-400">{{ t('usage.original') }}</span>
|
||||
@@ -312,6 +326,7 @@
|
||||
import { ref } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { formatDateTime, formatReasoningEffort } from '@/utils/format'
|
||||
import { formatCacheTokens, formatMultiplier } from '@/utils/formatters'
|
||||
import { formatTokenPricePerMillion } from '@/utils/usagePricing'
|
||||
import { getUsageServiceTierLabel } from '@/utils/usageServiceTier'
|
||||
import { resolveUsageRequestType } from '@/utils/usageRequestType'
|
||||
@@ -350,12 +365,19 @@ const getRequestTypeBadgeClass = (row: AdminUsageLog): string => {
|
||||
return 'bg-amber-100 text-amber-800 dark:bg-amber-900 dark:text-amber-200'
|
||||
}
|
||||
|
||||
const formatCacheTokens = (tokens: number): string => {
|
||||
if (tokens >= 1000000) return `${(tokens / 1000000).toFixed(1)}M`
|
||||
if (tokens >= 1000) return `${(tokens / 1000).toFixed(1)}K`
|
||||
return tokens.toString()
|
||||
const getBillingModeLabel = (mode: string | null | undefined): string => {
|
||||
if (mode === 'per_request') return t('admin.usage.billingModePerRequest')
|
||||
if (mode === 'image') return t('admin.usage.billingModeImage')
|
||||
return t('admin.usage.billingModeToken')
|
||||
}
|
||||
|
||||
const getBillingModeBadgeClass = (mode: string | null | undefined): string => {
|
||||
if (mode === 'per_request') return 'bg-blue-100 text-blue-800 dark:bg-blue-900 dark:text-blue-200'
|
||||
if (mode === 'image') return 'bg-green-100 text-green-800 dark:bg-green-900 dark:text-green-200'
|
||||
return 'bg-gray-100 text-gray-800 dark:bg-gray-700 dark:text-gray-200'
|
||||
}
|
||||
|
||||
|
||||
const formatUserAgent = (ua: string): string => {
|
||||
return ua
|
||||
}
|
||||
|
||||
@@ -161,6 +161,7 @@ const props = withDefaults(
|
||||
showSourceToggle?: boolean
|
||||
startDate?: string
|
||||
endDate?: string
|
||||
filters?: Record<string, any>
|
||||
}>(),
|
||||
{
|
||||
upstreamEndpointStats: () => [],
|
||||
@@ -193,6 +194,7 @@ const toggleBreakdown = async (endpoint: string) => {
|
||||
breakdownItems.value = []
|
||||
try {
|
||||
const res = await getUserBreakdown({
|
||||
...props.filters,
|
||||
start_date: props.startDate,
|
||||
end_date: props.endDate,
|
||||
endpoint,
|
||||
|
||||
@@ -125,6 +125,7 @@ const props = withDefaults(defineProps<{
|
||||
showMetricToggle?: boolean
|
||||
startDate?: string
|
||||
endDate?: string
|
||||
filters?: Record<string, any>
|
||||
}>(), {
|
||||
loading: false,
|
||||
metric: 'tokens',
|
||||
@@ -150,6 +151,7 @@ const toggleBreakdown = async (type: string, id: number | string) => {
|
||||
breakdownItems.value = []
|
||||
try {
|
||||
const res = await getUserBreakdown({
|
||||
...props.filters,
|
||||
start_date: props.startDate,
|
||||
end_date: props.endDate,
|
||||
group_id: Number(id),
|
||||
|
||||
@@ -270,6 +270,7 @@ const props = withDefaults(defineProps<{
|
||||
rankingError?: boolean
|
||||
startDate?: string
|
||||
endDate?: string
|
||||
filters?: Record<string, any>
|
||||
}>(), {
|
||||
upstreamModelStats: () => [],
|
||||
mappingModelStats: () => [],
|
||||
@@ -302,6 +303,7 @@ const toggleBreakdown = async (type: string, id: string) => {
|
||||
breakdownItems.value = []
|
||||
try {
|
||||
const res = await getUserBreakdown({
|
||||
...props.filters,
|
||||
start_date: props.startDate,
|
||||
end_date: props.endDate,
|
||||
model: id,
|
||||
|
||||
@@ -287,6 +287,21 @@ const FolderIcon = {
|
||||
)
|
||||
}
|
||||
|
||||
const ChannelIcon = {
|
||||
render: () =>
|
||||
h(
|
||||
'svg',
|
||||
{ fill: 'none', viewBox: '0 0 24 24', stroke: 'currentColor', 'stroke-width': '1.5' },
|
||||
[
|
||||
h('path', {
|
||||
'stroke-linecap': 'round',
|
||||
'stroke-linejoin': 'round',
|
||||
d: 'M6.429 9.75L2.25 12l4.179 2.25m0-4.5l5.571 3 5.571-3m-11.142 0L2.25 7.5 12 2.25l9.75 5.25-4.179 2.25m0 0l4.179 2.25L12 17.25 2.25 12m15.321-2.25l4.179 2.25L12 17.25l-9.75-5.25'
|
||||
})
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
const CreditCardIcon = {
|
||||
render: () =>
|
||||
h(
|
||||
@@ -568,6 +583,7 @@ const adminNavItems = computed((): NavItem[] => {
|
||||
: []),
|
||||
{ path: '/admin/users', label: t('nav.users'), icon: UsersIcon, hideInSimpleMode: true },
|
||||
{ path: '/admin/groups', label: t('nav.groups'), icon: FolderIcon, hideInSimpleMode: true },
|
||||
{ path: '/admin/channels', label: t('nav.channels', '渠道管理'), icon: ChannelIcon, hideInSimpleMode: true },
|
||||
{ path: '/admin/subscriptions', label: t('nav.subscriptions'), icon: CreditCardIcon, hideInSimpleMode: true },
|
||||
{ path: '/admin/accounts', label: t('nav.accounts'), icon: GlobeIcon },
|
||||
{ path: '/admin/announcements', label: t('nav.announcements'), icon: BellIcon },
|
||||
|
||||
@@ -335,6 +335,7 @@ export default {
|
||||
profile: 'Profile',
|
||||
users: 'Users',
|
||||
groups: 'Groups',
|
||||
channels: 'Channels',
|
||||
subscriptions: 'Subscriptions',
|
||||
accounts: 'Accounts',
|
||||
proxies: 'Proxies',
|
||||
@@ -1719,6 +1720,107 @@ export default {
|
||||
}
|
||||
},
|
||||
|
||||
// Channel Management
|
||||
channels: {
|
||||
title: 'Channel Management',
|
||||
description: 'Manage channels and custom model pricing',
|
||||
searchChannels: 'Search channels...',
|
||||
createChannel: 'Create Channel',
|
||||
editChannel: 'Edit Channel',
|
||||
deleteChannel: 'Delete Channel',
|
||||
statusActive: 'Active',
|
||||
statusDisabled: 'Disabled',
|
||||
allStatus: 'All Status',
|
||||
groupsUnit: 'groups',
|
||||
pricingUnit: 'pricing rules',
|
||||
noChannelsYet: 'No Channels Yet',
|
||||
createFirstChannel: 'Create your first channel to manage model pricing',
|
||||
loadError: 'Failed to load channels',
|
||||
createSuccess: 'Channel created',
|
||||
updateSuccess: 'Channel updated',
|
||||
deleteSuccess: 'Channel deleted',
|
||||
createError: 'Failed to create channel',
|
||||
updateError: 'Failed to update channel',
|
||||
deleteError: 'Failed to delete channel',
|
||||
nameRequired: 'Please enter a channel name',
|
||||
duplicateModels: 'Model "{0}" appears in multiple pricing entries',
|
||||
modelConflict: "Model patterns '{model1}' and '{model2}' conflict: overlapping match range",
|
||||
mappingConflict: "Mapping source patterns '{model1}' and '{model2}' conflict: overlapping match range",
|
||||
deleteConfirm: 'Are you sure you want to delete channel "{name}"? This cannot be undone.',
|
||||
columns: {
|
||||
name: 'Name',
|
||||
description: 'Description',
|
||||
status: 'Status',
|
||||
groups: 'Groups',
|
||||
pricing: 'Pricing',
|
||||
createdAt: 'Created',
|
||||
actions: 'Actions'
|
||||
},
|
||||
billingMode: {
|
||||
token: 'Token',
|
||||
perRequest: 'Per Request',
|
||||
image: 'Image (Per Request)'
|
||||
},
|
||||
form: {
|
||||
name: 'Name',
|
||||
namePlaceholder: 'Enter channel name',
|
||||
description: 'Description',
|
||||
descriptionPlaceholder: 'Optional description',
|
||||
status: 'Status',
|
||||
groups: 'Associated Groups',
|
||||
noGroupsAvailable: 'No groups available',
|
||||
inOtherChannel: 'In "{name}"',
|
||||
modelPricing: 'Model Pricing',
|
||||
models: 'Models',
|
||||
modelsPlaceholder: 'Type full model name and press Enter',
|
||||
modelInputHint: 'Press Enter to add, supports paste for batch import.',
|
||||
billingMode: 'Billing Mode',
|
||||
defaultPrices: 'Default prices (fallback when no interval matches)',
|
||||
inputPrice: 'Input',
|
||||
outputPrice: 'Output',
|
||||
cacheWritePrice: 'Cache Write',
|
||||
cacheReadPrice: 'Cache Read',
|
||||
imageTokenPrice: 'Image Output',
|
||||
imageOutputPrice: 'Image Output Price',
|
||||
pricePlaceholder: 'Default',
|
||||
intervals: 'Context Intervals (optional)',
|
||||
addInterval: 'Add Interval',
|
||||
requestTiers: 'Request Tiers',
|
||||
imageTiers: 'Image Tiers (Per Request)',
|
||||
addTier: 'Add Tier',
|
||||
noTiersYet: 'No tiers yet. Click add to configure per-request pricing.',
|
||||
noPricingRules: 'No pricing rules yet. Click "Add" to create one.',
|
||||
perRequestPrice: 'Price per Request',
|
||||
perRequestPriceRequired: 'Per-request price or billing tiers required for per-request/image billing mode',
|
||||
tierLabel: 'Tier',
|
||||
resolution: 'Resolution',
|
||||
modelMapping: 'Model Mapping',
|
||||
modelMappingHint: 'Map request model names to actual model names. Runs before account-level mapping.',
|
||||
noMappingRules: 'No mapping rules. Click "Add" to create one.',
|
||||
mappingSource: 'Source model',
|
||||
mappingTarget: 'Target model',
|
||||
billingModelSource: 'Billing Model',
|
||||
billingModelSourceChannelMapped: 'Bill by channel-mapped model',
|
||||
billingModelSourceRequested: 'Bill by requested model',
|
||||
billingModelSourceUpstream: 'Bill by final upstream model',
|
||||
billingModelSourceHint: 'Controls which model name is used for pricing lookup',
|
||||
selectedCount: '{count} selected',
|
||||
searchGroups: 'Search groups...',
|
||||
noGroupsMatch: 'No groups match your search',
|
||||
restrictModels: 'Restrict Models',
|
||||
restrictModelsHint: 'When enabled, only models in the pricing list are allowed. Others will be rejected.',
|
||||
defaultPerRequestPrice: 'Default per-request price (fallback when no tier matches)',
|
||||
defaultImagePrice: 'Default image price (fallback when no tier matches)',
|
||||
platformConfig: 'Platform Configuration',
|
||||
basicSettings: 'Basic Settings',
|
||||
addPlatform: 'Add Platform',
|
||||
noPlatforms: 'Click "Add Platform" to start configuring the channel',
|
||||
mappingCount: 'mappings',
|
||||
pricingEntry: 'Pricing Entry',
|
||||
noModels: 'No models added'
|
||||
}
|
||||
},
|
||||
|
||||
// Subscriptions
|
||||
subscriptions: {
|
||||
title: 'Subscription Management',
|
||||
@@ -3258,6 +3360,11 @@ export default {
|
||||
allBillingTypes: 'All Billing Types',
|
||||
billingTypeBalance: 'Balance',
|
||||
billingTypeSubscription: 'Subscription',
|
||||
billingMode: 'Billing Mode',
|
||||
billingModeToken: 'Token',
|
||||
billingModePerRequest: 'Per Request',
|
||||
billingModeImage: 'Image',
|
||||
allBillingModes: 'All Billing Modes',
|
||||
ipAddress: 'IP',
|
||||
clickToViewBalance: 'Click to view balance history',
|
||||
failedToLoadUser: 'Failed to load user info',
|
||||
|
||||
@@ -335,6 +335,7 @@ export default {
|
||||
profile: '个人资料',
|
||||
users: '用户管理',
|
||||
groups: '分组管理',
|
||||
channels: '渠道管理',
|
||||
subscriptions: '订阅管理',
|
||||
accounts: '账号管理',
|
||||
proxies: 'IP管理',
|
||||
@@ -1799,6 +1800,107 @@ export default {
|
||||
}
|
||||
},
|
||||
|
||||
// Channel Management
|
||||
channels: {
|
||||
title: '渠道管理',
|
||||
description: '管理渠道和自定义模型定价',
|
||||
searchChannels: '搜索渠道...',
|
||||
createChannel: '创建渠道',
|
||||
editChannel: '编辑渠道',
|
||||
deleteChannel: '删除渠道',
|
||||
statusActive: '启用',
|
||||
statusDisabled: '停用',
|
||||
allStatus: '全部状态',
|
||||
groupsUnit: '个分组',
|
||||
pricingUnit: '条定价',
|
||||
noChannelsYet: '暂无渠道',
|
||||
createFirstChannel: '创建第一个渠道来管理模型定价',
|
||||
loadError: '加载渠道列表失败',
|
||||
createSuccess: '渠道创建成功',
|
||||
updateSuccess: '渠道更新成功',
|
||||
deleteSuccess: '渠道删除成功',
|
||||
createError: '创建渠道失败',
|
||||
updateError: '更新渠道失败',
|
||||
deleteError: '删除渠道失败',
|
||||
nameRequired: '请输入渠道名称',
|
||||
duplicateModels: '模型「{0}」在多个定价条目中重复',
|
||||
modelConflict: "模型模式 '{model1}' 和 '{model2}' 冲突:匹配范围重叠",
|
||||
mappingConflict: "模型映射源 '{model1}' 和 '{model2}' 冲突:匹配范围重叠",
|
||||
deleteConfirm: '确定要删除渠道「{name}」吗?此操作不可撤销。',
|
||||
columns: {
|
||||
name: '名称',
|
||||
description: '描述',
|
||||
status: '状态',
|
||||
groups: '分组',
|
||||
pricing: '定价',
|
||||
createdAt: '创建时间',
|
||||
actions: '操作'
|
||||
},
|
||||
billingMode: {
|
||||
token: 'Token',
|
||||
perRequest: '按次',
|
||||
image: '图片(按次)'
|
||||
},
|
||||
form: {
|
||||
name: '名称',
|
||||
namePlaceholder: '输入渠道名称',
|
||||
description: '描述',
|
||||
descriptionPlaceholder: '可选描述',
|
||||
status: '状态',
|
||||
groups: '关联分组',
|
||||
noGroupsAvailable: '暂无可用分组',
|
||||
inOtherChannel: '已属于「{name}」',
|
||||
modelPricing: '模型定价',
|
||||
models: '模型列表',
|
||||
modelsPlaceholder: '输入完整模型名后按回车添加',
|
||||
modelInputHint: '按回车添加,支持粘贴批量导入',
|
||||
billingMode: '计费模式',
|
||||
defaultPrices: '默认价格(未命中区间时使用)',
|
||||
inputPrice: '输入',
|
||||
outputPrice: '输出',
|
||||
cacheWritePrice: '缓存写入',
|
||||
cacheReadPrice: '缓存读取',
|
||||
imageTokenPrice: '图片输出',
|
||||
imageOutputPrice: '图片输出价格',
|
||||
pricePlaceholder: '默认',
|
||||
intervals: '上下文区间定价(可选)',
|
||||
addInterval: '添加区间',
|
||||
requestTiers: '按次计费层级',
|
||||
imageTiers: '图片计费层级(按次)',
|
||||
addTier: '添加层级',
|
||||
noTiersYet: '暂无层级,点击添加配置按次计费价格',
|
||||
noPricingRules: '暂无定价规则,点击"添加"创建',
|
||||
perRequestPrice: '单次价格',
|
||||
perRequestPriceRequired: '按次/图片计费模式必须设置默认价格或至少一个计费层级',
|
||||
tierLabel: '层级',
|
||||
resolution: '分辨率',
|
||||
modelMapping: '模型映射',
|
||||
modelMappingHint: '将请求中的模型名映射为实际模型名。在账号级别映射之前执行。',
|
||||
noMappingRules: '暂无映射规则,点击"添加"创建',
|
||||
mappingSource: '源模型',
|
||||
mappingTarget: '目标模型',
|
||||
billingModelSource: '计费基准',
|
||||
billingModelSourceChannelMapped: '以渠道映射后的模型计费',
|
||||
billingModelSourceRequested: '以请求模型计费',
|
||||
billingModelSourceUpstream: '以最终模型计费',
|
||||
billingModelSourceHint: '控制使用哪个模型名称进行定价查找',
|
||||
selectedCount: '已选 {count} 个',
|
||||
searchGroups: '搜索分组...',
|
||||
noGroupsMatch: '没有匹配的分组',
|
||||
restrictModels: '限制模型',
|
||||
restrictModelsHint: '开启后,仅允许模型定价列表中的模型。不在列表中的模型请求将被拒绝。',
|
||||
defaultPerRequestPrice: '默认单次价格(未命中层级时使用)',
|
||||
defaultImagePrice: '默认图片价格(未命中层级时使用)',
|
||||
platformConfig: '平台配置',
|
||||
basicSettings: '基础设置',
|
||||
addPlatform: '添加平台',
|
||||
noPlatforms: '点击"添加平台"开始配置渠道',
|
||||
mappingCount: '条映射',
|
||||
pricingEntry: '定价配置',
|
||||
noModels: '未添加模型'
|
||||
}
|
||||
},
|
||||
|
||||
// Subscriptions Management
|
||||
subscriptions: {
|
||||
title: '订阅管理',
|
||||
@@ -3417,6 +3519,11 @@ export default {
|
||||
allBillingTypes: '全部计费类型',
|
||||
billingTypeBalance: '钱包余额',
|
||||
billingTypeSubscription: '订阅套餐',
|
||||
billingMode: '计费模式',
|
||||
billingModeToken: '按量',
|
||||
billingModePerRequest: '按次',
|
||||
billingModeImage: '按次(图片)',
|
||||
allBillingModes: '全部计费模式',
|
||||
ipAddress: 'IP',
|
||||
clickToViewBalance: '点击查看充值记录',
|
||||
failedToLoadUser: '加载用户信息失败',
|
||||
|
||||
@@ -278,6 +278,18 @@ const routes: RouteRecordRaw[] = [
|
||||
descriptionKey: 'admin.groups.description'
|
||||
}
|
||||
},
|
||||
{
|
||||
path: '/admin/channels',
|
||||
name: 'AdminChannels',
|
||||
component: () => import('@/views/admin/ChannelsView.vue'),
|
||||
meta: {
|
||||
requiresAuth: true,
|
||||
requiresAdmin: true,
|
||||
title: 'Channel Management',
|
||||
titleKey: 'admin.channels.title',
|
||||
descriptionKey: 'admin.channels.description'
|
||||
}
|
||||
},
|
||||
{
|
||||
path: '/admin/subscriptions',
|
||||
name: 'AdminSubscriptions',
|
||||
|
||||
@@ -1036,6 +1036,9 @@ export interface UsageLog {
|
||||
// Cache TTL Override
|
||||
cache_ttl_overridden: boolean
|
||||
|
||||
// 计费模式
|
||||
billing_mode?: string | null
|
||||
|
||||
created_at: string
|
||||
|
||||
user?: User
|
||||
@@ -1051,6 +1054,7 @@ export interface UsageLogAccountSummary {
|
||||
|
||||
export interface AdminUsageLog extends UsageLog {
|
||||
upstream_model?: string | null
|
||||
model_mapping_chain?: string | null
|
||||
|
||||
// 账号计费倍率(仅管理员可见)
|
||||
account_rate_multiplier?: number | null
|
||||
|
||||
18
frontend/src/utils/formatters.ts
Normal file
18
frontend/src/utils/formatters.ts
Normal file
@@ -0,0 +1,18 @@
|
||||
/**
|
||||
* 格式化缓存 token 数量(1K/1M 缩写)
|
||||
*/
|
||||
export function formatCacheTokens(tokens: number): string {
|
||||
if (tokens >= 1000000) return `${(tokens / 1000000).toFixed(1)}M`
|
||||
if (tokens >= 1000) return `${(tokens / 1000).toFixed(1)}K`
|
||||
return tokens.toLocaleString()
|
||||
}
|
||||
|
||||
/**
|
||||
* 自适应精度格式化倍率(确保小数值如 0.001 不被截断)
|
||||
*/
|
||||
export function formatMultiplier(val: number): string {
|
||||
if (val >= 0.01) return val.toFixed(2)
|
||||
if (val >= 0.001) return val.toFixed(3)
|
||||
if (val >= 0.0001) return val.toFixed(4)
|
||||
return val.toPrecision(2)
|
||||
}
|
||||
1065
frontend/src/views/admin/ChannelsView.vue
Normal file
1065
frontend/src/views/admin/ChannelsView.vue
Normal file
File diff suppressed because it is too large
Load Diff
@@ -34,6 +34,7 @@
|
||||
:show-metric-toggle="true"
|
||||
:start-date="startDate"
|
||||
:end-date="endDate"
|
||||
:filters="breakdownFilters"
|
||||
/>
|
||||
<GroupDistributionChart
|
||||
v-model:metric="groupDistributionMetric"
|
||||
@@ -42,6 +43,7 @@
|
||||
:show-metric-toggle="true"
|
||||
:start-date="startDate"
|
||||
:end-date="endDate"
|
||||
:filters="breakdownFilters"
|
||||
/>
|
||||
</div>
|
||||
<div class="grid grid-cols-1 gap-6 lg:grid-cols-2">
|
||||
@@ -57,6 +59,7 @@
|
||||
:title="t('usage.endpointDistribution')"
|
||||
:start-date="startDate"
|
||||
:end-date="endDate"
|
||||
:filters="breakdownFilters"
|
||||
/>
|
||||
<TokenUsageTrend :trend-data="trendData" :loading="chartsLoading" />
|
||||
</div>
|
||||
@@ -169,6 +172,17 @@ const cleanupDialogVisible = ref(false)
|
||||
const showBalanceHistoryModal = ref(false)
|
||||
const balanceHistoryUser = ref<AdminUser | null>(null)
|
||||
|
||||
const breakdownFilters = computed(() => {
|
||||
const f: Record<string, any> = {}
|
||||
if (filters.value.user_id) f.user_id = filters.value.user_id
|
||||
if (filters.value.api_key_id) f.api_key_id = filters.value.api_key_id
|
||||
if (filters.value.account_id) f.account_id = filters.value.account_id
|
||||
if (filters.value.group_id) f.group_id = filters.value.group_id
|
||||
if (filters.value.request_type != null) f.request_type = filters.value.request_type
|
||||
if (filters.value.billing_type != null) f.billing_type = filters.value.billing_type
|
||||
return f
|
||||
})
|
||||
|
||||
const handleUserClick = async (userId: number) => {
|
||||
try {
|
||||
const user = await adminAPI.users.getById(userId)
|
||||
@@ -392,7 +406,7 @@ const resetFilters = () => {
|
||||
const range = getLast24HoursRangeDates()
|
||||
startDate.value = range.start
|
||||
endDate.value = range.end
|
||||
filters.value = { start_date: startDate.value, end_date: endDate.value, request_type: undefined, billing_type: null }
|
||||
filters.value = { start_date: startDate.value, end_date: endDate.value, request_type: undefined, billing_type: null, billing_mode: undefined }
|
||||
granularity.value = getGranularityForRange(startDate.value, endDate.value)
|
||||
applyFilters()
|
||||
}
|
||||
@@ -440,7 +454,7 @@ const exportToExcel = async () => {
|
||||
log.input_tokens, log.output_tokens, log.cache_read_tokens, log.cache_creation_tokens,
|
||||
log.input_cost?.toFixed(6) || '0.000000', log.output_cost?.toFixed(6) || '0.000000',
|
||||
log.cache_read_cost?.toFixed(6) || '0.000000', log.cache_creation_cost?.toFixed(6) || '0.000000',
|
||||
log.rate_multiplier?.toFixed(2) || '1.00', (log.account_rate_multiplier ?? 1).toFixed(2),
|
||||
log.rate_multiplier?.toPrecision(4) || '1.00', (log.account_rate_multiplier ?? 1).toPrecision(4),
|
||||
log.total_cost?.toFixed(6) || '0.000000', log.actual_cost?.toFixed(6) || '0.000000',
|
||||
(log.total_cost * (log.account_rate_multiplier ?? 1)).toFixed(6), log.first_token_ms ?? '', log.duration_ms,
|
||||
log.request_id || '', log.user_agent || '', log.ip_address || ''
|
||||
@@ -477,6 +491,7 @@ const allColumns = computed(() => [
|
||||
{ key: 'endpoint', label: t('usage.endpoint'), sortable: false },
|
||||
{ key: 'group', label: t('admin.usage.group'), sortable: false },
|
||||
{ key: 'stream', label: t('usage.type'), sortable: false },
|
||||
{ key: 'billing_mode', label: t('admin.usage.billingMode'), sortable: false },
|
||||
{ key: 'tokens', label: t('usage.tokens'), sortable: false },
|
||||
{ key: 'cost', label: t('usage.cost'), sortable: false },
|
||||
{ key: 'first_token', label: t('usage.firstToken'), sortable: false },
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user