mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-05-04 21:20:51 +08:00
feat(channel): 渠道管理全链路集成 — 模型映射、定价、限制、用量统计
- 渠道模型映射:支持精确匹配和通配符映射,按平台隔离 - 渠道模型定价:支持 token/按次/图片三种计费模式,区间分层定价 - 模型限制:渠道可限制仅允许定价列表中的模型 - 计费模型来源:支持 requested/upstream 两种计费模型选择 - 用量统计:usage_logs 新增 channel_id/model_mapping_chain/billing_tier/billing_mode 字段 - Dashboard 支持 model_source 维度(requested/upstream/mapping)查看模型统计 - 全部 gateway handler 统一接入 ResolveChannelMappingAndRestrict - 修复测试:同步 SoraGenerationRepository 接口、SQL INSERT 参数、scan 字段
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -274,7 +275,8 @@ func replaceModelPricingTx(ctx context.Context, exec dbExec, channelID int64, pr
|
||||
|
||||
// isUniqueViolation 检查 pq 唯一约束违反错误
|
||||
func isUniqueViolation(err error) bool {
|
||||
if pqErr, ok := err.(*pq.Error); ok {
|
||||
var pqErr *pq.Error
|
||||
if errors.As(err, &pqErr) && pqErr != nil {
|
||||
return pqErr.Code == "23505"
|
||||
}
|
||||
return false
|
||||
|
||||
227
backend/internal/repository/channel_repo_test.go
Normal file
227
backend/internal/repository/channel_repo_test.go
Normal file
@@ -0,0 +1,227 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- marshalModelMapping ---
|
||||
|
||||
func TestMarshalModelMapping(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input map[string]map[string]string
|
||||
wantJSON string // expected JSON output (exact match)
|
||||
}{
|
||||
{
|
||||
name: "empty map",
|
||||
input: map[string]map[string]string{},
|
||||
wantJSON: "{}",
|
||||
},
|
||||
{
|
||||
name: "nil map",
|
||||
input: nil,
|
||||
wantJSON: "{}",
|
||||
},
|
||||
{
|
||||
name: "populated map",
|
||||
input: map[string]map[string]string{
|
||||
"openai": {"gpt-4": "gpt-4-turbo"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nested values",
|
||||
input: map[string]map[string]string{
|
||||
"openai": {"*": "gpt-5.4"},
|
||||
"anthropic": {"claude-old": "claude-new"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := marshalModelMapping(tt.input)
|
||||
require.NoError(t, err)
|
||||
|
||||
if tt.wantJSON != "" {
|
||||
require.Equal(t, []byte(tt.wantJSON), result)
|
||||
} else {
|
||||
// round-trip: unmarshal and compare with input
|
||||
var parsed map[string]map[string]string
|
||||
require.NoError(t, json.Unmarshal(result, &parsed))
|
||||
require.Equal(t, tt.input, parsed)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- unmarshalModelMapping ---
|
||||
|
||||
func TestUnmarshalModelMapping(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
wantNil bool
|
||||
want map[string]map[string]string
|
||||
}{
|
||||
{
|
||||
name: "nil data",
|
||||
input: nil,
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "empty data",
|
||||
input: []byte{},
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
input: []byte("not-json"),
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "type error - number",
|
||||
input: []byte("42"),
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "type error - array",
|
||||
input: []byte("[1,2,3]"),
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "valid JSON",
|
||||
input: []byte(`{"openai":{"gpt-4":"gpt-4-turbo"},"anthropic":{"old":"new"}}`),
|
||||
want: map[string]map[string]string{
|
||||
"openai": {"gpt-4": "gpt-4-turbo"},
|
||||
"anthropic": {"old": "new"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty object",
|
||||
input: []byte("{}"),
|
||||
want: map[string]map[string]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := unmarshalModelMapping(tt.input)
|
||||
if tt.wantNil {
|
||||
require.Nil(t, result)
|
||||
} else {
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, tt.want, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- escapeLike ---
|
||||
|
||||
func TestEscapeLike(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "no special chars",
|
||||
input: "hello",
|
||||
want: "hello",
|
||||
},
|
||||
{
|
||||
name: "backslash",
|
||||
input: `a\b`,
|
||||
want: `a\\b`,
|
||||
},
|
||||
{
|
||||
name: "percent",
|
||||
input: "50%",
|
||||
want: `50\%`,
|
||||
},
|
||||
{
|
||||
name: "underscore",
|
||||
input: "a_b",
|
||||
want: `a\_b`,
|
||||
},
|
||||
{
|
||||
name: "all special chars",
|
||||
input: `a\b%c_d`,
|
||||
want: `a\\b\%c\_d`,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "consecutive special chars",
|
||||
input: "%_%",
|
||||
want: `\%\_\%`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, escapeLike(tt.input))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- isUniqueViolation ---
|
||||
|
||||
func TestIsUniqueViolation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "unique violation code 23505",
|
||||
err: &pq.Error{Code: "23505"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "different pq error code",
|
||||
err: &pq.Error{Code: "23503"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "non-pq error",
|
||||
err: errors.New("some generic error"),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "typed nil pq.Error",
|
||||
err: func() error {
|
||||
var pqErr *pq.Error
|
||||
return pqErr
|
||||
}(),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "bare nil",
|
||||
err: nil,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "wrapped pq error with 23505",
|
||||
err: fmt.Errorf("wrapped: %w", &pq.Error{Code: "23505"}),
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, isUniqueViolation(tt.err))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -3144,6 +3144,30 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim
|
||||
query += fmt.Sprintf(" AND %s = $%d", col, len(args)+1)
|
||||
args = append(args, dim.Endpoint)
|
||||
}
|
||||
if dim.UserID > 0 {
|
||||
query += fmt.Sprintf(" AND ul.user_id = $%d", len(args)+1)
|
||||
args = append(args, dim.UserID)
|
||||
}
|
||||
if dim.APIKeyID > 0 {
|
||||
query += fmt.Sprintf(" AND ul.api_key_id = $%d", len(args)+1)
|
||||
args = append(args, dim.APIKeyID)
|
||||
}
|
||||
if dim.AccountID > 0 {
|
||||
query += fmt.Sprintf(" AND ul.account_id = $%d", len(args)+1)
|
||||
args = append(args, dim.AccountID)
|
||||
}
|
||||
if dim.RequestType != nil {
|
||||
query += fmt.Sprintf(" AND ul.request_type = $%d", len(args)+1)
|
||||
args = append(args, *dim.RequestType)
|
||||
}
|
||||
if dim.Stream != nil {
|
||||
query += fmt.Sprintf(" AND ul.stream = $%d", len(args)+1)
|
||||
args = append(args, *dim.Stream)
|
||||
}
|
||||
if dim.BillingType != nil {
|
||||
query += fmt.Sprintf(" AND ul.billing_type = $%d", len(args)+1)
|
||||
args = append(args, *dim.BillingType)
|
||||
}
|
||||
|
||||
query += " GROUP BY ul.user_id, u.email ORDER BY actual_cost DESC"
|
||||
if limit > 0 {
|
||||
|
||||
@@ -80,6 +80,10 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
||||
sqlmock.AnyArg(), // inbound_endpoint
|
||||
sqlmock.AnyArg(), // upstream_endpoint
|
||||
log.CacheTTLOverridden,
|
||||
sqlmock.AnyArg(), // channel_id
|
||||
sqlmock.AnyArg(), // model_mapping_chain
|
||||
sqlmock.AnyArg(), // billing_tier
|
||||
sqlmock.AnyArg(), // billing_mode
|
||||
createdAt,
|
||||
).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt))
|
||||
@@ -153,6 +157,10 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
log.CacheTTLOverridden,
|
||||
sqlmock.AnyArg(), // channel_id
|
||||
sqlmock.AnyArg(), // model_mapping_chain
|
||||
sqlmock.AnyArg(), // billing_tier
|
||||
sqlmock.AnyArg(), // billing_mode
|
||||
createdAt,
|
||||
).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt))
|
||||
@@ -463,6 +471,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
false,
|
||||
sql.NullInt64{}, // channel_id
|
||||
sql.NullString{}, // model_mapping_chain
|
||||
sql.NullString{}, // billing_tier
|
||||
sql.NullString{}, // billing_mode
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
@@ -506,6 +518,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
false,
|
||||
sql.NullInt64{}, // channel_id
|
||||
sql.NullString{}, // model_mapping_chain
|
||||
sql.NullString{}, // billing_tier
|
||||
sql.NullString{}, // billing_mode
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
@@ -549,6 +565,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
false,
|
||||
sql.NullInt64{}, // channel_id
|
||||
sql.NullString{}, // model_mapping_chain
|
||||
sql.NullString{}, // billing_tier
|
||||
sql.NullString{}, // billing_mode
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
|
||||
Reference in New Issue
Block a user