mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-02 22:42:14 +08:00
Merge remote-tracking branch 'origin/main' into fix/enc_coot
# Conflicts: # backend/internal/service/openai_gateway_service.go
This commit is contained in:
@@ -81,6 +81,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
userHandler := handler.NewUserHandler(userService)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
||||
usageBillingRepository := repository.NewUsageBillingRepository(client, db)
|
||||
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||
@@ -163,9 +164,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
||||
digestSessionStore := service.NewDigestSessionStore()
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
|
||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||
|
||||
@@ -7,7 +7,7 @@ require (
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||
github.com/DouDOU-start/go-sora2api v1.1.0
|
||||
github.com/alitto/pond/v2 v2.6.2
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.2
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.3
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.10
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
|
||||
@@ -66,7 +66,7 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect
|
||||
github.com/aws/smithy-go v1.24.1 // indirect
|
||||
github.com/aws/smithy-go v1.24.2 // indirect
|
||||
github.com/bdandy/go-errors v1.2.2 // indirect
|
||||
github.com/bdandy/go-socks4 v1.2.3 // indirect
|
||||
github.com/bmatcuk/doublestar v1.3.4 // indirect
|
||||
|
||||
@@ -24,6 +24,8 @@ github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew
|
||||
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.10 h1:9DMthfO6XWZYLfzZglAgW5Fyou2nRI5CuV44sTedKBI=
|
||||
@@ -60,6 +62,8 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb8
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
|
||||
github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0=
|
||||
github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
|
||||
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
|
||||
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
|
||||
github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM=
|
||||
github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic=
|
||||
|
||||
@@ -934,9 +934,10 @@ type DashboardAggregationConfig struct {
|
||||
|
||||
// DashboardAggregationRetentionConfig 预聚合保留窗口
|
||||
type DashboardAggregationRetentionConfig struct {
|
||||
UsageLogsDays int `mapstructure:"usage_logs_days"`
|
||||
HourlyDays int `mapstructure:"hourly_days"`
|
||||
DailyDays int `mapstructure:"daily_days"`
|
||||
UsageLogsDays int `mapstructure:"usage_logs_days"`
|
||||
UsageBillingDedupDays int `mapstructure:"usage_billing_dedup_days"`
|
||||
HourlyDays int `mapstructure:"hourly_days"`
|
||||
DailyDays int `mapstructure:"daily_days"`
|
||||
}
|
||||
|
||||
// UsageCleanupConfig 使用记录清理任务配置
|
||||
@@ -1301,6 +1302,7 @@ func setDefaults() {
|
||||
viper.SetDefault("dashboard_aggregation.backfill_enabled", false)
|
||||
viper.SetDefault("dashboard_aggregation.backfill_max_days", 31)
|
||||
viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90)
|
||||
viper.SetDefault("dashboard_aggregation.retention.usage_billing_dedup_days", 365)
|
||||
viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180)
|
||||
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
|
||||
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
|
||||
@@ -1758,6 +1760,12 @@ func (c *Config) Validate() error {
|
||||
if c.DashboardAgg.Retention.UsageLogsDays <= 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive")
|
||||
}
|
||||
if c.DashboardAgg.Retention.UsageBillingDedupDays <= 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be positive")
|
||||
}
|
||||
if c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
|
||||
}
|
||||
if c.DashboardAgg.Retention.HourlyDays <= 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive")
|
||||
}
|
||||
@@ -1780,6 +1788,14 @@ func (c *Config) Validate() error {
|
||||
if c.DashboardAgg.Retention.UsageLogsDays < 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative")
|
||||
}
|
||||
if c.DashboardAgg.Retention.UsageBillingDedupDays < 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be non-negative")
|
||||
}
|
||||
if c.DashboardAgg.Retention.UsageBillingDedupDays > 0 &&
|
||||
c.DashboardAgg.Retention.UsageLogsDays > 0 &&
|
||||
c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
|
||||
}
|
||||
if c.DashboardAgg.Retention.HourlyDays < 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative")
|
||||
}
|
||||
|
||||
@@ -441,6 +441,9 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
|
||||
if cfg.DashboardAgg.Retention.UsageLogsDays != 90 {
|
||||
t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays)
|
||||
}
|
||||
if cfg.DashboardAgg.Retention.UsageBillingDedupDays != 365 {
|
||||
t.Fatalf("DashboardAgg.Retention.UsageBillingDedupDays = %d, want 365", cfg.DashboardAgg.Retention.UsageBillingDedupDays)
|
||||
}
|
||||
if cfg.DashboardAgg.Retention.HourlyDays != 180 {
|
||||
t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays)
|
||||
}
|
||||
@@ -1016,6 +1019,23 @@ func TestValidateConfigErrors(t *testing.T) {
|
||||
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 },
|
||||
wantErr: "dashboard_aggregation.retention.usage_logs_days",
|
||||
},
|
||||
{
|
||||
name: "dashboard aggregation dedup retention",
|
||||
mutate: func(c *Config) {
|
||||
c.DashboardAgg.Enabled = true
|
||||
c.DashboardAgg.Retention.UsageBillingDedupDays = 0
|
||||
},
|
||||
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
|
||||
},
|
||||
{
|
||||
name: "dashboard aggregation dedup retention smaller than usage logs",
|
||||
mutate: func(c *Config) {
|
||||
c.DashboardAgg.Enabled = true
|
||||
c.DashboardAgg.Retention.UsageLogsDays = 30
|
||||
c.DashboardAgg.Retention.UsageBillingDedupDays = 29
|
||||
},
|
||||
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
|
||||
},
|
||||
{
|
||||
name: "dashboard aggregation disabled interval",
|
||||
mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 },
|
||||
|
||||
@@ -27,10 +27,12 @@ const (
|
||||
|
||||
// Account type constants
|
||||
const (
|
||||
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名连接 Bedrock)
|
||||
AccountTypeBedrockAPIKey = "bedrock-apikey" // AWS Bedrock API Key 类型账号(通过 Bearer Token 连接 Bedrock)
|
||||
)
|
||||
|
||||
// Redeem type constants
|
||||
@@ -113,3 +115,27 @@ var DefaultAntigravityModelMapping = map[string]string{
|
||||
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
|
||||
"tab_flash_lite_preview": "tab_flash_lite_preview",
|
||||
}
|
||||
|
||||
// DefaultBedrockModelMapping 是 AWS Bedrock 平台的默认模型映射
|
||||
// 将 Anthropic 标准模型名映射到 Bedrock 模型 ID
|
||||
// 注意:此处的 "us." 前缀仅为默认值,ResolveBedrockModelID 会根据账号配置的
|
||||
// aws_region 自动调整为匹配的区域前缀(如 eu.、apac.、jp. 等)
|
||||
var DefaultBedrockModelMapping = map[string]string{
|
||||
// Claude Opus
|
||||
"claude-opus-4-6-thinking": "us.anthropic.claude-opus-4-6-v1",
|
||||
"claude-opus-4-6": "us.anthropic.claude-opus-4-6-v1",
|
||||
"claude-opus-4-5-thinking": "us.anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
"claude-opus-4-5-20251101": "us.anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
"claude-opus-4-1": "us.anthropic.claude-opus-4-1-20250805-v1:0",
|
||||
"claude-opus-4-20250514": "us.anthropic.claude-opus-4-20250514-v1:0",
|
||||
// Claude Sonnet
|
||||
"claude-sonnet-4-6-thinking": "us.anthropic.claude-sonnet-4-6",
|
||||
"claude-sonnet-4-6": "us.anthropic.claude-sonnet-4-6",
|
||||
"claude-sonnet-4-5": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
"claude-sonnet-4-5-thinking": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
"claude-sonnet-4-5-20250929": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
"claude-sonnet-4-20250514": "us.anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
// Claude Haiku
|
||||
"claude-haiku-4-5": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
"claude-haiku-4-5-20251001": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
}
|
||||
|
||||
@@ -97,7 +97,7 @@ type CreateAccountRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Notes *string `json:"notes"`
|
||||
Platform string `json:"platform" binding:"required"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock bedrock-apikey"`
|
||||
Credentials map[string]any `json:"credentials" binding:"required"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
@@ -116,7 +116,7 @@ type CreateAccountRequest struct {
|
||||
type UpdateAccountRequest struct {
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock bedrock-apikey"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
|
||||
@@ -466,9 +466,60 @@ type BatchUsersUsageRequest struct {
|
||||
UserIDs []int64 `json:"user_ids" binding:"required"`
|
||||
}
|
||||
|
||||
var dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
|
||||
var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second)
|
||||
var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second)
|
||||
|
||||
func parseRankingLimit(raw string) int {
|
||||
limit, err := strconv.Atoi(strings.TrimSpace(raw))
|
||||
if err != nil || limit <= 0 {
|
||||
return 12
|
||||
}
|
||||
if limit > 50 {
|
||||
return 50
|
||||
}
|
||||
return limit
|
||||
}
|
||||
|
||||
// GetUserSpendingRanking handles getting user spending ranking data.
|
||||
// GET /api/v1/admin/dashboard/users-ranking
|
||||
func (h *DashboardHandler) GetUserSpendingRanking(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
limit := parseRankingLimit(c.DefaultQuery("limit", "12"))
|
||||
|
||||
keyRaw, _ := json.Marshal(struct {
|
||||
Start string `json:"start"`
|
||||
End string `json:"end"`
|
||||
Limit int `json:"limit"`
|
||||
}{
|
||||
Start: startTime.UTC().Format(time.RFC3339),
|
||||
End: endTime.UTC().Format(time.RFC3339),
|
||||
Limit: limit,
|
||||
})
|
||||
cacheKey := string(keyRaw)
|
||||
if cached, ok := dashboardUsersRankingCache.Get(cacheKey); ok {
|
||||
c.Header("X-Snapshot-Cache", "hit")
|
||||
response.Success(c, cached.Payload)
|
||||
return
|
||||
}
|
||||
|
||||
ranking, err := h.dashboardService.GetUserSpendingRanking(c.Request.Context(), startTime, endTime, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user spending ranking")
|
||||
return
|
||||
}
|
||||
|
||||
payload := gin.H{
|
||||
"ranking": ranking.Ranking,
|
||||
"total_actual_cost": ranking.TotalActualCost,
|
||||
"start_date": startTime.Format("2006-01-02"),
|
||||
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
}
|
||||
dashboardUsersRankingCache.Set(cacheKey, payload)
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, payload)
|
||||
}
|
||||
|
||||
// GetBatchUsersUsage handles getting usage stats for multiple users
|
||||
// POST /api/v1/admin/dashboard/users-usage
|
||||
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||
|
||||
@@ -19,6 +19,9 @@ type dashboardUsageRepoCapture struct {
|
||||
trendStream *bool
|
||||
modelRequestType *int16
|
||||
modelStream *bool
|
||||
rankingLimit int
|
||||
ranking []usagestats.UserSpendingRankingItem
|
||||
rankingTotal float64
|
||||
}
|
||||
|
||||
func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters(
|
||||
@@ -49,6 +52,18 @@ func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters(
|
||||
return []usagestats.ModelStat{}, nil
|
||||
}
|
||||
|
||||
func (s *dashboardUsageRepoCapture) GetUserSpendingRanking(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
limit int,
|
||||
) (*usagestats.UserSpendingRankingResponse, error) {
|
||||
s.rankingLimit = limit
|
||||
return &usagestats.UserSpendingRankingResponse{
|
||||
Ranking: s.ranking,
|
||||
TotalActualCost: s.rankingTotal,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||
@@ -56,6 +71,7 @@ func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Eng
|
||||
router := gin.New()
|
||||
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
|
||||
router.GET("/admin/dashboard/models", handler.GetModelStats)
|
||||
router.GET("/admin/dashboard/users-ranking", handler.GetUserSpendingRanking)
|
||||
return router
|
||||
}
|
||||
|
||||
@@ -130,3 +146,30 @@ func TestDashboardModelStatsInvalidStream(t *testing.T) {
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestDashboardUsersRankingLimitAndCache(t *testing.T) {
|
||||
dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
|
||||
repo := &dashboardUsageRepoCapture{
|
||||
ranking: []usagestats.UserSpendingRankingItem{
|
||||
{UserID: 7, Email: "rank@example.com", ActualCost: 10.5, Requests: 3, Tokens: 300},
|
||||
},
|
||||
rankingTotal: 88.8,
|
||||
}
|
||||
router := newDashboardRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, 50, repo.rankingLimit)
|
||||
require.Contains(t, rec.Body.String(), "\"total_actual_cost\":88.8")
|
||||
require.Equal(t, "miss", rec.Header().Get("X-Snapshot-Cache"))
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
|
||||
rec2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec2, req2)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec2.Code)
|
||||
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
|
||||
}
|
||||
|
||||
@@ -41,12 +41,15 @@ type GenerateRedeemCodesRequest struct {
|
||||
}
|
||||
|
||||
// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
|
||||
// Type 为 omitempty 而非 required 是为了向后兼容旧版调用方(不传 type 时默认 balance)。
|
||||
type CreateAndRedeemCodeRequest struct {
|
||||
Code string `json:"code" binding:"required,min=3,max=128"`
|
||||
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
|
||||
Value float64 `json:"value" binding:"required,gt=0"`
|
||||
UserID int64 `json:"user_id" binding:"required,gt=0"`
|
||||
Notes string `json:"notes"`
|
||||
Code string `json:"code" binding:"required,min=3,max=128"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // 不传时默认 balance(向后兼容)
|
||||
Value float64 `json:"value" binding:"required,gt=0"`
|
||||
UserID int64 `json:"user_id" binding:"required,gt=0"`
|
||||
GroupID *int64 `json:"group_id"` // subscription 类型必填
|
||||
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // subscription 类型必填,>0
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
// List handles listing all redeem codes with pagination
|
||||
@@ -136,6 +139,22 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
req.Code = strings.TrimSpace(req.Code)
|
||||
// 向后兼容:旧版调用方(如 Sub2ApiPay)不传 type 字段,默认当作 balance 充值处理。
|
||||
// 请勿删除此默认值逻辑,否则会导致旧版调用方 400 报错。
|
||||
if req.Type == "" {
|
||||
req.Type = "balance"
|
||||
}
|
||||
|
||||
if req.Type == "subscription" {
|
||||
if req.GroupID == nil {
|
||||
response.BadRequest(c, "group_id is required for subscription type")
|
||||
return
|
||||
}
|
||||
if req.ValidityDays <= 0 {
|
||||
response.BadRequest(c, "validity_days must be greater than 0 for subscription type")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
existing, err := h.redeemService.GetByCode(ctx, req.Code)
|
||||
@@ -147,11 +166,13 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
|
||||
}
|
||||
|
||||
createErr := h.redeemService.CreateCode(ctx, &service.RedeemCode{
|
||||
Code: req.Code,
|
||||
Type: req.Type,
|
||||
Value: req.Value,
|
||||
Status: service.StatusUnused,
|
||||
Notes: req.Notes,
|
||||
Code: req.Code,
|
||||
Type: req.Type,
|
||||
Value: req.Value,
|
||||
Status: service.StatusUnused,
|
||||
Notes: req.Notes,
|
||||
GroupID: req.GroupID,
|
||||
ValidityDays: req.ValidityDays,
|
||||
})
|
||||
if createErr != nil {
|
||||
// Unique code race: if code now exists, use idempotent semantics by used_by.
|
||||
|
||||
135
backend/internal/handler/admin/redeem_handler_test.go
Normal file
135
backend/internal/handler/admin/redeem_handler_test.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// newCreateAndRedeemHandler creates a RedeemHandler with a non-nil (but minimal)
|
||||
// RedeemService so that CreateAndRedeem's nil guard passes and we can test the
|
||||
// parameter-validation layer that runs before any service call.
|
||||
func newCreateAndRedeemHandler() *RedeemHandler {
|
||||
return &RedeemHandler{
|
||||
adminService: newStubAdminService(),
|
||||
redeemService: &service.RedeemService{}, // non-nil to pass nil guard
|
||||
}
|
||||
}
|
||||
|
||||
// postCreateAndRedeemValidation calls CreateAndRedeem and returns the response
|
||||
// status code. For cases that pass validation and proceed into the service layer,
|
||||
// a panic may occur (because RedeemService internals are nil); this is expected
|
||||
// and treated as "validation passed" (returns 0 to indicate panic).
|
||||
func postCreateAndRedeemValidation(t *testing.T, handler *RedeemHandler, body any) (code int) {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
jsonBytes, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
c.Request, _ = http.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/create-and-redeem", bytes.NewReader(jsonBytes))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Panic means we passed validation and entered service layer (expected for minimal stub).
|
||||
code = 0
|
||||
}
|
||||
}()
|
||||
handler.CreateAndRedeem(c)
|
||||
return w.Code
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_TypeDefaultsToBalance(t *testing.T) {
|
||||
// 不传 type 字段时应默认 balance,不触发 subscription 校验。
|
||||
// 验证通过后进入 service 层会 panic(返回 0),说明默认值生效。
|
||||
h := newCreateAndRedeemHandler()
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-balance-default",
|
||||
"value": 10.0,
|
||||
"user_id": 1,
|
||||
})
|
||||
|
||||
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||
"omitting type should default to balance and pass validation")
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_SubscriptionRequiresGroupID(t *testing.T) {
|
||||
h := newCreateAndRedeemHandler()
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-sub-no-group",
|
||||
"type": "subscription",
|
||||
"value": 29.9,
|
||||
"user_id": 1,
|
||||
"validity_days": 30,
|
||||
// group_id 缺失
|
||||
})
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, code)
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_SubscriptionRequiresPositiveValidityDays(t *testing.T) {
|
||||
groupID := int64(5)
|
||||
h := newCreateAndRedeemHandler()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
validityDays int
|
||||
}{
|
||||
{"zero", 0},
|
||||
{"negative", -1},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-sub-bad-days-" + tc.name,
|
||||
"type": "subscription",
|
||||
"value": 29.9,
|
||||
"user_id": 1,
|
||||
"group_id": groupID,
|
||||
"validity_days": tc.validityDays,
|
||||
})
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_SubscriptionValidParamsPassValidation(t *testing.T) {
|
||||
groupID := int64(5)
|
||||
h := newCreateAndRedeemHandler()
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-sub-valid",
|
||||
"type": "subscription",
|
||||
"value": 29.9,
|
||||
"user_id": 1,
|
||||
"group_id": groupID,
|
||||
"validity_days": 31,
|
||||
})
|
||||
|
||||
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||
"valid subscription params should pass validation")
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_BalanceIgnoresSubscriptionFields(t *testing.T) {
|
||||
h := newCreateAndRedeemHandler()
|
||||
// balance 类型不传 group_id 和 validity_days,不应报 400
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-balance-no-extras",
|
||||
"type": "balance",
|
||||
"value": 50.0,
|
||||
"user_id": 1,
|
||||
})
|
||||
|
||||
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||
"balance type should not require group_id or validity_days")
|
||||
}
|
||||
@@ -434,19 +434,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.messages"),
|
||||
@@ -736,19 +738,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: currentAPIKey,
|
||||
User: currentAPIKey.User,
|
||||
Account: account,
|
||||
Subscription: currentSubscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: currentAPIKey,
|
||||
User: currentAPIKey.User,
|
||||
Account: account,
|
||||
Subscription: currentSubscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.messages"),
|
||||
|
||||
@@ -139,6 +139,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
|
||||
nil, // accountRepo (not used: scheduler snapshot hit)
|
||||
&fakeGroupRepo{group: group},
|
||||
nil, // usageLogRepo
|
||||
nil, // usageBillingRepo
|
||||
nil, // userRepo
|
||||
nil, // userSubRepo
|
||||
nil, // userGroupRateRepo
|
||||
|
||||
@@ -503,6 +503,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
||||
Result: result,
|
||||
@@ -512,6 +513,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
LongContextThreshold: 200000, // Gemini 200K 阈值
|
||||
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
|
||||
@@ -352,18 +352,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.responses"),
|
||||
@@ -732,17 +734,19 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.messages"),
|
||||
@@ -1231,14 +1235,15 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
h.submitUsageRecordTask(func(taskCtx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
reqLog.Error("openai.websocket_record_usage_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
|
||||
@@ -2206,7 +2206,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac
|
||||
// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。
|
||||
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
|
||||
return service.NewGatewayService(
|
||||
accountRepo, nil, nil, nil, nil, nil, nil, nil,
|
||||
accountRepo, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -399,17 +399,19 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.sora_gateway.chat_completions"),
|
||||
|
||||
@@ -343,6 +343,9 @@ func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, e
|
||||
func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -431,6 +434,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
testutil.StubGatewayCache{},
|
||||
cfg,
|
||||
nil,
|
||||
|
||||
@@ -189,6 +189,5 @@ var DefaultStopSequences = []string{
|
||||
"<|user|>",
|
||||
"<|endoftext|>",
|
||||
"<|end_of_turn|>",
|
||||
"[DONE]",
|
||||
"\n\nHuman:",
|
||||
}
|
||||
|
||||
@@ -96,12 +96,28 @@ type UserUsageTrendPoint struct {
|
||||
Date string `json:"date"`
|
||||
UserID int64 `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// UserSpendingRankingItem represents a user spending ranking row.
|
||||
type UserSpendingRankingItem struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
}
|
||||
|
||||
// UserSpendingRankingResponse represents ranking rows plus total spend for the time range.
|
||||
type UserSpendingRankingResponse struct {
|
||||
Ranking []UserSpendingRankingItem `json:"ranking"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
}
|
||||
|
||||
// APIKeyUsageTrendPoint represents API key usage trend data point
|
||||
type APIKeyUsageTrendPoint struct {
|
||||
Date string `json:"date"`
|
||||
|
||||
@@ -17,6 +17,9 @@ type dashboardAggregationRepository struct {
|
||||
sql sqlExecutor
|
||||
}
|
||||
|
||||
const usageLogsCleanupBatchSize = 10000
|
||||
const usageBillingDedupCleanupBatchSize = 10000
|
||||
|
||||
// NewDashboardAggregationRepository 创建仪表盘预聚合仓储。
|
||||
func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository {
|
||||
if sqlDB == nil {
|
||||
@@ -42,6 +45,9 @@ func isPostgresDriver(db *sql.DB) bool {
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||
if r == nil || r.sql == nil {
|
||||
return nil
|
||||
}
|
||||
loc := timezone.Location()
|
||||
startLocal := start.In(loc)
|
||||
endLocal := end.In(loc)
|
||||
@@ -61,6 +67,22 @@ func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, sta
|
||||
dayEnd = dayEnd.Add(24 * time.Hour)
|
||||
}
|
||||
|
||||
if db, ok := r.sql.(*sql.DB); ok {
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
txRepo := newDashboardAggregationRepositoryWithSQL(tx)
|
||||
if err := txRepo.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
return r.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd)
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) aggregateRangeInTx(ctx context.Context, hourStart, hourEnd, dayStart, dayEnd time.Time) error {
|
||||
// 以桶边界聚合,允许覆盖 end 所在桶的剩余区间。
|
||||
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
@@ -195,8 +217,58 @@ func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, c
|
||||
if isPartitioned {
|
||||
return r.dropUsageLogsPartitions(ctx, cutoff)
|
||||
}
|
||||
_, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC())
|
||||
return err
|
||||
for {
|
||||
res, err := r.sql.ExecContext(ctx, `
|
||||
WITH victims AS (
|
||||
SELECT ctid
|
||||
FROM usage_logs
|
||||
WHERE created_at < $1
|
||||
LIMIT $2
|
||||
)
|
||||
DELETE FROM usage_logs
|
||||
WHERE ctid IN (SELECT ctid FROM victims)
|
||||
`, cutoff.UTC(), usageLogsCleanupBatchSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected < usageLogsCleanupBatchSize {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
|
||||
for {
|
||||
res, err := r.sql.ExecContext(ctx, `
|
||||
WITH victims AS (
|
||||
SELECT ctid, request_id, api_key_id, request_fingerprint, created_at
|
||||
FROM usage_billing_dedup
|
||||
WHERE created_at < $1
|
||||
LIMIT $2
|
||||
), archived AS (
|
||||
INSERT INTO usage_billing_dedup_archive (request_id, api_key_id, request_fingerprint, created_at)
|
||||
SELECT request_id, api_key_id, request_fingerprint, created_at
|
||||
FROM victims
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
)
|
||||
DELETE FROM usage_billing_dedup
|
||||
WHERE ctid IN (SELECT ctid FROM victims)
|
||||
`, cutoff.UTC(), usageBillingDedupCleanupBatchSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected < usageBillingDedupCleanupBatchSize {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||
|
||||
@@ -262,6 +262,42 @@ func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.APIKey) *se
|
||||
SetKey(k.Key).
|
||||
SetName(k.Name).
|
||||
SetStatus(k.Status)
|
||||
if k.Quota != 0 {
|
||||
create.SetQuota(k.Quota)
|
||||
}
|
||||
if k.QuotaUsed != 0 {
|
||||
create.SetQuotaUsed(k.QuotaUsed)
|
||||
}
|
||||
if k.RateLimit5h != 0 {
|
||||
create.SetRateLimit5h(k.RateLimit5h)
|
||||
}
|
||||
if k.RateLimit1d != 0 {
|
||||
create.SetRateLimit1d(k.RateLimit1d)
|
||||
}
|
||||
if k.RateLimit7d != 0 {
|
||||
create.SetRateLimit7d(k.RateLimit7d)
|
||||
}
|
||||
if k.Usage5h != 0 {
|
||||
create.SetUsage5h(k.Usage5h)
|
||||
}
|
||||
if k.Usage1d != 0 {
|
||||
create.SetUsage1d(k.Usage1d)
|
||||
}
|
||||
if k.Usage7d != 0 {
|
||||
create.SetUsage7d(k.Usage7d)
|
||||
}
|
||||
if k.Window5hStart != nil {
|
||||
create.SetWindow5hStart(*k.Window5hStart)
|
||||
}
|
||||
if k.Window1dStart != nil {
|
||||
create.SetWindow1dStart(*k.Window1dStart)
|
||||
}
|
||||
if k.Window7dStart != nil {
|
||||
create.SetWindow7dStart(*k.Window7dStart)
|
||||
}
|
||||
if k.ExpiresAt != nil {
|
||||
create.SetExpiresAt(*k.ExpiresAt)
|
||||
}
|
||||
if k.GroupID != nil {
|
||||
create.SetGroupID(*k.GroupID)
|
||||
}
|
||||
|
||||
@@ -45,6 +45,20 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
|
||||
requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false)
|
||||
requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false)
|
||||
|
||||
// usage_billing_dedup: billing idempotency narrow table
|
||||
var usageBillingDedupRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup')").Scan(&usageBillingDedupRegclass))
|
||||
require.True(t, usageBillingDedupRegclass.Valid, "expected usage_billing_dedup table to exist")
|
||||
requireColumn(t, tx, "usage_billing_dedup", "request_fingerprint", "character varying", 64, false)
|
||||
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_request_api_key")
|
||||
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_created_at_brin")
|
||||
|
||||
var usageBillingDedupArchiveRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup_archive')").Scan(&usageBillingDedupArchiveRegclass))
|
||||
require.True(t, usageBillingDedupArchiveRegclass.Valid, "expected usage_billing_dedup_archive table to exist")
|
||||
requireColumn(t, tx, "usage_billing_dedup_archive", "request_fingerprint", "character varying", 64, false)
|
||||
requireIndex(t, tx, "usage_billing_dedup_archive", "usage_billing_dedup_archive_pkey")
|
||||
|
||||
// settings table should exist
|
||||
var settingsRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass))
|
||||
@@ -75,6 +89,23 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
|
||||
requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
|
||||
}
|
||||
|
||||
func requireIndex(t *testing.T, tx *sql.Tx, table, index string) {
|
||||
t.Helper()
|
||||
|
||||
var exists bool
|
||||
err := tx.QueryRowContext(context.Background(), `
|
||||
SELECT EXISTS (
|
||||
SELECT 1
|
||||
FROM pg_indexes
|
||||
WHERE schemaname = 'public'
|
||||
AND tablename = $1
|
||||
AND indexname = $2
|
||||
)
|
||||
`, table, index).Scan(&exists)
|
||||
require.NoError(t, err, "query pg_indexes for %s.%s", table, index)
|
||||
require.True(t, exists, "expected index %s on %s", index, table)
|
||||
}
|
||||
|
||||
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
308
backend/internal/repository/usage_billing_repo.go
Normal file
308
backend/internal/repository/usage_billing_repo.go
Normal file
@@ -0,0 +1,308 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type usageBillingRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewUsageBillingRepository(_ *dbent.Client, sqlDB *sql.DB) service.UsageBillingRepository {
|
||||
return &usageBillingRepository{db: sqlDB}
|
||||
}
|
||||
|
||||
func (r *usageBillingRepository) Apply(ctx context.Context, cmd *service.UsageBillingCommand) (_ *service.UsageBillingApplyResult, err error) {
|
||||
if cmd == nil {
|
||||
return &service.UsageBillingApplyResult{}, nil
|
||||
}
|
||||
if r == nil || r.db == nil {
|
||||
return nil, errors.New("usage billing repository db is nil")
|
||||
}
|
||||
|
||||
cmd.Normalize()
|
||||
if cmd.RequestID == "" {
|
||||
return nil, service.ErrUsageBillingRequestIDRequired
|
||||
}
|
||||
|
||||
tx, err := r.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if tx != nil {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
applied, err := r.claimUsageBillingKey(ctx, tx, cmd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !applied {
|
||||
return &service.UsageBillingApplyResult{Applied: false}, nil
|
||||
}
|
||||
|
||||
result := &service.UsageBillingApplyResult{Applied: true}
|
||||
if err := r.applyUsageBillingEffects(ctx, tx, cmd, result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tx = nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *usageBillingRepository) claimUsageBillingKey(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand) (bool, error) {
|
||||
var id int64
|
||||
err := tx.QueryRowContext(ctx, `
|
||||
INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id
|
||||
`, cmd.RequestID, cmd.APIKeyID, cmd.RequestFingerprint).Scan(&id)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
var existingFingerprint string
|
||||
if err := tx.QueryRowContext(ctx, `
|
||||
SELECT request_fingerprint
|
||||
FROM usage_billing_dedup
|
||||
WHERE request_id = $1 AND api_key_id = $2
|
||||
`, cmd.RequestID, cmd.APIKeyID).Scan(&existingFingerprint); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if strings.TrimSpace(existingFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
|
||||
return false, service.ErrUsageBillingRequestConflict
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
var archivedFingerprint string
|
||||
err = tx.QueryRowContext(ctx, `
|
||||
SELECT request_fingerprint
|
||||
FROM usage_billing_dedup_archive
|
||||
WHERE request_id = $1 AND api_key_id = $2
|
||||
`, cmd.RequestID, cmd.APIKeyID).Scan(&archivedFingerprint)
|
||||
if err == nil {
|
||||
if strings.TrimSpace(archivedFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
|
||||
return false, service.ErrUsageBillingRequestConflict
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand, result *service.UsageBillingApplyResult) error {
|
||||
if cmd.SubscriptionCost > 0 && cmd.SubscriptionID != nil {
|
||||
if err := incrementUsageBillingSubscription(ctx, tx, *cmd.SubscriptionID, cmd.SubscriptionCost); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.BalanceCost > 0 {
|
||||
if err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.APIKeyQuotaCost > 0 {
|
||||
exhausted, err := incrementUsageBillingAPIKeyQuota(ctx, tx, cmd.APIKeyID, cmd.APIKeyQuotaCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result.APIKeyQuotaExhausted = exhausted
|
||||
}
|
||||
|
||||
if cmd.APIKeyRateLimitCost > 0 {
|
||||
if err := incrementUsageBillingAPIKeyRateLimit(ctx, tx, cmd.APIKeyID, cmd.APIKeyRateLimitCost); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.AccountQuotaCost > 0 && strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) {
|
||||
if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func incrementUsageBillingSubscription(ctx context.Context, tx *sql.Tx, subscriptionID int64, costUSD float64) error {
|
||||
const updateSQL = `
|
||||
UPDATE user_subscriptions us
|
||||
SET
|
||||
daily_usage_usd = us.daily_usage_usd + $1,
|
||||
weekly_usage_usd = us.weekly_usage_usd + $1,
|
||||
monthly_usage_usd = us.monthly_usage_usd + $1,
|
||||
updated_at = NOW()
|
||||
FROM groups g
|
||||
WHERE us.id = $2
|
||||
AND us.deleted_at IS NULL
|
||||
AND us.group_id = g.id
|
||||
AND g.deleted_at IS NULL
|
||||
`
|
||||
res, err := tx.ExecContext(ctx, updateSQL, costUSD, subscriptionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected > 0 {
|
||||
return nil
|
||||
}
|
||||
return service.ErrSubscriptionNotFound
|
||||
}
|
||||
|
||||
func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) error {
|
||||
res, err := tx.ExecContext(ctx, `
|
||||
UPDATE users
|
||||
SET balance = balance - $1,
|
||||
updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
`, amount, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected > 0 {
|
||||
return nil
|
||||
}
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
|
||||
func incrementUsageBillingAPIKeyQuota(ctx context.Context, tx *sql.Tx, apiKeyID int64, amount float64) (bool, error) {
|
||||
var exhausted bool
|
||||
err := tx.QueryRowContext(ctx, `
|
||||
UPDATE api_keys
|
||||
SET quota_used = quota_used + $1,
|
||||
status = CASE
|
||||
WHEN quota > 0
|
||||
AND status = $3
|
||||
AND quota_used < quota
|
||||
AND quota_used + $1 >= quota
|
||||
THEN $4
|
||||
ELSE status
|
||||
END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
RETURNING quota > 0 AND quota_used >= quota AND quota_used - $1 < quota
|
||||
`, amount, apiKeyID, service.StatusAPIKeyActive, service.StatusAPIKeyQuotaExhausted).Scan(&exhausted)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false, service.ErrAPIKeyNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return exhausted, nil
|
||||
}
|
||||
|
||||
func incrementUsageBillingAPIKeyRateLimit(ctx context.Context, tx *sql.Tx, apiKeyID int64, cost float64) error {
|
||||
res, err := tx.ExecContext(ctx, `
|
||||
UPDATE api_keys SET
|
||||
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END,
|
||||
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END,
|
||||
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END,
|
||||
window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
`, cost, apiKeyID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return service.ErrAPIKeyNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) error {
|
||||
rows, err := tx.QueryContext(ctx,
|
||||
`UPDATE accounts SET extra = (
|
||||
COALESCE(extra, '{}'::jsonb)
|
||||
|| jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1)
|
||||
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_daily_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
|
||||
'quota_daily_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
|
||||
)
|
||||
ELSE '{}'::jsonb END
|
||||
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_weekly_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
|
||||
'quota_weekly_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
|
||||
)
|
||||
ELSE '{}'::jsonb END
|
||||
), updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
RETURNING
|
||||
COALESCE((extra->>'quota_used')::numeric, 0),
|
||||
COALESCE((extra->>'quota_limit')::numeric, 0)`,
|
||||
amount, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var newUsed, limit float64
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&newUsed, &limit); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
|
||||
if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,279 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func TestUsageBillingRepositoryApply_DeduplicatesBalanceBilling(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := NewUsageBillingRepository(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("usage-billing-user-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Balance: 100,
|
||||
})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-usage-billing-" + uuid.NewString(),
|
||||
Name: "billing",
|
||||
Quota: 1,
|
||||
})
|
||||
account := mustCreateAccount(t, client, &service.Account{
|
||||
Name: "usage-billing-account-" + uuid.NewString(),
|
||||
Type: service.AccountTypeAPIKey,
|
||||
})
|
||||
|
||||
requestID := uuid.NewString()
|
||||
cmd := &service.UsageBillingCommand{
|
||||
RequestID: requestID,
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
AccountID: account.ID,
|
||||
AccountType: service.AccountTypeAPIKey,
|
||||
BalanceCost: 1.25,
|
||||
APIKeyQuotaCost: 1.25,
|
||||
APIKeyRateLimitCost: 1.25,
|
||||
}
|
||||
|
||||
result1, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result1)
|
||||
require.True(t, result1.Applied)
|
||||
require.True(t, result1.APIKeyQuotaExhausted)
|
||||
|
||||
result2, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result2)
|
||||
require.False(t, result2.Applied)
|
||||
|
||||
var balance float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
|
||||
require.InDelta(t, 98.75, balance, 0.000001)
|
||||
|
||||
var quotaUsed float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT quota_used FROM api_keys WHERE id = $1", apiKey.ID).Scan("aUsed))
|
||||
require.InDelta(t, 1.25, quotaUsed, 0.000001)
|
||||
|
||||
var usage5h float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT usage_5h FROM api_keys WHERE id = $1", apiKey.ID).Scan(&usage5h))
|
||||
require.InDelta(t, 1.25, usage5h, 0.000001)
|
||||
|
||||
var status string
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT status FROM api_keys WHERE id = $1", apiKey.ID).Scan(&status))
|
||||
require.Equal(t, service.StatusAPIKeyQuotaExhausted, status)
|
||||
|
||||
var dedupCount int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&dedupCount))
|
||||
require.Equal(t, 1, dedupCount)
|
||||
}
|
||||
|
||||
func TestUsageBillingRepositoryApply_DeduplicatesSubscriptionBilling(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := NewUsageBillingRepository(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("usage-billing-sub-user-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
})
|
||||
group := mustCreateGroup(t, client, &service.Group{
|
||||
Name: "usage-billing-group-" + uuid.NewString(),
|
||||
Platform: service.PlatformAnthropic,
|
||||
SubscriptionType: service.SubscriptionTypeSubscription,
|
||||
})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||
UserID: user.ID,
|
||||
GroupID: &group.ID,
|
||||
Key: "sk-usage-billing-sub-" + uuid.NewString(),
|
||||
Name: "billing-sub",
|
||||
})
|
||||
subscription := mustCreateSubscription(t, client, &service.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
})
|
||||
|
||||
requestID := uuid.NewString()
|
||||
cmd := &service.UsageBillingCommand{
|
||||
RequestID: requestID,
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
AccountID: 0,
|
||||
SubscriptionID: &subscription.ID,
|
||||
SubscriptionCost: 2.5,
|
||||
}
|
||||
|
||||
result1, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.True(t, result1.Applied)
|
||||
|
||||
result2, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.False(t, result2.Applied)
|
||||
|
||||
var dailyUsage float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT daily_usage_usd FROM user_subscriptions WHERE id = $1", subscription.ID).Scan(&dailyUsage))
|
||||
require.InDelta(t, 2.5, dailyUsage, 0.000001)
|
||||
}
|
||||
|
||||
func TestUsageBillingRepositoryApply_RequestFingerprintConflict(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := NewUsageBillingRepository(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("usage-billing-conflict-user-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Balance: 100,
|
||||
})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-usage-billing-conflict-" + uuid.NewString(),
|
||||
Name: "billing-conflict",
|
||||
})
|
||||
|
||||
requestID := uuid.NewString()
|
||||
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
|
||||
RequestID: requestID,
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
BalanceCost: 1.25,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.Apply(ctx, &service.UsageBillingCommand{
|
||||
RequestID: requestID,
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
BalanceCost: 2.50,
|
||||
})
|
||||
require.ErrorIs(t, err, service.ErrUsageBillingRequestConflict)
|
||||
}
|
||||
|
||||
func TestUsageBillingRepositoryApply_UpdatesAccountQuota(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := NewUsageBillingRepository(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("usage-billing-account-user-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-usage-billing-account-" + uuid.NewString(),
|
||||
Name: "billing-account",
|
||||
})
|
||||
account := mustCreateAccount(t, client, &service.Account{
|
||||
Name: "usage-billing-account-quota-" + uuid.NewString(),
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Extra: map[string]any{
|
||||
"quota_limit": 100.0,
|
||||
},
|
||||
})
|
||||
|
||||
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
|
||||
RequestID: uuid.NewString(),
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
AccountID: account.ID,
|
||||
AccountType: service.AccountTypeAPIKey,
|
||||
AccountQuotaCost: 3.5,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var quotaUsed float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COALESCE((extra->>'quota_used')::numeric, 0) FROM accounts WHERE id = $1", account.ID).Scan("aUsed))
|
||||
require.InDelta(t, 3.5, quotaUsed, 0.000001)
|
||||
}
|
||||
|
||||
func TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := newDashboardAggregationRepositoryWithSQL(integrationDB)
|
||||
|
||||
oldRequestID := "dedup-old-" + uuid.NewString()
|
||||
newRequestID := "dedup-new-" + uuid.NewString()
|
||||
oldCreatedAt := time.Now().UTC().AddDate(0, 0, -400)
|
||||
newCreatedAt := time.Now().UTC().Add(-time.Hour)
|
||||
|
||||
_, err := integrationDB.ExecContext(ctx, `
|
||||
INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint, created_at)
|
||||
VALUES ($1, 1, $2, $3), ($4, 1, $5, $6)
|
||||
`,
|
||||
oldRequestID, strings.Repeat("a", 64), oldCreatedAt,
|
||||
newRequestID, strings.Repeat("b", 64), newCreatedAt,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, repo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))
|
||||
|
||||
var oldCount int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", oldRequestID).Scan(&oldCount))
|
||||
require.Equal(t, 0, oldCount)
|
||||
|
||||
var newCount int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", newRequestID).Scan(&newCount))
|
||||
require.Equal(t, 1, newCount)
|
||||
|
||||
var archivedCount int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup_archive WHERE request_id = $1", oldRequestID).Scan(&archivedCount))
|
||||
require.Equal(t, 1, archivedCount)
|
||||
}
|
||||
|
||||
// TestUsageBillingRepositoryApply_DeduplicatesAgainstArchivedKey verifies that
// a billing command is still treated as a duplicate after its dedup row has
// been moved to the archive table by the cleanup job: the second Apply must be
// a no-op (Applied=false) and the user's balance is charged exactly once.
func TestUsageBillingRepositoryApply_DeduplicatesAgainstArchivedKey(t *testing.T) {
	ctx := context.Background()
	client := testEntClient(t)
	repo := NewUsageBillingRepository(client, integrationDB)
	aggRepo := newDashboardAggregationRepositoryWithSQL(integrationDB)

	user := mustCreateUser(t, client, &service.User{
		Email:        fmt.Sprintf("usage-billing-archive-user-%d@example.com", time.Now().UnixNano()),
		PasswordHash: "hash",
		Balance:      100,
	})
	apiKey := mustCreateApiKey(t, client, &service.APIKey{
		UserID: user.ID,
		Key:    "sk-usage-billing-archive-" + uuid.NewString(),
		Name:   "billing-archive",
	})

	requestID := uuid.NewString()
	cmd := &service.UsageBillingCommand{
		RequestID:   requestID,
		APIKeyID:    apiKey.ID,
		UserID:      user.ID,
		BalanceCost: 1.25,
	}

	// First application must succeed and charge the balance.
	result1, err := repo.Apply(ctx, cmd)
	require.NoError(t, err)
	require.True(t, result1.Applied)

	// Age the dedup row past the retention window, then run cleanup so it is
	// moved into usage_billing_dedup_archive.
	_, err = integrationDB.ExecContext(ctx, `
UPDATE usage_billing_dedup
SET created_at = $1
WHERE request_id = $2 AND api_key_id = $3
`, time.Now().UTC().AddDate(0, 0, -400), requestID, apiKey.ID)
	require.NoError(t, err)
	require.NoError(t, aggRepo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))

	// Replaying the same command must still be detected via the archive.
	result2, err := repo.Apply(ctx, cmd)
	require.NoError(t, err)
	require.False(t, result2.Applied)

	// Balance reflects exactly one 1.25 charge: 100 - 1.25 = 98.75.
	var balance float64
	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
	require.InDelta(t, 98.75, balance, 0.000001)
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,6 +4,8 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -14,6 +16,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
@@ -84,6 +87,367 @@ func (s *UsageLogRepoSuite) TestCreate() {
|
||||
s.Require().NotZero(log.ID)
|
||||
}
|
||||
|
||||
// TestUsageLogRepositoryCreate_BatchPathConcurrent fires 16 concurrent Create
// calls with distinct request IDs through the batched insert path and checks
// that every call reports inserted=true, every log is back-filled with an ID,
// and exactly 16 rows exist for the API key afterwards.
func TestUsageLogRepositoryCreate_BatchPathConcurrent(t *testing.T) {
	ctx := context.Background()
	client := testEntClient(t)
	repo := newUsageLogRepositoryWithSQL(client, integrationDB)

	user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-%d@example.com", time.Now().UnixNano())})
	apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-" + uuid.NewString(), Name: "k"})
	account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-" + uuid.NewString()})

	const total = 16
	// Per-goroutine result slots; each goroutine writes only its own index,
	// so no extra synchronization beyond the WaitGroup is needed.
	results := make([]bool, total)
	errs := make([]error, total)
	logs := make([]*service.UsageLog, total)

	var wg sync.WaitGroup
	wg.Add(total)
	for i := 0; i < total; i++ {
		i := i // capture a per-iteration copy for the goroutine (pre-Go 1.22 semantics)
		logs[i] = &service.UsageLog{
			UserID:       user.ID,
			APIKeyID:     apiKey.ID,
			AccountID:    account.ID,
			RequestID:    uuid.NewString(),
			Model:        "claude-3",
			InputTokens:  10 + i,
			OutputTokens: 20 + i,
			TotalCost:    0.5,
			ActualCost:   0.5,
			CreatedAt:    time.Now().UTC(),
		}
		go func() {
			defer wg.Done()
			results[i], errs[i] = repo.Create(ctx, logs[i])
		}()
	}
	wg.Wait()

	for i := 0; i < total; i++ {
		require.NoError(t, errs[i])
		require.True(t, results[i])
		require.NotZero(t, logs[i].ID)
	}

	// All 16 distinct request IDs must have produced distinct rows.
	var count int
	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE api_key_id = $1", apiKey.ID).Scan(&count))
	require.Equal(t, total, count)
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_BatchPathDuplicateRequestID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-dup-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-dup-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-dup-" + uuid.NewString()})
|
||||
requestID := uuid.NewString()
|
||||
|
||||
log1 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
log2 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
inserted1, err1 := repo.Create(ctx, log1)
|
||||
inserted2, err2 := repo.Create(ctx, log2)
|
||||
require.NoError(t, err1)
|
||||
require.NoError(t, err2)
|
||||
require.True(t, inserted1)
|
||||
require.False(t, inserted2)
|
||||
require.Equal(t, log1.ID, log2.ID)
|
||||
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count))
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
// TestUsageLogRepositoryFlushCreateBatch_DeduplicatesSameKeyInMemory feeds
// flushCreateBatch eight requests that all share one (request_id, api_key_id)
// key and verifies the batch is deduplicated before hitting the database:
// exactly one request reports inserted=true, every log is back-filled with the
// same row ID, and only a single row lands in usage_logs.
func TestUsageLogRepositoryFlushCreateBatch_DeduplicatesSameKeyInMemory(t *testing.T) {
	ctx := context.Background()
	client := testEntClient(t)
	repo := newUsageLogRepositoryWithSQL(client, integrationDB)

	user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-memdup-%d@example.com", time.Now().UnixNano())})
	apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-memdup-" + uuid.NewString(), Name: "k"})
	account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-memdup-" + uuid.NewString()})
	requestID := uuid.NewString()

	const total = 8
	batch := make([]usageLogCreateRequest, 0, total)
	logs := make([]*service.UsageLog, 0, total)

	for i := 0; i < total; i++ {
		// All logs deliberately share requestID/apiKey.ID — the dedup key.
		log := &service.UsageLog{
			UserID:       user.ID,
			APIKeyID:     apiKey.ID,
			AccountID:    account.ID,
			RequestID:    requestID,
			Model:        "claude-3",
			InputTokens:  10 + i,
			OutputTokens: 20 + i,
			TotalCost:    0.5,
			ActualCost:   0.5,
			CreatedAt:    time.Now().UTC(),
		}
		logs = append(logs, log)
		batch = append(batch, usageLogCreateRequest{
			log:      log,
			prepared: prepareUsageLogInsert(log),
			resultCh: make(chan usageLogCreateResult, 1), // buffered so flush never blocks
		})
	}

	// Invoke the internal flush directly, bypassing the background batcher.
	repo.flushCreateBatch(integrationDB, batch)

	insertedCount := 0
	var firstID int64
	for idx, req := range batch {
		res := <-req.resultCh
		require.NoError(t, res.err)
		if res.inserted {
			insertedCount++
		}
		// Winner and duplicates alike receive the shared row ID.
		require.NotZero(t, logs[idx].ID)
		if idx == 0 {
			firstID = logs[idx].ID
		} else {
			require.Equal(t, firstID, logs[idx].ID)
		}
	}

	// Exactly one of the eight requests actually inserted a row.
	require.Equal(t, 1, insertedCount)

	var count int
	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count))
	require.Equal(t, 1, count)
}
|
||||
|
||||
func TestUsageLogRepositoryCreateBestEffort_BatchPathDuplicateRequestID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-dup-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-dup-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-dup-" + uuid.NewString()})
|
||||
requestID := uuid.NewString()
|
||||
|
||||
log1 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
log2 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
require.NoError(t, repo.CreateBestEffort(ctx, log1))
|
||||
require.NoError(t, repo.CreateBestEffort(ctx, log2))
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
var count int
|
||||
err := integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count)
|
||||
return err == nil && count == 1
|
||||
}, 3*time.Second, 20*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreateBestEffort_QueueFullReturnsDropped(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
repo.bestEffortBatchCh = make(chan usageLogBestEffortRequest, 1)
|
||||
repo.bestEffortBatchCh <- usageLogBestEffortRequest{}
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-full-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-full-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-full-" + uuid.NewString()})
|
||||
|
||||
err := repo.CreateBestEffort(ctx, &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.NewString(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
require.True(t, service.IsUsageLogCreateDropped(err))
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *testing.T) {
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-" + uuid.NewString()})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
inserted, err := repo.Create(ctx, &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.NewString(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
})
|
||||
|
||||
require.False(t, inserted)
|
||||
require.Error(t, err)
|
||||
require.True(t, service.IsUsageLogCreateNotPersisted(err))
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_BatchPathQueueFullMarksNotPersisted(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
repo.createBatchCh = make(chan usageLogCreateRequest, 1)
|
||||
repo.createBatchCh <- usageLogCreateRequest{}
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-create-full-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-create-full-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-create-full-" + uuid.NewString()})
|
||||
|
||||
inserted, err := repo.Create(ctx, &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.NewString(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
})
|
||||
|
||||
require.False(t, inserted)
|
||||
require.Error(t, err)
|
||||
require.True(t, service.IsUsageLogCreateNotPersisted(err))
|
||||
}
|
||||
|
||||
// TestUsageLogRepositoryCreate_BatchPathCanceledAfterQueueMarksNotPersisted
// covers the window where a request has already been queued but the caller's
// context is canceled before the batcher flushes it: the caller must receive a
// "not persisted" error, and the dequeued request is completed manually at the
// end so no state is left dangling.
func TestUsageLogRepositoryCreate_BatchPathCanceledAfterQueueMarksNotPersisted(t *testing.T) {
	client := testEntClient(t)
	repo := newUsageLogRepositoryWithSQL(client, integrationDB)
	repo.createBatchCh = make(chan usageLogCreateRequest, 1)

	user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-queued-%d@example.com", time.Now().UnixNano())})
	apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-queued-" + uuid.NewString(), Name: "k"})
	account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-queued-" + uuid.NewString()})

	ctx, cancel := context.WithCancel(context.Background())
	errCh := make(chan error, 1)

	// Run createBatched in the background; it blocks after enqueueing because
	// nothing is draining createBatchCh in this test.
	go func() {
		_, err := repo.createBatched(ctx, &service.UsageLog{
			UserID:       user.ID,
			APIKeyID:     apiKey.ID,
			AccountID:    account.ID,
			RequestID:    uuid.NewString(),
			Model:        "claude-3",
			InputTokens:  10,
			OutputTokens: 20,
			TotalCost:    0.5,
			ActualCost:   0.5,
			CreatedAt:    time.Now().UTC(),
		})
		errCh <- err
	}()

	// Act as the batcher: take the request off the queue, then cancel the
	// caller's context while its result is still pending.
	req := <-repo.createBatchCh
	require.NotNil(t, req.shared)
	cancel()

	err := <-errCh
	require.Error(t, err)
	require.True(t, service.IsUsageLogCreateNotPersisted(err))
	// Complete the dequeued request so its shared state is settled.
	completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: service.MarkUsageLogCreateNotPersisted(context.Canceled)})
}
|
||||
|
||||
// TestUsageLogRepositoryFlushCreateBatch_CanceledRequestReturnsNotPersisted
// hands flushCreateBatch a request whose shared state is already marked
// canceled and verifies the flush reports inserted=false with a
// "not persisted" error on the request's result channel.
func TestUsageLogRepositoryFlushCreateBatch_CanceledRequestReturnsNotPersisted(t *testing.T) {
	client := testEntClient(t)
	repo := newUsageLogRepositoryWithSQL(client, integrationDB)

	user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-flush-cancel-%d@example.com", time.Now().UnixNano())})
	apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-flush-cancel-" + uuid.NewString(), Name: "k"})
	account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-flush-cancel-" + uuid.NewString()})

	log := &service.UsageLog{
		UserID:       user.ID,
		APIKeyID:     apiKey.ID,
		AccountID:    account.ID,
		RequestID:    uuid.NewString(),
		Model:        "claude-3",
		InputTokens:  10,
		OutputTokens: 20,
		TotalCost:    0.5,
		ActualCost:   0.5,
		CreatedAt:    time.Now().UTC(),
	}
	req := usageLogCreateRequest{
		log:      log,
		prepared: prepareUsageLogInsert(log),
		shared:   &usageLogCreateShared{},
		resultCh: make(chan usageLogCreateResult, 1), // buffered so flush cannot block
	}
	// Pre-mark the request as canceled before the flush runs.
	req.shared.state.Store(usageLogCreateStateCanceled)

	repo.flushCreateBatch(integrationDB, []usageLogCreateRequest{req})

	res := <-req.resultCh
	require.False(t, res.inserted)
	require.Error(t, res.err)
	require.True(t, service.IsUsageLogCreateNotPersisted(res.err))
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetByID() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
|
||||
|
||||
@@ -248,6 +248,35 @@ func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryGetUserSpendingRanking(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageLogRepository{sql: db}
|
||||
|
||||
start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(24 * time.Hour)
|
||||
|
||||
rows := sqlmock.NewRows([]string{"user_id", "email", "actual_cost", "requests", "tokens", "total_actual_cost"}).
|
||||
AddRow(int64(2), "beta@example.com", 12.5, int64(9), int64(900), 40.0).
|
||||
AddRow(int64(1), "alpha@example.com", 12.5, int64(8), int64(800), 40.0).
|
||||
AddRow(int64(3), "gamma@example.com", 4.25, int64(5), int64(300), 40.0)
|
||||
|
||||
mock.ExpectQuery("WITH user_spend AS \\(").
|
||||
WithArgs(start, end, 12).
|
||||
WillReturnRows(rows)
|
||||
|
||||
got, err := repo.GetUserSpendingRanking(context.Background(), start, end, 12)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &usagestats.UserSpendingRankingResponse{
|
||||
Ranking: []usagestats.UserSpendingRankingItem{
|
||||
{UserID: 2, Email: "beta@example.com", ActualCost: 12.5, Requests: 9, Tokens: 900},
|
||||
{UserID: 1, Email: "alpha@example.com", ActualCost: 12.5, Requests: 8, Tokens: 800},
|
||||
{UserID: 3, Email: "gamma@example.com", ActualCost: 4.25, Requests: 5, Tokens: 300},
|
||||
},
|
||||
TotalActualCost: 40.0,
|
||||
}, got)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestBuildRequestTypeFilterConditionLegacyFallback(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -3,8 +3,11 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -39,3 +42,26 @@ func TestSafeDateFormat(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildUsageLogBatchInsertQuery_UsesConflictDoNothing(t *testing.T) {
|
||||
log := &service.UsageLog{
|
||||
UserID: 1,
|
||||
APIKeyID: 2,
|
||||
AccountID: 3,
|
||||
RequestID: "req-batch-no-update",
|
||||
Model: "gpt-5",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
TotalCost: 1.2,
|
||||
ActualCost: 1.2,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
prepared := prepareUsageLogInsert(log)
|
||||
|
||||
query, _ := buildUsageLogBatchInsertQuery([]string{usageLogBatchKey(log.RequestID, log.APIKeyID)}, map[string]usageLogInsertPrepared{
|
||||
usageLogBatchKey(log.RequestID, log.APIKeyID): prepared,
|
||||
})
|
||||
|
||||
require.Contains(t, query, "ON CONFLICT (request_id, api_key_id) DO NOTHING")
|
||||
require.NotContains(t, strings.ToUpper(query), "DO UPDATE")
|
||||
}
|
||||
|
||||
@@ -62,6 +62,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewAnnouncementRepository,
|
||||
NewAnnouncementReadRepository,
|
||||
NewUsageLogRepository,
|
||||
NewUsageBillingRepository,
|
||||
NewIdempotencyRepository,
|
||||
NewUsageCleanupRepository,
|
||||
NewDashboardAggregationRepository,
|
||||
|
||||
@@ -1635,6 +1635,10 @@ func (r *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, end
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
// GetUserSpendingRanking is a stub: the tests using stubUsageLogRepo do not
// exercise dashboard ranking, so it always reports "not implemented".
func (r *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
	return nil, errors.New("not implemented")
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
logs := r.userLogs[userID]
|
||||
if len(logs) == 0 {
|
||||
|
||||
@@ -192,6 +192,7 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
dashboard.GET("/groups", h.Admin.Dashboard.GetGroupStats)
|
||||
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetAPIKeyUsageTrend)
|
||||
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
|
||||
dashboard.GET("/users-ranking", h.Admin.Dashboard.GetUserSpendingRanking)
|
||||
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
|
||||
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
|
||||
dashboard.POST("/aggregation/backfill", h.Admin.Dashboard.BackfillAggregation)
|
||||
|
||||
@@ -412,6 +412,7 @@ func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]stri
|
||||
if a.Platform == domain.PlatformAntigravity {
|
||||
return domain.DefaultAntigravityModelMapping
|
||||
}
|
||||
// Bedrock 默认映射由 forwardBedrock 统一处理(需配合 region prefix 调整)
|
||||
return nil
|
||||
}
|
||||
if len(rawMapping) == 0 {
|
||||
@@ -764,6 +765,14 @@ func (a *Account) IsInterceptWarmupEnabled() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// IsBedrock reports whether the account routes Anthropic traffic through AWS
// Bedrock, covering both the SigV4 and the Bedrock-API-key account types.
func (a *Account) IsBedrock() bool {
	return a.Platform == PlatformAnthropic && (a.Type == AccountTypeBedrock || a.Type == AccountTypeBedrockAPIKey)
}
|
||||
|
||||
// IsBedrockAPIKey reports whether the account is a Bedrock account that
// authenticates with an API key (Bearer token) rather than SigV4 signing.
func (a *Account) IsBedrockAPIKey() bool {
	return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrockAPIKey
}
|
||||
|
||||
// IsOpenAI reports whether the account targets the OpenAI platform.
func (a *Account) IsOpenAI() bool {
	return a.Platform == PlatformOpenAI
}
|
||||
|
||||
@@ -207,14 +207,14 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
||||
testModelID = claude.DefaultTestModel
|
||||
}
|
||||
|
||||
// For API Key accounts with model mapping, map the model
|
||||
// API Key 账号测试连接时也需要应用通配符模型映射。
|
||||
if account.Type == "apikey" {
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) > 0 {
|
||||
if mappedModel, exists := mapping[testModelID]; exists {
|
||||
testModelID = mappedModel
|
||||
}
|
||||
}
|
||||
testModelID = account.GetMappedModel(testModelID)
|
||||
}
|
||||
|
||||
// Bedrock accounts use a separate test path
|
||||
if account.IsBedrock() {
|
||||
return s.testBedrockAccountConnection(c, ctx, account, testModelID)
|
||||
}
|
||||
|
||||
// Determine authentication method and API URL
|
||||
@@ -312,6 +312,109 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
||||
return s.processClaudeStream(c, resp.Body)
|
||||
}
|
||||
|
||||
// testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke.
// It resolves the account's Bedrock model ID, sends a minimal "hi" message to
// the invoke endpoint, and relays the outcome to the test UI as SSE events.
// A non-nil return comes only from the SSE helpers; upstream/API failures are
// reported to the client via sendErrorAndEnd.
func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error {
	region := bedrockRuntimeRegion(account)
	resolvedModelID, ok := ResolveBedrockModelID(account, testModelID)
	if !ok {
		return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported Bedrock model: %s", testModelID))
	}
	testModelID = resolvedModelID

	// Set SSE headers (test UI expects SSE)
	c.Writer.Header().Set("Content-Type", "text/event-stream")
	c.Writer.Header().Set("Cache-Control", "no-cache")
	c.Writer.Header().Set("Connection", "keep-alive")
	c.Writer.Header().Set("X-Accel-Buffering", "no")
	c.Writer.Flush()

	// Create a minimal Bedrock-compatible payload (no stream, no cache_control)
	bedrockPayload := map[string]any{
		"anthropic_version": "bedrock-2023-05-31",
		"messages": []map[string]any{
			{
				"role": "user",
				"content": []map[string]any{
					{
						"type": "text",
						"text": "hi",
					},
				},
			},
		},
		"max_tokens":  256,
		"temperature": 1,
	}
	// Marshal error ignored: the payload is a fixed map of marshalable values.
	bedrockBody, _ := json.Marshal(bedrockPayload)

	// Use non-streaming endpoint (response is standard Claude JSON)
	apiURL := BuildBedrockURL(region, testModelID, false)

	s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})

	req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(bedrockBody))
	if err != nil {
		return s.sendErrorAndEnd(c, "Failed to create request")
	}
	req.Header.Set("Content-Type", "application/json")

	// Sign or set auth based on account type
	if account.IsBedrockAPIKey() {
		apiKey := account.GetCredential("api_key")
		if apiKey == "" {
			return s.sendErrorAndEnd(c, "No API key available")
		}
		req.Header.Set("Authorization", "Bearer "+apiKey)
	} else {
		// SigV4 path: build a signer from the account's AWS credentials.
		signer, err := NewBedrockSignerFromAccount(account)
		if err != nil {
			return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create Bedrock signer: %s", err.Error()))
		}
		if err := signer.SignRequest(ctx, req, bedrockBody); err != nil {
			return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to sign request: %s", err.Error()))
		}
	}

	// Route through the account's proxy when one is configured.
	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}

	resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, false)
	if err != nil {
		return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
	}
	defer func() { _ = resp.Body.Close() }()

	// Read error ignored: a partial/empty body still yields a usable error
	// message below. NOTE(review): assumes test responses are small enough for
	// io.ReadAll — confirm there is no need for a size cap here.
	body, _ := io.ReadAll(resp.Body)

	if resp.StatusCode != http.StatusOK {
		return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
	}

	// Bedrock non-streaming response is standard Claude JSON, extract the text
	var result struct {
		Content []struct {
			Text string `json:"text"`
		} `json:"content"`
	}
	if err := json.Unmarshal(body, &result); err != nil {
		return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to parse response: %s", err.Error()))
	}

	text := ""
	if len(result.Content) > 0 {
		text = result.Content[0].Text
	}
	if text == "" {
		text = "(empty response)"
	}

	s.sendEvent(c, TestEvent{Type: "content", Text: text})
	s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
	return nil
}
|
||||
|
||||
// testOpenAIAccountConnection tests an OpenAI account's connection
|
||||
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
@@ -47,6 +47,7 @@ type UsageLogRepository interface {
|
||||
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
|
||||
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
||||
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
||||
GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error)
|
||||
GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error)
|
||||
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
|
||||
|
||||
|
||||
607
backend/internal/service/bedrock_request.go
Normal file
607
backend/internal/service/bedrock_request.go
Normal file
@@ -0,0 +1,607 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const defaultBedrockRegion = "us-east-1"
|
||||
|
||||
var bedrockCrossRegionPrefixes = []string{"us.", "eu.", "apac.", "jp.", "au.", "us-gov.", "global."}
|
||||
|
||||
// BedrockCrossRegionPrefix returns the model-ID prefix used by Bedrock
// cross-region inference profiles for the given AWS region.
// See: https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html
func BedrockCrossRegionPrefix(region string) string {
	// Regions with their own dedicated prefix are matched exactly first; they
	// would otherwise fall into the generic geography buckets below.
	if region == "ap-northeast-1" {
		return "jp" // Japan uses the standalone jp prefix (AWS-defined)
	}
	if region == "ap-southeast-2" {
		return "au" // Australia uses the standalone au prefix (AWS-defined)
	}
	switch {
	case strings.HasPrefix(region, "us-gov"):
		return "us-gov" // GovCloud uses the standalone us-gov prefix
	case strings.HasPrefix(region, "us-"):
		return "us"
	case strings.HasPrefix(region, "eu-"):
		return "eu"
	case strings.HasPrefix(region, "ap-"):
		return "apac" // remaining Asia-Pacific regions share the apac prefix
	case strings.HasPrefix(region, "ca-"), strings.HasPrefix(region, "sa-"):
		return "us" // Canada and South America ride the us cross-region profile
	default:
		return "us"
	}
}
||||
|
||||
// AdjustBedrockModelRegionPrefix 将模型 ID 的区域前缀替换为与当前 AWS Region 匹配的前缀
|
||||
// 例如 region=eu-west-1 时,"us.anthropic.claude-opus-4-6-v1" → "eu.anthropic.claude-opus-4-6-v1"
|
||||
// 特殊值 region="global" 强制使用 global. 前缀
|
||||
func AdjustBedrockModelRegionPrefix(modelID, region string) string {
|
||||
var targetPrefix string
|
||||
if region == "global" {
|
||||
targetPrefix = "global"
|
||||
} else {
|
||||
targetPrefix = BedrockCrossRegionPrefix(region)
|
||||
}
|
||||
|
||||
for _, p := range bedrockCrossRegionPrefixes {
|
||||
if strings.HasPrefix(modelID, p) {
|
||||
if p == targetPrefix+"." {
|
||||
return modelID // 前缀已匹配,无需替换
|
||||
}
|
||||
return targetPrefix + "." + modelID[len(p):]
|
||||
}
|
||||
}
|
||||
|
||||
// 模型 ID 没有已知区域前缀(如 "anthropic.claude-..."),不做修改
|
||||
return modelID
|
||||
}
|
||||
|
||||
func bedrockRuntimeRegion(account *Account) string {
|
||||
if account == nil {
|
||||
return defaultBedrockRegion
|
||||
}
|
||||
if region := account.GetCredential("aws_region"); region != "" {
|
||||
return region
|
||||
}
|
||||
return defaultBedrockRegion
|
||||
}
|
||||
|
||||
func shouldForceBedrockGlobal(account *Account) bool {
|
||||
return account != nil && account.GetCredential("aws_force_global") == "true"
|
||||
}
|
||||
|
||||
func isRegionalBedrockModelID(modelID string) bool {
|
||||
for _, prefix := range bedrockCrossRegionPrefixes {
|
||||
if strings.HasPrefix(modelID, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isLikelyBedrockModelID(modelID string) bool {
|
||||
lower := strings.ToLower(strings.TrimSpace(modelID))
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
if strings.HasPrefix(lower, "arn:") {
|
||||
return true
|
||||
}
|
||||
for _, prefix := range []string{
|
||||
"anthropic.",
|
||||
"amazon.",
|
||||
"meta.",
|
||||
"mistral.",
|
||||
"cohere.",
|
||||
"ai21.",
|
||||
"deepseek.",
|
||||
"stability.",
|
||||
"writer.",
|
||||
"nova.",
|
||||
} {
|
||||
if strings.HasPrefix(lower, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return isRegionalBedrockModelID(lower)
|
||||
}
|
||||
|
||||
func normalizeBedrockModelID(modelID string) (normalized string, shouldAdjustRegion bool, ok bool) {
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if modelID == "" {
|
||||
return "", false, false
|
||||
}
|
||||
if mapped, exists := domain.DefaultBedrockModelMapping[modelID]; exists {
|
||||
return mapped, true, true
|
||||
}
|
||||
if isRegionalBedrockModelID(modelID) {
|
||||
return modelID, true, true
|
||||
}
|
||||
if isLikelyBedrockModelID(modelID) {
|
||||
return modelID, false, true
|
||||
}
|
||||
return "", false, false
|
||||
}
|
||||
|
||||
// ResolveBedrockModelID resolves a requested Claude model into a Bedrock model ID.
|
||||
// It applies account model_mapping first, then default Bedrock aliases, and finally
|
||||
// adjusts Anthropic cross-region prefixes to match the account region.
|
||||
func ResolveBedrockModelID(account *Account, requestedModel string) (string, bool) {
|
||||
if account == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
mappedModel := account.GetMappedModel(requestedModel)
|
||||
modelID, shouldAdjustRegion, ok := normalizeBedrockModelID(mappedModel)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
if shouldAdjustRegion {
|
||||
targetRegion := bedrockRuntimeRegion(account)
|
||||
if shouldForceBedrockGlobal(account) {
|
||||
targetRegion = "global"
|
||||
}
|
||||
modelID = AdjustBedrockModelRegionPrefix(modelID, targetRegion)
|
||||
}
|
||||
return modelID, true
|
||||
}
|
||||
|
||||
// BuildBedrockURL 构建 Bedrock InvokeModel 的 URL
|
||||
// stream=true 时使用 invoke-with-response-stream 端点
|
||||
// modelID 中的特殊字符会被 URL 编码(与 litellm 的 urllib.parse.quote(safe="") 对齐)
|
||||
func BuildBedrockURL(region, modelID string, stream bool) string {
|
||||
if region == "" {
|
||||
region = defaultBedrockRegion
|
||||
}
|
||||
encodedModelID := url.PathEscape(modelID)
|
||||
// url.PathEscape 不编码冒号(RFC 允许 path 中出现 ":"),
|
||||
// 但 AWS Bedrock 期望模型 ID 中的冒号被编码为 %3A
|
||||
encodedModelID = strings.ReplaceAll(encodedModelID, ":", "%3A")
|
||||
if stream {
|
||||
return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke-with-response-stream", region, encodedModelID)
|
||||
}
|
||||
return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke", region, encodedModelID)
|
||||
}
|
||||
|
||||
// PrepareBedrockRequestBody adapts a Claude Messages request body for the
// Bedrock Invoke API:
//  1. injects anthropic_version
//  2. injects anthropic_beta (resolved from the client's anthropic-beta header)
//  3. removes fields Bedrock rejects (model, stream, output_format, output_config)
//  4. removes the "custom" field from tool definitions (Claude Code sends
//     custom: {defer_loading: true})
//  5. strips cache_control fields Bedrock does not support (scope, ttl)
func PrepareBedrockRequestBody(body []byte, modelID string, betaHeader string) ([]byte, error) {
	betaTokens := ResolveBedrockBetaTokens(betaHeader, body, modelID)
	return PrepareBedrockRequestBodyWithTokens(body, modelID, betaTokens)
}
|
||||
|
||||
// PrepareBedrockRequestBodyWithTokens prepares a Bedrock request using pre-resolved beta tokens.
func PrepareBedrockRequestBodyWithTokens(body []byte, modelID string, betaTokens []string) ([]byte, error) {
	var err error

	// Inject anthropic_version (required by Bedrock).
	body, err = sjson.SetBytes(body, "anthropic_version", "bedrock-2023-05-31")
	if err != nil {
		return nil, fmt.Errorf("inject anthropic_version: %w", err)
	}

	// Inject anthropic_beta (Bedrock Invoke carries beta flags in the request
	// body rather than in an HTTP header):
	//  1. tokens parsed from the client's anthropic-beta header
	//  2. tokens auto-added based on the request body content
	// Reference litellm: AnthropicModelInfo.get_anthropic_beta_list() +
	// _get_tool_search_beta_header_for_bedrock()
	if len(betaTokens) > 0 {
		body, err = sjson.SetBytes(body, "anthropic_beta", betaTokens)
		if err != nil {
			return nil, fmt.Errorf("inject anthropic_beta: %w", err)
		}
	}

	// Remove the model field (Bedrock selects the model via the URL).
	body, err = sjson.DeleteBytes(body, "model")
	if err != nil {
		return nil, fmt.Errorf("remove model field: %w", err)
	}

	// Remove the stream field (Bedrock selects streaming via the endpoint and
	// does not accept a stream field in the body).
	body, err = sjson.DeleteBytes(body, "stream")
	if err != nil {
		return nil, fmt.Errorf("remove stream field: %w", err)
	}

	// Convert output_format (unsupported by Bedrock Invoke; the schema can be
	// inlined into the last user message instead).
	// Reference litellm: _convert_output_format_to_inline_schema()
	body = convertOutputFormatToInlineSchema(body)

	// Remove the output_config field (unsupported by Bedrock Invoke).
	body, err = sjson.DeleteBytes(body, "output_config")
	if err != nil {
		return nil, fmt.Errorf("remove output_config field: %w", err)
	}

	// Remove the custom field from tool definitions.
	// Claude Code (v2.1.69+) sends custom: {defer_loading: true} in tool
	// definitions; the Anthropic API accepts it but Bedrock rejects the request
	// with "Extra inputs are not permitted".
	body = removeCustomFieldFromTools(body)

	// Strip cache_control fields Bedrock does not support.
	body = sanitizeBedrockCacheControl(body, modelID)

	return body, nil
}
|
||||
|
||||
// ResolveBedrockBetaTokens computes the final Bedrock beta token list before policy filtering.
|
||||
func ResolveBedrockBetaTokens(betaHeader string, body []byte, modelID string) []string {
|
||||
betaTokens := parseAnthropicBetaHeader(betaHeader)
|
||||
betaTokens = autoInjectBedrockBetaTokens(betaTokens, body, modelID)
|
||||
return filterBedrockBetaTokens(betaTokens)
|
||||
}
|
||||
|
||||
// convertOutputFormatToInlineSchema inlines the JSON schema from output_format
// into the last user message. Bedrock Invoke does not support the
// output_format parameter; following litellm, the schema is appended to the
// user message text instead.
// Reference: litellm AmazonAnthropicClaudeMessagesConfig._convert_output_format_to_inline_schema()
func convertOutputFormatToInlineSchema(body []byte) []byte {
	outputFormat := gjson.GetBytes(body, "output_format")
	if !outputFormat.Exists() || !outputFormat.IsObject() {
		return body
	}

	// Remove output_format from the request body first.
	body, _ = sjson.DeleteBytes(body, "output_format")

	schema := outputFormat.Get("schema")
	if !schema.Exists() {
		return body
	}

	// Locate the last user message.
	messages := gjson.GetBytes(body, "messages")
	if !messages.Exists() || !messages.IsArray() {
		return body
	}
	msgArr := messages.Array()
	lastUserIdx := -1
	for i := len(msgArr) - 1; i >= 0; i-- {
		if msgArr[i].Get("role").String() == "user" {
			lastUserIdx = i
			break
		}
	}
	if lastUserIdx < 0 {
		return body
	}

	// Serialize the schema to JSON text and append it to that message's
	// content array.
	schemaJSON, err := json.Marshal(json.RawMessage(schema.Raw))
	if err != nil {
		return body
	}

	content := msgArr[lastUserIdx].Get("content")
	basePath := fmt.Sprintf("messages.%d.content", lastUserIdx)

	if content.IsArray() {
		// Append a text block at the end of the content array.
		idx := len(content.Array())
		body, _ = sjson.SetBytes(body, fmt.Sprintf("%s.%d.type", basePath, idx), "text")
		body, _ = sjson.SetBytes(body, fmt.Sprintf("%s.%d.text", basePath, idx), string(schemaJSON))
	} else if content.Type == gjson.String {
		// content is a plain string: convert it to array form.
		originalText := content.String()
		body, _ = sjson.SetBytes(body, basePath, []map[string]string{
			{"type": "text", "text": originalText},
			{"type": "text", "text": string(schemaJSON)},
		})
	}

	return body
}
|
||||
|
||||
// removeCustomFieldFromTools 移除 tools 数组中每个工具定义的 custom 字段
|
||||
func removeCustomFieldFromTools(body []byte) []byte {
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if !tools.Exists() || !tools.IsArray() {
|
||||
return body
|
||||
}
|
||||
var err error
|
||||
for i := range tools.Array() {
|
||||
body, err = sjson.DeleteBytes(body, fmt.Sprintf("tools.%d.custom", i))
|
||||
if err != nil {
|
||||
// 删除失败不影响整体流程,跳过
|
||||
continue
|
||||
}
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// claudeVersionRe captures the major/minor version embedded in a Claude model
// ID. It accepts both claude-{tier}-{major}-{minor} and
// claude-{tier}-{major}.{minor} forms.
var claudeVersionRe = regexp.MustCompile(`claude-(?:haiku|sonnet|opus)-(\d+)[-.](\d+)`)
|
||||
|
||||
// isBedrockClaude45OrNewer 判断 Bedrock 模型 ID 是否为 Claude 4.5 或更新版本
|
||||
// Claude 4.5+ 支持 cache_control 中的 ttl 字段("5m" 和 "1h")
|
||||
func isBedrockClaude45OrNewer(modelID string) bool {
|
||||
lower := strings.ToLower(modelID)
|
||||
matches := claudeVersionRe.FindStringSubmatch(lower)
|
||||
if matches == nil {
|
||||
return false
|
||||
}
|
||||
major, _ := strconv.Atoi(matches[1])
|
||||
minor, _ := strconv.Atoi(matches[2])
|
||||
return major > 4 || (major == 4 && minor >= 5)
|
||||
}
|
||||
|
||||
// sanitizeBedrockCacheControl removes cache_control fields in system and
// messages that Bedrock does not support:
//   - scope: never supported by Bedrock (e.g. "global" cross-request caching)
//   - ttl: only Claude 4.5+ accepts "5m" and "1h"; removed for older models
func sanitizeBedrockCacheControl(body []byte, modelID string) []byte {
	isClaude45 := isBedrockClaude45OrNewer(modelID)

	// Sanitize cache_control entries in the system array.
	systemArr := gjson.GetBytes(body, "system")
	if systemArr.Exists() && systemArr.IsArray() {
		for i, item := range systemArr.Array() {
			if !item.IsObject() {
				continue
			}
			cc := item.Get("cache_control")
			if !cc.Exists() || !cc.IsObject() {
				continue
			}
			body = deleteCacheControlUnsupportedFields(body, fmt.Sprintf("system.%d.cache_control", i), cc, isClaude45)
		}
	}

	// Sanitize cache_control entries inside message content blocks.
	messages := gjson.GetBytes(body, "messages")
	if !messages.Exists() || !messages.IsArray() {
		return body
	}
	for mi, msg := range messages.Array() {
		if !msg.IsObject() {
			continue
		}
		content := msg.Get("content")
		if !content.Exists() || !content.IsArray() {
			continue
		}
		for ci, block := range content.Array() {
			if !block.IsObject() {
				continue
			}
			cc := block.Get("cache_control")
			if !cc.Exists() || !cc.IsObject() {
				continue
			}
			body = deleteCacheControlUnsupportedFields(body, fmt.Sprintf("messages.%d.content.%d.cache_control", mi, ci), cc, isClaude45)
		}
	}

	return body
}
|
||||
|
||||
// deleteCacheControlUnsupportedFields 删除给定 cache_control 路径下 Bedrock 不支持的字段
|
||||
func deleteCacheControlUnsupportedFields(body []byte, basePath string, cc gjson.Result, isClaude45 bool) []byte {
|
||||
// Bedrock 不支持 scope(如 "global")
|
||||
if cc.Get("scope").Exists() {
|
||||
body, _ = sjson.DeleteBytes(body, basePath+".scope")
|
||||
}
|
||||
|
||||
// ttl:仅 Claude 4.5+ 支持 "5m" 和 "1h",其余情况移除
|
||||
ttl := cc.Get("ttl")
|
||||
if ttl.Exists() {
|
||||
shouldRemove := true
|
||||
if isClaude45 {
|
||||
v := ttl.String()
|
||||
if v == "5m" || v == "1h" {
|
||||
shouldRemove = false
|
||||
}
|
||||
}
|
||||
if shouldRemove {
|
||||
body, _ = sjson.DeleteBytes(body, basePath+".ttl")
|
||||
}
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
// parseAnthropicBetaHeader splits an anthropic-beta header value into a token
// list. It accepts either a comma-separated string or a JSON array; blank
// entries are dropped and surrounding whitespace is trimmed.
func parseAnthropicBetaHeader(header string) []string {
	header = strings.TrimSpace(header)
	if header == "" {
		return nil
	}

	// Some clients send the header as a JSON array instead of a CSV string.
	if strings.HasPrefix(header, "[") && strings.HasSuffix(header, "]") {
		var items []any
		if json.Unmarshal([]byte(header), &items) == nil {
			out := make([]string, 0, len(items))
			for _, item := range items {
				if token := strings.TrimSpace(fmt.Sprint(item)); token != "" {
					out = append(out, token)
				}
			}
			return out
		}
		// Not valid JSON after all: fall through to comma splitting.
	}

	var out []string
	for _, piece := range strings.Split(header, ",") {
		if token := strings.TrimSpace(piece); token != "" {
			out = append(out, token)
		}
	}
	return out
}
|
||||
|
||||
// bedrockSupportedBetaTokens is the allowlist of beta headers accepted by
// Bedrock Invoke.
// Reference: litellm/litellm/llms/bedrock/common_utils.py (anthropic_beta_headers_config.json)
// Maintenance: extend this allowlist when AWS Bedrock adds support for new
// beta tokens.
var bedrockSupportedBetaTokens = map[string]bool{
	"computer-use-2025-01-24":         true,
	"computer-use-2025-11-24":         true,
	"context-1m-2025-08-07":           true,
	"context-management-2025-06-27":   true,
	"compact-2026-01-12":              true,
	"interleaved-thinking-2025-05-14": true,
	"tool-search-tool-2025-10-19":     true,
	"tool-examples-2025-10-29":        true,
}
|
||||
|
||||
// bedrockBetaTokenTransforms maps generic Anthropic beta headers to the
// Bedrock-Invoke-specific replacements Bedrock requires; the direct Anthropic
// API uses the generic header.
var bedrockBetaTokenTransforms = map[string]string{
	"advanced-tool-use-2025-11-20": "tool-search-tool-2025-10-19",
}
|
||||
|
||||
// autoInjectBedrockBetaTokens 根据请求体内容自动补齐必要的 beta token
|
||||
// 参考 litellm: AnthropicModelInfo.get_anthropic_beta_list() 和
|
||||
// AmazonAnthropicClaudeMessagesConfig._get_tool_search_beta_header_for_bedrock()
|
||||
//
|
||||
// 客户端(特别是非 Claude Code 客户端)可能只在 body 中启用了功能而不在 header 中带对应 beta token,
|
||||
// 这里通过检测请求体特征自动补齐,确保 Bedrock Invoke 不会因缺少必要 beta 头而 400。
|
||||
func autoInjectBedrockBetaTokens(tokens []string, body []byte, modelID string) []string {
|
||||
seen := make(map[string]bool, len(tokens))
|
||||
for _, t := range tokens {
|
||||
seen[t] = true
|
||||
}
|
||||
|
||||
inject := func(token string) {
|
||||
if !seen[token] {
|
||||
tokens = append(tokens, token)
|
||||
seen[token] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 检测 thinking / interleaved thinking
|
||||
// 请求体中有 "thinking" 字段 → 需要 interleaved-thinking beta
|
||||
if gjson.GetBytes(body, "thinking").Exists() {
|
||||
inject("interleaved-thinking-2025-05-14")
|
||||
}
|
||||
|
||||
// 检测 computer_use 工具
|
||||
// tools 中有 type="computer_20xxxxxx" 的工具 → 需要 computer-use beta
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if tools.Exists() && tools.IsArray() {
|
||||
toolSearchUsed := false
|
||||
programmaticToolCallingUsed := false
|
||||
inputExamplesUsed := false
|
||||
for _, tool := range tools.Array() {
|
||||
toolType := tool.Get("type").String()
|
||||
if strings.HasPrefix(toolType, "computer_20") {
|
||||
inject("computer-use-2025-11-24")
|
||||
}
|
||||
if isBedrockToolSearchType(toolType) {
|
||||
toolSearchUsed = true
|
||||
}
|
||||
if hasCodeExecutionAllowedCallers(tool) {
|
||||
programmaticToolCallingUsed = true
|
||||
}
|
||||
if hasInputExamples(tool) {
|
||||
inputExamplesUsed = true
|
||||
}
|
||||
}
|
||||
if programmaticToolCallingUsed || inputExamplesUsed {
|
||||
// programmatic tool calling 和 input examples 需要 advanced-tool-use,
|
||||
// 后续 filterBedrockBetaTokens 会将其转换为 Bedrock 特定的 tool-search-tool
|
||||
inject("advanced-tool-use-2025-11-20")
|
||||
}
|
||||
if toolSearchUsed && bedrockModelSupportsToolSearch(modelID) {
|
||||
// 纯 tool search(无 programmatic/inputExamples)时直接注入 Bedrock 特定头,
|
||||
// 跳过 advanced-tool-use → tool-search-tool 的转换步骤(与 litellm 对齐)
|
||||
if !programmaticToolCallingUsed && !inputExamplesUsed {
|
||||
inject("tool-search-tool-2025-10-19")
|
||||
} else {
|
||||
inject("advanced-tool-use-2025-11-20")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tokens
|
||||
}
|
||||
|
||||
// isBedrockToolSearchType reports whether toolType is one of the tool-search
// tool variants (regex or BM25).
func isBedrockToolSearchType(toolType string) bool {
	switch toolType {
	case "tool_search_tool_regex_20251119", "tool_search_tool_bm25_20251119":
		return true
	default:
		return false
	}
}
|
||||
|
||||
func hasCodeExecutionAllowedCallers(tool gjson.Result) bool {
|
||||
allowedCallers := tool.Get("allowed_callers")
|
||||
if containsStringInJSONArray(allowedCallers, "code_execution_20250825") {
|
||||
return true
|
||||
}
|
||||
return containsStringInJSONArray(tool.Get("function.allowed_callers"), "code_execution_20250825")
|
||||
}
|
||||
|
||||
func hasInputExamples(tool gjson.Result) bool {
|
||||
if arr := tool.Get("input_examples"); arr.Exists() && arr.IsArray() && len(arr.Array()) > 0 {
|
||||
return true
|
||||
}
|
||||
arr := tool.Get("function.input_examples")
|
||||
return arr.Exists() && arr.IsArray() && len(arr.Array()) > 0
|
||||
}
|
||||
|
||||
func containsStringInJSONArray(result gjson.Result, target string) bool {
|
||||
if !result.Exists() || !result.IsArray() {
|
||||
return false
|
||||
}
|
||||
for _, item := range result.Array() {
|
||||
if item.String() == target {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// bedrockModelSupportsToolSearch 判断 Bedrock 模型是否支持 tool search
|
||||
// 目前仅 Claude Opus/Sonnet 4.5+ 支持,Haiku 不支持
|
||||
func bedrockModelSupportsToolSearch(modelID string) bool {
|
||||
lower := strings.ToLower(modelID)
|
||||
matches := claudeVersionRe.FindStringSubmatch(lower)
|
||||
if matches == nil {
|
||||
return false
|
||||
}
|
||||
// Haiku 不支持 tool search
|
||||
if strings.Contains(lower, "haiku") {
|
||||
return false
|
||||
}
|
||||
major, _ := strconv.Atoi(matches[1])
|
||||
minor, _ := strconv.Atoi(matches[2])
|
||||
return major > 4 || (major == 4 && minor >= 5)
|
||||
}
|
||||
|
||||
// filterBedrockBetaTokens 过滤并转换 beta token 列表,仅保留 Bedrock Invoke 支持的 token
|
||||
// 1. 应用转换规则(如 advanced-tool-use → tool-search-tool)
|
||||
// 2. 过滤掉 Bedrock 不支持的 token(如 output-128k, files-api, structured-outputs 等)
|
||||
// 3. 自动关联 tool-examples(当 tool-search-tool 存在时)
|
||||
func filterBedrockBetaTokens(tokens []string) []string {
|
||||
seen := make(map[string]bool, len(tokens))
|
||||
var result []string
|
||||
|
||||
for _, t := range tokens {
|
||||
// 应用转换规则
|
||||
if replacement, ok := bedrockBetaTokenTransforms[t]; ok {
|
||||
t = replacement
|
||||
}
|
||||
// 只保留白名单中的 token,且去重
|
||||
if bedrockSupportedBetaTokens[t] && !seen[t] {
|
||||
result = append(result, t)
|
||||
seen[t] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 自动关联: tool-search-tool 存在时,确保 tool-examples 也存在
|
||||
if seen["tool-search-tool-2025-10-19"] && !seen["tool-examples-2025-10-29"] {
|
||||
result = append(result, "tool-examples-2025-10-29")
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
659
backend/internal/service/bedrock_request_test.go
Normal file
659
backend/internal/service/bedrock_request_test.go
Normal file
@@ -0,0 +1,659 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// TestPrepareBedrockRequestBody_BasicFields verifies the baseline rewrites:
// anthropic_version is injected, model/stream are removed, and unrelated
// fields survive.
func TestPrepareBedrockRequestBody_BasicFields(t *testing.T) {
	input := `{"model":"claude-opus-4-6","stream":true,"max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`
	result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
	require.NoError(t, err)

	// anthropic_version must be injected
	assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String())
	// model and stream must be removed
	assert.False(t, gjson.GetBytes(result, "model").Exists())
	assert.False(t, gjson.GetBytes(result, "stream").Exists())
	// max_tokens must be preserved
	assert.Equal(t, int64(1024), gjson.GetBytes(result, "max_tokens").Int())
}
|
||||
|
||||
// TestPrepareBedrockRequestBody_OutputFormatInlineSchema covers the
// output_format → inline-schema conversion paths.
func TestPrepareBedrockRequestBody_OutputFormatInlineSchema(t *testing.T) {
	t.Run("schema inlined into last user message array content", func(t *testing.T) {
		input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"name":"string"}},"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`
		result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
		require.NoError(t, err)

		assert.False(t, gjson.GetBytes(result, "output_format").Exists())
		// The schema must be appended at the end of the last user message's content array.
		contentArr := gjson.GetBytes(result, "messages.0.content").Array()
		require.Len(t, contentArr, 2)
		assert.Equal(t, "text", contentArr[1].Get("type").String())
		assert.Contains(t, contentArr[1].Get("text").String(), `"name":"string"`)
	})

	t.Run("schema inlined into string content", func(t *testing.T) {
		input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"result":"number"}},"messages":[{"role":"user","content":"compute this"}]}`
		result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
		require.NoError(t, err)

		assert.False(t, gjson.GetBytes(result, "output_format").Exists())
		contentArr := gjson.GetBytes(result, "messages.0.content").Array()
		require.Len(t, contentArr, 2)
		assert.Equal(t, "compute this", contentArr[0].Get("text").String())
		assert.Contains(t, contentArr[1].Get("text").String(), `"result":"number"`)
	})

	t.Run("no schema field just removes output_format", func(t *testing.T) {
		input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json"},"messages":[{"role":"user","content":"hi"}]}`
		result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
		require.NoError(t, err)

		assert.False(t, gjson.GetBytes(result, "output_format").Exists())
	})

	t.Run("no messages just removes output_format", func(t *testing.T) {
		input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"name":"string"}}}`
		result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
		require.NoError(t, err)

		assert.False(t, gjson.GetBytes(result, "output_format").Exists())
	})
}
|
||||
|
||||
// TestPrepareBedrockRequestBody_RemoveOutputConfig verifies the unsupported
// output_config field is stripped from the request body.
func TestPrepareBedrockRequestBody_RemoveOutputConfig(t *testing.T) {
	input := `{"model":"claude-sonnet-4-5","output_config":{"max_tokens":100},"messages":[]}`
	result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
	require.NoError(t, err)

	assert.False(t, gjson.GetBytes(result, "output_config").Exists())
}
|
||||
|
||||
// TestRemoveCustomFieldFromTools verifies the custom field is stripped from
// each tool definition while the remaining fields survive.
func TestRemoveCustomFieldFromTools(t *testing.T) {
	input := `{
		"tools": [
			{"name":"tool1","custom":{"defer_loading":true},"description":"desc1"},
			{"name":"tool2","description":"desc2"},
			{"name":"tool3","custom":{"defer_loading":true,"other":123},"description":"desc3"}
		]
	}`
	result := removeCustomFieldFromTools([]byte(input))

	tools := gjson.GetBytes(result, "tools").Array()
	require.Len(t, tools, 3)
	// custom must be removed
	assert.False(t, tools[0].Get("custom").Exists())
	// name/description must be preserved
	assert.Equal(t, "tool1", tools[0].Get("name").String())
	assert.Equal(t, "desc1", tools[0].Get("description").String())
	// a tool without custom is left untouched
	assert.Equal(t, "tool2", tools[1].Get("name").String())
	// the third tool's custom must be removed as well
	assert.False(t, tools[2].Get("custom").Exists())
	assert.Equal(t, "tool3", tools[2].Get("name").String())
}
|
||||
|
||||
// TestRemoveCustomFieldFromTools_NoTools verifies a body without a tools array
// is returned unchanged.
func TestRemoveCustomFieldFromTools_NoTools(t *testing.T) {
	input := `{"messages":[{"role":"user","content":"hi"}]}`
	result := removeCustomFieldFromTools([]byte(input))
	// Without tools the original data must not change.
	assert.JSONEq(t, input, string(result))
}
|
||||
|
||||
// TestSanitizeBedrockCacheControl_RemoveScope verifies the unsupported scope
// field is removed from cache_control in both system and message content.
func TestSanitizeBedrockCacheControl_RemoveScope(t *testing.T) {
	input := `{
		"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","scope":"global"}}],
		"messages": [{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral","scope":"global"}}]}]
	}`
	result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")

	// scope must be removed
	assert.False(t, gjson.GetBytes(result, "system.0.cache_control.scope").Exists())
	assert.False(t, gjson.GetBytes(result, "messages.0.content.0.cache_control.scope").Exists())
	// type must be preserved
	assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
	assert.Equal(t, "ephemeral", gjson.GetBytes(result, "messages.0.content.0.cache_control.type").String())
}
|
||||
|
||||
// TestSanitizeBedrockCacheControl_TTL_OldModel verifies ttl is stripped for
// models older than Claude 4.5.
func TestSanitizeBedrockCacheControl_TTL_OldModel(t *testing.T) {
	input := `{
		"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}}]
	}`
	// Older models (Claude 3.5) do not support ttl.
	result := sanitizeBedrockCacheControl([]byte(input), "anthropic.claude-3-5-sonnet-20241022-v2:0")

	assert.False(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
	assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
}
|
||||
|
||||
// TestSanitizeBedrockCacheControl_TTL_Claude45_Supported verifies supported
// ttl values survive on Claude 4.5+.
func TestSanitizeBedrockCacheControl_TTL_Claude45_Supported(t *testing.T) {
	input := `{
		"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}}]
	}`
	// Claude 4.5+ supports "5m" and "1h".
	result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-sonnet-4-5-20250929-v1:0")

	assert.True(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
	assert.Equal(t, "5m", gjson.GetBytes(result, "system.0.cache_control.ttl").String())
}
|
||||
|
||||
// TestSanitizeBedrockCacheControl_TTL_Claude45_UnsupportedValue verifies a ttl
// outside the supported set is removed even on Claude 4.5.
func TestSanitizeBedrockCacheControl_TTL_Claude45_UnsupportedValue(t *testing.T) {
	input := `{
		"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"10m"}}]
	}`
	// Claude 4.5 does not support "10m".
	result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-sonnet-4-5-20250929-v1:0")

	assert.False(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
}
|
||||
|
||||
// TestSanitizeBedrockCacheControl_TTL_Claude46 verifies a supported ttl inside
// a message content block survives on Claude 4.6.
func TestSanitizeBedrockCacheControl_TTL_Claude46(t *testing.T) {
	input := `{
		"messages": [{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral","ttl":"1h"}}]}]
	}`
	result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")

	assert.True(t, gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").Exists())
	assert.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String())
}
|
||||
|
||||
// TestSanitizeBedrockCacheControl_NoCacheControl verifies a body without any
// cache_control passes through unchanged.
func TestSanitizeBedrockCacheControl_NoCacheControl(t *testing.T) {
	input := `{"system":[{"type":"text","text":"sys"}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`
	result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")
	// Without cache_control the original data must not change.
	assert.JSONEq(t, input, string(result))
}
|
||||
|
||||
// TestIsBedrockClaude45OrNewer table-tests the Claude version comparison
// across current, future, legacy and non-Claude model IDs.
func TestIsBedrockClaude45OrNewer(t *testing.T) {
	tests := []struct {
		modelID string
		expect  bool
	}{
		{"us.anthropic.claude-opus-4-6-v1", true},
		{"us.anthropic.claude-sonnet-4-6", true},
		{"us.anthropic.claude-sonnet-4-5-20250929-v1:0", true},
		{"us.anthropic.claude-opus-4-5-20251101-v1:0", true},
		{"us.anthropic.claude-haiku-4-5-20251001-v1:0", true},
		{"anthropic.claude-3-5-sonnet-20241022-v2:0", false},
		{"anthropic.claude-3-opus-20240229-v1:0", false},
		{"anthropic.claude-3-haiku-20240307-v1:0", false},
		// future versions must be supported automatically
		{"us.anthropic.claude-sonnet-5-0-v1", true},
		{"us.anthropic.claude-opus-4-7-v1", true},
		// legacy versions
		{"anthropic.claude-opus-4-1-v1", false},
		{"anthropic.claude-sonnet-4-0-v1", false},
		// non-Claude models
		{"amazon.nova-pro-v1", false},
		{"meta.llama3-70b", false},
	}
	for _, tt := range tests {
		t.Run(tt.modelID, func(t *testing.T) {
			assert.Equal(t, tt.expect, isBedrockClaude45OrNewer(tt.modelID))
		})
	}
}
|
||||
|
||||
// TestPrepareBedrockRequestBody_FullIntegration exercises every rewrite at
// once against a realistic Claude Code request.
func TestPrepareBedrockRequestBody_FullIntegration(t *testing.T) {
	// Simulate a complete Claude Code request.
	input := `{
		"model": "claude-opus-4-6",
		"stream": true,
		"max_tokens": 16384,
		"output_format": {"type": "json", "schema": {"result": "string"}},
		"output_config": {"max_tokens": 100},
		"system": [{"type": "text", "text": "You are helpful", "cache_control": {"type": "ephemeral", "scope": "global", "ttl": "5m"}}],
		"messages": [
			{"role": "user", "content": [{"type": "text", "text": "hello", "cache_control": {"type": "ephemeral", "ttl": "1h"}}]}
		],
		"tools": [
			{"name": "bash", "description": "Run bash", "custom": {"defer_loading": true}, "input_schema": {"type": "object"}},
			{"name": "read", "description": "Read file", "input_schema": {"type": "object"}}
		]
	}`

	betaHeader := "interleaved-thinking-2025-05-14, context-1m-2025-08-07, compact-2026-01-12"
	result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", betaHeader)
	require.NoError(t, err)

	// Basic fields.
	assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String())
	assert.False(t, gjson.GetBytes(result, "model").Exists())
	assert.False(t, gjson.GetBytes(result, "stream").Exists())
	assert.Equal(t, int64(16384), gjson.GetBytes(result, "max_tokens").Int())

	// anthropic_beta must contain every beta token.
	betaArr := gjson.GetBytes(result, "anthropic_beta").Array()
	require.Len(t, betaArr, 3)
	assert.Equal(t, "interleaved-thinking-2025-05-14", betaArr[0].String())
	assert.Equal(t, "context-1m-2025-08-07", betaArr[1].String())
	assert.Equal(t, "compact-2026-01-12", betaArr[2].String())

	// output_format must be removed, with the schema inlined into the last user message.
	assert.False(t, gjson.GetBytes(result, "output_format").Exists())
	assert.False(t, gjson.GetBytes(result, "output_config").Exists())
	// content array: the original text block plus the inlined schema block.
	contentArr := gjson.GetBytes(result, "messages.0.content").Array()
	require.Len(t, contentArr, 2)
	assert.Equal(t, "hello", contentArr[0].Get("text").String())
	assert.Contains(t, contentArr[1].Get("text").String(), `"result":"string"`)

	// custom must be removed from tools.
	assert.False(t, gjson.GetBytes(result, "tools.0.custom").Exists())
	assert.Equal(t, "bash", gjson.GetBytes(result, "tools.0.name").String())
	assert.Equal(t, "read", gjson.GetBytes(result, "tools.1.name").String())

	// cache_control: scope must be removed; valid ttl values survive on Claude 4.6.
	assert.False(t, gjson.GetBytes(result, "system.0.cache_control.scope").Exists())
	assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
	assert.Equal(t, "5m", gjson.GetBytes(result, "system.0.cache_control.ttl").String())
	assert.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String())
}
|
||||
|
||||
func TestPrepareBedrockRequestBody_BetaHeader(t *testing.T) {
|
||||
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`
|
||||
|
||||
t.Run("empty beta header", func(t *testing.T) {
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, gjson.GetBytes(result, "anthropic_beta").Exists())
|
||||
})
|
||||
|
||||
t.Run("single beta token", func(t *testing.T) {
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14")
|
||||
require.NoError(t, err)
|
||||
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||
require.Len(t, arr, 1)
|
||||
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
|
||||
})
|
||||
|
||||
t.Run("multiple beta tokens with spaces", func(t *testing.T) {
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14 , context-1m-2025-08-07 ")
|
||||
require.NoError(t, err)
|
||||
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||
require.Len(t, arr, 2)
|
||||
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
|
||||
assert.Equal(t, "context-1m-2025-08-07", arr[1].String())
|
||||
})
|
||||
|
||||
t.Run("json array beta header", func(t *testing.T) {
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", `["interleaved-thinking-2025-05-14","context-1m-2025-08-07"]`)
|
||||
require.NoError(t, err)
|
||||
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||
require.Len(t, arr, 2)
|
||||
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
|
||||
assert.Equal(t, "context-1m-2025-08-07", arr[1].String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseAnthropicBetaHeader(t *testing.T) {
|
||||
assert.Nil(t, parseAnthropicBetaHeader(""))
|
||||
assert.Equal(t, []string{"a"}, parseAnthropicBetaHeader("a"))
|
||||
assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader("a,b"))
|
||||
assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader("a , b "))
|
||||
assert.Equal(t, []string{"a", "b", "c"}, parseAnthropicBetaHeader("a,b,c"))
|
||||
assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader(`["a","b"]`))
|
||||
}
|
||||
|
||||
func TestFilterBedrockBetaTokens(t *testing.T) {
|
||||
t.Run("supported tokens pass through", func(t *testing.T) {
|
||||
tokens := []string{"interleaved-thinking-2025-05-14", "context-1m-2025-08-07", "compact-2026-01-12"}
|
||||
result := filterBedrockBetaTokens(tokens)
|
||||
assert.Equal(t, tokens, result)
|
||||
})
|
||||
|
||||
t.Run("unsupported tokens are filtered out", func(t *testing.T) {
|
||||
tokens := []string{"interleaved-thinking-2025-05-14", "output-128k-2025-02-19", "files-api-2025-04-14", "structured-outputs-2025-11-13"}
|
||||
result := filterBedrockBetaTokens(tokens)
|
||||
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result)
|
||||
})
|
||||
|
||||
t.Run("advanced-tool-use transforms to tool-search-tool", func(t *testing.T) {
|
||||
tokens := []string{"advanced-tool-use-2025-11-20"}
|
||||
result := filterBedrockBetaTokens(tokens)
|
||||
assert.Contains(t, result, "tool-search-tool-2025-10-19")
|
||||
// tool-examples 自动关联
|
||||
assert.Contains(t, result, "tool-examples-2025-10-29")
|
||||
})
|
||||
|
||||
t.Run("tool-search-tool auto-associates tool-examples", func(t *testing.T) {
|
||||
tokens := []string{"tool-search-tool-2025-10-19"}
|
||||
result := filterBedrockBetaTokens(tokens)
|
||||
assert.Contains(t, result, "tool-search-tool-2025-10-19")
|
||||
assert.Contains(t, result, "tool-examples-2025-10-29")
|
||||
})
|
||||
|
||||
t.Run("no duplication when tool-examples already present", func(t *testing.T) {
|
||||
tokens := []string{"tool-search-tool-2025-10-19", "tool-examples-2025-10-29"}
|
||||
result := filterBedrockBetaTokens(tokens)
|
||||
count := 0
|
||||
for _, t := range result {
|
||||
if t == "tool-examples-2025-10-29" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, count)
|
||||
})
|
||||
|
||||
t.Run("empty input returns nil", func(t *testing.T) {
|
||||
result := filterBedrockBetaTokens(nil)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
|
||||
t.Run("all unsupported returns nil", func(t *testing.T) {
|
||||
result := filterBedrockBetaTokens([]string{"output-128k-2025-02-19", "effort-2025-11-24"})
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
|
||||
t.Run("duplicate tokens are deduplicated", func(t *testing.T) {
|
||||
tokens := []string{"context-1m-2025-08-07", "context-1m-2025-08-07"}
|
||||
result := filterBedrockBetaTokens(tokens)
|
||||
assert.Equal(t, []string{"context-1m-2025-08-07"}, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPrepareBedrockRequestBody_BetaFiltering(t *testing.T) {
|
||||
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`
|
||||
|
||||
t.Run("unsupported beta tokens are filtered", func(t *testing.T) {
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1",
|
||||
"interleaved-thinking-2025-05-14, output-128k-2025-02-19, files-api-2025-04-14")
|
||||
require.NoError(t, err)
|
||||
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||
require.Len(t, arr, 1)
|
||||
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
|
||||
})
|
||||
|
||||
t.Run("advanced-tool-use transformed in full pipeline", func(t *testing.T) {
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1",
|
||||
"advanced-tool-use-2025-11-20")
|
||||
require.NoError(t, err)
|
||||
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||
require.Len(t, arr, 2)
|
||||
assert.Equal(t, "tool-search-tool-2025-10-19", arr[0].String())
|
||||
assert.Equal(t, "tool-examples-2025-10-29", arr[1].String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestBedrockCrossRegionPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
region string
|
||||
expect string
|
||||
}{
|
||||
// US regions
|
||||
{"us-east-1", "us"},
|
||||
{"us-east-2", "us"},
|
||||
{"us-west-1", "us"},
|
||||
{"us-west-2", "us"},
|
||||
// GovCloud
|
||||
{"us-gov-east-1", "us-gov"},
|
||||
{"us-gov-west-1", "us-gov"},
|
||||
// EU regions
|
||||
{"eu-west-1", "eu"},
|
||||
{"eu-west-2", "eu"},
|
||||
{"eu-west-3", "eu"},
|
||||
{"eu-central-1", "eu"},
|
||||
{"eu-central-2", "eu"},
|
||||
{"eu-north-1", "eu"},
|
||||
{"eu-south-1", "eu"},
|
||||
// APAC regions
|
||||
{"ap-northeast-1", "jp"},
|
||||
{"ap-northeast-2", "apac"},
|
||||
{"ap-southeast-1", "apac"},
|
||||
{"ap-southeast-2", "au"},
|
||||
{"ap-south-1", "apac"},
|
||||
// Canada / South America fallback to us
|
||||
{"ca-central-1", "us"},
|
||||
{"sa-east-1", "us"},
|
||||
// Unknown defaults to us
|
||||
{"me-south-1", "us"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.region, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expect, BedrockCrossRegionPrefix(tt.region))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveBedrockModelID(t *testing.T) {
|
||||
t.Run("default alias resolves and adjusts region", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeBedrock,
|
||||
Credentials: map[string]any{
|
||||
"aws_region": "eu-west-1",
|
||||
},
|
||||
}
|
||||
|
||||
modelID, ok := ResolveBedrockModelID(account, "claude-sonnet-4-5")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "eu.anthropic.claude-sonnet-4-5-20250929-v1:0", modelID)
|
||||
})
|
||||
|
||||
t.Run("custom alias mapping reuses default bedrock mapping", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeBedrock,
|
||||
Credentials: map[string]any{
|
||||
"aws_region": "ap-southeast-2",
|
||||
"model_mapping": map[string]any{
|
||||
"claude-*": "claude-opus-4-6",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
modelID, ok := ResolveBedrockModelID(account, "claude-opus-4-6-thinking")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "au.anthropic.claude-opus-4-6-v1", modelID)
|
||||
})
|
||||
|
||||
t.Run("force global rewrites anthropic regional model id", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeBedrock,
|
||||
Credentials: map[string]any{
|
||||
"aws_region": "us-east-1",
|
||||
"aws_force_global": "true",
|
||||
"model_mapping": map[string]any{
|
||||
"claude-sonnet-4-6": "us.anthropic.claude-sonnet-4-6",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
modelID, ok := ResolveBedrockModelID(account, "claude-sonnet-4-6")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "global.anthropic.claude-sonnet-4-6", modelID)
|
||||
})
|
||||
|
||||
t.Run("direct bedrock model id passes through", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeBedrock,
|
||||
Credentials: map[string]any{
|
||||
"aws_region": "us-east-1",
|
||||
},
|
||||
}
|
||||
|
||||
modelID, ok := ResolveBedrockModelID(account, "anthropic.claude-haiku-4-5-20251001-v1:0")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "anthropic.claude-haiku-4-5-20251001-v1:0", modelID)
|
||||
})
|
||||
|
||||
t.Run("unsupported alias returns false", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeBedrock,
|
||||
Credentials: map[string]any{
|
||||
"aws_region": "us-east-1",
|
||||
},
|
||||
}
|
||||
|
||||
_, ok := ResolveBedrockModelID(account, "claude-3-5-sonnet-20241022")
|
||||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAutoInjectBedrockBetaTokens(t *testing.T) {
|
||||
t.Run("inject interleaved-thinking when thinking present", func(t *testing.T) {
|
||||
body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Contains(t, result, "interleaved-thinking-2025-05-14")
|
||||
})
|
||||
|
||||
t.Run("no duplicate when already present", func(t *testing.T) {
|
||||
body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := autoInjectBedrockBetaTokens([]string{"interleaved-thinking-2025-05-14"}, body, "us.anthropic.claude-opus-4-6-v1")
|
||||
count := 0
|
||||
for _, t := range result {
|
||||
if t == "interleaved-thinking-2025-05-14" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, count)
|
||||
})
|
||||
|
||||
t.Run("inject computer-use when computer tool present", func(t *testing.T) {
|
||||
body := []byte(`{"tools":[{"type":"computer_20250124","name":"computer","display_width_px":1024}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Contains(t, result, "computer-use-2025-11-24")
|
||||
})
|
||||
|
||||
t.Run("inject advanced-tool-use for programmatic tool calling", func(t *testing.T) {
|
||||
body := []byte(`{"tools":[{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
|
||||
})
|
||||
|
||||
t.Run("inject advanced-tool-use for input examples", func(t *testing.T) {
|
||||
body := []byte(`{"tools":[{"name":"bash","input_examples":[{"cmd":"ls"}]}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
|
||||
})
|
||||
|
||||
t.Run("inject tool-search-tool directly for pure tool search (no programmatic/inputExamples)", func(t *testing.T) {
|
||||
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-sonnet-4-6")
|
||||
// 纯 tool search 场景直接注入 Bedrock 特定头,不走 advanced-tool-use 转换
|
||||
assert.Contains(t, result, "tool-search-tool-2025-10-19")
|
||||
assert.NotContains(t, result, "advanced-tool-use-2025-11-20")
|
||||
})
|
||||
|
||||
t.Run("inject advanced-tool-use when tool search combined with programmatic calling", func(t *testing.T) {
|
||||
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"},{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-sonnet-4-6")
|
||||
// 混合场景使用 advanced-tool-use(后续由 filter 转换为 tool-search-tool)
|
||||
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
|
||||
})
|
||||
|
||||
t.Run("do not inject tool-search beta for unsupported models", func(t *testing.T) {
|
||||
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := autoInjectBedrockBetaTokens(nil, body, "anthropic.claude-3-5-sonnet-20241022-v2:0")
|
||||
assert.NotContains(t, result, "advanced-tool-use-2025-11-20")
|
||||
assert.NotContains(t, result, "tool-search-tool-2025-10-19")
|
||||
})
|
||||
|
||||
t.Run("no injection for regular tools", func(t *testing.T) {
|
||||
body := []byte(`{"tools":[{"name":"bash","description":"run bash","input_schema":{"type":"object"}}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Empty(t, result)
|
||||
})
|
||||
|
||||
t.Run("no injection when no features detected", func(t *testing.T) {
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`)
|
||||
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Empty(t, result)
|
||||
})
|
||||
|
||||
t.Run("preserves existing tokens", func(t *testing.T) {
|
||||
body := []byte(`{"thinking":{"type":"enabled"},"messages":[{"role":"user","content":"hi"}]}`)
|
||||
existing := []string{"context-1m-2025-08-07", "compact-2026-01-12"}
|
||||
result := autoInjectBedrockBetaTokens(existing, body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Contains(t, result, "context-1m-2025-08-07")
|
||||
assert.Contains(t, result, "compact-2026-01-12")
|
||||
assert.Contains(t, result, "interleaved-thinking-2025-05-14")
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolveBedrockBetaTokens(t *testing.T) {
|
||||
t.Run("body-only tool features resolve to final bedrock tokens", func(t *testing.T) {
|
||||
body := []byte(`{"tools":[{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := ResolveBedrockBetaTokens("", body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Contains(t, result, "tool-search-tool-2025-10-19")
|
||||
assert.Contains(t, result, "tool-examples-2025-10-29")
|
||||
})
|
||||
|
||||
t.Run("unsupported client beta tokens are filtered out", func(t *testing.T) {
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := ResolveBedrockBetaTokens("interleaved-thinking-2025-05-14,files-api-2025-04-14", body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPrepareBedrockRequestBody_AutoBetaInjection(t *testing.T) {
|
||||
t.Run("thinking in body auto-injects beta without header", func(t *testing.T) {
|
||||
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}`
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
|
||||
require.NoError(t, err)
|
||||
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||
found := false
|
||||
for _, v := range arr {
|
||||
if v.String() == "interleaved-thinking-2025-05-14" {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "interleaved-thinking should be auto-injected")
|
||||
})
|
||||
|
||||
t.Run("header tokens merged with auto-injected tokens", func(t *testing.T) {
|
||||
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}`
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "context-1m-2025-08-07")
|
||||
require.NoError(t, err)
|
||||
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||
names := make([]string, len(arr))
|
||||
for i, v := range arr {
|
||||
names[i] = v.String()
|
||||
}
|
||||
assert.Contains(t, names, "context-1m-2025-08-07")
|
||||
assert.Contains(t, names, "interleaved-thinking-2025-05-14")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdjustBedrockModelRegionPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modelID string
|
||||
region string
|
||||
expect string
|
||||
}{
|
||||
// US region — no change needed
|
||||
{"us region keeps us prefix", "us.anthropic.claude-opus-4-6-v1", "us-east-1", "us.anthropic.claude-opus-4-6-v1"},
|
||||
// EU region — replace us → eu
|
||||
{"eu region replaces prefix", "us.anthropic.claude-opus-4-6-v1", "eu-west-1", "eu.anthropic.claude-opus-4-6-v1"},
|
||||
{"eu region sonnet", "us.anthropic.claude-sonnet-4-6", "eu-central-1", "eu.anthropic.claude-sonnet-4-6"},
|
||||
// APAC region — jp and au have dedicated prefixes per AWS docs
|
||||
{"jp region (ap-northeast-1)", "us.anthropic.claude-sonnet-4-5-20250929-v1:0", "ap-northeast-1", "jp.anthropic.claude-sonnet-4-5-20250929-v1:0"},
|
||||
{"au region (ap-southeast-2)", "us.anthropic.claude-haiku-4-5-20251001-v1:0", "ap-southeast-2", "au.anthropic.claude-haiku-4-5-20251001-v1:0"},
|
||||
{"apac region (ap-southeast-1)", "us.anthropic.claude-sonnet-4-5-20250929-v1:0", "ap-southeast-1", "apac.anthropic.claude-sonnet-4-5-20250929-v1:0"},
|
||||
// eu → us (user manually set eu prefix, moved to us region)
|
||||
{"eu to us", "eu.anthropic.claude-opus-4-6-v1", "us-west-2", "us.anthropic.claude-opus-4-6-v1"},
|
||||
// global prefix — replace to match region
|
||||
{"global to eu", "global.anthropic.claude-opus-4-6-v1", "eu-west-1", "eu.anthropic.claude-opus-4-6-v1"},
|
||||
// No known prefix — leave unchanged
|
||||
{"no prefix unchanged", "anthropic.claude-3-5-sonnet-20241022-v2:0", "eu-west-1", "anthropic.claude-3-5-sonnet-20241022-v2:0"},
|
||||
// GovCloud — uses independent us-gov prefix
|
||||
{"govcloud from us", "us.anthropic.claude-opus-4-6-v1", "us-gov-east-1", "us-gov.anthropic.claude-opus-4-6-v1"},
|
||||
{"govcloud already correct", "us-gov.anthropic.claude-opus-4-6-v1", "us-gov-west-1", "us-gov.anthropic.claude-opus-4-6-v1"},
|
||||
// Force global (special region value)
|
||||
{"force global from us", "us.anthropic.claude-opus-4-6-v1", "global", "global.anthropic.claude-opus-4-6-v1"},
|
||||
{"force global from eu", "eu.anthropic.claude-sonnet-4-6", "global", "global.anthropic.claude-sonnet-4-6"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expect, AdjustBedrockModelRegionPrefix(tt.modelID, tt.region))
|
||||
})
|
||||
}
|
||||
}
|
||||
67
backend/internal/service/bedrock_signer.go
Normal file
67
backend/internal/service/bedrock_signer.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
|
||||
)
|
||||
|
||||
// BedrockSigner signs Bedrock HTTP requests with AWS Signature Version 4.
type BedrockSigner struct {
	credentials aws.Credentials // static AWS credentials (access key / secret / optional session token)
	region      string          // AWS region used as the SigV4 signing scope
	signer      *v4.Signer      // reusable SigV4 signer from the AWS SDK
}
|
||||
|
||||
// NewBedrockSigner 创建 BedrockSigner
|
||||
func NewBedrockSigner(accessKeyID, secretAccessKey, sessionToken, region string) *BedrockSigner {
|
||||
return &BedrockSigner{
|
||||
credentials: aws.Credentials{
|
||||
AccessKeyID: accessKeyID,
|
||||
SecretAccessKey: secretAccessKey,
|
||||
SessionToken: sessionToken,
|
||||
},
|
||||
region: region,
|
||||
signer: v4.NewSigner(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewBedrockSignerFromAccount 从 Account 凭证创建 BedrockSigner
|
||||
func NewBedrockSignerFromAccount(account *Account) (*BedrockSigner, error) {
|
||||
accessKeyID := account.GetCredential("aws_access_key_id")
|
||||
if accessKeyID == "" {
|
||||
return nil, fmt.Errorf("aws_access_key_id not found in credentials")
|
||||
}
|
||||
secretAccessKey := account.GetCredential("aws_secret_access_key")
|
||||
if secretAccessKey == "" {
|
||||
return nil, fmt.Errorf("aws_secret_access_key not found in credentials")
|
||||
}
|
||||
region := account.GetCredential("aws_region")
|
||||
if region == "" {
|
||||
region = defaultBedrockRegion
|
||||
}
|
||||
sessionToken := account.GetCredential("aws_session_token") // 可选
|
||||
|
||||
return NewBedrockSigner(accessKeyID, secretAccessKey, sessionToken, region), nil
|
||||
}
|
||||
|
||||
// SignRequest signs the HTTP request with AWS SigV4 for the "bedrock" service.
//
// Important constraint: before calling this method, req should carry only
// AWS-relevant headers (such as Content-Type and Accept). Non-AWS headers
// (e.g. anthropic-beta) participate in the signature computation, and if the
// Bedrock service side does not recognize those headers, signature
// verification may fail. litellm handles this with its
// _filter_headers_for_aws_signature header filtering; in the current
// implementation buildUpstreamRequestBedrock only sets Content-Type and
// Accept, so this is safe.
func (s *BedrockSigner) SignRequest(ctx context.Context, req *http.Request, body []byte) error {
	// SigV4 requires the hex-encoded SHA-256 of the payload in the canonical request.
	payloadHash := sha256Hash(body)
	return s.signer.SignHTTP(ctx, s.credentials, req, payloadHash, "bedrock", s.region, time.Now())
}
|
||||
|
||||
// sha256Hash returns the lowercase hex-encoded SHA-256 digest of data.
func sha256Hash(data []byte) string {
	digest := sha256.Sum256(data)
	return hex.EncodeToString(digest[:])
}
|
||||
35
backend/internal/service/bedrock_signer_test.go
Normal file
35
backend/internal/service/bedrock_signer_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewBedrockSignerFromAccount_DefaultRegion(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeBedrock,
|
||||
Credentials: map[string]any{
|
||||
"aws_access_key_id": "test-akid",
|
||||
"aws_secret_access_key": "test-secret",
|
||||
},
|
||||
}
|
||||
|
||||
signer, err := NewBedrockSignerFromAccount(account)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, signer)
|
||||
assert.Equal(t, defaultBedrockRegion, signer.region)
|
||||
}
|
||||
|
||||
func TestFilterBetaTokens(t *testing.T) {
|
||||
tokens := []string{"interleaved-thinking-2025-05-14", "tool-search-tool-2025-10-19"}
|
||||
filterSet := map[string]struct{}{
|
||||
"tool-search-tool-2025-10-19": {},
|
||||
}
|
||||
|
||||
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, filterBetaTokens(tokens, filterSet))
|
||||
assert.Equal(t, tokens, filterBetaTokens(tokens, nil))
|
||||
assert.Nil(t, filterBetaTokens(nil, filterSet))
|
||||
}
|
||||
414
backend/internal/service/bedrock_stream.go
Normal file
414
backend/internal/service/bedrock_stream.go
Normal file
@@ -0,0 +1,414 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
// handleBedrockStreamingResponse handles the EventStream response of Bedrock's
// InvokeModelWithResponseStream. Bedrock returns the AWS EventStream binary
// format; each event's payload carries chunk.bytes, a base64-encoded Claude
// SSE event JSON. This method decodes those events and relays them to the
// client as standard SSE, while accumulating usage for billing.
//
// Returns a streamingResult carrying the accumulated usage, first-token
// latency, and whether the client disconnected mid-stream; a non-nil error is
// returned for upstream read failures or a stream-data interval timeout.
func (s *GatewayService) handleBedrockStreamingResponse(
	ctx context.Context,
	resp *http.Response,
	c *gin.Context,
	account *Account,
	startTime time.Time,
	model string,
) (*streamingResult, error) {
	w := c.Writer
	flusher, ok := w.(http.Flusher)
	if !ok {
		// The response writer must support flushing to stream SSE.
		return nil, errors.New("streaming not supported")
	}

	// Standard SSE headers; X-Accel-Buffering disables proxy buffering.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("X-Accel-Buffering", "no")
	// Propagate the upstream AWS request ID for correlation, when present.
	if v := resp.Header.Get("x-amzn-requestid"); v != "" {
		c.Header("x-request-id", v)
	}

	usage := &ClaudeUsage{}
	var firstTokenMs *int
	clientDisconnected := false

	// Bedrock EventStream uses the application/vnd.amazon.eventstream binary
	// format. Each frame: total_length(4) + headers_length(4) + prelude_crc(4)
	// + headers + payload + message_crc(4). Line scanning for JSON chunks is
	// not reliable because the response is wrapped in binary frames, so a
	// dedicated EventStream decoder is used instead.
	decoder := newBedrockEventStreamDecoder(resp.Body)

	// The decoder runs in a separate goroutine so the main loop can also react
	// to the data-interval ticker; events carry either a payload or an error.
	type decodeEvent struct {
		payload []byte
		err     error
	}
	events := make(chan decodeEvent, 16)
	done := make(chan struct{})
	// sendEvent delivers an event unless the consumer has already exited
	// (done closed), which prevents the decoder goroutine from leaking.
	sendEvent := func(ev decodeEvent) bool {
		select {
		case events <- ev:
			return true
		case <-done:
			return false
		}
	}
	var lastReadAt atomic.Int64
	lastReadAt.Store(time.Now().UnixNano())

	go func() {
		defer close(events)
		for {
			payload, err := decoder.Decode()
			if err != nil {
				if err == io.EOF {
					// Normal end of stream: closing events ends the main loop.
					return
				}
				_ = sendEvent(decodeEvent{err: err})
				return
			}
			// Record read time so the interval watchdog can detect stalls.
			lastReadAt.Store(time.Now().UnixNano())
			if !sendEvent(decodeEvent{payload: payload}) {
				return
			}
		}
	}()
	defer close(done)

	// Optional stall watchdog: abort if no upstream data arrives within the
	// configured interval.
	streamInterval := time.Duration(0)
	if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
		streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
	}
	var intervalTicker *time.Ticker
	if streamInterval > 0 {
		intervalTicker = time.NewTicker(streamInterval)
		defer intervalTicker.Stop()
	}
	var intervalCh <-chan time.Time
	if intervalTicker != nil {
		// A nil channel blocks forever in select, so the timeout branch is a
		// no-op when the watchdog is disabled.
		intervalCh = intervalTicker.C
	}

	for {
		select {
		case ev, ok := <-events:
			if !ok {
				// Upstream finished (EOF): flush any buffered SSE and return.
				if !clientDisconnected {
					flusher.Flush()
				}
				return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
			}
			if ev.err != nil {
				if clientDisconnected {
					return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
				}
				if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
					// Context cancellation is treated as a client disconnect,
					// not an upstream failure.
					return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
				}
				return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("bedrock stream read error: %w", ev.err)
			}

			// The payload is JSON; extract chunk.bytes (a base64-encoded
			// Claude SSE event). Non-chunk frames are skipped.
			sseData := extractBedrockChunkData(ev.payload)
			if sseData == nil {
				continue
			}

			// Record latency to the first decoded chunk.
			if firstTokenMs == nil {
				ms := int(time.Since(startTime).Milliseconds())
				firstTokenMs = &ms
			}

			// Convert Bedrock's amazon-bedrock-invocationMetrics into the
			// standard Anthropic usage shape, and strip the Bedrock-only field
			// so it is not passed through to the client.
			sseData = transformBedrockInvocationMetrics(sseData)

			// Accumulate usage from the SSE event for billing.
			s.parseSSEUsagePassthrough(string(sseData), usage)

			// Determine the SSE event type from the decoded JSON.
			eventType := gjson.GetBytes(sseData, "type").String()

			// Relay as standard SSE. On write failure, keep draining upstream
			// so usage accounting stays accurate even after a disconnect.
			if !clientDisconnected {
				var writeErr error
				if eventType != "" {
					_, writeErr = fmt.Fprintf(w, "event: %s\ndata: %s\n\n", eventType, sseData)
				} else {
					_, writeErr = fmt.Fprintf(w, "data: %s\n\n", sseData)
				}
				if writeErr != nil {
					clientDisconnected = true
					logger.LegacyPrintf("service.gateway", "[Bedrock] Client disconnected during streaming, continue draining for usage: account=%d", account.ID)
				} else {
					flusher.Flush()
				}
			}

		case <-intervalCh:
			// Only time out if no data has been read for a full interval; the
			// ticker may fire shortly after a read otherwise.
			lastRead := time.Unix(0, lastReadAt.Load())
			if time.Since(lastRead) < streamInterval {
				continue
			}
			if clientDisconnected {
				return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
			}
			logger.LegacyPrintf("service.gateway", "[Bedrock] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval)
			if s.rateLimitService != nil {
				s.rateLimitService.HandleStreamTimeout(ctx, account, model)
			}
			return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
		}
	}
}
|
||||
|
||||
// extractBedrockChunkData 从 Bedrock EventStream payload 中提取 Claude SSE 事件数据
|
||||
// Bedrock payload 格式:{"bytes":"<base64-encoded-json>"}
|
||||
func extractBedrockChunkData(payload []byte) []byte {
|
||||
b64 := gjson.GetBytes(payload, "bytes").String()
|
||||
if b64 == "" {
|
||||
return nil
|
||||
}
|
||||
decoded, err := base64.StdEncoding.DecodeString(b64)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return decoded
|
||||
}
|
||||
|
||||
// transformBedrockInvocationMetrics 将 Bedrock 特有的 amazon-bedrock-invocationMetrics
|
||||
// 转换为标准 Anthropic usage 格式,并从 SSE 数据中移除该字段。
|
||||
//
|
||||
// Bedrock Invoke 返回的 message_delta 事件可能包含:
|
||||
//
|
||||
// {"type":"message_delta","delta":{...},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}
|
||||
//
|
||||
// 转换为:
|
||||
//
|
||||
// {"type":"message_delta","delta":{...},"usage":{"input_tokens":150,"output_tokens":42}}
|
||||
func transformBedrockInvocationMetrics(data []byte) []byte {
|
||||
metrics := gjson.GetBytes(data, "amazon-bedrock-invocationMetrics")
|
||||
if !metrics.Exists() || !metrics.IsObject() {
|
||||
return data
|
||||
}
|
||||
|
||||
// 移除 Bedrock 特有字段
|
||||
data, _ = sjson.DeleteBytes(data, "amazon-bedrock-invocationMetrics")
|
||||
|
||||
// 如果已有标准 usage 字段,不覆盖
|
||||
if gjson.GetBytes(data, "usage").Exists() {
|
||||
return data
|
||||
}
|
||||
|
||||
// 转换 camelCase → snake_case 写入 usage
|
||||
inputTokens := metrics.Get("inputTokenCount")
|
||||
outputTokens := metrics.Get("outputTokenCount")
|
||||
if inputTokens.Exists() {
|
||||
data, _ = sjson.SetBytes(data, "usage.input_tokens", inputTokens.Int())
|
||||
}
|
||||
if outputTokens.Exists() {
|
||||
data, _ = sjson.SetBytes(data, "usage.output_tokens", outputTokens.Int())
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// bedrockEventStreamDecoder decodes AWS EventStream binary frames.
//
// EventStream frame layout:
//
//	[total_byte_length: 4 bytes]
//	[headers_byte_length: 4 bytes]
//	[prelude_crc: 4 bytes]
//	[headers: variable]
//	[payload: variable]
//	[message_crc: 4 bytes]
type bedrockEventStreamDecoder struct {
	// reader buffers the upstream body so frame parsing does not issue a
	// syscall per small read.
	reader *bufio.Reader
}
|
||||
|
||||
func newBedrockEventStreamDecoder(r io.Reader) *bedrockEventStreamDecoder {
|
||||
return &bedrockEventStreamDecoder{
|
||||
reader: bufio.NewReaderSize(r, 64*1024),
|
||||
}
|
||||
}
|
||||
|
||||
// Decode 读取下一个 EventStream 帧并返回 chunk 类型事件的 payload
|
||||
func (d *bedrockEventStreamDecoder) Decode() ([]byte, error) {
|
||||
for {
|
||||
// 读取 prelude: total_length(4) + headers_length(4) + prelude_crc(4) = 12 bytes
|
||||
prelude := make([]byte, 12)
|
||||
if _, err := io.ReadFull(d.reader, prelude); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 验证 prelude CRC(AWS EventStream 使用标准 CRC32 / IEEE)
|
||||
preludeCRC := bedrockReadUint32(prelude[8:12])
|
||||
if crc32.Checksum(prelude[0:8], crc32IEEETable) != preludeCRC {
|
||||
return nil, fmt.Errorf("eventstream prelude CRC mismatch")
|
||||
}
|
||||
|
||||
totalLength := bedrockReadUint32(prelude[0:4])
|
||||
headersLength := bedrockReadUint32(prelude[4:8])
|
||||
|
||||
if totalLength < 16 { // minimum: 12 prelude + 4 message_crc
|
||||
return nil, fmt.Errorf("invalid eventstream frame: total_length=%d", totalLength)
|
||||
}
|
||||
|
||||
// 读取 headers + payload + message_crc
|
||||
remaining := int(totalLength) - 12
|
||||
if remaining <= 0 {
|
||||
continue
|
||||
}
|
||||
data := make([]byte, remaining)
|
||||
if _, err := io.ReadFull(d.reader, data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 验证 message CRC(覆盖 prelude + headers + payload)
|
||||
messageCRC := bedrockReadUint32(data[len(data)-4:])
|
||||
h := crc32.New(crc32IEEETable)
|
||||
_, _ = h.Write(prelude)
|
||||
_, _ = h.Write(data[:len(data)-4])
|
||||
if h.Sum32() != messageCRC {
|
||||
return nil, fmt.Errorf("eventstream message CRC mismatch")
|
||||
}
|
||||
|
||||
// 解析 headers
|
||||
headers := data[:headersLength]
|
||||
payload := data[headersLength : len(data)-4] // 去掉 message_crc
|
||||
|
||||
// 从 headers 中提取 :event-type
|
||||
eventType := extractEventStreamHeaderValue(headers, ":event-type")
|
||||
|
||||
// 只处理 chunk 事件
|
||||
if eventType == "chunk" {
|
||||
// payload 是完整的 JSON,包含 bytes 字段
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// 检查异常事件
|
||||
exceptionType := extractEventStreamHeaderValue(headers, ":exception-type")
|
||||
if exceptionType != "" {
|
||||
return nil, fmt.Errorf("bedrock exception: %s: %s", exceptionType, string(payload))
|
||||
}
|
||||
|
||||
messageType := extractEventStreamHeaderValue(headers, ":message-type")
|
||||
if messageType == "exception" || messageType == "error" {
|
||||
return nil, fmt.Errorf("bedrock error: %s", string(payload))
|
||||
}
|
||||
|
||||
// 跳过其他事件类型(如 initial-response)
|
||||
}
|
||||
}
|
||||
|
||||
// extractEventStreamHeaderValue scans a raw EventStream headers block and
// returns the string value of the header named targetName, or "" when the
// header is absent, malformed, or not representable as a string.
//
// EventStream header wire format:
//
//	[name_length: 1 byte][name: variable][value_type: 1 byte][value: variable]
//
// value_type 7 is a string whose value is prefixed by a 2-byte big-endian length.
func extractEventStreamHeaderValue(headers []byte, targetName string) string {
	pos := 0
	for pos < len(headers) {
		nameLen := int(headers[pos])
		pos++
		if pos+nameLen > len(headers) {
			break
		}
		name := string(headers[pos : pos+nameLen])
		pos += nameLen

		if pos >= len(headers) {
			break
		}
		valueType := headers[pos]
		pos++

		switch valueType {
		case 7: // string
			if pos+2 > len(headers) {
				return ""
			}
			valueLen := int(headers[pos])<<8 | int(headers[pos+1]) // big-endian uint16
			pos += 2
			if pos+valueLen > len(headers) {
				return ""
			}
			value := string(headers[pos : pos+valueLen])
			pos += valueLen
			if name == targetName {
				return value
			}
		case 0: // bool true (no value bytes on the wire)
			if name == targetName {
				return "true"
			}
		case 1: // bool false (no value bytes on the wire)
			if name == targetName {
				return "false"
			}
		case 2: // byte — numeric headers are not rendered as strings
			pos++
			if name == targetName {
				return ""
			}
		case 3: // short
			pos += 2
			if name == targetName {
				return ""
			}
		case 4: // int
			pos += 4
			if name == targetName {
				return ""
			}
		case 5: // long
			pos += 8
			if name == targetName {
				return ""
			}
		case 6: // bytes — skipped; binary values are never returned
			if pos+2 > len(headers) {
				return ""
			}
			valueLen := int(headers[pos])<<8 | int(headers[pos+1]) // big-endian uint16
			pos += 2 + valueLen
		case 8: // timestamp
			pos += 8
		case 9: // uuid
			pos += 16
		default:
			return "" // unknown type: value width unknown, cannot advance safely
		}
	}
	return ""
}
|
||||
|
||||
// crc32IEEETable is the CRC32 (IEEE polynomial) table used by AWS EventStream
// for both the prelude and whole-message checksums.
var crc32IEEETable = crc32.MakeTable(crc32.IEEE)
|
||||
|
||||
// bedrockReadUint32 decodes the first four bytes of b as a big-endian uint32.
func bedrockReadUint32(b []byte) uint32 {
	var v uint32
	for _, octet := range b[:4] {
		v = v<<8 | uint32(octet)
	}
	return v
}
|
||||
|
||||
// bedrockReadUint16 decodes the first two bytes of b as a big-endian uint16.
func bedrockReadUint16(b []byte) uint16 {
	var v uint16
	for _, octet := range b[:2] {
		v = v<<8 | uint16(octet)
	}
	return v
}
|
||||
261
backend/internal/service/bedrock_stream_test.go
Normal file
261
backend/internal/service/bedrock_stream_test.go
Normal file
@@ -0,0 +1,261 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// TestExtractBedrockChunkData covers the happy path plus the three failure
// modes (empty, missing, and undecodable "bytes" field), all of which must
// yield nil rather than an error.
func TestExtractBedrockChunkData(t *testing.T) {
	t.Run("valid base64 payload", func(t *testing.T) {
		original := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}`
		b64 := base64.StdEncoding.EncodeToString([]byte(original))
		payload := []byte(`{"bytes":"` + b64 + `"}`)

		result := extractBedrockChunkData(payload)
		require.NotNil(t, result)
		assert.JSONEq(t, original, string(result))
	})

	t.Run("empty bytes field", func(t *testing.T) {
		result := extractBedrockChunkData([]byte(`{"bytes":""}`))
		assert.Nil(t, result)
	})

	t.Run("no bytes field", func(t *testing.T) {
		result := extractBedrockChunkData([]byte(`{"other":"value"}`))
		assert.Nil(t, result)
	})

	t.Run("invalid base64", func(t *testing.T) {
		result := extractBedrockChunkData([]byte(`{"bytes":"not-valid-base64!!!"}`))
		assert.Nil(t, result)
	})
}
|
||||
|
||||
// TestTransformBedrockInvocationMetrics checks the three metric-rewrite cases:
// metrics converted into a usage object, no metrics at all, and a pre-existing
// usage object that must win over the Bedrock counters.
func TestTransformBedrockInvocationMetrics(t *testing.T) {
	t.Run("converts metrics to usage", func(t *testing.T) {
		input := `{"type":"message_delta","delta":{"stop_reason":"end_turn"},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}`
		result := transformBedrockInvocationMetrics([]byte(input))

		// amazon-bedrock-invocationMetrics should be removed
		assert.False(t, gjson.GetBytes(result, "amazon-bedrock-invocationMetrics").Exists())
		// usage should be set
		assert.Equal(t, int64(150), gjson.GetBytes(result, "usage.input_tokens").Int())
		assert.Equal(t, int64(42), gjson.GetBytes(result, "usage.output_tokens").Int())
		// original fields preserved
		assert.Equal(t, "message_delta", gjson.GetBytes(result, "type").String())
		assert.Equal(t, "end_turn", gjson.GetBytes(result, "delta.stop_reason").String())
	})

	t.Run("no metrics present", func(t *testing.T) {
		input := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}`
		result := transformBedrockInvocationMetrics([]byte(input))
		assert.JSONEq(t, input, string(result))
	})

	t.Run("does not overwrite existing usage", func(t *testing.T) {
		input := `{"type":"message_delta","usage":{"output_tokens":100},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}`
		result := transformBedrockInvocationMetrics([]byte(input))

		// metrics removed but existing usage preserved
		assert.False(t, gjson.GetBytes(result, "amazon-bedrock-invocationMetrics").Exists())
		assert.Equal(t, int64(100), gjson.GetBytes(result, "usage.output_tokens").Int())
	})
}
|
||||
|
||||
// TestExtractEventStreamHeaderValue exercises the headers-block parser against
// hand-built wire-format headers: a lookup hit, a miss, multiple concatenated
// headers, and empty input.
func TestExtractEventStreamHeaderValue(t *testing.T) {
	// Build a header with :event-type = "chunk" (string type = 7)
	buildStringHeader := func(name, value string) []byte {
		var buf bytes.Buffer
		// name length (1 byte)
		_ = buf.WriteByte(byte(len(name)))
		// name
		_, _ = buf.WriteString(name)
		// value type (7 = string)
		_ = buf.WriteByte(7)
		// value length (2 bytes, big-endian)
		_ = binary.Write(&buf, binary.BigEndian, uint16(len(value)))
		// value
		_, _ = buf.WriteString(value)
		return buf.Bytes()
	}

	t.Run("find string header", func(t *testing.T) {
		headers := buildStringHeader(":event-type", "chunk")
		assert.Equal(t, "chunk", extractEventStreamHeaderValue(headers, ":event-type"))
	})

	t.Run("header not found", func(t *testing.T) {
		headers := buildStringHeader(":event-type", "chunk")
		assert.Equal(t, "", extractEventStreamHeaderValue(headers, ":message-type"))
	})

	t.Run("multiple headers", func(t *testing.T) {
		var buf bytes.Buffer
		_, _ = buf.Write(buildStringHeader(":content-type", "application/json"))
		_, _ = buf.Write(buildStringHeader(":event-type", "chunk"))
		_, _ = buf.Write(buildStringHeader(":message-type", "event"))

		headers := buf.Bytes()
		assert.Equal(t, "chunk", extractEventStreamHeaderValue(headers, ":event-type"))
		assert.Equal(t, "application/json", extractEventStreamHeaderValue(headers, ":content-type"))
		assert.Equal(t, "event", extractEventStreamHeaderValue(headers, ":message-type"))
	})

	t.Run("empty headers", func(t *testing.T) {
		assert.Equal(t, "", extractEventStreamHeaderValue([]byte{}, ":event-type"))
	})
}
|
||||
|
||||
// TestBedrockEventStreamDecoder builds EventStream frames byte-by-byte
// (including valid CRC32/IEEE checksums) and verifies: chunk decoding,
// skipping of non-chunk events, EOF propagation, rejection of corrupted
// prelude/message CRCs, and rejection of frames checksummed with the wrong
// (Castagnoli) polynomial.
func TestBedrockEventStreamDecoder(t *testing.T) {
	crc32IeeeTab := crc32.MakeTable(crc32.IEEE)

	// Build a valid EventStream frame with correct CRC32/IEEE checksums.
	buildFrame := func(eventType string, payload []byte) []byte {
		// Build headers
		var headersBuf bytes.Buffer
		// :event-type header
		_ = headersBuf.WriteByte(byte(len(":event-type")))
		_, _ = headersBuf.WriteString(":event-type")
		_ = headersBuf.WriteByte(7) // string type
		_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len(eventType)))
		_, _ = headersBuf.WriteString(eventType)
		// :message-type header
		_ = headersBuf.WriteByte(byte(len(":message-type")))
		_, _ = headersBuf.WriteString(":message-type")
		_ = headersBuf.WriteByte(7)
		_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("event")))
		_, _ = headersBuf.WriteString("event")

		headers := headersBuf.Bytes()
		headersLen := uint32(len(headers))
		// total = 12 (prelude) + headers + payload + 4 (message_crc)
		totalLen := uint32(12 + len(headers) + len(payload) + 4)

		// Prelude: total_length(4) + headers_length(4)
		var preludeBuf bytes.Buffer
		_ = binary.Write(&preludeBuf, binary.BigEndian, totalLen)
		_ = binary.Write(&preludeBuf, binary.BigEndian, headersLen)
		preludeBytes := preludeBuf.Bytes()
		preludeCRC := crc32.Checksum(preludeBytes, crc32IeeeTab)

		// Build frame: prelude + prelude_crc + headers + payload
		var frame bytes.Buffer
		_, _ = frame.Write(preludeBytes)
		_ = binary.Write(&frame, binary.BigEndian, preludeCRC)
		_, _ = frame.Write(headers)
		_, _ = frame.Write(payload)

		// Message CRC covers everything before itself
		messageCRC := crc32.Checksum(frame.Bytes(), crc32IeeeTab)
		_ = binary.Write(&frame, binary.BigEndian, messageCRC)
		return frame.Bytes()
	}

	t.Run("decode chunk event", func(t *testing.T) {
		payload := []byte(`{"bytes":"dGVzdA=="}`) // base64("test")
		frame := buildFrame("chunk", payload)

		decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
		result, err := decoder.Decode()
		require.NoError(t, err)
		assert.Equal(t, payload, result)
	})

	t.Run("skip non-chunk events", func(t *testing.T) {
		// Write initial-response followed by chunk
		var buf bytes.Buffer
		_, _ = buf.Write(buildFrame("initial-response", []byte(`{}`)))
		chunkPayload := []byte(`{"bytes":"aGVsbG8="}`)
		_, _ = buf.Write(buildFrame("chunk", chunkPayload))

		decoder := newBedrockEventStreamDecoder(&buf)
		result, err := decoder.Decode()
		require.NoError(t, err)
		assert.Equal(t, chunkPayload, result)
	})

	t.Run("EOF on empty input", func(t *testing.T) {
		decoder := newBedrockEventStreamDecoder(bytes.NewReader(nil))
		_, err := decoder.Decode()
		assert.Equal(t, io.EOF, err)
	})

	t.Run("corrupted prelude CRC", func(t *testing.T) {
		frame := buildFrame("chunk", []byte(`{"bytes":"dGVzdA=="}`))
		// Corrupt the prelude CRC (bytes 8-11)
		frame[8] ^= 0xFF
		decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
		_, err := decoder.Decode()
		require.Error(t, err)
		assert.Contains(t, err.Error(), "prelude CRC mismatch")
	})

	t.Run("corrupted message CRC", func(t *testing.T) {
		frame := buildFrame("chunk", []byte(`{"bytes":"dGVzdA=="}`))
		// Corrupt the message CRC (last 4 bytes)
		frame[len(frame)-1] ^= 0xFF
		decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
		_, err := decoder.Decode()
		require.Error(t, err)
		assert.Contains(t, err.Error(), "message CRC mismatch")
	})

	t.Run("castagnoli encoded frame is rejected", func(t *testing.T) {
		castagnoliTab := crc32.MakeTable(crc32.Castagnoli)
		payload := []byte(`{"bytes":"dGVzdA=="}`)

		// Same frame layout as buildFrame, but checksummed with the wrong table.
		var headersBuf bytes.Buffer
		_ = headersBuf.WriteByte(byte(len(":event-type")))
		_, _ = headersBuf.WriteString(":event-type")
		_ = headersBuf.WriteByte(7)
		_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("chunk")))
		_, _ = headersBuf.WriteString("chunk")

		headers := headersBuf.Bytes()
		headersLen := uint32(len(headers))
		totalLen := uint32(12 + len(headers) + len(payload) + 4)

		var preludeBuf bytes.Buffer
		_ = binary.Write(&preludeBuf, binary.BigEndian, totalLen)
		_ = binary.Write(&preludeBuf, binary.BigEndian, headersLen)
		preludeBytes := preludeBuf.Bytes()

		var frame bytes.Buffer
		_, _ = frame.Write(preludeBytes)
		_ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(preludeBytes, castagnoliTab))
		_, _ = frame.Write(headers)
		_, _ = frame.Write(payload)
		_ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(frame.Bytes(), castagnoliTab))

		decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame.Bytes()))
		_, err := decoder.Decode()
		require.Error(t, err)
		assert.Contains(t, err.Error(), "prelude CRC mismatch")
	})
}
|
||||
|
||||
// TestBuildBedrockURL checks invoke / invoke-with-response-stream URL assembly
// for different regions, and in particular that the ":" in versioned model IDs
// is percent-encoded as %3A.
func TestBuildBedrockURL(t *testing.T) {
	t.Run("stream URL with colon in model ID", func(t *testing.T) {
		url := BuildBedrockURL("us-east-1", "us.anthropic.claude-opus-4-5-20251101-v1:0", true)
		assert.Equal(t, "https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-opus-4-5-20251101-v1%3A0/invoke-with-response-stream", url)
	})

	t.Run("non-stream URL with colon in model ID", func(t *testing.T) {
		url := BuildBedrockURL("eu-west-1", "eu.anthropic.claude-sonnet-4-5-20250929-v1:0", false)
		assert.Equal(t, "https://bedrock-runtime.eu-west-1.amazonaws.com/model/eu.anthropic.claude-sonnet-4-5-20250929-v1%3A0/invoke", url)
	})

	t.Run("model ID without colon", func(t *testing.T) {
		url := BuildBedrockURL("us-east-1", "us.anthropic.claude-sonnet-4-6", true)
		assert.Equal(t, "https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-sonnet-4-6/invoke-with-response-stream", url)
	})
}
|
||||
@@ -35,6 +35,7 @@ type DashboardAggregationRepository interface {
|
||||
UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
|
||||
CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
|
||||
CleanupUsageLogs(ctx context.Context, cutoff time.Time) error
|
||||
CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error
|
||||
EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error
|
||||
}
|
||||
|
||||
@@ -296,6 +297,7 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
|
||||
hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays)
|
||||
dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays)
|
||||
usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays)
|
||||
dedupCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageBillingDedupDays)
|
||||
|
||||
aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff)
|
||||
if aggErr != nil {
|
||||
@@ -305,7 +307,11 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
|
||||
if usageErr != nil {
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr)
|
||||
}
|
||||
if aggErr == nil && usageErr == nil {
|
||||
dedupErr := s.repo.CleanupUsageBillingDedup(ctx, dedupCutoff)
|
||||
if dedupErr != nil {
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_billing_dedup 保留清理失败: %v", dedupErr)
|
||||
}
|
||||
if aggErr == nil && usageErr == nil && dedupErr == nil {
|
||||
s.lastRetentionCleanup.Store(now)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,12 +12,18 @@ import (
|
||||
|
||||
// dashboardAggregationRepoTestStub is a hand-rolled test double for the
// DashboardAggregationRepository interface: it counts method invocations and
// returns caller-injected errors.
type dashboardAggregationRepoTestStub struct {
	// call counters bumped by the corresponding stub methods
	aggregateCalls       int
	recomputeCalls       int
	cleanupUsageCalls    int
	cleanupDedupCalls    int
	ensurePartitionCalls int
	// presumably the last AggregateRange window and watermark — the
	// AggregateRange body is outside this view, confirm against it
	lastStart time.Time
	lastEnd   time.Time
	watermark time.Time
	// injectable errors returned by the matching methods
	aggregateErr         error
	cleanupAggregatesErr error
	cleanupUsageErr      error
	cleanupDedupErr      error
	ensurePartitionErr   error
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||
@@ -28,6 +34,7 @@ func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s
|
||||
}
|
||||
|
||||
// RecomputeRange counts the call and then delegates to AggregateRange, so a
// recompute also exercises the aggregate bookkeeping and aggregateErr.
func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
	s.recomputeCalls++
	return s.AggregateRange(ctx, start, end)
}
|
||||
|
||||
@@ -44,11 +51,18 @@ func (s *dashboardAggregationRepoTestStub) CleanupAggregates(ctx context.Context
|
||||
}
|
||||
|
||||
// CleanupUsageLogs counts the call and returns the injected cleanupUsageErr.
func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
	s.cleanupUsageCalls++
	return s.cleanupUsageErr
}
|
||||
|
||||
// CleanupUsageBillingDedup counts the call and returns the injected cleanupDedupErr.
func (s *dashboardAggregationRepoTestStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
	s.cleanupDedupCalls++
	return s.cleanupDedupErr
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||
return nil
|
||||
s.ensurePartitionCalls++
|
||||
return s.ensurePartitionErr
|
||||
}
|
||||
|
||||
func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) {
|
||||
@@ -90,6 +104,50 @@ func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *te
|
||||
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
|
||||
|
||||
require.Nil(t, svc.lastRetentionCleanup.Load())
|
||||
require.Equal(t, 1, repo.cleanupUsageCalls)
|
||||
require.Equal(t, 1, repo.cleanupDedupCalls)
|
||||
}
|
||||
|
||||
// TestDashboardAggregationService_CleanupDedupFailure_DoesNotRecord verifies
// that a failing usage_billing_dedup cleanup keeps the retention watermark
// (lastRetentionCleanup) unset while the cleanup is still attempted exactly once.
func TestDashboardAggregationService_CleanupDedupFailure_DoesNotRecord(t *testing.T) {
	repo := &dashboardAggregationRepoTestStub{cleanupDedupErr: errors.New("dedup cleanup failed")}
	svc := &DashboardAggregationService{
		repo: repo,
		cfg: config.DashboardAggregationConfig{
			Retention: config.DashboardAggregationRetentionConfig{
				UsageLogsDays: 1,
				HourlyDays:    1,
				DailyDays:     1,
			},
		},
	}

	svc.maybeCleanupRetention(context.Background(), time.Now().UTC())

	require.Nil(t, svc.lastRetentionCleanup.Load())
	require.Equal(t, 1, repo.cleanupDedupCalls)
}
|
||||
|
||||
// TestDashboardAggregationService_PartitionFailure_DoesNotAggregate injects a
// partition-ensure failure into a scheduled aggregation run.
// NOTE(review): the test name suggests aggregation is skipped on partition
// failure, yet it asserts aggregateCalls == 1 — confirm which behavior is intended.
func TestDashboardAggregationService_PartitionFailure_DoesNotAggregate(t *testing.T) {
	repo := &dashboardAggregationRepoTestStub{ensurePartitionErr: errors.New("partition failed")}
	svc := &DashboardAggregationService{
		repo: repo,
		cfg: config.DashboardAggregationConfig{
			Enabled:         true,
			IntervalSeconds: 60,
			LookbackSeconds: 120,
			Retention: config.DashboardAggregationRetentionConfig{
				UsageLogsDays:         1,
				UsageBillingDedupDays: 2,
				HourlyDays:            1,
				DailyDays:             1,
			},
		},
	}

	svc.runScheduledAggregation()

	require.Equal(t, 1, repo.ensurePartitionCalls)
	require.Equal(t, 1, repo.aggregateCalls)
}
|
||||
|
||||
func TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) {
|
||||
|
||||
@@ -327,6 +327,14 @@ func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, end
|
||||
return trend, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
|
||||
ranking, err := s.usageRepo.GetUserSpendingRanking(ctx, startTime, endTime, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user spending ranking: %w", err)
|
||||
}
|
||||
return ranking, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime)
|
||||
if err != nil {
|
||||
|
||||
@@ -124,6 +124,10 @@ func (s *dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cut
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanupUsageBillingDedup is a no-op stub satisfying the repository interface.
func (s *dashboardAggregationRepoStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
	return nil
}
|
||||
|
||||
// EnsureUsageLogsPartitions is a no-op stub satisfying the repository interface.
func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
	return nil
}
|
||||
|
||||
@@ -29,10 +29,12 @@ const (
|
||||
|
||||
// Account type constants
|
||||
const (
|
||||
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
|
||||
AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
|
||||
AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名连接 Bedrock)
|
||||
AccountTypeBedrockAPIKey = domain.AccountTypeBedrockAPIKey // AWS Bedrock API Key 类型账号(通过 Bearer Token 连接 Bedrock)
|
||||
)
|
||||
|
||||
// Redeem type constants
|
||||
|
||||
@@ -136,16 +136,18 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
deferredService: &DeferredService{},
|
||||
billingCacheService: nil,
|
||||
}
|
||||
svc := &GatewayService{
|
||||
cfg: cfg,
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
deferredService: &DeferredService{},
|
||||
billingCacheService: nil,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
@@ -221,14 +223,16 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
svc := &GatewayService{
|
||||
cfg: cfg,
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
@@ -727,6 +731,39 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAf
|
||||
require.Equal(t, 5, result.usage.OutputTokens)
|
||||
}
|
||||
|
||||
// TestGatewayService_AnthropicAPIKeyPassthrough_MissingTerminalEventReturnsError
// feeds an SSE stream that carries usage data but ends without a terminal
// event, and expects a "missing terminal event" error alongside a non-nil
// (partially collected) result.
func TestGatewayService_AnthropicAPIKeyPassthrough_MissingTerminalEventReturnsError(t *testing.T) {
	gin.SetMode(gin.TestMode)

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)

	// Only the fields the streaming handler touches are wired.
	svc := &GatewayService{
		cfg: &config.Config{
			Gateway: config.GatewayConfig{
				MaxLineSize: defaultMaxLineSize,
			},
		},
		rateLimitService: &RateLimitService{},
	}

	// Usage events only — no message_stop (or equivalent) terminal event.
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
		Body: io.NopCloser(strings.NewReader(strings.Join([]string{
			`data: {"type":"message_start","message":{"usage":{"input_tokens":11}}}`,
			"",
			`data: {"type":"message_delta","usage":{"output_tokens":5}}`,
			"",
		}, "\n"))),
	}

	result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "claude-3-7-sonnet-20250219")
	require.Error(t, err)
	require.Contains(t, err.Error(), "missing terminal event")
	require.NotNil(t, result)
}
|
||||
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuccess(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -1074,7 +1111,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingTimeoutAfterClientDi
|
||||
_ = pr.Close()
|
||||
<-done
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "stream usage incomplete after timeout")
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.clientDisconnect)
|
||||
require.Equal(t, 9, result.usage.InputTokens)
|
||||
@@ -1103,7 +1141,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingContextCanceled(t *t
|
||||
}
|
||||
|
||||
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 3}, time.Now(), "claude-3-7-sonnet-20250219")
|
||||
require.NoError(t, err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "stream usage incomplete")
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.clientDisconnect)
|
||||
}
|
||||
@@ -1133,7 +1172,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingUpstreamReadErrorAft
|
||||
}
|
||||
|
||||
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 4}, time.Now(), "claude-3-7-sonnet-20250219")
|
||||
require.NoError(t, err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "stream usage incomplete after disconnect")
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.clientDisconnect)
|
||||
require.Equal(t, 8, result.usage.InputTokens)
|
||||
|
||||
371
backend/internal/service/gateway_record_usage_test.go
Normal file
371
backend/internal/service/gateway_record_usage_test.go
Normal file
@@ -0,0 +1,371 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// newGatewayRecordUsageServiceForTest builds a GatewayService wired with only
// the dependencies the RecordUsage tests touch; everything else is nil.
// The non-default rate multiplier (1.1) makes billing math observable.
// Positional comments follow the NewGatewayService parameter order — confirm
// against the constructor if it changes.
func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService {
	cfg := &config.Config{}
	cfg.Default.RateMultiplier = 1.1
	return NewGatewayService(
		nil, // account repository
		nil, // group repository
		usageRepo,
		nil, // usage billing repository
		userRepo,
		subRepo,
		nil, // user group rate repository
		nil, // gateway cache
		cfg,
		nil, // scheduler snapshot service
		nil, // concurrency service
		NewBillingService(cfg, nil),
		nil, // rate limit service
		&BillingCacheService{},
		nil, // identity service
		nil, // HTTP upstream
		&DeferredService{},
		nil, // claude token provider
		nil, // session limit cache
		nil, // rpm cache
		nil, // digest session store
		nil, // setting service
	)
}
|
||||
|
||||
// newGatewayRecordUsageServiceWithBillingRepoForTest is identical to
// newGatewayRecordUsageServiceForTest but additionally injects a
// usage-billing repository after construction.
func newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo UsageLogRepository, billingRepo UsageBillingRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService {
	svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
	svc.usageBillingRepo = billingRepo
	return svc
}
|
||||
|
||||
// openAIRecordUsageBestEffortLogRepoStub is a UsageLogRepository test double
// that records invocations of both the best-effort and regular create paths.
type openAIRecordUsageBestEffortLogRepoStub struct {
	// Embedded interface satisfies UsageLogRepository; any method not stubbed
	// below would panic if called (nil embedded interface).
	UsageLogRepository

	bestEffortErr   error // returned by CreateBestEffort
	createErr       error // returned by Create
	bestEffortCalls int
	createCalls     int
	lastLog         *UsageLog // last log passed to either create method
	lastCtxErr      error     // ctx.Err() observed at the time of the call
}
||||
|
||||
// CreateBestEffort records the call, the log, and the context state at call
// time, then returns the injected bestEffortErr.
func (s *openAIRecordUsageBestEffortLogRepoStub) CreateBestEffort(ctx context.Context, log *UsageLog) error {
	s.bestEffortCalls++
	s.lastLog = log
	s.lastCtxErr = ctx.Err()
	return s.bestEffortErr
}
|
||||
|
||||
func (s *openAIRecordUsageBestEffortLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) {
|
||||
s.createCalls++
|
||||
s.lastLog = log
|
||||
s.lastCtxErr = ctx.Err()
|
||||
return false, s.createErr
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_BillingUsesDetachedContext(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
err := svc.RecordUsage(reqCtx, &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_detached_ctx",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 501,
|
||||
Quota: 100,
|
||||
},
|
||||
User: &User{ID: 601},
|
||||
Account: &Account{ID: 701},
|
||||
APIKeyService: quotaSvc,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
require.NoError(t, userRepo.lastCtxErr)
|
||||
require.Equal(t, 1, quotaSvc.quotaCalls)
|
||||
require.NoError(t, quotaSvc.lastQuotaCtxErr)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||
|
||||
payloadHash := HashUsageRequestPayload([]byte(`{"messages":[{"role":"user","content":"hello"}]}`))
|
||||
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_payload_hash",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 501, Quota: 100},
|
||||
User: &User{ID: 601},
|
||||
Account: &Account{ID: 701},
|
||||
RequestPayloadHash: payloadHash,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, billingRepo.lastCmd)
|
||||
require.Equal(t, payloadHash, billingRepo.lastCmd.RequestPayloadHash)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_BillingFingerprintFallsBackToContextRequestID(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||
|
||||
ctx := context.WithValue(context.Background(), ctxkey.RequestID, "req-local-123")
|
||||
err := svc.RecordUsage(ctx, &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_payload_fallback",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 501, Quota: 100},
|
||||
User: &User{ID: 601},
|
||||
Account: &Account{ID: 701},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, billingRepo.lastCmd)
|
||||
require.Equal(t, "local:req-local-123", billingRepo.lastCmd.RequestPayloadHash)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_not_persisted",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 503,
|
||||
Quota: 100,
|
||||
},
|
||||
User: &User{ID: 603},
|
||||
Account: &Account{ID: 703},
|
||||
APIKeyService: quotaSvc,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
require.Equal(t, 1, quotaSvc.quotaCalls)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsageWithLongContext_BillingUsesDetachedContext(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
err := svc.RecordUsageWithLongContext(reqCtx, &RecordUsageLongContextInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_long_context_detached_ctx",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 12,
|
||||
OutputTokens: 8,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 502,
|
||||
Quota: 100,
|
||||
},
|
||||
User: &User{ID: 602},
|
||||
Account: &Account{ID: 702},
|
||||
LongContextThreshold: 200000,
|
||||
LongContextMultiplier: 2,
|
||||
APIKeyService: quotaSvc,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
require.NoError(t, userRepo.lastCtxErr)
|
||||
require.Equal(t, 1, quotaSvc.quotaCalls)
|
||||
require.NoError(t, quotaSvc.lastQuotaCtxErr)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_UsesFallbackRequestIDForUsageLog(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||
|
||||
ctx := context.WithValue(context.Background(), ctxkey.RequestID, "gateway-local-fallback")
|
||||
err := svc.RecordUsage(ctx, &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 504},
|
||||
User: &User{ID: 604},
|
||||
Account: &Account{ID: 704},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "local:gateway-local-fallback", usageRepo.lastLog.RequestID)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_PrefersClientRequestIDOverUpstreamRequestID(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||
|
||||
ctx := context.WithValue(context.Background(), ctxkey.ClientRequestID, "client-stable-123")
|
||||
ctx = context.WithValue(ctx, ctxkey.RequestID, "req-local-ignored")
|
||||
err := svc.RecordUsage(ctx, &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "upstream-volatile-456",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 506},
|
||||
User: &User{ID: 606},
|
||||
Account: &Account{ID: 706},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, billingRepo.lastCmd)
|
||||
require.Equal(t, "client:client-stable-123", billingRepo.lastCmd.RequestID)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "client:client-stable-123", usageRepo.lastLog.RequestID)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_GeneratesRequestIDWhenAllSourcesMissing(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 507},
|
||||
User: &User{ID: 607},
|
||||
Account: &Account{ID: 707},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, billingRepo.lastCmd)
|
||||
require.True(t, strings.HasPrefix(billingRepo.lastCmd.RequestID, "generated:"))
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, billingRepo.lastCmd.RequestID, usageRepo.lastLog.RequestID)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_DroppedUsageLogDoesNotSyncFallback(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageBestEffortLogRepoStub{
|
||||
bestEffortErr: MarkUsageLogCreateDropped(errors.New("usage log best-effort queue full")),
|
||||
}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_drop_usage_log",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 508},
|
||||
User: &User{ID: 608},
|
||||
Account: &Account{ID: 708},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, usageRepo.bestEffortCalls)
|
||||
require.Equal(t, 0, usageRepo.createCalls)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{err: context.DeadlineExceeded}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_billing_fail",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 505},
|
||||
User: &User{ID: 605},
|
||||
Account: &Account{ID: 705},
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
require.Equal(t, 1, billingRepo.calls)
|
||||
require.Equal(t, 0, usageRepo.calls)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
267
backend/internal/service/gateway_service_bedrock_beta_test.go
Normal file
267
backend/internal/service/gateway_service_bedrock_beta_test.go
Normal file
@@ -0,0 +1,267 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
type betaPolicySettingRepoStub struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func (s *betaPolicySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||
panic("unexpected Get call")
|
||||
}
|
||||
|
||||
func (s *betaPolicySettingRepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
||||
if v, ok := s.values[key]; ok {
|
||||
return v, nil
|
||||
}
|
||||
return "", ErrSettingNotFound
|
||||
}
|
||||
|
||||
func (s *betaPolicySettingRepoStub) Set(ctx context.Context, key, value string) error {
|
||||
panic("unexpected Set call")
|
||||
}
|
||||
|
||||
func (s *betaPolicySettingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
panic("unexpected GetMultiple call")
|
||||
}
|
||||
|
||||
func (s *betaPolicySettingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
panic("unexpected SetMultiple call")
|
||||
}
|
||||
|
||||
func (s *betaPolicySettingRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
panic("unexpected GetAll call")
|
||||
}
|
||||
|
||||
func (s *betaPolicySettingRepoStub) Delete(ctx context.Context, key string) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func TestResolveBedrockBetaTokensForRequest_BlocksOnOriginalAnthropicToken(t *testing.T) {
|
||||
settings := &BetaPolicySettings{
|
||||
Rules: []BetaPolicyRule{
|
||||
{
|
||||
BetaToken: "advanced-tool-use-2025-11-20",
|
||||
Action: BetaPolicyActionBlock,
|
||||
Scope: BetaPolicyScopeAll,
|
||||
ErrorMessage: "advanced tool use is blocked",
|
||||
},
|
||||
},
|
||||
}
|
||||
raw, err := json.Marshal(settings)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal settings: %v", err)
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
settingService: NewSettingService(
|
||||
&betaPolicySettingRepoStub{values: map[string]string{
|
||||
SettingKeyBetaPolicySettings: string(raw),
|
||||
}},
|
||||
&config.Config{},
|
||||
),
|
||||
}
|
||||
account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock}
|
||||
|
||||
_, err = svc.resolveBedrockBetaTokensForRequest(
|
||||
context.Background(),
|
||||
account,
|
||||
"advanced-tool-use-2025-11-20",
|
||||
[]byte(`{"messages":[{"role":"user","content":"hi"}]}`),
|
||||
"us.anthropic.claude-opus-4-6-v1",
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("expected raw advanced-tool-use token to be blocked before Bedrock transform")
|
||||
}
|
||||
if err.Error() != "advanced tool use is blocked" {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveBedrockBetaTokensForRequest_FiltersAfterBedrockTransform(t *testing.T) {
|
||||
settings := &BetaPolicySettings{
|
||||
Rules: []BetaPolicyRule{
|
||||
{
|
||||
BetaToken: "tool-search-tool-2025-10-19",
|
||||
Action: BetaPolicyActionFilter,
|
||||
Scope: BetaPolicyScopeAll,
|
||||
},
|
||||
},
|
||||
}
|
||||
raw, err := json.Marshal(settings)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal settings: %v", err)
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
settingService: NewSettingService(
|
||||
&betaPolicySettingRepoStub{values: map[string]string{
|
||||
SettingKeyBetaPolicySettings: string(raw),
|
||||
}},
|
||||
&config.Config{},
|
||||
),
|
||||
}
|
||||
account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock}
|
||||
|
||||
betaTokens, err := svc.resolveBedrockBetaTokensForRequest(
|
||||
context.Background(),
|
||||
account,
|
||||
"advanced-tool-use-2025-11-20",
|
||||
[]byte(`{"messages":[{"role":"user","content":"hi"}]}`),
|
||||
"us.anthropic.claude-opus-4-6-v1",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
for _, token := range betaTokens {
|
||||
if token == "tool-search-tool-2025-10-19" {
|
||||
t.Fatalf("expected transformed Bedrock token to be filtered")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedThinking 验证:
|
||||
// 管理员 block 了 interleaved-thinking,客户端不在 header 中带该 token,
|
||||
// 但请求体包含 thinking 字段 → 自动注入后应被 block。
|
||||
func TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedThinking(t *testing.T) {
|
||||
settings := &BetaPolicySettings{
|
||||
Rules: []BetaPolicyRule{
|
||||
{
|
||||
BetaToken: "interleaved-thinking-2025-05-14",
|
||||
Action: BetaPolicyActionBlock,
|
||||
Scope: BetaPolicyScopeAll,
|
||||
ErrorMessage: "thinking is blocked",
|
||||
},
|
||||
},
|
||||
}
|
||||
raw, err := json.Marshal(settings)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal settings: %v", err)
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
settingService: NewSettingService(
|
||||
&betaPolicySettingRepoStub{values: map[string]string{
|
||||
SettingKeyBetaPolicySettings: string(raw),
|
||||
}},
|
||||
&config.Config{},
|
||||
),
|
||||
}
|
||||
account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock}
|
||||
|
||||
// header 中不带 beta token,但 body 中有 thinking 字段
|
||||
_, err = svc.resolveBedrockBetaTokensForRequest(
|
||||
context.Background(),
|
||||
account,
|
||||
"", // 空 header
|
||||
[]byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`),
|
||||
"us.anthropic.claude-opus-4-6-v1",
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("expected body-injected interleaved-thinking to be blocked")
|
||||
}
|
||||
if err.Error() != "thinking is blocked" {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedToolSearch 验证:
|
||||
// 管理员 block 了 tool-search-tool,客户端不在 header 中带 beta token,
|
||||
// 但请求体包含 tool search 工具 → 自动注入后应被 block。
|
||||
func TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedToolSearch(t *testing.T) {
|
||||
settings := &BetaPolicySettings{
|
||||
Rules: []BetaPolicyRule{
|
||||
{
|
||||
BetaToken: "tool-search-tool-2025-10-19",
|
||||
Action: BetaPolicyActionBlock,
|
||||
Scope: BetaPolicyScopeAll,
|
||||
ErrorMessage: "tool search is blocked",
|
||||
},
|
||||
},
|
||||
}
|
||||
raw, err := json.Marshal(settings)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal settings: %v", err)
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
settingService: NewSettingService(
|
||||
&betaPolicySettingRepoStub{values: map[string]string{
|
||||
SettingKeyBetaPolicySettings: string(raw),
|
||||
}},
|
||||
&config.Config{},
|
||||
),
|
||||
}
|
||||
account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock}
|
||||
|
||||
// header 中不带 beta token,但 body 中有 tool_search_tool 工具
|
||||
_, err = svc.resolveBedrockBetaTokensForRequest(
|
||||
context.Background(),
|
||||
account,
|
||||
"",
|
||||
[]byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`),
|
||||
"us.anthropic.claude-sonnet-4-6",
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("expected body-injected tool-search-tool to be blocked")
|
||||
}
|
||||
if err.Error() != "tool search is blocked" {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveBedrockBetaTokensForRequest_PassesWhenNoBlockRuleMatches 验证:
|
||||
// body 自动注入的 token 如果没有对应的 block 规则,应正常通过。
|
||||
func TestResolveBedrockBetaTokensForRequest_PassesWhenNoBlockRuleMatches(t *testing.T) {
|
||||
settings := &BetaPolicySettings{
|
||||
Rules: []BetaPolicyRule{
|
||||
{
|
||||
BetaToken: "computer-use-2025-11-24",
|
||||
Action: BetaPolicyActionBlock,
|
||||
Scope: BetaPolicyScopeAll,
|
||||
ErrorMessage: "computer use is blocked",
|
||||
},
|
||||
},
|
||||
}
|
||||
raw, err := json.Marshal(settings)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal settings: %v", err)
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
settingService: NewSettingService(
|
||||
&betaPolicySettingRepoStub{values: map[string]string{
|
||||
SettingKeyBetaPolicySettings: string(raw),
|
||||
}},
|
||||
&config.Config{},
|
||||
),
|
||||
}
|
||||
account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock}
|
||||
|
||||
// body 中有 thinking(会注入 interleaved-thinking),但 block 规则只针对 computer-use
|
||||
tokens, err := svc.resolveBedrockBetaTokensForRequest(
|
||||
context.Background(),
|
||||
account,
|
||||
"",
|
||||
[]byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`),
|
||||
"us.anthropic.claude-opus-4-6-v1",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
found := false
|
||||
for _, token := range tokens {
|
||||
if token == "interleaved-thinking-2025-05-14" {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatal("expected interleaved-thinking token to be present")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package service
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGatewayServiceIsModelSupportedByAccount_BedrockDefaultMappingRestrictsModels(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
account := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeBedrock,
|
||||
Credentials: map[string]any{
|
||||
"aws_region": "us-east-1",
|
||||
},
|
||||
}
|
||||
|
||||
if !svc.isModelSupportedByAccount(account, "claude-sonnet-4-5") {
|
||||
t.Fatalf("expected default Bedrock alias to be supported")
|
||||
}
|
||||
|
||||
if svc.isModelSupportedByAccount(account, "claude-3-5-sonnet-20241022") {
|
||||
t.Fatalf("expected unsupported alias to be rejected for Bedrock account")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayServiceIsModelSupportedByAccount_BedrockCustomMappingStillActsAsAllowlist(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
account := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeBedrock,
|
||||
Credentials: map[string]any{
|
||||
"aws_region": "eu-west-1",
|
||||
"model_mapping": map[string]any{
|
||||
"claude-sonnet-*": "claude-sonnet-4-6",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if !svc.isModelSupportedByAccount(account, "claude-sonnet-4-6") {
|
||||
t.Fatalf("expected matched custom mapping to be supported")
|
||||
}
|
||||
|
||||
if !svc.isModelSupportedByAccount(account, "claude-opus-4-6") {
|
||||
t.Fatalf("expected default Bedrock alias fallback to remain supported")
|
||||
}
|
||||
|
||||
if svc.isModelSupportedByAccount(account, "claude-3-5-sonnet-20241022") {
|
||||
t.Fatalf("expected unsupported model to still be rejected")
|
||||
}
|
||||
}
|
||||
@@ -181,7 +181,8 @@ func TestHandleStreamingResponse_EmptyStream(t *testing.T) {
|
||||
|
||||
result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
|
||||
_ = pr.Close()
|
||||
require.NoError(t, err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "missing terminal event")
|
||||
require.NotNil(t, result)
|
||||
}
|
||||
|
||||
|
||||
@@ -129,6 +129,41 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
|
||||
}
|
||||
}
|
||||
|
||||
// 兼容遗留的 functions 和 function_call,转换为 tools 和 tool_choice
|
||||
if functionsRaw, ok := reqBody["functions"]; ok {
|
||||
if functions, k := functionsRaw.([]any); k {
|
||||
tools := make([]any, 0, len(functions))
|
||||
for _, f := range functions {
|
||||
tools = append(tools, map[string]any{
|
||||
"type": "function",
|
||||
"function": f,
|
||||
})
|
||||
}
|
||||
reqBody["tools"] = tools
|
||||
}
|
||||
delete(reqBody, "functions")
|
||||
result.Modified = true
|
||||
}
|
||||
|
||||
if fcRaw, ok := reqBody["function_call"]; ok {
|
||||
if fcStr, ok := fcRaw.(string); ok {
|
||||
// e.g. "auto", "none"
|
||||
reqBody["tool_choice"] = fcStr
|
||||
} else if fcObj, ok := fcRaw.(map[string]any); ok {
|
||||
// e.g. {"name": "my_func"}
|
||||
if name, ok := fcObj["name"].(string); ok && strings.TrimSpace(name) != "" {
|
||||
reqBody["tool_choice"] = map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": name,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
delete(reqBody, "function_call")
|
||||
result.Modified = true
|
||||
}
|
||||
|
||||
if normalizeCodexTools(reqBody) {
|
||||
result.Modified = true
|
||||
}
|
||||
@@ -303,6 +338,18 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
|
||||
continue
|
||||
}
|
||||
typ, _ := m["type"].(string)
|
||||
|
||||
// 修复 OpenAI 上游的最新校验:"Expected an ID that begins with 'fc'"
|
||||
fixIDPrefix := func(id string) string {
|
||||
if id == "" || strings.HasPrefix(id, "fc") {
|
||||
return id
|
||||
}
|
||||
if strings.HasPrefix(id, "call_") {
|
||||
return "fc" + strings.TrimPrefix(id, "call_")
|
||||
}
|
||||
return "fc_" + id
|
||||
}
|
||||
|
||||
if typ == "item_reference" {
|
||||
if !preserveReferences {
|
||||
continue
|
||||
@@ -311,6 +358,9 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
|
||||
for key, value := range m {
|
||||
newItem[key] = value
|
||||
}
|
||||
if id, ok := newItem["id"].(string); ok && id != "" {
|
||||
newItem["id"] = fixIDPrefix(id)
|
||||
}
|
||||
filtered = append(filtered, newItem)
|
||||
continue
|
||||
}
|
||||
@@ -330,10 +380,20 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
|
||||
}
|
||||
|
||||
if isCodexToolCallItemType(typ) {
|
||||
if callID, ok := m["call_id"].(string); !ok || strings.TrimSpace(callID) == "" {
|
||||
callID, ok := m["call_id"].(string)
|
||||
if !ok || strings.TrimSpace(callID) == "" {
|
||||
if id, ok := m["id"].(string); ok && strings.TrimSpace(id) != "" {
|
||||
callID = id
|
||||
ensureCopy()
|
||||
newItem["call_id"] = id
|
||||
newItem["call_id"] = callID
|
||||
}
|
||||
}
|
||||
|
||||
if callID != "" {
|
||||
fixedCallID := fixIDPrefix(callID)
|
||||
if fixedCallID != callID {
|
||||
ensureCopy()
|
||||
newItem["call_id"] = fixedCallID
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -344,6 +404,14 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
|
||||
if !isCodexToolCallItemType(typ) {
|
||||
delete(newItem, "call_id")
|
||||
}
|
||||
} else {
|
||||
if id, ok := newItem["id"].(string); ok && id != "" {
|
||||
fixedID := fixIDPrefix(id)
|
||||
if fixedID != id {
|
||||
ensureCopy()
|
||||
newItem["id"] = fixedID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
filtered = append(filtered, newItem)
|
||||
|
||||
@@ -33,12 +33,12 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
|
||||
first, ok := input[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "item_reference", first["type"])
|
||||
require.Equal(t, "ref1", first["id"])
|
||||
require.Equal(t, "fc_ref1", first["id"])
|
||||
|
||||
// 校验 input[1] 为 map,确保后续字段断言安全。
|
||||
second, ok := input[1].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "o1", second["id"])
|
||||
require.Equal(t, "fc_o1", second["id"])
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
|
||||
|
||||
@@ -3,39 +3,68 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type openAIRecordUsageLogRepoStub struct {
|
||||
UsageLogRepository
|
||||
|
||||
inserted bool
|
||||
err error
|
||||
calls int
|
||||
lastLog *UsageLog
|
||||
inserted bool
|
||||
err error
|
||||
calls int
|
||||
lastLog *UsageLog
|
||||
lastCtxErr error
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) {
|
||||
s.calls++
|
||||
s.lastLog = log
|
||||
s.lastCtxErr = ctx.Err()
|
||||
return s.inserted, s.err
|
||||
}
|
||||
|
||||
type openAIRecordUsageBillingRepoStub struct {
|
||||
UsageBillingRepository
|
||||
|
||||
result *UsageBillingApplyResult
|
||||
err error
|
||||
calls int
|
||||
lastCmd *UsageBillingCommand
|
||||
lastCtxErr error
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageBillingRepoStub) Apply(ctx context.Context, cmd *UsageBillingCommand) (*UsageBillingApplyResult, error) {
|
||||
s.calls++
|
||||
s.lastCmd = cmd
|
||||
s.lastCtxErr = ctx.Err()
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
if s.result != nil {
|
||||
return s.result, nil
|
||||
}
|
||||
return &UsageBillingApplyResult{Applied: true}, nil
|
||||
}
|
||||
|
||||
type openAIRecordUsageUserRepoStub struct {
|
||||
UserRepository
|
||||
|
||||
deductCalls int
|
||||
deductErr error
|
||||
lastAmount float64
|
||||
lastCtxErr error
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageUserRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error {
|
||||
s.deductCalls++
|
||||
s.lastAmount = amount
|
||||
s.lastCtxErr = ctx.Err()
|
||||
return s.deductErr
|
||||
}
|
||||
|
||||
@@ -44,29 +73,35 @@ type openAIRecordUsageSubRepoStub struct {
|
||||
|
||||
incrementCalls int
|
||||
incrementErr error
|
||||
lastCtxErr error
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageSubRepoStub) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
|
||||
s.incrementCalls++
|
||||
s.lastCtxErr = ctx.Err()
|
||||
return s.incrementErr
|
||||
}
|
||||
|
||||
type openAIRecordUsageAPIKeyQuotaStub struct {
|
||||
quotaCalls int
|
||||
rateLimitCalls int
|
||||
err error
|
||||
lastAmount float64
|
||||
quotaCalls int
|
||||
rateLimitCalls int
|
||||
err error
|
||||
lastAmount float64
|
||||
lastQuotaCtxErr error
|
||||
lastRateLimitCtxErr error
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error {
|
||||
s.quotaCalls++
|
||||
s.lastAmount = cost
|
||||
s.lastQuotaCtxErr = ctx.Err()
|
||||
return s.err
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error {
|
||||
s.rateLimitCalls++
|
||||
s.lastAmount = cost
|
||||
s.lastRateLimitCtxErr = ctx.Err()
|
||||
return s.err
|
||||
}
|
||||
|
||||
@@ -93,23 +128,38 @@ func i64p(v int64) *int64 {
|
||||
func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService {
|
||||
cfg := &config.Config{}
|
||||
cfg.Default.RateMultiplier = 1.1
|
||||
svc := NewOpenAIGatewayService(
|
||||
nil,
|
||||
usageRepo,
|
||||
nil,
|
||||
userRepo,
|
||||
subRepo,
|
||||
rateRepo,
|
||||
nil,
|
||||
cfg,
|
||||
nil,
|
||||
nil,
|
||||
NewBillingService(cfg, nil),
|
||||
nil,
|
||||
&BillingCacheService{},
|
||||
nil,
|
||||
&DeferredService{},
|
||||
nil,
|
||||
)
|
||||
svc.userGroupRateResolver = newUserGroupRateResolver(
|
||||
rateRepo,
|
||||
nil,
|
||||
resolveUserGroupRateCacheTTL(cfg),
|
||||
nil,
|
||||
"service.openai_gateway.test",
|
||||
)
|
||||
return svc
|
||||
}
|
||||
|
||||
return &OpenAIGatewayService{
|
||||
usageLogRepo: usageRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: subRepo,
|
||||
cfg: cfg,
|
||||
billingService: NewBillingService(cfg, nil),
|
||||
billingCacheService: &BillingCacheService{},
|
||||
deferredService: &DeferredService{},
|
||||
userGroupRateResolver: newUserGroupRateResolver(
|
||||
rateRepo,
|
||||
nil,
|
||||
resolveUserGroupRateCacheTTL(cfg),
|
||||
nil,
|
||||
"service.openai_gateway.test",
|
||||
),
|
||||
}
|
||||
func newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo UsageLogRepository, billingRepo UsageBillingRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService {
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo)
|
||||
svc.usageBillingRepo = billingRepo
|
||||
return svc
|
||||
}
|
||||
|
||||
func expectedOpenAICost(t *testing.T, svc *OpenAIGatewayService, model string, usage OpenAIUsage, multiplier float64) *CostBreakdown {
|
||||
@@ -252,9 +302,10 @@ func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateWhenResolver
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: false}}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
@@ -272,11 +323,313 @@ func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testin
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, billingRepo.calls)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 0, userRepo.deductCalls)
|
||||
require.Equal(t, 0, subRepo.incrementCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_DuplicateBillingKeySkipsBillingWithRepo(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: false}}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_duplicate_billing_key",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 8,
|
||||
OutputTokens: 4,
|
||||
},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 10045,
|
||||
Quota: 100,
|
||||
},
|
||||
User: &User{ID: 20045},
|
||||
Account: &Account{ID: 30045},
|
||||
APIKeyService: quotaSvc,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, billingRepo.calls)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 0, userRepo.deductCalls)
|
||||
require.Equal(t, 0, subRepo.incrementCalls)
|
||||
require.Equal(t, 0, quotaSvc.quotaCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_BillsWhenUsageLogCreateReturnsError(t *testing.T) {
|
||||
usage := OpenAIUsage{InputTokens: 8, OutputTokens: 4}
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: errors.New("usage log batch state uncertain")}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_usage_log_error",
|
||||
Usage: usage,
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 10041},
|
||||
User: &User{ID: 20041},
|
||||
Account: &Account{ID: 30041},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
require.Equal(t, 0, subRepo.incrementCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_not_persisted",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 8,
|
||||
OutputTokens: 4,
|
||||
},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 10043,
|
||||
Quota: 100,
|
||||
},
|
||||
User: &User{ID: 20043},
|
||||
Account: &Account{ID: 30043},
|
||||
APIKeyService: quotaSvc,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
require.Equal(t, 0, subRepo.incrementCalls)
|
||||
require.Equal(t, 1, quotaSvc.quotaCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_BillingUsesDetachedContext(t *testing.T) {
|
||||
usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2}
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
err := svc.RecordUsage(reqCtx, &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_detached_billing_ctx",
|
||||
Usage: usage,
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 10042,
|
||||
Quota: 100,
|
||||
},
|
||||
User: &User{ID: 20042},
|
||||
Account: &Account{ID: 30042},
|
||||
APIKeyService: quotaSvc,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
require.NoError(t, userRepo.lastCtxErr)
|
||||
require.Equal(t, 1, quotaSvc.quotaCalls)
|
||||
require.NoError(t, quotaSvc.lastQuotaCtxErr)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_BillingRepoUsesDetachedContext(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
err := svc.RecordUsage(reqCtx, &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_detached_billing_repo_ctx",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 8,
|
||||
OutputTokens: 4,
|
||||
},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 10046},
|
||||
User: &User{ID: 20046},
|
||||
Account: &Account{ID: 30046},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, billingRepo.calls)
|
||||
require.NoError(t, billingRepo.lastCtxErr)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.NoError(t, usageRepo.lastCtxErr)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)
|
||||
|
||||
payloadHash := HashUsageRequestPayload([]byte(`{"model":"gpt-5","input":"hello"}`))
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "openai_payload_hash",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "gpt-5",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 501, Quota: 100},
|
||||
User: &User{ID: 601},
|
||||
Account: &Account{ID: 701},
|
||||
RequestPayloadHash: payloadHash,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, billingRepo.lastCmd)
|
||||
require.Equal(t, payloadHash, billingRepo.lastCmd.RequestPayloadHash)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_UsesFallbackRequestIDForBillingAndUsageLog(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||
|
||||
ctx := context.WithValue(context.Background(), ctxkey.RequestID, "req-local-fallback")
|
||||
err := svc.RecordUsage(ctx, &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 8,
|
||||
OutputTokens: 4,
|
||||
},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 10047},
|
||||
User: &User{ID: 20047},
|
||||
Account: &Account{ID: 30047},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, billingRepo.lastCmd)
|
||||
require.Equal(t, "local:req-local-fallback", billingRepo.lastCmd.RequestID)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "local:req-local-fallback", usageRepo.lastLog.RequestID)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_PrefersClientRequestIDOverUpstreamRequestID(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||
|
||||
ctx := context.WithValue(context.Background(), ctxkey.ClientRequestID, "openai-client-stable-123")
|
||||
err := svc.RecordUsage(ctx, &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "upstream-openai-volatile-456",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 8,
|
||||
OutputTokens: 4,
|
||||
},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 10049},
|
||||
User: &User{ID: 20049},
|
||||
Account: &Account{ID: 30049},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, billingRepo.lastCmd)
|
||||
require.Equal(t, "client:openai-client-stable-123", billingRepo.lastCmd.RequestID)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "client:openai-client-stable-123", usageRepo.lastLog.RequestID)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_GeneratesRequestIDWhenAllSourcesMissing(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 8,
|
||||
OutputTokens: 4,
|
||||
},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 10050},
|
||||
User: &User{ID: 20050},
|
||||
Account: &Account{ID: 30050},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, billingRepo.lastCmd)
|
||||
require.True(t, strings.HasPrefix(billingRepo.lastCmd.RequestID, "generated:"))
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, billingRepo.lastCmd.RequestID, usageRepo.lastLog.RequestID)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{err: errors.New("billing tx failed")}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_billing_fail",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 8,
|
||||
OutputTokens: 4,
|
||||
},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 10048},
|
||||
User: &User{ID: 20048},
|
||||
Account: &Account{ID: 30048},
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
require.Equal(t, 1, billingRepo.calls)
|
||||
require.Equal(t, 0, usageRepo.calls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_UpdatesAPIKeyQuotaWhenConfigured(t *testing.T) {
|
||||
usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2}
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
|
||||
@@ -301,6 +301,7 @@ var defaultOpenAICodexSnapshotPersistThrottle = newAccountWriteThrottle(openAICo
|
||||
type OpenAIGatewayService struct {
|
||||
accountRepo AccountRepository
|
||||
usageLogRepo UsageLogRepository
|
||||
usageBillingRepo UsageBillingRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
cache GatewayCache
|
||||
@@ -338,6 +339,7 @@ type OpenAIGatewayService struct {
|
||||
func NewOpenAIGatewayService(
|
||||
accountRepo AccountRepository,
|
||||
usageLogRepo UsageLogRepository,
|
||||
usageBillingRepo UsageBillingRepository,
|
||||
userRepo UserRepository,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
userGroupRateRepo UserGroupRateRepository,
|
||||
@@ -355,6 +357,7 @@ func NewOpenAIGatewayService(
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: accountRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
usageBillingRepo: usageBillingRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
cache: cache,
|
||||
@@ -2119,7 +2122,9 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
httpInvalidEncryptedContentRetryTried := false
|
||||
for {
|
||||
// Build upstream request
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -2326,7 +2331,9 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(ctx, c, account, body, token)
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -2663,6 +2670,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
var firstTokenMs *int
|
||||
clientDisconnected := false
|
||||
sawDone := false
|
||||
sawTerminalEvent := false
|
||||
upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id"))
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
@@ -2682,6 +2690,9 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
if trimmedData == "[DONE]" {
|
||||
sawDone = true
|
||||
}
|
||||
if openAIStreamEventIsTerminal(trimmedData) {
|
||||
sawTerminalEvent = true
|
||||
}
|
||||
if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
@@ -2699,19 +2710,14 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
if clientDisconnected {
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, err)
|
||||
if sawTerminalEvent {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
if clientDisconnected {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err)
|
||||
}
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
logger.LegacyPrintf("service.openai_gateway",
|
||||
"[OpenAI passthrough] 流读取被取消,可能发生断流: account=%d request_id=%s err=%v ctx_err=%v",
|
||||
account.ID,
|
||||
upstreamRequestID,
|
||||
err,
|
||||
ctx.Err(),
|
||||
)
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err)
|
||||
}
|
||||
if errors.Is(err, bufio.ErrTooLong) {
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
|
||||
@@ -2725,12 +2731,13 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
)
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
|
||||
}
|
||||
if !clientDisconnected && !sawDone && ctx.Err() == nil {
|
||||
if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil {
|
||||
logger.FromContext(ctx).With(
|
||||
zap.String("component", "service.openai_gateway"),
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("upstream_request_id", upstreamRequestID),
|
||||
).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流")
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event")
|
||||
}
|
||||
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
@@ -3264,6 +3271,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
// 否则下游 SDK(例如 OpenCode)会因为类型校验失败而报错。
|
||||
errorEventSent := false
|
||||
clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage
|
||||
sawTerminalEvent := false
|
||||
sendErrorEvent := func(reason string) {
|
||||
if errorEventSent || clientDisconnected {
|
||||
return
|
||||
@@ -3294,22 +3302,27 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage")
|
||||
}
|
||||
}
|
||||
if !sawTerminalEvent {
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
|
||||
}
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) {
|
||||
if scanErr == nil {
|
||||
return nil, nil, false
|
||||
}
|
||||
if sawTerminalEvent {
|
||||
logger.LegacyPrintf("service.openai_gateway", "Upstream scan ended after terminal event: %v", scanErr)
|
||||
return resultWithUsage(), nil, true
|
||||
}
|
||||
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
|
||||
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
|
||||
if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) {
|
||||
logger.LegacyPrintf("service.openai_gateway", "Context canceled during streaming, returning collected usage")
|
||||
return resultWithUsage(), nil, true
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", scanErr), true
|
||||
}
|
||||
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
|
||||
if clientDisconnected {
|
||||
logger.LegacyPrintf("service.openai_gateway", "Upstream read error after client disconnect: %v, returning collected usage", scanErr)
|
||||
return resultWithUsage(), nil, true
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true
|
||||
}
|
||||
if errors.Is(scanErr, bufio.ErrTooLong) {
|
||||
logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr)
|
||||
@@ -3332,6 +3345,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
}
|
||||
|
||||
dataBytes := []byte(data)
|
||||
if openAIStreamEventIsTerminal(data) {
|
||||
sawTerminalEvent = true
|
||||
}
|
||||
|
||||
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
|
||||
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected {
|
||||
@@ -3448,8 +3464,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
continue
|
||||
}
|
||||
if clientDisconnected {
|
||||
logger.LegacyPrintf("service.openai_gateway", "Upstream timeout after client disconnect, returning collected usage")
|
||||
return resultWithUsage(), nil
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout")
|
||||
}
|
||||
logger.LegacyPrintf("service.openai_gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||
// 处理流超时,可能标记账户为临时不可调度或错误状态
|
||||
@@ -3547,11 +3562,12 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag
|
||||
if usage == nil || len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) {
|
||||
return
|
||||
}
|
||||
// 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。
|
||||
if len(data) < 80 || !bytes.Contains(data, []byte(`"response.completed"`)) {
|
||||
// 选择性解析:仅在数据中包含终止事件标识时才进入字段提取。
|
||||
if len(data) < 72 {
|
||||
return
|
||||
}
|
||||
if gjson.GetBytes(data, "type").String() != "response.completed" {
|
||||
eventType := gjson.GetBytes(data, "type").String()
|
||||
if eventType != "response.completed" && eventType != "response.done" {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -4007,14 +4023,15 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
|
||||
|
||||
// OpenAIRecordUsageInput input for recording usage
|
||||
type OpenAIRecordUsageInput struct {
|
||||
Result *OpenAIForwardResult
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
APIKeyService APIKeyQuotaUpdater
|
||||
Result *OpenAIForwardResult
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
RequestPayloadHash string
|
||||
APIKeyService APIKeyQuotaUpdater
|
||||
}
|
||||
|
||||
// RecordUsage records usage and deducts balance
|
||||
@@ -4080,11 +4097,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
// Create usage log
|
||||
durationMs := int(result.Duration.Milliseconds())
|
||||
accountRateMultiplier := account.BillingRateMultiplier()
|
||||
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
RequestID: requestID,
|
||||
Model: billingModel,
|
||||
ServiceTier: result.ServiceTier,
|
||||
ReasoningEffort: result.ReasoningEffort,
|
||||
@@ -4125,29 +4143,32 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
usageLog.SubscriptionID = &subscription.ID
|
||||
}
|
||||
|
||||
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway")
|
||||
logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
shouldBill := inserted || err != nil
|
||||
|
||||
if shouldBill {
|
||||
postUsageBilling(ctx, &postUsageBillingParams{
|
||||
billingErr := func() error {
|
||||
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
|
||||
Cost: cost,
|
||||
User: user,
|
||||
APIKey: apiKey,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
|
||||
IsSubscriptionBill: isSubscriptionBilling,
|
||||
AccountRateMultiplier: accountRateMultiplier,
|
||||
APIKeyService: input.APIKeyService,
|
||||
}, s.billingDeps())
|
||||
} else {
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
}, s.billingDeps(), s.usageBillingRepo)
|
||||
return err
|
||||
}()
|
||||
|
||||
if billingErr != nil {
|
||||
return billingErr
|
||||
}
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -916,7 +916,7 @@ func TestOpenAIStreamingTimeout(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) {
|
||||
func TestOpenAIStreamingContextCanceledReturnsIncompleteErrorWithoutInjectingErrorEvent(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
@@ -940,8 +940,8 @@ func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) {
|
||||
}
|
||||
|
||||
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error, got %v", err)
|
||||
if err == nil || !strings.Contains(err.Error(), "stream usage incomplete") {
|
||||
t.Fatalf("expected incomplete stream error, got %v", err)
|
||||
}
|
||||
if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "stream_read_error") {
|
||||
t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
|
||||
@@ -993,6 +993,107 @@ func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingMissingTerminalEventReturnsIncompleteError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamDataIntervalTimeout: 0,
|
||||
StreamKeepaliveInterval: 0,
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: pr,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
|
||||
}()
|
||||
|
||||
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
||||
_ = pr.Close()
|
||||
if err == nil || !strings.Contains(err.Error(), "missing terminal event") {
|
||||
t.Fatalf("expected missing terminal event error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: pr,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
|
||||
}()
|
||||
|
||||
_, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now())
|
||||
_ = pr.Close()
|
||||
if err == nil || !strings.Contains(err.Error(), "missing terminal event") {
|
||||
t.Fatalf("expected missing terminal event error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: pr,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.done\",\"response\":{\"usage\":{\"input_tokens\":2,\"output_tokens\":3,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
|
||||
}()
|
||||
|
||||
result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now())
|
||||
_ = pr.Close()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.usage)
|
||||
require.Equal(t, 2, result.usage.InputTokens)
|
||||
require.Equal(t, 3, result.usage.OutputTokens)
|
||||
require.Equal(t, 1, result.usage.CacheReadInputTokens)
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingTooLong(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
@@ -1124,7 +1225,7 @@ func TestOpenAIStreamingHeadersOverride(t *testing.T) {
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {}\n\n"))
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{}}\n\n"))
|
||||
}()
|
||||
|
||||
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
||||
@@ -1674,6 +1775,12 @@ func TestParseSSEUsage_SelectiveParsing(t *testing.T) {
|
||||
require.Equal(t, 3, usage.InputTokens)
|
||||
require.Equal(t, 5, usage.OutputTokens)
|
||||
require.Equal(t, 2, usage.CacheReadInputTokens)
|
||||
|
||||
// done 事件同样可能携带最终 usage
|
||||
svc.parseSSEUsage(`{"type":"response.done","response":{"usage":{"input_tokens":13,"output_tokens":15,"input_tokens_details":{"cached_tokens":4}}}}`, usage)
|
||||
require.Equal(t, 13, usage.InputTokens)
|
||||
require.Equal(t, 15, usage.OutputTokens)
|
||||
require.Equal(t, 4, usage.CacheReadInputTokens)
|
||||
}
|
||||
|
||||
func TestExtractCodexFinalResponse_SampleReplay(t *testing.T) {
|
||||
|
||||
@@ -439,7 +439,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_ResponseHeadersAllowXCodex(t *tes
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
|
||||
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`)
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`)
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("Content-Type", "application/json")
|
||||
@@ -453,7 +453,14 @@ func TestOpenAIGatewayService_OAuthPassthrough_ResponseHeadersAllowXCodex(t *tes
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: headers,
|
||||
Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)),
|
||||
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||
`data: {"type":"response.output_text.delta","delta":"h"}`,
|
||||
"",
|
||||
`data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n"))),
|
||||
}
|
||||
upstream := &httpUpstreamRecorder{resp: resp}
|
||||
|
||||
@@ -895,7 +902,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_InfoWhenStreamEndsWithoutDone(t *
|
||||
}
|
||||
|
||||
_, err := svc.Forward(context.Background(), c, account, originalBody)
|
||||
require.NoError(t, err)
|
||||
require.EqualError(t, err, "stream usage incomplete: missing terminal event")
|
||||
require.True(t, logSink.ContainsMessage("上游流在未收到 [DONE] 时结束,疑似断流"))
|
||||
require.True(t, logSink.ContainsMessageAtLevel("上游流在未收到 [DONE] 时结束,疑似断流", "info"))
|
||||
require.True(t, logSink.ContainsFieldValue("upstream_request_id", "rid-truncate"))
|
||||
@@ -911,11 +918,16 @@ func TestOpenAIGatewayService_OAuthPassthrough_DefaultFiltersTimeoutHeaders(t *t
|
||||
c.Request.Header.Set("x-stainless-timeout", "120000")
|
||||
c.Request.Header.Set("X-Test", "keep")
|
||||
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`)
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}, "X-Request-Id": []string{"rid-filter-default"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)),
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "X-Request-Id": []string{"rid-filter-default"}},
|
||||
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||
`data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n"))),
|
||||
}
|
||||
upstream := &httpUpstreamRecorder{resp: resp}
|
||||
svc := &OpenAIGatewayService{
|
||||
@@ -952,11 +964,16 @@ func TestOpenAIGatewayService_OAuthPassthrough_AllowTimeoutHeadersWhenConfigured
|
||||
c.Request.Header.Set("x-stainless-timeout", "120000")
|
||||
c.Request.Header.Set("X-Test", "keep")
|
||||
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`)
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}, "X-Request-Id": []string{"rid-filter-allow"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)),
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "X-Request-Id": []string{"rid-filter-allow"}},
|
||||
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||
`data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n"))),
|
||||
}
|
||||
upstream := &httpUpstreamRecorder{resp: resp}
|
||||
svc := &OpenAIGatewayService{
|
||||
|
||||
@@ -335,7 +335,7 @@ func TestOpenAIGatewayService_Forward_HTTPIngressRetriesWrappedInvalidEncryptedC
|
||||
},
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"gpt-5.1","stream":true,"previous_response_id":"resp_http_retry_wrapped","input":[{"type":"reasoning","encrypted_content":"gAAA","summary":[{"type":"summary_text","text":"keep me too"}]},{"type":"input_text","text":"hello"}]}`)
|
||||
body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_http_retry_wrapped","input":[{"type":"reasoning","encrypted_content":"gAAA","summary":[{"type":"summary_text","text":"keep me too"}]},{"type":"input_text","text":"hello"}]}`)
|
||||
result, err := svc.Forward(context.Background(), c, account, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
@@ -604,6 +604,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
cfg,
|
||||
nil,
|
||||
nil,
|
||||
|
||||
110
backend/internal/service/usage_billing.go
Normal file
110
backend/internal/service/usage_billing.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Sentinel errors surfaced by the usage-billing idempotency layer.
var (
	// ErrUsageBillingRequestIDRequired is returned when a billing command
	// arrives without a request_id (the idempotency key).
	ErrUsageBillingRequestIDRequired = errors.New("usage billing request_id is required")

	// ErrUsageBillingRequestConflict is returned when a request fingerprint
	// conflicts with one already recorded for the same request.
	ErrUsageBillingRequestConflict = errors.New("usage billing request fingerprint conflict")
)
||||
|
||||
// UsageBillingCommand describes one billable request that must be applied at most once.
//
// Callers should invoke Normalize before use: it trims RequestID and backfills
// RequestFingerprint from the other fields when the fingerprint is empty.
type UsageBillingCommand struct {
	// RequestID is the caller-supplied idempotency key
	// (see ErrUsageBillingRequestIDRequired). Trimmed by Normalize.
	RequestID string
	// APIKeyID, together with RequestID, forms the dedup identity
	// (matches the unique index on usage_billing_dedup).
	APIKeyID int64
	// RequestFingerprint is a SHA-256 hex digest over the billing dimensions;
	// derived by buildUsageBillingFingerprint when left empty.
	RequestFingerprint string
	// RequestPayloadHash is an optional SHA-256 hex of the raw request payload
	// (see HashUsageRequestPayload); when non-empty it is folded into the fingerprint.
	RequestPayloadHash string

	// Billing dimensions — all of these participate in the derived fingerprint.
	UserID          int64
	AccountID       int64
	SubscriptionID  *int64 // nil means no subscription; fingerprinted as 0 via valueOrZero
	AccountType     string
	Model           string
	ServiceTier     string
	ReasoningEffort string
	BillingType     int8
	InputTokens     int
	OutputTokens    int
	CacheCreationTokens int
	CacheReadTokens     int
	ImageCount          int
	MediaType           string

	// Cost components — presumably all in the same currency/credit unit;
	// TODO confirm units against the billing service. All participate in the
	// derived fingerprint (formatted with %0.10f).
	BalanceCost         float64
	SubscriptionCost    float64
	APIKeyQuotaCost     float64
	APIKeyRateLimitCost float64
	AccountQuotaCost    float64
}
|
||||
|
||||
func (c *UsageBillingCommand) Normalize() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.RequestID = strings.TrimSpace(c.RequestID)
|
||||
if strings.TrimSpace(c.RequestFingerprint) == "" {
|
||||
c.RequestFingerprint = buildUsageBillingFingerprint(c)
|
||||
}
|
||||
}
|
||||
|
||||
// buildUsageBillingFingerprint derives a deterministic SHA-256 hex digest from
// every billing-relevant dimension of the command. Identical billing inputs
// yield identical fingerprints, which the dedup layer uses to tell whether a
// retried request_id carries the same charge.
//
// NOTE(review): the field order, the "|" separator, and the %0.10f float
// formatting are part of the persisted fingerprint format — changing any of
// them invalidates fingerprints already stored in usage_billing_dedup.
func buildUsageBillingFingerprint(c *UsageBillingCommand) string {
	if c == nil {
		return ""
	}
	raw := fmt.Sprintf(
		"%d|%d|%d|%s|%s|%s|%s|%d|%d|%d|%d|%d|%d|%s|%d|%0.10f|%0.10f|%0.10f|%0.10f|%0.10f",
		c.UserID,
		c.AccountID,
		c.APIKeyID,
		strings.TrimSpace(c.AccountType),
		strings.TrimSpace(c.Model),
		strings.TrimSpace(c.ServiceTier),
		strings.TrimSpace(c.ReasoningEffort),
		c.BillingType,
		c.InputTokens,
		c.OutputTokens,
		c.CacheCreationTokens,
		c.CacheReadTokens,
		c.ImageCount,
		strings.TrimSpace(c.MediaType),
		valueOrZero(c.SubscriptionID),
		c.BalanceCost,
		c.SubscriptionCost,
		c.APIKeyQuotaCost,
		c.APIKeyRateLimitCost,
		c.AccountQuotaCost,
	)
	// The payload hash is appended only when present so that commands without
	// one keep the shorter historical raw form (and its fingerprint).
	if payloadHash := strings.TrimSpace(c.RequestPayloadHash); payloadHash != "" {
		raw += "|" + payloadHash
	}
	sum := sha256.Sum256([]byte(raw))
	return hex.EncodeToString(sum[:])
}
|
||||
|
||||
// HashUsageRequestPayload returns the lowercase hex SHA-256 digest of payload.
// An empty or nil payload yields the empty string rather than the digest of
// zero bytes, so "no payload" is distinguishable from "empty payload hash".
func HashUsageRequestPayload(payload []byte) string {
	if len(payload) > 0 {
		digest := sha256.Sum256(payload)
		return hex.EncodeToString(digest[:])
	}
	return ""
}
|
||||
|
||||
// valueOrZero dereferences v, substituting 0 when the pointer is nil.
// Used so an absent SubscriptionID fingerprints deterministically.
func valueOrZero(v *int64) int64 {
	if v != nil {
		return *v
	}
	return 0
}
|
||||
|
||||
// UsageBillingApplyResult reports the outcome of applying one billing command.
type UsageBillingApplyResult struct {
	// Applied indicates the command took effect in this call — presumably
	// false when the dedup key suppressed a duplicate; confirm against the
	// repository implementation.
	Applied bool
	// APIKeyQuotaExhausted — NOTE(review): looks like a signal that the charge
	// consumed the API key's remaining quota; verify against callers.
	APIKeyQuotaExhausted bool
}
|
||||
|
||||
// UsageBillingRepository persists billing commands with at-most-once semantics.
type UsageBillingRepository interface {
	// Apply executes cmd atomically and returns whether it was applied or
	// suppressed as a duplicate. Implementations back this with the
	// usage_billing_dedup unique key on (request_id, api_key_id).
	Apply(ctx context.Context, cmd *UsageBillingCommand) (*UsageBillingApplyResult, error)
}
|
||||
@@ -56,7 +56,8 @@ type cleanupRepoStub struct {
|
||||
}
|
||||
|
||||
type dashboardRepoStub struct {
|
||||
recomputeErr error
|
||||
recomputeErr error
|
||||
recomputeCalls int
|
||||
}
|
||||
|
||||
func (s *dashboardRepoStub) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||
@@ -64,6 +65,7 @@ func (s *dashboardRepoStub) AggregateRange(ctx context.Context, start, end time.
|
||||
}
|
||||
|
||||
func (s *dashboardRepoStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
|
||||
s.recomputeCalls++
|
||||
return s.recomputeErr
|
||||
}
|
||||
|
||||
@@ -83,6 +85,10 @@ func (s *dashboardRepoStub) CleanupUsageLogs(ctx context.Context, cutoff time.Ti
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardRepoStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||
return nil
|
||||
}
|
||||
@@ -550,13 +556,14 @@ func TestUsageCleanupServiceExecuteTaskMarkFailedUpdateError(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceExecuteTaskDashboardRecomputeError(t *testing.T) {
|
||||
dashboardRepo := &dashboardRepoStub{recomputeErr: errors.New("recompute failed")}
|
||||
repo := &cleanupRepoStub{
|
||||
deleteQueue: []cleanupDeleteResponse{
|
||||
{deleted: 0},
|
||||
},
|
||||
}
|
||||
dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{
|
||||
DashboardAgg: config.DashboardAggregationConfig{Enabled: false},
|
||||
dashboard := NewDashboardAggregationService(dashboardRepo, nil, &config.Config{
|
||||
DashboardAgg: config.DashboardAggregationConfig{Enabled: true},
|
||||
})
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
|
||||
svc := NewUsageCleanupService(repo, nil, dashboard, cfg)
|
||||
@@ -573,15 +580,17 @@ func TestUsageCleanupServiceExecuteTaskDashboardRecomputeError(t *testing.T) {
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Len(t, repo.markSucceeded, 1)
|
||||
require.Eventually(t, func() bool { return dashboardRepo.recomputeCalls == 1 }, time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess(t *testing.T) {
|
||||
dashboardRepo := &dashboardRepoStub{}
|
||||
repo := &cleanupRepoStub{
|
||||
deleteQueue: []cleanupDeleteResponse{
|
||||
{deleted: 0},
|
||||
},
|
||||
}
|
||||
dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{
|
||||
dashboard := NewDashboardAggregationService(dashboardRepo, nil, &config.Config{
|
||||
DashboardAgg: config.DashboardAggregationConfig{Enabled: true},
|
||||
})
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
|
||||
@@ -599,6 +608,7 @@ func TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess(t *testing.T) {
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Len(t, repo.markSucceeded, 1)
|
||||
require.Eventually(t, func() bool { return dashboardRepo.recomputeCalls == 1 }, time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceExecuteTaskCanceled(t *testing.T) {
|
||||
|
||||
82
backend/internal/service/usage_log_create_result.go
Normal file
82
backend/internal/service/usage_log_create_result.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package service
|
||||
|
||||
import "errors"
|
||||
|
||||
// usageLogCreateDisposition classifies how a failed usage-log create should be
// interpreted by downstream billing logic (see ShouldBillAfterUsageLogCreate).
type usageLogCreateDisposition int

const (
	// usageLogCreateDispositionUnknown is the zero value: the error carries no
	// explicit disposition.
	usageLogCreateDispositionUnknown usageLogCreateDisposition = iota
	// usageLogCreateDispositionNotPersisted marks errors where the log row was
	// definitely not written; billing is skipped for such errors.
	usageLogCreateDispositionNotPersisted
	// usageLogCreateDispositionDropped marks errors where the log entry was
	// dropped rather than written — presumably a deliberate shed/overflow path;
	// confirm against the producers of MarkUsageLogCreateDropped.
	usageLogCreateDispositionDropped
)
|
||||
|
||||
// UsageLogCreateError wraps an error from usage-log creation together with a
// disposition describing whether the log row reached storage.
type UsageLogCreateError struct {
	err         error                     // underlying cause; exposed via Unwrap
	disposition usageLogCreateDisposition // how billing should treat the failure
}
|
||||
|
||||
func (e *UsageLogCreateError) Error() string {
|
||||
if e == nil || e.err == nil {
|
||||
return "usage log create error"
|
||||
}
|
||||
return e.err.Error()
|
||||
}
|
||||
|
||||
func (e *UsageLogCreateError) Unwrap() error {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
return e.err
|
||||
}
|
||||
|
||||
func MarkUsageLogCreateNotPersisted(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
return &UsageLogCreateError{
|
||||
err: err,
|
||||
disposition: usageLogCreateDispositionNotPersisted,
|
||||
}
|
||||
}
|
||||
|
||||
func MarkUsageLogCreateDropped(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
return &UsageLogCreateError{
|
||||
err: err,
|
||||
disposition: usageLogCreateDispositionDropped,
|
||||
}
|
||||
}
|
||||
|
||||
func IsUsageLogCreateNotPersisted(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var target *UsageLogCreateError
|
||||
if !errors.As(err, &target) {
|
||||
return false
|
||||
}
|
||||
return target.disposition == usageLogCreateDispositionNotPersisted
|
||||
}
|
||||
|
||||
func IsUsageLogCreateDropped(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var target *UsageLogCreateError
|
||||
if !errors.As(err, &target) {
|
||||
return false
|
||||
}
|
||||
return target.disposition == usageLogCreateDispositionDropped
|
||||
}
|
||||
|
||||
func ShouldBillAfterUsageLogCreate(inserted bool, err error) bool {
|
||||
if inserted {
|
||||
return true
|
||||
}
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return !IsUsageLogCreateNotPersisted(err)
|
||||
}
|
||||
13
backend/migrations/071_add_usage_billing_dedup.sql
Normal file
13
backend/migrations/071_add_usage_billing_dedup.sql
Normal file
@@ -0,0 +1,13 @@
|
||||
-- Narrow billing idempotency table: decouples "has this request been charged"
-- from usage_logs.
-- Idempotent migration: safe to run repeatedly (IF NOT EXISTS guards).

CREATE TABLE IF NOT EXISTS usage_billing_dedup (
    id BIGSERIAL PRIMARY KEY,
    request_id VARCHAR(255) NOT NULL,
    api_key_id BIGINT NOT NULL,
    request_fingerprint VARCHAR(64) NOT NULL, -- SHA-256 hex digest (64 chars)
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

-- Unique key enforcing at-most-once billing per (request_id, api_key_id).
CREATE UNIQUE INDEX IF NOT EXISTS idx_usage_billing_dedup_request_api_key
    ON usage_billing_dedup (request_id, api_key_id);
|
||||
@@ -0,0 +1,7 @@
|
||||
-- usage_billing_dedup is an append-only, time-ordered idempotency table.
-- A BRIN index on created_at supports batched retention-window cleanup while
-- keeping write amplification minimal.
-- CONCURRENTLY avoids long write-blocking on the hot table.
-- NOTE(review): CREATE INDEX CONCURRENTLY cannot run inside a transaction
-- block — confirm the migration runner executes this file non-transactionally.

CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_usage_billing_dedup_created_at_brin
    ON usage_billing_dedup
    USING BRIN (created_at);
|
||||
10
backend/migrations/073_add_usage_billing_dedup_archive.sql
Normal file
10
backend/migrations/073_add_usage_billing_dedup_archive.sql
Normal file
@@ -0,0 +1,10 @@
|
||||
-- Cold archive for aged billing idempotency keys: shrinks the hot table's
-- indexes and cleanup scope without losing long-term dedup capability.

CREATE TABLE IF NOT EXISTS usage_billing_dedup_archive (
    request_id VARCHAR(255) NOT NULL,
    api_key_id BIGINT NOT NULL,
    request_fingerprint VARCHAR(64) NOT NULL, -- SHA-256 hex digest (64 chars)
    created_at TIMESTAMPTZ NOT NULL,          -- original hot-table timestamp
    archived_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    PRIMARY KEY (request_id, api_key_id)
);
|
||||
@@ -105,7 +105,7 @@ EXPOSE 8080
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
||||
CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
||||
CMD wget -q -T 5 -O /dev/null http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
||||
|
||||
# Run the application
|
||||
ENTRYPOINT ["/app/sub2api"]
|
||||
|
||||
0
deploy/build_image.sh
Executable file → Normal file
0
deploy/build_image.sh
Executable file → Normal file
@@ -154,7 +154,7 @@ services:
|
||||
networks:
|
||||
- sub2api-network
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8080/health"]
|
||||
test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
|
||||
@@ -94,7 +94,7 @@ services:
|
||||
- GEMINI_CLI_OAUTH_CLIENT_SECRET=${GEMINI_CLI_OAUTH_CLIENT_SECRET:-}
|
||||
- ANTIGRAVITY_OAUTH_CLIENT_SECRET=${ANTIGRAVITY_OAUTH_CLIENT_SECRET:-}
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8080/health"]
|
||||
test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
|
||||
@@ -146,7 +146,7 @@ services:
|
||||
networks:
|
||||
- sub2api-network
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8080/health"]
|
||||
test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
|
||||
0
deploy/install-datamanagementd.sh
Executable file → Normal file
0
deploy/install-datamanagementd.sh
Executable file → Normal file
@@ -11,6 +11,7 @@ import type {
|
||||
GroupStat,
|
||||
ApiKeyUsageTrendPoint,
|
||||
UserUsageTrendPoint,
|
||||
UserSpendingRankingResponse,
|
||||
UsageRequestType
|
||||
} from '@/types'
|
||||
|
||||
@@ -201,6 +202,11 @@ export interface UserTrendResponse {
|
||||
granularity: string
|
||||
}
|
||||
|
||||
export interface UserSpendingRankingParams
|
||||
extends Pick<TrendParams, 'start_date' | 'end_date'> {
|
||||
limit?: number
|
||||
}
|
||||
|
||||
/**
|
||||
* Get user usage trend data
|
||||
* @param params - Query parameters for filtering
|
||||
@@ -213,6 +219,20 @@ export async function getUserUsageTrend(params?: UserTrendParams): Promise<UserT
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
* Get user spending ranking data
|
||||
* @param params - Query parameters for filtering
|
||||
* @returns User spending ranking data
|
||||
*/
|
||||
export async function getUserSpendingRanking(
|
||||
params?: UserSpendingRankingParams
|
||||
): Promise<UserSpendingRankingResponse> {
|
||||
const { data } = await apiClient.get<UserSpendingRankingResponse>('/admin/dashboard/users-ranking', {
|
||||
params
|
||||
})
|
||||
return data
|
||||
}
|
||||
|
||||
export interface BatchUserUsageStats {
|
||||
user_id: number
|
||||
today_actual_cost: number
|
||||
@@ -271,6 +291,7 @@ export const dashboardAPI = {
|
||||
getSnapshotV2,
|
||||
getApiKeyUsageTrend,
|
||||
getUserUsageTrend,
|
||||
getUserSpendingRanking,
|
||||
getBatchUsersUsage,
|
||||
getBatchApiKeysUsage
|
||||
}
|
||||
|
||||
@@ -232,7 +232,7 @@
|
||||
<!-- Account Type Selection (Anthropic) -->
|
||||
<div v-if="form.platform === 'anthropic'">
|
||||
<label class="input-label">{{ t('admin.accounts.accountType') }}</label>
|
||||
<div class="mt-2 grid grid-cols-2 gap-3" data-tour="account-form-type">
|
||||
<div class="mt-2 grid grid-cols-3 gap-3" data-tour="account-form-type">
|
||||
<button
|
||||
type="button"
|
||||
@click="accountCategory = 'oauth-based'"
|
||||
@@ -292,6 +292,66 @@
|
||||
}}</span>
|
||||
</div>
|
||||
</button>
|
||||
|
||||
<button
|
||||
type="button"
|
||||
@click="accountCategory = 'bedrock'"
|
||||
:class="[
|
||||
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
|
||||
accountCategory === 'bedrock'
|
||||
? 'border-amber-500 bg-amber-50 dark:bg-amber-900/20'
|
||||
: 'border-gray-200 hover:border-amber-300 dark:border-dark-600 dark:hover:border-amber-700'
|
||||
]"
|
||||
>
|
||||
<div
|
||||
:class="[
|
||||
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
|
||||
accountCategory === 'bedrock'
|
||||
? 'bg-amber-500 text-white'
|
||||
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
|
||||
]"
|
||||
>
|
||||
<Icon name="cloud" size="sm" />
|
||||
</div>
|
||||
<div>
|
||||
<span class="block text-sm font-medium text-gray-900 dark:text-white">{{
|
||||
t('admin.accounts.bedrockLabel')
|
||||
}}</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">{{
|
||||
t('admin.accounts.bedrockDesc')
|
||||
}}</span>
|
||||
</div>
|
||||
</button>
|
||||
|
||||
<button
|
||||
type="button"
|
||||
@click="accountCategory = 'bedrock-apikey'"
|
||||
:class="[
|
||||
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
|
||||
accountCategory === 'bedrock-apikey'
|
||||
? 'border-amber-500 bg-amber-50 dark:bg-amber-900/20'
|
||||
: 'border-gray-200 hover:border-amber-300 dark:border-dark-600 dark:hover:border-amber-700'
|
||||
]"
|
||||
>
|
||||
<div
|
||||
:class="[
|
||||
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
|
||||
accountCategory === 'bedrock-apikey'
|
||||
? 'bg-amber-500 text-white'
|
||||
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
|
||||
]"
|
||||
>
|
||||
<Icon name="key" size="sm" />
|
||||
</div>
|
||||
<div>
|
||||
<span class="block text-sm font-medium text-gray-900 dark:text-white">{{
|
||||
t('admin.accounts.bedrockApiKeyLabel')
|
||||
}}</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">{{
|
||||
t('admin.accounts.bedrockApiKeyDesc')
|
||||
}}</span>
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -896,7 +956,7 @@
|
||||
</div>
|
||||
|
||||
<!-- API Key input (only for apikey type, excluding Antigravity which has its own fields) -->
|
||||
<div v-if="form.type === 'apikey' && form.platform !== 'antigravity'" class="space-y-4">
|
||||
<div v-if="form.type === 'apikey' && form.platform !== 'antigravity' && accountCategory !== 'bedrock-apikey'" class="space-y-4">
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.baseUrl') }}</label>
|
||||
<input
|
||||
@@ -1279,6 +1339,289 @@
|
||||
|
||||
</div>
|
||||
|
||||
<!-- Bedrock credentials (only for Anthropic Bedrock type) -->
|
||||
<div v-if="form.platform === 'anthropic' && accountCategory === 'bedrock'" class="space-y-4">
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockAccessKeyId') }}</label>
|
||||
<input
|
||||
v-model="bedrockAccessKeyId"
|
||||
type="text"
|
||||
required
|
||||
class="input font-mono"
|
||||
placeholder="AKIA..."
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockSecretAccessKey') }}</label>
|
||||
<input
|
||||
v-model="bedrockSecretAccessKey"
|
||||
type="password"
|
||||
required
|
||||
class="input font-mono"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockSessionToken') }}</label>
|
||||
<input
|
||||
v-model="bedrockSessionToken"
|
||||
type="password"
|
||||
class="input font-mono"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockSessionTokenHint') }}</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockRegion') }}</label>
|
||||
<select v-model="bedrockRegion" class="input">
|
||||
<optgroup label="US">
|
||||
<option value="us-east-1">us-east-1 (N. Virginia)</option>
|
||||
<option value="us-east-2">us-east-2 (Ohio)</option>
|
||||
<option value="us-west-1">us-west-1 (N. California)</option>
|
||||
<option value="us-west-2">us-west-2 (Oregon)</option>
|
||||
<option value="us-gov-east-1">us-gov-east-1 (GovCloud US-East)</option>
|
||||
<option value="us-gov-west-1">us-gov-west-1 (GovCloud US-West)</option>
|
||||
</optgroup>
|
||||
<optgroup label="Europe">
|
||||
<option value="eu-west-1">eu-west-1 (Ireland)</option>
|
||||
<option value="eu-west-2">eu-west-2 (London)</option>
|
||||
<option value="eu-west-3">eu-west-3 (Paris)</option>
|
||||
<option value="eu-central-1">eu-central-1 (Frankfurt)</option>
|
||||
<option value="eu-central-2">eu-central-2 (Zurich)</option>
|
||||
<option value="eu-south-1">eu-south-1 (Milan)</option>
|
||||
<option value="eu-south-2">eu-south-2 (Spain)</option>
|
||||
<option value="eu-north-1">eu-north-1 (Stockholm)</option>
|
||||
</optgroup>
|
||||
<optgroup label="Asia Pacific">
|
||||
<option value="ap-northeast-1">ap-northeast-1 (Tokyo)</option>
|
||||
<option value="ap-northeast-2">ap-northeast-2 (Seoul)</option>
|
||||
<option value="ap-northeast-3">ap-northeast-3 (Osaka)</option>
|
||||
<option value="ap-south-1">ap-south-1 (Mumbai)</option>
|
||||
<option value="ap-south-2">ap-south-2 (Hyderabad)</option>
|
||||
<option value="ap-southeast-1">ap-southeast-1 (Singapore)</option>
|
||||
<option value="ap-southeast-2">ap-southeast-2 (Sydney)</option>
|
||||
</optgroup>
|
||||
<optgroup label="Canada">
|
||||
<option value="ca-central-1">ca-central-1 (Canada)</option>
|
||||
</optgroup>
|
||||
<optgroup label="South America">
|
||||
<option value="sa-east-1">sa-east-1 (São Paulo)</option>
|
||||
</optgroup>
|
||||
</select>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockRegionHint') }}</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="flex items-center gap-2 cursor-pointer">
|
||||
<input
|
||||
v-model="bedrockForceGlobal"
|
||||
type="checkbox"
|
||||
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500 dark:border-dark-500"
|
||||
/>
|
||||
<span class="text-sm text-gray-700 dark:text-gray-300">{{ t('admin.accounts.bedrockForceGlobal') }}</span>
|
||||
</label>
|
||||
<p class="input-hint mt-1">{{ t('admin.accounts.bedrockForceGlobalHint') }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Model Restriction Section for Bedrock -->
|
||||
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
|
||||
|
||||
<!-- Mode Toggle -->
|
||||
<div class="mb-4 flex gap-2">
|
||||
<button
|
||||
type="button"
|
||||
@click="modelRestrictionMode = 'whitelist'"
|
||||
:class="[
|
||||
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
|
||||
modelRestrictionMode === 'whitelist'
|
||||
? 'bg-primary-100 text-primary-700 dark:bg-primary-900/30 dark:text-primary-400'
|
||||
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
|
||||
]"
|
||||
>
|
||||
{{ t('admin.accounts.modelWhitelist') }}
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
@click="modelRestrictionMode = 'mapping'"
|
||||
:class="[
|
||||
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
|
||||
modelRestrictionMode === 'mapping'
|
||||
? 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400'
|
||||
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
|
||||
]"
|
||||
>
|
||||
{{ t('admin.accounts.modelMapping') }}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Whitelist Mode -->
|
||||
<div v-if="modelRestrictionMode === 'whitelist'">
|
||||
<ModelWhitelistSelector v-model="allowedModels" platform="anthropic" />
|
||||
<p class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
|
||||
<span v-if="allowedModels.length === 0">{{ t('admin.accounts.supportsAllModels') }}</span>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Mapping Mode -->
|
||||
<div v-else class="space-y-3">
|
||||
<div v-for="(mapping, index) in modelMappings" :key="index" class="flex items-center gap-2">
|
||||
<input v-model="mapping.from" type="text" class="input flex-1" :placeholder="t('admin.accounts.fromModel')" />
|
||||
<span class="text-gray-400">→</span>
|
||||
<input v-model="mapping.to" type="text" class="input flex-1" :placeholder="t('admin.accounts.toModel')" />
|
||||
<button type="button" @click="modelMappings.splice(index, 1)" class="text-red-500 hover:text-red-700">
|
||||
<Icon name="trash" size="sm" />
|
||||
</button>
|
||||
</div>
|
||||
<button type="button" @click="modelMappings.push({ from: '', to: '' })" class="btn btn-secondary text-sm">
|
||||
+ {{ t('admin.accounts.addMapping') }}
|
||||
</button>
|
||||
<!-- Bedrock Preset Mappings -->
|
||||
<div class="flex flex-wrap gap-2">
|
||||
<button
|
||||
v-for="preset in bedrockPresets"
|
||||
:key="preset.from"
|
||||
type="button"
|
||||
@click="addPresetMapping(preset.from, preset.to)"
|
||||
:class="['rounded-lg px-3 py-1 text-xs transition-colors', preset.color]"
|
||||
>
|
||||
+ {{ preset.label }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Bedrock API Key credentials (only for Anthropic Bedrock API Key type) -->
|
||||
<div v-if="form.platform === 'anthropic' && accountCategory === 'bedrock-apikey'" class="space-y-4">
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockApiKeyInput') }}</label>
|
||||
<input
|
||||
v-model="bedrockApiKeyValue"
|
||||
type="password"
|
||||
required
|
||||
class="input font-mono"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockRegion') }}</label>
|
||||
<select v-model="bedrockApiKeyRegion" class="input">
|
||||
<optgroup label="US">
|
||||
<option value="us-east-1">us-east-1 (N. Virginia)</option>
|
||||
<option value="us-east-2">us-east-2 (Ohio)</option>
|
||||
<option value="us-west-1">us-west-1 (N. California)</option>
|
||||
<option value="us-west-2">us-west-2 (Oregon)</option>
|
||||
<option value="us-gov-east-1">us-gov-east-1 (GovCloud US-East)</option>
|
||||
<option value="us-gov-west-1">us-gov-west-1 (GovCloud US-West)</option>
|
||||
</optgroup>
|
||||
<optgroup label="Europe">
|
||||
<option value="eu-west-1">eu-west-1 (Ireland)</option>
|
||||
<option value="eu-west-2">eu-west-2 (London)</option>
|
||||
<option value="eu-west-3">eu-west-3 (Paris)</option>
|
||||
<option value="eu-central-1">eu-central-1 (Frankfurt)</option>
|
||||
<option value="eu-central-2">eu-central-2 (Zurich)</option>
|
||||
<option value="eu-south-1">eu-south-1 (Milan)</option>
|
||||
<option value="eu-south-2">eu-south-2 (Spain)</option>
|
||||
<option value="eu-north-1">eu-north-1 (Stockholm)</option>
|
||||
</optgroup>
|
||||
<optgroup label="Asia Pacific">
|
||||
<option value="ap-northeast-1">ap-northeast-1 (Tokyo)</option>
|
||||
<option value="ap-northeast-2">ap-northeast-2 (Seoul)</option>
|
||||
<option value="ap-northeast-3">ap-northeast-3 (Osaka)</option>
|
||||
<option value="ap-south-1">ap-south-1 (Mumbai)</option>
|
||||
<option value="ap-south-2">ap-south-2 (Hyderabad)</option>
|
||||
<option value="ap-southeast-1">ap-southeast-1 (Singapore)</option>
|
||||
<option value="ap-southeast-2">ap-southeast-2 (Sydney)</option>
|
||||
</optgroup>
|
||||
<optgroup label="Canada">
|
||||
<option value="ca-central-1">ca-central-1 (Canada)</option>
|
||||
</optgroup>
|
||||
<optgroup label="South America">
|
||||
<option value="sa-east-1">sa-east-1 (São Paulo)</option>
|
||||
</optgroup>
|
||||
</select>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockRegionHint') }}</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="flex items-center gap-2 cursor-pointer">
|
||||
<input
|
||||
v-model="bedrockApiKeyForceGlobal"
|
||||
type="checkbox"
|
||||
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500 dark:border-dark-500"
|
||||
/>
|
||||
<span class="text-sm text-gray-700 dark:text-gray-300">{{ t('admin.accounts.bedrockForceGlobal') }}</span>
|
||||
</label>
|
||||
<p class="input-hint mt-1">{{ t('admin.accounts.bedrockForceGlobalHint') }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Model Restriction Section for Bedrock API Key -->
|
||||
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
|
||||
|
||||
<!-- Mode Toggle -->
|
||||
<div class="mb-4 flex gap-2">
|
||||
<button
|
||||
type="button"
|
||||
@click="modelRestrictionMode = 'whitelist'"
|
||||
:class="[
|
||||
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
|
||||
modelRestrictionMode === 'whitelist'
|
||||
? 'bg-primary-100 text-primary-700 dark:bg-primary-900/30 dark:text-primary-400'
|
||||
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
|
||||
]"
|
||||
>
|
||||
{{ t('admin.accounts.modelWhitelist') }}
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
@click="modelRestrictionMode = 'mapping'"
|
||||
:class="[
|
||||
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
|
||||
modelRestrictionMode === 'mapping'
|
||||
? 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400'
|
||||
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
|
||||
]"
|
||||
>
|
||||
{{ t('admin.accounts.modelMapping') }}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Whitelist Mode -->
|
||||
<div v-if="modelRestrictionMode === 'whitelist'">
|
||||
<ModelWhitelistSelector v-model="allowedModels" platform="anthropic" />
|
||||
<p class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
|
||||
<span v-if="allowedModels.length === 0">{{ t('admin.accounts.supportsAllModels') }}</span>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Mapping Mode -->
|
||||
<div v-else class="space-y-3">
|
||||
<div v-for="(mapping, index) in modelMappings" :key="index" class="flex items-center gap-2">
|
||||
<input v-model="mapping.from" type="text" class="input flex-1" :placeholder="t('admin.accounts.fromModel')" />
|
||||
<span class="text-gray-400">→</span>
|
||||
<input v-model="mapping.to" type="text" class="input flex-1" :placeholder="t('admin.accounts.toModel')" />
|
||||
<button type="button" @click="modelMappings.splice(index, 1)" class="text-red-500 hover:text-red-700">
|
||||
<Icon name="trash" size="sm" />
|
||||
</button>
|
||||
</div>
|
||||
<button type="button" @click="modelMappings.push({ from: '', to: '' })" class="btn btn-secondary text-sm">
|
||||
+ {{ t('admin.accounts.addMapping') }}
|
||||
</button>
|
||||
<!-- Bedrock Preset Mappings -->
|
||||
<div class="flex flex-wrap gap-2">
|
||||
<button
|
||||
v-for="preset in bedrockPresets"
|
||||
:key="preset.from"
|
||||
type="button"
|
||||
@click="addPresetMapping(preset.from, preset.to)"
|
||||
:class="['rounded-lg px-3 py-1 text-xs transition-colors', preset.color]"
|
||||
>
|
||||
+ {{ preset.label }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- API Key 账号配额限制 -->
|
||||
<div v-if="form.type === 'apikey'" class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-4">
|
||||
<div class="mb-3">
|
||||
@@ -2671,7 +3014,7 @@ interface TempUnschedRuleForm {
|
||||
// State
|
||||
const step = ref(1)
|
||||
const submitting = ref(false)
|
||||
const accountCategory = ref<'oauth-based' | 'apikey'>('oauth-based') // UI selection for account category
|
||||
const accountCategory = ref<'oauth-based' | 'apikey' | 'bedrock' | 'bedrock-apikey'>('oauth-based') // UI selection for account category
|
||||
const addMethod = ref<AddMethod>('oauth') // For oauth-based: 'oauth' or 'setup-token'
|
||||
const apiKeyBaseUrl = ref('https://api.anthropic.com')
|
||||
const apiKeyValue = ref('')
|
||||
@@ -2704,6 +3047,19 @@ const antigravityModelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist'
|
||||
const antigravityWhitelistModels = ref<string[]>([])
|
||||
const antigravityModelMappings = ref<ModelMapping[]>([])
|
||||
const antigravityPresetMappings = computed(() => getPresetMappingsByPlatform('antigravity'))
|
||||
const bedrockPresets = computed(() => getPresetMappingsByPlatform('bedrock'))
|
||||
|
||||
// Bedrock credentials
|
||||
const bedrockAccessKeyId = ref('')
|
||||
const bedrockSecretAccessKey = ref('')
|
||||
const bedrockSessionToken = ref('')
|
||||
const bedrockRegion = ref('us-east-1')
|
||||
const bedrockForceGlobal = ref(false)
|
||||
|
||||
// Bedrock API Key credentials
|
||||
const bedrockApiKeyValue = ref('')
|
||||
const bedrockApiKeyRegion = ref('us-east-1')
|
||||
const bedrockApiKeyForceGlobal = ref(false)
|
||||
const tempUnschedEnabled = ref(false)
|
||||
const tempUnschedRules = ref<TempUnschedRuleForm[]>([])
|
||||
const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('create-model-mapping')
|
||||
@@ -2868,6 +3224,10 @@ const isOAuthFlow = computed(() => {
|
||||
if (form.platform === 'antigravity' && antigravityAccountType.value === 'upstream') {
|
||||
return false
|
||||
}
|
||||
// Bedrock 类型不需要 OAuth 流程
|
||||
if (form.platform === 'anthropic' && accountCategory.value === 'bedrock') {
|
||||
return false
|
||||
}
|
||||
return accountCategory.value === 'oauth-based'
|
||||
})
|
||||
|
||||
@@ -2935,6 +3295,11 @@ watch(
|
||||
form.type = 'apikey'
|
||||
return
|
||||
}
|
||||
// Bedrock 类型
|
||||
if (form.platform === 'anthropic' && category === 'bedrock') {
|
||||
form.type = 'bedrock' as AccountType
|
||||
return
|
||||
}
|
||||
if (category === 'oauth-based') {
|
||||
form.type = method as AccountType // 'oauth' or 'setup-token'
|
||||
} else {
|
||||
@@ -2972,6 +3337,13 @@ watch(
|
||||
antigravityModelMappings.value = []
|
||||
antigravityModelRestrictionMode.value = 'mapping'
|
||||
}
|
||||
// Reset Bedrock fields when switching platforms
|
||||
bedrockAccessKeyId.value = ''
|
||||
bedrockSecretAccessKey.value = ''
|
||||
bedrockSessionToken.value = ''
|
||||
bedrockRegion.value = 'us-east-1'
|
||||
bedrockForceGlobal.value = false
|
||||
bedrockApiKeyForceGlobal.value = false
|
||||
// Reset Anthropic/Antigravity-specific settings when switching to other platforms
|
||||
if (newPlatform !== 'anthropic' && newPlatform !== 'antigravity') {
|
||||
interceptWarmupRequests.value = false
|
||||
@@ -3541,6 +3913,84 @@ const handleSubmit = async () => {
|
||||
return
|
||||
}
|
||||
|
||||
// For Bedrock type, create directly
|
||||
if (form.platform === 'anthropic' && accountCategory.value === 'bedrock') {
|
||||
if (!form.name.trim()) {
|
||||
appStore.showError(t('admin.accounts.pleaseEnterAccountName'))
|
||||
return
|
||||
}
|
||||
if (!bedrockAccessKeyId.value.trim()) {
|
||||
appStore.showError(t('admin.accounts.bedrockAccessKeyIdRequired'))
|
||||
return
|
||||
}
|
||||
if (!bedrockSecretAccessKey.value.trim()) {
|
||||
appStore.showError(t('admin.accounts.bedrockSecretAccessKeyRequired'))
|
||||
return
|
||||
}
|
||||
if (!bedrockRegion.value.trim()) {
|
||||
appStore.showError(t('admin.accounts.bedrockRegionRequired'))
|
||||
return
|
||||
}
|
||||
|
||||
const credentials: Record<string, unknown> = {
|
||||
aws_access_key_id: bedrockAccessKeyId.value.trim(),
|
||||
aws_secret_access_key: bedrockSecretAccessKey.value.trim(),
|
||||
aws_region: bedrockRegion.value.trim(),
|
||||
}
|
||||
if (bedrockSessionToken.value.trim()) {
|
||||
credentials.aws_session_token = bedrockSessionToken.value.trim()
|
||||
}
|
||||
if (bedrockForceGlobal.value) {
|
||||
credentials.aws_force_global = 'true'
|
||||
}
|
||||
|
||||
// Model mapping
|
||||
const modelMapping = buildModelMappingObject(
|
||||
modelRestrictionMode.value, allowedModels.value, modelMappings.value
|
||||
)
|
||||
if (modelMapping) {
|
||||
credentials.model_mapping = modelMapping
|
||||
}
|
||||
|
||||
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
|
||||
|
||||
await createAccountAndFinish('anthropic', 'bedrock' as AccountType, credentials)
|
||||
return
|
||||
}
|
||||
|
||||
// For Bedrock API Key type, create directly
|
||||
if (form.platform === 'anthropic' && accountCategory.value === 'bedrock-apikey') {
|
||||
if (!form.name.trim()) {
|
||||
appStore.showError(t('admin.accounts.pleaseEnterAccountName'))
|
||||
return
|
||||
}
|
||||
if (!bedrockApiKeyValue.value.trim()) {
|
||||
appStore.showError(t('admin.accounts.bedrockApiKeyRequired'))
|
||||
return
|
||||
}
|
||||
|
||||
const credentials: Record<string, unknown> = {
|
||||
api_key: bedrockApiKeyValue.value.trim(),
|
||||
aws_region: bedrockApiKeyRegion.value.trim() || 'us-east-1',
|
||||
}
|
||||
if (bedrockApiKeyForceGlobal.value) {
|
||||
credentials.aws_force_global = 'true'
|
||||
}
|
||||
|
||||
// Model mapping
|
||||
const modelMapping = buildModelMappingObject(
|
||||
modelRestrictionMode.value, allowedModels.value, modelMappings.value
|
||||
)
|
||||
if (modelMapping) {
|
||||
credentials.model_mapping = modelMapping
|
||||
}
|
||||
|
||||
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
|
||||
|
||||
await createAccountAndFinish('anthropic', 'bedrock-apikey' as AccountType, credentials)
|
||||
return
|
||||
}
|
||||
|
||||
// For Antigravity upstream type, create directly
|
||||
if (form.platform === 'antigravity' && antigravityAccountType.value === 'upstream') {
|
||||
if (!form.name.trim()) {
|
||||
|
||||
@@ -563,6 +563,233 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Bedrock fields (only for bedrock type) -->
|
||||
<div v-if="account.type === 'bedrock'" class="space-y-4">
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockAccessKeyId') }}</label>
|
||||
<input
|
||||
v-model="editBedrockAccessKeyId"
|
||||
type="text"
|
||||
class="input font-mono"
|
||||
placeholder="AKIA..."
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockSecretAccessKey') }}</label>
|
||||
<input
|
||||
v-model="editBedrockSecretAccessKey"
|
||||
type="password"
|
||||
class="input font-mono"
|
||||
:placeholder="t('admin.accounts.bedrockSecretKeyLeaveEmpty')"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockSecretKeyLeaveEmpty') }}</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockSessionToken') }}</label>
|
||||
<input
|
||||
v-model="editBedrockSessionToken"
|
||||
type="password"
|
||||
class="input font-mono"
|
||||
:placeholder="t('admin.accounts.bedrockSecretKeyLeaveEmpty')"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockSessionTokenHint') }}</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockRegion') }}</label>
|
||||
<input
|
||||
v-model="editBedrockRegion"
|
||||
type="text"
|
||||
class="input"
|
||||
placeholder="us-east-1"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockRegionHint') }}</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="flex items-center gap-2 cursor-pointer">
|
||||
<input
|
||||
v-model="editBedrockForceGlobal"
|
||||
type="checkbox"
|
||||
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500 dark:border-dark-500"
|
||||
/>
|
||||
<span class="text-sm text-gray-700 dark:text-gray-300">{{ t('admin.accounts.bedrockForceGlobal') }}</span>
|
||||
</label>
|
||||
<p class="input-hint mt-1">{{ t('admin.accounts.bedrockForceGlobalHint') }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Model Restriction for Bedrock -->
|
||||
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
|
||||
|
||||
<!-- Mode Toggle -->
|
||||
<div class="mb-4 flex gap-2">
|
||||
<button
|
||||
type="button"
|
||||
@click="modelRestrictionMode = 'whitelist'"
|
||||
:class="[
|
||||
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
|
||||
modelRestrictionMode === 'whitelist'
|
||||
? 'bg-primary-100 text-primary-700 dark:bg-primary-900/30 dark:text-primary-400'
|
||||
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
|
||||
]"
|
||||
>
|
||||
{{ t('admin.accounts.modelWhitelist') }}
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
@click="modelRestrictionMode = 'mapping'"
|
||||
:class="[
|
||||
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
|
||||
modelRestrictionMode === 'mapping'
|
||||
? 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400'
|
||||
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
|
||||
]"
|
||||
>
|
||||
{{ t('admin.accounts.modelMapping') }}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Whitelist Mode -->
|
||||
<div v-if="modelRestrictionMode === 'whitelist'">
|
||||
<ModelWhitelistSelector v-model="allowedModels" platform="anthropic" />
|
||||
<p class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
|
||||
<span v-if="allowedModels.length === 0">{{ t('admin.accounts.supportsAllModels') }}</span>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Mapping Mode -->
|
||||
<div v-else class="space-y-3">
|
||||
<div v-for="(mapping, index) in modelMappings" :key="getModelMappingKey(mapping)" class="flex items-center gap-2">
|
||||
<input v-model="mapping.from" type="text" class="input flex-1" :placeholder="t('admin.accounts.fromModel')" />
|
||||
<span class="text-gray-400">→</span>
|
||||
<input v-model="mapping.to" type="text" class="input flex-1" :placeholder="t('admin.accounts.toModel')" />
|
||||
<button type="button" @click="modelMappings.splice(index, 1)" class="text-red-500 hover:text-red-700">
|
||||
<Icon name="trash" size="sm" />
|
||||
</button>
|
||||
</div>
|
||||
<button type="button" @click="modelMappings.push({ from: '', to: '' })" class="btn btn-secondary text-sm">
|
||||
+ {{ t('admin.accounts.addMapping') }}
|
||||
</button>
|
||||
<!-- Bedrock Preset Mappings -->
|
||||
<div class="flex flex-wrap gap-2">
|
||||
<button
|
||||
v-for="preset in bedrockPresets"
|
||||
:key="preset.from"
|
||||
type="button"
|
||||
@click="modelMappings.push({ from: preset.from, to: preset.to })"
|
||||
:class="['rounded-lg px-3 py-1 text-xs transition-colors', preset.color]"
|
||||
>
|
||||
+ {{ preset.label }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Bedrock API Key fields (only for bedrock-apikey type) -->
|
||||
<div v-if="account.type === 'bedrock-apikey'" class="space-y-4">
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockApiKeyInput') }}</label>
|
||||
<input
|
||||
v-model="editBedrockApiKeyValue"
|
||||
type="password"
|
||||
class="input font-mono"
|
||||
:placeholder="t('admin.accounts.bedrockApiKeyLeaveEmpty')"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockApiKeyLeaveEmpty') }}</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockRegion') }}</label>
|
||||
<input
|
||||
v-model="editBedrockApiKeyRegion"
|
||||
type="text"
|
||||
class="input"
|
||||
placeholder="us-east-1"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockRegionHint') }}</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="flex items-center gap-2 cursor-pointer">
|
||||
<input
|
||||
v-model="editBedrockApiKeyForceGlobal"
|
||||
type="checkbox"
|
||||
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500 dark:border-dark-500"
|
||||
/>
|
||||
<span class="text-sm text-gray-700 dark:text-gray-300">{{ t('admin.accounts.bedrockForceGlobal') }}</span>
|
||||
</label>
|
||||
<p class="input-hint mt-1">{{ t('admin.accounts.bedrockForceGlobalHint') }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Model Restriction for Bedrock API Key -->
|
||||
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
|
||||
|
||||
<!-- Mode Toggle -->
|
||||
<div class="mb-4 flex gap-2">
|
||||
<button
|
||||
type="button"
|
||||
@click="modelRestrictionMode = 'whitelist'"
|
||||
:class="[
|
||||
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
|
||||
modelRestrictionMode === 'whitelist'
|
||||
? 'bg-primary-100 text-primary-700 dark:bg-primary-900/30 dark:text-primary-400'
|
||||
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
|
||||
]"
|
||||
>
|
||||
{{ t('admin.accounts.modelWhitelist') }}
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
@click="modelRestrictionMode = 'mapping'"
|
||||
:class="[
|
||||
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
|
||||
modelRestrictionMode === 'mapping'
|
||||
? 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400'
|
||||
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
|
||||
]"
|
||||
>
|
||||
{{ t('admin.accounts.modelMapping') }}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Whitelist Mode -->
|
||||
<div v-if="modelRestrictionMode === 'whitelist'">
|
||||
<ModelWhitelistSelector v-model="allowedModels" platform="anthropic" />
|
||||
<p class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
|
||||
<span v-if="allowedModels.length === 0">{{ t('admin.accounts.supportsAllModels') }}</span>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Mapping Mode -->
|
||||
<div v-else class="space-y-3">
|
||||
<div v-for="(mapping, index) in modelMappings" :key="getModelMappingKey(mapping)" class="flex items-center gap-2">
|
||||
<input v-model="mapping.from" type="text" class="input flex-1" :placeholder="t('admin.accounts.fromModel')" />
|
||||
<span class="text-gray-400">→</span>
|
||||
<input v-model="mapping.to" type="text" class="input flex-1" :placeholder="t('admin.accounts.toModel')" />
|
||||
<button type="button" @click="modelMappings.splice(index, 1)" class="text-red-500 hover:text-red-700">
|
||||
<Icon name="trash" size="sm" />
|
||||
</button>
|
||||
</div>
|
||||
<button type="button" @click="modelMappings.push({ from: '', to: '' })" class="btn btn-secondary text-sm">
|
||||
+ {{ t('admin.accounts.addMapping') }}
|
||||
</button>
|
||||
<!-- Bedrock Preset Mappings -->
|
||||
<div class="flex flex-wrap gap-2">
|
||||
<button
|
||||
v-for="preset in bedrockPresets"
|
||||
:key="preset.from"
|
||||
type="button"
|
||||
@click="modelMappings.push({ from: preset.from, to: preset.to })"
|
||||
:class="['rounded-lg px-3 py-1 text-xs transition-colors', preset.color]"
|
||||
>
|
||||
+ {{ preset.label }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Antigravity model restriction (applies to all antigravity types) -->
|
||||
<!-- Antigravity 只支持模型映射模式,不支持白名单模式 -->
|
||||
<div v-if="account.platform === 'antigravity'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
@@ -1529,6 +1756,7 @@ const baseUrlHint = computed(() => {
|
||||
})
|
||||
|
||||
const antigravityPresetMappings = computed(() => getPresetMappingsByPlatform('antigravity'))
|
||||
const bedrockPresets = computed(() => getPresetMappingsByPlatform('bedrock'))
|
||||
|
||||
// Model mapping type
|
||||
interface ModelMapping {
|
||||
@@ -1547,6 +1775,17 @@ interface TempUnschedRuleForm {
|
||||
const submitting = ref(false)
|
||||
const editBaseUrl = ref('https://api.anthropic.com')
|
||||
const editApiKey = ref('')
|
||||
// Bedrock credentials
|
||||
const editBedrockAccessKeyId = ref('')
|
||||
const editBedrockSecretAccessKey = ref('')
|
||||
const editBedrockSessionToken = ref('')
|
||||
const editBedrockRegion = ref('')
|
||||
const editBedrockForceGlobal = ref(false)
|
||||
|
||||
// Bedrock API Key credentials
|
||||
const editBedrockApiKeyValue = ref('')
|
||||
const editBedrockApiKeyRegion = ref('')
|
||||
const editBedrockApiKeyForceGlobal = ref(false)
|
||||
const modelMappings = ref<ModelMapping[]>([])
|
||||
const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
|
||||
const allowedModels = ref<string[]>([])
|
||||
@@ -1889,6 +2128,58 @@ watch(
|
||||
} else {
|
||||
selectedErrorCodes.value = []
|
||||
}
|
||||
} else if (newAccount.type === 'bedrock' && newAccount.credentials) {
|
||||
const bedrockCreds = newAccount.credentials as Record<string, unknown>
|
||||
editBedrockAccessKeyId.value = (bedrockCreds.aws_access_key_id as string) || ''
|
||||
editBedrockRegion.value = (bedrockCreds.aws_region as string) || ''
|
||||
editBedrockForceGlobal.value = (bedrockCreds.aws_force_global as string) === 'true'
|
||||
editBedrockSecretAccessKey.value = ''
|
||||
editBedrockSessionToken.value = ''
|
||||
|
||||
// Load model mappings for bedrock
|
||||
const existingMappings = bedrockCreds.model_mapping as Record<string, string> | undefined
|
||||
if (existingMappings && typeof existingMappings === 'object') {
|
||||
const entries = Object.entries(existingMappings)
|
||||
const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to)
|
||||
if (isWhitelistMode) {
|
||||
modelRestrictionMode.value = 'whitelist'
|
||||
allowedModels.value = entries.map(([from]) => from)
|
||||
modelMappings.value = []
|
||||
} else {
|
||||
modelRestrictionMode.value = 'mapping'
|
||||
modelMappings.value = entries.map(([from, to]) => ({ from, to }))
|
||||
allowedModels.value = []
|
||||
}
|
||||
} else {
|
||||
modelRestrictionMode.value = 'whitelist'
|
||||
modelMappings.value = []
|
||||
allowedModels.value = []
|
||||
}
|
||||
} else if (newAccount.type === 'bedrock-apikey' && newAccount.credentials) {
|
||||
const bedrockApiKeyCreds = newAccount.credentials as Record<string, unknown>
|
||||
editBedrockApiKeyRegion.value = (bedrockApiKeyCreds.aws_region as string) || 'us-east-1'
|
||||
editBedrockApiKeyForceGlobal.value = (bedrockApiKeyCreds.aws_force_global as string) === 'true'
|
||||
editBedrockApiKeyValue.value = ''
|
||||
|
||||
// Load model mappings for bedrock-apikey
|
||||
const existingMappings = bedrockApiKeyCreds.model_mapping as Record<string, string> | undefined
|
||||
if (existingMappings && typeof existingMappings === 'object') {
|
||||
const entries = Object.entries(existingMappings)
|
||||
const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to)
|
||||
if (isWhitelistMode) {
|
||||
modelRestrictionMode.value = 'whitelist'
|
||||
allowedModels.value = entries.map(([from]) => from)
|
||||
modelMappings.value = []
|
||||
} else {
|
||||
modelRestrictionMode.value = 'mapping'
|
||||
modelMappings.value = entries.map(([from, to]) => ({ from, to }))
|
||||
allowedModels.value = []
|
||||
}
|
||||
} else {
|
||||
modelRestrictionMode.value = 'whitelist'
|
||||
modelMappings.value = []
|
||||
allowedModels.value = []
|
||||
}
|
||||
} else if (newAccount.type === 'upstream' && newAccount.credentials) {
|
||||
const credentials = newAccount.credentials as Record<string, unknown>
|
||||
editBaseUrl.value = (credentials.base_url as string) || ''
|
||||
@@ -2431,6 +2722,70 @@ const handleSubmit = async () => {
|
||||
return
|
||||
}
|
||||
|
||||
updatePayload.credentials = newCredentials
|
||||
} else if (props.account.type === 'bedrock') {
|
||||
const currentCredentials = (props.account.credentials as Record<string, unknown>) || {}
|
||||
const newCredentials: Record<string, unknown> = { ...currentCredentials }
|
||||
|
||||
newCredentials.aws_access_key_id = editBedrockAccessKeyId.value.trim()
|
||||
newCredentials.aws_region = editBedrockRegion.value.trim()
|
||||
if (editBedrockForceGlobal.value) {
|
||||
newCredentials.aws_force_global = 'true'
|
||||
} else {
|
||||
delete newCredentials.aws_force_global
|
||||
}
|
||||
|
||||
// Only update secrets if user provided new values
|
||||
if (editBedrockSecretAccessKey.value.trim()) {
|
||||
newCredentials.aws_secret_access_key = editBedrockSecretAccessKey.value.trim()
|
||||
}
|
||||
if (editBedrockSessionToken.value.trim()) {
|
||||
newCredentials.aws_session_token = editBedrockSessionToken.value.trim()
|
||||
}
|
||||
|
||||
// Model mapping
|
||||
const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value)
|
||||
if (modelMapping) {
|
||||
newCredentials.model_mapping = modelMapping
|
||||
} else {
|
||||
delete newCredentials.model_mapping
|
||||
}
|
||||
|
||||
applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit')
|
||||
if (!applyTempUnschedConfig(newCredentials)) {
|
||||
return
|
||||
}
|
||||
|
||||
updatePayload.credentials = newCredentials
|
||||
} else if (props.account.type === 'bedrock-apikey') {
|
||||
const currentCredentials = (props.account.credentials as Record<string, unknown>) || {}
|
||||
const newCredentials: Record<string, unknown> = { ...currentCredentials }
|
||||
|
||||
newCredentials.aws_region = editBedrockApiKeyRegion.value.trim() || 'us-east-1'
|
||||
if (editBedrockApiKeyForceGlobal.value) {
|
||||
newCredentials.aws_force_global = 'true'
|
||||
} else {
|
||||
delete newCredentials.aws_force_global
|
||||
}
|
||||
|
||||
// Only update API key if user provided new value
|
||||
if (editBedrockApiKeyValue.value.trim()) {
|
||||
newCredentials.api_key = editBedrockApiKeyValue.value.trim()
|
||||
}
|
||||
|
||||
// Model mapping
|
||||
const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value)
|
||||
if (modelMapping) {
|
||||
newCredentials.model_mapping = modelMapping
|
||||
} else {
|
||||
delete newCredentials.model_mapping
|
||||
}
|
||||
|
||||
applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit')
|
||||
if (!applyTempUnschedConfig(newCredentials)) {
|
||||
return
|
||||
}
|
||||
|
||||
updatePayload.credentials = newCredentials
|
||||
} else {
|
||||
// For oauth/setup-token types, only update intercept_warmup_requests if changed
|
||||
|
||||
@@ -2,38 +2,72 @@
|
||||
<div class="card p-4">
|
||||
<div class="mb-4 flex items-center justify-between gap-3">
|
||||
<h3 class="text-sm font-semibold text-gray-900 dark:text-white">
|
||||
{{ t('admin.dashboard.modelDistribution') }}
|
||||
{{ !enableRankingView || activeView === 'model_distribution'
|
||||
? t('admin.dashboard.modelDistribution')
|
||||
: t('admin.dashboard.spendingRankingTitle') }}
|
||||
</h3>
|
||||
<div
|
||||
v-if="showMetricToggle"
|
||||
class="inline-flex rounded-lg border border-gray-200 bg-gray-50 p-0.5 dark:border-gray-700 dark:bg-dark-800"
|
||||
>
|
||||
<button
|
||||
type="button"
|
||||
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
|
||||
:class="metric === 'tokens'
|
||||
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
|
||||
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'"
|
||||
@click="emit('update:metric', 'tokens')"
|
||||
<div class="flex items-center gap-2">
|
||||
<div
|
||||
v-if="showMetricToggle"
|
||||
class="inline-flex rounded-lg border border-gray-200 bg-gray-50 p-0.5 dark:border-gray-700 dark:bg-dark-800"
|
||||
>
|
||||
{{ t('admin.dashboard.metricTokens') }}
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
|
||||
:class="metric === 'actual_cost'
|
||||
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
|
||||
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'"
|
||||
@click="emit('update:metric', 'actual_cost')"
|
||||
>
|
||||
{{ t('admin.dashboard.metricActualCost') }}
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
|
||||
:class="metric === 'tokens'
|
||||
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
|
||||
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'"
|
||||
@click="emit('update:metric', 'tokens')"
|
||||
>
|
||||
{{ t('admin.dashboard.metricTokens') }}
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
|
||||
:class="metric === 'actual_cost'
|
||||
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
|
||||
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'"
|
||||
@click="emit('update:metric', 'actual_cost')"
|
||||
>
|
||||
{{ t('admin.dashboard.metricActualCost') }}
|
||||
</button>
|
||||
</div>
|
||||
<div v-if="enableRankingView" class="inline-flex rounded-lg bg-gray-100 p-1 dark:bg-dark-800">
|
||||
<button
|
||||
type="button"
|
||||
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
|
||||
:class="
|
||||
activeView === 'model_distribution'
|
||||
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
|
||||
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'
|
||||
"
|
||||
@click="activeView = 'model_distribution'"
|
||||
>
|
||||
{{ t('admin.dashboard.viewModelDistribution') }}
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
|
||||
:class="
|
||||
activeView === 'spending_ranking'
|
||||
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
|
||||
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'
|
||||
"
|
||||
@click="activeView = 'spending_ranking'"
|
||||
>
|
||||
{{ t('admin.dashboard.viewSpendingRanking') }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div v-if="loading" class="flex h-48 items-center justify-center">
|
||||
|
||||
<div v-if="activeView === 'model_distribution' && loading" class="flex h-48 items-center justify-center">
|
||||
<LoadingSpinner />
|
||||
</div>
|
||||
<div v-else-if="displayModelStats.length > 0 && chartData" class="flex items-center gap-6">
|
||||
<div
|
||||
v-else-if="activeView === 'model_distribution' && displayModelStats.length > 0 && chartData"
|
||||
class="flex items-center gap-6"
|
||||
>
|
||||
<div class="h-48 w-48">
|
||||
<Doughnut :data="chartData" :options="doughnutOptions" />
|
||||
</div>
|
||||
@@ -77,6 +111,70 @@
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
v-else-if="activeView === 'model_distribution'"
|
||||
class="flex h-48 items-center justify-center text-sm text-gray-500 dark:text-gray-400"
|
||||
>
|
||||
{{ t('admin.dashboard.noDataAvailable') }}
|
||||
</div>
|
||||
|
||||
<div v-else-if="rankingLoading" class="flex h-48 items-center justify-center">
|
||||
<LoadingSpinner />
|
||||
</div>
|
||||
<div
|
||||
v-else-if="rankingError"
|
||||
class="flex h-48 items-center justify-center text-sm text-gray-500 dark:text-gray-400"
|
||||
>
|
||||
{{ t('admin.dashboard.failedToLoad') }}
|
||||
</div>
|
||||
<div v-else-if="rankingItems.length > 0 && rankingChartData" class="flex items-center gap-6">
|
||||
<div class="h-48 w-48">
|
||||
<Doughnut :data="rankingChartData" :options="rankingDoughnutOptions" />
|
||||
</div>
|
||||
<div class="max-h-48 flex-1 overflow-y-auto">
|
||||
<table class="w-full text-xs">
|
||||
<thead>
|
||||
<tr class="text-gray-500 dark:text-gray-400">
|
||||
<th class="pb-2 text-left">{{ t('admin.dashboard.spendingRankingUser') }}</th>
|
||||
<th class="pb-2 text-right">{{ t('admin.dashboard.spendingRankingRequests') }}</th>
|
||||
<th class="pb-2 text-right">{{ t('admin.dashboard.spendingRankingTokens') }}</th>
|
||||
<th class="pb-2 text-right">{{ t('admin.dashboard.spendingRankingSpend') }}</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr
|
||||
v-for="(item, index) in rankingItems"
|
||||
:key="`${item.user_id}-${index}`"
|
||||
class="cursor-pointer border-t border-gray-100 transition-colors hover:bg-gray-50 dark:border-gray-700 dark:hover:bg-dark-700/40"
|
||||
@click="emit('ranking-click', item)"
|
||||
>
|
||||
<td class="py-1.5">
|
||||
<div class="flex min-w-0 items-center gap-2">
|
||||
<span class="shrink-0 text-[11px] font-semibold text-gray-500 dark:text-gray-400">
|
||||
#{{ index + 1 }}
|
||||
</span>
|
||||
<span
|
||||
class="block max-w-[140px] truncate font-medium text-gray-900 dark:text-white"
|
||||
:title="getRankingUserLabel(item)"
|
||||
>
|
||||
{{ getRankingUserLabel(item) }}
|
||||
</span>
|
||||
</div>
|
||||
</td>
|
||||
<td class="py-1.5 text-right text-gray-600 dark:text-gray-400">
|
||||
{{ formatNumber(item.requests) }}
|
||||
</td>
|
||||
<td class="py-1.5 text-right text-gray-600 dark:text-gray-400">
|
||||
{{ formatTokens(item.tokens) }}
|
||||
</td>
|
||||
<td class="py-1.5 text-right text-green-600 dark:text-green-400">
|
||||
${{ formatCost(item.actual_cost) }}
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
v-else
|
||||
class="flex h-48 items-center justify-center text-sm text-gray-500 dark:text-gray-400"
|
||||
@@ -87,34 +185,47 @@
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed } from 'vue'
|
||||
import { computed, ref } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { Chart as ChartJS, ArcElement, Tooltip, Legend } from 'chart.js'
|
||||
import { Doughnut } from 'vue-chartjs'
|
||||
import LoadingSpinner from '@/components/common/LoadingSpinner.vue'
|
||||
import type { ModelStat } from '@/types'
|
||||
import type { ModelStat, UserSpendingRankingItem } from '@/types'
|
||||
|
||||
ChartJS.register(ArcElement, Tooltip, Legend)
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
type DistributionMetric = 'tokens' | 'actual_cost'
|
||||
|
||||
const props = withDefaults(defineProps<{
|
||||
modelStats: ModelStat[]
|
||||
enableRankingView?: boolean
|
||||
rankingItems?: UserSpendingRankingItem[]
|
||||
rankingTotalActualCost?: number
|
||||
loading?: boolean
|
||||
metric?: DistributionMetric
|
||||
showMetricToggle?: boolean
|
||||
rankingLoading?: boolean
|
||||
rankingError?: boolean
|
||||
}>(), {
|
||||
enableRankingView: false,
|
||||
rankingItems: () => [],
|
||||
rankingTotalActualCost: 0,
|
||||
loading: false,
|
||||
metric: 'tokens',
|
||||
showMetricToggle: false,
|
||||
rankingLoading: false,
|
||||
rankingError: false
|
||||
})
|
||||
|
||||
const emit = defineEmits<{
|
||||
'update:metric': [value: DistributionMetric]
|
||||
'ranking-click': [item: UserSpendingRankingItem]
|
||||
}>()
|
||||
|
||||
const enableRankingView = computed(() => props.enableRankingView)
|
||||
const activeView = ref<'model_distribution' | 'spending_ranking'>('model_distribution')
|
||||
|
||||
const chartColors = [
|
||||
'#3b82f6',
|
||||
'#10b981',
|
||||
@@ -125,7 +236,9 @@ const chartColors = [
|
||||
'#14b8a6',
|
||||
'#f97316',
|
||||
'#6366f1',
|
||||
'#84cc16'
|
||||
'#84cc16',
|
||||
'#06b6d4',
|
||||
'#a855f7'
|
||||
]
|
||||
|
||||
const displayModelStats = computed(() => {
|
||||
@@ -150,6 +263,31 @@ const chartData = computed(() => {
|
||||
}
|
||||
})
|
||||
|
||||
const rankingChartData = computed(() => {
|
||||
if (!props.rankingItems?.length) return null
|
||||
|
||||
const rankedTotal = props.rankingItems.reduce((sum, item) => sum + item.actual_cost, 0)
|
||||
const otherActualCost = Math.max((props.rankingTotalActualCost || 0) - rankedTotal, 0)
|
||||
const labels = props.rankingItems.map((item, index) => `#${index + 1} ${getRankingUserLabel(item)}`)
|
||||
const data = props.rankingItems.map((item) => item.actual_cost)
|
||||
|
||||
if (otherActualCost > 0.000001) {
|
||||
labels.push(t('admin.dashboard.spendingRankingOther'))
|
||||
data.push(otherActualCost)
|
||||
}
|
||||
|
||||
return {
|
||||
labels,
|
||||
datasets: [
|
||||
{
|
||||
data,
|
||||
backgroundColor: chartColors.slice(0, data.length),
|
||||
borderWidth: 0
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
|
||||
const doughnutOptions = computed(() => ({
|
||||
responsive: true,
|
||||
maintainAspectRatio: false,
|
||||
@@ -173,6 +311,26 @@ const doughnutOptions = computed(() => ({
|
||||
}
|
||||
}))
|
||||
|
||||
const rankingDoughnutOptions = computed(() => ({
|
||||
responsive: true,
|
||||
maintainAspectRatio: false,
|
||||
plugins: {
|
||||
legend: {
|
||||
display: false
|
||||
},
|
||||
tooltip: {
|
||||
callbacks: {
|
||||
label: (context: any) => {
|
||||
const value = context.raw as number
|
||||
const total = context.dataset.data.reduce((a: number, b: number) => a + b, 0)
|
||||
const percentage = total > 0 ? ((value / total) * 100).toFixed(1) : '0.0'
|
||||
return `${context.label}: $${formatCost(value)} (${percentage}%)`
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
const formatTokens = (value: number): string => {
|
||||
if (value >= 1_000_000_000) {
|
||||
return `${(value / 1_000_000_000).toFixed(2)}B`
|
||||
@@ -188,6 +346,11 @@ const formatNumber = (value: number): string => {
|
||||
return value.toLocaleString()
|
||||
}
|
||||
|
||||
const getRankingUserLabel = (item: UserSpendingRankingItem): string => {
|
||||
if (item.email) return item.email
|
||||
return t('admin.redeem.userPrefix', { id: item.user_id })
|
||||
}
|
||||
|
||||
const formatCost = (value: number): string => {
|
||||
if (value >= 1000) {
|
||||
return (value / 1000).toFixed(2) + 'K'
|
||||
|
||||
@@ -82,6 +82,8 @@ const typeLabel = computed(() => {
|
||||
return 'Token'
|
||||
case 'apikey':
|
||||
return 'Key'
|
||||
case 'bedrock':
|
||||
return 'Bedrock'
|
||||
default:
|
||||
return props.type
|
||||
}
|
||||
|
||||
@@ -331,6 +331,15 @@ const antigravityPresetMappings = [
|
||||
{ label: 'Opus 4.6-thinking', from: 'claude-opus-4-6-thinking', to: 'claude-opus-4-6-thinking', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' }
|
||||
]
|
||||
|
||||
// Bedrock 预设映射(与后端 DefaultBedrockModelMapping 保持一致)
|
||||
const bedrockPresetMappings = [
|
||||
{ label: 'Opus 4.6', from: 'claude-opus-4-6', to: 'us.anthropic.claude-opus-4-6-v1', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
|
||||
{ label: 'Sonnet 4.6', from: 'claude-sonnet-4-6', to: 'us.anthropic.claude-sonnet-4-6', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
|
||||
{ label: 'Opus 4.5', from: 'claude-opus-4-5-thinking', to: 'us.anthropic.claude-opus-4-5-20251101-v1:0', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
|
||||
{ label: 'Sonnet 4.5', from: 'claude-sonnet-4-5', to: 'us.anthropic.claude-sonnet-4-5-20250929-v1:0', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
|
||||
{ label: 'Haiku 4.5', from: 'claude-haiku-4-5', to: 'us.anthropic.claude-haiku-4-5-20251001-v1:0', color: 'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 dark:text-green-400' },
|
||||
]
|
||||
|
||||
// Antigravity 默认映射(从后端 API 获取,与 constants.go 保持一致)
|
||||
// 使用 fetchAntigravityDefaultMappings() 异步获取
|
||||
import { getAntigravityDefaultModelMapping } from '@/api/admin/accounts'
|
||||
@@ -403,6 +412,7 @@ export function getPresetMappingsByPlatform(platform: string) {
|
||||
if (platform === 'gemini') return geminiPresetMappings
|
||||
if (platform === 'sora') return soraPresetMappings
|
||||
if (platform === 'antigravity') return antigravityPresetMappings
|
||||
if (platform === 'bedrock' || platform === 'bedrock-apikey') return bedrockPresetMappings
|
||||
return anthropicPresetMappings
|
||||
}
|
||||
|
||||
|
||||
@@ -963,6 +963,18 @@ export default {
|
||||
standard: 'Standard',
|
||||
noDataAvailable: 'No data available',
|
||||
recentUsage: 'Recent Usage',
|
||||
viewModelDistribution: 'Model Distribution',
|
||||
viewSpendingRanking: 'User Spending Ranking',
|
||||
spendingRankingTitle: 'User Spending Ranking',
|
||||
spendingRankingUser: 'User',
|
||||
spendingRankingRequests: 'Requests',
|
||||
spendingRankingTokens: 'Tokens',
|
||||
spendingRankingSpend: 'Spend',
|
||||
spendingRankingOther: 'Others',
|
||||
spendingRankingUsage: 'Usage',
|
||||
spendShort: 'Spend',
|
||||
requestsShort: 'Req',
|
||||
tokensShort: 'Tok',
|
||||
failedToLoad: 'Failed to load dashboard statistics'
|
||||
},
|
||||
|
||||
@@ -1921,6 +1933,8 @@ export default {
|
||||
accountType: 'Account Type',
|
||||
claudeCode: 'Claude Code',
|
||||
claudeConsole: 'Claude Console',
|
||||
bedrockLabel: 'AWS Bedrock',
|
||||
bedrockDesc: 'SigV4 Signing',
|
||||
oauthSetupToken: 'OAuth / Setup Token',
|
||||
addMethod: 'Add Method',
|
||||
setupTokenLongLived: 'Setup Token (Long-lived)',
|
||||
@@ -2110,6 +2124,23 @@ export default {
|
||||
mixedChannelWarning: 'Warning: Group "{groupName}" contains both {currentPlatform} and {otherPlatform} accounts. Mixing different channels may cause thinking block signature validation issues, which will fallback to non-thinking mode. Are you sure you want to continue?',
|
||||
pleaseEnterAccountName: 'Please enter account name',
|
||||
pleaseEnterApiKey: 'Please enter API Key',
|
||||
bedrockAccessKeyId: 'AWS Access Key ID',
|
||||
bedrockSecretAccessKey: 'AWS Secret Access Key',
|
||||
bedrockSessionToken: 'AWS Session Token',
|
||||
bedrockRegion: 'AWS Region',
|
||||
bedrockRegionHint: 'e.g. us-east-1, us-west-2, eu-west-1',
|
||||
bedrockForceGlobal: 'Force Global cross-region inference',
|
||||
bedrockForceGlobalHint: 'When enabled, model IDs use the global. prefix (e.g. global.anthropic.claude-...), routing requests to any supported region worldwide for higher availability',
|
||||
bedrockAccessKeyIdRequired: 'Please enter AWS Access Key ID',
|
||||
bedrockSecretAccessKeyRequired: 'Please enter AWS Secret Access Key',
|
||||
bedrockRegionRequired: 'Please select AWS Region',
|
||||
bedrockSessionTokenHint: 'Optional, for temporary credentials',
|
||||
bedrockSecretKeyLeaveEmpty: 'Leave empty to keep current key',
|
||||
bedrockApiKeyLabel: 'Bedrock API Key',
|
||||
bedrockApiKeyDesc: 'Bearer Token',
|
||||
bedrockApiKeyInput: 'API Key',
|
||||
bedrockApiKeyRequired: 'Please enter Bedrock API Key',
|
||||
bedrockApiKeyLeaveEmpty: 'Leave empty to keep current key',
|
||||
apiKeyIsRequired: 'API Key is required',
|
||||
leaveEmptyToKeep: 'Leave empty to keep current key',
|
||||
// Upstream type
|
||||
|
||||
@@ -974,6 +974,18 @@ export default {
|
||||
tokens: 'Token',
|
||||
cache: '缓存',
|
||||
recentUsage: '最近使用',
|
||||
viewModelDistribution: '模型分布',
|
||||
viewSpendingRanking: '用户消费榜',
|
||||
spendingRankingTitle: '用户消费榜',
|
||||
spendingRankingUser: '用户',
|
||||
spendingRankingRequests: '请求',
|
||||
spendingRankingTokens: 'Token',
|
||||
spendingRankingSpend: '消费',
|
||||
spendingRankingOther: '其他',
|
||||
spendingRankingUsage: '用量',
|
||||
spendShort: '消费',
|
||||
requestsShort: '请求',
|
||||
tokensShort: 'Token',
|
||||
last7Days: '近 7 天',
|
||||
noUsageRecords: '暂无使用记录',
|
||||
startUsingApi: '开始使用 API 后,使用历史将显示在这里。',
|
||||
@@ -2069,6 +2081,8 @@ export default {
|
||||
accountType: '账号类型',
|
||||
claudeCode: 'Claude Code',
|
||||
claudeConsole: 'Claude Console',
|
||||
bedrockLabel: 'AWS Bedrock',
|
||||
bedrockDesc: 'SigV4 签名',
|
||||
oauthSetupToken: 'OAuth / Setup Token',
|
||||
addMethod: '添加方式',
|
||||
setupTokenLongLived: 'Setup Token(长期有效)',
|
||||
@@ -2251,6 +2265,23 @@ export default {
|
||||
mixedChannelWarning: '警告:分组 "{groupName}" 中同时包含 {currentPlatform} 和 {otherPlatform} 账号。混合使用不同渠道可能导致 thinking block 签名验证问题,会自动回退到非 thinking 模式。确定要继续吗?',
|
||||
pleaseEnterAccountName: '请输入账号名称',
|
||||
pleaseEnterApiKey: '请输入 API Key',
|
||||
bedrockAccessKeyId: 'AWS Access Key ID',
|
||||
bedrockSecretAccessKey: 'AWS Secret Access Key',
|
||||
bedrockSessionToken: 'AWS Session Token',
|
||||
bedrockRegion: 'AWS Region',
|
||||
bedrockRegionHint: '例如 us-east-1, us-west-2, eu-west-1',
|
||||
bedrockForceGlobal: '强制使用 Global 跨区域推理',
|
||||
bedrockForceGlobalHint: '启用后模型 ID 使用 global. 前缀(如 global.anthropic.claude-...),请求可路由到全球任意支持的区域,获得更高可用性',
|
||||
bedrockAccessKeyIdRequired: '请输入 AWS Access Key ID',
|
||||
bedrockSecretAccessKeyRequired: '请输入 AWS Secret Access Key',
|
||||
bedrockRegionRequired: '请选择 AWS Region',
|
||||
bedrockSessionTokenHint: '可选,用于临时凭证',
|
||||
bedrockSecretKeyLeaveEmpty: '留空以保持当前密钥',
|
||||
bedrockApiKeyLabel: 'Bedrock API Key',
|
||||
bedrockApiKeyDesc: 'Bearer Token 认证',
|
||||
bedrockApiKeyInput: 'API Key',
|
||||
bedrockApiKeyRequired: '请输入 Bedrock API Key',
|
||||
bedrockApiKeyLeaveEmpty: '留空以保持当前密钥',
|
||||
apiKeyIsRequired: 'API Key 是必需的',
|
||||
leaveEmptyToKeep: '留空以保持当前密钥',
|
||||
// Upstream type
|
||||
|
||||
@@ -531,7 +531,7 @@ export interface UpdateGroupRequest {
|
||||
// ==================== Account & Proxy Types ====================
|
||||
|
||||
export type AccountPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity' | 'sora'
|
||||
export type AccountType = 'oauth' | 'setup-token' | 'apikey' | 'upstream'
|
||||
export type AccountType = 'oauth' | 'setup-token' | 'apikey' | 'upstream' | 'bedrock' | 'bedrock-apikey'
|
||||
export type OAuthAddMethod = 'oauth' | 'setup-token'
|
||||
export type ProxyProtocol = 'http' | 'https' | 'socks5' | 'socks5h'
|
||||
|
||||
@@ -1155,12 +1155,28 @@ export interface UserUsageTrendPoint {
|
||||
date: string
|
||||
user_id: number
|
||||
email: string
|
||||
username: string
|
||||
requests: number
|
||||
tokens: number
|
||||
cost: number // 标准计费
|
||||
actual_cost: number // 实际扣除
|
||||
}
|
||||
|
||||
export interface UserSpendingRankingItem {
|
||||
user_id: number
|
||||
email: string
|
||||
actual_cost: number
|
||||
requests: number
|
||||
tokens: number
|
||||
}
|
||||
|
||||
export interface UserSpendingRankingResponse {
|
||||
ranking: UserSpendingRankingItem[]
|
||||
total_actual_cost: number
|
||||
start_date: string
|
||||
end_date: string
|
||||
}
|
||||
|
||||
export interface ApiKeyUsageTrendPoint {
|
||||
date: string
|
||||
api_key_id: number
|
||||
|
||||
@@ -236,7 +236,16 @@
|
||||
|
||||
<!-- Charts Grid -->
|
||||
<div class="grid grid-cols-1 gap-6 lg:grid-cols-2">
|
||||
<ModelDistributionChart :model-stats="modelStats" :loading="chartsLoading" />
|
||||
<ModelDistributionChart
|
||||
:model-stats="modelStats"
|
||||
:enable-ranking-view="true"
|
||||
:ranking-items="rankingItems"
|
||||
:ranking-total-actual-cost="rankingTotalActualCost"
|
||||
:loading="chartsLoading"
|
||||
:ranking-loading="rankingLoading"
|
||||
:ranking-error="rankingError"
|
||||
@ranking-click="goToUserUsage"
|
||||
/>
|
||||
<TokenUsageTrend :trend-data="trendData" :loading="chartsLoading" />
|
||||
</div>
|
||||
|
||||
@@ -267,11 +276,18 @@
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { useRouter } from 'vue-router'
|
||||
import { useAppStore } from '@/stores/app'
|
||||
|
||||
const { t } = useI18n()
|
||||
import { adminAPI } from '@/api/admin'
|
||||
import type { DashboardStats, TrendDataPoint, ModelStat, UserUsageTrendPoint } from '@/types'
|
||||
import type {
|
||||
DashboardStats,
|
||||
TrendDataPoint,
|
||||
ModelStat,
|
||||
UserUsageTrendPoint,
|
||||
UserSpendingRankingItem
|
||||
} from '@/types'
|
||||
import AppLayout from '@/components/layout/AppLayout.vue'
|
||||
import LoadingSpinner from '@/components/common/LoadingSpinner.vue'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
@@ -286,7 +302,6 @@ import {
|
||||
LinearScale,
|
||||
PointElement,
|
||||
LineElement,
|
||||
Title,
|
||||
Tooltip,
|
||||
Legend,
|
||||
Filler
|
||||
@@ -299,39 +314,42 @@ ChartJS.register(
|
||||
LinearScale,
|
||||
PointElement,
|
||||
LineElement,
|
||||
Title,
|
||||
Tooltip,
|
||||
Legend,
|
||||
Filler
|
||||
)
|
||||
|
||||
const appStore = useAppStore()
|
||||
const router = useRouter()
|
||||
const stats = ref<DashboardStats | null>(null)
|
||||
const loading = ref(false)
|
||||
const chartsLoading = ref(false)
|
||||
const userTrendLoading = ref(false)
|
||||
const rankingLoading = ref(false)
|
||||
const rankingError = ref(false)
|
||||
|
||||
// Chart data
|
||||
const trendData = ref<TrendDataPoint[]>([])
|
||||
const modelStats = ref<ModelStat[]>([])
|
||||
const userTrend = ref<UserUsageTrendPoint[]>([])
|
||||
const rankingItems = ref<UserSpendingRankingItem[]>([])
|
||||
const rankingTotalActualCost = ref(0)
|
||||
let chartLoadSeq = 0
|
||||
let usersTrendLoadSeq = 0
|
||||
let rankingLoadSeq = 0
|
||||
const rankingLimit = 12
|
||||
|
||||
// Helper function to format date in local timezone
|
||||
const formatLocalDate = (date: Date): string => {
|
||||
return `${date.getFullYear()}-${String(date.getMonth() + 1).padStart(2, '0')}-${String(date.getDate()).padStart(2, '0')}`
|
||||
}
|
||||
|
||||
// Initialize date range immediately
|
||||
const now = new Date()
|
||||
const weekAgo = new Date(now)
|
||||
weekAgo.setDate(weekAgo.getDate() - 6)
|
||||
const getTodayLocalDate = () => formatLocalDate(new Date())
|
||||
|
||||
// Date range
|
||||
const granularity = ref<'day' | 'hour'>('day')
|
||||
const startDate = ref(formatLocalDate(weekAgo))
|
||||
const endDate = ref(formatLocalDate(now))
|
||||
const startDate = ref(getTodayLocalDate())
|
||||
const endDate = ref(getTodayLocalDate())
|
||||
|
||||
// Granularity options for Select component
|
||||
const granularityOptions = computed(() => [
|
||||
@@ -415,23 +433,29 @@ const lineOptions = computed(() => ({
|
||||
const userTrendChartData = computed(() => {
|
||||
if (!userTrend.value?.length) return null
|
||||
|
||||
// Extract display name from email (part before @)
|
||||
const getDisplayName = (email: string, userId: number): string => {
|
||||
if (email && email.includes('@')) {
|
||||
return email.split('@')[0]
|
||||
const getDisplayName = (point: UserUsageTrendPoint): string => {
|
||||
const username = point.username?.trim()
|
||||
if (username) {
|
||||
return username
|
||||
}
|
||||
return t('admin.redeem.userPrefix', { id: userId })
|
||||
|
||||
const email = point.email?.trim()
|
||||
if (email) {
|
||||
return email
|
||||
}
|
||||
|
||||
return t('admin.redeem.userPrefix', { id: point.user_id })
|
||||
}
|
||||
|
||||
// Group by user
|
||||
const userGroups = new Map<string, { name: string; data: Map<string, number> }>()
|
||||
// Group by user_id to avoid merging different users with the same display name
|
||||
const userGroups = new Map<number, { name: string; data: Map<string, number> }>()
|
||||
const allDates = new Set<string>()
|
||||
|
||||
userTrend.value.forEach((point) => {
|
||||
allDates.add(point.date)
|
||||
const key = getDisplayName(point.email, point.user_id)
|
||||
const key = point.user_id
|
||||
if (!userGroups.has(key)) {
|
||||
userGroups.set(key, { name: key, data: new Map() })
|
||||
userGroups.set(key, { name: getDisplayName(point), data: new Map() })
|
||||
}
|
||||
userGroups.get(key)!.data.set(point.date, point.tokens)
|
||||
})
|
||||
@@ -502,6 +526,17 @@ const formatDuration = (ms: number): string => {
|
||||
return `${Math.round(ms)}ms`
|
||||
}
|
||||
|
||||
const goToUserUsage = (item: UserSpendingRankingItem) => {
|
||||
void router.push({
|
||||
path: '/admin/usage',
|
||||
query: {
|
||||
user_id: String(item.user_id),
|
||||
start_date: startDate.value,
|
||||
end_date: endDate.value
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Date range change handler
|
||||
const onDateRangeChange = (range: {
|
||||
startDate: string
|
||||
@@ -582,14 +617,46 @@ const loadUsersTrend = async () => {
|
||||
}
|
||||
}
|
||||
|
||||
const loadUserSpendingRanking = async () => {
|
||||
const currentSeq = ++rankingLoadSeq
|
||||
rankingLoading.value = true
|
||||
rankingError.value = false
|
||||
try {
|
||||
const response = await adminAPI.dashboard.getUserSpendingRanking({
|
||||
start_date: startDate.value,
|
||||
end_date: endDate.value,
|
||||
limit: rankingLimit
|
||||
})
|
||||
if (currentSeq !== rankingLoadSeq) return
|
||||
rankingItems.value = response.ranking || []
|
||||
rankingTotalActualCost.value = response.total_actual_cost || 0
|
||||
} catch (error) {
|
||||
if (currentSeq !== rankingLoadSeq) return
|
||||
console.error('Error loading user spending ranking:', error)
|
||||
rankingItems.value = []
|
||||
rankingTotalActualCost.value = 0
|
||||
rankingError.value = true
|
||||
} finally {
|
||||
if (currentSeq === rankingLoadSeq) {
|
||||
rankingLoading.value = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const loadDashboardStats = async () => {
|
||||
await loadDashboardSnapshot(true)
|
||||
void loadUsersTrend()
|
||||
await Promise.all([
|
||||
loadDashboardSnapshot(true),
|
||||
loadUsersTrend(),
|
||||
loadUserSpendingRanking()
|
||||
])
|
||||
}
|
||||
|
||||
const loadChartData = async () => {
|
||||
await loadDashboardSnapshot(false)
|
||||
void loadUsersTrend()
|
||||
await Promise.all([
|
||||
loadDashboardSnapshot(false),
|
||||
loadUsersTrend(),
|
||||
loadUserSpendingRanking()
|
||||
])
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
|
||||
@@ -89,6 +89,7 @@
|
||||
import { ref, reactive, computed, onMounted, onUnmounted } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { saveAs } from 'file-saver'
|
||||
import { useRoute } from 'vue-router'
|
||||
import { useAppStore } from '@/stores/app'; import { adminAPI } from '@/api/admin'; import { adminUsageAPI } from '@/api/admin/usage'
|
||||
import { formatReasoningEffort } from '@/utils/format'
|
||||
import { resolveUsageRequestType, requestTypeToLegacyStream } from '@/utils/usageRequestType'
|
||||
@@ -104,7 +105,7 @@ import type { AdminUsageLog, TrendDataPoint, ModelStat, GroupStat, AdminUser } f
|
||||
const { t } = useI18n()
|
||||
const appStore = useAppStore()
|
||||
type DistributionMetric = 'tokens' | 'actual_cost'
|
||||
|
||||
const route = useRoute()
|
||||
const usageStats = ref<AdminUsageStatsResponse | null>(null); const usageLogs = ref<AdminUsageLog[]>([]); const loading = ref(false); const exporting = ref(false)
|
||||
const trendData = ref<TrendDataPoint[]>([]); const modelStats = ref<ModelStat[]>([]); const groupStats = ref<GroupStat[]>([]); const chartsLoading = ref(false); const granularity = ref<'day' | 'hour'>('day')
|
||||
const modelDistributionMetric = ref<DistributionMetric>('tokens')
|
||||
@@ -135,11 +136,43 @@ const formatLD = (d: Date) => {
|
||||
const day = String(d.getDate()).padStart(2, '0')
|
||||
return `${year}-${month}-${day}`
|
||||
}
|
||||
const now = new Date(); const weekAgo = new Date(); weekAgo.setDate(weekAgo.getDate() - 6)
|
||||
const startDate = ref(formatLD(weekAgo)); const endDate = ref(formatLD(now))
|
||||
const getTodayLocalDate = () => formatLD(new Date())
|
||||
const startDate = ref(getTodayLocalDate()); const endDate = ref(getTodayLocalDate())
|
||||
const filters = ref<AdminUsageQueryParams>({ user_id: undefined, model: undefined, group_id: undefined, request_type: undefined, billing_type: null, start_date: startDate.value, end_date: endDate.value })
|
||||
const pagination = reactive({ page: 1, page_size: 20, total: 0 })
|
||||
|
||||
const getSingleQueryValue = (value: string | null | Array<string | null> | undefined): string | undefined => {
|
||||
if (Array.isArray(value)) return value.find((item): item is string => typeof item === 'string' && item.length > 0)
|
||||
return typeof value === 'string' && value.length > 0 ? value : undefined
|
||||
}
|
||||
|
||||
const getNumericQueryValue = (value: string | null | Array<string | null> | undefined): number | undefined => {
|
||||
const raw = getSingleQueryValue(value)
|
||||
if (!raw) return undefined
|
||||
const parsed = Number(raw)
|
||||
return Number.isFinite(parsed) ? parsed : undefined
|
||||
}
|
||||
|
||||
const applyRouteQueryFilters = () => {
|
||||
const queryStartDate = getSingleQueryValue(route.query.start_date)
|
||||
const queryEndDate = getSingleQueryValue(route.query.end_date)
|
||||
const queryUserId = getNumericQueryValue(route.query.user_id)
|
||||
|
||||
if (queryStartDate) {
|
||||
startDate.value = queryStartDate
|
||||
}
|
||||
if (queryEndDate) {
|
||||
endDate.value = queryEndDate
|
||||
}
|
||||
|
||||
filters.value = {
|
||||
...filters.value,
|
||||
user_id: queryUserId,
|
||||
start_date: startDate.value,
|
||||
end_date: endDate.value
|
||||
}
|
||||
}
|
||||
|
||||
const loadLogs = async () => {
|
||||
abortController?.abort(); const c = new AbortController(); abortController = c; loading.value = true
|
||||
try {
|
||||
@@ -191,7 +224,7 @@ const loadChartData = async () => {
|
||||
}
|
||||
const applyFilters = () => { pagination.page = 1; loadLogs(); loadStats(); loadChartData() }
|
||||
const refreshData = () => { loadLogs(); loadStats(); loadChartData() }
|
||||
const resetFilters = () => { startDate.value = formatLD(weekAgo); endDate.value = formatLD(now); filters.value = { start_date: startDate.value, end_date: endDate.value, request_type: undefined, billing_type: null }; granularity.value = 'day'; applyFilters() }
|
||||
const resetFilters = () => { startDate.value = getTodayLocalDate(); endDate.value = getTodayLocalDate(); filters.value = { start_date: startDate.value, end_date: endDate.value, request_type: undefined, billing_type: null }; granularity.value = 'day'; applyFilters() }
|
||||
const handlePageChange = (p: number) => { pagination.page = p; loadLogs() }
|
||||
const handlePageSizeChange = (s: number) => { pagination.page_size = s; pagination.page = 1; loadLogs() }
|
||||
const cancelExport = () => exportAbortController?.abort()
|
||||
@@ -329,6 +362,7 @@ const handleColumnClickOutside = (event: MouseEvent) => {
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
applyRouteQueryFilters()
|
||||
loadLogs()
|
||||
loadStats()
|
||||
window.setTimeout(() => {
|
||||
|
||||
Reference in New Issue
Block a user