mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-05-05 13:40:44 +08:00
feat(channel): 渠道管理全链路集成 — 模型映射、定价、限制、用量统计
- 渠道模型映射:支持精确匹配和通配符映射,按平台隔离 - 渠道模型定价:支持 token/按次/图片三种计费模式,区间分层定价 - 模型限制:渠道可限制仅允许定价列表中的模型 - 计费模型来源:支持 requested/upstream 两种计费模型选择 - 用量统计:usage_logs 新增 channel_id/model_mapping_chain/billing_tier/billing_mode 字段 - Dashboard 支持 model_source 维度(requested/upstream/mapping)查看模型统计 - 全部 gateway handler 统一接入 ResolveChannelMappingAndRestrict - 修复测试:同步 SoraGenerationRepository 接口、SQL INSERT 参数、scan 字段
This commit is contained in:
@@ -51,15 +51,15 @@ type Channel struct {
|
||||
type ChannelModelPricing struct {
|
||||
ID int64
|
||||
ChannelID int64
|
||||
Platform string // 所属平台(anthropic/openai/gemini/...)
|
||||
Models []string // 绑定的模型列表
|
||||
BillingMode BillingMode // 计费模式
|
||||
InputPrice *float64 // 每 token 输入价格(USD)— 向后兼容 flat 定价
|
||||
OutputPrice *float64 // 每 token 输出价格(USD)
|
||||
CacheWritePrice *float64 // 缓存写入价格
|
||||
CacheReadPrice *float64 // 缓存读取价格
|
||||
ImageOutputPrice *float64 // 图片输出价格(向后兼容)
|
||||
PerRequestPrice *float64 // 默认按次计费价格(USD)
|
||||
Platform string // 所属平台(anthropic/openai/gemini/...)
|
||||
Models []string // 绑定的模型列表
|
||||
BillingMode BillingMode // 计费模式
|
||||
InputPrice *float64 // 每 token 输入价格(USD)— 向后兼容 flat 定价
|
||||
OutputPrice *float64 // 每 token 输出价格(USD)
|
||||
CacheWritePrice *float64 // 缓存写入价格
|
||||
CacheReadPrice *float64 // 缓存读取价格
|
||||
ImageOutputPrice *float64 // 图片输出价格(向后兼容)
|
||||
PerRequestPrice *float64 // 默认按次计费价格(USD)
|
||||
Intervals []PricingInterval // 区间定价列表
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
@@ -175,3 +175,11 @@ func (c *Channel) Clone() *Channel {
|
||||
}
|
||||
return &cp
|
||||
}
|
||||
|
||||
// ChannelUsageFields 渠道相关的使用记录字段(嵌入到各平台的 RecordUsageInput 中)
|
||||
type ChannelUsageFields struct {
|
||||
ChannelID int64 // 渠道 ID(0 = 无渠道)
|
||||
OriginalModel string // 用户原始请求模型(渠道映射前)
|
||||
BillingModelSource string // 计费模型来源:"requested" / "upstream"
|
||||
ModelMappingChain string // 映射链描述,如 "a→b→c"
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -17,8 +16,8 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrChannelNotFound = infraerrors.NotFound("CHANNEL_NOT_FOUND", "channel not found")
|
||||
ErrChannelExists = infraerrors.Conflict("CHANNEL_EXISTS", "channel name already exists")
|
||||
ErrChannelNotFound = infraerrors.NotFound("CHANNEL_NOT_FOUND", "channel not found")
|
||||
ErrChannelExists = infraerrors.Conflict("CHANNEL_EXISTS", "channel name already exists")
|
||||
ErrGroupAlreadyInChannel = infraerrors.Conflict(
|
||||
"GROUP_ALREADY_IN_CHANNEL",
|
||||
"one or more groups already belong to another channel",
|
||||
@@ -81,12 +80,12 @@ type wildcardMappingEntry struct {
|
||||
// channelCache 渠道缓存快照(扁平化哈希结构,热路径 O(1) 查找)
|
||||
type channelCache struct {
|
||||
// 热路径查找
|
||||
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价
|
||||
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价
|
||||
wildcardByGroupPlatform map[channelGroupPlatformKey][]*wildcardPricingEntry // (groupID, platform) → 通配符定价(前缀长度降序)
|
||||
mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标
|
||||
wildcardMappingByGP map[channelGroupPlatformKey][]*wildcardMappingEntry // (groupID, platform) → 通配符映射(前缀长度降序)
|
||||
channelByGroupID map[int64]*Channel // groupID → 渠道
|
||||
groupPlatform map[int64]string // groupID → platform
|
||||
mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标
|
||||
wildcardMappingByGP map[channelGroupPlatformKey][]*wildcardMappingEntry // (groupID, platform) → 通配符映射(前缀长度降序)
|
||||
channelByGroupID map[int64]*Channel // groupID → 渠道
|
||||
groupPlatform map[int64]string // groupID → platform
|
||||
|
||||
// 冷路径(CRUD 操作)
|
||||
byID map[int64]*Channel
|
||||
@@ -118,9 +117,19 @@ func (r ChannelMappingResult) BuildModelMappingChain(reqModel, upstreamModel str
|
||||
return reqModel + "→" + r.MappedModel
|
||||
}
|
||||
|
||||
// ToUsageFields 将渠道映射结果转为使用记录字段
|
||||
func (r ChannelMappingResult) ToUsageFields(reqModel, upstreamModel string) ChannelUsageFields {
|
||||
return ChannelUsageFields{
|
||||
ChannelID: r.ChannelID,
|
||||
OriginalModel: reqModel,
|
||||
BillingModelSource: r.BillingModelSource,
|
||||
ModelMappingChain: r.BuildModelMappingChain(reqModel, upstreamModel),
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
channelCacheTTL = 60 * time.Second
|
||||
channelErrorTTL = 5 * time.Second // DB 错误时的短缓存
|
||||
channelCacheTTL = 60 * time.Second
|
||||
channelErrorTTL = 5 * time.Second // DB 错误时的短缓存
|
||||
channelCacheDBTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
@@ -177,14 +186,14 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
||||
// error-TTL:失败时存入短 TTL 空缓存,防止紧密重试
|
||||
slog.Warn("failed to build channel cache", "error", err)
|
||||
errorCache := &channelCache{
|
||||
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
||||
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
||||
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
|
||||
mappingByGroupModel: make(map[channelModelKey]string),
|
||||
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
|
||||
channelByGroupID: make(map[int64]*Channel),
|
||||
groupPlatform: make(map[int64]string),
|
||||
byID: make(map[int64]*Channel),
|
||||
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL
|
||||
mappingByGroupModel: make(map[channelModelKey]string),
|
||||
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
|
||||
channelByGroupID: make(map[int64]*Channel),
|
||||
groupPlatform: make(map[int64]string),
|
||||
byID: make(map[int64]*Channel),
|
||||
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL
|
||||
}
|
||||
s.cache.Store(errorCache)
|
||||
return nil, fmt.Errorf("list all channels: %w", err)
|
||||
@@ -205,14 +214,14 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
||||
}
|
||||
|
||||
cache := &channelCache{
|
||||
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
||||
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
||||
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
|
||||
mappingByGroupModel: make(map[channelModelKey]string),
|
||||
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
|
||||
channelByGroupID: make(map[int64]*Channel),
|
||||
groupPlatform: groupPlatforms,
|
||||
byID: make(map[int64]*Channel, len(channels)),
|
||||
loadedAt: time.Now(),
|
||||
mappingByGroupModel: make(map[channelModelKey]string),
|
||||
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
|
||||
channelByGroupID: make(map[int64]*Channel),
|
||||
groupPlatform: groupPlatforms,
|
||||
byID: make(map[int64]*Channel, len(channels)),
|
||||
loadedAt: time.Now(),
|
||||
}
|
||||
|
||||
for i := range channels {
|
||||
@@ -266,19 +275,7 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
||||
}
|
||||
}
|
||||
|
||||
// 通配符条目按前缀长度降序排列(最长前缀优先匹配)
|
||||
for gpKey, entries := range cache.wildcardByGroupPlatform {
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return len(entries[i].prefix) > len(entries[j].prefix)
|
||||
})
|
||||
cache.wildcardByGroupPlatform[gpKey] = entries
|
||||
}
|
||||
for gpKey, entries := range cache.wildcardMappingByGP {
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return len(entries[i].prefix) > len(entries[j].prefix)
|
||||
})
|
||||
cache.wildcardMappingByGP[gpKey] = entries
|
||||
}
|
||||
// 通配符条目保持配置顺序(最先匹配到优先)
|
||||
|
||||
s.cache.Store(cache)
|
||||
return cache, nil
|
||||
@@ -290,7 +287,7 @@ func (s *ChannelService) invalidateCache() {
|
||||
s.cacheSF.Forget("channel_cache")
|
||||
}
|
||||
|
||||
// matchWildcard 在通配符定价中查找匹配项(最长前缀优先)
|
||||
// matchWildcard 在通配符定价中查找匹配项(最先匹配到优先)
|
||||
func (c *channelCache) matchWildcard(groupID int64, platform, modelLower string) *ChannelModelPricing {
|
||||
gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform}
|
||||
wildcards := c.wildcardByGroupPlatform[gpKey]
|
||||
@@ -302,7 +299,7 @@ func (c *channelCache) matchWildcard(groupID int64, platform, modelLower string)
|
||||
return nil
|
||||
}
|
||||
|
||||
// matchWildcardMapping 在通配符映射中查找匹配项(最长前缀优先)
|
||||
// matchWildcardMapping 在通配符映射中查找匹配项(最先匹配到优先)
|
||||
func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower string) string {
|
||||
gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform}
|
||||
wildcards := c.wildcardMappingByGP[gpKey]
|
||||
@@ -479,15 +476,18 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
|
||||
Status: StatusActive,
|
||||
BillingModelSource: input.BillingModelSource,
|
||||
RestrictModels: input.RestrictModels,
|
||||
GroupIDs: input.GroupIDs,
|
||||
ModelPricing: input.ModelPricing,
|
||||
ModelMapping: input.ModelMapping,
|
||||
GroupIDs: input.GroupIDs,
|
||||
ModelPricing: input.ModelPricing,
|
||||
ModelMapping: input.ModelMapping,
|
||||
}
|
||||
if channel.BillingModelSource == "" {
|
||||
channel.BillingModelSource = BillingModelSourceRequested
|
||||
}
|
||||
|
||||
if err := validateNoDuplicateModels(channel.ModelPricing); err != nil {
|
||||
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -558,7 +558,10 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
|
||||
channel.BillingModelSource = input.BillingModelSource
|
||||
}
|
||||
|
||||
if err := validateNoDuplicateModels(channel.ModelPricing); err != nil {
|
||||
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -610,16 +613,79 @@ func (s *ChannelService) List(ctx context.Context, params pagination.PaginationP
|
||||
return s.repo.List(ctx, params, status, search)
|
||||
}
|
||||
|
||||
// validateNoDuplicateModels 检查定价列表中是否有重复模型(同一平台下不允许重复)
|
||||
func validateNoDuplicateModels(pricingList []ChannelModelPricing) error {
|
||||
seen := make(map[string]bool)
|
||||
// modelEntry 表示一个模型模式条目(用于冲突检测)
|
||||
type modelEntry struct {
|
||||
pattern string // 原始模式(如 "claude-*" 或 "claude-opus-4")
|
||||
prefix string // lowercase 前缀(通配符去掉 *,精确名保持原样)
|
||||
wildcard bool
|
||||
}
|
||||
|
||||
// conflictsBetween 检查两个模型模式是否冲突
|
||||
func conflictsBetween(a, b modelEntry) bool {
|
||||
switch {
|
||||
case !a.wildcard && !b.wildcard:
|
||||
return a.prefix == b.prefix
|
||||
case a.wildcard && !b.wildcard:
|
||||
return strings.HasPrefix(b.prefix, a.prefix)
|
||||
case !a.wildcard && b.wildcard:
|
||||
return strings.HasPrefix(a.prefix, b.prefix)
|
||||
default:
|
||||
return strings.HasPrefix(a.prefix, b.prefix) ||
|
||||
strings.HasPrefix(b.prefix, a.prefix)
|
||||
}
|
||||
}
|
||||
|
||||
// toModelEntry 将模型名转换为 modelEntry
|
||||
func toModelEntry(pattern string) modelEntry {
|
||||
lower := strings.ToLower(pattern)
|
||||
isWild := strings.HasSuffix(lower, "*")
|
||||
prefix := lower
|
||||
if isWild {
|
||||
prefix = strings.TrimSuffix(lower, "*")
|
||||
}
|
||||
return modelEntry{pattern: pattern, prefix: prefix, wildcard: isWild}
|
||||
}
|
||||
|
||||
// validateNoConflictingModels 检查定价列表中是否有冲突模型模式(同一平台下)。
|
||||
// 冲突包括:精确重复、通配符之间的前缀包含、通配符与精确名的前缀匹配。
|
||||
func validateNoConflictingModels(pricingList []ChannelModelPricing) error {
|
||||
byPlatform := make(map[string][]modelEntry)
|
||||
for _, p := range pricingList {
|
||||
for _, model := range p.Models {
|
||||
key := p.Platform + ":" + strings.ToLower(model)
|
||||
if seen[key] {
|
||||
return infraerrors.BadRequest("DUPLICATE_MODEL", fmt.Sprintf("model '%s' appears in multiple pricing entries for platform '%s'", model, p.Platform))
|
||||
byPlatform[p.Platform] = append(byPlatform[p.Platform], toModelEntry(model))
|
||||
}
|
||||
}
|
||||
for platform, entries := range byPlatform {
|
||||
if err := detectConflicts(entries, platform, "MODEL_PATTERN_CONFLICT", "model patterns"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateNoConflictingMappings 检查模型映射中是否有冲突的源模式
|
||||
func validateNoConflictingMappings(mapping map[string]map[string]string) error {
|
||||
for platform, platformMapping := range mapping {
|
||||
entries := make([]modelEntry, 0, len(platformMapping))
|
||||
for src := range platformMapping {
|
||||
entries = append(entries, toModelEntry(src))
|
||||
}
|
||||
if err := detectConflicts(entries, platform, "MAPPING_PATTERN_CONFLICT", "mapping source patterns"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// detectConflicts 在一组 modelEntry 中检测冲突,返回带有 errCode 和 label 的错误
|
||||
func detectConflicts(entries []modelEntry, platform, errCode, label string) error {
|
||||
for i := 0; i < len(entries); i++ {
|
||||
for j := i + 1; j < len(entries); j++ {
|
||||
if conflictsBetween(entries[i], entries[j]) {
|
||||
return infraerrors.BadRequest(errCode,
|
||||
fmt.Sprintf("%s '%s' and '%s' conflict in platform '%s': overlapping match range",
|
||||
label, entries[i].pattern, entries[j].pattern, platform))
|
||||
}
|
||||
seen[key] = true
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
1890
backend/internal/service/channel_service_test.go
Normal file
1890
backend/internal/service/channel_service_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -8,13 +8,10 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func channelTestPtrFloat64(v float64) *float64 { return &v }
|
||||
func channelTestPtrInt(v int) *int { return &v }
|
||||
|
||||
func TestGetModelPricing(t *testing.T) {
|
||||
ch := &Channel{
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{ID: 1, Models: []string{"claude-sonnet-4"}, BillingMode: BillingModeToken, InputPrice: channelTestPtrFloat64(3e-6)},
|
||||
{ID: 1, Models: []string{"claude-sonnet-4"}, BillingMode: BillingModeToken, InputPrice: testPtrFloat64(3e-6)},
|
||||
{ID: 3, Models: []string{"gpt-5.1"}, BillingMode: BillingModePerRequest},
|
||||
},
|
||||
}
|
||||
@@ -48,7 +45,7 @@ func TestGetModelPricing(t *testing.T) {
|
||||
func TestGetModelPricing_ReturnsCopy(t *testing.T) {
|
||||
ch := &Channel{
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{ID: 1, Models: []string{"claude-sonnet-4"}, InputPrice: channelTestPtrFloat64(3e-6)},
|
||||
{ID: 1, Models: []string{"claude-sonnet-4"}, InputPrice: testPtrFloat64(3e-6)},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -73,23 +70,23 @@ func TestGetModelPricing_EmptyPricing(t *testing.T) {
|
||||
func TestGetIntervalForContext(t *testing.T) {
|
||||
p := &ChannelModelPricing{
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: channelTestPtrInt(128000), InputPrice: channelTestPtrFloat64(1e-6)},
|
||||
{MinTokens: 128000, MaxTokens: nil, InputPrice: channelTestPtrFloat64(2e-6)},
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
|
||||
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokens int
|
||||
wantPrice *float64
|
||||
wantNil bool
|
||||
name string
|
||||
tokens int
|
||||
wantPrice *float64
|
||||
wantNil bool
|
||||
}{
|
||||
{"first interval", 50000, channelTestPtrFloat64(1e-6), false},
|
||||
{"first interval", 50000, testPtrFloat64(1e-6), false},
|
||||
// (min, max] — 128000 在第一个区间的 max,包含,所以匹配第一个
|
||||
{"boundary: max of first (inclusive)", 128000, channelTestPtrFloat64(1e-6), false},
|
||||
{"boundary: max of first (inclusive)", 128000, testPtrFloat64(1e-6), false},
|
||||
// 128001 > 128000,匹配第二个区间
|
||||
{"boundary: just above first max", 128001, channelTestPtrFloat64(2e-6), false},
|
||||
{"unbounded interval", 500000, channelTestPtrFloat64(2e-6), false},
|
||||
{"boundary: just above first max", 128001, testPtrFloat64(2e-6), false},
|
||||
{"unbounded interval", 500000, testPtrFloat64(2e-6), false},
|
||||
// (0, max] — 0 不匹配任何区间(左开)
|
||||
{"zero tokens: no match", 0, nil, true},
|
||||
}
|
||||
@@ -110,11 +107,11 @@ func TestGetIntervalForContext(t *testing.T) {
|
||||
func TestGetIntervalForContext_NoMatch(t *testing.T) {
|
||||
p := &ChannelModelPricing{
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 10000, MaxTokens: channelTestPtrInt(50000)},
|
||||
{MinTokens: 10000, MaxTokens: testPtrInt(50000)},
|
||||
},
|
||||
}
|
||||
require.Nil(t, p.GetIntervalForContext(5000)) // 5000 <= 10000, not > min
|
||||
require.Nil(t, p.GetIntervalForContext(10000)) // 10000 not > 10000 (left-open)
|
||||
require.Nil(t, p.GetIntervalForContext(5000)) // 5000 <= 10000, not > min
|
||||
require.Nil(t, p.GetIntervalForContext(10000)) // 10000 not > 10000 (left-open)
|
||||
require.NotNil(t, p.GetIntervalForContext(50000)) // 50000 <= 50000 (right-closed)
|
||||
require.Nil(t, p.GetIntervalForContext(50001)) // 50001 > 50000
|
||||
}
|
||||
@@ -127,9 +124,9 @@ func TestGetIntervalForContext_Empty(t *testing.T) {
|
||||
func TestGetTierByLabel(t *testing.T) {
|
||||
p := &ChannelModelPricing{
|
||||
Intervals: []PricingInterval{
|
||||
{TierLabel: "1K", PerRequestPrice: channelTestPtrFloat64(0.04)},
|
||||
{TierLabel: "2K", PerRequestPrice: channelTestPtrFloat64(0.08)},
|
||||
{TierLabel: "HD", PerRequestPrice: channelTestPtrFloat64(0.12)},
|
||||
{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
|
||||
{TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)},
|
||||
{TierLabel: "HD", PerRequestPrice: testPtrFloat64(0.12)},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -171,7 +168,7 @@ func TestChannelClone(t *testing.T) {
|
||||
{
|
||||
ID: 100,
|
||||
Models: []string{"model-a"},
|
||||
InputPrice: channelTestPtrFloat64(5e-6),
|
||||
InputPrice: testPtrFloat64(5e-6),
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -211,3 +208,102 @@ func TestChannelModelPricingClone(t *testing.T) {
|
||||
cloned.Intervals[0].TierLabel = "hacked"
|
||||
require.Equal(t, "tier1", original.Intervals[0].TierLabel)
|
||||
}
|
||||
|
||||
// --- BillingMode.IsValid ---
|
||||
|
||||
func TestBillingModeIsValid(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mode BillingMode
|
||||
want bool
|
||||
}{
|
||||
{"token", BillingModeToken, true},
|
||||
{"per_request", BillingModePerRequest, true},
|
||||
{"image", BillingModeImage, true},
|
||||
{"empty", BillingMode(""), true},
|
||||
{"unknown", BillingMode("unknown"), false},
|
||||
{"random", BillingMode("xyz"), false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, tt.mode.IsValid())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- Channel.IsActive ---
|
||||
|
||||
func TestChannelIsActive(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status string
|
||||
want bool
|
||||
}{
|
||||
{"active", StatusActive, true},
|
||||
{"disabled", "disabled", false},
|
||||
{"empty", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ch := &Channel{Status: tt.status}
|
||||
require.Equal(t, tt.want, ch.IsActive())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- ChannelModelPricing.Clone edge cases ---
|
||||
|
||||
func TestChannelModelPricingClone_EdgeCases(t *testing.T) {
|
||||
t.Run("nil models", func(t *testing.T) {
|
||||
original := ChannelModelPricing{Models: nil}
|
||||
cloned := original.Clone()
|
||||
require.Nil(t, cloned.Models)
|
||||
})
|
||||
|
||||
t.Run("nil intervals", func(t *testing.T) {
|
||||
original := ChannelModelPricing{Intervals: nil}
|
||||
cloned := original.Clone()
|
||||
require.Nil(t, cloned.Intervals)
|
||||
})
|
||||
|
||||
t.Run("empty models", func(t *testing.T) {
|
||||
original := ChannelModelPricing{Models: []string{}}
|
||||
cloned := original.Clone()
|
||||
require.NotNil(t, cloned.Models)
|
||||
require.Empty(t, cloned.Models)
|
||||
})
|
||||
}
|
||||
|
||||
// --- Channel.Clone edge cases ---
|
||||
|
||||
func TestChannelClone_EdgeCases(t *testing.T) {
|
||||
t.Run("nil model mapping", func(t *testing.T) {
|
||||
original := &Channel{ID: 1, ModelMapping: nil}
|
||||
cloned := original.Clone()
|
||||
require.Nil(t, cloned.ModelMapping)
|
||||
})
|
||||
|
||||
t.Run("nil model pricing", func(t *testing.T) {
|
||||
original := &Channel{ID: 1, ModelPricing: nil}
|
||||
cloned := original.Clone()
|
||||
require.Nil(t, cloned.ModelPricing)
|
||||
})
|
||||
|
||||
t.Run("deep copy model mapping", func(t *testing.T) {
|
||||
original := &Channel{
|
||||
ID: 1,
|
||||
ModelMapping: map[string]map[string]string{
|
||||
"openai": {"gpt-4": "gpt-4-turbo"},
|
||||
},
|
||||
}
|
||||
cloned := original.Clone()
|
||||
|
||||
// Modify the cloned nested map
|
||||
cloned.ModelMapping["openai"]["gpt-4"] = "hacked"
|
||||
|
||||
// Original must remain unchanged
|
||||
require.Equal(t, "gpt-4-turbo", original.ModelMapping["openai"]["gpt-4"])
|
||||
})
|
||||
}
|
||||
|
||||
@@ -7407,11 +7407,7 @@ type RecordUsageInput struct {
|
||||
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
|
||||
APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额
|
||||
|
||||
// 渠道映射信息(由 handler 在 Forward 前解析)
|
||||
ChannelID int64 // 渠道 ID(0 = 无渠道)
|
||||
OriginalModel string // 用户原始请求模型(渠道映射前)
|
||||
BillingModelSource string // 计费模型来源:"requested" / "upstream"
|
||||
ModelMappingChain string // 映射链描述,如 "a→b→c"
|
||||
ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析)
|
||||
}
|
||||
|
||||
// APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage
|
||||
@@ -7940,11 +7936,7 @@ type RecordUsageLongContextInput struct {
|
||||
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
|
||||
APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选)
|
||||
|
||||
// 渠道映射信息(由 handler 在 Forward 前解析)
|
||||
ChannelID int64 // 渠道 ID(0 = 无渠道)
|
||||
OriginalModel string // 用户原始请求模型(渠道映射前)
|
||||
BillingModelSource string // 计费模型来源:"requested" / "upstream"
|
||||
ModelMappingChain string // 映射链描述,如 "a→b→c"
|
||||
ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析)
|
||||
}
|
||||
|
||||
// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini)
|
||||
|
||||
@@ -4,14 +4,12 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func resolverPtrFloat64(v float64) *float64 { return &v }
|
||||
func resolverPtrInt(v int) *int { return &v }
|
||||
|
||||
func newTestBillingServiceForResolver() *BillingService {
|
||||
bs := &BillingService{
|
||||
fallbackPrices: make(map[string]*ModelPricing),
|
||||
@@ -83,8 +81,8 @@ func TestGetIntervalPricing_MatchesInterval(t *testing.T) {
|
||||
BasePricing: &ModelPricing{InputPricePerToken: 5e-6},
|
||||
SupportsCacheBreakdown: true,
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: resolverPtrInt(128000), InputPrice: resolverPtrFloat64(1e-6), OutputPrice: resolverPtrFloat64(2e-6)},
|
||||
{MinTokens: 128000, MaxTokens: nil, InputPrice: resolverPtrFloat64(3e-6), OutputPrice: resolverPtrFloat64(6e-6)},
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6), OutputPrice: testPtrFloat64(2e-6)},
|
||||
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(3e-6), OutputPrice: testPtrFloat64(6e-6)},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -108,7 +106,7 @@ func TestGetIntervalPricing_NoMatch_FallsBackToBase(t *testing.T) {
|
||||
Mode: BillingModeToken,
|
||||
BasePricing: basePricing,
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 10000, MaxTokens: resolverPtrInt(50000), InputPrice: resolverPtrFloat64(1e-6)},
|
||||
{MinTokens: 10000, MaxTokens: testPtrInt(50000), InputPrice: testPtrFloat64(1e-6)},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -123,8 +121,8 @@ func TestGetRequestTierPrice(t *testing.T) {
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModePerRequest,
|
||||
RequestTiers: []PricingInterval{
|
||||
{TierLabel: "1K", PerRequestPrice: resolverPtrFloat64(0.04)},
|
||||
{TierLabel: "2K", PerRequestPrice: resolverPtrFloat64(0.08)},
|
||||
{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
|
||||
{TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -140,8 +138,8 @@ func TestGetRequestTierPriceByContext(t *testing.T) {
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModePerRequest,
|
||||
RequestTiers: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: resolverPtrInt(128000), PerRequestPrice: resolverPtrFloat64(0.05)},
|
||||
{MinTokens: 128000, MaxTokens: nil, PerRequestPrice: resolverPtrFloat64(0.10)},
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.05)},
|
||||
{MinTokens: 128000, MaxTokens: nil, PerRequestPrice: testPtrFloat64(0.10)},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -162,3 +160,428 @@ func TestGetRequestTierPrice_NilPerRequestPrice(t *testing.T) {
|
||||
|
||||
require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "1K"), 1e-12)
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// Channel override tests — exercises applyChannelOverrides via Resolve
|
||||
// ===========================================================================
|
||||
|
||||
// helper: creates a resolver wired to a ChannelService that returns the given
|
||||
// channel (active, groupID=100, platform=anthropic) with the specified pricing.
|
||||
func newResolverWithChannel(t *testing.T, pricing []ChannelModelPricing) *ModelPricingResolver {
|
||||
t.Helper()
|
||||
const groupID = 100
|
||||
repo := &mockChannelRepository{
|
||||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||||
return []Channel{{
|
||||
ID: 1,
|
||||
Name: "test-channel",
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{groupID},
|
||||
ModelPricing: pricing,
|
||||
}}, nil
|
||||
},
|
||||
getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) {
|
||||
return map[int64]string{groupID: "anthropic"}, nil
|
||||
},
|
||||
}
|
||||
cs := NewChannelService(repo, nil)
|
||||
bs := newTestBillingServiceForResolver()
|
||||
return NewModelPricingResolver(cs, bs)
|
||||
}
|
||||
|
||||
// groupIDPtr returns a pointer to groupID 100 (the test constant).
|
||||
func groupIDPtr() *int64 { v := int64(100); return &v }
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 1. Token mode overrides
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestResolve_WithChannelOverride_TokenFlat(t *testing.T) {
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModeToken,
|
||||
InputPrice: testPtrFloat64(10e-6),
|
||||
OutputPrice: testPtrFloat64(50e-6),
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
require.Equal(t, BillingModeToken, resolved.Mode)
|
||||
require.Equal(t, "channel", resolved.Source)
|
||||
require.NotNil(t, resolved.BasePricing)
|
||||
require.InDelta(t, 10e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 10e-6, resolved.BasePricing.InputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 50e-6, resolved.BasePricing.OutputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 50e-6, resolved.BasePricing.OutputPricePerTokenPriority, 1e-12)
|
||||
}
|
||||
|
||||
func TestResolve_WithChannelOverride_TokenPartialOverride(t *testing.T) {
|
||||
// Channel only sets InputPrice; OutputPrice should remain from the base (LiteLLM/fallback).
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModeToken,
|
||||
InputPrice: testPtrFloat64(20e-6),
|
||||
// OutputPrice intentionally nil
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
require.Equal(t, "channel", resolved.Source)
|
||||
require.NotNil(t, resolved.BasePricing)
|
||||
// InputPrice overridden by channel
|
||||
require.InDelta(t, 20e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
|
||||
// OutputPrice kept from base (fallback: 15e-6)
|
||||
require.InDelta(t, 15e-6, resolved.BasePricing.OutputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
func TestResolve_WithChannelOverride_TokenWithIntervals(t *testing.T) {
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModeToken,
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(2e-6), OutputPrice: testPtrFloat64(8e-6)},
|
||||
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(4e-6), OutputPrice: testPtrFloat64(16e-6)},
|
||||
},
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
require.Equal(t, "channel", resolved.Source)
|
||||
require.Len(t, resolved.Intervals, 2)
|
||||
|
||||
// GetIntervalPricing should use channel intervals
|
||||
iv := r.GetIntervalPricing(resolved, 50000)
|
||||
require.NotNil(t, iv)
|
||||
require.InDelta(t, 2e-6, iv.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 8e-6, iv.OutputPricePerToken, 1e-12)
|
||||
|
||||
iv2 := r.GetIntervalPricing(resolved, 200000)
|
||||
require.NotNil(t, iv2)
|
||||
require.InDelta(t, 4e-6, iv2.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 16e-6, iv2.OutputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
func TestResolve_WithChannelOverride_TokenNilBasePricing(t *testing.T) {
|
||||
// Base pricing is nil (unknown model), channel has flat prices → creates new BasePricing.
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"unknown-model-xyz"},
|
||||
BillingMode: BillingModeToken,
|
||||
InputPrice: testPtrFloat64(7e-6),
|
||||
OutputPrice: testPtrFloat64(21e-6),
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "unknown-model-xyz",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
require.Equal(t, "channel", resolved.Source)
|
||||
// BasePricing was nil from resolveBasePricing but applyTokenOverrides creates a new one
|
||||
require.NotNil(t, resolved.BasePricing)
|
||||
require.InDelta(t, 7e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 21e-6, resolved.BasePricing.OutputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 2. Per-request mode overrides
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestResolve_WithChannelOverride_PerRequest(t *testing.T) {
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModePerRequest,
|
||||
PerRequestPrice: testPtrFloat64(0.05),
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.03)},
|
||||
{MinTokens: 128000, MaxTokens: nil, PerRequestPrice: testPtrFloat64(0.10)},
|
||||
},
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
require.Equal(t, BillingModePerRequest, resolved.Mode)
|
||||
require.Equal(t, "channel", resolved.Source)
|
||||
require.InDelta(t, 0.05, resolved.DefaultPerRequestPrice, 1e-12)
|
||||
require.Len(t, resolved.RequestTiers, 2)
|
||||
|
||||
// Verify tier lookups
|
||||
require.InDelta(t, 0.03, r.GetRequestTierPriceByContext(resolved, 50000), 1e-12)
|
||||
require.InDelta(t, 0.10, r.GetRequestTierPriceByContext(resolved, 200000), 1e-12)
|
||||
}
|
||||
|
||||
func TestResolve_WithChannelOverride_PerRequestNilPrice(t *testing.T) {
|
||||
// PerRequestPrice nil → DefaultPerRequestPrice stays 0.
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModePerRequest,
|
||||
// PerRequestPrice intentionally nil
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.02)},
|
||||
},
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
require.Equal(t, BillingModePerRequest, resolved.Mode)
|
||||
require.InDelta(t, 0.0, resolved.DefaultPerRequestPrice, 1e-12)
|
||||
require.Len(t, resolved.RequestTiers, 1)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 3. Image mode overrides
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestResolve_WithChannelOverride_Image(t *testing.T) {
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModeImage,
|
||||
PerRequestPrice: testPtrFloat64(0.08),
|
||||
Intervals: []PricingInterval{
|
||||
{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
|
||||
{TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)},
|
||||
{TierLabel: "4K", PerRequestPrice: testPtrFloat64(0.16)},
|
||||
},
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
require.Equal(t, BillingModeImage, resolved.Mode)
|
||||
require.Equal(t, "channel", resolved.Source)
|
||||
require.InDelta(t, 0.08, resolved.DefaultPerRequestPrice, 1e-12)
|
||||
require.Len(t, resolved.RequestTiers, 3)
|
||||
}
|
||||
|
||||
func TestResolve_WithChannelOverride_ImageTierLabels(t *testing.T) {
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModeImage,
|
||||
Intervals: []PricingInterval{
|
||||
{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
|
||||
{TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)},
|
||||
{TierLabel: "4K", PerRequestPrice: testPtrFloat64(0.16)},
|
||||
},
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.InDelta(t, 0.04, r.GetRequestTierPrice(resolved, "1K"), 1e-12)
|
||||
require.InDelta(t, 0.08, r.GetRequestTierPrice(resolved, "2K"), 1e-12)
|
||||
require.InDelta(t, 0.16, r.GetRequestTierPrice(resolved, "4K"), 1e-12)
|
||||
require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "8K"), 1e-12) // not found
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 4. Source tracking & default mode
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestResolve_WithChannelOverride_SourceIsChannel(t *testing.T) {
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModeToken,
|
||||
InputPrice: testPtrFloat64(1e-6),
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.Equal(t, "channel", resolved.Source)
|
||||
}
|
||||
|
||||
func TestResolve_WithChannelOverride_DefaultMode(t *testing.T) {
|
||||
// Channel pricing with empty BillingMode → defaults to BillingModeToken.
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: "", // intentionally empty
|
||||
InputPrice: testPtrFloat64(5e-6),
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.Equal(t, "channel", resolved.Source)
|
||||
require.Equal(t, BillingModeToken, resolved.Mode)
|
||||
require.NotNil(t, resolved.BasePricing)
|
||||
require.InDelta(t, 5e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 5. GetIntervalPricing integration after channel override
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGetIntervalPricing_WithChannelIntervals(t *testing.T) {
|
||||
// Channel provides intervals that override the base pricing path.
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModeToken,
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(100000), InputPrice: testPtrFloat64(1e-6), OutputPrice: testPtrFloat64(5e-6)},
|
||||
{MinTokens: 100000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6), OutputPrice: testPtrFloat64(10e-6)},
|
||||
},
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
// Token count 50000 matches first interval
|
||||
pricing := r.GetIntervalPricing(resolved, 50000)
|
||||
require.NotNil(t, pricing)
|
||||
require.InDelta(t, 1e-6, pricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 5e-6, pricing.OutputPricePerToken, 1e-12)
|
||||
|
||||
// Token count 150000 matches second interval
|
||||
pricing2 := r.GetIntervalPricing(resolved, 150000)
|
||||
require.NotNil(t, pricing2)
|
||||
require.InDelta(t, 2e-6, pricing2.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 10e-6, pricing2.OutputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetIntervalPricing_ChannelIntervalsNoMatch(t *testing.T) {
|
||||
// Channel intervals don't match token count → falls back to BasePricing.
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModeToken,
|
||||
Intervals: []PricingInterval{
|
||||
// Only covers tokens > 50000
|
||||
{MinTokens: 50000, MaxTokens: testPtrInt(200000), InputPrice: testPtrFloat64(9e-6)},
|
||||
},
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
// Token count 1000 doesn't match any interval (1000 <= 50000 minTokens)
|
||||
pricing := r.GetIntervalPricing(resolved, 1000)
|
||||
// Should fall back to BasePricing (from the billing service fallback)
|
||||
require.NotNil(t, pricing)
|
||||
require.Equal(t, resolved.BasePricing, pricing)
|
||||
require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12) // original base price
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// 6. Error path tests
|
||||
// ===========================================================================
|
||||
|
||||
func TestResolve_WithChannelOverride_CacheError(t *testing.T) {
|
||||
// When ListAll returns an error, the ChannelService cache build fails.
|
||||
// Resolve should gracefully fall back to base pricing without panicking.
|
||||
repo := &mockChannelRepository{
|
||||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||||
return nil, errors.New("database unavailable")
|
||||
},
|
||||
}
|
||||
cs := NewChannelService(repo, nil)
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(cs, bs)
|
||||
|
||||
gid := int64(100)
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: &gid,
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
// Should NOT panic, should NOT have source "channel"
|
||||
require.NotEqual(t, "channel", resolved.Source)
|
||||
// Base pricing should still be present (from BillingService fallback)
|
||||
require.NotNil(t, resolved.BasePricing)
|
||||
require.InDelta(t, 3e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// 7. GetRequestTierPriceByContext boundary tests
|
||||
// ===========================================================================
|
||||
|
||||
func TestGetRequestTierPriceByContext_EmptyTiers(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModePerRequest,
|
||||
RequestTiers: nil, // empty
|
||||
}
|
||||
|
||||
price := r.GetRequestTierPriceByContext(resolved, 50000)
|
||||
require.InDelta(t, 0.0, price, 1e-12)
|
||||
|
||||
// Also test with explicit empty slice
|
||||
resolved2 := &ResolvedPricing{
|
||||
Mode: BillingModePerRequest,
|
||||
RequestTiers: []PricingInterval{},
|
||||
}
|
||||
|
||||
price2 := r.GetRequestTierPriceByContext(resolved2, 50000)
|
||||
require.InDelta(t, 0.0, price2, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetRequestTierPriceByContext_ExactBoundary(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModePerRequest,
|
||||
RequestTiers: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.05)},
|
||||
{MinTokens: 128000, MaxTokens: nil, PerRequestPrice: testPtrFloat64(0.10)},
|
||||
},
|
||||
}
|
||||
|
||||
// totalContextTokens = 128000 exactly:
|
||||
// FindMatchingInterval checks: totalTokens > MinTokens && totalTokens <= MaxTokens
|
||||
// For first interval: 128000 > 0 (true) && 128000 <= 128000 (true) → matches first interval
|
||||
price := r.GetRequestTierPriceByContext(resolved, 128000)
|
||||
require.InDelta(t, 0.05, price, 1e-12)
|
||||
|
||||
// totalContextTokens = 128001 should match second interval
|
||||
// For first interval: 128001 > 0 (true) && 128001 <= 128000 (false) → no match
|
||||
// For second interval: 128001 > 128000 (true) && MaxTokens == nil → matches
|
||||
price2 := r.GetRequestTierPriceByContext(resolved, 128001)
|
||||
require.InDelta(t, 0.10, price2, 1e-12)
|
||||
}
|
||||
|
||||
@@ -4146,10 +4146,7 @@ type OpenAIRecordUsageInput struct {
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
RequestPayloadHash string
|
||||
APIKeyService APIKeyQuotaUpdater
|
||||
ChannelID int64
|
||||
OriginalModel string
|
||||
BillingModelSource string
|
||||
ModelMappingChain string
|
||||
ChannelUsageFields
|
||||
}
|
||||
|
||||
// RecordUsage records usage and deducts balance
|
||||
|
||||
15
backend/internal/service/testhelpers_test.go
Normal file
15
backend/internal/service/testhelpers_test.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
// testPtrFloat64 returns a pointer to the given float64 value.
|
||||
func testPtrFloat64(v float64) *float64 { return &v }
|
||||
|
||||
// testPtrInt returns a pointer to the given int value.
|
||||
func testPtrInt(v int) *int { return &v }
|
||||
|
||||
// testPtrString returns a pointer to the given string value.
|
||||
func testPtrString(v string) *string { return &v }
|
||||
|
||||
// testPtrBool returns a pointer to the given bool value.
|
||||
func testPtrBool(v bool) *bool { return &v }
|
||||
Reference in New Issue
Block a user