diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 4d4517d2..444e4e31 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -81,6 +81,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userHandler := handler.NewUserHandler(userService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageLogRepository := repository.NewUsageLogRepository(client, db) + usageBillingRepository := repository.NewUsageBillingRepository(client, db) usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) redeemHandler := handler.NewRedeemHandler(redeemService) @@ -163,9 +164,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService) digestSessionStore := service.NewDigestSessionStore() - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService) + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, 
settingService) openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService) - openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) + openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) diff --git a/backend/go.mod b/backend/go.mod index 03637401..135cbd3e 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -7,7 +7,7 @@ require ( github.com/DATA-DOG/go-sqlmock v1.5.2 github.com/DouDOU-start/go-sora2api v1.1.0 github.com/alitto/pond/v2 v2.6.2 - github.com/aws/aws-sdk-go-v2 v1.41.2 + github.com/aws/aws-sdk-go-v2 v1.41.3 github.com/aws/aws-sdk-go-v2/config v1.32.10 github.com/aws/aws-sdk-go-v2/credentials v1.19.10 github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2 @@ -66,7 +66,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect 
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect - github.com/aws/smithy-go v1.24.1 // indirect + github.com/aws/smithy-go v1.24.2 // indirect github.com/bdandy/go-errors v1.2.2 // indirect github.com/bdandy/go-socks4 v1.2.3 // indirect github.com/bmatcuk/doublestar v1.3.4 // indirect diff --git a/backend/go.sum b/backend/go.sum index 993a1d54..324fe652 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -24,6 +24,8 @@ github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4= github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls= github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4= +github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA= +github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c= github.com/aws/aws-sdk-go-v2/config v1.32.10 h1:9DMthfO6XWZYLfzZglAgW5Fyou2nRI5CuV44sTedKBI= @@ -60,6 +62,8 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb8 github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs= github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0= github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= +github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= +github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q= 
github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM= github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index de876098..e90e56af 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -934,9 +934,10 @@ type DashboardAggregationConfig struct { // DashboardAggregationRetentionConfig 预聚合保留窗口 type DashboardAggregationRetentionConfig struct { - UsageLogsDays int `mapstructure:"usage_logs_days"` - HourlyDays int `mapstructure:"hourly_days"` - DailyDays int `mapstructure:"daily_days"` + UsageLogsDays int `mapstructure:"usage_logs_days"` + UsageBillingDedupDays int `mapstructure:"usage_billing_dedup_days"` + HourlyDays int `mapstructure:"hourly_days"` + DailyDays int `mapstructure:"daily_days"` } // UsageCleanupConfig 使用记录清理任务配置 @@ -1301,6 +1302,7 @@ func setDefaults() { viper.SetDefault("dashboard_aggregation.backfill_enabled", false) viper.SetDefault("dashboard_aggregation.backfill_max_days", 31) viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90) + viper.SetDefault("dashboard_aggregation.retention.usage_billing_dedup_days", 365) viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180) viper.SetDefault("dashboard_aggregation.retention.daily_days", 730) viper.SetDefault("dashboard_aggregation.recompute_days", 2) @@ -1758,6 +1760,12 @@ func (c *Config) Validate() error { if c.DashboardAgg.Retention.UsageLogsDays <= 0 { return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive") } + if c.DashboardAgg.Retention.UsageBillingDedupDays <= 0 { + return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be positive") + } + if c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays { + return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or 
equal to usage_logs_days") + } if c.DashboardAgg.Retention.HourlyDays <= 0 { return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive") } @@ -1780,6 +1788,14 @@ func (c *Config) Validate() error { if c.DashboardAgg.Retention.UsageLogsDays < 0 { return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative") } + if c.DashboardAgg.Retention.UsageBillingDedupDays < 0 { + return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be non-negative") + } + if c.DashboardAgg.Retention.UsageBillingDedupDays > 0 && + c.DashboardAgg.Retention.UsageLogsDays > 0 && + c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays { + return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days") + } if c.DashboardAgg.Retention.HourlyDays < 0 { return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative") } diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index 79fcc6d0..abb76549 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -441,6 +441,9 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) { if cfg.DashboardAgg.Retention.UsageLogsDays != 90 { t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays) } + if cfg.DashboardAgg.Retention.UsageBillingDedupDays != 365 { + t.Fatalf("DashboardAgg.Retention.UsageBillingDedupDays = %d, want 365", cfg.DashboardAgg.Retention.UsageBillingDedupDays) + } if cfg.DashboardAgg.Retention.HourlyDays != 180 { t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays) } @@ -1016,6 +1019,23 @@ func TestValidateConfigErrors(t *testing.T) { mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 }, wantErr: 
"dashboard_aggregation.retention.usage_logs_days", }, + { + name: "dashboard aggregation dedup retention", + mutate: func(c *Config) { + c.DashboardAgg.Enabled = true + c.DashboardAgg.Retention.UsageBillingDedupDays = 0 + }, + wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days", + }, + { + name: "dashboard aggregation dedup retention smaller than usage logs", + mutate: func(c *Config) { + c.DashboardAgg.Enabled = true + c.DashboardAgg.Retention.UsageLogsDays = 30 + c.DashboardAgg.Retention.UsageBillingDedupDays = 29 + }, + wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days", + }, { name: "dashboard aggregation disabled interval", mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 }, diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index 8a6621a1..36d043b5 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -27,10 +27,12 @@ const ( // Account type constants const ( - AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference) - AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope) - AccountTypeAPIKey = "apikey" // API Key类型账号 - AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游) + AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference) + AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope) + AccountTypeAPIKey = "apikey" // API Key类型账号 + AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游) + AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名连接 Bedrock) + AccountTypeBedrockAPIKey = "bedrock-apikey" // AWS Bedrock API Key 类型账号(通过 Bearer Token 连接 Bedrock) ) // Redeem type constants @@ -113,3 +115,27 @@ var DefaultAntigravityModelMapping = map[string]string{ "gpt-oss-120b-medium": "gpt-oss-120b-medium", "tab_flash_lite_preview": "tab_flash_lite_preview", } + +// 
DefaultBedrockModelMapping 是 AWS Bedrock 平台的默认模型映射 +// 将 Anthropic 标准模型名映射到 Bedrock 模型 ID +// 注意:此处的 "us." 前缀仅为默认值,ResolveBedrockModelID 会根据账号配置的 +// aws_region 自动调整为匹配的区域前缀(如 eu.、apac.、jp. 等) +var DefaultBedrockModelMapping = map[string]string{ + // Claude Opus + "claude-opus-4-6-thinking": "us.anthropic.claude-opus-4-6-v1", + "claude-opus-4-6": "us.anthropic.claude-opus-4-6-v1", + "claude-opus-4-5-thinking": "us.anthropic.claude-opus-4-5-20251101-v1:0", + "claude-opus-4-5-20251101": "us.anthropic.claude-opus-4-5-20251101-v1:0", + "claude-opus-4-1": "us.anthropic.claude-opus-4-1-20250805-v1:0", + "claude-opus-4-20250514": "us.anthropic.claude-opus-4-20250514-v1:0", + // Claude Sonnet + "claude-sonnet-4-6-thinking": "us.anthropic.claude-sonnet-4-6", + "claude-sonnet-4-6": "us.anthropic.claude-sonnet-4-6", + "claude-sonnet-4-5": "us.anthropic.claude-sonnet-4-5-20250929-v1:0", + "claude-sonnet-4-5-thinking": "us.anthropic.claude-sonnet-4-5-20250929-v1:0", + "claude-sonnet-4-5-20250929": "us.anthropic.claude-sonnet-4-5-20250929-v1:0", + "claude-sonnet-4-20250514": "us.anthropic.claude-sonnet-4-20250514-v1:0", + // Claude Haiku + "claude-haiku-4-5": "us.anthropic.claude-haiku-4-5-20251001-v1:0", + "claude-haiku-4-5-20251001": "us.anthropic.claude-haiku-4-5-20251001-v1:0", +} diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 57c2dad1..c7ca0ca2 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -97,7 +97,7 @@ type CreateAccountRequest struct { Name string `json:"name" binding:"required"` Notes *string `json:"notes"` Platform string `json:"platform" binding:"required"` - Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream"` + Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock bedrock-apikey"` Credentials map[string]any `json:"credentials" binding:"required"` 
Extra map[string]any `json:"extra"` ProxyID *int64 `json:"proxy_id"` @@ -116,7 +116,7 @@ type CreateAccountRequest struct { type UpdateAccountRequest struct { Name string `json:"name"` Notes *string `json:"notes"` - Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream"` + Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock bedrock-apikey"` Credentials map[string]any `json:"credentials"` Extra map[string]any `json:"extra"` ProxyID *int64 `json:"proxy_id"` diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go index aa82b24f..cc4ef2d0 100644 --- a/backend/internal/handler/admin/dashboard_handler.go +++ b/backend/internal/handler/admin/dashboard_handler.go @@ -466,9 +466,60 @@ type BatchUsersUsageRequest struct { UserIDs []int64 `json:"user_ids" binding:"required"` } +var dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute) var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second) var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second) +func parseRankingLimit(raw string) int { + limit, err := strconv.Atoi(strings.TrimSpace(raw)) + if err != nil || limit <= 0 { + return 12 + } + if limit > 50 { + return 50 + } + return limit +} + +// GetUserSpendingRanking handles getting user spending ranking data. 
+// GET /api/v1/admin/dashboard/users-ranking +func (h *DashboardHandler) GetUserSpendingRanking(c *gin.Context) { + startTime, endTime := parseTimeRange(c) + limit := parseRankingLimit(c.DefaultQuery("limit", "12")) + + keyRaw, _ := json.Marshal(struct { + Start string `json:"start"` + End string `json:"end"` + Limit int `json:"limit"` + }{ + Start: startTime.UTC().Format(time.RFC3339), + End: endTime.UTC().Format(time.RFC3339), + Limit: limit, + }) + cacheKey := string(keyRaw) + if cached, ok := dashboardUsersRankingCache.Get(cacheKey); ok { + c.Header("X-Snapshot-Cache", "hit") + response.Success(c, cached.Payload) + return + } + + ranking, err := h.dashboardService.GetUserSpendingRanking(c.Request.Context(), startTime, endTime, limit) + if err != nil { + response.Error(c, 500, "Failed to get user spending ranking") + return + } + + payload := gin.H{ + "ranking": ranking.Ranking, + "total_actual_cost": ranking.TotalActualCost, + "start_date": startTime.Format("2006-01-02"), + "end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"), + } + dashboardUsersRankingCache.Set(cacheKey, payload) + c.Header("X-Snapshot-Cache", "miss") + response.Success(c, payload) +} + // GetBatchUsersUsage handles getting usage stats for multiple users // POST /api/v1/admin/dashboard/users-usage func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) { diff --git a/backend/internal/handler/admin/dashboard_handler_request_type_test.go b/backend/internal/handler/admin/dashboard_handler_request_type_test.go index 72af6b45..6b363bb5 100644 --- a/backend/internal/handler/admin/dashboard_handler_request_type_test.go +++ b/backend/internal/handler/admin/dashboard_handler_request_type_test.go @@ -19,6 +19,9 @@ type dashboardUsageRepoCapture struct { trendStream *bool modelRequestType *int16 modelStream *bool + rankingLimit int + ranking []usagestats.UserSpendingRankingItem + rankingTotal float64 } func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters( @@ -49,6 +52,18 @@ func 
(s *dashboardUsageRepoCapture) GetModelStatsWithFilters( return []usagestats.ModelStat{}, nil } +func (s *dashboardUsageRepoCapture) GetUserSpendingRanking( + ctx context.Context, + startTime, endTime time.Time, + limit int, +) (*usagestats.UserSpendingRankingResponse, error) { + s.rankingLimit = limit + return &usagestats.UserSpendingRankingResponse{ + Ranking: s.ranking, + TotalActualCost: s.rankingTotal, + }, nil +} + func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine { gin.SetMode(gin.TestMode) dashboardSvc := service.NewDashboardService(repo, nil, nil, nil) @@ -56,6 +71,7 @@ func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Eng router := gin.New() router.GET("/admin/dashboard/trend", handler.GetUsageTrend) router.GET("/admin/dashboard/models", handler.GetModelStats) + router.GET("/admin/dashboard/users-ranking", handler.GetUserSpendingRanking) return router } @@ -130,3 +146,30 @@ func TestDashboardModelStatsInvalidStream(t *testing.T) { require.Equal(t, http.StatusBadRequest, rec.Code) } + +func TestDashboardUsersRankingLimitAndCache(t *testing.T) { + dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute) + repo := &dashboardUsageRepoCapture{ + ranking: []usagestats.UserSpendingRankingItem{ + {UserID: 7, Email: "rank@example.com", ActualCost: 10.5, Requests: 3, Tokens: 300}, + }, + rankingTotal: 88.8, + } + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, 50, repo.rankingLimit) + require.Contains(t, rec.Body.String(), "\"total_actual_cost\":88.8") + require.Equal(t, "miss", rec.Header().Get("X-Snapshot-Cache")) + + req2 := httptest.NewRequest(http.MethodGet, 
"/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil) + rec2 := httptest.NewRecorder() + router.ServeHTTP(rec2, req2) + + require.Equal(t, http.StatusOK, rec2.Code) + require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache")) +} diff --git a/backend/internal/handler/admin/redeem_handler.go b/backend/internal/handler/admin/redeem_handler.go index 0a932ee9..13ea88d9 100644 --- a/backend/internal/handler/admin/redeem_handler.go +++ b/backend/internal/handler/admin/redeem_handler.go @@ -41,12 +41,15 @@ type GenerateRedeemCodesRequest struct { } // CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user. +// Type 为 omitempty 而非 required 是为了向后兼容旧版调用方(不传 type 时默认 balance)。 type CreateAndRedeemCodeRequest struct { - Code string `json:"code" binding:"required,min=3,max=128"` - Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"` - Value float64 `json:"value" binding:"required,gt=0"` - UserID int64 `json:"user_id" binding:"required,gt=0"` - Notes string `json:"notes"` + Code string `json:"code" binding:"required,min=3,max=128"` + Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // 不传时默认 balance(向后兼容) + Value float64 `json:"value" binding:"required,gt=0"` + UserID int64 `json:"user_id" binding:"required,gt=0"` + GroupID *int64 `json:"group_id"` // subscription 类型必填 + ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // subscription 类型必填,>0 + Notes string `json:"notes"` } // List handles listing all redeem codes with pagination @@ -136,6 +139,22 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) { return } req.Code = strings.TrimSpace(req.Code) + // 向后兼容:旧版调用方(如 Sub2ApiPay)不传 type 字段,默认当作 balance 充值处理。 + // 请勿删除此默认值逻辑,否则会导致旧版调用方 400 报错。 + if req.Type == "" { + req.Type = "balance" + } + + if req.Type == "subscription" { + if req.GroupID == nil { + response.BadRequest(c, "group_id is 
required for subscription type") + return + } + if req.ValidityDays <= 0 { + response.BadRequest(c, "validity_days must be greater than 0 for subscription type") + return + } + } executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { existing, err := h.redeemService.GetByCode(ctx, req.Code) @@ -147,11 +166,13 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) { } createErr := h.redeemService.CreateCode(ctx, &service.RedeemCode{ - Code: req.Code, - Type: req.Type, - Value: req.Value, - Status: service.StatusUnused, - Notes: req.Notes, + Code: req.Code, + Type: req.Type, + Value: req.Value, + Status: service.StatusUnused, + Notes: req.Notes, + GroupID: req.GroupID, + ValidityDays: req.ValidityDays, }) if createErr != nil { // Unique code race: if code now exists, use idempotent semantics by used_by. diff --git a/backend/internal/handler/admin/redeem_handler_test.go b/backend/internal/handler/admin/redeem_handler_test.go new file mode 100644 index 00000000..0d42f64f --- /dev/null +++ b/backend/internal/handler/admin/redeem_handler_test.go @@ -0,0 +1,135 @@ +package admin + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// newCreateAndRedeemHandler creates a RedeemHandler with a non-nil (but minimal) +// RedeemService so that CreateAndRedeem's nil guard passes and we can test the +// parameter-validation layer that runs before any service call. +func newCreateAndRedeemHandler() *RedeemHandler { + return &RedeemHandler{ + adminService: newStubAdminService(), + redeemService: &service.RedeemService{}, // non-nil to pass nil guard + } +} + +// postCreateAndRedeemValidation calls CreateAndRedeem and returns the response +// status code. 
For cases that pass validation and proceed into the service layer, +// a panic may occur (because RedeemService internals are nil); this is expected +// and treated as "validation passed" (returns 0 to indicate panic). +func postCreateAndRedeemValidation(t *testing.T, handler *RedeemHandler, body any) (code int) { + t.Helper() + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + jsonBytes, err := json.Marshal(body) + require.NoError(t, err) + c.Request, _ = http.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/create-and-redeem", bytes.NewReader(jsonBytes)) + c.Request.Header.Set("Content-Type", "application/json") + + defer func() { + if r := recover(); r != nil { + // Panic means we passed validation and entered service layer (expected for minimal stub). + code = 0 + } + }() + handler.CreateAndRedeem(c) + return w.Code +} + +func TestCreateAndRedeem_TypeDefaultsToBalance(t *testing.T) { + // 不传 type 字段时应默认 balance,不触发 subscription 校验。 + // 验证通过后进入 service 层会 panic(返回 0),说明默认值生效。 + h := newCreateAndRedeemHandler() + code := postCreateAndRedeemValidation(t, h, map[string]any{ + "code": "test-balance-default", + "value": 10.0, + "user_id": 1, + }) + + assert.NotEqual(t, http.StatusBadRequest, code, + "omitting type should default to balance and pass validation") +} + +func TestCreateAndRedeem_SubscriptionRequiresGroupID(t *testing.T) { + h := newCreateAndRedeemHandler() + code := postCreateAndRedeemValidation(t, h, map[string]any{ + "code": "test-sub-no-group", + "type": "subscription", + "value": 29.9, + "user_id": 1, + "validity_days": 30, + // group_id 缺失 + }) + + assert.Equal(t, http.StatusBadRequest, code) +} + +func TestCreateAndRedeem_SubscriptionRequiresPositiveValidityDays(t *testing.T) { + groupID := int64(5) + h := newCreateAndRedeemHandler() + + cases := []struct { + name string + validityDays int + }{ + {"zero", 0}, + {"negative", -1}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t 
*testing.T) { + code := postCreateAndRedeemValidation(t, h, map[string]any{ + "code": "test-sub-bad-days-" + tc.name, + "type": "subscription", + "value": 29.9, + "user_id": 1, + "group_id": groupID, + "validity_days": tc.validityDays, + }) + + assert.Equal(t, http.StatusBadRequest, code) + }) + } +} + +func TestCreateAndRedeem_SubscriptionValidParamsPassValidation(t *testing.T) { + groupID := int64(5) + h := newCreateAndRedeemHandler() + code := postCreateAndRedeemValidation(t, h, map[string]any{ + "code": "test-sub-valid", + "type": "subscription", + "value": 29.9, + "user_id": 1, + "group_id": groupID, + "validity_days": 31, + }) + + assert.NotEqual(t, http.StatusBadRequest, code, + "valid subscription params should pass validation") +} + +func TestCreateAndRedeem_BalanceIgnoresSubscriptionFields(t *testing.T) { + h := newCreateAndRedeemHandler() + // balance 类型不传 group_id 和 validity_days,不应报 400 + code := postCreateAndRedeemValidation(t, h, map[string]any{ + "code": "test-balance-no-extras", + "type": "balance", + "value": 50.0, + "user_id": 1, + }) + + assert.NotEqual(t, http.StatusBadRequest, code, + "balance type should not require group_id or validity_days") +} diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 4441cf07..676ba0e1 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -434,19 +434,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) + requestPayloadHash := service.HashUsageRequestPayload(body) // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - UserAgent: userAgent, - IPAddress: clientIP, 
- ForceCacheBilling: fs.ForceCacheBilling, - APIKeyService: h.apiKeyService, + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, + ForceCacheBilling: fs.ForceCacheBilling, + APIKeyService: h.apiKeyService, }); err != nil { logger.L().With( zap.String("component", "handler.gateway.messages"), @@ -736,19 +738,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) + requestPayloadHash := service.HashUsageRequestPayload(body) // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - APIKey: currentAPIKey, - User: currentAPIKey.User, - Account: account, - Subscription: currentSubscription, - UserAgent: userAgent, - IPAddress: clientIP, - ForceCacheBilling: fs.ForceCacheBilling, - APIKeyService: h.apiKeyService, + Result: result, + APIKey: currentAPIKey, + User: currentAPIKey.User, + Account: account, + Subscription: currentSubscription, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, + ForceCacheBilling: fs.ForceCacheBilling, + APIKeyService: h.apiKeyService, }); err != nil { logger.L().With( zap.String("component", "handler.gateway.messages"), diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index 0c94d50b..6bcc0003 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -139,6 +139,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi nil, // accountRepo (not used: scheduler snapshot hit) 
&fakeGroupRepo{group: group}, nil, // usageLogRepo + nil, // usageBillingRepo nil, // userRepo nil, // userSubRepo nil, // userGroupRateRepo diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 50af9c8f..9a16ff3a 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -503,6 +503,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 + requestPayloadHash := service.HashUsageRequestPayload(body) h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{ Result: result, @@ -512,6 +513,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { Subscription: subscription, UserAgent: userAgent, IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, LongContextThreshold: 200000, // Gemini 200K 阈值 LongContextMultiplier: 2.0, // 超出部分双倍计费 ForceCacheBilling: fs.ForceCacheBilling, diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 8567b52b..d23c7efe 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -352,18 +352,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) + requestPayloadHash := service.HashUsageRequestPayload(body) // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - UserAgent: userAgent, - IPAddress: clientIP, - APIKeyService: h.apiKeyService, + Result: result, + APIKey: 
apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, + APIKeyService: h.apiKeyService, }); err != nil { logger.L().With( zap.String("component", "handler.openai_gateway.responses"), @@ -732,17 +734,19 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) + requestPayloadHash := service.HashUsageRequestPayload(body) h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - UserAgent: userAgent, - IPAddress: clientIP, - APIKeyService: h.apiKeyService, + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, + APIKeyService: h.apiKeyService, }); err != nil { logger.L().With( zap.String("component", "handler.openai_gateway.messages"), @@ -1231,14 +1235,15 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) h.submitUsageRecordTask(func(taskCtx context.Context) { if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - UserAgent: userAgent, - IPAddress: clientIP, - APIKeyService: h.apiKeyService, + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: service.HashUsageRequestPayload(firstMessage), + APIKeyService: h.apiKeyService, }); err != nil { reqLog.Error("openai.websocket_record_usage_failed", 
zap.Int64("account_id", account.ID), diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go index 30a761bd..dab17673 100644 --- a/backend/internal/handler/sora_client_handler_test.go +++ b/backend/internal/handler/sora_client_handler_test.go @@ -2206,7 +2206,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac // newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。 func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService { return service.NewGatewayService( - accountRepo, nil, nil, nil, nil, nil, nil, nil, + accountRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, ) } diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go index 48c1e451..06abdf60 100644 --- a/backend/internal/handler/sora_gateway_handler.go +++ b/backend/internal/handler/sora_gateway_handler.go @@ -399,17 +399,19 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) + requestPayloadHash := service.HashUsageRequestPayload(body) // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - UserAgent: userAgent, - IPAddress: clientIP, + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, }); err != nil { logger.L().With( zap.String("component", "handler.sora_gateway.chat_completions"), diff --git a/backend/internal/handler/sora_gateway_handler_test.go 
b/backend/internal/handler/sora_gateway_handler_test.go index 688c5d12..312c7511 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -343,6 +343,9 @@ func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, e func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) { return nil, nil } +func (s *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) { + return nil, nil +} func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) { return nil, nil } @@ -431,6 +434,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) { nil, nil, nil, + nil, testutil.StubGatewayCache{}, cfg, nil, diff --git a/backend/internal/pkg/antigravity/gemini_types.go b/backend/internal/pkg/antigravity/gemini_types.go index 0ff24a1f..1a0ca5bb 100644 --- a/backend/internal/pkg/antigravity/gemini_types.go +++ b/backend/internal/pkg/antigravity/gemini_types.go @@ -189,6 +189,5 @@ var DefaultStopSequences = []string{ "<|user|>", "<|endoftext|>", "<|end_of_turn|>", - "[DONE]", "\n\nHuman:", } diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go index 8826c048..55a049d3 100644 --- a/backend/internal/pkg/usagestats/usage_log_types.go +++ b/backend/internal/pkg/usagestats/usage_log_types.go @@ -96,12 +96,28 @@ type UserUsageTrendPoint struct { Date string `json:"date"` UserID int64 `json:"user_id"` Email string `json:"email"` + Username string `json:"username"` Requests int64 `json:"requests"` Tokens int64 `json:"tokens"` Cost float64 `json:"cost"` // 标准计费 ActualCost float64 `json:"actual_cost"` // 实际扣除 } +// 
UserSpendingRankingItem represents a user spending ranking row. +type UserSpendingRankingItem struct { + UserID int64 `json:"user_id"` + Email string `json:"email"` + ActualCost float64 `json:"actual_cost"` // 实际扣除 + Requests int64 `json:"requests"` + Tokens int64 `json:"tokens"` +} + +// UserSpendingRankingResponse represents ranking rows plus total spend for the time range. +type UserSpendingRankingResponse struct { + Ranking []UserSpendingRankingItem `json:"ranking"` + TotalActualCost float64 `json:"total_actual_cost"` +} + // APIKeyUsageTrendPoint represents API key usage trend data point type APIKeyUsageTrendPoint struct { Date string `json:"date"` diff --git a/backend/internal/repository/dashboard_aggregation_repo.go b/backend/internal/repository/dashboard_aggregation_repo.go index 59bbd6a3..e82a73a3 100644 --- a/backend/internal/repository/dashboard_aggregation_repo.go +++ b/backend/internal/repository/dashboard_aggregation_repo.go @@ -17,6 +17,9 @@ type dashboardAggregationRepository struct { sql sqlExecutor } +const usageLogsCleanupBatchSize = 10000 +const usageBillingDedupCleanupBatchSize = 10000 + // NewDashboardAggregationRepository 创建仪表盘预聚合仓储。 func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository { if sqlDB == nil { @@ -42,6 +45,9 @@ func isPostgresDriver(db *sql.DB) bool { } func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error { + if r == nil || r.sql == nil { + return nil + } loc := timezone.Location() startLocal := start.In(loc) endLocal := end.In(loc) @@ -61,6 +67,22 @@ func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, sta dayEnd = dayEnd.Add(24 * time.Hour) } + if db, ok := r.sql.(*sql.DB); ok { + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return err + } + txRepo := newDashboardAggregationRepositoryWithSQL(tx) + if err := txRepo.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd); err != nil { + _ = tx.Rollback() 
+ return err + } + return tx.Commit() + } + return r.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd) +} + +func (r *dashboardAggregationRepository) aggregateRangeInTx(ctx context.Context, hourStart, hourEnd, dayStart, dayEnd time.Time) error { // 以桶边界聚合,允许覆盖 end 所在桶的剩余区间。 if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil { return err @@ -195,8 +217,58 @@ func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, c if isPartitioned { return r.dropUsageLogsPartitions(ctx, cutoff) } - _, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC()) - return err + for { + res, err := r.sql.ExecContext(ctx, ` + WITH victims AS ( + SELECT ctid + FROM usage_logs + WHERE created_at < $1 + LIMIT $2 + ) + DELETE FROM usage_logs + WHERE ctid IN (SELECT ctid FROM victims) + `, cutoff.UTC(), usageLogsCleanupBatchSize) + if err != nil { + return err + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if affected < usageLogsCleanupBatchSize { + return nil + } + } +} + +func (r *dashboardAggregationRepository) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error { + for { + res, err := r.sql.ExecContext(ctx, ` + WITH victims AS ( + SELECT ctid, request_id, api_key_id, request_fingerprint, created_at + FROM usage_billing_dedup + WHERE created_at < $1 + LIMIT $2 + ), archived AS ( + INSERT INTO usage_billing_dedup_archive (request_id, api_key_id, request_fingerprint, created_at) + SELECT request_id, api_key_id, request_fingerprint, created_at + FROM victims + ON CONFLICT (request_id, api_key_id) DO NOTHING + ) + DELETE FROM usage_billing_dedup + WHERE ctid IN (SELECT ctid FROM victims) + `, cutoff.UTC(), usageBillingDedupCleanupBatchSize) + if err != nil { + return err + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if affected < usageBillingDedupCleanupBatchSize { + return nil + } + } } func (r 
*dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error { diff --git a/backend/internal/repository/fixtures_integration_test.go b/backend/internal/repository/fixtures_integration_test.go index 23adb4e4..80b9cab6 100644 --- a/backend/internal/repository/fixtures_integration_test.go +++ b/backend/internal/repository/fixtures_integration_test.go @@ -262,6 +262,42 @@ func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.APIKey) *se SetKey(k.Key). SetName(k.Name). SetStatus(k.Status) + if k.Quota != 0 { + create.SetQuota(k.Quota) + } + if k.QuotaUsed != 0 { + create.SetQuotaUsed(k.QuotaUsed) + } + if k.RateLimit5h != 0 { + create.SetRateLimit5h(k.RateLimit5h) + } + if k.RateLimit1d != 0 { + create.SetRateLimit1d(k.RateLimit1d) + } + if k.RateLimit7d != 0 { + create.SetRateLimit7d(k.RateLimit7d) + } + if k.Usage5h != 0 { + create.SetUsage5h(k.Usage5h) + } + if k.Usage1d != 0 { + create.SetUsage1d(k.Usage1d) + } + if k.Usage7d != 0 { + create.SetUsage7d(k.Usage7d) + } + if k.Window5hStart != nil { + create.SetWindow5hStart(*k.Window5hStart) + } + if k.Window1dStart != nil { + create.SetWindow1dStart(*k.Window1dStart) + } + if k.Window7dStart != nil { + create.SetWindow7dStart(*k.Window7dStart) + } + if k.ExpiresAt != nil { + create.SetExpiresAt(*k.ExpiresAt) + } if k.GroupID != nil { create.SetGroupID(*k.GroupID) } diff --git a/backend/internal/repository/migrations_schema_integration_test.go b/backend/internal/repository/migrations_schema_integration_test.go index 72422d18..dd3019bb 100644 --- a/backend/internal/repository/migrations_schema_integration_test.go +++ b/backend/internal/repository/migrations_schema_integration_test.go @@ -45,6 +45,20 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) { requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false) requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false) + // usage_billing_dedup: billing idempotency 
narrow table + var usageBillingDedupRegclass sql.NullString + require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup')").Scan(&usageBillingDedupRegclass)) + require.True(t, usageBillingDedupRegclass.Valid, "expected usage_billing_dedup table to exist") + requireColumn(t, tx, "usage_billing_dedup", "request_fingerprint", "character varying", 64, false) + requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_request_api_key") + requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_created_at_brin") + + var usageBillingDedupArchiveRegclass sql.NullString + require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup_archive')").Scan(&usageBillingDedupArchiveRegclass)) + require.True(t, usageBillingDedupArchiveRegclass.Valid, "expected usage_billing_dedup_archive table to exist") + requireColumn(t, tx, "usage_billing_dedup_archive", "request_fingerprint", "character varying", 64, false) + requireIndex(t, tx, "usage_billing_dedup_archive", "usage_billing_dedup_archive_pkey") + // settings table should exist var settingsRegclass sql.NullString require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass)) @@ -75,6 +89,23 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) { requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false) } +func requireIndex(t *testing.T, tx *sql.Tx, table, index string) { + t.Helper() + + var exists bool + err := tx.QueryRowContext(context.Background(), ` +SELECT EXISTS ( + SELECT 1 + FROM pg_indexes + WHERE schemaname = 'public' + AND tablename = $1 + AND indexname = $2 +) +`, table, index).Scan(&exists) + require.NoError(t, err, "query pg_indexes for %s.%s", table, index) + require.True(t, exists, "expected index %s on %s", index, table) +} + func requireColumn(t *testing.T, tx *sql.Tx, table, 
column, dataType string, maxLen int, nullable bool) { t.Helper() diff --git a/backend/internal/repository/usage_billing_repo.go b/backend/internal/repository/usage_billing_repo.go new file mode 100644 index 00000000..b13cfeb8 --- /dev/null +++ b/backend/internal/repository/usage_billing_repo.go @@ -0,0 +1,308 @@ +package repository + +import ( + "context" + "database/sql" + "errors" + "strings" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +type usageBillingRepository struct { + db *sql.DB +} + +func NewUsageBillingRepository(_ *dbent.Client, sqlDB *sql.DB) service.UsageBillingRepository { + return &usageBillingRepository{db: sqlDB} +} + +func (r *usageBillingRepository) Apply(ctx context.Context, cmd *service.UsageBillingCommand) (_ *service.UsageBillingApplyResult, err error) { + if cmd == nil { + return &service.UsageBillingApplyResult{}, nil + } + if r == nil || r.db == nil { + return nil, errors.New("usage billing repository db is nil") + } + + cmd.Normalize() + if cmd.RequestID == "" { + return nil, service.ErrUsageBillingRequestIDRequired + } + + tx, err := r.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer func() { + if tx != nil { + _ = tx.Rollback() + } + }() + + applied, err := r.claimUsageBillingKey(ctx, tx, cmd) + if err != nil { + return nil, err + } + if !applied { + return &service.UsageBillingApplyResult{Applied: false}, nil + } + + result := &service.UsageBillingApplyResult{Applied: true} + if err := r.applyUsageBillingEffects(ctx, tx, cmd, result); err != nil { + return nil, err + } + + if err := tx.Commit(); err != nil { + return nil, err + } + tx = nil + return result, nil +} + +func (r *usageBillingRepository) claimUsageBillingKey(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand) (bool, error) { + var id int64 + err := tx.QueryRowContext(ctx, ` + INSERT INTO usage_billing_dedup (request_id, api_key_id, 
request_fingerprint) + VALUES ($1, $2, $3) + ON CONFLICT (request_id, api_key_id) DO NOTHING + RETURNING id + `, cmd.RequestID, cmd.APIKeyID, cmd.RequestFingerprint).Scan(&id) + if errors.Is(err, sql.ErrNoRows) { + var existingFingerprint string + if err := tx.QueryRowContext(ctx, ` + SELECT request_fingerprint + FROM usage_billing_dedup + WHERE request_id = $1 AND api_key_id = $2 + `, cmd.RequestID, cmd.APIKeyID).Scan(&existingFingerprint); err != nil { + return false, err + } + if strings.TrimSpace(existingFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) { + return false, service.ErrUsageBillingRequestConflict + } + return false, nil + } + if err != nil { + return false, err + } + var archivedFingerprint string + err = tx.QueryRowContext(ctx, ` + SELECT request_fingerprint + FROM usage_billing_dedup_archive + WHERE request_id = $1 AND api_key_id = $2 + `, cmd.RequestID, cmd.APIKeyID).Scan(&archivedFingerprint) + if err == nil { + if strings.TrimSpace(archivedFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) { + return false, service.ErrUsageBillingRequestConflict + } + return false, nil + } + if !errors.Is(err, sql.ErrNoRows) { + return false, err + } + return true, nil +} + +func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand, result *service.UsageBillingApplyResult) error { + if cmd.SubscriptionCost > 0 && cmd.SubscriptionID != nil { + if err := incrementUsageBillingSubscription(ctx, tx, *cmd.SubscriptionID, cmd.SubscriptionCost); err != nil { + return err + } + } + + if cmd.BalanceCost > 0 { + if err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost); err != nil { + return err + } + } + + if cmd.APIKeyQuotaCost > 0 { + exhausted, err := incrementUsageBillingAPIKeyQuota(ctx, tx, cmd.APIKeyID, cmd.APIKeyQuotaCost) + if err != nil { + return err + } + result.APIKeyQuotaExhausted = exhausted + } + + if cmd.APIKeyRateLimitCost > 0 { + if err := 
incrementUsageBillingAPIKeyRateLimit(ctx, tx, cmd.APIKeyID, cmd.APIKeyRateLimitCost); err != nil { + return err + } + } + + if cmd.AccountQuotaCost > 0 && strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) { + if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil { + return err + } + } + + return nil +} + +func incrementUsageBillingSubscription(ctx context.Context, tx *sql.Tx, subscriptionID int64, costUSD float64) error { + const updateSQL = ` + UPDATE user_subscriptions us + SET + daily_usage_usd = us.daily_usage_usd + $1, + weekly_usage_usd = us.weekly_usage_usd + $1, + monthly_usage_usd = us.monthly_usage_usd + $1, + updated_at = NOW() + FROM groups g + WHERE us.id = $2 + AND us.deleted_at IS NULL + AND us.group_id = g.id + AND g.deleted_at IS NULL + ` + res, err := tx.ExecContext(ctx, updateSQL, costUSD, subscriptionID) + if err != nil { + return err + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if affected > 0 { + return nil + } + return service.ErrSubscriptionNotFound +} + +func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) error { + res, err := tx.ExecContext(ctx, ` + UPDATE users + SET balance = balance - $1, + updated_at = NOW() + WHERE id = $2 AND deleted_at IS NULL + `, amount, userID) + if err != nil { + return err + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if affected > 0 { + return nil + } + return service.ErrUserNotFound +} + +func incrementUsageBillingAPIKeyQuota(ctx context.Context, tx *sql.Tx, apiKeyID int64, amount float64) (bool, error) { + var exhausted bool + err := tx.QueryRowContext(ctx, ` + UPDATE api_keys + SET quota_used = quota_used + $1, + status = CASE + WHEN quota > 0 + AND status = $3 + AND quota_used < quota + AND quota_used + $1 >= quota + THEN $4 + ELSE status + END, + updated_at = NOW() + WHERE id = $2 AND deleted_at IS NULL + RETURNING quota > 0 AND quota_used 
>= quota AND quota_used - $1 < quota + `, amount, apiKeyID, service.StatusAPIKeyActive, service.StatusAPIKeyQuotaExhausted).Scan(&exhausted) + if errors.Is(err, sql.ErrNoRows) { + return false, service.ErrAPIKeyNotFound + } + if err != nil { + return false, err + } + return exhausted, nil +} + +func incrementUsageBillingAPIKeyRateLimit(ctx context.Context, tx *sql.Tx, apiKeyID int64, cost float64) error { + res, err := tx.ExecContext(ctx, ` + UPDATE api_keys SET + usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END, + usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END, + usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END, + window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END, + window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END, + window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END, + updated_at = NOW() + WHERE id = $2 AND deleted_at IS NULL + `, cost, apiKeyID) + if err != nil { + return err + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if affected == 0 { + return service.ErrAPIKeyNotFound + } + return nil +} + +func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) error { + rows, err := tx.QueryContext(ctx, + `UPDATE accounts SET extra = ( + COALESCE(extra, '{}'::jsonb) + || jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1) + || CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN + jsonb_build_object( + 
'quota_daily_used', + CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz) + + '24 hours'::interval <= NOW() + THEN $1 + ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END, + 'quota_daily_start', + CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz) + + '24 hours'::interval <= NOW() + THEN `+nowUTC+` + ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END + ) + ELSE '{}'::jsonb END + || CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN + jsonb_build_object( + 'quota_weekly_used', + CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz) + + '168 hours'::interval <= NOW() + THEN $1 + ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END, + 'quota_weekly_start', + CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz) + + '168 hours'::interval <= NOW() + THEN `+nowUTC+` + ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END + ) + ELSE '{}'::jsonb END + ), updated_at = NOW() + WHERE id = $2 AND deleted_at IS NULL + RETURNING + COALESCE((extra->>'quota_used')::numeric, 0), + COALESCE((extra->>'quota_limit')::numeric, 0)`, + amount, accountID) + if err != nil { + return err + } + defer func() { _ = rows.Close() }() + + var newUsed, limit float64 + if rows.Next() { + if err := rows.Scan(&newUsed, &limit); err != nil { + return err + } + } else { + if err := rows.Err(); err != nil { + return err + } + return service.ErrAccountNotFound + } + if err := rows.Err(); err != nil { + return err + } + if limit > 0 && newUsed >= limit && (newUsed-amount) < limit { + if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil { + logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err) + return err + } + } + return nil +} diff --git 
a/backend/internal/repository/usage_billing_repo_integration_test.go b/backend/internal/repository/usage_billing_repo_integration_test.go new file mode 100644 index 00000000..eda34cc9 --- /dev/null +++ b/backend/internal/repository/usage_billing_repo_integration_test.go @@ -0,0 +1,279 @@ +//go:build integration + +package repository + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +func TestUsageBillingRepositoryApply_DeduplicatesBalanceBilling(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := NewUsageBillingRepository(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("usage-billing-user-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + Balance: 100, + }) + apiKey := mustCreateApiKey(t, client, &service.APIKey{ + UserID: user.ID, + Key: "sk-usage-billing-" + uuid.NewString(), + Name: "billing", + Quota: 1, + }) + account := mustCreateAccount(t, client, &service.Account{ + Name: "usage-billing-account-" + uuid.NewString(), + Type: service.AccountTypeAPIKey, + }) + + requestID := uuid.NewString() + cmd := &service.UsageBillingCommand{ + RequestID: requestID, + APIKeyID: apiKey.ID, + UserID: user.ID, + AccountID: account.ID, + AccountType: service.AccountTypeAPIKey, + BalanceCost: 1.25, + APIKeyQuotaCost: 1.25, + APIKeyRateLimitCost: 1.25, + } + + result1, err := repo.Apply(ctx, cmd) + require.NoError(t, err) + require.NotNil(t, result1) + require.True(t, result1.Applied) + require.True(t, result1.APIKeyQuotaExhausted) + + result2, err := repo.Apply(ctx, cmd) + require.NoError(t, err) + require.NotNil(t, result2) + require.False(t, result2.Applied) + + var balance float64 + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance)) + require.InDelta(t, 98.75, balance, 
0.000001) + + var quotaUsed float64 + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT quota_used FROM api_keys WHERE id = $1", apiKey.ID).Scan(&quotaUsed)) + require.InDelta(t, 1.25, quotaUsed, 0.000001) + + var usage5h float64 + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT usage_5h FROM api_keys WHERE id = $1", apiKey.ID).Scan(&usage5h)) + require.InDelta(t, 1.25, usage5h, 0.000001) + + var status string + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT status FROM api_keys WHERE id = $1", apiKey.ID).Scan(&status)) + require.Equal(t, service.StatusAPIKeyQuotaExhausted, status) + + var dedupCount int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&dedupCount)) + require.Equal(t, 1, dedupCount) +} + +func TestUsageBillingRepositoryApply_DeduplicatesSubscriptionBilling(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := NewUsageBillingRepository(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("usage-billing-sub-user-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + }) + group := mustCreateGroup(t, client, &service.Group{ + Name: "usage-billing-group-" + uuid.NewString(), + Platform: service.PlatformAnthropic, + SubscriptionType: service.SubscriptionTypeSubscription, + }) + apiKey := mustCreateApiKey(t, client, &service.APIKey{ + UserID: user.ID, + GroupID: &group.ID, + Key: "sk-usage-billing-sub-" + uuid.NewString(), + Name: "billing-sub", + }) + subscription := mustCreateSubscription(t, client, &service.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + }) + + requestID := uuid.NewString() + cmd := &service.UsageBillingCommand{ + RequestID: requestID, + APIKeyID: apiKey.ID, + UserID: user.ID, + AccountID: 0, + SubscriptionID: &subscription.ID, + SubscriptionCost: 2.5, + } + + result1, err := 
repo.Apply(ctx, cmd) + require.NoError(t, err) + require.True(t, result1.Applied) + + result2, err := repo.Apply(ctx, cmd) + require.NoError(t, err) + require.False(t, result2.Applied) + + var dailyUsage float64 + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT daily_usage_usd FROM user_subscriptions WHERE id = $1", subscription.ID).Scan(&dailyUsage)) + require.InDelta(t, 2.5, dailyUsage, 0.000001) +} + +func TestUsageBillingRepositoryApply_RequestFingerprintConflict(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := NewUsageBillingRepository(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("usage-billing-conflict-user-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + Balance: 100, + }) + apiKey := mustCreateApiKey(t, client, &service.APIKey{ + UserID: user.ID, + Key: "sk-usage-billing-conflict-" + uuid.NewString(), + Name: "billing-conflict", + }) + + requestID := uuid.NewString() + _, err := repo.Apply(ctx, &service.UsageBillingCommand{ + RequestID: requestID, + APIKeyID: apiKey.ID, + UserID: user.ID, + BalanceCost: 1.25, + }) + require.NoError(t, err) + + _, err = repo.Apply(ctx, &service.UsageBillingCommand{ + RequestID: requestID, + APIKeyID: apiKey.ID, + UserID: user.ID, + BalanceCost: 2.50, + }) + require.ErrorIs(t, err, service.ErrUsageBillingRequestConflict) +} + +func TestUsageBillingRepositoryApply_UpdatesAccountQuota(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := NewUsageBillingRepository(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("usage-billing-account-user-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + }) + apiKey := mustCreateApiKey(t, client, &service.APIKey{ + UserID: user.ID, + Key: "sk-usage-billing-account-" + uuid.NewString(), + Name: "billing-account", + }) + account := mustCreateAccount(t, client, &service.Account{ + 
Name: "usage-billing-account-quota-" + uuid.NewString(), + Type: service.AccountTypeAPIKey, + Extra: map[string]any{ + "quota_limit": 100.0, + }, + }) + + _, err := repo.Apply(ctx, &service.UsageBillingCommand{ + RequestID: uuid.NewString(), + APIKeyID: apiKey.ID, + UserID: user.ID, + AccountID: account.ID, + AccountType: service.AccountTypeAPIKey, + AccountQuotaCost: 3.5, + }) + require.NoError(t, err) + + var quotaUsed float64 + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COALESCE((extra->>'quota_used')::numeric, 0) FROM accounts WHERE id = $1", account.ID).Scan(&quotaUsed)) + require.InDelta(t, 3.5, quotaUsed, 0.000001) +} + +func TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows(t *testing.T) { + ctx := context.Background() + repo := newDashboardAggregationRepositoryWithSQL(integrationDB) + + oldRequestID := "dedup-old-" + uuid.NewString() + newRequestID := "dedup-new-" + uuid.NewString() + oldCreatedAt := time.Now().UTC().AddDate(0, 0, -400) + newCreatedAt := time.Now().UTC().Add(-time.Hour) + + _, err := integrationDB.ExecContext(ctx, ` + INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint, created_at) + VALUES ($1, 1, $2, $3), ($4, 1, $5, $6) + `, + oldRequestID, strings.Repeat("a", 64), oldCreatedAt, + newRequestID, strings.Repeat("b", 64), newCreatedAt, + ) + require.NoError(t, err) + + require.NoError(t, repo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365))) + + var oldCount int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", oldRequestID).Scan(&oldCount)) + require.Equal(t, 0, oldCount) + + var newCount int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", newRequestID).Scan(&newCount)) + require.Equal(t, 1, newCount) + + var archivedCount int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM 
usage_billing_dedup_archive WHERE request_id = $1", oldRequestID).Scan(&archivedCount)) + require.Equal(t, 1, archivedCount) +} + +func TestUsageBillingRepositoryApply_DeduplicatesAgainstArchivedKey(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := NewUsageBillingRepository(client, integrationDB) + aggRepo := newDashboardAggregationRepositoryWithSQL(integrationDB) + + user := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("usage-billing-archive-user-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + Balance: 100, + }) + apiKey := mustCreateApiKey(t, client, &service.APIKey{ + UserID: user.ID, + Key: "sk-usage-billing-archive-" + uuid.NewString(), + Name: "billing-archive", + }) + + requestID := uuid.NewString() + cmd := &service.UsageBillingCommand{ + RequestID: requestID, + APIKeyID: apiKey.ID, + UserID: user.ID, + BalanceCost: 1.25, + } + + result1, err := repo.Apply(ctx, cmd) + require.NoError(t, err) + require.True(t, result1.Applied) + + _, err = integrationDB.ExecContext(ctx, ` + UPDATE usage_billing_dedup + SET created_at = $1 + WHERE request_id = $2 AND api_key_id = $3 + `, time.Now().UTC().AddDate(0, 0, -400), requestID, apiKey.ID) + require.NoError(t, err) + require.NoError(t, aggRepo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365))) + + result2, err := repo.Apply(ctx, cmd) + require.NoError(t, err) + require.False(t, result2.Applied) + + var balance float64 + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance)) + require.InDelta(t, 98.75, balance, 0.000001) +} diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index c91a68e5..845f2cf0 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -3,10 +3,14 @@ package repository import ( "context" "database/sql" + "encoding/json" "errors" "fmt" 
"os" + "strconv" "strings" + "sync" + "sync/atomic" "time" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -15,15 +19,56 @@ import ( dbgroup "github.com/Wei-Shaw/sub2api/ent/group" dbuser "github.com/Wei-Shaw/sub2api/ent/user" dbusersub "github.com/Wei-Shaw/sub2api/ent/usersubscription" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/lib/pq" + gocache "github.com/patrickmn/go-cache" ) const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, cache_ttl_overridden, created_at" +var usageLogInsertArgTypes = [...]string{ + "bigint", + "bigint", + "bigint", + "text", + "text", + "bigint", + "bigint", + "integer", + "integer", + "integer", + "integer", + "integer", + "integer", + "numeric", + "numeric", + "numeric", + "numeric", + "numeric", + "numeric", + "numeric", + "numeric", + "smallint", + "smallint", + "boolean", + "boolean", + "integer", + "integer", + "text", + "text", + "integer", + "text", + "text", + "text", + "text", + "boolean", + "timestamptz", +} + // dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL var dateFormatWhitelist = map[string]string{ "hour": "YYYY-MM-DD HH24:00", @@ -43,15 +88,89 @@ func safeDateFormat(granularity string) string { type usageLogRepository struct { client *dbent.Client sql sqlExecutor + db *sql.DB + + 
createBatchOnce sync.Once + createBatchCh chan usageLogCreateRequest + bestEffortBatchOnce sync.Once + bestEffortBatchCh chan usageLogBestEffortRequest + bestEffortRecent *gocache.Cache } +const ( + usageLogCreateBatchMaxSize = 64 + usageLogCreateBatchWindow = 3 * time.Millisecond + usageLogCreateBatchQueueCap = 4096 + usageLogCreateCancelWait = 2 * time.Second + + usageLogBestEffortBatchMaxSize = 256 + usageLogBestEffortBatchWindow = 20 * time.Millisecond + usageLogBestEffortBatchQueueCap = 32768 + usageLogBestEffortRecentTTL = 30 * time.Second +) + +type usageLogCreateRequest struct { + log *service.UsageLog + prepared usageLogInsertPrepared + shared *usageLogCreateShared + resultCh chan usageLogCreateResult +} + +type usageLogCreateResult struct { + inserted bool + err error +} + +type usageLogBestEffortRequest struct { + prepared usageLogInsertPrepared + apiKeyID int64 + resultCh chan error +} + +type usageLogInsertPrepared struct { + createdAt time.Time + requestID string + rateMultiplier float64 + requestType int16 + args []any +} + +type usageLogBatchState struct { + ID int64 + CreatedAt time.Time +} + +type usageLogBatchRow struct { + RequestID string `json:"request_id"` + APIKeyID int64 `json:"api_key_id"` + ID int64 `json:"id"` + CreatedAt time.Time `json:"created_at"` + Inserted bool `json:"inserted"` +} + +type usageLogCreateShared struct { + state atomic.Int32 +} + +const ( + usageLogCreateStateQueued int32 = iota + usageLogCreateStateProcessing + usageLogCreateStateCompleted + usageLogCreateStateCanceled +) + func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLogRepository { return newUsageLogRepositoryWithSQL(client, sqlDB) } func newUsageLogRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usageLogRepository { // 使用 scanSingleRow 替代 QueryRowContext,保证 ent.Tx 作为 sqlExecutor 可用。 - return &usageLogRepository{client: client, sql: sqlq} + repo := &usageLogRepository{client: client, sql: sqlq} + if db, ok := 
sqlq.(*sql.DB); ok { + repo.db = db + } + repo.bestEffortRecent = gocache.New(usageLogBestEffortRecentTTL, time.Minute) + return repo } // getPerformanceStats 获取 RPM 和 TPM(近5分钟平均值,可选按用户过滤) @@ -82,24 +201,72 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) return false, nil } - // 在事务上下文中,使用 tx 绑定的 ExecQuerier 执行原生 SQL,保证与其他更新同事务。 - // 无事务时回退到默认的 *sql.DB 执行器。 - sqlq := r.sql if tx := dbent.TxFromContext(ctx); tx != nil { - sqlq = tx.Client() + return r.createSingle(ctx, tx.Client(), log) } - - createdAt := log.CreatedAt - if createdAt.IsZero() { - createdAt = time.Now() - } - requestID := strings.TrimSpace(log.RequestID) + if requestID == "" { + return r.createSingle(ctx, r.sql, log) + } log.RequestID = requestID + return r.createBatched(ctx, log) +} - rateMultiplier := log.RateMultiplier - log.SyncRequestTypeAndLegacyFields() - requestType := int16(log.RequestType) +func (r *usageLogRepository) CreateBestEffort(ctx context.Context, log *service.UsageLog) error { + if log == nil { + return nil + } + + if tx := dbent.TxFromContext(ctx); tx != nil { + _, err := r.createSingle(ctx, tx.Client(), log) + return err + } + if r.db == nil { + _, err := r.createSingle(ctx, r.sql, log) + return err + } + + r.ensureBestEffortBatcher() + if r.bestEffortBatchCh == nil { + _, err := r.createSingle(ctx, r.sql, log) + return err + } + + req := usageLogBestEffortRequest{ + prepared: prepareUsageLogInsert(log), + apiKeyID: log.APIKeyID, + resultCh: make(chan error, 1), + } + if key, ok := r.bestEffortRecentKey(req.prepared.requestID, req.apiKeyID); ok { + if _, exists := r.bestEffortRecent.Get(key); exists { + return nil + } + } + + select { + case r.bestEffortBatchCh <- req: + case <-ctx.Done(): + return service.MarkUsageLogCreateDropped(ctx.Err()) + default: + return service.MarkUsageLogCreateDropped(errors.New("usage log best-effort queue full")) + } + + select { + case err := <-req.resultCh: + return err + case <-ctx.Done(): + return 
service.MarkUsageLogCreateDropped(ctx.Err()) + } +} + +func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, log *service.UsageLog) (bool, error) { + prepared := prepareUsageLogInsert(log) + if sqlq == nil { + sqlq = r.sql + } + if ctx != nil && ctx.Err() != nil { + return false, service.MarkUsageLogCreateNotPersisted(ctx.Err()) + } query := ` INSERT INTO usage_logs ( @@ -151,6 +318,779 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) RETURNING id, created_at ` + if err := scanSingleRow(ctx, sqlq, query, prepared.args, &log.ID, &log.CreatedAt); err != nil { + if errors.Is(err, sql.ErrNoRows) && prepared.requestID != "" { + selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2" + if err := scanSingleRow(ctx, sqlq, selectQuery, []any{prepared.requestID, log.APIKeyID}, &log.ID, &log.CreatedAt); err != nil { + return false, err + } + log.RateMultiplier = prepared.rateMultiplier + return false, nil + } else { + return false, err + } + } + log.RateMultiplier = prepared.rateMultiplier + return true, nil +} + +func (r *usageLogRepository) createBatched(ctx context.Context, log *service.UsageLog) (bool, error) { + if r.db == nil { + return r.createSingle(ctx, r.sql, log) + } + r.ensureCreateBatcher() + if r.createBatchCh == nil { + return r.createSingle(ctx, r.sql, log) + } + + req := usageLogCreateRequest{ + log: log, + prepared: prepareUsageLogInsert(log), + shared: &usageLogCreateShared{}, + resultCh: make(chan usageLogCreateResult, 1), + } + + select { + case r.createBatchCh <- req: + case <-ctx.Done(): + return false, service.MarkUsageLogCreateNotPersisted(ctx.Err()) + default: + return false, service.MarkUsageLogCreateNotPersisted(errors.New("usage log create batch queue full")) + } + + select { + case res := <-req.resultCh: + return res.inserted, res.err + case <-ctx.Done(): + if req.shared != nil && req.shared.state.CompareAndSwap(usageLogCreateStateQueued, 
usageLogCreateStateCanceled) { + return false, service.MarkUsageLogCreateNotPersisted(ctx.Err()) + } + timer := time.NewTimer(usageLogCreateCancelWait) + defer timer.Stop() + select { + case res := <-req.resultCh: + return res.inserted, res.err + case <-timer.C: + return false, ctx.Err() + } + } +} + +func (r *usageLogRepository) ensureCreateBatcher() { + if r == nil || r.db == nil || r.createBatchCh != nil { + return + } + r.createBatchOnce.Do(func() { + r.createBatchCh = make(chan usageLogCreateRequest, usageLogCreateBatchQueueCap) + go r.runCreateBatcher(r.db) + }) +} + +func (r *usageLogRepository) ensureBestEffortBatcher() { + if r == nil || r.db == nil || r.bestEffortBatchCh != nil { + return + } + r.bestEffortBatchOnce.Do(func() { + r.bestEffortBatchCh = make(chan usageLogBestEffortRequest, usageLogBestEffortBatchQueueCap) + go r.runBestEffortBatcher(r.db) + }) +} + +func (r *usageLogRepository) runCreateBatcher(db *sql.DB) { + for { + first, ok := <-r.createBatchCh + if !ok { + return + } + + batch := make([]usageLogCreateRequest, 0, usageLogCreateBatchMaxSize) + batch = append(batch, first) + + timer := time.NewTimer(usageLogCreateBatchWindow) + batchLoop: + for len(batch) < usageLogCreateBatchMaxSize { + select { + case req, ok := <-r.createBatchCh: + if !ok { + break batchLoop + } + batch = append(batch, req) + case <-timer.C: + break batchLoop + } + } + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + + r.flushCreateBatch(db, batch) + } +} + +func (r *usageLogRepository) runBestEffortBatcher(db *sql.DB) { + for { + first, ok := <-r.bestEffortBatchCh + if !ok { + return + } + + batch := make([]usageLogBestEffortRequest, 0, usageLogBestEffortBatchMaxSize) + batch = append(batch, first) + + timer := time.NewTimer(usageLogBestEffortBatchWindow) + bestEffortLoop: + for len(batch) < usageLogBestEffortBatchMaxSize { + select { + case req, ok := <-r.bestEffortBatchCh: + if !ok { + break bestEffortLoop + } + batch = append(batch, req) + case 
<-timer.C: + break bestEffortLoop + } + } + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + + r.flushBestEffortBatch(db, batch) + } +} + +func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreateRequest) { + if len(batch) == 0 { + return + } + + uniqueOrder := make([]string, 0, len(batch)) + preparedByKey := make(map[string]usageLogInsertPrepared, len(batch)) + requestsByKey := make(map[string][]usageLogCreateRequest, len(batch)) + fallback := make([]usageLogCreateRequest, 0) + + for _, req := range batch { + if req.log == nil { + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil}) + continue + } + if req.shared != nil && !req.shared.state.CompareAndSwap(usageLogCreateStateQueued, usageLogCreateStateProcessing) { + if req.shared.state.Load() == usageLogCreateStateCanceled { + completeUsageLogCreateRequest(req, usageLogCreateResult{ + inserted: false, + err: service.MarkUsageLogCreateNotPersisted(context.Canceled), + }) + continue + } + } + prepared := req.prepared + if prepared.requestID == "" { + fallback = append(fallback, req) + continue + } + key := usageLogBatchKey(prepared.requestID, req.log.APIKeyID) + if _, exists := requestsByKey[key]; !exists { + uniqueOrder = append(uniqueOrder, key) + preparedByKey[key] = prepared + } + requestsByKey[key] = append(requestsByKey[key], req) + } + + if len(uniqueOrder) > 0 { + insertedMap, stateMap, safeFallback, err := r.batchInsertUsageLogs(db, uniqueOrder, preparedByKey) + if err != nil { + if safeFallback { + for _, key := range uniqueOrder { + fallback = append(fallback, requestsByKey[key]...) 
+ } + } else { + for _, key := range uniqueOrder { + reqs := requestsByKey[key] + state, hasState := stateMap[key] + inserted := insertedMap[key] + for idx, req := range reqs { + req.log.RateMultiplier = preparedByKey[key].rateMultiplier + if hasState { + req.log.ID = state.ID + req.log.CreatedAt = state.CreatedAt + } + switch { + case inserted && idx == 0: + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: true, err: nil}) + case inserted: + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil}) + case hasState: + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil}) + case idx == 0: + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: err}) + default: + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil}) + } + } + } + } + } else { + for _, key := range uniqueOrder { + reqs := requestsByKey[key] + state, ok := stateMap[key] + if !ok { + for _, req := range reqs { + completeUsageLogCreateRequest(req, usageLogCreateResult{ + inserted: false, + err: fmt.Errorf("usage log batch state missing for key=%s", key), + }) + } + continue + } + for idx, req := range reqs { + req.log.ID = state.ID + req.log.CreatedAt = state.CreatedAt + req.log.RateMultiplier = preparedByKey[key].rateMultiplier + completeUsageLogCreateRequest(req, usageLogCreateResult{ + inserted: idx == 0 && insertedMap[key], + err: nil, + }) + } + } + } + } + + if len(fallback) == 0 { + return + } + + for _, req := range fallback { + fallbackCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + inserted, err := r.createSingle(fallbackCtx, db, req.log) + cancel() + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: inserted, err: err}) + } +} + +func (r *usageLogRepository) flushBestEffortBatch(db *sql.DB, batch []usageLogBestEffortRequest) { + if len(batch) == 0 { + return + } + + type bestEffortGroup struct { + prepared 
usageLogInsertPrepared + apiKeyID int64 + key string + reqs []usageLogBestEffortRequest + } + + groupsByKey := make(map[string]*bestEffortGroup, len(batch)) + groupOrder := make([]*bestEffortGroup, 0, len(batch)) + preparedList := make([]usageLogInsertPrepared, 0, len(batch)) + + for idx, req := range batch { + prepared := req.prepared + key := fmt.Sprintf("__best_effort_%d", idx) + if prepared.requestID != "" { + key = usageLogBatchKey(prepared.requestID, req.apiKeyID) + } + group, exists := groupsByKey[key] + if !exists { + group = &bestEffortGroup{ + prepared: prepared, + apiKeyID: req.apiKeyID, + key: key, + } + groupsByKey[key] = group + groupOrder = append(groupOrder, group) + preparedList = append(preparedList, prepared) + } + group.reqs = append(group.reqs, req) + } + + if len(preparedList) == 0 { + for _, req := range batch { + sendUsageLogBestEffortResult(req.resultCh, nil) + } + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + query, args := buildUsageLogBestEffortInsertQuery(preparedList) + if _, err := db.ExecContext(ctx, query, args...); err != nil { + logger.LegacyPrintf("repository.usage_log", "best-effort batch insert failed: %v", err) + for _, group := range groupOrder { + singleErr := execUsageLogInsertNoResult(ctx, db, group.prepared) + if singleErr != nil { + logger.LegacyPrintf("repository.usage_log", "best-effort single fallback insert failed: %v", singleErr) + } else if group.prepared.requestID != "" && r != nil && r.bestEffortRecent != nil { + r.bestEffortRecent.SetDefault(group.key, struct{}{}) + } + for _, req := range group.reqs { + sendUsageLogBestEffortResult(req.resultCh, singleErr) + } + } + return + } + for _, group := range groupOrder { + if group.prepared.requestID != "" && r != nil && r.bestEffortRecent != nil { + r.bestEffortRecent.SetDefault(group.key, struct{}{}) + } + for _, req := range group.reqs { + sendUsageLogBestEffortResult(req.resultCh, nil) + } + } +} + 
+func sendUsageLogBestEffortResult(ch chan error, err error) { + if ch == nil { + return + } + select { + case ch <- err: + default: + } +} + +func completeUsageLogCreateRequest(req usageLogCreateRequest, res usageLogCreateResult) { + if req.shared != nil { + req.shared.state.Store(usageLogCreateStateCompleted) + } + sendUsageLogCreateResult(req.resultCh, res) +} + +func (r *usageLogRepository) batchInsertUsageLogs(db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]bool, map[string]usageLogBatchState, bool, error) { + if len(keys) == 0 { + return map[string]bool{}, map[string]usageLogBatchState{}, false, nil + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + query, args := buildUsageLogBatchInsertQuery(keys, preparedByKey) + var payload []byte + if err := db.QueryRowContext(ctx, query, args...).Scan(&payload); err != nil { + return nil, nil, true, err + } + var rows []usageLogBatchRow + if err := json.Unmarshal(payload, &rows); err != nil { + return nil, nil, false, err + } + insertedMap := make(map[string]bool, len(keys)) + stateMap := make(map[string]usageLogBatchState, len(keys)) + for _, row := range rows { + key := usageLogBatchKey(row.RequestID, row.APIKeyID) + insertedMap[key] = row.Inserted + stateMap[key] = usageLogBatchState{ + ID: row.ID, + CreatedAt: row.CreatedAt, + } + } + if len(stateMap) != len(keys) { + return insertedMap, stateMap, false, fmt.Errorf("usage log batch state count mismatch: got=%d want=%d", len(stateMap), len(keys)) + } + return insertedMap, stateMap, false, nil +} + +func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usageLogInsertPrepared) (string, []any) { + var query strings.Builder + _, _ = query.WriteString(` + WITH input ( + input_idx, + user_id, + api_key_id, + account_id, + request_id, + model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + 
cache_creation_5m_tokens, + cache_creation_1h_tokens, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + media_type, + service_tier, + reasoning_effort, + cache_ttl_overridden, + created_at + ) AS (VALUES `) + + args := make([]any, 0, len(keys)*37) + argPos := 1 + for idx, key := range keys { + if idx > 0 { + _, _ = query.WriteString(",") + } + _, _ = query.WriteString("(") + _, _ = query.WriteString("$") + _, _ = query.WriteString(strconv.Itoa(argPos)) + args = append(args, idx) + argPos++ + prepared := preparedByKey[key] + for i := 0; i < len(prepared.args); i++ { + _, _ = query.WriteString(",") + _, _ = query.WriteString("$") + _, _ = query.WriteString(strconv.Itoa(argPos)) + if i < len(usageLogInsertArgTypes) { + _, _ = query.WriteString("::") + _, _ = query.WriteString(usageLogInsertArgTypes[i]) + } + argPos++ + } + _, _ = query.WriteString(")") + args = append(args, prepared.args...) 
+ } + _, _ = query.WriteString(` + ), + inserted AS ( + INSERT INTO usage_logs ( + user_id, + api_key_id, + account_id, + request_id, + model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + media_type, + service_tier, + reasoning_effort, + cache_ttl_overridden, + created_at + ) + SELECT + user_id, + api_key_id, + account_id, + request_id, + model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + media_type, + service_tier, + reasoning_effort, + cache_ttl_overridden, + created_at + FROM input + ON CONFLICT (request_id, api_key_id) DO NOTHING + RETURNING request_id, api_key_id, id, created_at + ), + resolved AS ( + SELECT + input.input_idx, + input.request_id, + input.api_key_id, + COALESCE(inserted.id, existing.id) AS id, + COALESCE(inserted.created_at, existing.created_at) AS created_at, + (inserted.id IS NOT NULL) AS inserted + FROM input + LEFT JOIN inserted + ON inserted.request_id = input.request_id + AND inserted.api_key_id = input.api_key_id + LEFT JOIN usage_logs existing + ON existing.request_id = input.request_id + AND existing.api_key_id = input.api_key_id + ) + SELECT COALESCE( + json_agg( + json_build_object( + 'request_id', 
resolved.request_id, + 'api_key_id', resolved.api_key_id, + 'id', resolved.id, + 'created_at', resolved.created_at, + 'inserted', resolved.inserted + ) + ORDER BY resolved.input_idx + ), + '[]'::json + ) + FROM resolved + `) + return query.String(), args +} + +func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (string, []any) { + var query strings.Builder + _, _ = query.WriteString(` + WITH input ( + user_id, + api_key_id, + account_id, + request_id, + model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + media_type, + service_tier, + reasoning_effort, + cache_ttl_overridden, + created_at + ) AS (VALUES `) + + args := make([]any, 0, len(preparedList)*36) + argPos := 1 + for idx, prepared := range preparedList { + if idx > 0 { + _, _ = query.WriteString(",") + } + _, _ = query.WriteString("(") + for i := 0; i < len(prepared.args); i++ { + if i > 0 { + _, _ = query.WriteString(",") + } + _, _ = query.WriteString("$") + _, _ = query.WriteString(strconv.Itoa(argPos)) + if i < len(usageLogInsertArgTypes) { + _, _ = query.WriteString("::") + _, _ = query.WriteString(usageLogInsertArgTypes[i]) + } + argPos++ + } + _, _ = query.WriteString(")") + args = append(args, prepared.args...) 
+ } + + _, _ = query.WriteString(` + ) + INSERT INTO usage_logs ( + user_id, + api_key_id, + account_id, + request_id, + model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + media_type, + service_tier, + reasoning_effort, + cache_ttl_overridden, + created_at + ) + SELECT + user_id, + api_key_id, + account_id, + request_id, + model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + media_type, + service_tier, + reasoning_effort, + cache_ttl_overridden, + created_at + FROM input + ON CONFLICT (request_id, api_key_id) DO NOTHING + `) + + return query.String(), args +} + +func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared usageLogInsertPrepared) error { + _, err := sqlq.ExecContext(ctx, ` + INSERT INTO usage_logs ( + user_id, + api_key_id, + account_id, + request_id, + model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + 
openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + media_type, + service_tier, + reasoning_effort, + cache_ttl_overridden, + created_at + ) VALUES ( + $1, $2, $3, $4, $5, + $6, $7, + $8, $9, $10, $11, + $12, $13, + $14, $15, $16, $17, $18, $19, + $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36 + ) + ON CONFLICT (request_id, api_key_id) DO NOTHING + `, prepared.args...) + return err +} + +func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { + createdAt := log.CreatedAt + if createdAt.IsZero() { + createdAt = time.Now() + } + + requestID := strings.TrimSpace(log.RequestID) + log.RequestID = requestID + + rateMultiplier := log.RateMultiplier + log.SyncRequestTypeAndLegacyFields() + requestType := int16(log.RequestType) + groupID := nullInt64(log.GroupID) subscriptionID := nullInt64(log.SubscriptionID) duration := nullInt(log.DurationMs) @@ -167,58 +1107,72 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) requestIDArg = requestID } - args := []any{ - log.UserID, - log.APIKeyID, - log.AccountID, - requestIDArg, - log.Model, - groupID, - subscriptionID, - log.InputTokens, - log.OutputTokens, - log.CacheCreationTokens, - log.CacheReadTokens, - log.CacheCreation5mTokens, - log.CacheCreation1hTokens, - log.InputCost, - log.OutputCost, - log.CacheCreationCost, - log.CacheReadCost, - log.TotalCost, - log.ActualCost, - rateMultiplier, - log.AccountRateMultiplier, - log.BillingType, - requestType, - log.Stream, - log.OpenAIWSMode, - duration, - firstToken, - userAgent, - ipAddress, - log.ImageCount, - imageSize, - mediaType, - serviceTier, - reasoningEffort, - log.CacheTTLOverridden, - createdAt, + return usageLogInsertPrepared{ + createdAt: createdAt, + requestID: requestID, + rateMultiplier: rateMultiplier, + requestType: requestType, + args: []any{ + log.UserID, + log.APIKeyID, + log.AccountID, + requestIDArg, + log.Model, + 
groupID, + subscriptionID, + log.InputTokens, + log.OutputTokens, + log.CacheCreationTokens, + log.CacheReadTokens, + log.CacheCreation5mTokens, + log.CacheCreation1hTokens, + log.InputCost, + log.OutputCost, + log.CacheCreationCost, + log.CacheReadCost, + log.TotalCost, + log.ActualCost, + rateMultiplier, + log.AccountRateMultiplier, + log.BillingType, + requestType, + log.Stream, + log.OpenAIWSMode, + duration, + firstToken, + userAgent, + ipAddress, + log.ImageCount, + imageSize, + mediaType, + serviceTier, + reasoningEffort, + log.CacheTTLOverridden, + createdAt, + }, } - if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil { - if errors.Is(err, sql.ErrNoRows) && requestID != "" { - selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2" - if err := scanSingleRow(ctx, sqlq, selectQuery, []any{requestID, log.APIKeyID}, &log.ID, &log.CreatedAt); err != nil { - return false, err - } - log.RateMultiplier = rateMultiplier - return false, nil - } else { - return false, err - } +} + +func usageLogBatchKey(requestID string, apiKeyID int64) string { + return requestID + "\x1f" + strconv.FormatInt(apiKeyID, 10) +} + +func sendUsageLogCreateResult(ch chan usageLogCreateResult, res usageLogCreateResult) { + if ch == nil { + return } - log.RateMultiplier = rateMultiplier - return true, nil + select { + case ch <- res: + default: + } +} + +func (r *usageLogRepository) bestEffortRecentKey(requestID string, apiKeyID int64) (string, bool) { + requestID = strings.TrimSpace(requestID) + if requestID == "" || r == nil || r.bestEffortRecent == nil { + return "", false + } + return usageLogBatchKey(requestID, apiKeyID), true } func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) { @@ -1039,6 +1993,10 @@ type ModelStat = usagestats.ModelStat // UserUsageTrendPoint represents user usage trend data point type UserUsageTrendPoint = usagestats.UserUsageTrendPoint +// 
UserSpendingRankingItem represents a user spending ranking row. +type UserSpendingRankingItem = usagestats.UserSpendingRankingItem +type UserSpendingRankingResponse = usagestats.UserSpendingRankingResponse + // APIKeyUsageTrendPoint represents API key usage trend data point type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint @@ -1114,6 +2072,7 @@ func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e TO_CHAR(u.created_at, '%s') as date, u.user_id, COALESCE(us.email, '') as email, + COALESCE(us.username, '') as username, COUNT(*) as requests, COALESCE(SUM(u.input_tokens + u.output_tokens + u.cache_creation_tokens + u.cache_read_tokens), 0) as tokens, COALESCE(SUM(u.total_cost), 0) as cost, @@ -1122,7 +2081,7 @@ func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e LEFT JOIN users us ON u.user_id = us.id WHERE u.user_id IN (SELECT user_id FROM top_users) AND u.created_at >= $4 AND u.created_at < $5 - GROUP BY date, u.user_id, us.email + GROUP BY date, u.user_id, us.email, us.username ORDER BY date ASC, tokens DESC `, dateFormat) @@ -1142,7 +2101,7 @@ func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e results = make([]UserUsageTrendPoint, 0) for rows.Next() { var row UserUsageTrendPoint - if err = rows.Scan(&row.Date, &row.UserID, &row.Email, &row.Requests, &row.Tokens, &row.Cost, &row.ActualCost); err != nil { + if err = rows.Scan(&row.Date, &row.UserID, &row.Email, &row.Username, &row.Requests, &row.Tokens, &row.Cost, &row.ActualCost); err != nil { return nil, err } results = append(results, row) @@ -1154,6 +2113,78 @@ func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e return results, nil } +// GetUserSpendingRanking returns user spending ranking aggregated within the time range. 
+func (r *usageLogRepository) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (result *UserSpendingRankingResponse, err error) { + if limit <= 0 { + limit = 12 + } + + query := ` + WITH user_spend AS ( + SELECT + u.user_id, + COALESCE(us.email, '') as email, + COALESCE(SUM(u.actual_cost), 0) as actual_cost, + COUNT(*) as requests, + COALESCE(SUM(u.input_tokens + u.output_tokens + u.cache_creation_tokens + u.cache_read_tokens), 0) as tokens + FROM usage_logs u + LEFT JOIN users us ON u.user_id = us.id + WHERE u.created_at >= $1 AND u.created_at < $2 + GROUP BY u.user_id, us.email + ), + ranked AS ( + SELECT + user_id, + email, + actual_cost, + requests, + tokens, + COALESCE(SUM(actual_cost) OVER (), 0) as total_actual_cost + FROM user_spend + ORDER BY actual_cost DESC, tokens DESC, user_id ASC + LIMIT $3 + ) + SELECT + user_id, + email, + actual_cost, + requests, + tokens, + total_actual_cost + FROM ranked + ORDER BY actual_cost DESC, tokens DESC, user_id ASC + ` + + rows, err := r.sql.QueryContext(ctx, query, startTime, endTime, limit) + if err != nil { + return nil, err + } + defer func() { + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + result = nil + } + }() + + ranking := make([]UserSpendingRankingItem, 0) + totalActualCost := 0.0 + for rows.Next() { + var row UserSpendingRankingItem + if err = rows.Scan(&row.UserID, &row.Email, &row.ActualCost, &row.Requests, &row.Tokens, &totalActualCost); err != nil { + return nil, err + } + ranking = append(ranking, row) + } + if err = rows.Err(); err != nil { + return nil, err + } + + return &UserSpendingRankingResponse{ + Ranking: ranking, + TotalActualCost: totalActualCost, + }, nil +} + // UserDashboardStats 用户仪表盘统计 type UserDashboardStats = usagestats.UserDashboardStats diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go index 4d50f7de..0383f3bc 100644 --- 
a/backend/internal/repository/usage_log_repo_integration_test.go +++ b/backend/internal/repository/usage_log_repo_integration_test.go @@ -4,6 +4,8 @@ package repository import ( "context" + "fmt" + "sync" "testing" "time" @@ -14,6 +16,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) @@ -84,6 +87,367 @@ func (s *UsageLogRepoSuite) TestCreate() { s.Require().NotZero(log.ID) } +func TestUsageLogRepositoryCreate_BatchPathConcurrent(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-" + uuid.NewString()}) + + const total = 16 + results := make([]bool, total) + errs := make([]error, total) + logs := make([]*service.UsageLog, total) + + var wg sync.WaitGroup + wg.Add(total) + for i := 0; i < total; i++ { + i := i + logs[i] = &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.NewString(), + Model: "claude-3", + InputTokens: 10 + i, + OutputTokens: 20 + i, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + } + go func() { + defer wg.Done() + results[i], errs[i] = repo.Create(ctx, logs[i]) + }() + } + wg.Wait() + + for i := 0; i < total; i++ { + require.NoError(t, errs[i]) + require.True(t, results[i]) + require.NotZero(t, logs[i].ID) + } + + var count int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE api_key_id = $1", apiKey.ID).Scan(&count)) + 
require.Equal(t, total, count) +} + +func TestUsageLogRepositoryCreate_BatchPathDuplicateRequestID(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-dup-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-dup-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-dup-" + uuid.NewString()}) + requestID := uuid.NewString() + + log1 := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: requestID, + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + } + log2 := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: requestID, + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + } + + inserted1, err1 := repo.Create(ctx, log1) + inserted2, err2 := repo.Create(ctx, log2) + require.NoError(t, err1) + require.NoError(t, err2) + require.True(t, inserted1) + require.False(t, inserted2) + require.Equal(t, log1.ID, log2.ID) + + var count int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count)) + require.Equal(t, 1, count) +} + +func TestUsageLogRepositoryFlushCreateBatch_DeduplicatesSameKeyInMemory(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-memdup-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: 
user.ID, Key: "sk-usage-batch-memdup-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-memdup-" + uuid.NewString()}) + requestID := uuid.NewString() + + const total = 8 + batch := make([]usageLogCreateRequest, 0, total) + logs := make([]*service.UsageLog, 0, total) + + for i := 0; i < total; i++ { + log := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: requestID, + Model: "claude-3", + InputTokens: 10 + i, + OutputTokens: 20 + i, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + } + logs = append(logs, log) + batch = append(batch, usageLogCreateRequest{ + log: log, + prepared: prepareUsageLogInsert(log), + resultCh: make(chan usageLogCreateResult, 1), + }) + } + + repo.flushCreateBatch(integrationDB, batch) + + insertedCount := 0 + var firstID int64 + for idx, req := range batch { + res := <-req.resultCh + require.NoError(t, res.err) + if res.inserted { + insertedCount++ + } + require.NotZero(t, logs[idx].ID) + if idx == 0 { + firstID = logs[idx].ID + } else { + require.Equal(t, firstID, logs[idx].ID) + } + } + + require.Equal(t, 1, insertedCount) + + var count int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count)) + require.Equal(t, 1, count) +} + +func TestUsageLogRepositoryCreateBestEffort_BatchPathDuplicateRequestID(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-dup-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-dup-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: 
"acc-usage-best-effort-dup-" + uuid.NewString()}) + requestID := uuid.NewString() + + log1 := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: requestID, + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + } + log2 := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: requestID, + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + } + + require.NoError(t, repo.CreateBestEffort(ctx, log1)) + require.NoError(t, repo.CreateBestEffort(ctx, log2)) + + require.Eventually(t, func() bool { + var count int + err := integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count) + return err == nil && count == 1 + }, 3*time.Second, 20*time.Millisecond) +} + +func TestUsageLogRepositoryCreateBestEffort_QueueFullReturnsDropped(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + repo.bestEffortBatchCh = make(chan usageLogBestEffortRequest, 1) + repo.bestEffortBatchCh <- usageLogBestEffortRequest{} + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-full-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-full-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-full-" + uuid.NewString()}) + + err := repo.CreateBestEffort(ctx, &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.NewString(), + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + 
}) + + require.Error(t, err) + require.True(t, service.IsUsageLogCreateDropped(err)) +} + +func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *testing.T) { + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-" + uuid.NewString()}) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + inserted, err := repo.Create(ctx, &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.NewString(), + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + }) + + require.False(t, inserted) + require.Error(t, err) + require.True(t, service.IsUsageLogCreateNotPersisted(err)) +} + +func TestUsageLogRepositoryCreate_BatchPathQueueFullMarksNotPersisted(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + repo.createBatchCh = make(chan usageLogCreateRequest, 1) + repo.createBatchCh <- usageLogCreateRequest{} + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-create-full-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-create-full-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-create-full-" + uuid.NewString()}) + + inserted, err := repo.Create(ctx, &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.NewString(), + Model: "claude-3", + InputTokens: 
10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + }) + + require.False(t, inserted) + require.Error(t, err) + require.True(t, service.IsUsageLogCreateNotPersisted(err)) +} + +func TestUsageLogRepositoryCreate_BatchPathCanceledAfterQueueMarksNotPersisted(t *testing.T) { + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + repo.createBatchCh = make(chan usageLogCreateRequest, 1) + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-queued-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-queued-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-queued-" + uuid.NewString()}) + + ctx, cancel := context.WithCancel(context.Background()) + errCh := make(chan error, 1) + + go func() { + _, err := repo.createBatched(ctx, &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.NewString(), + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + }) + errCh <- err + }() + + req := <-repo.createBatchCh + require.NotNil(t, req.shared) + cancel() + + err := <-errCh + require.Error(t, err) + require.True(t, service.IsUsageLogCreateNotPersisted(err)) + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: service.MarkUsageLogCreateNotPersisted(context.Canceled)}) +} + +func TestUsageLogRepositoryFlushCreateBatch_CanceledRequestReturnsNotPersisted(t *testing.T) { + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-flush-cancel-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: 
"sk-usage-flush-cancel-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-flush-cancel-" + uuid.NewString()}) + + log := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.NewString(), + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + } + req := usageLogCreateRequest{ + log: log, + prepared: prepareUsageLogInsert(log), + shared: &usageLogCreateShared{}, + resultCh: make(chan usageLogCreateResult, 1), + } + req.shared.state.Store(usageLogCreateStateCanceled) + + repo.flushCreateBatch(integrationDB, []usageLogCreateRequest{req}) + + res := <-req.resultCh + require.False(t, res.inserted) + require.Error(t, res.err) + require.True(t, service.IsUsageLogCreateNotPersisted(res.err)) +} + func (s *UsageLogRepoSuite) TestGetByID() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"}) diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go index 7d82b4d0..bcb23717 100644 --- a/backend/internal/repository/usage_log_repo_request_type_test.go +++ b/backend/internal/repository/usage_log_repo_request_type_test.go @@ -248,6 +248,35 @@ func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T) require.NoError(t, mock.ExpectationsWereMet()) } +func TestUsageLogRepositoryGetUserSpendingRanking(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageLogRepository{sql: db} + + start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + + rows := sqlmock.NewRows([]string{"user_id", "email", "actual_cost", "requests", "tokens", "total_actual_cost"}). + AddRow(int64(2), "beta@example.com", 12.5, int64(9), int64(900), 40.0). 
+ AddRow(int64(1), "alpha@example.com", 12.5, int64(8), int64(800), 40.0). + AddRow(int64(3), "gamma@example.com", 4.25, int64(5), int64(300), 40.0) + + mock.ExpectQuery("WITH user_spend AS \\("). + WithArgs(start, end, 12). + WillReturnRows(rows) + + got, err := repo.GetUserSpendingRanking(context.Background(), start, end, 12) + require.NoError(t, err) + require.Equal(t, &usagestats.UserSpendingRankingResponse{ + Ranking: []usagestats.UserSpendingRankingItem{ + {UserID: 2, Email: "beta@example.com", ActualCost: 12.5, Requests: 9, Tokens: 900}, + {UserID: 1, Email: "alpha@example.com", ActualCost: 12.5, Requests: 8, Tokens: 800}, + {UserID: 3, Email: "gamma@example.com", ActualCost: 4.25, Requests: 5, Tokens: 300}, + }, + TotalActualCost: 40.0, + }, got) + require.NoError(t, mock.ExpectationsWereMet()) +} + func TestBuildRequestTypeFilterConditionLegacyFallback(t *testing.T) { tests := []struct { name string diff --git a/backend/internal/repository/usage_log_repo_unit_test.go b/backend/internal/repository/usage_log_repo_unit_test.go index d0e14ffd..0458902d 100644 --- a/backend/internal/repository/usage_log_repo_unit_test.go +++ b/backend/internal/repository/usage_log_repo_unit_test.go @@ -3,8 +3,11 @@ package repository import ( + "strings" "testing" + "time" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/require" ) @@ -39,3 +42,26 @@ func TestSafeDateFormat(t *testing.T) { }) } } + +func TestBuildUsageLogBatchInsertQuery_UsesConflictDoNothing(t *testing.T) { + log := &service.UsageLog{ + UserID: 1, + APIKeyID: 2, + AccountID: 3, + RequestID: "req-batch-no-update", + Model: "gpt-5", + InputTokens: 10, + OutputTokens: 5, + TotalCost: 1.2, + ActualCost: 1.2, + CreatedAt: time.Now().UTC(), + } + prepared := prepareUsageLogInsert(log) + + query, _ := buildUsageLogBatchInsertQuery([]string{usageLogBatchKey(log.RequestID, log.APIKeyID)}, map[string]usageLogInsertPrepared{ + usageLogBatchKey(log.RequestID, log.APIKeyID): prepared, + }) + + 
require.Contains(t, query, "ON CONFLICT (request_id, api_key_id) DO NOTHING") + require.NotContains(t, strings.ToUpper(query), "DO UPDATE") +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 5fe7a98e..01395bcb 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -62,6 +62,7 @@ var ProviderSet = wire.NewSet( NewAnnouncementRepository, NewAnnouncementReadRepository, NewUsageLogRepository, + NewUsageBillingRepository, NewIdempotencyRepository, NewUsageCleanupRepository, NewDashboardAggregationRepository, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index a1ce896e..d46e0624 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -1635,6 +1635,10 @@ func (r *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, end return nil, errors.New("not implemented") } +func (r *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) { + return nil, errors.New("not implemented") +} + func (r *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { logs := r.userLogs[userID] if len(logs) == 0 { diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 46c2ccde..b6437bda 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -192,6 +192,7 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) { dashboard.GET("/groups", h.Admin.Dashboard.GetGroupStats) dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetAPIKeyUsageTrend) dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend) + dashboard.GET("/users-ranking", h.Admin.Dashboard.GetUserSpendingRanking) 
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage) dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage) dashboard.POST("/aggregation/backfill", h.Admin.Dashboard.BackfillAggregation) diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 91c85196..7c858fd5 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -412,6 +412,7 @@ func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]stri if a.Platform == domain.PlatformAntigravity { return domain.DefaultAntigravityModelMapping } + // Bedrock 默认映射由 forwardBedrock 统一处理(需配合 region prefix 调整) return nil } if len(rawMapping) == 0 { @@ -776,6 +777,14 @@ func (a *Account) IsInterceptWarmupEnabled() bool { return false } +func (a *Account) IsBedrock() bool { + return a.Platform == PlatformAnthropic && (a.Type == AccountTypeBedrock || a.Type == AccountTypeBedrockAPIKey) +} + +func (a *Account) IsBedrockAPIKey() bool { + return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrockAPIKey +} + func (a *Account) IsOpenAI() bool { return a.Platform == PlatformOpenAI } diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 472551cf..482d22b1 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -207,14 +207,14 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account testModelID = claude.DefaultTestModel } - // For API Key accounts with model mapping, map the model + // API Key 账号测试连接时也需要应用通配符模型映射。 if account.Type == "apikey" { - mapping := account.GetModelMapping() - if len(mapping) > 0 { - if mappedModel, exists := mapping[testModelID]; exists { - testModelID = mappedModel - } - } + testModelID = account.GetMappedModel(testModelID) + } + + // Bedrock accounts use a separate test path + if account.IsBedrock() { + return 
s.testBedrockAccountConnection(c, ctx, account, testModelID) } // Determine authentication method and API URL @@ -312,6 +312,109 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account return s.processClaudeStream(c, resp.Body) } +// testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke +func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error { + region := bedrockRuntimeRegion(account) + resolvedModelID, ok := ResolveBedrockModelID(account, testModelID) + if !ok { + return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported Bedrock model: %s", testModelID)) + } + testModelID = resolvedModelID + + // Set SSE headers (test UI expects SSE) + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.Flush() + + // Create a minimal Bedrock-compatible payload (no stream, no cache_control) + bedrockPayload := map[string]any{ + "anthropic_version": "bedrock-2023-05-31", + "messages": []map[string]any{ + { + "role": "user", + "content": []map[string]any{ + { + "type": "text", + "text": "hi", + }, + }, + }, + }, + "max_tokens": 256, + "temperature": 1, + } + bedrockBody, _ := json.Marshal(bedrockPayload) + + // Use non-streaming endpoint (response is standard Claude JSON) + apiURL := BuildBedrockURL(region, testModelID, false) + + s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID}) + + req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(bedrockBody)) + if err != nil { + return s.sendErrorAndEnd(c, "Failed to create request") + } + req.Header.Set("Content-Type", "application/json") + + // Sign or set auth based on account type + if account.IsBedrockAPIKey() { + apiKey := account.GetCredential("api_key") + if apiKey == "" { + 
return s.sendErrorAndEnd(c, "No API key available") + } + req.Header.Set("Authorization", "Bearer "+apiKey) + } else { + signer, err := NewBedrockSignerFromAccount(account) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create Bedrock signer: %s", err.Error())) + } + if err := signer.SignRequest(ctx, req, bedrockBody); err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to sign request: %s", err.Error())) + } + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, false) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) + } + defer func() { _ = resp.Body.Close() }() + + body, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))) + } + + // Bedrock non-streaming response is standard Claude JSON, extract the text + var result struct { + Content []struct { + Text string `json:"text"` + } `json:"content"` + } + if err := json.Unmarshal(body, &result); err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to parse response: %s", err.Error())) + } + + text := "" + if len(result.Content) > 0 { + text = result.Content[0].Text + } + if text == "" { + text = "(empty response)" + } + + s.sendEvent(c, TestEvent{Type: "content", Text: text}) + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil +} + // testOpenAIAccountConnection tests an OpenAI account's connection func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string) error { ctx := c.Request.Context() diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index e4245133..3dd931be 100644 --- a/backend/internal/service/account_usage_service.go +++ 
b/backend/internal/service/account_usage_service.go @@ -47,6 +47,7 @@ type UsageLogRepository interface { GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) + GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) diff --git a/backend/internal/service/bedrock_request.go b/backend/internal/service/bedrock_request.go new file mode 100644 index 00000000..2160c13c --- /dev/null +++ b/backend/internal/service/bedrock_request.go @@ -0,0 +1,607 @@ +package service + +import ( + "encoding/json" + "fmt" + "net/url" + "regexp" + "strconv" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/domain" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +const defaultBedrockRegion = "us-east-1" + +var bedrockCrossRegionPrefixes = []string{"us.", "eu.", "apac.", "jp.", "au.", "us-gov.", "global."} + +// BedrockCrossRegionPrefix 根据 AWS Region 返回 Bedrock 跨区域推理的模型 ID 前缀 +// 参考: https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html +func BedrockCrossRegionPrefix(region string) string { + switch { + case strings.HasPrefix(region, "us-gov"): + return "us-gov" // GovCloud 使用独立的 us-gov 前缀 + case strings.HasPrefix(region, "us-"): + return "us" + case 
strings.HasPrefix(region, "eu-"): + return "eu" + case region == "ap-northeast-1": + return "jp" // 日本区域使用独立的 jp 前缀(AWS 官方定义) + case region == "ap-southeast-2": + return "au" // 澳大利亚区域使用独立的 au 前缀(AWS 官方定义) + case strings.HasPrefix(region, "ap-"): + return "apac" // 其余亚太区域使用通用 apac 前缀 + case strings.HasPrefix(region, "ca-"): + return "us" // 加拿大区域使用 us 前缀的跨区域推理 + case strings.HasPrefix(region, "sa-"): + return "us" // 南美区域使用 us 前缀的跨区域推理 + default: + return "us" + } +} + +// AdjustBedrockModelRegionPrefix 将模型 ID 的区域前缀替换为与当前 AWS Region 匹配的前缀 +// 例如 region=eu-west-1 时,"us.anthropic.claude-opus-4-6-v1" → "eu.anthropic.claude-opus-4-6-v1" +// 特殊值 region="global" 强制使用 global. 前缀 +func AdjustBedrockModelRegionPrefix(modelID, region string) string { + var targetPrefix string + if region == "global" { + targetPrefix = "global" + } else { + targetPrefix = BedrockCrossRegionPrefix(region) + } + + for _, p := range bedrockCrossRegionPrefixes { + if strings.HasPrefix(modelID, p) { + if p == targetPrefix+"." { + return modelID // 前缀已匹配,无需替换 + } + return targetPrefix + "." 
+ modelID[len(p):] + } + } + + // 模型 ID 没有已知区域前缀(如 "anthropic.claude-..."),不做修改 + return modelID +} + +func bedrockRuntimeRegion(account *Account) string { + if account == nil { + return defaultBedrockRegion + } + if region := account.GetCredential("aws_region"); region != "" { + return region + } + return defaultBedrockRegion +} + +func shouldForceBedrockGlobal(account *Account) bool { + return account != nil && account.GetCredential("aws_force_global") == "true" +} + +func isRegionalBedrockModelID(modelID string) bool { + for _, prefix := range bedrockCrossRegionPrefixes { + if strings.HasPrefix(modelID, prefix) { + return true + } + } + return false +} + +func isLikelyBedrockModelID(modelID string) bool { + lower := strings.ToLower(strings.TrimSpace(modelID)) + if lower == "" { + return false + } + if strings.HasPrefix(lower, "arn:") { + return true + } + for _, prefix := range []string{ + "anthropic.", + "amazon.", + "meta.", + "mistral.", + "cohere.", + "ai21.", + "deepseek.", + "stability.", + "writer.", + "nova.", + } { + if strings.HasPrefix(lower, prefix) { + return true + } + } + return isRegionalBedrockModelID(lower) +} + +func normalizeBedrockModelID(modelID string) (normalized string, shouldAdjustRegion bool, ok bool) { + modelID = strings.TrimSpace(modelID) + if modelID == "" { + return "", false, false + } + if mapped, exists := domain.DefaultBedrockModelMapping[modelID]; exists { + return mapped, true, true + } + if isRegionalBedrockModelID(modelID) { + return modelID, true, true + } + if isLikelyBedrockModelID(modelID) { + return modelID, false, true + } + return "", false, false +} + +// ResolveBedrockModelID resolves a requested Claude model into a Bedrock model ID. +// It applies account model_mapping first, then default Bedrock aliases, and finally +// adjusts Anthropic cross-region prefixes to match the account region. 
+func ResolveBedrockModelID(account *Account, requestedModel string) (string, bool) { + if account == nil { + return "", false + } + + mappedModel := account.GetMappedModel(requestedModel) + modelID, shouldAdjustRegion, ok := normalizeBedrockModelID(mappedModel) + if !ok { + return "", false + } + if shouldAdjustRegion { + targetRegion := bedrockRuntimeRegion(account) + if shouldForceBedrockGlobal(account) { + targetRegion = "global" + } + modelID = AdjustBedrockModelRegionPrefix(modelID, targetRegion) + } + return modelID, true +} + +// BuildBedrockURL 构建 Bedrock InvokeModel 的 URL +// stream=true 时使用 invoke-with-response-stream 端点 +// modelID 中的特殊字符会被 URL 编码(与 litellm 的 urllib.parse.quote(safe="") 对齐) +func BuildBedrockURL(region, modelID string, stream bool) string { + if region == "" { + region = defaultBedrockRegion + } + encodedModelID := url.PathEscape(modelID) + // url.PathEscape 不编码冒号(RFC 允许 path 中出现 ":"), + // 但 AWS Bedrock 期望模型 ID 中的冒号被编码为 %3A + encodedModelID = strings.ReplaceAll(encodedModelID, ":", "%3A") + if stream { + return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke-with-response-stream", region, encodedModelID) + } + return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke", region, encodedModelID) +} + +// PrepareBedrockRequestBody 处理请求体以适配 Bedrock API +// 1. 注入 anthropic_version +// 2. 注入 anthropic_beta(从客户端 anthropic-beta 头解析) +// 3. 移除 Bedrock 不支持的字段(model, stream, output_format, output_config) +// 4. 移除工具定义中的 custom 字段(Claude Code 会发送 custom: {defer_loading: true}) +// 5. 清理 cache_control 中 Bedrock 不支持的字段(scope, ttl) +func PrepareBedrockRequestBody(body []byte, modelID string, betaHeader string) ([]byte, error) { + betaTokens := ResolveBedrockBetaTokens(betaHeader, body, modelID) + return PrepareBedrockRequestBodyWithTokens(body, modelID, betaTokens) +} + +// PrepareBedrockRequestBodyWithTokens prepares a Bedrock request using pre-resolved beta tokens. 
+func PrepareBedrockRequestBodyWithTokens(body []byte, modelID string, betaTokens []string) ([]byte, error) { + var err error + + // 注入 anthropic_version(Bedrock 要求) + body, err = sjson.SetBytes(body, "anthropic_version", "bedrock-2023-05-31") + if err != nil { + return nil, fmt.Errorf("inject anthropic_version: %w", err) + } + + // 注入 anthropic_beta(Bedrock Invoke 通过请求体传递 beta 头,而非 HTTP 头) + // 1. 从客户端 anthropic-beta header 解析 + // 2. 根据请求体内容自动补齐必要的 beta token + // 参考 litellm: AnthropicModelInfo.get_anthropic_beta_list() + _get_tool_search_beta_header_for_bedrock() + if len(betaTokens) > 0 { + body, err = sjson.SetBytes(body, "anthropic_beta", betaTokens) + if err != nil { + return nil, fmt.Errorf("inject anthropic_beta: %w", err) + } + } + + // 移除 model 字段(Bedrock 通过 URL 指定模型) + body, err = sjson.DeleteBytes(body, "model") + if err != nil { + return nil, fmt.Errorf("remove model field: %w", err) + } + + // 移除 stream 字段(Bedrock 通过不同端点控制流式,不接受请求体中的 stream 字段) + body, err = sjson.DeleteBytes(body, "stream") + if err != nil { + return nil, fmt.Errorf("remove stream field: %w", err) + } + + // 转换 output_format(Bedrock Invoke 不支持此字段,但可将 schema 内联到最后一条 user message) + // 参考 litellm: _convert_output_format_to_inline_schema() + body = convertOutputFormatToInlineSchema(body) + + // 移除 output_config 字段(Bedrock Invoke 不支持) + body, err = sjson.DeleteBytes(body, "output_config") + if err != nil { + return nil, fmt.Errorf("remove output_config field: %w", err) + } + + // 移除工具定义中的 custom 字段 + // Claude Code (v2.1.69+) 在 tool 定义中发送 custom: {defer_loading: true}, + // Anthropic API 接受但 Bedrock 会拒绝并报 "Extra inputs are not permitted" + body = removeCustomFieldFromTools(body) + + // 清理 cache_control 中 Bedrock 不支持的字段 + body = sanitizeBedrockCacheControl(body, modelID) + + return body, nil +} + +// ResolveBedrockBetaTokens computes the final Bedrock beta token list before policy filtering. 
func ResolveBedrockBetaTokens(betaHeader string, body []byte, modelID string) []string {
	// Parse client header tokens, add tokens implied by the request body,
	// then translate/filter down to what Bedrock Invoke accepts.
	betaTokens := parseAnthropicBetaHeader(betaHeader)
	betaTokens = autoInjectBedrockBetaTokens(betaTokens, body, modelID)
	return filterBedrockBetaTokens(betaTokens)
}

// convertOutputFormatToInlineSchema inlines the JSON schema from output_format
// into the last user message. Bedrock Invoke does not support the
// output_format parameter; litellm's approach (which this mirrors) is to
// append the schema as a text block to the user message.
// Reference: litellm AmazonAnthropicClaudeMessagesConfig._convert_output_format_to_inline_schema()
func convertOutputFormatToInlineSchema(body []byte) []byte {
	outputFormat := gjson.GetBytes(body, "output_format")
	if !outputFormat.Exists() || !outputFormat.IsObject() {
		return body
	}

	// Remove output_format from the body first; it is dropped even when the
	// schema cannot be inlined below (no schema / no messages / no user turn).
	body, _ = sjson.DeleteBytes(body, "output_format")

	schema := outputFormat.Get("schema")
	if !schema.Exists() {
		return body
	}

	// Locate the last user message.
	messages := gjson.GetBytes(body, "messages")
	if !messages.Exists() || !messages.IsArray() {
		return body
	}
	msgArr := messages.Array()
	lastUserIdx := -1
	for i := len(msgArr) - 1; i >= 0; i-- {
		if msgArr[i].Get("role").String() == "user" {
			lastUserIdx = i
			break
		}
	}
	if lastUserIdx < 0 {
		return body
	}

	// Re-marshal the raw schema (compact JSON) to append as message text.
	schemaJSON, err := json.Marshal(json.RawMessage(schema.Raw))
	if err != nil {
		return body
	}

	content := msgArr[lastUserIdx].Get("content")
	basePath := fmt.Sprintf("messages.%d.content", lastUserIdx)

	if content.IsArray() {
		// Append a text block at the end of the existing content array.
		idx := len(content.Array())
		body, _ = sjson.SetBytes(body, fmt.Sprintf("%s.%d.type", basePath, idx), "text")
		body, _ = sjson.SetBytes(body, fmt.Sprintf("%s.%d.text", basePath, idx), string(schemaJSON))
	} else if content.Type == gjson.String {
		// Plain-string content: convert to array form, original text first.
		originalText := content.String()
		body, _ = sjson.SetBytes(body, basePath, []map[string]string{
			{"type": "text", "text": originalText},
			{"type": "text", "text": string(schemaJSON)},
		})
	}

	return body
}

// removeCustomFieldFromTools removes the "custom" field from every tool
// definition in the tools array (Bedrock rejects it; see
// PrepareBedrockRequestBodyWithTokens).
func removeCustomFieldFromTools(body []byte) []byte {
	tools := gjson.GetBytes(body, "tools")
	if !tools.Exists() || !tools.IsArray() {
		return body
	}
	var err error
	for i := range tools.Array() {
		body, err = sjson.DeleteBytes(body, fmt.Sprintf("tools.%d.custom", i))
		if err != nil {
			// A failed delete must not abort the whole loop; skip this tool.
			// NOTE(review): the trailing `continue` is a no-op here — it only
			// documents the intent to keep going.
			continue
		}
	}
	return body
}

// claudeVersionRe matches the version portion of a Claude model ID.
// Supports both claude-{tier}-{major}-{minor} and claude-{tier}-{major}.{minor}.
var claudeVersionRe = regexp.MustCompile(`claude-(?:haiku|sonnet|opus)-(\d+)[-.](\d+)`)

// isBedrockClaude45OrNewer reports whether the Bedrock model ID is Claude 4.5
// or newer. Claude 4.5+ supports the ttl field ("5m" / "1h") in cache_control.
func isBedrockClaude45OrNewer(modelID string) bool {
	lower := strings.ToLower(modelID)
	matches := claudeVersionRe.FindStringSubmatch(lower)
	if matches == nil {
		// Not a recognizable Claude version — treat as unsupported.
		return false
	}
	major, _ := strconv.Atoi(matches[1])
	minor, _ := strconv.Atoi(matches[2])
	return major > 4 || (major == 4 && minor >= 5)
}

// sanitizeBedrockCacheControl strips fields inside cache_control (in both the
// system array and message content blocks) that Bedrock does not support:
//   - scope: never supported (e.g. "global" cross-request caching)
//   - ttl:   only Claude 4.5+ accepts "5m" and "1h"; removed otherwise
func sanitizeBedrockCacheControl(body []byte, modelID string) []byte {
	isClaude45 := isBedrockClaude45OrNewer(modelID)

	// Clean cache_control entries in the system array.
	systemArr := gjson.GetBytes(body, "system")
	if systemArr.Exists() && systemArr.IsArray() {
		for i, item := range systemArr.Array() {
			if !item.IsObject() {
				continue
			}
			cc := item.Get("cache_control")
			if !cc.Exists() || !cc.IsObject() {
				continue
			}
			body = deleteCacheControlUnsupportedFields(body, fmt.Sprintf("system.%d.cache_control", i), cc, isClaude45)
		}
	}

	// Clean cache_control entries in message content blocks.
	messages := gjson.GetBytes(body, "messages")
	if !messages.Exists() || !messages.IsArray() {
		return body
	}
	for mi, msg := range messages.Array() {
		if !msg.IsObject() {
			continue
		}
		content := msg.Get("content")
		if !content.Exists() || !content.IsArray() {
			continue
		}
		for ci, block := range content.Array() {
			if !block.IsObject() {
				continue
			}
			cc := block.Get("cache_control")
			if !cc.Exists() || !cc.IsObject() {
				continue
			}
			body = deleteCacheControlUnsupportedFields(body, fmt.Sprintf("messages.%d.content.%d.cache_control", mi, ci), cc, isClaude45)
		}
	}

	return body
}

// deleteCacheControlUnsupportedFields removes Bedrock-unsupported fields under
// the cache_control object at basePath. cc is the already-fetched object;
// isClaude45 gates whether a legal ttl ("5m"/"1h") may be kept.
func deleteCacheControlUnsupportedFields(body []byte, basePath string, cc gjson.Result, isClaude45 bool) []byte {
	// Bedrock never supports scope (e.g. "global").
	if cc.Get("scope").Exists() {
		body, _ = sjson.DeleteBytes(body, basePath+".scope")
	}

	// ttl: only Claude 4.5+ accepts "5m" and "1h"; anything else is removed.
	ttl := cc.Get("ttl")
	if ttl.Exists() {
		shouldRemove := true
		if isClaude45 {
			v := ttl.String()
			if v == "5m" || v == "1h" {
				shouldRemove = false
			}
		}
		if shouldRemove {
			body, _ = sjson.DeleteBytes(body, basePath+".ttl")
		}
	}

	return body
}

// parseAnthropicBetaHeader parses an anthropic-beta header value into a token
// list. Accepts either a comma-separated string or a JSON array of strings;
// blank entries are dropped and surrounding whitespace is trimmed.
func parseAnthropicBetaHeader(header string) []string {
	header = strings.TrimSpace(header)
	if header == "" {
		return nil
	}
	// JSON-array form, e.g. `["a","b"]`; falls through to CSV parsing on error.
	if strings.HasPrefix(header, "[") && strings.HasSuffix(header, "]") {
		var parsed []any
		if err := json.Unmarshal([]byte(header), &parsed); err == nil {
			tokens := make([]string, 0, len(parsed))
			for _, item := range parsed {
				token := strings.TrimSpace(fmt.Sprint(item))
				if token != "" {
					tokens = append(tokens, token)
				}
			}
			return tokens
		}
	}
	// Comma-separated form.
	var tokens []string
	for _, part := range strings.Split(header, ",") {
		t := strings.TrimSpace(part)
		if t != "" {
			tokens = append(tokens, t)
		}
	}
	return tokens
}

// bedrockSupportedBetaTokens is the whitelist of beta headers accepted by
// Bedrock Invoke.
// Reference: litellm/litellm/llms/bedrock/common_utils.py (anthropic_beta_headers_config.json)
// Update policy: keep this whitelist in sync when AWS Bedrock adds support
// for new beta tokens.
var bedrockSupportedBetaTokens = map[string]bool{
	"computer-use-2025-01-24":         true,
	"computer-use-2025-11-24":         true,
	"context-1m-2025-08-07":           true,
	"context-management-2025-06-27":   true,
	"compact-2026-01-12":              true,
	"interleaved-thinking-2025-05-14": true,
	"tool-search-tool-2025-10-19":     true,
	"tool-examples-2025-10-29":        true,
}

// bedrockBetaTokenTransforms maps generic Anthropic beta headers to the
// Bedrock-Invoke-specific replacements it requires.
var bedrockBetaTokenTransforms = map[string]string{
	"advanced-tool-use-2025-11-20": "tool-search-tool-2025-10-19",
}

// autoInjectBedrockBetaTokens appends beta tokens implied by the request body.
// Reference: litellm AnthropicModelInfo.get_anthropic_beta_list() and
// AmazonAnthropicClaudeMessagesConfig._get_tool_search_beta_header_for_bedrock()
//
// Clients (especially non-Claude-Code ones) may enable a feature in the body
// without sending the matching beta header; detecting body features here keeps
// Bedrock Invoke from returning 400 for a missing beta flag.
func autoInjectBedrockBetaTokens(tokens []string, body []byte, modelID string) []string {
	seen := make(map[string]bool, len(tokens))
	for _, t := range tokens {
		seen[t] = true
	}

	// inject appends token once, preserving the order of first appearance.
	inject := func(token string) {
		if !seen[token] {
			tokens = append(tokens, token)
			seen[token] = true
		}
	}

	// A "thinking" field in the body implies the interleaved-thinking beta.
	if gjson.GetBytes(body, "thinking").Exists() {
		inject("interleaved-thinking-2025-05-14")
	}

	// Scan tools for computer-use, tool-search, programmatic tool calling and
	// input-example features.
	tools := gjson.GetBytes(body, "tools")
	if tools.Exists() && tools.IsArray() {
		toolSearchUsed := false
		programmaticToolCallingUsed := false
		inputExamplesUsed := false
		for _, tool := range tools.Array() {
			toolType := tool.Get("type").String()
			// A tool of type "computer_20xxxxxx" needs the computer-use beta.
			if strings.HasPrefix(toolType, "computer_20") {
				inject("computer-use-2025-11-24")
			}
			if isBedrockToolSearchType(toolType) {
				toolSearchUsed = true
			}
			if hasCodeExecutionAllowedCallers(tool) {
				programmaticToolCallingUsed = true
			}
			if hasInputExamples(tool) {
				inputExamplesUsed = true
			}
		}
		if programmaticToolCallingUsed || inputExamplesUsed {
			// Programmatic tool calling and input examples need
			// advanced-tool-use; filterBedrockBetaTokens later converts it to
			// the Bedrock-specific tool-search-tool token.
			inject("advanced-tool-use-2025-11-20")
		}
		if toolSearchUsed && bedrockModelSupportsToolSearch(modelID) {
			// Pure tool search (no programmatic/inputExamples): inject the
			// Bedrock-specific header directly, skipping the
			// advanced-tool-use → tool-search-tool transform (litellm parity).
			if !programmaticToolCallingUsed && !inputExamplesUsed {
				inject("tool-search-tool-2025-10-19")
			} else {
				inject("advanced-tool-use-2025-11-20")
			}
		}
	}

	return tokens
}

// isBedrockToolSearchType reports whether toolType is one of the tool-search
// tool variants (regex or bm25).
func isBedrockToolSearchType(toolType string) bool {
	return toolType == "tool_search_tool_regex_20251119" || toolType == "tool_search_tool_bm25_20251119"
}

// hasCodeExecutionAllowedCallers reports whether the tool lists the code
// execution caller in allowed_callers (top-level or under "function").
func hasCodeExecutionAllowedCallers(tool gjson.Result) bool {
	allowedCallers := tool.Get("allowed_callers")
	if containsStringInJSONArray(allowedCallers, "code_execution_20250825") {
		return true
	}
	return containsStringInJSONArray(tool.Get("function.allowed_callers"), "code_execution_20250825")
}

// hasInputExamples reports whether the tool carries a non-empty input_examples
// array (top-level or under "function").
func hasInputExamples(tool gjson.Result) bool {
	if arr := tool.Get("input_examples"); arr.Exists() && arr.IsArray() && len(arr.Array()) > 0 {
		return true
	}
	arr := tool.Get("function.input_examples")
	return arr.Exists() && arr.IsArray() && len(arr.Array()) > 0
}

// containsStringInJSONArray reports whether the JSON array result contains an
// element whose string value equals target.
func containsStringInJSONArray(result gjson.Result, target string) bool {
	if !result.Exists() || !result.IsArray() {
		return false
	}
	for _, item := range result.Array() {
		if item.String() == target {
			return true
		}
	}
	return false
}

// bedrockModelSupportsToolSearch reports whether the Bedrock model supports
// tool search. Currently only Claude Opus/Sonnet 4.5+; Haiku does not.
func bedrockModelSupportsToolSearch(modelID string) bool {
	lower := strings.ToLower(modelID)
	matches := claudeVersionRe.FindStringSubmatch(lower)
	if matches == nil {
		return false
	}
	// Haiku never supports tool search regardless of version.
	if strings.Contains(lower, "haiku") {
		return false
	}
	major, _ := strconv.Atoi(matches[1])
	minor, _ := strconv.Atoi(matches[2])
	return major > 4 || (major == 4 && minor >= 5)
}

// filterBedrockBetaTokens filters and converts the beta token list down to
// what Bedrock Invoke supports:
//  1. apply transforms (e.g. advanced-tool-use → tool-search-tool)
//  2. drop tokens Bedrock rejects (output-128k, files-api, structured-outputs, …)
//  3. auto-associate tool-examples whenever tool-search-tool is present
func filterBedrockBetaTokens(tokens []string) []string {
	seen := make(map[string]bool, len(tokens))
	var result []string

	for _, t := range tokens {
		// Apply Bedrock-specific transform rules.
		if replacement, ok := bedrockBetaTokenTransforms[t]; ok {
			t = replacement
		}
		// Keep only whitelisted tokens, deduplicated.
		if bedrockSupportedBetaTokens[t] && !seen[t] {
			result = append(result, t)
			seen[t] = true
		}
	}

	// Auto-association: tool-search-tool implies tool-examples.
	if seen["tool-search-tool-2025-10-19"] && !seen["tool-examples-2025-10-29"] {
		result = append(result, "tool-examples-2025-10-29")
	}

	return result
}
diff --git a/backend/internal/service/bedrock_request_test.go b/backend/internal/service/bedrock_request_test.go
new file mode 100644
index 00000000..361cafb4
--- /dev/null
+++ b/backend/internal/service/bedrock_request_test.go
@@ -0,0 +1,659 @@
package service

import (
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/tidwall/gjson"
)

func TestPrepareBedrockRequestBody_BasicFields(t *testing.T) {
	input := `{"model":"claude-opus-4-6","stream":true,"max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`
	result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
	require.NoError(t, err)

	// anthropic_version should be injected
	assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result,
		"anthropic_version").String())
	// model and stream should be removed
	assert.False(t, gjson.GetBytes(result, "model").Exists())
	assert.False(t, gjson.GetBytes(result, "stream").Exists())
	// max_tokens should be preserved
	assert.Equal(t, int64(1024), gjson.GetBytes(result, "max_tokens").Int())
}

