2026-04-14 08:42:28 +08:00
|
|
|
|
//go:build unit
|
|
|
|
|
|
|
|
|
|
|
|
package service
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
|
"context"
|
|
|
|
|
|
"testing"
|
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
|
|
|
|
"github.com/stretchr/testify/require"
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
// ---------------------------------------------------------------------------
|
|
|
|
|
|
// CalculateCostUnified
|
|
|
|
|
|
// ---------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
func TestCalculateCostUnified_NilResolver_FallsBackToOldPath(t *testing.T) {
|
|
|
|
|
|
svc := newTestBillingService()
|
|
|
|
|
|
|
|
|
|
|
|
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
|
|
|
|
|
|
input := CostInput{
|
|
|
|
|
|
Model: "claude-sonnet-4",
|
|
|
|
|
|
Tokens: tokens,
|
|
|
|
|
|
RateMultiplier: 1.0,
|
|
|
|
|
|
Resolver: nil, // no resolver
|
|
|
|
|
|
}
|
|
|
|
|
|
cost, err := svc.CalculateCostUnified(input)
|
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
|
|
// Should match the old-path result exactly
|
|
|
|
|
|
expected, err := svc.calculateCostInternal("claude-sonnet-4", tokens, 1.0, "", nil)
|
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
require.InDelta(t, expected.TotalCost, cost.TotalCost, 1e-10)
|
|
|
|
|
|
require.InDelta(t, expected.ActualCost, cost.ActualCost, 1e-10)
|
|
|
|
|
|
// BillingMode is NOT set by old path through CalculateCostUnified (resolver == nil)
|
|
|
|
|
|
require.Empty(t, cost.BillingMode)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestCalculateCostUnified_TokenMode(t *testing.T) {
|
|
|
|
|
|
bs := newTestBillingService()
|
|
|
|
|
|
resolver := NewModelPricingResolver(nil, bs)
|
|
|
|
|
|
|
|
|
|
|
|
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
|
|
|
|
|
|
input := CostInput{
|
|
|
|
|
|
Ctx: context.Background(),
|
|
|
|
|
|
Model: "claude-sonnet-4",
|
|
|
|
|
|
Tokens: tokens,
|
|
|
|
|
|
RateMultiplier: 1.5,
|
|
|
|
|
|
Resolver: resolver,
|
|
|
|
|
|
}
|
|
|
|
|
|
cost, err := bs.CalculateCostUnified(input)
|
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
require.NotNil(t, cost)
|
|
|
|
|
|
|
|
|
|
|
|
// Verify token billing: Input: 1000*3e-6=0.003, Output: 500*15e-6=0.0075
|
|
|
|
|
|
expectedTotal := 1000*3e-6 + 500*15e-6
|
|
|
|
|
|
require.InDelta(t, expectedTotal, cost.TotalCost, 1e-10)
|
|
|
|
|
|
require.InDelta(t, expectedTotal*1.5, cost.ActualCost, 1e-10)
|
|
|
|
|
|
require.Equal(t, string(BillingModeToken), cost.BillingMode)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestCalculateCostUnified_PerRequestMode(t *testing.T) {
|
|
|
|
|
|
// Set up a ChannelService with a per-request pricing channel
|
|
|
|
|
|
cs := newTestChannelServiceWithCache(t, &channelCache{
|
|
|
|
|
|
pricingByGroupModel: map[channelModelKey]*ChannelModelPricing{
|
|
|
|
|
|
{groupID: 1, model: "claude-sonnet-4"}: {
|
|
|
|
|
|
BillingMode: BillingModePerRequest,
|
|
|
|
|
|
PerRequestPrice: testPtrFloat64(0.05),
|
|
|
|
|
|
},
|
|
|
|
|
|
},
|
|
|
|
|
|
channelByGroupID: map[int64]*Channel{
|
|
|
|
|
|
1: {ID: 1, Status: StatusActive},
|
|
|
|
|
|
},
|
|
|
|
|
|
groupPlatform: map[int64]string{1: ""},
|
|
|
|
|
|
wildcardByGroupPlatform: map[channelGroupPlatformKey][]*wildcardPricingEntry{},
|
|
|
|
|
|
mappingByGroupModel: map[channelModelKey]string{},
|
|
|
|
|
|
wildcardMappingByGP: map[channelGroupPlatformKey][]*wildcardMappingEntry{},
|
|
|
|
|
|
byID: map[int64]*Channel{},
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
bs := newTestBillingService()
|
|
|
|
|
|
resolver := NewModelPricingResolver(cs, bs)
|
|
|
|
|
|
groupID := int64(1)
|
|
|
|
|
|
|
|
|
|
|
|
input := CostInput{
|
|
|
|
|
|
Ctx: context.Background(),
|
|
|
|
|
|
Model: "claude-sonnet-4",
|
|
|
|
|
|
GroupID: &groupID,
|
|
|
|
|
|
Tokens: UsageTokens{InputTokens: 100, OutputTokens: 50},
|
|
|
|
|
|
RequestCount: 3,
|
|
|
|
|
|
RateMultiplier: 2.0,
|
|
|
|
|
|
Resolver: resolver,
|
|
|
|
|
|
}
|
|
|
|
|
|
cost, err := bs.CalculateCostUnified(input)
|
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
require.NotNil(t, cost)
|
|
|
|
|
|
|
|
|
|
|
|
// 3 requests * $0.05 = $0.15
|
|
|
|
|
|
require.InDelta(t, 0.15, cost.TotalCost, 1e-10)
|
|
|
|
|
|
// ActualCost = 0.15 * 2.0 = 0.30
|
|
|
|
|
|
require.InDelta(t, 0.30, cost.ActualCost, 1e-10)
|
|
|
|
|
|
require.Equal(t, string(BillingModePerRequest), cost.BillingMode)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestCalculateCostUnified_ImageMode(t *testing.T) {
|
|
|
|
|
|
cs := newTestChannelServiceWithCache(t, &channelCache{
|
|
|
|
|
|
pricingByGroupModel: map[channelModelKey]*ChannelModelPricing{
|
|
|
|
|
|
{groupID: 2, model: "gemini-image"}: {
|
|
|
|
|
|
BillingMode: BillingModeImage,
|
|
|
|
|
|
PerRequestPrice: testPtrFloat64(0.10),
|
|
|
|
|
|
},
|
|
|
|
|
|
},
|
|
|
|
|
|
channelByGroupID: map[int64]*Channel{
|
|
|
|
|
|
2: {ID: 2, Status: StatusActive},
|
|
|
|
|
|
},
|
|
|
|
|
|
groupPlatform: map[int64]string{2: ""},
|
|
|
|
|
|
wildcardByGroupPlatform: map[channelGroupPlatformKey][]*wildcardPricingEntry{},
|
|
|
|
|
|
mappingByGroupModel: map[channelModelKey]string{},
|
|
|
|
|
|
wildcardMappingByGP: map[channelGroupPlatformKey][]*wildcardMappingEntry{},
|
|
|
|
|
|
byID: map[int64]*Channel{},
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
bs := &BillingService{
|
|
|
|
|
|
cfg: &config.Config{},
|
|
|
|
|
|
fallbackPrices: map[string]*ModelPricing{},
|
|
|
|
|
|
}
|
|
|
|
|
|
resolver := NewModelPricingResolver(cs, bs)
|
|
|
|
|
|
groupID := int64(2)
|
|
|
|
|
|
|
|
|
|
|
|
input := CostInput{
|
|
|
|
|
|
Ctx: context.Background(),
|
|
|
|
|
|
Model: "gemini-image",
|
|
|
|
|
|
GroupID: &groupID,
|
|
|
|
|
|
Tokens: UsageTokens{},
|
|
|
|
|
|
RequestCount: 2,
|
|
|
|
|
|
RateMultiplier: 1.0,
|
|
|
|
|
|
Resolver: resolver,
|
|
|
|
|
|
}
|
|
|
|
|
|
cost, err := bs.CalculateCostUnified(input)
|
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
require.NotNil(t, cost)
|
|
|
|
|
|
|
|
|
|
|
|
// 2 * $0.10 = $0.20
|
|
|
|
|
|
require.InDelta(t, 0.20, cost.TotalCost, 1e-10)
|
|
|
|
|
|
require.InDelta(t, 0.20, cost.ActualCost, 1e-10)
|
|
|
|
|
|
require.Equal(t, string(BillingModeImage), cost.BillingMode)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-17 18:32:12 +08:00
|
|
|
|
// TestCalculateCostUnified_RateMultiplierZeroProducesZero 锁定新行为:
|
|
|
|
|
|
// 保存时强制 > 0;若 0 仍泄漏到计费层,按 0 计费(而非历史上的 1.0)。
|
|
|
|
|
|
func TestCalculateCostUnified_RateMultiplierZeroProducesZero(t *testing.T) {
|
2026-04-14 08:42:28 +08:00
|
|
|
|
bs := newTestBillingService()
|
|
|
|
|
|
resolver := NewModelPricingResolver(nil, bs)
|
|
|
|
|
|
|
|
|
|
|
|
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
|
|
|
|
|
|
|
2026-04-17 18:32:12 +08:00
|
|
|
|
cost, err := bs.CalculateCostUnified(CostInput{
|
2026-04-14 08:42:28 +08:00
|
|
|
|
Ctx: context.Background(),
|
|
|
|
|
|
Model: "claude-sonnet-4",
|
|
|
|
|
|
Tokens: tokens,
|
2026-04-17 18:32:12 +08:00
|
|
|
|
RateMultiplier: 0,
|
2026-04-14 08:42:28 +08:00
|
|
|
|
Resolver: resolver,
|
|
|
|
|
|
})
|
|
|
|
|
|
require.NoError(t, err)
|
2026-04-17 18:32:12 +08:00
|
|
|
|
require.Greater(t, cost.TotalCost, 0.0)
|
|
|
|
|
|
require.InDelta(t, 0.0, cost.ActualCost, 1e-10)
|
2026-04-14 08:42:28 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-17 18:32:12 +08:00
|
|
|
|
// TestCalculateCostUnified_NegativeRateMultiplierClampedToZero 锁定新行为:
|
|
|
|
|
|
// 负数倍率按 0 计费,避免历史的 <=0 → 1.0 把配置异常静默按标准价扣费。
|
|
|
|
|
|
func TestCalculateCostUnified_NegativeRateMultiplierClampedToZero(t *testing.T) {
|
2026-04-14 08:42:28 +08:00
|
|
|
|
bs := newTestBillingService()
|
|
|
|
|
|
resolver := NewModelPricingResolver(nil, bs)
|
|
|
|
|
|
|
|
|
|
|
|
tokens := UsageTokens{InputTokens: 1000}
|
|
|
|
|
|
|
2026-04-17 18:32:12 +08:00
|
|
|
|
cost, err := bs.CalculateCostUnified(CostInput{
|
2026-04-14 08:42:28 +08:00
|
|
|
|
Ctx: context.Background(),
|
|
|
|
|
|
Model: "claude-sonnet-4",
|
|
|
|
|
|
Tokens: tokens,
|
|
|
|
|
|
RateMultiplier: -5.0,
|
|
|
|
|
|
Resolver: resolver,
|
|
|
|
|
|
})
|
|
|
|
|
|
require.NoError(t, err)
|
2026-04-17 18:32:12 +08:00
|
|
|
|
require.Greater(t, cost.TotalCost, 0.0)
|
|
|
|
|
|
require.InDelta(t, 0.0, cost.ActualCost, 1e-10)
|
2026-04-14 08:42:28 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestCalculateCostUnified_BillingModeFieldFilled(t *testing.T) {
|
|
|
|
|
|
bs := newTestBillingService()
|
|
|
|
|
|
resolver := NewModelPricingResolver(nil, bs)
|
|
|
|
|
|
|
|
|
|
|
|
cost, err := bs.CalculateCostUnified(CostInput{
|
|
|
|
|
|
Ctx: context.Background(),
|
|
|
|
|
|
Model: "claude-sonnet-4",
|
|
|
|
|
|
Tokens: UsageTokens{InputTokens: 100},
|
|
|
|
|
|
RateMultiplier: 1.0,
|
|
|
|
|
|
Resolver: resolver,
|
|
|
|
|
|
})
|
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
require.Equal(t, "token", cost.BillingMode)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestCalculateCostUnified_UsesPreResolvedPricing(t *testing.T) {
|
|
|
|
|
|
bs := newTestBillingService()
|
|
|
|
|
|
resolver := NewModelPricingResolver(nil, bs)
|
|
|
|
|
|
|
|
|
|
|
|
// Pre-resolve with per_request mode to verify it's used instead of re-resolving
|
|
|
|
|
|
preResolved := &ResolvedPricing{
|
|
|
|
|
|
Mode: BillingModePerRequest,
|
|
|
|
|
|
DefaultPerRequestPrice: 0.07,
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
cost, err := bs.CalculateCostUnified(CostInput{
|
|
|
|
|
|
Ctx: context.Background(),
|
|
|
|
|
|
Model: "claude-sonnet-4",
|
|
|
|
|
|
Tokens: UsageTokens{InputTokens: 100},
|
|
|
|
|
|
RequestCount: 2,
|
|
|
|
|
|
RateMultiplier: 1.0,
|
|
|
|
|
|
Resolver: resolver,
|
|
|
|
|
|
Resolved: preResolved,
|
|
|
|
|
|
})
|
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
require.NotNil(t, cost)
|
|
|
|
|
|
|
|
|
|
|
|
// 2 * $0.07 = $0.14
|
|
|
|
|
|
require.InDelta(t, 0.14, cost.TotalCost, 1e-10)
|
|
|
|
|
|
require.Equal(t, string(BillingModePerRequest), cost.BillingMode)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// ---------------------------------------------------------------------------
|
|
|
|
|
|
// helpers
|
|
|
|
|
|
// ---------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
// newTestChannelServiceWithCache creates a ChannelService with a pre-populated
|
|
|
|
|
|
// cache snapshot, bypassing the repository layer entirely.
|
|
|
|
|
|
func newTestChannelServiceWithCache(t *testing.T, cache *channelCache) *ChannelService {
|
|
|
|
|
|
t.Helper()
|
|
|
|
|
|
cs := &ChannelService{}
|
|
|
|
|
|
cache.loadedAt = time.Now()
|
|
|
|
|
|
cs.cache.Store(cache)
|
|
|
|
|
|
return cs
|
|
|
|
|
|
}
|