mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-10 01:54:46 +08:00
Merge branch 'main' into release/custom-0.1.91
# Conflicts: # frontend/src/components/admin/account/AccountActionMenu.vue # frontend/src/views/admin/AccountsView.vue
This commit is contained in:
@@ -102,6 +102,7 @@ type CreateAccountRequest struct {
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
LoadFactor *int `json:"load_factor"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ExpiresAt *int64 `json:"expires_at"`
|
||||
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
|
||||
@@ -120,6 +121,7 @@ type UpdateAccountRequest struct {
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
LoadFactor *int `json:"load_factor"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
ExpiresAt *int64 `json:"expires_at"`
|
||||
@@ -135,6 +137,7 @@ type BulkUpdateAccountsRequest struct {
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
LoadFactor *int `json:"load_factor"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
|
||||
Schedulable *bool `json:"schedulable"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
@@ -506,6 +509,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
LoadFactor: req.LoadFactor,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
AutoPauseOnExpired: req.AutoPauseOnExpired,
|
||||
@@ -575,6 +579,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
|
||||
Concurrency: req.Concurrency, // 指针类型,nil 表示未提供
|
||||
Priority: req.Priority, // 指针类型,nil 表示未提供
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
LoadFactor: req.LoadFactor,
|
||||
Status: req.Status,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
@@ -1101,6 +1106,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
req.Concurrency != nil ||
|
||||
req.Priority != nil ||
|
||||
req.RateMultiplier != nil ||
|
||||
req.LoadFactor != nil ||
|
||||
req.Status != "" ||
|
||||
req.Schedulable != nil ||
|
||||
req.GroupIDs != nil ||
|
||||
@@ -1119,6 +1125,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
LoadFactor: req.LoadFactor,
|
||||
Status: req.Status,
|
||||
Schedulable: req.Schedulable,
|
||||
GroupIDs: req.GroupIDs,
|
||||
@@ -1132,6 +1139,12 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
c.JSON(409, gin.H{
|
||||
"error": "mixed_channel_warning",
|
||||
"message": mixedErr.Error(),
|
||||
"details": gin.H{
|
||||
"group_id": mixedErr.GroupID,
|
||||
"group_name": mixedErr.GroupName,
|
||||
"current_platform": mixedErr.CurrentPlatform,
|
||||
"other_platform": mixedErr.OtherPlatform,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -1328,6 +1341,29 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// ResetQuota handles resetting account quota usage
|
||||
// POST /api/v1/admin/accounts/:id/reset-quota
|
||||
func (h *AccountHandler) ResetQuota(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.adminService.ResetAccountQuota(c.Request.Context(), accountID); err != nil {
|
||||
response.InternalError(c, "Failed to reset account quota: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// GetTempUnschedulable handles getting temporary unschedulable status
|
||||
// GET /api/v1/admin/accounts/:id/temp-unschedulable
|
||||
func (h *AccountHandler) GetTempUnschedulable(c *gin.Context) {
|
||||
|
||||
@@ -111,7 +111,7 @@ func TestAccountHandlerCreateMixedChannelConflictSimplifiedResponse(t *testing.T
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, "mixed_channel_warning", resp["error"])
|
||||
require.Contains(t, resp["message"], "mixed_channel_warning")
|
||||
require.Contains(t, resp["message"], "claude-max")
|
||||
_, hasDetails := resp["details"]
|
||||
_, hasRequireConfirmation := resp["require_confirmation"]
|
||||
require.False(t, hasDetails)
|
||||
@@ -140,7 +140,7 @@ func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, "mixed_channel_warning", resp["error"])
|
||||
require.Contains(t, resp["message"], "mixed_channel_warning")
|
||||
require.Contains(t, resp["message"], "claude-max")
|
||||
_, hasDetails := resp["details"]
|
||||
_, hasRequireConfirmation := resp["require_confirmation"]
|
||||
require.False(t, hasDetails)
|
||||
|
||||
@@ -425,5 +425,9 @@ func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure stub implements interface.
|
||||
var _ service.AdminService = (*stubAdminService)(nil)
|
||||
|
||||
@@ -46,9 +46,10 @@ type CreateGroupRequest struct {
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||
SimulateClaudeMaxEnabled *bool `json:"simulate_claude_max_enabled"`
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||
// Sora 存储配额
|
||||
@@ -81,9 +82,10 @@ type UpdateGroupRequest struct {
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
|
||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
|
||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||
SimulateClaudeMaxEnabled *bool `json:"simulate_claude_max_enabled"`
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes *[]string `json:"supported_model_scopes"`
|
||||
// Sora 存储配额
|
||||
@@ -201,6 +203,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
ModelRouting: req.ModelRouting,
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
MCPXMLInject: req.MCPXMLInject,
|
||||
SimulateClaudeMaxEnabled: req.SimulateClaudeMaxEnabled,
|
||||
SupportedModelScopes: req.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||
@@ -252,6 +255,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
ModelRouting: req.ModelRouting,
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
MCPXMLInject: req.MCPXMLInject,
|
||||
SimulateClaudeMaxEnabled: req.SimulateClaudeMaxEnabled,
|
||||
SupportedModelScopes: req.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||
|
||||
@@ -122,13 +122,14 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
||||
return nil
|
||||
}
|
||||
out := &AdminGroup{
|
||||
Group: groupFromServiceBase(g),
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
MCPXMLInject: g.MCPXMLInject,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
AccountCount: g.AccountCount,
|
||||
SortOrder: g.SortOrder,
|
||||
Group: groupFromServiceBase(g),
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
MCPXMLInject: g.MCPXMLInject,
|
||||
SimulateClaudeMaxEnabled: g.SimulateClaudeMaxEnabled,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
AccountCount: g.AccountCount,
|
||||
SortOrder: g.SortOrder,
|
||||
}
|
||||
if len(g.AccountGroups) > 0 {
|
||||
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
||||
@@ -183,6 +184,7 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
Extra: a.Extra,
|
||||
ProxyID: a.ProxyID,
|
||||
Concurrency: a.Concurrency,
|
||||
LoadFactor: a.LoadFactor,
|
||||
Priority: a.Priority,
|
||||
RateMultiplier: a.BillingRateMultiplier(),
|
||||
Status: a.Status,
|
||||
@@ -248,6 +250,17 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
}
|
||||
}
|
||||
|
||||
// 提取 API Key 账号配额限制(仅 apikey 类型有效)
|
||||
if a.Type == service.AccountTypeAPIKey {
|
||||
if limit := a.GetQuotaLimit(); limit > 0 {
|
||||
out.QuotaLimit = &limit
|
||||
}
|
||||
used := a.GetQuotaUsed()
|
||||
if out.QuotaLimit != nil {
|
||||
out.QuotaUsed = &used
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
|
||||
@@ -111,6 +111,8 @@ type AdminGroup struct {
|
||||
|
||||
// MCP XML 协议注入(仅 antigravity 平台使用)
|
||||
MCPXMLInject bool `json:"mcp_xml_inject"`
|
||||
// Claude usage 模拟开关(仅管理员可见)
|
||||
SimulateClaudeMaxEnabled bool `json:"simulate_claude_max_enabled"`
|
||||
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||
@@ -131,6 +133,7 @@ type Account struct {
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
LoadFactor *int `json:"load_factor,omitempty"`
|
||||
Priority int `json:"priority"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
Status string `json:"status"`
|
||||
@@ -185,6 +188,10 @@ type Account struct {
|
||||
CacheTTLOverrideEnabled *bool `json:"cache_ttl_override_enabled,omitempty"`
|
||||
CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"`
|
||||
|
||||
// API Key 账号配额限制
|
||||
QuotaLimit *float64 `json:"quota_limit,omitempty"`
|
||||
QuotaUsed *float64 `json:"quota_used,omitempty"`
|
||||
|
||||
Proxy *Proxy `json:"proxy,omitempty"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
|
||||
|
||||
@@ -439,6 +439,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ParsedRequest: parsedReq,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
@@ -630,6 +631,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// ===== 用户消息串行队列 END =====
|
||||
|
||||
// 转发请求 - 根据账号平台分流
|
||||
c.Set("parsed_request", parsedReq)
|
||||
var result *service.ForwardResult
|
||||
requestCtx := c.Request.Context()
|
||||
if fs.SwitchCount > 0 {
|
||||
@@ -734,6 +736,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ParsedRequest: parsedReq,
|
||||
APIKey: currentAPIKey,
|
||||
User: currentAPIKey.User,
|
||||
Account: account,
|
||||
|
||||
@@ -2132,6 +2132,14 @@ func (r *stubAccountRepoForHandler) BulkUpdate(context.Context, []int64, service
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepoForHandler) IncrementQuotaUsed(context.Context, int64, float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepoForHandler) ResetQuotaUsed(context.Context, int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== Stub: SoraClient (用于 SoraGatewayService) ====================
|
||||
|
||||
var _ service.SoraClient = (*stubSoraClientForHandler)(nil)
|
||||
|
||||
@@ -216,6 +216,14 @@ func (r *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates s
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepo) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepo) listSchedulable() []service.Account {
|
||||
var result []service.Account
|
||||
for _, acc := range r.accounts {
|
||||
|
||||
@@ -18,6 +18,9 @@ const (
|
||||
BlockTypeFunction
|
||||
)
|
||||
|
||||
// UsageMapHook is a callback that can modify usage data before it's emitted in SSE events.
|
||||
type UsageMapHook func(usageMap map[string]any)
|
||||
|
||||
// StreamingProcessor 流式响应处理器
|
||||
type StreamingProcessor struct {
|
||||
blockType BlockType
|
||||
@@ -30,6 +33,7 @@ type StreamingProcessor struct {
|
||||
originalModel string
|
||||
webSearchQueries []string
|
||||
groundingChunks []GeminiGroundingChunk
|
||||
usageMapHook UsageMapHook
|
||||
|
||||
// 累计 usage
|
||||
inputTokens int
|
||||
@@ -45,6 +49,25 @@ func NewStreamingProcessor(originalModel string) *StreamingProcessor {
|
||||
}
|
||||
}
|
||||
|
||||
// SetUsageMapHook sets an optional hook that modifies usage maps before they are emitted.
|
||||
func (p *StreamingProcessor) SetUsageMapHook(fn UsageMapHook) {
|
||||
p.usageMapHook = fn
|
||||
}
|
||||
|
||||
func usageToMap(u ClaudeUsage) map[string]any {
|
||||
m := map[string]any{
|
||||
"input_tokens": u.InputTokens,
|
||||
"output_tokens": u.OutputTokens,
|
||||
}
|
||||
if u.CacheCreationInputTokens > 0 {
|
||||
m["cache_creation_input_tokens"] = u.CacheCreationInputTokens
|
||||
}
|
||||
if u.CacheReadInputTokens > 0 {
|
||||
m["cache_read_input_tokens"] = u.CacheReadInputTokens
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件
|
||||
func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
||||
line = strings.TrimSpace(line)
|
||||
@@ -158,6 +181,13 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
|
||||
responseID = "msg_" + generateRandomID()
|
||||
}
|
||||
|
||||
var usageValue any = usage
|
||||
if p.usageMapHook != nil {
|
||||
usageMap := usageToMap(usage)
|
||||
p.usageMapHook(usageMap)
|
||||
usageValue = usageMap
|
||||
}
|
||||
|
||||
message := map[string]any{
|
||||
"id": responseID,
|
||||
"type": "message",
|
||||
@@ -166,7 +196,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
|
||||
"model": p.originalModel,
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": usage,
|
||||
"usage": usageValue,
|
||||
}
|
||||
|
||||
event := map[string]any{
|
||||
@@ -477,13 +507,20 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
|
||||
CacheReadInputTokens: p.cacheReadTokens,
|
||||
}
|
||||
|
||||
var usageValue any = usage
|
||||
if p.usageMapHook != nil {
|
||||
usageMap := usageToMap(usage)
|
||||
p.usageMapHook(usageMap)
|
||||
usageValue = usageMap
|
||||
}
|
||||
|
||||
deltaEvent := map[string]any{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{
|
||||
"stop_reason": stopReason,
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
"usage": usage,
|
||||
"usage": usageValue,
|
||||
}
|
||||
|
||||
_, _ = result.Write(p.formatSSE("message_delta", deltaEvent))
|
||||
|
||||
@@ -84,6 +84,9 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
|
||||
if account.RateMultiplier != nil {
|
||||
builder.SetRateMultiplier(*account.RateMultiplier)
|
||||
}
|
||||
if account.LoadFactor != nil {
|
||||
builder.SetLoadFactor(*account.LoadFactor)
|
||||
}
|
||||
|
||||
if account.ProxyID != nil {
|
||||
builder.SetProxyID(*account.ProxyID)
|
||||
@@ -318,6 +321,11 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
||||
if account.RateMultiplier != nil {
|
||||
builder.SetRateMultiplier(*account.RateMultiplier)
|
||||
}
|
||||
if account.LoadFactor != nil {
|
||||
builder.SetLoadFactor(*account.LoadFactor)
|
||||
} else {
|
||||
builder.ClearLoadFactor()
|
||||
}
|
||||
|
||||
if account.ProxyID != nil {
|
||||
builder.SetProxyID(*account.ProxyID)
|
||||
@@ -1223,6 +1231,15 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
|
||||
args = append(args, *updates.RateMultiplier)
|
||||
idx++
|
||||
}
|
||||
if updates.LoadFactor != nil {
|
||||
if *updates.LoadFactor <= 0 {
|
||||
setClauses = append(setClauses, "load_factor = NULL")
|
||||
} else {
|
||||
setClauses = append(setClauses, "load_factor = $"+itoa(idx))
|
||||
args = append(args, *updates.LoadFactor)
|
||||
idx++
|
||||
}
|
||||
}
|
||||
if updates.Status != nil {
|
||||
setClauses = append(setClauses, "status = $"+itoa(idx))
|
||||
args = append(args, *updates.Status)
|
||||
@@ -1545,6 +1562,7 @@ func accountEntityToService(m *dbent.Account) *service.Account {
|
||||
Concurrency: m.Concurrency,
|
||||
Priority: m.Priority,
|
||||
RateMultiplier: &rateMultiplier,
|
||||
LoadFactor: m.LoadFactor,
|
||||
Status: m.Status,
|
||||
ErrorMessage: derefString(m.ErrorMessage),
|
||||
LastUsedAt: m.LastUsedAt,
|
||||
@@ -1657,3 +1675,60 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va
|
||||
|
||||
return r.accountsToService(ctx, accounts)
|
||||
}
|
||||
|
||||
// IncrementQuotaUsed 原子递增账号的 extra.quota_used 字段
|
||||
func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
rows, err := r.sql.QueryContext(ctx,
|
||||
`UPDATE accounts SET extra = jsonb_set(
|
||||
COALESCE(extra, '{}'::jsonb),
|
||||
'{quota_used}',
|
||||
to_jsonb(COALESCE((extra->>'quota_used')::numeric, 0) + $1)
|
||||
), updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
RETURNING
|
||||
COALESCE((extra->>'quota_used')::numeric, 0),
|
||||
COALESCE((extra->>'quota_limit')::numeric, 0)`,
|
||||
amount, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var newUsed, limit float64
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&newUsed, &limit); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 配额刚超限时触发调度快照刷新,使账号及时从调度候选中移除
|
||||
if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", id, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetQuotaUsed 重置账号的 extra.quota_used 为 0
|
||||
func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
_, err := r.sql.ExecContext(ctx,
|
||||
`UPDATE accounts SET extra = jsonb_set(
|
||||
COALESCE(extra, '{}'::jsonb),
|
||||
'{quota_used}',
|
||||
'0'::jsonb
|
||||
), updated_at = NOW()
|
||||
WHERE id = $1 AND deleted_at IS NULL`,
|
||||
id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 重置配额后触发调度快照刷新,使账号重新参与调度
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue quota reset failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -164,6 +164,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
group.FieldModelRoutingEnabled,
|
||||
group.FieldModelRouting,
|
||||
group.FieldMcpXMLInject,
|
||||
group.FieldSimulateClaudeMaxEnabled,
|
||||
group.FieldSupportedModelScopes,
|
||||
)
|
||||
}).
|
||||
@@ -617,6 +618,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
MCPXMLInject: g.McpXMLInject,
|
||||
SimulateClaudeMaxEnabled: g.SimulateClaudeMaxEnabled,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
SortOrder: g.SortOrder,
|
||||
CreatedAt: g.CreatedAt,
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
@@ -95,7 +96,8 @@ func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *se
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
|
||||
msg := fmt.Sprintf("API returned status %d: %s", resp.StatusCode, string(body))
|
||||
return nil, infraerrors.New(http.StatusInternalServerError, "UPSTREAM_ERROR", msg)
|
||||
}
|
||||
|
||||
var usageResp service.ClaudeUsageResponse
|
||||
|
||||
@@ -59,7 +59,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
||||
SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes)
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
|
||||
SetSimulateClaudeMaxEnabled(groupIn.SimulateClaudeMaxEnabled)
|
||||
|
||||
// 设置模型路由配置
|
||||
if groupIn.ModelRouting != nil {
|
||||
@@ -125,7 +126,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes)
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
|
||||
SetSimulateClaudeMaxEnabled(groupIn.SimulateClaudeMaxEnabled)
|
||||
|
||||
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
|
||||
if groupIn.DailyLimitUSD != nil {
|
||||
|
||||
@@ -1870,7 +1870,7 @@ func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, start
|
||||
query := `
|
||||
SELECT
|
||||
COALESCE(ul.group_id, 0) as group_id,
|
||||
COALESCE(g.name, '') as group_name,
|
||||
COALESCE(g.name, '(无分组)') as group_name,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(ul.total_cost), 0) as cost,
|
||||
|
||||
@@ -1096,6 +1096,14 @@ func (s *stubAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
|
||||
s.bulkUpdateIDs = append([]int64{}, ids...)
|
||||
return int64(len(ids)), nil
|
||||
|
||||
@@ -252,6 +252,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
|
||||
accounts.POST("/today-stats/batch", h.Admin.Account.GetBatchTodayStats)
|
||||
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
|
||||
accounts.POST("/:id/reset-quota", h.Admin.Account.ResetQuota)
|
||||
accounts.GET("/:id/temp-unschedulable", h.Admin.Account.GetTempUnschedulable)
|
||||
accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable)
|
||||
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
|
||||
|
||||
@@ -28,6 +28,7 @@ type Account struct {
|
||||
// RateMultiplier 账号计费倍率(>=0,允许 0 表示该账号计费为 0)。
|
||||
// 使用指针用于兼容旧版本调度缓存(Redis)中缺字段的情况:nil 表示按 1.0 处理。
|
||||
RateMultiplier *float64
|
||||
LoadFactor *int // 调度负载因子;nil 表示使用 Concurrency
|
||||
Status string
|
||||
ErrorMessage string
|
||||
LastUsedAt *time.Time
|
||||
@@ -88,6 +89,19 @@ func (a *Account) BillingRateMultiplier() float64 {
|
||||
return *a.RateMultiplier
|
||||
}
|
||||
|
||||
func (a *Account) EffectiveLoadFactor() int {
|
||||
if a == nil {
|
||||
return 1
|
||||
}
|
||||
if a.LoadFactor != nil && *a.LoadFactor > 0 {
|
||||
return *a.LoadFactor
|
||||
}
|
||||
if a.Concurrency > 0 {
|
||||
return a.Concurrency
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func (a *Account) IsSchedulable() bool {
|
||||
if !a.IsActive() || !a.Schedulable {
|
||||
return false
|
||||
@@ -1117,6 +1131,38 @@ func (a *Account) GetCacheTTLOverrideTarget() string {
|
||||
return "5m"
|
||||
}
|
||||
|
||||
// GetQuotaLimit 获取 API Key 账号的配额限制(美元)
|
||||
// 返回 0 表示未启用
|
||||
func (a *Account) GetQuotaLimit() float64 {
|
||||
if a.Extra == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := a.Extra["quota_limit"]; ok {
|
||||
return parseExtraFloat64(v)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetQuotaUsed 获取 API Key 账号的已用配额(美元)
|
||||
func (a *Account) GetQuotaUsed() float64 {
|
||||
if a.Extra == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := a.Extra["quota_used"]; ok {
|
||||
return parseExtraFloat64(v)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// IsQuotaExceeded 检查 API Key 账号配额是否已超限
|
||||
func (a *Account) IsQuotaExceeded() bool {
|
||||
limit := a.GetQuotaLimit()
|
||||
if limit <= 0 {
|
||||
return false
|
||||
}
|
||||
return a.GetQuotaUsed() >= limit
|
||||
}
|
||||
|
||||
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
|
||||
// 返回 0 表示未启用
|
||||
func (a *Account) GetWindowCostLimit() float64 {
|
||||
|
||||
42
backend/internal/service/account_load_factor_test.go
Normal file
42
backend/internal/service/account_load_factor_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEffectiveLoadFactor_NilAccount(t *testing.T) {
|
||||
var a *Account
|
||||
require.Equal(t, 1, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_NilLoadFactor_PositiveConcurrency(t *testing.T) {
|
||||
a := &Account{Concurrency: 5}
|
||||
require.Equal(t, 5, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_NilLoadFactor_ZeroConcurrency(t *testing.T) {
|
||||
a := &Account{Concurrency: 0}
|
||||
require.Equal(t, 1, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_PositiveLoadFactor(t *testing.T) {
|
||||
a := &Account{Concurrency: 5, LoadFactor: intPtr(20)}
|
||||
require.Equal(t, 20, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_ZeroLoadFactor_FallbackToConcurrency(t *testing.T) {
|
||||
a := &Account{Concurrency: 5, LoadFactor: intPtr(0)}
|
||||
require.Equal(t, 5, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_NegativeLoadFactor_FallbackToConcurrency(t *testing.T) {
|
||||
a := &Account{Concurrency: 3, LoadFactor: intPtr(-1)}
|
||||
require.Equal(t, 3, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_ZeroLoadFactor_ZeroConcurrency(t *testing.T) {
|
||||
a := &Account{Concurrency: 0, LoadFactor: intPtr(0)}
|
||||
require.Equal(t, 1, a.EffectiveLoadFactor())
|
||||
}
|
||||
@@ -68,6 +68,10 @@ type AccountRepository interface {
|
||||
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
|
||||
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
|
||||
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
|
||||
// IncrementQuotaUsed 原子递增 API Key 账号的配额用量
|
||||
IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error
|
||||
// ResetQuotaUsed 重置 API Key 账号的配额用量为 0
|
||||
ResetQuotaUsed(ctx context.Context, id int64) error
|
||||
}
|
||||
|
||||
// AccountBulkUpdate describes the fields that can be updated in a bulk operation.
|
||||
@@ -78,6 +82,7 @@ type AccountBulkUpdate struct {
|
||||
Concurrency *int
|
||||
Priority *int
|
||||
RateMultiplier *float64
|
||||
LoadFactor *int
|
||||
Status *string
|
||||
Schedulable *bool
|
||||
Credentials map[string]any
|
||||
|
||||
@@ -199,6 +199,14 @@ func (s *accountRepoStub) BulkUpdate(ctx context.Context, ids []int64, updates A
|
||||
panic("unexpected BulkUpdate call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestAccountService_Delete_NotFound 测试删除不存在的账号时返回正确的错误。
|
||||
// 预期行为:
|
||||
// - ExistsByID 返回 false(账号不存在)
|
||||
|
||||
@@ -180,7 +180,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
}
|
||||
|
||||
if account.Platform == PlatformAntigravity {
|
||||
return s.testAntigravityAccountConnection(c, account, modelID)
|
||||
return s.routeAntigravityTest(c, account, modelID)
|
||||
}
|
||||
|
||||
if account.Platform == PlatformSora {
|
||||
@@ -1177,6 +1177,18 @@ func truncateSoraErrorBody(body []byte, max int) string {
|
||||
return soraerror.TruncateBody(body, max)
|
||||
}
|
||||
|
||||
// routeAntigravityTest 路由 Antigravity 账号的测试请求。
|
||||
// APIKey 类型走原生协议(与 gateway_handler 路由一致),OAuth/Upstream 走 CRS 中转。
|
||||
func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string) error {
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if strings.HasPrefix(modelID, "gemini-") {
|
||||
return s.testGeminiAccountConnection(c, account, modelID)
|
||||
}
|
||||
return s.testClaudeAccountConnection(c, account, modelID)
|
||||
}
|
||||
return s.testAntigravityAccountConnection(c, account, modelID)
|
||||
}
|
||||
|
||||
// testAntigravityAccountConnection tests an Antigravity account's connection
|
||||
// 支持 Claude 和 Gemini 两种协议,使用非流式请求
|
||||
func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error {
|
||||
|
||||
@@ -84,6 +84,7 @@ type AdminService interface {
|
||||
DeleteRedeemCode(ctx context.Context, id int64) error
|
||||
BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error)
|
||||
ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
|
||||
ResetAccountQuota(ctx context.Context, id int64) error
|
||||
}
|
||||
|
||||
// CreateUserInput represents input for creating a new user via admin operations.
|
||||
@@ -137,9 +138,10 @@ type CreateGroupInput struct {
|
||||
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
||||
FallbackGroupIDOnInvalidRequest *int64
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled bool // 是否启用模型路由
|
||||
MCPXMLInject *bool
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled bool // 是否启用模型路由
|
||||
MCPXMLInject *bool
|
||||
SimulateClaudeMaxEnabled *bool
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string
|
||||
// Sora 存储配额
|
||||
@@ -173,9 +175,10 @@ type UpdateGroupInput struct {
|
||||
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
||||
FallbackGroupIDOnInvalidRequest *int64
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled *bool // 是否启用模型路由
|
||||
MCPXMLInject *bool
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled *bool // 是否启用模型路由
|
||||
MCPXMLInject *bool
|
||||
SimulateClaudeMaxEnabled *bool
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes *[]string
|
||||
// Sora 存储配额
|
||||
@@ -195,6 +198,7 @@ type CreateAccountInput struct {
|
||||
Concurrency int
|
||||
Priority int
|
||||
RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
|
||||
LoadFactor *int
|
||||
GroupIDs []int64
|
||||
ExpiresAt *int64
|
||||
AutoPauseOnExpired *bool
|
||||
@@ -215,6 +219,7 @@ type UpdateAccountInput struct {
|
||||
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
||||
Priority *int // 使用指针区分"未提供"和"设置为0"
|
||||
RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
|
||||
LoadFactor *int
|
||||
Status string
|
||||
GroupIDs *[]int64
|
||||
ExpiresAt *int64
|
||||
@@ -230,6 +235,7 @@ type BulkUpdateAccountsInput struct {
|
||||
Concurrency *int
|
||||
Priority *int
|
||||
RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
|
||||
LoadFactor *int
|
||||
Status string
|
||||
Schedulable *bool
|
||||
GroupIDs *[]int64
|
||||
@@ -353,6 +359,10 @@ type ProxyExitInfoProber interface {
|
||||
ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
|
||||
}
|
||||
|
||||
type groupExistenceBatchReader interface {
|
||||
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
|
||||
}
|
||||
|
||||
type proxyQualityTarget struct {
|
||||
Target string
|
||||
URL string
|
||||
@@ -428,10 +438,6 @@ type userGroupRateBatchReader interface {
|
||||
GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error)
|
||||
}
|
||||
|
||||
type groupExistenceBatchReader interface {
|
||||
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
|
||||
}
|
||||
|
||||
// NewAdminService creates a new AdminService
|
||||
func NewAdminService(
|
||||
userRepo UserRepository,
|
||||
@@ -847,6 +853,13 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
if input.MCPXMLInject != nil {
|
||||
mcpXMLInject = *input.MCPXMLInject
|
||||
}
|
||||
simulateClaudeMaxEnabled := false
|
||||
if input.SimulateClaudeMaxEnabled != nil {
|
||||
if platform != PlatformAnthropic && *input.SimulateClaudeMaxEnabled {
|
||||
return nil, fmt.Errorf("simulate_claude_max_enabled only supported for anthropic groups")
|
||||
}
|
||||
simulateClaudeMaxEnabled = *input.SimulateClaudeMaxEnabled
|
||||
}
|
||||
|
||||
// 如果指定了复制账号的源分组,先获取账号 ID 列表
|
||||
var accountIDsToCopy []int64
|
||||
@@ -903,6 +916,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest,
|
||||
ModelRouting: input.ModelRouting,
|
||||
MCPXMLInject: mcpXMLInject,
|
||||
SimulateClaudeMaxEnabled: simulateClaudeMaxEnabled,
|
||||
SupportedModelScopes: input.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
|
||||
}
|
||||
@@ -1112,6 +1126,15 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
if input.MCPXMLInject != nil {
|
||||
group.MCPXMLInject = *input.MCPXMLInject
|
||||
}
|
||||
if input.SimulateClaudeMaxEnabled != nil {
|
||||
if group.Platform != PlatformAnthropic && *input.SimulateClaudeMaxEnabled {
|
||||
return nil, fmt.Errorf("simulate_claude_max_enabled only supported for anthropic groups")
|
||||
}
|
||||
group.SimulateClaudeMaxEnabled = *input.SimulateClaudeMaxEnabled
|
||||
}
|
||||
if group.Platform != PlatformAnthropic {
|
||||
group.SimulateClaudeMaxEnabled = false
|
||||
}
|
||||
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
if input.SupportedModelScopes != nil {
|
||||
@@ -1413,6 +1436,9 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
||||
}
|
||||
account.RateMultiplier = input.RateMultiplier
|
||||
}
|
||||
if input.LoadFactor != nil && *input.LoadFactor > 0 {
|
||||
account.LoadFactor = input.LoadFactor
|
||||
}
|
||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1458,6 +1484,10 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
account.Credentials = input.Credentials
|
||||
}
|
||||
if len(input.Extra) > 0 {
|
||||
// 保留 quota_used,防止编辑账号时意外重置配额用量
|
||||
if oldQuotaUsed, ok := account.Extra["quota_used"]; ok {
|
||||
input.Extra["quota_used"] = oldQuotaUsed
|
||||
}
|
||||
account.Extra = input.Extra
|
||||
}
|
||||
if input.ProxyID != nil {
|
||||
@@ -1483,6 +1513,13 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
}
|
||||
account.RateMultiplier = input.RateMultiplier
|
||||
}
|
||||
if input.LoadFactor != nil {
|
||||
if *input.LoadFactor <= 0 {
|
||||
account.LoadFactor = nil // 0 或负数表示清除
|
||||
} else {
|
||||
account.LoadFactor = input.LoadFactor
|
||||
}
|
||||
}
|
||||
if input.Status != "" {
|
||||
account.Status = input.Status
|
||||
}
|
||||
@@ -1616,6 +1653,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
if input.RateMultiplier != nil {
|
||||
repoUpdates.RateMultiplier = input.RateMultiplier
|
||||
}
|
||||
if input.LoadFactor != nil {
|
||||
repoUpdates.LoadFactor = input.LoadFactor
|
||||
}
|
||||
if input.Status != "" {
|
||||
repoUpdates.Status = &input.Status
|
||||
}
|
||||
@@ -2439,3 +2479,7 @@ func (e *MixedChannelError) Error() string {
|
||||
return fmt.Sprintf("mixed_channel_warning: Group '%s' contains both %s and %s accounts. Using mixed channels in the same context may cause thinking block signature validation issues, which will fallback to non-thinking mode for historical messages.",
|
||||
e.GroupName, e.CurrentPlatform, e.OtherPlatform)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) ResetAccountQuota(ctx context.Context, id int64) error {
|
||||
return s.accountRepo.ResetQuotaUsed(ctx, id)
|
||||
}
|
||||
|
||||
@@ -43,6 +43,16 @@ func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID i
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) {
|
||||
if err, ok := s.listByGroupErr[groupID]; ok {
|
||||
return nil, err
|
||||
}
|
||||
if rows, ok := s.listByGroupData[groupID]; ok {
|
||||
return rows, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) GetByIDs(_ context.Context, ids []int64) ([]*Account, error) {
|
||||
s.getByIDsCalled = true
|
||||
s.getByIDsIDs = append([]int64{}, ids...)
|
||||
@@ -63,16 +73,6 @@ func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Ac
|
||||
return nil, errors.New("account not found")
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) {
|
||||
if err, ok := s.listByGroupErr[groupID]; ok {
|
||||
return nil, err
|
||||
}
|
||||
if rows, ok := s.listByGroupData[groupID]; ok {
|
||||
return rows, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
|
||||
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
|
||||
repo := &accountRepoStubForBulkUpdate{}
|
||||
|
||||
@@ -785,3 +785,57 @@ func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes
|
||||
require.NotNil(t, repo.updated)
|
||||
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
|
||||
}
|
||||
|
||||
func TestAdminService_CreateGroup_SimulateClaudeMaxRequiresAnthropic(t *testing.T) {
|
||||
repo := &groupRepoStubForAdmin{}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
enabled := true
|
||||
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "openai-group",
|
||||
Platform: PlatformOpenAI,
|
||||
SimulateClaudeMaxEnabled: &enabled,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "simulate_claude_max_enabled only supported for anthropic groups")
|
||||
require.Nil(t, repo.created)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_SimulateClaudeMaxRequiresAnthropic(t *testing.T) {
|
||||
existingGroup := &Group{
|
||||
ID: 1,
|
||||
Name: "openai-group",
|
||||
Platform: PlatformOpenAI,
|
||||
Status: StatusActive,
|
||||
}
|
||||
repo := &groupRepoStubForAdmin{getByID: existingGroup}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
enabled := true
|
||||
_, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
|
||||
SimulateClaudeMaxEnabled: &enabled,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "simulate_claude_max_enabled only supported for anthropic groups")
|
||||
require.Nil(t, repo.updated)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_ClearsSimulateClaudeMaxWhenPlatformChanges(t *testing.T) {
|
||||
existingGroup := &Group{
|
||||
ID: 1,
|
||||
Name: "anthropic-group",
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
SimulateClaudeMaxEnabled: true,
|
||||
}
|
||||
repo := &groupRepoStubForAdmin{getByID: existingGroup}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
|
||||
Platform: PlatformOpenAI,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
require.NotNil(t, repo.updated)
|
||||
require.False(t, repo.updated.SimulateClaudeMaxEnabled)
|
||||
}
|
||||
|
||||
@@ -1599,7 +1599,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
var clientDisconnect bool
|
||||
if claudeReq.Stream {
|
||||
// 客户端要求流式,直接透传转换
|
||||
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel)
|
||||
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel, account.ID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_error error=%v", prefix, err)
|
||||
return nil, err
|
||||
@@ -1609,7 +1609,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
clientDisconnect = streamRes.clientDisconnect
|
||||
} else {
|
||||
// 客户端要求非流式,收集流式响应后转换返回
|
||||
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel)
|
||||
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel, account.ID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_collect_error error=%v", prefix, err)
|
||||
return nil, err
|
||||
@@ -1618,6 +1618,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
firstTokenMs = streamRes.firstTokenMs
|
||||
}
|
||||
|
||||
// Claude Max cache billing: 同步 ForwardResult.Usage 与客户端响应体一致
|
||||
applyClaudeMaxCacheBillingPolicyToUsage(usage, parsedRequestFromGinContext(c), claudeMaxGroupFromGinContext(c), originalModel, account.ID)
|
||||
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
@@ -3415,7 +3418,7 @@ func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int,
|
||||
|
||||
// handleClaudeStreamToNonStreaming 收集上游流式响应,转换为 Claude 非流式格式返回
|
||||
// 用于处理客户端非流式请求但上游只支持流式的情况
|
||||
func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) {
|
||||
func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string, accountID int64) (*antigravityStreamResult, error) {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
|
||||
@@ -3573,6 +3576,9 @@ returnResponse:
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
|
||||
}
|
||||
|
||||
// Claude Max cache billing simulation (non-streaming)
|
||||
claudeResp = applyClaudeMaxNonStreamingRewrite(c, claudeResp, agUsage, originalModel, accountID)
|
||||
|
||||
c.Data(http.StatusOK, "application/json", claudeResp)
|
||||
|
||||
// 转换为 service.ClaudeUsage
|
||||
@@ -3587,7 +3593,7 @@ returnResponse:
|
||||
}
|
||||
|
||||
// handleClaudeStreamingResponse 处理 Claude 流式响应(Gemini SSE → Claude SSE 转换)
|
||||
func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) {
|
||||
func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string, accountID int64) (*antigravityStreamResult, error) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
@@ -3600,6 +3606,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
}
|
||||
|
||||
processor := antigravity.NewStreamingProcessor(originalModel)
|
||||
setupClaudeMaxStreamingHook(c, processor, originalModel, accountID)
|
||||
|
||||
var firstTokenMs *int
|
||||
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
|
||||
@@ -710,7 +710,7 @@ func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) {
|
||||
fmt.Fprintln(pw, "")
|
||||
}()
|
||||
|
||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
|
||||
_ = pr.Close()
|
||||
|
||||
require.NoError(t, err)
|
||||
@@ -787,7 +787,7 @@ func TestHandleClaudeStreamingResponse_ThoughtsTokenCount(t *testing.T) {
|
||||
fmt.Fprintln(pw, "")
|
||||
}()
|
||||
|
||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "gemini-2.5-pro")
|
||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "gemini-2.5-pro", 0)
|
||||
_ = pr.Close()
|
||||
|
||||
require.NoError(t, err)
|
||||
@@ -990,7 +990,7 @@ func TestHandleClaudeStreamingResponse_ClientDisconnect(t *testing.T) {
|
||||
fmt.Fprintln(pw, "")
|
||||
}()
|
||||
|
||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
|
||||
_ = pr.Close()
|
||||
|
||||
require.NoError(t, err)
|
||||
@@ -1014,7 +1014,7 @@ func TestHandleClaudeStreamingResponse_ContextCanceled(t *testing.T) {
|
||||
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
|
||||
|
||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
@@ -59,9 +59,10 @@ type APIKeyAuthGroupSnapshot struct {
|
||||
|
||||
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
|
||||
// Only anthropic groups use these fields; others may leave them empty.
|
||||
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||
MCPXMLInject bool `json:"mcp_xml_inject"`
|
||||
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||
MCPXMLInject bool `json:"mcp_xml_inject"`
|
||||
SimulateClaudeMaxEnabled bool `json:"simulate_claude_max_enabled"`
|
||||
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
|
||||
|
||||
@@ -244,6 +244,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
||||
ModelRouting: apiKey.Group.ModelRouting,
|
||||
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
|
||||
MCPXMLInject: apiKey.Group.MCPXMLInject,
|
||||
SimulateClaudeMaxEnabled: apiKey.Group.SimulateClaudeMaxEnabled,
|
||||
SupportedModelScopes: apiKey.Group.SupportedModelScopes,
|
||||
}
|
||||
}
|
||||
@@ -301,6 +302,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
||||
ModelRouting: snapshot.Group.ModelRouting,
|
||||
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
|
||||
MCPXMLInject: snapshot.Group.MCPXMLInject,
|
||||
SimulateClaudeMaxEnabled: snapshot.Group.SimulateClaudeMaxEnabled,
|
||||
SupportedModelScopes: snapshot.Group.SupportedModelScopes,
|
||||
}
|
||||
}
|
||||
|
||||
450
backend/internal/service/claude_max_cache_billing_policy.go
Normal file
450
backend/internal/service/claude_max_cache_billing_policy.go
Normal file
@@ -0,0 +1,450 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type claudeMaxCacheBillingOutcome struct {
|
||||
Simulated bool
|
||||
}
|
||||
|
||||
func applyClaudeMaxCacheBillingPolicyToUsage(usage *ClaudeUsage, parsed *ParsedRequest, group *Group, model string, accountID int64) claudeMaxCacheBillingOutcome {
|
||||
var out claudeMaxCacheBillingOutcome
|
||||
if usage == nil || !shouldApplyClaudeMaxBillingRulesForUsage(group, model, parsed) {
|
||||
return out
|
||||
}
|
||||
|
||||
resolvedModel := strings.TrimSpace(model)
|
||||
if resolvedModel == "" && parsed != nil {
|
||||
resolvedModel = strings.TrimSpace(parsed.Model)
|
||||
}
|
||||
|
||||
if hasCacheCreationTokens(*usage) {
|
||||
// Upstream already returned cache creation usage; keep original usage.
|
||||
return out
|
||||
}
|
||||
|
||||
if !shouldSimulateClaudeMaxUsageForUsage(*usage, parsed) {
|
||||
return out
|
||||
}
|
||||
beforeInputTokens := usage.InputTokens
|
||||
out.Simulated = safelyProjectUsageToClaudeMax1H(usage, parsed)
|
||||
if out.Simulated {
|
||||
logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage: model=%s account=%d input_tokens:%d->%d cache_creation_1h=%d",
|
||||
resolvedModel,
|
||||
accountID,
|
||||
beforeInputTokens,
|
||||
usage.InputTokens,
|
||||
usage.CacheCreation1hTokens,
|
||||
)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func isClaudeFamilyModel(model string) bool {
|
||||
normalized := strings.ToLower(strings.TrimSpace(claude.NormalizeModelID(model)))
|
||||
if normalized == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(normalized, "claude-")
|
||||
}
|
||||
|
||||
func shouldApplyClaudeMaxBillingRules(input *RecordUsageInput) bool {
|
||||
if input == nil || input.Result == nil || input.APIKey == nil || input.APIKey.Group == nil {
|
||||
return false
|
||||
}
|
||||
return shouldApplyClaudeMaxBillingRulesForUsage(input.APIKey.Group, input.Result.Model, input.ParsedRequest)
|
||||
}
|
||||
|
||||
func shouldApplyClaudeMaxBillingRulesForUsage(group *Group, model string, parsed *ParsedRequest) bool {
|
||||
if group == nil {
|
||||
return false
|
||||
}
|
||||
if !group.SimulateClaudeMaxEnabled || group.Platform != PlatformAnthropic {
|
||||
return false
|
||||
}
|
||||
|
||||
resolvedModel := model
|
||||
if resolvedModel == "" && parsed != nil {
|
||||
resolvedModel = parsed.Model
|
||||
}
|
||||
if !isClaudeFamilyModel(resolvedModel) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func hasCacheCreationTokens(usage ClaudeUsage) bool {
|
||||
return usage.CacheCreationInputTokens > 0 || usage.CacheCreation5mTokens > 0 || usage.CacheCreation1hTokens > 0
|
||||
}
|
||||
|
||||
func shouldSimulateClaudeMaxUsage(input *RecordUsageInput) bool {
|
||||
if input == nil || input.Result == nil {
|
||||
return false
|
||||
}
|
||||
if !shouldApplyClaudeMaxBillingRules(input) {
|
||||
return false
|
||||
}
|
||||
return shouldSimulateClaudeMaxUsageForUsage(input.Result.Usage, input.ParsedRequest)
|
||||
}
|
||||
|
||||
func shouldSimulateClaudeMaxUsageForUsage(usage ClaudeUsage, parsed *ParsedRequest) bool {
|
||||
if usage.InputTokens <= 0 {
|
||||
return false
|
||||
}
|
||||
if hasCacheCreationTokens(usage) {
|
||||
return false
|
||||
}
|
||||
if !hasClaudeCacheSignals(parsed) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func safelyProjectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) (changed bool) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage skipped: panic=%v", r)
|
||||
changed = false
|
||||
}
|
||||
}()
|
||||
return projectUsageToClaudeMax1H(usage, parsed)
|
||||
}
|
||||
|
||||
func projectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) bool {
|
||||
if usage == nil {
|
||||
return false
|
||||
}
|
||||
totalWindowTokens := usage.InputTokens + usage.CacheCreation5mTokens + usage.CacheCreation1hTokens
|
||||
if totalWindowTokens <= 1 {
|
||||
return false
|
||||
}
|
||||
|
||||
simulatedInputTokens := computeClaudeMaxProjectedInputTokens(totalWindowTokens, parsed)
|
||||
if simulatedInputTokens <= 0 {
|
||||
simulatedInputTokens = 1
|
||||
}
|
||||
if simulatedInputTokens >= totalWindowTokens {
|
||||
simulatedInputTokens = totalWindowTokens - 1
|
||||
}
|
||||
|
||||
cacheCreation1hTokens := totalWindowTokens - simulatedInputTokens
|
||||
if usage.InputTokens == simulatedInputTokens &&
|
||||
usage.CacheCreation5mTokens == 0 &&
|
||||
usage.CacheCreation1hTokens == cacheCreation1hTokens &&
|
||||
usage.CacheCreationInputTokens == cacheCreation1hTokens {
|
||||
return false
|
||||
}
|
||||
|
||||
usage.InputTokens = simulatedInputTokens
|
||||
usage.CacheCreation5mTokens = 0
|
||||
usage.CacheCreation1hTokens = cacheCreation1hTokens
|
||||
usage.CacheCreationInputTokens = cacheCreation1hTokens
|
||||
return true
|
||||
}
|
||||
|
||||
type claudeCacheProjection struct {
|
||||
HasBreakpoint bool
|
||||
BreakpointCount int
|
||||
TotalEstimatedTokens int
|
||||
TailEstimatedTokens int
|
||||
}
|
||||
|
||||
func computeClaudeMaxProjectedInputTokens(totalWindowTokens int, parsed *ParsedRequest) int {
|
||||
if totalWindowTokens <= 1 {
|
||||
return totalWindowTokens
|
||||
}
|
||||
|
||||
projection := analyzeClaudeCacheProjection(parsed)
|
||||
if !projection.HasBreakpoint || projection.TotalEstimatedTokens <= 0 || projection.TailEstimatedTokens <= 0 {
|
||||
return totalWindowTokens
|
||||
}
|
||||
|
||||
totalEstimate := int64(projection.TotalEstimatedTokens)
|
||||
tailEstimate := int64(projection.TailEstimatedTokens)
|
||||
if tailEstimate > totalEstimate {
|
||||
tailEstimate = totalEstimate
|
||||
}
|
||||
|
||||
scaled := (int64(totalWindowTokens)*tailEstimate + totalEstimate/2) / totalEstimate
|
||||
if scaled <= 0 {
|
||||
scaled = 1
|
||||
}
|
||||
if scaled >= int64(totalWindowTokens) {
|
||||
scaled = int64(totalWindowTokens - 1)
|
||||
}
|
||||
return int(scaled)
|
||||
}
|
||||
|
||||
func hasClaudeCacheSignals(parsed *ParsedRequest) bool {
|
||||
if parsed == nil {
|
||||
return false
|
||||
}
|
||||
if hasTopLevelEphemeralCacheControl(parsed) {
|
||||
return true
|
||||
}
|
||||
return countExplicitCacheBreakpoints(parsed) > 0
|
||||
}
|
||||
|
||||
func hasTopLevelEphemeralCacheControl(parsed *ParsedRequest) bool {
|
||||
if parsed == nil || len(parsed.Body) == 0 {
|
||||
return false
|
||||
}
|
||||
cacheType := strings.TrimSpace(gjson.GetBytes(parsed.Body, "cache_control.type").String())
|
||||
return strings.EqualFold(cacheType, "ephemeral")
|
||||
}
|
||||
|
||||
func analyzeClaudeCacheProjection(parsed *ParsedRequest) claudeCacheProjection {
|
||||
var projection claudeCacheProjection
|
||||
if parsed == nil {
|
||||
return projection
|
||||
}
|
||||
|
||||
total := 0
|
||||
lastBreakpointAt := -1
|
||||
|
||||
switch system := parsed.System.(type) {
|
||||
case string:
|
||||
total += claudeMaxMessageOverheadTokens + estimateClaudeTextTokens(system)
|
||||
case []any:
|
||||
for _, raw := range system {
|
||||
block, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
total += claudeMaxUnknownContentTokens
|
||||
continue
|
||||
}
|
||||
total += estimateClaudeBlockTokens(block)
|
||||
if hasEphemeralCacheControl(block) {
|
||||
lastBreakpointAt = total
|
||||
projection.BreakpointCount++
|
||||
projection.HasBreakpoint = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, rawMsg := range parsed.Messages {
|
||||
total += claudeMaxMessageOverheadTokens
|
||||
msg, ok := rawMsg.(map[string]any)
|
||||
if !ok {
|
||||
total += claudeMaxUnknownContentTokens
|
||||
continue
|
||||
}
|
||||
content, exists := msg["content"]
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
msgTokens, msgLastBreak, msgBreakCount := estimateClaudeContentTokens(content)
|
||||
total += msgTokens
|
||||
if msgBreakCount > 0 {
|
||||
lastBreakpointAt = total - msgTokens + msgLastBreak
|
||||
projection.BreakpointCount += msgBreakCount
|
||||
projection.HasBreakpoint = true
|
||||
}
|
||||
}
|
||||
|
||||
if total <= 0 {
|
||||
total = 1
|
||||
}
|
||||
projection.TotalEstimatedTokens = total
|
||||
|
||||
if projection.HasBreakpoint && lastBreakpointAt >= 0 {
|
||||
tail := total - lastBreakpointAt
|
||||
if tail <= 0 {
|
||||
tail = 1
|
||||
}
|
||||
projection.TailEstimatedTokens = tail
|
||||
return projection
|
||||
}
|
||||
|
||||
if hasTopLevelEphemeralCacheControl(parsed) {
|
||||
tail := estimateLastUserMessageTokens(parsed)
|
||||
if tail <= 0 {
|
||||
tail = 1
|
||||
}
|
||||
projection.HasBreakpoint = true
|
||||
projection.BreakpointCount = 1
|
||||
projection.TailEstimatedTokens = tail
|
||||
}
|
||||
return projection
|
||||
}
|
||||
|
||||
func countExplicitCacheBreakpoints(parsed *ParsedRequest) int {
|
||||
if parsed == nil {
|
||||
return 0
|
||||
}
|
||||
total := 0
|
||||
if system, ok := parsed.System.([]any); ok {
|
||||
for _, raw := range system {
|
||||
if block, ok := raw.(map[string]any); ok && hasEphemeralCacheControl(block) {
|
||||
total++
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, rawMsg := range parsed.Messages {
|
||||
msg, ok := rawMsg.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
content, ok := msg["content"].([]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, raw := range content {
|
||||
if block, ok := raw.(map[string]any); ok && hasEphemeralCacheControl(block) {
|
||||
total++
|
||||
}
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func hasEphemeralCacheControl(block map[string]any) bool {
|
||||
if block == nil {
|
||||
return false
|
||||
}
|
||||
raw, ok := block["cache_control"]
|
||||
if !ok || raw == nil {
|
||||
return false
|
||||
}
|
||||
switch cc := raw.(type) {
|
||||
case map[string]any:
|
||||
cacheType, _ := cc["type"].(string)
|
||||
return strings.EqualFold(strings.TrimSpace(cacheType), "ephemeral")
|
||||
case map[string]string:
|
||||
return strings.EqualFold(strings.TrimSpace(cc["type"]), "ephemeral")
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func estimateClaudeContentTokens(content any) (tokens int, lastBreakAt int, breakpointCount int) {
|
||||
switch value := content.(type) {
|
||||
case string:
|
||||
return estimateClaudeTextTokens(value), -1, 0
|
||||
case []any:
|
||||
total := 0
|
||||
lastBreak := -1
|
||||
breaks := 0
|
||||
for _, raw := range value {
|
||||
block, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
total += claudeMaxUnknownContentTokens
|
||||
continue
|
||||
}
|
||||
total += estimateClaudeBlockTokens(block)
|
||||
if hasEphemeralCacheControl(block) {
|
||||
lastBreak = total
|
||||
breaks++
|
||||
}
|
||||
}
|
||||
return total, lastBreak, breaks
|
||||
default:
|
||||
return estimateStructuredTokens(value), -1, 0
|
||||
}
|
||||
}
|
||||
|
||||
func estimateClaudeBlockTokens(block map[string]any) int {
|
||||
if block == nil {
|
||||
return claudeMaxUnknownContentTokens
|
||||
}
|
||||
tokens := claudeMaxBlockOverheadTokens
|
||||
blockType, _ := block["type"].(string)
|
||||
switch blockType {
|
||||
case "text":
|
||||
if text, ok := block["text"].(string); ok {
|
||||
tokens += estimateClaudeTextTokens(text)
|
||||
}
|
||||
case "tool_result":
|
||||
if content, ok := block["content"]; ok {
|
||||
nested, _, _ := estimateClaudeContentTokens(content)
|
||||
tokens += nested
|
||||
}
|
||||
case "tool_use":
|
||||
if name, ok := block["name"].(string); ok {
|
||||
tokens += estimateClaudeTextTokens(name)
|
||||
}
|
||||
if input, ok := block["input"]; ok {
|
||||
tokens += estimateStructuredTokens(input)
|
||||
}
|
||||
default:
|
||||
if text, ok := block["text"].(string); ok {
|
||||
tokens += estimateClaudeTextTokens(text)
|
||||
} else if content, ok := block["content"]; ok {
|
||||
nested, _, _ := estimateClaudeContentTokens(content)
|
||||
tokens += nested
|
||||
}
|
||||
}
|
||||
if tokens <= claudeMaxBlockOverheadTokens {
|
||||
tokens += claudeMaxUnknownContentTokens
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
func estimateLastUserMessageTokens(parsed *ParsedRequest) int {
|
||||
if parsed == nil || len(parsed.Messages) == 0 {
|
||||
return 0
|
||||
}
|
||||
for i := len(parsed.Messages) - 1; i >= 0; i-- {
|
||||
msg, ok := parsed.Messages[i].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role, _ := msg["role"].(string)
|
||||
if !strings.EqualFold(strings.TrimSpace(role), "user") {
|
||||
continue
|
||||
}
|
||||
tokens, _, _ := estimateClaudeContentTokens(msg["content"])
|
||||
return claudeMaxMessageOverheadTokens + tokens
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func estimateStructuredTokens(v any) int {
|
||||
if v == nil {
|
||||
return 0
|
||||
}
|
||||
raw, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return claudeMaxUnknownContentTokens
|
||||
}
|
||||
return estimateClaudeTextTokens(string(raw))
|
||||
}
|
||||
|
||||
func estimateClaudeTextTokens(text string) int {
|
||||
if tokens, ok := estimateTokensByThirdPartyTokenizer(text); ok {
|
||||
return tokens
|
||||
}
|
||||
return estimateClaudeTextTokensHeuristic(text)
|
||||
}
|
||||
|
||||
func estimateClaudeTextTokensHeuristic(text string) int {
|
||||
normalized := strings.Join(strings.Fields(strings.TrimSpace(text)), " ")
|
||||
if normalized == "" {
|
||||
return 0
|
||||
}
|
||||
asciiChars := 0
|
||||
nonASCIIChars := 0
|
||||
for _, r := range normalized {
|
||||
if r <= 127 {
|
||||
asciiChars++
|
||||
} else {
|
||||
nonASCIIChars++
|
||||
}
|
||||
}
|
||||
tokens := nonASCIIChars
|
||||
if asciiChars > 0 {
|
||||
tokens += (asciiChars + 3) / 4
|
||||
}
|
||||
if words := len(strings.Fields(normalized)); words > tokens {
|
||||
tokens = words
|
||||
}
|
||||
if tokens <= 0 {
|
||||
return 1
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
156
backend/internal/service/claude_max_simulation_test.go
Normal file
156
backend/internal/service/claude_max_simulation_test.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestProjectUsageToClaudeMax1H_Conservation(t *testing.T) {
|
||||
usage := &ClaudeUsage{
|
||||
InputTokens: 1200,
|
||||
CacheCreationInputTokens: 0,
|
||||
CacheCreation5mTokens: 0,
|
||||
CacheCreation1hTokens: 0,
|
||||
}
|
||||
parsed := &ParsedRequest{
|
||||
Model: "claude-sonnet-4-5",
|
||||
Messages: []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": strings.Repeat("cached context ", 200),
|
||||
"cache_control": map[string]any{"type": "ephemeral"},
|
||||
},
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "summarize quickly",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
changed := projectUsageToClaudeMax1H(usage, parsed)
|
||||
if !changed {
|
||||
t.Fatalf("expected usage to be projected")
|
||||
}
|
||||
|
||||
total := usage.InputTokens + usage.CacheCreation5mTokens + usage.CacheCreation1hTokens
|
||||
if total != 1200 {
|
||||
t.Fatalf("total tokens changed: got=%d want=%d", total, 1200)
|
||||
}
|
||||
if usage.CacheCreation5mTokens != 0 {
|
||||
t.Fatalf("cache_creation_5m should be 0, got=%d", usage.CacheCreation5mTokens)
|
||||
}
|
||||
if usage.InputTokens <= 0 || usage.InputTokens >= 1200 {
|
||||
t.Fatalf("simulated input out of range, got=%d", usage.InputTokens)
|
||||
}
|
||||
if usage.InputTokens > 100 {
|
||||
t.Fatalf("simulated input should stay near cache breakpoint tail, got=%d", usage.InputTokens)
|
||||
}
|
||||
if usage.CacheCreation1hTokens <= 0 {
|
||||
t.Fatalf("cache_creation_1h should be > 0, got=%d", usage.CacheCreation1hTokens)
|
||||
}
|
||||
if usage.CacheCreationInputTokens != usage.CacheCreation1hTokens {
|
||||
t.Fatalf("cache_creation_input_tokens mismatch: got=%d want=%d", usage.CacheCreationInputTokens, usage.CacheCreation1hTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeClaudeMaxProjectedInputTokens_Deterministic(t *testing.T) {
|
||||
parsed := &ParsedRequest{
|
||||
Model: "claude-opus-4-5",
|
||||
Messages: []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "build context",
|
||||
"cache_control": map[string]any{"type": "ephemeral"},
|
||||
},
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "what is failing now",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got1 := computeClaudeMaxProjectedInputTokens(4096, parsed)
|
||||
got2 := computeClaudeMaxProjectedInputTokens(4096, parsed)
|
||||
if got1 != got2 {
|
||||
t.Fatalf("non-deterministic input tokens: %d != %d", got1, got2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldSimulateClaudeMaxUsage(t *testing.T) {
|
||||
group := &Group{
|
||||
Platform: PlatformAnthropic,
|
||||
SimulateClaudeMaxEnabled: true,
|
||||
}
|
||||
input := &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
Model: "claude-sonnet-4-5",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 3000,
|
||||
CacheCreationInputTokens: 0,
|
||||
CacheCreation5mTokens: 0,
|
||||
CacheCreation1hTokens: 0,
|
||||
},
|
||||
},
|
||||
ParsedRequest: &ParsedRequest{
|
||||
Messages: []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "cached",
|
||||
"cache_control": map[string]any{"type": "ephemeral"},
|
||||
},
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "tail",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
APIKey: &APIKey{Group: group},
|
||||
}
|
||||
|
||||
if !shouldSimulateClaudeMaxUsage(input) {
|
||||
t.Fatalf("expected simulate=true for claude group with cache signal")
|
||||
}
|
||||
|
||||
input.ParsedRequest = &ParsedRequest{
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "no cache signal"},
|
||||
},
|
||||
}
|
||||
if shouldSimulateClaudeMaxUsage(input) {
|
||||
t.Fatalf("expected simulate=false when request has no cache signal")
|
||||
}
|
||||
|
||||
input.ParsedRequest = &ParsedRequest{
|
||||
Messages: []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "cached",
|
||||
"cache_control": map[string]any{"type": "ephemeral"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
input.Result.Usage.CacheCreationInputTokens = 100
|
||||
if shouldSimulateClaudeMaxUsage(input) {
|
||||
t.Fatalf("expected simulate=false when cache creation already exists")
|
||||
}
|
||||
}
|
||||
41
backend/internal/service/claude_tokenizer.go
Normal file
41
backend/internal/service/claude_tokenizer.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
tiktoken "github.com/pkoukk/tiktoken-go"
|
||||
tiktokenloader "github.com/pkoukk/tiktoken-go-loader"
|
||||
)
|
||||
|
||||
var (
|
||||
claudeTokenizerOnce sync.Once
|
||||
claudeTokenizer *tiktoken.Tiktoken
|
||||
)
|
||||
|
||||
func getClaudeTokenizer() *tiktoken.Tiktoken {
|
||||
claudeTokenizerOnce.Do(func() {
|
||||
// Use offline loader to avoid runtime dictionary download.
|
||||
tiktoken.SetBpeLoader(tiktokenloader.NewOfflineLoader())
|
||||
// Use a high-capacity tokenizer as the default approximation for Claude payloads.
|
||||
enc, err := tiktoken.GetEncoding(tiktoken.MODEL_O200K_BASE)
|
||||
if err != nil {
|
||||
enc, err = tiktoken.GetEncoding(tiktoken.MODEL_CL100K_BASE)
|
||||
}
|
||||
if err == nil {
|
||||
claudeTokenizer = enc
|
||||
}
|
||||
})
|
||||
return claudeTokenizer
|
||||
}
|
||||
|
||||
func estimateTokensByThirdPartyTokenizer(text string) (int, bool) {
|
||||
enc := getClaudeTokenizer()
|
||||
if enc == nil {
|
||||
return 0, false
|
||||
}
|
||||
tokens := len(enc.EncodeOrdinary(text))
|
||||
if tokens <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
return tokens, true
|
||||
}
|
||||
@@ -331,8 +331,9 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
|
||||
}()
|
||||
}
|
||||
|
||||
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
|
||||
// Returns a map of accountID -> current concurrency count
|
||||
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts.
|
||||
// Uses a detached context with timeout to prevent HTTP request cancellation from
|
||||
// causing the entire batch to fail (which would show all concurrency as 0).
|
||||
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
if len(accountIDs) == 0 {
|
||||
return map[int64]int{}, nil
|
||||
@@ -344,5 +345,11 @@ func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, acc
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
return s.cache.GetAccountConcurrencyBatch(ctx, accountIDs)
|
||||
|
||||
// Use a detached context so that a cancelled HTTP request doesn't cause
|
||||
// the Redis pipeline to fail and return all-zero concurrency counts.
|
||||
redisCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
return s.cache.GetAccountConcurrencyBatch(redisCtx, accountIDs)
|
||||
}
|
||||
|
||||
@@ -220,7 +220,7 @@ func TestApplyErrorPassthroughRule_SkipMonitoringSetsContextKey(t *testing.T) {
|
||||
v, exists := c.Get(OpsSkipPassthroughKey)
|
||||
assert.True(t, exists, "OpsSkipPassthroughKey should be set when skip_monitoring=true")
|
||||
boolVal, ok := v.(bool)
|
||||
assert.True(t, ok, "value should be bool")
|
||||
assert.True(t, ok, "value should be a bool")
|
||||
assert.True(t, boolVal)
|
||||
}
|
||||
|
||||
|
||||
196
backend/internal/service/gateway_claude_max_response_helpers.go
Normal file
196
backend/internal/service/gateway_claude_max_response_helpers.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
type claudeMaxResponseRewriteContext struct {
|
||||
Parsed *ParsedRequest
|
||||
Group *Group
|
||||
}
|
||||
|
||||
type claudeMaxResponseRewriteContextKeyType struct{}
|
||||
|
||||
var claudeMaxResponseRewriteContextKey = claudeMaxResponseRewriteContextKeyType{}
|
||||
|
||||
func withClaudeMaxResponseRewriteContext(ctx context.Context, c *gin.Context, parsed *ParsedRequest) context.Context {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
value := claudeMaxResponseRewriteContext{
|
||||
Parsed: parsed,
|
||||
Group: claudeMaxGroupFromGinContext(c),
|
||||
}
|
||||
return context.WithValue(ctx, claudeMaxResponseRewriteContextKey, value)
|
||||
}
|
||||
|
||||
func claudeMaxResponseRewriteContextFromContext(ctx context.Context) claudeMaxResponseRewriteContext {
|
||||
if ctx == nil {
|
||||
return claudeMaxResponseRewriteContext{}
|
||||
}
|
||||
value, _ := ctx.Value(claudeMaxResponseRewriteContextKey).(claudeMaxResponseRewriteContext)
|
||||
return value
|
||||
}
|
||||
|
||||
func claudeMaxGroupFromGinContext(c *gin.Context) *Group {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
raw, exists := c.Get("api_key")
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
apiKey, ok := raw.(*APIKey)
|
||||
if !ok || apiKey == nil {
|
||||
return nil
|
||||
}
|
||||
return apiKey.Group
|
||||
}
|
||||
|
||||
func parsedRequestFromGinContext(c *gin.Context) *ParsedRequest {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
raw, exists := c.Get("parsed_request")
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
parsed, _ := raw.(*ParsedRequest)
|
||||
return parsed
|
||||
}
|
||||
|
||||
func applyClaudeMaxSimulationToUsage(ctx context.Context, usage *ClaudeUsage, model string, accountID int64) claudeMaxCacheBillingOutcome {
|
||||
var out claudeMaxCacheBillingOutcome
|
||||
if usage == nil {
|
||||
return out
|
||||
}
|
||||
rewriteCtx := claudeMaxResponseRewriteContextFromContext(ctx)
|
||||
return applyClaudeMaxCacheBillingPolicyToUsage(usage, rewriteCtx.Parsed, rewriteCtx.Group, model, accountID)
|
||||
}
|
||||
|
||||
func applyClaudeMaxSimulationToUsageJSONMap(ctx context.Context, usageObj map[string]any, model string, accountID int64) claudeMaxCacheBillingOutcome {
|
||||
var out claudeMaxCacheBillingOutcome
|
||||
if usageObj == nil {
|
||||
return out
|
||||
}
|
||||
usage := claudeUsageFromJSONMap(usageObj)
|
||||
out = applyClaudeMaxSimulationToUsage(ctx, &usage, model, accountID)
|
||||
if out.Simulated {
|
||||
rewriteClaudeUsageJSONMap(usageObj, usage)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func rewriteClaudeUsageJSONBytes(body []byte, usage ClaudeUsage) []byte {
|
||||
updated := body
|
||||
var err error
|
||||
|
||||
updated, err = sjson.SetBytes(updated, "usage.input_tokens", usage.InputTokens)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
updated, err = sjson.SetBytes(updated, "usage.cache_creation_input_tokens", usage.CacheCreationInputTokens)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
updated, err = sjson.SetBytes(updated, "usage.cache_creation.ephemeral_5m_input_tokens", usage.CacheCreation5mTokens)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
updated, err = sjson.SetBytes(updated, "usage.cache_creation.ephemeral_1h_input_tokens", usage.CacheCreation1hTokens)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
func claudeUsageFromJSONMap(usageObj map[string]any) ClaudeUsage {
|
||||
var usage ClaudeUsage
|
||||
if usageObj == nil {
|
||||
return usage
|
||||
}
|
||||
|
||||
usage.InputTokens = usageIntFromAny(usageObj["input_tokens"])
|
||||
usage.OutputTokens = usageIntFromAny(usageObj["output_tokens"])
|
||||
usage.CacheCreationInputTokens = usageIntFromAny(usageObj["cache_creation_input_tokens"])
|
||||
usage.CacheReadInputTokens = usageIntFromAny(usageObj["cache_read_input_tokens"])
|
||||
|
||||
if ccObj, ok := usageObj["cache_creation"].(map[string]any); ok {
|
||||
usage.CacheCreation5mTokens = usageIntFromAny(ccObj["ephemeral_5m_input_tokens"])
|
||||
usage.CacheCreation1hTokens = usageIntFromAny(ccObj["ephemeral_1h_input_tokens"])
|
||||
}
|
||||
return usage
|
||||
}
|
||||
|
||||
func rewriteClaudeUsageJSONMap(usageObj map[string]any, usage ClaudeUsage) {
|
||||
if usageObj == nil {
|
||||
return
|
||||
}
|
||||
usageObj["input_tokens"] = usage.InputTokens
|
||||
usageObj["cache_creation_input_tokens"] = usage.CacheCreationInputTokens
|
||||
|
||||
ccObj, _ := usageObj["cache_creation"].(map[string]any)
|
||||
if ccObj == nil {
|
||||
ccObj = make(map[string]any, 2)
|
||||
usageObj["cache_creation"] = ccObj
|
||||
}
|
||||
ccObj["ephemeral_5m_input_tokens"] = usage.CacheCreation5mTokens
|
||||
ccObj["ephemeral_1h_input_tokens"] = usage.CacheCreation1hTokens
|
||||
}
|
||||
|
||||
func usageIntFromAny(v any) int {
|
||||
switch value := v.(type) {
|
||||
case int:
|
||||
return value
|
||||
case int64:
|
||||
return int(value)
|
||||
case float64:
|
||||
return int(value)
|
||||
case json.Number:
|
||||
if n, err := value.Int64(); err == nil {
|
||||
return int(n)
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// setupClaudeMaxStreamingHook 为 Antigravity 流式路径设置 SSE usage 改写 hook。
|
||||
func setupClaudeMaxStreamingHook(c *gin.Context, processor *antigravity.StreamingProcessor, originalModel string, accountID int64) {
|
||||
group := claudeMaxGroupFromGinContext(c)
|
||||
parsed := parsedRequestFromGinContext(c)
|
||||
if !shouldApplyClaudeMaxBillingRulesForUsage(group, originalModel, parsed) {
|
||||
return
|
||||
}
|
||||
processor.SetUsageMapHook(func(usageMap map[string]any) {
|
||||
svcUsage := claudeUsageFromJSONMap(usageMap)
|
||||
outcome := applyClaudeMaxCacheBillingPolicyToUsage(&svcUsage, parsed, group, originalModel, accountID)
|
||||
if outcome.Simulated {
|
||||
rewriteClaudeUsageJSONMap(usageMap, svcUsage)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// applyClaudeMaxNonStreamingRewrite 为 Antigravity 非流式路径改写响应体中的 usage。
|
||||
func applyClaudeMaxNonStreamingRewrite(c *gin.Context, claudeResp []byte, agUsage *antigravity.ClaudeUsage, originalModel string, accountID int64) []byte {
|
||||
group := claudeMaxGroupFromGinContext(c)
|
||||
parsed := parsedRequestFromGinContext(c)
|
||||
if !shouldApplyClaudeMaxBillingRulesForUsage(group, originalModel, parsed) {
|
||||
return claudeResp
|
||||
}
|
||||
svcUsage := &ClaudeUsage{
|
||||
InputTokens: agUsage.InputTokens,
|
||||
OutputTokens: agUsage.OutputTokens,
|
||||
CacheCreationInputTokens: agUsage.CacheCreationInputTokens,
|
||||
CacheReadInputTokens: agUsage.CacheReadInputTokens,
|
||||
}
|
||||
outcome := applyClaudeMaxCacheBillingPolicyToUsage(svcUsage, parsed, group, originalModel, accountID)
|
||||
if outcome.Simulated {
|
||||
return rewriteClaudeUsageJSONBytes(claudeResp, *svcUsage)
|
||||
}
|
||||
return claudeResp
|
||||
}
|
||||
@@ -187,6 +187,14 @@ func (m *mockAccountRepoForPlatform) BulkUpdate(ctx context.Context, ids []int64
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForPlatform) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForPlatform) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Verify interface implementation
|
||||
var _ AccountRepository = (*mockAccountRepoForPlatform)(nil)
|
||||
|
||||
|
||||
199
backend/internal/service/gateway_record_usage_claude_max_test.go
Normal file
199
backend/internal/service/gateway_record_usage_claude_max_test.go
Normal file
@@ -0,0 +1,199 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type usageLogRepoRecordUsageStub struct {
|
||||
UsageLogRepository
|
||||
|
||||
last *UsageLog
|
||||
inserted bool
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *usageLogRepoRecordUsageStub) Create(_ context.Context, log *UsageLog) (bool, error) {
|
||||
copied := *log
|
||||
s.last = &copied
|
||||
return s.inserted, s.err
|
||||
}
|
||||
|
||||
func newGatewayServiceForRecordUsageTest(repo UsageLogRepository) *GatewayService {
|
||||
return &GatewayService{
|
||||
usageLogRepo: repo,
|
||||
billingService: NewBillingService(&config.Config{}, nil),
|
||||
cfg: &config.Config{RunMode: config.RunModeSimple},
|
||||
deferredService: &DeferredService{},
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordUsage_SimulateClaudeMaxEnabled_ProjectsUsageAndSkipsTTLOverride(t *testing.T) {
|
||||
repo := &usageLogRepoRecordUsageStub{inserted: true}
|
||||
svc := newGatewayServiceForRecordUsageTest(repo)
|
||||
|
||||
groupID := int64(11)
|
||||
input := &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "req-sim-1",
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 160,
|
||||
},
|
||||
},
|
||||
ParsedRequest: &ParsedRequest{
|
||||
Model: "claude-sonnet-4",
|
||||
Messages: []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "long cached context for prior turns",
|
||||
"cache_control": map[string]any{"type": "ephemeral"},
|
||||
},
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "please summarize the logs and provide root cause analysis",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 1,
|
||||
GroupID: &groupID,
|
||||
Group: &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
RateMultiplier: 1,
|
||||
SimulateClaudeMaxEnabled: true,
|
||||
},
|
||||
},
|
||||
User: &User{ID: 2},
|
||||
Account: &Account{
|
||||
ID: 3,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"cache_ttl_override_enabled": true,
|
||||
"cache_ttl_override_target": "5m",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := svc.RecordUsage(context.Background(), input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, repo.last)
|
||||
|
||||
log := repo.last
|
||||
require.Equal(t, 80, log.InputTokens)
|
||||
require.Equal(t, 80, log.CacheCreationTokens)
|
||||
require.Equal(t, 0, log.CacheCreation5mTokens)
|
||||
require.Equal(t, 80, log.CacheCreation1hTokens)
|
||||
require.False(t, log.CacheTTLOverridden, "simulate outcome should skip account ttl override")
|
||||
}
|
||||
|
||||
func TestRecordUsage_SimulateClaudeMaxDisabled_AppliesTTLOverride(t *testing.T) {
|
||||
repo := &usageLogRepoRecordUsageStub{inserted: true}
|
||||
svc := newGatewayServiceForRecordUsageTest(repo)
|
||||
|
||||
groupID := int64(12)
|
||||
input := &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "req-sim-2",
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 40,
|
||||
CacheCreationInputTokens: 120,
|
||||
CacheCreation1hTokens: 120,
|
||||
},
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 2,
|
||||
GroupID: &groupID,
|
||||
Group: &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
RateMultiplier: 1,
|
||||
SimulateClaudeMaxEnabled: false,
|
||||
},
|
||||
},
|
||||
User: &User{ID: 3},
|
||||
Account: &Account{
|
||||
ID: 4,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"cache_ttl_override_enabled": true,
|
||||
"cache_ttl_override_target": "5m",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := svc.RecordUsage(context.Background(), input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, repo.last)
|
||||
|
||||
log := repo.last
|
||||
require.Equal(t, 120, log.CacheCreationTokens)
|
||||
require.Equal(t, 120, log.CacheCreation5mTokens)
|
||||
require.Equal(t, 0, log.CacheCreation1hTokens)
|
||||
require.True(t, log.CacheTTLOverridden)
|
||||
}
|
||||
|
||||
func TestRecordUsage_SimulateClaudeMaxEnabled_ExistingCacheCreationBypassesSimulation(t *testing.T) {
|
||||
repo := &usageLogRepoRecordUsageStub{inserted: true}
|
||||
svc := newGatewayServiceForRecordUsageTest(repo)
|
||||
|
||||
groupID := int64(13)
|
||||
input := &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "req-sim-3",
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 20,
|
||||
CacheCreationInputTokens: 120,
|
||||
CacheCreation5mTokens: 120,
|
||||
},
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 3,
|
||||
GroupID: &groupID,
|
||||
Group: &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
RateMultiplier: 1,
|
||||
SimulateClaudeMaxEnabled: true,
|
||||
},
|
||||
},
|
||||
User: &User{ID: 4},
|
||||
Account: &Account{
|
||||
ID: 5,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"cache_ttl_override_enabled": true,
|
||||
"cache_ttl_override_target": "5m",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := svc.RecordUsage(context.Background(), input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, repo.last)
|
||||
|
||||
log := repo.last
|
||||
require.Equal(t, 20, log.InputTokens)
|
||||
require.Equal(t, 120, log.CacheCreation5mTokens)
|
||||
require.Equal(t, 0, log.CacheCreation1hTokens)
|
||||
require.Equal(t, 120, log.CacheCreationTokens)
|
||||
require.False(t, log.CacheTTLOverridden, "existing cache_creation with SimulateClaudeMax enabled should skip account ttl override")
|
||||
}
|
||||
170
backend/internal/service/gateway_response_usage_sync_test.go
Normal file
170
backend/internal/service/gateway_response_usage_sync_test.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestHandleNonStreamingResponse_UsageAlignedWithClaudeMaxSimulation(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{},
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 11,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"cache_ttl_override_enabled": true,
|
||||
"cache_ttl_override_target": "5m",
|
||||
},
|
||||
}
|
||||
group := &Group{
|
||||
ID: 99,
|
||||
Platform: PlatformAnthropic,
|
||||
SimulateClaudeMaxEnabled: true,
|
||||
}
|
||||
parsed := &ParsedRequest{
|
||||
Model: "claude-sonnet-4",
|
||||
Messages: []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "long cached context",
|
||||
"cache_control": map[string]any{"type": "ephemeral"},
|
||||
},
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "new user question",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
upstreamBody := []byte(`{"id":"msg_1","model":"claude-sonnet-4","usage":{"input_tokens":120,"output_tokens":8}}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: ioNopCloserBytes(upstreamBody),
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(nil))
|
||||
c.Set("api_key", &APIKey{Group: group})
|
||||
requestCtx := withClaudeMaxResponseRewriteContext(context.Background(), c, parsed)
|
||||
|
||||
usage, err := svc.handleNonStreamingResponse(requestCtx, resp, c, account, "claude-sonnet-4", "claude-sonnet-4")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
|
||||
var rendered struct {
|
||||
Usage ClaudeUsage `json:"usage"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &rendered))
|
||||
rendered.Usage.CacheCreation5mTokens = int(gjson.GetBytes(rec.Body.Bytes(), "usage.cache_creation.ephemeral_5m_input_tokens").Int())
|
||||
rendered.Usage.CacheCreation1hTokens = int(gjson.GetBytes(rec.Body.Bytes(), "usage.cache_creation.ephemeral_1h_input_tokens").Int())
|
||||
|
||||
require.Equal(t, rendered.Usage.InputTokens, usage.InputTokens)
|
||||
require.Equal(t, rendered.Usage.OutputTokens, usage.OutputTokens)
|
||||
require.Equal(t, rendered.Usage.CacheCreationInputTokens, usage.CacheCreationInputTokens)
|
||||
require.Equal(t, rendered.Usage.CacheCreation5mTokens, usage.CacheCreation5mTokens)
|
||||
require.Equal(t, rendered.Usage.CacheCreation1hTokens, usage.CacheCreation1hTokens)
|
||||
require.Equal(t, rendered.Usage.CacheReadInputTokens, usage.CacheReadInputTokens)
|
||||
|
||||
require.Greater(t, usage.CacheCreation1hTokens, 0)
|
||||
require.Equal(t, 0, usage.CacheCreation5mTokens)
|
||||
require.Less(t, usage.InputTokens, 120)
|
||||
}
|
||||
|
||||
func TestHandleNonStreamingResponse_ClaudeMaxDisabled_NoSimulationIntercept(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{},
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 12,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"cache_ttl_override_enabled": true,
|
||||
"cache_ttl_override_target": "5m",
|
||||
},
|
||||
}
|
||||
group := &Group{
|
||||
ID: 100,
|
||||
Platform: PlatformAnthropic,
|
||||
SimulateClaudeMaxEnabled: false,
|
||||
}
|
||||
parsed := &ParsedRequest{
|
||||
Model: "claude-sonnet-4",
|
||||
Messages: []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "long cached context",
|
||||
"cache_control": map[string]any{"type": "ephemeral"},
|
||||
},
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "new user question",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
upstreamBody := []byte(`{"id":"msg_2","model":"claude-sonnet-4","usage":{"input_tokens":120,"output_tokens":8}}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: ioNopCloserBytes(upstreamBody),
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(nil))
|
||||
c.Set("api_key", &APIKey{Group: group})
|
||||
requestCtx := withClaudeMaxResponseRewriteContext(context.Background(), c, parsed)
|
||||
|
||||
usage, err := svc.handleNonStreamingResponse(requestCtx, resp, c, account, "claude-sonnet-4", "claude-sonnet-4")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
|
||||
require.Equal(t, 120, usage.InputTokens)
|
||||
require.Equal(t, 0, usage.CacheCreationInputTokens)
|
||||
require.Equal(t, 0, usage.CacheCreation5mTokens)
|
||||
require.Equal(t, 0, usage.CacheCreation1hTokens)
|
||||
}
|
||||
|
||||
func ioNopCloserBytes(b []byte) *readCloserFromBytes {
|
||||
return &readCloserFromBytes{Reader: bytes.NewReader(b)}
|
||||
}
|
||||
|
||||
type readCloserFromBytes struct {
|
||||
*bytes.Reader
|
||||
}
|
||||
|
||||
func (r *readCloserFromBytes) Close() error {
|
||||
return nil
|
||||
}
|
||||
@@ -56,6 +56,12 @@ const (
|
||||
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
|
||||
)
|
||||
|
||||
const (
|
||||
claudeMaxMessageOverheadTokens = 3
|
||||
claudeMaxBlockOverheadTokens = 1
|
||||
claudeMaxUnknownContentTokens = 4
|
||||
)
|
||||
|
||||
// ForceCacheBillingContextKey 强制缓存计费上下文键
|
||||
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
|
||||
type forceCacheBillingKeyType struct{}
|
||||
@@ -1228,6 +1234,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
modelScopeSkippedIDs = append(modelScopeSkippedIDs, account.ID)
|
||||
continue
|
||||
}
|
||||
// 配额检查
|
||||
if !s.isAccountSchedulableForQuota(account) {
|
||||
continue
|
||||
}
|
||||
// 窗口费用检查(非粘性会话路径)
|
||||
if !s.isAccountSchedulableForWindowCost(ctx, account, false) {
|
||||
filteredWindowCost++
|
||||
@@ -1260,6 +1270,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
|
||||
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) &&
|
||||
s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) &&
|
||||
s.isAccountSchedulableForQuota(stickyAccount) &&
|
||||
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) &&
|
||||
|
||||
s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查
|
||||
@@ -1311,7 +1322,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
for _, acc := range routingCandidates {
|
||||
routingLoads = append(routingLoads, AccountWithConcurrency{
|
||||
ID: acc.ID,
|
||||
MaxConcurrency: acc.Concurrency,
|
||||
MaxConcurrency: acc.EffectiveLoadFactor(),
|
||||
})
|
||||
}
|
||||
routingLoadMap, _ := s.concurrencyService.GetAccountsLoadBatch(ctx, routingLoads)
|
||||
@@ -1416,6 +1427,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
|
||||
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) &&
|
||||
s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) &&
|
||||
s.isAccountSchedulableForQuota(account) &&
|
||||
s.isAccountSchedulableForWindowCost(ctx, account, true) &&
|
||||
|
||||
s.isAccountSchedulableForRPM(ctx, account, true) { // 粘性会话窗口费用+RPM 检查
|
||||
@@ -1480,6 +1492,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
// 配额检查
|
||||
if !s.isAccountSchedulableForQuota(acc) {
|
||||
continue
|
||||
}
|
||||
// 窗口费用检查(非粘性会话路径)
|
||||
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
|
||||
continue
|
||||
@@ -1499,7 +1515,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
for _, acc := range candidates {
|
||||
accountLoads = append(accountLoads, AccountWithConcurrency{
|
||||
ID: acc.ID,
|
||||
MaxConcurrency: acc.Concurrency,
|
||||
MaxConcurrency: acc.EffectiveLoadFactor(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2113,6 +2129,15 @@ func (s *GatewayService) withWindowCostPrefetch(ctx context.Context, accounts []
|
||||
return context.WithValue(ctx, windowCostPrefetchContextKey, costs)
|
||||
}
|
||||
|
||||
// isAccountSchedulableForQuota 检查 API Key 账号是否在配额限制内
|
||||
// 仅适用于配置了 quota_limit 的 apikey 类型账号
|
||||
func (s *GatewayService) isAccountSchedulableForQuota(account *Account) bool {
|
||||
if account.Type != AccountTypeAPIKey {
|
||||
return true
|
||||
}
|
||||
return !account.IsQuotaExceeded()
|
||||
}
|
||||
|
||||
// isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度
|
||||
// 仅适用于 Anthropic OAuth/SetupToken 账号
|
||||
// 返回 true 表示可调度,false 表示不可调度
|
||||
@@ -2590,7 +2615,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if s.debugModelRoutingEnabled() {
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
}
|
||||
@@ -2644,6 +2669,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForQuota(acc) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
|
||||
continue
|
||||
}
|
||||
@@ -2700,7 +2728,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
@@ -2743,6 +2771,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForQuota(acc) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
|
||||
continue
|
||||
}
|
||||
@@ -2818,7 +2849,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
if s.debugModelRoutingEnabled() {
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
@@ -2874,6 +2905,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForQuota(acc) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
|
||||
continue
|
||||
}
|
||||
@@ -2930,7 +2964,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
return account, nil
|
||||
}
|
||||
@@ -2975,6 +3009,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForQuota(acc) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
|
||||
continue
|
||||
}
|
||||
@@ -4317,6 +4354,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}
|
||||
|
||||
// 处理正常响应
|
||||
ctx = withClaudeMaxResponseRewriteContext(ctx, c, parsed)
|
||||
|
||||
// 触发上游接受回调(提前释放串行锁,不等流完成)
|
||||
if parsed.OnUpstreamAccepted != nil {
|
||||
@@ -5773,6 +5811,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
|
||||
needModelReplace := originalModel != mappedModel
|
||||
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
|
||||
skipAccountTTLOverride := false
|
||||
|
||||
pendingEventLines := make([]string, 0, 4)
|
||||
|
||||
@@ -5833,17 +5872,25 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
if msg, ok := event["message"].(map[string]any); ok {
|
||||
if u, ok := msg["usage"].(map[string]any); ok {
|
||||
eventChanged = reconcileCachedTokens(u) || eventChanged
|
||||
claudeMaxOutcome := applyClaudeMaxSimulationToUsageJSONMap(ctx, u, originalModel, account.ID)
|
||||
if claudeMaxOutcome.Simulated {
|
||||
skipAccountTTLOverride = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if eventType == "message_delta" {
|
||||
if u, ok := event["usage"].(map[string]any); ok {
|
||||
eventChanged = reconcileCachedTokens(u) || eventChanged
|
||||
claudeMaxOutcome := applyClaudeMaxSimulationToUsageJSONMap(ctx, u, originalModel, account.ID)
|
||||
if claudeMaxOutcome.Simulated {
|
||||
skipAccountTTLOverride = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类
|
||||
if account.IsCacheTTLOverrideEnabled() {
|
||||
if account.IsCacheTTLOverrideEnabled() && !skipAccountTTLOverride {
|
||||
overrideTarget := account.GetCacheTTLOverrideTarget()
|
||||
if eventType == "message_start" {
|
||||
if msg, ok := event["message"].(map[string]any); ok {
|
||||
@@ -6253,8 +6300,13 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
||||
}
|
||||
}
|
||||
|
||||
claudeMaxOutcome := applyClaudeMaxSimulationToUsage(ctx, &response.Usage, originalModel, account.ID)
|
||||
if claudeMaxOutcome.Simulated {
|
||||
body = rewriteClaudeUsageJSONBytes(body, response.Usage)
|
||||
}
|
||||
|
||||
// Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类
|
||||
if account.IsCacheTTLOverrideEnabled() {
|
||||
if account.IsCacheTTLOverrideEnabled() && !claudeMaxOutcome.Simulated {
|
||||
overrideTarget := account.GetCacheTTLOverrideTarget()
|
||||
if applyCacheTTLOverride(&response.Usage, overrideTarget) {
|
||||
// 同步更新 body JSON 中的嵌套 cache_creation 对象
|
||||
@@ -6363,6 +6415,7 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID,
|
||||
// RecordUsageInput 记录使用量的输入参数
|
||||
type RecordUsageInput struct {
|
||||
Result *ForwardResult
|
||||
ParsedRequest *ParsedRequest
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
@@ -6379,6 +6432,89 @@ type APIKeyQuotaUpdater interface {
|
||||
UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error
|
||||
}
|
||||
|
||||
// postUsageBillingParams 统一扣费所需的参数
|
||||
type postUsageBillingParams struct {
|
||||
Cost *CostBreakdown
|
||||
User *User
|
||||
APIKey *APIKey
|
||||
Account *Account
|
||||
Subscription *UserSubscription
|
||||
IsSubscriptionBill bool
|
||||
AccountRateMultiplier float64
|
||||
APIKeyService APIKeyQuotaUpdater
|
||||
}
|
||||
|
||||
// postUsageBilling 统一处理使用量记录后的扣费逻辑:
|
||||
// - 订阅/余额扣费
|
||||
// - API Key 配额更新
|
||||
// - API Key 限速用量更新
|
||||
// - 账号配额用量更新(账号口径:TotalCost × 账号计费倍率)
|
||||
func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) {
|
||||
cost := p.Cost
|
||||
|
||||
// 1. 订阅 / 余额扣费
|
||||
if p.IsSubscriptionBill {
|
||||
if cost.TotalCost > 0 {
|
||||
if err := deps.userSubRepo.IncrementUsage(ctx, p.Subscription.ID, cost.TotalCost); err != nil {
|
||||
slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err)
|
||||
}
|
||||
deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost)
|
||||
}
|
||||
} else {
|
||||
if cost.ActualCost > 0 {
|
||||
if err := deps.userRepo.DeductBalance(ctx, p.User.ID, cost.ActualCost); err != nil {
|
||||
slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err)
|
||||
}
|
||||
deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. API Key 配额
|
||||
if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil {
|
||||
if err := p.APIKeyService.UpdateQuotaUsed(ctx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||
slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. API Key 限速用量
|
||||
if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
|
||||
if err := p.APIKeyService.UpdateRateLimitUsage(ctx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||
slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err)
|
||||
}
|
||||
deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, cost.ActualCost)
|
||||
}
|
||||
|
||||
// 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率)
|
||||
if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.GetQuotaLimit() > 0 {
|
||||
accountCost := cost.TotalCost * p.AccountRateMultiplier
|
||||
if err := deps.accountRepo.IncrementQuotaUsed(ctx, p.Account.ID, accountCost); err != nil {
|
||||
slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 5. 更新账号最近使用时间
|
||||
deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
|
||||
}
|
||||
|
||||
// billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供)
|
||||
type billingDeps struct {
|
||||
accountRepo AccountRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
billingCacheService *BillingCacheService
|
||||
deferredService *DeferredService
|
||||
}
|
||||
|
||||
func (s *GatewayService) billingDeps() *billingDeps {
|
||||
return &billingDeps{
|
||||
accountRepo: s.accountRepo,
|
||||
userRepo: s.userRepo,
|
||||
userSubRepo: s.userSubRepo,
|
||||
billingCacheService: s.billingCacheService,
|
||||
deferredService: s.deferredService,
|
||||
}
|
||||
}
|
||||
|
||||
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
||||
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
|
||||
result := input.Result
|
||||
@@ -6396,9 +6532,19 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
result.Usage.InputTokens = 0
|
||||
}
|
||||
|
||||
// Claude Max cache billing policy (group-level):
|
||||
// - GatewayService 路径: Forward 已改写 usage(含 cache tokens)→ apply 见到 cache tokens 跳过 → simulatedClaudeMax=true(通过第二条件)
|
||||
// - Antigravity 路径: Forward 中 hook 改写了客户端 SSE,但 ForwardResult.Usage 是原始值 → apply 实际执行模拟 → simulatedClaudeMax=true
|
||||
var apiKeyGroup *Group
|
||||
if apiKey != nil {
|
||||
apiKeyGroup = apiKey.Group
|
||||
}
|
||||
claudeMaxOutcome := applyClaudeMaxCacheBillingPolicyToUsage(&result.Usage, input.ParsedRequest, apiKeyGroup, result.Model, account.ID)
|
||||
simulatedClaudeMax := claudeMaxOutcome.Simulated ||
|
||||
(shouldApplyClaudeMaxBillingRulesForUsage(apiKeyGroup, result.Model, input.ParsedRequest) && hasCacheCreationTokens(result.Usage))
|
||||
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
|
||||
cacheTTLOverridden := false
|
||||
if account.IsCacheTTLOverrideEnabled() {
|
||||
if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax {
|
||||
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
|
||||
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
|
||||
}
|
||||
@@ -6542,45 +6688,21 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
|
||||
shouldBill := inserted || err != nil
|
||||
|
||||
// 根据计费类型执行扣费
|
||||
if isSubscriptionBilling {
|
||||
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
|
||||
if shouldBill && cost.TotalCost > 0 {
|
||||
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err)
|
||||
}
|
||||
// 异步更新订阅缓存
|
||||
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
|
||||
}
|
||||
if shouldBill {
|
||||
postUsageBilling(ctx, &postUsageBillingParams{
|
||||
Cost: cost,
|
||||
User: user,
|
||||
APIKey: apiKey,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
IsSubscriptionBill: isSubscriptionBilling,
|
||||
AccountRateMultiplier: accountRateMultiplier,
|
||||
APIKeyService: input.APIKeyService,
|
||||
}, s.billingDeps())
|
||||
} else {
|
||||
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
|
||||
if shouldBill && cost.ActualCost > 0 {
|
||||
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err)
|
||||
}
|
||||
// 异步更新余额缓存
|
||||
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
|
||||
}
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
}
|
||||
|
||||
// 更新 API Key 配额(如果设置了配额限制)
|
||||
if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil {
|
||||
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Update API key quota failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Update API Key rate limit usage
|
||||
if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil {
|
||||
if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err)
|
||||
}
|
||||
s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost)
|
||||
}
|
||||
|
||||
// Schedule batch update for account last_used_at
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -6740,44 +6862,21 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
|
||||
shouldBill := inserted || err != nil
|
||||
|
||||
// 根据计费类型执行扣费
|
||||
if isSubscriptionBilling {
|
||||
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
|
||||
if shouldBill && cost.TotalCost > 0 {
|
||||
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err)
|
||||
}
|
||||
// 异步更新订阅缓存
|
||||
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
|
||||
}
|
||||
if shouldBill {
|
||||
postUsageBilling(ctx, &postUsageBillingParams{
|
||||
Cost: cost,
|
||||
User: user,
|
||||
APIKey: apiKey,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
IsSubscriptionBill: isSubscriptionBilling,
|
||||
AccountRateMultiplier: accountRateMultiplier,
|
||||
APIKeyService: input.APIKeyService,
|
||||
}, s.billingDeps())
|
||||
} else {
|
||||
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
|
||||
if shouldBill && cost.ActualCost > 0 {
|
||||
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err)
|
||||
}
|
||||
// 异步更新余额缓存
|
||||
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
|
||||
// API Key 独立配额扣费
|
||||
if input.APIKeyService != nil && apiKey.Quota > 0 {
|
||||
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Add API key quota used failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
}
|
||||
|
||||
// Update API Key rate limit usage
|
||||
if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil {
|
||||
if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err)
|
||||
}
|
||||
s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost)
|
||||
}
|
||||
|
||||
// Schedule batch update for account last_used_at
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -176,6 +176,14 @@ func (m *mockAccountRepoForGemini) BulkUpdate(ctx context.Context, ids []int64,
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForGemini) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForGemini) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Verify interface implementation
|
||||
var _ AccountRepository = (*mockAccountRepoForGemini)(nil)
|
||||
|
||||
|
||||
@@ -50,6 +50,9 @@ type Group struct {
|
||||
// MCP XML 协议注入开关(仅 antigravity 平台使用)
|
||||
MCPXMLInject bool
|
||||
|
||||
// Claude usage 模拟开关:将无写缓存 usage 模拟为 claude-max 风格
|
||||
SimulateClaudeMaxEnabled bool
|
||||
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
// 可选值: claude, gemini_text, gemini_image
|
||||
SupportedModelScopes []string
|
||||
|
||||
@@ -590,7 +590,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
filtered = append(filtered, account)
|
||||
loadReq = append(loadReq, AccountWithConcurrency{
|
||||
ID: account.ID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
MaxConcurrency: account.EffectiveLoadFactor(),
|
||||
})
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
|
||||
@@ -319,6 +319,16 @@ func NewOpenAIGatewayService(
|
||||
return svc
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) billingDeps() *billingDeps {
|
||||
return &billingDeps{
|
||||
accountRepo: s.accountRepo,
|
||||
userRepo: s.userRepo,
|
||||
userSubRepo: s.userSubRepo,
|
||||
billingCacheService: s.billingCacheService,
|
||||
deferredService: s.deferredService,
|
||||
}
|
||||
}
|
||||
|
||||
// CloseOpenAIWSPool 关闭 OpenAI WebSocket 连接池的后台 worker 和空闲连接。
|
||||
// 应在应用优雅关闭时调用。
|
||||
func (s *OpenAIGatewayService) CloseOpenAIWSPool() {
|
||||
@@ -1242,7 +1252,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
for _, acc := range candidates {
|
||||
accountLoads = append(accountLoads, AccountWithConcurrency{
|
||||
ID: acc.ID,
|
||||
MaxConcurrency: acc.Concurrency,
|
||||
MaxConcurrency: acc.EffectiveLoadFactor(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3474,37 +3484,21 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
|
||||
shouldBill := inserted || err != nil
|
||||
|
||||
// Deduct based on billing type
|
||||
if isSubscriptionBilling {
|
||||
if shouldBill && cost.TotalCost > 0 {
|
||||
_ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost)
|
||||
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
|
||||
}
|
||||
if shouldBill {
|
||||
postUsageBilling(ctx, &postUsageBillingParams{
|
||||
Cost: cost,
|
||||
User: user,
|
||||
APIKey: apiKey,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
IsSubscriptionBill: isSubscriptionBilling,
|
||||
AccountRateMultiplier: accountRateMultiplier,
|
||||
APIKeyService: input.APIKeyService,
|
||||
}, s.billingDeps())
|
||||
} else {
|
||||
if shouldBill && cost.ActualCost > 0 {
|
||||
_ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost)
|
||||
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
|
||||
}
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
}
|
||||
|
||||
// Update API key quota if applicable (only for balance mode with quota set)
|
||||
if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil {
|
||||
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
||||
logger.LegacyPrintf("service.openai_gateway", "Update API key quota failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Update API Key rate limit usage
|
||||
if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil {
|
||||
if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
||||
logger.LegacyPrintf("service.openai_gateway", "Update API key rate limit usage failed: %v", err)
|
||||
}
|
||||
s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost)
|
||||
}
|
||||
|
||||
// Schedule batch update for account last_used_at
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -864,7 +864,8 @@ func isOpenAIWSClientDisconnectError(err error) bool {
|
||||
strings.Contains(message, "unexpected eof") ||
|
||||
strings.Contains(message, "use of closed network connection") ||
|
||||
strings.Contains(message, "connection reset by peer") ||
|
||||
strings.Contains(message, "broken pipe")
|
||||
strings.Contains(message, "broken pipe") ||
|
||||
strings.Contains(message, "an established connection was aborted")
|
||||
}
|
||||
|
||||
func classifyOpenAIWSReadFallbackReason(err error) string {
|
||||
|
||||
@@ -64,8 +64,9 @@ func (s *OpsService) getAccountsLoadMapBestEffort(ctx context.Context, accounts
|
||||
if acc.ID <= 0 {
|
||||
continue
|
||||
}
|
||||
if prev, ok := unique[acc.ID]; !ok || acc.Concurrency > prev {
|
||||
unique[acc.ID] = acc.Concurrency
|
||||
lf := acc.EffectiveLoadFactor()
|
||||
if prev, ok := unique[acc.ID]; !ok || lf > prev {
|
||||
unique[acc.ID] = lf
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -389,13 +389,9 @@ func (c *OpsMetricsCollector) collectConcurrencyQueueDepth(parentCtx context.Con
|
||||
if acc.ID <= 0 {
|
||||
continue
|
||||
}
|
||||
maxConc := acc.Concurrency
|
||||
if maxConc < 0 {
|
||||
maxConc = 0
|
||||
}
|
||||
batch = append(batch, AccountWithConcurrency{
|
||||
ID: acc.ID,
|
||||
MaxConcurrency: maxConc,
|
||||
MaxConcurrency: acc.EffectiveLoadFactor(),
|
||||
})
|
||||
}
|
||||
if len(batch) == 0 {
|
||||
|
||||
@@ -34,9 +34,10 @@ func TestCalculateProgress_BasicFields(t *testing.T) {
|
||||
assert.Equal(t, int64(100), progress.ID)
|
||||
assert.Equal(t, "Premium", progress.GroupName)
|
||||
assert.Equal(t, sub.ExpiresAt, progress.ExpiresAt)
|
||||
assert.Equal(t, 29, progress.ExpiresInDays) // 约 30 天
|
||||
assert.Nil(t, progress.Daily, "无日限额时 Daily 应为 nil")
|
||||
assert.Nil(t, progress.Weekly, "无周限额时 Weekly 应为 nil")
|
||||
assert.GreaterOrEqual(t, progress.ExpiresInDays, 29)
|
||||
assert.LessOrEqual(t, progress.ExpiresInDays, 30)
|
||||
assert.Nil(t, progress.Daily)
|
||||
assert.Nil(t, progress.Weekly)
|
||||
assert.Nil(t, progress.Monthly, "无月限额时 Monthly 应为 nil")
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user