Merge pull request #1455 from touwaeriol/feat/channel-management

feat(channel): add channel management with multi-mode pricing and billing integration
This commit is contained in:
Wesley Liddick
2026-04-04 23:42:33 +08:00
committed by GitHub
101 changed files with 12458 additions and 550 deletions

View File

@@ -49,6 +49,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
refreshTokenCache := repository.NewRefreshTokenCache(redisClient)
settingRepository := repository.NewSettingRepository(client)
groupRepository := repository.NewGroupRepository(client, db)
channelRepository := repository.NewChannelRepository(db)
settingService := service.ProvideSettingService(settingRepository, groupRepository, configConfig)
emailCache := repository.NewEmailCache(redisClient)
emailService := service.NewEmailService(settingRepository, emailCache)
@@ -138,11 +139,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI, tempUnschedCache)
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client)
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
@@ -175,9 +176,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI)
digestSessionStore := service.NewDigestSessionStore()
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService)
channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator)
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver)
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
@@ -213,7 +216,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler)
channelHandler := admin.NewChannelHandler(channelService, billingService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)

View File

@@ -744,6 +744,10 @@ var (
{Name: "model", Type: field.TypeString, Size: 100},
{Name: "requested_model", Type: field.TypeString, Nullable: true, Size: 100},
{Name: "upstream_model", Type: field.TypeString, Nullable: true, Size: 100},
{Name: "channel_id", Type: field.TypeInt64, Nullable: true},
{Name: "model_mapping_chain", Type: field.TypeString, Nullable: true, Size: 500},
{Name: "billing_tier", Type: field.TypeString, Nullable: true, Size: 50},
{Name: "billing_mode", Type: field.TypeString, Nullable: true, Size: 20},
{Name: "input_tokens", Type: field.TypeInt, Default: 0},
{Name: "output_tokens", Type: field.TypeInt, Default: 0},
{Name: "cache_creation_tokens", Type: field.TypeInt, Default: 0},
@@ -783,31 +787,31 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "usage_logs_api_keys_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[30]},
Columns: []*schema.Column{UsageLogsColumns[34]},
RefColumns: []*schema.Column{APIKeysColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_accounts_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[31]},
Columns: []*schema.Column{UsageLogsColumns[35]},
RefColumns: []*schema.Column{AccountsColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_groups_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[32]},
Columns: []*schema.Column{UsageLogsColumns[36]},
RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull,
},
{
Symbol: "usage_logs_users_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[33]},
Columns: []*schema.Column{UsageLogsColumns[37]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_user_subscriptions_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[34]},
Columns: []*schema.Column{UsageLogsColumns[38]},
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
OnDelete: schema.SetNull,
},
@@ -816,32 +820,32 @@ var (
{
Name: "usagelog_user_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[33]},
Columns: []*schema.Column{UsageLogsColumns[37]},
},
{
Name: "usagelog_api_key_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[30]},
Columns: []*schema.Column{UsageLogsColumns[34]},
},
{
Name: "usagelog_account_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[31]},
Columns: []*schema.Column{UsageLogsColumns[35]},
},
{
Name: "usagelog_group_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[32]},
Columns: []*schema.Column{UsageLogsColumns[36]},
},
{
Name: "usagelog_subscription_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[34]},
Columns: []*schema.Column{UsageLogsColumns[38]},
},
{
Name: "usagelog_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[29]},
Columns: []*schema.Column{UsageLogsColumns[33]},
},
{
Name: "usagelog_model",
@@ -861,17 +865,17 @@ var (
{
Name: "usagelog_user_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[33], UsageLogsColumns[29]},
Columns: []*schema.Column{UsageLogsColumns[37], UsageLogsColumns[33]},
},
{
Name: "usagelog_api_key_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[29]},
Columns: []*schema.Column{UsageLogsColumns[34], UsageLogsColumns[33]},
},
{
Name: "usagelog_group_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[32], UsageLogsColumns[29]},
Columns: []*schema.Column{UsageLogsColumns[36], UsageLogsColumns[33]},
},
},
}

View File

@@ -19725,6 +19725,11 @@ type UsageLogMutation struct {
model *string
requested_model *string
upstream_model *string
channel_id *int64
addchannel_id *int64
model_mapping_chain *string
billing_tier *string
billing_mode *string
input_tokens *int
addinput_tokens *int
output_tokens *int
@@ -20160,6 +20165,223 @@ func (m *UsageLogMutation) ResetUpstreamModel() {
delete(m.clearedFields, usagelog.FieldUpstreamModel)
}
// SetChannelID sets the "channel_id" field.
func (m *UsageLogMutation) SetChannelID(i int64) {
m.channel_id = &i
m.addchannel_id = nil
}
// ChannelID returns the value of the "channel_id" field in the mutation.
func (m *UsageLogMutation) ChannelID() (r int64, exists bool) {
v := m.channel_id
if v == nil {
return
}
return *v, true
}
// OldChannelID returns the old "channel_id" field's value of the UsageLog entity.
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UsageLogMutation) OldChannelID(ctx context.Context) (v *int64, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldChannelID is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldChannelID requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldChannelID: %w", err)
}
return oldValue.ChannelID, nil
}
// AddChannelID adds i to the "channel_id" field.
func (m *UsageLogMutation) AddChannelID(i int64) {
if m.addchannel_id != nil {
*m.addchannel_id += i
} else {
m.addchannel_id = &i
}
}
// AddedChannelID returns the value that was added to the "channel_id" field in this mutation.
func (m *UsageLogMutation) AddedChannelID() (r int64, exists bool) {
v := m.addchannel_id
if v == nil {
return
}
return *v, true
}
// ClearChannelID clears the value of the "channel_id" field.
func (m *UsageLogMutation) ClearChannelID() {
m.channel_id = nil
m.addchannel_id = nil
m.clearedFields[usagelog.FieldChannelID] = struct{}{}
}
// ChannelIDCleared returns if the "channel_id" field was cleared in this mutation.
func (m *UsageLogMutation) ChannelIDCleared() bool {
_, ok := m.clearedFields[usagelog.FieldChannelID]
return ok
}
// ResetChannelID resets all changes to the "channel_id" field.
func (m *UsageLogMutation) ResetChannelID() {
m.channel_id = nil
m.addchannel_id = nil
delete(m.clearedFields, usagelog.FieldChannelID)
}
// SetModelMappingChain sets the "model_mapping_chain" field.
func (m *UsageLogMutation) SetModelMappingChain(s string) {
m.model_mapping_chain = &s
}
// ModelMappingChain returns the value of the "model_mapping_chain" field in the mutation.
func (m *UsageLogMutation) ModelMappingChain() (r string, exists bool) {
v := m.model_mapping_chain
if v == nil {
return
}
return *v, true
}
// OldModelMappingChain returns the old "model_mapping_chain" field's value of the UsageLog entity.
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UsageLogMutation) OldModelMappingChain(ctx context.Context) (v *string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldModelMappingChain is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldModelMappingChain requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldModelMappingChain: %w", err)
}
return oldValue.ModelMappingChain, nil
}
// ClearModelMappingChain clears the value of the "model_mapping_chain" field.
func (m *UsageLogMutation) ClearModelMappingChain() {
m.model_mapping_chain = nil
m.clearedFields[usagelog.FieldModelMappingChain] = struct{}{}
}
// ModelMappingChainCleared returns if the "model_mapping_chain" field was cleared in this mutation.
func (m *UsageLogMutation) ModelMappingChainCleared() bool {
_, ok := m.clearedFields[usagelog.FieldModelMappingChain]
return ok
}
// ResetModelMappingChain resets all changes to the "model_mapping_chain" field.
func (m *UsageLogMutation) ResetModelMappingChain() {
m.model_mapping_chain = nil
delete(m.clearedFields, usagelog.FieldModelMappingChain)
}
// SetBillingTier sets the "billing_tier" field.
func (m *UsageLogMutation) SetBillingTier(s string) {
m.billing_tier = &s
}
// BillingTier returns the value of the "billing_tier" field in the mutation.
func (m *UsageLogMutation) BillingTier() (r string, exists bool) {
v := m.billing_tier
if v == nil {
return
}
return *v, true
}
// OldBillingTier returns the old "billing_tier" field's value of the UsageLog entity.
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UsageLogMutation) OldBillingTier(ctx context.Context) (v *string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldBillingTier is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldBillingTier requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldBillingTier: %w", err)
}
return oldValue.BillingTier, nil
}
// ClearBillingTier clears the value of the "billing_tier" field.
func (m *UsageLogMutation) ClearBillingTier() {
m.billing_tier = nil
m.clearedFields[usagelog.FieldBillingTier] = struct{}{}
}
// BillingTierCleared returns if the "billing_tier" field was cleared in this mutation.
func (m *UsageLogMutation) BillingTierCleared() bool {
_, ok := m.clearedFields[usagelog.FieldBillingTier]
return ok
}
// ResetBillingTier resets all changes to the "billing_tier" field.
func (m *UsageLogMutation) ResetBillingTier() {
m.billing_tier = nil
delete(m.clearedFields, usagelog.FieldBillingTier)
}
// SetBillingMode sets the "billing_mode" field.
func (m *UsageLogMutation) SetBillingMode(s string) {
m.billing_mode = &s
}
// BillingMode returns the value of the "billing_mode" field in the mutation.
func (m *UsageLogMutation) BillingMode() (r string, exists bool) {
v := m.billing_mode
if v == nil {
return
}
return *v, true
}
// OldBillingMode returns the old "billing_mode" field's value of the UsageLog entity.
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UsageLogMutation) OldBillingMode(ctx context.Context) (v *string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldBillingMode is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldBillingMode requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldBillingMode: %w", err)
}
return oldValue.BillingMode, nil
}
// ClearBillingMode clears the value of the "billing_mode" field.
func (m *UsageLogMutation) ClearBillingMode() {
m.billing_mode = nil
m.clearedFields[usagelog.FieldBillingMode] = struct{}{}
}
// BillingModeCleared returns if the "billing_mode" field was cleared in this mutation.
func (m *UsageLogMutation) BillingModeCleared() bool {
_, ok := m.clearedFields[usagelog.FieldBillingMode]
return ok
}
// ResetBillingMode resets all changes to the "billing_mode" field.
func (m *UsageLogMutation) ResetBillingMode() {
m.billing_mode = nil
delete(m.clearedFields, usagelog.FieldBillingMode)
}
// SetGroupID sets the "group_id" field.
func (m *UsageLogMutation) SetGroupID(i int64) {
m.group = &i
@@ -21781,7 +22003,7 @@ func (m *UsageLogMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UsageLogMutation) Fields() []string {
fields := make([]string, 0, 34)
fields := make([]string, 0, 38)
if m.user != nil {
fields = append(fields, usagelog.FieldUserID)
}
@@ -21803,6 +22025,18 @@ func (m *UsageLogMutation) Fields() []string {
if m.upstream_model != nil {
fields = append(fields, usagelog.FieldUpstreamModel)
}
if m.channel_id != nil {
fields = append(fields, usagelog.FieldChannelID)
}
if m.model_mapping_chain != nil {
fields = append(fields, usagelog.FieldModelMappingChain)
}
if m.billing_tier != nil {
fields = append(fields, usagelog.FieldBillingTier)
}
if m.billing_mode != nil {
fields = append(fields, usagelog.FieldBillingMode)
}
if m.group != nil {
fields = append(fields, usagelog.FieldGroupID)
}
@@ -21906,6 +22140,14 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
return m.RequestedModel()
case usagelog.FieldUpstreamModel:
return m.UpstreamModel()
case usagelog.FieldChannelID:
return m.ChannelID()
case usagelog.FieldModelMappingChain:
return m.ModelMappingChain()
case usagelog.FieldBillingTier:
return m.BillingTier()
case usagelog.FieldBillingMode:
return m.BillingMode()
case usagelog.FieldGroupID:
return m.GroupID()
case usagelog.FieldSubscriptionID:
@@ -21983,6 +22225,14 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
return m.OldRequestedModel(ctx)
case usagelog.FieldUpstreamModel:
return m.OldUpstreamModel(ctx)
case usagelog.FieldChannelID:
return m.OldChannelID(ctx)
case usagelog.FieldModelMappingChain:
return m.OldModelMappingChain(ctx)
case usagelog.FieldBillingTier:
return m.OldBillingTier(ctx)
case usagelog.FieldBillingMode:
return m.OldBillingMode(ctx)
case usagelog.FieldGroupID:
return m.OldGroupID(ctx)
case usagelog.FieldSubscriptionID:
@@ -22095,6 +22345,34 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
}
m.SetUpstreamModel(v)
return nil
case usagelog.FieldChannelID:
v, ok := value.(int64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetChannelID(v)
return nil
case usagelog.FieldModelMappingChain:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetModelMappingChain(v)
return nil
case usagelog.FieldBillingTier:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetBillingTier(v)
return nil
case usagelog.FieldBillingMode:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetBillingMode(v)
return nil
case usagelog.FieldGroupID:
v, ok := value.(int64)
if !ok {
@@ -22292,6 +22570,9 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
// this mutation.
func (m *UsageLogMutation) AddedFields() []string {
var fields []string
if m.addchannel_id != nil {
fields = append(fields, usagelog.FieldChannelID)
}
if m.addinput_tokens != nil {
fields = append(fields, usagelog.FieldInputTokens)
}
@@ -22354,6 +22635,8 @@ func (m *UsageLogMutation) AddedFields() []string {
// was not set, or was not defined in the schema.
func (m *UsageLogMutation) AddedField(name string) (ent.Value, bool) {
switch name {
case usagelog.FieldChannelID:
return m.AddedChannelID()
case usagelog.FieldInputTokens:
return m.AddedInputTokens()
case usagelog.FieldOutputTokens:
@@ -22399,6 +22682,13 @@ func (m *UsageLogMutation) AddedField(name string) (ent.Value, bool) {
// type.
func (m *UsageLogMutation) AddField(name string, value ent.Value) error {
switch name {
case usagelog.FieldChannelID:
v, ok := value.(int64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.AddChannelID(v)
return nil
case usagelog.FieldInputTokens:
v, ok := value.(int)
if !ok {
@@ -22539,6 +22829,18 @@ func (m *UsageLogMutation) ClearedFields() []string {
if m.FieldCleared(usagelog.FieldUpstreamModel) {
fields = append(fields, usagelog.FieldUpstreamModel)
}
if m.FieldCleared(usagelog.FieldChannelID) {
fields = append(fields, usagelog.FieldChannelID)
}
if m.FieldCleared(usagelog.FieldModelMappingChain) {
fields = append(fields, usagelog.FieldModelMappingChain)
}
if m.FieldCleared(usagelog.FieldBillingTier) {
fields = append(fields, usagelog.FieldBillingTier)
}
if m.FieldCleared(usagelog.FieldBillingMode) {
fields = append(fields, usagelog.FieldBillingMode)
}
if m.FieldCleared(usagelog.FieldGroupID) {
fields = append(fields, usagelog.FieldGroupID)
}
@@ -22586,6 +22888,18 @@ func (m *UsageLogMutation) ClearField(name string) error {
case usagelog.FieldUpstreamModel:
m.ClearUpstreamModel()
return nil
case usagelog.FieldChannelID:
m.ClearChannelID()
return nil
case usagelog.FieldModelMappingChain:
m.ClearModelMappingChain()
return nil
case usagelog.FieldBillingTier:
m.ClearBillingTier()
return nil
case usagelog.FieldBillingMode:
m.ClearBillingMode()
return nil
case usagelog.FieldGroupID:
m.ClearGroupID()
return nil
@@ -22642,6 +22956,18 @@ func (m *UsageLogMutation) ResetField(name string) error {
case usagelog.FieldUpstreamModel:
m.ResetUpstreamModel()
return nil
case usagelog.FieldChannelID:
m.ResetChannelID()
return nil
case usagelog.FieldModelMappingChain:
m.ResetModelMappingChain()
return nil
case usagelog.FieldBillingTier:
m.ResetBillingTier()
return nil
case usagelog.FieldBillingMode:
m.ResetBillingMode()
return nil
case usagelog.FieldGroupID:
m.ResetGroupID()
return nil

View File

@@ -875,92 +875,104 @@ func init() {
usagelogDescUpstreamModel := usagelogFields[6].Descriptor()
// usagelog.UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
usagelog.UpstreamModelValidator = usagelogDescUpstreamModel.Validators[0].(func(string) error)
// usagelogDescModelMappingChain is the schema descriptor for model_mapping_chain field.
usagelogDescModelMappingChain := usagelogFields[8].Descriptor()
// usagelog.ModelMappingChainValidator is a validator for the "model_mapping_chain" field. It is called by the builders before save.
usagelog.ModelMappingChainValidator = usagelogDescModelMappingChain.Validators[0].(func(string) error)
// usagelogDescBillingTier is the schema descriptor for billing_tier field.
usagelogDescBillingTier := usagelogFields[9].Descriptor()
// usagelog.BillingTierValidator is a validator for the "billing_tier" field. It is called by the builders before save.
usagelog.BillingTierValidator = usagelogDescBillingTier.Validators[0].(func(string) error)
// usagelogDescBillingMode is the schema descriptor for billing_mode field.
usagelogDescBillingMode := usagelogFields[10].Descriptor()
// usagelog.BillingModeValidator is a validator for the "billing_mode" field. It is called by the builders before save.
usagelog.BillingModeValidator = usagelogDescBillingMode.Validators[0].(func(string) error)
// usagelogDescInputTokens is the schema descriptor for input_tokens field.
usagelogDescInputTokens := usagelogFields[9].Descriptor()
usagelogDescInputTokens := usagelogFields[13].Descriptor()
// usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field.
usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int)
// usagelogDescOutputTokens is the schema descriptor for output_tokens field.
usagelogDescOutputTokens := usagelogFields[10].Descriptor()
usagelogDescOutputTokens := usagelogFields[14].Descriptor()
// usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field.
usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int)
// usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field.
usagelogDescCacheCreationTokens := usagelogFields[11].Descriptor()
usagelogDescCacheCreationTokens := usagelogFields[15].Descriptor()
// usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field.
usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int)
// usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field.
usagelogDescCacheReadTokens := usagelogFields[12].Descriptor()
usagelogDescCacheReadTokens := usagelogFields[16].Descriptor()
// usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field.
usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int)
// usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field.
usagelogDescCacheCreation5mTokens := usagelogFields[13].Descriptor()
usagelogDescCacheCreation5mTokens := usagelogFields[17].Descriptor()
// usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field.
usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int)
// usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field.
usagelogDescCacheCreation1hTokens := usagelogFields[14].Descriptor()
usagelogDescCacheCreation1hTokens := usagelogFields[18].Descriptor()
// usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field.
usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int)
// usagelogDescInputCost is the schema descriptor for input_cost field.
usagelogDescInputCost := usagelogFields[15].Descriptor()
usagelogDescInputCost := usagelogFields[19].Descriptor()
// usagelog.DefaultInputCost holds the default value on creation for the input_cost field.
usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64)
// usagelogDescOutputCost is the schema descriptor for output_cost field.
usagelogDescOutputCost := usagelogFields[16].Descriptor()
usagelogDescOutputCost := usagelogFields[20].Descriptor()
// usagelog.DefaultOutputCost holds the default value on creation for the output_cost field.
usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64)
// usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field.
usagelogDescCacheCreationCost := usagelogFields[17].Descriptor()
usagelogDescCacheCreationCost := usagelogFields[21].Descriptor()
// usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field.
usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64)
// usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field.
usagelogDescCacheReadCost := usagelogFields[18].Descriptor()
usagelogDescCacheReadCost := usagelogFields[22].Descriptor()
// usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field.
usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64)
// usagelogDescTotalCost is the schema descriptor for total_cost field.
usagelogDescTotalCost := usagelogFields[19].Descriptor()
usagelogDescTotalCost := usagelogFields[23].Descriptor()
// usagelog.DefaultTotalCost holds the default value on creation for the total_cost field.
usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64)
// usagelogDescActualCost is the schema descriptor for actual_cost field.
usagelogDescActualCost := usagelogFields[20].Descriptor()
usagelogDescActualCost := usagelogFields[24].Descriptor()
// usagelog.DefaultActualCost holds the default value on creation for the actual_cost field.
usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64)
// usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field.
usagelogDescRateMultiplier := usagelogFields[21].Descriptor()
usagelogDescRateMultiplier := usagelogFields[25].Descriptor()
// usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64)
// usagelogDescBillingType is the schema descriptor for billing_type field.
usagelogDescBillingType := usagelogFields[23].Descriptor()
usagelogDescBillingType := usagelogFields[27].Descriptor()
// usagelog.DefaultBillingType holds the default value on creation for the billing_type field.
usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8)
// usagelogDescStream is the schema descriptor for stream field.
usagelogDescStream := usagelogFields[24].Descriptor()
usagelogDescStream := usagelogFields[28].Descriptor()
// usagelog.DefaultStream holds the default value on creation for the stream field.
usagelog.DefaultStream = usagelogDescStream.Default.(bool)
// usagelogDescUserAgent is the schema descriptor for user_agent field.
usagelogDescUserAgent := usagelogFields[27].Descriptor()
usagelogDescUserAgent := usagelogFields[31].Descriptor()
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
// usagelogDescIPAddress is the schema descriptor for ip_address field.
usagelogDescIPAddress := usagelogFields[28].Descriptor()
usagelogDescIPAddress := usagelogFields[32].Descriptor()
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
// usagelogDescImageCount is the schema descriptor for image_count field.
usagelogDescImageCount := usagelogFields[29].Descriptor()
usagelogDescImageCount := usagelogFields[33].Descriptor()
// usagelog.DefaultImageCount holds the default value on creation for the image_count field.
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
// usagelogDescImageSize is the schema descriptor for image_size field.
usagelogDescImageSize := usagelogFields[30].Descriptor()
usagelogDescImageSize := usagelogFields[34].Descriptor()
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
// usagelogDescMediaType is the schema descriptor for media_type field.
usagelogDescMediaType := usagelogFields[31].Descriptor()
usagelogDescMediaType := usagelogFields[35].Descriptor()
// usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
usagelogDescCacheTTLOverridden := usagelogFields[32].Descriptor()
usagelogDescCacheTTLOverridden := usagelogFields[36].Descriptor()
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
// usagelogDescCreatedAt is the schema descriptor for created_at field.
usagelogDescCreatedAt := usagelogFields[33].Descriptor()
usagelogDescCreatedAt := usagelogFields[37].Descriptor()
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
userMixin := schema.User{}.Mixin()

View File

@@ -53,6 +53,10 @@ func (UsageLog) Fields() []ent.Field {
MaxLen(100).
Optional().
Nillable(),
field.Int64("channel_id").Optional().Nillable().Comment("渠道 ID"),
field.String("model_mapping_chain").MaxLen(500).Optional().Nillable().Comment("模型映射链"),
field.String("billing_tier").MaxLen(50).Optional().Nillable().Comment("计费层级标签"),
field.String("billing_mode").MaxLen(20).Optional().Nillable().Comment("计费模式token/per_request/image"),
field.Int64("group_id").
Optional().
Nillable(),

View File

@@ -36,6 +36,14 @@ type UsageLog struct {
RequestedModel *string `json:"requested_model,omitempty"`
// UpstreamModel holds the value of the "upstream_model" field.
UpstreamModel *string `json:"upstream_model,omitempty"`
// 渠道 ID
ChannelID *int64 `json:"channel_id,omitempty"`
// 模型映射链
ModelMappingChain *string `json:"model_mapping_chain,omitempty"`
// 计费层级标签
BillingTier *string `json:"billing_tier,omitempty"`
// 计费模式token/per_request/image
BillingMode *string `json:"billing_mode,omitempty"`
// GroupID holds the value of the "group_id" field.
GroupID *int64 `json:"group_id,omitempty"`
// SubscriptionID holds the value of the "subscription_id" field.
@@ -177,9 +185,9 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullBool)
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier:
values[i] = new(sql.NullFloat64)
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldChannelID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
values[i] = new(sql.NullInt64)
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldModelMappingChain, usagelog.FieldBillingTier, usagelog.FieldBillingMode, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
values[i] = new(sql.NullString)
case usagelog.FieldCreatedAt:
values[i] = new(sql.NullTime)
@@ -248,6 +256,34 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
_m.UpstreamModel = new(string)
*_m.UpstreamModel = value.String
}
case usagelog.FieldChannelID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field channel_id", values[i])
} else if value.Valid {
_m.ChannelID = new(int64)
*_m.ChannelID = value.Int64
}
case usagelog.FieldModelMappingChain:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field model_mapping_chain", values[i])
} else if value.Valid {
_m.ModelMappingChain = new(string)
*_m.ModelMappingChain = value.String
}
case usagelog.FieldBillingTier:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field billing_tier", values[i])
} else if value.Valid {
_m.BillingTier = new(string)
*_m.BillingTier = value.String
}
case usagelog.FieldBillingMode:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field billing_mode", values[i])
} else if value.Valid {
_m.BillingMode = new(string)
*_m.BillingMode = value.String
}
case usagelog.FieldGroupID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field group_id", values[i])
@@ -505,6 +541,26 @@ func (_m *UsageLog) String() string {
builder.WriteString(*v)
}
builder.WriteString(", ")
if v := _m.ChannelID; v != nil {
builder.WriteString("channel_id=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
if v := _m.ModelMappingChain; v != nil {
builder.WriteString("model_mapping_chain=")
builder.WriteString(*v)
}
builder.WriteString(", ")
if v := _m.BillingTier; v != nil {
builder.WriteString("billing_tier=")
builder.WriteString(*v)
}
builder.WriteString(", ")
if v := _m.BillingMode; v != nil {
builder.WriteString("billing_mode=")
builder.WriteString(*v)
}
builder.WriteString(", ")
if v := _m.GroupID; v != nil {
builder.WriteString("group_id=")
builder.WriteString(fmt.Sprintf("%v", *v))

View File

@@ -28,6 +28,14 @@ const (
FieldRequestedModel = "requested_model"
// FieldUpstreamModel holds the string denoting the upstream_model field in the database.
FieldUpstreamModel = "upstream_model"
// FieldChannelID holds the string denoting the channel_id field in the database.
FieldChannelID = "channel_id"
// FieldModelMappingChain holds the string denoting the model_mapping_chain field in the database.
FieldModelMappingChain = "model_mapping_chain"
// FieldBillingTier holds the string denoting the billing_tier field in the database.
FieldBillingTier = "billing_tier"
// FieldBillingMode holds the string denoting the billing_mode field in the database.
FieldBillingMode = "billing_mode"
// FieldGroupID holds the string denoting the group_id field in the database.
FieldGroupID = "group_id"
// FieldSubscriptionID holds the string denoting the subscription_id field in the database.
@@ -141,6 +149,10 @@ var Columns = []string{
FieldModel,
FieldRequestedModel,
FieldUpstreamModel,
FieldChannelID,
FieldModelMappingChain,
FieldBillingTier,
FieldBillingMode,
FieldGroupID,
FieldSubscriptionID,
FieldInputTokens,
@@ -189,6 +201,12 @@ var (
RequestedModelValidator func(string) error
// UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
UpstreamModelValidator func(string) error
// ModelMappingChainValidator is a validator for the "model_mapping_chain" field. It is called by the builders before save.
ModelMappingChainValidator func(string) error
// BillingTierValidator is a validator for the "billing_tier" field. It is called by the builders before save.
BillingTierValidator func(string) error
// BillingModeValidator is a validator for the "billing_mode" field. It is called by the builders before save.
BillingModeValidator func(string) error
// DefaultInputTokens holds the default value on creation for the "input_tokens" field.
DefaultInputTokens int
// DefaultOutputTokens holds the default value on creation for the "output_tokens" field.
@@ -278,6 +296,26 @@ func ByUpstreamModel(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpstreamModel, opts...).ToFunc()
}
// ByChannelID orders the results by the channel_id field.
func ByChannelID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldChannelID, opts...).ToFunc()
}
// ByModelMappingChain orders the results by the model_mapping_chain field.
func ByModelMappingChain(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldModelMappingChain, opts...).ToFunc()
}
// ByBillingTier orders the results by the billing_tier field.
func ByBillingTier(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBillingTier, opts...).ToFunc()
}
// ByBillingMode orders the results by the billing_mode field.
func ByBillingMode(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBillingMode, opts...).ToFunc()
}
// ByGroupID orders the results by the group_id field.
func ByGroupID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldGroupID, opts...).ToFunc()

View File

@@ -90,6 +90,26 @@ func UpstreamModel(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v))
}
// ChannelID applies equality check predicate on the "channel_id" field. It's identical to ChannelIDEQ.
func ChannelID(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldChannelID, v))
}
// ModelMappingChain applies equality check predicate on the "model_mapping_chain" field. It's identical to ModelMappingChainEQ.
func ModelMappingChain(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldModelMappingChain, v))
}
// BillingTier applies equality check predicate on the "billing_tier" field. It's identical to BillingTierEQ.
func BillingTier(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldBillingTier, v))
}
// BillingMode applies equality check predicate on the "billing_mode" field. It's identical to BillingModeEQ.
func BillingMode(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldBillingMode, v))
}
// GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ.
func GroupID(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))
@@ -565,6 +585,281 @@ func UpstreamModelContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldUpstreamModel, v))
}
// ChannelIDEQ applies the EQ predicate on the "channel_id" field.
func ChannelIDEQ(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldChannelID, v))
}
// ChannelIDNEQ applies the NEQ predicate on the "channel_id" field.
func ChannelIDNEQ(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldChannelID, v))
}
// ChannelIDIn applies the In predicate on the "channel_id" field.
func ChannelIDIn(vs ...int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldChannelID, vs...))
}
// ChannelIDNotIn applies the NotIn predicate on the "channel_id" field.
func ChannelIDNotIn(vs ...int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldChannelID, vs...))
}
// ChannelIDGT applies the GT predicate on the "channel_id" field.
func ChannelIDGT(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldChannelID, v))
}
// ChannelIDGTE applies the GTE predicate on the "channel_id" field.
func ChannelIDGTE(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldChannelID, v))
}
// ChannelIDLT applies the LT predicate on the "channel_id" field.
func ChannelIDLT(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldChannelID, v))
}
// ChannelIDLTE applies the LTE predicate on the "channel_id" field.
func ChannelIDLTE(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldChannelID, v))
}
// ChannelIDIsNil applies the IsNil predicate on the "channel_id" field.
func ChannelIDIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldChannelID))
}
// ChannelIDNotNil applies the NotNil predicate on the "channel_id" field.
func ChannelIDNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldChannelID))
}
// ModelMappingChainEQ applies the EQ predicate on the "model_mapping_chain" field.
func ModelMappingChainEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldModelMappingChain, v))
}
// ModelMappingChainNEQ applies the NEQ predicate on the "model_mapping_chain" field.
func ModelMappingChainNEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldModelMappingChain, v))
}
// ModelMappingChainIn applies the In predicate on the "model_mapping_chain" field.
func ModelMappingChainIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldModelMappingChain, vs...))
}
// ModelMappingChainNotIn applies the NotIn predicate on the "model_mapping_chain" field.
func ModelMappingChainNotIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldModelMappingChain, vs...))
}
// ModelMappingChainGT applies the GT predicate on the "model_mapping_chain" field.
func ModelMappingChainGT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldModelMappingChain, v))
}
// ModelMappingChainGTE applies the GTE predicate on the "model_mapping_chain" field.
func ModelMappingChainGTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldModelMappingChain, v))
}
// ModelMappingChainLT applies the LT predicate on the "model_mapping_chain" field.
func ModelMappingChainLT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldModelMappingChain, v))
}
// ModelMappingChainLTE applies the LTE predicate on the "model_mapping_chain" field.
func ModelMappingChainLTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldModelMappingChain, v))
}
// ModelMappingChainContains applies the Contains predicate on the "model_mapping_chain" field.
func ModelMappingChainContains(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContains(FieldModelMappingChain, v))
}
// ModelMappingChainHasPrefix applies the HasPrefix predicate on the "model_mapping_chain" field.
func ModelMappingChainHasPrefix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasPrefix(FieldModelMappingChain, v))
}
// ModelMappingChainHasSuffix applies the HasSuffix predicate on the "model_mapping_chain" field.
func ModelMappingChainHasSuffix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasSuffix(FieldModelMappingChain, v))
}
// ModelMappingChainIsNil applies the IsNil predicate on the "model_mapping_chain" field.
func ModelMappingChainIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldModelMappingChain))
}
// ModelMappingChainNotNil applies the NotNil predicate on the "model_mapping_chain" field.
func ModelMappingChainNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldModelMappingChain))
}
// ModelMappingChainEqualFold applies the EqualFold predicate on the "model_mapping_chain" field.
func ModelMappingChainEqualFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEqualFold(FieldModelMappingChain, v))
}
// ModelMappingChainContainsFold applies the ContainsFold predicate on the "model_mapping_chain" field.
func ModelMappingChainContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldModelMappingChain, v))
}
// BillingTierEQ applies the EQ predicate on the "billing_tier" field.
func BillingTierEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldBillingTier, v))
}
// BillingTierNEQ applies the NEQ predicate on the "billing_tier" field.
func BillingTierNEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldBillingTier, v))
}
// BillingTierIn applies the In predicate on the "billing_tier" field.
func BillingTierIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldBillingTier, vs...))
}
// BillingTierNotIn applies the NotIn predicate on the "billing_tier" field.
func BillingTierNotIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldBillingTier, vs...))
}
// BillingTierGT applies the GT predicate on the "billing_tier" field.
func BillingTierGT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldBillingTier, v))
}
// BillingTierGTE applies the GTE predicate on the "billing_tier" field.
func BillingTierGTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldBillingTier, v))
}
// BillingTierLT applies the LT predicate on the "billing_tier" field.
func BillingTierLT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldBillingTier, v))
}
// BillingTierLTE applies the LTE predicate on the "billing_tier" field.
func BillingTierLTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldBillingTier, v))
}
// BillingTierContains applies the Contains predicate on the "billing_tier" field.
func BillingTierContains(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContains(FieldBillingTier, v))
}
// BillingTierHasPrefix applies the HasPrefix predicate on the "billing_tier" field.
func BillingTierHasPrefix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasPrefix(FieldBillingTier, v))
}
// BillingTierHasSuffix applies the HasSuffix predicate on the "billing_tier" field.
func BillingTierHasSuffix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasSuffix(FieldBillingTier, v))
}
// BillingTierIsNil applies the IsNil predicate on the "billing_tier" field.
func BillingTierIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldBillingTier))
}
// BillingTierNotNil applies the NotNil predicate on the "billing_tier" field.
func BillingTierNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldBillingTier))
}
// BillingTierEqualFold applies the EqualFold predicate on the "billing_tier" field.
func BillingTierEqualFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEqualFold(FieldBillingTier, v))
}
// BillingTierContainsFold applies the ContainsFold predicate on the "billing_tier" field.
func BillingTierContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldBillingTier, v))
}
// BillingModeEQ applies the EQ predicate on the "billing_mode" field.
func BillingModeEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldBillingMode, v))
}
// BillingModeNEQ applies the NEQ predicate on the "billing_mode" field.
func BillingModeNEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldBillingMode, v))
}
// BillingModeIn applies the In predicate on the "billing_mode" field.
func BillingModeIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldBillingMode, vs...))
}
// BillingModeNotIn applies the NotIn predicate on the "billing_mode" field.
func BillingModeNotIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldBillingMode, vs...))
}
// BillingModeGT applies the GT predicate on the "billing_mode" field.
func BillingModeGT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldBillingMode, v))
}
// BillingModeGTE applies the GTE predicate on the "billing_mode" field.
func BillingModeGTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldBillingMode, v))
}
// BillingModeLT applies the LT predicate on the "billing_mode" field.
func BillingModeLT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldBillingMode, v))
}
// BillingModeLTE applies the LTE predicate on the "billing_mode" field.
func BillingModeLTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldBillingMode, v))
}
// BillingModeContains applies the Contains predicate on the "billing_mode" field.
func BillingModeContains(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContains(FieldBillingMode, v))
}
// BillingModeHasPrefix applies the HasPrefix predicate on the "billing_mode" field.
func BillingModeHasPrefix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasPrefix(FieldBillingMode, v))
}
// BillingModeHasSuffix applies the HasSuffix predicate on the "billing_mode" field.
func BillingModeHasSuffix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasSuffix(FieldBillingMode, v))
}
// BillingModeIsNil applies the IsNil predicate on the "billing_mode" field.
func BillingModeIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldBillingMode))
}
// BillingModeNotNil applies the NotNil predicate on the "billing_mode" field.
func BillingModeNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldBillingMode))
}
// BillingModeEqualFold applies the EqualFold predicate on the "billing_mode" field.
func BillingModeEqualFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEqualFold(FieldBillingMode, v))
}
// BillingModeContainsFold applies the ContainsFold predicate on the "billing_mode" field.
func BillingModeContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldBillingMode, v))
}
// GroupIDEQ applies the EQ predicate on the "group_id" field.
func GroupIDEQ(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))

View File

@@ -85,6 +85,62 @@ func (_c *UsageLogCreate) SetNillableUpstreamModel(v *string) *UsageLogCreate {
return _c
}
// SetChannelID sets the "channel_id" field.
func (_c *UsageLogCreate) SetChannelID(v int64) *UsageLogCreate {
_c.mutation.SetChannelID(v)
return _c
}
// SetNillableChannelID sets the "channel_id" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableChannelID(v *int64) *UsageLogCreate {
if v != nil {
_c.SetChannelID(*v)
}
return _c
}
// SetModelMappingChain sets the "model_mapping_chain" field.
func (_c *UsageLogCreate) SetModelMappingChain(v string) *UsageLogCreate {
_c.mutation.SetModelMappingChain(v)
return _c
}
// SetNillableModelMappingChain sets the "model_mapping_chain" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableModelMappingChain(v *string) *UsageLogCreate {
if v != nil {
_c.SetModelMappingChain(*v)
}
return _c
}
// SetBillingTier sets the "billing_tier" field.
func (_c *UsageLogCreate) SetBillingTier(v string) *UsageLogCreate {
_c.mutation.SetBillingTier(v)
return _c
}
// SetNillableBillingTier sets the "billing_tier" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableBillingTier(v *string) *UsageLogCreate {
if v != nil {
_c.SetBillingTier(*v)
}
return _c
}
// SetBillingMode sets the "billing_mode" field.
func (_c *UsageLogCreate) SetBillingMode(v string) *UsageLogCreate {
_c.mutation.SetBillingMode(v)
return _c
}
// SetNillableBillingMode sets the "billing_mode" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableBillingMode(v *string) *UsageLogCreate {
if v != nil {
_c.SetBillingMode(*v)
}
return _c
}
// SetGroupID sets the "group_id" field.
func (_c *UsageLogCreate) SetGroupID(v int64) *UsageLogCreate {
_c.mutation.SetGroupID(v)
@@ -634,6 +690,21 @@ func (_c *UsageLogCreate) check() error {
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
}
}
if v, ok := _c.mutation.ModelMappingChain(); ok {
if err := usagelog.ModelMappingChainValidator(v); err != nil {
return &ValidationError{Name: "model_mapping_chain", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model_mapping_chain": %w`, err)}
}
}
if v, ok := _c.mutation.BillingTier(); ok {
if err := usagelog.BillingTierValidator(v); err != nil {
return &ValidationError{Name: "billing_tier", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_tier": %w`, err)}
}
}
if v, ok := _c.mutation.BillingMode(); ok {
if err := usagelog.BillingModeValidator(v); err != nil {
return &ValidationError{Name: "billing_mode", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_mode": %w`, err)}
}
}
if _, ok := _c.mutation.InputTokens(); !ok {
return &ValidationError{Name: "input_tokens", err: errors.New(`ent: missing required field "UsageLog.input_tokens"`)}
}
@@ -760,6 +831,22 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
_node.UpstreamModel = &value
}
if value, ok := _c.mutation.ChannelID(); ok {
_spec.SetField(usagelog.FieldChannelID, field.TypeInt64, value)
_node.ChannelID = &value
}
if value, ok := _c.mutation.ModelMappingChain(); ok {
_spec.SetField(usagelog.FieldModelMappingChain, field.TypeString, value)
_node.ModelMappingChain = &value
}
if value, ok := _c.mutation.BillingTier(); ok {
_spec.SetField(usagelog.FieldBillingTier, field.TypeString, value)
_node.BillingTier = &value
}
if value, ok := _c.mutation.BillingMode(); ok {
_spec.SetField(usagelog.FieldBillingMode, field.TypeString, value)
_node.BillingMode = &value
}
if value, ok := _c.mutation.InputTokens(); ok {
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
_node.InputTokens = value
@@ -1093,6 +1180,84 @@ func (u *UsageLogUpsert) ClearUpstreamModel() *UsageLogUpsert {
return u
}
// SetChannelID sets the "channel_id" field.
func (u *UsageLogUpsert) SetChannelID(v int64) *UsageLogUpsert {
u.Set(usagelog.FieldChannelID, v)
return u
}
// UpdateChannelID sets the "channel_id" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateChannelID() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldChannelID)
return u
}
// AddChannelID adds v to the "channel_id" field.
func (u *UsageLogUpsert) AddChannelID(v int64) *UsageLogUpsert {
u.Add(usagelog.FieldChannelID, v)
return u
}
// ClearChannelID clears the value of the "channel_id" field.
func (u *UsageLogUpsert) ClearChannelID() *UsageLogUpsert {
u.SetNull(usagelog.FieldChannelID)
return u
}
// SetModelMappingChain sets the "model_mapping_chain" field.
func (u *UsageLogUpsert) SetModelMappingChain(v string) *UsageLogUpsert {
u.Set(usagelog.FieldModelMappingChain, v)
return u
}
// UpdateModelMappingChain sets the "model_mapping_chain" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateModelMappingChain() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldModelMappingChain)
return u
}
// ClearModelMappingChain clears the value of the "model_mapping_chain" field.
func (u *UsageLogUpsert) ClearModelMappingChain() *UsageLogUpsert {
u.SetNull(usagelog.FieldModelMappingChain)
return u
}
// SetBillingTier sets the "billing_tier" field.
func (u *UsageLogUpsert) SetBillingTier(v string) *UsageLogUpsert {
u.Set(usagelog.FieldBillingTier, v)
return u
}
// UpdateBillingTier sets the "billing_tier" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateBillingTier() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldBillingTier)
return u
}
// ClearBillingTier clears the value of the "billing_tier" field.
func (u *UsageLogUpsert) ClearBillingTier() *UsageLogUpsert {
u.SetNull(usagelog.FieldBillingTier)
return u
}
// SetBillingMode sets the "billing_mode" field.
func (u *UsageLogUpsert) SetBillingMode(v string) *UsageLogUpsert {
u.Set(usagelog.FieldBillingMode, v)
return u
}
// UpdateBillingMode sets the "billing_mode" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateBillingMode() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldBillingMode)
return u
}
// ClearBillingMode clears the value of the "billing_mode" field.
func (u *UsageLogUpsert) ClearBillingMode() *UsageLogUpsert {
u.SetNull(usagelog.FieldBillingMode)
return u
}
// SetGroupID sets the "group_id" field.
func (u *UsageLogUpsert) SetGroupID(v int64) *UsageLogUpsert {
u.Set(usagelog.FieldGroupID, v)
@@ -1724,6 +1889,97 @@ func (u *UsageLogUpsertOne) ClearUpstreamModel() *UsageLogUpsertOne {
})
}
// SetChannelID sets the "channel_id" field.
func (u *UsageLogUpsertOne) SetChannelID(v int64) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetChannelID(v)
})
}
// AddChannelID adds v to the "channel_id" field.
func (u *UsageLogUpsertOne) AddChannelID(v int64) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.AddChannelID(v)
})
}
// UpdateChannelID sets the "channel_id" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateChannelID() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateChannelID()
})
}
// ClearChannelID clears the value of the "channel_id" field.
func (u *UsageLogUpsertOne) ClearChannelID() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearChannelID()
})
}
// SetModelMappingChain sets the "model_mapping_chain" field.
func (u *UsageLogUpsertOne) SetModelMappingChain(v string) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetModelMappingChain(v)
})
}
// UpdateModelMappingChain sets the "model_mapping_chain" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateModelMappingChain() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateModelMappingChain()
})
}
// ClearModelMappingChain clears the value of the "model_mapping_chain" field.
func (u *UsageLogUpsertOne) ClearModelMappingChain() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearModelMappingChain()
})
}
// SetBillingTier sets the "billing_tier" field.
func (u *UsageLogUpsertOne) SetBillingTier(v string) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetBillingTier(v)
})
}
// UpdateBillingTier sets the "billing_tier" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateBillingTier() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateBillingTier()
})
}
// ClearBillingTier clears the value of the "billing_tier" field.
func (u *UsageLogUpsertOne) ClearBillingTier() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearBillingTier()
})
}
// SetBillingMode sets the "billing_mode" field.
func (u *UsageLogUpsertOne) SetBillingMode(v string) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetBillingMode(v)
})
}
// UpdateBillingMode sets the "billing_mode" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateBillingMode() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateBillingMode()
})
}
// ClearBillingMode clears the value of the "billing_mode" field.
func (u *UsageLogUpsertOne) ClearBillingMode() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearBillingMode()
})
}
// SetGroupID sets the "group_id" field.
func (u *UsageLogUpsertOne) SetGroupID(v int64) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
@@ -2600,6 +2856,97 @@ func (u *UsageLogUpsertBulk) ClearUpstreamModel() *UsageLogUpsertBulk {
})
}
// SetChannelID sets the "channel_id" field.
func (u *UsageLogUpsertBulk) SetChannelID(v int64) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetChannelID(v)
})
}
// AddChannelID adds v to the "channel_id" field.
func (u *UsageLogUpsertBulk) AddChannelID(v int64) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.AddChannelID(v)
})
}
// UpdateChannelID sets the "channel_id" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateChannelID() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateChannelID()
})
}
// ClearChannelID clears the value of the "channel_id" field.
func (u *UsageLogUpsertBulk) ClearChannelID() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearChannelID()
})
}
// SetModelMappingChain sets the "model_mapping_chain" field.
func (u *UsageLogUpsertBulk) SetModelMappingChain(v string) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetModelMappingChain(v)
})
}
// UpdateModelMappingChain sets the "model_mapping_chain" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateModelMappingChain() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateModelMappingChain()
})
}
// ClearModelMappingChain clears the value of the "model_mapping_chain" field.
func (u *UsageLogUpsertBulk) ClearModelMappingChain() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearModelMappingChain()
})
}
// SetBillingTier sets the "billing_tier" field.
func (u *UsageLogUpsertBulk) SetBillingTier(v string) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetBillingTier(v)
})
}
// UpdateBillingTier sets the "billing_tier" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateBillingTier() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateBillingTier()
})
}
// ClearBillingTier clears the value of the "billing_tier" field.
func (u *UsageLogUpsertBulk) ClearBillingTier() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearBillingTier()
})
}
// SetBillingMode sets the "billing_mode" field.
func (u *UsageLogUpsertBulk) SetBillingMode(v string) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetBillingMode(v)
})
}
// UpdateBillingMode sets the "billing_mode" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateBillingMode() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateBillingMode()
})
}
// ClearBillingMode clears the value of the "billing_mode" field.
func (u *UsageLogUpsertBulk) ClearBillingMode() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearBillingMode()
})
}
// SetGroupID sets the "group_id" field.
func (u *UsageLogUpsertBulk) SetGroupID(v int64) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {

View File

@@ -142,6 +142,93 @@ func (_u *UsageLogUpdate) ClearUpstreamModel() *UsageLogUpdate {
return _u
}
// SetChannelID sets the "channel_id" field.
func (_u *UsageLogUpdate) SetChannelID(v int64) *UsageLogUpdate {
_u.mutation.ResetChannelID()
_u.mutation.SetChannelID(v)
return _u
}
// SetNillableChannelID sets the "channel_id" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableChannelID(v *int64) *UsageLogUpdate {
if v != nil {
_u.SetChannelID(*v)
}
return _u
}
// AddChannelID adds value to the "channel_id" field.
func (_u *UsageLogUpdate) AddChannelID(v int64) *UsageLogUpdate {
_u.mutation.AddChannelID(v)
return _u
}
// ClearChannelID clears the value of the "channel_id" field.
func (_u *UsageLogUpdate) ClearChannelID() *UsageLogUpdate {
_u.mutation.ClearChannelID()
return _u
}
// SetModelMappingChain sets the "model_mapping_chain" field.
func (_u *UsageLogUpdate) SetModelMappingChain(v string) *UsageLogUpdate {
_u.mutation.SetModelMappingChain(v)
return _u
}
// SetNillableModelMappingChain sets the "model_mapping_chain" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableModelMappingChain(v *string) *UsageLogUpdate {
if v != nil {
_u.SetModelMappingChain(*v)
}
return _u
}
// ClearModelMappingChain clears the value of the "model_mapping_chain" field.
func (_u *UsageLogUpdate) ClearModelMappingChain() *UsageLogUpdate {
_u.mutation.ClearModelMappingChain()
return _u
}
// SetBillingTier sets the "billing_tier" field.
func (_u *UsageLogUpdate) SetBillingTier(v string) *UsageLogUpdate {
_u.mutation.SetBillingTier(v)
return _u
}
// SetNillableBillingTier sets the "billing_tier" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableBillingTier(v *string) *UsageLogUpdate {
if v != nil {
_u.SetBillingTier(*v)
}
return _u
}
// ClearBillingTier clears the value of the "billing_tier" field.
func (_u *UsageLogUpdate) ClearBillingTier() *UsageLogUpdate {
_u.mutation.ClearBillingTier()
return _u
}
// SetBillingMode sets the "billing_mode" field.
func (_u *UsageLogUpdate) SetBillingMode(v string) *UsageLogUpdate {
_u.mutation.SetBillingMode(v)
return _u
}
// SetNillableBillingMode sets the "billing_mode" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableBillingMode(v *string) *UsageLogUpdate {
if v != nil {
_u.SetBillingMode(*v)
}
return _u
}
// ClearBillingMode clears the value of the "billing_mode" field.
func (_u *UsageLogUpdate) ClearBillingMode() *UsageLogUpdate {
_u.mutation.ClearBillingMode()
return _u
}
// SetGroupID sets the "group_id" field.
func (_u *UsageLogUpdate) SetGroupID(v int64) *UsageLogUpdate {
_u.mutation.SetGroupID(v)
@@ -795,6 +882,21 @@ func (_u *UsageLogUpdate) check() error {
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
}
}
if v, ok := _u.mutation.ModelMappingChain(); ok {
if err := usagelog.ModelMappingChainValidator(v); err != nil {
return &ValidationError{Name: "model_mapping_chain", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model_mapping_chain": %w`, err)}
}
}
if v, ok := _u.mutation.BillingTier(); ok {
if err := usagelog.BillingTierValidator(v); err != nil {
return &ValidationError{Name: "billing_tier", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_tier": %w`, err)}
}
}
if v, ok := _u.mutation.BillingMode(); ok {
if err := usagelog.BillingModeValidator(v); err != nil {
return &ValidationError{Name: "billing_mode", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_mode": %w`, err)}
}
}
if v, ok := _u.mutation.UserAgent(); ok {
if err := usagelog.UserAgentValidator(v); err != nil {
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
@@ -857,6 +959,33 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.UpstreamModelCleared() {
_spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString)
}
if value, ok := _u.mutation.ChannelID(); ok {
_spec.SetField(usagelog.FieldChannelID, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedChannelID(); ok {
_spec.AddField(usagelog.FieldChannelID, field.TypeInt64, value)
}
if _u.mutation.ChannelIDCleared() {
_spec.ClearField(usagelog.FieldChannelID, field.TypeInt64)
}
if value, ok := _u.mutation.ModelMappingChain(); ok {
_spec.SetField(usagelog.FieldModelMappingChain, field.TypeString, value)
}
if _u.mutation.ModelMappingChainCleared() {
_spec.ClearField(usagelog.FieldModelMappingChain, field.TypeString)
}
if value, ok := _u.mutation.BillingTier(); ok {
_spec.SetField(usagelog.FieldBillingTier, field.TypeString, value)
}
if _u.mutation.BillingTierCleared() {
_spec.ClearField(usagelog.FieldBillingTier, field.TypeString)
}
if value, ok := _u.mutation.BillingMode(); ok {
_spec.SetField(usagelog.FieldBillingMode, field.TypeString, value)
}
if _u.mutation.BillingModeCleared() {
_spec.ClearField(usagelog.FieldBillingMode, field.TypeString)
}
if value, ok := _u.mutation.InputTokens(); ok {
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
}
@@ -1279,6 +1408,93 @@ func (_u *UsageLogUpdateOne) ClearUpstreamModel() *UsageLogUpdateOne {
return _u
}
// SetChannelID sets the "channel_id" field.
func (_u *UsageLogUpdateOne) SetChannelID(v int64) *UsageLogUpdateOne {
_u.mutation.ResetChannelID()
_u.mutation.SetChannelID(v)
return _u
}
// SetNillableChannelID sets the "channel_id" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableChannelID(v *int64) *UsageLogUpdateOne {
if v != nil {
_u.SetChannelID(*v)
}
return _u
}
// AddChannelID adds value to the "channel_id" field.
func (_u *UsageLogUpdateOne) AddChannelID(v int64) *UsageLogUpdateOne {
_u.mutation.AddChannelID(v)
return _u
}
// ClearChannelID clears the value of the "channel_id" field.
func (_u *UsageLogUpdateOne) ClearChannelID() *UsageLogUpdateOne {
_u.mutation.ClearChannelID()
return _u
}
// SetModelMappingChain sets the "model_mapping_chain" field.
func (_u *UsageLogUpdateOne) SetModelMappingChain(v string) *UsageLogUpdateOne {
_u.mutation.SetModelMappingChain(v)
return _u
}
// SetNillableModelMappingChain sets the "model_mapping_chain" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableModelMappingChain(v *string) *UsageLogUpdateOne {
if v != nil {
_u.SetModelMappingChain(*v)
}
return _u
}
// ClearModelMappingChain clears the value of the "model_mapping_chain" field.
func (_u *UsageLogUpdateOne) ClearModelMappingChain() *UsageLogUpdateOne {
_u.mutation.ClearModelMappingChain()
return _u
}
// SetBillingTier sets the "billing_tier" field.
func (_u *UsageLogUpdateOne) SetBillingTier(v string) *UsageLogUpdateOne {
_u.mutation.SetBillingTier(v)
return _u
}
// SetNillableBillingTier sets the "billing_tier" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableBillingTier(v *string) *UsageLogUpdateOne {
if v != nil {
_u.SetBillingTier(*v)
}
return _u
}
// ClearBillingTier clears the value of the "billing_tier" field.
func (_u *UsageLogUpdateOne) ClearBillingTier() *UsageLogUpdateOne {
_u.mutation.ClearBillingTier()
return _u
}
// SetBillingMode sets the "billing_mode" field.
func (_u *UsageLogUpdateOne) SetBillingMode(v string) *UsageLogUpdateOne {
_u.mutation.SetBillingMode(v)
return _u
}
// SetNillableBillingMode sets the "billing_mode" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableBillingMode(v *string) *UsageLogUpdateOne {
if v != nil {
_u.SetBillingMode(*v)
}
return _u
}
// ClearBillingMode clears the value of the "billing_mode" field.
func (_u *UsageLogUpdateOne) ClearBillingMode() *UsageLogUpdateOne {
_u.mutation.ClearBillingMode()
return _u
}
// SetGroupID sets the "group_id" field.
func (_u *UsageLogUpdateOne) SetGroupID(v int64) *UsageLogUpdateOne {
_u.mutation.SetGroupID(v)
@@ -1945,6 +2161,21 @@ func (_u *UsageLogUpdateOne) check() error {
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
}
}
if v, ok := _u.mutation.ModelMappingChain(); ok {
if err := usagelog.ModelMappingChainValidator(v); err != nil {
return &ValidationError{Name: "model_mapping_chain", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model_mapping_chain": %w`, err)}
}
}
if v, ok := _u.mutation.BillingTier(); ok {
if err := usagelog.BillingTierValidator(v); err != nil {
return &ValidationError{Name: "billing_tier", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_tier": %w`, err)}
}
}
if v, ok := _u.mutation.BillingMode(); ok {
if err := usagelog.BillingModeValidator(v); err != nil {
return &ValidationError{Name: "billing_mode", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_mode": %w`, err)}
}
}
if v, ok := _u.mutation.UserAgent(); ok {
if err := usagelog.UserAgentValidator(v); err != nil {
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
@@ -2024,6 +2255,33 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if _u.mutation.UpstreamModelCleared() {
_spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString)
}
if value, ok := _u.mutation.ChannelID(); ok {
_spec.SetField(usagelog.FieldChannelID, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedChannelID(); ok {
_spec.AddField(usagelog.FieldChannelID, field.TypeInt64, value)
}
if _u.mutation.ChannelIDCleared() {
_spec.ClearField(usagelog.FieldChannelID, field.TypeInt64)
}
if value, ok := _u.mutation.ModelMappingChain(); ok {
_spec.SetField(usagelog.FieldModelMappingChain, field.TypeString, value)
}
if _u.mutation.ModelMappingChainCleared() {
_spec.ClearField(usagelog.FieldModelMappingChain, field.TypeString)
}
if value, ok := _u.mutation.BillingTier(); ok {
_spec.SetField(usagelog.FieldBillingTier, field.TypeString, value)
}
if _u.mutation.BillingTierCleared() {
_spec.ClearField(usagelog.FieldBillingTier, field.TypeString)
}
if value, ok := _u.mutation.BillingMode(); ok {
_spec.SetField(usagelog.FieldBillingMode, field.TypeString, value)
}
if _u.mutation.BillingModeCleared() {
_spec.ClearField(usagelog.FieldBillingMode, field.TypeString)
}
if value, ok := _u.mutation.InputTokens(); ok {
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
}

View File

@@ -0,0 +1,452 @@
package admin
import (
"errors"
"fmt"
"strconv"
"strings"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// ChannelHandler handles admin channel management
type ChannelHandler struct {
channelService *service.ChannelService
billingService *service.BillingService
}
// NewChannelHandler creates a new admin channel handler
func NewChannelHandler(channelService *service.ChannelService, billingService *service.BillingService) *ChannelHandler {
return &ChannelHandler{channelService: channelService, billingService: billingService}
}
// --- Request / Response types ---
type createChannelRequest struct {
Name string `json:"name" binding:"required,max=100"`
Description string `json:"description"`
GroupIDs []int64 `json:"group_ids"`
ModelPricing []channelModelPricingRequest `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
RestrictModels bool `json:"restrict_models"`
}
type updateChannelRequest struct {
Name string `json:"name" binding:"omitempty,max=100"`
Description *string `json:"description"`
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
GroupIDs *[]int64 `json:"group_ids"`
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
RestrictModels *bool `json:"restrict_models"`
}
type channelModelPricingRequest struct {
Platform string `json:"platform" binding:"omitempty,max=50"`
Models []string `json:"models" binding:"required,min=1,max=100"`
BillingMode string `json:"billing_mode" binding:"omitempty,oneof=token per_request image"`
InputPrice *float64 `json:"input_price" binding:"omitempty,min=0"`
OutputPrice *float64 `json:"output_price" binding:"omitempty,min=0"`
CacheWritePrice *float64 `json:"cache_write_price" binding:"omitempty,min=0"`
CacheReadPrice *float64 `json:"cache_read_price" binding:"omitempty,min=0"`
ImageOutputPrice *float64 `json:"image_output_price" binding:"omitempty,min=0"`
PerRequestPrice *float64 `json:"per_request_price" binding:"omitempty,min=0"`
Intervals []pricingIntervalRequest `json:"intervals"`
}
type pricingIntervalRequest struct {
MinTokens int `json:"min_tokens"`
MaxTokens *int `json:"max_tokens"`
TierLabel string `json:"tier_label"`
InputPrice *float64 `json:"input_price"`
OutputPrice *float64 `json:"output_price"`
CacheWritePrice *float64 `json:"cache_write_price"`
CacheReadPrice *float64 `json:"cache_read_price"`
PerRequestPrice *float64 `json:"per_request_price"`
SortOrder int `json:"sort_order"`
}
type channelResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Status string `json:"status"`
BillingModelSource string `json:"billing_model_source"`
RestrictModels bool `json:"restrict_models"`
GroupIDs []int64 `json:"group_ids"`
ModelPricing []channelModelPricingResponse `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
type channelModelPricingResponse struct {
ID int64 `json:"id"`
Platform string `json:"platform"`
Models []string `json:"models"`
BillingMode string `json:"billing_mode"`
InputPrice *float64 `json:"input_price"`
OutputPrice *float64 `json:"output_price"`
CacheWritePrice *float64 `json:"cache_write_price"`
CacheReadPrice *float64 `json:"cache_read_price"`
ImageOutputPrice *float64 `json:"image_output_price"`
PerRequestPrice *float64 `json:"per_request_price"`
Intervals []pricingIntervalResponse `json:"intervals"`
}
type pricingIntervalResponse struct {
ID int64 `json:"id"`
MinTokens int `json:"min_tokens"`
MaxTokens *int `json:"max_tokens"`
TierLabel string `json:"tier_label,omitempty"`
InputPrice *float64 `json:"input_price"`
OutputPrice *float64 `json:"output_price"`
CacheWritePrice *float64 `json:"cache_write_price"`
CacheReadPrice *float64 `json:"cache_read_price"`
PerRequestPrice *float64 `json:"per_request_price"`
SortOrder int `json:"sort_order"`
}
func channelToResponse(ch *service.Channel) *channelResponse {
if ch == nil {
return nil
}
resp := &channelResponse{
ID: ch.ID,
Name: ch.Name,
Description: ch.Description,
Status: ch.Status,
RestrictModels: ch.RestrictModels,
GroupIDs: ch.GroupIDs,
ModelMapping: ch.ModelMapping,
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
UpdatedAt: ch.UpdatedAt.Format("2006-01-02T15:04:05Z"),
}
resp.BillingModelSource = ch.BillingModelSource
if resp.BillingModelSource == "" {
resp.BillingModelSource = service.BillingModelSourceChannelMapped
}
if resp.GroupIDs == nil {
resp.GroupIDs = []int64{}
}
if resp.ModelMapping == nil {
resp.ModelMapping = map[string]map[string]string{}
}
resp.ModelPricing = make([]channelModelPricingResponse, 0, len(ch.ModelPricing))
for _, p := range ch.ModelPricing {
resp.ModelPricing = append(resp.ModelPricing, pricingToResponse(&p))
}
return resp
}
func pricingToResponse(p *service.ChannelModelPricing) channelModelPricingResponse {
models := p.Models
if models == nil {
models = []string{}
}
billingMode := string(p.BillingMode)
if billingMode == "" {
billingMode = string(service.BillingModeToken)
}
platform := p.Platform
if platform == "" {
platform = service.PlatformAnthropic
}
intervals := make([]pricingIntervalResponse, 0, len(p.Intervals))
for _, iv := range p.Intervals {
intervals = append(intervals, intervalToResponse(iv))
}
return channelModelPricingResponse{
ID: p.ID,
Platform: platform,
Models: models,
BillingMode: billingMode,
InputPrice: p.InputPrice,
OutputPrice: p.OutputPrice,
CacheWritePrice: p.CacheWritePrice,
CacheReadPrice: p.CacheReadPrice,
ImageOutputPrice: p.ImageOutputPrice,
PerRequestPrice: p.PerRequestPrice,
Intervals: intervals,
}
}
func intervalToResponse(iv service.PricingInterval) pricingIntervalResponse {
return pricingIntervalResponse{
ID: iv.ID,
MinTokens: iv.MinTokens,
MaxTokens: iv.MaxTokens,
TierLabel: iv.TierLabel,
InputPrice: iv.InputPrice,
OutputPrice: iv.OutputPrice,
CacheWritePrice: iv.CacheWritePrice,
CacheReadPrice: iv.CacheReadPrice,
PerRequestPrice: iv.PerRequestPrice,
SortOrder: iv.SortOrder,
}
}
func pricingRequestToService(reqs []channelModelPricingRequest) []service.ChannelModelPricing {
result := make([]service.ChannelModelPricing, 0, len(reqs))
for _, r := range reqs {
billingMode := service.BillingMode(r.BillingMode)
if billingMode == "" {
billingMode = service.BillingModeToken
}
platform := r.Platform
if platform == "" {
platform = service.PlatformAnthropic
}
intervals := make([]service.PricingInterval, 0, len(r.Intervals))
for _, iv := range r.Intervals {
intervals = append(intervals, service.PricingInterval{
MinTokens: iv.MinTokens,
MaxTokens: iv.MaxTokens,
TierLabel: iv.TierLabel,
InputPrice: iv.InputPrice,
OutputPrice: iv.OutputPrice,
CacheWritePrice: iv.CacheWritePrice,
CacheReadPrice: iv.CacheReadPrice,
PerRequestPrice: iv.PerRequestPrice,
SortOrder: iv.SortOrder,
})
}
result = append(result, service.ChannelModelPricing{
Platform: platform,
Models: r.Models,
BillingMode: billingMode,
InputPrice: r.InputPrice,
OutputPrice: r.OutputPrice,
CacheWritePrice: r.CacheWritePrice,
CacheReadPrice: r.CacheReadPrice,
ImageOutputPrice: r.ImageOutputPrice,
PerRequestPrice: r.PerRequestPrice,
Intervals: intervals,
})
}
return result
}
// validatePricingBillingMode 校验计费配置
func validatePricingBillingMode(pricing []service.ChannelModelPricing) error {
for _, p := range pricing {
// 按次/图片模式必须配置默认价格或区间
if p.BillingMode == service.BillingModePerRequest || p.BillingMode == service.BillingModeImage {
if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
return errors.New("per-request price or intervals required for per_request/image billing mode")
}
}
// 校验价格不能为负
if err := validatePriceNotNegative("input_price", p.InputPrice); err != nil {
return err
}
if err := validatePriceNotNegative("output_price", p.OutputPrice); err != nil {
return err
}
if err := validatePriceNotNegative("cache_write_price", p.CacheWritePrice); err != nil {
return err
}
if err := validatePriceNotNegative("cache_read_price", p.CacheReadPrice); err != nil {
return err
}
if err := validatePriceNotNegative("image_output_price", p.ImageOutputPrice); err != nil {
return err
}
if err := validatePriceNotNegative("per_request_price", p.PerRequestPrice); err != nil {
return err
}
// 校验 interval至少有一个价格字段非空
for _, iv := range p.Intervals {
if iv.InputPrice == nil && iv.OutputPrice == nil &&
iv.CacheWritePrice == nil && iv.CacheReadPrice == nil &&
iv.PerRequestPrice == nil {
return fmt.Errorf("interval [%d, %s] has no price fields set for model %v",
iv.MinTokens, formatMaxTokens(iv.MaxTokens), p.Models)
}
}
}
return nil
}
func validatePriceNotNegative(field string, val *float64) error {
if val != nil && *val < 0 {
return fmt.Errorf("%s must be >= 0", field)
}
return nil
}
func formatMaxTokens(max *int) string {
if max == nil {
return "∞"
}
return fmt.Sprintf("%d", *max)
}
// --- Handlers ---
// List handles listing channels with pagination
// GET /api/v1/admin/channels
func (h *ChannelHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
status := c.Query("status")
search := strings.TrimSpace(c.Query("search"))
if len(search) > 100 {
search = search[:100]
}
channels, pag, err := h.channelService.List(c.Request.Context(), pagination.PaginationParams{Page: page, PageSize: pageSize}, status, search)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]*channelResponse, 0, len(channels))
for i := range channels {
out = append(out, channelToResponse(&channels[i]))
}
response.Paginated(c, out, pag.Total, page, pageSize)
}
// GetByID handles getting a channel by ID
// GET /api/v1/admin/channels/:id
func (h *ChannelHandler) GetByID(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_CHANNEL_ID", "Invalid channel ID"))
return
}
channel, err := h.channelService.GetByID(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, channelToResponse(channel))
}
// Create handles creating a new channel
// POST /api/v1/admin/channels
func (h *ChannelHandler) Create(c *gin.Context) {
var req createChannelRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
return
}
pricing := pricingRequestToService(req.ModelPricing)
if err := validatePricingBillingMode(pricing); err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
return
}
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
Name: req.Name,
Description: req.Description,
GroupIDs: req.GroupIDs,
ModelPricing: pricing,
ModelMapping: req.ModelMapping,
BillingModelSource: req.BillingModelSource,
RestrictModels: req.RestrictModels,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, channelToResponse(channel))
}
// Update handles updating a channel
// PUT /api/v1/admin/channels/:id
func (h *ChannelHandler) Update(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_CHANNEL_ID", "Invalid channel ID"))
return
}
var req updateChannelRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
return
}
input := &service.UpdateChannelInput{
Name: req.Name,
Description: req.Description,
Status: req.Status,
GroupIDs: req.GroupIDs,
ModelMapping: req.ModelMapping,
BillingModelSource: req.BillingModelSource,
RestrictModels: req.RestrictModels,
}
if req.ModelPricing != nil {
pricing := pricingRequestToService(*req.ModelPricing)
if err := validatePricingBillingMode(pricing); err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
return
}
input.ModelPricing = &pricing
}
channel, err := h.channelService.Update(c.Request.Context(), id, input)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, channelToResponse(channel))
}
// Delete handles deleting a channel
// DELETE /api/v1/admin/channels/:id
func (h *ChannelHandler) Delete(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_CHANNEL_ID", "Invalid channel ID"))
return
}
if err := h.channelService.Delete(c.Request.Context(), id); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Channel deleted successfully"})
}
// GetModelDefaultPricing 获取模型的默认定价(用于前端自动填充)
// GET /api/v1/admin/channels/model-pricing?model=claude-sonnet-4
func (h *ChannelHandler) GetModelDefaultPricing(c *gin.Context) {
model := strings.TrimSpace(c.Query("model"))
if model == "" {
response.ErrorFrom(c, infraerrors.BadRequest("MISSING_PARAMETER", "model parameter is required").
WithMetadata(map[string]string{"param": "model"}))
return
}
pricing, err := h.billingService.GetModelPricing(model)
if err != nil {
// 模型不在定价列表中
response.Success(c, gin.H{"found": false})
return
}
response.Success(c, gin.H{
"found": true,
"input_price": pricing.InputPricePerToken,
"output_price": pricing.OutputPricePerToken,
"cache_write_price": pricing.CacheCreationPricePerToken,
"cache_read_price": pricing.CacheReadPricePerToken,
"image_output_price": pricing.ImageOutputPricePerToken,
})
}

View File

@@ -0,0 +1,502 @@
//go:build unit
package admin
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// helpers
// ---------------------------------------------------------------------------
func float64Ptr(v float64) *float64 { return &v }
func intPtr(v int) *int { return &v }
// ---------------------------------------------------------------------------
// 1. channelToResponse
// ---------------------------------------------------------------------------
func TestChannelToResponse_NilInput(t *testing.T) {
require.Nil(t, channelToResponse(nil))
}
func TestChannelToResponse_FullChannel(t *testing.T) {
now := time.Date(2025, 6, 1, 12, 0, 0, 0, time.UTC)
ch := &service.Channel{
ID: 42,
Name: "test-channel",
Description: "desc",
Status: "active",
BillingModelSource: "upstream",
RestrictModels: true,
CreatedAt: now,
UpdatedAt: now.Add(time.Hour),
GroupIDs: []int64{1, 2, 3},
ModelPricing: []service.ChannelModelPricing{
{
ID: 10,
Platform: "openai",
Models: []string{"gpt-4"},
BillingMode: service.BillingModeToken,
InputPrice: float64Ptr(0.01),
OutputPrice: float64Ptr(0.03),
CacheWritePrice: float64Ptr(0.005),
CacheReadPrice: float64Ptr(0.002),
PerRequestPrice: float64Ptr(0.5),
},
},
ModelMapping: map[string]map[string]string{
"anthropic": {"claude-3-haiku": "claude-haiku-3"},
},
}
resp := channelToResponse(ch)
require.NotNil(t, resp)
require.Equal(t, int64(42), resp.ID)
require.Equal(t, "test-channel", resp.Name)
require.Equal(t, "desc", resp.Description)
require.Equal(t, "active", resp.Status)
require.Equal(t, "upstream", resp.BillingModelSource)
require.True(t, resp.RestrictModels)
require.Equal(t, []int64{1, 2, 3}, resp.GroupIDs)
require.Equal(t, "2025-06-01T12:00:00Z", resp.CreatedAt)
require.Equal(t, "2025-06-01T13:00:00Z", resp.UpdatedAt)
// model mapping
require.Len(t, resp.ModelMapping, 1)
require.Equal(t, "claude-haiku-3", resp.ModelMapping["anthropic"]["claude-3-haiku"])
// pricing
require.Len(t, resp.ModelPricing, 1)
p := resp.ModelPricing[0]
require.Equal(t, int64(10), p.ID)
require.Equal(t, "openai", p.Platform)
require.Equal(t, []string{"gpt-4"}, p.Models)
require.Equal(t, "token", p.BillingMode)
require.Equal(t, float64Ptr(0.01), p.InputPrice)
require.Equal(t, float64Ptr(0.03), p.OutputPrice)
require.Equal(t, float64Ptr(0.005), p.CacheWritePrice)
require.Equal(t, float64Ptr(0.002), p.CacheReadPrice)
require.Equal(t, float64Ptr(0.5), p.PerRequestPrice)
require.Empty(t, p.Intervals)
}
func TestChannelToResponse_EmptyDefaults(t *testing.T) {
now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
ch := &service.Channel{
ID: 1,
Name: "ch",
BillingModelSource: "",
CreatedAt: now,
UpdatedAt: now,
GroupIDs: nil,
ModelMapping: nil,
ModelPricing: []service.ChannelModelPricing{
{
Platform: "",
BillingMode: "",
Models: []string{"m1"},
},
},
}
resp := channelToResponse(ch)
require.Equal(t, "channel_mapped", resp.BillingModelSource)
require.NotNil(t, resp.GroupIDs)
require.Empty(t, resp.GroupIDs)
require.NotNil(t, resp.ModelMapping)
require.Empty(t, resp.ModelMapping)
require.Len(t, resp.ModelPricing, 1)
require.Equal(t, "anthropic", resp.ModelPricing[0].Platform)
require.Equal(t, "token", resp.ModelPricing[0].BillingMode)
}
func TestChannelToResponse_NilModels(t *testing.T) {
now := time.Now()
ch := &service.Channel{
ID: 1,
Name: "ch",
CreatedAt: now,
UpdatedAt: now,
ModelPricing: []service.ChannelModelPricing{
{
Models: nil,
},
},
}
resp := channelToResponse(ch)
require.Len(t, resp.ModelPricing, 1)
require.NotNil(t, resp.ModelPricing[0].Models)
require.Empty(t, resp.ModelPricing[0].Models)
}
func TestChannelToResponse_WithIntervals(t *testing.T) {
now := time.Now()
ch := &service.Channel{
ID: 1,
Name: "ch",
CreatedAt: now,
UpdatedAt: now,
ModelPricing: []service.ChannelModelPricing{
{
Models: []string{"m1"},
BillingMode: service.BillingModePerRequest,
Intervals: []service.PricingInterval{
{
ID: 100,
MinTokens: 0,
MaxTokens: intPtr(1000),
TierLabel: "1K",
InputPrice: float64Ptr(0.01),
OutputPrice: float64Ptr(0.02),
CacheWritePrice: float64Ptr(0.003),
CacheReadPrice: float64Ptr(0.001),
PerRequestPrice: float64Ptr(0.1),
SortOrder: 1,
},
{
ID: 101,
MinTokens: 1000,
MaxTokens: nil,
TierLabel: "unlimited",
SortOrder: 2,
},
},
},
},
}
resp := channelToResponse(ch)
require.Len(t, resp.ModelPricing, 1)
intervals := resp.ModelPricing[0].Intervals
require.Len(t, intervals, 2)
iv0 := intervals[0]
require.Equal(t, int64(100), iv0.ID)
require.Equal(t, 0, iv0.MinTokens)
require.Equal(t, intPtr(1000), iv0.MaxTokens)
require.Equal(t, "1K", iv0.TierLabel)
require.Equal(t, float64Ptr(0.01), iv0.InputPrice)
require.Equal(t, float64Ptr(0.02), iv0.OutputPrice)
require.Equal(t, float64Ptr(0.003), iv0.CacheWritePrice)
require.Equal(t, float64Ptr(0.001), iv0.CacheReadPrice)
require.Equal(t, float64Ptr(0.1), iv0.PerRequestPrice)
require.Equal(t, 1, iv0.SortOrder)
iv1 := intervals[1]
require.Equal(t, int64(101), iv1.ID)
require.Equal(t, 1000, iv1.MinTokens)
require.Nil(t, iv1.MaxTokens)
require.Equal(t, "unlimited", iv1.TierLabel)
require.Equal(t, 2, iv1.SortOrder)
}
func TestChannelToResponse_MultipleEntries(t *testing.T) {
now := time.Now()
ch := &service.Channel{
ID: 1,
Name: "multi",
CreatedAt: now,
UpdatedAt: now,
ModelPricing: []service.ChannelModelPricing{
{
ID: 1,
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: service.BillingModeToken,
InputPrice: float64Ptr(0.003),
OutputPrice: float64Ptr(0.015),
},
{
ID: 2,
Platform: "openai",
Models: []string{"gpt-4", "gpt-4o"},
BillingMode: service.BillingModePerRequest,
PerRequestPrice: float64Ptr(1.0),
},
{
ID: 3,
Platform: "gemini",
Models: []string{"gemini-2.5-pro"},
BillingMode: service.BillingModeImage,
ImageOutputPrice: float64Ptr(0.05),
PerRequestPrice: float64Ptr(0.2),
},
},
}
resp := channelToResponse(ch)
require.Len(t, resp.ModelPricing, 3)
require.Equal(t, int64(1), resp.ModelPricing[0].ID)
require.Equal(t, "anthropic", resp.ModelPricing[0].Platform)
require.Equal(t, []string{"claude-sonnet-4"}, resp.ModelPricing[0].Models)
require.Equal(t, "token", resp.ModelPricing[0].BillingMode)
require.Equal(t, int64(2), resp.ModelPricing[1].ID)
require.Equal(t, "openai", resp.ModelPricing[1].Platform)
require.Equal(t, []string{"gpt-4", "gpt-4o"}, resp.ModelPricing[1].Models)
require.Equal(t, "per_request", resp.ModelPricing[1].BillingMode)
require.Equal(t, int64(3), resp.ModelPricing[2].ID)
require.Equal(t, "gemini", resp.ModelPricing[2].Platform)
require.Equal(t, []string{"gemini-2.5-pro"}, resp.ModelPricing[2].Models)
require.Equal(t, "image", resp.ModelPricing[2].BillingMode)
require.Equal(t, float64Ptr(0.05), resp.ModelPricing[2].ImageOutputPrice)
}
// ---------------------------------------------------------------------------
// 2. pricingRequestToService
// ---------------------------------------------------------------------------
func TestPricingRequestToService_Defaults(t *testing.T) {
tests := []struct {
name string
req channelModelPricingRequest
wantField string // which default field to check
wantValue string
}{
{
name: "empty billing mode defaults to token",
req: channelModelPricingRequest{
Models: []string{"m1"},
BillingMode: "",
},
wantField: "BillingMode",
wantValue: string(service.BillingModeToken),
},
{
name: "empty platform defaults to anthropic",
req: channelModelPricingRequest{
Models: []string{"m1"},
Platform: "",
},
wantField: "Platform",
wantValue: "anthropic",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := pricingRequestToService([]channelModelPricingRequest{tt.req})
require.Len(t, result, 1)
switch tt.wantField {
case "BillingMode":
require.Equal(t, service.BillingMode(tt.wantValue), result[0].BillingMode)
case "Platform":
require.Equal(t, tt.wantValue, result[0].Platform)
}
})
}
}
func TestPricingRequestToService_WithAllFields(t *testing.T) {
reqs := []channelModelPricingRequest{
{
Platform: "openai",
Models: []string{"gpt-4", "gpt-4o"},
BillingMode: "per_request",
InputPrice: float64Ptr(0.01),
OutputPrice: float64Ptr(0.03),
CacheWritePrice: float64Ptr(0.005),
CacheReadPrice: float64Ptr(0.002),
ImageOutputPrice: float64Ptr(0.04),
PerRequestPrice: float64Ptr(0.5),
},
}
result := pricingRequestToService(reqs)
require.Len(t, result, 1)
r := result[0]
require.Equal(t, "openai", r.Platform)
require.Equal(t, []string{"gpt-4", "gpt-4o"}, r.Models)
require.Equal(t, service.BillingModePerRequest, r.BillingMode)
require.Equal(t, float64Ptr(0.01), r.InputPrice)
require.Equal(t, float64Ptr(0.03), r.OutputPrice)
require.Equal(t, float64Ptr(0.005), r.CacheWritePrice)
require.Equal(t, float64Ptr(0.002), r.CacheReadPrice)
require.Equal(t, float64Ptr(0.04), r.ImageOutputPrice)
require.Equal(t, float64Ptr(0.5), r.PerRequestPrice)
}
func TestPricingRequestToService_WithIntervals(t *testing.T) {
reqs := []channelModelPricingRequest{
{
Models: []string{"m1"},
BillingMode: "per_request",
Intervals: []pricingIntervalRequest{
{
MinTokens: 0,
MaxTokens: intPtr(2000),
TierLabel: "small",
InputPrice: float64Ptr(0.01),
OutputPrice: float64Ptr(0.02),
CacheWritePrice: float64Ptr(0.003),
CacheReadPrice: float64Ptr(0.001),
PerRequestPrice: float64Ptr(0.1),
SortOrder: 1,
},
{
MinTokens: 2000,
MaxTokens: nil,
TierLabel: "large",
SortOrder: 2,
},
},
},
}
result := pricingRequestToService(reqs)
require.Len(t, result, 1)
require.Len(t, result[0].Intervals, 2)
iv0 := result[0].Intervals[0]
require.Equal(t, 0, iv0.MinTokens)
require.Equal(t, intPtr(2000), iv0.MaxTokens)
require.Equal(t, "small", iv0.TierLabel)
require.Equal(t, float64Ptr(0.01), iv0.InputPrice)
require.Equal(t, float64Ptr(0.02), iv0.OutputPrice)
require.Equal(t, float64Ptr(0.003), iv0.CacheWritePrice)
require.Equal(t, float64Ptr(0.001), iv0.CacheReadPrice)
require.Equal(t, float64Ptr(0.1), iv0.PerRequestPrice)
require.Equal(t, 1, iv0.SortOrder)
iv1 := result[0].Intervals[1]
require.Equal(t, 2000, iv1.MinTokens)
require.Nil(t, iv1.MaxTokens)
require.Equal(t, "large", iv1.TierLabel)
require.Equal(t, 2, iv1.SortOrder)
}
func TestPricingRequestToService_EmptySlice(t *testing.T) {
result := pricingRequestToService([]channelModelPricingRequest{})
require.NotNil(t, result)
require.Empty(t, result)
}
func TestPricingRequestToService_NilPriceFields(t *testing.T) {
reqs := []channelModelPricingRequest{
{
Models: []string{"m1"},
BillingMode: "token",
// all price fields are nil by default
},
}
result := pricingRequestToService(reqs)
require.Len(t, result, 1)
r := result[0]
require.Nil(t, r.InputPrice)
require.Nil(t, r.OutputPrice)
require.Nil(t, r.CacheWritePrice)
require.Nil(t, r.CacheReadPrice)
require.Nil(t, r.ImageOutputPrice)
require.Nil(t, r.PerRequestPrice)
}
// ---------------------------------------------------------------------------
// 3. validatePricingBillingMode
// ---------------------------------------------------------------------------
func TestValidatePricingBillingMode(t *testing.T) {
tests := []struct {
name string
pricing []service.ChannelModelPricing
wantErr bool
}{
{
name: "token mode - valid",
pricing: []service.ChannelModelPricing{
{BillingMode: service.BillingModeToken},
},
wantErr: false,
},
{
name: "per_request with price - valid",
pricing: []service.ChannelModelPricing{
{
BillingMode: service.BillingModePerRequest,
PerRequestPrice: float64Ptr(0.5),
},
},
wantErr: false,
},
{
name: "per_request with intervals - valid",
pricing: []service.ChannelModelPricing{
{
BillingMode: service.BillingModePerRequest,
Intervals: []service.PricingInterval{
{MinTokens: 0, MaxTokens: intPtr(1000), PerRequestPrice: float64Ptr(0.1)},
},
},
},
wantErr: false,
},
{
name: "per_request no price no intervals - invalid",
pricing: []service.ChannelModelPricing{
{BillingMode: service.BillingModePerRequest},
},
wantErr: true,
},
{
name: "image with price - valid",
pricing: []service.ChannelModelPricing{
{
BillingMode: service.BillingModeImage,
PerRequestPrice: float64Ptr(0.2),
},
},
wantErr: false,
},
{
name: "image no price no intervals - invalid",
pricing: []service.ChannelModelPricing{
{BillingMode: service.BillingModeImage},
},
wantErr: true,
},
{
name: "empty list - valid",
pricing: []service.ChannelModelPricing{},
wantErr: false,
},
{
name: "mixed modes with invalid image - invalid",
pricing: []service.ChannelModelPricing{
{
BillingMode: service.BillingModeToken,
InputPrice: float64Ptr(0.01),
},
{
BillingMode: service.BillingModePerRequest,
PerRequestPrice: float64Ptr(0.5),
},
{
BillingMode: service.BillingModeImage,
},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validatePricingBillingMode(tt.pricing)
if tt.wantErr {
require.Error(t, err)
require.Contains(t, err.Error(), "per-request price or intervals required")
} else {
require.NoError(t, err)
}
})
}
}

View File

@@ -636,6 +636,40 @@ func (h *DashboardHandler) GetUserBreakdown(c *gin.Context) {
dim.Endpoint = c.Query("endpoint")
dim.EndpointType = c.DefaultQuery("endpoint_type", "inbound")
// Additional filter conditions
if v := c.Query("user_id"); v != "" {
if id, err := strconv.ParseInt(v, 10, 64); err == nil {
dim.UserID = id
}
}
if v := c.Query("api_key_id"); v != "" {
if id, err := strconv.ParseInt(v, 10, 64); err == nil {
dim.APIKeyID = id
}
}
if v := c.Query("account_id"); v != "" {
if id, err := strconv.ParseInt(v, 10, 64); err == nil {
dim.AccountID = id
}
}
if v := c.Query("request_type"); v != "" {
if rt, err := strconv.ParseInt(v, 10, 16); err == nil {
rtVal := int16(rt)
dim.RequestType = &rtVal
}
}
if v := c.Query("stream"); v != "" {
if s, err := strconv.ParseBool(v); err == nil {
dim.Stream = &s
}
}
if v := c.Query("billing_type"); v != "" {
if bt, err := strconv.ParseInt(v, 10, 8); err == nil {
btVal := int8(bt)
dim.BillingType = &btVal
}
}
limit := 50
if v := c.Query("limit"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= 200 {

View File

@@ -110,6 +110,7 @@ func (h *UsageHandler) List(c *gin.Context) {
}
model := c.Query("model")
billingMode := strings.TrimSpace(c.Query("billing_mode"))
var requestType *int16
var stream *bool
@@ -174,6 +175,7 @@ func (h *UsageHandler) List(c *gin.Context) {
RequestType: requestType,
Stream: stream,
BillingType: billingType,
BillingMode: billingMode,
StartTime: startTime,
EndTime: endTime,
ExactTotal: exactTotal,
@@ -234,6 +236,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
}
model := c.Query("model")
billingMode := strings.TrimSpace(c.Query("billing_mode"))
var requestType *int16
var stream *bool
@@ -312,6 +315,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
RequestType: requestType,
Stream: stream,
BillingType: billingType,
BillingMode: billingMode,
StartTime: &startTime,
EndTime: &endTime,
}

View File

@@ -577,6 +577,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
MediaType: l.MediaType,
UserAgent: l.UserAgent,
CacheTTLOverridden: l.CacheTTLOverridden,
BillingMode: l.BillingMode,
CreatedAt: l.CreatedAt,
User: UserFromServiceShallow(l.User),
APIKey: APIKeyFromService(l.APIKey),
@@ -604,6 +605,9 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog {
return &AdminUsageLog{
UsageLog: usageLogFromServiceUser(l),
UpstreamModel: l.UpstreamModel,
ChannelID: l.ChannelID,
ModelMappingChain: l.ModelMappingChain,
BillingTier: l.BillingTier,
AccountRateMultiplier: l.AccountRateMultiplier,
IPAddress: l.IPAddress,
Account: AccountSummaryFromService(l.Account),

View File

@@ -390,6 +390,9 @@ type UsageLog struct {
// Cache TTL Override 标记
CacheTTLOverridden bool `json:"cache_ttl_overridden"`
// BillingMode 计费模式token/image
BillingMode *string `json:"billing_mode,omitempty"`
CreatedAt time.Time `json:"created_at"`
User *User `json:"user,omitempty"`
@@ -406,6 +409,13 @@ type AdminUsageLog struct {
// Omitted when no mapping was applied (requested model was used as-is).
UpstreamModel *string `json:"upstream_model,omitempty"`
// ChannelID 渠道 ID
ChannelID *int64 `json:"channel_id,omitempty"`
// ModelMappingChain 模型映射链,如 "a→b→c"
ModelMappingChain *string `json:"model_mapping_chain,omitempty"`
// BillingTier 计费层级标签per_request/image 模式)
BillingTier *string `json:"billing_tier,omitempty"`
// AccountRateMultiplier 账号计费倍率快照nil 表示按 1.0 处理)
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`

View File

@@ -158,6 +158,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
reqStream := parsedReq.Stream
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
// 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) {
@@ -292,7 +295,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "") // Gemini 不使用会话限制
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
@@ -478,6 +481,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
RequestPayloadHash: requestPayloadHash,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
logger.L().With(
zap.String("component", "handler.gateway.messages"),
@@ -514,7 +518,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for {
// 选择支持该模型的账号
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, int64(0))
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
@@ -660,6 +664,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
parsedReq.OnUpstreamAccepted = queueRelease
// ===== 用户消息串行队列 END =====
// 应用渠道模型映射到请求
if channelMapping.Mapped {
parsedReq.Model = channelMapping.MappedModel
parsedReq.Body = h.gatewayService.ReplaceModelInBody(parsedReq.Body, channelMapping.MappedModel)
body = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
requestCtx := c.Request.Context()
@@ -810,6 +821,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
RequestPayloadHash: requestPayloadHash,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
logger.L().With(
zap.String("component", "handler.gateway.messages"),

View File

@@ -80,6 +80,9 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
// Claude Code only restriction
if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly {
h.chatCompletionsErrorResponse(c, http.StatusForbidden, "permission_error",
@@ -154,7 +157,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
fs := NewFailoverState(h.maxAccountSwitches, false)
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "")
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
@@ -203,7 +206,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
// 5. Forward request
writerSizeBeforeForward := c.Writer.Size()
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, parsedReq)
forwardBody := body
if channelMapping.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq)
if accountReleaseFunc != nil {
accountReleaseFunc()
@@ -255,6 +262,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
reqLog.Error("gateway.cc.record_usage_failed",
zap.Int64("account_id", account.ID),

View File

@@ -80,6 +80,9 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
// Claude Code only restriction:
// /v1/responses is never a Claude Code endpoint.
// When claude_code_only is enabled, this endpoint is rejected.
@@ -159,7 +162,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
fs := NewFailoverState(h.maxAccountSwitches, false)
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "")
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
@@ -208,7 +211,11 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
// 5. Forward request
writerSizeBeforeForward := c.Writer.Size()
result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, body, parsedReq)
forwardBody := body
if channelMapping.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, forwardBody, parsedReq)
if accountReleaseFunc != nil {
accountReleaseFunc()
@@ -261,6 +268,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
reqLog.Error("gateway.responses.record_usage_failed",
zap.Int64("account_id", account.ID),

View File

@@ -161,6 +161,8 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
nil, // digestStore
nil, // settingService
nil, // tlsFPProfileService
nil, // channelService
nil, // resolver
)
// RunModeSimple跳过计费检查避免引入 repo/cache 依赖。

View File

@@ -184,6 +184,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
setOpsRequestContext(c, modelName, stream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
// 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName)
reqModel := modelName // 保存映射前的原始模型名
if channelMapping.Mapped {
modelName = channelMapping.MappedModel
}
// Get subscription (may be nil)
subscription, _ := middleware.GetSubscriptionFromContext(c)
@@ -353,7 +360,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "") // Gemini 不使用会话限制
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
@@ -523,6 +530,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
LongContextMultiplier: 2.0, // 超出部分双倍计费
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
logger.L().With(
zap.String("component", "handler.gemini_v1beta.models"),

View File

@@ -30,6 +30,7 @@ type AdminHandlers struct {
TLSFingerprintProfile *admin.TLSFingerprintProfileHandler
APIKey *admin.AdminAPIKeyHandler
ScheduledTest *admin.ScheduledTestHandler
Channel *admin.ChannelHandler
}
// Handlers contains all HTTP handlers

View File

@@ -79,6 +79,9 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
}
@@ -183,7 +186,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
forwardStart := time.Now()
defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_chat_completions_fallback_model"))
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
forwardBody := body
if channelMapping.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel)
forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil {
@@ -257,16 +264,17 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: GetInboundEndpoint(c),
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: GetInboundEndpoint(c),
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.chat_completions"),

View File

@@ -185,6 +185,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
return
@@ -284,7 +287,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Forward request
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now()
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
// 应用渠道模型映射到请求体
forwardBody := body
if channelMapping.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody)
forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil {
accountReleaseFunc()
@@ -379,6 +387,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.responses"),
@@ -549,6 +558,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射
channelMappingMsg, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
@@ -673,7 +685,12 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
// Forward 层需要始终拿到 group 默认映射模型,这样未命中账号级映射的
// Claude 兼容模型才不会在后续 Codex 规范化中意外退化到 gpt-5.1。
defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_messages_fallback_model"))
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
// 应用渠道模型映射到请求体
forwardBody := body
if channelMappingMsg.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMappingMsg.MappedModel)
}
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel)
forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil {
@@ -759,6 +776,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMappingMsg.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.messages"),
@@ -1101,6 +1119,9 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
setOpsRequestContext(c, reqModel, true, firstMessage)
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
// 解析渠道级模型映射
channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel)
var currentUserRelease func()
var currentAccountRelease func()
releaseTurnSlots := func() {
@@ -1259,6 +1280,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
IPAddress: clientIP,
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMappingWS.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
reqLog.Error("openai.websocket_record_usage_failed",
zap.Int64("account_id", account.ID),
@@ -1270,7 +1292,13 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
},
}
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, firstMessage, hooks); err != nil {
// 应用渠道模型映射到 WebSocket 首条消息
wsFirstMessage := firstMessage
if channelMappingWS.Mapped {
wsFirstMessage = h.gatewayService.ReplaceModelInBody(firstMessage, channelMappingWS.MappedModel)
}
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, wsFirstMessage, hooks); err != nil {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
reqLog.Warn("openai.websocket_proxy_failed",

View File

@@ -2225,6 +2225,7 @@ func newMinimalGatewayService(accountRepo service.AccountRepository) *service.Ga
return service.NewGatewayService(
accountRepo, nil, nil, nil, nil, nil, nil, nil, nil,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
nil, nil,
)
}

View File

@@ -30,6 +30,8 @@ import (
)
// SoraGatewayHandler handles Sora chat completions requests
//
// NOTE: Sora 平台计划后续移除不集成渠道Channel功能。
type SoraGatewayHandler struct {
gatewayService *service.GatewayService
soraGatewayService *service.SoraGatewayService
@@ -226,7 +228,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
var lastFailoverHeaders http.Header
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "")
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "", int64(0))
if err != nil {
reqLog.Warn("sora.account_select_failed",
zap.Error(err),

View File

@@ -465,6 +465,8 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
nil, // digestStore
nil, // settingService
nil, // tlsFPProfileService
nil, // channelService
nil, // resolver
)
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}

View File

@@ -33,6 +33,7 @@ func ProvideAdminHandlers(
tlsFingerprintProfileHandler *admin.TLSFingerprintProfileHandler,
apiKeyHandler *admin.AdminAPIKeyHandler,
scheduledTestHandler *admin.ScheduledTestHandler,
channelHandler *admin.ChannelHandler,
) *AdminHandlers {
return &AdminHandlers{
Dashboard: dashboardHandler,
@@ -59,6 +60,7 @@ func ProvideAdminHandlers(
TLSFingerprintProfile: tlsFingerprintProfileHandler,
APIKey: apiKeyHandler,
ScheduledTest: scheduledTestHandler,
Channel: channelHandler,
}
}
@@ -150,6 +152,7 @@ var ProviderSet = wire.NewSet(
admin.NewTLSFingerprintProfileHandler,
admin.NewAdminAPIKeyHandler,
admin.NewScheduledTestHandler,
admin.NewChannelHandler,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers,

View File

@@ -125,6 +125,7 @@ type ClaudeUsage struct {
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
ImageOutputTokens int `json:"image_output_tokens,omitempty"`
}
// ClaudeError Claude 错误响应

View File

@@ -149,13 +149,31 @@ type GeminiCandidate struct {
GroundingMetadata *GeminiGroundingMetadata `json:"groundingMetadata,omitempty"`
}
// GeminiTokenDetail Gemini token 详情(按模态分类)
type GeminiTokenDetail struct {
Modality string `json:"modality"`
TokenCount int `json:"tokenCount"`
}
// GeminiUsageMetadata Gemini 用量元数据
type GeminiUsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount,omitempty"`
CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"`
TotalTokenCount int `json:"totalTokenCount,omitempty"`
ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"` // thinking tokens按输出价格计费
PromptTokenCount int `json:"promptTokenCount,omitempty"`
CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"`
TotalTokenCount int `json:"totalTokenCount,omitempty"`
ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"` // thinking tokens按输出价格计费
CandidatesTokensDetails []GeminiTokenDetail `json:"candidatesTokensDetails,omitempty"`
PromptTokensDetails []GeminiTokenDetail `json:"promptTokensDetails,omitempty"`
}
// ImageOutputTokens 从 CandidatesTokensDetails 中提取 IMAGE 模态的 token 数
func (m *GeminiUsageMetadata) ImageOutputTokens() int {
for _, d := range m.CandidatesTokensDetails {
if d.Modality == "IMAGE" {
return d.TokenCount
}
}
return 0
}
// GeminiGroundingMetadata Gemini grounding 元数据Web Search

View File

@@ -284,6 +284,7 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount
usage.CacheReadInputTokens = cached
usage.ImageOutputTokens = geminiResp.UsageMetadata.ImageOutputTokens()
}
// 生成响应 ID

View File

@@ -32,9 +32,10 @@ type StreamingProcessor struct {
groundingChunks []GeminiGroundingChunk
// 累计 usage
inputTokens int
outputTokens int
cacheReadTokens int
inputTokens int
outputTokens int
cacheReadTokens int
imageOutputTokens int
}
// NewStreamingProcessor 创建流式响应处理器
@@ -87,6 +88,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount
p.cacheReadTokens = cached
p.imageOutputTokens = geminiResp.UsageMetadata.ImageOutputTokens()
}
// 处理 parts
@@ -127,6 +129,7 @@ func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) {
InputTokens: p.inputTokens,
OutputTokens: p.outputTokens,
CacheReadInputTokens: p.cacheReadTokens,
ImageOutputTokens: p.imageOutputTokens,
}
if !p.messageStartSent {
@@ -158,6 +161,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached
usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount + v1Resp.Response.UsageMetadata.ThoughtsTokenCount
usage.CacheReadInputTokens = cached
usage.ImageOutputTokens = v1Resp.Response.UsageMetadata.ImageOutputTokens()
}
responseID := v1Resp.ResponseID
@@ -485,6 +489,7 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
InputTokens: p.inputTokens,
OutputTokens: p.outputTokens,
CacheReadInputTokens: p.cacheReadTokens,
ImageOutputTokens: p.imageOutputTokens,
}
deltaEvent := map[string]any{

View File

@@ -175,6 +175,13 @@ type UserBreakdownDimension struct {
ModelType string // "requested", "upstream", or "mapping"
Endpoint string // filter by endpoint value (non-empty to enable)
EndpointType string // "inbound", "upstream", or "path"
// Additional filter conditions
UserID int64 // filter by user_id (>0 to enable)
APIKeyID int64 // filter by api_key_id (>0 to enable)
AccountID int64 // filter by account_id (>0 to enable)
RequestType *int16 // filter by request_type (non-nil to enable)
Stream *bool // filter by stream flag (non-nil to enable)
BillingType *int8 // filter by billing_type (non-nil to enable)
}
// APIKeyUsageTrendPoint represents API key usage trend data point
@@ -230,6 +237,7 @@ type UsageLogFilters struct {
RequestType *int16
Stream *bool
BillingType *int8
BillingMode string
StartTime *time.Time
EndTime *time.Time
// ExactTotal requests exact COUNT(*) for pagination. Default false for fast large-table paging.

View File

@@ -0,0 +1,461 @@
package repository
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
type channelRepository struct {
db *sql.DB
}
// NewChannelRepository 创建渠道数据访问实例
func NewChannelRepository(db *sql.DB) service.ChannelRepository {
return &channelRepository{db: db}
}
// runInTx 在事务中执行 fn成功 commit失败 rollback。
func (r *channelRepository) runInTx(ctx context.Context, fn func(tx *sql.Tx) error) error {
tx, err := r.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("begin tx: %w", err)
}
defer func() { _ = tx.Rollback() }()
if err := fn(tx); err != nil {
return err
}
return tx.Commit()
}
func (r *channelRepository) Create(ctx context.Context, channel *service.Channel) error {
return r.runInTx(ctx, func(tx *sql.Tx) error {
modelMappingJSON, err := marshalModelMapping(channel.ModelMapping)
if err != nil {
return err
}
err = tx.QueryRowContext(ctx,
`INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models) VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, created_at, updated_at`,
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels,
).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
if err != nil {
if isUniqueViolation(err) {
return service.ErrChannelExists
}
return fmt.Errorf("insert channel: %w", err)
}
// 设置分组关联
if len(channel.GroupIDs) > 0 {
if err := setGroupIDsTx(ctx, tx, channel.ID, channel.GroupIDs); err != nil {
return err
}
}
// 设置模型定价
if len(channel.ModelPricing) > 0 {
if err := replaceModelPricingTx(ctx, tx, channel.ID, channel.ModelPricing); err != nil {
return err
}
}
return nil
})
}
func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) {
ch := &service.Channel{}
var modelMappingJSON []byte
err := r.db.QueryRowContext(ctx,
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at
FROM channels WHERE id = $1`, id,
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt)
if err == sql.ErrNoRows {
return nil, service.ErrChannelNotFound
}
if err != nil {
return nil, fmt.Errorf("get channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
groupIDs, err := r.GetGroupIDs(ctx, id)
if err != nil {
return nil, err
}
ch.GroupIDs = groupIDs
pricing, err := r.ListModelPricing(ctx, id)
if err != nil {
return nil, err
}
ch.ModelPricing = pricing
return ch, nil
}
func (r *channelRepository) Update(ctx context.Context, channel *service.Channel) error {
return r.runInTx(ctx, func(tx *sql.Tx) error {
modelMappingJSON, err := marshalModelMapping(channel.ModelMapping)
if err != nil {
return err
}
result, err := tx.ExecContext(ctx,
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, updated_at = NOW()
WHERE id = $7`,
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.ID,
)
if err != nil {
if isUniqueViolation(err) {
return service.ErrChannelExists
}
return fmt.Errorf("update channel: %w", err)
}
rows, _ := result.RowsAffected()
if rows == 0 {
return service.ErrChannelNotFound
}
// 更新分组关联
if channel.GroupIDs != nil {
if err := setGroupIDsTx(ctx, tx, channel.ID, channel.GroupIDs); err != nil {
return err
}
}
// 更新模型定价
if channel.ModelPricing != nil {
if err := replaceModelPricingTx(ctx, tx, channel.ID, channel.ModelPricing); err != nil {
return err
}
}
return nil
})
}
func (r *channelRepository) Delete(ctx context.Context, id int64) error {
result, err := r.db.ExecContext(ctx, `DELETE FROM channels WHERE id = $1`, id)
if err != nil {
return fmt.Errorf("delete channel: %w", err)
}
rows, _ := result.RowsAffected()
if rows == 0 {
return service.ErrChannelNotFound
}
return nil
}
func (r *channelRepository) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]service.Channel, *pagination.PaginationResult, error) {
where := []string{"1=1"}
args := []any{}
argIdx := 1
if status != "" {
where = append(where, fmt.Sprintf("c.status = $%d", argIdx))
args = append(args, status)
argIdx++
}
if search != "" {
where = append(where, fmt.Sprintf("(c.name ILIKE $%d OR c.description ILIKE $%d)", argIdx, argIdx))
args = append(args, "%"+escapeLike(search)+"%")
argIdx++
}
whereClause := strings.Join(where, " AND ")
// 计数
var total int64
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM channels c WHERE %s", whereClause)
if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
return nil, nil, fmt.Errorf("count channels: %w", err)
}
pageSize := params.Limit() // 约束在 [1, 100]
page := params.Page
if page < 1 {
page = 1
}
offset := (page - 1) * pageSize
// 查询 channel 列表
dataQuery := fmt.Sprintf(
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.created_at, c.updated_at
FROM channels c WHERE %s ORDER BY c.id ASC LIMIT $%d OFFSET $%d`,
whereClause, argIdx, argIdx+1,
)
args = append(args, pageSize, offset)
rows, err := r.db.QueryContext(ctx, dataQuery, args...)
if err != nil {
return nil, nil, fmt.Errorf("query channels: %w", err)
}
defer func() { _ = rows.Close() }()
var channels []service.Channel
var channelIDs []int64
for rows.Next() {
var ch service.Channel
var modelMappingJSON []byte
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, nil, fmt.Errorf("scan channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
channels = append(channels, ch)
channelIDs = append(channelIDs, ch.ID)
}
if err := rows.Err(); err != nil {
return nil, nil, fmt.Errorf("iterate channels: %w", err)
}
// 批量加载分组 ID 和模型定价(避免 N+1
if len(channelIDs) > 0 {
groupMap, err := r.batchLoadGroupIDs(ctx, channelIDs)
if err != nil {
return nil, nil, err
}
pricingMap, err := r.batchLoadModelPricing(ctx, channelIDs)
if err != nil {
return nil, nil, err
}
for i := range channels {
channels[i].GroupIDs = groupMap[channels[i].ID]
channels[i].ModelPricing = pricingMap[channels[i].ID]
}
}
pages := 0
if total > 0 {
pages = int((total + int64(pageSize) - 1) / int64(pageSize))
}
paginationResult := &pagination.PaginationResult{
Total: total,
Page: page,
PageSize: pageSize,
Pages: pages,
}
return channels, paginationResult, nil
}
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at FROM channels ORDER BY id`,
)
if err != nil {
return nil, fmt.Errorf("query all channels: %w", err)
}
defer func() { _ = rows.Close() }()
var channels []service.Channel
var channelIDs []int64
for rows.Next() {
var ch service.Channel
var modelMappingJSON []byte
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, fmt.Errorf("scan channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
channels = append(channels, ch)
channelIDs = append(channelIDs, ch.ID)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate channels: %w", err)
}
if len(channelIDs) == 0 {
return channels, nil
}
// 批量加载分组 ID
groupMap, err := r.batchLoadGroupIDs(ctx, channelIDs)
if err != nil {
return nil, err
}
// 批量加载模型定价
pricingMap, err := r.batchLoadModelPricing(ctx, channelIDs)
if err != nil {
return nil, err
}
for i := range channels {
channels[i].GroupIDs = groupMap[channels[i].ID]
channels[i].ModelPricing = pricingMap[channels[i].ID]
}
return channels, nil
}
// --- 批量加载辅助方法 ---
// batchLoadGroupIDs 批量加载多个渠道的分组 ID
func (r *channelRepository) batchLoadGroupIDs(ctx context.Context, channelIDs []int64) (map[int64][]int64, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT channel_id, group_id FROM channel_groups
WHERE channel_id = ANY($1) ORDER BY channel_id, group_id`,
pq.Array(channelIDs),
)
if err != nil {
return nil, fmt.Errorf("batch load group ids: %w", err)
}
defer func() { _ = rows.Close() }()
groupMap := make(map[int64][]int64, len(channelIDs))
for rows.Next() {
var channelID, groupID int64
if err := rows.Scan(&channelID, &groupID); err != nil {
return nil, fmt.Errorf("scan group id: %w", err)
}
groupMap[channelID] = append(groupMap[channelID], groupID)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate group ids: %w", err)
}
return groupMap, nil
}
func (r *channelRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
var exists bool
err := r.db.QueryRowContext(ctx,
`SELECT EXISTS(SELECT 1 FROM channels WHERE name = $1)`, name,
).Scan(&exists)
return exists, err
}
func (r *channelRepository) ExistsByNameExcluding(ctx context.Context, name string, excludeID int64) (bool, error) {
var exists bool
err := r.db.QueryRowContext(ctx,
`SELECT EXISTS(SELECT 1 FROM channels WHERE name = $1 AND id != $2)`, name, excludeID,
).Scan(&exists)
return exists, err
}
// --- 分组关联 ---
func (r *channelRepository) GetGroupIDs(ctx context.Context, channelID int64) ([]int64, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT group_id FROM channel_groups WHERE channel_id = $1 ORDER BY group_id`, channelID,
)
if err != nil {
return nil, fmt.Errorf("get group ids: %w", err)
}
defer func() { _ = rows.Close() }()
var ids []int64
for rows.Next() {
var id int64
if err := rows.Scan(&id); err != nil {
return nil, fmt.Errorf("scan group id: %w", err)
}
ids = append(ids, id)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate group ids: %w", err)
}
return ids, nil
}
func (r *channelRepository) SetGroupIDs(ctx context.Context, channelID int64, groupIDs []int64) error {
return setGroupIDsTx(ctx, r.db, channelID, groupIDs)
}
func (r *channelRepository) GetChannelIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
var channelID int64
err := r.db.QueryRowContext(ctx,
`SELECT channel_id FROM channel_groups WHERE group_id = $1`, groupID,
).Scan(&channelID)
if err == sql.ErrNoRows {
return 0, nil
}
return channelID, err
}
func (r *channelRepository) GetGroupsInOtherChannels(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error) {
if len(groupIDs) == 0 {
return nil, nil
}
rows, err := r.db.QueryContext(ctx,
`SELECT group_id FROM channel_groups WHERE group_id = ANY($1) AND channel_id != $2`,
pq.Array(groupIDs), channelID,
)
if err != nil {
return nil, fmt.Errorf("get groups in other channels: %w", err)
}
defer func() { _ = rows.Close() }()
var conflicting []int64
for rows.Next() {
var id int64
if err := rows.Scan(&id); err != nil {
return nil, fmt.Errorf("scan conflicting group id: %w", err)
}
conflicting = append(conflicting, id)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate conflicting group ids: %w", err)
}
return conflicting, nil
}
// marshalModelMapping 将 model mapping 序列化为嵌套 JSON 字节
// 格式:{"platform": {"src": "dst"}, ...}
func marshalModelMapping(m map[string]map[string]string) ([]byte, error) {
if len(m) == 0 {
return []byte("{}"), nil
}
data, err := json.Marshal(m)
if err != nil {
return nil, fmt.Errorf("marshal model_mapping: %w", err)
}
return data, nil
}
// unmarshalModelMapping 将 JSON 字节反序列化为嵌套 model mapping
func unmarshalModelMapping(data []byte) map[string]map[string]string {
if len(data) == 0 {
return nil
}
var m map[string]map[string]string
if err := json.Unmarshal(data, &m); err != nil {
return nil
}
return m
}
// GetGroupPlatforms 批量查询分组 ID 对应的平台
func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
if len(groupIDs) == 0 {
return make(map[int64]string), nil
}
rows, err := r.db.QueryContext(ctx,
`SELECT id, platform FROM groups WHERE id = ANY($1)`,
pq.Array(groupIDs),
)
if err != nil {
return nil, fmt.Errorf("get group platforms: %w", err)
}
defer rows.Close() //nolint:errcheck
result := make(map[int64]string, len(groupIDs))
for rows.Next() {
var id int64
var platform string
if err := rows.Scan(&id, &platform); err != nil {
return nil, fmt.Errorf("scan group platform: %w", err)
}
result[id] = platform
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate group platforms: %w", err)
}
return result, nil
}

View File

@@ -0,0 +1,291 @@
package repository
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strings"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
// --- 模型定价 ---
func (r *channelRepository) ListModelPricing(ctx context.Context, channelID int64) ([]service.ChannelModelPricing, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT id, channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
FROM channel_model_pricing WHERE channel_id = $1 ORDER BY id`, channelID,
)
if err != nil {
return nil, fmt.Errorf("list model pricing: %w", err)
}
defer func() { _ = rows.Close() }()
result, pricingIDs, err := scanModelPricingRows(rows)
if err != nil {
return nil, err
}
if len(pricingIDs) > 0 {
intervalMap, err := r.batchLoadIntervals(ctx, pricingIDs)
if err != nil {
return nil, err
}
for i := range result {
result[i].Intervals = intervalMap[result[i].ID]
}
}
return result, nil
}
func (r *channelRepository) CreateModelPricing(ctx context.Context, pricing *service.ChannelModelPricing) error {
return createModelPricingExec(ctx, r.db, pricing)
}
func (r *channelRepository) UpdateModelPricing(ctx context.Context, pricing *service.ChannelModelPricing) error {
modelsJSON, err := json.Marshal(pricing.Models)
if err != nil {
return fmt.Errorf("marshal models: %w", err)
}
billingMode := pricing.BillingMode
if billingMode == "" {
billingMode = service.BillingModeToken
}
result, err := r.db.ExecContext(ctx,
`UPDATE channel_model_pricing
SET models = $1, billing_mode = $2, input_price = $3, output_price = $4, cache_write_price = $5, cache_read_price = $6, image_output_price = $7, per_request_price = $8, platform = $9, updated_at = NOW()
WHERE id = $10`,
modelsJSON, billingMode, pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
pricing.ImageOutputPrice, pricing.PerRequestPrice, pricing.Platform, pricing.ID,
)
if err != nil {
return fmt.Errorf("update model pricing: %w", err)
}
rows, _ := result.RowsAffected()
if rows == 0 {
return fmt.Errorf("pricing entry not found: %d", pricing.ID)
}
return nil
}
func (r *channelRepository) DeleteModelPricing(ctx context.Context, id int64) error {
_, err := r.db.ExecContext(ctx, `DELETE FROM channel_model_pricing WHERE id = $1`, id)
if err != nil {
return fmt.Errorf("delete model pricing: %w", err)
}
return nil
}
func (r *channelRepository) ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []service.ChannelModelPricing) error {
return r.runInTx(ctx, func(tx *sql.Tx) error {
return replaceModelPricingTx(ctx, tx, channelID, pricingList)
})
}
// --- 批量加载辅助方法 ---
// batchLoadModelPricing 批量加载多个渠道的模型定价(含区间)
func (r *channelRepository) batchLoadModelPricing(ctx context.Context, channelIDs []int64) (map[int64][]service.ChannelModelPricing, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT id, channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
FROM channel_model_pricing WHERE channel_id = ANY($1) ORDER BY channel_id, id`,
pq.Array(channelIDs),
)
if err != nil {
return nil, fmt.Errorf("batch load model pricing: %w", err)
}
defer func() { _ = rows.Close() }()
allPricing, allPricingIDs, err := scanModelPricingRows(rows)
if err != nil {
return nil, err
}
// 按 channelID 分组
pricingMap := make(map[int64][]service.ChannelModelPricing, len(channelIDs))
for _, p := range allPricing {
pricingMap[p.ChannelID] = append(pricingMap[p.ChannelID], p)
}
// 批量加载所有区间
if len(allPricingIDs) > 0 {
intervalMap, err := r.batchLoadIntervals(ctx, allPricingIDs)
if err != nil {
return nil, err
}
for chID := range pricingMap {
for i := range pricingMap[chID] {
pricingMap[chID][i].Intervals = intervalMap[pricingMap[chID][i].ID]
}
}
}
return pricingMap, nil
}
// batchLoadIntervals 批量加载多个定价条目的区间
func (r *channelRepository) batchLoadIntervals(ctx context.Context, pricingIDs []int64) (map[int64][]service.PricingInterval, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT id, pricing_id, min_tokens, max_tokens, tier_label,
input_price, output_price, cache_write_price, cache_read_price,
per_request_price, sort_order, created_at, updated_at
FROM channel_pricing_intervals
WHERE pricing_id = ANY($1) ORDER BY pricing_id, sort_order, id`,
pq.Array(pricingIDs),
)
if err != nil {
return nil, fmt.Errorf("batch load intervals: %w", err)
}
defer func() { _ = rows.Close() }()
intervalMap := make(map[int64][]service.PricingInterval, len(pricingIDs))
for rows.Next() {
var iv service.PricingInterval
if err := rows.Scan(
&iv.ID, &iv.PricingID, &iv.MinTokens, &iv.MaxTokens, &iv.TierLabel,
&iv.InputPrice, &iv.OutputPrice, &iv.CacheWritePrice, &iv.CacheReadPrice,
&iv.PerRequestPrice, &iv.SortOrder, &iv.CreatedAt, &iv.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan interval: %w", err)
}
intervalMap[iv.PricingID] = append(intervalMap[iv.PricingID], iv)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate intervals: %w", err)
}
return intervalMap, nil
}
// --- 共享 scan 辅助 ---
// scanModelPricingRows 扫描 model pricing 行,返回结果列表和 ID 列表
func scanModelPricingRows(rows *sql.Rows) ([]service.ChannelModelPricing, []int64, error) {
var result []service.ChannelModelPricing
var pricingIDs []int64
for rows.Next() {
var p service.ChannelModelPricing
var modelsJSON []byte
if err := rows.Scan(
&p.ID, &p.ChannelID, &p.Platform, &modelsJSON, &p.BillingMode,
&p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice,
&p.ImageOutputPrice, &p.PerRequestPrice, &p.CreatedAt, &p.UpdatedAt,
); err != nil {
return nil, nil, fmt.Errorf("scan model pricing: %w", err)
}
if err := json.Unmarshal(modelsJSON, &p.Models); err != nil {
p.Models = []string{}
}
pricingIDs = append(pricingIDs, p.ID)
result = append(result, p)
}
if err := rows.Err(); err != nil {
return nil, nil, fmt.Errorf("iterate model pricing: %w", err)
}
return result, pricingIDs, nil
}
// --- 事务内辅助方法 ---
// dbExec 是 *sql.DB 和 *sql.Tx 共享的最小 SQL 执行接口
type dbExec interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}
func setGroupIDsTx(ctx context.Context, exec dbExec, channelID int64, groupIDs []int64) error {
if _, err := exec.ExecContext(ctx, `DELETE FROM channel_groups WHERE channel_id = $1`, channelID); err != nil {
return fmt.Errorf("delete old group associations: %w", err)
}
if len(groupIDs) == 0 {
return nil
}
_, err := exec.ExecContext(ctx,
`INSERT INTO channel_groups (channel_id, group_id)
SELECT $1, unnest($2::bigint[])`,
channelID, pq.Array(groupIDs),
)
if err != nil {
return fmt.Errorf("insert group associations: %w", err)
}
return nil
}
func createModelPricingExec(ctx context.Context, exec dbExec, pricing *service.ChannelModelPricing) error {
modelsJSON, err := json.Marshal(pricing.Models)
if err != nil {
return fmt.Errorf("marshal models: %w", err)
}
billingMode := pricing.BillingMode
if billingMode == "" {
billingMode = service.BillingModeToken
}
platform := pricing.Platform
if platform == "" {
platform = "anthropic"
}
err = exec.QueryRowContext(ctx,
`INSERT INTO channel_model_pricing (channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
pricing.ChannelID, platform, modelsJSON, billingMode,
pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
pricing.ImageOutputPrice, pricing.PerRequestPrice,
).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt)
if err != nil {
return fmt.Errorf("insert model pricing: %w", err)
}
for i := range pricing.Intervals {
pricing.Intervals[i].PricingID = pricing.ID
if err := createIntervalExec(ctx, exec, &pricing.Intervals[i]); err != nil {
return err
}
}
return nil
}
func createIntervalExec(ctx context.Context, exec dbExec, iv *service.PricingInterval) error {
return exec.QueryRowContext(ctx,
`INSERT INTO channel_pricing_intervals
(pricing_id, min_tokens, max_tokens, tier_label, input_price, output_price, cache_write_price, cache_read_price, per_request_price, sort_order)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
iv.PricingID, iv.MinTokens, iv.MaxTokens, iv.TierLabel,
iv.InputPrice, iv.OutputPrice, iv.CacheWritePrice, iv.CacheReadPrice,
iv.PerRequestPrice, iv.SortOrder,
).Scan(&iv.ID, &iv.CreatedAt, &iv.UpdatedAt)
}
func replaceModelPricingTx(ctx context.Context, exec dbExec, channelID int64, pricingList []service.ChannelModelPricing) error {
if _, err := exec.ExecContext(ctx, `DELETE FROM channel_model_pricing WHERE channel_id = $1`, channelID); err != nil {
return fmt.Errorf("delete old model pricing: %w", err)
}
for i := range pricingList {
pricingList[i].ChannelID = channelID
if err := createModelPricingExec(ctx, exec, &pricingList[i]); err != nil {
return fmt.Errorf("insert model pricing: %w", err)
}
}
return nil
}
// isUniqueViolation 检查 pq 唯一约束违反错误
func isUniqueViolation(err error) bool {
var pqErr *pq.Error
if errors.As(err, &pqErr) && pqErr != nil {
return pqErr.Code == "23505"
}
return false
}
// escapeLike 转义 LIKE/ILIKE 模式中的特殊字符
func escapeLike(s string) string {
s = strings.ReplaceAll(s, `\`, `\\`)
s = strings.ReplaceAll(s, `%`, `\%`)
s = strings.ReplaceAll(s, `_`, `\_`)
return s
}

View File

@@ -0,0 +1,227 @@
//go:build unit
package repository
import (
"encoding/json"
"errors"
"fmt"
"testing"
"github.com/lib/pq"
"github.com/stretchr/testify/require"
)
// --- marshalModelMapping ---
func TestMarshalModelMapping(t *testing.T) {
tests := []struct {
name string
input map[string]map[string]string
wantJSON string // expected JSON output (exact match)
}{
{
name: "empty map",
input: map[string]map[string]string{},
wantJSON: "{}",
},
{
name: "nil map",
input: nil,
wantJSON: "{}",
},
{
name: "populated map",
input: map[string]map[string]string{
"openai": {"gpt-4": "gpt-4-turbo"},
},
},
{
name: "nested values",
input: map[string]map[string]string{
"openai": {"*": "gpt-5.4"},
"anthropic": {"claude-old": "claude-new"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := marshalModelMapping(tt.input)
require.NoError(t, err)
if tt.wantJSON != "" {
require.Equal(t, []byte(tt.wantJSON), result)
} else {
// round-trip: unmarshal and compare with input
var parsed map[string]map[string]string
require.NoError(t, json.Unmarshal(result, &parsed))
require.Equal(t, tt.input, parsed)
}
})
}
}
// --- unmarshalModelMapping ---
func TestUnmarshalModelMapping(t *testing.T) {
tests := []struct {
name string
input []byte
wantNil bool
want map[string]map[string]string
}{
{
name: "nil data",
input: nil,
wantNil: true,
},
{
name: "empty data",
input: []byte{},
wantNil: true,
},
{
name: "invalid JSON",
input: []byte("not-json"),
wantNil: true,
},
{
name: "type error - number",
input: []byte("42"),
wantNil: true,
},
{
name: "type error - array",
input: []byte("[1,2,3]"),
wantNil: true,
},
{
name: "valid JSON",
input: []byte(`{"openai":{"gpt-4":"gpt-4-turbo"},"anthropic":{"old":"new"}}`),
want: map[string]map[string]string{
"openai": {"gpt-4": "gpt-4-turbo"},
"anthropic": {"old": "new"},
},
},
{
name: "empty object",
input: []byte("{}"),
want: map[string]map[string]string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := unmarshalModelMapping(tt.input)
if tt.wantNil {
require.Nil(t, result)
} else {
require.NotNil(t, result)
require.Equal(t, tt.want, result)
}
})
}
}
// --- escapeLike ---
func TestEscapeLike(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{
name: "no special chars",
input: "hello",
want: "hello",
},
{
name: "backslash",
input: `a\b`,
want: `a\\b`,
},
{
name: "percent",
input: "50%",
want: `50\%`,
},
{
name: "underscore",
input: "a_b",
want: `a\_b`,
},
{
name: "all special chars",
input: `a\b%c_d`,
want: `a\\b\%c\_d`,
},
{
name: "empty string",
input: "",
want: "",
},
{
name: "consecutive special chars",
input: "%_%",
want: `\%\_\%`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, escapeLike(tt.input))
})
}
}
// --- isUniqueViolation ---
func TestIsUniqueViolation(t *testing.T) {
tests := []struct {
name string
err error
want bool
}{
{
name: "unique violation code 23505",
err: &pq.Error{Code: "23505"},
want: true,
},
{
name: "different pq error code",
err: &pq.Error{Code: "23503"},
want: false,
},
{
name: "non-pq error",
err: errors.New("some generic error"),
want: false,
},
{
name: "typed nil pq.Error",
err: func() error {
var pqErr *pq.Error
return pqErr
}(),
want: false,
},
{
name: "bare nil",
err: nil,
want: false,
},
{
name: "wrapped pq error with 23505",
err: fmt.Errorf("wrapped: %w", &pq.Error{Code: "23505"}),
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, isUniqueViolation(tt.err))
})
}
}

View File

@@ -28,7 +28,7 @@ import (
gocache "github.com/patrickmn/go-cache"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at"
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at"
// usageLogInsertArgTypes must stay in the same order as:
// 1. prepareUsageLogInsert().args
@@ -53,6 +53,8 @@ var usageLogInsertArgTypes = [...]string{
"integer", // cache_read_tokens
"integer", // cache_creation_5m_tokens
"integer", // cache_creation_1h_tokens
"integer", // image_output_tokens
"numeric", // image_output_cost
"numeric", // input_cost
"numeric", // output_cost
"numeric", // cache_creation_cost
@@ -77,6 +79,10 @@ var usageLogInsertArgTypes = [...]string{
"text", // inbound_endpoint
"text", // upstream_endpoint
"boolean", // cache_ttl_overridden
"bigint", // channel_id
"text", // model_mapping_chain
"text", // billing_tier
"text", // billing_mode
"timestamptz", // created_at
}
@@ -326,6 +332,8 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
@@ -350,14 +358,18 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7,
$8, $9,
$10, $11, $12, $13,
$14, $15,
$16, $17, $18, $19, $20, $21,
$22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40
$14, $15, $16, $17,
$18, $19, $20, $21, $22, $23,
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
@@ -758,6 +770,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
@@ -782,10 +796,14 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at
) AS (VALUES `)
args := make([]any, 0, len(keys)*39)
args := make([]any, 0, len(keys)*47)
argPos := 1
for idx, key := range keys {
if idx > 0 {
@@ -829,6 +847,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
@@ -853,6 +873,10 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at
)
SELECT
@@ -871,6 +895,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
@@ -895,6 +921,10 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at
FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING
@@ -953,6 +983,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
@@ -977,10 +1009,14 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at
) AS (VALUES `)
args := make([]any, 0, len(preparedList)*40)
args := make([]any, 0, len(preparedList)*46)
argPos := 1
for idx, prepared := range preparedList {
if idx > 0 {
@@ -1021,6 +1057,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
@@ -1045,6 +1083,10 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at
)
SELECT
@@ -1063,6 +1105,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
@@ -1087,6 +1131,10 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at
FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING
@@ -1113,6 +1161,8 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
@@ -1137,14 +1187,18 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7,
$8, $9,
$10, $11, $12, $13,
$14, $15,
$16, $17, $18, $19, $20, $21,
$22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40
$14, $15, $16, $17,
$18, $19, $20, $21, $22, $23,
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...)
@@ -1176,6 +1230,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
reasoningEffort := nullString(log.ReasoningEffort)
inboundEndpoint := nullString(log.InboundEndpoint)
upstreamEndpoint := nullString(log.UpstreamEndpoint)
channelID := nullInt64(log.ChannelID)
modelMappingChain := nullString(log.ModelMappingChain)
billingTier := nullString(log.BillingTier)
billingMode := nullString(log.BillingMode)
requestedModel := strings.TrimSpace(log.RequestedModel)
if requestedModel == "" {
requestedModel = strings.TrimSpace(log.Model)
@@ -1208,6 +1266,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
log.CacheReadTokens,
log.CacheCreation5mTokens,
log.CacheCreation1hTokens,
log.ImageOutputTokens,
log.ImageOutputCost,
log.InputCost,
log.OutputCost,
log.CacheCreationCost,
@@ -1232,6 +1292,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
inboundEndpoint,
upstreamEndpoint,
log.CacheTTLOverridden,
channelID,
modelMappingChain,
billingTier,
billingMode,
createdAt,
},
}
@@ -2564,8 +2628,8 @@ type UsageLogFilters = usagestats.UsageLogFilters
// ListWithFilters lists usage logs with optional filters (for admin)
func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
conditions := make([]string, 0, 8)
args := make([]any, 0, 8)
conditions := make([]string, 0, 9)
args := make([]any, 0, 9)
if filters.UserID > 0 {
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
@@ -2589,6 +2653,10 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType))
}
if filters.BillingMode != "" {
conditions = append(conditions, fmt.Sprintf("billing_mode = $%d", len(args)+1))
args = append(args, filters.BillingMode)
}
if filters.StartTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
args = append(args, *filters.StartTime)
@@ -3096,6 +3164,30 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim
query += fmt.Sprintf(" AND %s = $%d", col, len(args)+1)
args = append(args, dim.Endpoint)
}
if dim.UserID > 0 {
query += fmt.Sprintf(" AND ul.user_id = $%d", len(args)+1)
args = append(args, dim.UserID)
}
if dim.APIKeyID > 0 {
query += fmt.Sprintf(" AND ul.api_key_id = $%d", len(args)+1)
args = append(args, dim.APIKeyID)
}
if dim.AccountID > 0 {
query += fmt.Sprintf(" AND ul.account_id = $%d", len(args)+1)
args = append(args, dim.AccountID)
}
if dim.RequestType != nil {
query += fmt.Sprintf(" AND ul.request_type = $%d", len(args)+1)
args = append(args, *dim.RequestType)
}
if dim.Stream != nil {
query += fmt.Sprintf(" AND ul.stream = $%d", len(args)+1)
args = append(args, *dim.Stream)
}
if dim.BillingType != nil {
query += fmt.Sprintf(" AND ul.billing_type = $%d", len(args)+1)
args = append(args, *dim.BillingType)
}
query += " GROUP BY ul.user_id, u.email ORDER BY actual_cost DESC"
if limit > 0 {
@@ -3256,6 +3348,10 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType))
}
if filters.BillingMode != "" {
conditions = append(conditions, fmt.Sprintf("billing_mode = $%d", len(args)+1))
args = append(args, filters.BillingMode)
}
if filters.StartTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
args = append(args, *filters.StartTime)
@@ -3935,6 +4031,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
cacheReadTokens int
cacheCreation5m int
cacheCreation1h int
imageOutputTokens int
imageOutputCost float64
inputCost float64
outputCost float64
cacheCreationCost float64
@@ -3959,6 +4057,10 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
inboundEndpoint sql.NullString
upstreamEndpoint sql.NullString
cacheTTLOverridden bool
channelID sql.NullInt64
modelMappingChain sql.NullString
billingTier sql.NullString
billingMode sql.NullString
createdAt time.Time
)
@@ -3979,6 +4081,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&cacheReadTokens,
&cacheCreation5m,
&cacheCreation1h,
&imageOutputTokens,
&imageOutputCost,
&inputCost,
&outputCost,
&cacheCreationCost,
@@ -4003,6 +4107,10 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&inboundEndpoint,
&upstreamEndpoint,
&cacheTTLOverridden,
&channelID,
&modelMappingChain,
&billingTier,
&billingMode,
&createdAt,
); err != nil {
return nil, err
@@ -4021,6 +4129,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
CacheReadTokens: cacheReadTokens,
CacheCreation5mTokens: cacheCreation5m,
CacheCreation1hTokens: cacheCreation1h,
ImageOutputTokens: imageOutputTokens,
ImageOutputCost: imageOutputCost,
InputCost: inputCost,
OutputCost: outputCost,
CacheCreationCost: cacheCreationCost,
@@ -4087,6 +4197,19 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if upstreamModel.Valid {
log.UpstreamModel = &upstreamModel.String
}
if channelID.Valid {
value := channelID.Int64
log.ChannelID = &value
}
if modelMappingChain.Valid {
log.ModelMappingChain = &modelMappingChain.String
}
if billingTier.Valid {
log.BillingTier = &billingTier.String
}
if billingMode.Valid {
log.BillingMode = &billingMode.String
}
return log, nil
}

View File

@@ -56,6 +56,8 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
log.CacheReadTokens,
log.CacheCreation5mTokens,
log.CacheCreation1hTokens,
log.ImageOutputTokens,
log.ImageOutputCost,
log.InputCost,
log.OutputCost,
log.CacheCreationCost,
@@ -80,6 +82,10 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
sqlmock.AnyArg(), // inbound_endpoint
sqlmock.AnyArg(), // upstream_endpoint
log.CacheTTLOverridden,
sqlmock.AnyArg(), // channel_id
sqlmock.AnyArg(), // model_mapping_chain
sqlmock.AnyArg(), // billing_tier
sqlmock.AnyArg(), // billing_mode
createdAt,
).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt))
@@ -129,6 +135,8 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
log.CacheReadTokens,
log.CacheCreation5mTokens,
log.CacheCreation1hTokens,
log.ImageOutputTokens,
log.ImageOutputCost,
log.InputCost,
log.OutputCost,
log.CacheCreationCost,
@@ -153,6 +161,10 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
sqlmock.AnyArg(),
sqlmock.AnyArg(),
log.CacheTTLOverridden,
sqlmock.AnyArg(), // channel_id
sqlmock.AnyArg(), // model_mapping_chain
sqlmock.AnyArg(), // billing_tier
sqlmock.AnyArg(), // billing_mode
createdAt,
).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt))
@@ -439,6 +451,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
4, // cache_read_tokens
5, // cache_creation_5m_tokens
6, // cache_creation_1h_tokens
0, // image_output_tokens
0.0, // image_output_cost
0.1, // input_cost
0.2, // output_cost
0.3, // cache_creation_cost
@@ -463,6 +477,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{},
false,
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
now,
}})
require.NoError(t, err)
@@ -487,6 +505,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullInt64{},
sql.NullInt64{},
1, 2, 3, 4, 5, 6,
0, 0.0, // image_output_tokens, image_output_cost
0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
1.0,
sql.NullFloat64{},
@@ -506,6 +525,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{},
false,
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
now,
}})
require.NoError(t, err)
@@ -530,6 +553,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullInt64{},
sql.NullInt64{},
1, 2, 3, 4, 5, 6,
0, 0.0, // image_output_tokens, image_output_cost
0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
1.0,
sql.NullFloat64{},
@@ -549,6 +573,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{},
false,
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
now,
}})
require.NoError(t, err)

View File

@@ -74,6 +74,7 @@ var ProviderSet = wire.NewSet(
NewUserGroupRateRepository,
NewErrorPassthroughRepository,
NewTLSFingerprintProfileRepository,
NewChannelRepository,
// Cache implementations
NewGatewayCache,

View File

@@ -87,6 +87,9 @@ func RegisterAdminRoutes(
// 定时测试计划
registerScheduledTestRoutes(admin, h)
// 渠道管理
registerChannelRoutes(admin, h)
}
}
@@ -567,3 +570,15 @@ func registerTLSFingerprintProfileRoutes(admin *gin.RouterGroup, h *handler.Hand
profiles.DELETE("/:id", h.Admin.TLSFingerprintProfile.Delete)
}
}
func registerChannelRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
channels := admin.Group("/channels")
{
channels.GET("", h.Admin.Channel.List)
channels.GET("/model-pricing", h.Admin.Channel.GetModelDefaultPricing)
channels.GET("/:id", h.Admin.Channel.GetByID)
channels.POST("", h.Admin.Channel.Create)
channels.PUT("/:id", h.Admin.Channel.Update)
channels.DELETE("/:id", h.Admin.Channel.Delete)
}
}

View File

@@ -56,6 +56,7 @@ type ModelPricing struct {
LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格
LongContextInputMultiplier float64 // 长上下文整次会话输入倍率
LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率
ImageOutputPricePerToken float64 // 图片输出 token 价格 (USD)
}
const (
@@ -94,16 +95,19 @@ type UsageTokens struct {
CacheReadTokens int
CacheCreation5mTokens int
CacheCreation1hTokens int
ImageOutputTokens int
}
// CostBreakdown 费用明细
type CostBreakdown struct {
InputCost float64
OutputCost float64
ImageOutputCost float64
CacheCreationCost float64
CacheReadCost float64
TotalCost float64
ActualCost float64 // 应用倍率后的实际费用
BillingMode string // 计费模式("token"/"per_request"/"image"),由 CalculateCostUnified 填充
}
// BillingService 计费服务
@@ -357,6 +361,7 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold,
LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier,
LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier,
ImageOutputPricePerToken: litellmPricing.OutputCostPerImageToken,
}), nil
}
}
@@ -371,81 +376,252 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
return nil, fmt.Errorf("pricing not found for model: %s", model)
}
// CalculateCost 计算使用费用
func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMultiplier float64) (*CostBreakdown, error) {
return s.CalculateCostWithServiceTier(model, tokens, rateMultiplier, "")
}
func (s *BillingService) CalculateCostWithServiceTier(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string) (*CostBreakdown, error) {
// GetModelPricingWithChannel 获取模型定价,渠道配置的价格覆盖默认值
// 仅覆盖渠道中非 nil 的价格字段nil 字段使用默认定价
func (s *BillingService) GetModelPricingWithChannel(model string, channelPricing *ChannelModelPricing) (*ModelPricing, error) {
pricing, err := s.GetModelPricing(model)
if err != nil {
return nil, err
}
if channelPricing == nil {
return pricing, nil
}
if channelPricing.InputPrice != nil {
pricing.InputPricePerToken = *channelPricing.InputPrice
pricing.InputPricePerTokenPriority = *channelPricing.InputPrice
}
if channelPricing.OutputPrice != nil {
pricing.OutputPricePerToken = *channelPricing.OutputPrice
pricing.OutputPricePerTokenPriority = *channelPricing.OutputPrice
}
if channelPricing.CacheWritePrice != nil {
pricing.CacheCreationPricePerToken = *channelPricing.CacheWritePrice
pricing.CacheCreation5mPrice = *channelPricing.CacheWritePrice
pricing.CacheCreation1hPrice = *channelPricing.CacheWritePrice
}
if channelPricing.CacheReadPrice != nil {
pricing.CacheReadPricePerToken = *channelPricing.CacheReadPrice
pricing.CacheReadPricePerTokenPriority = *channelPricing.CacheReadPrice
}
if channelPricing.ImageOutputPrice != nil {
pricing.ImageOutputPricePerToken = *channelPricing.ImageOutputPrice
}
return pricing, nil
}
breakdown := &CostBreakdown{}
inputPricePerToken := pricing.InputPricePerToken
outputPricePerToken := pricing.OutputPricePerToken
cacheReadPricePerToken := pricing.CacheReadPricePerToken
// --- 统一计费入口 ---
// CostInput 统一计费输入
type CostInput struct {
Ctx context.Context
Model string
GroupID *int64 // 用于渠道定价查找
Tokens UsageTokens
RequestCount int // 按次计费时使用
SizeTier string // 按次/图片模式的层级标签("1K","2K","4K","HD" 等)
RateMultiplier float64
ServiceTier string // "priority","flex","" 等
Resolver *ModelPricingResolver // 定价解析器
Resolved *ResolvedPricing // 可选:预解析的定价结果(避免重复 Resolve 调用)
}
// CalculateCostUnified 统一计费入口,支持三种计费模式。
// 使用 ModelPricingResolver 解析定价,然后根据 BillingMode 分发计算。
func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown, error) {
if input.Resolver == nil {
// 无 Resolver回退到旧路径
return s.calculateCostInternal(input.Model, input.Tokens, input.RateMultiplier, input.ServiceTier, nil)
}
// 优先使用预解析结果,避免重复 Resolve 调用
resolved := input.Resolved
if resolved == nil {
resolved = input.Resolver.Resolve(input.Ctx, PricingInput{
Model: input.Model,
GroupID: input.GroupID,
})
}
if input.RateMultiplier <= 0 {
input.RateMultiplier = 1.0
}
var breakdown *CostBreakdown
var err error
switch resolved.Mode {
case BillingModePerRequest, BillingModeImage:
breakdown, err = s.calculatePerRequestCost(resolved, input)
default: // BillingModeToken
breakdown, err = s.calculateTokenCost(resolved, input)
}
if err == nil && breakdown != nil {
breakdown.BillingMode = string(resolved.Mode)
if breakdown.BillingMode == "" {
breakdown.BillingMode = string(BillingModeToken)
}
}
return breakdown, err
}
// calculateTokenCost 按 token 区间计费
func (s *BillingService) calculateTokenCost(resolved *ResolvedPricing, input CostInput) (*CostBreakdown, error) {
totalContext := input.Tokens.InputTokens + input.Tokens.CacheReadTokens
pricing := input.Resolver.GetIntervalPricing(resolved, totalContext)
if pricing == nil {
return nil, fmt.Errorf("no pricing available for model: %s", input.Model)
}
pricing = s.applyModelSpecificPricingPolicy(input.Model, pricing)
// 长上下文定价仅在无区间定价时应用(区间定价已包含上下文分层)
applyLongCtx := len(resolved.Intervals) == 0
return s.computeTokenBreakdown(pricing, input.Tokens, input.RateMultiplier, input.ServiceTier, applyLongCtx), nil
}
// computeTokenBreakdown 是 token 计费的核心逻辑,由 calculateTokenCost 和 calculateCostInternal 共用。
// applyLongCtx 控制是否检查长上下文定价(区间定价已自含上下文分层,不需要额外应用)。
func (s *BillingService) computeTokenBreakdown(
pricing *ModelPricing, tokens UsageTokens,
rateMultiplier float64, serviceTier string,
applyLongCtx bool,
) *CostBreakdown {
if rateMultiplier <= 0 {
rateMultiplier = 1.0
}
inputPrice := pricing.InputPricePerToken
outputPrice := pricing.OutputPricePerToken
cacheReadPrice := pricing.CacheReadPricePerToken
tierMultiplier := 1.0
if usePriorityServiceTierPricing(serviceTier, pricing) {
if pricing.InputPricePerTokenPriority > 0 {
inputPricePerToken = pricing.InputPricePerTokenPriority
inputPrice = pricing.InputPricePerTokenPriority
}
if pricing.OutputPricePerTokenPriority > 0 {
outputPricePerToken = pricing.OutputPricePerTokenPriority
outputPrice = pricing.OutputPricePerTokenPriority
}
if pricing.CacheReadPricePerTokenPriority > 0 {
cacheReadPricePerToken = pricing.CacheReadPricePerTokenPriority
cacheReadPrice = pricing.CacheReadPricePerTokenPriority
}
} else {
tierMultiplier = serviceTierCostMultiplier(serviceTier)
}
if s.shouldApplySessionLongContextPricing(tokens, pricing) {
inputPricePerToken *= pricing.LongContextInputMultiplier
outputPricePerToken *= pricing.LongContextOutputMultiplier
if applyLongCtx && s.shouldApplySessionLongContextPricing(tokens, pricing) {
inputPrice *= pricing.LongContextInputMultiplier
outputPrice *= pricing.LongContextOutputMultiplier
}
// 计算输入token费用使用per-token价格
breakdown.InputCost = float64(tokens.InputTokens) * inputPricePerToken
bd := &CostBreakdown{}
bd.InputCost = float64(tokens.InputTokens) * inputPrice
// 计算输出token费用
breakdown.OutputCost = float64(tokens.OutputTokens) * outputPricePerToken
// 分离图片输出 token 与文本输出 token
textOutputTokens := tokens.OutputTokens - tokens.ImageOutputTokens
if textOutputTokens < 0 {
textOutputTokens = 0
}
bd.OutputCost = float64(textOutputTokens) * outputPrice
// 计算缓存费用
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
// 支持详细缓存分类的模型5分钟/1小时缓存价格为 per-token
if tokens.CacheCreation5mTokens == 0 && tokens.CacheCreation1hTokens == 0 && tokens.CacheCreationTokens > 0 {
// API 未返回 ephemeral 明细,回退到全部按 5m 单价计费
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice
} else {
breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice +
float64(tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice
// 图片输出 token 费用(独立费率)
if tokens.ImageOutputTokens > 0 {
imgPrice := pricing.ImageOutputPricePerToken
if imgPrice == 0 {
imgPrice = outputPrice // 回退到常规输出价格
}
} else {
// 标准缓存创建价格per-token
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
bd.ImageOutputCost = float64(tokens.ImageOutputTokens) * imgPrice
}
breakdown.CacheReadCost = float64(tokens.CacheReadTokens) * cacheReadPricePerToken
// 缓存创建费用
bd.CacheCreationCost = s.computeCacheCreationCost(pricing, tokens)
bd.CacheReadCost = float64(tokens.CacheReadTokens) * cacheReadPrice
if tierMultiplier != 1.0 {
breakdown.InputCost *= tierMultiplier
breakdown.OutputCost *= tierMultiplier
breakdown.CacheCreationCost *= tierMultiplier
breakdown.CacheReadCost *= tierMultiplier
bd.InputCost *= tierMultiplier
bd.OutputCost *= tierMultiplier
bd.ImageOutputCost *= tierMultiplier
bd.CacheCreationCost *= tierMultiplier
bd.CacheReadCost *= tierMultiplier
}
// 计算总费用
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost +
breakdown.CacheCreationCost + breakdown.CacheReadCost
bd.TotalCost = bd.InputCost + bd.OutputCost + bd.ImageOutputCost +
bd.CacheCreationCost + bd.CacheReadCost
bd.ActualCost = bd.TotalCost * rateMultiplier
// 应用倍率计算实际费用
if rateMultiplier <= 0 {
rateMultiplier = 1.0
return bd
}
// computeCacheCreationCost 计算缓存创建费用(支持 5m/1h 分类或标准计费)。
func (s *BillingService) computeCacheCreationCost(pricing *ModelPricing, tokens UsageTokens) float64 {
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
if tokens.CacheCreation5mTokens == 0 && tokens.CacheCreation1hTokens == 0 && tokens.CacheCreationTokens > 0 {
// API 未返回 ephemeral 明细,回退到全部按 5m 单价计费
return float64(tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice
}
return float64(tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice +
float64(tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice
}
breakdown.ActualCost = breakdown.TotalCost * rateMultiplier
return float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
}
return breakdown, nil
// calculatePerRequestCost 按次/图片计费
func (s *BillingService) calculatePerRequestCost(resolved *ResolvedPricing, input CostInput) (*CostBreakdown, error) {
count := input.RequestCount
if count <= 0 {
count = 1
}
var unitPrice float64
if input.SizeTier != "" {
unitPrice = input.Resolver.GetRequestTierPrice(resolved, input.SizeTier)
}
if unitPrice == 0 {
totalContext := input.Tokens.InputTokens + input.Tokens.CacheReadTokens
unitPrice = input.Resolver.GetRequestTierPriceByContext(resolved, totalContext)
}
// 回退到默认按次价格
if unitPrice == 0 {
unitPrice = resolved.DefaultPerRequestPrice
}
totalCost := unitPrice * float64(count)
actualCost := totalCost * input.RateMultiplier
return &CostBreakdown{
TotalCost: totalCost,
ActualCost: actualCost,
}, nil
}
// CalculateCost 计算使用费用
func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMultiplier float64) (*CostBreakdown, error) {
return s.calculateCostInternal(model, tokens, rateMultiplier, "", nil)
}
func (s *BillingService) CalculateCostWithServiceTier(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string) (*CostBreakdown, error) {
return s.calculateCostInternal(model, tokens, rateMultiplier, serviceTier, nil)
}
func (s *BillingService) calculateCostInternal(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string, channelPricing *ChannelModelPricing) (*CostBreakdown, error) {
var pricing *ModelPricing
var err error
if channelPricing != nil {
pricing, err = s.GetModelPricingWithChannel(model, channelPricing)
} else {
pricing, err = s.GetModelPricing(model)
}
if err != nil {
return nil, err
}
// 旧路径始终检查长上下文定价(无区间定价概念)
return s.computeTokenBreakdown(pricing, tokens, rateMultiplier, serviceTier, true), nil
}
func (s *BillingService) applyModelSpecificPricingPolicy(model string, pricing *ModelPricing) *ModelPricing {
@@ -541,6 +717,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage
CacheReadTokens: inRangeCacheTokens,
CacheCreation5mTokens: tokens.CacheCreation5mTokens,
CacheCreation1hTokens: tokens.CacheCreation1hTokens,
ImageOutputTokens: tokens.ImageOutputTokens,
}
inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier)
if err != nil {
@@ -561,6 +738,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage
return &CostBreakdown{
InputCost: inRangeCost.InputCost + outRangeCost.InputCost,
OutputCost: inRangeCost.OutputCost,
ImageOutputCost: inRangeCost.ImageOutputCost,
CacheCreationCost: inRangeCost.CacheCreationCost,
CacheReadCost: inRangeCost.CacheReadCost + outRangeCost.CacheReadCost,
TotalCost: inRangeCost.TotalCost + outRangeCost.TotalCost,
@@ -662,8 +840,9 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag
actualCost := totalCost * rateMultiplier
return &CostBreakdown{
TotalCost: totalCost,
ActualCost: actualCost,
TotalCost: totalCost,
ActualCost: actualCost,
BillingMode: string(BillingModeImage),
}
}

View File

@@ -0,0 +1,277 @@
package service
import (
"fmt"
"sort"
"strings"
"time"
)
// BillingMode 计费模式
type BillingMode string
const (
BillingModeToken BillingMode = "token" // 按 token 区间计费
BillingModePerRequest BillingMode = "per_request" // 按次计费(支持上下文窗口分层)
BillingModeImage BillingMode = "image" // 图片计费(当前按次,预留 token 计费)
)
// IsValid 检查 BillingMode 是否为合法值
func (m BillingMode) IsValid() bool {
switch m {
case BillingModeToken, BillingModePerRequest, BillingModeImage, "":
return true
}
return false
}
const (
BillingModelSourceRequested = "requested"
BillingModelSourceUpstream = "upstream"
BillingModelSourceChannelMapped = "channel_mapped"
)
// Channel 渠道实体
type Channel struct {
ID int64
Name string
Description string
Status string
BillingModelSource string // "requested", "upstream", or "channel_mapped"
RestrictModels bool // 是否限制模型(仅允许定价列表中的模型)
CreatedAt time.Time
UpdatedAt time.Time
// 关联的分组 ID 列表
GroupIDs []int64
// 模型定价列表(每条含 Platform 字段)
ModelPricing []ChannelModelPricing
// 渠道级模型映射按平台分组platform → {src→dst}
ModelMapping map[string]map[string]string
}
// ChannelModelPricing 渠道模型定价条目
type ChannelModelPricing struct {
ID int64
ChannelID int64
Platform string // 所属平台anthropic/openai/gemini/...
Models []string // 绑定的模型列表
BillingMode BillingMode // 计费模式
InputPrice *float64 // 每 token 输入价格USD— 向后兼容 flat 定价
OutputPrice *float64 // 每 token 输出价格USD
CacheWritePrice *float64 // 缓存写入价格
CacheReadPrice *float64 // 缓存读取价格
ImageOutputPrice *float64 // 图片输出价格(向后兼容)
PerRequestPrice *float64 // 默认按次计费价格USD
Intervals []PricingInterval // 区间定价列表
CreatedAt time.Time
UpdatedAt time.Time
}
// PricingInterval 定价区间token 区间 / 按次分层 / 图片分辨率分层)
type PricingInterval struct {
ID int64
PricingID int64
MinTokens int // 区间下界(含)
MaxTokens *int // 区间上界不含nil = 无上限
TierLabel string // 层级标签(按次/图片模式1K, 2K, 4K, HD 等)
InputPrice *float64 // token 模式:每 token 输入价
OutputPrice *float64 // token 模式:每 token 输出价
CacheWritePrice *float64 // token 模式:缓存写入价
CacheReadPrice *float64 // token 模式:缓存读取价
PerRequestPrice *float64 // 按次/图片模式:每次请求价格
SortOrder int
CreatedAt time.Time
UpdatedAt time.Time
}
// IsActive 判断渠道是否启用
func (c *Channel) IsActive() bool {
return c.Status == StatusActive
}
// GetModelPricing 根据模型名查找渠道定价,未找到返回 nil。
// 精确匹配,大小写不敏感。返回值拷贝,不污染缓存。
func (c *Channel) GetModelPricing(model string) *ChannelModelPricing {
modelLower := strings.ToLower(model)
for i := range c.ModelPricing {
for _, m := range c.ModelPricing[i].Models {
if strings.ToLower(m) == modelLower {
cp := c.ModelPricing[i].Clone()
return &cp
}
}
}
return nil
}
// FindMatchingInterval 在区间列表中查找匹配 totalTokens 的区间。
// 区间为左开右闭 (min, max]min 不含max 包含。
// 第一个区间 min=0 时0 token 不匹配任何区间(回退到默认价格)。
func FindMatchingInterval(intervals []PricingInterval, totalTokens int) *PricingInterval {
for i := range intervals {
iv := &intervals[i]
if totalTokens > iv.MinTokens && (iv.MaxTokens == nil || totalTokens <= *iv.MaxTokens) {
return iv
}
}
return nil
}
// GetIntervalForContext 根据总 context token 数查找匹配的区间。
func (p *ChannelModelPricing) GetIntervalForContext(totalTokens int) *PricingInterval {
return FindMatchingInterval(p.Intervals, totalTokens)
}
// GetTierByLabel 根据标签查找层级(用于 per_request / image 模式)
func (p *ChannelModelPricing) GetTierByLabel(label string) *PricingInterval {
labelLower := strings.ToLower(label)
for i := range p.Intervals {
if strings.ToLower(p.Intervals[i].TierLabel) == labelLower {
return &p.Intervals[i]
}
}
return nil
}
// Clone 返回 ChannelModelPricing 的拷贝(切片独立,指针字段共享,调用方只读安全)
func (p ChannelModelPricing) Clone() ChannelModelPricing {
cp := p
if p.Models != nil {
cp.Models = make([]string, len(p.Models))
copy(cp.Models, p.Models)
}
if p.Intervals != nil {
cp.Intervals = make([]PricingInterval, len(p.Intervals))
copy(cp.Intervals, p.Intervals)
}
return cp
}
// Clone 返回 Channel 的深拷贝
func (c *Channel) Clone() *Channel {
if c == nil {
return nil
}
cp := *c
if c.GroupIDs != nil {
cp.GroupIDs = make([]int64, len(c.GroupIDs))
copy(cp.GroupIDs, c.GroupIDs)
}
if c.ModelPricing != nil {
cp.ModelPricing = make([]ChannelModelPricing, len(c.ModelPricing))
for i := range c.ModelPricing {
cp.ModelPricing[i] = c.ModelPricing[i].Clone()
}
}
if c.ModelMapping != nil {
cp.ModelMapping = make(map[string]map[string]string, len(c.ModelMapping))
for platform, mapping := range c.ModelMapping {
inner := make(map[string]string, len(mapping))
for k, v := range mapping {
inner[k] = v
}
cp.ModelMapping[platform] = inner
}
}
return &cp
}
// ValidateIntervals 校验区间列表的合法性。
// 规则MinTokens >= 0MaxTokens 若非 nil 则 > 0 且 > MinTokens
// 所有价格字段 >= 0区间按 MinTokens 排序后无重叠((min, max] 语义);
// 无界区间MaxTokens=nil必须是最后一个。间隙允许回退默认价格
func ValidateIntervals(intervals []PricingInterval) error {
if len(intervals) == 0 {
return nil
}
sorted := make([]PricingInterval, len(intervals))
copy(sorted, intervals)
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].MinTokens < sorted[j].MinTokens
})
for i := range sorted {
if err := validateSingleInterval(&sorted[i], i); err != nil {
return err
}
}
return validateIntervalOverlap(sorted)
}
// validateSingleInterval 校验单个区间的字段合法性
func validateSingleInterval(iv *PricingInterval, idx int) error {
if iv.MinTokens < 0 {
return fmt.Errorf("interval #%d: min_tokens (%d) must be >= 0", idx+1, iv.MinTokens)
}
if iv.MaxTokens != nil {
if *iv.MaxTokens <= 0 {
return fmt.Errorf("interval #%d: max_tokens (%d) must be > 0", idx+1, *iv.MaxTokens)
}
if *iv.MaxTokens <= iv.MinTokens {
return fmt.Errorf("interval #%d: max_tokens (%d) must be > min_tokens (%d)",
idx+1, *iv.MaxTokens, iv.MinTokens)
}
}
return validateIntervalPrices(iv, idx)
}
// validateIntervalPrices 校验区间内所有价格字段 >= 0
func validateIntervalPrices(iv *PricingInterval, idx int) error {
prices := []struct {
name string
val *float64
}{
{"input_price", iv.InputPrice},
{"output_price", iv.OutputPrice},
{"cache_write_price", iv.CacheWritePrice},
{"cache_read_price", iv.CacheReadPrice},
{"per_request_price", iv.PerRequestPrice},
}
for _, p := range prices {
if p.val != nil && *p.val < 0 {
return fmt.Errorf("interval #%d: %s must be >= 0", idx+1, p.name)
}
}
return nil
}
// validateIntervalOverlap 校验排序后的区间列表无重叠,且无界区间在最后
func validateIntervalOverlap(sorted []PricingInterval) error {
for i, iv := range sorted {
// 无界区间必须是最后一个
if iv.MaxTokens == nil && i < len(sorted)-1 {
return fmt.Errorf("interval #%d: unbounded interval (max_tokens=null) must be the last one",
i+1)
}
if i == 0 {
continue
}
prev := sorted[i-1]
// 检查重叠:前一个区间的上界 > 当前区间的下界则重叠
// (min, max] 语义prev 覆盖 (prev.Min, prev.Max]cur 覆盖 (cur.Min, cur.Max]
if prev.MaxTokens == nil || *prev.MaxTokens > iv.MinTokens {
return fmt.Errorf("interval #%d and #%d overlap: prev max=%s > cur min=%d",
i, i+1, formatMaxTokensLabel(prev.MaxTokens), iv.MinTokens)
}
}
return nil
}
func formatMaxTokensLabel(max *int) string {
if max == nil {
return "∞"
}
return fmt.Sprintf("%d", *max)
}
// ChannelUsageFields 渠道相关的使用记录字段(嵌入到各平台的 RecordUsageInput 中)
type ChannelUsageFields struct {
ChannelID int64 // 渠道 ID0 = 无渠道)
OriginalModel string // 用户原始请求模型(渠道映射前)
ChannelMappedModel string // 渠道映射后的模型名(无映射时等于 OriginalModel
BillingModelSource string // 计费模型来源:"requested" / "upstream" / "channel_mapped"
ModelMappingChain string // 映射链描述,如 "a→b→c"
}

View File

@@ -0,0 +1,857 @@
package service
import (
"context"
"fmt"
"log/slog"
"strings"
"sync/atomic"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"golang.org/x/sync/singleflight"
)
var (
ErrChannelNotFound = infraerrors.NotFound("CHANNEL_NOT_FOUND", "channel not found")
ErrChannelExists = infraerrors.Conflict("CHANNEL_EXISTS", "channel name already exists")
ErrGroupAlreadyInChannel = infraerrors.Conflict(
"GROUP_ALREADY_IN_CHANNEL",
"one or more groups already belong to another channel",
)
)
// ChannelRepository 渠道数据访问接口
type ChannelRepository interface {
Create(ctx context.Context, channel *Channel) error
GetByID(ctx context.Context, id int64) (*Channel, error)
Update(ctx context.Context, channel *Channel) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error)
ListAll(ctx context.Context) ([]Channel, error)
ExistsByName(ctx context.Context, name string) (bool, error)
ExistsByNameExcluding(ctx context.Context, name string, excludeID int64) (bool, error)
// 分组关联
GetGroupIDs(ctx context.Context, channelID int64) ([]int64, error)
SetGroupIDs(ctx context.Context, channelID int64, groupIDs []int64) error
GetChannelIDByGroupID(ctx context.Context, groupID int64) (int64, error)
GetGroupsInOtherChannels(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error)
// 分组平台查询
GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error)
// 模型定价
ListModelPricing(ctx context.Context, channelID int64) ([]ChannelModelPricing, error)
CreateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error
UpdateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error
DeleteModelPricing(ctx context.Context, id int64) error
ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error
}
// channelModelKey 渠道缓存复合键(显式包含 platform 防止跨平台同名模型冲突)
type channelModelKey struct {
groupID int64
platform string // 平台标识
model string // lowercase
}
// channelGroupPlatformKey 通配符定价缓存键
type channelGroupPlatformKey struct {
groupID int64
platform string
}
// wildcardPricingEntry 通配符定价条目
type wildcardPricingEntry struct {
prefix string
pricing *ChannelModelPricing
}
// wildcardMappingEntry 通配符映射条目
type wildcardMappingEntry struct {
prefix string
target string
}
// channelCache 渠道缓存快照(扁平化哈希结构,热路径 O(1) 查找)
type channelCache struct {
// 热路径查找
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价
wildcardByGroupPlatform map[channelGroupPlatformKey][]*wildcardPricingEntry // (groupID, platform) → 通配符定价(前缀长度降序)
mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标
wildcardMappingByGP map[channelGroupPlatformKey][]*wildcardMappingEntry // (groupID, platform) → 通配符映射(前缀长度降序)
channelByGroupID map[int64]*Channel // groupID → 渠道
groupPlatform map[int64]string // groupID → platform
// 冷路径CRUD 操作)
byID map[int64]*Channel
loadedAt time.Time
}
// ChannelMappingResult 渠道映射查找结果
type ChannelMappingResult struct {
MappedModel string // 映射后的模型名(无映射时等于原始模型名)
ChannelID int64 // 渠道 ID0 = 无渠道关联)
Mapped bool // 是否发生了映射
BillingModelSource string // 计费模型来源("requested" / "upstream" / "channel_mapped"
}
// BuildModelMappingChain 根据映射结果和上游实际模型构建映射链描述。
// reqModel: 客户端请求的原始模型名。
// upstreamModel: 上游实际使用的模型名ForwardResult.UpstreamModel
// 返回空字符串表示无映射。
func (r ChannelMappingResult) BuildModelMappingChain(reqModel, upstreamModel string) string {
if !r.Mapped {
if upstreamModel != "" && upstreamModel != reqModel {
return reqModel + "→" + upstreamModel
}
return ""
}
if upstreamModel != "" && upstreamModel != r.MappedModel {
return reqModel + "→" + r.MappedModel + "→" + upstreamModel
}
return reqModel + "→" + r.MappedModel
}
// ToUsageFields 将渠道映射结果转为使用记录字段
func (r ChannelMappingResult) ToUsageFields(reqModel, upstreamModel string) ChannelUsageFields {
channelMappedModel := reqModel
if r.Mapped {
channelMappedModel = r.MappedModel
}
return ChannelUsageFields{
ChannelID: r.ChannelID,
OriginalModel: reqModel,
ChannelMappedModel: channelMappedModel,
BillingModelSource: r.BillingModelSource,
ModelMappingChain: r.BuildModelMappingChain(reqModel, upstreamModel),
}
}
const (
channelCacheTTL = 10 * time.Minute
channelErrorTTL = 5 * time.Second // DB 错误时的短缓存
channelCacheDBTimeout = 10 * time.Second
)
// ChannelService 渠道管理服务
type ChannelService struct {
repo ChannelRepository
authCacheInvalidator APIKeyAuthCacheInvalidator
cache atomic.Value // *channelCache
cacheSF singleflight.Group
}
// NewChannelService 创建渠道服务实例
func NewChannelService(repo ChannelRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *ChannelService {
s := &ChannelService{
repo: repo,
authCacheInvalidator: authCacheInvalidator,
}
return s
}
// loadCache 加载或返回缓存的渠道数据
func (s *ChannelService) loadCache(ctx context.Context) (*channelCache, error) {
if cached, ok := s.cache.Load().(*channelCache); ok && cached != nil {
if time.Since(cached.loadedAt) < channelCacheTTL {
return cached, nil
}
}
result, err, _ := s.cacheSF.Do("channel_cache", func() (any, error) {
// 双重检查
if cached, ok := s.cache.Load().(*channelCache); ok && cached != nil {
if time.Since(cached.loadedAt) < channelCacheTTL {
return cached, nil
}
}
return s.buildCache(ctx)
})
if err != nil {
return nil, err
}
cache, ok := result.(*channelCache)
if !ok {
return nil, fmt.Errorf("unexpected cache type")
}
return cache, nil
}
// newEmptyChannelCache 创建空的渠道缓存(所有 map 已初始化)
func newEmptyChannelCache() *channelCache {
return &channelCache{
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
mappingByGroupModel: make(map[channelModelKey]string),
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
channelByGroupID: make(map[int64]*Channel),
groupPlatform: make(map[int64]string),
byID: make(map[int64]*Channel),
}
}
// expandPricingToCache 将渠道的模型定价展开到缓存(按分组+平台维度)。
// antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目。
// 缓存 key 使用定价条目的原始平台pricing.Platform而非分组平台
// 避免跨平台同名模型(如 anthropic 和 gemini 都有 "model-x")互相覆盖。
// 查找时通过 lookupPricingAcrossPlatforms() 依次尝试所有匹配平台。
func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
for j := range ch.ModelPricing {
pricing := &ch.ModelPricing[j]
if !isPlatformPricingMatch(platform, pricing.Platform) {
continue // 跳过非本平台的定价
}
// 使用定价条目的原始平台作为缓存 key防止跨平台同名模型冲突
pricingPlatform := pricing.Platform
gpKey := channelGroupPlatformKey{groupID: gid, platform: pricingPlatform}
for _, model := range pricing.Models {
if strings.HasSuffix(model, "*") {
prefix := strings.ToLower(strings.TrimSuffix(model, "*"))
cache.wildcardByGroupPlatform[gpKey] = append(cache.wildcardByGroupPlatform[gpKey], &wildcardPricingEntry{
prefix: prefix,
pricing: pricing,
})
} else {
key := channelModelKey{groupID: gid, platform: pricingPlatform, model: strings.ToLower(model)}
cache.pricingByGroupModel[key] = pricing
}
}
}
}
// expandMappingToCache 将渠道的模型映射展开到缓存(按分组+平台维度)。
// antigravity 平台同时服务 Claude 和 Gemini 模型。
// 缓存 key 使用映射条目的原始平台mappingPlatform避免跨平台同名映射覆盖。
func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
for _, mappingPlatform := range matchingPlatforms(platform) {
platformMapping, ok := ch.ModelMapping[mappingPlatform]
if !ok {
continue
}
// 使用映射条目的原始平台作为缓存 key防止跨平台同名映射冲突
gpKey := channelGroupPlatformKey{groupID: gid, platform: mappingPlatform}
for src, dst := range platformMapping {
if strings.HasSuffix(src, "*") {
prefix := strings.ToLower(strings.TrimSuffix(src, "*"))
cache.wildcardMappingByGP[gpKey] = append(cache.wildcardMappingByGP[gpKey], &wildcardMappingEntry{
prefix: prefix,
target: dst,
})
} else {
key := channelModelKey{groupID: gid, platform: mappingPlatform, model: strings.ToLower(src)}
cache.mappingByGroupModel[key] = dst
}
}
}
}
// buildCache 从数据库构建渠道缓存。
// 使用独立 context 避免请求取消导致空值被长期缓存。
func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) {
// 断开请求取消链,避免客户端断连导致空值被长期缓存
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), channelCacheDBTimeout)
defer cancel()
channels, err := s.repo.ListAll(dbCtx)
if err != nil {
// error-TTL失败时存入短 TTL 空缓存,防止紧密重试
slog.Warn("failed to build channel cache", "error", err)
errorCache := newEmptyChannelCache()
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) // 使剩余 TTL = errorTTL
s.cache.Store(errorCache)
return nil, fmt.Errorf("list all channels: %w", err)
}
// 收集所有 groupID批量查询 platform
var allGroupIDs []int64
for i := range channels {
allGroupIDs = append(allGroupIDs, channels[i].GroupIDs...)
}
groupPlatforms := make(map[int64]string)
if len(allGroupIDs) > 0 {
groupPlatforms, err = s.repo.GetGroupPlatforms(dbCtx, allGroupIDs)
if err != nil {
slog.Warn("failed to load group platforms for channel cache", "error", err)
errorCache := newEmptyChannelCache()
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL))
s.cache.Store(errorCache)
return nil, fmt.Errorf("get group platforms: %w", err)
}
}
cache := newEmptyChannelCache()
cache.groupPlatform = groupPlatforms
cache.byID = make(map[int64]*Channel, len(channels))
cache.loadedAt = time.Now()
for i := range channels {
ch := &channels[i]
cache.byID[ch.ID] = ch
for _, gid := range ch.GroupIDs {
cache.channelByGroupID[gid] = ch
platform := groupPlatforms[gid]
expandPricingToCache(cache, ch, gid, platform)
expandMappingToCache(cache, ch, gid, platform)
}
}
// 通配符条目保持配置顺序(最先匹配到优先)
s.cache.Store(cache)
return cache, nil
}
// invalidateCache 使缓存失效,让下次读取时自然重建
// isPlatformPricingMatch 判断定价条目的平台是否匹配分组平台。
// antigravity 平台同时服务 Claudeanthropic和 Geminigemini模型
// 因此 antigravity 分组应匹配 anthropic 和 gemini 的定价条目。
func isPlatformPricingMatch(groupPlatform, pricingPlatform string) bool {
if groupPlatform == pricingPlatform {
return true
}
if groupPlatform == PlatformAntigravity {
return pricingPlatform == PlatformAnthropic || pricingPlatform == PlatformGemini
}
return false
}
// matchingPlatforms 返回分组平台对应的所有可匹配平台列表。
func matchingPlatforms(groupPlatform string) []string {
if groupPlatform == PlatformAntigravity {
return []string{PlatformAntigravity, PlatformAnthropic, PlatformGemini}
}
return []string{groupPlatform}
}
func (s *ChannelService) invalidateCache() {
s.cache.Store((*channelCache)(nil))
s.cacheSF.Forget("channel_cache")
// 主动重建缓存,确保 CRUD 后立即生效
if _, err := s.buildCache(context.Background()); err != nil {
slog.Warn("failed to rebuild channel cache after invalidation", "error", err)
}
}
// matchWildcard 在通配符定价中查找匹配项(最先匹配到优先)
func (c *channelCache) matchWildcard(groupID int64, platform, modelLower string) *ChannelModelPricing {
gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform}
wildcards := c.wildcardByGroupPlatform[gpKey]
for _, wc := range wildcards {
if strings.HasPrefix(modelLower, wc.prefix) {
return wc.pricing
}
}
return nil
}
// matchWildcardMapping 在通配符映射中查找匹配项(最先匹配到优先)
func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower string) string {
gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform}
wildcards := c.wildcardMappingByGP[gpKey]
for _, wc := range wildcards {
if strings.HasPrefix(modelLower, wc.prefix) {
return wc.target
}
}
return ""
}
// lookupPricingAcrossPlatforms 在所有匹配平台中查找模型定价。
// antigravity 分组的缓存 key 使用定价条目的原始平台,因此查找时需依次尝试
// matchingPlatforms() 返回的所有平台antigravity → anthropic → gemini
// 返回第一个命中的结果。非 antigravity 平台只尝试自身。
func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) *ChannelModelPricing {
for _, p := range matchingPlatforms(groupPlatform) {
key := channelModelKey{groupID: groupID, platform: p, model: modelLower}
if pricing, ok := cache.pricingByGroupModel[key]; ok {
return pricing
}
}
// 精确查找全部失败,依次尝试通配符匹配
for _, p := range matchingPlatforms(groupPlatform) {
if pricing := cache.matchWildcard(groupID, p, modelLower); pricing != nil {
return pricing
}
}
return nil
}
// lookupMappingAcrossPlatforms 在所有匹配平台中查找模型映射。
// 逻辑与 lookupPricingAcrossPlatforms 相同:先精确查找,再通配符。
func lookupMappingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) string {
for _, p := range matchingPlatforms(groupPlatform) {
key := channelModelKey{groupID: groupID, platform: p, model: modelLower}
if mapped, ok := cache.mappingByGroupModel[key]; ok {
return mapped
}
}
for _, p := range matchingPlatforms(groupPlatform) {
if mapped := cache.matchWildcardMapping(groupID, p, modelLower); mapped != "" {
return mapped
}
}
return ""
}
// GetChannelForGroup 获取分组关联的渠道(热路径 O(1)
func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) (*Channel, error) {
cache, err := s.loadCache(ctx)
if err != nil {
return nil, err
}
ch, ok := cache.channelByGroupID[groupID]
if !ok || !ch.IsActive() {
return nil, nil
}
return ch.Clone(), nil
}
// channelLookup 热路径公共查找结果
type channelLookup struct {
cache *channelCache
channel *Channel
platform string
}
// lookupGroupChannel 加载缓存并查找分组对应的渠道信息(公共热路径前置逻辑)。
// 返回 nil 且 err==nil 表示分组无活跃渠道err!=nil 表示缓存加载失败。
func (s *ChannelService) lookupGroupChannel(ctx context.Context, groupID int64) (*channelLookup, error) {
cache, err := s.loadCache(ctx)
if err != nil {
return nil, err
}
ch, ok := cache.channelByGroupID[groupID]
if !ok || !ch.IsActive() {
return nil, nil
}
return &channelLookup{
cache: cache,
channel: ch,
platform: cache.groupPlatform[groupID],
}, nil
}
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1))。
// antigravity 分组依次尝试所有匹配平台antigravity → anthropic → gemini
// 确保跨平台同名模型各自独立匹配。
func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing {
lk, err := s.lookupGroupChannel(ctx, groupID)
if err != nil {
slog.Warn("failed to load channel cache", "group_id", groupID, "error", err)
return nil
}
if lk == nil {
return nil
}
modelLower := strings.ToLower(model)
pricing := lookupPricingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower)
if pricing == nil {
return nil
}
cp := pricing.Clone()
return &cp
}
// ResolveChannelMapping 解析渠道级模型映射(热路径 O(1)
// 返回映射结果,包含映射后的模型名、渠道 ID、计费模型来源。
func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult {
lk, err := s.lookupGroupChannel(ctx, groupID)
if err != nil {
slog.Warn("failed to load channel cache for mapping", "group_id", groupID, "error", err)
}
if lk == nil {
return ChannelMappingResult{MappedModel: model}
}
return resolveMapping(lk, groupID, model)
}
// IsModelRestricted 检查模型是否被渠道限制。
// 返回 true 表示模型被限制(不在允许列表中)。
// 如果渠道未启用模型限制或分组无渠道关联,返回 false。
func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
lk, _ := s.lookupGroupChannel(ctx, groupID)
if lk == nil {
return false
}
return checkRestricted(lk, groupID, model)
}
// ResolveChannelMappingAndRestrict 解析渠道映射。
// 返回映射结果。模型限制检查已移至调度阶段GatewayService.checkChannelPricingRestriction
// restricted 始终返回 false保留签名兼容性。
func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
if groupID == nil {
return ChannelMappingResult{MappedModel: model}, false
}
lk, _ := s.lookupGroupChannel(ctx, *groupID)
if lk == nil {
return ChannelMappingResult{MappedModel: model}, false
}
return resolveMapping(lk, *groupID, model), false
}
// resolveMapping 基于已查找的渠道信息解析模型映射。
// antigravity 分组依次尝试所有匹配平台,确保跨平台同名映射各自独立。
func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappingResult {
result := ChannelMappingResult{
MappedModel: model,
ChannelID: lk.channel.ID,
BillingModelSource: lk.channel.BillingModelSource,
}
if result.BillingModelSource == "" {
result.BillingModelSource = BillingModelSourceChannelMapped
}
modelLower := strings.ToLower(model)
if mapped := lookupMappingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower); mapped != "" {
result.MappedModel = mapped
result.Mapped = true
}
return result
}
// checkRestricted 基于已查找的渠道信息检查模型是否被限制。
// antigravity 分组依次尝试所有匹配平台的定价列表。
func checkRestricted(lk *channelLookup, groupID int64, model string) bool {
if !lk.channel.RestrictModels {
return false
}
modelLower := strings.ToLower(model)
// 使用与查找定价相同的跨平台逻辑
if lookupPricingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower) != nil {
return false
}
return true
}
// ReplaceModelInBody 替换请求体 JSON 中的 model 字段。
func ReplaceModelInBody(body []byte, newModel string) []byte {
if len(body) == 0 {
return body
}
if current := gjson.GetBytes(body, "model"); current.Exists() && current.String() == newModel {
return body
}
newBody, err := sjson.SetBytes(body, "model", newModel)
if err != nil {
return body
}
return newBody
}
// --- CRUD ---
// Create 创建渠道
func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) (*Channel, error) {
exists, err := s.repo.ExistsByName(ctx, input.Name)
if err != nil {
return nil, fmt.Errorf("check channel exists: %w", err)
}
if exists {
return nil, ErrChannelExists
}
// 检查分组冲突
if len(input.GroupIDs) > 0 {
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, 0, input.GroupIDs)
if err != nil {
return nil, fmt.Errorf("check group conflicts: %w", err)
}
if len(conflicting) > 0 {
return nil, ErrGroupAlreadyInChannel
}
}
channel := &Channel{
Name: input.Name,
Description: input.Description,
Status: StatusActive,
BillingModelSource: input.BillingModelSource,
RestrictModels: input.RestrictModels,
GroupIDs: input.GroupIDs,
ModelPricing: input.ModelPricing,
ModelMapping: input.ModelMapping,
}
if channel.BillingModelSource == "" {
channel.BillingModelSource = BillingModelSourceChannelMapped
}
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
return nil, err
}
if err := validatePricingIntervals(channel.ModelPricing); err != nil {
return nil, err
}
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
return nil, err
}
if err := s.repo.Create(ctx, channel); err != nil {
return nil, fmt.Errorf("create channel: %w", err)
}
s.invalidateCache()
return s.repo.GetByID(ctx, channel.ID)
}
// GetByID 获取渠道详情
func (s *ChannelService) GetByID(ctx context.Context, id int64) (*Channel, error) {
return s.repo.GetByID(ctx, id)
}
// Update 更新渠道
func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChannelInput) (*Channel, error) {
channel, err := s.repo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get channel: %w", err)
}
if input.Name != "" && input.Name != channel.Name {
exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, id)
if err != nil {
return nil, fmt.Errorf("check channel exists: %w", err)
}
if exists {
return nil, ErrChannelExists
}
channel.Name = input.Name
}
if input.Description != nil {
channel.Description = *input.Description
}
if input.Status != "" {
channel.Status = input.Status
}
if input.RestrictModels != nil {
channel.RestrictModels = *input.RestrictModels
}
// 检查分组冲突
if input.GroupIDs != nil {
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, id, *input.GroupIDs)
if err != nil {
return nil, fmt.Errorf("check group conflicts: %w", err)
}
if len(conflicting) > 0 {
return nil, ErrGroupAlreadyInChannel
}
channel.GroupIDs = *input.GroupIDs
}
if input.ModelPricing != nil {
channel.ModelPricing = *input.ModelPricing
}
if input.ModelMapping != nil {
channel.ModelMapping = input.ModelMapping
}
if input.BillingModelSource != "" {
channel.BillingModelSource = input.BillingModelSource
}
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
return nil, err
}
if err := validatePricingIntervals(channel.ModelPricing); err != nil {
return nil, err
}
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
return nil, err
}
// 先获取旧分组Update 后旧分组关联已删除,无法再查到
var oldGroupIDs []int64
if s.authCacheInvalidator != nil {
var err2 error
oldGroupIDs, err2 = s.repo.GetGroupIDs(ctx, id)
if err2 != nil {
slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", id, "error", err2)
}
}
if err := s.repo.Update(ctx, channel); err != nil {
return nil, fmt.Errorf("update channel: %w", err)
}
s.invalidateCache()
// 失效新旧分组的 auth 缓存
if s.authCacheInvalidator != nil {
seen := make(map[int64]struct{}, len(oldGroupIDs)+len(channel.GroupIDs))
for _, gid := range oldGroupIDs {
if _, ok := seen[gid]; !ok {
seen[gid] = struct{}{}
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
}
}
for _, gid := range channel.GroupIDs {
if _, ok := seen[gid]; !ok {
seen[gid] = struct{}{}
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
}
}
}
return s.repo.GetByID(ctx, id)
}
// Delete 删除渠道
func (s *ChannelService) Delete(ctx context.Context, id int64) error {
// 先获取关联分组用于失效缓存
groupIDs, err := s.repo.GetGroupIDs(ctx, id)
if err != nil {
slog.Warn("failed to get group IDs before delete", "channel_id", id, "error", err)
}
if err := s.repo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete channel: %w", err)
}
s.invalidateCache()
if s.authCacheInvalidator != nil {
for _, gid := range groupIDs {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
}
}
return nil
}
// List 获取渠道列表
func (s *ChannelService) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error) {
return s.repo.List(ctx, params, status, search)
}
// modelEntry 表示一个模型模式条目(用于冲突检测)
type modelEntry struct {
pattern string // 原始模式(如 "claude-*" 或 "claude-opus-4"
prefix string // lowercase 前缀(通配符去掉 *,精确名保持原样)
wildcard bool
}
// conflictsBetween 检查两个模型模式是否冲突
func conflictsBetween(a, b modelEntry) bool {
switch {
case !a.wildcard && !b.wildcard:
return a.prefix == b.prefix
case a.wildcard && !b.wildcard:
return strings.HasPrefix(b.prefix, a.prefix)
case !a.wildcard && b.wildcard:
return strings.HasPrefix(a.prefix, b.prefix)
default:
return strings.HasPrefix(a.prefix, b.prefix) ||
strings.HasPrefix(b.prefix, a.prefix)
}
}
// toModelEntry 将模型名转换为 modelEntry
func toModelEntry(pattern string) modelEntry {
lower := strings.ToLower(pattern)
isWild := strings.HasSuffix(lower, "*")
prefix := lower
if isWild {
prefix = strings.TrimSuffix(lower, "*")
}
return modelEntry{pattern: pattern, prefix: prefix, wildcard: isWild}
}
// validateNoConflictingModels 检查定价列表中是否有冲突模型模式(同一平台下)。
// 冲突包括:精确重复、通配符之间的前缀包含、通配符与精确名的前缀匹配。
func validateNoConflictingModels(pricingList []ChannelModelPricing) error {
byPlatform := make(map[string][]modelEntry)
for _, p := range pricingList {
for _, model := range p.Models {
byPlatform[p.Platform] = append(byPlatform[p.Platform], toModelEntry(model))
}
}
for platform, entries := range byPlatform {
if err := detectConflicts(entries, platform, "MODEL_PATTERN_CONFLICT", "model patterns"); err != nil {
return err
}
}
return nil
}
// validateNoConflictingMappings 检查模型映射中是否有冲突的源模式
func validateNoConflictingMappings(mapping map[string]map[string]string) error {
for platform, platformMapping := range mapping {
entries := make([]modelEntry, 0, len(platformMapping))
for src := range platformMapping {
entries = append(entries, toModelEntry(src))
}
if err := detectConflicts(entries, platform, "MAPPING_PATTERN_CONFLICT", "mapping source patterns"); err != nil {
return err
}
}
return nil
}
func validatePricingIntervals(pricingList []ChannelModelPricing) error {
for _, pricing := range pricingList {
if err := ValidateIntervals(pricing.Intervals); err != nil {
return infraerrors.BadRequest(
"INVALID_PRICING_INTERVALS",
fmt.Sprintf("invalid pricing intervals for platform '%s' models %v: %v",
pricing.Platform, pricing.Models, err),
)
}
}
return nil
}
// detectConflicts 在一组 modelEntry 中检测冲突,返回带有 errCode 和 label 的错误
func detectConflicts(entries []modelEntry, platform, errCode, label string) error {
for i := 0; i < len(entries); i++ {
for j := i + 1; j < len(entries); j++ {
if conflictsBetween(entries[i], entries[j]) {
return infraerrors.BadRequest(errCode,
fmt.Sprintf("%s '%s' and '%s' conflict in platform '%s': overlapping match range",
label, entries[i].pattern, entries[j].pattern, platform))
}
}
}
return nil
}
// --- Input types ---
// CreateChannelInput 创建渠道输入
type CreateChannelInput struct {
Name string
Description string
GroupIDs []int64
ModelPricing []ChannelModelPricing
ModelMapping map[string]map[string]string // platform → {src→dst}
BillingModelSource string
RestrictModels bool
}
// UpdateChannelInput 更新渠道输入
type UpdateChannelInput struct {
Name string
Description *string
Status string
GroupIDs *[]int64
ModelPricing *[]ChannelModelPricing
ModelMapping map[string]map[string]string // platform → {src→dst}
BillingModelSource string
RestrictModels *bool
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,435 @@
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestGetModelPricing(t *testing.T) {
ch := &Channel{
ModelPricing: []ChannelModelPricing{
{ID: 1, Models: []string{"claude-sonnet-4"}, BillingMode: BillingModeToken, InputPrice: testPtrFloat64(3e-6)},
{ID: 3, Models: []string{"gpt-5.1"}, BillingMode: BillingModePerRequest},
},
}
tests := []struct {
name string
model string
wantID int64
wantNil bool
}{
{"exact match", "claude-sonnet-4", 1, false},
{"case insensitive", "Claude-Sonnet-4", 1, false},
{"not found", "gemini-3.1-pro", 0, true},
{"wildcard pattern not matched", "claude-opus-4-20250514", 0, true},
{"per_request model", "gpt-5.1", 3, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ch.GetModelPricing(tt.model)
if tt.wantNil {
require.Nil(t, result)
return
}
require.NotNil(t, result)
require.Equal(t, tt.wantID, result.ID)
})
}
}
func TestGetModelPricing_ReturnsCopy(t *testing.T) {
ch := &Channel{
ModelPricing: []ChannelModelPricing{
{ID: 1, Models: []string{"claude-sonnet-4"}, InputPrice: testPtrFloat64(3e-6)},
},
}
result := ch.GetModelPricing("claude-sonnet-4")
require.NotNil(t, result)
// Modify the returned copy's slice — original should be unchanged
result.Models = append(result.Models, "hacked")
// Original should be unchanged
require.Equal(t, 1, len(ch.ModelPricing[0].Models))
}
func TestGetModelPricing_EmptyPricing(t *testing.T) {
ch := &Channel{ModelPricing: nil}
require.Nil(t, ch.GetModelPricing("any-model"))
ch2 := &Channel{ModelPricing: []ChannelModelPricing{}}
require.Nil(t, ch2.GetModelPricing("any-model"))
}
func TestGetIntervalForContext(t *testing.T) {
p := &ChannelModelPricing{
Intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)},
},
}
tests := []struct {
name string
tokens int
wantPrice *float64
wantNil bool
}{
{"first interval", 50000, testPtrFloat64(1e-6), false},
// (min, max] — 128000 在第一个区间的 max包含所以匹配第一个
{"boundary: max of first (inclusive)", 128000, testPtrFloat64(1e-6), false},
// 128001 > 128000匹配第二个区间
{"boundary: just above first max", 128001, testPtrFloat64(2e-6), false},
{"unbounded interval", 500000, testPtrFloat64(2e-6), false},
// (0, max] — 0 不匹配任何区间(左开)
{"zero tokens: no match", 0, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := p.GetIntervalForContext(tt.tokens)
if tt.wantNil {
require.Nil(t, result)
return
}
require.NotNil(t, result)
require.InDelta(t, *tt.wantPrice, *result.InputPrice, 1e-12)
})
}
}
func TestGetIntervalForContext_NoMatch(t *testing.T) {
p := &ChannelModelPricing{
Intervals: []PricingInterval{
{MinTokens: 10000, MaxTokens: testPtrInt(50000)},
},
}
require.Nil(t, p.GetIntervalForContext(5000)) // 5000 <= 10000, not > min
require.Nil(t, p.GetIntervalForContext(10000)) // 10000 not > 10000 (left-open)
require.NotNil(t, p.GetIntervalForContext(50000)) // 50000 <= 50000 (right-closed)
require.Nil(t, p.GetIntervalForContext(50001)) // 50001 > 50000
}
func TestGetIntervalForContext_Empty(t *testing.T) {
p := &ChannelModelPricing{Intervals: nil}
require.Nil(t, p.GetIntervalForContext(1000))
}
func TestGetTierByLabel(t *testing.T) {
p := &ChannelModelPricing{
Intervals: []PricingInterval{
{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
{TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)},
{TierLabel: "HD", PerRequestPrice: testPtrFloat64(0.12)},
},
}
tests := []struct {
name string
label string
wantNil bool
want float64
}{
{"exact match", "1K", false, 0.04},
{"case insensitive", "hd", false, 0.12},
{"not found", "4K", true, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := p.GetTierByLabel(tt.label)
if tt.wantNil {
require.Nil(t, result)
return
}
require.NotNil(t, result)
require.InDelta(t, tt.want, *result.PerRequestPrice, 1e-12)
})
}
}
func TestGetTierByLabel_Empty(t *testing.T) {
p := &ChannelModelPricing{Intervals: nil}
require.Nil(t, p.GetTierByLabel("1K"))
}
func TestChannelClone(t *testing.T) {
original := &Channel{
ID: 1,
Name: "test",
GroupIDs: []int64{10, 20},
ModelPricing: []ChannelModelPricing{
{
ID: 100,
Models: []string{"model-a"},
InputPrice: testPtrFloat64(5e-6),
},
},
}
cloned := original.Clone()
require.NotNil(t, cloned)
require.Equal(t, original.ID, cloned.ID)
require.Equal(t, original.Name, cloned.Name)
// Modify clone slices — original should not change
cloned.GroupIDs[0] = 999
require.Equal(t, int64(10), original.GroupIDs[0])
cloned.ModelPricing[0].Models[0] = "hacked"
require.Equal(t, "model-a", original.ModelPricing[0].Models[0])
}
func TestChannelClone_Nil(t *testing.T) {
var ch *Channel
require.Nil(t, ch.Clone())
}
func TestChannelModelPricingClone(t *testing.T) {
original := ChannelModelPricing{
Models: []string{"a", "b"},
Intervals: []PricingInterval{
{MinTokens: 0, TierLabel: "tier1"},
},
}
cloned := original.Clone()
// Modify clone slices — original unchanged
cloned.Models[0] = "hacked"
require.Equal(t, "a", original.Models[0])
cloned.Intervals[0].TierLabel = "hacked"
require.Equal(t, "tier1", original.Intervals[0].TierLabel)
}
// --- BillingMode.IsValid ---
func TestBillingModeIsValid(t *testing.T) {
tests := []struct {
name string
mode BillingMode
want bool
}{
{"token", BillingModeToken, true},
{"per_request", BillingModePerRequest, true},
{"image", BillingModeImage, true},
{"empty", BillingMode(""), true},
{"unknown", BillingMode("unknown"), false},
{"random", BillingMode("xyz"), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, tt.mode.IsValid())
})
}
}
// --- Channel.IsActive ---
func TestChannelIsActive(t *testing.T) {
tests := []struct {
name string
status string
want bool
}{
{"active", StatusActive, true},
{"disabled", "disabled", false},
{"empty", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ch := &Channel{Status: tt.status}
require.Equal(t, tt.want, ch.IsActive())
})
}
}
// --- ChannelModelPricing.Clone edge cases ---
func TestChannelModelPricingClone_EdgeCases(t *testing.T) {
t.Run("nil models", func(t *testing.T) {
original := ChannelModelPricing{Models: nil}
cloned := original.Clone()
require.Nil(t, cloned.Models)
})
t.Run("nil intervals", func(t *testing.T) {
original := ChannelModelPricing{Intervals: nil}
cloned := original.Clone()
require.Nil(t, cloned.Intervals)
})
t.Run("empty models", func(t *testing.T) {
original := ChannelModelPricing{Models: []string{}}
cloned := original.Clone()
require.NotNil(t, cloned.Models)
require.Empty(t, cloned.Models)
})
}
// --- Channel.Clone edge cases ---
func TestChannelClone_EdgeCases(t *testing.T) {
t.Run("nil model mapping", func(t *testing.T) {
original := &Channel{ID: 1, ModelMapping: nil}
cloned := original.Clone()
require.Nil(t, cloned.ModelMapping)
})
t.Run("nil model pricing", func(t *testing.T) {
original := &Channel{ID: 1, ModelPricing: nil}
cloned := original.Clone()
require.Nil(t, cloned.ModelPricing)
})
t.Run("deep copy model mapping", func(t *testing.T) {
original := &Channel{
ID: 1,
ModelMapping: map[string]map[string]string{
"openai": {"gpt-4": "gpt-4-turbo"},
},
}
cloned := original.Clone()
// Modify the cloned nested map
cloned.ModelMapping["openai"]["gpt-4"] = "hacked"
// Original must remain unchanged
require.Equal(t, "gpt-4-turbo", original.ModelMapping["openai"]["gpt-4"])
})
}
// --- ValidateIntervals ---
func TestValidateIntervals_Empty(t *testing.T) {
require.NoError(t, ValidateIntervals(nil))
require.NoError(t, ValidateIntervals([]PricingInterval{}))
}
func TestValidateIntervals_ValidIntervals(t *testing.T) {
tests := []struct {
name string
intervals []PricingInterval
}{
{
name: "single bounded interval",
intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
},
},
{
name: "two intervals with gap",
intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(100000), InputPrice: testPtrFloat64(1e-6)},
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)},
},
},
{
name: "two contiguous intervals",
intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)},
},
},
{
name: "unsorted input (auto-sorted by validator)",
intervals: []PricingInterval{
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)},
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
},
},
{
name: "single unbounded interval",
intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: nil, InputPrice: testPtrFloat64(1e-6)},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.NoError(t, ValidateIntervals(tt.intervals))
})
}
}
func TestValidateIntervals_NegativeMinTokens(t *testing.T) {
intervals := []PricingInterval{
{MinTokens: -1, MaxTokens: testPtrInt(100), InputPrice: testPtrFloat64(1e-6)},
}
err := ValidateIntervals(intervals)
require.Error(t, err)
require.Contains(t, err.Error(), "min_tokens")
require.Contains(t, err.Error(), ">= 0")
}
func TestValidateIntervals_MaxTokensZero(t *testing.T) {
intervals := []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(0), InputPrice: testPtrFloat64(1e-6)},
}
err := ValidateIntervals(intervals)
require.Error(t, err)
require.Contains(t, err.Error(), "max_tokens")
require.Contains(t, err.Error(), "> 0")
}
func TestValidateIntervals_MaxLessThanMin(t *testing.T) {
intervals := []PricingInterval{
{MinTokens: 100, MaxTokens: testPtrInt(50), InputPrice: testPtrFloat64(1e-6)},
}
err := ValidateIntervals(intervals)
require.Error(t, err)
require.Contains(t, err.Error(), "max_tokens")
require.Contains(t, err.Error(), "> min_tokens")
}
func TestValidateIntervals_MaxEqualsMin(t *testing.T) {
intervals := []PricingInterval{
{MinTokens: 100, MaxTokens: testPtrInt(100), InputPrice: testPtrFloat64(1e-6)},
}
err := ValidateIntervals(intervals)
require.Error(t, err)
require.Contains(t, err.Error(), "max_tokens")
require.Contains(t, err.Error(), "> min_tokens")
}
func TestValidateIntervals_NegativePrice(t *testing.T) {
negPrice := -0.01
intervals := []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(100), InputPrice: &negPrice},
}
err := ValidateIntervals(intervals)
require.Error(t, err)
require.Contains(t, err.Error(), "input_price")
require.Contains(t, err.Error(), ">= 0")
}
func TestValidateIntervals_OverlappingIntervals(t *testing.T) {
intervals := []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(200), InputPrice: testPtrFloat64(1e-6)},
{MinTokens: 100, MaxTokens: testPtrInt(300), InputPrice: testPtrFloat64(2e-6)},
}
err := ValidateIntervals(intervals)
require.Error(t, err)
require.Contains(t, err.Error(), "overlap")
}
func TestValidateIntervals_UnboundedNotLast(t *testing.T) {
intervals := []PricingInterval{
{MinTokens: 0, MaxTokens: nil, InputPrice: testPtrFloat64(1e-6)},
{MinTokens: 128000, MaxTokens: testPtrInt(256000), InputPrice: testPtrFloat64(2e-6)},
}
err := ValidateIntervals(intervals)
require.Error(t, err)
require.Contains(t, err.Error(), "unbounded")
require.Contains(t, err.Error(), "last")
}

View File

@@ -0,0 +1,130 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
func TestSelectAccountForModelWithExclusions_UsesFallbackGroupForChannelRestriction(t *testing.T) {
t.Parallel()
groupID := int64(10)
fallbackID := int64(11)
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{fallbackID},
RestrictModels: true,
ModelPricing: []ChannelModelPricing{
{Platform: PlatformAnthropic, Models: []string{"claude-sonnet-4-6"}},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{
fallbackID: PlatformAnthropic,
}))
accountRepo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range accountRepo.accounts {
accountRepo.accountsByID[accountRepo.accounts[i].ID] = &accountRepo.accounts[i]
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
ClaudeCodeOnly: true,
FallbackGroupID: &fallbackID,
Hydrated: true,
},
fallbackID: {
ID: fallbackID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
},
},
}
svc := &GatewayService{
accountRepo: accountRepo,
groupRepo: groupRepo,
channelService: channelSvc,
cfg: testConfig(),
}
ctx := context.WithValue(context.Background(), ctxkey.Group, groupRepo.groups[groupID])
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-sonnet-4-6", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, int64(1), account.ID)
}
func TestSelectAccountWithLoadAwareness_UsesFallbackGroupForChannelRestriction(t *testing.T) {
t.Parallel()
groupID := int64(10)
fallbackID := int64(11)
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{fallbackID},
RestrictModels: true,
ModelPricing: []ChannelModelPricing{
{Platform: PlatformAnthropic, Models: []string{"claude-sonnet-4-6"}},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{
fallbackID: PlatformAnthropic,
}))
accountRepo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range accountRepo.accounts {
accountRepo.accountsByID[accountRepo.accounts[i].ID] = &accountRepo.accounts[i]
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
ClaudeCodeOnly: true,
FallbackGroupID: &fallbackID,
Hydrated: true,
},
fallbackID: {
ID: fallbackID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
},
},
}
svc := &GatewayService{
accountRepo: accountRepo,
groupRepo: groupRepo,
channelService: channelSvc,
cfg: testConfig(),
}
ctx := context.WithValue(context.Background(), ctxkey.Group, groupRepo.groups[groupID])
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-sonnet-4-6", nil, "", 0)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(1), result.Account.ID)
}

View File

@@ -0,0 +1,293 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
// --- billingModelForRestriction ---
func TestBillingModelForRestriction_Requested(t *testing.T) {
t.Parallel()
got := billingModelForRestriction(BillingModelSourceRequested, "claude-sonnet-4-5", "claude-sonnet-4-6")
require.Equal(t, "claude-sonnet-4-5", got)
}
func TestBillingModelForRestriction_ChannelMapped(t *testing.T) {
t.Parallel()
got := billingModelForRestriction(BillingModelSourceChannelMapped, "claude-sonnet-4-5", "claude-sonnet-4-6")
require.Equal(t, "claude-sonnet-4-6", got)
}
func TestBillingModelForRestriction_Upstream(t *testing.T) {
t.Parallel()
got := billingModelForRestriction(BillingModelSourceUpstream, "claude-sonnet-4-5", "claude-sonnet-4-6")
require.Equal(t, "", got, "upstream should return empty (per-account check needed)")
}
func TestBillingModelForRestriction_Empty(t *testing.T) {
t.Parallel()
got := billingModelForRestriction("", "claude-sonnet-4-5", "claude-sonnet-4-6")
require.Equal(t, "claude-sonnet-4-6", got, "empty source defaults to channel_mapped")
}
// --- resolveAccountUpstreamModel ---
func TestResolveAccountUpstreamModel_Antigravity(t *testing.T) {
t.Parallel()
account := &Account{
Platform: PlatformAntigravity,
}
// Antigravity 平台使用 DefaultAntigravityModelMapping
got := resolveAccountUpstreamModel(account, "claude-sonnet-4-6")
require.Equal(t, "claude-sonnet-4-6", got)
}
func TestResolveAccountUpstreamModel_Antigravity_Unsupported(t *testing.T) {
t.Parallel()
account := &Account{
Platform: PlatformAntigravity,
}
got := resolveAccountUpstreamModel(account, "totally-unknown-model")
require.Equal(t, "", got, "unsupported model should return empty")
}
func TestResolveAccountUpstreamModel_NonAntigravity(t *testing.T) {
t.Parallel()
account := &Account{
Platform: PlatformAnthropic,
}
got := resolveAccountUpstreamModel(account, "claude-sonnet-4-6")
require.Equal(t, "claude-sonnet-4-6", got, "no mapping = passthrough")
}
// --- checkChannelPricingRestriction ---
func TestCheckChannelPricingRestriction_NilGroupID(t *testing.T) {
t.Parallel()
svc := &GatewayService{channelService: &ChannelService{}}
require.False(t, svc.checkChannelPricingRestriction(context.Background(), nil, "claude-sonnet-4"))
}
func TestCheckChannelPricingRestriction_NilChannelService(t *testing.T) {
t.Parallel()
svc := &GatewayService{}
gid := int64(10)
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4"))
}
func TestCheckChannelPricingRestriction_EmptyModel(t *testing.T) {
t.Parallel()
svc := &GatewayService{channelService: &ChannelService{}}
gid := int64(10)
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, ""))
}
func TestCheckChannelPricingRestriction_ChannelMapped_Restricted(t *testing.T) {
t.Parallel()
// 渠道映射 claude-sonnet-4-5 → claude-sonnet-4-6但定价列表只有 claude-opus-4-6
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
BillingModelSource: BillingModelSourceChannelMapped,
ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
},
ModelMapping: map[string]map[string]string{
"anthropic": {"claude-sonnet-4-5": "claude-sonnet-4-6"},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
svc := &GatewayService{channelService: channelSvc}
gid := int64(10)
require.True(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"),
"mapped model claude-sonnet-4-6 is NOT in pricing → restricted")
}
func TestCheckChannelPricingRestriction_ChannelMapped_Allowed(t *testing.T) {
t.Parallel()
// 渠道映射 claude-sonnet-4-5 → claude-sonnet-4-6定价列表包含 claude-sonnet-4-6
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
BillingModelSource: BillingModelSourceChannelMapped,
ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}},
},
ModelMapping: map[string]map[string]string{
"anthropic": {"claude-sonnet-4-5": "claude-sonnet-4-6"},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
svc := &GatewayService{channelService: channelSvc}
gid := int64(10)
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"),
"mapped model claude-sonnet-4-6 IS in pricing → allowed")
}
func TestCheckChannelPricingRestriction_Requested_Restricted(t *testing.T) {
t.Parallel()
// billing_model_source=requested定价列表有 claude-sonnet-4-6 但请求的是 claude-sonnet-4-5
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
BillingModelSource: BillingModelSourceRequested,
ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
svc := &GatewayService{channelService: channelSvc}
gid := int64(10)
require.True(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"),
"requested model claude-sonnet-4-5 is NOT in pricing → restricted")
}
func TestCheckChannelPricingRestriction_Requested_Allowed(t *testing.T) {
t.Parallel()
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
BillingModelSource: BillingModelSourceRequested,
ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-sonnet-4-5"}},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
svc := &GatewayService{channelService: channelSvc}
gid := int64(10)
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"),
"requested model IS in pricing → allowed")
}
func TestCheckChannelPricingRestriction_Upstream_SkipsPreCheck(t *testing.T) {
t.Parallel()
// upstream 模式:预检查始终跳过(返回 false需逐账号检查
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
BillingModelSource: BillingModelSourceUpstream,
ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
svc := &GatewayService{channelService: channelSvc}
gid := int64(10)
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "unknown-model"),
"upstream mode should skip pre-check (per-account check needed)")
}
func TestCheckChannelPricingRestriction_RestrictModelsDisabled(t *testing.T) {
t.Parallel()
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: false, // 未开启模型限制
ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
svc := &GatewayService{channelService: channelSvc}
gid := int64(10)
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "any-model"),
"RestrictModels=false → always allowed")
}
func TestCheckChannelPricingRestriction_NoChannel(t *testing.T) {
t.Parallel()
// 分组没有关联渠道
repo := &mockChannelRepository{
listAllFn: func(_ context.Context) ([]Channel, error) { return nil, nil },
}
channelSvc := newTestChannelService(repo)
svc := &GatewayService{channelService: channelSvc}
gid := int64(999)
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "any-model"),
"no channel for group → allowed")
}
// --- isUpstreamModelRestrictedByChannel ---
func TestIsUpstreamModelRestrictedByChannel_Restricted(t *testing.T) {
t.Parallel()
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
svc := &GatewayService{channelService: channelSvc}
account := &Account{Platform: PlatformAntigravity}
// claude-sonnet-4-6 在 DefaultAntigravityModelMapping 中,映射后仍为 claude-sonnet-4-6
// 但定价列表只有 claude-opus-4-6
require.True(t, svc.isUpstreamModelRestrictedByChannel(context.Background(), 10, account, "claude-sonnet-4-6"),
"upstream model claude-sonnet-4-6 NOT in pricing → restricted")
}
func TestIsUpstreamModelRestrictedByChannel_Allowed(t *testing.T) {
t.Parallel()
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
svc := &GatewayService{channelService: channelSvc}
account := &Account{Platform: PlatformAntigravity}
require.False(t, svc.isUpstreamModelRestrictedByChannel(context.Background(), 10, account, "claude-sonnet-4-6"),
"upstream model claude-sonnet-4-6 IS in pricing → allowed")
}
func TestIsUpstreamModelRestrictedByChannel_UnsupportedModel(t *testing.T) {
t.Parallel()
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
svc := &GatewayService{channelService: channelSvc}
account := &Account{Platform: PlatformAntigravity}
// totally-unknown-model 不在 DefaultAntigravityModelMapping 中 → 映射结果为空
require.False(t, svc.isUpstreamModelRestrictedByChannel(context.Background(), 10, account, "totally-unknown-model"),
"unmappable model → upstream model empty → not restricted (account filter handles this)")
}

View File

@@ -732,7 +732,7 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) {
modelsListCacheTTL: time.Minute,
}
result, err := svc.SelectAccountWithLoadAwareness(baseCtx, nil, "sess-hash", "", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(baseCtx, nil, "sess-hash", "", nil, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -754,7 +754,7 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) {
ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, account.ID)
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(0))
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -776,7 +776,7 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) {
ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, int64(999))
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(77))
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)

View File

@@ -2031,7 +2031,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil, // No concurrency service
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -2084,7 +2084,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil, // legacy path
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -2116,7 +2116,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -2148,7 +2148,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
}
excludedIDs := map[int64]struct{}{1: {}}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -2182,7 +2182,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -2218,7 +2218,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -2259,7 +2259,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(testCtx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(testCtx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -2287,7 +2287,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.Error(t, err)
require.Nil(t, result)
require.ErrorIs(t, err, ErrNoAvailableAccounts)
@@ -2319,7 +2319,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -2352,7 +2352,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -2390,7 +2390,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.WaitPlan)
@@ -2426,7 +2426,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "legacy", "claude-3-5-sonnet-20241022", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "legacy", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -2485,7 +2485,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.WaitPlan)
@@ -2539,7 +2539,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -2593,7 +2593,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -2651,7 +2651,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route", "claude-3-5-sonnet-20241022", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -2709,7 +2709,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route-full", "claude-3-5-sonnet-20241022", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route-full", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.WaitPlan)
@@ -2767,7 +2767,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "fallback", "claude-3-5-sonnet-20241022", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "fallback", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -2804,7 +2804,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.WaitPlan)
@@ -2856,7 +2856,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "gemini", "gemini-2.5-pro", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "gemini", "gemini-2.5-pro", nil, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -2934,7 +2934,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
}
excluded := map[int64]struct{}{1: {}}
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", excluded, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", excluded, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -2988,7 +2988,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "gemini-2.5-pro", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "gemini-2.5-pro", nil, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -3021,7 +3021,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.Error(t, err)
require.Nil(t, result)
require.ErrorIs(t, err, ErrClaudeCodeOnly)
@@ -3059,7 +3059,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "wait", "claude-3-5-sonnet-20241022", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "wait", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.WaitPlan)
@@ -3097,7 +3097,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "missing-load", "claude-3-5-sonnet-20241022", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "missing-load", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)

View File

@@ -41,6 +41,8 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
nil,
nil,
nil,
nil,
nil,
)
}

View File

@@ -60,6 +60,13 @@ const (
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
)
// MediaType 媒体类型常量
const (
MediaTypeImage = "image"
MediaTypeVideo = "video"
MediaTypePrompt = "prompt"
)
// ForceCacheBillingContextKey 强制缓存计费上下文键
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
type forceCacheBillingKeyType struct{}
@@ -483,6 +490,7 @@ type ClaudeUsage struct {
CacheReadInputTokens int `json:"cache_read_input_tokens"`
CacheCreation5mTokens int // 5分钟缓存创建token来自嵌套 cache_creation 对象)
CacheCreation1hTokens int // 1小时缓存创建token来自嵌套 cache_creation 对象)
ImageOutputTokens int `json:"image_output_tokens,omitempty"`
}
// ForwardResult 转发结果
@@ -568,6 +576,8 @@ type GatewayService struct {
responseHeaderFilter *responseheaders.CompiledHeaderFilter
debugModelRouting atomic.Bool
debugClaudeMimic atomic.Bool
channelService *ChannelService
resolver *ModelPricingResolver
debugGatewayBodyFile atomic.Pointer[os.File] // non-nil when SUB2API_DEBUG_GATEWAY_BODY is set
tlsFPProfileService *TLSFingerprintProfileService
}
@@ -597,6 +607,8 @@ func NewGatewayService(
digestStore *DigestSessionStore,
settingService *SettingService,
tlsFPProfileService *TLSFingerprintProfileService,
channelService *ChannelService,
resolver *ModelPricingResolver,
) *GatewayService {
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
modelsListTTL := resolveModelsListCacheTTL(cfg)
@@ -629,6 +641,8 @@ func NewGatewayService(
modelsListCacheTTL: modelsListTTL,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
tlsFPProfileService: tlsFPProfileService,
channelService: channelService,
resolver: resolver,
}
svc.userGroupRateResolver = newUserGroupRateResolver(
userGroupRateRepo,
@@ -866,17 +880,7 @@ type anthropicMetadataPayload struct {
// replaceModelInBody 替换请求体中的model字段
// 优先使用定点修改,尽量保持客户端原始字段顺序。
func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte {
if len(body) == 0 {
return body
}
if current := gjson.GetBytes(body, "model"); current.Exists() && current.String() == newModel {
return body
}
newBody, err := sjson.SetBytes(body, "model", newModel)
if err != nil {
return body
}
return newBody
return ReplaceModelInBody(body, newModel)
}
type claudeOAuthNormalizeOptions struct {
@@ -1186,6 +1190,15 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
platform = PlatformAnthropic
}
// Claude Code 限制可能已将 groupID 解析为 fallback group
// 渠道限制预检查必须使用解析后的分组。
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
slog.Warn("channel pricing restriction blocked request",
"group_id", derefGroupID(groupID),
"model", requestedModel)
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
}
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
// 注意:强制平台模式不走混合调度
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
@@ -1198,8 +1211,10 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
}
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
// metadataUserID: 已废弃参数,会话限制现在统一使用 sessionHash
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) {
// 调度流程文档见 docs/ACCOUNT_SCHEDULING_FLOW.md 。
// metadataUserID: 用于客户端亲和调度,从中提取客户端 ID
// sub2apiUserID: 系统用户 ID用于二维亲和调度
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string, sub2apiUserID int64) (*AccountSelectionResult, error) {
// 调试日志:记录调度入口参数
excludedIDsList := make([]int64, 0, len(excludedIDs))
for id := range excludedIDs {
@@ -1220,6 +1235,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
ctx = s.withGroupContext(ctx, group)
// Claude Code 限制可能已将 groupID 解析为 fallback group
// 渠道限制预检查必须使用解析后的分组。
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
slog.Warn("channel pricing restriction blocked request",
"group_id", derefGroupID(groupID),
"model", requestedModel)
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
}
var stickyAccountID int64
if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 {
stickyAccountID = prefetch
@@ -2945,6 +2969,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
ctx = s.withRPMPrefetch(ctx, accounts)
// 3. 按优先级+最久未用选择(考虑模型支持)
// needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查,
// 因为粘性会话优先保持连接一致性,且 upstream 计费基准极少使用。
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
var selected *Account
for i := range accounts {
acc := &accounts[i]
@@ -2965,6 +2992,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) {
continue
}
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
@@ -3197,6 +3227,8 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
ctx = s.withRPMPrefetch(ctx, accounts)
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
// needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查。
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
var selected *Account
for i := range accounts {
acc := &accounts[i]
@@ -3221,6 +3253,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) {
continue
}
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
@@ -7410,6 +7445,8 @@ type RecordUsageInput struct {
RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
APIKeyService APIKeyQuotaUpdater // 可选用于更新API Key配额
ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析)
}
// APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage
@@ -7439,6 +7476,18 @@ type postUsageBillingParams struct {
APIKeyService APIKeyQuotaUpdater
}
func (p *postUsageBillingParams) shouldDeductAPIKeyQuota() bool {
return p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil
}
func (p *postUsageBillingParams) shouldUpdateRateLimits() bool {
return p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil
}
func (p *postUsageBillingParams) shouldUpdateAccountQuota() bool {
return p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit()
}
// postUsageBilling 统一处理使用量记录后的扣费逻辑:
// - 订阅/余额扣费
// - API Key 配额更新
@@ -7468,21 +7517,21 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
}
// 2. API Key 配额
if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil {
if p.shouldDeductAPIKeyQuota() {
if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err)
}
}
// 3. API Key 限速用量
if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
if p.shouldUpdateRateLimits() {
if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err)
}
}
// 4. 账号配额用量账号口径TotalCost × 账号计费倍率)
if cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() {
if p.shouldUpdateAccountQuota() {
accountCost := cost.TotalCost * p.AccountRateMultiplier
if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil {
slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
@@ -7564,13 +7613,13 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
cmd.BalanceCost = p.Cost.ActualCost
}
if p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil {
if p.shouldDeductAPIKeyQuota() {
cmd.APIKeyQuotaCost = p.Cost.ActualCost
}
if p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
if p.shouldUpdateRateLimits() {
cmd.APIKeyRateLimitCost = p.Cost.ActualCost
}
if p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() {
if p.shouldUpdateAccountQuota() {
cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier
}
@@ -7694,191 +7743,41 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage
}
}
// recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。
type recordUsageOpts struct {
// Claude Max 策略所需的 ParsedRequest可选仅 Claude 路径传入)
ParsedRequest *ParsedRequest
// EnableClaudePath 启用 Claude 路径特有逻辑:
// - Claude Max 缓存计费策略
// - Sora 媒体类型分支image/video/prompt
// - MediaType 字段写入使用日志
EnableClaudePath bool
// 长上下文计费(仅 Gemini 路径需要)
LongContextThreshold int
LongContextMultiplier float64
}
// RecordUsage 记录使用量并扣费(或更新订阅用量)
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
result := input.Result
apiKey := input.APIKey
user := input.User
account := input.Account
subscription := input.Subscription
// 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens
// 用于粘性会话切换时的特殊计费处理
if input.ForceCacheBilling && result.Usage.InputTokens > 0 {
logger.LegacyPrintf("service.gateway", "force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)",
result.Usage.InputTokens, account.ID)
result.Usage.CacheReadInputTokens += result.Usage.InputTokens
result.Usage.InputTokens = 0
}
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
cacheTTLOverridden := false
if account.IsCacheTTLOverrideEnabled() {
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
}
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier := 1.0
if s.cfg != nil {
multiplier = s.cfg.Default.RateMultiplier
}
if apiKey.GroupID != nil && apiKey.Group != nil {
groupDefault := apiKey.Group.RateMultiplier
multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault)
}
var cost *CostBreakdown
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
// 根据请求类型选择计费方式
if result.MediaType == "image" || result.MediaType == "video" {
var soraConfig *SoraPriceConfig
if apiKey.Group != nil {
soraConfig = &SoraPriceConfig{
ImagePrice360: apiKey.Group.SoraImagePrice360,
ImagePrice540: apiKey.Group.SoraImagePrice540,
VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
}
}
if result.MediaType == "image" {
cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
} else {
cost = s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
}
} else if result.MediaType == "prompt" {
cost = &CostBreakdown{}
} else if result.ImageCount > 0 {
// 图片生成计费
var groupConfig *ImagePriceConfig
if apiKey.Group != nil {
groupConfig = &ImagePriceConfig{
Price1K: apiKey.Group.ImagePrice1K,
Price2K: apiKey.Group.ImagePrice2K,
Price4K: apiKey.Group.ImagePrice4K,
}
}
cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
} else {
// Token 计费
tokens := UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
}
var err error
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
cost = &CostBreakdown{ActualCost: 0}
}
}
// 判断计费方式:订阅模式 vs 余额模式
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
billingType := BillingTypeBalance
if isSubscriptionBilling {
billingType = BillingTypeSubscription
}
// 创建使用日志
durationMs := int(result.Duration.Milliseconds())
var imageSize *string
if result.ImageSize != "" {
imageSize = &result.ImageSize
}
var mediaType *string
if strings.TrimSpace(result.MediaType) != "" {
mediaType = &result.MediaType
}
accountRateMultiplier := account.BillingRateMultiplier()
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
usageLog := &UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: result.Model,
RequestedModel: result.Model,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
InputCost: cost.InputCost,
OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost,
ActualCost: cost.ActualCost,
RateMultiplier: multiplier,
AccountRateMultiplier: &accountRateMultiplier,
BillingType: billingType,
Stream: result.Stream,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
ImageCount: result.ImageCount,
ImageSize: imageSize,
MediaType: mediaType,
CacheTTLOverridden: cacheTTLOverridden,
CreatedAt: time.Now(),
}
// 添加 UserAgent
if input.UserAgent != "" {
usageLog.UserAgent = &input.UserAgent
}
// 添加 IPAddress
if input.IPAddress != "" {
usageLog.IPAddress = &input.IPAddress
}
// 添加分组和订阅关联
if apiKey.GroupID != nil {
usageLog.GroupID = apiKey.GroupID
}
if subscription != nil {
usageLog.SubscriptionID = &subscription.ID
}
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
billingErr := func() error {
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
Cost: cost,
User: user,
APIKey: apiKey,
Account: account,
Subscription: subscription,
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
IsSubscriptionBill: isSubscriptionBilling,
AccountRateMultiplier: accountRateMultiplier,
APIKeyService: input.APIKeyService,
}, s.billingDeps(), s.usageBillingRepo)
return err
}()
if billingErr != nil {
return billingErr
}
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
return nil
return s.recordUsageCore(ctx, &recordUsageCoreInput{
Result: input.Result,
APIKey: input.APIKey,
User: input.User,
Account: input.Account,
Subscription: input.Subscription,
InboundEndpoint: input.InboundEndpoint,
UpstreamEndpoint: input.UpstreamEndpoint,
UserAgent: input.UserAgent,
IPAddress: input.IPAddress,
RequestPayloadHash: input.RequestPayloadHash,
ForceCacheBilling: input.ForceCacheBilling,
APIKeyService: input.APIKeyService,
ChannelUsageFields: input.ChannelUsageFields,
}, &recordUsageOpts{
EnableClaudePath: true,
})
}
// RecordUsageLongContextInput 记录使用量的输入参数(支持长上下文双倍计费)
@@ -7897,10 +7796,55 @@ type RecordUsageLongContextInput struct {
LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选)
ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析)
}
// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini
func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *RecordUsageLongContextInput) error {
return s.recordUsageCore(ctx, &recordUsageCoreInput{
Result: input.Result,
APIKey: input.APIKey,
User: input.User,
Account: input.Account,
Subscription: input.Subscription,
InboundEndpoint: input.InboundEndpoint,
UpstreamEndpoint: input.UpstreamEndpoint,
UserAgent: input.UserAgent,
IPAddress: input.IPAddress,
RequestPayloadHash: input.RequestPayloadHash,
ForceCacheBilling: input.ForceCacheBilling,
APIKeyService: input.APIKeyService,
ChannelUsageFields: input.ChannelUsageFields,
}, &recordUsageOpts{
LongContextThreshold: input.LongContextThreshold,
LongContextMultiplier: input.LongContextMultiplier,
})
}
// recordUsageCoreInput 是 recordUsageCore 的公共输入字段,从两种输入结构体中提取。
type recordUsageCoreInput struct {
Result *ForwardResult
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription
InboundEndpoint string
UpstreamEndpoint string
UserAgent string
IPAddress string
RequestPayloadHash string
ForceCacheBilling bool
APIKeyService APIKeyQuotaUpdater
ChannelUsageFields
}
// recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。
// opts 中的字段控制两者之间的差异行为:
// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略
// - EnableSoraMedia → 启用 Sora MediaType 分支image/video/prompt
// - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext
func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error {
result := input.Result
apiKey := input.APIKey
user := input.User
@@ -7933,38 +7877,23 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault)
}
var cost *CostBreakdown
// 确定计费模型
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
// 根据请求类型选择计费方式
if result.ImageCount > 0 {
// 图片生成计费
var groupConfig *ImagePriceConfig
if apiKey.Group != nil {
groupConfig = &ImagePriceConfig{
Price1K: apiKey.Group.ImagePrice1K,
Price2K: apiKey.Group.ImagePrice2K,
Price4K: apiKey.Group.ImagePrice4K,
}
}
cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
} else {
// Token 计费(使用长上下文计费方法)
tokens := UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
}
var err error
cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
cost = &CostBreakdown{ActualCost: 0}
}
if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" {
billingModel = input.ChannelMappedModel
}
if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
billingModel = input.OriginalModel
}
// 确定 RequestedModel渠道映射前的原始模型
requestedModel := result.Model
if input.OriginalModel != "" {
requestedModel = input.OriginalModel
}
// 计算费用
cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, opts)
// 判断计费方式:订阅模式 vs 余额模式
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
@@ -7974,12 +7903,214 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
}
// 创建使用日志
durationMs := int(result.Duration.Milliseconds())
var imageSize *string
if result.ImageSize != "" {
imageSize = &result.ImageSize
}
accountRateMultiplier := account.BillingRateMultiplier()
usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription,
requestedModel, multiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts)
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
requestID := usageLog.RequestID
_, billingErr := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
Cost: cost,
User: user,
APIKey: apiKey,
Account: account,
Subscription: subscription,
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
IsSubscriptionBill: isSubscriptionBilling,
AccountRateMultiplier: accountRateMultiplier,
APIKeyService: input.APIKeyService,
}, s.billingDeps(), s.usageBillingRepo)
if billingErr != nil {
return billingErr
}
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
return nil
}
// calculateRecordUsageCost 根据请求类型和选项计算费用。
func (s *GatewayService) calculateRecordUsageCost(
ctx context.Context,
result *ForwardResult,
apiKey *APIKey,
billingModel string,
multiplier float64,
opts *recordUsageOpts,
) *CostBreakdown {
// Sora 媒体类型分支(仅 Claude 路径启用)
if opts.EnableClaudePath {
if result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo {
return s.calculateSoraMediaCost(result, apiKey, billingModel, multiplier)
}
if result.MediaType == MediaTypePrompt {
return &CostBreakdown{}
}
}
// 图片生成计费
if result.ImageCount > 0 {
return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier)
}
// Token 计费
return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts)
}
// calculateSoraMediaCost 计算 Sora 图片/视频的费用。
func (s *GatewayService) calculateSoraMediaCost(
result *ForwardResult,
apiKey *APIKey,
billingModel string,
multiplier float64,
) *CostBreakdown {
var soraConfig *SoraPriceConfig
if apiKey.Group != nil {
soraConfig = &SoraPriceConfig{
ImagePrice360: apiKey.Group.SoraImagePrice360,
ImagePrice540: apiKey.Group.SoraImagePrice540,
VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
}
}
if result.MediaType == MediaTypeImage {
return s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
}
return s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
}
// resolveChannelPricing 检查指定模型是否存在渠道级别定价。
// 返回非 nil 的 ResolvedPricing 表示有渠道定价nil 表示走默认定价路径。
func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
if s.resolver == nil || apiKey.Group == nil {
return nil
}
gid := apiKey.Group.ID
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
if resolved.Source == PricingSourceChannel {
return resolved
}
return nil
}
// calculateImageCost 计算图片生成费用:渠道级别定价优先,否则走按次计费。
func (s *GatewayService) calculateImageCost(
ctx context.Context,
result *ForwardResult,
apiKey *APIKey,
billingModel string,
multiplier float64,
) *CostBreakdown {
if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil {
tokens := UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
}
gid := apiKey.Group.ID
cost, err := s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
Model: billingModel,
GroupID: &gid,
Tokens: tokens,
RequestCount: 1,
RateMultiplier: multiplier,
Resolver: s.resolver,
Resolved: resolved,
})
if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate image token cost failed: %v", err)
return &CostBreakdown{ActualCost: 0}
}
return cost
}
var groupConfig *ImagePriceConfig
if apiKey.Group != nil {
groupConfig = &ImagePriceConfig{
Price1K: apiKey.Group.ImagePrice1K,
Price2K: apiKey.Group.ImagePrice2K,
Price4K: apiKey.Group.ImagePrice4K,
}
}
return s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
}
// calculateTokenCost 计算 Token 计费:根据 opts 决定走普通/长上下文/渠道统一计费。
func (s *GatewayService) calculateTokenCost(
ctx context.Context,
result *ForwardResult,
apiKey *APIKey,
billingModel string,
multiplier float64,
opts *recordUsageOpts,
) *CostBreakdown {
tokens := UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
}
var cost *CostBreakdown
var err error
// 优先尝试渠道定价 → CalculateCostUnified
if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil {
gid := apiKey.Group.ID
cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
Model: billingModel,
GroupID: &gid,
Tokens: tokens,
RequestCount: 1,
RateMultiplier: multiplier,
Resolver: s.resolver,
Resolved: resolved,
})
} else if opts.LongContextThreshold > 0 {
// 长上下文双倍计费(如 Gemini 200K 阈值)
cost, err = s.billingService.CalculateCostWithLongContext(
billingModel, tokens, multiplier,
opts.LongContextThreshold, opts.LongContextMultiplier,
)
} else {
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
}
if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
return &CostBreakdown{ActualCost: 0}
}
return cost
}
// buildRecordUsageLog 构建使用日志并设置计费模式。
func (s *GatewayService) buildRecordUsageLog(
ctx context.Context,
input *recordUsageCoreInput,
result *ForwardResult,
apiKey *APIKey,
user *User,
account *Account,
subscription *UserSubscription,
requestedModel string,
multiplier float64,
accountRateMultiplier float64,
billingType int8,
cacheTTLOverridden bool,
cost *CostBreakdown,
opts *recordUsageOpts,
) *UsageLog {
durationMs := int(result.Duration.Milliseconds())
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
usageLog := &UsageLog{
UserID: user.ID,
@@ -7987,7 +8118,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
AccountID: account.ID,
RequestID: requestID,
Model: result.Model,
RequestedModel: result.Model,
RequestedModel: requestedModel,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
@@ -7998,72 +8129,170 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
InputCost: cost.InputCost,
OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost,
ActualCost: cost.ActualCost,
ImageOutputTokens: result.Usage.ImageOutputTokens,
RateMultiplier: multiplier,
AccountRateMultiplier: &accountRateMultiplier,
BillingType: billingType,
BillingMode: resolveBillingMode(opts, result, cost),
Stream: result.Stream,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
ImageCount: result.ImageCount,
ImageSize: imageSize,
ImageSize: optionalTrimmedStringPtr(result.ImageSize),
MediaType: resolveMediaType(opts, result),
CacheTTLOverridden: cacheTTLOverridden,
ChannelID: optionalInt64Ptr(input.ChannelID),
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
UserAgent: optionalTrimmedStringPtr(input.UserAgent),
IPAddress: optionalTrimmedStringPtr(input.IPAddress),
GroupID: apiKey.GroupID,
SubscriptionID: optionalSubscriptionID(subscription),
CreatedAt: time.Now(),
}
// 添加 UserAgent
if input.UserAgent != "" {
usageLog.UserAgent = &input.UserAgent
if cost != nil {
usageLog.InputCost = cost.InputCost
usageLog.OutputCost = cost.OutputCost
usageLog.ImageOutputCost = cost.ImageOutputCost
usageLog.CacheCreationCost = cost.CacheCreationCost
usageLog.CacheReadCost = cost.CacheReadCost
usageLog.TotalCost = cost.TotalCost
usageLog.ActualCost = cost.ActualCost
}
// 添加 IPAddress
if input.IPAddress != "" {
usageLog.IPAddress = &input.IPAddress
}
return usageLog
}
// 添加分组和订阅关联
if apiKey.GroupID != nil {
usageLog.GroupID = apiKey.GroupID
}
if subscription != nil {
usageLog.SubscriptionID = &subscription.ID
}
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
s.deferredService.ScheduleLastUsedUpdate(account.ID)
// resolveBillingMode 根据计费结果和请求类型确定计费模式。
// Sora 媒体类型自身已确定计费模式(由上游处理),返回 nil 跳过。
func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string {
isSoraMedia := opts.EnableClaudePath &&
(result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo || result.MediaType == MediaTypePrompt)
if isSoraMedia {
return nil
}
billingErr := func() error {
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
Cost: cost,
User: user,
APIKey: apiKey,
Account: account,
Subscription: subscription,
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
IsSubscriptionBill: isSubscriptionBilling,
AccountRateMultiplier: accountRateMultiplier,
APIKeyService: input.APIKeyService,
}, s.billingDeps(), s.usageBillingRepo)
return err
}()
if billingErr != nil {
return billingErr
var mode string
switch {
case cost != nil && cost.BillingMode != "":
mode = cost.BillingMode
case result.ImageCount > 0:
mode = string(BillingModeImage)
default:
mode = string(BillingModeToken)
}
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
return &mode
}
func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string {
if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" {
return &result.MediaType
}
return nil
}
func optionalSubscriptionID(subscription *UserSubscription) *int64 {
if subscription != nil {
return &subscription.ID
}
return nil
}
// ResolveChannelMapping 委托渠道服务解析模型映射
func (s *GatewayService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult {
if s.channelService == nil {
return ChannelMappingResult{MappedModel: model}
}
return s.channelService.ResolveChannelMapping(ctx, groupID, model)
}
// ReplaceModelInBody 替换请求体中的模型名(导出供 handler 使用)
func (s *GatewayService) ReplaceModelInBody(body []byte, newModel string) []byte {
return ReplaceModelInBody(body, newModel)
}
// IsModelRestricted 检查模型是否被渠道限制
func (s *GatewayService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
if s.channelService == nil {
return false
}
return s.channelService.IsModelRestricted(ctx, groupID, model)
}
// ResolveChannelMappingAndRestrict 解析渠道映射。
// 模型限制检查已移至调度阶段checkChannelPricingRestrictionrestricted 始终返回 false。
func (s *GatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
if s.channelService == nil {
return ChannelMappingResult{MappedModel: model}, false
}
return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model)
}
// checkChannelPricingRestriction 根据渠道计费基准检查模型是否受定价列表限制。
// 供调度阶段预检查requested / channel_mapped
// upstream 需逐账号检查,此处返回 false。
func (s *GatewayService) checkChannelPricingRestriction(ctx context.Context, groupID *int64, requestedModel string) bool {
if groupID == nil || s.channelService == nil || requestedModel == "" {
return false
}
mapping := s.channelService.ResolveChannelMapping(ctx, *groupID, requestedModel)
billingModel := billingModelForRestriction(mapping.BillingModelSource, requestedModel, mapping.MappedModel)
if billingModel == "" {
return false
}
return s.channelService.IsModelRestricted(ctx, *groupID, billingModel)
}
// billingModelForRestriction 根据计费基准确定限制检查使用的模型。
// upstream 返回空(需逐账号检查)。
func billingModelForRestriction(source, requestedModel, channelMappedModel string) string {
switch source {
case BillingModelSourceRequested:
return requestedModel
case BillingModelSourceUpstream:
return ""
case BillingModelSourceChannelMapped:
return channelMappedModel
default:
return channelMappedModel
}
}
// isUpstreamModelRestrictedByChannel 检查账号映射后的上游模型是否受渠道定价限制。
// 仅在 BillingModelSource="upstream" 且 RestrictModels=true 时由调度循环调用。
func (s *GatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context, groupID int64, account *Account, requestedModel string) bool {
if s.channelService == nil {
return false
}
upstreamModel := resolveAccountUpstreamModel(account, requestedModel)
if upstreamModel == "" {
return false
}
return s.channelService.IsModelRestricted(ctx, groupID, upstreamModel)
}
// resolveAccountUpstreamModel 确定账号将请求模型映射为什么上游模型。
func resolveAccountUpstreamModel(account *Account, requestedModel string) string {
if account.Platform == PlatformAntigravity {
return mapAntigravityModel(account, requestedModel)
}
return account.GetMappedModel(requestedModel)
}
// needsUpstreamChannelRestrictionCheck 判断是否需要在调度循环中逐账号检查上游模型的渠道限制。
func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Context, groupID *int64) bool {
if groupID == nil || s.channelService == nil {
return false
}
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
if err != nil {
slog.Warn("failed to check channel upstream restriction", "group_id", *groupID, "error", err)
return false
}
if ch == nil || !ch.RestrictModels {
return false
}
return ch.BillingModelSource == BillingModelSourceUpstream
}
// ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {

View File

@@ -2692,12 +2692,27 @@ func extractGeminiUsage(data []byte) *ClaudeUsage {
cand := int(usage.Get("candidatesTokenCount").Int())
cached := int(usage.Get("cachedContentTokenCount").Int())
thoughts := int(usage.Get("thoughtsTokenCount").Int())
// 从 candidatesTokensDetails 提取 IMAGE 模态 token 数
imageTokens := 0
candidateDetails := usage.Get("candidatesTokensDetails")
if candidateDetails.Exists() {
candidateDetails.ForEach(func(_, detail gjson.Result) bool {
if detail.Get("modality").String() == "IMAGE" {
imageTokens = int(detail.Get("tokenCount").Int())
return false
}
return true
})
}
// 注意Gemini 的 promptTokenCount 包含 cachedContentTokenCount
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens需要减去
return &ClaudeUsage{
InputTokens: prompt - cached,
OutputTokens: cand + thoughts,
CacheReadInputTokens: cached,
ImageOutputTokens: imageTokens,
}
}

View File

@@ -0,0 +1,231 @@
package service
import (
"context"
"log/slog"
)
// PricingSource 定价来源标识
const (
PricingSourceChannel = "channel"
PricingSourceLiteLLM = "litellm"
PricingSourceFallback = "fallback"
)
// ResolvedPricing 统一定价解析结果
type ResolvedPricing struct {
// Mode 计费模式
Mode BillingMode
// Token 模式:基础定价(来自 LiteLLM 或 fallback
BasePricing *ModelPricing
// Token 模式:区间定价列表(如有,覆盖 BasePricing 中的对应字段)
Intervals []PricingInterval
// 按次/图片模式:分层定价
RequestTiers []PricingInterval
// 按次/图片模式:默认价格(未命中层级时使用)
DefaultPerRequestPrice float64
// 来源标识
Source string // "channel", "litellm", "fallback"
// 是否支持缓存细分
SupportsCacheBreakdown bool
}
// ModelPricingResolver 统一模型定价解析器。
// 解析链Channel → LiteLLM → Fallback。
type ModelPricingResolver struct {
channelService *ChannelService
billingService *BillingService
}
// NewModelPricingResolver 创建定价解析器实例
func NewModelPricingResolver(channelService *ChannelService, billingService *BillingService) *ModelPricingResolver {
return &ModelPricingResolver{
channelService: channelService,
billingService: billingService,
}
}
// PricingInput 定价解析输入
type PricingInput struct {
Model string
GroupID *int64 // nil 表示不检查渠道
}
// Resolve 解析模型定价。
// 1. 获取基础定价LiteLLM → Fallback
// 2. 如果指定了 GroupID查找渠道定价并覆盖
func (r *ModelPricingResolver) Resolve(ctx context.Context, input PricingInput) *ResolvedPricing {
// 1. 获取基础定价
basePricing, source := r.resolveBasePricing(input.Model)
resolved := &ResolvedPricing{
Mode: BillingModeToken,
BasePricing: basePricing,
Source: source,
SupportsCacheBreakdown: basePricing != nil && basePricing.SupportsCacheBreakdown,
}
// 2. 如果有 GroupID尝试渠道覆盖
if input.GroupID != nil {
r.applyChannelOverrides(ctx, *input.GroupID, input.Model, resolved)
}
return resolved
}
// resolveBasePricing 从 LiteLLM 或 Fallback 获取基础定价
func (r *ModelPricingResolver) resolveBasePricing(model string) (*ModelPricing, string) {
pricing, err := r.billingService.GetModelPricing(model)
if err != nil {
slog.Debug("failed to get model pricing from LiteLLM, using fallback",
"model", model, "error", err)
return nil, PricingSourceFallback
}
return pricing, PricingSourceLiteLLM
}
// applyChannelOverrides 应用渠道定价覆盖
func (r *ModelPricingResolver) applyChannelOverrides(ctx context.Context, groupID int64, model string, resolved *ResolvedPricing) {
chPricing := r.channelService.GetChannelModelPricing(ctx, groupID, model)
if chPricing == nil {
return
}
resolved.Source = PricingSourceChannel
resolved.Mode = chPricing.BillingMode
if resolved.Mode == "" {
resolved.Mode = BillingModeToken
}
switch resolved.Mode {
case BillingModeToken:
r.applyTokenOverrides(chPricing, resolved)
case BillingModePerRequest, BillingModeImage:
r.applyRequestTierOverrides(chPricing, resolved)
}
}
// applyTokenOverrides 应用 token 模式的渠道覆盖
func (r *ModelPricingResolver) applyTokenOverrides(chPricing *ChannelModelPricing, resolved *ResolvedPricing) {
// 过滤掉所有价格字段都为空的无效 interval
validIntervals := filterValidIntervals(chPricing.Intervals)
// 如果有有效的区间定价,使用区间
if len(validIntervals) > 0 {
resolved.Intervals = validIntervals
return
}
// 否则用 flat 字段覆盖 BasePricing
if resolved.BasePricing == nil {
resolved.BasePricing = &ModelPricing{}
}
if chPricing.InputPrice != nil {
resolved.BasePricing.InputPricePerToken = *chPricing.InputPrice
resolved.BasePricing.InputPricePerTokenPriority = *chPricing.InputPrice
}
if chPricing.OutputPrice != nil {
resolved.BasePricing.OutputPricePerToken = *chPricing.OutputPrice
resolved.BasePricing.OutputPricePerTokenPriority = *chPricing.OutputPrice
}
if chPricing.CacheWritePrice != nil {
resolved.BasePricing.CacheCreationPricePerToken = *chPricing.CacheWritePrice
resolved.BasePricing.CacheCreation5mPrice = *chPricing.CacheWritePrice
resolved.BasePricing.CacheCreation1hPrice = *chPricing.CacheWritePrice
}
if chPricing.CacheReadPrice != nil {
resolved.BasePricing.CacheReadPricePerToken = *chPricing.CacheReadPrice
resolved.BasePricing.CacheReadPricePerTokenPriority = *chPricing.CacheReadPrice
}
if chPricing.ImageOutputPrice != nil {
resolved.BasePricing.ImageOutputPricePerToken = *chPricing.ImageOutputPrice
}
}
// applyRequestTierOverrides 应用按次/图片模式的渠道覆盖
func (r *ModelPricingResolver) applyRequestTierOverrides(chPricing *ChannelModelPricing, resolved *ResolvedPricing) {
resolved.RequestTiers = filterValidIntervals(chPricing.Intervals)
if chPricing.PerRequestPrice != nil {
resolved.DefaultPerRequestPrice = *chPricing.PerRequestPrice
}
}
// filterValidIntervals 过滤掉所有价格字段都为空的无效 interval。
// 前端可能创建了只有 min/max 但无价格的空 interval。
func filterValidIntervals(intervals []PricingInterval) []PricingInterval {
var valid []PricingInterval
for _, iv := range intervals {
if iv.InputPrice != nil || iv.OutputPrice != nil ||
iv.CacheWritePrice != nil || iv.CacheReadPrice != nil ||
iv.PerRequestPrice != nil {
valid = append(valid, iv)
}
}
return valid
}
// GetIntervalPricing 根据 context token 数获取区间定价。
// 如果有区间列表,找到匹配区间并构造 ModelPricing否则直接返回 BasePricing。
func (r *ModelPricingResolver) GetIntervalPricing(resolved *ResolvedPricing, totalContextTokens int) *ModelPricing {
if len(resolved.Intervals) == 0 {
return resolved.BasePricing
}
iv := FindMatchingInterval(resolved.Intervals, totalContextTokens)
if iv == nil {
return resolved.BasePricing
}
return intervalToModelPricing(iv, resolved.SupportsCacheBreakdown)
}
// intervalToModelPricing 将区间定价转换为 ModelPricing
func intervalToModelPricing(iv *PricingInterval, supportsCacheBreakdown bool) *ModelPricing {
pricing := &ModelPricing{
SupportsCacheBreakdown: supportsCacheBreakdown,
}
if iv.InputPrice != nil {
pricing.InputPricePerToken = *iv.InputPrice
pricing.InputPricePerTokenPriority = *iv.InputPrice
}
if iv.OutputPrice != nil {
pricing.OutputPricePerToken = *iv.OutputPrice
pricing.OutputPricePerTokenPriority = *iv.OutputPrice
}
if iv.CacheWritePrice != nil {
pricing.CacheCreationPricePerToken = *iv.CacheWritePrice
pricing.CacheCreation5mPrice = *iv.CacheWritePrice
pricing.CacheCreation1hPrice = *iv.CacheWritePrice
}
if iv.CacheReadPrice != nil {
pricing.CacheReadPricePerToken = *iv.CacheReadPrice
pricing.CacheReadPricePerTokenPriority = *iv.CacheReadPrice
}
return pricing
}
// GetRequestTierPrice 根据层级标签获取按次价格
func (r *ModelPricingResolver) GetRequestTierPrice(resolved *ResolvedPricing, tierLabel string) float64 {
for _, tier := range resolved.RequestTiers {
if tier.TierLabel == tierLabel && tier.PerRequestPrice != nil {
return *tier.PerRequestPrice
}
}
return 0
}
// GetRequestTierPriceByContext 根据 context token 数获取按次价格
func (r *ModelPricingResolver) GetRequestTierPriceByContext(resolved *ResolvedPricing, totalContextTokens int) float64 {
iv := FindMatchingInterval(resolved.RequestTiers, totalContextTokens)
if iv != nil && iv.PerRequestPrice != nil {
return *iv.PerRequestPrice
}
return 0
}

View File

@@ -0,0 +1,663 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/require"
)
func newTestBillingServiceForResolver() *BillingService {
bs := &BillingService{
fallbackPrices: make(map[string]*ModelPricing),
}
bs.fallbackPrices["claude-sonnet-4"] = &ModelPricing{
InputPricePerToken: 3e-6,
OutputPricePerToken: 15e-6,
CacheCreationPricePerToken: 3.75e-6,
CacheReadPricePerToken: 0.3e-6,
SupportsCacheBreakdown: false,
}
return bs
}
func TestResolve_NoGroupID(t *testing.T) {
bs := newTestBillingServiceForResolver()
r := NewModelPricingResolver(&ChannelService{}, bs)
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: nil,
})
require.NotNil(t, resolved)
require.Equal(t, BillingModeToken, resolved.Mode)
require.NotNil(t, resolved.BasePricing)
require.InDelta(t, 3e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
// BillingService.GetModelPricing uses fallback internally, but resolveBasePricing
// reports "litellm" when GetModelPricing succeeds (regardless of internal source)
require.Equal(t, "litellm", resolved.Source)
}
func TestResolve_UnknownModel(t *testing.T) {
bs := newTestBillingServiceForResolver()
r := NewModelPricingResolver(&ChannelService{}, bs)
resolved := r.Resolve(context.Background(), PricingInput{
Model: "unknown-model-xyz",
GroupID: nil,
})
require.NotNil(t, resolved)
require.Nil(t, resolved.BasePricing)
// Unknown model: GetModelPricing returns error, source is "fallback"
require.Equal(t, "fallback", resolved.Source)
}
func TestGetIntervalPricing_NoIntervals(t *testing.T) {
bs := newTestBillingServiceForResolver()
r := NewModelPricingResolver(&ChannelService{}, bs)
basePricing := &ModelPricing{InputPricePerToken: 5e-6}
resolved := &ResolvedPricing{
Mode: BillingModeToken,
BasePricing: basePricing,
Intervals: nil,
}
result := r.GetIntervalPricing(resolved, 50000)
require.Equal(t, basePricing, result)
}
func TestGetIntervalPricing_MatchesInterval(t *testing.T) {
bs := newTestBillingServiceForResolver()
r := NewModelPricingResolver(&ChannelService{}, bs)
resolved := &ResolvedPricing{
Mode: BillingModeToken,
BasePricing: &ModelPricing{InputPricePerToken: 5e-6},
SupportsCacheBreakdown: true,
Intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6), OutputPrice: testPtrFloat64(2e-6)},
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(3e-6), OutputPrice: testPtrFloat64(6e-6)},
},
}
result := r.GetIntervalPricing(resolved, 50000)
require.NotNil(t, result)
require.InDelta(t, 1e-6, result.InputPricePerToken, 1e-12)
require.InDelta(t, 2e-6, result.OutputPricePerToken, 1e-12)
require.True(t, result.SupportsCacheBreakdown)
result2 := r.GetIntervalPricing(resolved, 200000)
require.NotNil(t, result2)
require.InDelta(t, 3e-6, result2.InputPricePerToken, 1e-12)
}
func TestGetIntervalPricing_NoMatch_FallsBackToBase(t *testing.T) {
bs := newTestBillingServiceForResolver()
r := NewModelPricingResolver(&ChannelService{}, bs)
basePricing := &ModelPricing{InputPricePerToken: 99e-6}
resolved := &ResolvedPricing{
Mode: BillingModeToken,
BasePricing: basePricing,
Intervals: []PricingInterval{
{MinTokens: 10000, MaxTokens: testPtrInt(50000), InputPrice: testPtrFloat64(1e-6)},
},
}
result := r.GetIntervalPricing(resolved, 5000)
require.Equal(t, basePricing, result)
}
func TestGetRequestTierPrice(t *testing.T) {
bs := newTestBillingServiceForResolver()
r := NewModelPricingResolver(&ChannelService{}, bs)
resolved := &ResolvedPricing{
Mode: BillingModePerRequest,
RequestTiers: []PricingInterval{
{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
{TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)},
},
}
require.InDelta(t, 0.04, r.GetRequestTierPrice(resolved, "1K"), 1e-12)
require.InDelta(t, 0.08, r.GetRequestTierPrice(resolved, "2K"), 1e-12)
require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "4K"), 1e-12)
}
func TestGetRequestTierPriceByContext(t *testing.T) {
bs := newTestBillingServiceForResolver()
r := NewModelPricingResolver(&ChannelService{}, bs)
resolved := &ResolvedPricing{
Mode: BillingModePerRequest,
RequestTiers: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.05)},
{MinTokens: 128000, MaxTokens: nil, PerRequestPrice: testPtrFloat64(0.10)},
},
}
require.InDelta(t, 0.05, r.GetRequestTierPriceByContext(resolved, 50000), 1e-12)
require.InDelta(t, 0.10, r.GetRequestTierPriceByContext(resolved, 200000), 1e-12)
}
func TestGetRequestTierPrice_NilPerRequestPrice(t *testing.T) {
bs := newTestBillingServiceForResolver()
r := NewModelPricingResolver(&ChannelService{}, bs)
resolved := &ResolvedPricing{
Mode: BillingModePerRequest,
RequestTiers: []PricingInterval{
{TierLabel: "1K", PerRequestPrice: nil},
},
}
require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "1K"), 1e-12)
}
// ===========================================================================
// Channel override tests — exercises applyChannelOverrides via Resolve
// ===========================================================================
// helper: creates a resolver wired to a ChannelService that returns the given
// channel (active, groupID=100, platform=anthropic) with the specified pricing.
func newResolverWithChannel(t *testing.T, pricing []ChannelModelPricing) *ModelPricingResolver {
t.Helper()
const groupID = 100
repo := &mockChannelRepository{
listAllFn: func(_ context.Context) ([]Channel, error) {
return []Channel{{
ID: 1,
Name: "test-channel",
Status: StatusActive,
GroupIDs: []int64{groupID},
ModelPricing: pricing,
}}, nil
},
getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) {
return map[int64]string{groupID: "anthropic"}, nil
},
}
cs := NewChannelService(repo, nil)
bs := newTestBillingServiceForResolver()
return NewModelPricingResolver(cs, bs)
}
// groupIDPtr returns a pointer to groupID 100 (the test constant).
func groupIDPtr() *int64 { v := int64(100); return &v }
// ---------------------------------------------------------------------------
// 1. Token mode overrides
// ---------------------------------------------------------------------------
func TestResolve_WithChannelOverride_TokenFlat(t *testing.T) {
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModeToken,
InputPrice: testPtrFloat64(10e-6),
OutputPrice: testPtrFloat64(50e-6),
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
require.NotNil(t, resolved)
require.Equal(t, BillingModeToken, resolved.Mode)
require.Equal(t, "channel", resolved.Source)
require.NotNil(t, resolved.BasePricing)
require.InDelta(t, 10e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
require.InDelta(t, 10e-6, resolved.BasePricing.InputPricePerTokenPriority, 1e-12)
require.InDelta(t, 50e-6, resolved.BasePricing.OutputPricePerToken, 1e-12)
require.InDelta(t, 50e-6, resolved.BasePricing.OutputPricePerTokenPriority, 1e-12)
}
func TestResolve_WithChannelOverride_TokenPartialOverride(t *testing.T) {
// Channel only sets InputPrice; OutputPrice should remain from the base (LiteLLM/fallback).
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModeToken,
InputPrice: testPtrFloat64(20e-6),
// OutputPrice intentionally nil
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
require.NotNil(t, resolved)
require.Equal(t, "channel", resolved.Source)
require.NotNil(t, resolved.BasePricing)
// InputPrice overridden by channel
require.InDelta(t, 20e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
// OutputPrice kept from base (fallback: 15e-6)
require.InDelta(t, 15e-6, resolved.BasePricing.OutputPricePerToken, 1e-12)
}
func TestResolve_WithChannelOverride_TokenWithIntervals(t *testing.T) {
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModeToken,
Intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(2e-6), OutputPrice: testPtrFloat64(8e-6)},
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(4e-6), OutputPrice: testPtrFloat64(16e-6)},
},
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
require.NotNil(t, resolved)
require.Equal(t, "channel", resolved.Source)
require.Len(t, resolved.Intervals, 2)
// GetIntervalPricing should use channel intervals
iv := r.GetIntervalPricing(resolved, 50000)
require.NotNil(t, iv)
require.InDelta(t, 2e-6, iv.InputPricePerToken, 1e-12)
require.InDelta(t, 8e-6, iv.OutputPricePerToken, 1e-12)
iv2 := r.GetIntervalPricing(resolved, 200000)
require.NotNil(t, iv2)
require.InDelta(t, 4e-6, iv2.InputPricePerToken, 1e-12)
require.InDelta(t, 16e-6, iv2.OutputPricePerToken, 1e-12)
}
func TestResolve_WithChannelOverride_TokenNilBasePricing(t *testing.T) {
// Base pricing is nil (unknown model), channel has flat prices → creates new BasePricing.
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"unknown-model-xyz"},
BillingMode: BillingModeToken,
InputPrice: testPtrFloat64(7e-6),
OutputPrice: testPtrFloat64(21e-6),
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "unknown-model-xyz",
GroupID: groupIDPtr(),
})
require.NotNil(t, resolved)
require.Equal(t, "channel", resolved.Source)
// BasePricing was nil from resolveBasePricing but applyTokenOverrides creates a new one
require.NotNil(t, resolved.BasePricing)
require.InDelta(t, 7e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
require.InDelta(t, 21e-6, resolved.BasePricing.OutputPricePerToken, 1e-12)
}
// ---------------------------------------------------------------------------
// 2. Per-request mode overrides
// ---------------------------------------------------------------------------
func TestResolve_WithChannelOverride_PerRequest(t *testing.T) {
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModePerRequest,
PerRequestPrice: testPtrFloat64(0.05),
Intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.03)},
{MinTokens: 128000, MaxTokens: nil, PerRequestPrice: testPtrFloat64(0.10)},
},
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
require.NotNil(t, resolved)
require.Equal(t, BillingModePerRequest, resolved.Mode)
require.Equal(t, "channel", resolved.Source)
require.InDelta(t, 0.05, resolved.DefaultPerRequestPrice, 1e-12)
require.Len(t, resolved.RequestTiers, 2)
// Verify tier lookups
require.InDelta(t, 0.03, r.GetRequestTierPriceByContext(resolved, 50000), 1e-12)
require.InDelta(t, 0.10, r.GetRequestTierPriceByContext(resolved, 200000), 1e-12)
}
func TestResolve_WithChannelOverride_PerRequestNilPrice(t *testing.T) {
// PerRequestPrice nil → DefaultPerRequestPrice stays 0.
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModePerRequest,
// PerRequestPrice intentionally nil
Intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.02)},
},
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
require.NotNil(t, resolved)
require.Equal(t, BillingModePerRequest, resolved.Mode)
require.InDelta(t, 0.0, resolved.DefaultPerRequestPrice, 1e-12)
require.Len(t, resolved.RequestTiers, 1)
}
// ---------------------------------------------------------------------------
// 3. Image mode overrides
// ---------------------------------------------------------------------------
func TestResolve_WithChannelOverride_Image(t *testing.T) {
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModeImage,
PerRequestPrice: testPtrFloat64(0.08),
Intervals: []PricingInterval{
{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
{TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)},
{TierLabel: "4K", PerRequestPrice: testPtrFloat64(0.16)},
},
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
require.NotNil(t, resolved)
require.Equal(t, BillingModeImage, resolved.Mode)
require.Equal(t, "channel", resolved.Source)
require.InDelta(t, 0.08, resolved.DefaultPerRequestPrice, 1e-12)
require.Len(t, resolved.RequestTiers, 3)
}
func TestResolve_WithChannelOverride_ImageTierLabels(t *testing.T) {
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModeImage,
Intervals: []PricingInterval{
{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
{TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)},
{TierLabel: "4K", PerRequestPrice: testPtrFloat64(0.16)},
},
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
require.InDelta(t, 0.04, r.GetRequestTierPrice(resolved, "1K"), 1e-12)
require.InDelta(t, 0.08, r.GetRequestTierPrice(resolved, "2K"), 1e-12)
require.InDelta(t, 0.16, r.GetRequestTierPrice(resolved, "4K"), 1e-12)
require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "8K"), 1e-12) // not found
}
// ---------------------------------------------------------------------------
// 4. Source tracking & default mode
// ---------------------------------------------------------------------------
func TestResolve_WithChannelOverride_SourceIsChannel(t *testing.T) {
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModeToken,
InputPrice: testPtrFloat64(1e-6),
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
require.Equal(t, "channel", resolved.Source)
}
func TestResolve_WithChannelOverride_DefaultMode(t *testing.T) {
// Channel pricing with empty BillingMode → defaults to BillingModeToken.
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: "", // intentionally empty
InputPrice: testPtrFloat64(5e-6),
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
require.Equal(t, "channel", resolved.Source)
require.Equal(t, BillingModeToken, resolved.Mode)
require.NotNil(t, resolved.BasePricing)
require.InDelta(t, 5e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
}
// ---------------------------------------------------------------------------
// 5. GetIntervalPricing integration after channel override
// ---------------------------------------------------------------------------
func TestGetIntervalPricing_WithChannelIntervals(t *testing.T) {
// Channel provides intervals that override the base pricing path.
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModeToken,
Intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(100000), InputPrice: testPtrFloat64(1e-6), OutputPrice: testPtrFloat64(5e-6)},
{MinTokens: 100000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6), OutputPrice: testPtrFloat64(10e-6)},
},
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
// Token count 50000 matches first interval
pricing := r.GetIntervalPricing(resolved, 50000)
require.NotNil(t, pricing)
require.InDelta(t, 1e-6, pricing.InputPricePerToken, 1e-12)
require.InDelta(t, 5e-6, pricing.OutputPricePerToken, 1e-12)
// Token count 150000 matches second interval
pricing2 := r.GetIntervalPricing(resolved, 150000)
require.NotNil(t, pricing2)
require.InDelta(t, 2e-6, pricing2.InputPricePerToken, 1e-12)
require.InDelta(t, 10e-6, pricing2.OutputPricePerToken, 1e-12)
}
func TestGetIntervalPricing_ChannelIntervalsNoMatch(t *testing.T) {
// Channel intervals don't match token count → falls back to BasePricing.
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModeToken,
Intervals: []PricingInterval{
// Only covers tokens > 50000
{MinTokens: 50000, MaxTokens: testPtrInt(200000), InputPrice: testPtrFloat64(9e-6)},
},
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
// Token count 1000 doesn't match any interval (1000 <= 50000 minTokens)
pricing := r.GetIntervalPricing(resolved, 1000)
// Should fall back to BasePricing (from the billing service fallback)
require.NotNil(t, pricing)
require.Equal(t, resolved.BasePricing, pricing)
require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12) // original base price
}
// ===========================================================================
// 6. Error path tests
// ===========================================================================
func TestResolve_WithChannelOverride_CacheError(t *testing.T) {
// When ListAll returns an error, the ChannelService cache build fails.
// Resolve should gracefully fall back to base pricing without panicking.
repo := &mockChannelRepository{
listAllFn: func(_ context.Context) ([]Channel, error) {
return nil, errors.New("database unavailable")
},
}
cs := NewChannelService(repo, nil)
bs := newTestBillingServiceForResolver()
r := NewModelPricingResolver(cs, bs)
gid := int64(100)
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: &gid,
})
require.NotNil(t, resolved)
// Should NOT panic, should NOT have source "channel"
require.NotEqual(t, "channel", resolved.Source)
// Base pricing should still be present (from BillingService fallback)
require.NotNil(t, resolved.BasePricing)
require.InDelta(t, 3e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
}
// ===========================================================================
// 7. GetRequestTierPriceByContext boundary tests
// ===========================================================================
func TestGetRequestTierPriceByContext_EmptyTiers(t *testing.T) {
bs := newTestBillingServiceForResolver()
r := NewModelPricingResolver(&ChannelService{}, bs)
resolved := &ResolvedPricing{
Mode: BillingModePerRequest,
RequestTiers: nil, // empty
}
price := r.GetRequestTierPriceByContext(resolved, 50000)
require.InDelta(t, 0.0, price, 1e-12)
// Also test with explicit empty slice
resolved2 := &ResolvedPricing{
Mode: BillingModePerRequest,
RequestTiers: []PricingInterval{},
}
price2 := r.GetRequestTierPriceByContext(resolved2, 50000)
require.InDelta(t, 0.0, price2, 1e-12)
}
func TestGetRequestTierPriceByContext_ExactBoundary(t *testing.T) {
bs := newTestBillingServiceForResolver()
r := NewModelPricingResolver(&ChannelService{}, bs)
resolved := &ResolvedPricing{
Mode: BillingModePerRequest,
RequestTiers: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.05)},
{MinTokens: 128000, MaxTokens: nil, PerRequestPrice: testPtrFloat64(0.10)},
},
}
// totalContextTokens = 128000 exactly:
// FindMatchingInterval checks: totalTokens > MinTokens && totalTokens <= MaxTokens
// For first interval: 128000 > 0 (true) && 128000 <= 128000 (true) → matches first interval
price := r.GetRequestTierPriceByContext(resolved, 128000)
require.InDelta(t, 0.05, price, 1e-12)
// totalContextTokens = 128001 should match second interval
// For first interval: 128001 > 0 (true) && 128001 <= 128000 (false) → no match
// For second interval: 128001 > 128000 (true) && MaxTokens == nil → matches
price2 := r.GetRequestTierPriceByContext(resolved, 128001)
require.InDelta(t, 0.10, price2, 1e-12)
}
// ===========================================================================
// 8. filterValidIntervals
// ===========================================================================
func TestFilterValidIntervals(t *testing.T) {
tests := []struct {
name string
intervals []PricingInterval
wantLen int
}{
{
name: "empty list",
intervals: nil,
wantLen: 0,
},
{
name: "all-nil interval filtered out",
intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000)},
},
wantLen: 0,
},
{
name: "interval with only InputPrice kept",
intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
},
wantLen: 1,
},
{
name: "interval with only OutputPrice kept",
intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), OutputPrice: testPtrFloat64(2e-6)},
},
wantLen: 1,
},
{
name: "interval with only CacheWritePrice kept",
intervals: []PricingInterval{
{MinTokens: 0, CacheWritePrice: testPtrFloat64(3e-6)},
},
wantLen: 1,
},
{
name: "interval with only CacheReadPrice kept",
intervals: []PricingInterval{
{MinTokens: 0, CacheReadPrice: testPtrFloat64(0.5e-6)},
},
wantLen: 1,
},
{
name: "interval with only PerRequestPrice kept",
intervals: []PricingInterval{
{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
},
wantLen: 1,
},
{
name: "mixed valid and invalid",
intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
{MinTokens: 128000, MaxTokens: nil}, // all-nil → filtered out
{MinTokens: 256000, OutputPrice: testPtrFloat64(5e-6)},
},
wantLen: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := filterValidIntervals(tt.intervals)
require.Len(t, result, tt.wantLen)
})
}
}

View File

@@ -0,0 +1,140 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
func TestOpenAISelectAccountForModelWithExclusions_ChannelMappedRestrictionRejectsEarly(t *testing.T) {
t.Parallel()
channelSvc := newTestChannelService(makeStandardRepo(Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
BillingModelSource: BillingModelSourceChannelMapped,
ModelPricing: []ChannelModelPricing{
{Platform: PlatformOpenAI, Models: []string{"gpt-4o"}},
},
ModelMapping: map[string]map[string]string{
PlatformOpenAI: {"gpt-4.1": "o3-mini"},
},
}, map[int64]string{10: PlatformOpenAI}))
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true},
}},
channelService: channelSvc,
}
groupID := int64(10)
_, err := svc.SelectAccountForModelWithExclusions(context.Background(), &groupID, "", "gpt-4.1", nil)
require.ErrorIs(t, err, ErrNoAvailableAccounts)
require.Contains(t, err.Error(), "channel pricing restriction")
}
func TestOpenAISelectAccountForModelWithExclusions_UpstreamRestrictionSkipsDisallowedAccount(t *testing.T) {
t.Parallel()
channelSvc := newTestChannelService(makeStandardRepo(Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
BillingModelSource: BillingModelSourceUpstream,
ModelPricing: []ChannelModelPricing{
{Platform: PlatformOpenAI, Models: []string{"o3-mini"}},
},
}, map[int64]string{10: PlatformOpenAI}))
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{
{
ID: 1,
Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
Priority: 10,
Credentials: map[string]any{
"model_mapping": map[string]any{"gpt-4.1": "gpt-4o"},
},
},
{
ID: 2,
Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
Priority: 20,
Credentials: map[string]any{
"model_mapping": map[string]any{"gpt-4.1": "o3-mini"},
},
},
}},
channelService: channelSvc,
}
groupID := int64(10)
account, err := svc.SelectAccountForModelWithExclusions(context.Background(), &groupID, "", "gpt-4.1", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, int64(2), account.ID)
}
func TestOpenAISelectAccountForModelWithExclusions_StickyRestrictedUpstreamFallsBack(t *testing.T) {
t.Parallel()
channelSvc := newTestChannelService(makeStandardRepo(Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
BillingModelSource: BillingModelSourceUpstream,
ModelPricing: []ChannelModelPricing{
{Platform: PlatformOpenAI, Models: []string{"o3-mini"}},
},
}, map[int64]string{10: PlatformOpenAI}))
cache := &stubGatewayCache{
sessionBindings: map[string]int64{"openai:sticky-session": 1},
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{
{
ID: 1,
Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
Priority: 10,
Credentials: map[string]any{
"model_mapping": map[string]any{"gpt-4.1": "gpt-4o"},
},
},
{
ID: 2,
Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
Priority: 20,
Credentials: map[string]any{
"model_mapping": map[string]any{"gpt-4.1": "o3-mini"},
},
},
}},
channelService: channelSvc,
cache: cache,
}
groupID := int64(10)
account, err := svc.SelectAccountForModelWithExclusions(context.Background(), &groupID, "sticky-session", "gpt-4.1", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, int64(2), account.ID)
require.Equal(t, 1, cache.deletedSessions["openai:sticky-session"])
require.Equal(t, int64(2), cache.sessionBindings["openai:sticky-session"])
}

View File

@@ -10,8 +10,8 @@ import (
const compatPromptCacheKeyPrefix = "compat_cc_"
func shouldAutoInjectPromptCacheKeyForCompat(model string) bool {
switch resolveOpenAIUpstreamModel(strings.TrimSpace(model)) {
case "gpt-5.4", "gpt-5.3-codex", "gpt-5.3-codex-spark":
switch normalizeCodexModel(strings.TrimSpace(model)) {
case "gpt-5.4", "gpt-5.3-codex":
return true
default:
return false
@@ -23,9 +23,9 @@ func deriveCompatPromptCacheKey(req *apicompat.ChatCompletionsRequest, mappedMod
return ""
}
normalizedModel := resolveOpenAIUpstreamModel(strings.TrimSpace(mappedModel))
normalizedModel := normalizeCodexModel(strings.TrimSpace(mappedModel))
if normalizedModel == "" {
normalizedModel = resolveOpenAIUpstreamModel(strings.TrimSpace(req.Model))
normalizedModel = normalizeCodexModel(strings.TrimSpace(req.Model))
}
if normalizedModel == "" {
normalizedModel = strings.TrimSpace(req.Model)

View File

@@ -46,7 +46,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
// 2. Resolve model mapping early so compat prompt_cache_key injection can
// derive a stable seed from the final upstream model family.
billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
upstreamModel := resolveOpenAIUpstreamModel(billingModel)
upstreamModel := normalizeCodexModel(billingModel)
promptCacheKey = strings.TrimSpace(promptCacheKey)
compatPromptCacheInjected := false

View File

@@ -62,7 +62,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
// 3. Model mapping
billingModel := resolveOpenAIForwardModel(account, normalizedModel, defaultMappedModel)
upstreamModel := resolveOpenAIUpstreamModel(billingModel)
upstreamModel := normalizeCodexModel(billingModel)
responsesReq.Model = upstreamModel
logger.L().Debug("openai messages: model mapping applied",

View File

@@ -145,6 +145,8 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
nil,
&DeferredService{},
nil,
nil,
nil,
)
svc.userGroupRateResolver = newUserGroupRateResolver(
rateRepo,

View File

@@ -10,6 +10,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"math/rand"
"net/http"
"sort"
@@ -204,6 +205,7 @@ type OpenAIUsage struct {
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
ImageOutputTokens int `json:"image_output_tokens,omitempty"`
}
// OpenAIForwardResult represents the result of forwarding
@@ -322,6 +324,8 @@ type OpenAIGatewayService struct {
openAITokenProvider *OpenAITokenProvider
toolCorrector *CodexToolCorrector
openaiWSResolver OpenAIWSProtocolResolver
resolver *ModelPricingResolver
channelService *ChannelService
openaiWSPoolOnce sync.Once
openaiWSStateStoreOnce sync.Once
@@ -357,6 +361,8 @@ func NewOpenAIGatewayService(
httpUpstream HTTPUpstream,
deferredService *DeferredService,
openAITokenProvider *OpenAITokenProvider,
resolver *ModelPricingResolver,
channelService *ChannelService,
) *OpenAIGatewayService {
svc := &OpenAIGatewayService{
accountRepo: accountRepo,
@@ -384,6 +390,8 @@ func NewOpenAIGatewayService(
openAITokenProvider: openAITokenProvider,
toolCorrector: NewCodexToolCorrector(),
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
resolver: resolver,
channelService: channelService,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
}
@@ -391,6 +399,74 @@ func NewOpenAIGatewayService(
return svc
}
// ResolveChannelMapping 解析渠道级模型映射(代理到 ChannelService
func (s *OpenAIGatewayService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult {
if s.channelService == nil {
return ChannelMappingResult{MappedModel: model}
}
return s.channelService.ResolveChannelMapping(ctx, groupID, model)
}
// IsModelRestricted 检查模型是否被渠道限制(代理到 ChannelService
func (s *OpenAIGatewayService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
if s.channelService == nil {
return false
}
return s.channelService.IsModelRestricted(ctx, groupID, model)
}
// ResolveChannelMappingAndRestrict 解析渠道映射。
// 模型限制检查已移至调度阶段restricted 始终返回 false。
func (s *OpenAIGatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
if s.channelService == nil {
return ChannelMappingResult{MappedModel: model}, false
}
return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model)
}
func (s *OpenAIGatewayService) checkChannelPricingRestriction(ctx context.Context, groupID *int64, requestedModel string) bool {
if groupID == nil || s.channelService == nil || requestedModel == "" {
return false
}
mapping := s.channelService.ResolveChannelMapping(ctx, *groupID, requestedModel)
billingModel := billingModelForRestriction(mapping.BillingModelSource, requestedModel, mapping.MappedModel)
if billingModel == "" {
return false
}
return s.channelService.IsModelRestricted(ctx, *groupID, billingModel)
}
func (s *OpenAIGatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context, groupID int64, account *Account, requestedModel string) bool {
if s.channelService == nil {
return false
}
upstreamModel := resolveOpenAIForwardModel(account, requestedModel, "")
if upstreamModel == "" {
return false
}
return s.channelService.IsModelRestricted(ctx, groupID, upstreamModel)
}
func (s *OpenAIGatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Context, groupID *int64) bool {
if groupID == nil || s.channelService == nil {
return false
}
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
if err != nil {
slog.Warn("failed to check openai channel upstream restriction", "group_id", *groupID, "error", err)
return false
}
if ch == nil || !ch.RestrictModels {
return false
}
return ch.BillingModelSource == BillingModelSourceUpstream
}
// ReplaceModelInBody 替换请求体中的 JSON model 字段(通用 gjson/sjson 实现)。
func (s *OpenAIGatewayService) ReplaceModelInBody(body []byte, newModel string) []byte {
return ReplaceModelInBody(body, newModel)
}
func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle {
if s != nil && s.codexSnapshotThrottle != nil {
return s.codexSnapshotThrottle
@@ -1125,6 +1201,13 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
}
func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) (*Account, error) {
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
slog.Warn("channel pricing restriction blocked request",
"group_id", derefGroupID(groupID),
"model", requestedModel)
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
}
// 1. 尝试粘性会话命中
// Try sticky session hit
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID); account != nil {
@@ -1140,7 +1223,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
// 3. 按优先级 + LRU 选择最佳账号
// Select by priority + LRU
selected := s.selectBestAccount(ctx, accounts, requestedModel, excludedIDs)
selected := s.selectBestAccount(ctx, groupID, accounts, requestedModel, excludedIDs)
if selected == nil {
if requestedModel != "" {
@@ -1206,6 +1289,11 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
return nil
}
if groupID != nil && s.needsUpstreamChannelRestrictionCheck(ctx, groupID) &&
s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
return nil
}
// 刷新会话 TTL 并返回账号
// Refresh session TTL and return account
@@ -1218,8 +1306,9 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
//
// selectBestAccount selects the best account from candidates (priority + LRU).
// Returns nil if no available account.
func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *int64, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
var selected *Account
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
for i := range accounts {
acc := &accounts[i]
@@ -1238,6 +1327,9 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts [
if fresh == nil {
continue
}
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
continue
}
// 选择优先级最高且最久未使用的账号
// Select highest priority and least recently used
@@ -1289,7 +1381,15 @@ func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
slog.Warn("channel pricing restriction blocked request",
"group_id", derefGroupID(groupID),
"model", requestedModel)
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
}
cfg := s.schedulingConfig()
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
var stickyAccountID int64
if sessionHash != "" && s.cache != nil {
if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil {
@@ -1365,6 +1465,8 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
if account == nil {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
} else if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
} else {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
@@ -1410,6 +1512,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue
}
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) {
continue
}
candidates = append(candidates, acc)
}
@@ -1434,6 +1539,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if fresh == nil {
continue
}
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
continue
}
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
@@ -1488,6 +1596,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if fresh == nil {
continue
}
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
continue
}
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
@@ -1510,6 +1621,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if fresh == nil {
continue
}
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
continue
}
return &AccountSelectionResult{
Account: fresh,
WaitPlan: &AccountWaitPlan{
@@ -1825,7 +1939,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
// 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。
if model, ok := reqBody["model"].(string); ok {
upstreamModel = resolveOpenAIUpstreamModel(model)
upstreamModel = normalizeCodexModel(model)
if upstreamModel != "" && upstreamModel != model {
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Upstream model resolved: %s -> %s (account: %s, type: %s, isCodexCLI: %v)",
model, upstreamModel, account.Name, account.Type, isCodexCLI)
@@ -4110,6 +4224,7 @@ type OpenAIRecordUsageInput struct {
IPAddress string // 请求的客户端 IP 地址
RequestPayloadHash string
APIKeyService APIKeyQuotaUpdater
ChannelUsageFields
}
// RecordUsage records usage and deducts balance
@@ -4140,10 +4255,14 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
}
// Get rate multiplier
multiplier := s.cfg.Default.RateMultiplier
multiplier := 1.0
if s.cfg != nil {
multiplier = s.cfg.Default.RateMultiplier
}
if apiKey.GroupID != nil && apiKey.Group != nil {
resolver := s.userGroupRateResolver
if resolver == nil {
@@ -4152,12 +4271,37 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
}
var cost *CostBreakdown
var err error
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
if result.BillingModel != "" {
billingModel = strings.TrimSpace(result.BillingModel)
}
if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" {
billingModel = input.ChannelMappedModel
}
if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
billingModel = input.OriginalModel
}
serviceTier := ""
if result.ServiceTier != nil {
serviceTier = strings.TrimSpace(*result.ServiceTier)
}
cost, err := s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID
cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
Model: billingModel,
GroupID: &gid,
Tokens: tokens,
RequestCount: 1,
RateMultiplier: multiplier,
ServiceTier: serviceTier,
Resolver: s.resolver,
})
} else {
cost, err = s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
}
if err != nil {
cost = &CostBreakdown{ActualCost: 0}
}
@@ -4173,36 +4317,58 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
durationMs := int(result.Duration.Milliseconds())
accountRateMultiplier := account.BillingRateMultiplier()
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
// 确定 RequestedModel渠道映射前的原始模型
requestedModel := result.Model
if input.OriginalModel != "" {
requestedModel = input.OriginalModel
}
usageLog := &UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: result.Model,
RequestedModel: result.Model,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ServiceTier: result.ServiceTier,
ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
InputTokens: actualInputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
InputCost: cost.InputCost,
OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost,
ActualCost: cost.ActualCost,
RateMultiplier: multiplier,
AccountRateMultiplier: &accountRateMultiplier,
BillingType: billingType,
Stream: result.Stream,
OpenAIWSMode: result.OpenAIWSMode,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
CreatedAt: time.Now(),
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: result.Model,
RequestedModel: requestedModel,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ServiceTier: result.ServiceTier,
ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
InputTokens: actualInputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
}
if cost != nil {
usageLog.InputCost = cost.InputCost
usageLog.OutputCost = cost.OutputCost
usageLog.ImageOutputCost = cost.ImageOutputCost
usageLog.CacheCreationCost = cost.CacheCreationCost
usageLog.CacheReadCost = cost.CacheReadCost
usageLog.TotalCost = cost.TotalCost
usageLog.ActualCost = cost.ActualCost
}
usageLog.RateMultiplier = multiplier
usageLog.AccountRateMultiplier = &accountRateMultiplier
usageLog.BillingType = billingType
usageLog.Stream = result.Stream
usageLog.OpenAIWSMode = result.OpenAIWSMode
usageLog.DurationMs = &durationMs
usageLog.FirstTokenMs = result.FirstTokenMs
usageLog.CreatedAt = time.Now()
// 设置渠道信息
usageLog.ChannelID = optionalInt64Ptr(input.ChannelID)
usageLog.ModelMappingChain = optionalTrimmedStringPtr(input.ModelMappingChain)
// 设置计费模式
if cost != nil && cost.BillingMode != "" {
billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode
} else {
billingMode := string(BillingModeToken)
usageLog.BillingMode = &billingMode
}
// 添加 UserAgent
if input.UserAgent != "" {

View File

@@ -1,10 +1,8 @@
package service
import "strings"
// resolveOpenAIForwardModel resolves the account/group mapping result for
// OpenAI-compatible forwarding. Group-level default mapping only applies when
// the account itself did not match any explicit model_mapping rule.
// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible
// forwarding. Group-level default mapping only applies when the account itself
// did not match any explicit model_mapping rule.
func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string {
if account == nil {
if defaultMappedModel != "" {
@@ -19,23 +17,3 @@ func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedMo
}
return mappedModel
}
func resolveOpenAIUpstreamModel(model string) string {
if isBareGPT53CodexSparkModel(model) {
return "gpt-5.3-codex-spark"
}
return normalizeCodexModel(strings.TrimSpace(model))
}
func isBareGPT53CodexSparkModel(model string) bool {
modelID := strings.TrimSpace(model)
if modelID == "" {
return false
}
if strings.Contains(modelID, "/") {
parts := strings.Split(modelID, "/")
modelID = parts[len(parts)-1]
}
normalized := strings.ToLower(strings.TrimSpace(modelID))
return normalized == "gpt-5.3-codex-spark" || normalized == "gpt 5.3 codex spark"
}

View File

@@ -74,30 +74,28 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t *
Credentials: map[string]any{},
}
withoutDefault := resolveOpenAIUpstreamModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", ""))
withoutDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", ""))
if withoutDefault != "gpt-5.1" {
t.Fatalf("resolveOpenAIUpstreamModel(...) = %q, want %q", withoutDefault, "gpt-5.1")
t.Fatalf("normalizeCodexModel(...) = %q, want %q", withoutDefault, "gpt-5.1")
}
withDefault := resolveOpenAIUpstreamModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4"))
withDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4"))
if withDefault != "gpt-5.4" {
t.Fatalf("resolveOpenAIUpstreamModel(...) = %q, want %q", withDefault, "gpt-5.4")
t.Fatalf("normalizeCodexModel(...) = %q, want %q", withDefault, "gpt-5.4")
}
}
func TestResolveOpenAIUpstreamModel(t *testing.T) {
func TestNormalizeCodexModel(t *testing.T) {
cases := map[string]string{
"gpt-5.3-codex-spark": "gpt-5.3-codex-spark",
"gpt 5.3 codex spark": "gpt-5.3-codex-spark",
" openai/gpt-5.3-codex-spark ": "gpt-5.3-codex-spark",
"gpt-5.3-codex-spark-high": "gpt-5.3-codex",
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex",
"gpt-5.3": "gpt-5.3-codex",
"gpt-5.3-codex-spark": "gpt-5.3-codex",
"gpt-5.3-codex-spark-high": "gpt-5.3-codex",
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex",
"gpt-5.3": "gpt-5.3-codex",
}
for input, expected := range cases {
if got := resolveOpenAIUpstreamModel(input); got != expected {
t.Fatalf("resolveOpenAIUpstreamModel(%q) = %q, want %q", input, got, expected)
if got := normalizeCodexModel(input); got != expected {
t.Fatalf("normalizeCodexModel(%q) = %q, want %q", input, got, expected)
}
}
}

View File

@@ -2515,7 +2515,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
}
normalized = next
}
upstreamModel := resolveOpenAIUpstreamModel(account.GetMappedModel(originalModel))
upstreamModel := normalizeCodexModel(account.GetMappedModel(originalModel))
if upstreamModel != originalModel {
next, setErr := applyPayloadMutation(normalized, "model", upstreamModel)
if setErr != nil {
@@ -2773,7 +2773,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
mappedModel := ""
var mappedModelBytes []byte
if originalModel != "" {
mappedModel = resolveOpenAIUpstreamModel(account.GetMappedModel(originalModel))
mappedModel = normalizeCodexModel(account.GetMappedModel(originalModel))
needModelReplace = mappedModel != "" && mappedModel != originalModel
if needModelReplace {
mappedModelBytes = []byte(mappedModel)

View File

@@ -615,6 +615,8 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil,
nil,
nil,
nil,
nil,
)
decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)

View File

@@ -519,7 +519,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry
if s.gatewayService == nil {
return nil, fmt.Errorf("gateway service not available")
}
return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs, "") // 重试不使用会话限制
return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs, "", int64(0)) // 重试不使用会话限制
default:
return nil, fmt.Errorf("unsupported retry type: %s", reqType)
}

View File

@@ -70,7 +70,8 @@ type LiteLLMModelPricing struct {
LiteLLMProvider string `json:"litellm_provider"`
Mode string `json:"mode"`
SupportsPromptCaching bool `json:"supports_prompt_caching"`
OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格
OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格
OutputCostPerImageToken float64 `json:"output_cost_per_image_token"` // 图片输出 token 价格
}
// PricingRemoteClient 远程价格数据获取接口
@@ -94,6 +95,7 @@ type LiteLLMRawEntry struct {
Mode string `json:"mode"`
SupportsPromptCaching bool `json:"supports_prompt_caching"`
OutputCostPerImage *float64 `json:"output_cost_per_image"`
OutputCostPerImageToken *float64 `json:"output_cost_per_image_token"`
}
// PricingService 动态价格服务
@@ -408,6 +410,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel
if entry.OutputCostPerImage != nil {
pricing.OutputCostPerImage = *entry.OutputCostPerImage
}
if entry.OutputCostPerImageToken != nil {
pricing.OutputCostPerImageToken = *entry.OutputCostPerImageToken
}
result[modelName] = pricing
}

View File

@@ -0,0 +1,15 @@
//go:build unit
package service
// testPtrFloat64 returns a pointer to the given float64 value.
func testPtrFloat64(v float64) *float64 { return &v }
// testPtrInt returns a pointer to the given int value.
func testPtrInt(v int) *int { return &v }
// testPtrString returns a pointer to the given string value.
func testPtrString(v string) *string { return &v }
// testPtrBool returns a pointer to the given bool value.
func testPtrBool(v bool) *bool { return &v }

View File

@@ -104,6 +104,14 @@ type UsageLog struct {
// UpstreamModel is the actual model sent to the upstream provider after mapping.
// Nil means no mapping was applied (requested model was used as-is).
UpstreamModel *string
// ChannelID 渠道 ID
ChannelID *int64
// ModelMappingChain 模型映射链,如 "a→b→c"
ModelMappingChain *string
// BillingTier 计费层级标签per_request/image 模式)
BillingTier *string
// BillingMode 计费模式token/imagesora 路径为 nil
BillingMode *string
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
ServiceTier *string
// ReasoningEffort is the request's reasoning effort level.
@@ -126,6 +134,9 @@ type UsageLog struct {
CacheCreation5mTokens int `gorm:"column:cache_creation_5m_tokens"`
CacheCreation1hTokens int `gorm:"column:cache_creation_1h_tokens"`
ImageOutputTokens int
ImageOutputCost float64
InputCost float64
OutputCost float64
CacheCreationCost float64

View File

@@ -26,3 +26,10 @@ func forwardResultBillingModel(requestedModel, upstreamModel string) string {
}
return strings.TrimSpace(upstreamModel)
}
func optionalInt64Ptr(v int64) *int64 {
if v == 0 {
return nil
}
return &v
}

View File

@@ -490,4 +490,6 @@ var ProviderSet = wire.NewSet(
ProvideScheduledTestService,
ProvideScheduledTestRunnerService,
NewGroupCapacityService,
NewChannelService,
NewModelPricingResolver,
)

View File

@@ -0,0 +1,56 @@
-- Create channels table for managing pricing channels.
-- A channel groups multiple groups together and provides custom model pricing.
SET LOCAL lock_timeout = '5s';
SET LOCAL statement_timeout = '10min';
-- 渠道表
CREATE TABLE IF NOT EXISTS channels (
id BIGSERIAL PRIMARY KEY,
name VARCHAR(100) NOT NULL,
description TEXT DEFAULT '',
status VARCHAR(20) NOT NULL DEFAULT 'active',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
-- 渠道名称唯一索引
CREATE UNIQUE INDEX IF NOT EXISTS idx_channels_name ON channels (name);
CREATE INDEX IF NOT EXISTS idx_channels_status ON channels (status);
-- 渠道-分组关联表(每个分组只能属于一个渠道)
CREATE TABLE IF NOT EXISTS channel_groups (
id BIGSERIAL PRIMARY KEY,
channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
group_id BIGINT NOT NULL REFERENCES groups(id) ON DELETE CASCADE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_channel_groups_group_id ON channel_groups (group_id);
CREATE INDEX IF NOT EXISTS idx_channel_groups_channel_id ON channel_groups (channel_id);
-- 渠道模型定价表(一条定价可绑定多个模型)
CREATE TABLE IF NOT EXISTS channel_model_pricing (
id BIGSERIAL PRIMARY KEY,
channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
models JSONB NOT NULL DEFAULT '[]',
input_price NUMERIC(20,12),
output_price NUMERIC(20,12),
cache_write_price NUMERIC(20,12),
cache_read_price NUMERIC(20,12),
image_output_price NUMERIC(20,8),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_channel_model_pricing_channel_id ON channel_model_pricing (channel_id);
COMMENT ON TABLE channels IS '渠道管理:关联多个分组,提供自定义模型定价';
COMMENT ON TABLE channel_groups IS '渠道-分组关联表:每个分组最多属于一个渠道';
COMMENT ON TABLE channel_model_pricing IS '渠道模型定价:一条定价可绑定多个模型,价格一致';
COMMENT ON COLUMN channel_model_pricing.models IS '绑定的模型列表JSON 数组,如 ["claude-opus-4-6","claude-opus-4-6-thinking"]';
COMMENT ON COLUMN channel_model_pricing.input_price IS '每 token 输入价格USDNULL 表示使用默认';
COMMENT ON COLUMN channel_model_pricing.output_price IS '每 token 输出价格USDNULL 表示使用默认';
COMMENT ON COLUMN channel_model_pricing.cache_write_price IS '缓存写入每 token 价格NULL 表示使用默认';
COMMENT ON COLUMN channel_model_pricing.cache_read_price IS '缓存读取每 token 价格NULL 表示使用默认';
COMMENT ON COLUMN channel_model_pricing.image_output_price IS '图片输出价格Gemini Image 等NULL 表示使用默认';

View File

@@ -0,0 +1,67 @@
-- Extend channel_model_pricing with billing_mode and add context-interval child table.
-- Supports three billing modes: token (per-token with context intervals),
-- per_request (per-request with context-size tiers), and image (per-image).
SET LOCAL lock_timeout = '5s';
SET LOCAL statement_timeout = '10min';
-- 1. 为 channel_model_pricing 添加 billing_mode 列
ALTER TABLE channel_model_pricing
ADD COLUMN IF NOT EXISTS billing_mode VARCHAR(20) NOT NULL DEFAULT 'token';
COMMENT ON COLUMN channel_model_pricing.billing_mode IS '计费模式token按 token 区间计费、per_request按次计费、image图片计费';
-- 2. 创建区间定价子表
CREATE TABLE IF NOT EXISTS channel_pricing_intervals (
id BIGSERIAL PRIMARY KEY,
pricing_id BIGINT NOT NULL REFERENCES channel_model_pricing(id) ON DELETE CASCADE,
min_tokens INT NOT NULL DEFAULT 0,
max_tokens INT,
tier_label VARCHAR(50),
input_price NUMERIC(20,12),
output_price NUMERIC(20,12),
cache_write_price NUMERIC(20,12),
cache_read_price NUMERIC(20,12),
per_request_price NUMERIC(20,12),
sort_order INT NOT NULL DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_channel_pricing_intervals_pricing_id
ON channel_pricing_intervals (pricing_id);
COMMENT ON TABLE channel_pricing_intervals IS '渠道定价区间:支持按 token 区间、按次分层、图片分辨率分层';
COMMENT ON COLUMN channel_pricing_intervals.min_tokens IS '区间下界token 模式使用';
COMMENT ON COLUMN channel_pricing_intervals.max_tokens IS '区间上界不含NULL 表示无上限';
COMMENT ON COLUMN channel_pricing_intervals.tier_label IS '层级标签,按次/图片模式使用(如 1K、2K、4K、HD';
COMMENT ON COLUMN channel_pricing_intervals.input_price IS 'token 模式:每 token 输入价';
COMMENT ON COLUMN channel_pricing_intervals.output_price IS 'token 模式:每 token 输出价';
COMMENT ON COLUMN channel_pricing_intervals.cache_write_price IS 'token 模式:缓存写入价';
COMMENT ON COLUMN channel_pricing_intervals.cache_read_price IS 'token 模式:缓存读取价';
COMMENT ON COLUMN channel_pricing_intervals.per_request_price IS '按次/图片模式:每次请求价格';
-- 3. 迁移现有 flat 定价为单区间 [0, +inf)
-- 仅迁移有明确定价(至少一个价格字段非 NULL的条目
INSERT INTO channel_pricing_intervals (pricing_id, min_tokens, max_tokens, input_price, output_price, cache_write_price, cache_read_price, sort_order)
SELECT
cmp.id,
0,
NULL,
cmp.input_price,
cmp.output_price,
cmp.cache_write_price,
cmp.cache_read_price,
0
FROM channel_model_pricing cmp
WHERE cmp.billing_mode = 'token'
AND (cmp.input_price IS NOT NULL OR cmp.output_price IS NOT NULL
OR cmp.cache_write_price IS NOT NULL OR cmp.cache_read_price IS NOT NULL)
AND NOT EXISTS (
SELECT 1 FROM channel_pricing_intervals cpi WHERE cpi.pricing_id = cmp.id
);
-- 4. 迁移 image_output_price 为 image 模式的区间条目
-- 将有 image_output_price 的现有条目复制为 billing_mode='image' 的独立条目
-- 注意:这里不改变原条目的 billing_mode而是将 image_output_price 作为向后兼容字段保留
-- 实际的 image 计费在未来由独立的 billing_mode='image' 条目处理

View File

@@ -0,0 +1,5 @@
SET LOCAL lock_timeout = '5s';
SET LOCAL statement_timeout = '10min';
ALTER TABLE channels ADD COLUMN IF NOT EXISTS model_mapping JSONB DEFAULT '{}';
COMMENT ON COLUMN channels.model_mapping IS '渠道级模型映射,在账号映射之前执行。格式:{"source_model": "target_model"}';

View File

@@ -0,0 +1,7 @@
-- Add billing_model_source to channels (controls whether billing uses requested or upstream model)
ALTER TABLE channels ADD COLUMN IF NOT EXISTS billing_model_source VARCHAR(20) DEFAULT 'requested';
-- Add channel tracking fields to usage_logs
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS channel_id BIGINT;
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS model_mapping_chain VARCHAR(500);
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS billing_tier VARCHAR(50);

View File

@@ -0,0 +1,5 @@
-- Add model restriction switch to channels
ALTER TABLE channels ADD COLUMN IF NOT EXISTS restrict_models BOOLEAN DEFAULT false;
-- Add default per_request_price to channel_model_pricing (fallback when no tier matches)
ALTER TABLE channel_model_pricing ADD COLUMN IF NOT EXISTS per_request_price NUMERIC(20,10);

View File

@@ -0,0 +1,21 @@
-- 086_channel_platform_pricing.sql
-- 渠道按平台维度model_pricing 加 platform 列model_mapping 改为嵌套格式
-- 1. channel_model_pricing 加 platform 列
ALTER TABLE channel_model_pricing
ADD COLUMN IF NOT EXISTS platform VARCHAR(50) NOT NULL DEFAULT 'anthropic';
CREATE INDEX IF NOT EXISTS idx_channel_model_pricing_platform
ON channel_model_pricing (platform);
-- 2. model_mapping: 从扁平 {"src":"dst"} 迁移为嵌套 {"anthropic":{"src":"dst"}}
-- 仅迁移非空、非 '{}' 的旧格式数据(通过检查第一个 value 是否为字符串来判断是否为旧格式)
UPDATE channels
SET model_mapping = jsonb_build_object('anthropic', model_mapping)
WHERE model_mapping IS NOT NULL
AND model_mapping::text NOT IN ('{}', 'null', '')
AND NOT EXISTS (
SELECT 1 FROM jsonb_each(model_mapping) AS kv
WHERE jsonb_typeof(kv.value) = 'object'
LIMIT 1
);

View File

@@ -0,0 +1,2 @@
-- Add billing_mode to usage_logs (records the billing mode: token/per_request/image)
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS billing_mode VARCHAR(20);

View File

@@ -0,0 +1,3 @@
-- Change default billing_model_source for new channels to 'channel_mapped'
-- Existing channels keep their current setting (no UPDATE on existing rows)
ALTER TABLE channels ALTER COLUMN billing_model_source SET DEFAULT 'channel_mapped';

View File

@@ -0,0 +1,2 @@
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS image_output_tokens INTEGER NOT NULL DEFAULT 0;
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS image_output_cost DECIMAL(20, 10) NOT NULL DEFAULT 0;

View File

@@ -0,0 +1,148 @@
/**
* Admin Channels API endpoints
* Handles channel management for administrators
*/
import { apiClient } from '../client'
export type BillingMode = 'token' | 'per_request' | 'image'
export interface PricingInterval {
id?: number
min_tokens: number
max_tokens: number | null
tier_label: string
input_price: number | null
output_price: number | null
cache_write_price: number | null
cache_read_price: number | null
per_request_price: number | null
sort_order: number
}
export interface ChannelModelPricing {
id?: number
platform: string
models: string[]
billing_mode: BillingMode
input_price: number | null
output_price: number | null
cache_write_price: number | null
cache_read_price: number | null
image_output_price: number | null
per_request_price: number | null
intervals: PricingInterval[]
}
export interface Channel {
id: number
name: string
description: string
status: string
billing_model_source: string // "requested" | "upstream"
restrict_models: boolean
group_ids: number[]
model_pricing: ChannelModelPricing[]
model_mapping: Record<string, Record<string, string>> // platform → {src→dst}
created_at: string
updated_at: string
}
export interface CreateChannelRequest {
name: string
description?: string
group_ids?: number[]
model_pricing?: ChannelModelPricing[]
model_mapping?: Record<string, Record<string, string>>
billing_model_source?: string
restrict_models?: boolean
}
export interface UpdateChannelRequest {
name?: string
description?: string
status?: string
group_ids?: number[]
model_pricing?: ChannelModelPricing[]
model_mapping?: Record<string, Record<string, string>>
billing_model_source?: string
restrict_models?: boolean
}
interface PaginatedResponse<T> {
items: T[]
total: number
}
/**
* List channels with pagination
*/
export async function list(
page: number = 1,
pageSize: number = 20,
filters?: {
status?: string
search?: string
},
options?: { signal?: AbortSignal }
): Promise<PaginatedResponse<Channel>> {
const { data } = await apiClient.get<PaginatedResponse<Channel>>('/admin/channels', {
params: {
page,
page_size: pageSize,
...filters
},
signal: options?.signal
})
return data
}
/**
* Get channel by ID
*/
export async function getById(id: number): Promise<Channel> {
const { data } = await apiClient.get<Channel>(`/admin/channels/${id}`)
return data
}
/**
* Create a new channel
*/
export async function create(req: CreateChannelRequest): Promise<Channel> {
const { data } = await apiClient.post<Channel>('/admin/channels', req)
return data
}
/**
* Update a channel
*/
export async function update(id: number, req: UpdateChannelRequest): Promise<Channel> {
const { data } = await apiClient.put<Channel>(`/admin/channels/${id}`, req)
return data
}
/**
* Delete a channel
*/
export async function remove(id: number): Promise<void> {
await apiClient.delete(`/admin/channels/${id}`)
}
export interface ModelDefaultPricing {
found: boolean
input_price?: number // per-token price
output_price?: number
cache_write_price?: number
cache_read_price?: number
image_output_price?: number
}
export async function getModelDefaultPricing(model: string): Promise<ModelDefaultPricing> {
const { data } = await apiClient.get<ModelDefaultPricing>('/admin/channels/model-pricing', {
params: { model }
})
return data
}
const channelsAPI = { list, getById, create, update, remove, getModelDefaultPricing }
export default channelsAPI

View File

@@ -167,6 +167,13 @@ export interface UserBreakdownParams {
endpoint?: string
endpoint_type?: 'inbound' | 'upstream' | 'path'
limit?: number
// Additional filter conditions
user_id?: number
api_key_id?: number
account_id?: number
request_type?: number
stream?: boolean
billing_type?: number | null
}
export interface UserBreakdownResponse {

View File

@@ -25,6 +25,7 @@ import apiKeysAPI from './apiKeys'
import scheduledTestsAPI from './scheduledTests'
import backupAPI from './backup'
import tlsFingerprintProfileAPI from './tlsFingerprintProfile'
import channelsAPI from './channels'
/**
* Unified admin API object for convenient access
@@ -51,7 +52,8 @@ export const adminAPI = {
apiKeys: apiKeysAPI,
scheduledTests: scheduledTestsAPI,
backup: backupAPI,
tlsFingerprintProfiles: tlsFingerprintProfileAPI
tlsFingerprintProfiles: tlsFingerprintProfileAPI,
channels: channelsAPI
}
export {
@@ -76,7 +78,8 @@ export {
apiKeysAPI,
scheduledTestsAPI,
backupAPI,
tlsFingerprintProfileAPI
tlsFingerprintProfileAPI,
channelsAPI
}
export default adminAPI

View File

@@ -80,6 +80,7 @@ export interface CreateUsageCleanupTaskRequest {
export interface AdminUsageQueryParams extends UsageQueryParams {
user_id?: number
exact_total?: boolean
billing_mode?: string
}
// ==================== API Functions ====================

View File

@@ -0,0 +1,113 @@
<template>
<div class="flex items-start gap-2 rounded border p-2"
:class="isEmpty ? 'border-red-400 bg-red-50 dark:border-red-500 dark:bg-red-950/20' : 'border-gray-200 bg-white dark:border-dark-500 dark:bg-dark-700'">
<!-- Token mode: context range + prices ($/MTok) -->
<template v-if="mode === 'token'">
<div class="w-20">
<label class="text-xs text-gray-400">Min</label>
<input :value="interval.min_tokens" @input="emitField('min_tokens', toInt(($event.target as HTMLInputElement).value))"
type="number" min="0" class="input mt-0.5 text-xs" />
</div>
<div class="w-20">
<label class="text-xs text-gray-400">Max <span class="text-gray-300">()</span></label>
<input :value="interval.max_tokens ?? ''" @input="emitField('max_tokens', toIntOrNull(($event.target as HTMLInputElement).value))"
type="number" min="0" class="input mt-0.5 text-xs" :placeholder="'∞'" />
</div>
<div class="flex-1">
<label class="text-xs text-gray-400">{{ t('admin.channels.form.inputPrice', '输入') }} <span v-if="isEmpty" class="text-red-500">*</span> <span class="text-gray-300">$/M</span></label>
<input :value="interval.input_price" @input="emitField('input_price', ($event.target as HTMLInputElement).value)"
type="number" step="any" min="0" class="input mt-0.5 text-xs" />
</div>
<div class="flex-1">
<label class="text-xs text-gray-400">{{ t('admin.channels.form.outputPrice', '输出') }} <span v-if="isEmpty" class="text-red-500">*</span> <span class="text-gray-300">$/M</span></label>
<input :value="interval.output_price" @input="emitField('output_price', ($event.target as HTMLInputElement).value)"
type="number" step="any" min="0" class="input mt-0.5 text-xs" />
</div>
<div class="flex-1">
<label class="text-xs text-gray-400">{{ t('admin.channels.form.cacheWritePrice', '缓存W') }} <span class="text-gray-300">$/M</span></label>
<input :value="interval.cache_write_price" @input="emitField('cache_write_price', ($event.target as HTMLInputElement).value)"
type="number" step="any" min="0" class="input mt-0.5 text-xs" />
</div>
<div class="flex-1">
<label class="text-xs text-gray-400">{{ t('admin.channels.form.cacheReadPrice', '缓存R') }} <span class="text-gray-300">$/M</span></label>
<input :value="interval.cache_read_price" @input="emitField('cache_read_price', ($event.target as HTMLInputElement).value)"
type="number" step="any" min="0" class="input mt-0.5 text-xs" />
</div>
</template>
<!-- Per-request / Image mode: tier label + context range + price -->
<template v-else>
<div class="w-24">
<label class="text-xs text-gray-400">
{{ mode === 'image' ? t('admin.channels.form.resolution', '分辨率') : t('admin.channels.form.tierLabel', '层级') }}
</label>
<input :value="interval.tier_label" @input="emitField('tier_label', ($event.target as HTMLInputElement).value)"
type="text" class="input mt-0.5 text-xs" :placeholder="mode === 'image' ? '1K / 2K / 4K' : ''" />
</div>
<div class="w-20">
<label class="text-xs text-gray-400">Min</label>
<input :value="interval.min_tokens" @input="emitField('min_tokens', toInt(($event.target as HTMLInputElement).value))"
type="number" min="0" class="input mt-0.5 text-xs" />
</div>
<div class="w-20">
<label class="text-xs text-gray-400">Max <span class="text-gray-300">()</span></label>
<input :value="interval.max_tokens ?? ''" @input="emitField('max_tokens', toIntOrNull(($event.target as HTMLInputElement).value))"
type="number" min="0" class="input mt-0.5 text-xs" :placeholder="'∞'" />
</div>
<div class="flex-1">
<label class="text-xs text-gray-400">{{ t('admin.channels.form.perRequestPrice', '单次价格') }} <span v-if="isEmpty" class="text-red-500">*</span> <span class="text-gray-300">$</span></label>
<input :value="interval.per_request_price" @input="emitField('per_request_price', ($event.target as HTMLInputElement).value)"
type="number" step="any" min="0" class="input mt-0.5 text-xs" />
</div>
</template>
<button type="button" @click="emit('remove')" class="mt-4 rounded p-0.5 text-gray-400 hover:text-red-500">
<Icon name="x" size="sm" />
</button>
</div>
</template>
<script setup lang="ts">
import { computed } from 'vue'
import { useI18n } from 'vue-i18n'
import Icon from '@/components/icons/Icon.vue'
import type { IntervalFormEntry } from './types'
import type { BillingMode } from '@/api/admin/channels'
const { t } = useI18n()
const props = defineProps<{
interval: IntervalFormEntry
mode: BillingMode
}>()
const emit = defineEmits<{
update: [interval: IntervalFormEntry]
remove: []
}>()
// 检测所有价格字段是否都为空
const isEmpty = computed(() => {
const iv = props.interval
return (iv.input_price == null || iv.input_price === '') &&
(iv.output_price == null || iv.output_price === '') &&
(iv.cache_write_price == null || iv.cache_write_price === '') &&
(iv.cache_read_price == null || iv.cache_read_price === '') &&
(iv.per_request_price == null || iv.per_request_price === '')
})
function emitField(field: keyof IntervalFormEntry, value: string | number | null) {
emit('update', { ...props.interval, [field]: value === '' ? null : value })
}
function toInt(val: string): number {
const n = parseInt(val, 10)
return isNaN(n) ? 0 : n
}
function toIntOrNull(val: string): number | null {
if (val === '') return null
const n = parseInt(val, 10)
return isNaN(n) ? null : n
}
</script>

View File

@@ -0,0 +1,89 @@
<template>
<div>
<!-- Tags display -->
<div class="flex flex-wrap gap-1.5 rounded-lg border border-gray-200 bg-white p-2 dark:border-dark-600 dark:bg-dark-800 min-h-[2.5rem]">
<span
v-for="(model, idx) in models"
:key="idx"
class="inline-flex items-center gap-1 rounded-md px-2 py-0.5 text-sm"
:class="getPlatformTagClass(props.platform || '')"
>
{{ model }}
<button
type="button"
@click="removeModel(idx)"
class="ml-0.5 rounded-full p-0.5 hover:bg-primary-200 dark:hover:bg-primary-800"
>
<Icon name="x" size="xs" />
</button>
</span>
<input
ref="inputRef"
v-model="inputValue"
type="text"
class="flex-1 min-w-[120px] border-none bg-transparent text-sm outline-none placeholder:text-gray-400 dark:text-white"
:placeholder="models.length === 0 ? placeholder : ''"
@keydown.enter.prevent="addModel"
@keydown.tab.prevent="addModel"
@keydown.delete="handleBackspace"
@paste="handlePaste"
/>
</div>
<p class="mt-1 text-xs text-gray-400">
{{ t('admin.channels.form.modelInputHint', 'Press Enter to add, supports paste for batch import.') }}
</p>
</div>
</template>
<script setup lang="ts">
import { ref } from 'vue'
import { useI18n } from 'vue-i18n'
import Icon from '@/components/icons/Icon.vue'
import { getPlatformTagClass } from './types'
const { t } = useI18n()
const props = defineProps<{
models: string[]
placeholder?: string
platform?: string
}>()
const emit = defineEmits<{
'update:models': [models: string[]]
}>()
const inputValue = ref('')
const inputRef = ref<HTMLInputElement>()
function addModel() {
const val = inputValue.value.trim()
if (!val) return
if (!props.models.includes(val)) {
emit('update:models', [...props.models, val])
}
inputValue.value = ''
}
function removeModel(idx: number) {
const newModels = [...props.models]
newModels.splice(idx, 1)
emit('update:models', newModels)
}
function handleBackspace() {
if (inputValue.value === '' && props.models.length > 0) {
removeModel(props.models.length - 1)
}
}
function handlePaste(e: ClipboardEvent) {
e.preventDefault()
const text = e.clipboardData?.getData('text') || ''
const items = text.split(/[,\n;]+/).map(s => s.trim()).filter(Boolean)
if (items.length === 0) return
const unique = [...new Set([...props.models, ...items])]
emit('update:models', unique)
inputValue.value = ''
}
</script>

View File

@@ -0,0 +1,354 @@
<template>
<div class="rounded-lg border border-gray-200 bg-gray-50 p-3 dark:border-dark-600 dark:bg-dark-800">
<!-- Collapsed summary header (clickable) -->
<div
class="flex cursor-pointer select-none items-center gap-2"
@click="collapsed = !collapsed"
>
<Icon
:name="collapsed ? 'chevronRight' : 'chevronDown'"
size="sm"
:stroke-width="2"
class="flex-shrink-0 text-gray-400 transition-transform duration-200"
/>
<!-- Summary: model tags + billing badge -->
<div v-if="collapsed" class="flex min-w-0 flex-1 items-center gap-2 overflow-hidden">
<!-- Compact model tags (show first 3) -->
<div class="flex min-w-0 flex-1 flex-wrap items-center gap-1">
<span
v-for="(m, i) in entry.models.slice(0, 3)"
:key="i"
class="inline-flex shrink-0 rounded px-1.5 py-0.5 text-xs"
:class="getPlatformTagClass(props.platform || '')"
>
{{ m }}
</span>
<span
v-if="entry.models.length > 3"
class="whitespace-nowrap text-xs text-gray-400"
>
+{{ entry.models.length - 3 }}
</span>
<span
v-if="entry.models.length === 0"
class="text-xs italic text-gray-400"
>
{{ t('admin.channels.form.noModels', '未添加模型') }}
</span>
</div>
<!-- Billing mode badge -->
<span
class="flex-shrink-0 rounded-full bg-primary-100 px-2 py-0.5 text-xs font-medium text-primary-700 dark:bg-primary-900/30 dark:text-primary-300"
>
{{ billingModeLabel }}
</span>
</div>
<!-- Expanded: show the label "Pricing Entry" or similar -->
<div v-else class="flex-1 text-xs font-medium text-gray-500 dark:text-gray-400">
{{ t('admin.channels.form.pricingEntry', '定价配置') }}
</div>
<!-- Remove button (always visible, stop propagation) -->
<button
type="button"
@click.stop="emit('remove')"
class="flex-shrink-0 rounded p-1 text-gray-400 hover:text-red-500"
>
<Icon name="trash" size="sm" />
</button>
</div>
<!-- Expandable content with transition -->
<div
class="collapsible-content"
:class="{ 'collapsible-content--collapsed': collapsed }"
>
<div class="collapsible-inner">
<!-- Header: Models + Billing Mode -->
<div class="mt-3 flex items-start gap-2">
<div class="flex-1">
<label class="text-xs font-medium text-gray-500 dark:text-gray-400">
{{ t('admin.channels.form.models', '模型列表') }} <span class="text-red-500">*</span>
</label>
<ModelTagInput
:models="entry.models"
:platform="props.platform"
@update:models="onModelsUpdate($event)"
:placeholder="t('admin.channels.form.modelsPlaceholder', '输入模型名后按回车添加,支持通配符 *')"
class="mt-1"
/>
</div>
<div class="w-40">
<label class="text-xs font-medium text-gray-500 dark:text-gray-400">
{{ t('admin.channels.form.billingMode', '计费模式') }}
</label>
<Select
:modelValue="entry.billing_mode"
@update:modelValue="emit('update', { ...entry, billing_mode: $event as BillingMode, intervals: [] })"
:options="billingModeOptions"
class="mt-1"
/>
</div>
</div>
<!-- Token mode -->
<div v-if="entry.billing_mode === 'token'">
<!-- Default prices (fallback when no interval matches) -->
<label class="mt-3 block text-xs font-medium text-gray-500 dark:text-gray-400">
{{ t('admin.channels.form.defaultPrices', '默认价格(未命中区间时使用)') }}
<span class="ml-1 font-normal text-gray-400">$/MTok</span>
</label>
<div class="mt-1 grid grid-cols-2 gap-2 sm:grid-cols-5">
<div>
<label class="text-xs text-gray-400">{{ t('admin.channels.form.inputPrice', '输入') }}</label>
<input :value="entry.input_price" @input="emitField('input_price', ($event.target as HTMLInputElement).value)"
type="number" step="any" min="0" class="input mt-0.5 text-sm" :placeholder="t('admin.channels.form.pricePlaceholder', '默认')" />
</div>
<div>
<label class="text-xs text-gray-400">{{ t('admin.channels.form.outputPrice', '输出') }}</label>
<input :value="entry.output_price" @input="emitField('output_price', ($event.target as HTMLInputElement).value)"
type="number" step="any" min="0" class="input mt-0.5 text-sm" :placeholder="t('admin.channels.form.pricePlaceholder', '默认')" />
</div>
<div>
<label class="text-xs text-gray-400">{{ t('admin.channels.form.cacheWritePrice', '缓存写入') }}</label>
<input :value="entry.cache_write_price" @input="emitField('cache_write_price', ($event.target as HTMLInputElement).value)"
type="number" step="any" min="0" class="input mt-0.5 text-sm" :placeholder="t('admin.channels.form.pricePlaceholder', '默认')" />
</div>
<div>
<label class="text-xs text-gray-400">{{ t('admin.channels.form.cacheReadPrice', '缓存读取') }}</label>
<input :value="entry.cache_read_price" @input="emitField('cache_read_price', ($event.target as HTMLInputElement).value)"
type="number" step="any" min="0" class="input mt-0.5 text-sm" :placeholder="t('admin.channels.form.pricePlaceholder', '默认')" />
</div>
<div>
<label class="text-xs text-gray-400">{{ t('admin.channels.form.imageTokenPrice', '图片输出') }}</label>
<input :value="entry.image_output_price" @input="emitField('image_output_price', ($event.target as HTMLInputElement).value)"
type="number" step="any" min="0" class="input mt-0.5 text-sm" :placeholder="t('admin.channels.form.pricePlaceholder', '默认')" />
</div>
</div>
<!-- Token intervals -->
<div class="mt-3">
<div class="flex items-center justify-between">
<label class="text-xs font-medium text-gray-500 dark:text-gray-400">
{{ t('admin.channels.form.intervals', '上下文区间定价(可选)') }}
<span class="ml-1 font-normal text-gray-400">(min, max]</span>
</label>
<button type="button" @click="addInterval" class="text-xs text-primary-600 hover:text-primary-700">
+ {{ t('admin.channels.form.addInterval', '添加区间') }}
</button>
</div>
<div v-if="entry.intervals && entry.intervals.length > 0" class="mt-2 space-y-2">
<IntervalRow
v-for="(iv, idx) in entry.intervals"
:key="idx"
:interval="iv"
:mode="entry.billing_mode"
@update="updateInterval(idx, $event)"
@remove="removeInterval(idx)"
/>
</div>
</div>
</div>
<!-- Per-request mode -->
<div v-else-if="entry.billing_mode === 'per_request'">
<!-- Default per-request price -->
<label class="mt-3 block text-xs font-medium text-gray-500 dark:text-gray-400">
{{ t('admin.channels.form.defaultPerRequestPrice', '默认单次价格(未命中层级时使用)') }}
<span class="ml-1 font-normal text-gray-400">$</span>
</label>
<div class="mt-1 w-48">
<input :value="entry.per_request_price" @input="emitField('per_request_price', ($event.target as HTMLInputElement).value)"
type="number" step="any" min="0" class="input text-sm" :placeholder="t('admin.channels.form.pricePlaceholder', '默认')" />
</div>
<!-- Tiers -->
<div class="mt-3 flex items-center justify-between">
<label class="text-xs font-medium text-gray-500 dark:text-gray-400">
{{ t('admin.channels.form.requestTiers', '按次计费层级') }}
</label>
<button type="button" @click="addInterval" class="text-xs text-primary-600 hover:text-primary-700">
+ {{ t('admin.channels.form.addTier', '添加层级') }}
</button>
</div>
<div v-if="entry.intervals && entry.intervals.length > 0" class="mt-2 space-y-2">
<IntervalRow
v-for="(iv, idx) in entry.intervals"
:key="idx"
:interval="iv"
:mode="entry.billing_mode"
@update="updateInterval(idx, $event)"
@remove="removeInterval(idx)"
/>
</div>
<div v-else class="mt-2 rounded border border-dashed border-gray-300 p-3 text-center text-xs text-gray-400 dark:border-dark-500">
{{ t('admin.channels.form.noTiersYet', '暂无层级,点击添加配置按次计费价格') }}
</div>
</div>
<!-- Image mode -->
<div v-else-if="entry.billing_mode === 'image'">
<!-- Default image price (per-request, same as per_request mode) -->
<label class="mt-3 block text-xs font-medium text-gray-500 dark:text-gray-400">
{{ t('admin.channels.form.defaultImagePrice', '默认图片价格(未命中层级时使用)') }}
<span class="ml-1 font-normal text-gray-400">$</span>
</label>
<div class="mt-1 w-48">
<input :value="entry.per_request_price" @input="emitField('per_request_price', ($event.target as HTMLInputElement).value)"
type="number" step="any" min="0" class="input text-sm" :placeholder="t('admin.channels.form.pricePlaceholder', '默认')" />
</div>
<!-- Image tiers -->
<div class="mt-3 flex items-center justify-between">
<label class="text-xs font-medium text-gray-500 dark:text-gray-400">
{{ t('admin.channels.form.imageTiers', '图片计费层级(按次)') }}
</label>
<button type="button" @click="addImageTier" class="text-xs text-primary-600 hover:text-primary-700">
+ {{ t('admin.channels.form.addTier', '添加层级') }}
</button>
</div>
<div v-if="entry.intervals && entry.intervals.length > 0" class="mt-2 space-y-2">
<IntervalRow
v-for="(iv, idx) in entry.intervals"
:key="idx"
:interval="iv"
:mode="entry.billing_mode"
@update="updateInterval(idx, $event)"
@remove="removeInterval(idx)"
/>
</div>
</div>
</div>
</div>
</div>
</template>
<script setup lang="ts">
import { ref, computed } from 'vue'
import { useI18n } from 'vue-i18n'
import Select from '@/components/common/Select.vue'
import Icon from '@/components/icons/Icon.vue'
import IntervalRow from './IntervalRow.vue'
import ModelTagInput from './ModelTagInput.vue'
import type { PricingFormEntry, IntervalFormEntry } from './types'
import { perTokenToMTok, getPlatformTagClass } from './types'
import type { BillingMode } from '@/api/admin/channels'
import channelsAPI from '@/api/admin/channels'
const { t } = useI18n()
const props = defineProps<{
entry: PricingFormEntry
platform?: string
}>()
const emit = defineEmits<{
update: [entry: PricingFormEntry]
remove: []
}>()
// Collapse state: entries with existing models default to collapsed
const collapsed = ref(props.entry.models.length > 0)
const billingModeOptions = computed(() => [
{ value: 'token', label: 'Token' },
{ value: 'per_request', label: t('admin.channels.billingMode.perRequest', '按次') },
{ value: 'image', label: t('admin.channels.billingMode.image', '图片(按次)') }
])
const billingModeLabel = computed(() => {
const opt = billingModeOptions.value.find(o => o.value === props.entry.billing_mode)
return opt ? opt.label : props.entry.billing_mode
})
function emitField(field: keyof PricingFormEntry, value: string) {
emit('update', { ...props.entry, [field]: value === '' ? null : value })
}
function addInterval() {
const intervals = [...(props.entry.intervals || [])]
intervals.push({
min_tokens: 0, max_tokens: null, tier_label: '',
input_price: null, output_price: null, cache_write_price: null,
cache_read_price: null, per_request_price: null,
sort_order: intervals.length
})
emit('update', { ...props.entry, intervals })
}
function addImageTier() {
const intervals = [...(props.entry.intervals || [])]
const labels = ['1K', '2K', '4K', 'HD']
intervals.push({
min_tokens: 0, max_tokens: null, tier_label: labels[intervals.length] || '',
input_price: null, output_price: null, cache_write_price: null,
cache_read_price: null, per_request_price: null,
sort_order: intervals.length
})
emit('update', { ...props.entry, intervals })
}
function updateInterval(idx: number, updated: IntervalFormEntry) {
const intervals = [...(props.entry.intervals || [])]
intervals[idx] = updated
emit('update', { ...props.entry, intervals })
}
function removeInterval(idx: number) {
const intervals = [...(props.entry.intervals || [])]
intervals.splice(idx, 1)
emit('update', { ...props.entry, intervals })
}
async function onModelsUpdate(newModels: string[]) {
const oldModels = props.entry.models
emit('update', { ...props.entry, models: newModels })
// 只在新增模型且当前无价格时自动填充
const addedModels = newModels.filter(m => !oldModels.includes(m))
if (addedModels.length === 0) return
// 检查是否所有价格字段都为空
const e = props.entry
const hasPrice = e.input_price != null || e.output_price != null ||
e.cache_write_price != null || e.cache_read_price != null
if (hasPrice) return
// 查询第一个新增模型的默认价格
try {
const result = await channelsAPI.getModelDefaultPricing(addedModels[0])
if (result.found) {
emit('update', {
...props.entry,
models: newModels,
input_price: perTokenToMTok(result.input_price ?? null),
output_price: perTokenToMTok(result.output_price ?? null),
cache_write_price: perTokenToMTok(result.cache_write_price ?? null),
cache_read_price: perTokenToMTok(result.cache_read_price ?? null),
image_output_price: perTokenToMTok(result.image_output_price ?? null),
})
}
} catch {
// 查询失败不影响用户操作
}
}
</script>
<style scoped>
.collapsible-content {
display: grid;
grid-template-rows: 1fr;
transition: grid-template-rows 0.25s ease;
}
.collapsible-content--collapsed {
grid-template-rows: 0fr;
}
.collapsible-inner {
overflow: hidden;
}
</style>

View File

@@ -0,0 +1,190 @@
import type { BillingMode, PricingInterval } from '@/api/admin/channels'
export interface IntervalFormEntry {
min_tokens: number
max_tokens: number | null
tier_label: string
input_price: number | string | null
output_price: number | string | null
cache_write_price: number | string | null
cache_read_price: number | string | null
per_request_price: number | string | null
sort_order: number
}
export interface PricingFormEntry {
models: string[]
billing_mode: BillingMode
input_price: number | string | null
output_price: number | string | null
cache_write_price: number | string | null
cache_read_price: number | string | null
image_output_price: number | string | null
per_request_price: number | string | null
intervals: IntervalFormEntry[]
}
// 价格转换:后端存 per-token前端显示 per-MTok ($/1M tokens)
const MTOK = 1_000_000
export function toNullableNumber(val: number | string | null | undefined): number | null {
if (val === null || val === undefined || val === '') return null
const num = Number(val)
return isNaN(num) ? null : num
}
/** 前端显示值($/MTok) → 后端存储值(per-token) */
export function mTokToPerToken(val: number | string | null | undefined): number | null {
const num = toNullableNumber(val)
return num === null ? null : parseFloat((num / MTOK).toPrecision(10))
}
/** 后端存储值(per-token) → 前端显示值($/MTok) */
export function perTokenToMTok(val: number | null | undefined): number | null {
if (val === null || val === undefined) return null
// toPrecision(10) 消除 IEEE 754 浮点乘法精度误差,如 5e-8 * 1e6 = 0.04999...96 → 0.05
return parseFloat((val * MTOK).toPrecision(10))
}
export function apiIntervalsToForm(intervals: PricingInterval[]): IntervalFormEntry[] {
return (intervals || []).map(iv => ({
min_tokens: iv.min_tokens,
max_tokens: iv.max_tokens,
tier_label: iv.tier_label || '',
input_price: perTokenToMTok(iv.input_price),
output_price: perTokenToMTok(iv.output_price),
cache_write_price: perTokenToMTok(iv.cache_write_price),
cache_read_price: perTokenToMTok(iv.cache_read_price),
per_request_price: iv.per_request_price,
sort_order: iv.sort_order
}))
}
export function formIntervalsToAPI(intervals: IntervalFormEntry[]): PricingInterval[] {
return (intervals || []).map(iv => ({
min_tokens: iv.min_tokens,
max_tokens: iv.max_tokens,
tier_label: iv.tier_label,
input_price: mTokToPerToken(iv.input_price),
output_price: mTokToPerToken(iv.output_price),
cache_write_price: mTokToPerToken(iv.cache_write_price),
cache_read_price: mTokToPerToken(iv.cache_read_price),
per_request_price: toNullableNumber(iv.per_request_price),
sort_order: iv.sort_order
}))
}
// ── 模型模式冲突检测 ──────────────────────────────────────
interface ModelPattern {
pattern: string
prefix: string // lowercase, 通配符去掉尾部 *
wildcard: boolean
}
function toModelPattern(model: string): ModelPattern {
const lower = model.toLowerCase()
const wildcard = lower.endsWith('*')
return {
pattern: model,
prefix: wildcard ? lower.slice(0, -1) : lower,
wildcard,
}
}
function patternsConflict(a: ModelPattern, b: ModelPattern): boolean {
if (!a.wildcard && !b.wildcard) return a.prefix === b.prefix
if (a.wildcard && !b.wildcard) return b.prefix.startsWith(a.prefix)
if (!a.wildcard && b.wildcard) return a.prefix.startsWith(b.prefix)
// 双通配符:任一前缀是另一前缀的前缀即冲突
return a.prefix.startsWith(b.prefix) || b.prefix.startsWith(a.prefix)
}
/** 检测模型模式列表中的冲突,返回冲突的两个模式名;无冲突返回 null */
export function findModelConflict(models: string[]): [string, string] | null {
const patterns = models.map(toModelPattern)
for (let i = 0; i < patterns.length; i++) {
for (let j = i + 1; j < patterns.length; j++) {
if (patternsConflict(patterns[i], patterns[j])) {
return [patterns[i].pattern, patterns[j].pattern]
}
}
}
return null
}
// ── 区间校验 ──────────────────────────────────────────────
/** 校验区间列表的合法性,返回错误消息;通过则返回 null */
export function validateIntervals(intervals: IntervalFormEntry[]): string | null {
if (!intervals || intervals.length === 0) return null
// 按 min_tokens 排序(不修改原数组)
const sorted = [...intervals].sort((a, b) => a.min_tokens - b.min_tokens)
for (let i = 0; i < sorted.length; i++) {
const err = validateSingleInterval(sorted[i], i)
if (err) return err
}
return checkIntervalOverlap(sorted)
}
function validateSingleInterval(iv: IntervalFormEntry, idx: number): string | null {
if (iv.min_tokens < 0) {
return `区间 #${idx + 1}: 最小 token 数 (${iv.min_tokens}) 不能为负数`
}
if (iv.max_tokens != null) {
if (iv.max_tokens <= 0) {
return `区间 #${idx + 1}: 最大 token 数 (${iv.max_tokens}) 必须大于 0`
}
if (iv.max_tokens <= iv.min_tokens) {
return `区间 #${idx + 1}: 最大 token 数 (${iv.max_tokens}) 必须大于最小 token 数 (${iv.min_tokens})`
}
}
return validateIntervalPrices(iv, idx)
}
function validateIntervalPrices(iv: IntervalFormEntry, idx: number): string | null {
const prices: [string, number | string | null][] = [
['输入价格', iv.input_price],
['输出价格', iv.output_price],
['缓存写入价格', iv.cache_write_price],
['缓存读取价格', iv.cache_read_price],
['单次价格', iv.per_request_price],
]
for (const [name, val] of prices) {
if (val != null && val !== '' && Number(val) < 0) {
return `区间 #${idx + 1}: ${name}不能为负数`
}
}
return null
}
function checkIntervalOverlap(sorted: IntervalFormEntry[]): string | null {
for (let i = 0; i < sorted.length; i++) {
// 无上限区间必须是最后一个
if (sorted[i].max_tokens == null && i < sorted.length - 1) {
return `区间 #${i + 1}: 无上限区间(最大 token 数为空)只能是最后一个`
}
if (i === 0) continue
const prev = sorted[i - 1]
// (min, max] 语义:前一个区间上界 > 当前区间下界则重叠
if (prev.max_tokens == null || prev.max_tokens > sorted[i].min_tokens) {
const prevMax = prev.max_tokens == null ? '∞' : String(prev.max_tokens)
return `区间 #${i} 和 #${i + 1} 重叠:前一个区间上界 (${prevMax}) 大于当前区间下界 (${sorted[i].min_tokens})`
}
}
return null
}
/** 平台对应的模型 tag 样式(背景+文字) */
export function getPlatformTagClass(platform: string): string {
switch (platform) {
case 'anthropic': return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400'
case 'openai': return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-400'
case 'gemini': return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-400'
case 'antigravity': return 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400'
case 'sora': return 'bg-rose-100 text-rose-700 dark:bg-rose-900/30 dark:text-rose-400'
default: return 'bg-gray-100 text-gray-700 dark:bg-gray-900/30 dark:text-gray-400'
}
}

View File

@@ -133,6 +133,12 @@
<Select v-model="filters.billing_type" :options="billingTypeOptions" @change="emitChange" />
</div>
<!-- Billing Mode Filter -->
<div class="w-full sm:w-auto sm:min-w-[200px]">
<label class="input-label">{{ t('admin.usage.billingMode') }}</label>
<Select v-model="filters.billing_mode" :options="billingModeOptions" @change="emitChange" />
</div>
<!-- Group Filter -->
<div class="w-full sm:w-auto sm:min-w-[200px]">
<label class="input-label">{{ t('admin.usage.group') }}</label>
@@ -232,6 +238,13 @@ const billingTypeOptions = ref<SelectOption[]>([
{ value: 1, label: t('admin.usage.billingTypeSubscription') }
])
const billingModeOptions = ref<SelectOption[]>([
{ value: null, label: t('admin.usage.allBillingModes') },
{ value: 'token', label: t('admin.usage.billingModeToken') },
{ value: 'per_request', label: t('admin.usage.billingModePerRequest') },
{ value: 'image', label: t('admin.usage.billingModeImage') }
])
const emitChange = () => emit('change')
const debounceUserSearch = () => {

View File

@@ -26,7 +26,15 @@
</template>
<template #cell-model="{ row }">
<div v-if="row.upstream_model && row.upstream_model !== row.model" class="space-y-0.5 text-xs">
<div v-if="row.model_mapping_chain && row.model_mapping_chain.includes('→')" class="space-y-0.5 text-xs">
<div v-for="(step, i) in row.model_mapping_chain.split('→')" :key="i"
class="break-all"
:class="i === 0 ? 'font-medium text-gray-900 dark:text-white' : 'text-gray-500 dark:text-gray-400'"
:style="i > 0 ? `padding-left: ${i * 0.75}rem` : ''">
<span v-if="i > 0" class="mr-0.5"></span>{{ step }}
</div>
</div>
<div v-else-if="row.upstream_model && row.upstream_model !== row.model" class="space-y-0.5 text-xs">
<div class="break-all font-medium text-gray-900 dark:text-white">
{{ row.model }}
</div>
@@ -69,9 +77,15 @@
</span>
</template>
<template #cell-billing_mode="{ row }">
<span class="inline-flex items-center rounded px-2 py-0.5 text-xs font-medium" :class="getBillingModeBadgeClass(row.billing_mode)">
{{ getBillingModeLabel(row.billing_mode) }}
</span>
</template>
<template #cell-tokens="{ row }">
<!-- 图片生成请求 -->
<div v-if="row.image_count > 0" class="flex items-center gap-1.5">
<!-- 图片生成请求仅按次计费时显示图片格式 -->
<div v-if="row.image_count > 0 && row.billing_mode === 'image'" class="flex items-center gap-1.5">
<svg class="h-4 w-4 text-indigo-500" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 16l4.586-4.586a2 2 0 012.828 0L16 16m-2-2l1.586-1.586a2 2 0 012.828 0L20 14m-6-6h.01M6 20h12a2 2 0 002-2V6a2 2 0 00-2-2H6a2 2 0 00-2 2v12a2 2 0 002 2z" />
</svg>
@@ -281,11 +295,11 @@
</div>
<div class="flex items-center justify-between gap-6">
<span class="text-gray-400">{{ t('usage.rate') }}</span>
<span class="font-semibold text-blue-400">{{ (tooltipData?.rate_multiplier || 1).toFixed(2) }}x</span>
<span class="font-semibold text-blue-400">{{ formatMultiplier(tooltipData?.rate_multiplier || 1) }}x</span>
</div>
<div class="flex items-center justify-between gap-6">
<span class="text-gray-400">{{ t('usage.accountMultiplier') }}</span>
<span class="font-semibold text-blue-400">{{ (tooltipData?.account_rate_multiplier ?? 1).toFixed(2) }}x</span>
<span class="font-semibold text-blue-400">{{ formatMultiplier(tooltipData?.account_rate_multiplier ?? 1) }}x</span>
</div>
<div class="flex items-center justify-between gap-6">
<span class="text-gray-400">{{ t('usage.original') }}</span>
@@ -312,6 +326,7 @@
import { ref } from 'vue'
import { useI18n } from 'vue-i18n'
import { formatDateTime, formatReasoningEffort } from '@/utils/format'
import { formatCacheTokens, formatMultiplier } from '@/utils/formatters'
import { formatTokenPricePerMillion } from '@/utils/usagePricing'
import { getUsageServiceTierLabel } from '@/utils/usageServiceTier'
import { resolveUsageRequestType } from '@/utils/usageRequestType'
@@ -350,12 +365,19 @@ const getRequestTypeBadgeClass = (row: AdminUsageLog): string => {
return 'bg-amber-100 text-amber-800 dark:bg-amber-900 dark:text-amber-200'
}
const formatCacheTokens = (tokens: number): string => {
if (tokens >= 1000000) return `${(tokens / 1000000).toFixed(1)}M`
if (tokens >= 1000) return `${(tokens / 1000).toFixed(1)}K`
return tokens.toString()
const getBillingModeLabel = (mode: string | null | undefined): string => {
if (mode === 'per_request') return t('admin.usage.billingModePerRequest')
if (mode === 'image') return t('admin.usage.billingModeImage')
return t('admin.usage.billingModeToken')
}
const getBillingModeBadgeClass = (mode: string | null | undefined): string => {
if (mode === 'per_request') return 'bg-blue-100 text-blue-800 dark:bg-blue-900 dark:text-blue-200'
if (mode === 'image') return 'bg-green-100 text-green-800 dark:bg-green-900 dark:text-green-200'
return 'bg-gray-100 text-gray-800 dark:bg-gray-700 dark:text-gray-200'
}
const formatUserAgent = (ua: string): string => {
return ua
}

View File

@@ -161,6 +161,7 @@ const props = withDefaults(
showSourceToggle?: boolean
startDate?: string
endDate?: string
filters?: Record<string, any>
}>(),
{
upstreamEndpointStats: () => [],
@@ -193,6 +194,7 @@ const toggleBreakdown = async (endpoint: string) => {
breakdownItems.value = []
try {
const res = await getUserBreakdown({
...props.filters,
start_date: props.startDate,
end_date: props.endDate,
endpoint,

View File

@@ -125,6 +125,7 @@ const props = withDefaults(defineProps<{
showMetricToggle?: boolean
startDate?: string
endDate?: string
filters?: Record<string, any>
}>(), {
loading: false,
metric: 'tokens',
@@ -150,6 +151,7 @@ const toggleBreakdown = async (type: string, id: number | string) => {
breakdownItems.value = []
try {
const res = await getUserBreakdown({
...props.filters,
start_date: props.startDate,
end_date: props.endDate,
group_id: Number(id),

View File

@@ -270,6 +270,7 @@ const props = withDefaults(defineProps<{
rankingError?: boolean
startDate?: string
endDate?: string
filters?: Record<string, any>
}>(), {
upstreamModelStats: () => [],
mappingModelStats: () => [],
@@ -302,6 +303,7 @@ const toggleBreakdown = async (type: string, id: string) => {
breakdownItems.value = []
try {
const res = await getUserBreakdown({
...props.filters,
start_date: props.startDate,
end_date: props.endDate,
model: id,

View File

@@ -287,6 +287,21 @@ const FolderIcon = {
)
}
const ChannelIcon = {
render: () =>
h(
'svg',
{ fill: 'none', viewBox: '0 0 24 24', stroke: 'currentColor', 'stroke-width': '1.5' },
[
h('path', {
'stroke-linecap': 'round',
'stroke-linejoin': 'round',
d: 'M6.429 9.75L2.25 12l4.179 2.25m0-4.5l5.571 3 5.571-3m-11.142 0L2.25 7.5 12 2.25l9.75 5.25-4.179 2.25m0 0l4.179 2.25L12 17.25 2.25 12m15.321-2.25l4.179 2.25L12 17.25l-9.75-5.25'
})
]
)
}
const CreditCardIcon = {
render: () =>
h(
@@ -568,6 +583,7 @@ const adminNavItems = computed((): NavItem[] => {
: []),
{ path: '/admin/users', label: t('nav.users'), icon: UsersIcon, hideInSimpleMode: true },
{ path: '/admin/groups', label: t('nav.groups'), icon: FolderIcon, hideInSimpleMode: true },
{ path: '/admin/channels', label: t('nav.channels', '渠道管理'), icon: ChannelIcon, hideInSimpleMode: true },
{ path: '/admin/subscriptions', label: t('nav.subscriptions'), icon: CreditCardIcon, hideInSimpleMode: true },
{ path: '/admin/accounts', label: t('nav.accounts'), icon: GlobeIcon },
{ path: '/admin/announcements', label: t('nav.announcements'), icon: BellIcon },

View File

@@ -335,6 +335,7 @@ export default {
profile: 'Profile',
users: 'Users',
groups: 'Groups',
channels: 'Channels',
subscriptions: 'Subscriptions',
accounts: 'Accounts',
proxies: 'Proxies',
@@ -1719,6 +1720,107 @@ export default {
}
},
// Channel Management
channels: {
title: 'Channel Management',
description: 'Manage channels and custom model pricing',
searchChannels: 'Search channels...',
createChannel: 'Create Channel',
editChannel: 'Edit Channel',
deleteChannel: 'Delete Channel',
statusActive: 'Active',
statusDisabled: 'Disabled',
allStatus: 'All Status',
groupsUnit: 'groups',
pricingUnit: 'pricing rules',
noChannelsYet: 'No Channels Yet',
createFirstChannel: 'Create your first channel to manage model pricing',
loadError: 'Failed to load channels',
createSuccess: 'Channel created',
updateSuccess: 'Channel updated',
deleteSuccess: 'Channel deleted',
createError: 'Failed to create channel',
updateError: 'Failed to update channel',
deleteError: 'Failed to delete channel',
nameRequired: 'Please enter a channel name',
duplicateModels: 'Model "{0}" appears in multiple pricing entries',
modelConflict: "Model patterns '{model1}' and '{model2}' conflict: overlapping match range",
mappingConflict: "Mapping source patterns '{model1}' and '{model2}' conflict: overlapping match range",
deleteConfirm: 'Are you sure you want to delete channel "{name}"? This cannot be undone.',
columns: {
name: 'Name',
description: 'Description',
status: 'Status',
groups: 'Groups',
pricing: 'Pricing',
createdAt: 'Created',
actions: 'Actions'
},
billingMode: {
token: 'Token',
perRequest: 'Per Request',
image: 'Image (Per Request)'
},
form: {
name: 'Name',
namePlaceholder: 'Enter channel name',
description: 'Description',
descriptionPlaceholder: 'Optional description',
status: 'Status',
groups: 'Associated Groups',
noGroupsAvailable: 'No groups available',
inOtherChannel: 'In "{name}"',
modelPricing: 'Model Pricing',
models: 'Models',
modelsPlaceholder: 'Type full model name and press Enter',
modelInputHint: 'Press Enter to add, supports paste for batch import.',
billingMode: 'Billing Mode',
defaultPrices: 'Default prices (fallback when no interval matches)',
inputPrice: 'Input',
outputPrice: 'Output',
cacheWritePrice: 'Cache Write',
cacheReadPrice: 'Cache Read',
imageTokenPrice: 'Image Output',
imageOutputPrice: 'Image Output Price',
pricePlaceholder: 'Default',
intervals: 'Context Intervals (optional)',
addInterval: 'Add Interval',
requestTiers: 'Request Tiers',
imageTiers: 'Image Tiers (Per Request)',
addTier: 'Add Tier',
noTiersYet: 'No tiers yet. Click add to configure per-request pricing.',
noPricingRules: 'No pricing rules yet. Click "Add" to create one.',
perRequestPrice: 'Price per Request',
perRequestPriceRequired: 'Per-request price or billing tiers required for per-request/image billing mode',
tierLabel: 'Tier',
resolution: 'Resolution',
modelMapping: 'Model Mapping',
modelMappingHint: 'Map request model names to actual model names. Runs before account-level mapping.',
noMappingRules: 'No mapping rules. Click "Add" to create one.',
mappingSource: 'Source model',
mappingTarget: 'Target model',
billingModelSource: 'Billing Model',
billingModelSourceChannelMapped: 'Bill by channel-mapped model',
billingModelSourceRequested: 'Bill by requested model',
billingModelSourceUpstream: 'Bill by final upstream model',
billingModelSourceHint: 'Controls which model name is used for pricing lookup',
selectedCount: '{count} selected',
searchGroups: 'Search groups...',
noGroupsMatch: 'No groups match your search',
restrictModels: 'Restrict Models',
restrictModelsHint: 'When enabled, only models in the pricing list are allowed. Others will be rejected.',
defaultPerRequestPrice: 'Default per-request price (fallback when no tier matches)',
defaultImagePrice: 'Default image price (fallback when no tier matches)',
platformConfig: 'Platform Configuration',
basicSettings: 'Basic Settings',
addPlatform: 'Add Platform',
noPlatforms: 'Click "Add Platform" to start configuring the channel',
mappingCount: 'mappings',
pricingEntry: 'Pricing Entry',
noModels: 'No models added'
}
},
// Subscriptions
subscriptions: {
title: 'Subscription Management',
@@ -3258,6 +3360,11 @@ export default {
allBillingTypes: 'All Billing Types',
billingTypeBalance: 'Balance',
billingTypeSubscription: 'Subscription',
billingMode: 'Billing Mode',
billingModeToken: 'Token',
billingModePerRequest: 'Per Request',
billingModeImage: 'Image',
allBillingModes: 'All Billing Modes',
ipAddress: 'IP',
clickToViewBalance: 'Click to view balance history',
failedToLoadUser: 'Failed to load user info',

View File

@@ -335,6 +335,7 @@ export default {
profile: '个人资料',
users: '用户管理',
groups: '分组管理',
channels: '渠道管理',
subscriptions: '订阅管理',
accounts: '账号管理',
proxies: 'IP管理',
@@ -1799,6 +1800,107 @@ export default {
}
},
// Channel Management
channels: {
title: '渠道管理',
description: '管理渠道和自定义模型定价',
searchChannels: '搜索渠道...',
createChannel: '创建渠道',
editChannel: '编辑渠道',
deleteChannel: '删除渠道',
statusActive: '启用',
statusDisabled: '停用',
allStatus: '全部状态',
groupsUnit: '个分组',
pricingUnit: '条定价',
noChannelsYet: '暂无渠道',
createFirstChannel: '创建第一个渠道来管理模型定价',
loadError: '加载渠道列表失败',
createSuccess: '渠道创建成功',
updateSuccess: '渠道更新成功',
deleteSuccess: '渠道删除成功',
createError: '创建渠道失败',
updateError: '更新渠道失败',
deleteError: '删除渠道失败',
nameRequired: '请输入渠道名称',
duplicateModels: '模型「{0}」在多个定价条目中重复',
modelConflict: "模型模式 '{model1}' 和 '{model2}' 冲突:匹配范围重叠",
mappingConflict: "模型映射源 '{model1}' 和 '{model2}' 冲突:匹配范围重叠",
deleteConfirm: '确定要删除渠道「{name}」吗?此操作不可撤销。',
columns: {
name: '名称',
description: '描述',
status: '状态',
groups: '分组',
pricing: '定价',
createdAt: '创建时间',
actions: '操作'
},
billingMode: {
token: 'Token',
perRequest: '按次',
image: '图片(按次)'
},
form: {
name: '名称',
namePlaceholder: '输入渠道名称',
description: '描述',
descriptionPlaceholder: '可选描述',
status: '状态',
groups: '关联分组',
noGroupsAvailable: '暂无可用分组',
inOtherChannel: '已属于「{name}」',
modelPricing: '模型定价',
models: '模型列表',
modelsPlaceholder: '输入完整模型名后按回车添加',
modelInputHint: '按回车添加,支持粘贴批量导入',
billingMode: '计费模式',
defaultPrices: '默认价格(未命中区间时使用)',
inputPrice: '输入',
outputPrice: '输出',
cacheWritePrice: '缓存写入',
cacheReadPrice: '缓存读取',
imageTokenPrice: '图片输出',
imageOutputPrice: '图片输出价格',
pricePlaceholder: '默认',
intervals: '上下文区间定价(可选)',
addInterval: '添加区间',
requestTiers: '按次计费层级',
imageTiers: '图片计费层级(按次)',
addTier: '添加层级',
noTiersYet: '暂无层级,点击添加配置按次计费价格',
noPricingRules: '暂无定价规则,点击"添加"创建',
perRequestPrice: '单次价格',
perRequestPriceRequired: '按次/图片计费模式必须设置默认价格或至少一个计费层级',
tierLabel: '层级',
resolution: '分辨率',
modelMapping: '模型映射',
modelMappingHint: '将请求中的模型名映射为实际模型名。在账号级别映射之前执行。',
noMappingRules: '暂无映射规则,点击"添加"创建',
mappingSource: '源模型',
mappingTarget: '目标模型',
billingModelSource: '计费基准',
billingModelSourceChannelMapped: '以渠道映射后的模型计费',
billingModelSourceRequested: '以请求模型计费',
billingModelSourceUpstream: '以最终模型计费',
billingModelSourceHint: '控制使用哪个模型名称进行定价查找',
selectedCount: '已选 {count} 个',
searchGroups: '搜索分组...',
noGroupsMatch: '没有匹配的分组',
restrictModels: '限制模型',
restrictModelsHint: '开启后,仅允许模型定价列表中的模型。不在列表中的模型请求将被拒绝。',
defaultPerRequestPrice: '默认单次价格(未命中层级时使用)',
defaultImagePrice: '默认图片价格(未命中层级时使用)',
platformConfig: '平台配置',
basicSettings: '基础设置',
addPlatform: '添加平台',
noPlatforms: '点击"添加平台"开始配置渠道',
mappingCount: '条映射',
pricingEntry: '定价配置',
noModels: '未添加模型'
}
},
// Subscriptions Management
subscriptions: {
title: '订阅管理',
@@ -3417,6 +3519,11 @@ export default {
allBillingTypes: '全部计费类型',
billingTypeBalance: '钱包余额',
billingTypeSubscription: '订阅套餐',
billingMode: '计费模式',
billingModeToken: '按量',
billingModePerRequest: '按次',
billingModeImage: '按次(图片)',
allBillingModes: '全部计费模式',
ipAddress: 'IP',
clickToViewBalance: '点击查看充值记录',
failedToLoadUser: '加载用户信息失败',

View File

@@ -278,6 +278,18 @@ const routes: RouteRecordRaw[] = [
descriptionKey: 'admin.groups.description'
}
},
{
path: '/admin/channels',
name: 'AdminChannels',
component: () => import('@/views/admin/ChannelsView.vue'),
meta: {
requiresAuth: true,
requiresAdmin: true,
title: 'Channel Management',
titleKey: 'admin.channels.title',
descriptionKey: 'admin.channels.description'
}
},
{
path: '/admin/subscriptions',
name: 'AdminSubscriptions',

View File

@@ -1036,6 +1036,9 @@ export interface UsageLog {
// Cache TTL Override
cache_ttl_overridden: boolean
// 计费模式
billing_mode?: string | null
created_at: string
user?: User
@@ -1051,6 +1054,7 @@ export interface UsageLogAccountSummary {
export interface AdminUsageLog extends UsageLog {
upstream_model?: string | null
model_mapping_chain?: string | null
// 账号计费倍率(仅管理员可见)
account_rate_multiplier?: number | null

View File

@@ -0,0 +1,18 @@
/**
* 格式化缓存 token 数量1K/1M 缩写)
*/
export function formatCacheTokens(tokens: number): string {
if (tokens >= 1000000) return `${(tokens / 1000000).toFixed(1)}M`
if (tokens >= 1000) return `${(tokens / 1000).toFixed(1)}K`
return tokens.toLocaleString()
}
/**
* 自适应精度格式化倍率(确保小数值如 0.001 不被截断)
*/
export function formatMultiplier(val: number): string {
if (val >= 0.01) return val.toFixed(2)
if (val >= 0.001) return val.toFixed(3)
if (val >= 0.0001) return val.toFixed(4)
return val.toPrecision(2)
}

File diff suppressed because it is too large Load Diff

View File

@@ -34,6 +34,7 @@
:show-metric-toggle="true"
:start-date="startDate"
:end-date="endDate"
:filters="breakdownFilters"
/>
<GroupDistributionChart
v-model:metric="groupDistributionMetric"
@@ -42,6 +43,7 @@
:show-metric-toggle="true"
:start-date="startDate"
:end-date="endDate"
:filters="breakdownFilters"
/>
</div>
<div class="grid grid-cols-1 gap-6 lg:grid-cols-2">
@@ -57,6 +59,7 @@
:title="t('usage.endpointDistribution')"
:start-date="startDate"
:end-date="endDate"
:filters="breakdownFilters"
/>
<TokenUsageTrend :trend-data="trendData" :loading="chartsLoading" />
</div>
@@ -169,6 +172,17 @@ const cleanupDialogVisible = ref(false)
const showBalanceHistoryModal = ref(false)
const balanceHistoryUser = ref<AdminUser | null>(null)
const breakdownFilters = computed(() => {
const f: Record<string, any> = {}
if (filters.value.user_id) f.user_id = filters.value.user_id
if (filters.value.api_key_id) f.api_key_id = filters.value.api_key_id
if (filters.value.account_id) f.account_id = filters.value.account_id
if (filters.value.group_id) f.group_id = filters.value.group_id
if (filters.value.request_type != null) f.request_type = filters.value.request_type
if (filters.value.billing_type != null) f.billing_type = filters.value.billing_type
return f
})
const handleUserClick = async (userId: number) => {
try {
const user = await adminAPI.users.getById(userId)
@@ -392,7 +406,7 @@ const resetFilters = () => {
const range = getLast24HoursRangeDates()
startDate.value = range.start
endDate.value = range.end
filters.value = { start_date: startDate.value, end_date: endDate.value, request_type: undefined, billing_type: null }
filters.value = { start_date: startDate.value, end_date: endDate.value, request_type: undefined, billing_type: null, billing_mode: undefined }
granularity.value = getGranularityForRange(startDate.value, endDate.value)
applyFilters()
}
@@ -440,7 +454,7 @@ const exportToExcel = async () => {
log.input_tokens, log.output_tokens, log.cache_read_tokens, log.cache_creation_tokens,
log.input_cost?.toFixed(6) || '0.000000', log.output_cost?.toFixed(6) || '0.000000',
log.cache_read_cost?.toFixed(6) || '0.000000', log.cache_creation_cost?.toFixed(6) || '0.000000',
log.rate_multiplier?.toFixed(2) || '1.00', (log.account_rate_multiplier ?? 1).toFixed(2),
log.rate_multiplier?.toPrecision(4) || '1.00', (log.account_rate_multiplier ?? 1).toPrecision(4),
log.total_cost?.toFixed(6) || '0.000000', log.actual_cost?.toFixed(6) || '0.000000',
(log.total_cost * (log.account_rate_multiplier ?? 1)).toFixed(6), log.first_token_ms ?? '', log.duration_ms,
log.request_id || '', log.user_agent || '', log.ip_address || ''
@@ -477,6 +491,7 @@ const allColumns = computed(() => [
{ key: 'endpoint', label: t('usage.endpoint'), sortable: false },
{ key: 'group', label: t('admin.usage.group'), sortable: false },
{ key: 'stream', label: t('usage.type'), sortable: false },
{ key: 'billing_mode', label: t('admin.usage.billingMode'), sortable: false },
{ key: 'tokens', label: t('usage.tokens'), sortable: false },
{ key: 'cost', label: t('usage.cost'), sortable: false },
{ key: 'first_token', label: t('usage.firstToken'), sortable: false },

Some files were not shown because too many files have changed in this diff Show More