func TestPrepareBedrockRequestBody_OutputFormatInlineSchema(t *testing.T) {
	t.Run("schema inlined into last user message array content", func(t *testing.T) {
		input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"name":"string"}},"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`
		result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
		require.NoError(t, err)

		assert.False(t, gjson.GetBytes(result, "output_format").Exists())
		// schema should be inlined at the end of the last user message's content array
		contentArr := gjson.GetBytes(result, "messages.0.content").Array()
		require.Len(t, contentArr, 2)
		assert.Equal(t, "text", contentArr[1].Get("type").String())
		assert.Contains(t, contentArr[1].Get("text").String(), `"name":"string"`)
	})

	t.Run("schema inlined into string content", func(t *testing.T) {
		input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"result":"number"}},"messages":[{"role":"user","content":"compute this"}]}`
		result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
		require.NoError(t, err)

		assert.False(t, gjson.GetBytes(result, "output_format").Exists())
		contentArr := gjson.GetBytes(result, "messages.0.content").Array()
		require.Len(t, contentArr, 2)
		assert.Equal(t, "compute this", contentArr[0].Get("text").String())
		assert.Contains(t, contentArr[1].Get("text").String(), `"result":"number"`)
	})

	t.Run("no schema field just removes output_format", func(t *testing.T) {
		input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json"},"messages":[{"role":"user","content":"hi"}]}`
		result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
		require.NoError(t, err)

		assert.False(t, gjson.GetBytes(result, "output_format").Exists())
	})

	t.Run("no messages just removes output_format", func(t *testing.T) {
		input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"name":"string"}}}`
		result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
		require.NoError(t, err)

		assert.False(t, gjson.GetBytes(result, "output_format").Exists())
	})
}

func TestPrepareBedrockRequestBody_RemoveOutputConfig(t *testing.T) {
	input := `{"model":"claude-sonnet-4-5","output_config":{"max_tokens":100},"messages":[]}`
	result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
	require.NoError(t, err)

	assert.False(t, gjson.GetBytes(result, "output_config").Exists())
}

func TestRemoveCustomFieldFromTools(t *testing.T) {
	input := `{
	"tools": [
	{"name":"tool1","custom":{"defer_loading":true},"description":"desc1"},
	{"name":"tool2","description":"desc2"},
	{"name":"tool3","custom":{"defer_loading":true,"other":123},"description":"desc3"}
	]
	}`
	result := removeCustomFieldFromTools([]byte(input))

	tools := gjson.GetBytes(result, "tools").Array()
	require.Len(t, tools, 3)
	// custom should be removed
	assert.False(t, tools[0].Get("custom").Exists())
	// name/description should be preserved
	assert.Equal(t, "tool1", tools[0].Get("name").String())
	assert.Equal(t, "desc1", tools[0].Get("description").String())
	// tools without custom are unaffected
	assert.Equal(t, "tool2", tools[1].Get("name").String())
	// the third tool's custom should be removed as well
	assert.False(t, tools[2].Get("custom").Exists())
	assert.Equal(t, "tool3", tools[2].Get("name").String())
}

func TestRemoveCustomFieldFromTools_NoTools(t *testing.T) {
	input := `{"messages":[{"role":"user","content":"hi"}]}`
	result := removeCustomFieldFromTools([]byte(input))
	// without tools the original data is left untouched
	assert.JSONEq(t, input, string(result))
}

func TestSanitizeBedrockCacheControl_RemoveScope(t *testing.T) {
	input := `{
	"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","scope":"global"}}],
	"messages": [{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral","scope":"global"}}]}]
	}`
	result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")

	// scope should be removed
	assert.False(t, gjson.GetBytes(result, "system.0.cache_control.scope").Exists())
	assert.False(t, gjson.GetBytes(result, "messages.0.content.0.cache_control.scope").Exists())
	// type should be preserved
	assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
	assert.Equal(t, "ephemeral", gjson.GetBytes(result, "messages.0.content.0.cache_control.type").String())
}

func TestSanitizeBedrockCacheControl_TTL_OldModel(t *testing.T) {
	input := `{
	"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}}]
	}`
	// older models (Claude 3.5) do not support ttl
	result := sanitizeBedrockCacheControl([]byte(input), "anthropic.claude-3-5-sonnet-20241022-v2:0")

	assert.False(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
	assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
}

func TestSanitizeBedrockCacheControl_TTL_Claude45_Supported(t *testing.T) {
	input := `{
	"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}}]
	}`
	// Claude 4.5+ supports "5m" and "1h"
	result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-sonnet-4-5-20250929-v1:0")

	assert.True(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
	assert.Equal(t, "5m", gjson.GetBytes(result, "system.0.cache_control.ttl").String())
}

func TestSanitizeBedrockCacheControl_TTL_Claude45_UnsupportedValue(t *testing.T) {
	input := `{
	"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"10m"}}]
	}`
	// Claude 4.5 does not accept "10m"
	result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-sonnet-4-5-20250929-v1:0")

	assert.False(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
}

func TestSanitizeBedrockCacheControl_TTL_Claude46(t *testing.T) {
	input := `{
	"messages": [{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral","ttl":"1h"}}]}]
	}`
	result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")

	assert.True(t, gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").Exists())
	assert.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String())
}

func TestSanitizeBedrockCacheControl_NoCacheControl(t *testing.T) {
	input := `{"system":[{"type":"text","text":"sys"}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`
	result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")
	// without cache_control the original data is left untouched
	assert.JSONEq(t, input, string(result))
}

func TestIsBedrockClaude45OrNewer(t *testing.T) {
	tests := []struct {
		modelID string
		expect  bool
	}{
		{"us.anthropic.claude-opus-4-6-v1", true},
		{"us.anthropic.claude-sonnet-4-6", true},
		{"us.anthropic.claude-sonnet-4-5-20250929-v1:0", true},
		{"us.anthropic.claude-opus-4-5-20251101-v1:0", true},
		{"us.anthropic.claude-haiku-4-5-20251001-v1:0", true},
		{"anthropic.claude-3-5-sonnet-20241022-v2:0", false},
		{"anthropic.claude-3-opus-20240229-v1:0", false},
		{"anthropic.claude-3-haiku-20240307-v1:0", false},
		// future versions should be supported automatically
		{"us.anthropic.claude-sonnet-5-0-v1", true},
		{"us.anthropic.claude-opus-4-7-v1", true},
		// older versions
		{"anthropic.claude-opus-4-1-v1", false},
		{"anthropic.claude-sonnet-4-0-v1", false},
		// non-Claude models
		{"amazon.nova-pro-v1", false},
		{"meta.llama3-70b", false},
	}
	for _, tt := range tests {
		t.Run(tt.modelID, func(t *testing.T) {
			assert.Equal(t, tt.expect, isBedrockClaude45OrNewer(tt.modelID))
		})
	}
}

func TestPrepareBedrockRequestBody_FullIntegration(t *testing.T) {
	// simulate a complete Claude Code request
	input := `{
	"model": "claude-opus-4-6",
	"stream": true,
	"max_tokens": 16384,
	"output_format": {"type": "json", "schema": {"result": "string"}},
	"output_config": {"max_tokens": 100},
	"system": [{"type": "text", "text": "You are helpful", "cache_control": {"type": "ephemeral", "scope": "global", "ttl": "5m"}}],
	"messages": [
	{"role": "user", "content": [{"type": "text", "text": "hello", "cache_control": {"type": "ephemeral", "ttl": "1h"}}]}
	],
	"tools": [
	{"name": "bash", "description": "Run bash", "custom": {"defer_loading": true}, "input_schema": {"type": "object"}},
	{"name": "read", "description": "Read file", "input_schema": {"type": "object"}}
	]
	}`

	betaHeader := "interleaved-thinking-2025-05-14, context-1m-2025-08-07, compact-2026-01-12"
	result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", betaHeader)
	require.NoError(t, err)

	// basic fields
	assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String())
	assert.False(t, gjson.GetBytes(result, "model").Exists())
	assert.False(t, gjson.GetBytes(result, "stream").Exists())
	assert.Equal(t, int64(16384), gjson.GetBytes(result, "max_tokens").Int())

	// anthropic_beta should contain every beta token
	betaArr := gjson.GetBytes(result, "anthropic_beta").Array()
	require.Len(t, betaArr, 3)
	assert.Equal(t, "interleaved-thinking-2025-05-14", betaArr[0].String())
	assert.Equal(t, "context-1m-2025-08-07", betaArr[1].String())
	assert.Equal(t, "compact-2026-01-12", betaArr[2].String())

	// output_format should be removed, schema inlined into the last user message
	assert.False(t, gjson.GetBytes(result, "output_format").Exists())
	assert.False(t, gjson.GetBytes(result, "output_config").Exists())
	// content array: original text block + inlined schema block
	contentArr := gjson.GetBytes(result, "messages.0.content").Array()
	require.Len(t, contentArr, 2)
	assert.Equal(t, "hello", contentArr[0].Get("text").String())
	assert.Contains(t, contentArr[1].Get("text").String(), `"result":"string"`)

	// custom should be removed from tools
	assert.False(t, gjson.GetBytes(result, "tools.0.custom").Exists())
	assert.Equal(t, "bash", gjson.GetBytes(result, "tools.0.name").String())
	assert.Equal(t, "read", gjson.GetBytes(result, "tools.1.name").String())

	// cache_control: scope removed; legal ttl values kept on Claude 4.6
	assert.False(t, gjson.GetBytes(result, "system.0.cache_control.scope").Exists())
	assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
	assert.Equal(t, "5m", gjson.GetBytes(result, "system.0.cache_control.ttl").String())
	assert.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String())
}

func TestPrepareBedrockRequestBody_BetaHeader(t *testing.T) {
	input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`

	t.Run("empty beta header", func(t *testing.T) {
		result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
		require.NoError(t, err)
		assert.False(t, gjson.GetBytes(result, "anthropic_beta").Exists())
	})

	t.Run("single beta token", func(t *testing.T) {
		result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14")
		require.NoError(t, err)
		arr := gjson.GetBytes(result, "anthropic_beta").Array()
		require.Len(t, arr, 1)
		assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
	})

	t.Run("multiple beta tokens with spaces", func(t *testing.T) {
		result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14 , context-1m-2025-08-07 ")
		require.NoError(t, err)
		arr := gjson.GetBytes(result, "anthropic_beta").Array()
		require.Len(t, arr, 2)
		assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
		assert.Equal(t, "context-1m-2025-08-07", arr[1].String())
	})

	t.Run("json array beta header", func(t *testing.T) {
		result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", `["interleaved-thinking-2025-05-14","context-1m-2025-08-07"]`)
		require.NoError(t, err)
		arr := gjson.GetBytes(result, "anthropic_beta").Array()
		require.Len(t, arr, 2)
		assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
		assert.Equal(t, "context-1m-2025-08-07", arr[1].String())
	})
}

func TestParseAnthropicBetaHeader(t *testing.T) {
	assert.Nil(t, parseAnthropicBetaHeader(""))
	assert.Equal(t, []string{"a"}, parseAnthropicBetaHeader("a"))
	assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader("a,b"))
	assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader("a , b "))
	assert.Equal(t, []string{"a", "b", "c"}, parseAnthropicBetaHeader("a,b,c"))
	assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader(`["a","b"]`))
}

func TestFilterBedrockBetaTokens(t *testing.T) {
	t.Run("supported tokens pass through", func(t *testing.T) {
		tokens := []string{"interleaved-thinking-2025-05-14", "context-1m-2025-08-07", "compact-2026-01-12"}
		result := filterBedrockBetaTokens(tokens)
		assert.Equal(t, tokens, result)
	})

	t.Run("unsupported tokens are filtered out", func(t *testing.T) {
		tokens := []string{"interleaved-thinking-2025-05-14", "output-128k-2025-02-19", "files-api-2025-04-14", "structured-outputs-2025-11-13"}
		result := filterBedrockBetaTokens(tokens)
		assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result)
	})

	t.Run("advanced-tool-use transforms to tool-search-tool", func(t *testing.T) {
		tokens := []string{"advanced-tool-use-2025-11-20"}
		result := filterBedrockBetaTokens(tokens)
		assert.Contains(t, result, "tool-search-tool-2025-10-19")
		// tool-examples is auto-associated
		assert.Contains(t, result, "tool-examples-2025-10-29")
	})

	t.Run("tool-search-tool auto-associates tool-examples", func(t *testing.T) {
		tokens := []string{"tool-search-tool-2025-10-19"}
		result := filterBedrockBetaTokens(tokens)
		assert.Contains(t, result, "tool-search-tool-2025-10-19")
		assert.Contains(t, result, "tool-examples-2025-10-29")
	})

	t.Run("no duplication when tool-examples already present", func(t *testing.T) {
		tokens := []string{"tool-search-tool-2025-10-19", "tool-examples-2025-10-29"}
		result := filterBedrockBetaTokens(tokens)
		count := 0
		// NOTE(review): the loop variable t shadows *testing.T here; consider renaming.
		for _, t := range result {
			if t == "tool-examples-2025-10-29" {
				count++
			}
		}
		assert.Equal(t, 1, count)
	})

	t.Run("empty input returns nil", func(t *testing.T) {
		result := filterBedrockBetaTokens(nil)
		assert.Nil(t, result)
	})

	t.Run("all unsupported returns nil", func(t *testing.T) {
		result := filterBedrockBetaTokens([]string{"output-128k-2025-02-19", "effort-2025-11-24"})
		assert.Nil(t, result)
	})

	t.Run("duplicate tokens are deduplicated", func(t *testing.T) {
		tokens := []string{"context-1m-2025-08-07", "context-1m-2025-08-07"}
		result := filterBedrockBetaTokens(tokens)
		assert.Equal(t, []string{"context-1m-2025-08-07"}, result)
	})
}

