feat(channel): 渠道管理全链路集成 — 模型映射、定价、限制、用量统计

- 渠道模型映射:支持精确匹配和通配符映射,按平台隔离
- 渠道模型定价:支持 token/按次/图片三种计费模式,区间分层定价
- 模型限制:渠道可限制仅允许定价列表中的模型
- 计费模型来源:支持 requested/upstream 两种计费模型选择
- 用量统计:usage_logs 新增 channel_id/model_mapping_chain/billing_tier/billing_mode 字段
- Dashboard 支持 model_source 维度(requested/upstream/mapping)查看模型统计
- 全部 gateway handler 统一接入 ResolveChannelMappingAndRestrict
- 修复测试:同步 SoraGenerationRepository 接口、SQL INSERT 参数、scan 字段
This commit is contained in:
erio
2026-04-01 01:51:19 +08:00
parent 669bff78c4
commit 2555951be4
33 changed files with 3633 additions and 262 deletions

View File

@@ -51,15 +51,15 @@ type Channel struct {
type ChannelModelPricing struct {
ID int64
ChannelID int64
Platform string // 所属平台anthropic/openai/gemini/...
Models []string // 绑定的模型列表
BillingMode BillingMode // 计费模式
InputPrice *float64 // 每 token 输入价格USD— 向后兼容 flat 定价
OutputPrice *float64 // 每 token 输出价格USD
CacheWritePrice *float64 // 缓存写入价格
CacheReadPrice *float64 // 缓存读取价格
ImageOutputPrice *float64 // 图片输出价格(向后兼容)
PerRequestPrice *float64 // 默认按次计费价格USD
Platform string // 所属平台anthropic/openai/gemini/...
Models []string // 绑定的模型列表
BillingMode BillingMode // 计费模式
InputPrice *float64 // 每 token 输入价格USD— 向后兼容 flat 定价
OutputPrice *float64 // 每 token 输出价格USD
CacheWritePrice *float64 // 缓存写入价格
CacheReadPrice *float64 // 缓存读取价格
ImageOutputPrice *float64 // 图片输出价格(向后兼容)
PerRequestPrice *float64 // 默认按次计费价格USD
Intervals []PricingInterval // 区间定价列表
CreatedAt time.Time
UpdatedAt time.Time
@@ -175,3 +175,11 @@ func (c *Channel) Clone() *Channel {
}
return &cp
}
// ChannelUsageFields 渠道相关的使用记录字段(嵌入到各平台的 RecordUsageInput 中)
type ChannelUsageFields struct {
ChannelID int64 // 渠道 ID0 = 无渠道)
OriginalModel string // 用户原始请求模型(渠道映射前)
BillingModelSource string // 计费模型来源:"requested" / "upstream"
ModelMappingChain string // 映射链描述,如 "a→b→c"
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"log/slog"
"sort"
"strings"
"sync/atomic"
"time"
@@ -17,8 +16,8 @@ import (
)
var (
ErrChannelNotFound = infraerrors.NotFound("CHANNEL_NOT_FOUND", "channel not found")
ErrChannelExists = infraerrors.Conflict("CHANNEL_EXISTS", "channel name already exists")
ErrChannelNotFound = infraerrors.NotFound("CHANNEL_NOT_FOUND", "channel not found")
ErrChannelExists = infraerrors.Conflict("CHANNEL_EXISTS", "channel name already exists")
ErrGroupAlreadyInChannel = infraerrors.Conflict(
"GROUP_ALREADY_IN_CHANNEL",
"one or more groups already belong to another channel",
@@ -81,12 +80,12 @@ type wildcardMappingEntry struct {
// channelCache 渠道缓存快照(扁平化哈希结构,热路径 O(1) 查找)
type channelCache struct {
// 热路径查找
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价
wildcardByGroupPlatform map[channelGroupPlatformKey][]*wildcardPricingEntry // (groupID, platform) → 通配符定价(前缀长度降序)
mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标
wildcardMappingByGP map[channelGroupPlatformKey][]*wildcardMappingEntry // (groupID, platform) → 通配符映射(前缀长度降序)
channelByGroupID map[int64]*Channel // groupID → 渠道
groupPlatform map[int64]string // groupID → platform
mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标
wildcardMappingByGP map[channelGroupPlatformKey][]*wildcardMappingEntry // (groupID, platform) → 通配符映射(前缀长度降序)
channelByGroupID map[int64]*Channel // groupID → 渠道
groupPlatform map[int64]string // groupID → platform
// 冷路径CRUD 操作)
byID map[int64]*Channel
@@ -118,9 +117,19 @@ func (r ChannelMappingResult) BuildModelMappingChain(reqModel, upstreamModel str
return reqModel + "→" + r.MappedModel
}
// ToUsageFields 将渠道映射结果转为使用记录字段
func (r ChannelMappingResult) ToUsageFields(reqModel, upstreamModel string) ChannelUsageFields {
return ChannelUsageFields{
ChannelID: r.ChannelID,
OriginalModel: reqModel,
BillingModelSource: r.BillingModelSource,
ModelMappingChain: r.BuildModelMappingChain(reqModel, upstreamModel),
}
}
const (
channelCacheTTL = 60 * time.Second
channelErrorTTL = 5 * time.Second // DB 错误时的短缓存
channelCacheTTL = 60 * time.Second
channelErrorTTL = 5 * time.Second // DB 错误时的短缓存
channelCacheDBTimeout = 10 * time.Second
)
@@ -177,14 +186,14 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
// error-TTL失败时存入短 TTL 空缓存,防止紧密重试
slog.Warn("failed to build channel cache", "error", err)
errorCache := &channelCache{
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
mappingByGroupModel: make(map[channelModelKey]string),
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
channelByGroupID: make(map[int64]*Channel),
groupPlatform: make(map[int64]string),
byID: make(map[int64]*Channel),
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL
mappingByGroupModel: make(map[channelModelKey]string),
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
channelByGroupID: make(map[int64]*Channel),
groupPlatform: make(map[int64]string),
byID: make(map[int64]*Channel),
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL
}
s.cache.Store(errorCache)
return nil, fmt.Errorf("list all channels: %w", err)
@@ -205,14 +214,14 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
}
cache := &channelCache{
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
mappingByGroupModel: make(map[channelModelKey]string),
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
channelByGroupID: make(map[int64]*Channel),
groupPlatform: groupPlatforms,
byID: make(map[int64]*Channel, len(channels)),
loadedAt: time.Now(),
mappingByGroupModel: make(map[channelModelKey]string),
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
channelByGroupID: make(map[int64]*Channel),
groupPlatform: groupPlatforms,
byID: make(map[int64]*Channel, len(channels)),
loadedAt: time.Now(),
}
for i := range channels {
@@ -266,19 +275,7 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
}
}
// 通配符条目按前缀长度降序排列(最长前缀优先匹配
for gpKey, entries := range cache.wildcardByGroupPlatform {
sort.Slice(entries, func(i, j int) bool {
return len(entries[i].prefix) > len(entries[j].prefix)
})
cache.wildcardByGroupPlatform[gpKey] = entries
}
for gpKey, entries := range cache.wildcardMappingByGP {
sort.Slice(entries, func(i, j int) bool {
return len(entries[i].prefix) > len(entries[j].prefix)
})
cache.wildcardMappingByGP[gpKey] = entries
}
// 通配符条目保持配置顺序(最先匹配到优先
s.cache.Store(cache)
return cache, nil
@@ -290,7 +287,7 @@ func (s *ChannelService) invalidateCache() {
s.cacheSF.Forget("channel_cache")
}
// matchWildcard 在通配符定价中查找匹配项(最长前缀优先)
// matchWildcard 在通配符定价中查找匹配项(最先匹配到优先)
func (c *channelCache) matchWildcard(groupID int64, platform, modelLower string) *ChannelModelPricing {
gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform}
wildcards := c.wildcardByGroupPlatform[gpKey]
@@ -302,7 +299,7 @@ func (c *channelCache) matchWildcard(groupID int64, platform, modelLower string)
return nil
}
// matchWildcardMapping 在通配符映射中查找匹配项(最长前缀优先)
// matchWildcardMapping 在通配符映射中查找匹配项(最先匹配到优先)
func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower string) string {
gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform}
wildcards := c.wildcardMappingByGP[gpKey]
@@ -479,15 +476,18 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
Status: StatusActive,
BillingModelSource: input.BillingModelSource,
RestrictModels: input.RestrictModels,
GroupIDs: input.GroupIDs,
ModelPricing: input.ModelPricing,
ModelMapping: input.ModelMapping,
GroupIDs: input.GroupIDs,
ModelPricing: input.ModelPricing,
ModelMapping: input.ModelMapping,
}
if channel.BillingModelSource == "" {
channel.BillingModelSource = BillingModelSourceRequested
}
if err := validateNoDuplicateModels(channel.ModelPricing); err != nil {
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
return nil, err
}
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
return nil, err
}
@@ -558,7 +558,10 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
channel.BillingModelSource = input.BillingModelSource
}
if err := validateNoDuplicateModels(channel.ModelPricing); err != nil {
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
return nil, err
}
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
return nil, err
}
@@ -610,16 +613,79 @@ func (s *ChannelService) List(ctx context.Context, params pagination.PaginationP
return s.repo.List(ctx, params, status, search)
}
// validateNoDuplicateModels 检查定价列表中是否有重复模型(同一平台下不允许重复
func validateNoDuplicateModels(pricingList []ChannelModelPricing) error {
seen := make(map[string]bool)
// modelEntry 表示一个模型模式条目(用于冲突检测
type modelEntry struct {
pattern string // 原始模式(如 "claude-*" 或 "claude-opus-4"
prefix string // lowercase 前缀(通配符去掉 *,精确名保持原样)
wildcard bool
}
// conflictsBetween 检查两个模型模式是否冲突
func conflictsBetween(a, b modelEntry) bool {
switch {
case !a.wildcard && !b.wildcard:
return a.prefix == b.prefix
case a.wildcard && !b.wildcard:
return strings.HasPrefix(b.prefix, a.prefix)
case !a.wildcard && b.wildcard:
return strings.HasPrefix(a.prefix, b.prefix)
default:
return strings.HasPrefix(a.prefix, b.prefix) ||
strings.HasPrefix(b.prefix, a.prefix)
}
}
// toModelEntry 将模型名转换为 modelEntry
func toModelEntry(pattern string) modelEntry {
lower := strings.ToLower(pattern)
isWild := strings.HasSuffix(lower, "*")
prefix := lower
if isWild {
prefix = strings.TrimSuffix(lower, "*")
}
return modelEntry{pattern: pattern, prefix: prefix, wildcard: isWild}
}
// validateNoConflictingModels 检查定价列表中是否有冲突模型模式(同一平台下)。
// 冲突包括:精确重复、通配符之间的前缀包含、通配符与精确名的前缀匹配。
func validateNoConflictingModels(pricingList []ChannelModelPricing) error {
byPlatform := make(map[string][]modelEntry)
for _, p := range pricingList {
for _, model := range p.Models {
key := p.Platform + ":" + strings.ToLower(model)
if seen[key] {
return infraerrors.BadRequest("DUPLICATE_MODEL", fmt.Sprintf("model '%s' appears in multiple pricing entries for platform '%s'", model, p.Platform))
byPlatform[p.Platform] = append(byPlatform[p.Platform], toModelEntry(model))
}
}
for platform, entries := range byPlatform {
if err := detectConflicts(entries, platform, "MODEL_PATTERN_CONFLICT", "model patterns"); err != nil {
return err
}
}
return nil
}
// validateNoConflictingMappings 检查模型映射中是否有冲突的源模式
func validateNoConflictingMappings(mapping map[string]map[string]string) error {
for platform, platformMapping := range mapping {
entries := make([]modelEntry, 0, len(platformMapping))
for src := range platformMapping {
entries = append(entries, toModelEntry(src))
}
if err := detectConflicts(entries, platform, "MAPPING_PATTERN_CONFLICT", "mapping source patterns"); err != nil {
return err
}
}
return nil
}
// detectConflicts 在一组 modelEntry 中检测冲突,返回带有 errCode 和 label 的错误
func detectConflicts(entries []modelEntry, platform, errCode, label string) error {
for i := 0; i < len(entries); i++ {
for j := i + 1; j < len(entries); j++ {
if conflictsBetween(entries[i], entries[j]) {
return infraerrors.BadRequest(errCode,
fmt.Sprintf("%s '%s' and '%s' conflict in platform '%s': overlapping match range",
label, entries[i].pattern, entries[j].pattern, platform))
}
seen[key] = true
}
}
return nil

File diff suppressed because it is too large Load Diff

View File

@@ -8,13 +8,10 @@ import (
"github.com/stretchr/testify/require"
)
func channelTestPtrFloat64(v float64) *float64 { return &v }
func channelTestPtrInt(v int) *int { return &v }
func TestGetModelPricing(t *testing.T) {
ch := &Channel{
ModelPricing: []ChannelModelPricing{
{ID: 1, Models: []string{"claude-sonnet-4"}, BillingMode: BillingModeToken, InputPrice: channelTestPtrFloat64(3e-6)},
{ID: 1, Models: []string{"claude-sonnet-4"}, BillingMode: BillingModeToken, InputPrice: testPtrFloat64(3e-6)},
{ID: 3, Models: []string{"gpt-5.1"}, BillingMode: BillingModePerRequest},
},
}
@@ -48,7 +45,7 @@ func TestGetModelPricing(t *testing.T) {
func TestGetModelPricing_ReturnsCopy(t *testing.T) {
ch := &Channel{
ModelPricing: []ChannelModelPricing{
{ID: 1, Models: []string{"claude-sonnet-4"}, InputPrice: channelTestPtrFloat64(3e-6)},
{ID: 1, Models: []string{"claude-sonnet-4"}, InputPrice: testPtrFloat64(3e-6)},
},
}
@@ -73,23 +70,23 @@ func TestGetModelPricing_EmptyPricing(t *testing.T) {
func TestGetIntervalForContext(t *testing.T) {
p := &ChannelModelPricing{
Intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: channelTestPtrInt(128000), InputPrice: channelTestPtrFloat64(1e-6)},
{MinTokens: 128000, MaxTokens: nil, InputPrice: channelTestPtrFloat64(2e-6)},
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)},
},
}
tests := []struct {
name string
tokens int
wantPrice *float64
wantNil bool
name string
tokens int
wantPrice *float64
wantNil bool
}{
{"first interval", 50000, channelTestPtrFloat64(1e-6), false},
{"first interval", 50000, testPtrFloat64(1e-6), false},
// (min, max] — 128000 在第一个区间的 max包含所以匹配第一个
{"boundary: max of first (inclusive)", 128000, channelTestPtrFloat64(1e-6), false},
{"boundary: max of first (inclusive)", 128000, testPtrFloat64(1e-6), false},
// 128001 > 128000匹配第二个区间
{"boundary: just above first max", 128001, channelTestPtrFloat64(2e-6), false},
{"unbounded interval", 500000, channelTestPtrFloat64(2e-6), false},
{"boundary: just above first max", 128001, testPtrFloat64(2e-6), false},
{"unbounded interval", 500000, testPtrFloat64(2e-6), false},
// (0, max] — 0 不匹配任何区间(左开)
{"zero tokens: no match", 0, nil, true},
}
@@ -110,11 +107,11 @@ func TestGetIntervalForContext(t *testing.T) {
func TestGetIntervalForContext_NoMatch(t *testing.T) {
p := &ChannelModelPricing{
Intervals: []PricingInterval{
{MinTokens: 10000, MaxTokens: channelTestPtrInt(50000)},
{MinTokens: 10000, MaxTokens: testPtrInt(50000)},
},
}
require.Nil(t, p.GetIntervalForContext(5000)) // 5000 <= 10000, not > min
require.Nil(t, p.GetIntervalForContext(10000)) // 10000 not > 10000 (left-open)
require.Nil(t, p.GetIntervalForContext(5000)) // 5000 <= 10000, not > min
require.Nil(t, p.GetIntervalForContext(10000)) // 10000 not > 10000 (left-open)
require.NotNil(t, p.GetIntervalForContext(50000)) // 50000 <= 50000 (right-closed)
require.Nil(t, p.GetIntervalForContext(50001)) // 50001 > 50000
}
@@ -127,9 +124,9 @@ func TestGetIntervalForContext_Empty(t *testing.T) {
func TestGetTierByLabel(t *testing.T) {
p := &ChannelModelPricing{
Intervals: []PricingInterval{
{TierLabel: "1K", PerRequestPrice: channelTestPtrFloat64(0.04)},
{TierLabel: "2K", PerRequestPrice: channelTestPtrFloat64(0.08)},
{TierLabel: "HD", PerRequestPrice: channelTestPtrFloat64(0.12)},
{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
{TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)},
{TierLabel: "HD", PerRequestPrice: testPtrFloat64(0.12)},
},
}
@@ -171,7 +168,7 @@ func TestChannelClone(t *testing.T) {
{
ID: 100,
Models: []string{"model-a"},
InputPrice: channelTestPtrFloat64(5e-6),
InputPrice: testPtrFloat64(5e-6),
},
},
}
@@ -211,3 +208,102 @@ func TestChannelModelPricingClone(t *testing.T) {
cloned.Intervals[0].TierLabel = "hacked"
require.Equal(t, "tier1", original.Intervals[0].TierLabel)
}
// --- BillingMode.IsValid ---
func TestBillingModeIsValid(t *testing.T) {
tests := []struct {
name string
mode BillingMode
want bool
}{
{"token", BillingModeToken, true},
{"per_request", BillingModePerRequest, true},
{"image", BillingModeImage, true},
{"empty", BillingMode(""), true},
{"unknown", BillingMode("unknown"), false},
{"random", BillingMode("xyz"), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, tt.mode.IsValid())
})
}
}
// --- Channel.IsActive ---
func TestChannelIsActive(t *testing.T) {
tests := []struct {
name string
status string
want bool
}{
{"active", StatusActive, true},
{"disabled", "disabled", false},
{"empty", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ch := &Channel{Status: tt.status}
require.Equal(t, tt.want, ch.IsActive())
})
}
}
// --- ChannelModelPricing.Clone edge cases ---
func TestChannelModelPricingClone_EdgeCases(t *testing.T) {
t.Run("nil models", func(t *testing.T) {
original := ChannelModelPricing{Models: nil}
cloned := original.Clone()
require.Nil(t, cloned.Models)
})
t.Run("nil intervals", func(t *testing.T) {
original := ChannelModelPricing{Intervals: nil}
cloned := original.Clone()
require.Nil(t, cloned.Intervals)
})
t.Run("empty models", func(t *testing.T) {
original := ChannelModelPricing{Models: []string{}}
cloned := original.Clone()
require.NotNil(t, cloned.Models)
require.Empty(t, cloned.Models)
})
}
// --- Channel.Clone edge cases ---
func TestChannelClone_EdgeCases(t *testing.T) {
t.Run("nil model mapping", func(t *testing.T) {
original := &Channel{ID: 1, ModelMapping: nil}
cloned := original.Clone()
require.Nil(t, cloned.ModelMapping)
})
t.Run("nil model pricing", func(t *testing.T) {
original := &Channel{ID: 1, ModelPricing: nil}
cloned := original.Clone()
require.Nil(t, cloned.ModelPricing)
})
t.Run("deep copy model mapping", func(t *testing.T) {
original := &Channel{
ID: 1,
ModelMapping: map[string]map[string]string{
"openai": {"gpt-4": "gpt-4-turbo"},
},
}
cloned := original.Clone()
// Modify the cloned nested map
cloned.ModelMapping["openai"]["gpt-4"] = "hacked"
// Original must remain unchanged
require.Equal(t, "gpt-4-turbo", original.ModelMapping["openai"]["gpt-4"])
})
}

View File

@@ -7407,11 +7407,7 @@ type RecordUsageInput struct {
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
APIKeyService APIKeyQuotaUpdater // 可选用于更新API Key配额
// 渠道映射信息(由 handler 在 Forward 前解析)
ChannelID int64 // 渠道 ID0 = 无渠道)
OriginalModel string // 用户原始请求模型(渠道映射前)
BillingModelSource string // 计费模型来源:"requested" / "upstream"
ModelMappingChain string // 映射链描述,如 "a→b→c"
ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析)
}
// APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage
@@ -7940,11 +7936,7 @@ type RecordUsageLongContextInput struct {
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选)
// 渠道映射信息(由 handler 在 Forward 前解析)
ChannelID int64 // 渠道 ID0 = 无渠道)
OriginalModel string // 用户原始请求模型(渠道映射前)
BillingModelSource string // 计费模型来源:"requested" / "upstream"
ModelMappingChain string // 映射链描述,如 "a→b→c"
ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析)
}
// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini

