mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-05-04 21:20:51 +08:00
feat(channel): 通配符定价匹配 + OpenAI BillingModelSource + 按次价格校验 + 用户端计费模式展示
- 定价查找支持通配符(suffix *),最长前缀优先匹配 - 模型限制(restrict_models)同样支持通配符匹配 - OpenAI 网关接入渠道映射/BillingModelSource/模型限制 - 按次/图片计费模式创建时强制要求价格或层级(前后端) - 用户使用记录列表增加计费模式 badge 列
This commit is contained in:
@@ -180,7 +180,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
|
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
|
||||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver)
|
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver)
|
||||||
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI)
|
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI)
|
||||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver)
|
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService)
|
||||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||||
|
|||||||
@@ -276,11 +276,21 @@ func (h *ChannelHandler) Create(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pricing := pricingRequestToService(req.ModelPricing)
|
||||||
|
for _, p := range pricing {
|
||||||
|
if p.BillingMode == service.BillingModePerRequest || p.BillingMode == service.BillingModeImage {
|
||||||
|
if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
|
||||||
|
response.BadRequest(c, "Per-request price or intervals required for per_request/image billing mode")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
|
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
Description: req.Description,
|
Description: req.Description,
|
||||||
GroupIDs: req.GroupIDs,
|
GroupIDs: req.GroupIDs,
|
||||||
ModelPricing: pricingRequestToService(req.ModelPricing),
|
ModelPricing: pricing,
|
||||||
ModelMapping: req.ModelMapping,
|
ModelMapping: req.ModelMapping,
|
||||||
BillingModelSource: req.BillingModelSource,
|
BillingModelSource: req.BillingModelSource,
|
||||||
RestrictModels: req.RestrictModels,
|
RestrictModels: req.RestrictModels,
|
||||||
@@ -319,6 +329,14 @@ func (h *ChannelHandler) Update(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
if req.ModelPricing != nil {
|
if req.ModelPricing != nil {
|
||||||
pricing := pricingRequestToService(*req.ModelPricing)
|
pricing := pricingRequestToService(*req.ModelPricing)
|
||||||
|
for _, p := range pricing {
|
||||||
|
if p.BillingMode == service.BillingModePerRequest || p.BillingMode == service.BillingModeImage {
|
||||||
|
if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
|
||||||
|
response.BadRequest(c, "Per-request price or intervals required for per_request/image billing mode")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
input.ModelPricing = &pricing
|
input.ModelPricing = &pricing
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -185,6 +185,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||||
|
|
||||||
|
// 解析渠道级模型映射
|
||||||
|
var channelMapping service.ChannelMappingResult
|
||||||
|
if apiKey.GroupID != nil {
|
||||||
|
channelMapping = h.gatewayService.ResolveChannelMapping(c.Request.Context(), *apiKey.GroupID, reqModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 渠道模型限制检查
|
||||||
|
if apiKey.GroupID != nil {
|
||||||
|
if h.gatewayService.IsModelRestricted(c.Request.Context(), *apiKey.GroupID, reqModel) {
|
||||||
|
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
|
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
|
||||||
if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
|
if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
|
||||||
return
|
return
|
||||||
@@ -379,6 +393,21 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
RequestPayloadHash: requestPayloadHash,
|
RequestPayloadHash: requestPayloadHash,
|
||||||
APIKeyService: h.apiKeyService,
|
APIKeyService: h.apiKeyService,
|
||||||
|
ChannelID: channelMapping.ChannelID,
|
||||||
|
OriginalModel: reqModel,
|
||||||
|
BillingModelSource: channelMapping.BillingModelSource,
|
||||||
|
ModelMappingChain: func() string {
|
||||||
|
if !channelMapping.Mapped {
|
||||||
|
if result.UpstreamModel != "" && result.UpstreamModel != result.Model {
|
||||||
|
return reqModel + "→" + result.UpstreamModel
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if result.UpstreamModel != "" && result.UpstreamModel != channelMapping.MappedModel {
|
||||||
|
return reqModel + "→" + channelMapping.MappedModel + "→" + result.UpstreamModel
|
||||||
|
}
|
||||||
|
return reqModel + "→" + channelMapping.MappedModel
|
||||||
|
}(),
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
zap.String("component", "handler.openai_gateway.responses"),
|
zap.String("component", "handler.openai_gateway.responses"),
|
||||||
@@ -549,6 +578,20 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
|||||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||||
|
|
||||||
|
// 解析渠道级模型映射
|
||||||
|
var channelMappingMsg service.ChannelMappingResult
|
||||||
|
if apiKey.GroupID != nil {
|
||||||
|
channelMappingMsg = h.gatewayService.ResolveChannelMapping(c.Request.Context(), *apiKey.GroupID, reqModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 渠道模型限制检查
|
||||||
|
if apiKey.GroupID != nil {
|
||||||
|
if h.gatewayService.IsModelRestricted(c.Request.Context(), *apiKey.GroupID, reqModel) {
|
||||||
|
h.anthropicErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
|
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
|
||||||
if h.errorPassthroughService != nil {
|
if h.errorPassthroughService != nil {
|
||||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||||
@@ -759,6 +802,21 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
|||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
RequestPayloadHash: requestPayloadHash,
|
RequestPayloadHash: requestPayloadHash,
|
||||||
APIKeyService: h.apiKeyService,
|
APIKeyService: h.apiKeyService,
|
||||||
|
ChannelID: channelMappingMsg.ChannelID,
|
||||||
|
OriginalModel: reqModel,
|
||||||
|
BillingModelSource: channelMappingMsg.BillingModelSource,
|
||||||
|
ModelMappingChain: func() string {
|
||||||
|
if !channelMappingMsg.Mapped {
|
||||||
|
if result.UpstreamModel != "" && result.UpstreamModel != result.Model {
|
||||||
|
return reqModel + "→" + result.UpstreamModel
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if result.UpstreamModel != "" && result.UpstreamModel != channelMappingMsg.MappedModel {
|
||||||
|
return reqModel + "→" + channelMappingMsg.MappedModel + "→" + result.UpstreamModel
|
||||||
|
}
|
||||||
|
return reqModel + "→" + channelMappingMsg.MappedModel
|
||||||
|
}(),
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
zap.String("component", "handler.openai_gateway.messages"),
|
zap.String("component", "handler.openai_gateway.messages"),
|
||||||
@@ -1101,6 +1159,20 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
|||||||
setOpsRequestContext(c, reqModel, true, firstMessage)
|
setOpsRequestContext(c, reqModel, true, firstMessage)
|
||||||
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
|
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
|
||||||
|
|
||||||
|
// 解析渠道级模型映射
|
||||||
|
var channelMappingWS service.ChannelMappingResult
|
||||||
|
if apiKey.GroupID != nil {
|
||||||
|
channelMappingWS = h.gatewayService.ResolveChannelMapping(ctx, *apiKey.GroupID, reqModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 渠道模型限制检查
|
||||||
|
if apiKey.GroupID != nil {
|
||||||
|
if h.gatewayService.IsModelRestricted(ctx, *apiKey.GroupID, reqModel) {
|
||||||
|
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "model not allowed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var currentUserRelease func()
|
var currentUserRelease func()
|
||||||
var currentAccountRelease func()
|
var currentAccountRelease func()
|
||||||
releaseTurnSlots := func() {
|
releaseTurnSlots := func() {
|
||||||
@@ -1259,6 +1331,21 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
|||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
|
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
|
||||||
APIKeyService: h.apiKeyService,
|
APIKeyService: h.apiKeyService,
|
||||||
|
ChannelID: channelMappingWS.ChannelID,
|
||||||
|
OriginalModel: reqModel,
|
||||||
|
BillingModelSource: channelMappingWS.BillingModelSource,
|
||||||
|
ModelMappingChain: func() string {
|
||||||
|
if !channelMappingWS.Mapped {
|
||||||
|
if result.UpstreamModel != "" && result.UpstreamModel != result.Model {
|
||||||
|
return reqModel + "→" + result.UpstreamModel
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if result.UpstreamModel != "" && result.UpstreamModel != channelMappingWS.MappedModel {
|
||||||
|
return reqModel + "→" + channelMappingWS.MappedModel + "→" + result.UpstreamModel
|
||||||
|
}
|
||||||
|
return reqModel + "→" + channelMappingWS.MappedModel
|
||||||
|
}(),
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
reqLog.Error("openai.websocket_record_usage_failed",
|
reqLog.Error("openai.websocket_record_usage_failed",
|
||||||
zap.Int64("account_id", account.ID),
|
zap.Int64("account_id", account.ID),
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@@ -57,13 +58,26 @@ type channelModelKey struct {
|
|||||||
model string // lowercase
|
model string // lowercase
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// channelGroupPlatformKey 通配符定价缓存键
|
||||||
|
type channelGroupPlatformKey struct {
|
||||||
|
groupID int64
|
||||||
|
platform string
|
||||||
|
}
|
||||||
|
|
||||||
|
// wildcardPricingEntry 通配符定价条目
|
||||||
|
type wildcardPricingEntry struct {
|
||||||
|
prefix string
|
||||||
|
pricing *ChannelModelPricing
|
||||||
|
}
|
||||||
|
|
||||||
// channelCache 渠道缓存快照(扁平化哈希结构,热路径 O(1) 查找)
|
// channelCache 渠道缓存快照(扁平化哈希结构,热路径 O(1) 查找)
|
||||||
type channelCache struct {
|
type channelCache struct {
|
||||||
// 热路径查找
|
// 热路径查找
|
||||||
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价
|
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价
|
||||||
mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标
|
wildcardByGroupPlatform map[channelGroupPlatformKey][]*wildcardPricingEntry // (groupID, platform) → 通配符定价(前缀长度降序)
|
||||||
channelByGroupID map[int64]*Channel // groupID → 渠道
|
mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标
|
||||||
groupPlatform map[int64]string // groupID → platform
|
channelByGroupID map[int64]*Channel // groupID → 渠道
|
||||||
|
groupPlatform map[int64]string // groupID → platform
|
||||||
|
|
||||||
// 冷路径(CRUD 操作)
|
// 冷路径(CRUD 操作)
|
||||||
byID map[int64]*Channel
|
byID map[int64]*Channel
|
||||||
@@ -137,12 +151,13 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
|||||||
// error-TTL:失败时存入短 TTL 空缓存,防止紧密重试
|
// error-TTL:失败时存入短 TTL 空缓存,防止紧密重试
|
||||||
slog.Warn("failed to build channel cache", "error", err)
|
slog.Warn("failed to build channel cache", "error", err)
|
||||||
errorCache := &channelCache{
|
errorCache := &channelCache{
|
||||||
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
||||||
mappingByGroupModel: make(map[channelModelKey]string),
|
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
|
||||||
channelByGroupID: make(map[int64]*Channel),
|
mappingByGroupModel: make(map[channelModelKey]string),
|
||||||
groupPlatform: make(map[int64]string),
|
channelByGroupID: make(map[int64]*Channel),
|
||||||
byID: make(map[int64]*Channel),
|
groupPlatform: make(map[int64]string),
|
||||||
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL
|
byID: make(map[int64]*Channel),
|
||||||
|
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL
|
||||||
}
|
}
|
||||||
s.cache.Store(errorCache)
|
s.cache.Store(errorCache)
|
||||||
return nil, fmt.Errorf("list all channels: %w", err)
|
return nil, fmt.Errorf("list all channels: %w", err)
|
||||||
@@ -163,12 +178,13 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
cache := &channelCache{
|
cache := &channelCache{
|
||||||
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
||||||
mappingByGroupModel: make(map[channelModelKey]string),
|
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
|
||||||
channelByGroupID: make(map[int64]*Channel),
|
mappingByGroupModel: make(map[channelModelKey]string),
|
||||||
groupPlatform: groupPlatforms,
|
channelByGroupID: make(map[int64]*Channel),
|
||||||
byID: make(map[int64]*Channel, len(channels)),
|
groupPlatform: groupPlatforms,
|
||||||
loadedAt: time.Now(),
|
byID: make(map[int64]*Channel, len(channels)),
|
||||||
|
loadedAt: time.Now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range channels {
|
for i := range channels {
|
||||||
@@ -187,8 +203,18 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
|||||||
continue // 跳过非本平台的定价
|
continue // 跳过非本平台的定价
|
||||||
}
|
}
|
||||||
for _, model := range pricing.Models {
|
for _, model := range pricing.Models {
|
||||||
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(model)}
|
if strings.HasSuffix(model, "*") {
|
||||||
cache.pricingByGroupModel[key] = pricing
|
// 通配符模型 → 存入 wildcardByGroupPlatform
|
||||||
|
prefix := strings.ToLower(strings.TrimSuffix(model, "*"))
|
||||||
|
gpKey := channelGroupPlatformKey{groupID: gid, platform: platform}
|
||||||
|
cache.wildcardByGroupPlatform[gpKey] = append(cache.wildcardByGroupPlatform[gpKey], &wildcardPricingEntry{
|
||||||
|
prefix: prefix,
|
||||||
|
pricing: pricing,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(model)}
|
||||||
|
cache.pricingByGroupModel[key] = pricing
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -202,6 +228,14 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 通配符条目按前缀长度降序排列(最长前缀优先匹配)
|
||||||
|
for gpKey, entries := range cache.wildcardByGroupPlatform {
|
||||||
|
sort.Slice(entries, func(i, j int) bool {
|
||||||
|
return len(entries[i].prefix) > len(entries[j].prefix)
|
||||||
|
})
|
||||||
|
cache.wildcardByGroupPlatform[gpKey] = entries
|
||||||
|
}
|
||||||
|
|
||||||
s.cache.Store(cache)
|
s.cache.Store(cache)
|
||||||
return cache, nil
|
return cache, nil
|
||||||
}
|
}
|
||||||
@@ -212,6 +246,18 @@ func (s *ChannelService) invalidateCache() {
|
|||||||
s.cacheSF.Forget("channel_cache")
|
s.cacheSF.Forget("channel_cache")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// matchWildcard 在通配符定价中查找匹配项(最长前缀优先)
|
||||||
|
func (c *channelCache) matchWildcard(groupID int64, platform, modelLower string) *ChannelModelPricing {
|
||||||
|
gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform}
|
||||||
|
wildcards := c.wildcardByGroupPlatform[gpKey]
|
||||||
|
for _, wc := range wildcards {
|
||||||
|
if strings.HasPrefix(modelLower, wc.prefix) {
|
||||||
|
return wc.pricing
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetChannelForGroup 获取分组关联的渠道(热路径 O(1))
|
// GetChannelForGroup 获取分组关联的渠道(热路径 O(1))
|
||||||
func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) (*Channel, error) {
|
func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) (*Channel, error) {
|
||||||
cache, err := s.loadCache(ctx)
|
cache, err := s.loadCache(ctx)
|
||||||
@@ -245,7 +291,11 @@ func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int
|
|||||||
key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)}
|
key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)}
|
||||||
pricing, ok := cache.pricingByGroupModel[key]
|
pricing, ok := cache.pricingByGroupModel[key]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
// 精确查找失败,尝试通配符匹配
|
||||||
|
pricing = cache.matchWildcard(groupID, platform, strings.ToLower(model))
|
||||||
|
if pricing == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cp := pricing.Clone()
|
cp := pricing.Clone()
|
||||||
@@ -302,7 +352,14 @@ func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, m
|
|||||||
platform := cache.groupPlatform[groupID]
|
platform := cache.groupPlatform[groupID]
|
||||||
key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)}
|
key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)}
|
||||||
_, exists := cache.pricingByGroupModel[key]
|
_, exists := cache.pricingByGroupModel[key]
|
||||||
return !exists
|
if exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// 精确查找失败,尝试通配符匹配
|
||||||
|
if cache.matchWildcard(groupID, platform, strings.ToLower(model)) != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- CRUD ---
|
// --- CRUD ---
|
||||||
|
|||||||
@@ -146,6 +146,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
|
|||||||
&DeferredService{},
|
&DeferredService{},
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
svc.userGroupRateResolver = newUserGroupRateResolver(
|
svc.userGroupRateResolver = newUserGroupRateResolver(
|
||||||
rateRepo,
|
rateRepo,
|
||||||
|
|||||||
@@ -323,6 +323,7 @@ type OpenAIGatewayService struct {
|
|||||||
toolCorrector *CodexToolCorrector
|
toolCorrector *CodexToolCorrector
|
||||||
openaiWSResolver OpenAIWSProtocolResolver
|
openaiWSResolver OpenAIWSProtocolResolver
|
||||||
resolver *ModelPricingResolver
|
resolver *ModelPricingResolver
|
||||||
|
channelService *ChannelService
|
||||||
|
|
||||||
openaiWSPoolOnce sync.Once
|
openaiWSPoolOnce sync.Once
|
||||||
openaiWSStateStoreOnce sync.Once
|
openaiWSStateStoreOnce sync.Once
|
||||||
@@ -359,6 +360,7 @@ func NewOpenAIGatewayService(
|
|||||||
deferredService *DeferredService,
|
deferredService *DeferredService,
|
||||||
openAITokenProvider *OpenAITokenProvider,
|
openAITokenProvider *OpenAITokenProvider,
|
||||||
resolver *ModelPricingResolver,
|
resolver *ModelPricingResolver,
|
||||||
|
channelService *ChannelService,
|
||||||
) *OpenAIGatewayService {
|
) *OpenAIGatewayService {
|
||||||
svc := &OpenAIGatewayService{
|
svc := &OpenAIGatewayService{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
@@ -387,6 +389,7 @@ func NewOpenAIGatewayService(
|
|||||||
toolCorrector: NewCodexToolCorrector(),
|
toolCorrector: NewCodexToolCorrector(),
|
||||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||||
resolver: resolver,
|
resolver: resolver,
|
||||||
|
channelService: channelService,
|
||||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||||
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
|
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
|
||||||
}
|
}
|
||||||
@@ -394,6 +397,22 @@ func NewOpenAIGatewayService(
|
|||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ResolveChannelMapping 解析渠道级模型映射(代理到 ChannelService)
|
||||||
|
func (s *OpenAIGatewayService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult {
|
||||||
|
if s.channelService == nil {
|
||||||
|
return ChannelMappingResult{MappedModel: model}
|
||||||
|
}
|
||||||
|
return s.channelService.ResolveChannelMapping(ctx, groupID, model)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsModelRestricted 检查模型是否被渠道限制(代理到 ChannelService)
|
||||||
|
func (s *OpenAIGatewayService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
|
||||||
|
if s.channelService == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return s.channelService.IsModelRestricted(ctx, groupID, model)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle {
|
func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle {
|
||||||
if s != nil && s.codexSnapshotThrottle != nil {
|
if s != nil && s.codexSnapshotThrottle != nil {
|
||||||
return s.codexSnapshotThrottle
|
return s.codexSnapshotThrottle
|
||||||
@@ -4113,6 +4132,10 @@ type OpenAIRecordUsageInput struct {
|
|||||||
IPAddress string // 请求的客户端 IP 地址
|
IPAddress string // 请求的客户端 IP 地址
|
||||||
RequestPayloadHash string
|
RequestPayloadHash string
|
||||||
APIKeyService APIKeyQuotaUpdater
|
APIKeyService APIKeyQuotaUpdater
|
||||||
|
ChannelID int64
|
||||||
|
OriginalModel string
|
||||||
|
BillingModelSource string
|
||||||
|
ModelMappingChain string
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecordUsage records usage and deducts balance
|
// RecordUsage records usage and deducts balance
|
||||||
@@ -4158,6 +4181,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
|||||||
var cost *CostBreakdown
|
var cost *CostBreakdown
|
||||||
var err error
|
var err error
|
||||||
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
|
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
|
||||||
|
if result.BillingModel != "" {
|
||||||
|
billingModel = strings.TrimSpace(result.BillingModel)
|
||||||
|
}
|
||||||
|
if input.BillingModelSource == "requested" && input.OriginalModel != "" {
|
||||||
|
billingModel = input.OriginalModel
|
||||||
|
}
|
||||||
serviceTier := ""
|
serviceTier := ""
|
||||||
if result.ServiceTier != nil {
|
if result.ServiceTier != nil {
|
||||||
serviceTier = strings.TrimSpace(*result.ServiceTier)
|
serviceTier = strings.TrimSpace(*result.ServiceTier)
|
||||||
@@ -4223,6 +4252,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
|||||||
FirstTokenMs: result.FirstTokenMs,
|
FirstTokenMs: result.FirstTokenMs,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
|
// 设置渠道信息
|
||||||
|
usageLog.ChannelID = optionalInt64Ptr(input.ChannelID)
|
||||||
|
usageLog.ModelMappingChain = optionalTrimmedStringPtr(input.ModelMappingChain)
|
||||||
// 设置计费模式
|
// 设置计费模式
|
||||||
if cost != nil && cost.BillingMode != "" {
|
if cost != nil && cost.BillingMode != "" {
|
||||||
billingMode := cost.BillingMode
|
billingMode := cost.BillingMode
|
||||||
|
|||||||
@@ -616,6 +616,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
|
|||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)
|
decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)
|
||||||
|
|||||||
@@ -1789,6 +1789,7 @@ export default {
|
|||||||
noTiersYet: 'No tiers yet. Click add to configure per-request pricing.',
|
noTiersYet: 'No tiers yet. Click add to configure per-request pricing.',
|
||||||
noPricingRules: 'No pricing rules yet. Click "Add" to create one.',
|
noPricingRules: 'No pricing rules yet. Click "Add" to create one.',
|
||||||
perRequestPrice: 'Price per Request',
|
perRequestPrice: 'Price per Request',
|
||||||
|
perRequestPriceRequired: 'Per-request price or billing tiers required for per-request/image billing mode',
|
||||||
tierLabel: 'Tier',
|
tierLabel: 'Tier',
|
||||||
resolution: 'Resolution',
|
resolution: 'Resolution',
|
||||||
modelMapping: 'Model Mapping',
|
modelMapping: 'Model Mapping',
|
||||||
|
|||||||
@@ -1869,6 +1869,7 @@ export default {
|
|||||||
noTiersYet: '暂无层级,点击添加配置按次计费价格',
|
noTiersYet: '暂无层级,点击添加配置按次计费价格',
|
||||||
noPricingRules: '暂无定价规则,点击"添加"创建',
|
noPricingRules: '暂无定价规则,点击"添加"创建',
|
||||||
perRequestPrice: '单次价格',
|
perRequestPrice: '单次价格',
|
||||||
|
perRequestPriceRequired: '按次/图片计费模式必须设置默认价格或至少一个计费层级',
|
||||||
tierLabel: '层级',
|
tierLabel: '层级',
|
||||||
resolution: '分辨率',
|
resolution: '分辨率',
|
||||||
modelMapping: '模型映射',
|
modelMapping: '模型映射',
|
||||||
|
|||||||
@@ -876,6 +876,19 @@ async function handleSubmit() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 校验 per_request/image 模式必须有价格
|
||||||
|
for (const section of form.platforms) {
|
||||||
|
for (const entry of section.model_pricing) {
|
||||||
|
if (entry.models.length === 0) continue
|
||||||
|
if ((entry.billing_mode === 'per_request' || entry.billing_mode === 'image') &&
|
||||||
|
(entry.per_request_price == null || entry.per_request_price === '') &&
|
||||||
|
(!entry.intervals || entry.intervals.length === 0)) {
|
||||||
|
appStore.showError(t('admin.channels.perRequestPriceRequired', '按次/图片计费模式必须设置默认价格或至少一个计费层级'))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const { group_ids, model_pricing, model_mapping } = formToAPI()
|
const { group_ids, model_pricing, model_mapping } = formToAPI()
|
||||||
console.log('[handleSubmit] model_pricing to send:', JSON.stringify(model_pricing))
|
console.log('[handleSubmit] model_pricing to send:', JSON.stringify(model_pricing))
|
||||||
|
|
||||||
|
|||||||
@@ -181,6 +181,13 @@
|
|||||||
</span>
|
</span>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
|
<template #cell-billing_mode="{ row }">
|
||||||
|
<span class="inline-flex items-center rounded px-1.5 py-0.5 text-xs font-medium"
|
||||||
|
:class="getBillingModeBadgeClass(row.billing_mode)">
|
||||||
|
{{ getBillingModeLabel(row.billing_mode) }}
|
||||||
|
</span>
|
||||||
|
</template>
|
||||||
|
|
||||||
<template #cell-tokens="{ row }">
|
<template #cell-tokens="{ row }">
|
||||||
<!-- 图片生成请求 -->
|
<!-- 图片生成请求 -->
|
||||||
<div v-if="row.image_count > 0" class="flex items-center gap-1.5">
|
<div v-if="row.image_count > 0" class="flex items-center gap-1.5">
|
||||||
@@ -525,6 +532,7 @@ const columns = computed<Column[]>(() => [
|
|||||||
{ key: 'reasoning_effort', label: t('usage.reasoningEffort'), sortable: false },
|
{ key: 'reasoning_effort', label: t('usage.reasoningEffort'), sortable: false },
|
||||||
{ key: 'endpoint', label: t('usage.endpoint'), sortable: false },
|
{ key: 'endpoint', label: t('usage.endpoint'), sortable: false },
|
||||||
{ key: 'stream', label: t('usage.type'), sortable: false },
|
{ key: 'stream', label: t('usage.type'), sortable: false },
|
||||||
|
{ key: 'billing_mode', label: t('admin.usage.billingMode'), sortable: false },
|
||||||
{ key: 'tokens', label: t('usage.tokens'), sortable: false },
|
{ key: 'tokens', label: t('usage.tokens'), sortable: false },
|
||||||
{ key: 'cost', label: t('usage.cost'), sortable: false },
|
{ key: 'cost', label: t('usage.cost'), sortable: false },
|
||||||
{ key: 'first_token', label: t('usage.firstToken'), sortable: false },
|
{ key: 'first_token', label: t('usage.firstToken'), sortable: false },
|
||||||
@@ -615,6 +623,18 @@ const getRequestTypeBadgeClass = (log: UsageLog): string => {
|
|||||||
return 'bg-amber-100 text-amber-800 dark:bg-amber-900 dark:text-amber-200'
|
return 'bg-amber-100 text-amber-800 dark:bg-amber-900 dark:text-amber-200'
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const getBillingModeLabel = (mode: string | null | undefined): string => {
|
||||||
|
if (mode === 'per_request') return t('admin.usage.billingModePerRequest')
|
||||||
|
if (mode === 'image') return t('admin.usage.billingModeImage')
|
||||||
|
return t('admin.usage.billingModeToken')
|
||||||
|
}
|
||||||
|
|
||||||
|
const getBillingModeBadgeClass = (mode: string | null | undefined): string => {
|
||||||
|
if (mode === 'per_request') return 'bg-blue-100 text-blue-800 dark:bg-blue-900/30 dark:text-blue-200'
|
||||||
|
if (mode === 'image') return 'bg-green-100 text-green-800 dark:bg-green-900/30 dark:text-green-200'
|
||||||
|
return 'bg-gray-100 text-gray-800 dark:bg-gray-700 dark:text-gray-300'
|
||||||
|
}
|
||||||
|
|
||||||
const getRequestTypeExportText = (log: UsageLog): string => {
|
const getRequestTypeExportText = (log: UsageLog): string => {
|
||||||
const requestType = resolveUsageRequestType(log)
|
const requestType = resolveUsageRequestType(log)
|
||||||
if (requestType === 'ws_v2') return 'WS'
|
if (requestType === 'ws_v2') return 'WS'
|
||||||
@@ -804,6 +824,7 @@ const exportToCSV = async () => {
|
|||||||
'Reasoning Effort',
|
'Reasoning Effort',
|
||||||
'Inbound Endpoint',
|
'Inbound Endpoint',
|
||||||
'Type',
|
'Type',
|
||||||
|
'Billing Mode',
|
||||||
'Input Tokens',
|
'Input Tokens',
|
||||||
'Output Tokens',
|
'Output Tokens',
|
||||||
'Cache Read Tokens',
|
'Cache Read Tokens',
|
||||||
@@ -822,6 +843,7 @@ const exportToCSV = async () => {
|
|||||||
formatReasoningEffort(log.reasoning_effort),
|
formatReasoningEffort(log.reasoning_effort),
|
||||||
log.inbound_endpoint || '',
|
log.inbound_endpoint || '',
|
||||||
getRequestTypeExportText(log),
|
getRequestTypeExportText(log),
|
||||||
|
getBillingModeLabel(log.billing_mode),
|
||||||
log.input_tokens,
|
log.input_tokens,
|
||||||
log.output_tokens,
|
log.output_tokens,
|
||||||
log.cache_read_tokens,
|
log.cache_read_tokens,
|
||||||
|
|||||||
Reference in New Issue
Block a user