func TestPrepareBedrockRequestBody_BetaFiltering(t *testing.T) {
	input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`

	t.Run("unsupported beta tokens are filtered", func(t *testing.T) {
		result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1",
			"interleaved-thinking-2025-05-14, output-128k-2025-02-19, files-api-2025-04-14")
		require.NoError(t, err)
		arr := gjson.GetBytes(result, "anthropic_beta").Array()
		require.Len(t, arr, 1)
		assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
	})

	t.Run("advanced-tool-use transformed in full pipeline", func(t *testing.T) {
		result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1",
			"advanced-tool-use-2025-11-20")
		require.NoError(t, err)
		arr := gjson.GetBytes(result, "anthropic_beta").Array()
		require.Len(t, arr, 2)
		assert.Equal(t, "tool-search-tool-2025-10-19", arr[0].String())
		assert.Equal(t, "tool-examples-2025-10-29", arr[1].String())
	})
}

func TestBedrockCrossRegionPrefix(t *testing.T) {
	tests := []struct {
		region string
		expect string
	}{
		// US regions
		{"us-east-1", "us"},
		{"us-east-2", "us"},
		{"us-west-1", "us"},
		{"us-west-2", "us"},
		// GovCloud
		{"us-gov-east-1", "us-gov"},
		{"us-gov-west-1", "us-gov"},
		// EU regions
		{"eu-west-1", "eu"},
		{"eu-west-2", "eu"},
		{"eu-west-3", "eu"},
		{"eu-central-1", "eu"},
		{"eu-central-2", "eu"},
		{"eu-north-1", "eu"},
		{"eu-south-1", "eu"},
		// APAC regions
		{"ap-northeast-1", "jp"},
		{"ap-northeast-2", "apac"},
		{"ap-southeast-1", "apac"},
		{"ap-southeast-2", "au"},
		{"ap-south-1", "apac"},
		// Canada / South America fallback to us
		{"ca-central-1", "us"},
		{"sa-east-1", "us"},
		// Unknown defaults to us
		{"me-south-1", "us"},
	}
	for _, tt := range tests {
		t.Run(tt.region, func(t *testing.T) {
			assert.Equal(t, tt.expect, BedrockCrossRegionPrefix(tt.region))
		})
	}
}

func TestResolveBedrockModelID(t *testing.T) {
	t.Run("default alias resolves and adjusts region", func(t *testing.T) {
		account := &Account{
			Platform: PlatformAnthropic,
			Type:     AccountTypeBedrock,
			Credentials: map[string]any{
				"aws_region": "eu-west-1",
			},
		}

		modelID, ok := ResolveBedrockModelID(account, "claude-sonnet-4-5")
		require.True(t, ok)
		assert.Equal(t, "eu.anthropic.claude-sonnet-4-5-20250929-v1:0", modelID)
	})

	t.Run("custom alias mapping reuses default bedrock mapping", func(t *testing.T) {
		account := &Account{
			Platform: PlatformAnthropic,
			Type:     AccountTypeBedrock,
			Credentials: map[string]any{
				"aws_region": "ap-southeast-2",
				"model_mapping": map[string]any{
					"claude-*": "claude-opus-4-6",
				},
			},
		}

		modelID, ok := ResolveBedrockModelID(account, "claude-opus-4-6-thinking")
		require.True(t, ok)
		assert.Equal(t, "au.anthropic.claude-opus-4-6-v1", modelID)
	})

	t.Run("force global rewrites anthropic regional model id", func(t *testing.T) {
		account := &Account{
			Platform: PlatformAnthropic,
			Type:     AccountTypeBedrock,
			Credentials: map[string]any{
				"aws_region":       "us-east-1",
				"aws_force_global": "true",
				"model_mapping": map[string]any{
					"claude-sonnet-4-6": "us.anthropic.claude-sonnet-4-6",
				},
			},
		}

		modelID, ok := ResolveBedrockModelID(account, "claude-sonnet-4-6")
		require.True(t, ok)
		assert.Equal(t, "global.anthropic.claude-sonnet-4-6", modelID)
	})

	t.Run("direct bedrock model id passes through", func(t *testing.T) {
		account := &Account{
			Platform: PlatformAnthropic,
			Type:     AccountTypeBedrock,
			Credentials: map[string]any{
				"aws_region": "us-east-1",
			},
		}

		modelID, ok := ResolveBedrockModelID(account, "anthropic.claude-haiku-4-5-20251001-v1:0")
		require.True(t, ok)
		assert.Equal(t, "anthropic.claude-haiku-4-5-20251001-v1:0", modelID)
	})

	t.Run("unsupported alias returns false", func(t *testing.T) {
		account := &Account{
			Platform: PlatformAnthropic,
			Type:     AccountTypeBedrock,
			Credentials: map[string]any{
				"aws_region": "us-east-1",
			},
		}

		_, ok := ResolveBedrockModelID(account, "claude-3-5-sonnet-20241022")
		assert.False(t, ok)
	})
}

func TestAutoInjectBedrockBetaTokens(t *testing.T) {
	t.Run("inject interleaved-thinking when thinking present", func(t *testing.T) {
		body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`)
		result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
		assert.Contains(t, result, "interleaved-thinking-2025-05-14")
	})

	t.Run("no duplicate when already present", func(t *testing.T) {
		body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`)
		result := autoInjectBedrockBetaTokens([]string{"interleaved-thinking-2025-05-14"}, body, "us.anthropic.claude-opus-4-6-v1")
		count := 0
		// NOTE(review): the loop variable t shadows *testing.T here; consider renaming.
		for _, t := range result {
			if t == "interleaved-thinking-2025-05-14" {
				count++
			}
		}
		assert.Equal(t, 1, count)
	})

	t.Run("inject computer-use when computer tool present", func(t *testing.T) {
		body := []byte(`{"tools":[{"type":"computer_20250124","name":"computer","display_width_px":1024}],"messages":[{"role":"user","content":"hi"}]}`)
		result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
		assert.Contains(t, result, "computer-use-2025-11-24")
	})

	t.Run("inject advanced-tool-use for programmatic tool calling", func(t *testing.T) {
		body := []byte(`{"tools":[{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
		result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
		assert.Contains(t, result, "advanced-tool-use-2025-11-20")
	})

	t.Run("inject advanced-tool-use for input examples", func(t *testing.T) {
		body := []byte(`{"tools":[{"name":"bash","input_examples":[{"cmd":"ls"}]}],"messages":[{"role":"user","content":"hi"}]}`)
		result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
		assert.Contains(t, result, "advanced-tool-use-2025-11-20")
	})

	t.Run("inject tool-search-tool directly for pure tool search (no programmatic/inputExamples)", func(t *testing.T) {
		body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`)
		result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-sonnet-4-6")
		// pure tool-search injects the Bedrock-specific header directly,
		// bypassing the advanced-tool-use transform
		assert.Contains(t, result, "tool-search-tool-2025-10-19")
		assert.NotContains(t, result, "advanced-tool-use-2025-11-20")
	})

	t.Run("inject advanced-tool-use when tool search combined with programmatic calling", func(t *testing.T) {
		body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"},{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
		result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-sonnet-4-6")
		// mixed scenario uses advanced-tool-use (later converted to
		// tool-search-tool by the filter)
		assert.Contains(t, result, "advanced-tool-use-2025-11-20")
	})

	t.Run("do not inject tool-search beta for unsupported models", func(t *testing.T) {
		body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`)
		result := autoInjectBedrockBetaTokens(nil, body, "anthropic.claude-3-5-sonnet-20241022-v2:0")
		assert.NotContains(t, result, "advanced-tool-use-2025-11-20")
		assert.NotContains(t, result, "tool-search-tool-2025-10-19")
	})

	t.Run("no injection for regular tools", func(t *testing.T) {
		body := []byte(`{"tools":[{"name":"bash","description":"run bash","input_schema":{"type":"object"}}],"messages":[{"role":"user","content":"hi"}]}`)
		result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
		assert.Empty(t, result)
	})

	t.Run("no injection when no features detected", func(t *testing.T) {
		body := []byte(`{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`)
		result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
		assert.Empty(t, result)
	})

	t.Run("preserves existing tokens", func(t *testing.T) {
		body := []byte(`{"thinking":{"type":"enabled"},"messages":[{"role":"user","content":"hi"}]}`)
		existing := []string{"context-1m-2025-08-07", "compact-2026-01-12"}
		result := autoInjectBedrockBetaTokens(existing, body, "us.anthropic.claude-opus-4-6-v1")
		assert.Contains(t, result, "context-1m-2025-08-07")
		assert.Contains(t, result, "compact-2026-01-12")
		assert.Contains(t, result, "interleaved-thinking-2025-05-14")
	})
}

func TestResolveBedrockBetaTokens(t *testing.T) {
	t.Run("body-only tool features resolve to final bedrock tokens", func(t *testing.T) {
		body := []byte(`{"tools":[{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
		result := ResolveBedrockBetaTokens("", body, "us.anthropic.claude-opus-4-6-v1")
		assert.Contains(t, result, "tool-search-tool-2025-10-19")
		assert.Contains(t, result, "tool-examples-2025-10-29")
	})

	t.Run("unsupported client beta tokens are filtered out", func(t *testing.T) {
		body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`)
		result := ResolveBedrockBetaTokens("interleaved-thinking-2025-05-14,files-api-2025-04-14", body, "us.anthropic.claude-opus-4-6-v1")
		assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result)
	})
}