View File

@@ -4,14 +4,12 @@ package service
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/require"
)
func resolverPtrFloat64(v float64) *float64 { return &v }
func resolverPtrInt(v int) *int { return &v }
func newTestBillingServiceForResolver() *BillingService {
bs := &BillingService{
fallbackPrices: make(map[string]*ModelPricing),
@@ -83,8 +81,8 @@ func TestGetIntervalPricing_MatchesInterval(t *testing.T) {
BasePricing: &ModelPricing{InputPricePerToken: 5e-6},
SupportsCacheBreakdown: true,
Intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: resolverPtrInt(128000), InputPrice: resolverPtrFloat64(1e-6), OutputPrice: resolverPtrFloat64(2e-6)},
{MinTokens: 128000, MaxTokens: nil, InputPrice: resolverPtrFloat64(3e-6), OutputPrice: resolverPtrFloat64(6e-6)},
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6), OutputPrice: testPtrFloat64(2e-6)},
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(3e-6), OutputPrice: testPtrFloat64(6e-6)},
},
}
@@ -108,7 +106,7 @@ func TestGetIntervalPricing_NoMatch_FallsBackToBase(t *testing.T) {
Mode: BillingModeToken,
BasePricing: basePricing,
Intervals: []PricingInterval{
{MinTokens: 10000, MaxTokens: resolverPtrInt(50000), InputPrice: resolverPtrFloat64(1e-6)},
{MinTokens: 10000, MaxTokens: testPtrInt(50000), InputPrice: testPtrFloat64(1e-6)},
},
}
@@ -123,8 +121,8 @@ func TestGetRequestTierPrice(t *testing.T) {
resolved := &ResolvedPricing{
Mode: BillingModePerRequest,
RequestTiers: []PricingInterval{
{TierLabel: "1K", PerRequestPrice: resolverPtrFloat64(0.04)},
{TierLabel: "2K", PerRequestPrice: resolverPtrFloat64(0.08)},
{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
{TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)},
},
}
@@ -140,8 +138,8 @@ func TestGetRequestTierPriceByContext(t *testing.T) {
resolved := &ResolvedPricing{
Mode: BillingModePerRequest,
RequestTiers: []PricingInterval{
{MinTokens: 0, MaxTokens: resolverPtrInt(128000), PerRequestPrice: resolverPtrFloat64(0.05)},
{MinTokens: 128000, MaxTokens: nil, PerRequestPrice: resolverPtrFloat64(0.10)},
{MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.05)},
{MinTokens: 128000, MaxTokens: nil, PerRequestPrice: testPtrFloat64(0.10)},
},
}
@@ -162,3 +160,428 @@ func TestGetRequestTierPrice_NilPerRequestPrice(t *testing.T) {
require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "1K"), 1e-12)
}
// ===========================================================================
// Channel override tests — exercises applyChannelOverrides via Resolve
// ===========================================================================
// helper: creates a resolver wired to a ChannelService that returns the given
// channel (active, groupID=100, platform=anthropic) with the specified pricing.
func newResolverWithChannel(t *testing.T, pricing []ChannelModelPricing) *ModelPricingResolver {
t.Helper()
const groupID = 100
repo := &mockChannelRepository{
listAllFn: func(_ context.Context) ([]Channel, error) {
return []Channel{{
ID: 1,
Name: "test-channel",
Status: StatusActive,
GroupIDs: []int64{groupID},
ModelPricing: pricing,
}}, nil
},
getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) {
return map[int64]string{groupID: "anthropic"}, nil
},
}
cs := NewChannelService(repo, nil)
bs := newTestBillingServiceForResolver()
return NewModelPricingResolver(cs, bs)
}
// groupIDPtr returns a pointer to groupID 100 (the test constant).
func groupIDPtr() *int64 { v := int64(100); return &v }
// ---------------------------------------------------------------------------
// 1. Token mode overrides
// ---------------------------------------------------------------------------
func TestResolve_WithChannelOverride_TokenFlat(t *testing.T) {
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModeToken,
InputPrice: testPtrFloat64(10e-6),
OutputPrice: testPtrFloat64(50e-6),
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
require.NotNil(t, resolved)
require.Equal(t, BillingModeToken, resolved.Mode)
require.Equal(t, "channel", resolved.Source)
require.NotNil(t, resolved.BasePricing)
require.InDelta(t, 10e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
require.InDelta(t, 10e-6, resolved.BasePricing.InputPricePerTokenPriority, 1e-12)
require.InDelta(t, 50e-6, resolved.BasePricing.OutputPricePerToken, 1e-12)
require.InDelta(t, 50e-6, resolved.BasePricing.OutputPricePerTokenPriority, 1e-12)
}
func TestResolve_WithChannelOverride_TokenPartialOverride(t *testing.T) {
// Channel only sets InputPrice; OutputPrice should remain from the base (LiteLLM/fallback).
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModeToken,
InputPrice: testPtrFloat64(20e-6),
// OutputPrice intentionally nil
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
require.NotNil(t, resolved)
require.Equal(t, "channel", resolved.Source)
require.NotNil(t, resolved.BasePricing)
// InputPrice overridden by channel
require.InDelta(t, 20e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
// OutputPrice kept from base (fallback: 15e-6)
require.InDelta(t, 15e-6, resolved.BasePricing.OutputPricePerToken, 1e-12)
}
func TestResolve_WithChannelOverride_TokenWithIntervals(t *testing.T) {
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModeToken,
Intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(2e-6), OutputPrice: testPtrFloat64(8e-6)},
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(4e-6), OutputPrice: testPtrFloat64(16e-6)},
},
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
require.NotNil(t, resolved)
require.Equal(t, "channel", resolved.Source)
require.Len(t, resolved.Intervals, 2)
// GetIntervalPricing should use channel intervals
iv := r.GetIntervalPricing(resolved, 50000)
require.NotNil(t, iv)
require.InDelta(t, 2e-6, iv.InputPricePerToken, 1e-12)
require.InDelta(t, 8e-6, iv.OutputPricePerToken, 1e-12)
iv2 := r.GetIntervalPricing(resolved, 200000)
require.NotNil(t, iv2)
require.InDelta(t, 4e-6, iv2.InputPricePerToken, 1e-12)
require.InDelta(t, 16e-6, iv2.OutputPricePerToken, 1e-12)
}
func TestResolve_WithChannelOverride_TokenNilBasePricing(t *testing.T) {
// Base pricing is nil (unknown model), channel has flat prices → creates new BasePricing.
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"unknown-model-xyz"},
BillingMode: BillingModeToken,
InputPrice: testPtrFloat64(7e-6),
OutputPrice: testPtrFloat64(21e-6),
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "unknown-model-xyz",
GroupID: groupIDPtr(),
})
require.NotNil(t, resolved)
require.Equal(t, "channel", resolved.Source)
// BasePricing was nil from resolveBasePricing but applyTokenOverrides creates a new one
require.NotNil(t, resolved.BasePricing)
require.InDelta(t, 7e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
require.InDelta(t, 21e-6, resolved.BasePricing.OutputPricePerToken, 1e-12)
}
// ---------------------------------------------------------------------------
// 2. Per-request mode overrides
// ---------------------------------------------------------------------------
func TestResolve_WithChannelOverride_PerRequest(t *testing.T) {
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModePerRequest,
PerRequestPrice: testPtrFloat64(0.05),
Intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.03)},
{MinTokens: 128000, MaxTokens: nil, PerRequestPrice: testPtrFloat64(0.10)},
},
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
require.NotNil(t, resolved)
require.Equal(t, BillingModePerRequest, resolved.Mode)
require.Equal(t, "channel", resolved.Source)
require.InDelta(t, 0.05, resolved.DefaultPerRequestPrice, 1e-12)
require.Len(t, resolved.RequestTiers, 2)
// Verify tier lookups
require.InDelta(t, 0.03, r.GetRequestTierPriceByContext(resolved, 50000), 1e-12)
require.InDelta(t, 0.10, r.GetRequestTierPriceByContext(resolved, 200000), 1e-12)
}
func TestResolve_WithChannelOverride_PerRequestNilPrice(t *testing.T) {
// PerRequestPrice nil → DefaultPerRequestPrice stays 0.
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModePerRequest,
// PerRequestPrice intentionally nil
Intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.02)},
},
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
require.NotNil(t, resolved)
require.Equal(t, BillingModePerRequest, resolved.Mode)
require.InDelta(t, 0.0, resolved.DefaultPerRequestPrice, 1e-12)
require.Len(t, resolved.RequestTiers, 1)
}
// ---------------------------------------------------------------------------
// 3. Image mode overrides
// ---------------------------------------------------------------------------
func TestResolve_WithChannelOverride_Image(t *testing.T) {
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModeImage,
PerRequestPrice: testPtrFloat64(0.08),
Intervals: []PricingInterval{
{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
{TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)},
{TierLabel: "4K", PerRequestPrice: testPtrFloat64(0.16)},
},
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
require.NotNil(t, resolved)
require.Equal(t, BillingModeImage, resolved.Mode)
require.Equal(t, "channel", resolved.Source)
require.InDelta(t, 0.08, resolved.DefaultPerRequestPrice, 1e-12)
require.Len(t, resolved.RequestTiers, 3)
}
func TestResolve_WithChannelOverride_ImageTierLabels(t *testing.T) {
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModeImage,
Intervals: []PricingInterval{
{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
{TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)},
{TierLabel: "4K", PerRequestPrice: testPtrFloat64(0.16)},
},
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
require.InDelta(t, 0.04, r.GetRequestTierPrice(resolved, "1K"), 1e-12)
require.InDelta(t, 0.08, r.GetRequestTierPrice(resolved, "2K"), 1e-12)
require.InDelta(t, 0.16, r.GetRequestTierPrice(resolved, "4K"), 1e-12)
require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "8K"), 1e-12) // not found
}
// ---------------------------------------------------------------------------
// 4. Source tracking & default mode
// ---------------------------------------------------------------------------
func TestResolve_WithChannelOverride_SourceIsChannel(t *testing.T) {
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModeToken,
InputPrice: testPtrFloat64(1e-6),
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
require.Equal(t, "channel", resolved.Source)
}
func TestResolve_WithChannelOverride_DefaultMode(t *testing.T) {
// Channel pricing with empty BillingMode → defaults to BillingModeToken.
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: "", // intentionally empty
InputPrice: testPtrFloat64(5e-6),
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
require.Equal(t, "channel", resolved.Source)
require.Equal(t, BillingModeToken, resolved.Mode)
require.NotNil(t, resolved.BasePricing)
require.InDelta(t, 5e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
}
// ---------------------------------------------------------------------------
// 5. GetIntervalPricing integration after channel override
// ---------------------------------------------------------------------------
func TestGetIntervalPricing_WithChannelIntervals(t *testing.T) {
// Channel provides intervals that override the base pricing path.
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModeToken,
Intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(100000), InputPrice: testPtrFloat64(1e-6), OutputPrice: testPtrFloat64(5e-6)},
{MinTokens: 100000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6), OutputPrice: testPtrFloat64(10e-6)},
},
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
// Token count 50000 matches first interval
pricing := r.GetIntervalPricing(resolved, 50000)
require.NotNil(t, pricing)
require.InDelta(t, 1e-6, pricing.InputPricePerToken, 1e-12)
require.InDelta(t, 5e-6, pricing.OutputPricePerToken, 1e-12)
// Token count 150000 matches second interval
pricing2 := r.GetIntervalPricing(resolved, 150000)
require.NotNil(t, pricing2)
require.InDelta(t, 2e-6, pricing2.InputPricePerToken, 1e-12)
require.InDelta(t, 10e-6, pricing2.OutputPricePerToken, 1e-12)
}
func TestGetIntervalPricing_ChannelIntervalsNoMatch(t *testing.T) {
// Channel intervals don't match token count → falls back to BasePricing.
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModeToken,
Intervals: []PricingInterval{
// Only covers tokens > 50000
{MinTokens: 50000, MaxTokens: testPtrInt(200000), InputPrice: testPtrFloat64(9e-6)},
},
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
// Token count 1000 doesn't match any interval (1000 <= 50000 minTokens)
pricing := r.GetIntervalPricing(resolved, 1000)
// Should fall back to BasePricing (from the billing service fallback)
require.NotNil(t, pricing)
require.Equal(t, resolved.BasePricing, pricing)
require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12) // original base price
}
// ===========================================================================
// 6. Error path tests
// ===========================================================================
func TestResolve_WithChannelOverride_CacheError(t *testing.T) {
// When ListAll returns an error, the ChannelService cache build fails.
// Resolve should gracefully fall back to base pricing without panicking.
repo := &mockChannelRepository{
listAllFn: func(_ context.Context) ([]Channel, error) {
return nil, errors.New("database unavailable")
},
}
cs := NewChannelService(repo, nil)
bs := newTestBillingServiceForResolver()
r := NewModelPricingResolver(cs, bs)
gid := int64(100)
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: &gid,
})
require.NotNil(t, resolved)
// Should NOT panic, should NOT have source "channel"
require.NotEqual(t, "channel", resolved.Source)
// Base pricing should still be present (from BillingService fallback)
require.NotNil(t, resolved.BasePricing)
require.InDelta(t, 3e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
}
// ===========================================================================
// 7. GetRequestTierPriceByContext boundary tests
// ===========================================================================
func TestGetRequestTierPriceByContext_EmptyTiers(t *testing.T) {
bs := newTestBillingServiceForResolver()
r := NewModelPricingResolver(&ChannelService{}, bs)
resolved := &ResolvedPricing{
Mode: BillingModePerRequest,
RequestTiers: nil, // empty
}
price := r.GetRequestTierPriceByContext(resolved, 50000)
require.InDelta(t, 0.0, price, 1e-12)
// Also test with explicit empty slice
resolved2 := &ResolvedPricing{
Mode: BillingModePerRequest,
RequestTiers: []PricingInterval{},
}
price2 := r.GetRequestTierPriceByContext(resolved2, 50000)
require.InDelta(t, 0.0, price2, 1e-12)
}
func TestGetRequestTierPriceByContext_ExactBoundary(t *testing.T) {
bs := newTestBillingServiceForResolver()
r := NewModelPricingResolver(&ChannelService{}, bs)
resolved := &ResolvedPricing{
Mode: BillingModePerRequest,
RequestTiers: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.05)},
{MinTokens: 128000, MaxTokens: nil, PerRequestPrice: testPtrFloat64(0.10)},
},
}
// totalContextTokens = 128000 exactly:
// FindMatchingInterval checks: totalTokens > MinTokens && totalTokens <= MaxTokens
// For first interval: 128000 > 0 (true) && 128000 <= 128000 (true) → matches first interval
price := r.GetRequestTierPriceByContext(resolved, 128000)
require.InDelta(t, 0.05, price, 1e-12)
// totalContextTokens = 128001 should match second interval
// For first interval: 128001 > 0 (true) && 128001 <= 128000 (false) → no match
// For second interval: 128001 > 128000 (true) && MaxTokens == nil → matches
price2 := r.GetRequestTierPriceByContext(resolved, 128001)
require.InDelta(t, 0.10, price2, 1e-12)
}

View File

@@ -4146,10 +4146,7 @@ type OpenAIRecordUsageInput struct {
IPAddress string // 请求的客户端 IP 地址
RequestPayloadHash string
APIKeyService APIKeyQuotaUpdater
ChannelID int64
OriginalModel string
BillingModelSource string
ModelMappingChain string
ChannelUsageFields
}
// RecordUsage records usage and deducts balance

View File

@@ -0,0 +1,15 @@
//go:build unit
package service
// testPtrFloat64 returns a pointer to the given float64 value.
func testPtrFloat64(v float64) *float64 { return &v }
// testPtrInt returns a pointer to the given int value.
func testPtrInt(v int) *int { return &v }
// testPtrString returns a pointer to the given string value.
func testPtrString(v string) *string { return &v }
// testPtrBool returns a pointer to the given bool value.
func testPtrBool(v bool) *bool { return &v }