mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-05-05 05:30:44 +08:00
feat(channel): 渠道管理全链路集成 — 模型映射、定价、限制、用量统计
- 渠道模型映射:支持精确匹配和通配符映射,按平台隔离 - 渠道模型定价:支持 token/按次/图片三种计费模式,区间分层定价 - 模型限制:渠道可限制仅允许定价列表中的模型 - 计费模型来源:支持 requested/upstream 两种计费模型选择 - 用量统计:usage_logs 新增 channel_id/model_mapping_chain/billing_tier/billing_mode 字段 - Dashboard 支持 model_source 维度(requested/upstream/mapping)查看模型统计 - 全部 gateway handler 统一接入 ResolveChannelMappingAndRestrict - 修复测试:同步 SoraGenerationRepository 接口、SQL INSERT 参数、scan 字段
This commit is contained in:
@@ -26,37 +26,37 @@ func NewChannelHandler(channelService *service.ChannelService, billingService *s
|
||||
// --- Request / Response types ---
|
||||
|
||||
type createChannelRequest struct {
|
||||
Name string `json:"name" binding:"required,max=100"`
|
||||
Description string `json:"description"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ModelPricing []channelModelPricingRequest `json:"model_pricing"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"`
|
||||
RestrictModels bool `json:"restrict_models"`
|
||||
Name string `json:"name" binding:"required,max=100"`
|
||||
Description string `json:"description"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ModelPricing []channelModelPricingRequest `json:"model_pricing"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"`
|
||||
RestrictModels bool `json:"restrict_models"`
|
||||
}
|
||||
|
||||
type updateChannelRequest struct {
|
||||
Name string `json:"name" binding:"omitempty,max=100"`
|
||||
Description *string `json:"description"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"`
|
||||
RestrictModels *bool `json:"restrict_models"`
|
||||
Name string `json:"name" binding:"omitempty,max=100"`
|
||||
Description *string `json:"description"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"`
|
||||
RestrictModels *bool `json:"restrict_models"`
|
||||
}
|
||||
|
||||
type channelModelPricingRequest struct {
|
||||
Platform string `json:"platform" binding:"omitempty,max=50"`
|
||||
Models []string `json:"models" binding:"required,min=1,max=100"`
|
||||
BillingMode string `json:"billing_mode" binding:"omitempty,oneof=token per_request image"`
|
||||
InputPrice *float64 `json:"input_price" binding:"omitempty,min=0"`
|
||||
OutputPrice *float64 `json:"output_price" binding:"omitempty,min=0"`
|
||||
CacheWritePrice *float64 `json:"cache_write_price" binding:"omitempty,min=0"`
|
||||
CacheReadPrice *float64 `json:"cache_read_price" binding:"omitempty,min=0"`
|
||||
ImageOutputPrice *float64 `json:"image_output_price" binding:"omitempty,min=0"`
|
||||
PerRequestPrice *float64 `json:"per_request_price" binding:"omitempty,min=0"`
|
||||
Intervals []pricingIntervalRequest `json:"intervals"`
|
||||
Platform string `json:"platform" binding:"omitempty,max=50"`
|
||||
Models []string `json:"models" binding:"required,min=1,max=100"`
|
||||
BillingMode string `json:"billing_mode" binding:"omitempty,oneof=token per_request image"`
|
||||
InputPrice *float64 `json:"input_price" binding:"omitempty,min=0"`
|
||||
OutputPrice *float64 `json:"output_price" binding:"omitempty,min=0"`
|
||||
CacheWritePrice *float64 `json:"cache_write_price" binding:"omitempty,min=0"`
|
||||
CacheReadPrice *float64 `json:"cache_read_price" binding:"omitempty,min=0"`
|
||||
ImageOutputPrice *float64 `json:"image_output_price" binding:"omitempty,min=0"`
|
||||
PerRequestPrice *float64 `json:"per_request_price" binding:"omitempty,min=0"`
|
||||
Intervals []pricingIntervalRequest `json:"intervals"`
|
||||
}
|
||||
|
||||
type pricingIntervalRequest struct {
|
||||
@@ -72,31 +72,31 @@ type pricingIntervalRequest struct {
|
||||
}
|
||||
|
||||
type channelResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Status string `json:"status"`
|
||||
BillingModelSource string `json:"billing_model_source"`
|
||||
RestrictModels bool `json:"restrict_models"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ModelPricing []channelModelPricingResponse `json:"model_pricing"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Status string `json:"status"`
|
||||
BillingModelSource string `json:"billing_model_source"`
|
||||
RestrictModels bool `json:"restrict_models"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ModelPricing []channelModelPricingResponse `json:"model_pricing"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
type channelModelPricingResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
Platform string `json:"platform"`
|
||||
Models []string `json:"models"`
|
||||
BillingMode string `json:"billing_mode"`
|
||||
InputPrice *float64 `json:"input_price"`
|
||||
OutputPrice *float64 `json:"output_price"`
|
||||
CacheWritePrice *float64 `json:"cache_write_price"`
|
||||
CacheReadPrice *float64 `json:"cache_read_price"`
|
||||
ImageOutputPrice *float64 `json:"image_output_price"`
|
||||
PerRequestPrice *float64 `json:"per_request_price"`
|
||||
Intervals []pricingIntervalResponse `json:"intervals"`
|
||||
ID int64 `json:"id"`
|
||||
Platform string `json:"platform"`
|
||||
Models []string `json:"models"`
|
||||
BillingMode string `json:"billing_mode"`
|
||||
InputPrice *float64 `json:"input_price"`
|
||||
OutputPrice *float64 `json:"output_price"`
|
||||
CacheWritePrice *float64 `json:"cache_write_price"`
|
||||
CacheReadPrice *float64 `json:"cache_read_price"`
|
||||
ImageOutputPrice *float64 `json:"image_output_price"`
|
||||
PerRequestPrice *float64 `json:"per_request_price"`
|
||||
Intervals []pricingIntervalResponse `json:"intervals"`
|
||||
}
|
||||
|
||||
type pricingIntervalResponse struct {
|
||||
@@ -117,15 +117,15 @@ func channelToResponse(ch *service.Channel) *channelResponse {
|
||||
return nil
|
||||
}
|
||||
resp := &channelResponse{
|
||||
ID: ch.ID,
|
||||
Name: ch.Name,
|
||||
Description: ch.Description,
|
||||
Status: ch.Status,
|
||||
ID: ch.ID,
|
||||
Name: ch.Name,
|
||||
Description: ch.Description,
|
||||
Status: ch.Status,
|
||||
RestrictModels: ch.RestrictModels,
|
||||
GroupIDs: ch.GroupIDs,
|
||||
ModelMapping: ch.ModelMapping,
|
||||
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
||||
UpdatedAt: ch.UpdatedAt.Format("2006-01-02T15:04:05Z"),
|
||||
GroupIDs: ch.GroupIDs,
|
||||
ModelMapping: ch.ModelMapping,
|
||||
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
||||
UpdatedAt: ch.UpdatedAt.Format("2006-01-02T15:04:05Z"),
|
||||
}
|
||||
resp.BillingModelSource = ch.BillingModelSource
|
||||
if resp.BillingModelSource == "" {
|
||||
@@ -298,9 +298,9 @@ func (h *ChannelHandler) Create(c *gin.Context) {
|
||||
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ModelPricing: pricing,
|
||||
ModelMapping: req.ModelMapping,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ModelPricing: pricing,
|
||||
ModelMapping: req.ModelMapping,
|
||||
BillingModelSource: req.BillingModelSource,
|
||||
RestrictModels: req.RestrictModels,
|
||||
})
|
||||
@@ -331,8 +331,8 @@ func (h *ChannelHandler) Update(c *gin.Context) {
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Status: req.Status,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ModelMapping: req.ModelMapping,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ModelMapping: req.ModelMapping,
|
||||
BillingModelSource: req.BillingModelSource,
|
||||
RestrictModels: req.RestrictModels,
|
||||
}
|
||||
|
||||
502
backend/internal/handler/admin/channel_handler_test.go
Normal file
502
backend/internal/handler/admin/channel_handler_test.go
Normal file
@@ -0,0 +1,502 @@
|
||||
//go:build unit
|
||||
|
||||
package admin
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func float64Ptr(v float64) *float64 { return &v }
|
||||
func intPtr(v int) *int { return &v }
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 1. channelToResponse
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestChannelToResponse_NilInput(t *testing.T) {
|
||||
require.Nil(t, channelToResponse(nil))
|
||||
}
|
||||
|
||||
func TestChannelToResponse_FullChannel(t *testing.T) {
|
||||
now := time.Date(2025, 6, 1, 12, 0, 0, 0, time.UTC)
|
||||
ch := &service.Channel{
|
||||
ID: 42,
|
||||
Name: "test-channel",
|
||||
Description: "desc",
|
||||
Status: "active",
|
||||
BillingModelSource: "upstream",
|
||||
RestrictModels: true,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now.Add(time.Hour),
|
||||
GroupIDs: []int64{1, 2, 3},
|
||||
ModelPricing: []service.ChannelModelPricing{
|
||||
{
|
||||
ID: 10,
|
||||
Platform: "openai",
|
||||
Models: []string{"gpt-4"},
|
||||
BillingMode: service.BillingModeToken,
|
||||
InputPrice: float64Ptr(0.01),
|
||||
OutputPrice: float64Ptr(0.03),
|
||||
CacheWritePrice: float64Ptr(0.005),
|
||||
CacheReadPrice: float64Ptr(0.002),
|
||||
PerRequestPrice: float64Ptr(0.5),
|
||||
},
|
||||
},
|
||||
ModelMapping: map[string]map[string]string{
|
||||
"anthropic": {"claude-3-haiku": "claude-haiku-3"},
|
||||
},
|
||||
}
|
||||
|
||||
resp := channelToResponse(ch)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, int64(42), resp.ID)
|
||||
require.Equal(t, "test-channel", resp.Name)
|
||||
require.Equal(t, "desc", resp.Description)
|
||||
require.Equal(t, "active", resp.Status)
|
||||
require.Equal(t, "upstream", resp.BillingModelSource)
|
||||
require.True(t, resp.RestrictModels)
|
||||
require.Equal(t, []int64{1, 2, 3}, resp.GroupIDs)
|
||||
require.Equal(t, "2025-06-01T12:00:00Z", resp.CreatedAt)
|
||||
require.Equal(t, "2025-06-01T13:00:00Z", resp.UpdatedAt)
|
||||
|
||||
// model mapping
|
||||
require.Len(t, resp.ModelMapping, 1)
|
||||
require.Equal(t, "claude-haiku-3", resp.ModelMapping["anthropic"]["claude-3-haiku"])
|
||||
|
||||
// pricing
|
||||
require.Len(t, resp.ModelPricing, 1)
|
||||
p := resp.ModelPricing[0]
|
||||
require.Equal(t, int64(10), p.ID)
|
||||
require.Equal(t, "openai", p.Platform)
|
||||
require.Equal(t, []string{"gpt-4"}, p.Models)
|
||||
require.Equal(t, "token", p.BillingMode)
|
||||
require.Equal(t, float64Ptr(0.01), p.InputPrice)
|
||||
require.Equal(t, float64Ptr(0.03), p.OutputPrice)
|
||||
require.Equal(t, float64Ptr(0.005), p.CacheWritePrice)
|
||||
require.Equal(t, float64Ptr(0.002), p.CacheReadPrice)
|
||||
require.Equal(t, float64Ptr(0.5), p.PerRequestPrice)
|
||||
require.Empty(t, p.Intervals)
|
||||
}
|
||||
|
||||
func TestChannelToResponse_EmptyDefaults(t *testing.T) {
|
||||
now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
ch := &service.Channel{
|
||||
ID: 1,
|
||||
Name: "ch",
|
||||
BillingModelSource: "",
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
GroupIDs: nil,
|
||||
ModelMapping: nil,
|
||||
ModelPricing: []service.ChannelModelPricing{
|
||||
{
|
||||
Platform: "",
|
||||
BillingMode: "",
|
||||
Models: []string{"m1"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp := channelToResponse(ch)
|
||||
require.Equal(t, "requested", resp.BillingModelSource)
|
||||
require.NotNil(t, resp.GroupIDs)
|
||||
require.Empty(t, resp.GroupIDs)
|
||||
require.NotNil(t, resp.ModelMapping)
|
||||
require.Empty(t, resp.ModelMapping)
|
||||
|
||||
require.Len(t, resp.ModelPricing, 1)
|
||||
require.Equal(t, "anthropic", resp.ModelPricing[0].Platform)
|
||||
require.Equal(t, "token", resp.ModelPricing[0].BillingMode)
|
||||
}
|
||||
|
||||
func TestChannelToResponse_NilModels(t *testing.T) {
|
||||
now := time.Now()
|
||||
ch := &service.Channel{
|
||||
ID: 1,
|
||||
Name: "ch",
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
ModelPricing: []service.ChannelModelPricing{
|
||||
{
|
||||
Models: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp := channelToResponse(ch)
|
||||
require.Len(t, resp.ModelPricing, 1)
|
||||
require.NotNil(t, resp.ModelPricing[0].Models)
|
||||
require.Empty(t, resp.ModelPricing[0].Models)
|
||||
}
|
||||
|
||||
func TestChannelToResponse_WithIntervals(t *testing.T) {
|
||||
now := time.Now()
|
||||
ch := &service.Channel{
|
||||
ID: 1,
|
||||
Name: "ch",
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
ModelPricing: []service.ChannelModelPricing{
|
||||
{
|
||||
Models: []string{"m1"},
|
||||
BillingMode: service.BillingModePerRequest,
|
||||
Intervals: []service.PricingInterval{
|
||||
{
|
||||
ID: 100,
|
||||
MinTokens: 0,
|
||||
MaxTokens: intPtr(1000),
|
||||
TierLabel: "1K",
|
||||
InputPrice: float64Ptr(0.01),
|
||||
OutputPrice: float64Ptr(0.02),
|
||||
CacheWritePrice: float64Ptr(0.003),
|
||||
CacheReadPrice: float64Ptr(0.001),
|
||||
PerRequestPrice: float64Ptr(0.1),
|
||||
SortOrder: 1,
|
||||
},
|
||||
{
|
||||
ID: 101,
|
||||
MinTokens: 1000,
|
||||
MaxTokens: nil,
|
||||
TierLabel: "unlimited",
|
||||
SortOrder: 2,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp := channelToResponse(ch)
|
||||
require.Len(t, resp.ModelPricing, 1)
|
||||
intervals := resp.ModelPricing[0].Intervals
|
||||
require.Len(t, intervals, 2)
|
||||
|
||||
iv0 := intervals[0]
|
||||
require.Equal(t, int64(100), iv0.ID)
|
||||
require.Equal(t, 0, iv0.MinTokens)
|
||||
require.Equal(t, intPtr(1000), iv0.MaxTokens)
|
||||
require.Equal(t, "1K", iv0.TierLabel)
|
||||
require.Equal(t, float64Ptr(0.01), iv0.InputPrice)
|
||||
require.Equal(t, float64Ptr(0.02), iv0.OutputPrice)
|
||||
require.Equal(t, float64Ptr(0.003), iv0.CacheWritePrice)
|
||||
require.Equal(t, float64Ptr(0.001), iv0.CacheReadPrice)
|
||||
require.Equal(t, float64Ptr(0.1), iv0.PerRequestPrice)
|
||||
require.Equal(t, 1, iv0.SortOrder)
|
||||
|
||||
iv1 := intervals[1]
|
||||
require.Equal(t, int64(101), iv1.ID)
|
||||
require.Equal(t, 1000, iv1.MinTokens)
|
||||
require.Nil(t, iv1.MaxTokens)
|
||||
require.Equal(t, "unlimited", iv1.TierLabel)
|
||||
require.Equal(t, 2, iv1.SortOrder)
|
||||
}
|
||||
|
||||
func TestChannelToResponse_MultipleEntries(t *testing.T) {
|
||||
now := time.Now()
|
||||
ch := &service.Channel{
|
||||
ID: 1,
|
||||
Name: "multi",
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
ModelPricing: []service.ChannelModelPricing{
|
||||
{
|
||||
ID: 1,
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: service.BillingModeToken,
|
||||
InputPrice: float64Ptr(0.003),
|
||||
OutputPrice: float64Ptr(0.015),
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Platform: "openai",
|
||||
Models: []string{"gpt-4", "gpt-4o"},
|
||||
BillingMode: service.BillingModePerRequest,
|
||||
PerRequestPrice: float64Ptr(1.0),
|
||||
},
|
||||
{
|
||||
ID: 3,
|
||||
Platform: "gemini",
|
||||
Models: []string{"gemini-2.5-pro"},
|
||||
BillingMode: service.BillingModeImage,
|
||||
ImageOutputPrice: float64Ptr(0.05),
|
||||
PerRequestPrice: float64Ptr(0.2),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp := channelToResponse(ch)
|
||||
require.Len(t, resp.ModelPricing, 3)
|
||||
|
||||
require.Equal(t, int64(1), resp.ModelPricing[0].ID)
|
||||
require.Equal(t, "anthropic", resp.ModelPricing[0].Platform)
|
||||
require.Equal(t, []string{"claude-sonnet-4"}, resp.ModelPricing[0].Models)
|
||||
require.Equal(t, "token", resp.ModelPricing[0].BillingMode)
|
||||
|
||||
require.Equal(t, int64(2), resp.ModelPricing[1].ID)
|
||||
require.Equal(t, "openai", resp.ModelPricing[1].Platform)
|
||||
require.Equal(t, []string{"gpt-4", "gpt-4o"}, resp.ModelPricing[1].Models)
|
||||
require.Equal(t, "per_request", resp.ModelPricing[1].BillingMode)
|
||||
|
||||
require.Equal(t, int64(3), resp.ModelPricing[2].ID)
|
||||
require.Equal(t, "gemini", resp.ModelPricing[2].Platform)
|
||||
require.Equal(t, []string{"gemini-2.5-pro"}, resp.ModelPricing[2].Models)
|
||||
require.Equal(t, "image", resp.ModelPricing[2].BillingMode)
|
||||
require.Equal(t, float64Ptr(0.05), resp.ModelPricing[2].ImageOutputPrice)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 2. pricingRequestToService
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestPricingRequestToService_Defaults(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req channelModelPricingRequest
|
||||
wantField string // which default field to check
|
||||
wantValue string
|
||||
}{
|
||||
{
|
||||
name: "empty billing mode defaults to token",
|
||||
req: channelModelPricingRequest{
|
||||
Models: []string{"m1"},
|
||||
BillingMode: "",
|
||||
},
|
||||
wantField: "BillingMode",
|
||||
wantValue: string(service.BillingModeToken),
|
||||
},
|
||||
{
|
||||
name: "empty platform defaults to anthropic",
|
||||
req: channelModelPricingRequest{
|
||||
Models: []string{"m1"},
|
||||
Platform: "",
|
||||
},
|
||||
wantField: "Platform",
|
||||
wantValue: "anthropic",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := pricingRequestToService([]channelModelPricingRequest{tt.req})
|
||||
require.Len(t, result, 1)
|
||||
switch tt.wantField {
|
||||
case "BillingMode":
|
||||
require.Equal(t, service.BillingMode(tt.wantValue), result[0].BillingMode)
|
||||
case "Platform":
|
||||
require.Equal(t, tt.wantValue, result[0].Platform)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPricingRequestToService_WithAllFields(t *testing.T) {
|
||||
reqs := []channelModelPricingRequest{
|
||||
{
|
||||
Platform: "openai",
|
||||
Models: []string{"gpt-4", "gpt-4o"},
|
||||
BillingMode: "per_request",
|
||||
InputPrice: float64Ptr(0.01),
|
||||
OutputPrice: float64Ptr(0.03),
|
||||
CacheWritePrice: float64Ptr(0.005),
|
||||
CacheReadPrice: float64Ptr(0.002),
|
||||
ImageOutputPrice: float64Ptr(0.04),
|
||||
PerRequestPrice: float64Ptr(0.5),
|
||||
},
|
||||
}
|
||||
|
||||
result := pricingRequestToService(reqs)
|
||||
require.Len(t, result, 1)
|
||||
r := result[0]
|
||||
require.Equal(t, "openai", r.Platform)
|
||||
require.Equal(t, []string{"gpt-4", "gpt-4o"}, r.Models)
|
||||
require.Equal(t, service.BillingModePerRequest, r.BillingMode)
|
||||
require.Equal(t, float64Ptr(0.01), r.InputPrice)
|
||||
require.Equal(t, float64Ptr(0.03), r.OutputPrice)
|
||||
require.Equal(t, float64Ptr(0.005), r.CacheWritePrice)
|
||||
require.Equal(t, float64Ptr(0.002), r.CacheReadPrice)
|
||||
require.Equal(t, float64Ptr(0.04), r.ImageOutputPrice)
|
||||
require.Equal(t, float64Ptr(0.5), r.PerRequestPrice)
|
||||
}
|
||||
|
||||
func TestPricingRequestToService_WithIntervals(t *testing.T) {
|
||||
reqs := []channelModelPricingRequest{
|
||||
{
|
||||
Models: []string{"m1"},
|
||||
BillingMode: "per_request",
|
||||
Intervals: []pricingIntervalRequest{
|
||||
{
|
||||
MinTokens: 0,
|
||||
MaxTokens: intPtr(2000),
|
||||
TierLabel: "small",
|
||||
InputPrice: float64Ptr(0.01),
|
||||
OutputPrice: float64Ptr(0.02),
|
||||
CacheWritePrice: float64Ptr(0.003),
|
||||
CacheReadPrice: float64Ptr(0.001),
|
||||
PerRequestPrice: float64Ptr(0.1),
|
||||
SortOrder: 1,
|
||||
},
|
||||
{
|
||||
MinTokens: 2000,
|
||||
MaxTokens: nil,
|
||||
TierLabel: "large",
|
||||
SortOrder: 2,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := pricingRequestToService(reqs)
|
||||
require.Len(t, result, 1)
|
||||
require.Len(t, result[0].Intervals, 2)
|
||||
|
||||
iv0 := result[0].Intervals[0]
|
||||
require.Equal(t, 0, iv0.MinTokens)
|
||||
require.Equal(t, intPtr(2000), iv0.MaxTokens)
|
||||
require.Equal(t, "small", iv0.TierLabel)
|
||||
require.Equal(t, float64Ptr(0.01), iv0.InputPrice)
|
||||
require.Equal(t, float64Ptr(0.02), iv0.OutputPrice)
|
||||
require.Equal(t, float64Ptr(0.003), iv0.CacheWritePrice)
|
||||
require.Equal(t, float64Ptr(0.001), iv0.CacheReadPrice)
|
||||
require.Equal(t, float64Ptr(0.1), iv0.PerRequestPrice)
|
||||
require.Equal(t, 1, iv0.SortOrder)
|
||||
|
||||
iv1 := result[0].Intervals[1]
|
||||
require.Equal(t, 2000, iv1.MinTokens)
|
||||
require.Nil(t, iv1.MaxTokens)
|
||||
require.Equal(t, "large", iv1.TierLabel)
|
||||
require.Equal(t, 2, iv1.SortOrder)
|
||||
}
|
||||
|
||||
func TestPricingRequestToService_EmptySlice(t *testing.T) {
|
||||
result := pricingRequestToService([]channelModelPricingRequest{})
|
||||
require.NotNil(t, result)
|
||||
require.Empty(t, result)
|
||||
}
|
||||
|
||||
func TestPricingRequestToService_NilPriceFields(t *testing.T) {
|
||||
reqs := []channelModelPricingRequest{
|
||||
{
|
||||
Models: []string{"m1"},
|
||||
BillingMode: "token",
|
||||
// all price fields are nil by default
|
||||
},
|
||||
}
|
||||
|
||||
result := pricingRequestToService(reqs)
|
||||
require.Len(t, result, 1)
|
||||
r := result[0]
|
||||
require.Nil(t, r.InputPrice)
|
||||
require.Nil(t, r.OutputPrice)
|
||||
require.Nil(t, r.CacheWritePrice)
|
||||
require.Nil(t, r.CacheReadPrice)
|
||||
require.Nil(t, r.ImageOutputPrice)
|
||||
require.Nil(t, r.PerRequestPrice)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 3. validatePricingBillingMode
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestValidatePricingBillingMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
pricing []service.ChannelModelPricing
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "token mode - valid",
|
||||
pricing: []service.ChannelModelPricing{
|
||||
{BillingMode: service.BillingModeToken},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "per_request with price - valid",
|
||||
pricing: []service.ChannelModelPricing{
|
||||
{
|
||||
BillingMode: service.BillingModePerRequest,
|
||||
PerRequestPrice: float64Ptr(0.5),
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "per_request with intervals - valid",
|
||||
pricing: []service.ChannelModelPricing{
|
||||
{
|
||||
BillingMode: service.BillingModePerRequest,
|
||||
Intervals: []service.PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: intPtr(1000), PerRequestPrice: float64Ptr(0.1)},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "per_request no price no intervals - invalid",
|
||||
pricing: []service.ChannelModelPricing{
|
||||
{BillingMode: service.BillingModePerRequest},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "image with price - valid",
|
||||
pricing: []service.ChannelModelPricing{
|
||||
{
|
||||
BillingMode: service.BillingModeImage,
|
||||
PerRequestPrice: float64Ptr(0.2),
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "image no price no intervals - invalid",
|
||||
pricing: []service.ChannelModelPricing{
|
||||
{BillingMode: service.BillingModeImage},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty list - valid",
|
||||
pricing: []service.ChannelModelPricing{},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "mixed modes with invalid image - invalid",
|
||||
pricing: []service.ChannelModelPricing{
|
||||
{
|
||||
BillingMode: service.BillingModeToken,
|
||||
InputPrice: float64Ptr(0.01),
|
||||
},
|
||||
{
|
||||
BillingMode: service.BillingModePerRequest,
|
||||
PerRequestPrice: float64Ptr(0.5),
|
||||
},
|
||||
{
|
||||
BillingMode: service.BillingModeImage,
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validatePricingBillingMode(tt.pricing)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "Per-request price or intervals required")
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -636,6 +636,40 @@ func (h *DashboardHandler) GetUserBreakdown(c *gin.Context) {
|
||||
dim.Endpoint = c.Query("endpoint")
|
||||
dim.EndpointType = c.DefaultQuery("endpoint_type", "inbound")
|
||||
|
||||
// Additional filter conditions
|
||||
if v := c.Query("user_id"); v != "" {
|
||||
if id, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
dim.UserID = id
|
||||
}
|
||||
}
|
||||
if v := c.Query("api_key_id"); v != "" {
|
||||
if id, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
dim.APIKeyID = id
|
||||
}
|
||||
}
|
||||
if v := c.Query("account_id"); v != "" {
|
||||
if id, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
dim.AccountID = id
|
||||
}
|
||||
}
|
||||
if v := c.Query("request_type"); v != "" {
|
||||
if rt, err := strconv.ParseInt(v, 10, 16); err == nil {
|
||||
rtVal := int16(rt)
|
||||
dim.RequestType = &rtVal
|
||||
}
|
||||
}
|
||||
if v := c.Query("stream"); v != "" {
|
||||
if s, err := strconv.ParseBool(v); err == nil {
|
||||
dim.Stream = &s
|
||||
}
|
||||
}
|
||||
if v := c.Query("billing_type"); v != "" {
|
||||
if bt, err := strconv.ParseInt(v, 10, 8); err == nil {
|
||||
btVal := int8(bt)
|
||||
dim.BillingType = &btVal
|
||||
}
|
||||
}
|
||||
|
||||
limit := 50
|
||||
if v := c.Query("limit"); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= 200 {
|
||||
|
||||
@@ -485,10 +485,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelID: channelMapping.ChannelID,
|
||||
OriginalModel: reqModel,
|
||||
BillingModelSource: channelMapping.BillingModelSource,
|
||||
ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel),
|
||||
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.messages"),
|
||||
@@ -828,10 +825,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelID: channelMapping.ChannelID,
|
||||
OriginalModel: reqModel,
|
||||
BillingModelSource: channelMapping.BillingModelSource,
|
||||
ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel),
|
||||
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.messages"),
|
||||
|
||||
@@ -266,10 +266,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelID: channelMapping.ChannelID,
|
||||
OriginalModel: reqModel,
|
||||
BillingModelSource: channelMapping.BillingModelSource,
|
||||
ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel),
|
||||
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
reqLog.Error("gateway.cc.record_usage_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
|
||||
@@ -272,10 +272,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelID: channelMapping.ChannelID,
|
||||
OriginalModel: reqModel,
|
||||
BillingModelSource: channelMapping.BillingModelSource,
|
||||
ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel),
|
||||
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
reqLog.Error("gateway.responses.record_usage_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
|
||||
@@ -534,10 +534,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelID: channelMapping.ChannelID,
|
||||
OriginalModel: reqModel,
|
||||
BillingModelSource: channelMapping.BillingModelSource,
|
||||
ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel),
|
||||
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gemini_v1beta.models"),
|
||||
|
||||
@@ -278,10 +278,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelID: channelMapping.ChannelID,
|
||||
OriginalModel: reqModel,
|
||||
BillingModelSource: channelMapping.BillingModelSource,
|
||||
ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel),
|
||||
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.chat_completions"),
|
||||
|
||||
@@ -391,10 +391,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelID: channelMapping.ChannelID,
|
||||
OriginalModel: reqModel,
|
||||
BillingModelSource: channelMapping.BillingModelSource,
|
||||
ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel),
|
||||
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.responses"),
|
||||
@@ -787,10 +784,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelID: channelMappingMsg.ChannelID,
|
||||
OriginalModel: reqModel,
|
||||
BillingModelSource: channelMappingMsg.BillingModelSource,
|
||||
ModelMappingChain: channelMappingMsg.BuildModelMappingChain(reqModel, result.UpstreamModel),
|
||||
ChannelUsageFields: channelMappingMsg.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.messages"),
|
||||
@@ -1298,10 +1292,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelID: channelMappingWS.ChannelID,
|
||||
OriginalModel: reqModel,
|
||||
BillingModelSource: channelMappingWS.BillingModelSource,
|
||||
ModelMappingChain: channelMappingWS.BuildModelMappingChain(reqModel, result.UpstreamModel),
|
||||
ChannelUsageFields: channelMappingWS.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
reqLog.Error("openai.websocket_record_usage_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
|
||||
@@ -125,6 +125,13 @@ func (r *stubSoraGenRepo) CountByUserAndStatus(_ context.Context, _ int64, _ []s
|
||||
return r.countValue, nil
|
||||
}
|
||||
|
||||
func (r *stubSoraGenRepo) CountByStorageType(_ context.Context, _ string, _ []string) (int64, error) {
|
||||
if r.countErr != nil {
|
||||
return 0, r.countErr
|
||||
}
|
||||
return r.countValue, nil
|
||||
}
|
||||
|
||||
// ==================== 辅助函数 ====================
|
||||
|
||||
func newTestSoraClientHandler(repo *stubSoraGenRepo) *SoraClientHandler {
|
||||
@@ -1657,8 +1664,8 @@ func TestStoreMediaWithDegradation_S3SuccessSingleURL(t *testing.T) {
|
||||
fakeS3 := newFakeS3Server("ok")
|
||||
defer fakeS3.Close()
|
||||
|
||||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||
h := &SoraClientHandler{s3Storage: s3Storage}
|
||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
||||
h := &SoraClientHandler{objectStorage: objectStorage}
|
||||
|
||||
storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(
|
||||
context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
|
||||
@@ -1679,8 +1686,8 @@ func TestStoreMediaWithDegradation_S3SuccessMultiURL(t *testing.T) {
|
||||
fakeS3 := newFakeS3Server("ok")
|
||||
defer fakeS3.Close()
|
||||
|
||||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||
h := &SoraClientHandler{s3Storage: s3Storage}
|
||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
||||
h := &SoraClientHandler{objectStorage: objectStorage}
|
||||
|
||||
urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"}
|
||||
storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(
|
||||
@@ -1704,8 +1711,8 @@ func TestStoreMediaWithDegradation_S3DownloadFails(t *testing.T) {
|
||||
}))
|
||||
defer badSource.Close()
|
||||
|
||||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||
h := &SoraClientHandler{s3Storage: s3Storage}
|
||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
||||
h := &SoraClientHandler{objectStorage: objectStorage}
|
||||
|
||||
_, _, storageType, _, _ := h.storeMediaWithDegradation(
|
||||
context.Background(), 1, "video", badSource.URL+"/missing.mp4", nil,
|
||||
@@ -1719,8 +1726,8 @@ func TestStoreMediaWithDegradation_S3FailsSingleURL(t *testing.T) {
|
||||
fakeS3 := newFakeS3Server("fail")
|
||||
defer fakeS3.Close()
|
||||
|
||||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||
h := &SoraClientHandler{s3Storage: s3Storage}
|
||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
||||
h := &SoraClientHandler{objectStorage: objectStorage}
|
||||
|
||||
_, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
|
||||
context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
|
||||
@@ -1736,8 +1743,8 @@ func TestStoreMediaWithDegradation_S3PartialFailureCleanup(t *testing.T) {
|
||||
fakeS3 := newFakeS3Server("fail-second")
|
||||
defer fakeS3.Close()
|
||||
|
||||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||
h := &SoraClientHandler{s3Storage: s3Storage}
|
||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
||||
h := &SoraClientHandler{objectStorage: objectStorage}
|
||||
|
||||
urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"}
|
||||
_, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
|
||||
@@ -1808,7 +1815,7 @@ func TestStoreMediaWithDegradation_S3FailsFallbackToLocal(t *testing.T) {
|
||||
fakeS3 := newFakeS3Server("fail")
|
||||
defer fakeS3.Close()
|
||||
|
||||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Storage: config.SoraStorageConfig{
|
||||
@@ -1821,8 +1828,8 @@ func TestStoreMediaWithDegradation_S3FailsFallbackToLocal(t *testing.T) {
|
||||
}
|
||||
mediaStorage := service.NewSoraMediaStorage(cfg)
|
||||
h := &SoraClientHandler{
|
||||
s3Storage: s3Storage,
|
||||
mediaStorage: mediaStorage,
|
||||
objectStorage: objectStorage,
|
||||
mediaStorage: mediaStorage,
|
||||
}
|
||||
|
||||
_, _, storageType, _, _ := h.storeMediaWithDegradation(
|
||||
@@ -1846,9 +1853,9 @@ func TestSaveToStorage_S3EnabledButUploadFails(t *testing.T) {
|
||||
StorageType: "upstream",
|
||||
MediaURL: sourceServer.URL + "/v.mp4",
|
||||
}
|
||||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
||||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||||
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
|
||||
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage}
|
||||
|
||||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||||
@@ -1872,9 +1879,9 @@ func TestSaveToStorage_UpstreamURLExpired(t *testing.T) {
|
||||
StorageType: "upstream",
|
||||
MediaURL: expiredServer.URL + "/v.mp4",
|
||||
}
|
||||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
||||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||||
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
|
||||
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage}
|
||||
|
||||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||||
@@ -1896,9 +1903,9 @@ func TestSaveToStorage_S3EnabledUploadSuccess(t *testing.T) {
|
||||
StorageType: "upstream",
|
||||
MediaURL: sourceServer.URL + "/v.mp4",
|
||||
}
|
||||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
||||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||||
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
|
||||
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage}
|
||||
|
||||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||||
@@ -1906,7 +1913,7 @@ func TestSaveToStorage_S3EnabledUploadSuccess(t *testing.T) {
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
resp := parseResponse(t, rec)
|
||||
data := resp["data"].(map[string]any)
|
||||
require.Contains(t, data["message"], "S3")
|
||||
require.Contains(t, data["message"], "云存储")
|
||||
require.NotEmpty(t, data["object_key"])
|
||||
// 验证记录已更新为 S3 存储
|
||||
require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
|
||||
@@ -1928,9 +1935,9 @@ func TestSaveToStorage_S3EnabledUploadSuccess_MultiMediaURLs(t *testing.T) {
|
||||
sourceServer.URL + "/v2.mp4",
|
||||
},
|
||||
}
|
||||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
||||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||||
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
|
||||
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage}
|
||||
|
||||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||||
@@ -1956,7 +1963,7 @@ func TestSaveToStorage_S3EnabledUploadSuccessWithQuota(t *testing.T) {
|
||||
StorageType: "upstream",
|
||||
MediaURL: sourceServer.URL + "/v.mp4",
|
||||
}
|
||||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
||||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
userRepo := newStubUserRepoForHandler()
|
||||
@@ -1966,7 +1973,7 @@ func TestSaveToStorage_S3EnabledUploadSuccessWithQuota(t *testing.T) {
|
||||
SoraStorageUsedBytes: 0,
|
||||
}
|
||||
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
|
||||
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
|
||||
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage, quotaService: quotaService}
|
||||
|
||||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||||
@@ -1990,9 +1997,9 @@ func TestSaveToStorage_S3UploadSuccessMarkCompletedFails(t *testing.T) {
|
||||
}
|
||||
// S3 上传成功后,MarkCompleted 会调用 repo.Update → 失败
|
||||
repo.updateErr = fmt.Errorf("db error")
|
||||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
||||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||||
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
|
||||
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage}
|
||||
|
||||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||||
@@ -2007,8 +2014,8 @@ func TestGetStorageStatus_S3EnabledNotHealthy(t *testing.T) {
|
||||
fakeS3 := newFakeS3Server("fail")
|
||||
defer fakeS3.Close()
|
||||
|
||||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||
h := &SoraClientHandler{s3Storage: s3Storage}
|
||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
||||
h := &SoraClientHandler{objectStorage: objectStorage}
|
||||
|
||||
c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
|
||||
h.GetStorageStatus(c)
|
||||
@@ -2023,8 +2030,8 @@ func TestGetStorageStatus_S3EnabledHealthy(t *testing.T) {
|
||||
fakeS3 := newFakeS3Server("ok")
|
||||
defer fakeS3.Close()
|
||||
|
||||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||
h := &SoraClientHandler{s3Storage: s3Storage}
|
||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
||||
h := &SoraClientHandler{objectStorage: objectStorage}
|
||||
|
||||
c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
|
||||
h.GetStorageStatus(c)
|
||||
@@ -2453,7 +2460,7 @@ func TestProcessGeneration_FullSuccessWithS3(t *testing.T) {
|
||||
},
|
||||
}
|
||||
soraGatewayService := newMinimalSoraGatewayService(soraClient)
|
||||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
||||
|
||||
userRepo := newStubUserRepoForHandler()
|
||||
userRepo.users[1] = &service.User{
|
||||
@@ -2465,7 +2472,7 @@ func TestProcessGeneration_FullSuccessWithS3(t *testing.T) {
|
||||
genService: genService,
|
||||
gatewayService: gatewayService,
|
||||
soraGatewayService: soraGatewayService,
|
||||
s3Storage: s3Storage,
|
||||
objectStorage: objectStorage,
|
||||
quotaService: quotaService,
|
||||
}
|
||||
|
||||
@@ -2515,7 +2522,7 @@ func TestProcessGeneration_MarkCompletedFails(t *testing.T) {
|
||||
// ==================== cleanupStoredMedia 直接测试 ====================
|
||||
|
||||
func TestCleanupStoredMedia_S3Path(t *testing.T) {
|
||||
// S3 清理路径:s3Storage 为 nil 时不 panic
|
||||
// S3 清理路径:objectStorage 为 nil 时不 panic
|
||||
h := &SoraClientHandler{}
|
||||
// 不应 panic
|
||||
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil)
|
||||
@@ -2962,7 +2969,7 @@ func TestSaveToStorage_QuotaExceeded(t *testing.T) {
|
||||
StorageType: "upstream",
|
||||
MediaURL: sourceServer.URL + "/v.mp4",
|
||||
}
|
||||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
||||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
// 用户配额已满
|
||||
@@ -2973,7 +2980,7 @@ func TestSaveToStorage_QuotaExceeded(t *testing.T) {
|
||||
SoraStorageUsedBytes: 10,
|
||||
}
|
||||
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
|
||||
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
|
||||
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage, quotaService: quotaService}
|
||||
|
||||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||||
@@ -2995,13 +3002,13 @@ func TestSaveToStorage_QuotaNonQuotaError(t *testing.T) {
|
||||
StorageType: "upstream",
|
||||
MediaURL: sourceServer.URL + "/v.mp4",
|
||||
}
|
||||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
||||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
// 用户不存在 → GetByID 失败 → AddUsage 返回普通 error
|
||||
userRepo := newStubUserRepoForHandler()
|
||||
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
|
||||
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
|
||||
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage, quotaService: quotaService}
|
||||
|
||||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||||
@@ -3022,9 +3029,9 @@ func TestSaveToStorage_EmptyMediaURLs(t *testing.T) {
|
||||
MediaURL: "",
|
||||
MediaURLs: []string{},
|
||||
}
|
||||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
||||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||||
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
|
||||
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage}
|
||||
|
||||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||||
@@ -3049,9 +3056,9 @@ func TestSaveToStorage_MultiURL_SecondUploadFails(t *testing.T) {
|
||||
MediaURL: sourceServer.URL + "/v1.mp4",
|
||||
MediaURLs: []string{sourceServer.URL + "/v1.mp4", sourceServer.URL + "/v2.mp4"},
|
||||
}
|
||||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
||||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||||
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
|
||||
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage}
|
||||
|
||||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||||
@@ -3074,7 +3081,7 @@ func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) {
|
||||
MediaURL: sourceServer.URL + "/v.mp4",
|
||||
}
|
||||
repo.updateErr = fmt.Errorf("db error")
|
||||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
||||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
userRepo := newStubUserRepoForHandler()
|
||||
@@ -3084,7 +3091,7 @@ func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) {
|
||||
SoraStorageUsedBytes: 0,
|
||||
}
|
||||
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
|
||||
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
|
||||
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage, quotaService: quotaService}
|
||||
|
||||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||||
@@ -3097,8 +3104,8 @@ func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) {
|
||||
func TestCleanupStoredMedia_WithS3Storage_ActualDelete(t *testing.T) {
|
||||
fakeS3 := newFakeS3Server("ok")
|
||||
defer fakeS3.Close()
|
||||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||
h := &SoraClientHandler{s3Storage: s3Storage}
|
||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
||||
h := &SoraClientHandler{objectStorage: objectStorage}
|
||||
|
||||
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1", "key2"}, nil)
|
||||
}
|
||||
@@ -3106,8 +3113,8 @@ func TestCleanupStoredMedia_WithS3Storage_ActualDelete(t *testing.T) {
|
||||
func TestCleanupStoredMedia_S3DeleteFails_LogOnly(t *testing.T) {
|
||||
fakeS3 := newFakeS3Server("fail")
|
||||
defer fakeS3.Close()
|
||||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||||
h := &SoraClientHandler{s3Storage: s3Storage}
|
||||
objectStorage := newS3StorageForHandler(fakeS3.URL)
|
||||
h := &SoraClientHandler{objectStorage: objectStorage}
|
||||
|
||||
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil)
|
||||
}
|
||||
|
||||
@@ -30,6 +30,8 @@ import (
|
||||
)
|
||||
|
||||
// SoraGatewayHandler handles Sora chat completions requests
|
||||
//
|
||||
// NOTE: Sora 平台计划后续移除,不集成渠道(Channel)功能。
|
||||
type SoraGatewayHandler struct {
|
||||
gatewayService *service.GatewayService
|
||||
soraGatewayService *service.SoraGatewayService
|
||||
|
||||
@@ -175,6 +175,13 @@ type UserBreakdownDimension struct {
|
||||
ModelType string // "requested", "upstream", or "mapping"
|
||||
Endpoint string // filter by endpoint value (non-empty to enable)
|
||||
EndpointType string // "inbound", "upstream", or "path"
|
||||
// Additional filter conditions
|
||||
UserID int64 // filter by user_id (>0 to enable)
|
||||
APIKeyID int64 // filter by api_key_id (>0 to enable)
|
||||
AccountID int64 // filter by account_id (>0 to enable)
|
||||
RequestType *int16 // filter by request_type (non-nil to enable)
|
||||
Stream *bool // filter by stream flag (non-nil to enable)
|
||||
BillingType *int8 // filter by billing_type (non-nil to enable)
|
||||
}
|
||||
|
||||
// APIKeyUsageTrendPoint represents API key usage trend data point
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -274,7 +275,8 @@ func replaceModelPricingTx(ctx context.Context, exec dbExec, channelID int64, pr
|
||||
|
||||
// isUniqueViolation 检查 pq 唯一约束违反错误
|
||||
func isUniqueViolation(err error) bool {
|
||||
if pqErr, ok := err.(*pq.Error); ok {
|
||||
var pqErr *pq.Error
|
||||
if errors.As(err, &pqErr) && pqErr != nil {
|
||||
return pqErr.Code == "23505"
|
||||
}
|
||||
return false
|
||||
|
||||
227
backend/internal/repository/channel_repo_test.go
Normal file
227
backend/internal/repository/channel_repo_test.go
Normal file
@@ -0,0 +1,227 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- marshalModelMapping ---
|
||||
|
||||
func TestMarshalModelMapping(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input map[string]map[string]string
|
||||
wantJSON string // expected JSON output (exact match)
|
||||
}{
|
||||
{
|
||||
name: "empty map",
|
||||
input: map[string]map[string]string{},
|
||||
wantJSON: "{}",
|
||||
},
|
||||
{
|
||||
name: "nil map",
|
||||
input: nil,
|
||||
wantJSON: "{}",
|
||||
},
|
||||
{
|
||||
name: "populated map",
|
||||
input: map[string]map[string]string{
|
||||
"openai": {"gpt-4": "gpt-4-turbo"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nested values",
|
||||
input: map[string]map[string]string{
|
||||
"openai": {"*": "gpt-5.4"},
|
||||
"anthropic": {"claude-old": "claude-new"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := marshalModelMapping(tt.input)
|
||||
require.NoError(t, err)
|
||||
|
||||
if tt.wantJSON != "" {
|
||||
require.Equal(t, []byte(tt.wantJSON), result)
|
||||
} else {
|
||||
// round-trip: unmarshal and compare with input
|
||||
var parsed map[string]map[string]string
|
||||
require.NoError(t, json.Unmarshal(result, &parsed))
|
||||
require.Equal(t, tt.input, parsed)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- unmarshalModelMapping ---
|
||||
|
||||
func TestUnmarshalModelMapping(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
wantNil bool
|
||||
want map[string]map[string]string
|
||||
}{
|
||||
{
|
||||
name: "nil data",
|
||||
input: nil,
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "empty data",
|
||||
input: []byte{},
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
input: []byte("not-json"),
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "type error - number",
|
||||
input: []byte("42"),
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "type error - array",
|
||||
input: []byte("[1,2,3]"),
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "valid JSON",
|
||||
input: []byte(`{"openai":{"gpt-4":"gpt-4-turbo"},"anthropic":{"old":"new"}}`),
|
||||
want: map[string]map[string]string{
|
||||
"openai": {"gpt-4": "gpt-4-turbo"},
|
||||
"anthropic": {"old": "new"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty object",
|
||||
input: []byte("{}"),
|
||||
want: map[string]map[string]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := unmarshalModelMapping(tt.input)
|
||||
if tt.wantNil {
|
||||
require.Nil(t, result)
|
||||
} else {
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, tt.want, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- escapeLike ---
|
||||
|
||||
func TestEscapeLike(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "no special chars",
|
||||
input: "hello",
|
||||
want: "hello",
|
||||
},
|
||||
{
|
||||
name: "backslash",
|
||||
input: `a\b`,
|
||||
want: `a\\b`,
|
||||
},
|
||||
{
|
||||
name: "percent",
|
||||
input: "50%",
|
||||
want: `50\%`,
|
||||
},
|
||||
{
|
||||
name: "underscore",
|
||||
input: "a_b",
|
||||
want: `a\_b`,
|
||||
},
|
||||
{
|
||||
name: "all special chars",
|
||||
input: `a\b%c_d`,
|
||||
want: `a\\b\%c\_d`,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "consecutive special chars",
|
||||
input: "%_%",
|
||||
want: `\%\_\%`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, escapeLike(tt.input))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- isUniqueViolation ---
|
||||
|
||||
func TestIsUniqueViolation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "unique violation code 23505",
|
||||
err: &pq.Error{Code: "23505"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "different pq error code",
|
||||
err: &pq.Error{Code: "23503"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "non-pq error",
|
||||
err: errors.New("some generic error"),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "typed nil pq.Error",
|
||||
err: func() error {
|
||||
var pqErr *pq.Error
|
||||
return pqErr
|
||||
}(),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "bare nil",
|
||||
err: nil,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "wrapped pq error with 23505",
|
||||
err: fmt.Errorf("wrapped: %w", &pq.Error{Code: "23505"}),
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, isUniqueViolation(tt.err))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -3144,6 +3144,30 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim
|
||||
query += fmt.Sprintf(" AND %s = $%d", col, len(args)+1)
|
||||
args = append(args, dim.Endpoint)
|
||||
}
|
||||
if dim.UserID > 0 {
|
||||
query += fmt.Sprintf(" AND ul.user_id = $%d", len(args)+1)
|
||||
args = append(args, dim.UserID)
|
||||
}
|
||||
if dim.APIKeyID > 0 {
|
||||
query += fmt.Sprintf(" AND ul.api_key_id = $%d", len(args)+1)
|
||||
args = append(args, dim.APIKeyID)
|
||||
}
|
||||
if dim.AccountID > 0 {
|
||||
query += fmt.Sprintf(" AND ul.account_id = $%d", len(args)+1)
|
||||
args = append(args, dim.AccountID)
|
||||
}
|
||||
if dim.RequestType != nil {
|
||||
query += fmt.Sprintf(" AND ul.request_type = $%d", len(args)+1)
|
||||
args = append(args, *dim.RequestType)
|
||||
}
|
||||
if dim.Stream != nil {
|
||||
query += fmt.Sprintf(" AND ul.stream = $%d", len(args)+1)
|
||||
args = append(args, *dim.Stream)
|
||||
}
|
||||
if dim.BillingType != nil {
|
||||
query += fmt.Sprintf(" AND ul.billing_type = $%d", len(args)+1)
|
||||
args = append(args, *dim.BillingType)
|
||||
}
|
||||
|
||||
query += " GROUP BY ul.user_id, u.email ORDER BY actual_cost DESC"
|
||||
if limit > 0 {
|
||||
|
||||
@@ -80,6 +80,10 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
||||
sqlmock.AnyArg(), // inbound_endpoint
|
||||
sqlmock.AnyArg(), // upstream_endpoint
|
||||
log.CacheTTLOverridden,
|
||||
sqlmock.AnyArg(), // channel_id
|
||||
sqlmock.AnyArg(), // model_mapping_chain
|
||||
sqlmock.AnyArg(), // billing_tier
|
||||
sqlmock.AnyArg(), // billing_mode
|
||||
createdAt,
|
||||
).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt))
|
||||
@@ -153,6 +157,10 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
log.CacheTTLOverridden,
|
||||
sqlmock.AnyArg(), // channel_id
|
||||
sqlmock.AnyArg(), // model_mapping_chain
|
||||
sqlmock.AnyArg(), // billing_tier
|
||||
sqlmock.AnyArg(), // billing_mode
|
||||
createdAt,
|
||||
).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt))
|
||||
@@ -463,6 +471,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
false,
|
||||
sql.NullInt64{}, // channel_id
|
||||
sql.NullString{}, // model_mapping_chain
|
||||
sql.NullString{}, // billing_tier
|
||||
sql.NullString{}, // billing_mode
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
@@ -506,6 +518,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
false,
|
||||
sql.NullInt64{}, // channel_id
|
||||
sql.NullString{}, // model_mapping_chain
|
||||
sql.NullString{}, // billing_tier
|
||||
sql.NullString{}, // billing_mode
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
@@ -549,6 +565,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
false,
|
||||
sql.NullInt64{}, // channel_id
|
||||
sql.NullString{}, // model_mapping_chain
|
||||
sql.NullString{}, // billing_tier
|
||||
sql.NullString{}, // billing_mode
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -51,15 +51,15 @@ type Channel struct {
|
||||
type ChannelModelPricing struct {
|
||||
ID int64
|
||||
ChannelID int64
|
||||
Platform string // 所属平台(anthropic/openai/gemini/...)
|
||||
Models []string // 绑定的模型列表
|
||||
BillingMode BillingMode // 计费模式
|
||||
InputPrice *float64 // 每 token 输入价格(USD)— 向后兼容 flat 定价
|
||||
OutputPrice *float64 // 每 token 输出价格(USD)
|
||||
CacheWritePrice *float64 // 缓存写入价格
|
||||
CacheReadPrice *float64 // 缓存读取价格
|
||||
ImageOutputPrice *float64 // 图片输出价格(向后兼容)
|
||||
PerRequestPrice *float64 // 默认按次计费价格(USD)
|
||||
Platform string // 所属平台(anthropic/openai/gemini/...)
|
||||
Models []string // 绑定的模型列表
|
||||
BillingMode BillingMode // 计费模式
|
||||
InputPrice *float64 // 每 token 输入价格(USD)— 向后兼容 flat 定价
|
||||
OutputPrice *float64 // 每 token 输出价格(USD)
|
||||
CacheWritePrice *float64 // 缓存写入价格
|
||||
CacheReadPrice *float64 // 缓存读取价格
|
||||
ImageOutputPrice *float64 // 图片输出价格(向后兼容)
|
||||
PerRequestPrice *float64 // 默认按次计费价格(USD)
|
||||
Intervals []PricingInterval // 区间定价列表
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
@@ -175,3 +175,11 @@ func (c *Channel) Clone() *Channel {
|
||||
}
|
||||
return &cp
|
||||
}
|
||||
|
||||
// ChannelUsageFields 渠道相关的使用记录字段(嵌入到各平台的 RecordUsageInput 中)
|
||||
type ChannelUsageFields struct {
|
||||
ChannelID int64 // 渠道 ID(0 = 无渠道)
|
||||
OriginalModel string // 用户原始请求模型(渠道映射前)
|
||||
BillingModelSource string // 计费模型来源:"requested" / "upstream"
|
||||
ModelMappingChain string // 映射链描述,如 "a→b→c"
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -17,8 +16,8 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrChannelNotFound = infraerrors.NotFound("CHANNEL_NOT_FOUND", "channel not found")
|
||||
ErrChannelExists = infraerrors.Conflict("CHANNEL_EXISTS", "channel name already exists")
|
||||
ErrChannelNotFound = infraerrors.NotFound("CHANNEL_NOT_FOUND", "channel not found")
|
||||
ErrChannelExists = infraerrors.Conflict("CHANNEL_EXISTS", "channel name already exists")
|
||||
ErrGroupAlreadyInChannel = infraerrors.Conflict(
|
||||
"GROUP_ALREADY_IN_CHANNEL",
|
||||
"one or more groups already belong to another channel",
|
||||
@@ -81,12 +80,12 @@ type wildcardMappingEntry struct {
|
||||
// channelCache 渠道缓存快照(扁平化哈希结构,热路径 O(1) 查找)
|
||||
type channelCache struct {
|
||||
// 热路径查找
|
||||
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价
|
||||
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价
|
||||
wildcardByGroupPlatform map[channelGroupPlatformKey][]*wildcardPricingEntry // (groupID, platform) → 通配符定价(前缀长度降序)
|
||||
mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标
|
||||
wildcardMappingByGP map[channelGroupPlatformKey][]*wildcardMappingEntry // (groupID, platform) → 通配符映射(前缀长度降序)
|
||||
channelByGroupID map[int64]*Channel // groupID → 渠道
|
||||
groupPlatform map[int64]string // groupID → platform
|
||||
mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标
|
||||
wildcardMappingByGP map[channelGroupPlatformKey][]*wildcardMappingEntry // (groupID, platform) → 通配符映射(前缀长度降序)
|
||||
channelByGroupID map[int64]*Channel // groupID → 渠道
|
||||
groupPlatform map[int64]string // groupID → platform
|
||||
|
||||
// 冷路径(CRUD 操作)
|
||||
byID map[int64]*Channel
|
||||
@@ -118,9 +117,19 @@ func (r ChannelMappingResult) BuildModelMappingChain(reqModel, upstreamModel str
|
||||
return reqModel + "→" + r.MappedModel
|
||||
}
|
||||
|
||||
// ToUsageFields 将渠道映射结果转为使用记录字段
|
||||
func (r ChannelMappingResult) ToUsageFields(reqModel, upstreamModel string) ChannelUsageFields {
|
||||
return ChannelUsageFields{
|
||||
ChannelID: r.ChannelID,
|
||||
OriginalModel: reqModel,
|
||||
BillingModelSource: r.BillingModelSource,
|
||||
ModelMappingChain: r.BuildModelMappingChain(reqModel, upstreamModel),
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
channelCacheTTL = 60 * time.Second
|
||||
channelErrorTTL = 5 * time.Second // DB 错误时的短缓存
|
||||
channelCacheTTL = 60 * time.Second
|
||||
channelErrorTTL = 5 * time.Second // DB 错误时的短缓存
|
||||
channelCacheDBTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
@@ -177,14 +186,14 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
||||
// error-TTL:失败时存入短 TTL 空缓存,防止紧密重试
|
||||
slog.Warn("failed to build channel cache", "error", err)
|
||||
errorCache := &channelCache{
|
||||
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
||||
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
||||
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
|
||||
mappingByGroupModel: make(map[channelModelKey]string),
|
||||
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
|
||||
channelByGroupID: make(map[int64]*Channel),
|
||||
groupPlatform: make(map[int64]string),
|
||||
byID: make(map[int64]*Channel),
|
||||
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL
|
||||
mappingByGroupModel: make(map[channelModelKey]string),
|
||||
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
|
||||
channelByGroupID: make(map[int64]*Channel),
|
||||
groupPlatform: make(map[int64]string),
|
||||
byID: make(map[int64]*Channel),
|
||||
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL
|
||||
}
|
||||
s.cache.Store(errorCache)
|
||||
return nil, fmt.Errorf("list all channels: %w", err)
|
||||
@@ -205,14 +214,14 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
||||
}
|
||||
|
||||
cache := &channelCache{
|
||||
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
||||
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
||||
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
|
||||
mappingByGroupModel: make(map[channelModelKey]string),
|
||||
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
|
||||
channelByGroupID: make(map[int64]*Channel),
|
||||
groupPlatform: groupPlatforms,
|
||||
byID: make(map[int64]*Channel, len(channels)),
|
||||
loadedAt: time.Now(),
|
||||
mappingByGroupModel: make(map[channelModelKey]string),
|
||||
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
|
||||
channelByGroupID: make(map[int64]*Channel),
|
||||
groupPlatform: groupPlatforms,
|
||||
byID: make(map[int64]*Channel, len(channels)),
|
||||
loadedAt: time.Now(),
|
||||
}
|
||||
|
||||
for i := range channels {
|
||||
@@ -266,19 +275,7 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
||||
}
|
||||
}
|
||||
|
||||
// 通配符条目按前缀长度降序排列(最长前缀优先匹配)
|
||||
for gpKey, entries := range cache.wildcardByGroupPlatform {
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return len(entries[i].prefix) > len(entries[j].prefix)
|
||||
})
|
||||
cache.wildcardByGroupPlatform[gpKey] = entries
|
||||
}
|
||||
for gpKey, entries := range cache.wildcardMappingByGP {
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return len(entries[i].prefix) > len(entries[j].prefix)
|
||||
})
|
||||
cache.wildcardMappingByGP[gpKey] = entries
|
||||
}
|
||||
// 通配符条目保持配置顺序(最先匹配到优先)
|
||||
|
||||
s.cache.Store(cache)
|
||||
return cache, nil
|
||||
@@ -290,7 +287,7 @@ func (s *ChannelService) invalidateCache() {
|
||||
s.cacheSF.Forget("channel_cache")
|
||||
}
|
||||
|
||||
// matchWildcard 在通配符定价中查找匹配项(最长前缀优先)
|
||||
// matchWildcard 在通配符定价中查找匹配项(最先匹配到优先)
|
||||
func (c *channelCache) matchWildcard(groupID int64, platform, modelLower string) *ChannelModelPricing {
|
||||
gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform}
|
||||
wildcards := c.wildcardByGroupPlatform[gpKey]
|
||||
@@ -302,7 +299,7 @@ func (c *channelCache) matchWildcard(groupID int64, platform, modelLower string)
|
||||
return nil
|
||||
}
|
||||
|
||||
// matchWildcardMapping 在通配符映射中查找匹配项(最长前缀优先)
|
||||
// matchWildcardMapping 在通配符映射中查找匹配项(最先匹配到优先)
|
||||
func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower string) string {
|
||||
gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform}
|
||||
wildcards := c.wildcardMappingByGP[gpKey]
|
||||
@@ -479,15 +476,18 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
|
||||
Status: StatusActive,
|
||||
BillingModelSource: input.BillingModelSource,
|
||||
RestrictModels: input.RestrictModels,
|
||||
GroupIDs: input.GroupIDs,
|
||||
ModelPricing: input.ModelPricing,
|
||||
ModelMapping: input.ModelMapping,
|
||||
GroupIDs: input.GroupIDs,
|
||||
ModelPricing: input.ModelPricing,
|
||||
ModelMapping: input.ModelMapping,
|
||||
}
|
||||
if channel.BillingModelSource == "" {
|
||||
channel.BillingModelSource = BillingModelSourceRequested
|
||||
}
|
||||
|
||||
if err := validateNoDuplicateModels(channel.ModelPricing); err != nil {
|
||||
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -558,7 +558,10 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
|
||||
channel.BillingModelSource = input.BillingModelSource
|
||||
}
|
||||
|
||||
if err := validateNoDuplicateModels(channel.ModelPricing); err != nil {
|
||||
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -610,16 +613,79 @@ func (s *ChannelService) List(ctx context.Context, params pagination.PaginationP
|
||||
return s.repo.List(ctx, params, status, search)
|
||||
}
|
||||
|
||||
// validateNoDuplicateModels 检查定价列表中是否有重复模型(同一平台下不允许重复)
|
||||
func validateNoDuplicateModels(pricingList []ChannelModelPricing) error {
|
||||
seen := make(map[string]bool)
|
||||
// modelEntry 表示一个模型模式条目(用于冲突检测)
|
||||
type modelEntry struct {
|
||||
pattern string // 原始模式(如 "claude-*" 或 "claude-opus-4")
|
||||
prefix string // lowercase 前缀(通配符去掉 *,精确名保持原样)
|
||||
wildcard bool
|
||||
}
|
||||
|
||||
// conflictsBetween 检查两个模型模式是否冲突
|
||||
func conflictsBetween(a, b modelEntry) bool {
|
||||
switch {
|
||||
case !a.wildcard && !b.wildcard:
|
||||
return a.prefix == b.prefix
|
||||
case a.wildcard && !b.wildcard:
|
||||
return strings.HasPrefix(b.prefix, a.prefix)
|
||||
case !a.wildcard && b.wildcard:
|
||||
return strings.HasPrefix(a.prefix, b.prefix)
|
||||
default:
|
||||
return strings.HasPrefix(a.prefix, b.prefix) ||
|
||||
strings.HasPrefix(b.prefix, a.prefix)
|
||||
}
|
||||
}
|
||||
|
||||
// toModelEntry 将模型名转换为 modelEntry
|
||||
func toModelEntry(pattern string) modelEntry {
|
||||
lower := strings.ToLower(pattern)
|
||||
isWild := strings.HasSuffix(lower, "*")
|
||||
prefix := lower
|
||||
if isWild {
|
||||
prefix = strings.TrimSuffix(lower, "*")
|
||||
}
|
||||
return modelEntry{pattern: pattern, prefix: prefix, wildcard: isWild}
|
||||
}
|
||||
|
||||
// validateNoConflictingModels 检查定价列表中是否有冲突模型模式(同一平台下)。
|
||||
// 冲突包括:精确重复、通配符之间的前缀包含、通配符与精确名的前缀匹配。
|
||||
func validateNoConflictingModels(pricingList []ChannelModelPricing) error {
|
||||
byPlatform := make(map[string][]modelEntry)
|
||||
for _, p := range pricingList {
|
||||
for _, model := range p.Models {
|
||||
key := p.Platform + ":" + strings.ToLower(model)
|
||||
if seen[key] {
|
||||
return infraerrors.BadRequest("DUPLICATE_MODEL", fmt.Sprintf("model '%s' appears in multiple pricing entries for platform '%s'", model, p.Platform))
|
||||
byPlatform[p.Platform] = append(byPlatform[p.Platform], toModelEntry(model))
|
||||
}
|
||||
}
|
||||
for platform, entries := range byPlatform {
|
||||
if err := detectConflicts(entries, platform, "MODEL_PATTERN_CONFLICT", "model patterns"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateNoConflictingMappings 检查模型映射中是否有冲突的源模式
|
||||
func validateNoConflictingMappings(mapping map[string]map[string]string) error {
|
||||
for platform, platformMapping := range mapping {
|
||||
entries := make([]modelEntry, 0, len(platformMapping))
|
||||
for src := range platformMapping {
|
||||
entries = append(entries, toModelEntry(src))
|
||||
}
|
||||
if err := detectConflicts(entries, platform, "MAPPING_PATTERN_CONFLICT", "mapping source patterns"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// detectConflicts 在一组 modelEntry 中检测冲突,返回带有 errCode 和 label 的错误
|
||||
func detectConflicts(entries []modelEntry, platform, errCode, label string) error {
|
||||
for i := 0; i < len(entries); i++ {
|
||||
for j := i + 1; j < len(entries); j++ {
|
||||
if conflictsBetween(entries[i], entries[j]) {
|
||||
return infraerrors.BadRequest(errCode,
|
||||
fmt.Sprintf("%s '%s' and '%s' conflict in platform '%s': overlapping match range",
|
||||
label, entries[i].pattern, entries[j].pattern, platform))
|
||||
}
|
||||
seen[key] = true
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
1890
backend/internal/service/channel_service_test.go
Normal file
1890
backend/internal/service/channel_service_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -8,13 +8,10 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func channelTestPtrFloat64(v float64) *float64 { return &v }
|
||||
func channelTestPtrInt(v int) *int { return &v }
|
||||
|
||||
func TestGetModelPricing(t *testing.T) {
|
||||
ch := &Channel{
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{ID: 1, Models: []string{"claude-sonnet-4"}, BillingMode: BillingModeToken, InputPrice: channelTestPtrFloat64(3e-6)},
|
||||
{ID: 1, Models: []string{"claude-sonnet-4"}, BillingMode: BillingModeToken, InputPrice: testPtrFloat64(3e-6)},
|
||||
{ID: 3, Models: []string{"gpt-5.1"}, BillingMode: BillingModePerRequest},
|
||||
},
|
||||
}
|
||||
@@ -48,7 +45,7 @@ func TestGetModelPricing(t *testing.T) {
|
||||
func TestGetModelPricing_ReturnsCopy(t *testing.T) {
|
||||
ch := &Channel{
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{ID: 1, Models: []string{"claude-sonnet-4"}, InputPrice: channelTestPtrFloat64(3e-6)},
|
||||
{ID: 1, Models: []string{"claude-sonnet-4"}, InputPrice: testPtrFloat64(3e-6)},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -73,23 +70,23 @@ func TestGetModelPricing_EmptyPricing(t *testing.T) {
|
||||
func TestGetIntervalForContext(t *testing.T) {
|
||||
p := &ChannelModelPricing{
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: channelTestPtrInt(128000), InputPrice: channelTestPtrFloat64(1e-6)},
|
||||
{MinTokens: 128000, MaxTokens: nil, InputPrice: channelTestPtrFloat64(2e-6)},
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
|
||||
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokens int
|
||||
wantPrice *float64
|
||||
wantNil bool
|
||||
name string
|
||||
tokens int
|
||||
wantPrice *float64
|
||||
wantNil bool
|
||||
}{
|
||||
{"first interval", 50000, channelTestPtrFloat64(1e-6), false},
|
||||
{"first interval", 50000, testPtrFloat64(1e-6), false},
|
||||
// (min, max] — 128000 在第一个区间的 max,包含,所以匹配第一个
|
||||
{"boundary: max of first (inclusive)", 128000, channelTestPtrFloat64(1e-6), false},
|
||||
{"boundary: max of first (inclusive)", 128000, testPtrFloat64(1e-6), false},
|
||||
// 128001 > 128000,匹配第二个区间
|
||||
{"boundary: just above first max", 128001, channelTestPtrFloat64(2e-6), false},
|
||||
{"unbounded interval", 500000, channelTestPtrFloat64(2e-6), false},
|
||||
{"boundary: just above first max", 128001, testPtrFloat64(2e-6), false},
|
||||
{"unbounded interval", 500000, testPtrFloat64(2e-6), false},
|
||||
// (0, max] — 0 不匹配任何区间(左开)
|
||||
{"zero tokens: no match", 0, nil, true},
|
||||
}
|
||||
@@ -110,11 +107,11 @@ func TestGetIntervalForContext(t *testing.T) {
|
||||
func TestGetIntervalForContext_NoMatch(t *testing.T) {
|
||||
p := &ChannelModelPricing{
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 10000, MaxTokens: channelTestPtrInt(50000)},
|
||||
{MinTokens: 10000, MaxTokens: testPtrInt(50000)},
|
||||
},
|
||||
}
|
||||
require.Nil(t, p.GetIntervalForContext(5000)) // 5000 <= 10000, not > min
|
||||
require.Nil(t, p.GetIntervalForContext(10000)) // 10000 not > 10000 (left-open)
|
||||
require.Nil(t, p.GetIntervalForContext(5000)) // 5000 <= 10000, not > min
|
||||
require.Nil(t, p.GetIntervalForContext(10000)) // 10000 not > 10000 (left-open)
|
||||
require.NotNil(t, p.GetIntervalForContext(50000)) // 50000 <= 50000 (right-closed)
|
||||
require.Nil(t, p.GetIntervalForContext(50001)) // 50001 > 50000
|
||||
}
|
||||
@@ -127,9 +124,9 @@ func TestGetIntervalForContext_Empty(t *testing.T) {
|
||||
func TestGetTierByLabel(t *testing.T) {
|
||||
p := &ChannelModelPricing{
|
||||
Intervals: []PricingInterval{
|
||||
{TierLabel: "1K", PerRequestPrice: channelTestPtrFloat64(0.04)},
|
||||
{TierLabel: "2K", PerRequestPrice: channelTestPtrFloat64(0.08)},
|
||||
{TierLabel: "HD", PerRequestPrice: channelTestPtrFloat64(0.12)},
|
||||
{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
|
||||
{TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)},
|
||||
{TierLabel: "HD", PerRequestPrice: testPtrFloat64(0.12)},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -171,7 +168,7 @@ func TestChannelClone(t *testing.T) {
|
||||
{
|
||||
ID: 100,
|
||||
Models: []string{"model-a"},
|
||||
InputPrice: channelTestPtrFloat64(5e-6),
|
||||
InputPrice: testPtrFloat64(5e-6),
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -211,3 +208,102 @@ func TestChannelModelPricingClone(t *testing.T) {
|
||||
cloned.Intervals[0].TierLabel = "hacked"
|
||||
require.Equal(t, "tier1", original.Intervals[0].TierLabel)
|
||||
}
|
||||
|
||||
// --- BillingMode.IsValid ---
|
||||
|
||||
func TestBillingModeIsValid(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mode BillingMode
|
||||
want bool
|
||||
}{
|
||||
{"token", BillingModeToken, true},
|
||||
{"per_request", BillingModePerRequest, true},
|
||||
{"image", BillingModeImage, true},
|
||||
{"empty", BillingMode(""), true},
|
||||
{"unknown", BillingMode("unknown"), false},
|
||||
{"random", BillingMode("xyz"), false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, tt.mode.IsValid())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- Channel.IsActive ---
|
||||
|
||||
func TestChannelIsActive(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status string
|
||||
want bool
|
||||
}{
|
||||
{"active", StatusActive, true},
|
||||
{"disabled", "disabled", false},
|
||||
{"empty", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ch := &Channel{Status: tt.status}
|
||||
require.Equal(t, tt.want, ch.IsActive())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- ChannelModelPricing.Clone edge cases ---
|
||||
|
||||
func TestChannelModelPricingClone_EdgeCases(t *testing.T) {
|
||||
t.Run("nil models", func(t *testing.T) {
|
||||
original := ChannelModelPricing{Models: nil}
|
||||
cloned := original.Clone()
|
||||
require.Nil(t, cloned.Models)
|
||||
})
|
||||
|
||||
t.Run("nil intervals", func(t *testing.T) {
|
||||
original := ChannelModelPricing{Intervals: nil}
|
||||
cloned := original.Clone()
|
||||
require.Nil(t, cloned.Intervals)
|
||||
})
|
||||
|
||||
t.Run("empty models", func(t *testing.T) {
|
||||
original := ChannelModelPricing{Models: []string{}}
|
||||
cloned := original.Clone()
|
||||
require.NotNil(t, cloned.Models)
|
||||
require.Empty(t, cloned.Models)
|
||||
})
|
||||
}
|
||||
|
||||
// --- Channel.Clone edge cases ---
|
||||
|
||||
func TestChannelClone_EdgeCases(t *testing.T) {
|
||||
t.Run("nil model mapping", func(t *testing.T) {
|
||||
original := &Channel{ID: 1, ModelMapping: nil}
|
||||
cloned := original.Clone()
|
||||
require.Nil(t, cloned.ModelMapping)
|
||||
})
|
||||
|
||||
t.Run("nil model pricing", func(t *testing.T) {
|
||||
original := &Channel{ID: 1, ModelPricing: nil}
|
||||
cloned := original.Clone()
|
||||
require.Nil(t, cloned.ModelPricing)
|
||||
})
|
||||
|
||||
t.Run("deep copy model mapping", func(t *testing.T) {
|
||||
original := &Channel{
|
||||
ID: 1,
|
||||
ModelMapping: map[string]map[string]string{
|
||||
"openai": {"gpt-4": "gpt-4-turbo"},
|
||||
},
|
||||
}
|
||||
cloned := original.Clone()
|
||||
|
||||
// Modify the cloned nested map
|
||||
cloned.ModelMapping["openai"]["gpt-4"] = "hacked"
|
||||
|
||||
// Original must remain unchanged
|
||||
require.Equal(t, "gpt-4-turbo", original.ModelMapping["openai"]["gpt-4"])
|
||||
})
|
||||
}
|
||||
|
||||
@@ -7407,11 +7407,7 @@ type RecordUsageInput struct {
|
||||
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
|
||||
APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额
|
||||
|
||||
// 渠道映射信息(由 handler 在 Forward 前解析)
|
||||
ChannelID int64 // 渠道 ID(0 = 无渠道)
|
||||
OriginalModel string // 用户原始请求模型(渠道映射前)
|
||||
BillingModelSource string // 计费模型来源:"requested" / "upstream"
|
||||
ModelMappingChain string // 映射链描述,如 "a→b→c"
|
||||
ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析)
|
||||
}
|
||||
|
||||
// APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage
|
||||
@@ -7940,11 +7936,7 @@ type RecordUsageLongContextInput struct {
|
||||
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
|
||||
APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选)
|
||||
|
||||
// 渠道映射信息(由 handler 在 Forward 前解析)
|
||||
ChannelID int64 // 渠道 ID(0 = 无渠道)
|
||||
OriginalModel string // 用户原始请求模型(渠道映射前)
|
||||
BillingModelSource string // 计费模型来源:"requested" / "upstream"
|
||||
ModelMappingChain string // 映射链描述,如 "a→b→c"
|
||||
ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析)
|
||||
}
|
||||
|
||||
// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini)
|
||||
|
||||
@@ -4,14 +4,12 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func resolverPtrFloat64(v float64) *float64 { return &v }
|
||||
func resolverPtrInt(v int) *int { return &v }
|
||||
|
||||
func newTestBillingServiceForResolver() *BillingService {
|
||||
bs := &BillingService{
|
||||
fallbackPrices: make(map[string]*ModelPricing),
|
||||
@@ -83,8 +81,8 @@ func TestGetIntervalPricing_MatchesInterval(t *testing.T) {
|
||||
BasePricing: &ModelPricing{InputPricePerToken: 5e-6},
|
||||
SupportsCacheBreakdown: true,
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: resolverPtrInt(128000), InputPrice: resolverPtrFloat64(1e-6), OutputPrice: resolverPtrFloat64(2e-6)},
|
||||
{MinTokens: 128000, MaxTokens: nil, InputPrice: resolverPtrFloat64(3e-6), OutputPrice: resolverPtrFloat64(6e-6)},
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6), OutputPrice: testPtrFloat64(2e-6)},
|
||||
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(3e-6), OutputPrice: testPtrFloat64(6e-6)},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -108,7 +106,7 @@ func TestGetIntervalPricing_NoMatch_FallsBackToBase(t *testing.T) {
|
||||
Mode: BillingModeToken,
|
||||
BasePricing: basePricing,
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 10000, MaxTokens: resolverPtrInt(50000), InputPrice: resolverPtrFloat64(1e-6)},
|
||||
{MinTokens: 10000, MaxTokens: testPtrInt(50000), InputPrice: testPtrFloat64(1e-6)},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -123,8 +121,8 @@ func TestGetRequestTierPrice(t *testing.T) {
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModePerRequest,
|
||||
RequestTiers: []PricingInterval{
|
||||
{TierLabel: "1K", PerRequestPrice: resolverPtrFloat64(0.04)},
|
||||
{TierLabel: "2K", PerRequestPrice: resolverPtrFloat64(0.08)},
|
||||
{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
|
||||
{TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -140,8 +138,8 @@ func TestGetRequestTierPriceByContext(t *testing.T) {
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModePerRequest,
|
||||
RequestTiers: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: resolverPtrInt(128000), PerRequestPrice: resolverPtrFloat64(0.05)},
|
||||
{MinTokens: 128000, MaxTokens: nil, PerRequestPrice: resolverPtrFloat64(0.10)},
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.05)},
|
||||
{MinTokens: 128000, MaxTokens: nil, PerRequestPrice: testPtrFloat64(0.10)},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -162,3 +160,428 @@ func TestGetRequestTierPrice_NilPerRequestPrice(t *testing.T) {
|
||||
|
||||
require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "1K"), 1e-12)
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// Channel override tests — exercises applyChannelOverrides via Resolve
|
||||
// ===========================================================================
|
||||
|
||||
// helper: creates a resolver wired to a ChannelService that returns the given
|
||||
// channel (active, groupID=100, platform=anthropic) with the specified pricing.
|
||||
func newResolverWithChannel(t *testing.T, pricing []ChannelModelPricing) *ModelPricingResolver {
|
||||
t.Helper()
|
||||
const groupID = 100
|
||||
repo := &mockChannelRepository{
|
||||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||||
return []Channel{{
|
||||
ID: 1,
|
||||
Name: "test-channel",
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{groupID},
|
||||
ModelPricing: pricing,
|
||||
}}, nil
|
||||
},
|
||||
getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) {
|
||||
return map[int64]string{groupID: "anthropic"}, nil
|
||||
},
|
||||
}
|
||||
cs := NewChannelService(repo, nil)
|
||||
bs := newTestBillingServiceForResolver()
|
||||
return NewModelPricingResolver(cs, bs)
|
||||
}
|
||||
|
||||
// groupIDPtr returns a pointer to groupID 100 (the test constant).
|
||||
func groupIDPtr() *int64 { v := int64(100); return &v }
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 1. Token mode overrides
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestResolve_WithChannelOverride_TokenFlat(t *testing.T) {
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModeToken,
|
||||
InputPrice: testPtrFloat64(10e-6),
|
||||
OutputPrice: testPtrFloat64(50e-6),
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
require.Equal(t, BillingModeToken, resolved.Mode)
|
||||
require.Equal(t, "channel", resolved.Source)
|
||||
require.NotNil(t, resolved.BasePricing)
|
||||
require.InDelta(t, 10e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 10e-6, resolved.BasePricing.InputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 50e-6, resolved.BasePricing.OutputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 50e-6, resolved.BasePricing.OutputPricePerTokenPriority, 1e-12)
|
||||
}
|
||||
|
||||
func TestResolve_WithChannelOverride_TokenPartialOverride(t *testing.T) {
|
||||
// Channel only sets InputPrice; OutputPrice should remain from the base (LiteLLM/fallback).
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModeToken,
|
||||
InputPrice: testPtrFloat64(20e-6),
|
||||
// OutputPrice intentionally nil
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
require.Equal(t, "channel", resolved.Source)
|
||||
require.NotNil(t, resolved.BasePricing)
|
||||
// InputPrice overridden by channel
|
||||
require.InDelta(t, 20e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
|
||||
// OutputPrice kept from base (fallback: 15e-6)
|
||||
require.InDelta(t, 15e-6, resolved.BasePricing.OutputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
func TestResolve_WithChannelOverride_TokenWithIntervals(t *testing.T) {
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModeToken,
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(2e-6), OutputPrice: testPtrFloat64(8e-6)},
|
||||
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(4e-6), OutputPrice: testPtrFloat64(16e-6)},
|
||||
},
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
require.Equal(t, "channel", resolved.Source)
|
||||
require.Len(t, resolved.Intervals, 2)
|
||||
|
||||
// GetIntervalPricing should use channel intervals
|
||||
iv := r.GetIntervalPricing(resolved, 50000)
|
||||
require.NotNil(t, iv)
|
||||
require.InDelta(t, 2e-6, iv.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 8e-6, iv.OutputPricePerToken, 1e-12)
|
||||
|
||||
iv2 := r.GetIntervalPricing(resolved, 200000)
|
||||
require.NotNil(t, iv2)
|
||||
require.InDelta(t, 4e-6, iv2.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 16e-6, iv2.OutputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
func TestResolve_WithChannelOverride_TokenNilBasePricing(t *testing.T) {
|
||||
// Base pricing is nil (unknown model), channel has flat prices → creates new BasePricing.
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"unknown-model-xyz"},
|
||||
BillingMode: BillingModeToken,
|
||||
InputPrice: testPtrFloat64(7e-6),
|
||||
OutputPrice: testPtrFloat64(21e-6),
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "unknown-model-xyz",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
require.Equal(t, "channel", resolved.Source)
|
||||
// BasePricing was nil from resolveBasePricing but applyTokenOverrides creates a new one
|
||||
require.NotNil(t, resolved.BasePricing)
|
||||
require.InDelta(t, 7e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 21e-6, resolved.BasePricing.OutputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 2. Per-request mode overrides
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestResolve_WithChannelOverride_PerRequest(t *testing.T) {
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModePerRequest,
|
||||
PerRequestPrice: testPtrFloat64(0.05),
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.03)},
|
||||
{MinTokens: 128000, MaxTokens: nil, PerRequestPrice: testPtrFloat64(0.10)},
|
||||
},
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
require.Equal(t, BillingModePerRequest, resolved.Mode)
|
||||
require.Equal(t, "channel", resolved.Source)
|
||||
require.InDelta(t, 0.05, resolved.DefaultPerRequestPrice, 1e-12)
|
||||
require.Len(t, resolved.RequestTiers, 2)
|
||||
|
||||
// Verify tier lookups
|
||||
require.InDelta(t, 0.03, r.GetRequestTierPriceByContext(resolved, 50000), 1e-12)
|
||||
require.InDelta(t, 0.10, r.GetRequestTierPriceByContext(resolved, 200000), 1e-12)
|
||||
}
|
||||
|
||||
func TestResolve_WithChannelOverride_PerRequestNilPrice(t *testing.T) {
|
||||
// PerRequestPrice nil → DefaultPerRequestPrice stays 0.
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModePerRequest,
|
||||
// PerRequestPrice intentionally nil
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.02)},
|
||||
},
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
require.Equal(t, BillingModePerRequest, resolved.Mode)
|
||||
require.InDelta(t, 0.0, resolved.DefaultPerRequestPrice, 1e-12)
|
||||
require.Len(t, resolved.RequestTiers, 1)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 3. Image mode overrides
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestResolve_WithChannelOverride_Image(t *testing.T) {
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModeImage,
|
||||
PerRequestPrice: testPtrFloat64(0.08),
|
||||
Intervals: []PricingInterval{
|
||||
{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
|
||||
{TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)},
|
||||
{TierLabel: "4K", PerRequestPrice: testPtrFloat64(0.16)},
|
||||
},
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
require.Equal(t, BillingModeImage, resolved.Mode)
|
||||
require.Equal(t, "channel", resolved.Source)
|
||||
require.InDelta(t, 0.08, resolved.DefaultPerRequestPrice, 1e-12)
|
||||
require.Len(t, resolved.RequestTiers, 3)
|
||||
}
|
||||
|
||||
func TestResolve_WithChannelOverride_ImageTierLabels(t *testing.T) {
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModeImage,
|
||||
Intervals: []PricingInterval{
|
||||
{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
|
||||
{TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)},
|
||||
{TierLabel: "4K", PerRequestPrice: testPtrFloat64(0.16)},
|
||||
},
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.InDelta(t, 0.04, r.GetRequestTierPrice(resolved, "1K"), 1e-12)
|
||||
require.InDelta(t, 0.08, r.GetRequestTierPrice(resolved, "2K"), 1e-12)
|
||||
require.InDelta(t, 0.16, r.GetRequestTierPrice(resolved, "4K"), 1e-12)
|
||||
require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "8K"), 1e-12) // not found
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 4. Source tracking & default mode
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestResolve_WithChannelOverride_SourceIsChannel(t *testing.T) {
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModeToken,
|
||||
InputPrice: testPtrFloat64(1e-6),
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.Equal(t, "channel", resolved.Source)
|
||||
}
|
||||
|
||||
func TestResolve_WithChannelOverride_DefaultMode(t *testing.T) {
|
||||
// Channel pricing with empty BillingMode → defaults to BillingModeToken.
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: "", // intentionally empty
|
||||
InputPrice: testPtrFloat64(5e-6),
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
require.Equal(t, "channel", resolved.Source)
|
||||
require.Equal(t, BillingModeToken, resolved.Mode)
|
||||
require.NotNil(t, resolved.BasePricing)
|
||||
require.InDelta(t, 5e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 5. GetIntervalPricing integration after channel override
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGetIntervalPricing_WithChannelIntervals(t *testing.T) {
|
||||
// Channel provides intervals that override the base pricing path.
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModeToken,
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(100000), InputPrice: testPtrFloat64(1e-6), OutputPrice: testPtrFloat64(5e-6)},
|
||||
{MinTokens: 100000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6), OutputPrice: testPtrFloat64(10e-6)},
|
||||
},
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
// Token count 50000 matches first interval
|
||||
pricing := r.GetIntervalPricing(resolved, 50000)
|
||||
require.NotNil(t, pricing)
|
||||
require.InDelta(t, 1e-6, pricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 5e-6, pricing.OutputPricePerToken, 1e-12)
|
||||
|
||||
// Token count 150000 matches second interval
|
||||
pricing2 := r.GetIntervalPricing(resolved, 150000)
|
||||
require.NotNil(t, pricing2)
|
||||
require.InDelta(t, 2e-6, pricing2.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 10e-6, pricing2.OutputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetIntervalPricing_ChannelIntervalsNoMatch(t *testing.T) {
|
||||
// Channel intervals don't match token count → falls back to BasePricing.
|
||||
r := newResolverWithChannel(t, []ChannelModelPricing{{
|
||||
Platform: "anthropic",
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
BillingMode: BillingModeToken,
|
||||
Intervals: []PricingInterval{
|
||||
// Only covers tokens > 50000
|
||||
{MinTokens: 50000, MaxTokens: testPtrInt(200000), InputPrice: testPtrFloat64(9e-6)},
|
||||
},
|
||||
}})
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: groupIDPtr(),
|
||||
})
|
||||
|
||||
// Token count 1000 doesn't match any interval (1000 <= 50000 minTokens)
|
||||
pricing := r.GetIntervalPricing(resolved, 1000)
|
||||
// Should fall back to BasePricing (from the billing service fallback)
|
||||
require.NotNil(t, pricing)
|
||||
require.Equal(t, resolved.BasePricing, pricing)
|
||||
require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12) // original base price
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// 6. Error path tests
|
||||
// ===========================================================================
|
||||
|
||||
func TestResolve_WithChannelOverride_CacheError(t *testing.T) {
|
||||
// When ListAll returns an error, the ChannelService cache build fails.
|
||||
// Resolve should gracefully fall back to base pricing without panicking.
|
||||
repo := &mockChannelRepository{
|
||||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||||
return nil, errors.New("database unavailable")
|
||||
},
|
||||
}
|
||||
cs := NewChannelService(repo, nil)
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(cs, bs)
|
||||
|
||||
gid := int64(100)
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: &gid,
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
// Should NOT panic, should NOT have source "channel"
|
||||
require.NotEqual(t, "channel", resolved.Source)
|
||||
// Base pricing should still be present (from BillingService fallback)
|
||||
require.NotNil(t, resolved.BasePricing)
|
||||
require.InDelta(t, 3e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// 7. GetRequestTierPriceByContext boundary tests
|
||||
// ===========================================================================
|
||||
|
||||
func TestGetRequestTierPriceByContext_EmptyTiers(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModePerRequest,
|
||||
RequestTiers: nil, // empty
|
||||
}
|
||||
|
||||
price := r.GetRequestTierPriceByContext(resolved, 50000)
|
||||
require.InDelta(t, 0.0, price, 1e-12)
|
||||
|
||||
// Also test with explicit empty slice
|
||||
resolved2 := &ResolvedPricing{
|
||||
Mode: BillingModePerRequest,
|
||||
RequestTiers: []PricingInterval{},
|
||||
}
|
||||
|
||||
price2 := r.GetRequestTierPriceByContext(resolved2, 50000)
|
||||
require.InDelta(t, 0.0, price2, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetRequestTierPriceByContext_ExactBoundary(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModePerRequest,
|
||||
RequestTiers: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.05)},
|
||||
{MinTokens: 128000, MaxTokens: nil, PerRequestPrice: testPtrFloat64(0.10)},
|
||||
},
|
||||
}
|
||||
|
||||
// totalContextTokens = 128000 exactly:
|
||||
// FindMatchingInterval checks: totalTokens > MinTokens && totalTokens <= MaxTokens
|
||||
// For first interval: 128000 > 0 (true) && 128000 <= 128000 (true) → matches first interval
|
||||
price := r.GetRequestTierPriceByContext(resolved, 128000)
|
||||
require.InDelta(t, 0.05, price, 1e-12)
|
||||
|
||||
// totalContextTokens = 128001 should match second interval
|
||||
// For first interval: 128001 > 0 (true) && 128001 <= 128000 (false) → no match
|
||||
// For second interval: 128001 > 128000 (true) && MaxTokens == nil → matches
|
||||
price2 := r.GetRequestTierPriceByContext(resolved, 128001)
|
||||
require.InDelta(t, 0.10, price2, 1e-12)
|
||||
}
|
||||
|
||||
@@ -4146,10 +4146,7 @@ type OpenAIRecordUsageInput struct {
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
RequestPayloadHash string
|
||||
APIKeyService APIKeyQuotaUpdater
|
||||
ChannelID int64
|
||||
OriginalModel string
|
||||
BillingModelSource string
|
||||
ModelMappingChain string
|
||||
ChannelUsageFields
|
||||
}
|
||||
|
||||
// RecordUsage records usage and deducts balance
|
||||
|
||||
15
backend/internal/service/testhelpers_test.go
Normal file
15
backend/internal/service/testhelpers_test.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
// testPtrFloat64 returns a pointer to the given float64 value.
|
||||
func testPtrFloat64(v float64) *float64 { return &v }
|
||||
|
||||
// testPtrInt returns a pointer to the given int value.
|
||||
func testPtrInt(v int) *int { return &v }
|
||||
|
||||
// testPtrString returns a pointer to the given string value.
|
||||
func testPtrString(v string) *string { return &v }
|
||||
|
||||
// testPtrBool returns a pointer to the given bool value.
|
||||
func testPtrBool(v bool) *bool { return &v }
|
||||
Reference in New Issue
Block a user