func TestPrepareBedrockRequestBody_AutoBetaInjection(t *testing.T) {
	t.Run("thinking in body auto-injects beta without header", func(t *testing.T) {
		input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}`
		result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
		require.NoError(t, err)
		arr := gjson.GetBytes(result, "anthropic_beta").Array()
		found := false
		for _, v := range arr {
			if v.String() == "interleaved-thinking-2025-05-14" {
				found = true
			}
		}
		assert.True(t, found, "interleaved-thinking should be auto-injected")
	})

	t.Run("header tokens merged with auto-injected tokens", func(t *testing.T) {
		input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}`
		result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "context-1m-2025-08-07")
		require.NoError(t, err)
		arr := gjson.GetBytes(result, "anthropic_beta").Array()
		names := make([]string, len(arr))
		for i, v := range arr {
			names[i] = v.String()
		}
		assert.Contains(t, names, "context-1m-2025-08-07")
		assert.Contains(t, names, "interleaved-thinking-2025-05-14")
	})
}

func TestAdjustBedrockModelRegionPrefix(t *testing.T) {
	tests := []struct {
		name    string
		modelID string
		region  string
		expect  string
	}{
		// US region — no change needed
		{"us region keeps us prefix", "us.anthropic.claude-opus-4-6-v1", "us-east-1", "us.anthropic.claude-opus-4-6-v1"},
		// EU region — replace us → eu
		{"eu region replaces prefix", "us.anthropic.claude-opus-4-6-v1", "eu-west-1", "eu.anthropic.claude-opus-4-6-v1"},
		{"eu region sonnet", "us.anthropic.claude-sonnet-4-6", "eu-central-1", "eu.anthropic.claude-sonnet-4-6"},
		// APAC region — jp and au have dedicated prefixes per AWS docs
		{"jp region (ap-northeast-1)", "us.anthropic.claude-sonnet-4-5-20250929-v1:0", "ap-northeast-1", "jp.anthropic.claude-sonnet-4-5-20250929-v1:0"},
		{"au region (ap-southeast-2)", "us.anthropic.claude-haiku-4-5-20251001-v1:0", "ap-southeast-2", "au.anthropic.claude-haiku-4-5-20251001-v1:0"},
		{"apac region (ap-southeast-1)", "us.anthropic.claude-sonnet-4-5-20250929-v1:0", "ap-southeast-1", "apac.anthropic.claude-sonnet-4-5-20250929-v1:0"},
		// eu → us (user manually set eu prefix, moved to us region)
		{"eu to us", "eu.anthropic.claude-opus-4-6-v1", "us-west-2", "us.anthropic.claude-opus-4-6-v1"},
		// global prefix — replace to match region
		{"global to eu", "global.anthropic.claude-opus-4-6-v1", "eu-west-1", "eu.anthropic.claude-opus-4-6-v1"},
		// No known prefix — leave unchanged
		{"no prefix unchanged", "anthropic.claude-3-5-sonnet-20241022-v2:0", "eu-west-1", "anthropic.claude-3-5-sonnet-20241022-v2:0"},
		// GovCloud — uses independent us-gov prefix
		{"govcloud from us", "us.anthropic.claude-opus-4-6-v1", "us-gov-east-1", "us-gov.anthropic.claude-opus-4-6-v1"},
		{"govcloud already
correct", "us-gov.anthropic.claude-opus-4-6-v1", "us-gov-west-1", "us-gov.anthropic.claude-opus-4-6-v1"}, + // Force global (special region value) + {"force global from us", "us.anthropic.claude-opus-4-6-v1", "global", "global.anthropic.claude-opus-4-6-v1"}, + {"force global from eu", "eu.anthropic.claude-sonnet-4-6", "global", "global.anthropic.claude-sonnet-4-6"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expect, AdjustBedrockModelRegionPrefix(tt.modelID, tt.region)) + }) + } +} diff --git a/backend/internal/service/bedrock_signer.go b/backend/internal/service/bedrock_signer.go new file mode 100644 index 00000000..e7000b4d --- /dev/null +++ b/backend/internal/service/bedrock_signer.go @@ -0,0 +1,67 @@ +package service + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "net/http" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" +) + +// BedrockSigner 使用 AWS SigV4 对 Bedrock 请求签名 +type BedrockSigner struct { + credentials aws.Credentials + region string + signer *v4.Signer +} + +// NewBedrockSigner 创建 BedrockSigner +func NewBedrockSigner(accessKeyID, secretAccessKey, sessionToken, region string) *BedrockSigner { + return &BedrockSigner{ + credentials: aws.Credentials{ + AccessKeyID: accessKeyID, + SecretAccessKey: secretAccessKey, + SessionToken: sessionToken, + }, + region: region, + signer: v4.NewSigner(), + } +} + +// NewBedrockSignerFromAccount 从 Account 凭证创建 BedrockSigner +func NewBedrockSignerFromAccount(account *Account) (*BedrockSigner, error) { + accessKeyID := account.GetCredential("aws_access_key_id") + if accessKeyID == "" { + return nil, fmt.Errorf("aws_access_key_id not found in credentials") + } + secretAccessKey := account.GetCredential("aws_secret_access_key") + if secretAccessKey == "" { + return nil, fmt.Errorf("aws_secret_access_key not found in credentials") + } + region := account.GetCredential("aws_region") + if region == "" { + 
region = defaultBedrockRegion + } + sessionToken := account.GetCredential("aws_session_token") // 可选 + + return NewBedrockSigner(accessKeyID, secretAccessKey, sessionToken, region), nil +} + +// SignRequest 对 HTTP 请求进行 SigV4 签名 +// 重要约束:调用此方法前,req 应只包含 AWS 相关的 header(如 Content-Type、Accept)。 +// 非 AWS header(如 anthropic-beta)会参与签名计算,如果 Bedrock 服务端不识别这些 header, +// 签名验证可能失败。litellm 通过 _filter_headers_for_aws_signature 实现头过滤, +// 当前实现中 buildUpstreamRequestBedrock 仅设置了 Content-Type 和 Accept,因此是安全的。 +func (s *BedrockSigner) SignRequest(ctx context.Context, req *http.Request, body []byte) error { + payloadHash := sha256Hash(body) + return s.signer.SignHTTP(ctx, s.credentials, req, payloadHash, "bedrock", s.region, time.Now()) +} + +func sha256Hash(data []byte) string { + h := sha256.Sum256(data) + return hex.EncodeToString(h[:]) +} diff --git a/backend/internal/service/bedrock_signer_test.go b/backend/internal/service/bedrock_signer_test.go new file mode 100644 index 00000000..641e9341 --- /dev/null +++ b/backend/internal/service/bedrock_signer_test.go @@ -0,0 +1,35 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewBedrockSignerFromAccount_DefaultRegion(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeBedrock, + Credentials: map[string]any{ + "aws_access_key_id": "test-akid", + "aws_secret_access_key": "test-secret", + }, + } + + signer, err := NewBedrockSignerFromAccount(account) + require.NoError(t, err) + require.NotNil(t, signer) + assert.Equal(t, defaultBedrockRegion, signer.region) +} + +func TestFilterBetaTokens(t *testing.T) { + tokens := []string{"interleaved-thinking-2025-05-14", "tool-search-tool-2025-10-19"} + filterSet := map[string]struct{}{ + "tool-search-tool-2025-10-19": {}, + } + + assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, filterBetaTokens(tokens, filterSet)) + assert.Equal(t, tokens, 
filterBetaTokens(tokens, nil)) + assert.Nil(t, filterBetaTokens(nil, filterSet)) +} diff --git a/backend/internal/service/bedrock_stream.go b/backend/internal/service/bedrock_stream.go new file mode 100644 index 00000000..98196d27 --- /dev/null +++ b/backend/internal/service/bedrock_stream.go @@ -0,0 +1,414 @@ +package service + +import ( + "bufio" + "context" + "encoding/base64" + "errors" + "fmt" + "hash/crc32" + "io" + "net/http" + "sync/atomic" + "time" + + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +// handleBedrockStreamingResponse 处理 Bedrock InvokeModelWithResponseStream 的 EventStream 响应 +// Bedrock 返回 AWS EventStream 二进制格式,每个事件的 payload 中 chunk.bytes 是 base64 编码的 +// Claude SSE 事件 JSON。本方法解码后转换为标准 SSE 格式写入客户端。 +func (s *GatewayService) handleBedrockStreamingResponse( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, + startTime time.Time, + model string, +) (*streamingResult, error) { + w := c.Writer + flusher, ok := w.(http.Flusher) + if !ok { + return nil, errors.New("streaming not supported") + } + + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + if v := resp.Header.Get("x-amzn-requestid"); v != "" { + c.Header("x-request-id", v) + } + + usage := &ClaudeUsage{} + var firstTokenMs *int + clientDisconnected := false + + // Bedrock EventStream 使用 application/vnd.amazon.eventstream 二进制格式。 + // 每个帧结构:total_length(4) + headers_length(4) + prelude_crc(4) + headers + payload + message_crc(4) + // 但更实用的方式是使用行扫描找 JSON chunks,因为 Bedrock 的响应在二进制帧中。 + // 我们使用 EventStream decoder 来正确解析。 + decoder := newBedrockEventStreamDecoder(resp.Body) + + type decodeEvent struct { + payload []byte + err error + } + events := make(chan decodeEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev decodeEvent) bool { + select 
{ + case events <- ev: + return true + case <-done: + return false + } + } + var lastReadAt atomic.Int64 + lastReadAt.Store(time.Now().UnixNano()) + + go func() { + defer close(events) + for { + payload, err := decoder.Decode() + if err != nil { + if err == io.EOF { + return + } + _ = sendEvent(decodeEvent{err: err}) + return + } + lastReadAt.Store(time.Now().UnixNano()) + if !sendEvent(decodeEvent{payload: payload}) { + return + } + } + }() + defer close(done) + + streamInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C + } + + for { + select { + case ev, ok := <-events: + if !ok { + if !clientDisconnected { + flusher.Flush() + } + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil + } + if ev.err != nil { + if clientDisconnected { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } + if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("bedrock stream read error: %w", ev.err) + } + + // payload 是 JSON,提取 chunk.bytes(base64 编码的 Claude SSE 事件数据) + sseData := extractBedrockChunkData(ev.payload) + if sseData == nil { + continue + } + + if firstTokenMs == nil { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + + // 转换 Bedrock 特有的 amazon-bedrock-invocationMetrics 为标准 Anthropic usage 格式 + // 同时移除该字段避免透传给客户端 + sseData = 
transformBedrockInvocationMetrics(sseData) + + // 解析 SSE 事件数据提取 usage + s.parseSSEUsagePassthrough(string(sseData), usage) + + // 确定 SSE event type + eventType := gjson.GetBytes(sseData, "type").String() + + // 写入标准 SSE 格式 + if !clientDisconnected { + var writeErr error + if eventType != "" { + _, writeErr = fmt.Fprintf(w, "event: %s\ndata: %s\n\n", eventType, sseData) + } else { + _, writeErr = fmt.Fprintf(w, "data: %s\n\n", sseData) + } + if writeErr != nil { + clientDisconnected = true + logger.LegacyPrintf("service.gateway", "[Bedrock] Client disconnected during streaming, continue draining for usage: account=%d", account.ID) + } else { + flusher.Flush() + } + } + + case <-intervalCh: + lastRead := time.Unix(0, lastReadAt.Load()) + if time.Since(lastRead) < streamInterval { + continue + } + if clientDisconnected { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } + logger.LegacyPrintf("service.gateway", "[Bedrock] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval) + if s.rateLimitService != nil { + s.rateLimitService.HandleStreamTimeout(ctx, account, model) + } + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") + } + } +} + +// extractBedrockChunkData 从 Bedrock EventStream payload 中提取 Claude SSE 事件数据 +// Bedrock payload 格式:{"bytes":""} +func extractBedrockChunkData(payload []byte) []byte { + b64 := gjson.GetBytes(payload, "bytes").String() + if b64 == "" { + return nil + } + decoded, err := base64.StdEncoding.DecodeString(b64) + if err != nil { + return nil + } + return decoded +} + +// transformBedrockInvocationMetrics 将 Bedrock 特有的 amazon-bedrock-invocationMetrics +// 转换为标准 Anthropic usage 格式,并从 SSE 数据中移除该字段。 +// +// Bedrock Invoke 返回的 message_delta 事件可能包含: +// +// {"type":"message_delta","delta":{...},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}} +// +// 转换为: +// 
+// {"type":"message_delta","delta":{...},"usage":{"input_tokens":150,"output_tokens":42}} +func transformBedrockInvocationMetrics(data []byte) []byte { + metrics := gjson.GetBytes(data, "amazon-bedrock-invocationMetrics") + if !metrics.Exists() || !metrics.IsObject() { + return data + } + + // 移除 Bedrock 特有字段 + data, _ = sjson.DeleteBytes(data, "amazon-bedrock-invocationMetrics") + + // 如果已有标准 usage 字段,不覆盖 + if gjson.GetBytes(data, "usage").Exists() { + return data + } + + // 转换 camelCase → snake_case 写入 usage + inputTokens := metrics.Get("inputTokenCount") + outputTokens := metrics.Get("outputTokenCount") + if inputTokens.Exists() { + data, _ = sjson.SetBytes(data, "usage.input_tokens", inputTokens.Int()) + } + if outputTokens.Exists() { + data, _ = sjson.SetBytes(data, "usage.output_tokens", outputTokens.Int()) + } + + return data +} + +// bedrockEventStreamDecoder 解码 AWS EventStream 二进制帧 +// EventStream 帧格式: +// +// [total_byte_length: 4 bytes] +// [headers_byte_length: 4 bytes] +// [prelude_crc: 4 bytes] +// [headers: variable] +// [payload: variable] +// [message_crc: 4 bytes] +type bedrockEventStreamDecoder struct { + reader *bufio.Reader +} + +func newBedrockEventStreamDecoder(r io.Reader) *bedrockEventStreamDecoder { + return &bedrockEventStreamDecoder{ + reader: bufio.NewReaderSize(r, 64*1024), + } +} + +// Decode 读取下一个 EventStream 帧并返回 chunk 类型事件的 payload +func (d *bedrockEventStreamDecoder) Decode() ([]byte, error) { + for { + // 读取 prelude: total_length(4) + headers_length(4) + prelude_crc(4) = 12 bytes + prelude := make([]byte, 12) + if _, err := io.ReadFull(d.reader, prelude); err != nil { + return nil, err + } + + // 验证 prelude CRC(AWS EventStream 使用标准 CRC32 / IEEE) + preludeCRC := bedrockReadUint32(prelude[8:12]) + if crc32.Checksum(prelude[0:8], crc32IEEETable) != preludeCRC { + return nil, fmt.Errorf("eventstream prelude CRC mismatch") + } + + totalLength := bedrockReadUint32(prelude[0:4]) + headersLength := bedrockReadUint32(prelude[4:8]) + + if 
totalLength < 16 { // minimum: 12 prelude + 4 message_crc + return nil, fmt.Errorf("invalid eventstream frame: total_length=%d", totalLength) + } + + // 读取 headers + payload + message_crc + remaining := int(totalLength) - 12 + if remaining <= 0 { + continue + } + data := make([]byte, remaining) + if _, err := io.ReadFull(d.reader, data); err != nil { + return nil, err + } + + // 验证 message CRC(覆盖 prelude + headers + payload) + messageCRC := bedrockReadUint32(data[len(data)-4:]) + h := crc32.New(crc32IEEETable) + _, _ = h.Write(prelude) + _, _ = h.Write(data[:len(data)-4]) + if h.Sum32() != messageCRC { + return nil, fmt.Errorf("eventstream message CRC mismatch") + } + + // 解析 headers + headers := data[:headersLength] + payload := data[headersLength : len(data)-4] // 去掉 message_crc + + // 从 headers 中提取 :event-type + eventType := extractEventStreamHeaderValue(headers, ":event-type") + + // 只处理 chunk 事件 + if eventType == "chunk" { + // payload 是完整的 JSON,包含 bytes 字段 + return payload, nil + } + + // 检查异常事件 + exceptionType := extractEventStreamHeaderValue(headers, ":exception-type") + if exceptionType != "" { + return nil, fmt.Errorf("bedrock exception: %s: %s", exceptionType, string(payload)) + } + + messageType := extractEventStreamHeaderValue(headers, ":message-type") + if messageType == "exception" || messageType == "error" { + return nil, fmt.Errorf("bedrock error: %s", string(payload)) + } + + // 跳过其他事件类型(如 initial-response) + } +} + +// extractEventStreamHeaderValue 从 EventStream headers 二进制数据中提取指定 header 的字符串值 +// EventStream header 格式: +// +// [name_length: 1 byte][name: variable][value_type: 1 byte][value: variable] +// +// value_type = 7 表示 string 类型,前 2 bytes 为长度 +func extractEventStreamHeaderValue(headers []byte, targetName string) string { + pos := 0 + for pos < len(headers) { + if pos >= len(headers) { + break + } + nameLen := int(headers[pos]) + pos++ + if pos+nameLen > len(headers) { + break + } + name := string(headers[pos : pos+nameLen]) + pos += 
nameLen + + if pos >= len(headers) { + break + } + valueType := headers[pos] + pos++ + + switch valueType { + case 7: // string + if pos+2 > len(headers) { + return "" + } + valueLen := int(bedrockReadUint16(headers[pos : pos+2])) + pos += 2 + if pos+valueLen > len(headers) { + return "" + } + value := string(headers[pos : pos+valueLen]) + pos += valueLen + if name == targetName { + return value + } + case 0: // bool true + if name == targetName { + return "true" + } + case 1: // bool false + if name == targetName { + return "false" + } + case 2: // byte + pos++ + if name == targetName { + return "" + } + case 3: // short + pos += 2 + if name == targetName { + return "" + } + case 4: // int + pos += 4 + if name == targetName { + return "" + } + case 5: // long + pos += 8 + if name == targetName { + return "" + } + case 6: // bytes + if pos+2 > len(headers) { + return "" + } + valueLen := int(bedrockReadUint16(headers[pos : pos+2])) + pos += 2 + valueLen + case 8: // timestamp + pos += 8 + case 9: // uuid + pos += 16 + default: + return "" // 未知类型,无法继续解析 + } + } + return "" +} + +// crc32IEEETable is the CRC32 / IEEE table used by AWS EventStream. 
+var crc32IEEETable = crc32.MakeTable(crc32.IEEE) + +func bedrockReadUint32(b []byte) uint32 { + return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3]) +} + +func bedrockReadUint16(b []byte) uint16 { + return uint16(b[0])<<8 | uint16(b[1]) +} diff --git a/backend/internal/service/bedrock_stream_test.go b/backend/internal/service/bedrock_stream_test.go new file mode 100644 index 00000000..3d066137 --- /dev/null +++ b/backend/internal/service/bedrock_stream_test.go @@ -0,0 +1,261 @@ +package service + +import ( + "bytes" + "encoding/base64" + "encoding/binary" + "hash/crc32" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestExtractBedrockChunkData(t *testing.T) { + t.Run("valid base64 payload", func(t *testing.T) { + original := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}` + b64 := base64.StdEncoding.EncodeToString([]byte(original)) + payload := []byte(`{"bytes":"` + b64 + `"}`) + + result := extractBedrockChunkData(payload) + require.NotNil(t, result) + assert.JSONEq(t, original, string(result)) + }) + + t.Run("empty bytes field", func(t *testing.T) { + result := extractBedrockChunkData([]byte(`{"bytes":""}`)) + assert.Nil(t, result) + }) + + t.Run("no bytes field", func(t *testing.T) { + result := extractBedrockChunkData([]byte(`{"other":"value"}`)) + assert.Nil(t, result) + }) + + t.Run("invalid base64", func(t *testing.T) { + result := extractBedrockChunkData([]byte(`{"bytes":"not-valid-base64!!!"}`)) + assert.Nil(t, result) + }) +} + +func TestTransformBedrockInvocationMetrics(t *testing.T) { + t.Run("converts metrics to usage", func(t *testing.T) { + input := `{"type":"message_delta","delta":{"stop_reason":"end_turn"},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}` + result := transformBedrockInvocationMetrics([]byte(input)) + + // amazon-bedrock-invocationMetrics should 
be removed + assert.False(t, gjson.GetBytes(result, "amazon-bedrock-invocationMetrics").Exists()) + // usage should be set + assert.Equal(t, int64(150), gjson.GetBytes(result, "usage.input_tokens").Int()) + assert.Equal(t, int64(42), gjson.GetBytes(result, "usage.output_tokens").Int()) + // original fields preserved + assert.Equal(t, "message_delta", gjson.GetBytes(result, "type").String()) + assert.Equal(t, "end_turn", gjson.GetBytes(result, "delta.stop_reason").String()) + }) + + t.Run("no metrics present", func(t *testing.T) { + input := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}` + result := transformBedrockInvocationMetrics([]byte(input)) + assert.JSONEq(t, input, string(result)) + }) + + t.Run("does not overwrite existing usage", func(t *testing.T) { + input := `{"type":"message_delta","usage":{"output_tokens":100},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}` + result := transformBedrockInvocationMetrics([]byte(input)) + + // metrics removed but existing usage preserved + assert.False(t, gjson.GetBytes(result, "amazon-bedrock-invocationMetrics").Exists()) + assert.Equal(t, int64(100), gjson.GetBytes(result, "usage.output_tokens").Int()) + }) +} + +func TestExtractEventStreamHeaderValue(t *testing.T) { + // Build a header with :event-type = "chunk" (string type = 7) + buildStringHeader := func(name, value string) []byte { + var buf bytes.Buffer + // name length (1 byte) + _ = buf.WriteByte(byte(len(name))) + // name + _, _ = buf.WriteString(name) + // value type (7 = string) + _ = buf.WriteByte(7) + // value length (2 bytes, big-endian) + _ = binary.Write(&buf, binary.BigEndian, uint16(len(value))) + // value + _, _ = buf.WriteString(value) + return buf.Bytes() + } + + t.Run("find string header", func(t *testing.T) { + headers := buildStringHeader(":event-type", "chunk") + assert.Equal(t, "chunk", extractEventStreamHeaderValue(headers, ":event-type")) + }) + + t.Run("header not 
found", func(t *testing.T) { + headers := buildStringHeader(":event-type", "chunk") + assert.Equal(t, "", extractEventStreamHeaderValue(headers, ":message-type")) + }) + + t.Run("multiple headers", func(t *testing.T) { + var buf bytes.Buffer + _, _ = buf.Write(buildStringHeader(":content-type", "application/json")) + _, _ = buf.Write(buildStringHeader(":event-type", "chunk")) + _, _ = buf.Write(buildStringHeader(":message-type", "event")) + + headers := buf.Bytes() + assert.Equal(t, "chunk", extractEventStreamHeaderValue(headers, ":event-type")) + assert.Equal(t, "application/json", extractEventStreamHeaderValue(headers, ":content-type")) + assert.Equal(t, "event", extractEventStreamHeaderValue(headers, ":message-type")) + }) + + t.Run("empty headers", func(t *testing.T) { + assert.Equal(t, "", extractEventStreamHeaderValue([]byte{}, ":event-type")) + }) +} + +func TestBedrockEventStreamDecoder(t *testing.T) { + crc32IeeeTab := crc32.MakeTable(crc32.IEEE) + + // Build a valid EventStream frame with correct CRC32/IEEE checksums. 
+ buildFrame := func(eventType string, payload []byte) []byte { + // Build headers + var headersBuf bytes.Buffer + // :event-type header + _ = headersBuf.WriteByte(byte(len(":event-type"))) + _, _ = headersBuf.WriteString(":event-type") + _ = headersBuf.WriteByte(7) // string type + _ = binary.Write(&headersBuf, binary.BigEndian, uint16(len(eventType))) + _, _ = headersBuf.WriteString(eventType) + // :message-type header + _ = headersBuf.WriteByte(byte(len(":message-type"))) + _, _ = headersBuf.WriteString(":message-type") + _ = headersBuf.WriteByte(7) + _ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("event"))) + _, _ = headersBuf.WriteString("event") + + headers := headersBuf.Bytes() + headersLen := uint32(len(headers)) + // total = 12 (prelude) + headers + payload + 4 (message_crc) + totalLen := uint32(12 + len(headers) + len(payload) + 4) + + // Prelude: total_length(4) + headers_length(4) + var preludeBuf bytes.Buffer + _ = binary.Write(&preludeBuf, binary.BigEndian, totalLen) + _ = binary.Write(&preludeBuf, binary.BigEndian, headersLen) + preludeBytes := preludeBuf.Bytes() + preludeCRC := crc32.Checksum(preludeBytes, crc32IeeeTab) + + // Build frame: prelude + prelude_crc + headers + payload + var frame bytes.Buffer + _, _ = frame.Write(preludeBytes) + _ = binary.Write(&frame, binary.BigEndian, preludeCRC) + _, _ = frame.Write(headers) + _, _ = frame.Write(payload) + + // Message CRC covers everything before itself + messageCRC := crc32.Checksum(frame.Bytes(), crc32IeeeTab) + _ = binary.Write(&frame, binary.BigEndian, messageCRC) + return frame.Bytes() + } + + t.Run("decode chunk event", func(t *testing.T) { + payload := []byte(`{"bytes":"dGVzdA=="}`) // base64("test") + frame := buildFrame("chunk", payload) + + decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame)) + result, err := decoder.Decode() + require.NoError(t, err) + assert.Equal(t, payload, result) + }) + + t.Run("skip non-chunk events", func(t *testing.T) { + // Write 
initial-response followed by chunk + var buf bytes.Buffer + _, _ = buf.Write(buildFrame("initial-response", []byte(`{}`))) + chunkPayload := []byte(`{"bytes":"aGVsbG8="}`) + _, _ = buf.Write(buildFrame("chunk", chunkPayload)) + + decoder := newBedrockEventStreamDecoder(&buf) + result, err := decoder.Decode() + require.NoError(t, err) + assert.Equal(t, chunkPayload, result) + }) + + t.Run("EOF on empty input", func(t *testing.T) { + decoder := newBedrockEventStreamDecoder(bytes.NewReader(nil)) + _, err := decoder.Decode() + assert.Equal(t, io.EOF, err) + }) + + t.Run("corrupted prelude CRC", func(t *testing.T) { + frame := buildFrame("chunk", []byte(`{"bytes":"dGVzdA=="}`)) + // Corrupt the prelude CRC (bytes 8-11) + frame[8] ^= 0xFF + decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame)) + _, err := decoder.Decode() + require.Error(t, err) + assert.Contains(t, err.Error(), "prelude CRC mismatch") + }) + + t.Run("corrupted message CRC", func(t *testing.T) { + frame := buildFrame("chunk", []byte(`{"bytes":"dGVzdA=="}`)) + // Corrupt the message CRC (last 4 bytes) + frame[len(frame)-1] ^= 0xFF + decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame)) + _, err := decoder.Decode() + require.Error(t, err) + assert.Contains(t, err.Error(), "message CRC mismatch") + }) + + t.Run("castagnoli encoded frame is rejected", func(t *testing.T) { + castagnoliTab := crc32.MakeTable(crc32.Castagnoli) + payload := []byte(`{"bytes":"dGVzdA=="}`) + + var headersBuf bytes.Buffer + _ = headersBuf.WriteByte(byte(len(":event-type"))) + _, _ = headersBuf.WriteString(":event-type") + _ = headersBuf.WriteByte(7) + _ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("chunk"))) + _, _ = headersBuf.WriteString("chunk") + + headers := headersBuf.Bytes() + headersLen := uint32(len(headers)) + totalLen := uint32(12 + len(headers) + len(payload) + 4) + + var preludeBuf bytes.Buffer + _ = binary.Write(&preludeBuf, binary.BigEndian, totalLen) + _ = binary.Write(&preludeBuf, 
binary.BigEndian, headersLen) + preludeBytes := preludeBuf.Bytes() + + var frame bytes.Buffer + _, _ = frame.Write(preludeBytes) + _ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(preludeBytes, castagnoliTab)) + _, _ = frame.Write(headers) + _, _ = frame.Write(payload) + _ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(frame.Bytes(), castagnoliTab)) + + decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame.Bytes())) + _, err := decoder.Decode() + require.Error(t, err) + assert.Contains(t, err.Error(), "prelude CRC mismatch") + }) +} + +func TestBuildBedrockURL(t *testing.T) { + t.Run("stream URL with colon in model ID", func(t *testing.T) { + url := BuildBedrockURL("us-east-1", "us.anthropic.claude-opus-4-5-20251101-v1:0", true) + assert.Equal(t, "https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-opus-4-5-20251101-v1%3A0/invoke-with-response-stream", url) + }) + + t.Run("non-stream URL with colon in model ID", func(t *testing.T) { + url := BuildBedrockURL("eu-west-1", "eu.anthropic.claude-sonnet-4-5-20250929-v1:0", false) + assert.Equal(t, "https://bedrock-runtime.eu-west-1.amazonaws.com/model/eu.anthropic.claude-sonnet-4-5-20250929-v1%3A0/invoke", url) + }) + + t.Run("model ID without colon", func(t *testing.T) { + url := BuildBedrockURL("us-east-1", "us.anthropic.claude-sonnet-4-6", true) + assert.Equal(t, "https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-sonnet-4-6/invoke-with-response-stream", url) + }) +} diff --git a/backend/internal/service/dashboard_aggregation_service.go b/backend/internal/service/dashboard_aggregation_service.go index a67f8532..b58a1ea9 100644 --- a/backend/internal/service/dashboard_aggregation_service.go +++ b/backend/internal/service/dashboard_aggregation_service.go @@ -35,6 +35,7 @@ type DashboardAggregationRepository interface { UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error CleanupAggregates(ctx context.Context, hourlyCutoff, 
dailyCutoff time.Time) error CleanupUsageLogs(ctx context.Context, cutoff time.Time) error + CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error } @@ -296,6 +297,7 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context, hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays) dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays) usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays) + dedupCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageBillingDedupDays) aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff) if aggErr != nil { @@ -305,7 +307,11 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context, if usageErr != nil { logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr) } - if aggErr == nil && usageErr == nil { + dedupErr := s.repo.CleanupUsageBillingDedup(ctx, dedupCutoff) + if dedupErr != nil { + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_billing_dedup 保留清理失败: %v", dedupErr) + } + if aggErr == nil && usageErr == nil && dedupErr == nil { s.lastRetentionCleanup.Store(now) } } diff --git a/backend/internal/service/dashboard_aggregation_service_test.go b/backend/internal/service/dashboard_aggregation_service_test.go index a7058985..fbb671bb 100644 --- a/backend/internal/service/dashboard_aggregation_service_test.go +++ b/backend/internal/service/dashboard_aggregation_service_test.go @@ -12,12 +12,18 @@ import ( type dashboardAggregationRepoTestStub struct { aggregateCalls int + recomputeCalls int + cleanupUsageCalls int + cleanupDedupCalls int + ensurePartitionCalls int lastStart time.Time lastEnd time.Time watermark time.Time aggregateErr error cleanupAggregatesErr error cleanupUsageErr error + cleanupDedupErr error + ensurePartitionErr error } func (s *dashboardAggregationRepoTestStub) 
AggregateRange(ctx context.Context, start, end time.Time) error { @@ -28,6 +34,7 @@ func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s } func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error { + s.recomputeCalls++ return s.AggregateRange(ctx, start, end) } @@ -44,11 +51,18 @@ func (s *dashboardAggregationRepoTestStub) CleanupAggregates(ctx context.Context } func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error { + s.cleanupUsageCalls++ return s.cleanupUsageErr } +func (s *dashboardAggregationRepoTestStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error { + s.cleanupDedupCalls++ + return s.cleanupDedupErr +} + func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error { - return nil + s.ensurePartitionCalls++ + return s.ensurePartitionErr } func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) { @@ -90,6 +104,50 @@ func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *te svc.maybeCleanupRetention(context.Background(), time.Now().UTC()) require.Nil(t, svc.lastRetentionCleanup.Load()) + require.Equal(t, 1, repo.cleanupUsageCalls) + require.Equal(t, 1, repo.cleanupDedupCalls) +} + +func TestDashboardAggregationService_CleanupDedupFailure_DoesNotRecord(t *testing.T) { + repo := &dashboardAggregationRepoTestStub{cleanupDedupErr: errors.New("dedup cleanup failed")} + svc := &DashboardAggregationService{ + repo: repo, + cfg: config.DashboardAggregationConfig{ + Retention: config.DashboardAggregationRetentionConfig{ + UsageLogsDays: 1, + HourlyDays: 1, + DailyDays: 1, + }, + }, + } + + svc.maybeCleanupRetention(context.Background(), time.Now().UTC()) + + require.Nil(t, svc.lastRetentionCleanup.Load()) + require.Equal(t, 1, repo.cleanupDedupCalls) +} + +func 
TestDashboardAggregationService_PartitionFailure_DoesNotAggregate(t *testing.T) { + repo := &dashboardAggregationRepoTestStub{ensurePartitionErr: errors.New("partition failed")} + svc := &DashboardAggregationService{ + repo: repo, + cfg: config.DashboardAggregationConfig{ + Enabled: true, + IntervalSeconds: 60, + LookbackSeconds: 120, + Retention: config.DashboardAggregationRetentionConfig{ + UsageLogsDays: 1, + UsageBillingDedupDays: 2, + HourlyDays: 1, + DailyDays: 1, + }, + }, + } + + svc.runScheduledAggregation() + + require.Equal(t, 1, repo.ensurePartitionCalls) + require.Equal(t, 1, repo.aggregateCalls) } func TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) { diff --git a/backend/internal/service/dashboard_service.go b/backend/internal/service/dashboard_service.go index 2af43386..63cad243 100644 --- a/backend/internal/service/dashboard_service.go +++ b/backend/internal/service/dashboard_service.go @@ -327,6 +327,14 @@ func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, end return trend, nil } +func (s *DashboardService) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) { + ranking, err := s.usageRepo.GetUserSpendingRanking(ctx, startTime, endTime, limit) + if err != nil { + return nil, fmt.Errorf("get user spending ranking: %w", err) + } + return ranking, nil +} + func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) { stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime) if err != nil { diff --git a/backend/internal/service/dashboard_service_test.go b/backend/internal/service/dashboard_service_test.go index 59b83e66..2a7f47b6 100644 --- a/backend/internal/service/dashboard_service_test.go +++ b/backend/internal/service/dashboard_service_test.go @@ -124,6 +124,10 @@ func (s 
*dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cut return nil } +func (s *dashboardAggregationRepoStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error { + return nil +} + func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error { return nil } diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 304c09f4..ad64b467 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -29,10 +29,12 @@ const ( // Account type constants const ( - AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference) - AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope) - AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号 - AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游) + AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference) + AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope) + AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号 + AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游) + AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名连接 Bedrock) + AccountTypeBedrockAPIKey = domain.AccountTypeBedrockAPIKey // AWS Bedrock API Key 类型账号(通过 Bearer Token 连接 Bedrock) ) // Redeem type constants diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go index 5dcda1de..789cbab8 100644 --- a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go +++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go @@ -136,16 +136,18 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd }, 
} - svc := &GatewayService{ - cfg: &config.Config{ - Gateway: config.GatewayConfig{ - MaxLineSize: defaultMaxLineSize, - }, + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, }, - httpUpstream: upstream, - rateLimitService: &RateLimitService{}, - deferredService: &DeferredService{}, - billingCacheService: nil, + } + svc := &GatewayService{ + cfg: cfg, + responseHeaderFilter: compileResponseHeaderFilter(cfg), + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, + deferredService: &DeferredService{}, + billingCacheService: nil, } account := &Account{ @@ -221,14 +223,16 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo }, } - svc := &GatewayService{ - cfg: &config.Config{ - Gateway: config.GatewayConfig{ - MaxLineSize: defaultMaxLineSize, - }, + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, }, - httpUpstream: upstream, - rateLimitService: &RateLimitService{}, + } + svc := &GatewayService{ + cfg: cfg, + responseHeaderFilter: compileResponseHeaderFilter(cfg), + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, } account := &Account{ @@ -727,6 +731,39 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAf require.Equal(t, 5, result.usage.OutputTokens) } +func TestGatewayService_AnthropicAPIKeyPassthrough_MissingTerminalEventReturnsError(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + }, + rateLimitService: &RateLimitService{}, + } + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `data: 
{"type":"message_start","message":{"usage":{"input_tokens":11}}}`, + "", + `data: {"type":"message_delta","usage":{"output_tokens":5}}`, + "", + }, "\n"))), + } + + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "claude-3-7-sonnet-20250219") + require.Error(t, err) + require.Contains(t, err.Error(), "missing terminal event") + require.NotNil(t, result) +} + func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuccess(t *testing.T) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() @@ -1074,7 +1111,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingTimeoutAfterClientDi _ = pr.Close() <-done - require.NoError(t, err) + require.Error(t, err) + require.Contains(t, err.Error(), "stream usage incomplete after timeout") require.NotNil(t, result) require.True(t, result.clientDisconnect) require.Equal(t, 9, result.usage.InputTokens) @@ -1103,7 +1141,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingContextCanceled(t *t } result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 3}, time.Now(), "claude-3-7-sonnet-20250219") - require.NoError(t, err) + require.Error(t, err) + require.Contains(t, err.Error(), "stream usage incomplete") require.NotNil(t, result) require.True(t, result.clientDisconnect) } @@ -1133,7 +1172,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingUpstreamReadErrorAft } result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 4}, time.Now(), "claude-3-7-sonnet-20250219") - require.NoError(t, err) + require.Error(t, err) + require.Contains(t, err.Error(), "stream usage incomplete after disconnect") require.NotNil(t, result) require.True(t, result.clientDisconnect) require.Equal(t, 8, result.usage.InputTokens) diff --git a/backend/internal/service/gateway_record_usage_test.go 
b/backend/internal/service/gateway_record_usage_test.go new file mode 100644 index 00000000..4e7e545a --- /dev/null +++ b/backend/internal/service/gateway_record_usage_test.go @@ -0,0 +1,371 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/stretchr/testify/require" +) + +func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService { + cfg := &config.Config{} + cfg.Default.RateMultiplier = 1.1 + return NewGatewayService( + nil, + nil, + usageRepo, + nil, + userRepo, + subRepo, + nil, + nil, + cfg, + nil, + nil, + NewBillingService(cfg, nil), + nil, + &BillingCacheService{}, + nil, + nil, + &DeferredService{}, + nil, + nil, + nil, + nil, + nil, + ) +} + +func newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo UsageLogRepository, billingRepo UsageBillingRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService { + svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + svc.usageBillingRepo = billingRepo + return svc +} + +type openAIRecordUsageBestEffortLogRepoStub struct { + UsageLogRepository + + bestEffortErr error + createErr error + bestEffortCalls int + createCalls int + lastLog *UsageLog + lastCtxErr error +} + +func (s *openAIRecordUsageBestEffortLogRepoStub) CreateBestEffort(ctx context.Context, log *UsageLog) error { + s.bestEffortCalls++ + s.lastLog = log + s.lastCtxErr = ctx.Err() + return s.bestEffortErr +} + +func (s *openAIRecordUsageBestEffortLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) { + s.createCalls++ + s.lastLog = log + s.lastCtxErr = ctx.Err() + return false, s.createErr +} + +func TestGatewayServiceRecordUsage_BillingUsesDetachedContext(t *testing.T) { + usageRepo := 
&openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + reqCtx, cancel := context.WithCancel(context.Background()) + cancel() + + err := svc.RecordUsage(reqCtx, &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "gateway_detached_ctx", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 501, + Quota: 100, + }, + User: &User{ID: 601}, + Account: &Account{ID: 701}, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 1, userRepo.deductCalls) + require.NoError(t, userRepo.lastCtxErr) + require.Equal(t, 1, quotaSvc.quotaCalls) + require.NoError(t, quotaSvc.lastQuotaCtxErr) +} + +func TestGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}) + + payloadHash := HashUsageRequestPayload([]byte(`{"messages":[{"role":"user","content":"hello"}]}`)) + err := svc.RecordUsage(context.Background(), &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "gateway_payload_hash", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 501, Quota: 100}, + User: &User{ID: 601}, + Account: &Account{ID: 701}, + RequestPayloadHash: payloadHash, + }) + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.Equal(t, payloadHash, 
billingRepo.lastCmd.RequestPayloadHash) +} + +func TestGatewayServiceRecordUsage_BillingFingerprintFallsBackToContextRequestID(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}) + + ctx := context.WithValue(context.Background(), ctxkey.RequestID, "req-local-123") + err := svc.RecordUsage(ctx, &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "gateway_payload_fallback", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 501, Quota: 100}, + User: &User{ID: 601}, + Account: &Account{ID: 701}, + }) + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.Equal(t, "local:req-local-123", billingRepo.lastCmd.RequestPayloadHash) +} + +func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + err := svc.RecordUsage(context.Background(), &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "gateway_not_persisted", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 503, + Quota: 100, + }, + User: &User{ID: 603}, + Account: &Account{ID: 703}, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 1, userRepo.deductCalls) + require.Equal(t, 1, quotaSvc.quotaCalls) +} + +func 
TestGatewayServiceRecordUsageWithLongContext_BillingUsesDetachedContext(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + reqCtx, cancel := context.WithCancel(context.Background()) + cancel() + + err := svc.RecordUsageWithLongContext(reqCtx, &RecordUsageLongContextInput{ + Result: &ForwardResult{ + RequestID: "gateway_long_context_detached_ctx", + Usage: ClaudeUsage{ + InputTokens: 12, + OutputTokens: 8, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 502, + Quota: 100, + }, + User: &User{ID: 602}, + Account: &Account{ID: 702}, + LongContextThreshold: 200000, + LongContextMultiplier: 2, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 1, userRepo.deductCalls) + require.NoError(t, userRepo.lastCtxErr) + require.Equal(t, 1, quotaSvc.quotaCalls) + require.NoError(t, quotaSvc.lastQuotaCtxErr) +} + +func TestGatewayServiceRecordUsage_UsesFallbackRequestIDForUsageLog(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + ctx := context.WithValue(context.Background(), ctxkey.RequestID, "gateway-local-fallback") + err := svc.RecordUsage(ctx, &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 504}, + User: &User{ID: 604}, + Account: &Account{ID: 704}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, 
"local:gateway-local-fallback", usageRepo.lastLog.RequestID) +} + +func TestGatewayServiceRecordUsage_PrefersClientRequestIDOverUpstreamRequestID(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}) + + ctx := context.WithValue(context.Background(), ctxkey.ClientRequestID, "client-stable-123") + ctx = context.WithValue(ctx, ctxkey.RequestID, "req-local-ignored") + err := svc.RecordUsage(ctx, &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "upstream-volatile-456", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 506}, + User: &User{ID: 606}, + Account: &Account{ID: 706}, + }) + + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.Equal(t, "client:client-stable-123", billingRepo.lastCmd.RequestID) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "client:client-stable-123", usageRepo.lastLog.RequestID) +} + +func TestGatewayServiceRecordUsage_GeneratesRequestIDWhenAllSourcesMissing(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}) + + err := svc.RecordUsage(context.Background(), &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 507}, + User: &User{ID: 607}, + Account: &Account{ID: 707}, + }) + + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + 
require.True(t, strings.HasPrefix(billingRepo.lastCmd.RequestID, "generated:")) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, billingRepo.lastCmd.RequestID, usageRepo.lastLog.RequestID) +} + +func TestGatewayServiceRecordUsage_DroppedUsageLogDoesNotSyncFallback(t *testing.T) { + usageRepo := &openAIRecordUsageBestEffortLogRepoStub{ + bestEffortErr: MarkUsageLogCreateDropped(errors.New("usage log best-effort queue full")), + } + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}) + + err := svc.RecordUsage(context.Background(), &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "gateway_drop_usage_log", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 508}, + User: &User{ID: 608}, + Account: &Account{ID: 708}, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.bestEffortCalls) + require.Equal(t, 0, usageRepo.createCalls) +} + +func TestGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{err: context.DeadlineExceeded} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo) + + err := svc.RecordUsage(context.Background(), &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "gateway_billing_fail", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 505}, + User: &User{ID: 605}, + Account: &Account{ID: 705}, + }) + + require.Error(t, err) + require.Equal(t, 1, billingRepo.calls) + require.Equal(t, 0, 
usageRepo.calls) +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 8a433a36..c86b6964 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -50,6 +50,7 @@ const ( defaultUserGroupRateCacheTTL = 30 * time.Second defaultModelsListCacheTTL = 15 * time.Second + postUsageBillingTimeout = 15 * time.Second ) const ( @@ -106,6 +107,36 @@ func GatewayModelsListCacheStats() (cacheHit, cacheMiss, store int64) { return modelsListCacheHitTotal.Load(), modelsListCacheMissTotal.Load(), modelsListCacheStoreTotal.Load() } +func openAIStreamEventIsTerminal(data string) bool { + trimmed := strings.TrimSpace(data) + if trimmed == "" { + return false + } + if trimmed == "[DONE]" { + return true + } + switch gjson.Get(trimmed, "type").String() { + case "response.completed", "response.done", "response.failed": + return true + default: + return false + } +} + +func anthropicStreamEventIsTerminal(eventName, data string) bool { + if strings.EqualFold(strings.TrimSpace(eventName), "message_stop") { + return true + } + trimmed := strings.TrimSpace(data) + if trimmed == "" { + return false + } + if trimmed == "[DONE]" { + return true + } + return gjson.Get(trimmed, "type").String() == "message_stop" +} + func cloneStringSlice(src []string) []string { if len(src) == 0 { return nil @@ -504,6 +535,7 @@ type GatewayService struct { accountRepo AccountRepository groupRepo GroupRepository usageLogRepo UsageLogRepository + usageBillingRepo UsageBillingRepository userRepo UserRepository userSubRepo UserSubscriptionRepository userGroupRateRepo UserGroupRateRepository @@ -537,6 +569,7 @@ func NewGatewayService( accountRepo AccountRepository, groupRepo GroupRepository, usageLogRepo UsageLogRepository, + usageBillingRepo UsageBillingRepository, userRepo UserRepository, userSubRepo UserSubscriptionRepository, userGroupRateRepo UserGroupRateRepository, @@ -563,6 +596,7 @@ func 
NewGatewayService( accountRepo: accountRepo, groupRepo: groupRepo, usageLogRepo: usageLogRepo, + usageBillingRepo: usageBillingRepo, userRepo: userRepo, userSubRepo: userSubRepo, userGroupRateRepo: userGroupRateRepo, @@ -3336,6 +3370,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo if account.Platform == PlatformSora { return s.isSoraModelSupportedByAccount(account, requestedModel) } + if account.IsBedrock() { + _, ok := ResolveBedrockModelID(account, requestedModel) + return ok + } // OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID) if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { requestedModel = claude.NormalizeModelID(requestedModel) @@ -3493,6 +3531,10 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) ( return "", "", errors.New("api_key not found in credentials") } return apiKey, "apikey", nil + case AccountTypeBedrock: + return "", "bedrock", nil // Bedrock 使用 SigV4 签名,不需要 token + case AccountTypeBedrockAPIKey: + return "", "bedrock-apikey", nil // Bedrock API Key 使用 Bearer Token,由 forwardBedrock 处理 default: return "", "", fmt.Errorf("unsupported account type: %s", account.Type) } @@ -3948,6 +3990,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody, passthroughModel, parsed.Stream, startTime) } + if account != nil && account.IsBedrock() { + return s.forwardBedrock(ctx, c, account, parsed, startTime) + } + // Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest. // Always overwrite the cache to prevent stale values from a previous retry with a different account. 
if account.Platform == PlatformAnthropic && c != nil { @@ -4049,7 +4095,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A retryStart := time.Now() for attempt := 1; attempt <= maxRetryAttempts; attempt++ { // 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取) - upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + releaseUpstreamCtx() if err != nil { return nil, err } @@ -4127,7 +4175,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // also downgrade tool_use/tool_result blocks to text. filteredBody := FilterThinkingBlocksForRetry(body) - retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + retryCtx, releaseRetryCtx := detachStreamUpstreamContext(ctx, reqStream) + retryReq, buildErr := s.buildUpstreamRequest(retryCtx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + releaseRetryCtx() if buildErr == nil { retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr == nil { @@ -4159,7 +4209,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed { logger.LegacyPrintf("service.gateway", "Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID) filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body) - retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + 
retryCtx2, releaseRetryCtx2 := detachStreamUpstreamContext(ctx, reqStream) + retryReq2, buildErr2 := s.buildUpstreamRequest(retryCtx2, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + releaseRetryCtx2() if buildErr2 == nil { retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr2 == nil { @@ -4226,7 +4278,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A rectifiedBody, applied := RectifyThinkingBudget(body) if applied && time.Since(retryStart) < maxRetryElapsed { logger.LegacyPrintf("service.gateway", "Account %d: detected budget_tokens constraint error, retrying with rectified budget (budget_tokens=%d, max_tokens=%d)", account.ID, BudgetRectifyBudgetTokens, BudgetRectifyMaxTokens) - budgetRetryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, rectifiedBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + budgetRetryCtx, releaseBudgetRetryCtx := detachStreamUpstreamContext(ctx, reqStream) + budgetRetryReq, buildErr := s.buildUpstreamRequest(budgetRetryCtx, c, account, rectifiedBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + releaseBudgetRetryCtx() if buildErr == nil { budgetRetryResp, retryErr := s.httpUpstream.DoWithTLS(budgetRetryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr == nil { @@ -4498,7 +4552,9 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( var resp *http.Response retryStart := time.Now() for attempt := 1; attempt <= maxRetryAttempts; attempt++ { - upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(ctx, c, account, body, token) + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(upstreamCtx, c, account, body, token) + releaseUpstreamCtx() if err != nil { return 
nil, err } @@ -4774,6 +4830,7 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( usage := &ClaudeUsage{} var firstTokenMs *int clientDisconnected := false + sawTerminalEvent := false scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize @@ -4836,17 +4893,20 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( // 兜底补刷,确保最后一个未以空行结尾的事件也能及时送达客户端。 flusher.Flush() } + if !sawTerminalEvent { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, fmt.Errorf("stream usage incomplete: missing terminal event") + } return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil } if ev.err != nil { + if sawTerminalEvent { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil + } if clientDisconnected { - logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, ev.err) - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after disconnect: %w", ev.err) } if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { - logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] 流读取被取消: account=%d request_id=%s err=%v ctx_err=%v", - account.ID, resp.Header.Get("x-request-id"), ev.err, ctx.Err()) - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete: %w", ev.err) } if errors.Is(ev.err, bufio.ErrTooLong) { logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] SSE line too long: account=%d max_size=%d 
error=%v", account.ID, maxLineSize, ev.err) @@ -4858,11 +4918,19 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( line := ev.line if data, ok := extractAnthropicSSEDataLine(line); ok { trimmed := strings.TrimSpace(data) + if anthropicStreamEventIsTerminal("", trimmed) { + sawTerminalEvent = true + } if firstTokenMs == nil && trimmed != "" && trimmed != "[DONE]" { ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms } s.parseSSEUsagePassthrough(data, usage) + } else { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "event:") && anthropicStreamEventIsTerminal(strings.TrimSpace(strings.TrimPrefix(trimmed, "event:")), "") { + sawTerminalEvent = true + } } if !clientDisconnected { @@ -4884,8 +4952,7 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( continue } if clientDisconnected { - logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream timeout after client disconnect: account=%d model=%s", account.ID, model) - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout") } logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval) if s.rateLimitService != nil { @@ -5068,6 +5135,366 @@ func writeAnthropicPassthroughResponseHeaders(dst http.Header, src http.Header, } } +// forwardBedrock 转发请求到 AWS Bedrock +func (s *GatewayService) forwardBedrock( + ctx context.Context, + c *gin.Context, + account *Account, + parsed *ParsedRequest, + startTime time.Time, +) (*ForwardResult, error) { + reqModel := parsed.Model + reqStream := parsed.Stream + body := parsed.Body + + region := bedrockRuntimeRegion(account) + mappedModel, ok := ResolveBedrockModelID(account, reqModel) + if !ok { + return 
nil, fmt.Errorf("unsupported bedrock model: %s", reqModel) + } + if mappedModel != reqModel { + logger.LegacyPrintf("service.gateway", "[Bedrock] Model mapping: %s -> %s (account: %s)", reqModel, mappedModel, account.Name) + } + + betaHeader := "" + if c != nil && c.Request != nil { + betaHeader = c.GetHeader("anthropic-beta") + } + + // 准备请求体(注入 anthropic_version/anthropic_beta,移除 Bedrock 不支持的字段,清理 cache_control) + betaTokens, err := s.resolveBedrockBetaTokensForRequest(ctx, account, betaHeader, body, mappedModel) + if err != nil { + return nil, err + } + + bedrockBody, err := PrepareBedrockRequestBodyWithTokens(body, mappedModel, betaTokens) + if err != nil { + return nil, fmt.Errorf("prepare bedrock request body: %w", err) + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + logger.LegacyPrintf("service.gateway", "[Bedrock] 命中 Bedrock 分支: account=%d name=%s model=%s->%s stream=%v", + account.ID, account.Name, reqModel, mappedModel, reqStream) + + // 根据账号类型选择认证方式 + var signer *BedrockSigner + var bedrockAPIKey string + if account.IsBedrockAPIKey() { + bedrockAPIKey = account.GetCredential("api_key") + if bedrockAPIKey == "" { + return nil, fmt.Errorf("api_key not found in bedrock-apikey credentials") + } + } else { + signer, err = NewBedrockSignerFromAccount(account) + if err != nil { + return nil, fmt.Errorf("create bedrock signer: %w", err) + } + } + + // 执行上游请求(含重试) + resp, err := s.executeBedrockUpstream(ctx, c, account, bedrockBody, mappedModel, region, reqStream, signer, bedrockAPIKey, proxyURL) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + // 将 Bedrock 的 x-amzn-requestid 映射到 x-request-id, + // 使通用错误处理函数(handleErrorResponse、handleRetryExhaustedError)能正确提取 AWS request ID。 + if awsReqID := resp.Header.Get("x-amzn-requestid"); awsReqID != "" && resp.Header.Get("x-request-id") == "" { + resp.Header.Set("x-request-id", awsReqID) + } + + // 错误/failover 处理 + 
if resp.StatusCode >= 400 { + return s.handleBedrockUpstreamErrors(ctx, resp, c, account) + } + + // 响应处理 + var usage *ClaudeUsage + var firstTokenMs *int + var clientDisconnect bool + if reqStream { + streamResult, err := s.handleBedrockStreamingResponse(ctx, resp, c, account, startTime, reqModel) + if err != nil { + return nil, err + } + usage = streamResult.usage + firstTokenMs = streamResult.firstTokenMs + clientDisconnect = streamResult.clientDisconnect + } else { + usage, err = s.handleBedrockNonStreamingResponse(ctx, resp, c, account) + if err != nil { + return nil, err + } + } + if usage == nil { + usage = &ClaudeUsage{} + } + + return &ForwardResult{ + RequestID: resp.Header.Get("x-amzn-requestid"), + Usage: *usage, + Model: reqModel, + Stream: reqStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ClientDisconnect: clientDisconnect, + }, nil +} + +// executeBedrockUpstream 执行 Bedrock 上游请求(含重试逻辑) +func (s *GatewayService) executeBedrockUpstream( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + modelID string, + region string, + stream bool, + signer *BedrockSigner, + apiKey string, + proxyURL string, +) (*http.Response, error) { + var resp *http.Response + var err error + retryStart := time.Now() + for attempt := 1; attempt <= maxRetryAttempts; attempt++ { + var upstreamReq *http.Request + if account.IsBedrockAPIKey() { + upstreamReq, err = s.buildUpstreamRequestBedrockAPIKey(ctx, body, modelID, region, stream, apiKey) + } else { + upstreamReq, err = s.buildUpstreamRequestBedrock(ctx, body, modelID, region, stream, signer) + } + if err != nil { + return nil, err + } + + resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, false) + if err != nil { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + 
Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream request failed", + }, + }) + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + + if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { + if attempt < maxRetryAttempts { + elapsed := time.Since(retryStart) + if elapsed >= maxRetryElapsed { + break + } + + delay := retryBackoffDelay(attempt) + remaining := maxRetryElapsed - elapsed + if delay > remaining { + delay = remaining + } + if delay <= 0 { + break + } + + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + Kind: "retry", + Message: extractUpstreamErrorMessage(respBody), + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) + logger.LegacyPrintf("service.gateway", "[Bedrock] account %d: upstream error %d, retry %d/%d after %v", + account.ID, resp.StatusCode, attempt, maxRetryAttempts, delay) + if err := sleepWithContext(ctx, delay); err != nil { + return nil, err + } + continue + } + break + } + + break + } + if resp == nil || resp.Body == nil { + return nil, errors.New("upstream request failed: empty response") + } + return resp, nil +} + +// handleBedrockUpstreamErrors 处理 Bedrock 上游 4xx/5xx 错误(failover + 错误响应) +func (s *GatewayService) handleBedrockUpstreamErrors( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, +) (*ForwardResult, error) { + // retry exhausted + failover + if 
s.shouldRetryUpstreamError(account, resp.StatusCode) { + if s.shouldFailoverUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + logger.LegacyPrintf("service.gateway", "[Bedrock] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d Body=%s", + account.ID, account.Name, resp.StatusCode, truncateString(string(respBody), 1000)) + + s.handleRetryExhaustedSideEffects(ctx, resp, account) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + Kind: "retry_exhausted_failover", + Message: extractUpstreamErrorMessage(respBody), + }) + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + } + } + return s.handleRetryExhaustedError(ctx, resp, c, account) + } + + // non-retryable failover + if s.shouldFailoverUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + s.handleFailoverSideEffects(ctx, resp, account) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + Kind: "failover", + Message: extractUpstreamErrorMessage(respBody), + }) + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + } + } + + // other errors + return s.handleErrorResponse(ctx, resp, c, account) +} + +// buildUpstreamRequestBedrock 构建 Bedrock 上游请求 +func (s *GatewayService) buildUpstreamRequestBedrock( + ctx context.Context, + body []byte, + modelID string, + region string, + stream bool, + signer *BedrockSigner, +) (*http.Request, error) { + targetURL := BuildBedrockURL(region, modelID, stream) + + req, err := 
http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + // SigV4 签名 + if err := signer.SignRequest(ctx, req, body); err != nil { + return nil, fmt.Errorf("sign bedrock request: %w", err) + } + + return req, nil +} + +// buildUpstreamRequestBedrockAPIKey 构建 Bedrock API Key (Bearer Token) 上游请求 +func (s *GatewayService) buildUpstreamRequestBedrockAPIKey( + ctx context.Context, + body []byte, + modelID string, + region string, + stream bool, + apiKey string, +) (*http.Request, error) { + targetURL := BuildBedrockURL(region, modelID, stream) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + return req, nil +} + +// handleBedrockNonStreamingResponse 处理 Bedrock 非流式响应 +// Bedrock InvokeModel 非流式响应的 body 格式与 Claude API 兼容 +func (s *GatewayService) handleBedrockNonStreamingResponse( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, +) (*ClaudeUsage, error) { + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) + if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } + return nil, err + } + + // 转换 Bedrock 特有的 amazon-bedrock-invocationMetrics 为标准 Anthropic usage 格式 + // 并移除该字段避免透传给客户端 + body = transformBedrockInvocationMetrics(body) + + usage := parseClaudeUsageFromResponseBody(body) + + 
c.Header("Content-Type", "application/json") + if v := resp.Header.Get("x-amzn-requestid"); v != "" { + c.Header("x-request-id", v) + } + c.Data(resp.StatusCode, "application/json", body) + return usage, nil +} + func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) { // 确定目标URL targetURL := claudeAPIURL @@ -5481,6 +5908,76 @@ func containsBetaToken(header, token string) bool { return false } +func filterBetaTokens(tokens []string, filterSet map[string]struct{}) []string { + if len(tokens) == 0 || len(filterSet) == 0 { + return tokens + } + kept := make([]string, 0, len(tokens)) + for _, token := range tokens { + if _, filtered := filterSet[token]; !filtered { + kept = append(kept, token) + } + } + return kept +} + +func (s *GatewayService) resolveBedrockBetaTokensForRequest( + ctx context.Context, + account *Account, + betaHeader string, + body []byte, + modelID string, +) ([]string, error) { + // 1. 对原始 header 中的 beta token 做 block 检查(快速失败) + policy := s.evaluateBetaPolicy(ctx, betaHeader, account) + if policy.blockErr != nil { + return nil, policy.blockErr + } + + // 2. 解析 header + body 自动注入 + Bedrock 转换/过滤 + betaTokens := ResolveBedrockBetaTokens(betaHeader, body, modelID) + + // 3. 
对最终 token 列表再做 block 检查,捕获通过 body 自动注入绕过 header block 的情况。 + // 例如:管理员 block 了 interleaved-thinking,客户端不在 header 中带该 token, + // 但请求体中包含 thinking 字段 → autoInjectBedrockBetaTokens 会自动补齐 → + // 如果不做此检查,block 规则会被绕过。 + if blockErr := s.checkBetaPolicyBlockForTokens(ctx, betaTokens, account); blockErr != nil { + return nil, blockErr + } + + return filterBetaTokens(betaTokens, policy.filterSet), nil +} + +// checkBetaPolicyBlockForTokens 检查 token 列表中是否有被管理员 block 规则命中的 token。 +// 用于补充 evaluateBetaPolicy 对 header 的检查,覆盖 body 自动注入的 token。 +func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, tokens []string, account *Account) *BetaBlockedError { + if s.settingService == nil || len(tokens) == 0 { + return nil + } + settings, err := s.settingService.GetBetaPolicySettings(ctx) + if err != nil || settings == nil { + return nil + } + isOAuth := account.IsOAuth() + tokenSet := buildBetaTokenSet(tokens) + for _, rule := range settings.Rules { + if rule.Action != BetaPolicyActionBlock { + continue + } + if !betaPolicyScopeMatches(rule.Scope, isOAuth) { + continue + } + if _, present := tokenSet[rule.BetaToken]; present { + msg := rule.ErrorMessage + if msg == "" { + msg = "beta feature " + rule.BetaToken + " is not allowed" + } + return &BetaBlockedError{Message: msg} + } + } + return nil +} + func buildBetaTokenSet(tokens []string) map[string]struct{} { m := make(map[string]struct{}, len(tokens)) for _, t := range tokens { @@ -6027,6 +6524,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http needModelReplace := originalModel != mappedModel clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage + sawTerminalEvent := false pendingEventLines := make([]string, 0, 4) @@ -6057,6 +6555,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } if dataLine == "[DONE]" { + sawTerminalEvent = true block := "" if eventName != "" { block = "event: " + eventName + "\n" @@ -6123,6 +6622,9 @@ func (s 
*GatewayService) handleStreamingResponse(ctx context.Context, resp *http } usagePatch := s.extractSSEUsagePatch(event) + if anthropicStreamEventIsTerminal(eventName, dataLine) { + sawTerminalEvent = true + } if !eventChanged { block := "" if eventName != "" { @@ -6156,18 +6658,22 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http case ev, ok := <-events: if !ok { // 上游完成,返回结果 + if !sawTerminalEvent { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, fmt.Errorf("stream usage incomplete: missing terminal event") + } return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil } if ev.err != nil { + if sawTerminalEvent { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil + } // 检测 context 取消(客户端断开会导致 context 取消,进而影响上游读取) if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { - logger.LegacyPrintf("service.gateway", "Context canceled during streaming, returning collected usage") - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete: %w", ev.err) } // 客户端已通过写入失败检测到断开,上游也出错了,返回已收集的 usage if clientDisconnected { - logger.LegacyPrintf("service.gateway", "Upstream read error after client disconnect: %v, returning collected usage", ev.err) - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after disconnect: %w", ev.err) } // 客户端未断开,正常的错误处理 if errors.Is(ev.err, bufio.ErrTooLong) { @@ -6226,9 +6732,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http continue } if 
clientDisconnected { - // 客户端已断开,上游也超时了,返回已收集的 usage - logger.LegacyPrintf("service.gateway", "Upstream timeout after client disconnect, returning collected usage") - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout") } logger.LegacyPrintf("service.gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) // 处理流超时,可能标记账户为临时不可调度或错误状态 @@ -6590,15 +7094,16 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, // RecordUsageInput 记录使用量的输入参数 type RecordUsageInput struct { - Result *ForwardResult - APIKey *APIKey - User *User - Account *Account - Subscription *UserSubscription // 可选:订阅信息 - UserAgent string // 请求的 User-Agent - IPAddress string // 请求的客户端 IP 地址 - ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) - APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额 + Result *ForwardResult + APIKey *APIKey + User *User + Account *Account + Subscription *UserSubscription // 可选:订阅信息 + UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 + RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险 + ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) + APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额 } // APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage @@ -6607,6 +7112,14 @@ type APIKeyQuotaUpdater interface { UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error } +type apiKeyAuthCacheInvalidator interface { + InvalidateAuthCacheByKey(ctx context.Context, key string) +} + +type usageLogBestEffortWriter interface { + CreateBestEffort(ctx context.Context, log *UsageLog) error +} + // postUsageBillingParams 统一扣费所需的参数 type postUsageBillingParams struct { Cost 
*CostBreakdown @@ -6614,6 +7127,7 @@ type postUsageBillingParams struct { APIKey *APIKey Account *Account Subscription *UserSubscription + RequestPayloadHash string IsSubscriptionBill bool AccountRateMultiplier float64 APIKeyService APIKeyQuotaUpdater @@ -6625,19 +7139,22 @@ type postUsageBillingParams struct { // - API Key 限速用量更新 // - 账号配额用量更新(账号口径:TotalCost × 账号计费倍率) func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) { + billingCtx, cancel := detachedBillingContext(ctx) + defer cancel() + cost := p.Cost // 1. 订阅 / 余额扣费 if p.IsSubscriptionBill { if cost.TotalCost > 0 { - if err := deps.userSubRepo.IncrementUsage(ctx, p.Subscription.ID, cost.TotalCost); err != nil { + if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil { slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err) } deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost) } } else { if cost.ActualCost > 0 { - if err := deps.userRepo.DeductBalance(ctx, p.User.ID, cost.ActualCost); err != nil { + if err := deps.userRepo.DeductBalance(billingCtx, p.User.ID, cost.ActualCost); err != nil { slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err) } deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost) @@ -6646,31 +7163,187 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill // 2. API Key 配额 if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil { - if err := p.APIKeyService.UpdateQuotaUsed(ctx, p.APIKey.ID, cost.ActualCost); err != nil { + if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err) } } // 3. 
API Key 限速用量 if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil { - if err := p.APIKeyService.UpdateRateLimitUsage(ctx, p.APIKey.ID, cost.ActualCost); err != nil { + if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err) } - deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, cost.ActualCost) } // 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率) if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() { accountCost := cost.TotalCost * p.AccountRateMultiplier - if err := deps.accountRepo.IncrementQuotaUsed(ctx, p.Account.ID, accountCost); err != nil { + if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil { slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err) } } - // 5. 更新账号最近使用时间 + finalizePostUsageBilling(p, deps) +} + +func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string { + if ctx != nil { + if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" { + return "client:" + strings.TrimSpace(clientRequestID) + } + if requestID, _ := ctx.Value(ctxkey.RequestID).(string); strings.TrimSpace(requestID) != "" { + return "local:" + strings.TrimSpace(requestID) + } + } + if requestID := strings.TrimSpace(upstreamRequestID); requestID != "" { + return requestID + } + return "generated:" + generateRequestID() +} + +func resolveUsageBillingPayloadFingerprint(ctx context.Context, requestPayloadHash string) string { + if payloadHash := strings.TrimSpace(requestPayloadHash); payloadHash != "" { + return payloadHash + } + if ctx != nil { + if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" { + return "client:" + 
strings.TrimSpace(clientRequestID) + } + if requestID, _ := ctx.Value(ctxkey.RequestID).(string); strings.TrimSpace(requestID) != "" { + return "local:" + strings.TrimSpace(requestID) + } + } + return "" +} + +func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsageBillingParams) *UsageBillingCommand { + if p == nil || p.Cost == nil || p.APIKey == nil || p.User == nil || p.Account == nil { + return nil + } + + cmd := &UsageBillingCommand{ + RequestID: requestID, + APIKeyID: p.APIKey.ID, + UserID: p.User.ID, + AccountID: p.Account.ID, + AccountType: p.Account.Type, + RequestPayloadHash: strings.TrimSpace(p.RequestPayloadHash), + } + if usageLog != nil { + cmd.Model = usageLog.Model + cmd.BillingType = usageLog.BillingType + cmd.InputTokens = usageLog.InputTokens + cmd.OutputTokens = usageLog.OutputTokens + cmd.CacheCreationTokens = usageLog.CacheCreationTokens + cmd.CacheReadTokens = usageLog.CacheReadTokens + cmd.ImageCount = usageLog.ImageCount + if usageLog.MediaType != nil { + cmd.MediaType = *usageLog.MediaType + } + if usageLog.ServiceTier != nil { + cmd.ServiceTier = *usageLog.ServiceTier + } + if usageLog.ReasoningEffort != nil { + cmd.ReasoningEffort = *usageLog.ReasoningEffort + } + if usageLog.SubscriptionID != nil { + cmd.SubscriptionID = usageLog.SubscriptionID + } + } + + if p.IsSubscriptionBill && p.Subscription != nil && p.Cost.TotalCost > 0 { + cmd.SubscriptionID = &p.Subscription.ID + cmd.SubscriptionCost = p.Cost.TotalCost + } else if p.Cost.ActualCost > 0 { + cmd.BalanceCost = p.Cost.ActualCost + } + + if p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil { + cmd.APIKeyQuotaCost = p.Cost.ActualCost + } + if p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil { + cmd.APIKeyRateLimitCost = p.Cost.ActualCost + } + if p.Cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() { + cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier + } + + 
cmd.Normalize() + return cmd +} + +func applyUsageBilling(ctx context.Context, requestID string, usageLog *UsageLog, p *postUsageBillingParams, deps *billingDeps, repo UsageBillingRepository) (bool, error) { + if p == nil || deps == nil { + return false, nil + } + + cmd := buildUsageBillingCommand(requestID, usageLog, p) + if cmd == nil || cmd.RequestID == "" || repo == nil { + postUsageBilling(ctx, p, deps) + return true, nil + } + + billingCtx, cancel := detachedBillingContext(ctx) + defer cancel() + + result, err := repo.Apply(billingCtx, cmd) + if err != nil { + return false, err + } + + if result == nil || !result.Applied { + deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID) + return false, nil + } + + if result.APIKeyQuotaExhausted { + if invalidator, ok := p.APIKeyService.(apiKeyAuthCacheInvalidator); ok && p.APIKey != nil && p.APIKey.Key != "" { + invalidator.InvalidateAuthCacheByKey(billingCtx, p.APIKey.Key) + } + } + + finalizePostUsageBilling(p, deps) + return true, nil +} + +func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) { + if p == nil || p.Cost == nil || deps == nil { + return + } + + if p.IsSubscriptionBill { + if p.Cost.TotalCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil { + deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.TotalCost) + } + } else if p.Cost.ActualCost > 0 && p.User != nil { + deps.billingCacheService.QueueDeductBalance(p.User.ID, p.Cost.ActualCost) + } + + if p.Cost.ActualCost > 0 && p.APIKey != nil && p.APIKey.HasRateLimits() { + deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, p.Cost.ActualCost) + } + deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID) } +func detachedBillingContext(ctx context.Context) (context.Context, context.CancelFunc) { + base := context.Background() + if ctx != nil { + base = context.WithoutCancel(ctx) + } + return context.WithTimeout(base, postUsageBillingTimeout) +} + +func 
detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Context, context.CancelFunc) { + if !stream { + return ctx, func() {} + } + if ctx == nil { + return context.Background(), func() {} + } + return context.WithoutCancel(ctx), func() {} +} + // billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供) type billingDeps struct { accountRepo AccountRepository @@ -6690,6 +7363,31 @@ func (s *GatewayService) billingDeps() *billingDeps { } } +func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usageLog *UsageLog, logKey string) { + if repo == nil || usageLog == nil { + return + } + usageCtx, cancel := detachedBillingContext(ctx) + defer cancel() + + if writer, ok := repo.(usageLogBestEffortWriter); ok { + if err := writer.CreateBestEffort(usageCtx, usageLog); err != nil { + logger.LegacyPrintf(logKey, "Create usage log failed: %v", err) + if IsUsageLogCreateDropped(err) { + return + } + if _, syncErr := repo.Create(usageCtx, usageLog); syncErr != nil { + logger.LegacyPrintf(logKey, "Create usage log sync fallback failed: %v", syncErr) + } + } + return + } + + if _, err := repo.Create(usageCtx, usageLog); err != nil { + logger.LegacyPrintf(logKey, "Create usage log failed: %v", err) + } +} + // RecordUsage 记录使用量并扣费(或更新订阅用量) func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error { result := input.Result @@ -6791,11 +7489,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu mediaType = &result.MediaType } accountRateMultiplier := account.BillingRateMultiplier() + requestID := resolveUsageBillingRequestID(ctx, result.RequestID) usageLog := &UsageLog{ UserID: user.ID, APIKeyID: apiKey.ID, AccountID: account.ID, - RequestID: result.RequestID, + RequestID: requestID, Model: result.Model, InputTokens: result.Usage.InputTokens, OutputTokens: result.Usage.OutputTokens, @@ -6840,33 +7539,32 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu 
usageLog.SubscriptionID = &subscription.ID } - inserted, err := s.usageLogRepo.Create(ctx, usageLog) - if err != nil { - logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err) - } - if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } - shouldBill := inserted || err != nil - - if shouldBill { - postUsageBilling(ctx, &postUsageBillingParams{ + billingErr := func() error { + _, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ Cost: cost, User: user, APIKey: apiKey, Account: account, Subscription: subscription, + RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash), IsSubscriptionBill: isSubscriptionBilling, AccountRateMultiplier: accountRateMultiplier, APIKeyService: input.APIKeyService, - }, s.billingDeps()) - } else { - s.deferredService.ScheduleLastUsedUpdate(account.ID) + }, s.billingDeps(), s.usageBillingRepo) + return err + }() + + if billingErr != nil { + return billingErr } + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") return nil } @@ -6877,13 +7575,14 @@ type RecordUsageLongContextInput struct { APIKey *APIKey User *User Account *Account - Subscription *UserSubscription // 可选:订阅信息 - UserAgent string // 请求的 User-Agent - IPAddress string // 请求的客户端 IP 地址 - LongContextThreshold int // 长上下文阈值(如 200000) - LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0) - ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) - APIKeyService *APIKeyService // API Key 配额服务(可选) + Subscription *UserSubscription // 可选:订阅信息 + UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 + RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险 + 
LongContextThreshold int // 长上下文阈值(如 200000) + LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0) + ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) + APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选) } // RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini) @@ -6966,11 +7665,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * imageSize = &result.ImageSize } accountRateMultiplier := account.BillingRateMultiplier() + requestID := resolveUsageBillingRequestID(ctx, result.RequestID) usageLog := &UsageLog{ UserID: user.ID, APIKeyID: apiKey.ID, AccountID: account.ID, - RequestID: result.RequestID, + RequestID: requestID, Model: result.Model, InputTokens: result.Usage.InputTokens, OutputTokens: result.Usage.OutputTokens, @@ -7014,33 +7714,32 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * usageLog.SubscriptionID = &subscription.ID } - inserted, err := s.usageLogRepo.Create(ctx, usageLog) - if err != nil { - logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err) - } - if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } - shouldBill := inserted || err != nil - - if shouldBill { - postUsageBilling(ctx, &postUsageBillingParams{ + billingErr := func() error { + _, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ Cost: cost, User: user, APIKey: apiKey, Account: account, Subscription: subscription, + RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash), IsSubscriptionBill: isSubscriptionBilling, AccountRateMultiplier: accountRateMultiplier, APIKeyService: input.APIKeyService, - }, s.billingDeps()) - } else 
{ - s.deferredService.ScheduleLastUsedUpdate(account.ID) + }, s.billingDeps(), s.usageBillingRepo) + return err + }() + + if billingErr != nil { + return billingErr } + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") return nil } @@ -7064,6 +7763,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody) } + // Bedrock 不支持 count_tokens 端点 + if account != nil && account.IsBedrock() { + s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported for Bedrock") + return nil + } + body := parsed.Body reqModel := parsed.Model diff --git a/backend/internal/service/gateway_service_bedrock_beta_test.go b/backend/internal/service/gateway_service_bedrock_beta_test.go new file mode 100644 index 00000000..8920ee08 --- /dev/null +++ b/backend/internal/service/gateway_service_bedrock_beta_test.go @@ -0,0 +1,267 @@ +package service + +import ( + "context" + "encoding/json" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +type betaPolicySettingRepoStub struct { + values map[string]string +} + +func (s *betaPolicySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + panic("unexpected Get call") +} + +func (s *betaPolicySettingRepoStub) GetValue(ctx context.Context, key string) (string, error) { + if v, ok := s.values[key]; ok { + return v, nil + } + return "", ErrSettingNotFound +} + +func (s *betaPolicySettingRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *betaPolicySettingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + panic("unexpected GetMultiple call") +} + +func (s *betaPolicySettingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *betaPolicySettingRepoStub) GetAll(ctx 
context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *betaPolicySettingRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +func TestResolveBedrockBetaTokensForRequest_BlocksOnOriginalAnthropicToken(t *testing.T) { + settings := &BetaPolicySettings{ + Rules: []BetaPolicyRule{ + { + BetaToken: "advanced-tool-use-2025-11-20", + Action: BetaPolicyActionBlock, + Scope: BetaPolicyScopeAll, + ErrorMessage: "advanced tool use is blocked", + }, + }, + } + raw, err := json.Marshal(settings) + if err != nil { + t.Fatalf("marshal settings: %v", err) + } + + svc := &GatewayService{ + settingService: NewSettingService( + &betaPolicySettingRepoStub{values: map[string]string{ + SettingKeyBetaPolicySettings: string(raw), + }}, + &config.Config{}, + ), + } + account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock} + + _, err = svc.resolveBedrockBetaTokensForRequest( + context.Background(), + account, + "advanced-tool-use-2025-11-20", + []byte(`{"messages":[{"role":"user","content":"hi"}]}`), + "us.anthropic.claude-opus-4-6-v1", + ) + if err == nil { + t.Fatal("expected raw advanced-tool-use token to be blocked before Bedrock transform") + } + if err.Error() != "advanced tool use is blocked" { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestResolveBedrockBetaTokensForRequest_FiltersAfterBedrockTransform(t *testing.T) { + settings := &BetaPolicySettings{ + Rules: []BetaPolicyRule{ + { + BetaToken: "tool-search-tool-2025-10-19", + Action: BetaPolicyActionFilter, + Scope: BetaPolicyScopeAll, + }, + }, + } + raw, err := json.Marshal(settings) + if err != nil { + t.Fatalf("marshal settings: %v", err) + } + + svc := &GatewayService{ + settingService: NewSettingService( + &betaPolicySettingRepoStub{values: map[string]string{ + SettingKeyBetaPolicySettings: string(raw), + }}, + &config.Config{}, + ), + } + account := &Account{Platform: PlatformAnthropic, Type: 
AccountTypeBedrock} + + betaTokens, err := svc.resolveBedrockBetaTokensForRequest( + context.Background(), + account, + "advanced-tool-use-2025-11-20", + []byte(`{"messages":[{"role":"user","content":"hi"}]}`), + "us.anthropic.claude-opus-4-6-v1", + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + for _, token := range betaTokens { + if token == "tool-search-tool-2025-10-19" { + t.Fatalf("expected transformed Bedrock token to be filtered") + } + } +} + +// TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedThinking 验证: +// 管理员 block 了 interleaved-thinking,客户端不在 header 中带该 token, +// 但请求体包含 thinking 字段 → 自动注入后应被 block。 +func TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedThinking(t *testing.T) { + settings := &BetaPolicySettings{ + Rules: []BetaPolicyRule{ + { + BetaToken: "interleaved-thinking-2025-05-14", + Action: BetaPolicyActionBlock, + Scope: BetaPolicyScopeAll, + ErrorMessage: "thinking is blocked", + }, + }, + } + raw, err := json.Marshal(settings) + if err != nil { + t.Fatalf("marshal settings: %v", err) + } + + svc := &GatewayService{ + settingService: NewSettingService( + &betaPolicySettingRepoStub{values: map[string]string{ + SettingKeyBetaPolicySettings: string(raw), + }}, + &config.Config{}, + ), + } + account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock} + + // header 中不带 beta token,但 body 中有 thinking 字段 + _, err = svc.resolveBedrockBetaTokensForRequest( + context.Background(), + account, + "", // 空 header + []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`), + "us.anthropic.claude-opus-4-6-v1", + ) + if err == nil { + t.Fatal("expected body-injected interleaved-thinking to be blocked") + } + if err.Error() != "thinking is blocked" { + t.Fatalf("unexpected error: %v", err) + } +} + +// TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedToolSearch 验证: +// 管理员 block 了 tool-search-tool,客户端不在 header 中带 beta token, +// 但请求体包含 
tool search 工具 → 自动注入后应被 block。 +func TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedToolSearch(t *testing.T) { + settings := &BetaPolicySettings{ + Rules: []BetaPolicyRule{ + { + BetaToken: "tool-search-tool-2025-10-19", + Action: BetaPolicyActionBlock, + Scope: BetaPolicyScopeAll, + ErrorMessage: "tool search is blocked", + }, + }, + } + raw, err := json.Marshal(settings) + if err != nil { + t.Fatalf("marshal settings: %v", err) + } + + svc := &GatewayService{ + settingService: NewSettingService( + &betaPolicySettingRepoStub{values: map[string]string{ + SettingKeyBetaPolicySettings: string(raw), + }}, + &config.Config{}, + ), + } + account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock} + + // header 中不带 beta token,但 body 中有 tool_search_tool 工具 + _, err = svc.resolveBedrockBetaTokensForRequest( + context.Background(), + account, + "", + []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`), + "us.anthropic.claude-sonnet-4-6", + ) + if err == nil { + t.Fatal("expected body-injected tool-search-tool to be blocked") + } + if err.Error() != "tool search is blocked" { + t.Fatalf("unexpected error: %v", err) + } +} + +// TestResolveBedrockBetaTokensForRequest_PassesWhenNoBlockRuleMatches 验证: +// body 自动注入的 token 如果没有对应的 block 规则,应正常通过。 +func TestResolveBedrockBetaTokensForRequest_PassesWhenNoBlockRuleMatches(t *testing.T) { + settings := &BetaPolicySettings{ + Rules: []BetaPolicyRule{ + { + BetaToken: "computer-use-2025-11-24", + Action: BetaPolicyActionBlock, + Scope: BetaPolicyScopeAll, + ErrorMessage: "computer use is blocked", + }, + }, + } + raw, err := json.Marshal(settings) + if err != nil { + t.Fatalf("marshal settings: %v", err) + } + + svc := &GatewayService{ + settingService: NewSettingService( + &betaPolicySettingRepoStub{values: map[string]string{ + SettingKeyBetaPolicySettings: string(raw), + }}, + &config.Config{}, + ), + } + account := 
&Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock} + + // body 中有 thinking(会注入 interleaved-thinking),但 block 规则只针对 computer-use + tokens, err := svc.resolveBedrockBetaTokensForRequest( + context.Background(), + account, + "", + []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`), + "us.anthropic.claude-opus-4-6-v1", + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + found := false + for _, token := range tokens { + if token == "interleaved-thinking-2025-05-14" { + found = true + } + } + if !found { + t.Fatal("expected interleaved-thinking token to be present") + } +} diff --git a/backend/internal/service/gateway_service_bedrock_model_support_test.go b/backend/internal/service/gateway_service_bedrock_model_support_test.go new file mode 100644 index 00000000..aa8d4756 --- /dev/null +++ b/backend/internal/service/gateway_service_bedrock_model_support_test.go @@ -0,0 +1,48 @@ +package service + +import "testing" + +func TestGatewayServiceIsModelSupportedByAccount_BedrockDefaultMappingRestrictsModels(t *testing.T) { + svc := &GatewayService{} + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeBedrock, + Credentials: map[string]any{ + "aws_region": "us-east-1", + }, + } + + if !svc.isModelSupportedByAccount(account, "claude-sonnet-4-5") { + t.Fatalf("expected default Bedrock alias to be supported") + } + + if svc.isModelSupportedByAccount(account, "claude-3-5-sonnet-20241022") { + t.Fatalf("expected unsupported alias to be rejected for Bedrock account") + } +} + +func TestGatewayServiceIsModelSupportedByAccount_BedrockCustomMappingStillActsAsAllowlist(t *testing.T) { + svc := &GatewayService{} + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeBedrock, + Credentials: map[string]any{ + "aws_region": "eu-west-1", + "model_mapping": map[string]any{ + "claude-sonnet-*": "claude-sonnet-4-6", + }, + }, + } + + if 
!svc.isModelSupportedByAccount(account, "claude-sonnet-4-6") { + t.Fatalf("expected matched custom mapping to be supported") + } + + if !svc.isModelSupportedByAccount(account, "claude-opus-4-6") { + t.Fatalf("expected default Bedrock alias fallback to remain supported") + } + + if svc.isModelSupportedByAccount(account, "claude-3-5-sonnet-20241022") { + t.Fatalf("expected unsupported model to still be rejected") + } +} diff --git a/backend/internal/service/gateway_streaming_test.go b/backend/internal/service/gateway_streaming_test.go index cd690cbd..b1584827 100644 --- a/backend/internal/service/gateway_streaming_test.go +++ b/backend/internal/service/gateway_streaming_test.go @@ -181,7 +181,8 @@ func TestHandleStreamingResponse_EmptyStream(t *testing.T) { result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false) _ = pr.Close() - require.NoError(t, err) + require.Error(t, err) + require.Contains(t, err.Error(), "missing terminal event") require.NotNil(t, result) } diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index b0e4d44f..8fffce1b 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -129,6 +129,41 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact } } + // 兼容遗留的 functions 和 function_call,转换为 tools 和 tool_choice + if functionsRaw, ok := reqBody["functions"]; ok { + if functions, k := functionsRaw.([]any); k { + tools := make([]any, 0, len(functions)) + for _, f := range functions { + tools = append(tools, map[string]any{ + "type": "function", + "function": f, + }) + } + reqBody["tools"] = tools + } + delete(reqBody, "functions") + result.Modified = true + } + + if fcRaw, ok := reqBody["function_call"]; ok { + if fcStr, ok := fcRaw.(string); ok { + // e.g. 
"auto", "none" + reqBody["tool_choice"] = fcStr + } else if fcObj, ok := fcRaw.(map[string]any); ok { + // e.g. {"name": "my_func"} + if name, ok := fcObj["name"].(string); ok && strings.TrimSpace(name) != "" { + reqBody["tool_choice"] = map[string]any{ + "type": "function", + "function": map[string]any{ + "name": name, + }, + } + } + } + delete(reqBody, "function_call") + result.Modified = true + } + if normalizeCodexTools(reqBody) { result.Modified = true } @@ -303,6 +338,18 @@ func filterCodexInput(input []any, preserveReferences bool) []any { continue } typ, _ := m["type"].(string) + + // 修复 OpenAI 上游的最新校验:"Expected an ID that begins with 'fc'" + fixIDPrefix := func(id string) string { + if id == "" || strings.HasPrefix(id, "fc") { + return id + } + if strings.HasPrefix(id, "call_") { + return "fc" + strings.TrimPrefix(id, "call_") + } + return "fc_" + id + } + if typ == "item_reference" { if !preserveReferences { continue @@ -311,6 +358,9 @@ func filterCodexInput(input []any, preserveReferences bool) []any { for key, value := range m { newItem[key] = value } + if id, ok := newItem["id"].(string); ok && id != "" { + newItem["id"] = fixIDPrefix(id) + } filtered = append(filtered, newItem) continue } @@ -330,10 +380,20 @@ func filterCodexInput(input []any, preserveReferences bool) []any { } if isCodexToolCallItemType(typ) { - if callID, ok := m["call_id"].(string); !ok || strings.TrimSpace(callID) == "" { + callID, ok := m["call_id"].(string) + if !ok || strings.TrimSpace(callID) == "" { if id, ok := m["id"].(string); ok && strings.TrimSpace(id) != "" { + callID = id ensureCopy() - newItem["call_id"] = id + newItem["call_id"] = callID + } + } + + if callID != "" { + fixedCallID := fixIDPrefix(callID) + if fixedCallID != callID { + ensureCopy() + newItem["call_id"] = fixedCallID } } } @@ -344,6 +404,14 @@ func filterCodexInput(input []any, preserveReferences bool) []any { if !isCodexToolCallItemType(typ) { delete(newItem, "call_id") } + } else { + if id, ok := 
newItem["id"].(string); ok && id != "" { + fixedID := fixIDPrefix(id) + if fixedID != id { + ensureCopy() + newItem["id"] = fixedID + } + } } filtered = append(filtered, newItem) diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index c8097aed..df012d7c 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -33,12 +33,12 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) { first, ok := input[0].(map[string]any) require.True(t, ok) require.Equal(t, "item_reference", first["type"]) - require.Equal(t, "ref1", first["id"]) + require.Equal(t, "fc_ref1", first["id"]) // 校验 input[1] 为 map,确保后续字段断言安全。 second, ok := input[1].(map[string]any) require.True(t, ok) - require.Equal(t, "o1", second["id"]) + require.Equal(t, "fc_o1", second["id"]) } func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) { diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index 9529462e..cd4d58fd 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -3,39 +3,68 @@ package service import ( "context" "errors" + "strings" "testing" "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/stretchr/testify/require" ) type openAIRecordUsageLogRepoStub struct { UsageLogRepository - inserted bool - err error - calls int - lastLog *UsageLog + inserted bool + err error + calls int + lastLog *UsageLog + lastCtxErr error } func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) { s.calls++ s.lastLog = log + s.lastCtxErr = ctx.Err() return s.inserted, s.err } +type openAIRecordUsageBillingRepoStub struct { + UsageBillingRepository + + 
result *UsageBillingApplyResult + err error + calls int + lastCmd *UsageBillingCommand + lastCtxErr error +} + +func (s *openAIRecordUsageBillingRepoStub) Apply(ctx context.Context, cmd *UsageBillingCommand) (*UsageBillingApplyResult, error) { + s.calls++ + s.lastCmd = cmd + s.lastCtxErr = ctx.Err() + if s.err != nil { + return nil, s.err + } + if s.result != nil { + return s.result, nil + } + return &UsageBillingApplyResult{Applied: true}, nil +} + type openAIRecordUsageUserRepoStub struct { UserRepository deductCalls int deductErr error lastAmount float64 + lastCtxErr error } func (s *openAIRecordUsageUserRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error { s.deductCalls++ s.lastAmount = amount + s.lastCtxErr = ctx.Err() return s.deductErr } @@ -44,29 +73,35 @@ type openAIRecordUsageSubRepoStub struct { incrementCalls int incrementErr error + lastCtxErr error } func (s *openAIRecordUsageSubRepoStub) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { s.incrementCalls++ + s.lastCtxErr = ctx.Err() return s.incrementErr } type openAIRecordUsageAPIKeyQuotaStub struct { - quotaCalls int - rateLimitCalls int - err error - lastAmount float64 + quotaCalls int + rateLimitCalls int + err error + lastAmount float64 + lastQuotaCtxErr error + lastRateLimitCtxErr error } func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error { s.quotaCalls++ s.lastAmount = cost + s.lastQuotaCtxErr = ctx.Err() return s.err } func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error { s.rateLimitCalls++ s.lastAmount = cost + s.lastRateLimitCtxErr = ctx.Err() return s.err } @@ -93,23 +128,38 @@ func i64p(v int64) *int64 { func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService { cfg := &config.Config{} 
cfg.Default.RateMultiplier = 1.1 + svc := NewOpenAIGatewayService( + nil, + usageRepo, + nil, + userRepo, + subRepo, + rateRepo, + nil, + cfg, + nil, + nil, + NewBillingService(cfg, nil), + nil, + &BillingCacheService{}, + nil, + &DeferredService{}, + nil, + ) + svc.userGroupRateResolver = newUserGroupRateResolver( + rateRepo, + nil, + resolveUserGroupRateCacheTTL(cfg), + nil, + "service.openai_gateway.test", + ) + return svc +} - return &OpenAIGatewayService{ - usageLogRepo: usageRepo, - userRepo: userRepo, - userSubRepo: subRepo, - cfg: cfg, - billingService: NewBillingService(cfg, nil), - billingCacheService: &BillingCacheService{}, - deferredService: &DeferredService{}, - userGroupRateResolver: newUserGroupRateResolver( - rateRepo, - nil, - resolveUserGroupRateCacheTTL(cfg), - nil, - "service.openai_gateway.test", - ), - } +func newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo UsageLogRepository, billingRepo UsageBillingRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService { + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo) + svc.usageBillingRepo = billingRepo + return svc } func expectedOpenAICost(t *testing.T, svc *OpenAIGatewayService, model string, usage OpenAIUsage, multiplier float64) *CostBreakdown { @@ -252,9 +302,10 @@ func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateWhenResolver func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testing.T) { usageRepo := &openAIRecordUsageLogRepoStub{inserted: false} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: false}} userRepo := &openAIRecordUsageUserRepoStub{} subRepo := &openAIRecordUsageSubRepoStub{} - svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) err := 
svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ Result: &OpenAIForwardResult{ @@ -272,11 +323,313 @@ func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testin }) require.NoError(t, err) + require.Equal(t, 1, billingRepo.calls) require.Equal(t, 1, usageRepo.calls) require.Equal(t, 0, userRepo.deductCalls) require.Equal(t, 0, subRepo.incrementCalls) } +func TestOpenAIGatewayServiceRecordUsage_DuplicateBillingKeySkipsBillingWithRepo(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: false}} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_duplicate_billing_key", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 10045, + Quota: 100, + }, + User: &User{ID: 20045}, + Account: &Account{ID: 30045}, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, billingRepo.calls) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 0, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) + require.Equal(t, 0, quotaSvc.quotaCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_BillsWhenUsageLogCreateReturnsError(t *testing.T) { + usage := OpenAIUsage{InputTokens: 8, OutputTokens: 4} + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: errors.New("usage log batch state uncertain")} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + err := 
svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_usage_log_error", + Usage: usage, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10041}, + User: &User{ID: 20041}, + Account: &Account{ID: 30041}, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 1, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_not_persisted", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 10043, + Quota: 100, + }, + User: &User{ID: 20043}, + Account: &Account{ID: 30043}, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 1, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) + require.Equal(t, 1, quotaSvc.quotaCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_BillingUsesDetachedContext(t *testing.T) { + usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2} + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + reqCtx, cancel := 
context.WithCancel(context.Background()) + cancel() + + err := svc.RecordUsage(reqCtx, &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_detached_billing_ctx", + Usage: usage, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 10042, + Quota: 100, + }, + User: &User{ID: 20042}, + Account: &Account{ID: 30042}, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, userRepo.deductCalls) + require.NoError(t, userRepo.lastCtxErr) + require.Equal(t, 1, quotaSvc.quotaCalls) + require.NoError(t, quotaSvc.lastQuotaCtxErr) +} + +func TestOpenAIGatewayServiceRecordUsage_BillingRepoUsesDetachedContext(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + + reqCtx, cancel := context.WithCancel(context.Background()) + cancel() + + err := svc.RecordUsage(reqCtx, &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_detached_billing_repo_ctx", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10046}, + User: &User{ID: 20046}, + Account: &Account{ID: 30046}, + }) + + require.NoError(t, err) + require.Equal(t, 1, billingRepo.calls) + require.NoError(t, billingRepo.lastCtxErr) + require.Equal(t, 1, usageRepo.calls) + require.NoError(t, usageRepo.lastCtxErr) +} + +func TestOpenAIGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, 
&openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil) + + payloadHash := HashUsageRequestPayload([]byte(`{"model":"gpt-5","input":"hello"}`)) + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "openai_payload_hash", + Usage: OpenAIUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "gpt-5", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 501, Quota: 100}, + User: &User{ID: 601}, + Account: &Account{ID: 701}, + RequestPayloadHash: payloadHash, + }) + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.Equal(t, payloadHash, billingRepo.lastCmd.RequestPayloadHash) +} + +func TestOpenAIGatewayServiceRecordUsage_UsesFallbackRequestIDForBillingAndUsageLog(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + + ctx := context.WithValue(context.Background(), ctxkey.RequestID, "req-local-fallback") + err := svc.RecordUsage(ctx, &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10047}, + User: &User{ID: 20047}, + Account: &Account{ID: 30047}, + }) + + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.Equal(t, "local:req-local-fallback", billingRepo.lastCmd.RequestID) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "local:req-local-fallback", usageRepo.lastLog.RequestID) +} + +func TestOpenAIGatewayServiceRecordUsage_PrefersClientRequestIDOverUpstreamRequestID(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := 
&openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + + ctx := context.WithValue(context.Background(), ctxkey.ClientRequestID, "openai-client-stable-123") + err := svc.RecordUsage(ctx, &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "upstream-openai-volatile-456", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10049}, + User: &User{ID: 20049}, + Account: &Account{ID: 30049}, + }) + + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.Equal(t, "client:openai-client-stable-123", billingRepo.lastCmd.RequestID) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "client:openai-client-stable-123", usageRepo.lastLog.RequestID) +} + +func TestOpenAIGatewayServiceRecordUsage_GeneratesRequestIDWhenAllSourcesMissing(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10050}, + User: &User{ID: 20050}, + Account: &Account{ID: 30050}, + }) + + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.True(t, strings.HasPrefix(billingRepo.lastCmd.RequestID, "generated:")) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, 
billingRepo.lastCmd.RequestID, usageRepo.lastLog.RequestID) +} + +func TestOpenAIGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{err: errors.New("billing tx failed")} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_billing_fail", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10048}, + User: &User{ID: 20048}, + Account: &Account{ID: 30048}, + }) + + require.Error(t, err) + require.Equal(t, 1, billingRepo.calls) + require.Equal(t, 0, usageRepo.calls) +} + func TestOpenAIGatewayServiceRecordUsage_UpdatesAPIKeyQuotaWhenConfigured(t *testing.T) { usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2} usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 023e4ed4..72461544 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -301,6 +301,7 @@ var defaultOpenAICodexSnapshotPersistThrottle = newAccountWriteThrottle(openAICo type OpenAIGatewayService struct { accountRepo AccountRepository usageLogRepo UsageLogRepository + usageBillingRepo UsageBillingRepository userRepo UserRepository userSubRepo UserSubscriptionRepository cache GatewayCache @@ -338,6 +339,7 @@ type OpenAIGatewayService struct { func NewOpenAIGatewayService( accountRepo AccountRepository, usageLogRepo UsageLogRepository, + usageBillingRepo UsageBillingRepository, userRepo UserRepository, userSubRepo 
UserSubscriptionRepository, userGroupRateRepo UserGroupRateRepository, @@ -355,6 +357,7 @@ func NewOpenAIGatewayService( svc := &OpenAIGatewayService{ accountRepo: accountRepo, usageLogRepo: usageLogRepo, + usageBillingRepo: usageBillingRepo, userRepo: userRepo, userSubRepo: userSubRepo, cache: cache, @@ -2073,7 +2076,9 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } // Build upstream request - upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI) + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI) + releaseUpstreamCtx() if err != nil { return nil, err } @@ -2265,7 +2270,9 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( return nil, err } - upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(ctx, c, account, body, token) + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token) + releaseUpstreamCtx() if err != nil { return nil, err } @@ -2602,6 +2609,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( var firstTokenMs *int clientDisconnected := false sawDone := false + sawTerminalEvent := false upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id")) scanner := bufio.NewScanner(resp.Body) @@ -2621,6 +2629,9 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( if trimmedData == "[DONE]" { sawDone = true } + if openAIStreamEventIsTerminal(trimmedData) { + sawTerminalEvent = true + } if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" { ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms @@ -2638,19 +2649,14 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( } } if err := 
scanner.Err(); err != nil { - if clientDisconnected { - logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, err) + if sawTerminalEvent { return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil } + if clientDisconnected { + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err) + } if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - logger.LegacyPrintf("service.openai_gateway", - "[OpenAI passthrough] 流读取被取消,可能发生断流: account=%d request_id=%s err=%v ctx_err=%v", - account.ID, - upstreamRequestID, - err, - ctx.Err(), - ) - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err) } if errors.Is(err, bufio.ErrTooLong) { logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err) @@ -2664,12 +2670,13 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( ) return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err) } - if !clientDisconnected && !sawDone && ctx.Err() == nil { + if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil { logger.FromContext(ctx).With( zap.String("component", "service.openai_gateway"), zap.Int64("account_id", account.ID), zap.String("upstream_request_id", upstreamRequestID), ).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流") + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event") } return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: 
firstTokenMs}, nil @@ -3203,6 +3210,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp // 否则下游 SDK(例如 OpenCode)会因为类型校验失败而报错。 errorEventSent := false clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage + sawTerminalEvent := false sendErrorEvent := func(reason string) { if errorEventSent || clientDisconnected { return @@ -3233,22 +3241,27 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage") } } + if !sawTerminalEvent { + return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event") + } return resultWithUsage(), nil } handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) { if scanErr == nil { return nil, nil, false } + if sawTerminalEvent { + logger.LegacyPrintf("service.openai_gateway", "Upstream scan ended after terminal event: %v", scanErr) + return resultWithUsage(), nil, true + } // 客户端断开/取消请求时,上游读取往往会返回 context canceled。 // /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。 if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) { - logger.LegacyPrintf("service.openai_gateway", "Context canceled during streaming, returning collected usage") - return resultWithUsage(), nil, true + return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", scanErr), true } // 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage if clientDisconnected { - logger.LegacyPrintf("service.openai_gateway", "Upstream read error after client disconnect: %v, returning collected usage", scanErr) - return resultWithUsage(), nil, true + return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true } if errors.Is(scanErr, bufio.ErrTooLong) { logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr) @@ -3271,6 
+3284,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp } dataBytes := []byte(data) + if openAIStreamEventIsTerminal(data) { + sawTerminalEvent = true + } // Correct Codex tool calls if needed (apply_patch -> edit, etc.) if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected { @@ -3387,8 +3403,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp continue } if clientDisconnected { - logger.LegacyPrintf("service.openai_gateway", "Upstream timeout after client disconnect, returning collected usage") - return resultWithUsage(), nil + return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout") } logger.LegacyPrintf("service.openai_gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) // 处理流超时,可能标记账户为临时不可调度或错误状态 @@ -3486,11 +3501,12 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag if usage == nil || len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) { return } - // 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。 - if len(data) < 80 || !bytes.Contains(data, []byte(`"response.completed"`)) { + // 选择性解析:仅在数据中包含终止事件标识时才进入字段提取。 + if len(data) < 72 { return } - if gjson.GetBytes(data, "type").String() != "response.completed" { + eventType := gjson.GetBytes(data, "type").String() + if eventType != "response.completed" && eventType != "response.done" { return } @@ -3843,14 +3859,15 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel // OpenAIRecordUsageInput input for recording usage type OpenAIRecordUsageInput struct { - Result *OpenAIForwardResult - APIKey *APIKey - User *User - Account *Account - Subscription *UserSubscription - UserAgent string // 请求的 User-Agent - IPAddress string // 请求的客户端 IP 地址 - APIKeyService APIKeyQuotaUpdater + Result *OpenAIForwardResult + APIKey *APIKey + User *User + Account *Account + Subscription 
*UserSubscription + UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 + RequestPayloadHash string + APIKeyService APIKeyQuotaUpdater } // RecordUsage records usage and deducts balance @@ -3916,11 +3933,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec // Create usage log durationMs := int(result.Duration.Milliseconds()) accountRateMultiplier := account.BillingRateMultiplier() + requestID := resolveUsageBillingRequestID(ctx, result.RequestID) usageLog := &UsageLog{ UserID: user.ID, APIKeyID: apiKey.ID, AccountID: account.ID, - RequestID: result.RequestID, + RequestID: requestID, Model: billingModel, ServiceTier: result.ServiceTier, ReasoningEffort: result.ReasoningEffort, @@ -3961,29 +3979,32 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec usageLog.SubscriptionID = &subscription.ID } - inserted, err := s.usageLogRepo.Create(ctx, usageLog) if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway") logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } - shouldBill := inserted || err != nil - - if shouldBill { - postUsageBilling(ctx, &postUsageBillingParams{ + billingErr := func() error { + _, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ Cost: cost, User: user, APIKey: apiKey, Account: account, Subscription: subscription, + RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash), IsSubscriptionBill: isSubscriptionBilling, AccountRateMultiplier: accountRateMultiplier, APIKeyService: input.APIKeyService, - }, s.billingDeps()) - } else { - s.deferredService.ScheduleLastUsedUpdate(account.ID) + }, s.billingDeps(), s.usageBillingRepo) + return err + }() + + if billingErr != nil 
{ + return billingErr } + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway") return nil } diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 43e2f39d..9e2f33f2 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -916,7 +916,7 @@ func TestOpenAIStreamingTimeout(t *testing.T) { } } -func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) { +func TestOpenAIStreamingContextCanceledReturnsIncompleteErrorWithoutInjectingErrorEvent(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ Gateway: config.GatewayConfig{ @@ -940,8 +940,8 @@ func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) { } _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model") - if err != nil { - t.Fatalf("expected nil error, got %v", err) + if err == nil || !strings.Contains(err.Error(), "stream usage incomplete") { + t.Fatalf("expected incomplete stream error, got %v", err) } if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "stream_read_error") { t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String()) @@ -993,6 +993,107 @@ func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) { } } +func TestOpenAIStreamingMissingTerminalEventReturnsIncompleteError(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + StreamKeepaliveInterval: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: pr, + 
Header: http.Header{}, + } + + go func() { + defer func() { _ = pw.Close() }() + _, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n")) + }() + + _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model") + _ = pr.Close() + if err == nil || !strings.Contains(err.Error(), "missing terminal event") { + t.Fatalf("expected missing terminal event error, got %v", err) + } +} + +func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: pr, + Header: http.Header{}, + } + + go func() { + defer func() { _ = pw.Close() }() + _, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n")) + }() + + _, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now()) + _ = pr.Close() + if err == nil || !strings.Contains(err.Error(), "missing terminal event") { + t.Fatalf("expected missing terminal event error, got %v", err) + } +} + +func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: pr, + Header: http.Header{}, + } + + go func() { + defer func() { _ = pw.Close() }() + _, _ = 
pw.Write([]byte("data: {\"type\":\"response.done\",\"response\":{\"usage\":{\"input_tokens\":2,\"output_tokens\":3,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n")) + }() + + result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now()) + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 2, result.usage.InputTokens) + require.Equal(t, 3, result.usage.OutputTokens) + require.Equal(t, 1, result.usage.CacheReadInputTokens) +} + func TestOpenAIStreamingTooLong(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ @@ -1124,7 +1225,7 @@ func TestOpenAIStreamingHeadersOverride(t *testing.T) { go func() { defer func() { _ = pw.Close() }() - _, _ = pw.Write([]byte("data: {}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{}}\n\n")) }() _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model") @@ -1674,6 +1775,12 @@ func TestParseSSEUsage_SelectiveParsing(t *testing.T) { require.Equal(t, 3, usage.InputTokens) require.Equal(t, 5, usage.OutputTokens) require.Equal(t, 2, usage.CacheReadInputTokens) + + // done 事件同样可能携带最终 usage + svc.parseSSEUsage(`{"type":"response.done","response":{"usage":{"input_tokens":13,"output_tokens":15,"input_tokens_details":{"cached_tokens":4}}}}`, usage) + require.Equal(t, 13, usage.InputTokens) + require.Equal(t, 15, usage.OutputTokens) + require.Equal(t, 4, usage.CacheReadInputTokens) } func TestExtractCodexFinalResponse_SampleReplay(t *testing.T) { diff --git a/backend/internal/service/openai_oauth_passthrough_test.go b/backend/internal/service/openai_oauth_passthrough_test.go index 6fbd2469..f51a7491 100644 --- a/backend/internal/service/openai_oauth_passthrough_test.go +++ b/backend/internal/service/openai_oauth_passthrough_test.go @@ -439,7 +439,7 @@ func 
TestOpenAIGatewayService_OAuthPassthrough_ResponseHeadersAllowXCodex(t *tes c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") - originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`) + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`) headers := make(http.Header) headers.Set("Content-Type", "application/json") @@ -453,7 +453,14 @@ func TestOpenAIGatewayService_OAuthPassthrough_ResponseHeadersAllowXCodex(t *tes resp := &http.Response{ StatusCode: http.StatusOK, Header: headers, - Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)), + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `data: {"type":"response.output_text.delta","delta":"h"}`, + "", + `data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}}`, + "", + "data: [DONE]", + "", + }, "\n"))), } upstream := &httpUpstreamRecorder{resp: resp} @@ -895,7 +902,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_InfoWhenStreamEndsWithoutDone(t * } _, err := svc.Forward(context.Background(), c, account, originalBody) - require.NoError(t, err) + require.EqualError(t, err, "stream usage incomplete: missing terminal event") require.True(t, logSink.ContainsMessage("上游流在未收到 [DONE] 时结束,疑似断流")) require.True(t, logSink.ContainsMessageAtLevel("上游流在未收到 [DONE] 时结束,疑似断流", "info")) require.True(t, logSink.ContainsFieldValue("upstream_request_id", "rid-truncate")) @@ -911,11 +918,16 @@ func TestOpenAIGatewayService_OAuthPassthrough_DefaultFiltersTimeoutHeaders(t *t c.Request.Header.Set("x-stainless-timeout", "120000") c.Request.Header.Set("X-Test", "keep") - originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`) + 
originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`) resp := &http.Response{ StatusCode: http.StatusOK, - Header: http.Header{"Content-Type": []string{"application/json"}, "X-Request-Id": []string{"rid-filter-default"}}, - Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)), + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "X-Request-Id": []string{"rid-filter-default"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}}`, + "", + "data: [DONE]", + "", + }, "\n"))), } upstream := &httpUpstreamRecorder{resp: resp} svc := &OpenAIGatewayService{ @@ -952,11 +964,16 @@ func TestOpenAIGatewayService_OAuthPassthrough_AllowTimeoutHeadersWhenConfigured c.Request.Header.Set("x-stainless-timeout", "120000") c.Request.Header.Set("X-Test", "keep") - originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`) + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`) resp := &http.Response{ StatusCode: http.StatusOK, - Header: http.Header{"Content-Type": []string{"application/json"}, "X-Request-Id": []string{"rid-filter-allow"}}, - Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)), + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "X-Request-Id": []string{"rid-filter-allow"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}}`, + "", + "data: [DONE]", + "", + }, "\n"))), } upstream := &httpUpstreamRecorder{resp: resp} svc := 
&OpenAIGatewayService{ diff --git a/backend/internal/service/openai_ws_protocol_forward_test.go b/backend/internal/service/openai_ws_protocol_forward_test.go index 7295b13d..08eb397b 100644 --- a/backend/internal/service/openai_ws_protocol_forward_test.go +++ b/backend/internal/service/openai_ws_protocol_forward_test.go @@ -392,6 +392,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) { nil, nil, nil, + nil, cfg, nil, nil, diff --git a/backend/internal/service/usage_billing.go b/backend/internal/service/usage_billing.go new file mode 100644 index 00000000..73b05743 --- /dev/null +++ b/backend/internal/service/usage_billing.go @@ -0,0 +1,110 @@ +package service + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "strings" +) + +var ErrUsageBillingRequestIDRequired = errors.New("usage billing request_id is required") +var ErrUsageBillingRequestConflict = errors.New("usage billing request fingerprint conflict") + +// UsageBillingCommand describes one billable request that must be applied at most once. 
+type UsageBillingCommand struct { + RequestID string + APIKeyID int64 + RequestFingerprint string + RequestPayloadHash string + + UserID int64 + AccountID int64 + SubscriptionID *int64 + AccountType string + Model string + ServiceTier string + ReasoningEffort string + BillingType int8 + InputTokens int + OutputTokens int + CacheCreationTokens int + CacheReadTokens int + ImageCount int + MediaType string + + BalanceCost float64 + SubscriptionCost float64 + APIKeyQuotaCost float64 + APIKeyRateLimitCost float64 + AccountQuotaCost float64 +} + +func (c *UsageBillingCommand) Normalize() { + if c == nil { + return + } + c.RequestID = strings.TrimSpace(c.RequestID) + if strings.TrimSpace(c.RequestFingerprint) == "" { + c.RequestFingerprint = buildUsageBillingFingerprint(c) + } +} + +func buildUsageBillingFingerprint(c *UsageBillingCommand) string { + if c == nil { + return "" + } + raw := fmt.Sprintf( + "%d|%d|%d|%s|%s|%s|%s|%d|%d|%d|%d|%d|%d|%s|%d|%0.10f|%0.10f|%0.10f|%0.10f|%0.10f", + c.UserID, + c.AccountID, + c.APIKeyID, + strings.TrimSpace(c.AccountType), + strings.TrimSpace(c.Model), + strings.TrimSpace(c.ServiceTier), + strings.TrimSpace(c.ReasoningEffort), + c.BillingType, + c.InputTokens, + c.OutputTokens, + c.CacheCreationTokens, + c.CacheReadTokens, + c.ImageCount, + strings.TrimSpace(c.MediaType), + valueOrZero(c.SubscriptionID), + c.BalanceCost, + c.SubscriptionCost, + c.APIKeyQuotaCost, + c.APIKeyRateLimitCost, + c.AccountQuotaCost, + ) + if payloadHash := strings.TrimSpace(c.RequestPayloadHash); payloadHash != "" { + raw += "|" + payloadHash + } + sum := sha256.Sum256([]byte(raw)) + return hex.EncodeToString(sum[:]) +} + +func HashUsageRequestPayload(payload []byte) string { + if len(payload) == 0 { + return "" + } + sum := sha256.Sum256(payload) + return hex.EncodeToString(sum[:]) +} + +func valueOrZero(v *int64) int64 { + if v == nil { + return 0 + } + return *v +} + +type UsageBillingApplyResult struct { + Applied bool + APIKeyQuotaExhausted bool +} + 
+type UsageBillingRepository interface { + Apply(ctx context.Context, cmd *UsageBillingCommand) (*UsageBillingApplyResult, error) +} diff --git a/backend/internal/service/usage_cleanup_service_test.go b/backend/internal/service/usage_cleanup_service_test.go index 0fdbfd47..17f21bef 100644 --- a/backend/internal/service/usage_cleanup_service_test.go +++ b/backend/internal/service/usage_cleanup_service_test.go @@ -56,7 +56,8 @@ type cleanupRepoStub struct { } type dashboardRepoStub struct { - recomputeErr error + recomputeErr error + recomputeCalls int } func (s *dashboardRepoStub) AggregateRange(ctx context.Context, start, end time.Time) error { @@ -64,6 +65,7 @@ func (s *dashboardRepoStub) AggregateRange(ctx context.Context, start, end time. } func (s *dashboardRepoStub) RecomputeRange(ctx context.Context, start, end time.Time) error { + s.recomputeCalls++ return s.recomputeErr } @@ -83,6 +85,10 @@ func (s *dashboardRepoStub) CleanupUsageLogs(ctx context.Context, cutoff time.Ti return nil } +func (s *dashboardRepoStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error { + return nil +} + func (s *dashboardRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error { return nil } @@ -550,13 +556,14 @@ func TestUsageCleanupServiceExecuteTaskMarkFailedUpdateError(t *testing.T) { } func TestUsageCleanupServiceExecuteTaskDashboardRecomputeError(t *testing.T) { + dashboardRepo := &dashboardRepoStub{recomputeErr: errors.New("recompute failed")} repo := &cleanupRepoStub{ deleteQueue: []cleanupDeleteResponse{ {deleted: 0}, }, } - dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{ - DashboardAgg: config.DashboardAggregationConfig{Enabled: false}, + dashboard := NewDashboardAggregationService(dashboardRepo, nil, &config.Config{ + DashboardAgg: config.DashboardAggregationConfig{Enabled: true}, }) cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}} svc := 
NewUsageCleanupService(repo, nil, dashboard, cfg) @@ -573,15 +580,17 @@ func TestUsageCleanupServiceExecuteTaskDashboardRecomputeError(t *testing.T) { repo.mu.Lock() defer repo.mu.Unlock() require.Len(t, repo.markSucceeded, 1) + require.Eventually(t, func() bool { return dashboardRepo.recomputeCalls == 1 }, time.Second, 10*time.Millisecond) } func TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess(t *testing.T) { + dashboardRepo := &dashboardRepoStub{} repo := &cleanupRepoStub{ deleteQueue: []cleanupDeleteResponse{ {deleted: 0}, }, } - dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{ + dashboard := NewDashboardAggregationService(dashboardRepo, nil, &config.Config{ DashboardAgg: config.DashboardAggregationConfig{Enabled: true}, }) cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}} @@ -599,6 +608,7 @@ func TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess(t *testing.T) { repo.mu.Lock() defer repo.mu.Unlock() require.Len(t, repo.markSucceeded, 1) + require.Eventually(t, func() bool { return dashboardRepo.recomputeCalls == 1 }, time.Second, 10*time.Millisecond) } func TestUsageCleanupServiceExecuteTaskCanceled(t *testing.T) { diff --git a/backend/internal/service/usage_log_create_result.go b/backend/internal/service/usage_log_create_result.go new file mode 100644 index 00000000..1cd84f44 --- /dev/null +++ b/backend/internal/service/usage_log_create_result.go @@ -0,0 +1,82 @@ +package service + +import "errors" + +type usageLogCreateDisposition int + +const ( + usageLogCreateDispositionUnknown usageLogCreateDisposition = iota + usageLogCreateDispositionNotPersisted + usageLogCreateDispositionDropped +) + +type UsageLogCreateError struct { + err error + disposition usageLogCreateDisposition +} + +func (e *UsageLogCreateError) Error() string { + if e == nil || e.err == nil { + return "usage log create error" + } + return e.err.Error() +} + +func (e *UsageLogCreateError) 
Unwrap() error { + if e == nil { + return nil + } + return e.err +} + +func MarkUsageLogCreateNotPersisted(err error) error { + if err == nil { + return nil + } + return &UsageLogCreateError{ + err: err, + disposition: usageLogCreateDispositionNotPersisted, + } +} + +func MarkUsageLogCreateDropped(err error) error { + if err == nil { + return nil + } + return &UsageLogCreateError{ + err: err, + disposition: usageLogCreateDispositionDropped, + } +} + +func IsUsageLogCreateNotPersisted(err error) bool { + if err == nil { + return false + } + var target *UsageLogCreateError + if !errors.As(err, &target) { + return false + } + return target.disposition == usageLogCreateDispositionNotPersisted +} + +func IsUsageLogCreateDropped(err error) bool { + if err == nil { + return false + } + var target *UsageLogCreateError + if !errors.As(err, &target) { + return false + } + return target.disposition == usageLogCreateDispositionDropped +} + +func ShouldBillAfterUsageLogCreate(inserted bool, err error) bool { + if inserted { + return true + } + if err == nil { + return false + } + return !IsUsageLogCreateNotPersisted(err) +} diff --git a/backend/migrations/071_add_usage_billing_dedup.sql b/backend/migrations/071_add_usage_billing_dedup.sql new file mode 100644 index 00000000..acc28459 --- /dev/null +++ b/backend/migrations/071_add_usage_billing_dedup.sql @@ -0,0 +1,13 @@ +-- 窄表账务幂等键:将“是否已扣费”从 usage_logs 解耦出来 +-- 幂等执行:可重复运行 + +CREATE TABLE IF NOT EXISTS usage_billing_dedup ( + id BIGSERIAL PRIMARY KEY, + request_id VARCHAR(255) NOT NULL, + api_key_id BIGINT NOT NULL, + request_fingerprint VARCHAR(64) NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE UNIQUE INDEX IF NOT EXISTS idx_usage_billing_dedup_request_api_key + ON usage_billing_dedup (request_id, api_key_id); diff --git a/backend/migrations/072_add_usage_billing_dedup_created_at_brin_notx.sql b/backend/migrations/072_add_usage_billing_dedup_created_at_brin_notx.sql new file mode 100644 index 
00000000..965a3412 --- /dev/null +++ b/backend/migrations/072_add_usage_billing_dedup_created_at_brin_notx.sql @@ -0,0 +1,7 @@ +-- usage_billing_dedup 是按时间追加写入的幂等窄表。 +-- 使用 BRIN 支撑按 created_at 的批量保留期清理,尽量降低写放大。 +-- 使用 CONCURRENTLY 避免在热表上长时间阻塞写入。 + +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_usage_billing_dedup_created_at_brin + ON usage_billing_dedup + USING BRIN (created_at); diff --git a/backend/migrations/073_add_usage_billing_dedup_archive.sql b/backend/migrations/073_add_usage_billing_dedup_archive.sql new file mode 100644 index 00000000..d156d4eb --- /dev/null +++ b/backend/migrations/073_add_usage_billing_dedup_archive.sql @@ -0,0 +1,10 @@ +-- 冷归档旧账务幂等键,缩小热表索引与清理范围,同时不丢失长期去重能力。 + +CREATE TABLE IF NOT EXISTS usage_billing_dedup_archive ( + request_id VARCHAR(255) NOT NULL, + api_key_id BIGINT NOT NULL, + request_fingerprint VARCHAR(64) NOT NULL, + created_at TIMESTAMPTZ NOT NULL, + archived_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + PRIMARY KEY (request_id, api_key_id) +); diff --git a/deploy/Dockerfile b/deploy/Dockerfile index ffe815e5..0f4f1de9 100644 --- a/deploy/Dockerfile +++ b/deploy/Dockerfile @@ -105,7 +105,7 @@ EXPOSE 8080 # Health check HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \ - CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1 + CMD wget -q -T 5 -O /dev/null http://localhost:${SERVER_PORT:-8080}/health || exit 1 # Run the application ENTRYPOINT ["/app/sub2api"] diff --git a/deploy/build_image.sh b/deploy/build_image.sh old mode 100755 new mode 100644 diff --git a/deploy/docker-compose.local.yml b/deploy/docker-compose.local.yml index 0ef397df..d404ac0b 100644 --- a/deploy/docker-compose.local.yml +++ b/deploy/docker-compose.local.yml @@ -154,7 +154,7 @@ services: networks: - sub2api-network healthcheck: - test: ["CMD", "curl", "-f", "http://localhost:8080/health"] + test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health"] interval: 30s timeout: 10s retries: 3 diff 
--git a/deploy/docker-compose.standalone.yml b/deploy/docker-compose.standalone.yml index 7676fb97..df0ccfcc 100644 --- a/deploy/docker-compose.standalone.yml +++ b/deploy/docker-compose.standalone.yml @@ -94,7 +94,7 @@ services: - GEMINI_CLI_OAUTH_CLIENT_SECRET=${GEMINI_CLI_OAUTH_CLIENT_SECRET:-} - ANTIGRAVITY_OAUTH_CLIENT_SECRET=${ANTIGRAVITY_OAUTH_CLIENT_SECRET:-} healthcheck: - test: ["CMD", "curl", "-f", "http://localhost:8080/health"] + test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health"] interval: 30s timeout: 10s retries: 3 diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index e5c97bf8..acd21fd9 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -146,7 +146,7 @@ services: networks: - sub2api-network healthcheck: - test: ["CMD", "curl", "-f", "http://localhost:8080/health"] + test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health"] interval: 30s timeout: 10s retries: 3 diff --git a/deploy/install-datamanagementd.sh b/deploy/install-datamanagementd.sh old mode 100755 new mode 100644 diff --git a/frontend/src/api/admin/dashboard.ts b/frontend/src/api/admin/dashboard.ts index 4393dda3..85200506 100644 --- a/frontend/src/api/admin/dashboard.ts +++ b/frontend/src/api/admin/dashboard.ts @@ -11,6 +11,7 @@ import type { GroupStat, ApiKeyUsageTrendPoint, UserUsageTrendPoint, + UserSpendingRankingResponse, UsageRequestType } from '@/types' @@ -201,6 +202,11 @@ export interface UserTrendResponse { granularity: string } +export interface UserSpendingRankingParams + extends Pick { + limit?: number +} + /** * Get user usage trend data * @param params - Query parameters for filtering @@ -213,6 +219,20 @@ export async function getUserUsageTrend(params?: UserTrendParams): Promise { + const { data } = await apiClient.get('/admin/dashboard/users-ranking', { + params + }) + return data +} + export interface BatchUserUsageStats { user_id: number today_actual_cost: 
number @@ -271,6 +291,7 @@ export const dashboardAPI = { getSnapshotV2, getApiKeyUsageTrend, getUserUsageTrend, + getUserSpendingRanking, getBatchUsersUsage, getBatchApiKeysUsage } diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 8423c1b9..1ac96ed6 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -232,7 +232,7 @@
-
+
+ + + +
@@ -896,7 +956,7 @@ -
+
+ +
+
+ + +
+
+ + +
+
+ + +

{{ t('admin.accounts.bedrockSessionTokenHint') }}

+
+
+ + +

{{ t('admin.accounts.bedrockRegionHint') }}

+
+
+ +

{{ t('admin.accounts.bedrockForceGlobalHint') }}

+
+ + +
+ + + +
+ + +
+ + +
+ +

+ {{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }} + {{ t('admin.accounts.supportsAllModels') }} +

+
+ + +
+
+ + + + +
+ + +
+ +
+
+
+
+ + +
+
+ + +
+
+ + +

{{ t('admin.accounts.bedrockRegionHint') }}

+
+
+ +

{{ t('admin.accounts.bedrockForceGlobalHint') }}

+
+ + +
+ + + +
+ + +
+ + +
+ +

+ {{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }} + {{ t('admin.accounts.supportsAllModels') }} +

+
+ + +
+
+ + + + +
+ + +
+ +
+
+
+
+
@@ -2671,7 +3014,7 @@ interface TempUnschedRuleForm { // State const step = ref(1) const submitting = ref(false) -const accountCategory = ref<'oauth-based' | 'apikey'>('oauth-based') // UI selection for account category +const accountCategory = ref<'oauth-based' | 'apikey' | 'bedrock' | 'bedrock-apikey'>('oauth-based') // UI selection for account category const addMethod = ref('oauth') // For oauth-based: 'oauth' or 'setup-token' const apiKeyBaseUrl = ref('https://api.anthropic.com') const apiKeyValue = ref('') @@ -2704,6 +3047,19 @@ const antigravityModelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist' const antigravityWhitelistModels = ref([]) const antigravityModelMappings = ref([]) const antigravityPresetMappings = computed(() => getPresetMappingsByPlatform('antigravity')) +const bedrockPresets = computed(() => getPresetMappingsByPlatform('bedrock')) + +// Bedrock credentials +const bedrockAccessKeyId = ref('') +const bedrockSecretAccessKey = ref('') +const bedrockSessionToken = ref('') +const bedrockRegion = ref('us-east-1') +const bedrockForceGlobal = ref(false) + +// Bedrock API Key credentials +const bedrockApiKeyValue = ref('') +const bedrockApiKeyRegion = ref('us-east-1') +const bedrockApiKeyForceGlobal = ref(false) const tempUnschedEnabled = ref(false) const tempUnschedRules = ref([]) const getModelMappingKey = createStableObjectKeyResolver('create-model-mapping') @@ -2868,6 +3224,10 @@ const isOAuthFlow = computed(() => { if (form.platform === 'antigravity' && antigravityAccountType.value === 'upstream') { return false } + // Bedrock 类型不需要 OAuth 流程 + if (form.platform === 'anthropic' && accountCategory.value === 'bedrock') { + return false + } return accountCategory.value === 'oauth-based' }) @@ -2935,6 +3295,11 @@ watch( form.type = 'apikey' return } + // Bedrock 类型 + if (form.platform === 'anthropic' && category === 'bedrock') { + form.type = 'bedrock' as AccountType + return + } if (category === 'oauth-based') { form.type = method as 
AccountType // 'oauth' or 'setup-token' } else { @@ -2972,6 +3337,13 @@ watch( antigravityModelMappings.value = [] antigravityModelRestrictionMode.value = 'mapping' } + // Reset Bedrock fields when switching platforms + bedrockAccessKeyId.value = '' + bedrockSecretAccessKey.value = '' + bedrockSessionToken.value = '' + bedrockRegion.value = 'us-east-1' + bedrockForceGlobal.value = false + bedrockApiKeyForceGlobal.value = false // Reset Anthropic/Antigravity-specific settings when switching to other platforms if (newPlatform !== 'anthropic' && newPlatform !== 'antigravity') { interceptWarmupRequests.value = false @@ -3541,6 +3913,84 @@ const handleSubmit = async () => { return } + // For Bedrock type, create directly + if (form.platform === 'anthropic' && accountCategory.value === 'bedrock') { + if (!form.name.trim()) { + appStore.showError(t('admin.accounts.pleaseEnterAccountName')) + return + } + if (!bedrockAccessKeyId.value.trim()) { + appStore.showError(t('admin.accounts.bedrockAccessKeyIdRequired')) + return + } + if (!bedrockSecretAccessKey.value.trim()) { + appStore.showError(t('admin.accounts.bedrockSecretAccessKeyRequired')) + return + } + if (!bedrockRegion.value.trim()) { + appStore.showError(t('admin.accounts.bedrockRegionRequired')) + return + } + + const credentials: Record = { + aws_access_key_id: bedrockAccessKeyId.value.trim(), + aws_secret_access_key: bedrockSecretAccessKey.value.trim(), + aws_region: bedrockRegion.value.trim(), + } + if (bedrockSessionToken.value.trim()) { + credentials.aws_session_token = bedrockSessionToken.value.trim() + } + if (bedrockForceGlobal.value) { + credentials.aws_force_global = 'true' + } + + // Model mapping + const modelMapping = buildModelMappingObject( + modelRestrictionMode.value, allowedModels.value, modelMappings.value + ) + if (modelMapping) { + credentials.model_mapping = modelMapping + } + + applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create') + + await 
createAccountAndFinish('anthropic', 'bedrock' as AccountType, credentials) + return + } + + // For Bedrock API Key type, create directly + if (form.platform === 'anthropic' && accountCategory.value === 'bedrock-apikey') { + if (!form.name.trim()) { + appStore.showError(t('admin.accounts.pleaseEnterAccountName')) + return + } + if (!bedrockApiKeyValue.value.trim()) { + appStore.showError(t('admin.accounts.bedrockApiKeyRequired')) + return + } + + const credentials: Record = { + api_key: bedrockApiKeyValue.value.trim(), + aws_region: bedrockApiKeyRegion.value.trim() || 'us-east-1', + } + if (bedrockApiKeyForceGlobal.value) { + credentials.aws_force_global = 'true' + } + + // Model mapping + const modelMapping = buildModelMappingObject( + modelRestrictionMode.value, allowedModels.value, modelMappings.value + ) + if (modelMapping) { + credentials.model_mapping = modelMapping + } + + applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create') + + await createAccountAndFinish('anthropic', 'bedrock-apikey' as AccountType, credentials) + return + } + // For Antigravity upstream type, create directly if (form.platform === 'antigravity' && antigravityAccountType.value === 'upstream') { if (!form.name.trim()) { diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 1f2e988c..b18e9db6 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -563,6 +563,233 @@
+ +
+
+ + +
+
+ + +

{{ t('admin.accounts.bedrockSecretKeyLeaveEmpty') }}

+
+
+ + +

{{ t('admin.accounts.bedrockSessionTokenHint') }}

+
+
+ + +

{{ t('admin.accounts.bedrockRegionHint') }}

+
+
+ +

{{ t('admin.accounts.bedrockForceGlobalHint') }}

+
+ + +
+ + + +
+ + +
+ + +
+ +

+ {{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }} + {{ t('admin.accounts.supportsAllModels') }} +

+
+ + +
+
+ + + + +
+ + +
+ +
+
+
+
+ + +
+
+ + +

{{ t('admin.accounts.bedrockApiKeyLeaveEmpty') }}

+
+
+ + +

{{ t('admin.accounts.bedrockRegionHint') }}

+
+
+ +

{{ t('admin.accounts.bedrockForceGlobalHint') }}

+
+ + +
+ + + +
+ + +
+ + +
+ +

+ {{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }} + {{ t('admin.accounts.supportsAllModels') }} +

+
+ + +
+
+ + + + +
+ + +
+ +
+
+
+
+
@@ -1529,6 +1756,7 @@ const baseUrlHint = computed(() => { }) const antigravityPresetMappings = computed(() => getPresetMappingsByPlatform('antigravity')) +const bedrockPresets = computed(() => getPresetMappingsByPlatform('bedrock')) // Model mapping type interface ModelMapping { @@ -1547,6 +1775,17 @@ interface TempUnschedRuleForm { const submitting = ref(false) const editBaseUrl = ref('https://api.anthropic.com') const editApiKey = ref('') +// Bedrock credentials +const editBedrockAccessKeyId = ref('') +const editBedrockSecretAccessKey = ref('') +const editBedrockSessionToken = ref('') +const editBedrockRegion = ref('') +const editBedrockForceGlobal = ref(false) + +// Bedrock API Key credentials +const editBedrockApiKeyValue = ref('') +const editBedrockApiKeyRegion = ref('') +const editBedrockApiKeyForceGlobal = ref(false) const modelMappings = ref([]) const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist') const allowedModels = ref([]) @@ -1889,6 +2128,58 @@ watch( } else { selectedErrorCodes.value = [] } + } else if (newAccount.type === 'bedrock' && newAccount.credentials) { + const bedrockCreds = newAccount.credentials as Record + editBedrockAccessKeyId.value = (bedrockCreds.aws_access_key_id as string) || '' + editBedrockRegion.value = (bedrockCreds.aws_region as string) || '' + editBedrockForceGlobal.value = (bedrockCreds.aws_force_global as string) === 'true' + editBedrockSecretAccessKey.value = '' + editBedrockSessionToken.value = '' + + // Load model mappings for bedrock + const existingMappings = bedrockCreds.model_mapping as Record | undefined + if (existingMappings && typeof existingMappings === 'object') { + const entries = Object.entries(existingMappings) + const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to) + if (isWhitelistMode) { + modelRestrictionMode.value = 'whitelist' + allowedModels.value = entries.map(([from]) => from) + modelMappings.value = [] + } else { + modelRestrictionMode.value = 
'mapping' + modelMappings.value = entries.map(([from, to]) => ({ from, to })) + allowedModels.value = [] + } + } else { + modelRestrictionMode.value = 'whitelist' + modelMappings.value = [] + allowedModels.value = [] + } + } else if (newAccount.type === 'bedrock-apikey' && newAccount.credentials) { + const bedrockApiKeyCreds = newAccount.credentials as Record + editBedrockApiKeyRegion.value = (bedrockApiKeyCreds.aws_region as string) || 'us-east-1' + editBedrockApiKeyForceGlobal.value = (bedrockApiKeyCreds.aws_force_global as string) === 'true' + editBedrockApiKeyValue.value = '' + + // Load model mappings for bedrock-apikey + const existingMappings = bedrockApiKeyCreds.model_mapping as Record | undefined + if (existingMappings && typeof existingMappings === 'object') { + const entries = Object.entries(existingMappings) + const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to) + if (isWhitelistMode) { + modelRestrictionMode.value = 'whitelist' + allowedModels.value = entries.map(([from]) => from) + modelMappings.value = [] + } else { + modelRestrictionMode.value = 'mapping' + modelMappings.value = entries.map(([from, to]) => ({ from, to })) + allowedModels.value = [] + } + } else { + modelRestrictionMode.value = 'whitelist' + modelMappings.value = [] + allowedModels.value = [] + } } else if (newAccount.type === 'upstream' && newAccount.credentials) { const credentials = newAccount.credentials as Record editBaseUrl.value = (credentials.base_url as string) || '' @@ -2431,6 +2722,70 @@ const handleSubmit = async () => { return } + updatePayload.credentials = newCredentials + } else if (props.account.type === 'bedrock') { + const currentCredentials = (props.account.credentials as Record) || {} + const newCredentials: Record = { ...currentCredentials } + + newCredentials.aws_access_key_id = editBedrockAccessKeyId.value.trim() + newCredentials.aws_region = editBedrockRegion.value.trim() + if (editBedrockForceGlobal.value) { + 
newCredentials.aws_force_global = 'true' + } else { + delete newCredentials.aws_force_global + } + + // Only update secrets if user provided new values + if (editBedrockSecretAccessKey.value.trim()) { + newCredentials.aws_secret_access_key = editBedrockSecretAccessKey.value.trim() + } + if (editBedrockSessionToken.value.trim()) { + newCredentials.aws_session_token = editBedrockSessionToken.value.trim() + } + + // Model mapping + const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value) + if (modelMapping) { + newCredentials.model_mapping = modelMapping + } else { + delete newCredentials.model_mapping + } + + applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit') + if (!applyTempUnschedConfig(newCredentials)) { + return + } + + updatePayload.credentials = newCredentials + } else if (props.account.type === 'bedrock-apikey') { + const currentCredentials = (props.account.credentials as Record) || {} + const newCredentials: Record = { ...currentCredentials } + + newCredentials.aws_region = editBedrockApiKeyRegion.value.trim() || 'us-east-1' + if (editBedrockApiKeyForceGlobal.value) { + newCredentials.aws_force_global = 'true' + } else { + delete newCredentials.aws_force_global + } + + // Only update API key if user provided new value + if (editBedrockApiKeyValue.value.trim()) { + newCredentials.api_key = editBedrockApiKeyValue.value.trim() + } + + // Model mapping + const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value) + if (modelMapping) { + newCredentials.model_mapping = modelMapping + } else { + delete newCredentials.model_mapping + } + + applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit') + if (!applyTempUnschedConfig(newCredentials)) { + return + } + updatePayload.credentials = newCredentials } else { // For oauth/setup-token types, only update intercept_warmup_requests if changed diff --git 
a/frontend/src/components/charts/ModelDistributionChart.vue b/frontend/src/components/charts/ModelDistributionChart.vue index 6f80e541..5db5a14f 100644 --- a/frontend/src/components/charts/ModelDistributionChart.vue +++ b/frontend/src/components/charts/ModelDistributionChart.vue @@ -2,38 +2,72 @@

- {{ t('admin.dashboard.modelDistribution') }} + {{ !enableRankingView || activeView === 'model_distribution' + ? t('admin.dashboard.modelDistribution') + : t('admin.dashboard.spendingRankingTitle') }}

-
- - + + +
+
+ + +
-
+ +
-
+
@@ -77,6 +111,70 @@
+
+ {{ t('admin.dashboard.noDataAvailable') }} +
+ +
+ +
+
+ {{ t('admin.dashboard.failedToLoad') }} +
+
+
+ +
+
+ + + + + + + + + + + + + + + + + +
{{ t('admin.dashboard.spendingRankingUser') }}{{ t('admin.dashboard.spendingRankingRequests') }}{{ t('admin.dashboard.spendingRankingTokens') }}{{ t('admin.dashboard.spendingRankingSpend') }}
+
+ + #{{ index + 1 }} + + + {{ getRankingUserLabel(item) }} + +
+
+ {{ formatNumber(item.requests) }} + + {{ formatTokens(item.tokens) }} + + ${{ formatCost(item.actual_cost) }} +
+
+