Merge remote-tracking branch 'origin/main' into fix/enc_coot

# Conflicts:
#	backend/internal/service/openai_gateway_service.go
This commit is contained in:
InCerry
2026-03-14 13:04:24 +08:00
82 changed files with 8571 additions and 394 deletions

View File

@@ -81,6 +81,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userHandler := handler.NewUserHandler(userService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
usageBillingRepository := repository.NewUsageBillingRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemHandler := handler.NewRedeemHandler(redeemService)
@@ -163,9 +164,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
digestSessionStore := service.NewDigestSessionStore()
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)

View File

@@ -7,7 +7,7 @@ require (
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/DouDOU-start/go-sora2api v1.1.0
github.com/alitto/pond/v2 v2.6.2
github.com/aws/aws-sdk-go-v2 v1.41.2
github.com/aws/aws-sdk-go-v2 v1.41.3
github.com/aws/aws-sdk-go-v2/config v1.32.10
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
@@ -66,7 +66,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect
github.com/aws/smithy-go v1.24.1 // indirect
github.com/aws/smithy-go v1.24.2 // indirect
github.com/bdandy/go-errors v1.2.2 // indirect
github.com/bdandy/go-socks4 v1.2.3 // indirect
github.com/bmatcuk/doublestar v1.3.4 // indirect

View File

@@ -24,6 +24,8 @@ github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls=
github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4=
github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA=
github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c=
github.com/aws/aws-sdk-go-v2/config v1.32.10 h1:9DMthfO6XWZYLfzZglAgW5Fyou2nRI5CuV44sTedKBI=
@@ -60,6 +62,8 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb8
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0=
github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM=
github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic=

View File

@@ -934,9 +934,10 @@ type DashboardAggregationConfig struct {
// DashboardAggregationRetentionConfig 预聚合保留窗口
type DashboardAggregationRetentionConfig struct {
UsageLogsDays int `mapstructure:"usage_logs_days"`
HourlyDays int `mapstructure:"hourly_days"`
DailyDays int `mapstructure:"daily_days"`
UsageLogsDays int `mapstructure:"usage_logs_days"`
UsageBillingDedupDays int `mapstructure:"usage_billing_dedup_days"`
HourlyDays int `mapstructure:"hourly_days"`
DailyDays int `mapstructure:"daily_days"`
}
// UsageCleanupConfig 使用记录清理任务配置
@@ -1301,6 +1302,7 @@ func setDefaults() {
viper.SetDefault("dashboard_aggregation.backfill_enabled", false)
viper.SetDefault("dashboard_aggregation.backfill_max_days", 31)
viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90)
viper.SetDefault("dashboard_aggregation.retention.usage_billing_dedup_days", 365)
viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180)
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
@@ -1758,6 +1760,12 @@ func (c *Config) Validate() error {
if c.DashboardAgg.Retention.UsageLogsDays <= 0 {
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive")
}
if c.DashboardAgg.Retention.UsageBillingDedupDays <= 0 {
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be positive")
}
if c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
}
if c.DashboardAgg.Retention.HourlyDays <= 0 {
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive")
}
@@ -1780,6 +1788,14 @@ func (c *Config) Validate() error {
if c.DashboardAgg.Retention.UsageLogsDays < 0 {
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative")
}
if c.DashboardAgg.Retention.UsageBillingDedupDays < 0 {
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be non-negative")
}
if c.DashboardAgg.Retention.UsageBillingDedupDays > 0 &&
c.DashboardAgg.Retention.UsageLogsDays > 0 &&
c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
}
if c.DashboardAgg.Retention.HourlyDays < 0 {
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative")
}

View File

@@ -441,6 +441,9 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
if cfg.DashboardAgg.Retention.UsageLogsDays != 90 {
t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays)
}
if cfg.DashboardAgg.Retention.UsageBillingDedupDays != 365 {
t.Fatalf("DashboardAgg.Retention.UsageBillingDedupDays = %d, want 365", cfg.DashboardAgg.Retention.UsageBillingDedupDays)
}
if cfg.DashboardAgg.Retention.HourlyDays != 180 {
t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays)
}
@@ -1016,6 +1019,23 @@ func TestValidateConfigErrors(t *testing.T) {
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 },
wantErr: "dashboard_aggregation.retention.usage_logs_days",
},
{
name: "dashboard aggregation dedup retention",
mutate: func(c *Config) {
c.DashboardAgg.Enabled = true
c.DashboardAgg.Retention.UsageBillingDedupDays = 0
},
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
},
{
name: "dashboard aggregation dedup retention smaller than usage logs",
mutate: func(c *Config) {
c.DashboardAgg.Enabled = true
c.DashboardAgg.Retention.UsageLogsDays = 30
c.DashboardAgg.Retention.UsageBillingDedupDays = 29
},
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
},
{
name: "dashboard aggregation disabled interval",
mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 },

View File

@@ -27,10 +27,12 @@ const (
// Account type constants
const (
AccountTypeOAuth = "oauth" // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = "setup-token" // Setup Token类型账号inference only scope
AccountTypeAPIKey = "apikey" // API Key类型账号
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
AccountTypeOAuth = "oauth" // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = "setup-token" // Setup Token类型账号inference only scope
AccountTypeAPIKey = "apikey" // API Key类型账号
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名连接 Bedrock
AccountTypeBedrockAPIKey = "bedrock-apikey" // AWS Bedrock API Key 类型账号(通过 Bearer Token 连接 Bedrock
)
// Redeem type constants
@@ -113,3 +115,27 @@ var DefaultAntigravityModelMapping = map[string]string{
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
"tab_flash_lite_preview": "tab_flash_lite_preview",
}
// DefaultBedrockModelMapping 是 AWS Bedrock 平台的默认模型映射
// 将 Anthropic 标准模型名映射到 Bedrock 模型 ID
// 注意:此处的 "us." 前缀仅为默认值ResolveBedrockModelID 会根据账号配置的
// aws_region 自动调整为匹配的区域前缀(如 eu.、apac.、jp. 等)
var DefaultBedrockModelMapping = map[string]string{
// Claude Opus
"claude-opus-4-6-thinking": "us.anthropic.claude-opus-4-6-v1",
"claude-opus-4-6": "us.anthropic.claude-opus-4-6-v1",
"claude-opus-4-5-thinking": "us.anthropic.claude-opus-4-5-20251101-v1:0",
"claude-opus-4-5-20251101": "us.anthropic.claude-opus-4-5-20251101-v1:0",
"claude-opus-4-1": "us.anthropic.claude-opus-4-1-20250805-v1:0",
"claude-opus-4-20250514": "us.anthropic.claude-opus-4-20250514-v1:0",
// Claude Sonnet
"claude-sonnet-4-6-thinking": "us.anthropic.claude-sonnet-4-6",
"claude-sonnet-4-6": "us.anthropic.claude-sonnet-4-6",
"claude-sonnet-4-5": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
"claude-sonnet-4-5-thinking": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
"claude-sonnet-4-5-20250929": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
"claude-sonnet-4-20250514": "us.anthropic.claude-sonnet-4-20250514-v1:0",
// Claude Haiku
"claude-haiku-4-5": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
"claude-haiku-4-5-20251001": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
}

View File

@@ -97,7 +97,7 @@ type CreateAccountRequest struct {
Name string `json:"name" binding:"required"`
Notes *string `json:"notes"`
Platform string `json:"platform" binding:"required"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock bedrock-apikey"`
Credentials map[string]any `json:"credentials" binding:"required"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
@@ -116,7 +116,7 @@ type CreateAccountRequest struct {
type UpdateAccountRequest struct {
Name string `json:"name"`
Notes *string `json:"notes"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock bedrock-apikey"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`

View File

@@ -466,9 +466,60 @@ type BatchUsersUsageRequest struct {
UserIDs []int64 `json:"user_ids" binding:"required"`
}
var dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second)
var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second)
func parseRankingLimit(raw string) int {
limit, err := strconv.Atoi(strings.TrimSpace(raw))
if err != nil || limit <= 0 {
return 12
}
if limit > 50 {
return 50
}
return limit
}
// GetUserSpendingRanking handles getting user spending ranking data.
// GET /api/v1/admin/dashboard/users-ranking
func (h *DashboardHandler) GetUserSpendingRanking(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
limit := parseRankingLimit(c.DefaultQuery("limit", "12"))
keyRaw, _ := json.Marshal(struct {
Start string `json:"start"`
End string `json:"end"`
Limit int `json:"limit"`
}{
Start: startTime.UTC().Format(time.RFC3339),
End: endTime.UTC().Format(time.RFC3339),
Limit: limit,
})
cacheKey := string(keyRaw)
if cached, ok := dashboardUsersRankingCache.Get(cacheKey); ok {
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
return
}
ranking, err := h.dashboardService.GetUserSpendingRanking(c.Request.Context(), startTime, endTime, limit)
if err != nil {
response.Error(c, 500, "Failed to get user spending ranking")
return
}
payload := gin.H{
"ranking": ranking.Ranking,
"total_actual_cost": ranking.TotalActualCost,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
}
dashboardUsersRankingCache.Set(cacheKey, payload)
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, payload)
}
// GetBatchUsersUsage handles getting usage stats for multiple users
// POST /api/v1/admin/dashboard/users-usage
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {

View File

@@ -19,6 +19,9 @@ type dashboardUsageRepoCapture struct {
trendStream *bool
modelRequestType *int16
modelStream *bool
rankingLimit int
ranking []usagestats.UserSpendingRankingItem
rankingTotal float64
}
func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters(
@@ -49,6 +52,18 @@ func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters(
return []usagestats.ModelStat{}, nil
}
func (s *dashboardUsageRepoCapture) GetUserSpendingRanking(
ctx context.Context,
startTime, endTime time.Time,
limit int,
) (*usagestats.UserSpendingRankingResponse, error) {
s.rankingLimit = limit
return &usagestats.UserSpendingRankingResponse{
Ranking: s.ranking,
TotalActualCost: s.rankingTotal,
}, nil
}
func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine {
gin.SetMode(gin.TestMode)
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
@@ -56,6 +71,7 @@ func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Eng
router := gin.New()
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
router.GET("/admin/dashboard/models", handler.GetModelStats)
router.GET("/admin/dashboard/users-ranking", handler.GetUserSpendingRanking)
return router
}
@@ -130,3 +146,30 @@ func TestDashboardModelStatsInvalidStream(t *testing.T) {
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardUsersRankingLimitAndCache(t *testing.T) {
dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
repo := &dashboardUsageRepoCapture{
ranking: []usagestats.UserSpendingRankingItem{
{UserID: 7, Email: "rank@example.com", ActualCost: 10.5, Requests: 3, Tokens: 300},
},
rankingTotal: 88.8,
}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, 50, repo.rankingLimit)
require.Contains(t, rec.Body.String(), "\"total_actual_cost\":88.8")
require.Equal(t, "miss", rec.Header().Get("X-Snapshot-Cache"))
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
require.Equal(t, http.StatusOK, rec2.Code)
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
}

View File

@@ -41,12 +41,15 @@ type GenerateRedeemCodesRequest struct {
}
// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
// Type 为 omitempty 而非 required 是为了向后兼容旧版调用方(不传 type 时默认 balance
type CreateAndRedeemCodeRequest struct {
Code string `json:"code" binding:"required,min=3,max=128"`
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
Value float64 `json:"value" binding:"required,gt=0"`
UserID int64 `json:"user_id" binding:"required,gt=0"`
Notes string `json:"notes"`
Code string `json:"code" binding:"required,min=3,max=128"`
Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // 不传时默认 balance向后兼容
Value float64 `json:"value" binding:"required,gt=0"`
UserID int64 `json:"user_id" binding:"required,gt=0"`
GroupID *int64 `json:"group_id"` // subscription 类型必填
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // subscription 类型必填,>0
Notes string `json:"notes"`
}
// List handles listing all redeem codes with pagination
@@ -136,6 +139,22 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
return
}
req.Code = strings.TrimSpace(req.Code)
// 向后兼容:旧版调用方(如 Sub2ApiPay不传 type 字段,默认当作 balance 充值处理。
// 请勿删除此默认值逻辑,否则会导致旧版调用方 400 报错。
if req.Type == "" {
req.Type = "balance"
}
if req.Type == "subscription" {
if req.GroupID == nil {
response.BadRequest(c, "group_id is required for subscription type")
return
}
if req.ValidityDays <= 0 {
response.BadRequest(c, "validity_days must be greater than 0 for subscription type")
return
}
}
executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
existing, err := h.redeemService.GetByCode(ctx, req.Code)
@@ -147,11 +166,13 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
}
createErr := h.redeemService.CreateCode(ctx, &service.RedeemCode{
Code: req.Code,
Type: req.Type,
Value: req.Value,
Status: service.StatusUnused,
Notes: req.Notes,
Code: req.Code,
Type: req.Type,
Value: req.Value,
Status: service.StatusUnused,
Notes: req.Notes,
GroupID: req.GroupID,
ValidityDays: req.ValidityDays,
})
if createErr != nil {
// Unique code race: if code now exists, use idempotent semantics by used_by.

View File

@@ -0,0 +1,135 @@
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newCreateAndRedeemHandler creates a RedeemHandler with a non-nil (but minimal)
// RedeemService so that CreateAndRedeem's nil guard passes and we can test the
// parameter-validation layer that runs before any service call.
func newCreateAndRedeemHandler() *RedeemHandler {
return &RedeemHandler{
adminService: newStubAdminService(),
redeemService: &service.RedeemService{}, // non-nil to pass nil guard
}
}
// postCreateAndRedeemValidation calls CreateAndRedeem and returns the response
// status code. For cases that pass validation and proceed into the service layer,
// a panic may occur (because RedeemService internals are nil); this is expected
// and treated as "validation passed" (returns 0 to indicate panic).
func postCreateAndRedeemValidation(t *testing.T, handler *RedeemHandler, body any) (code int) {
t.Helper()
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
jsonBytes, err := json.Marshal(body)
require.NoError(t, err)
c.Request, _ = http.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/create-and-redeem", bytes.NewReader(jsonBytes))
c.Request.Header.Set("Content-Type", "application/json")
defer func() {
if r := recover(); r != nil {
// Panic means we passed validation and entered service layer (expected for minimal stub).
code = 0
}
}()
handler.CreateAndRedeem(c)
return w.Code
}
func TestCreateAndRedeem_TypeDefaultsToBalance(t *testing.T) {
// 不传 type 字段时应默认 balance不触发 subscription 校验。
// 验证通过后进入 service 层会 panic返回 0说明默认值生效。
h := newCreateAndRedeemHandler()
code := postCreateAndRedeemValidation(t, h, map[string]any{
"code": "test-balance-default",
"value": 10.0,
"user_id": 1,
})
assert.NotEqual(t, http.StatusBadRequest, code,
"omitting type should default to balance and pass validation")
}
func TestCreateAndRedeem_SubscriptionRequiresGroupID(t *testing.T) {
h := newCreateAndRedeemHandler()
code := postCreateAndRedeemValidation(t, h, map[string]any{
"code": "test-sub-no-group",
"type": "subscription",
"value": 29.9,
"user_id": 1,
"validity_days": 30,
// group_id 缺失
})
assert.Equal(t, http.StatusBadRequest, code)
}
func TestCreateAndRedeem_SubscriptionRequiresPositiveValidityDays(t *testing.T) {
groupID := int64(5)
h := newCreateAndRedeemHandler()
cases := []struct {
name string
validityDays int
}{
{"zero", 0},
{"negative", -1},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
code := postCreateAndRedeemValidation(t, h, map[string]any{
"code": "test-sub-bad-days-" + tc.name,
"type": "subscription",
"value": 29.9,
"user_id": 1,
"group_id": groupID,
"validity_days": tc.validityDays,
})
assert.Equal(t, http.StatusBadRequest, code)
})
}
}
func TestCreateAndRedeem_SubscriptionValidParamsPassValidation(t *testing.T) {
groupID := int64(5)
h := newCreateAndRedeemHandler()
code := postCreateAndRedeemValidation(t, h, map[string]any{
"code": "test-sub-valid",
"type": "subscription",
"value": 29.9,
"user_id": 1,
"group_id": groupID,
"validity_days": 31,
})
assert.NotEqual(t, http.StatusBadRequest, code,
"valid subscription params should pass validation")
}
func TestCreateAndRedeem_BalanceIgnoresSubscriptionFields(t *testing.T) {
h := newCreateAndRedeemHandler()
// balance 类型不传 group_id 和 validity_days不应报 400
code := postCreateAndRedeemValidation(t, h, map[string]any{
"code": "test-balance-no-extras",
"type": "balance",
"value": 50.0,
"user_id": 1,
})
assert.NotEqual(t, http.StatusBadRequest, code,
"balance type should not require group_id or validity_days")
}

View File

@@ -434,19 +434,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
zap.String("component", "handler.gateway.messages"),
@@ -736,19 +738,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: currentAPIKey,
User: currentAPIKey.User,
Account: account,
Subscription: currentSubscription,
UserAgent: userAgent,
IPAddress: clientIP,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
Result: result,
APIKey: currentAPIKey,
User: currentAPIKey.User,
Account: account,
Subscription: currentSubscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
zap.String("component", "handler.gateway.messages"),

View File

@@ -139,6 +139,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
nil, // accountRepo (not used: scheduler snapshot hit)
&fakeGroupRepo{group: group},
nil, // usageLogRepo
nil, // usageBillingRepo
nil, // userRepo
nil, // userSubRepo
nil, // userGroupRateRepo

View File

@@ -503,6 +503,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
requestPayloadHash := service.HashUsageRequestPayload(body)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
Result: result,
@@ -512,6 +513,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
LongContextThreshold: 200000, // Gemini 200K 阈值
LongContextMultiplier: 2.0, // 超出部分双倍计费
ForceCacheBilling: fs.ForceCacheBilling,

View File

@@ -352,18 +352,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.responses"),
@@ -732,17 +734,19 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.messages"),
@@ -1231,14 +1235,15 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
h.submitUsageRecordTask(func(taskCtx context.Context) {
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
APIKeyService: h.apiKeyService,
}); err != nil {
reqLog.Error("openai.websocket_record_usage_failed",
zap.Int64("account_id", account.ID),

View File

@@ -2206,7 +2206,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac
// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService用于测试 SelectAccountForModel
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
return service.NewGatewayService(
accountRepo, nil, nil, nil, nil, nil, nil, nil,
accountRepo, nil, nil, nil, nil, nil, nil, nil, nil,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
)
}

View File

@@ -399,17 +399,19 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
}); err != nil {
logger.L().With(
zap.String("component", "handler.sora_gateway.chat_completions"),

View File

@@ -343,6 +343,9 @@ func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, e
func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
return nil, nil
}
@@ -431,6 +434,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
nil,
nil,
nil,
nil,
testutil.StubGatewayCache{},
cfg,
nil,

View File

@@ -189,6 +189,5 @@ var DefaultStopSequences = []string{
"<|user|>",
"<|endoftext|>",
"<|end_of_turn|>",
"[DONE]",
"\n\nHuman:",
}

View File

@@ -96,12 +96,28 @@ type UserUsageTrendPoint struct {
Date string `json:"date"`
UserID int64 `json:"user_id"`
Email string `json:"email"`
Username string `json:"username"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// UserSpendingRankingItem represents a user spending ranking row.
type UserSpendingRankingItem struct {
UserID int64 `json:"user_id"`
Email string `json:"email"`
ActualCost float64 `json:"actual_cost"` // 实际扣除
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
}
// UserSpendingRankingResponse represents ranking rows plus total spend for the time range.
type UserSpendingRankingResponse struct {
Ranking []UserSpendingRankingItem `json:"ranking"`
TotalActualCost float64 `json:"total_actual_cost"`
}
// APIKeyUsageTrendPoint represents API key usage trend data point
type APIKeyUsageTrendPoint struct {
Date string `json:"date"`

View File

@@ -17,6 +17,9 @@ type dashboardAggregationRepository struct {
sql sqlExecutor
}
const usageLogsCleanupBatchSize = 10000
const usageBillingDedupCleanupBatchSize = 10000
// NewDashboardAggregationRepository 创建仪表盘预聚合仓储。
func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository {
if sqlDB == nil {
@@ -42,6 +45,9 @@ func isPostgresDriver(db *sql.DB) bool {
}
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
if r == nil || r.sql == nil {
return nil
}
loc := timezone.Location()
startLocal := start.In(loc)
endLocal := end.In(loc)
@@ -61,6 +67,22 @@ func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, sta
dayEnd = dayEnd.Add(24 * time.Hour)
}
if db, ok := r.sql.(*sql.DB); ok {
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
txRepo := newDashboardAggregationRepositoryWithSQL(tx)
if err := txRepo.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd); err != nil {
_ = tx.Rollback()
return err
}
return tx.Commit()
}
return r.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd)
}
func (r *dashboardAggregationRepository) aggregateRangeInTx(ctx context.Context, hourStart, hourEnd, dayStart, dayEnd time.Time) error {
// 以桶边界聚合,允许覆盖 end 所在桶的剩余区间。
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
return err
@@ -195,8 +217,58 @@ func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, c
if isPartitioned {
return r.dropUsageLogsPartitions(ctx, cutoff)
}
_, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC())
return err
for {
res, err := r.sql.ExecContext(ctx, `
WITH victims AS (
SELECT ctid
FROM usage_logs
WHERE created_at < $1
LIMIT $2
)
DELETE FROM usage_logs
WHERE ctid IN (SELECT ctid FROM victims)
`, cutoff.UTC(), usageLogsCleanupBatchSize)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected < usageLogsCleanupBatchSize {
return nil
}
}
}
func (r *dashboardAggregationRepository) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
for {
res, err := r.sql.ExecContext(ctx, `
WITH victims AS (
SELECT ctid, request_id, api_key_id, request_fingerprint, created_at
FROM usage_billing_dedup
WHERE created_at < $1
LIMIT $2
), archived AS (
INSERT INTO usage_billing_dedup_archive (request_id, api_key_id, request_fingerprint, created_at)
SELECT request_id, api_key_id, request_fingerprint, created_at
FROM victims
ON CONFLICT (request_id, api_key_id) DO NOTHING
)
DELETE FROM usage_billing_dedup
WHERE ctid IN (SELECT ctid FROM victims)
`, cutoff.UTC(), usageBillingDedupCleanupBatchSize)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected < usageBillingDedupCleanupBatchSize {
return nil
}
}
}
func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {

View File

@@ -262,6 +262,42 @@ func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.APIKey) *se
SetKey(k.Key).
SetName(k.Name).
SetStatus(k.Status)
if k.Quota != 0 {
create.SetQuota(k.Quota)
}
if k.QuotaUsed != 0 {
create.SetQuotaUsed(k.QuotaUsed)
}
if k.RateLimit5h != 0 {
create.SetRateLimit5h(k.RateLimit5h)
}
if k.RateLimit1d != 0 {
create.SetRateLimit1d(k.RateLimit1d)
}
if k.RateLimit7d != 0 {
create.SetRateLimit7d(k.RateLimit7d)
}
if k.Usage5h != 0 {
create.SetUsage5h(k.Usage5h)
}
if k.Usage1d != 0 {
create.SetUsage1d(k.Usage1d)
}
if k.Usage7d != 0 {
create.SetUsage7d(k.Usage7d)
}
if k.Window5hStart != nil {
create.SetWindow5hStart(*k.Window5hStart)
}
if k.Window1dStart != nil {
create.SetWindow1dStart(*k.Window1dStart)
}
if k.Window7dStart != nil {
create.SetWindow7dStart(*k.Window7dStart)
}
if k.ExpiresAt != nil {
create.SetExpiresAt(*k.ExpiresAt)
}
if k.GroupID != nil {
create.SetGroupID(*k.GroupID)
}

View File

@@ -45,6 +45,20 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false)
requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false)
// usage_billing_dedup: billing idempotency narrow table
var usageBillingDedupRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup')").Scan(&usageBillingDedupRegclass))
require.True(t, usageBillingDedupRegclass.Valid, "expected usage_billing_dedup table to exist")
requireColumn(t, tx, "usage_billing_dedup", "request_fingerprint", "character varying", 64, false)
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_request_api_key")
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_created_at_brin")
var usageBillingDedupArchiveRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup_archive')").Scan(&usageBillingDedupArchiveRegclass))
require.True(t, usageBillingDedupArchiveRegclass.Valid, "expected usage_billing_dedup_archive table to exist")
requireColumn(t, tx, "usage_billing_dedup_archive", "request_fingerprint", "character varying", 64, false)
requireIndex(t, tx, "usage_billing_dedup_archive", "usage_billing_dedup_archive_pkey")
// settings table should exist
var settingsRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass))
@@ -75,6 +89,23 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
}
func requireIndex(t *testing.T, tx *sql.Tx, table, index string) {
t.Helper()
var exists bool
err := tx.QueryRowContext(context.Background(), `
SELECT EXISTS (
SELECT 1
FROM pg_indexes
WHERE schemaname = 'public'
AND tablename = $1
AND indexname = $2
)
`, table, index).Scan(&exists)
require.NoError(t, err, "query pg_indexes for %s.%s", table, index)
require.True(t, exists, "expected index %s on %s", index, table)
}
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
t.Helper()

View File

@@ -0,0 +1,308 @@
package repository
import (
"context"
"database/sql"
"errors"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type usageBillingRepository struct {
db *sql.DB
}
func NewUsageBillingRepository(_ *dbent.Client, sqlDB *sql.DB) service.UsageBillingRepository {
return &usageBillingRepository{db: sqlDB}
}
func (r *usageBillingRepository) Apply(ctx context.Context, cmd *service.UsageBillingCommand) (_ *service.UsageBillingApplyResult, err error) {
if cmd == nil {
return &service.UsageBillingApplyResult{}, nil
}
if r == nil || r.db == nil {
return nil, errors.New("usage billing repository db is nil")
}
cmd.Normalize()
if cmd.RequestID == "" {
return nil, service.ErrUsageBillingRequestIDRequired
}
tx, err := r.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer func() {
if tx != nil {
_ = tx.Rollback()
}
}()
applied, err := r.claimUsageBillingKey(ctx, tx, cmd)
if err != nil {
return nil, err
}
if !applied {
return &service.UsageBillingApplyResult{Applied: false}, nil
}
result := &service.UsageBillingApplyResult{Applied: true}
if err := r.applyUsageBillingEffects(ctx, tx, cmd, result); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
tx = nil
return result, nil
}
func (r *usageBillingRepository) claimUsageBillingKey(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand) (bool, error) {
var id int64
err := tx.QueryRowContext(ctx, `
INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint)
VALUES ($1, $2, $3)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id
`, cmd.RequestID, cmd.APIKeyID, cmd.RequestFingerprint).Scan(&id)
if errors.Is(err, sql.ErrNoRows) {
var existingFingerprint string
if err := tx.QueryRowContext(ctx, `
SELECT request_fingerprint
FROM usage_billing_dedup
WHERE request_id = $1 AND api_key_id = $2
`, cmd.RequestID, cmd.APIKeyID).Scan(&existingFingerprint); err != nil {
return false, err
}
if strings.TrimSpace(existingFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
return false, service.ErrUsageBillingRequestConflict
}
return false, nil
}
if err != nil {
return false, err
}
var archivedFingerprint string
err = tx.QueryRowContext(ctx, `
SELECT request_fingerprint
FROM usage_billing_dedup_archive
WHERE request_id = $1 AND api_key_id = $2
`, cmd.RequestID, cmd.APIKeyID).Scan(&archivedFingerprint)
if err == nil {
if strings.TrimSpace(archivedFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
return false, service.ErrUsageBillingRequestConflict
}
return false, nil
}
if !errors.Is(err, sql.ErrNoRows) {
return false, err
}
return true, nil
}
func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand, result *service.UsageBillingApplyResult) error {
if cmd.SubscriptionCost > 0 && cmd.SubscriptionID != nil {
if err := incrementUsageBillingSubscription(ctx, tx, *cmd.SubscriptionID, cmd.SubscriptionCost); err != nil {
return err
}
}
if cmd.BalanceCost > 0 {
if err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost); err != nil {
return err
}
}
if cmd.APIKeyQuotaCost > 0 {
exhausted, err := incrementUsageBillingAPIKeyQuota(ctx, tx, cmd.APIKeyID, cmd.APIKeyQuotaCost)
if err != nil {
return err
}
result.APIKeyQuotaExhausted = exhausted
}
if cmd.APIKeyRateLimitCost > 0 {
if err := incrementUsageBillingAPIKeyRateLimit(ctx, tx, cmd.APIKeyID, cmd.APIKeyRateLimitCost); err != nil {
return err
}
}
if cmd.AccountQuotaCost > 0 && strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) {
if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil {
return err
}
}
return nil
}
func incrementUsageBillingSubscription(ctx context.Context, tx *sql.Tx, subscriptionID int64, costUSD float64) error {
const updateSQL = `
UPDATE user_subscriptions us
SET
daily_usage_usd = us.daily_usage_usd + $1,
weekly_usage_usd = us.weekly_usage_usd + $1,
monthly_usage_usd = us.monthly_usage_usd + $1,
updated_at = NOW()
FROM groups g
WHERE us.id = $2
AND us.deleted_at IS NULL
AND us.group_id = g.id
AND g.deleted_at IS NULL
`
res, err := tx.ExecContext(ctx, updateSQL, costUSD, subscriptionID)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected > 0 {
return nil
}
return service.ErrSubscriptionNotFound
}
func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) error {
res, err := tx.ExecContext(ctx, `
UPDATE users
SET balance = balance - $1,
updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
`, amount, userID)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected > 0 {
return nil
}
return service.ErrUserNotFound
}
func incrementUsageBillingAPIKeyQuota(ctx context.Context, tx *sql.Tx, apiKeyID int64, amount float64) (bool, error) {
var exhausted bool
err := tx.QueryRowContext(ctx, `
UPDATE api_keys
SET quota_used = quota_used + $1,
status = CASE
WHEN quota > 0
AND status = $3
AND quota_used < quota
AND quota_used + $1 >= quota
THEN $4
ELSE status
END,
updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
RETURNING quota > 0 AND quota_used >= quota AND quota_used - $1 < quota
`, amount, apiKeyID, service.StatusAPIKeyActive, service.StatusAPIKeyQuotaExhausted).Scan(&exhausted)
if errors.Is(err, sql.ErrNoRows) {
return false, service.ErrAPIKeyNotFound
}
if err != nil {
return false, err
}
return exhausted, nil
}
func incrementUsageBillingAPIKeyRateLimit(ctx context.Context, tx *sql.Tx, apiKeyID int64, cost float64) error {
res, err := tx.ExecContext(ctx, `
UPDATE api_keys SET
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END,
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END,
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END,
window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
`, cost, apiKeyID)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return service.ErrAPIKeyNotFound
}
return nil
}
func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) error {
rows, err := tx.QueryContext(ctx,
`UPDATE accounts SET extra = (
COALESCE(extra, '{}'::jsonb)
|| jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1)
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
jsonb_build_object(
'quota_daily_used',
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+ '24 hours'::interval <= NOW()
THEN $1
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
'quota_daily_start',
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+ '24 hours'::interval <= NOW()
THEN `+nowUTC+`
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
)
ELSE '{}'::jsonb END
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
jsonb_build_object(
'quota_weekly_used',
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+ '168 hours'::interval <= NOW()
THEN $1
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
'quota_weekly_start',
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+ '168 hours'::interval <= NOW()
THEN `+nowUTC+`
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
)
ELSE '{}'::jsonb END
), updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
RETURNING
COALESCE((extra->>'quota_used')::numeric, 0),
COALESCE((extra->>'quota_limit')::numeric, 0)`,
amount, accountID)
if err != nil {
return err
}
defer func() { _ = rows.Close() }()
var newUsed, limit float64
if rows.Next() {
if err := rows.Scan(&newUsed, &limit); err != nil {
return err
}
} else {
if err := rows.Err(); err != nil {
return err
}
return service.ErrAccountNotFound
}
if err := rows.Err(); err != nil {
return err
}
if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil {
logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err)
return err
}
}
return nil
}

View File

@@ -0,0 +1,279 @@
//go:build integration
package repository
import (
"context"
"fmt"
"strings"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func TestUsageBillingRepositoryApply_DeduplicatesBalanceBilling(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := NewUsageBillingRepository(client, integrationDB)
user := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("usage-billing-user-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Balance: 100,
})
apiKey := mustCreateApiKey(t, client, &service.APIKey{
UserID: user.ID,
Key: "sk-usage-billing-" + uuid.NewString(),
Name: "billing",
Quota: 1,
})
account := mustCreateAccount(t, client, &service.Account{
Name: "usage-billing-account-" + uuid.NewString(),
Type: service.AccountTypeAPIKey,
})
requestID := uuid.NewString()
cmd := &service.UsageBillingCommand{
RequestID: requestID,
APIKeyID: apiKey.ID,
UserID: user.ID,
AccountID: account.ID,
AccountType: service.AccountTypeAPIKey,
BalanceCost: 1.25,
APIKeyQuotaCost: 1.25,
APIKeyRateLimitCost: 1.25,
}
result1, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.NotNil(t, result1)
require.True(t, result1.Applied)
require.True(t, result1.APIKeyQuotaExhausted)
result2, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.NotNil(t, result2)
require.False(t, result2.Applied)
var balance float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
require.InDelta(t, 98.75, balance, 0.000001)
var quotaUsed float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT quota_used FROM api_keys WHERE id = $1", apiKey.ID).Scan(&quotaUsed))
require.InDelta(t, 1.25, quotaUsed, 0.000001)
var usage5h float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT usage_5h FROM api_keys WHERE id = $1", apiKey.ID).Scan(&usage5h))
require.InDelta(t, 1.25, usage5h, 0.000001)
var status string
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT status FROM api_keys WHERE id = $1", apiKey.ID).Scan(&status))
require.Equal(t, service.StatusAPIKeyQuotaExhausted, status)
var dedupCount int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&dedupCount))
require.Equal(t, 1, dedupCount)
}
func TestUsageBillingRepositoryApply_DeduplicatesSubscriptionBilling(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := NewUsageBillingRepository(client, integrationDB)
user := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("usage-billing-sub-user-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
})
group := mustCreateGroup(t, client, &service.Group{
Name: "usage-billing-group-" + uuid.NewString(),
Platform: service.PlatformAnthropic,
SubscriptionType: service.SubscriptionTypeSubscription,
})
apiKey := mustCreateApiKey(t, client, &service.APIKey{
UserID: user.ID,
GroupID: &group.ID,
Key: "sk-usage-billing-sub-" + uuid.NewString(),
Name: "billing-sub",
})
subscription := mustCreateSubscription(t, client, &service.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
})
requestID := uuid.NewString()
cmd := &service.UsageBillingCommand{
RequestID: requestID,
APIKeyID: apiKey.ID,
UserID: user.ID,
AccountID: 0,
SubscriptionID: &subscription.ID,
SubscriptionCost: 2.5,
}
result1, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.True(t, result1.Applied)
result2, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.False(t, result2.Applied)
var dailyUsage float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT daily_usage_usd FROM user_subscriptions WHERE id = $1", subscription.ID).Scan(&dailyUsage))
require.InDelta(t, 2.5, dailyUsage, 0.000001)
}
func TestUsageBillingRepositoryApply_RequestFingerprintConflict(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := NewUsageBillingRepository(client, integrationDB)
user := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("usage-billing-conflict-user-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Balance: 100,
})
apiKey := mustCreateApiKey(t, client, &service.APIKey{
UserID: user.ID,
Key: "sk-usage-billing-conflict-" + uuid.NewString(),
Name: "billing-conflict",
})
requestID := uuid.NewString()
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
RequestID: requestID,
APIKeyID: apiKey.ID,
UserID: user.ID,
BalanceCost: 1.25,
})
require.NoError(t, err)
_, err = repo.Apply(ctx, &service.UsageBillingCommand{
RequestID: requestID,
APIKeyID: apiKey.ID,
UserID: user.ID,
BalanceCost: 2.50,
})
require.ErrorIs(t, err, service.ErrUsageBillingRequestConflict)
}
func TestUsageBillingRepositoryApply_UpdatesAccountQuota(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := NewUsageBillingRepository(client, integrationDB)
user := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("usage-billing-account-user-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
})
apiKey := mustCreateApiKey(t, client, &service.APIKey{
UserID: user.ID,
Key: "sk-usage-billing-account-" + uuid.NewString(),
Name: "billing-account",
})
account := mustCreateAccount(t, client, &service.Account{
Name: "usage-billing-account-quota-" + uuid.NewString(),
Type: service.AccountTypeAPIKey,
Extra: map[string]any{
"quota_limit": 100.0,
},
})
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
RequestID: uuid.NewString(),
APIKeyID: apiKey.ID,
UserID: user.ID,
AccountID: account.ID,
AccountType: service.AccountTypeAPIKey,
AccountQuotaCost: 3.5,
})
require.NoError(t, err)
var quotaUsed float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COALESCE((extra->>'quota_used')::numeric, 0) FROM accounts WHERE id = $1", account.ID).Scan(&quotaUsed))
require.InDelta(t, 3.5, quotaUsed, 0.000001)
}
func TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows(t *testing.T) {
ctx := context.Background()
repo := newDashboardAggregationRepositoryWithSQL(integrationDB)
oldRequestID := "dedup-old-" + uuid.NewString()
newRequestID := "dedup-new-" + uuid.NewString()
oldCreatedAt := time.Now().UTC().AddDate(0, 0, -400)
newCreatedAt := time.Now().UTC().Add(-time.Hour)
_, err := integrationDB.ExecContext(ctx, `
INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint, created_at)
VALUES ($1, 1, $2, $3), ($4, 1, $5, $6)
`,
oldRequestID, strings.Repeat("a", 64), oldCreatedAt,
newRequestID, strings.Repeat("b", 64), newCreatedAt,
)
require.NoError(t, err)
require.NoError(t, repo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))
var oldCount int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", oldRequestID).Scan(&oldCount))
require.Equal(t, 0, oldCount)
var newCount int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", newRequestID).Scan(&newCount))
require.Equal(t, 1, newCount)
var archivedCount int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup_archive WHERE request_id = $1", oldRequestID).Scan(&archivedCount))
require.Equal(t, 1, archivedCount)
}
func TestUsageBillingRepositoryApply_DeduplicatesAgainstArchivedKey(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := NewUsageBillingRepository(client, integrationDB)
aggRepo := newDashboardAggregationRepositoryWithSQL(integrationDB)
user := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("usage-billing-archive-user-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Balance: 100,
})
apiKey := mustCreateApiKey(t, client, &service.APIKey{
UserID: user.ID,
Key: "sk-usage-billing-archive-" + uuid.NewString(),
Name: "billing-archive",
})
requestID := uuid.NewString()
cmd := &service.UsageBillingCommand{
RequestID: requestID,
APIKeyID: apiKey.ID,
UserID: user.ID,
BalanceCost: 1.25,
}
result1, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.True(t, result1.Applied)
_, err = integrationDB.ExecContext(ctx, `
UPDATE usage_billing_dedup
SET created_at = $1
WHERE request_id = $2 AND api_key_id = $3
`, time.Now().UTC().AddDate(0, 0, -400), requestID, apiKey.ID)
require.NoError(t, err)
require.NoError(t, aggRepo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))
result2, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.False(t, result2.Applied)
var balance float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
require.InDelta(t, 98.75, balance, 0.000001)
}

File diff suppressed because it is too large Load Diff

View File

@@ -4,6 +4,8 @@ package repository
import (
"context"
"fmt"
"sync"
"testing"
"time"
@@ -14,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
@@ -84,6 +87,367 @@ func (s *UsageLogRepoSuite) TestCreate() {
s.Require().NotZero(log.ID)
}
func TestUsageLogRepositoryCreate_BatchPathConcurrent(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-" + uuid.NewString()})
const total = 16
results := make([]bool, total)
errs := make([]error, total)
logs := make([]*service.UsageLog, total)
var wg sync.WaitGroup
wg.Add(total)
for i := 0; i < total; i++ {
i := i
logs[i] = &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10 + i,
OutputTokens: 20 + i,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
go func() {
defer wg.Done()
results[i], errs[i] = repo.Create(ctx, logs[i])
}()
}
wg.Wait()
for i := 0; i < total; i++ {
require.NoError(t, errs[i])
require.True(t, results[i])
require.NotZero(t, logs[i].ID)
}
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE api_key_id = $1", apiKey.ID).Scan(&count))
require.Equal(t, total, count)
}
func TestUsageLogRepositoryCreate_BatchPathDuplicateRequestID(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-dup-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-dup-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-dup-" + uuid.NewString()})
requestID := uuid.NewString()
log1 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
log2 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
inserted1, err1 := repo.Create(ctx, log1)
inserted2, err2 := repo.Create(ctx, log2)
require.NoError(t, err1)
require.NoError(t, err2)
require.True(t, inserted1)
require.False(t, inserted2)
require.Equal(t, log1.ID, log2.ID)
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count))
require.Equal(t, 1, count)
}
func TestUsageLogRepositoryFlushCreateBatch_DeduplicatesSameKeyInMemory(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-memdup-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-memdup-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-memdup-" + uuid.NewString()})
requestID := uuid.NewString()
const total = 8
batch := make([]usageLogCreateRequest, 0, total)
logs := make([]*service.UsageLog, 0, total)
for i := 0; i < total; i++ {
log := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: "claude-3",
InputTokens: 10 + i,
OutputTokens: 20 + i,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
logs = append(logs, log)
batch = append(batch, usageLogCreateRequest{
log: log,
prepared: prepareUsageLogInsert(log),
resultCh: make(chan usageLogCreateResult, 1),
})
}
repo.flushCreateBatch(integrationDB, batch)
insertedCount := 0
var firstID int64
for idx, req := range batch {
res := <-req.resultCh
require.NoError(t, res.err)
if res.inserted {
insertedCount++
}
require.NotZero(t, logs[idx].ID)
if idx == 0 {
firstID = logs[idx].ID
} else {
require.Equal(t, firstID, logs[idx].ID)
}
}
require.Equal(t, 1, insertedCount)
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count))
require.Equal(t, 1, count)
}
func TestUsageLogRepositoryCreateBestEffort_BatchPathDuplicateRequestID(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-dup-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-dup-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-dup-" + uuid.NewString()})
requestID := uuid.NewString()
log1 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
log2 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
require.NoError(t, repo.CreateBestEffort(ctx, log1))
require.NoError(t, repo.CreateBestEffort(ctx, log2))
require.Eventually(t, func() bool {
var count int
err := integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count)
return err == nil && count == 1
}, 3*time.Second, 20*time.Millisecond)
}
func TestUsageLogRepositoryCreateBestEffort_QueueFullReturnsDropped(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
repo.bestEffortBatchCh = make(chan usageLogBestEffortRequest, 1)
repo.bestEffortBatchCh <- usageLogBestEffortRequest{}
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-full-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-full-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-full-" + uuid.NewString()})
err := repo.CreateBestEffort(ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
})
require.Error(t, err)
require.True(t, service.IsUsageLogCreateDropped(err))
}
func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *testing.T) {
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-" + uuid.NewString()})
ctx, cancel := context.WithCancel(context.Background())
cancel()
inserted, err := repo.Create(ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
})
require.False(t, inserted)
require.Error(t, err)
require.True(t, service.IsUsageLogCreateNotPersisted(err))
}
func TestUsageLogRepositoryCreate_BatchPathQueueFullMarksNotPersisted(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
repo.createBatchCh = make(chan usageLogCreateRequest, 1)
repo.createBatchCh <- usageLogCreateRequest{}
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-create-full-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-create-full-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-create-full-" + uuid.NewString()})
inserted, err := repo.Create(ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
})
require.False(t, inserted)
require.Error(t, err)
require.True(t, service.IsUsageLogCreateNotPersisted(err))
}
func TestUsageLogRepositoryCreate_BatchPathCanceledAfterQueueMarksNotPersisted(t *testing.T) {
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
repo.createBatchCh = make(chan usageLogCreateRequest, 1)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-queued-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-queued-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-queued-" + uuid.NewString()})
ctx, cancel := context.WithCancel(context.Background())
errCh := make(chan error, 1)
go func() {
_, err := repo.createBatched(ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
})
errCh <- err
}()
req := <-repo.createBatchCh
require.NotNil(t, req.shared)
cancel()
err := <-errCh
require.Error(t, err)
require.True(t, service.IsUsageLogCreateNotPersisted(err))
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: service.MarkUsageLogCreateNotPersisted(context.Canceled)})
}
func TestUsageLogRepositoryFlushCreateBatch_CanceledRequestReturnsNotPersisted(t *testing.T) {
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-flush-cancel-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-flush-cancel-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-flush-cancel-" + uuid.NewString()})
log := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
req := usageLogCreateRequest{
log: log,
prepared: prepareUsageLogInsert(log),
shared: &usageLogCreateShared{},
resultCh: make(chan usageLogCreateResult, 1),
}
req.shared.state.Store(usageLogCreateStateCanceled)
repo.flushCreateBatch(integrationDB, []usageLogCreateRequest{req})
res := <-req.resultCh
require.False(t, res.inserted)
require.Error(t, res.err)
require.True(t, service.IsUsageLogCreateNotPersisted(res.err))
}
func (s *UsageLogRepoSuite) TestGetByID() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})

View File

@@ -248,6 +248,35 @@ func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageLogRepositoryGetUserSpendingRanking(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageLogRepository{sql: db}
start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
rows := sqlmock.NewRows([]string{"user_id", "email", "actual_cost", "requests", "tokens", "total_actual_cost"}).
AddRow(int64(2), "beta@example.com", 12.5, int64(9), int64(900), 40.0).
AddRow(int64(1), "alpha@example.com", 12.5, int64(8), int64(800), 40.0).
AddRow(int64(3), "gamma@example.com", 4.25, int64(5), int64(300), 40.0)
mock.ExpectQuery("WITH user_spend AS \\(").
WithArgs(start, end, 12).
WillReturnRows(rows)
got, err := repo.GetUserSpendingRanking(context.Background(), start, end, 12)
require.NoError(t, err)
require.Equal(t, &usagestats.UserSpendingRankingResponse{
Ranking: []usagestats.UserSpendingRankingItem{
{UserID: 2, Email: "beta@example.com", ActualCost: 12.5, Requests: 9, Tokens: 900},
{UserID: 1, Email: "alpha@example.com", ActualCost: 12.5, Requests: 8, Tokens: 800},
{UserID: 3, Email: "gamma@example.com", ActualCost: 4.25, Requests: 5, Tokens: 300},
},
TotalActualCost: 40.0,
}, got)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestBuildRequestTypeFilterConditionLegacyFallback(t *testing.T) {
tests := []struct {
name string

View File

@@ -3,8 +3,11 @@
package repository
import (
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
@@ -39,3 +42,26 @@ func TestSafeDateFormat(t *testing.T) {
})
}
}
func TestBuildUsageLogBatchInsertQuery_UsesConflictDoNothing(t *testing.T) {
log := &service.UsageLog{
UserID: 1,
APIKeyID: 2,
AccountID: 3,
RequestID: "req-batch-no-update",
Model: "gpt-5",
InputTokens: 10,
OutputTokens: 5,
TotalCost: 1.2,
ActualCost: 1.2,
CreatedAt: time.Now().UTC(),
}
prepared := prepareUsageLogInsert(log)
query, _ := buildUsageLogBatchInsertQuery([]string{usageLogBatchKey(log.RequestID, log.APIKeyID)}, map[string]usageLogInsertPrepared{
usageLogBatchKey(log.RequestID, log.APIKeyID): prepared,
})
require.Contains(t, query, "ON CONFLICT (request_id, api_key_id) DO NOTHING")
require.NotContains(t, strings.ToUpper(query), "DO UPDATE")
}

View File

@@ -62,6 +62,7 @@ var ProviderSet = wire.NewSet(
NewAnnouncementRepository,
NewAnnouncementReadRepository,
NewUsageLogRepository,
NewUsageBillingRepository,
NewIdempotencyRepository,
NewUsageCleanupRepository,
NewDashboardAggregationRepository,

View File

@@ -1635,6 +1635,10 @@ func (r *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, end
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
logs := r.userLogs[userID]
if len(logs) == 0 {

View File

@@ -192,6 +192,7 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
dashboard.GET("/groups", h.Admin.Dashboard.GetGroupStats)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetAPIKeyUsageTrend)
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
dashboard.GET("/users-ranking", h.Admin.Dashboard.GetUserSpendingRanking)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
dashboard.POST("/aggregation/backfill", h.Admin.Dashboard.BackfillAggregation)

View File

@@ -412,6 +412,7 @@ func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]stri
if a.Platform == domain.PlatformAntigravity {
return domain.DefaultAntigravityModelMapping
}
// Bedrock 默认映射由 forwardBedrock 统一处理(需配合 region prefix 调整)
return nil
}
if len(rawMapping) == 0 {
@@ -764,6 +765,14 @@ func (a *Account) IsInterceptWarmupEnabled() bool {
return false
}
func (a *Account) IsBedrock() bool {
return a.Platform == PlatformAnthropic && (a.Type == AccountTypeBedrock || a.Type == AccountTypeBedrockAPIKey)
}
func (a *Account) IsBedrockAPIKey() bool {
return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrockAPIKey
}
func (a *Account) IsOpenAI() bool {
return a.Platform == PlatformOpenAI
}

View File

@@ -207,14 +207,14 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
testModelID = claude.DefaultTestModel
}
// For API Key accounts with model mapping, map the model
// API Key 账号测试连接时也需要应用通配符模型映射。
if account.Type == "apikey" {
mapping := account.GetModelMapping()
if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists {
testModelID = mappedModel
}
}
testModelID = account.GetMappedModel(testModelID)
}
// Bedrock accounts use a separate test path
if account.IsBedrock() {
return s.testBedrockAccountConnection(c, ctx, account, testModelID)
}
// Determine authentication method and API URL
@@ -312,6 +312,109 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
return s.processClaudeStream(c, resp.Body)
}
// testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke
func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error {
region := bedrockRuntimeRegion(account)
resolvedModelID, ok := ResolveBedrockModelID(account, testModelID)
if !ok {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported Bedrock model: %s", testModelID))
}
testModelID = resolvedModelID
// Set SSE headers (test UI expects SSE)
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
// Create a minimal Bedrock-compatible payload (no stream, no cache_control)
bedrockPayload := map[string]any{
"anthropic_version": "bedrock-2023-05-31",
"messages": []map[string]any{
{
"role": "user",
"content": []map[string]any{
{
"type": "text",
"text": "hi",
},
},
},
},
"max_tokens": 256,
"temperature": 1,
}
bedrockBody, _ := json.Marshal(bedrockPayload)
// Use non-streaming endpoint (response is standard Claude JSON)
apiURL := BuildBedrockURL(region, testModelID, false)
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(bedrockBody))
if err != nil {
return s.sendErrorAndEnd(c, "Failed to create request")
}
req.Header.Set("Content-Type", "application/json")
// Sign or set auth based on account type
if account.IsBedrockAPIKey() {
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return s.sendErrorAndEnd(c, "No API key available")
}
req.Header.Set("Authorization", "Bearer "+apiKey)
} else {
signer, err := NewBedrockSignerFromAccount(account)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create Bedrock signer: %s", err.Error()))
}
if err := signer.SignRequest(ctx, req, bedrockBody); err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to sign request: %s", err.Error()))
}
}
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, false)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
}
// Bedrock non-streaming response is standard Claude JSON, extract the text
var result struct {
Content []struct {
Text string `json:"text"`
} `json:"content"`
}
if err := json.Unmarshal(body, &result); err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to parse response: %s", err.Error()))
}
text := ""
if len(result.Content) > 0 {
text = result.Content[0].Text
}
if text == "" {
text = "(empty response)"
}
s.sendEvent(c, TestEvent{Type: "content", Text: text})
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
// testOpenAIAccountConnection tests an OpenAI account's connection
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string) error {
ctx := c.Request.Context()

View File

@@ -47,6 +47,7 @@ type UsageLogRepository interface {
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error)
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)

View File

@@ -0,0 +1,607 @@
package service
import (
"encoding/json"
"fmt"
"net/url"
"regexp"
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const defaultBedrockRegion = "us-east-1"
var bedrockCrossRegionPrefixes = []string{"us.", "eu.", "apac.", "jp.", "au.", "us-gov.", "global."}
// BedrockCrossRegionPrefix 根据 AWS Region 返回 Bedrock 跨区域推理的模型 ID 前缀
// 参考: https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html
func BedrockCrossRegionPrefix(region string) string {
switch {
case strings.HasPrefix(region, "us-gov"):
return "us-gov" // GovCloud 使用独立的 us-gov 前缀
case strings.HasPrefix(region, "us-"):
return "us"
case strings.HasPrefix(region, "eu-"):
return "eu"
case region == "ap-northeast-1":
return "jp" // 日本区域使用独立的 jp 前缀AWS 官方定义)
case region == "ap-southeast-2":
return "au" // 澳大利亚区域使用独立的 au 前缀AWS 官方定义)
case strings.HasPrefix(region, "ap-"):
return "apac" // 其余亚太区域使用通用 apac 前缀
case strings.HasPrefix(region, "ca-"):
return "us" // 加拿大区域使用 us 前缀的跨区域推理
case strings.HasPrefix(region, "sa-"):
return "us" // 南美区域使用 us 前缀的跨区域推理
default:
return "us"
}
}
// AdjustBedrockModelRegionPrefix 将模型 ID 的区域前缀替换为与当前 AWS Region 匹配的前缀
// 例如 region=eu-west-1 时,"us.anthropic.claude-opus-4-6-v1" → "eu.anthropic.claude-opus-4-6-v1"
// 特殊值 region="global" 强制使用 global. 前缀
func AdjustBedrockModelRegionPrefix(modelID, region string) string {
var targetPrefix string
if region == "global" {
targetPrefix = "global"
} else {
targetPrefix = BedrockCrossRegionPrefix(region)
}
for _, p := range bedrockCrossRegionPrefixes {
if strings.HasPrefix(modelID, p) {
if p == targetPrefix+"." {
return modelID // 前缀已匹配,无需替换
}
return targetPrefix + "." + modelID[len(p):]
}
}
// 模型 ID 没有已知区域前缀(如 "anthropic.claude-..."),不做修改
return modelID
}
func bedrockRuntimeRegion(account *Account) string {
if account == nil {
return defaultBedrockRegion
}
if region := account.GetCredential("aws_region"); region != "" {
return region
}
return defaultBedrockRegion
}
func shouldForceBedrockGlobal(account *Account) bool {
return account != nil && account.GetCredential("aws_force_global") == "true"
}
func isRegionalBedrockModelID(modelID string) bool {
for _, prefix := range bedrockCrossRegionPrefixes {
if strings.HasPrefix(modelID, prefix) {
return true
}
}
return false
}
func isLikelyBedrockModelID(modelID string) bool {
lower := strings.ToLower(strings.TrimSpace(modelID))
if lower == "" {
return false
}
if strings.HasPrefix(lower, "arn:") {
return true
}
for _, prefix := range []string{
"anthropic.",
"amazon.",
"meta.",
"mistral.",
"cohere.",
"ai21.",
"deepseek.",
"stability.",
"writer.",
"nova.",
} {
if strings.HasPrefix(lower, prefix) {
return true
}
}
return isRegionalBedrockModelID(lower)
}
func normalizeBedrockModelID(modelID string) (normalized string, shouldAdjustRegion bool, ok bool) {
modelID = strings.TrimSpace(modelID)
if modelID == "" {
return "", false, false
}
if mapped, exists := domain.DefaultBedrockModelMapping[modelID]; exists {
return mapped, true, true
}
if isRegionalBedrockModelID(modelID) {
return modelID, true, true
}
if isLikelyBedrockModelID(modelID) {
return modelID, false, true
}
return "", false, false
}
// ResolveBedrockModelID resolves a requested Claude model into a Bedrock model ID.
// It applies account model_mapping first, then default Bedrock aliases, and finally
// adjusts Anthropic cross-region prefixes to match the account region.
func ResolveBedrockModelID(account *Account, requestedModel string) (string, bool) {
if account == nil {
return "", false
}
mappedModel := account.GetMappedModel(requestedModel)
modelID, shouldAdjustRegion, ok := normalizeBedrockModelID(mappedModel)
if !ok {
return "", false
}
if shouldAdjustRegion {
targetRegion := bedrockRuntimeRegion(account)
if shouldForceBedrockGlobal(account) {
targetRegion = "global"
}
modelID = AdjustBedrockModelRegionPrefix(modelID, targetRegion)
}
return modelID, true
}
// BuildBedrockURL 构建 Bedrock InvokeModel 的 URL
// stream=true 时使用 invoke-with-response-stream 端点
// modelID 中的特殊字符会被 URL 编码(与 litellm 的 urllib.parse.quote(safe="") 对齐)
func BuildBedrockURL(region, modelID string, stream bool) string {
if region == "" {
region = defaultBedrockRegion
}
encodedModelID := url.PathEscape(modelID)
// url.PathEscape 不编码冒号RFC 允许 path 中出现 ":"
// 但 AWS Bedrock 期望模型 ID 中的冒号被编码为 %3A
encodedModelID = strings.ReplaceAll(encodedModelID, ":", "%3A")
if stream {
return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke-with-response-stream", region, encodedModelID)
}
return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke", region, encodedModelID)
}
// PrepareBedrockRequestBody 处理请求体以适配 Bedrock API
// 1. 注入 anthropic_version
// 2. 注入 anthropic_beta从客户端 anthropic-beta 头解析)
// 3. 移除 Bedrock 不支持的字段model, stream, output_format, output_config
// 4. 移除工具定义中的 custom 字段Claude Code 会发送 custom: {defer_loading: true}
// 5. 清理 cache_control 中 Bedrock 不支持的字段scope, ttl
func PrepareBedrockRequestBody(body []byte, modelID string, betaHeader string) ([]byte, error) {
betaTokens := ResolveBedrockBetaTokens(betaHeader, body, modelID)
return PrepareBedrockRequestBodyWithTokens(body, modelID, betaTokens)
}
// PrepareBedrockRequestBodyWithTokens prepares a Bedrock request using pre-resolved beta tokens.
func PrepareBedrockRequestBodyWithTokens(body []byte, modelID string, betaTokens []string) ([]byte, error) {
var err error
// 注入 anthropic_versionBedrock 要求)
body, err = sjson.SetBytes(body, "anthropic_version", "bedrock-2023-05-31")
if err != nil {
return nil, fmt.Errorf("inject anthropic_version: %w", err)
}
// 注入 anthropic_betaBedrock Invoke 通过请求体传递 beta 头,而非 HTTP 头)
// 1. 从客户端 anthropic-beta header 解析
// 2. 根据请求体内容自动补齐必要的 beta token
// 参考 litellm: AnthropicModelInfo.get_anthropic_beta_list() + _get_tool_search_beta_header_for_bedrock()
if len(betaTokens) > 0 {
body, err = sjson.SetBytes(body, "anthropic_beta", betaTokens)
if err != nil {
return nil, fmt.Errorf("inject anthropic_beta: %w", err)
}
}
// 移除 model 字段Bedrock 通过 URL 指定模型)
body, err = sjson.DeleteBytes(body, "model")
if err != nil {
return nil, fmt.Errorf("remove model field: %w", err)
}
// 移除 stream 字段Bedrock 通过不同端点控制流式,不接受请求体中的 stream 字段)
body, err = sjson.DeleteBytes(body, "stream")
if err != nil {
return nil, fmt.Errorf("remove stream field: %w", err)
}
// 转换 output_formatBedrock Invoke 不支持此字段,但可将 schema 内联到最后一条 user message
// 参考 litellm: _convert_output_format_to_inline_schema()
body = convertOutputFormatToInlineSchema(body)
// 移除 output_config 字段Bedrock Invoke 不支持)
body, err = sjson.DeleteBytes(body, "output_config")
if err != nil {
return nil, fmt.Errorf("remove output_config field: %w", err)
}
// 移除工具定义中的 custom 字段
// Claude Code (v2.1.69+) 在 tool 定义中发送 custom: {defer_loading: true}
// Anthropic API 接受但 Bedrock 会拒绝并报 "Extra inputs are not permitted"
body = removeCustomFieldFromTools(body)
// 清理 cache_control 中 Bedrock 不支持的字段
body = sanitizeBedrockCacheControl(body, modelID)
return body, nil
}
// ResolveBedrockBetaTokens computes the final Bedrock beta token list before policy filtering.
func ResolveBedrockBetaTokens(betaHeader string, body []byte, modelID string) []string {
betaTokens := parseAnthropicBetaHeader(betaHeader)
betaTokens = autoInjectBedrockBetaTokens(betaTokens, body, modelID)
return filterBedrockBetaTokens(betaTokens)
}
// convertOutputFormatToInlineSchema 将 output_format 中的 JSON schema 内联到最后一条 user message
// Bedrock Invoke 不支持 output_format 参数litellm 的做法是将 schema 追加到用户消息中
// 参考: litellm AmazonAnthropicClaudeMessagesConfig._convert_output_format_to_inline_schema()
func convertOutputFormatToInlineSchema(body []byte) []byte {
outputFormat := gjson.GetBytes(body, "output_format")
if !outputFormat.Exists() || !outputFormat.IsObject() {
return body
}
// 先从请求体中移除 output_format
body, _ = sjson.DeleteBytes(body, "output_format")
schema := outputFormat.Get("schema")
if !schema.Exists() {
return body
}
// 找到最后一条 user message
messages := gjson.GetBytes(body, "messages")
if !messages.Exists() || !messages.IsArray() {
return body
}
msgArr := messages.Array()
lastUserIdx := -1
for i := len(msgArr) - 1; i >= 0; i-- {
if msgArr[i].Get("role").String() == "user" {
lastUserIdx = i
break
}
}
if lastUserIdx < 0 {
return body
}
// 将 schema 序列化为 JSON 文本追加到该 message 的 content 数组
schemaJSON, err := json.Marshal(json.RawMessage(schema.Raw))
if err != nil {
return body
}
content := msgArr[lastUserIdx].Get("content")
basePath := fmt.Sprintf("messages.%d.content", lastUserIdx)
if content.IsArray() {
// 追加一个 text block 到 content 数组末尾
idx := len(content.Array())
body, _ = sjson.SetBytes(body, fmt.Sprintf("%s.%d.type", basePath, idx), "text")
body, _ = sjson.SetBytes(body, fmt.Sprintf("%s.%d.text", basePath, idx), string(schemaJSON))
} else if content.Type == gjson.String {
// content 是纯字符串,转换为数组格式
originalText := content.String()
body, _ = sjson.SetBytes(body, basePath, []map[string]string{
{"type": "text", "text": originalText},
{"type": "text", "text": string(schemaJSON)},
})
}
return body
}
// removeCustomFieldFromTools 移除 tools 数组中每个工具定义的 custom 字段
func removeCustomFieldFromTools(body []byte) []byte {
tools := gjson.GetBytes(body, "tools")
if !tools.Exists() || !tools.IsArray() {
return body
}
var err error
for i := range tools.Array() {
body, err = sjson.DeleteBytes(body, fmt.Sprintf("tools.%d.custom", i))
if err != nil {
// 删除失败不影响整体流程,跳过
continue
}
}
return body
}
// claudeVersionRe 匹配 Claude 模型 ID 中的版本号部分
// 支持 claude-{tier}-{major}-{minor} 和 claude-{tier}-{major}.{minor} 格式
var claudeVersionRe = regexp.MustCompile(`claude-(?:haiku|sonnet|opus)-(\d+)[-.](\d+)`)
// isBedrockClaude45OrNewer 判断 Bedrock 模型 ID 是否为 Claude 4.5 或更新版本
// Claude 4.5+ 支持 cache_control 中的 ttl 字段("5m" 和 "1h"
func isBedrockClaude45OrNewer(modelID string) bool {
lower := strings.ToLower(modelID)
matches := claudeVersionRe.FindStringSubmatch(lower)
if matches == nil {
return false
}
major, _ := strconv.Atoi(matches[1])
minor, _ := strconv.Atoi(matches[2])
return major > 4 || (major == 4 && minor >= 5)
}
// sanitizeBedrockCacheControl 清理 system 和 messages 中 cache_control 里
// Bedrock 不支持的字段:
// - scopeBedrock 不支持(如 "global" 跨请求缓存)
// - ttl仅 Claude 4.5+ 支持 "5m" 和 "1h",旧模型需要移除
func sanitizeBedrockCacheControl(body []byte, modelID string) []byte {
isClaude45 := isBedrockClaude45OrNewer(modelID)
// 清理 system 数组中的 cache_control
systemArr := gjson.GetBytes(body, "system")
if systemArr.Exists() && systemArr.IsArray() {
for i, item := range systemArr.Array() {
if !item.IsObject() {
continue
}
cc := item.Get("cache_control")
if !cc.Exists() || !cc.IsObject() {
continue
}
body = deleteCacheControlUnsupportedFields(body, fmt.Sprintf("system.%d.cache_control", i), cc, isClaude45)
}
}
// 清理 messages 中的 cache_control
messages := gjson.GetBytes(body, "messages")
if !messages.Exists() || !messages.IsArray() {
return body
}
for mi, msg := range messages.Array() {
if !msg.IsObject() {
continue
}
content := msg.Get("content")
if !content.Exists() || !content.IsArray() {
continue
}
for ci, block := range content.Array() {
if !block.IsObject() {
continue
}
cc := block.Get("cache_control")
if !cc.Exists() || !cc.IsObject() {
continue
}
body = deleteCacheControlUnsupportedFields(body, fmt.Sprintf("messages.%d.content.%d.cache_control", mi, ci), cc, isClaude45)
}
}
return body
}
// deleteCacheControlUnsupportedFields 删除给定 cache_control 路径下 Bedrock 不支持的字段
func deleteCacheControlUnsupportedFields(body []byte, basePath string, cc gjson.Result, isClaude45 bool) []byte {
// Bedrock 不支持 scope如 "global"
if cc.Get("scope").Exists() {
body, _ = sjson.DeleteBytes(body, basePath+".scope")
}
// ttl仅 Claude 4.5+ 支持 "5m" 和 "1h",其余情况移除
ttl := cc.Get("ttl")
if ttl.Exists() {
shouldRemove := true
if isClaude45 {
v := ttl.String()
if v == "5m" || v == "1h" {
shouldRemove = false
}
}
if shouldRemove {
body, _ = sjson.DeleteBytes(body, basePath+".ttl")
}
}
return body
}
// parseAnthropicBetaHeader 解析 anthropic-beta 头的逗号分隔字符串为 token 列表
func parseAnthropicBetaHeader(header string) []string {
header = strings.TrimSpace(header)
if header == "" {
return nil
}
if strings.HasPrefix(header, "[") && strings.HasSuffix(header, "]") {
var parsed []any
if err := json.Unmarshal([]byte(header), &parsed); err == nil {
tokens := make([]string, 0, len(parsed))
for _, item := range parsed {
token := strings.TrimSpace(fmt.Sprint(item))
if token != "" {
tokens = append(tokens, token)
}
}
return tokens
}
}
var tokens []string
for _, part := range strings.Split(header, ",") {
t := strings.TrimSpace(part)
if t != "" {
tokens = append(tokens, t)
}
}
return tokens
}
// bedrockSupportedBetaTokens 是 Bedrock Invoke 支持的 beta 头白名单
// 参考: litellm/litellm/llms/bedrock/common_utils.py (anthropic_beta_headers_config.json)
// 更新策略: 当 AWS Bedrock 新增支持的 beta token 时需同步更新此白名单
var bedrockSupportedBetaTokens = map[string]bool{
"computer-use-2025-01-24": true,
"computer-use-2025-11-24": true,
"context-1m-2025-08-07": true,
"context-management-2025-06-27": true,
"compact-2026-01-12": true,
"interleaved-thinking-2025-05-14": true,
"tool-search-tool-2025-10-19": true,
"tool-examples-2025-10-29": true,
}
// bedrockBetaTokenTransforms 定义 Bedrock Invoke 特有的 beta 头转换规则
// Anthropic 直接 API 使用通用头Bedrock Invoke 需要特定的替代头
var bedrockBetaTokenTransforms = map[string]string{
"advanced-tool-use-2025-11-20": "tool-search-tool-2025-10-19",
}
// autoInjectBedrockBetaTokens 根据请求体内容自动补齐必要的 beta token
// 参考 litellm: AnthropicModelInfo.get_anthropic_beta_list() 和
// AmazonAnthropicClaudeMessagesConfig._get_tool_search_beta_header_for_bedrock()
//
// 客户端(特别是非 Claude Code 客户端)可能只在 body 中启用了功能而不在 header 中带对应 beta token
// 这里通过检测请求体特征自动补齐,确保 Bedrock Invoke 不会因缺少必要 beta 头而 400。
func autoInjectBedrockBetaTokens(tokens []string, body []byte, modelID string) []string {
seen := make(map[string]bool, len(tokens))
for _, t := range tokens {
seen[t] = true
}
inject := func(token string) {
if !seen[token] {
tokens = append(tokens, token)
seen[token] = true
}
}
// 检测 thinking / interleaved thinking
// 请求体中有 "thinking" 字段 → 需要 interleaved-thinking beta
if gjson.GetBytes(body, "thinking").Exists() {
inject("interleaved-thinking-2025-05-14")
}
// 检测 computer_use 工具
// tools 中有 type="computer_20xxxxxx" 的工具 → 需要 computer-use beta
tools := gjson.GetBytes(body, "tools")
if tools.Exists() && tools.IsArray() {
toolSearchUsed := false
programmaticToolCallingUsed := false
inputExamplesUsed := false
for _, tool := range tools.Array() {
toolType := tool.Get("type").String()
if strings.HasPrefix(toolType, "computer_20") {
inject("computer-use-2025-11-24")
}
if isBedrockToolSearchType(toolType) {
toolSearchUsed = true
}
if hasCodeExecutionAllowedCallers(tool) {
programmaticToolCallingUsed = true
}
if hasInputExamples(tool) {
inputExamplesUsed = true
}
}
if programmaticToolCallingUsed || inputExamplesUsed {
// programmatic tool calling 和 input examples 需要 advanced-tool-use
// 后续 filterBedrockBetaTokens 会将其转换为 Bedrock 特定的 tool-search-tool
inject("advanced-tool-use-2025-11-20")
}
if toolSearchUsed && bedrockModelSupportsToolSearch(modelID) {
// 纯 tool search无 programmatic/inputExamples时直接注入 Bedrock 特定头,
// 跳过 advanced-tool-use → tool-search-tool 的转换步骤(与 litellm 对齐)
if !programmaticToolCallingUsed && !inputExamplesUsed {
inject("tool-search-tool-2025-10-19")
} else {
inject("advanced-tool-use-2025-11-20")
}
}
}
return tokens
}
func isBedrockToolSearchType(toolType string) bool {
return toolType == "tool_search_tool_regex_20251119" || toolType == "tool_search_tool_bm25_20251119"
}
func hasCodeExecutionAllowedCallers(tool gjson.Result) bool {
allowedCallers := tool.Get("allowed_callers")
if containsStringInJSONArray(allowedCallers, "code_execution_20250825") {
return true
}
return containsStringInJSONArray(tool.Get("function.allowed_callers"), "code_execution_20250825")
}
func hasInputExamples(tool gjson.Result) bool {
if arr := tool.Get("input_examples"); arr.Exists() && arr.IsArray() && len(arr.Array()) > 0 {
return true
}
arr := tool.Get("function.input_examples")
return arr.Exists() && arr.IsArray() && len(arr.Array()) > 0
}
func containsStringInJSONArray(result gjson.Result, target string) bool {
if !result.Exists() || !result.IsArray() {
return false
}
for _, item := range result.Array() {
if item.String() == target {
return true
}
}
return false
}
// bedrockModelSupportsToolSearch 判断 Bedrock 模型是否支持 tool search
// 目前仅 Claude Opus/Sonnet 4.5+ 支持Haiku 不支持
func bedrockModelSupportsToolSearch(modelID string) bool {
lower := strings.ToLower(modelID)
matches := claudeVersionRe.FindStringSubmatch(lower)
if matches == nil {
return false
}
// Haiku 不支持 tool search
if strings.Contains(lower, "haiku") {
return false
}
major, _ := strconv.Atoi(matches[1])
minor, _ := strconv.Atoi(matches[2])
return major > 4 || (major == 4 && minor >= 5)
}
// filterBedrockBetaTokens 过滤并转换 beta token 列表,仅保留 Bedrock Invoke 支持的 token
// 1. 应用转换规则(如 advanced-tool-use → tool-search-tool
// 2. 过滤掉 Bedrock 不支持的 token如 output-128k, files-api, structured-outputs 等)
// 3. 自动关联 tool-examples当 tool-search-tool 存在时)
func filterBedrockBetaTokens(tokens []string) []string {
seen := make(map[string]bool, len(tokens))
var result []string
for _, t := range tokens {
// 应用转换规则
if replacement, ok := bedrockBetaTokenTransforms[t]; ok {
t = replacement
}
// 只保留白名单中的 token且去重
if bedrockSupportedBetaTokens[t] && !seen[t] {
result = append(result, t)
seen[t] = true
}
}
// 自动关联: tool-search-tool 存在时,确保 tool-examples 也存在
if seen["tool-search-tool-2025-10-19"] && !seen["tool-examples-2025-10-29"] {
result = append(result, "tool-examples-2025-10-29")
}
return result
}

View File

@@ -0,0 +1,659 @@
package service
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestPrepareBedrockRequestBody_BasicFields(t *testing.T) {
input := `{"model":"claude-opus-4-6","stream":true,"max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
require.NoError(t, err)
// anthropic_version 应被注入
assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String())
// model 和 stream 应被移除
assert.False(t, gjson.GetBytes(result, "model").Exists())
assert.False(t, gjson.GetBytes(result, "stream").Exists())
// max_tokens 应保留
assert.Equal(t, int64(1024), gjson.GetBytes(result, "max_tokens").Int())
}
func TestPrepareBedrockRequestBody_OutputFormatInlineSchema(t *testing.T) {
t.Run("schema inlined into last user message array content", func(t *testing.T) {
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"name":"string"}},"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
// schema 应内联到最后一条 user message 的 content 数组末尾
contentArr := gjson.GetBytes(result, "messages.0.content").Array()
require.Len(t, contentArr, 2)
assert.Equal(t, "text", contentArr[1].Get("type").String())
assert.Contains(t, contentArr[1].Get("text").String(), `"name":"string"`)
})
t.Run("schema inlined into string content", func(t *testing.T) {
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"result":"number"}},"messages":[{"role":"user","content":"compute this"}]}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
contentArr := gjson.GetBytes(result, "messages.0.content").Array()
require.Len(t, contentArr, 2)
assert.Equal(t, "compute this", contentArr[0].Get("text").String())
assert.Contains(t, contentArr[1].Get("text").String(), `"result":"number"`)
})
t.Run("no schema field just removes output_format", func(t *testing.T) {
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json"},"messages":[{"role":"user","content":"hi"}]}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
})
t.Run("no messages just removes output_format", func(t *testing.T) {
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"name":"string"}}}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
})
}
func TestPrepareBedrockRequestBody_RemoveOutputConfig(t *testing.T) {
input := `{"model":"claude-sonnet-4-5","output_config":{"max_tokens":100},"messages":[]}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "output_config").Exists())
}
func TestRemoveCustomFieldFromTools(t *testing.T) {
input := `{
"tools": [
{"name":"tool1","custom":{"defer_loading":true},"description":"desc1"},
{"name":"tool2","description":"desc2"},
{"name":"tool3","custom":{"defer_loading":true,"other":123},"description":"desc3"}
]
}`
result := removeCustomFieldFromTools([]byte(input))
tools := gjson.GetBytes(result, "tools").Array()
require.Len(t, tools, 3)
// custom 应被移除
assert.False(t, tools[0].Get("custom").Exists())
// name/description 应保留
assert.Equal(t, "tool1", tools[0].Get("name").String())
assert.Equal(t, "desc1", tools[0].Get("description").String())
// 没有 custom 的工具不受影响
assert.Equal(t, "tool2", tools[1].Get("name").String())
// 第三个工具的 custom 也应被移除
assert.False(t, tools[2].Get("custom").Exists())
assert.Equal(t, "tool3", tools[2].Get("name").String())
}
func TestRemoveCustomFieldFromTools_NoTools(t *testing.T) {
input := `{"messages":[{"role":"user","content":"hi"}]}`
result := removeCustomFieldFromTools([]byte(input))
// 无 tools 时不改变原始数据
assert.JSONEq(t, input, string(result))
}
func TestSanitizeBedrockCacheControl_RemoveScope(t *testing.T) {
input := `{
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","scope":"global"}}],
"messages": [{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral","scope":"global"}}]}]
}`
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")
// scope 应被移除
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.scope").Exists())
assert.False(t, gjson.GetBytes(result, "messages.0.content.0.cache_control.scope").Exists())
// type 应保留
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "messages.0.content.0.cache_control.type").String())
}
func TestSanitizeBedrockCacheControl_TTL_OldModel(t *testing.T) {
input := `{
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}}]
}`
// 旧模型Claude 3.5)不支持 ttl
result := sanitizeBedrockCacheControl([]byte(input), "anthropic.claude-3-5-sonnet-20241022-v2:0")
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
}
func TestSanitizeBedrockCacheControl_TTL_Claude45_Supported(t *testing.T) {
input := `{
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}}]
}`
// Claude 4.5+ 支持 "5m" 和 "1h"
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-sonnet-4-5-20250929-v1:0")
assert.True(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
assert.Equal(t, "5m", gjson.GetBytes(result, "system.0.cache_control.ttl").String())
}
func TestSanitizeBedrockCacheControl_TTL_Claude45_UnsupportedValue(t *testing.T) {
input := `{
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"10m"}}]
}`
// Claude 4.5 不支持 "10m"
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-sonnet-4-5-20250929-v1:0")
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
}
func TestSanitizeBedrockCacheControl_TTL_Claude46(t *testing.T) {
input := `{
"messages": [{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral","ttl":"1h"}}]}]
}`
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")
assert.True(t, gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").Exists())
assert.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String())
}
func TestSanitizeBedrockCacheControl_NoCacheControl(t *testing.T) {
input := `{"system":[{"type":"text","text":"sys"}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")
// 无 cache_control 时不改变原始数据
assert.JSONEq(t, input, string(result))
}
func TestIsBedrockClaude45OrNewer(t *testing.T) {
tests := []struct {
modelID string
expect bool
}{
{"us.anthropic.claude-opus-4-6-v1", true},
{"us.anthropic.claude-sonnet-4-6", true},
{"us.anthropic.claude-sonnet-4-5-20250929-v1:0", true},
{"us.anthropic.claude-opus-4-5-20251101-v1:0", true},
{"us.anthropic.claude-haiku-4-5-20251001-v1:0", true},
{"anthropic.claude-3-5-sonnet-20241022-v2:0", false},
{"anthropic.claude-3-opus-20240229-v1:0", false},
{"anthropic.claude-3-haiku-20240307-v1:0", false},
// 未来版本应自动支持
{"us.anthropic.claude-sonnet-5-0-v1", true},
{"us.anthropic.claude-opus-4-7-v1", true},
// 旧版本
{"anthropic.claude-opus-4-1-v1", false},
{"anthropic.claude-sonnet-4-0-v1", false},
// 非 Claude 模型
{"amazon.nova-pro-v1", false},
{"meta.llama3-70b", false},
}
for _, tt := range tests {
t.Run(tt.modelID, func(t *testing.T) {
assert.Equal(t, tt.expect, isBedrockClaude45OrNewer(tt.modelID))
})
}
}
func TestPrepareBedrockRequestBody_FullIntegration(t *testing.T) {
// 模拟一个完整的 Claude Code 请求
input := `{
"model": "claude-opus-4-6",
"stream": true,
"max_tokens": 16384,
"output_format": {"type": "json", "schema": {"result": "string"}},
"output_config": {"max_tokens": 100},
"system": [{"type": "text", "text": "You are helpful", "cache_control": {"type": "ephemeral", "scope": "global", "ttl": "5m"}}],
"messages": [
{"role": "user", "content": [{"type": "text", "text": "hello", "cache_control": {"type": "ephemeral", "ttl": "1h"}}]}
],
"tools": [
{"name": "bash", "description": "Run bash", "custom": {"defer_loading": true}, "input_schema": {"type": "object"}},
{"name": "read", "description": "Read file", "input_schema": {"type": "object"}}
]
}`
betaHeader := "interleaved-thinking-2025-05-14, context-1m-2025-08-07, compact-2026-01-12"
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", betaHeader)
require.NoError(t, err)
// 基本字段
assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String())
assert.False(t, gjson.GetBytes(result, "model").Exists())
assert.False(t, gjson.GetBytes(result, "stream").Exists())
assert.Equal(t, int64(16384), gjson.GetBytes(result, "max_tokens").Int())
// anthropic_beta 应包含所有 beta tokens
betaArr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, betaArr, 3)
assert.Equal(t, "interleaved-thinking-2025-05-14", betaArr[0].String())
assert.Equal(t, "context-1m-2025-08-07", betaArr[1].String())
assert.Equal(t, "compact-2026-01-12", betaArr[2].String())
// output_format 应被移除schema 内联到最后一条 user message
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
assert.False(t, gjson.GetBytes(result, "output_config").Exists())
// content 数组:原始 text block + 内联 schema block
contentArr := gjson.GetBytes(result, "messages.0.content").Array()
require.Len(t, contentArr, 2)
assert.Equal(t, "hello", contentArr[0].Get("text").String())
assert.Contains(t, contentArr[1].Get("text").String(), `"result":"string"`)
// tools 中的 custom 应被移除
assert.False(t, gjson.GetBytes(result, "tools.0.custom").Exists())
assert.Equal(t, "bash", gjson.GetBytes(result, "tools.0.name").String())
assert.Equal(t, "read", gjson.GetBytes(result, "tools.1.name").String())
// cache_control: scope 应被移除ttl 在 Claude 4.6 上保留合法值
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.scope").Exists())
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
assert.Equal(t, "5m", gjson.GetBytes(result, "system.0.cache_control.ttl").String())
assert.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String())
}
func TestPrepareBedrockRequestBody_BetaHeader(t *testing.T) {
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`
t.Run("empty beta header", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "anthropic_beta").Exists())
})
t.Run("single beta token", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, arr, 1)
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
})
t.Run("multiple beta tokens with spaces", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14 , context-1m-2025-08-07 ")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, arr, 2)
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
assert.Equal(t, "context-1m-2025-08-07", arr[1].String())
})
t.Run("json array beta header", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", `["interleaved-thinking-2025-05-14","context-1m-2025-08-07"]`)
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, arr, 2)
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
assert.Equal(t, "context-1m-2025-08-07", arr[1].String())
})
}
func TestParseAnthropicBetaHeader(t *testing.T) {
assert.Nil(t, parseAnthropicBetaHeader(""))
assert.Equal(t, []string{"a"}, parseAnthropicBetaHeader("a"))
assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader("a,b"))
assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader("a , b "))
assert.Equal(t, []string{"a", "b", "c"}, parseAnthropicBetaHeader("a,b,c"))
assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader(`["a","b"]`))
}
func TestFilterBedrockBetaTokens(t *testing.T) {
t.Run("supported tokens pass through", func(t *testing.T) {
tokens := []string{"interleaved-thinking-2025-05-14", "context-1m-2025-08-07", "compact-2026-01-12"}
result := filterBedrockBetaTokens(tokens)
assert.Equal(t, tokens, result)
})
t.Run("unsupported tokens are filtered out", func(t *testing.T) {
tokens := []string{"interleaved-thinking-2025-05-14", "output-128k-2025-02-19", "files-api-2025-04-14", "structured-outputs-2025-11-13"}
result := filterBedrockBetaTokens(tokens)
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result)
})
t.Run("advanced-tool-use transforms to tool-search-tool", func(t *testing.T) {
tokens := []string{"advanced-tool-use-2025-11-20"}
result := filterBedrockBetaTokens(tokens)
assert.Contains(t, result, "tool-search-tool-2025-10-19")
// tool-examples 自动关联
assert.Contains(t, result, "tool-examples-2025-10-29")
})
t.Run("tool-search-tool auto-associates tool-examples", func(t *testing.T) {
tokens := []string{"tool-search-tool-2025-10-19"}
result := filterBedrockBetaTokens(tokens)
assert.Contains(t, result, "tool-search-tool-2025-10-19")
assert.Contains(t, result, "tool-examples-2025-10-29")
})
t.Run("no duplication when tool-examples already present", func(t *testing.T) {
tokens := []string{"tool-search-tool-2025-10-19", "tool-examples-2025-10-29"}
result := filterBedrockBetaTokens(tokens)
count := 0
for _, t := range result {
if t == "tool-examples-2025-10-29" {
count++
}
}
assert.Equal(t, 1, count)
})
t.Run("empty input returns nil", func(t *testing.T) {
result := filterBedrockBetaTokens(nil)
assert.Nil(t, result)
})
t.Run("all unsupported returns nil", func(t *testing.T) {
result := filterBedrockBetaTokens([]string{"output-128k-2025-02-19", "effort-2025-11-24"})
assert.Nil(t, result)
})
t.Run("duplicate tokens are deduplicated", func(t *testing.T) {
tokens := []string{"context-1m-2025-08-07", "context-1m-2025-08-07"}
result := filterBedrockBetaTokens(tokens)
assert.Equal(t, []string{"context-1m-2025-08-07"}, result)
})
}
func TestPrepareBedrockRequestBody_BetaFiltering(t *testing.T) {
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`
t.Run("unsupported beta tokens are filtered", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1",
"interleaved-thinking-2025-05-14, output-128k-2025-02-19, files-api-2025-04-14")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, arr, 1)
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
})
t.Run("advanced-tool-use transformed in full pipeline", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1",
"advanced-tool-use-2025-11-20")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, arr, 2)
assert.Equal(t, "tool-search-tool-2025-10-19", arr[0].String())
assert.Equal(t, "tool-examples-2025-10-29", arr[1].String())
})
}
func TestBedrockCrossRegionPrefix(t *testing.T) {
tests := []struct {
region string
expect string
}{
// US regions
{"us-east-1", "us"},
{"us-east-2", "us"},
{"us-west-1", "us"},
{"us-west-2", "us"},
// GovCloud
{"us-gov-east-1", "us-gov"},
{"us-gov-west-1", "us-gov"},
// EU regions
{"eu-west-1", "eu"},
{"eu-west-2", "eu"},
{"eu-west-3", "eu"},
{"eu-central-1", "eu"},
{"eu-central-2", "eu"},
{"eu-north-1", "eu"},
{"eu-south-1", "eu"},
// APAC regions
{"ap-northeast-1", "jp"},
{"ap-northeast-2", "apac"},
{"ap-southeast-1", "apac"},
{"ap-southeast-2", "au"},
{"ap-south-1", "apac"},
// Canada / South America fallback to us
{"ca-central-1", "us"},
{"sa-east-1", "us"},
// Unknown defaults to us
{"me-south-1", "us"},
}
for _, tt := range tests {
t.Run(tt.region, func(t *testing.T) {
assert.Equal(t, tt.expect, BedrockCrossRegionPrefix(tt.region))
})
}
}
func TestResolveBedrockModelID(t *testing.T) {
t.Run("default alias resolves and adjusts region", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "eu-west-1",
},
}
modelID, ok := ResolveBedrockModelID(account, "claude-sonnet-4-5")
require.True(t, ok)
assert.Equal(t, "eu.anthropic.claude-sonnet-4-5-20250929-v1:0", modelID)
})
t.Run("custom alias mapping reuses default bedrock mapping", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "ap-southeast-2",
"model_mapping": map[string]any{
"claude-*": "claude-opus-4-6",
},
},
}
modelID, ok := ResolveBedrockModelID(account, "claude-opus-4-6-thinking")
require.True(t, ok)
assert.Equal(t, "au.anthropic.claude-opus-4-6-v1", modelID)
})
t.Run("force global rewrites anthropic regional model id", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "us-east-1",
"aws_force_global": "true",
"model_mapping": map[string]any{
"claude-sonnet-4-6": "us.anthropic.claude-sonnet-4-6",
},
},
}
modelID, ok := ResolveBedrockModelID(account, "claude-sonnet-4-6")
require.True(t, ok)
assert.Equal(t, "global.anthropic.claude-sonnet-4-6", modelID)
})
t.Run("direct bedrock model id passes through", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "us-east-1",
},
}
modelID, ok := ResolveBedrockModelID(account, "anthropic.claude-haiku-4-5-20251001-v1:0")
require.True(t, ok)
assert.Equal(t, "anthropic.claude-haiku-4-5-20251001-v1:0", modelID)
})
t.Run("unsupported alias returns false", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "us-east-1",
},
}
_, ok := ResolveBedrockModelID(account, "claude-3-5-sonnet-20241022")
assert.False(t, ok)
})
}
func TestAutoInjectBedrockBetaTokens(t *testing.T) {
t.Run("inject interleaved-thinking when thinking present", func(t *testing.T) {
body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "interleaved-thinking-2025-05-14")
})
t.Run("no duplicate when already present", func(t *testing.T) {
body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens([]string{"interleaved-thinking-2025-05-14"}, body, "us.anthropic.claude-opus-4-6-v1")
count := 0
for _, t := range result {
if t == "interleaved-thinking-2025-05-14" {
count++
}
}
assert.Equal(t, 1, count)
})
t.Run("inject computer-use when computer tool present", func(t *testing.T) {
body := []byte(`{"tools":[{"type":"computer_20250124","name":"computer","display_width_px":1024}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "computer-use-2025-11-24")
})
t.Run("inject advanced-tool-use for programmatic tool calling", func(t *testing.T) {
body := []byte(`{"tools":[{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
})
t.Run("inject advanced-tool-use for input examples", func(t *testing.T) {
body := []byte(`{"tools":[{"name":"bash","input_examples":[{"cmd":"ls"}]}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
})
t.Run("inject tool-search-tool directly for pure tool search (no programmatic/inputExamples)", func(t *testing.T) {
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-sonnet-4-6")
// 纯 tool search 场景直接注入 Bedrock 特定头,不走 advanced-tool-use 转换
assert.Contains(t, result, "tool-search-tool-2025-10-19")
assert.NotContains(t, result, "advanced-tool-use-2025-11-20")
})
t.Run("inject advanced-tool-use when tool search combined with programmatic calling", func(t *testing.T) {
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"},{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-sonnet-4-6")
// 混合场景使用 advanced-tool-use后续由 filter 转换为 tool-search-tool
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
})
t.Run("do not inject tool-search beta for unsupported models", func(t *testing.T) {
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "anthropic.claude-3-5-sonnet-20241022-v2:0")
assert.NotContains(t, result, "advanced-tool-use-2025-11-20")
assert.NotContains(t, result, "tool-search-tool-2025-10-19")
})
t.Run("no injection for regular tools", func(t *testing.T) {
body := []byte(`{"tools":[{"name":"bash","description":"run bash","input_schema":{"type":"object"}}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Empty(t, result)
})
t.Run("no injection when no features detected", func(t *testing.T) {
body := []byte(`{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Empty(t, result)
})
t.Run("preserves existing tokens", func(t *testing.T) {
body := []byte(`{"thinking":{"type":"enabled"},"messages":[{"role":"user","content":"hi"}]}`)
existing := []string{"context-1m-2025-08-07", "compact-2026-01-12"}
result := autoInjectBedrockBetaTokens(existing, body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "context-1m-2025-08-07")
assert.Contains(t, result, "compact-2026-01-12")
assert.Contains(t, result, "interleaved-thinking-2025-05-14")
})
}
func TestResolveBedrockBetaTokens(t *testing.T) {
t.Run("body-only tool features resolve to final bedrock tokens", func(t *testing.T) {
body := []byte(`{"tools":[{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
result := ResolveBedrockBetaTokens("", body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "tool-search-tool-2025-10-19")
assert.Contains(t, result, "tool-examples-2025-10-29")
})
t.Run("unsupported client beta tokens are filtered out", func(t *testing.T) {
body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`)
result := ResolveBedrockBetaTokens("interleaved-thinking-2025-05-14,files-api-2025-04-14", body, "us.anthropic.claude-opus-4-6-v1")
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result)
})
}
func TestPrepareBedrockRequestBody_AutoBetaInjection(t *testing.T) {
t.Run("thinking in body auto-injects beta without header", func(t *testing.T) {
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
found := false
for _, v := range arr {
if v.String() == "interleaved-thinking-2025-05-14" {
found = true
}
}
assert.True(t, found, "interleaved-thinking should be auto-injected")
})
t.Run("header tokens merged with auto-injected tokens", func(t *testing.T) {
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "context-1m-2025-08-07")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
names := make([]string, len(arr))
for i, v := range arr {
names[i] = v.String()
}
assert.Contains(t, names, "context-1m-2025-08-07")
assert.Contains(t, names, "interleaved-thinking-2025-05-14")
})
}
func TestAdjustBedrockModelRegionPrefix(t *testing.T) {
tests := []struct {
name string
modelID string
region string
expect string
}{
// US region — no change needed
{"us region keeps us prefix", "us.anthropic.claude-opus-4-6-v1", "us-east-1", "us.anthropic.claude-opus-4-6-v1"},
// EU region — replace us → eu
{"eu region replaces prefix", "us.anthropic.claude-opus-4-6-v1", "eu-west-1", "eu.anthropic.claude-opus-4-6-v1"},
{"eu region sonnet", "us.anthropic.claude-sonnet-4-6", "eu-central-1", "eu.anthropic.claude-sonnet-4-6"},
// APAC region — jp and au have dedicated prefixes per AWS docs
{"jp region (ap-northeast-1)", "us.anthropic.claude-sonnet-4-5-20250929-v1:0", "ap-northeast-1", "jp.anthropic.claude-sonnet-4-5-20250929-v1:0"},
{"au region (ap-southeast-2)", "us.anthropic.claude-haiku-4-5-20251001-v1:0", "ap-southeast-2", "au.anthropic.claude-haiku-4-5-20251001-v1:0"},
{"apac region (ap-southeast-1)", "us.anthropic.claude-sonnet-4-5-20250929-v1:0", "ap-southeast-1", "apac.anthropic.claude-sonnet-4-5-20250929-v1:0"},
// eu → us (user manually set eu prefix, moved to us region)
{"eu to us", "eu.anthropic.claude-opus-4-6-v1", "us-west-2", "us.anthropic.claude-opus-4-6-v1"},
// global prefix — replace to match region
{"global to eu", "global.anthropic.claude-opus-4-6-v1", "eu-west-1", "eu.anthropic.claude-opus-4-6-v1"},
// No known prefix — leave unchanged
{"no prefix unchanged", "anthropic.claude-3-5-sonnet-20241022-v2:0", "eu-west-1", "anthropic.claude-3-5-sonnet-20241022-v2:0"},
// GovCloud — uses independent us-gov prefix
{"govcloud from us", "us.anthropic.claude-opus-4-6-v1", "us-gov-east-1", "us-gov.anthropic.claude-opus-4-6-v1"},
{"govcloud already correct", "us-gov.anthropic.claude-opus-4-6-v1", "us-gov-west-1", "us-gov.anthropic.claude-opus-4-6-v1"},
// Force global (special region value)
{"force global from us", "us.anthropic.claude-opus-4-6-v1", "global", "global.anthropic.claude-opus-4-6-v1"},
{"force global from eu", "eu.anthropic.claude-sonnet-4-6", "global", "global.anthropic.claude-sonnet-4-6"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expect, AdjustBedrockModelRegionPrefix(tt.modelID, tt.region))
})
}
}

View File

@@ -0,0 +1,67 @@
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
)
// BedrockSigner 使用 AWS SigV4 对 Bedrock 请求签名
type BedrockSigner struct {
credentials aws.Credentials
region string
signer *v4.Signer
}
// NewBedrockSigner 创建 BedrockSigner
func NewBedrockSigner(accessKeyID, secretAccessKey, sessionToken, region string) *BedrockSigner {
return &BedrockSigner{
credentials: aws.Credentials{
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SessionToken: sessionToken,
},
region: region,
signer: v4.NewSigner(),
}
}
// NewBedrockSignerFromAccount 从 Account 凭证创建 BedrockSigner
func NewBedrockSignerFromAccount(account *Account) (*BedrockSigner, error) {
accessKeyID := account.GetCredential("aws_access_key_id")
if accessKeyID == "" {
return nil, fmt.Errorf("aws_access_key_id not found in credentials")
}
secretAccessKey := account.GetCredential("aws_secret_access_key")
if secretAccessKey == "" {
return nil, fmt.Errorf("aws_secret_access_key not found in credentials")
}
region := account.GetCredential("aws_region")
if region == "" {
region = defaultBedrockRegion
}
sessionToken := account.GetCredential("aws_session_token") // 可选
return NewBedrockSigner(accessKeyID, secretAccessKey, sessionToken, region), nil
}
// SignRequest 对 HTTP 请求进行 SigV4 签名
// 重要约束调用此方法前req 应只包含 AWS 相关的 header如 Content-Type、Accept
// 非 AWS header如 anthropic-beta会参与签名计算如果 Bedrock 服务端不识别这些 header
// 签名验证可能失败。litellm 通过 _filter_headers_for_aws_signature 实现头过滤,
// 当前实现中 buildUpstreamRequestBedrock 仅设置了 Content-Type 和 Accept因此是安全的。
func (s *BedrockSigner) SignRequest(ctx context.Context, req *http.Request, body []byte) error {
payloadHash := sha256Hash(body)
return s.signer.SignHTTP(ctx, s.credentials, req, payloadHash, "bedrock", s.region, time.Now())
}
func sha256Hash(data []byte) string {
h := sha256.Sum256(data)
return hex.EncodeToString(h[:])
}

View File

@@ -0,0 +1,35 @@
package service
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewBedrockSignerFromAccount_DefaultRegion(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_access_key_id": "test-akid",
"aws_secret_access_key": "test-secret",
},
}
signer, err := NewBedrockSignerFromAccount(account)
require.NoError(t, err)
require.NotNil(t, signer)
assert.Equal(t, defaultBedrockRegion, signer.region)
}
func TestFilterBetaTokens(t *testing.T) {
tokens := []string{"interleaved-thinking-2025-05-14", "tool-search-tool-2025-10-19"}
filterSet := map[string]struct{}{
"tool-search-tool-2025-10-19": {},
}
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, filterBetaTokens(tokens, filterSet))
assert.Equal(t, tokens, filterBetaTokens(tokens, nil))
assert.Nil(t, filterBetaTokens(nil, filterSet))
}

View File

@@ -0,0 +1,414 @@
package service
import (
"bufio"
"context"
"encoding/base64"
"errors"
"fmt"
"hash/crc32"
"io"
"net/http"
"sync/atomic"
"time"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// handleBedrockStreamingResponse 处理 Bedrock InvokeModelWithResponseStream 的 EventStream 响应
// Bedrock 返回 AWS EventStream 二进制格式,每个事件的 payload 中 chunk.bytes 是 base64 编码的
// Claude SSE 事件 JSON。本方法解码后转换为标准 SSE 格式写入客户端。
func (s *GatewayService) handleBedrockStreamingResponse(
ctx context.Context,
resp *http.Response,
c *gin.Context,
account *Account,
startTime time.Time,
model string,
) (*streamingResult, error) {
w := c.Writer
flusher, ok := w.(http.Flusher)
if !ok {
return nil, errors.New("streaming not supported")
}
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
if v := resp.Header.Get("x-amzn-requestid"); v != "" {
c.Header("x-request-id", v)
}
usage := &ClaudeUsage{}
var firstTokenMs *int
clientDisconnected := false
// Bedrock EventStream 使用 application/vnd.amazon.eventstream 二进制格式。
// 每个帧结构total_length(4) + headers_length(4) + prelude_crc(4) + headers + payload + message_crc(4)
// 但更实用的方式是使用行扫描找 JSON chunks因为 Bedrock 的响应在二进制帧中。
// 我们使用 EventStream decoder 来正确解析。
decoder := newBedrockEventStreamDecoder(resp.Body)
type decodeEvent struct {
payload []byte
err error
}
events := make(chan decodeEvent, 16)
done := make(chan struct{})
sendEvent := func(ev decodeEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt atomic.Int64
lastReadAt.Store(time.Now().UnixNano())
go func() {
defer close(events)
for {
payload, err := decoder.Decode()
if err != nil {
if err == io.EOF {
return
}
_ = sendEvent(decodeEvent{err: err})
return
}
lastReadAt.Store(time.Now().UnixNano())
if !sendEvent(decodeEvent{payload: payload}) {
return
}
}
}()
defer close(done)
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
for {
select {
case ev, ok := <-events:
if !ok {
if !clientDisconnected {
flusher.Flush()
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
}
if ev.err != nil {
if clientDisconnected {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("bedrock stream read error: %w", ev.err)
}
// payload 是 JSON提取 chunk.bytesbase64 编码的 Claude SSE 事件数据)
sseData := extractBedrockChunkData(ev.payload)
if sseData == nil {
continue
}
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
// 转换 Bedrock 特有的 amazon-bedrock-invocationMetrics 为标准 Anthropic usage 格式
// 同时移除该字段避免透传给客户端
sseData = transformBedrockInvocationMetrics(sseData)
// 解析 SSE 事件数据提取 usage
s.parseSSEUsagePassthrough(string(sseData), usage)
// 确定 SSE event type
eventType := gjson.GetBytes(sseData, "type").String()
// 写入标准 SSE 格式
if !clientDisconnected {
var writeErr error
if eventType != "" {
_, writeErr = fmt.Fprintf(w, "event: %s\ndata: %s\n\n", eventType, sseData)
} else {
_, writeErr = fmt.Fprintf(w, "data: %s\n\n", sseData)
}
if writeErr != nil {
clientDisconnected = true
logger.LegacyPrintf("service.gateway", "[Bedrock] Client disconnected during streaming, continue draining for usage: account=%d", account.ID)
} else {
flusher.Flush()
}
}
case <-intervalCh:
lastRead := time.Unix(0, lastReadAt.Load())
if time.Since(lastRead) < streamInterval {
continue
}
if clientDisconnected {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
logger.LegacyPrintf("service.gateway", "[Bedrock] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval)
if s.rateLimitService != nil {
s.rateLimitService.HandleStreamTimeout(ctx, account, model)
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}
}
}
// extractBedrockChunkData 从 Bedrock EventStream payload 中提取 Claude SSE 事件数据
// Bedrock payload 格式:{"bytes":"<base64-encoded-json>"}
func extractBedrockChunkData(payload []byte) []byte {
b64 := gjson.GetBytes(payload, "bytes").String()
if b64 == "" {
return nil
}
decoded, err := base64.StdEncoding.DecodeString(b64)
if err != nil {
return nil
}
return decoded
}
// transformBedrockInvocationMetrics 将 Bedrock 特有的 amazon-bedrock-invocationMetrics
// 转换为标准 Anthropic usage 格式,并从 SSE 数据中移除该字段。
//
// Bedrock Invoke 返回的 message_delta 事件可能包含:
//
// {"type":"message_delta","delta":{...},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}
//
// 转换为:
//
// {"type":"message_delta","delta":{...},"usage":{"input_tokens":150,"output_tokens":42}}
func transformBedrockInvocationMetrics(data []byte) []byte {
metrics := gjson.GetBytes(data, "amazon-bedrock-invocationMetrics")
if !metrics.Exists() || !metrics.IsObject() {
return data
}
// 移除 Bedrock 特有字段
data, _ = sjson.DeleteBytes(data, "amazon-bedrock-invocationMetrics")
// 如果已有标准 usage 字段,不覆盖
if gjson.GetBytes(data, "usage").Exists() {
return data
}
// 转换 camelCase → snake_case 写入 usage
inputTokens := metrics.Get("inputTokenCount")
outputTokens := metrics.Get("outputTokenCount")
if inputTokens.Exists() {
data, _ = sjson.SetBytes(data, "usage.input_tokens", inputTokens.Int())
}
if outputTokens.Exists() {
data, _ = sjson.SetBytes(data, "usage.output_tokens", outputTokens.Int())
}
return data
}
// bedrockEventStreamDecoder 解码 AWS EventStream 二进制帧
// EventStream 帧格式:
//
// [total_byte_length: 4 bytes]
// [headers_byte_length: 4 bytes]
// [prelude_crc: 4 bytes]
// [headers: variable]
// [payload: variable]
// [message_crc: 4 bytes]
type bedrockEventStreamDecoder struct {
reader *bufio.Reader
}
func newBedrockEventStreamDecoder(r io.Reader) *bedrockEventStreamDecoder {
return &bedrockEventStreamDecoder{
reader: bufio.NewReaderSize(r, 64*1024),
}
}
// Decode 读取下一个 EventStream 帧并返回 chunk 类型事件的 payload
func (d *bedrockEventStreamDecoder) Decode() ([]byte, error) {
for {
// 读取 prelude: total_length(4) + headers_length(4) + prelude_crc(4) = 12 bytes
prelude := make([]byte, 12)
if _, err := io.ReadFull(d.reader, prelude); err != nil {
return nil, err
}
// 验证 prelude CRCAWS EventStream 使用标准 CRC32 / IEEE
preludeCRC := bedrockReadUint32(prelude[8:12])
if crc32.Checksum(prelude[0:8], crc32IEEETable) != preludeCRC {
return nil, fmt.Errorf("eventstream prelude CRC mismatch")
}
totalLength := bedrockReadUint32(prelude[0:4])
headersLength := bedrockReadUint32(prelude[4:8])
if totalLength < 16 { // minimum: 12 prelude + 4 message_crc
return nil, fmt.Errorf("invalid eventstream frame: total_length=%d", totalLength)
}
// 读取 headers + payload + message_crc
remaining := int(totalLength) - 12
if remaining <= 0 {
continue
}
data := make([]byte, remaining)
if _, err := io.ReadFull(d.reader, data); err != nil {
return nil, err
}
// 验证 message CRC覆盖 prelude + headers + payload
messageCRC := bedrockReadUint32(data[len(data)-4:])
h := crc32.New(crc32IEEETable)
_, _ = h.Write(prelude)
_, _ = h.Write(data[:len(data)-4])
if h.Sum32() != messageCRC {
return nil, fmt.Errorf("eventstream message CRC mismatch")
}
// 解析 headers
headers := data[:headersLength]
payload := data[headersLength : len(data)-4] // 去掉 message_crc
// 从 headers 中提取 :event-type
eventType := extractEventStreamHeaderValue(headers, ":event-type")
// 只处理 chunk 事件
if eventType == "chunk" {
// payload 是完整的 JSON包含 bytes 字段
return payload, nil
}
// 检查异常事件
exceptionType := extractEventStreamHeaderValue(headers, ":exception-type")
if exceptionType != "" {
return nil, fmt.Errorf("bedrock exception: %s: %s", exceptionType, string(payload))
}
messageType := extractEventStreamHeaderValue(headers, ":message-type")
if messageType == "exception" || messageType == "error" {
return nil, fmt.Errorf("bedrock error: %s", string(payload))
}
// 跳过其他事件类型(如 initial-response
}
}
// extractEventStreamHeaderValue 从 EventStream headers 二进制数据中提取指定 header 的字符串值
// EventStream header 格式:
//
// [name_length: 1 byte][name: variable][value_type: 1 byte][value: variable]
//
// value_type = 7 表示 string 类型,前 2 bytes 为长度
func extractEventStreamHeaderValue(headers []byte, targetName string) string {
pos := 0
for pos < len(headers) {
if pos >= len(headers) {
break
}
nameLen := int(headers[pos])
pos++
if pos+nameLen > len(headers) {
break
}
name := string(headers[pos : pos+nameLen])
pos += nameLen
if pos >= len(headers) {
break
}
valueType := headers[pos]
pos++
switch valueType {
case 7: // string
if pos+2 > len(headers) {
return ""
}
valueLen := int(bedrockReadUint16(headers[pos : pos+2]))
pos += 2
if pos+valueLen > len(headers) {
return ""
}
value := string(headers[pos : pos+valueLen])
pos += valueLen
if name == targetName {
return value
}
case 0: // bool true
if name == targetName {
return "true"
}
case 1: // bool false
if name == targetName {
return "false"
}
case 2: // byte
pos++
if name == targetName {
return ""
}
case 3: // short
pos += 2
if name == targetName {
return ""
}
case 4: // int
pos += 4
if name == targetName {
return ""
}
case 5: // long
pos += 8
if name == targetName {
return ""
}
case 6: // bytes
if pos+2 > len(headers) {
return ""
}
valueLen := int(bedrockReadUint16(headers[pos : pos+2]))
pos += 2 + valueLen
case 8: // timestamp
pos += 8
case 9: // uuid
pos += 16
default:
return "" // 未知类型,无法继续解析
}
}
return ""
}
// crc32IEEETable is the CRC32 / IEEE table used by AWS EventStream.
var crc32IEEETable = crc32.MakeTable(crc32.IEEE)
func bedrockReadUint32(b []byte) uint32 {
return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
}
func bedrockReadUint16(b []byte) uint16 {
return uint16(b[0])<<8 | uint16(b[1])
}

View File

@@ -0,0 +1,261 @@
package service
import (
"bytes"
"encoding/base64"
"encoding/binary"
"hash/crc32"
"io"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestExtractBedrockChunkData(t *testing.T) {
t.Run("valid base64 payload", func(t *testing.T) {
original := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}`
b64 := base64.StdEncoding.EncodeToString([]byte(original))
payload := []byte(`{"bytes":"` + b64 + `"}`)
result := extractBedrockChunkData(payload)
require.NotNil(t, result)
assert.JSONEq(t, original, string(result))
})
t.Run("empty bytes field", func(t *testing.T) {
result := extractBedrockChunkData([]byte(`{"bytes":""}`))
assert.Nil(t, result)
})
t.Run("no bytes field", func(t *testing.T) {
result := extractBedrockChunkData([]byte(`{"other":"value"}`))
assert.Nil(t, result)
})
t.Run("invalid base64", func(t *testing.T) {
result := extractBedrockChunkData([]byte(`{"bytes":"not-valid-base64!!!"}`))
assert.Nil(t, result)
})
}
func TestTransformBedrockInvocationMetrics(t *testing.T) {
t.Run("converts metrics to usage", func(t *testing.T) {
input := `{"type":"message_delta","delta":{"stop_reason":"end_turn"},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}`
result := transformBedrockInvocationMetrics([]byte(input))
// amazon-bedrock-invocationMetrics should be removed
assert.False(t, gjson.GetBytes(result, "amazon-bedrock-invocationMetrics").Exists())
// usage should be set
assert.Equal(t, int64(150), gjson.GetBytes(result, "usage.input_tokens").Int())
assert.Equal(t, int64(42), gjson.GetBytes(result, "usage.output_tokens").Int())
// original fields preserved
assert.Equal(t, "message_delta", gjson.GetBytes(result, "type").String())
assert.Equal(t, "end_turn", gjson.GetBytes(result, "delta.stop_reason").String())
})
t.Run("no metrics present", func(t *testing.T) {
input := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}`
result := transformBedrockInvocationMetrics([]byte(input))
assert.JSONEq(t, input, string(result))
})
t.Run("does not overwrite existing usage", func(t *testing.T) {
input := `{"type":"message_delta","usage":{"output_tokens":100},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}`
result := transformBedrockInvocationMetrics([]byte(input))
// metrics removed but existing usage preserved
assert.False(t, gjson.GetBytes(result, "amazon-bedrock-invocationMetrics").Exists())
assert.Equal(t, int64(100), gjson.GetBytes(result, "usage.output_tokens").Int())
})
}
func TestExtractEventStreamHeaderValue(t *testing.T) {
// Build a header with :event-type = "chunk" (string type = 7)
buildStringHeader := func(name, value string) []byte {
var buf bytes.Buffer
// name length (1 byte)
_ = buf.WriteByte(byte(len(name)))
// name
_, _ = buf.WriteString(name)
// value type (7 = string)
_ = buf.WriteByte(7)
// value length (2 bytes, big-endian)
_ = binary.Write(&buf, binary.BigEndian, uint16(len(value)))
// value
_, _ = buf.WriteString(value)
return buf.Bytes()
}
t.Run("find string header", func(t *testing.T) {
headers := buildStringHeader(":event-type", "chunk")
assert.Equal(t, "chunk", extractEventStreamHeaderValue(headers, ":event-type"))
})
t.Run("header not found", func(t *testing.T) {
headers := buildStringHeader(":event-type", "chunk")
assert.Equal(t, "", extractEventStreamHeaderValue(headers, ":message-type"))
})
t.Run("multiple headers", func(t *testing.T) {
var buf bytes.Buffer
_, _ = buf.Write(buildStringHeader(":content-type", "application/json"))
_, _ = buf.Write(buildStringHeader(":event-type", "chunk"))
_, _ = buf.Write(buildStringHeader(":message-type", "event"))
headers := buf.Bytes()
assert.Equal(t, "chunk", extractEventStreamHeaderValue(headers, ":event-type"))
assert.Equal(t, "application/json", extractEventStreamHeaderValue(headers, ":content-type"))
assert.Equal(t, "event", extractEventStreamHeaderValue(headers, ":message-type"))
})
t.Run("empty headers", func(t *testing.T) {
assert.Equal(t, "", extractEventStreamHeaderValue([]byte{}, ":event-type"))
})
}
func TestBedrockEventStreamDecoder(t *testing.T) {
crc32IeeeTab := crc32.MakeTable(crc32.IEEE)
// Build a valid EventStream frame with correct CRC32/IEEE checksums.
buildFrame := func(eventType string, payload []byte) []byte {
// Build headers
var headersBuf bytes.Buffer
// :event-type header
_ = headersBuf.WriteByte(byte(len(":event-type")))
_, _ = headersBuf.WriteString(":event-type")
_ = headersBuf.WriteByte(7) // string type
_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len(eventType)))
_, _ = headersBuf.WriteString(eventType)
// :message-type header
_ = headersBuf.WriteByte(byte(len(":message-type")))
_, _ = headersBuf.WriteString(":message-type")
_ = headersBuf.WriteByte(7)
_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("event")))
_, _ = headersBuf.WriteString("event")
headers := headersBuf.Bytes()
headersLen := uint32(len(headers))
// total = 12 (prelude) + headers + payload + 4 (message_crc)
totalLen := uint32(12 + len(headers) + len(payload) + 4)
// Prelude: total_length(4) + headers_length(4)
var preludeBuf bytes.Buffer
_ = binary.Write(&preludeBuf, binary.BigEndian, totalLen)
_ = binary.Write(&preludeBuf, binary.BigEndian, headersLen)
preludeBytes := preludeBuf.Bytes()
preludeCRC := crc32.Checksum(preludeBytes, crc32IeeeTab)
// Build frame: prelude + prelude_crc + headers + payload
var frame bytes.Buffer
_, _ = frame.Write(preludeBytes)
_ = binary.Write(&frame, binary.BigEndian, preludeCRC)
_, _ = frame.Write(headers)
_, _ = frame.Write(payload)
// Message CRC covers everything before itself
messageCRC := crc32.Checksum(frame.Bytes(), crc32IeeeTab)
_ = binary.Write(&frame, binary.BigEndian, messageCRC)
return frame.Bytes()
}
t.Run("decode chunk event", func(t *testing.T) {
payload := []byte(`{"bytes":"dGVzdA=="}`) // base64("test")
frame := buildFrame("chunk", payload)
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
result, err := decoder.Decode()
require.NoError(t, err)
assert.Equal(t, payload, result)
})
t.Run("skip non-chunk events", func(t *testing.T) {
// Write initial-response followed by chunk
var buf bytes.Buffer
_, _ = buf.Write(buildFrame("initial-response", []byte(`{}`)))
chunkPayload := []byte(`{"bytes":"aGVsbG8="}`)
_, _ = buf.Write(buildFrame("chunk", chunkPayload))
decoder := newBedrockEventStreamDecoder(&buf)
result, err := decoder.Decode()
require.NoError(t, err)
assert.Equal(t, chunkPayload, result)
})
t.Run("EOF on empty input", func(t *testing.T) {
decoder := newBedrockEventStreamDecoder(bytes.NewReader(nil))
_, err := decoder.Decode()
assert.Equal(t, io.EOF, err)
})
t.Run("corrupted prelude CRC", func(t *testing.T) {
frame := buildFrame("chunk", []byte(`{"bytes":"dGVzdA=="}`))
// Corrupt the prelude CRC (bytes 8-11)
frame[8] ^= 0xFF
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
_, err := decoder.Decode()
require.Error(t, err)
assert.Contains(t, err.Error(), "prelude CRC mismatch")
})
t.Run("corrupted message CRC", func(t *testing.T) {
frame := buildFrame("chunk", []byte(`{"bytes":"dGVzdA=="}`))
// Corrupt the message CRC (last 4 bytes)
frame[len(frame)-1] ^= 0xFF
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
_, err := decoder.Decode()
require.Error(t, err)
assert.Contains(t, err.Error(), "message CRC mismatch")
})
t.Run("castagnoli encoded frame is rejected", func(t *testing.T) {
castagnoliTab := crc32.MakeTable(crc32.Castagnoli)
payload := []byte(`{"bytes":"dGVzdA=="}`)
var headersBuf bytes.Buffer
_ = headersBuf.WriteByte(byte(len(":event-type")))
_, _ = headersBuf.WriteString(":event-type")
_ = headersBuf.WriteByte(7)
_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("chunk")))
_, _ = headersBuf.WriteString("chunk")
headers := headersBuf.Bytes()
headersLen := uint32(len(headers))
totalLen := uint32(12 + len(headers) + len(payload) + 4)
var preludeBuf bytes.Buffer
_ = binary.Write(&preludeBuf, binary.BigEndian, totalLen)
_ = binary.Write(&preludeBuf, binary.BigEndian, headersLen)
preludeBytes := preludeBuf.Bytes()
var frame bytes.Buffer
_, _ = frame.Write(preludeBytes)
_ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(preludeBytes, castagnoliTab))
_, _ = frame.Write(headers)
_, _ = frame.Write(payload)
_ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(frame.Bytes(), castagnoliTab))
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame.Bytes()))
_, err := decoder.Decode()
require.Error(t, err)
assert.Contains(t, err.Error(), "prelude CRC mismatch")
})
}
func TestBuildBedrockURL(t *testing.T) {
t.Run("stream URL with colon in model ID", func(t *testing.T) {
url := BuildBedrockURL("us-east-1", "us.anthropic.claude-opus-4-5-20251101-v1:0", true)
assert.Equal(t, "https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-opus-4-5-20251101-v1%3A0/invoke-with-response-stream", url)
})
t.Run("non-stream URL with colon in model ID", func(t *testing.T) {
url := BuildBedrockURL("eu-west-1", "eu.anthropic.claude-sonnet-4-5-20250929-v1:0", false)
assert.Equal(t, "https://bedrock-runtime.eu-west-1.amazonaws.com/model/eu.anthropic.claude-sonnet-4-5-20250929-v1%3A0/invoke", url)
})
t.Run("model ID without colon", func(t *testing.T) {
url := BuildBedrockURL("us-east-1", "us.anthropic.claude-sonnet-4-6", true)
assert.Equal(t, "https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-sonnet-4-6/invoke-with-response-stream", url)
})
}

View File

@@ -35,6 +35,7 @@ type DashboardAggregationRepository interface {
UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
CleanupUsageLogs(ctx context.Context, cutoff time.Time) error
CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error
EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error
}
@@ -296,6 +297,7 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays)
dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays)
usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays)
dedupCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageBillingDedupDays)
aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff)
if aggErr != nil {
@@ -305,7 +307,11 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
if usageErr != nil {
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr)
}
if aggErr == nil && usageErr == nil {
dedupErr := s.repo.CleanupUsageBillingDedup(ctx, dedupCutoff)
if dedupErr != nil {
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_billing_dedup 保留清理失败: %v", dedupErr)
}
if aggErr == nil && usageErr == nil && dedupErr == nil {
s.lastRetentionCleanup.Store(now)
}
}

View File

@@ -12,12 +12,18 @@ import (
type dashboardAggregationRepoTestStub struct {
aggregateCalls int
recomputeCalls int
cleanupUsageCalls int
cleanupDedupCalls int
ensurePartitionCalls int
lastStart time.Time
lastEnd time.Time
watermark time.Time
aggregateErr error
cleanupAggregatesErr error
cleanupUsageErr error
cleanupDedupErr error
ensurePartitionErr error
}
func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error {
@@ -28,6 +34,7 @@ func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s
}
func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
s.recomputeCalls++
return s.AggregateRange(ctx, start, end)
}
@@ -44,11 +51,18 @@ func (s *dashboardAggregationRepoTestStub) CleanupAggregates(ctx context.Context
}
func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
s.cleanupUsageCalls++
return s.cleanupUsageErr
}
func (s *dashboardAggregationRepoTestStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
s.cleanupDedupCalls++
return s.cleanupDedupErr
}
func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
return nil
s.ensurePartitionCalls++
return s.ensurePartitionErr
}
func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) {
@@ -90,6 +104,50 @@ func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *te
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
require.Nil(t, svc.lastRetentionCleanup.Load())
require.Equal(t, 1, repo.cleanupUsageCalls)
require.Equal(t, 1, repo.cleanupDedupCalls)
}
func TestDashboardAggregationService_CleanupDedupFailure_DoesNotRecord(t *testing.T) {
repo := &dashboardAggregationRepoTestStub{cleanupDedupErr: errors.New("dedup cleanup failed")}
svc := &DashboardAggregationService{
repo: repo,
cfg: config.DashboardAggregationConfig{
Retention: config.DashboardAggregationRetentionConfig{
UsageLogsDays: 1,
HourlyDays: 1,
DailyDays: 1,
},
},
}
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
require.Nil(t, svc.lastRetentionCleanup.Load())
require.Equal(t, 1, repo.cleanupDedupCalls)
}
func TestDashboardAggregationService_PartitionFailure_DoesNotAggregate(t *testing.T) {
repo := &dashboardAggregationRepoTestStub{ensurePartitionErr: errors.New("partition failed")}
svc := &DashboardAggregationService{
repo: repo,
cfg: config.DashboardAggregationConfig{
Enabled: true,
IntervalSeconds: 60,
LookbackSeconds: 120,
Retention: config.DashboardAggregationRetentionConfig{
UsageLogsDays: 1,
UsageBillingDedupDays: 2,
HourlyDays: 1,
DailyDays: 1,
},
},
}
svc.runScheduledAggregation()
require.Equal(t, 1, repo.ensurePartitionCalls)
require.Equal(t, 1, repo.aggregateCalls)
}
func TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) {

View File

@@ -327,6 +327,14 @@ func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, end
return trend, nil
}
func (s *DashboardService) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
ranking, err := s.usageRepo.GetUserSpendingRanking(ctx, startTime, endTime, limit)
if err != nil {
return nil, fmt.Errorf("get user spending ranking: %w", err)
}
return ranking, nil
}
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime)
if err != nil {

View File

@@ -124,6 +124,10 @@ func (s *dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cut
return nil
}
func (s *dashboardAggregationRepoStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
return nil
}
func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
return nil
}

View File

@@ -29,10 +29,12 @@ const (
// Account type constants
const (
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号inference only scope
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号inference only scope
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名连接 Bedrock
AccountTypeBedrockAPIKey = domain.AccountTypeBedrockAPIKey // AWS Bedrock API Key 类型账号(通过 Bearer Token 连接 Bedrock
)
// Redeem type constants

View File

@@ -136,16 +136,18 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
},
}
svc := &GatewayService{
cfg: &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
cfg := &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
deferredService: &DeferredService{},
billingCacheService: nil,
}
svc := &GatewayService{
cfg: cfg,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
deferredService: &DeferredService{},
billingCacheService: nil,
}
account := &Account{
@@ -221,14 +223,16 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
},
}
svc := &GatewayService{
cfg: &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
cfg := &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
}
svc := &GatewayService{
cfg: cfg,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
}
account := &Account{
@@ -727,6 +731,39 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAf
require.Equal(t, 5, result.usage.OutputTokens)
}
func TestGatewayService_AnthropicAPIKeyPassthrough_MissingTerminalEventReturnsError(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
svc := &GatewayService{
cfg: &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
},
rateLimitService: &RateLimitService{},
}
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
`data: {"type":"message_start","message":{"usage":{"input_tokens":11}}}`,
"",
`data: {"type":"message_delta","usage":{"output_tokens":5}}`,
"",
}, "\n"))),
}
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "claude-3-7-sonnet-20250219")
require.Error(t, err)
require.Contains(t, err.Error(), "missing terminal event")
require.NotNil(t, result)
}
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuccess(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
@@ -1074,7 +1111,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingTimeoutAfterClientDi
_ = pr.Close()
<-done
require.NoError(t, err)
require.Error(t, err)
require.Contains(t, err.Error(), "stream usage incomplete after timeout")
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.Equal(t, 9, result.usage.InputTokens)
@@ -1103,7 +1141,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingContextCanceled(t *t
}
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 3}, time.Now(), "claude-3-7-sonnet-20250219")
require.NoError(t, err)
require.Error(t, err)
require.Contains(t, err.Error(), "stream usage incomplete")
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
}
@@ -1133,7 +1172,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingUpstreamReadErrorAft
}
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 4}, time.Now(), "claude-3-7-sonnet-20250219")
require.NoError(t, err)
require.Error(t, err)
require.Contains(t, err.Error(), "stream usage incomplete after disconnect")
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.Equal(t, 8, result.usage.InputTokens)

View File

@@ -0,0 +1,371 @@
//go:build unit
package service
import (
"context"
"errors"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService {
cfg := &config.Config{}
cfg.Default.RateMultiplier = 1.1
return NewGatewayService(
nil,
nil,
usageRepo,
nil,
userRepo,
subRepo,
nil,
nil,
cfg,
nil,
nil,
NewBillingService(cfg, nil),
nil,
&BillingCacheService{},
nil,
nil,
&DeferredService{},
nil,
nil,
nil,
nil,
nil,
)
}
func newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo UsageLogRepository, billingRepo UsageBillingRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService {
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
svc.usageBillingRepo = billingRepo
return svc
}
type openAIRecordUsageBestEffortLogRepoStub struct {
UsageLogRepository
bestEffortErr error
createErr error
bestEffortCalls int
createCalls int
lastLog *UsageLog
lastCtxErr error
}
func (s *openAIRecordUsageBestEffortLogRepoStub) CreateBestEffort(ctx context.Context, log *UsageLog) error {
s.bestEffortCalls++
s.lastLog = log
s.lastCtxErr = ctx.Err()
return s.bestEffortErr
}
func (s *openAIRecordUsageBestEffortLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) {
s.createCalls++
s.lastLog = log
s.lastCtxErr = ctx.Err()
return false, s.createErr
}
func TestGatewayServiceRecordUsage_BillingUsesDetachedContext(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
reqCtx, cancel := context.WithCancel(context.Background())
cancel()
err := svc.RecordUsage(reqCtx, &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_detached_ctx",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 501,
Quota: 100,
},
User: &User{ID: 601},
Account: &Account{ID: 701},
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 1, userRepo.deductCalls)
require.NoError(t, userRepo.lastCtxErr)
require.Equal(t, 1, quotaSvc.quotaCalls)
require.NoError(t, quotaSvc.lastQuotaCtxErr)
}
func TestGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
payloadHash := HashUsageRequestPayload([]byte(`{"messages":[{"role":"user","content":"hello"}]}`))
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_payload_hash",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 501, Quota: 100},
User: &User{ID: 601},
Account: &Account{ID: 701},
RequestPayloadHash: payloadHash,
})
require.NoError(t, err)
require.NotNil(t, billingRepo.lastCmd)
require.Equal(t, payloadHash, billingRepo.lastCmd.RequestPayloadHash)
}
func TestGatewayServiceRecordUsage_BillingFingerprintFallsBackToContextRequestID(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
ctx := context.WithValue(context.Background(), ctxkey.RequestID, "req-local-123")
err := svc.RecordUsage(ctx, &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_payload_fallback",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 501, Quota: 100},
User: &User{ID: 601},
Account: &Account{ID: 701},
})
require.NoError(t, err)
require.NotNil(t, billingRepo.lastCmd)
require.Equal(t, "local:req-local-123", billingRepo.lastCmd.RequestPayloadHash)
}
func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_not_persisted",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 503,
Quota: 100,
},
User: &User{ID: 603},
Account: &Account{ID: 703},
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 1, userRepo.deductCalls)
require.Equal(t, 1, quotaSvc.quotaCalls)
}
func TestGatewayServiceRecordUsageWithLongContext_BillingUsesDetachedContext(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
reqCtx, cancel := context.WithCancel(context.Background())
cancel()
err := svc.RecordUsageWithLongContext(reqCtx, &RecordUsageLongContextInput{
Result: &ForwardResult{
RequestID: "gateway_long_context_detached_ctx",
Usage: ClaudeUsage{
InputTokens: 12,
OutputTokens: 8,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 502,
Quota: 100,
},
User: &User{ID: 602},
Account: &Account{ID: 702},
LongContextThreshold: 200000,
LongContextMultiplier: 2,
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 1, userRepo.deductCalls)
require.NoError(t, userRepo.lastCtxErr)
require.Equal(t, 1, quotaSvc.quotaCalls)
require.NoError(t, quotaSvc.lastQuotaCtxErr)
}
func TestGatewayServiceRecordUsage_UsesFallbackRequestIDForUsageLog(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
ctx := context.WithValue(context.Background(), ctxkey.RequestID, "gateway-local-fallback")
err := svc.RecordUsage(ctx, &RecordUsageInput{
Result: &ForwardResult{
RequestID: "",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 504},
User: &User{ID: 604},
Account: &Account{ID: 704},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "local:gateway-local-fallback", usageRepo.lastLog.RequestID)
}
func TestGatewayServiceRecordUsage_PrefersClientRequestIDOverUpstreamRequestID(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
ctx := context.WithValue(context.Background(), ctxkey.ClientRequestID, "client-stable-123")
ctx = context.WithValue(ctx, ctxkey.RequestID, "req-local-ignored")
err := svc.RecordUsage(ctx, &RecordUsageInput{
Result: &ForwardResult{
RequestID: "upstream-volatile-456",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 506},
User: &User{ID: 606},
Account: &Account{ID: 706},
})
require.NoError(t, err)
require.NotNil(t, billingRepo.lastCmd)
require.Equal(t, "client:client-stable-123", billingRepo.lastCmd.RequestID)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "client:client-stable-123", usageRepo.lastLog.RequestID)
}
func TestGatewayServiceRecordUsage_GeneratesRequestIDWhenAllSourcesMissing(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 507},
User: &User{ID: 607},
Account: &Account{ID: 707},
})
require.NoError(t, err)
require.NotNil(t, billingRepo.lastCmd)
require.True(t, strings.HasPrefix(billingRepo.lastCmd.RequestID, "generated:"))
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, billingRepo.lastCmd.RequestID, usageRepo.lastLog.RequestID)
}
func TestGatewayServiceRecordUsage_DroppedUsageLogDoesNotSyncFallback(t *testing.T) {
usageRepo := &openAIRecordUsageBestEffortLogRepoStub{
bestEffortErr: MarkUsageLogCreateDropped(errors.New("usage log best-effort queue full")),
}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_drop_usage_log",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 508},
User: &User{ID: 608},
Account: &Account{ID: 708},
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.bestEffortCalls)
require.Equal(t, 0, usageRepo.createCalls)
}
func TestGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{err: context.DeadlineExceeded}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo)
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_billing_fail",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 505},
User: &User{ID: 605},
Account: &Account{ID: 705},
})
require.Error(t, err)
require.Equal(t, 1, billingRepo.calls)
require.Equal(t, 0, usageRepo.calls)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,267 @@
package service
import (
"context"
"encoding/json"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
)
type betaPolicySettingRepoStub struct {
values map[string]string
}
func (s *betaPolicySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
panic("unexpected Get call")
}
func (s *betaPolicySettingRepoStub) GetValue(ctx context.Context, key string) (string, error) {
if v, ok := s.values[key]; ok {
return v, nil
}
return "", ErrSettingNotFound
}
func (s *betaPolicySettingRepoStub) Set(ctx context.Context, key, value string) error {
panic("unexpected Set call")
}
func (s *betaPolicySettingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
panic("unexpected GetMultiple call")
}
func (s *betaPolicySettingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
panic("unexpected SetMultiple call")
}
func (s *betaPolicySettingRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *betaPolicySettingRepoStub) Delete(ctx context.Context, key string) error {
panic("unexpected Delete call")
}
func TestResolveBedrockBetaTokensForRequest_BlocksOnOriginalAnthropicToken(t *testing.T) {
settings := &BetaPolicySettings{
Rules: []BetaPolicyRule{
{
BetaToken: "advanced-tool-use-2025-11-20",
Action: BetaPolicyActionBlock,
Scope: BetaPolicyScopeAll,
ErrorMessage: "advanced tool use is blocked",
},
},
}
raw, err := json.Marshal(settings)
if err != nil {
t.Fatalf("marshal settings: %v", err)
}
svc := &GatewayService{
settingService: NewSettingService(
&betaPolicySettingRepoStub{values: map[string]string{
SettingKeyBetaPolicySettings: string(raw),
}},
&config.Config{},
),
}
account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock}
_, err = svc.resolveBedrockBetaTokensForRequest(
context.Background(),
account,
"advanced-tool-use-2025-11-20",
[]byte(`{"messages":[{"role":"user","content":"hi"}]}`),
"us.anthropic.claude-opus-4-6-v1",
)
if err == nil {
t.Fatal("expected raw advanced-tool-use token to be blocked before Bedrock transform")
}
if err.Error() != "advanced tool use is blocked" {
t.Fatalf("unexpected error: %v", err)
}
}
func TestResolveBedrockBetaTokensForRequest_FiltersAfterBedrockTransform(t *testing.T) {
settings := &BetaPolicySettings{
Rules: []BetaPolicyRule{
{
BetaToken: "tool-search-tool-2025-10-19",
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeAll,
},
},
}
raw, err := json.Marshal(settings)
if err != nil {
t.Fatalf("marshal settings: %v", err)
}
svc := &GatewayService{
settingService: NewSettingService(
&betaPolicySettingRepoStub{values: map[string]string{
SettingKeyBetaPolicySettings: string(raw),
}},
&config.Config{},
),
}
account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock}
betaTokens, err := svc.resolveBedrockBetaTokensForRequest(
context.Background(),
account,
"advanced-tool-use-2025-11-20",
[]byte(`{"messages":[{"role":"user","content":"hi"}]}`),
"us.anthropic.claude-opus-4-6-v1",
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
for _, token := range betaTokens {
if token == "tool-search-tool-2025-10-19" {
t.Fatalf("expected transformed Bedrock token to be filtered")
}
}
}
// TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedThinking 验证:
// 管理员 block 了 interleaved-thinking客户端不在 header 中带该 token
// 但请求体包含 thinking 字段 → 自动注入后应被 block。
func TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedThinking(t *testing.T) {
settings := &BetaPolicySettings{
Rules: []BetaPolicyRule{
{
BetaToken: "interleaved-thinking-2025-05-14",
Action: BetaPolicyActionBlock,
Scope: BetaPolicyScopeAll,
ErrorMessage: "thinking is blocked",
},
},
}
raw, err := json.Marshal(settings)
if err != nil {
t.Fatalf("marshal settings: %v", err)
}
svc := &GatewayService{
settingService: NewSettingService(
&betaPolicySettingRepoStub{values: map[string]string{
SettingKeyBetaPolicySettings: string(raw),
}},
&config.Config{},
),
}
account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock}
// header 中不带 beta token但 body 中有 thinking 字段
_, err = svc.resolveBedrockBetaTokensForRequest(
context.Background(),
account,
"", // 空 header
[]byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`),
"us.anthropic.claude-opus-4-6-v1",
)
if err == nil {
t.Fatal("expected body-injected interleaved-thinking to be blocked")
}
if err.Error() != "thinking is blocked" {
t.Fatalf("unexpected error: %v", err)
}
}
// TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedToolSearch 验证:
// 管理员 block 了 tool-search-tool客户端不在 header 中带 beta token
// 但请求体包含 tool search 工具 → 自动注入后应被 block。
func TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedToolSearch(t *testing.T) {
settings := &BetaPolicySettings{
Rules: []BetaPolicyRule{
{
BetaToken: "tool-search-tool-2025-10-19",
Action: BetaPolicyActionBlock,
Scope: BetaPolicyScopeAll,
ErrorMessage: "tool search is blocked",
},
},
}
raw, err := json.Marshal(settings)
if err != nil {
t.Fatalf("marshal settings: %v", err)
}
svc := &GatewayService{
settingService: NewSettingService(
&betaPolicySettingRepoStub{values: map[string]string{
SettingKeyBetaPolicySettings: string(raw),
}},
&config.Config{},
),
}
account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock}
// header 中不带 beta token但 body 中有 tool_search_tool 工具
_, err = svc.resolveBedrockBetaTokensForRequest(
context.Background(),
account,
"",
[]byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`),
"us.anthropic.claude-sonnet-4-6",
)
if err == nil {
t.Fatal("expected body-injected tool-search-tool to be blocked")
}
if err.Error() != "tool search is blocked" {
t.Fatalf("unexpected error: %v", err)
}
}
// TestResolveBedrockBetaTokensForRequest_PassesWhenNoBlockRuleMatches 验证:
// body 自动注入的 token 如果没有对应的 block 规则,应正常通过。
func TestResolveBedrockBetaTokensForRequest_PassesWhenNoBlockRuleMatches(t *testing.T) {
settings := &BetaPolicySettings{
Rules: []BetaPolicyRule{
{
BetaToken: "computer-use-2025-11-24",
Action: BetaPolicyActionBlock,
Scope: BetaPolicyScopeAll,
ErrorMessage: "computer use is blocked",
},
},
}
raw, err := json.Marshal(settings)
if err != nil {
t.Fatalf("marshal settings: %v", err)
}
svc := &GatewayService{
settingService: NewSettingService(
&betaPolicySettingRepoStub{values: map[string]string{
SettingKeyBetaPolicySettings: string(raw),
}},
&config.Config{},
),
}
account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock}
// body 中有 thinking会注入 interleaved-thinking但 block 规则只针对 computer-use
tokens, err := svc.resolveBedrockBetaTokensForRequest(
context.Background(),
account,
"",
[]byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`),
"us.anthropic.claude-opus-4-6-v1",
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
found := false
for _, token := range tokens {
if token == "interleaved-thinking-2025-05-14" {
found = true
}
}
if !found {
t.Fatal("expected interleaved-thinking token to be present")
}
}

View File

@@ -0,0 +1,48 @@
package service
import "testing"
func TestGatewayServiceIsModelSupportedByAccount_BedrockDefaultMappingRestrictsModels(t *testing.T) {
svc := &GatewayService{}
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "us-east-1",
},
}
if !svc.isModelSupportedByAccount(account, "claude-sonnet-4-5") {
t.Fatalf("expected default Bedrock alias to be supported")
}
if svc.isModelSupportedByAccount(account, "claude-3-5-sonnet-20241022") {
t.Fatalf("expected unsupported alias to be rejected for Bedrock account")
}
}
func TestGatewayServiceIsModelSupportedByAccount_BedrockCustomMappingStillActsAsAllowlist(t *testing.T) {
svc := &GatewayService{}
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "eu-west-1",
"model_mapping": map[string]any{
"claude-sonnet-*": "claude-sonnet-4-6",
},
},
}
if !svc.isModelSupportedByAccount(account, "claude-sonnet-4-6") {
t.Fatalf("expected matched custom mapping to be supported")
}
if !svc.isModelSupportedByAccount(account, "claude-opus-4-6") {
t.Fatalf("expected default Bedrock alias fallback to remain supported")
}
if svc.isModelSupportedByAccount(account, "claude-3-5-sonnet-20241022") {
t.Fatalf("expected unsupported model to still be rejected")
}
}

View File

@@ -181,7 +181,8 @@ func TestHandleStreamingResponse_EmptyStream(t *testing.T) {
result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
_ = pr.Close()
require.NoError(t, err)
require.Error(t, err)
require.Contains(t, err.Error(), "missing terminal event")
require.NotNil(t, result)
}

View File

@@ -129,6 +129,41 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
}
}
// 兼容遗留的 functions 和 function_call转换为 tools 和 tool_choice
if functionsRaw, ok := reqBody["functions"]; ok {
if functions, k := functionsRaw.([]any); k {
tools := make([]any, 0, len(functions))
for _, f := range functions {
tools = append(tools, map[string]any{
"type": "function",
"function": f,
})
}
reqBody["tools"] = tools
}
delete(reqBody, "functions")
result.Modified = true
}
if fcRaw, ok := reqBody["function_call"]; ok {
if fcStr, ok := fcRaw.(string); ok {
// e.g. "auto", "none"
reqBody["tool_choice"] = fcStr
} else if fcObj, ok := fcRaw.(map[string]any); ok {
// e.g. {"name": "my_func"}
if name, ok := fcObj["name"].(string); ok && strings.TrimSpace(name) != "" {
reqBody["tool_choice"] = map[string]any{
"type": "function",
"function": map[string]any{
"name": name,
},
}
}
}
delete(reqBody, "function_call")
result.Modified = true
}
if normalizeCodexTools(reqBody) {
result.Modified = true
}
@@ -303,6 +338,18 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
continue
}
typ, _ := m["type"].(string)
// 修复 OpenAI 上游的最新校验:"Expected an ID that begins with 'fc'"
fixIDPrefix := func(id string) string {
if id == "" || strings.HasPrefix(id, "fc") {
return id
}
if strings.HasPrefix(id, "call_") {
return "fc" + strings.TrimPrefix(id, "call_")
}
return "fc_" + id
}
if typ == "item_reference" {
if !preserveReferences {
continue
@@ -311,6 +358,9 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
for key, value := range m {
newItem[key] = value
}
if id, ok := newItem["id"].(string); ok && id != "" {
newItem["id"] = fixIDPrefix(id)
}
filtered = append(filtered, newItem)
continue
}
@@ -330,10 +380,20 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
}
if isCodexToolCallItemType(typ) {
if callID, ok := m["call_id"].(string); !ok || strings.TrimSpace(callID) == "" {
callID, ok := m["call_id"].(string)
if !ok || strings.TrimSpace(callID) == "" {
if id, ok := m["id"].(string); ok && strings.TrimSpace(id) != "" {
callID = id
ensureCopy()
newItem["call_id"] = id
newItem["call_id"] = callID
}
}
if callID != "" {
fixedCallID := fixIDPrefix(callID)
if fixedCallID != callID {
ensureCopy()
newItem["call_id"] = fixedCallID
}
}
}
@@ -344,6 +404,14 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
if !isCodexToolCallItemType(typ) {
delete(newItem, "call_id")
}
} else {
if id, ok := newItem["id"].(string); ok && id != "" {
fixedID := fixIDPrefix(id)
if fixedID != id {
ensureCopy()
newItem["id"] = fixedID
}
}
}
filtered = append(filtered, newItem)

View File

@@ -33,12 +33,12 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
first, ok := input[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "item_reference", first["type"])
require.Equal(t, "ref1", first["id"])
require.Equal(t, "fc_ref1", first["id"])
// 校验 input[1] 为 map确保后续字段断言安全。
second, ok := input[1].(map[string]any)
require.True(t, ok)
require.Equal(t, "o1", second["id"])
require.Equal(t, "fc_o1", second["id"])
}
func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {

View File

@@ -3,39 +3,68 @@ package service
import (
"context"
"errors"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
type openAIRecordUsageLogRepoStub struct {
UsageLogRepository
inserted bool
err error
calls int
lastLog *UsageLog
inserted bool
err error
calls int
lastLog *UsageLog
lastCtxErr error
}
func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) {
s.calls++
s.lastLog = log
s.lastCtxErr = ctx.Err()
return s.inserted, s.err
}
type openAIRecordUsageBillingRepoStub struct {
UsageBillingRepository
result *UsageBillingApplyResult
err error
calls int
lastCmd *UsageBillingCommand
lastCtxErr error
}
func (s *openAIRecordUsageBillingRepoStub) Apply(ctx context.Context, cmd *UsageBillingCommand) (*UsageBillingApplyResult, error) {
s.calls++
s.lastCmd = cmd
s.lastCtxErr = ctx.Err()
if s.err != nil {
return nil, s.err
}
if s.result != nil {
return s.result, nil
}
return &UsageBillingApplyResult{Applied: true}, nil
}
type openAIRecordUsageUserRepoStub struct {
UserRepository
deductCalls int
deductErr error
lastAmount float64
lastCtxErr error
}
func (s *openAIRecordUsageUserRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error {
s.deductCalls++
s.lastAmount = amount
s.lastCtxErr = ctx.Err()
return s.deductErr
}
@@ -44,29 +73,35 @@ type openAIRecordUsageSubRepoStub struct {
incrementCalls int
incrementErr error
lastCtxErr error
}
func (s *openAIRecordUsageSubRepoStub) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
s.incrementCalls++
s.lastCtxErr = ctx.Err()
return s.incrementErr
}
type openAIRecordUsageAPIKeyQuotaStub struct {
quotaCalls int
rateLimitCalls int
err error
lastAmount float64
quotaCalls int
rateLimitCalls int
err error
lastAmount float64
lastQuotaCtxErr error
lastRateLimitCtxErr error
}
func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error {
s.quotaCalls++
s.lastAmount = cost
s.lastQuotaCtxErr = ctx.Err()
return s.err
}
func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error {
s.rateLimitCalls++
s.lastAmount = cost
s.lastRateLimitCtxErr = ctx.Err()
return s.err
}
@@ -93,23 +128,38 @@ func i64p(v int64) *int64 {
func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService {
cfg := &config.Config{}
cfg.Default.RateMultiplier = 1.1
svc := NewOpenAIGatewayService(
nil,
usageRepo,
nil,
userRepo,
subRepo,
rateRepo,
nil,
cfg,
nil,
nil,
NewBillingService(cfg, nil),
nil,
&BillingCacheService{},
nil,
&DeferredService{},
nil,
)
svc.userGroupRateResolver = newUserGroupRateResolver(
rateRepo,
nil,
resolveUserGroupRateCacheTTL(cfg),
nil,
"service.openai_gateway.test",
)
return svc
}
return &OpenAIGatewayService{
usageLogRepo: usageRepo,
userRepo: userRepo,
userSubRepo: subRepo,
cfg: cfg,
billingService: NewBillingService(cfg, nil),
billingCacheService: &BillingCacheService{},
deferredService: &DeferredService{},
userGroupRateResolver: newUserGroupRateResolver(
rateRepo,
nil,
resolveUserGroupRateCacheTTL(cfg),
nil,
"service.openai_gateway.test",
),
}
func newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo UsageLogRepository, billingRepo UsageBillingRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService {
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo)
svc.usageBillingRepo = billingRepo
return svc
}
func expectedOpenAICost(t *testing.T, svc *OpenAIGatewayService, model string, usage OpenAIUsage, multiplier float64) *CostBreakdown {
@@ -252,9 +302,10 @@ func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateWhenResolver
func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: false}}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
@@ -272,11 +323,313 @@ func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testin
})
require.NoError(t, err)
require.Equal(t, 1, billingRepo.calls)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 0, userRepo.deductCalls)
require.Equal(t, 0, subRepo.incrementCalls)
}
func TestOpenAIGatewayServiceRecordUsage_DuplicateBillingKeySkipsBillingWithRepo(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: false}}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_duplicate_billing_key",
Usage: OpenAIUsage{
InputTokens: 8,
OutputTokens: 4,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 10045,
Quota: 100,
},
User: &User{ID: 20045},
Account: &Account{ID: 30045},
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, billingRepo.calls)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 0, userRepo.deductCalls)
require.Equal(t, 0, subRepo.incrementCalls)
require.Equal(t, 0, quotaSvc.quotaCalls)
}
func TestOpenAIGatewayServiceRecordUsage_BillsWhenUsageLogCreateReturnsError(t *testing.T) {
usage := OpenAIUsage{InputTokens: 8, OutputTokens: 4}
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: errors.New("usage log batch state uncertain")}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_usage_log_error",
Usage: usage,
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{ID: 10041},
User: &User{ID: 20041},
Account: &Account{ID: 30041},
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 1, userRepo.deductCalls)
require.Equal(t, 0, subRepo.incrementCalls)
}
func TestOpenAIGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_not_persisted",
Usage: OpenAIUsage{
InputTokens: 8,
OutputTokens: 4,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 10043,
Quota: 100,
},
User: &User{ID: 20043},
Account: &Account{ID: 30043},
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 1, userRepo.deductCalls)
require.Equal(t, 0, subRepo.incrementCalls)
require.Equal(t, 1, quotaSvc.quotaCalls)
}
func TestOpenAIGatewayServiceRecordUsage_BillingUsesDetachedContext(t *testing.T) {
usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2}
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
reqCtx, cancel := context.WithCancel(context.Background())
cancel()
err := svc.RecordUsage(reqCtx, &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_detached_billing_ctx",
Usage: usage,
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 10042,
Quota: 100,
},
User: &User{ID: 20042},
Account: &Account{ID: 30042},
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, userRepo.deductCalls)
require.NoError(t, userRepo.lastCtxErr)
require.Equal(t, 1, quotaSvc.quotaCalls)
require.NoError(t, quotaSvc.lastQuotaCtxErr)
}
func TestOpenAIGatewayServiceRecordUsage_BillingRepoUsesDetachedContext(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
reqCtx, cancel := context.WithCancel(context.Background())
cancel()
err := svc.RecordUsage(reqCtx, &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_detached_billing_repo_ctx",
Usage: OpenAIUsage{
InputTokens: 8,
OutputTokens: 4,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{ID: 10046},
User: &User{ID: 20046},
Account: &Account{ID: 30046},
})
require.NoError(t, err)
require.Equal(t, 1, billingRepo.calls)
require.NoError(t, billingRepo.lastCtxErr)
require.Equal(t, 1, usageRepo.calls)
require.NoError(t, usageRepo.lastCtxErr)
}
func TestOpenAIGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)
payloadHash := HashUsageRequestPayload([]byte(`{"model":"gpt-5","input":"hello"}`))
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "openai_payload_hash",
Usage: OpenAIUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "gpt-5",
Duration: time.Second,
},
APIKey: &APIKey{ID: 501, Quota: 100},
User: &User{ID: 601},
Account: &Account{ID: 701},
RequestPayloadHash: payloadHash,
})
require.NoError(t, err)
require.NotNil(t, billingRepo.lastCmd)
require.Equal(t, payloadHash, billingRepo.lastCmd.RequestPayloadHash)
}
func TestOpenAIGatewayServiceRecordUsage_UsesFallbackRequestIDForBillingAndUsageLog(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
ctx := context.WithValue(context.Background(), ctxkey.RequestID, "req-local-fallback")
err := svc.RecordUsage(ctx, &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "",
Usage: OpenAIUsage{
InputTokens: 8,
OutputTokens: 4,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{ID: 10047},
User: &User{ID: 20047},
Account: &Account{ID: 30047},
})
require.NoError(t, err)
require.NotNil(t, billingRepo.lastCmd)
require.Equal(t, "local:req-local-fallback", billingRepo.lastCmd.RequestID)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "local:req-local-fallback", usageRepo.lastLog.RequestID)
}
func TestOpenAIGatewayServiceRecordUsage_PrefersClientRequestIDOverUpstreamRequestID(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
ctx := context.WithValue(context.Background(), ctxkey.ClientRequestID, "openai-client-stable-123")
err := svc.RecordUsage(ctx, &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "upstream-openai-volatile-456",
Usage: OpenAIUsage{
InputTokens: 8,
OutputTokens: 4,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{ID: 10049},
User: &User{ID: 20049},
Account: &Account{ID: 30049},
})
require.NoError(t, err)
require.NotNil(t, billingRepo.lastCmd)
require.Equal(t, "client:openai-client-stable-123", billingRepo.lastCmd.RequestID)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "client:openai-client-stable-123", usageRepo.lastLog.RequestID)
}
func TestOpenAIGatewayServiceRecordUsage_GeneratesRequestIDWhenAllSourcesMissing(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "",
Usage: OpenAIUsage{
InputTokens: 8,
OutputTokens: 4,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{ID: 10050},
User: &User{ID: 20050},
Account: &Account{ID: 30050},
})
require.NoError(t, err)
require.NotNil(t, billingRepo.lastCmd)
require.True(t, strings.HasPrefix(billingRepo.lastCmd.RequestID, "generated:"))
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, billingRepo.lastCmd.RequestID, usageRepo.lastLog.RequestID)
}
func TestOpenAIGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{err: errors.New("billing tx failed")}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_billing_fail",
Usage: OpenAIUsage{
InputTokens: 8,
OutputTokens: 4,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{ID: 10048},
User: &User{ID: 20048},
Account: &Account{ID: 30048},
})
require.Error(t, err)
require.Equal(t, 1, billingRepo.calls)
require.Equal(t, 0, usageRepo.calls)
}
func TestOpenAIGatewayServiceRecordUsage_UpdatesAPIKeyQuotaWhenConfigured(t *testing.T) {
usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2}
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}

View File

@@ -301,6 +301,7 @@ var defaultOpenAICodexSnapshotPersistThrottle = newAccountWriteThrottle(openAICo
type OpenAIGatewayService struct {
accountRepo AccountRepository
usageLogRepo UsageLogRepository
usageBillingRepo UsageBillingRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
cache GatewayCache
@@ -338,6 +339,7 @@ type OpenAIGatewayService struct {
func NewOpenAIGatewayService(
accountRepo AccountRepository,
usageLogRepo UsageLogRepository,
usageBillingRepo UsageBillingRepository,
userRepo UserRepository,
userSubRepo UserSubscriptionRepository,
userGroupRateRepo UserGroupRateRepository,
@@ -355,6 +357,7 @@ func NewOpenAIGatewayService(
svc := &OpenAIGatewayService{
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
usageBillingRepo: usageBillingRepo,
userRepo: userRepo,
userSubRepo: userSubRepo,
cache: cache,
@@ -2119,7 +2122,9 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
httpInvalidEncryptedContentRetryTried := false
for {
// Build upstream request
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
releaseUpstreamCtx()
if err != nil {
return nil, err
}
@@ -2326,7 +2331,9 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
return nil, err
}
upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(ctx, c, account, body, token)
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token)
releaseUpstreamCtx()
if err != nil {
return nil, err
}
@@ -2663,6 +2670,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
var firstTokenMs *int
clientDisconnected := false
sawDone := false
sawTerminalEvent := false
upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id"))
scanner := bufio.NewScanner(resp.Body)
@@ -2682,6 +2690,9 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
if trimmedData == "[DONE]" {
sawDone = true
}
if openAIStreamEventIsTerminal(trimmedData) {
sawTerminalEvent = true
}
if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
@@ -2699,19 +2710,14 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
}
}
if err := scanner.Err(); err != nil {
if clientDisconnected {
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, err)
if sawTerminalEvent {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if clientDisconnected {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err)
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
logger.LegacyPrintf("service.openai_gateway",
"[OpenAI passthrough] 流读取被取消,可能发生断流: account=%d request_id=%s err=%v ctx_err=%v",
account.ID,
upstreamRequestID,
err,
ctx.Err(),
)
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err)
}
if errors.Is(err, bufio.ErrTooLong) {
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
@@ -2725,12 +2731,13 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
)
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
}
if !clientDisconnected && !sawDone && ctx.Err() == nil {
if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil {
logger.FromContext(ctx).With(
zap.String("component", "service.openai_gateway"),
zap.Int64("account_id", account.ID),
zap.String("upstream_request_id", upstreamRequestID),
).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流")
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event")
}
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
@@ -3264,6 +3271,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
// 否则下游 SDK例如 OpenCode会因为类型校验失败而报错。
errorEventSent := false
clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage
sawTerminalEvent := false
sendErrorEvent := func(reason string) {
if errorEventSent || clientDisconnected {
return
@@ -3294,22 +3302,27 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage")
}
}
if !sawTerminalEvent {
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
}
return resultWithUsage(), nil
}
handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) {
if scanErr == nil {
return nil, nil, false
}
if sawTerminalEvent {
logger.LegacyPrintf("service.openai_gateway", "Upstream scan ended after terminal event: %v", scanErr)
return resultWithUsage(), nil, true
}
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event避免下游 SDK 解析失败。
if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) {
logger.LegacyPrintf("service.openai_gateway", "Context canceled during streaming, returning collected usage")
return resultWithUsage(), nil, true
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", scanErr), true
}
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
if clientDisconnected {
logger.LegacyPrintf("service.openai_gateway", "Upstream read error after client disconnect: %v, returning collected usage", scanErr)
return resultWithUsage(), nil, true
return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true
}
if errors.Is(scanErr, bufio.ErrTooLong) {
logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr)
@@ -3332,6 +3345,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
}
dataBytes := []byte(data)
if openAIStreamEventIsTerminal(data) {
sawTerminalEvent = true
}
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected {
@@ -3448,8 +3464,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
continue
}
if clientDisconnected {
logger.LegacyPrintf("service.openai_gateway", "Upstream timeout after client disconnect, returning collected usage")
return resultWithUsage(), nil
return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout")
}
logger.LegacyPrintf("service.openai_gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
// 处理流超时,可能标记账户为临时不可调度或错误状态
@@ -3547,11 +3562,12 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag
if usage == nil || len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) {
return
}
// 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。
if len(data) < 80 || !bytes.Contains(data, []byte(`"response.completed"`)) {
// 选择性解析:仅在数据中包含终止事件标识时才进入字段提取。
if len(data) < 72 {
return
}
if gjson.GetBytes(data, "type").String() != "response.completed" {
eventType := gjson.GetBytes(data, "type").String()
if eventType != "response.completed" && eventType != "response.done" {
return
}
@@ -4007,14 +4023,15 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
// OpenAIRecordUsageInput input for recording usage
type OpenAIRecordUsageInput struct {
Result *OpenAIForwardResult
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription
UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址
APIKeyService APIKeyQuotaUpdater
Result *OpenAIForwardResult
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription
UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址
RequestPayloadHash string
APIKeyService APIKeyQuotaUpdater
}
// RecordUsage records usage and deducts balance
@@ -4080,11 +4097,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
// Create usage log
durationMs := int(result.Duration.Milliseconds())
accountRateMultiplier := account.BillingRateMultiplier()
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
usageLog := &UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
RequestID: requestID,
Model: billingModel,
ServiceTier: result.ServiceTier,
ReasoningEffort: result.ReasoningEffort,
@@ -4125,29 +4143,32 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
usageLog.SubscriptionID = &subscription.ID
}
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway")
logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
shouldBill := inserted || err != nil
if shouldBill {
postUsageBilling(ctx, &postUsageBillingParams{
billingErr := func() error {
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
Cost: cost,
User: user,
APIKey: apiKey,
Account: account,
Subscription: subscription,
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
IsSubscriptionBill: isSubscriptionBilling,
AccountRateMultiplier: accountRateMultiplier,
APIKeyService: input.APIKeyService,
}, s.billingDeps())
} else {
s.deferredService.ScheduleLastUsedUpdate(account.ID)
}, s.billingDeps(), s.usageBillingRepo)
return err
}()
if billingErr != nil {
return billingErr
}
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway")
return nil
}

View File

@@ -916,7 +916,7 @@ func TestOpenAIStreamingTimeout(t *testing.T) {
}
}
func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) {
func TestOpenAIStreamingContextCanceledReturnsIncompleteErrorWithoutInjectingErrorEvent(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
@@ -940,8 +940,8 @@ func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) {
}
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
if err != nil {
t.Fatalf("expected nil error, got %v", err)
if err == nil || !strings.Contains(err.Error(), "stream usage incomplete") {
t.Fatalf("expected incomplete stream error, got %v", err)
}
if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "stream_read_error") {
t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
@@ -993,6 +993,107 @@ func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
}
}
func TestOpenAIStreamingMissingTerminalEventReturnsIncompleteError(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: pr,
Header: http.Header{},
}
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
}()
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
_ = pr.Close()
if err == nil || !strings.Contains(err.Error(), "missing terminal event") {
t.Fatalf("expected missing terminal event error, got %v", err)
}
}
func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: pr,
Header: http.Header{},
}
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
}()
_, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now())
_ = pr.Close()
if err == nil || !strings.Contains(err.Error(), "missing terminal event") {
t.Fatalf("expected missing terminal event error, got %v", err)
}
}
func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: pr,
Header: http.Header{},
}
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.done\",\"response\":{\"usage\":{\"input_tokens\":2,\"output_tokens\":3,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
}()
result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now())
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.usage)
require.Equal(t, 2, result.usage.InputTokens)
require.Equal(t, 3, result.usage.OutputTokens)
require.Equal(t, 1, result.usage.CacheReadInputTokens)
}
func TestOpenAIStreamingTooLong(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
@@ -1124,7 +1225,7 @@ func TestOpenAIStreamingHeadersOverride(t *testing.T) {
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {}\n\n"))
_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{}}\n\n"))
}()
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
@@ -1674,6 +1775,12 @@ func TestParseSSEUsage_SelectiveParsing(t *testing.T) {
require.Equal(t, 3, usage.InputTokens)
require.Equal(t, 5, usage.OutputTokens)
require.Equal(t, 2, usage.CacheReadInputTokens)
// done 事件同样可能携带最终 usage
svc.parseSSEUsage(`{"type":"response.done","response":{"usage":{"input_tokens":13,"output_tokens":15,"input_tokens_details":{"cached_tokens":4}}}}`, usage)
require.Equal(t, 13, usage.InputTokens)
require.Equal(t, 15, usage.OutputTokens)
require.Equal(t, 4, usage.CacheReadInputTokens)
}
func TestExtractCodexFinalResponse_SampleReplay(t *testing.T) {

View File

@@ -439,7 +439,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_ResponseHeadersAllowXCodex(t *tes
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`)
originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`)
headers := make(http.Header)
headers.Set("Content-Type", "application/json")
@@ -453,7 +453,14 @@ func TestOpenAIGatewayService_OAuthPassthrough_ResponseHeadersAllowXCodex(t *tes
resp := &http.Response{
StatusCode: http.StatusOK,
Header: headers,
Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)),
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
`data: {"type":"response.output_text.delta","delta":"h"}`,
"",
`data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}}`,
"",
"data: [DONE]",
"",
}, "\n"))),
}
upstream := &httpUpstreamRecorder{resp: resp}
@@ -895,7 +902,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_InfoWhenStreamEndsWithoutDone(t *
}
_, err := svc.Forward(context.Background(), c, account, originalBody)
require.NoError(t, err)
require.EqualError(t, err, "stream usage incomplete: missing terminal event")
require.True(t, logSink.ContainsMessage("上游流在未收到 [DONE] 时结束,疑似断流"))
require.True(t, logSink.ContainsMessageAtLevel("上游流在未收到 [DONE] 时结束,疑似断流", "info"))
require.True(t, logSink.ContainsFieldValue("upstream_request_id", "rid-truncate"))
@@ -911,11 +918,16 @@ func TestOpenAIGatewayService_OAuthPassthrough_DefaultFiltersTimeoutHeaders(t *t
c.Request.Header.Set("x-stainless-timeout", "120000")
c.Request.Header.Set("X-Test", "keep")
originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`)
originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`)
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}, "X-Request-Id": []string{"rid-filter-default"}},
Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)),
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "X-Request-Id": []string{"rid-filter-default"}},
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
`data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}}`,
"",
"data: [DONE]",
"",
}, "\n"))),
}
upstream := &httpUpstreamRecorder{resp: resp}
svc := &OpenAIGatewayService{
@@ -952,11 +964,16 @@ func TestOpenAIGatewayService_OAuthPassthrough_AllowTimeoutHeadersWhenConfigured
c.Request.Header.Set("x-stainless-timeout", "120000")
c.Request.Header.Set("X-Test", "keep")
originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`)
originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`)
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}, "X-Request-Id": []string{"rid-filter-allow"}},
Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)),
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "X-Request-Id": []string{"rid-filter-allow"}},
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
`data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}}`,
"",
"data: [DONE]",
"",
}, "\n"))),
}
upstream := &httpUpstreamRecorder{resp: resp}
svc := &OpenAIGatewayService{

View File

@@ -335,7 +335,7 @@ func TestOpenAIGatewayService_Forward_HTTPIngressRetriesWrappedInvalidEncryptedC
},
}
body := []byte(`{"model":"gpt-5.1","stream":true,"previous_response_id":"resp_http_retry_wrapped","input":[{"type":"reasoning","encrypted_content":"gAAA","summary":[{"type":"summary_text","text":"keep me too"}]},{"type":"input_text","text":"hello"}]}`)
body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_http_retry_wrapped","input":[{"type":"reasoning","encrypted_content":"gAAA","summary":[{"type":"summary_text","text":"keep me too"}]},{"type":"input_text","text":"hello"}]}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
@@ -604,6 +604,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil,
nil,
nil,
nil,
cfg,
nil,
nil,

View File

@@ -0,0 +1,110 @@
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"strings"
)
var ErrUsageBillingRequestIDRequired = errors.New("usage billing request_id is required")
var ErrUsageBillingRequestConflict = errors.New("usage billing request fingerprint conflict")
// UsageBillingCommand describes one billable request that must be applied at most once.
type UsageBillingCommand struct {
RequestID string
APIKeyID int64
RequestFingerprint string
RequestPayloadHash string
UserID int64
AccountID int64
SubscriptionID *int64
AccountType string
Model string
ServiceTier string
ReasoningEffort string
BillingType int8
InputTokens int
OutputTokens int
CacheCreationTokens int
CacheReadTokens int
ImageCount int
MediaType string
BalanceCost float64
SubscriptionCost float64
APIKeyQuotaCost float64
APIKeyRateLimitCost float64
AccountQuotaCost float64
}
func (c *UsageBillingCommand) Normalize() {
if c == nil {
return
}
c.RequestID = strings.TrimSpace(c.RequestID)
if strings.TrimSpace(c.RequestFingerprint) == "" {
c.RequestFingerprint = buildUsageBillingFingerprint(c)
}
}
func buildUsageBillingFingerprint(c *UsageBillingCommand) string {
if c == nil {
return ""
}
raw := fmt.Sprintf(
"%d|%d|%d|%s|%s|%s|%s|%d|%d|%d|%d|%d|%d|%s|%d|%0.10f|%0.10f|%0.10f|%0.10f|%0.10f",
c.UserID,
c.AccountID,
c.APIKeyID,
strings.TrimSpace(c.AccountType),
strings.TrimSpace(c.Model),
strings.TrimSpace(c.ServiceTier),
strings.TrimSpace(c.ReasoningEffort),
c.BillingType,
c.InputTokens,
c.OutputTokens,
c.CacheCreationTokens,
c.CacheReadTokens,
c.ImageCount,
strings.TrimSpace(c.MediaType),
valueOrZero(c.SubscriptionID),
c.BalanceCost,
c.SubscriptionCost,
c.APIKeyQuotaCost,
c.APIKeyRateLimitCost,
c.AccountQuotaCost,
)
if payloadHash := strings.TrimSpace(c.RequestPayloadHash); payloadHash != "" {
raw += "|" + payloadHash
}
sum := sha256.Sum256([]byte(raw))
return hex.EncodeToString(sum[:])
}
func HashUsageRequestPayload(payload []byte) string {
if len(payload) == 0 {
return ""
}
sum := sha256.Sum256(payload)
return hex.EncodeToString(sum[:])
}
func valueOrZero(v *int64) int64 {
if v == nil {
return 0
}
return *v
}
type UsageBillingApplyResult struct {
Applied bool
APIKeyQuotaExhausted bool
}
type UsageBillingRepository interface {
Apply(ctx context.Context, cmd *UsageBillingCommand) (*UsageBillingApplyResult, error)
}

View File

@@ -56,7 +56,8 @@ type cleanupRepoStub struct {
}
type dashboardRepoStub struct {
recomputeErr error
recomputeErr error
recomputeCalls int
}
func (s *dashboardRepoStub) AggregateRange(ctx context.Context, start, end time.Time) error {
@@ -64,6 +65,7 @@ func (s *dashboardRepoStub) AggregateRange(ctx context.Context, start, end time.
}
func (s *dashboardRepoStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
s.recomputeCalls++
return s.recomputeErr
}
@@ -83,6 +85,10 @@ func (s *dashboardRepoStub) CleanupUsageLogs(ctx context.Context, cutoff time.Ti
return nil
}
func (s *dashboardRepoStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
return nil
}
func (s *dashboardRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
return nil
}
@@ -550,13 +556,14 @@ func TestUsageCleanupServiceExecuteTaskMarkFailedUpdateError(t *testing.T) {
}
func TestUsageCleanupServiceExecuteTaskDashboardRecomputeError(t *testing.T) {
dashboardRepo := &dashboardRepoStub{recomputeErr: errors.New("recompute failed")}
repo := &cleanupRepoStub{
deleteQueue: []cleanupDeleteResponse{
{deleted: 0},
},
}
dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{
DashboardAgg: config.DashboardAggregationConfig{Enabled: false},
dashboard := NewDashboardAggregationService(dashboardRepo, nil, &config.Config{
DashboardAgg: config.DashboardAggregationConfig{Enabled: true},
})
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
svc := NewUsageCleanupService(repo, nil, dashboard, cfg)
@@ -573,15 +580,17 @@ func TestUsageCleanupServiceExecuteTaskDashboardRecomputeError(t *testing.T) {
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.markSucceeded, 1)
require.Eventually(t, func() bool { return dashboardRepo.recomputeCalls == 1 }, time.Second, 10*time.Millisecond)
}
func TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess(t *testing.T) {
dashboardRepo := &dashboardRepoStub{}
repo := &cleanupRepoStub{
deleteQueue: []cleanupDeleteResponse{
{deleted: 0},
},
}
dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{
dashboard := NewDashboardAggregationService(dashboardRepo, nil, &config.Config{
DashboardAgg: config.DashboardAggregationConfig{Enabled: true},
})
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
@@ -599,6 +608,7 @@ func TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess(t *testing.T) {
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.markSucceeded, 1)
require.Eventually(t, func() bool { return dashboardRepo.recomputeCalls == 1 }, time.Second, 10*time.Millisecond)
}
func TestUsageCleanupServiceExecuteTaskCanceled(t *testing.T) {

View File

@@ -0,0 +1,82 @@
package service
import "errors"
type usageLogCreateDisposition int
const (
usageLogCreateDispositionUnknown usageLogCreateDisposition = iota
usageLogCreateDispositionNotPersisted
usageLogCreateDispositionDropped
)
type UsageLogCreateError struct {
err error
disposition usageLogCreateDisposition
}
func (e *UsageLogCreateError) Error() string {
if e == nil || e.err == nil {
return "usage log create error"
}
return e.err.Error()
}
func (e *UsageLogCreateError) Unwrap() error {
if e == nil {
return nil
}
return e.err
}
func MarkUsageLogCreateNotPersisted(err error) error {
if err == nil {
return nil
}
return &UsageLogCreateError{
err: err,
disposition: usageLogCreateDispositionNotPersisted,
}
}
func MarkUsageLogCreateDropped(err error) error {
if err == nil {
return nil
}
return &UsageLogCreateError{
err: err,
disposition: usageLogCreateDispositionDropped,
}
}
func IsUsageLogCreateNotPersisted(err error) bool {
if err == nil {
return false
}
var target *UsageLogCreateError
if !errors.As(err, &target) {
return false
}
return target.disposition == usageLogCreateDispositionNotPersisted
}
func IsUsageLogCreateDropped(err error) bool {
if err == nil {
return false
}
var target *UsageLogCreateError
if !errors.As(err, &target) {
return false
}
return target.disposition == usageLogCreateDispositionDropped
}
func ShouldBillAfterUsageLogCreate(inserted bool, err error) bool {
if inserted {
return true
}
if err == nil {
return false
}
return !IsUsageLogCreateNotPersisted(err)
}

View File

@@ -0,0 +1,13 @@
-- 窄表账务幂等键:将“是否已扣费”从 usage_logs 解耦出来
-- 幂等执行:可重复运行
CREATE TABLE IF NOT EXISTS usage_billing_dedup (
id BIGSERIAL PRIMARY KEY,
request_id VARCHAR(255) NOT NULL,
api_key_id BIGINT NOT NULL,
request_fingerprint VARCHAR(64) NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_usage_billing_dedup_request_api_key
ON usage_billing_dedup (request_id, api_key_id);

View File

@@ -0,0 +1,7 @@
-- usage_billing_dedup 是按时间追加写入的幂等窄表。
-- 使用 BRIN 支撑按 created_at 的批量保留期清理,尽量降低写放大。
-- 使用 CONCURRENTLY 避免在热表上长时间阻塞写入。
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_usage_billing_dedup_created_at_brin
ON usage_billing_dedup
USING BRIN (created_at);

View File

@@ -0,0 +1,10 @@
-- 冷归档旧账务幂等键,缩小热表索引与清理范围,同时不丢失长期去重能力。
CREATE TABLE IF NOT EXISTS usage_billing_dedup_archive (
request_id VARCHAR(255) NOT NULL,
api_key_id BIGINT NOT NULL,
request_fingerprint VARCHAR(64) NOT NULL,
created_at TIMESTAMPTZ NOT NULL,
archived_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (request_id, api_key_id)
);

View File

@@ -105,7 +105,7 @@ EXPOSE 8080
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
CMD wget -q -T 5 -O /dev/null http://localhost:${SERVER_PORT:-8080}/health || exit 1
# Run the application
ENTRYPOINT ["/app/sub2api"]

0
deploy/build_image.sh Executable file → Normal file
View File

View File

@@ -154,7 +154,7 @@ services:
networks:
- sub2api-network
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8080/health"]
test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health"]
interval: 30s
timeout: 10s
retries: 3

View File

@@ -94,7 +94,7 @@ services:
- GEMINI_CLI_OAUTH_CLIENT_SECRET=${GEMINI_CLI_OAUTH_CLIENT_SECRET:-}
- ANTIGRAVITY_OAUTH_CLIENT_SECRET=${ANTIGRAVITY_OAUTH_CLIENT_SECRET:-}
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8080/health"]
test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health"]
interval: 30s
timeout: 10s
retries: 3

View File

@@ -146,7 +146,7 @@ services:
networks:
- sub2api-network
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8080/health"]
test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health"]
interval: 30s
timeout: 10s
retries: 3

0
deploy/install-datamanagementd.sh Executable file → Normal file
View File

View File

@@ -11,6 +11,7 @@ import type {
GroupStat,
ApiKeyUsageTrendPoint,
UserUsageTrendPoint,
UserSpendingRankingResponse,
UsageRequestType
} from '@/types'
@@ -201,6 +202,11 @@ export interface UserTrendResponse {
granularity: string
}
export interface UserSpendingRankingParams
extends Pick<TrendParams, 'start_date' | 'end_date'> {
limit?: number
}
/**
* Get user usage trend data
* @param params - Query parameters for filtering
@@ -213,6 +219,20 @@ export async function getUserUsageTrend(params?: UserTrendParams): Promise<UserT
return data
}
/**
* Get user spending ranking data
* @param params - Query parameters for filtering
* @returns User spending ranking data
*/
export async function getUserSpendingRanking(
params?: UserSpendingRankingParams
): Promise<UserSpendingRankingResponse> {
const { data } = await apiClient.get<UserSpendingRankingResponse>('/admin/dashboard/users-ranking', {
params
})
return data
}
export interface BatchUserUsageStats {
user_id: number
today_actual_cost: number
@@ -271,6 +291,7 @@ export const dashboardAPI = {
getSnapshotV2,
getApiKeyUsageTrend,
getUserUsageTrend,
getUserSpendingRanking,
getBatchUsersUsage,
getBatchApiKeysUsage
}

View File

@@ -232,7 +232,7 @@
<!-- Account Type Selection (Anthropic) -->
<div v-if="form.platform === 'anthropic'">
<label class="input-label">{{ t('admin.accounts.accountType') }}</label>
<div class="mt-2 grid grid-cols-2 gap-3" data-tour="account-form-type">
<div class="mt-2 grid grid-cols-3 gap-3" data-tour="account-form-type">
<button
type="button"
@click="accountCategory = 'oauth-based'"
@@ -292,6 +292,66 @@
}}</span>
</div>
</button>
<button
type="button"
@click="accountCategory = 'bedrock'"
:class="[
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
accountCategory === 'bedrock'
? 'border-amber-500 bg-amber-50 dark:bg-amber-900/20'
: 'border-gray-200 hover:border-amber-300 dark:border-dark-600 dark:hover:border-amber-700'
]"
>
<div
:class="[
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
accountCategory === 'bedrock'
? 'bg-amber-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
]"
>
<Icon name="cloud" size="sm" />
</div>
<div>
<span class="block text-sm font-medium text-gray-900 dark:text-white">{{
t('admin.accounts.bedrockLabel')
}}</span>
<span class="text-xs text-gray-500 dark:text-gray-400">{{
t('admin.accounts.bedrockDesc')
}}</span>
</div>
</button>
<button
type="button"
@click="accountCategory = 'bedrock-apikey'"
:class="[
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
accountCategory === 'bedrock-apikey'
? 'border-amber-500 bg-amber-50 dark:bg-amber-900/20'
: 'border-gray-200 hover:border-amber-300 dark:border-dark-600 dark:hover:border-amber-700'
]"
>
<div
:class="[
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
accountCategory === 'bedrock-apikey'
? 'bg-amber-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
]"
>
<Icon name="key" size="sm" />
</div>
<div>
<span class="block text-sm font-medium text-gray-900 dark:text-white">{{
t('admin.accounts.bedrockApiKeyLabel')
}}</span>
<span class="text-xs text-gray-500 dark:text-gray-400">{{
t('admin.accounts.bedrockApiKeyDesc')
}}</span>
</div>
</button>
</div>
</div>
@@ -896,7 +956,7 @@
</div>
<!-- API Key input (only for apikey type, excluding Antigravity which has its own fields) -->
<div v-if="form.type === 'apikey' && form.platform !== 'antigravity'" class="space-y-4">
<div v-if="form.type === 'apikey' && form.platform !== 'antigravity' && accountCategory !== 'bedrock-apikey'" class="space-y-4">
<div>
<label class="input-label">{{ t('admin.accounts.baseUrl') }}</label>
<input
@@ -1279,6 +1339,289 @@
</div>
<!-- Bedrock credentials (only for Anthropic Bedrock type) -->
<div v-if="form.platform === 'anthropic' && accountCategory === 'bedrock'" class="space-y-4">
<div>
<label class="input-label">{{ t('admin.accounts.bedrockAccessKeyId') }}</label>
<input
v-model="bedrockAccessKeyId"
type="text"
required
class="input font-mono"
placeholder="AKIA..."
/>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.bedrockSecretAccessKey') }}</label>
<input
v-model="bedrockSecretAccessKey"
type="password"
required
class="input font-mono"
/>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.bedrockSessionToken') }}</label>
<input
v-model="bedrockSessionToken"
type="password"
class="input font-mono"
/>
<p class="input-hint">{{ t('admin.accounts.bedrockSessionTokenHint') }}</p>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.bedrockRegion') }}</label>
<select v-model="bedrockRegion" class="input">
<optgroup label="US">
<option value="us-east-1">us-east-1 (N. Virginia)</option>
<option value="us-east-2">us-east-2 (Ohio)</option>
<option value="us-west-1">us-west-1 (N. California)</option>
<option value="us-west-2">us-west-2 (Oregon)</option>
<option value="us-gov-east-1">us-gov-east-1 (GovCloud US-East)</option>
<option value="us-gov-west-1">us-gov-west-1 (GovCloud US-West)</option>
</optgroup>
<optgroup label="Europe">
<option value="eu-west-1">eu-west-1 (Ireland)</option>
<option value="eu-west-2">eu-west-2 (London)</option>
<option value="eu-west-3">eu-west-3 (Paris)</option>
<option value="eu-central-1">eu-central-1 (Frankfurt)</option>
<option value="eu-central-2">eu-central-2 (Zurich)</option>
<option value="eu-south-1">eu-south-1 (Milan)</option>
<option value="eu-south-2">eu-south-2 (Spain)</option>
<option value="eu-north-1">eu-north-1 (Stockholm)</option>
</optgroup>
<optgroup label="Asia Pacific">
<option value="ap-northeast-1">ap-northeast-1 (Tokyo)</option>
<option value="ap-northeast-2">ap-northeast-2 (Seoul)</option>
<option value="ap-northeast-3">ap-northeast-3 (Osaka)</option>
<option value="ap-south-1">ap-south-1 (Mumbai)</option>
<option value="ap-south-2">ap-south-2 (Hyderabad)</option>
<option value="ap-southeast-1">ap-southeast-1 (Singapore)</option>
<option value="ap-southeast-2">ap-southeast-2 (Sydney)</option>
</optgroup>
<optgroup label="Canada">
<option value="ca-central-1">ca-central-1 (Canada)</option>
</optgroup>
<optgroup label="South America">
<option value="sa-east-1">sa-east-1 (São Paulo)</option>
</optgroup>
</select>
<p class="input-hint">{{ t('admin.accounts.bedrockRegionHint') }}</p>
</div>
<div>
<label class="flex items-center gap-2 cursor-pointer">
<input
v-model="bedrockForceGlobal"
type="checkbox"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500 dark:border-dark-500"
/>
<span class="text-sm text-gray-700 dark:text-gray-300">{{ t('admin.accounts.bedrockForceGlobal') }}</span>
</label>
<p class="input-hint mt-1">{{ t('admin.accounts.bedrockForceGlobalHint') }}</p>
</div>
<!-- Model Restriction Section for Bedrock -->
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
<!-- Mode Toggle -->
<div class="mb-4 flex gap-2">
<button
type="button"
@click="modelRestrictionMode = 'whitelist'"
:class="[
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
modelRestrictionMode === 'whitelist'
? 'bg-primary-100 text-primary-700 dark:bg-primary-900/30 dark:text-primary-400'
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
]"
>
{{ t('admin.accounts.modelWhitelist') }}
</button>
<button
type="button"
@click="modelRestrictionMode = 'mapping'"
:class="[
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
modelRestrictionMode === 'mapping'
? 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400'
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
]"
>
{{ t('admin.accounts.modelMapping') }}
</button>
</div>
<!-- Whitelist Mode -->
<div v-if="modelRestrictionMode === 'whitelist'">
<ModelWhitelistSelector v-model="allowedModels" platform="anthropic" />
<p class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
<span v-if="allowedModels.length === 0">{{ t('admin.accounts.supportsAllModels') }}</span>
</p>
</div>
<!-- Mapping Mode -->
<div v-else class="space-y-3">
<div v-for="(mapping, index) in modelMappings" :key="index" class="flex items-center gap-2">
<input v-model="mapping.from" type="text" class="input flex-1" :placeholder="t('admin.accounts.fromModel')" />
<span class="text-gray-400"></span>
<input v-model="mapping.to" type="text" class="input flex-1" :placeholder="t('admin.accounts.toModel')" />
<button type="button" @click="modelMappings.splice(index, 1)" class="text-red-500 hover:text-red-700">
<Icon name="trash" size="sm" />
</button>
</div>
<button type="button" @click="modelMappings.push({ from: '', to: '' })" class="btn btn-secondary text-sm">
+ {{ t('admin.accounts.addMapping') }}
</button>
<!-- Bedrock Preset Mappings -->
<div class="flex flex-wrap gap-2">
<button
v-for="preset in bedrockPresets"
:key="preset.from"
type="button"
@click="addPresetMapping(preset.from, preset.to)"
:class="['rounded-lg px-3 py-1 text-xs transition-colors', preset.color]"
>
+ {{ preset.label }}
</button>
</div>
</div>
</div>
</div>
<!-- Bedrock API Key credentials (only for Anthropic Bedrock API Key type) -->
<div v-if="form.platform === 'anthropic' && accountCategory === 'bedrock-apikey'" class="space-y-4">
<div>
<label class="input-label">{{ t('admin.accounts.bedrockApiKeyInput') }}</label>
<input
v-model="bedrockApiKeyValue"
type="password"
required
class="input font-mono"
/>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.bedrockRegion') }}</label>
<select v-model="bedrockApiKeyRegion" class="input">
<optgroup label="US">
<option value="us-east-1">us-east-1 (N. Virginia)</option>
<option value="us-east-2">us-east-2 (Ohio)</option>
<option value="us-west-1">us-west-1 (N. California)</option>
<option value="us-west-2">us-west-2 (Oregon)</option>
<option value="us-gov-east-1">us-gov-east-1 (GovCloud US-East)</option>
<option value="us-gov-west-1">us-gov-west-1 (GovCloud US-West)</option>
</optgroup>
<optgroup label="Europe">
<option value="eu-west-1">eu-west-1 (Ireland)</option>
<option value="eu-west-2">eu-west-2 (London)</option>
<option value="eu-west-3">eu-west-3 (Paris)</option>
<option value="eu-central-1">eu-central-1 (Frankfurt)</option>
<option value="eu-central-2">eu-central-2 (Zurich)</option>
<option value="eu-south-1">eu-south-1 (Milan)</option>
<option value="eu-south-2">eu-south-2 (Spain)</option>
<option value="eu-north-1">eu-north-1 (Stockholm)</option>
</optgroup>
<optgroup label="Asia Pacific">
<option value="ap-northeast-1">ap-northeast-1 (Tokyo)</option>
<option value="ap-northeast-2">ap-northeast-2 (Seoul)</option>
<option value="ap-northeast-3">ap-northeast-3 (Osaka)</option>
<option value="ap-south-1">ap-south-1 (Mumbai)</option>
<option value="ap-south-2">ap-south-2 (Hyderabad)</option>
<option value="ap-southeast-1">ap-southeast-1 (Singapore)</option>
<option value="ap-southeast-2">ap-southeast-2 (Sydney)</option>
</optgroup>
<optgroup label="Canada">
<option value="ca-central-1">ca-central-1 (Canada)</option>
</optgroup>
<optgroup label="South America">
<option value="sa-east-1">sa-east-1 (São Paulo)</option>
</optgroup>
</select>
<p class="input-hint">{{ t('admin.accounts.bedrockRegionHint') }}</p>
</div>
<div>
<label class="flex items-center gap-2 cursor-pointer">
<input
v-model="bedrockApiKeyForceGlobal"
type="checkbox"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500 dark:border-dark-500"
/>
<span class="text-sm text-gray-700 dark:text-gray-300">{{ t('admin.accounts.bedrockForceGlobal') }}</span>
</label>
<p class="input-hint mt-1">{{ t('admin.accounts.bedrockForceGlobalHint') }}</p>
</div>
<!-- Model Restriction Section for Bedrock API Key -->
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
<!-- Mode Toggle -->
<div class="mb-4 flex gap-2">
<button
type="button"
@click="modelRestrictionMode = 'whitelist'"
:class="[
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
modelRestrictionMode === 'whitelist'
? 'bg-primary-100 text-primary-700 dark:bg-primary-900/30 dark:text-primary-400'
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
]"
>
{{ t('admin.accounts.modelWhitelist') }}
</button>
<button
type="button"
@click="modelRestrictionMode = 'mapping'"
:class="[
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
modelRestrictionMode === 'mapping'
? 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400'
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
]"
>
{{ t('admin.accounts.modelMapping') }}
</button>
</div>
<!-- Whitelist Mode -->
<div v-if="modelRestrictionMode === 'whitelist'">
<ModelWhitelistSelector v-model="allowedModels" platform="anthropic" />
<p class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
<span v-if="allowedModels.length === 0">{{ t('admin.accounts.supportsAllModels') }}</span>
</p>
</div>
<!-- Mapping Mode -->
<div v-else class="space-y-3">
<div v-for="(mapping, index) in modelMappings" :key="index" class="flex items-center gap-2">
<input v-model="mapping.from" type="text" class="input flex-1" :placeholder="t('admin.accounts.fromModel')" />
<span class="text-gray-400"></span>
<input v-model="mapping.to" type="text" class="input flex-1" :placeholder="t('admin.accounts.toModel')" />
<button type="button" @click="modelMappings.splice(index, 1)" class="text-red-500 hover:text-red-700">
<Icon name="trash" size="sm" />
</button>
</div>
<button type="button" @click="modelMappings.push({ from: '', to: '' })" class="btn btn-secondary text-sm">
+ {{ t('admin.accounts.addMapping') }}
</button>
<!-- Bedrock Preset Mappings -->
<div class="flex flex-wrap gap-2">
<button
v-for="preset in bedrockPresets"
:key="preset.from"
type="button"
@click="addPresetMapping(preset.from, preset.to)"
:class="['rounded-lg px-3 py-1 text-xs transition-colors', preset.color]"
>
+ {{ preset.label }}
</button>
</div>
</div>
</div>
</div>
<!-- API Key 账号配额限制 -->
<div v-if="form.type === 'apikey'" class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-4">
<div class="mb-3">
@@ -2671,7 +3014,7 @@ interface TempUnschedRuleForm {
// State
const step = ref(1)
const submitting = ref(false)
const accountCategory = ref<'oauth-based' | 'apikey'>('oauth-based') // UI selection for account category
const accountCategory = ref<'oauth-based' | 'apikey' | 'bedrock' | 'bedrock-apikey'>('oauth-based') // UI selection for account category
const addMethod = ref<AddMethod>('oauth') // For oauth-based: 'oauth' or 'setup-token'
const apiKeyBaseUrl = ref('https://api.anthropic.com')
const apiKeyValue = ref('')
@@ -2704,6 +3047,19 @@ const antigravityModelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist'
const antigravityWhitelistModels = ref<string[]>([])
const antigravityModelMappings = ref<ModelMapping[]>([])
const antigravityPresetMappings = computed(() => getPresetMappingsByPlatform('antigravity'))
const bedrockPresets = computed(() => getPresetMappingsByPlatform('bedrock'))
// Bedrock credentials
const bedrockAccessKeyId = ref('')
const bedrockSecretAccessKey = ref('')
const bedrockSessionToken = ref('')
const bedrockRegion = ref('us-east-1')
const bedrockForceGlobal = ref(false)
// Bedrock API Key credentials
const bedrockApiKeyValue = ref('')
const bedrockApiKeyRegion = ref('us-east-1')
const bedrockApiKeyForceGlobal = ref(false)
const tempUnschedEnabled = ref(false)
const tempUnschedRules = ref<TempUnschedRuleForm[]>([])
const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('create-model-mapping')
@@ -2868,6 +3224,10 @@ const isOAuthFlow = computed(() => {
if (form.platform === 'antigravity' && antigravityAccountType.value === 'upstream') {
return false
}
// Bedrock 类型不需要 OAuth 流程
if (form.platform === 'anthropic' && accountCategory.value === 'bedrock') {
return false
}
return accountCategory.value === 'oauth-based'
})
@@ -2935,6 +3295,11 @@ watch(
form.type = 'apikey'
return
}
// Bedrock 类型
if (form.platform === 'anthropic' && category === 'bedrock') {
form.type = 'bedrock' as AccountType
return
}
if (category === 'oauth-based') {
form.type = method as AccountType // 'oauth' or 'setup-token'
} else {
@@ -2972,6 +3337,13 @@ watch(
antigravityModelMappings.value = []
antigravityModelRestrictionMode.value = 'mapping'
}
// Reset Bedrock fields when switching platforms
bedrockAccessKeyId.value = ''
bedrockSecretAccessKey.value = ''
bedrockSessionToken.value = ''
bedrockRegion.value = 'us-east-1'
bedrockForceGlobal.value = false
bedrockApiKeyForceGlobal.value = false
// Reset Anthropic/Antigravity-specific settings when switching to other platforms
if (newPlatform !== 'anthropic' && newPlatform !== 'antigravity') {
interceptWarmupRequests.value = false
@@ -3541,6 +3913,84 @@ const handleSubmit = async () => {
return
}
// For Bedrock type, create directly
if (form.platform === 'anthropic' && accountCategory.value === 'bedrock') {
if (!form.name.trim()) {
appStore.showError(t('admin.accounts.pleaseEnterAccountName'))
return
}
if (!bedrockAccessKeyId.value.trim()) {
appStore.showError(t('admin.accounts.bedrockAccessKeyIdRequired'))
return
}
if (!bedrockSecretAccessKey.value.trim()) {
appStore.showError(t('admin.accounts.bedrockSecretAccessKeyRequired'))
return
}
if (!bedrockRegion.value.trim()) {
appStore.showError(t('admin.accounts.bedrockRegionRequired'))
return
}
const credentials: Record<string, unknown> = {
aws_access_key_id: bedrockAccessKeyId.value.trim(),
aws_secret_access_key: bedrockSecretAccessKey.value.trim(),
aws_region: bedrockRegion.value.trim(),
}
if (bedrockSessionToken.value.trim()) {
credentials.aws_session_token = bedrockSessionToken.value.trim()
}
if (bedrockForceGlobal.value) {
credentials.aws_force_global = 'true'
}
// Model mapping
const modelMapping = buildModelMappingObject(
modelRestrictionMode.value, allowedModels.value, modelMappings.value
)
if (modelMapping) {
credentials.model_mapping = modelMapping
}
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
await createAccountAndFinish('anthropic', 'bedrock' as AccountType, credentials)
return
}
// For Bedrock API Key type, create directly
if (form.platform === 'anthropic' && accountCategory.value === 'bedrock-apikey') {
if (!form.name.trim()) {
appStore.showError(t('admin.accounts.pleaseEnterAccountName'))
return
}
if (!bedrockApiKeyValue.value.trim()) {
appStore.showError(t('admin.accounts.bedrockApiKeyRequired'))
return
}
const credentials: Record<string, unknown> = {
api_key: bedrockApiKeyValue.value.trim(),
aws_region: bedrockApiKeyRegion.value.trim() || 'us-east-1',
}
if (bedrockApiKeyForceGlobal.value) {
credentials.aws_force_global = 'true'
}
// Model mapping
const modelMapping = buildModelMappingObject(
modelRestrictionMode.value, allowedModels.value, modelMappings.value
)
if (modelMapping) {
credentials.model_mapping = modelMapping
}
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
await createAccountAndFinish('anthropic', 'bedrock-apikey' as AccountType, credentials)
return
}
// For Antigravity upstream type, create directly
if (form.platform === 'antigravity' && antigravityAccountType.value === 'upstream') {
if (!form.name.trim()) {

View File

@@ -563,6 +563,233 @@
</div>
</div>
<!-- Bedrock fields (only for bedrock type) -->
<div v-if="account.type === 'bedrock'" class="space-y-4">
<div>
<label class="input-label">{{ t('admin.accounts.bedrockAccessKeyId') }}</label>
<input
v-model="editBedrockAccessKeyId"
type="text"
class="input font-mono"
placeholder="AKIA..."
/>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.bedrockSecretAccessKey') }}</label>
<input
v-model="editBedrockSecretAccessKey"
type="password"
class="input font-mono"
:placeholder="t('admin.accounts.bedrockSecretKeyLeaveEmpty')"
/>
<p class="input-hint">{{ t('admin.accounts.bedrockSecretKeyLeaveEmpty') }}</p>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.bedrockSessionToken') }}</label>
<input
v-model="editBedrockSessionToken"
type="password"
class="input font-mono"
:placeholder="t('admin.accounts.bedrockSecretKeyLeaveEmpty')"
/>
<p class="input-hint">{{ t('admin.accounts.bedrockSessionTokenHint') }}</p>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.bedrockRegion') }}</label>
<input
v-model="editBedrockRegion"
type="text"
class="input"
placeholder="us-east-1"
/>
<p class="input-hint">{{ t('admin.accounts.bedrockRegionHint') }}</p>
</div>
<div>
<label class="flex items-center gap-2 cursor-pointer">
<input
v-model="editBedrockForceGlobal"
type="checkbox"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500 dark:border-dark-500"
/>
<span class="text-sm text-gray-700 dark:text-gray-300">{{ t('admin.accounts.bedrockForceGlobal') }}</span>
</label>
<p class="input-hint mt-1">{{ t('admin.accounts.bedrockForceGlobalHint') }}</p>
</div>
<!-- Model Restriction for Bedrock -->
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
<!-- Mode Toggle -->
<div class="mb-4 flex gap-2">
<button
type="button"
@click="modelRestrictionMode = 'whitelist'"
:class="[
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
modelRestrictionMode === 'whitelist'
? 'bg-primary-100 text-primary-700 dark:bg-primary-900/30 dark:text-primary-400'
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
]"
>
{{ t('admin.accounts.modelWhitelist') }}
</button>
<button
type="button"
@click="modelRestrictionMode = 'mapping'"
:class="[
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
modelRestrictionMode === 'mapping'
? 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400'
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
]"
>
{{ t('admin.accounts.modelMapping') }}
</button>
</div>
<!-- Whitelist Mode -->
<div v-if="modelRestrictionMode === 'whitelist'">
<ModelWhitelistSelector v-model="allowedModels" platform="anthropic" />
<p class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
<span v-if="allowedModels.length === 0">{{ t('admin.accounts.supportsAllModels') }}</span>
</p>
</div>
<!-- Mapping Mode -->
<div v-else class="space-y-3">
<div v-for="(mapping, index) in modelMappings" :key="getModelMappingKey(mapping)" class="flex items-center gap-2">
<input v-model="mapping.from" type="text" class="input flex-1" :placeholder="t('admin.accounts.fromModel')" />
<span class="text-gray-400"></span>
<input v-model="mapping.to" type="text" class="input flex-1" :placeholder="t('admin.accounts.toModel')" />
<button type="button" @click="modelMappings.splice(index, 1)" class="text-red-500 hover:text-red-700">
<Icon name="trash" size="sm" />
</button>
</div>
<button type="button" @click="modelMappings.push({ from: '', to: '' })" class="btn btn-secondary text-sm">
+ {{ t('admin.accounts.addMapping') }}
</button>
<!-- Bedrock Preset Mappings -->
<div class="flex flex-wrap gap-2">
<button
v-for="preset in bedrockPresets"
:key="preset.from"
type="button"
@click="modelMappings.push({ from: preset.from, to: preset.to })"
:class="['rounded-lg px-3 py-1 text-xs transition-colors', preset.color]"
>
+ {{ preset.label }}
</button>
</div>
</div>
</div>
</div>
<!-- Bedrock API Key fields (only for bedrock-apikey type) -->
<div v-if="account.type === 'bedrock-apikey'" class="space-y-4">
<div>
<label class="input-label">{{ t('admin.accounts.bedrockApiKeyInput') }}</label>
<input
v-model="editBedrockApiKeyValue"
type="password"
class="input font-mono"
:placeholder="t('admin.accounts.bedrockApiKeyLeaveEmpty')"
/>
<p class="input-hint">{{ t('admin.accounts.bedrockApiKeyLeaveEmpty') }}</p>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.bedrockRegion') }}</label>
<input
v-model="editBedrockApiKeyRegion"
type="text"
class="input"
placeholder="us-east-1"
/>
<p class="input-hint">{{ t('admin.accounts.bedrockRegionHint') }}</p>
</div>
<div>
<label class="flex items-center gap-2 cursor-pointer">
<input
v-model="editBedrockApiKeyForceGlobal"
type="checkbox"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500 dark:border-dark-500"
/>
<span class="text-sm text-gray-700 dark:text-gray-300">{{ t('admin.accounts.bedrockForceGlobal') }}</span>
</label>
<p class="input-hint mt-1">{{ t('admin.accounts.bedrockForceGlobalHint') }}</p>
</div>
<!-- Model Restriction for Bedrock API Key -->
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
<!-- Mode Toggle -->
<div class="mb-4 flex gap-2">
<button
type="button"
@click="modelRestrictionMode = 'whitelist'"
:class="[
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
modelRestrictionMode === 'whitelist'
? 'bg-primary-100 text-primary-700 dark:bg-primary-900/30 dark:text-primary-400'
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
]"
>
{{ t('admin.accounts.modelWhitelist') }}
</button>
<button
type="button"
@click="modelRestrictionMode = 'mapping'"
:class="[
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
modelRestrictionMode === 'mapping'
? 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400'
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
]"
>
{{ t('admin.accounts.modelMapping') }}
</button>
</div>
<!-- Whitelist Mode -->
<div v-if="modelRestrictionMode === 'whitelist'">
<ModelWhitelistSelector v-model="allowedModels" platform="anthropic" />
<p class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
<span v-if="allowedModels.length === 0">{{ t('admin.accounts.supportsAllModels') }}</span>
</p>
</div>
<!-- Mapping Mode -->
<div v-else class="space-y-3">
<div v-for="(mapping, index) in modelMappings" :key="getModelMappingKey(mapping)" class="flex items-center gap-2">
<input v-model="mapping.from" type="text" class="input flex-1" :placeholder="t('admin.accounts.fromModel')" />
<span class="text-gray-400"></span>
<input v-model="mapping.to" type="text" class="input flex-1" :placeholder="t('admin.accounts.toModel')" />
<button type="button" @click="modelMappings.splice(index, 1)" class="text-red-500 hover:text-red-700">
<Icon name="trash" size="sm" />
</button>
</div>
<button type="button" @click="modelMappings.push({ from: '', to: '' })" class="btn btn-secondary text-sm">
+ {{ t('admin.accounts.addMapping') }}
</button>
<!-- Bedrock Preset Mappings -->
<div class="flex flex-wrap gap-2">
<button
v-for="preset in bedrockPresets"
:key="preset.from"
type="button"
@click="modelMappings.push({ from: preset.from, to: preset.to })"
:class="['rounded-lg px-3 py-1 text-xs transition-colors', preset.color]"
>
+ {{ preset.label }}
</button>
</div>
</div>
</div>
</div>
<!-- Antigravity model restriction (applies to all antigravity types) -->
<!-- Antigravity 只支持模型映射模式不支持白名单模式 -->
<div v-if="account.platform === 'antigravity'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
@@ -1529,6 +1756,7 @@ const baseUrlHint = computed(() => {
})
const antigravityPresetMappings = computed(() => getPresetMappingsByPlatform('antigravity'))
const bedrockPresets = computed(() => getPresetMappingsByPlatform('bedrock'))
// Model mapping type
interface ModelMapping {
@@ -1547,6 +1775,17 @@ interface TempUnschedRuleForm {
const submitting = ref(false)
const editBaseUrl = ref('https://api.anthropic.com')
const editApiKey = ref('')
// Bedrock credentials
const editBedrockAccessKeyId = ref('')
const editBedrockSecretAccessKey = ref('')
const editBedrockSessionToken = ref('')
const editBedrockRegion = ref('')
const editBedrockForceGlobal = ref(false)
// Bedrock API Key credentials
const editBedrockApiKeyValue = ref('')
const editBedrockApiKeyRegion = ref('')
const editBedrockApiKeyForceGlobal = ref(false)
const modelMappings = ref<ModelMapping[]>([])
const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
const allowedModels = ref<string[]>([])
@@ -1889,6 +2128,58 @@ watch(
} else {
selectedErrorCodes.value = []
}
} else if (newAccount.type === 'bedrock' && newAccount.credentials) {
const bedrockCreds = newAccount.credentials as Record<string, unknown>
editBedrockAccessKeyId.value = (bedrockCreds.aws_access_key_id as string) || ''
editBedrockRegion.value = (bedrockCreds.aws_region as string) || ''
editBedrockForceGlobal.value = (bedrockCreds.aws_force_global as string) === 'true'
editBedrockSecretAccessKey.value = ''
editBedrockSessionToken.value = ''
// Load model mappings for bedrock
const existingMappings = bedrockCreds.model_mapping as Record<string, string> | undefined
if (existingMappings && typeof existingMappings === 'object') {
const entries = Object.entries(existingMappings)
const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to)
if (isWhitelistMode) {
modelRestrictionMode.value = 'whitelist'
allowedModels.value = entries.map(([from]) => from)
modelMappings.value = []
} else {
modelRestrictionMode.value = 'mapping'
modelMappings.value = entries.map(([from, to]) => ({ from, to }))
allowedModels.value = []
}
} else {
modelRestrictionMode.value = 'whitelist'
modelMappings.value = []
allowedModels.value = []
}
} else if (newAccount.type === 'bedrock-apikey' && newAccount.credentials) {
const bedrockApiKeyCreds = newAccount.credentials as Record<string, unknown>
editBedrockApiKeyRegion.value = (bedrockApiKeyCreds.aws_region as string) || 'us-east-1'
editBedrockApiKeyForceGlobal.value = (bedrockApiKeyCreds.aws_force_global as string) === 'true'
editBedrockApiKeyValue.value = ''
// Load model mappings for bedrock-apikey
const existingMappings = bedrockApiKeyCreds.model_mapping as Record<string, string> | undefined
if (existingMappings && typeof existingMappings === 'object') {
const entries = Object.entries(existingMappings)
const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to)
if (isWhitelistMode) {
modelRestrictionMode.value = 'whitelist'
allowedModels.value = entries.map(([from]) => from)
modelMappings.value = []
} else {
modelRestrictionMode.value = 'mapping'
modelMappings.value = entries.map(([from, to]) => ({ from, to }))
allowedModels.value = []
}
} else {
modelRestrictionMode.value = 'whitelist'
modelMappings.value = []
allowedModels.value = []
}
} else if (newAccount.type === 'upstream' && newAccount.credentials) {
const credentials = newAccount.credentials as Record<string, unknown>
editBaseUrl.value = (credentials.base_url as string) || ''
@@ -2431,6 +2722,70 @@ const handleSubmit = async () => {
return
}
updatePayload.credentials = newCredentials
} else if (props.account.type === 'bedrock') {
const currentCredentials = (props.account.credentials as Record<string, unknown>) || {}
const newCredentials: Record<string, unknown> = { ...currentCredentials }
newCredentials.aws_access_key_id = editBedrockAccessKeyId.value.trim()
newCredentials.aws_region = editBedrockRegion.value.trim()
if (editBedrockForceGlobal.value) {
newCredentials.aws_force_global = 'true'
} else {
delete newCredentials.aws_force_global
}
// Only update secrets if user provided new values
if (editBedrockSecretAccessKey.value.trim()) {
newCredentials.aws_secret_access_key = editBedrockSecretAccessKey.value.trim()
}
if (editBedrockSessionToken.value.trim()) {
newCredentials.aws_session_token = editBedrockSessionToken.value.trim()
}
// Model mapping
const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value)
if (modelMapping) {
newCredentials.model_mapping = modelMapping
} else {
delete newCredentials.model_mapping
}
applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit')
if (!applyTempUnschedConfig(newCredentials)) {
return
}
updatePayload.credentials = newCredentials
} else if (props.account.type === 'bedrock-apikey') {
const currentCredentials = (props.account.credentials as Record<string, unknown>) || {}
const newCredentials: Record<string, unknown> = { ...currentCredentials }
newCredentials.aws_region = editBedrockApiKeyRegion.value.trim() || 'us-east-1'
if (editBedrockApiKeyForceGlobal.value) {
newCredentials.aws_force_global = 'true'
} else {
delete newCredentials.aws_force_global
}
// Only update API key if user provided new value
if (editBedrockApiKeyValue.value.trim()) {
newCredentials.api_key = editBedrockApiKeyValue.value.trim()
}
// Model mapping
const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value)
if (modelMapping) {
newCredentials.model_mapping = modelMapping
} else {
delete newCredentials.model_mapping
}
applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit')
if (!applyTempUnschedConfig(newCredentials)) {
return
}
updatePayload.credentials = newCredentials
} else {
// For oauth/setup-token types, only update intercept_warmup_requests if changed

View File

@@ -2,38 +2,72 @@
<div class="card p-4">
<div class="mb-4 flex items-center justify-between gap-3">
<h3 class="text-sm font-semibold text-gray-900 dark:text-white">
{{ t('admin.dashboard.modelDistribution') }}
{{ !enableRankingView || activeView === 'model_distribution'
? t('admin.dashboard.modelDistribution')
: t('admin.dashboard.spendingRankingTitle') }}
</h3>
<div
v-if="showMetricToggle"
class="inline-flex rounded-lg border border-gray-200 bg-gray-50 p-0.5 dark:border-gray-700 dark:bg-dark-800"
>
<button
type="button"
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
:class="metric === 'tokens'
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'"
@click="emit('update:metric', 'tokens')"
<div class="flex items-center gap-2">
<div
v-if="showMetricToggle"
class="inline-flex rounded-lg border border-gray-200 bg-gray-50 p-0.5 dark:border-gray-700 dark:bg-dark-800"
>
{{ t('admin.dashboard.metricTokens') }}
</button>
<button
type="button"
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
:class="metric === 'actual_cost'
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'"
@click="emit('update:metric', 'actual_cost')"
>
{{ t('admin.dashboard.metricActualCost') }}
</button>
<button
type="button"
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
:class="metric === 'tokens'
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'"
@click="emit('update:metric', 'tokens')"
>
{{ t('admin.dashboard.metricTokens') }}
</button>
<button
type="button"
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
:class="metric === 'actual_cost'
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'"
@click="emit('update:metric', 'actual_cost')"
>
{{ t('admin.dashboard.metricActualCost') }}
</button>
</div>
<div v-if="enableRankingView" class="inline-flex rounded-lg bg-gray-100 p-1 dark:bg-dark-800">
<button
type="button"
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
:class="
activeView === 'model_distribution'
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'
"
@click="activeView = 'model_distribution'"
>
{{ t('admin.dashboard.viewModelDistribution') }}
</button>
<button
type="button"
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
:class="
activeView === 'spending_ranking'
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'
"
@click="activeView = 'spending_ranking'"
>
{{ t('admin.dashboard.viewSpendingRanking') }}
</button>
</div>
</div>
</div>
<div v-if="loading" class="flex h-48 items-center justify-center">
<div v-if="activeView === 'model_distribution' && loading" class="flex h-48 items-center justify-center">
<LoadingSpinner />
</div>
<div v-else-if="displayModelStats.length > 0 && chartData" class="flex items-center gap-6">
<div
v-else-if="activeView === 'model_distribution' && displayModelStats.length > 0 && chartData"
class="flex items-center gap-6"
>
<div class="h-48 w-48">
<Doughnut :data="chartData" :options="doughnutOptions" />
</div>
@@ -77,6 +111,70 @@
</table>
</div>
</div>
<div
v-else-if="activeView === 'model_distribution'"
class="flex h-48 items-center justify-center text-sm text-gray-500 dark:text-gray-400"
>
{{ t('admin.dashboard.noDataAvailable') }}
</div>
<div v-else-if="rankingLoading" class="flex h-48 items-center justify-center">
<LoadingSpinner />
</div>
<div
v-else-if="rankingError"
class="flex h-48 items-center justify-center text-sm text-gray-500 dark:text-gray-400"
>
{{ t('admin.dashboard.failedToLoad') }}
</div>
<div v-else-if="rankingItems.length > 0 && rankingChartData" class="flex items-center gap-6">
<div class="h-48 w-48">
<Doughnut :data="rankingChartData" :options="rankingDoughnutOptions" />
</div>
<div class="max-h-48 flex-1 overflow-y-auto">
<table class="w-full text-xs">
<thead>
<tr class="text-gray-500 dark:text-gray-400">
<th class="pb-2 text-left">{{ t('admin.dashboard.spendingRankingUser') }}</th>
<th class="pb-2 text-right">{{ t('admin.dashboard.spendingRankingRequests') }}</th>
<th class="pb-2 text-right">{{ t('admin.dashboard.spendingRankingTokens') }}</th>
<th class="pb-2 text-right">{{ t('admin.dashboard.spendingRankingSpend') }}</th>
</tr>
</thead>
<tbody>
<tr
v-for="(item, index) in rankingItems"
:key="`${item.user_id}-${index}`"
class="cursor-pointer border-t border-gray-100 transition-colors hover:bg-gray-50 dark:border-gray-700 dark:hover:bg-dark-700/40"
@click="emit('ranking-click', item)"
>
<td class="py-1.5">
<div class="flex min-w-0 items-center gap-2">
<span class="shrink-0 text-[11px] font-semibold text-gray-500 dark:text-gray-400">
#{{ index + 1 }}
</span>
<span
class="block max-w-[140px] truncate font-medium text-gray-900 dark:text-white"
:title="getRankingUserLabel(item)"
>
{{ getRankingUserLabel(item) }}
</span>
</div>
</td>
<td class="py-1.5 text-right text-gray-600 dark:text-gray-400">
{{ formatNumber(item.requests) }}
</td>
<td class="py-1.5 text-right text-gray-600 dark:text-gray-400">
{{ formatTokens(item.tokens) }}
</td>
<td class="py-1.5 text-right text-green-600 dark:text-green-400">
${{ formatCost(item.actual_cost) }}
</td>
</tr>
</tbody>
</table>
</div>
</div>
<div
v-else
class="flex h-48 items-center justify-center text-sm text-gray-500 dark:text-gray-400"
@@ -87,34 +185,47 @@
</template>
<script setup lang="ts">
import { computed } from 'vue'
import { computed, ref } from 'vue'
import { useI18n } from 'vue-i18n'
import { Chart as ChartJS, ArcElement, Tooltip, Legend } from 'chart.js'
import { Doughnut } from 'vue-chartjs'
import LoadingSpinner from '@/components/common/LoadingSpinner.vue'
import type { ModelStat } from '@/types'
import type { ModelStat, UserSpendingRankingItem } from '@/types'
ChartJS.register(ArcElement, Tooltip, Legend)
const { t } = useI18n()
type DistributionMetric = 'tokens' | 'actual_cost'
const props = withDefaults(defineProps<{
modelStats: ModelStat[]
enableRankingView?: boolean
rankingItems?: UserSpendingRankingItem[]
rankingTotalActualCost?: number
loading?: boolean
metric?: DistributionMetric
showMetricToggle?: boolean
rankingLoading?: boolean
rankingError?: boolean
}>(), {
enableRankingView: false,
rankingItems: () => [],
rankingTotalActualCost: 0,
loading: false,
metric: 'tokens',
showMetricToggle: false,
rankingLoading: false,
rankingError: false
})
const emit = defineEmits<{
'update:metric': [value: DistributionMetric]
'ranking-click': [item: UserSpendingRankingItem]
}>()
const enableRankingView = computed(() => props.enableRankingView)
const activeView = ref<'model_distribution' | 'spending_ranking'>('model_distribution')
const chartColors = [
'#3b82f6',
'#10b981',
@@ -125,7 +236,9 @@ const chartColors = [
'#14b8a6',
'#f97316',
'#6366f1',
'#84cc16'
'#84cc16',
'#06b6d4',
'#a855f7'
]
const displayModelStats = computed(() => {
@@ -150,6 +263,31 @@ const chartData = computed(() => {
}
})
const rankingChartData = computed(() => {
if (!props.rankingItems?.length) return null
const rankedTotal = props.rankingItems.reduce((sum, item) => sum + item.actual_cost, 0)
const otherActualCost = Math.max((props.rankingTotalActualCost || 0) - rankedTotal, 0)
const labels = props.rankingItems.map((item, index) => `#${index + 1} ${getRankingUserLabel(item)}`)
const data = props.rankingItems.map((item) => item.actual_cost)
if (otherActualCost > 0.000001) {
labels.push(t('admin.dashboard.spendingRankingOther'))
data.push(otherActualCost)
}
return {
labels,
datasets: [
{
data,
backgroundColor: chartColors.slice(0, data.length),
borderWidth: 0
}
]
}
})
const doughnutOptions = computed(() => ({
responsive: true,
maintainAspectRatio: false,
@@ -173,6 +311,26 @@ const doughnutOptions = computed(() => ({
}
}))
const rankingDoughnutOptions = computed(() => ({
responsive: true,
maintainAspectRatio: false,
plugins: {
legend: {
display: false
},
tooltip: {
callbacks: {
label: (context: any) => {
const value = context.raw as number
const total = context.dataset.data.reduce((a: number, b: number) => a + b, 0)
const percentage = total > 0 ? ((value / total) * 100).toFixed(1) : '0.0'
return `${context.label}: $${formatCost(value)} (${percentage}%)`
}
}
}
}
}))
const formatTokens = (value: number): string => {
if (value >= 1_000_000_000) {
return `${(value / 1_000_000_000).toFixed(2)}B`
@@ -188,6 +346,11 @@ const formatNumber = (value: number): string => {
return value.toLocaleString()
}
const getRankingUserLabel = (item: UserSpendingRankingItem): string => {
if (item.email) return item.email
return t('admin.redeem.userPrefix', { id: item.user_id })
}
const formatCost = (value: number): string => {
if (value >= 1000) {
return (value / 1000).toFixed(2) + 'K'

View File

@@ -82,6 +82,8 @@ const typeLabel = computed(() => {
return 'Token'
case 'apikey':
return 'Key'
case 'bedrock':
return 'Bedrock'
default:
return props.type
}

View File

@@ -331,6 +331,15 @@ const antigravityPresetMappings = [
{ label: 'Opus 4.6-thinking', from: 'claude-opus-4-6-thinking', to: 'claude-opus-4-6-thinking', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' }
]
// Bedrock 预设映射(与后端 DefaultBedrockModelMapping 保持一致)
const bedrockPresetMappings = [
{ label: 'Opus 4.6', from: 'claude-opus-4-6', to: 'us.anthropic.claude-opus-4-6-v1', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
{ label: 'Sonnet 4.6', from: 'claude-sonnet-4-6', to: 'us.anthropic.claude-sonnet-4-6', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
{ label: 'Opus 4.5', from: 'claude-opus-4-5-thinking', to: 'us.anthropic.claude-opus-4-5-20251101-v1:0', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
{ label: 'Sonnet 4.5', from: 'claude-sonnet-4-5', to: 'us.anthropic.claude-sonnet-4-5-20250929-v1:0', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
{ label: 'Haiku 4.5', from: 'claude-haiku-4-5', to: 'us.anthropic.claude-haiku-4-5-20251001-v1:0', color: 'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 dark:text-green-400' },
]
// Antigravity 默认映射(从后端 API 获取,与 constants.go 保持一致)
// 使用 fetchAntigravityDefaultMappings() 异步获取
import { getAntigravityDefaultModelMapping } from '@/api/admin/accounts'
@@ -403,6 +412,7 @@ export function getPresetMappingsByPlatform(platform: string) {
if (platform === 'gemini') return geminiPresetMappings
if (platform === 'sora') return soraPresetMappings
if (platform === 'antigravity') return antigravityPresetMappings
if (platform === 'bedrock' || platform === 'bedrock-apikey') return bedrockPresetMappings
return anthropicPresetMappings
}

View File

@@ -963,6 +963,18 @@ export default {
standard: 'Standard',
noDataAvailable: 'No data available',
recentUsage: 'Recent Usage',
viewModelDistribution: 'Model Distribution',
viewSpendingRanking: 'User Spending Ranking',
spendingRankingTitle: 'User Spending Ranking',
spendingRankingUser: 'User',
spendingRankingRequests: 'Requests',
spendingRankingTokens: 'Tokens',
spendingRankingSpend: 'Spend',
spendingRankingOther: 'Others',
spendingRankingUsage: 'Usage',
spendShort: 'Spend',
requestsShort: 'Req',
tokensShort: 'Tok',
failedToLoad: 'Failed to load dashboard statistics'
},
@@ -1921,6 +1933,8 @@ export default {
accountType: 'Account Type',
claudeCode: 'Claude Code',
claudeConsole: 'Claude Console',
bedrockLabel: 'AWS Bedrock',
bedrockDesc: 'SigV4 Signing',
oauthSetupToken: 'OAuth / Setup Token',
addMethod: 'Add Method',
setupTokenLongLived: 'Setup Token (Long-lived)',
@@ -2110,6 +2124,23 @@ export default {
mixedChannelWarning: 'Warning: Group "{groupName}" contains both {currentPlatform} and {otherPlatform} accounts. Mixing different channels may cause thinking block signature validation issues, which will fallback to non-thinking mode. Are you sure you want to continue?',
pleaseEnterAccountName: 'Please enter account name',
pleaseEnterApiKey: 'Please enter API Key',
bedrockAccessKeyId: 'AWS Access Key ID',
bedrockSecretAccessKey: 'AWS Secret Access Key',
bedrockSessionToken: 'AWS Session Token',
bedrockRegion: 'AWS Region',
bedrockRegionHint: 'e.g. us-east-1, us-west-2, eu-west-1',
bedrockForceGlobal: 'Force Global cross-region inference',
bedrockForceGlobalHint: 'When enabled, model IDs use the global. prefix (e.g. global.anthropic.claude-...), routing requests to any supported region worldwide for higher availability',
bedrockAccessKeyIdRequired: 'Please enter AWS Access Key ID',
bedrockSecretAccessKeyRequired: 'Please enter AWS Secret Access Key',
bedrockRegionRequired: 'Please select AWS Region',
bedrockSessionTokenHint: 'Optional, for temporary credentials',
bedrockSecretKeyLeaveEmpty: 'Leave empty to keep current key',
bedrockApiKeyLabel: 'Bedrock API Key',
bedrockApiKeyDesc: 'Bearer Token',
bedrockApiKeyInput: 'API Key',
bedrockApiKeyRequired: 'Please enter Bedrock API Key',
bedrockApiKeyLeaveEmpty: 'Leave empty to keep current key',
apiKeyIsRequired: 'API Key is required',
leaveEmptyToKeep: 'Leave empty to keep current key',
// Upstream type

View File

@@ -974,6 +974,18 @@ export default {
tokens: 'Token',
cache: '缓存',
recentUsage: '最近使用',
viewModelDistribution: '模型分布',
viewSpendingRanking: '用户消费榜',
spendingRankingTitle: '用户消费榜',
spendingRankingUser: '用户',
spendingRankingRequests: '请求',
spendingRankingTokens: 'Token',
spendingRankingSpend: '消费',
spendingRankingOther: '其他',
spendingRankingUsage: '用量',
spendShort: '消费',
requestsShort: '请求',
tokensShort: 'Token',
last7Days: '近 7 天',
noUsageRecords: '暂无使用记录',
startUsingApi: '开始使用 API 后,使用历史将显示在这里。',
@@ -2069,6 +2081,8 @@ export default {
accountType: '账号类型',
claudeCode: 'Claude Code',
claudeConsole: 'Claude Console',
bedrockLabel: 'AWS Bedrock',
bedrockDesc: 'SigV4 签名',
oauthSetupToken: 'OAuth / Setup Token',
addMethod: '添加方式',
setupTokenLongLived: 'Setup Token长期有效',
@@ -2251,6 +2265,23 @@ export default {
mixedChannelWarning: '警告:分组 "{groupName}" 中同时包含 {currentPlatform} 和 {otherPlatform} 账号。混合使用不同渠道可能导致 thinking block 签名验证问题,会自动回退到非 thinking 模式。确定要继续吗?',
pleaseEnterAccountName: '请输入账号名称',
pleaseEnterApiKey: '请输入 API Key',
bedrockAccessKeyId: 'AWS Access Key ID',
bedrockSecretAccessKey: 'AWS Secret Access Key',
bedrockSessionToken: 'AWS Session Token',
bedrockRegion: 'AWS Region',
bedrockRegionHint: '例如 us-east-1, us-west-2, eu-west-1',
bedrockForceGlobal: '强制使用 Global 跨区域推理',
bedrockForceGlobalHint: '启用后模型 ID 使用 global. 前缀(如 global.anthropic.claude-...),请求可路由到全球任意支持的区域,获得更高可用性',
bedrockAccessKeyIdRequired: '请输入 AWS Access Key ID',
bedrockSecretAccessKeyRequired: '请输入 AWS Secret Access Key',
bedrockRegionRequired: '请选择 AWS Region',
bedrockSessionTokenHint: '可选,用于临时凭证',
bedrockSecretKeyLeaveEmpty: '留空以保持当前密钥',
bedrockApiKeyLabel: 'Bedrock API Key',
bedrockApiKeyDesc: 'Bearer Token 认证',
bedrockApiKeyInput: 'API Key',
bedrockApiKeyRequired: '请输入 Bedrock API Key',
bedrockApiKeyLeaveEmpty: '留空以保持当前密钥',
apiKeyIsRequired: 'API Key 是必需的',
leaveEmptyToKeep: '留空以保持当前密钥',
// Upstream type

View File

@@ -531,7 +531,7 @@ export interface UpdateGroupRequest {
// ==================== Account & Proxy Types ====================
export type AccountPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity' | 'sora'
export type AccountType = 'oauth' | 'setup-token' | 'apikey' | 'upstream'
export type AccountType = 'oauth' | 'setup-token' | 'apikey' | 'upstream' | 'bedrock' | 'bedrock-apikey'
export type OAuthAddMethod = 'oauth' | 'setup-token'
export type ProxyProtocol = 'http' | 'https' | 'socks5' | 'socks5h'
@@ -1155,12 +1155,28 @@ export interface UserUsageTrendPoint {
date: string
user_id: number
email: string
username: string
requests: number
tokens: number
cost: number // 标准计费
actual_cost: number // 实际扣除
}
export interface UserSpendingRankingItem {
user_id: number
email: string
actual_cost: number
requests: number
tokens: number
}
export interface UserSpendingRankingResponse {
ranking: UserSpendingRankingItem[]
total_actual_cost: number
start_date: string
end_date: string
}
export interface ApiKeyUsageTrendPoint {
date: string
api_key_id: number

View File

@@ -236,7 +236,16 @@
<!-- Charts Grid -->
<div class="grid grid-cols-1 gap-6 lg:grid-cols-2">
<ModelDistributionChart :model-stats="modelStats" :loading="chartsLoading" />
<ModelDistributionChart
:model-stats="modelStats"
:enable-ranking-view="true"
:ranking-items="rankingItems"
:ranking-total-actual-cost="rankingTotalActualCost"
:loading="chartsLoading"
:ranking-loading="rankingLoading"
:ranking-error="rankingError"
@ranking-click="goToUserUsage"
/>
<TokenUsageTrend :trend-data="trendData" :loading="chartsLoading" />
</div>
@@ -267,11 +276,18 @@
<script setup lang="ts">
import { ref, computed, onMounted } from 'vue'
import { useI18n } from 'vue-i18n'
import { useRouter } from 'vue-router'
import { useAppStore } from '@/stores/app'
const { t } = useI18n()
import { adminAPI } from '@/api/admin'
import type { DashboardStats, TrendDataPoint, ModelStat, UserUsageTrendPoint } from '@/types'
import type {
DashboardStats,
TrendDataPoint,
ModelStat,
UserUsageTrendPoint,
UserSpendingRankingItem
} from '@/types'
import AppLayout from '@/components/layout/AppLayout.vue'
import LoadingSpinner from '@/components/common/LoadingSpinner.vue'
import Icon from '@/components/icons/Icon.vue'
@@ -286,7 +302,6 @@ import {
LinearScale,
PointElement,
LineElement,
Title,
Tooltip,
Legend,
Filler
@@ -299,39 +314,42 @@ ChartJS.register(
LinearScale,
PointElement,
LineElement,
Title,
Tooltip,
Legend,
Filler
)
const appStore = useAppStore()
const router = useRouter()
const stats = ref<DashboardStats | null>(null)
const loading = ref(false)
const chartsLoading = ref(false)
const userTrendLoading = ref(false)
const rankingLoading = ref(false)
const rankingError = ref(false)
// Chart data
const trendData = ref<TrendDataPoint[]>([])
const modelStats = ref<ModelStat[]>([])
const userTrend = ref<UserUsageTrendPoint[]>([])
const rankingItems = ref<UserSpendingRankingItem[]>([])
const rankingTotalActualCost = ref(0)
let chartLoadSeq = 0
let usersTrendLoadSeq = 0
let rankingLoadSeq = 0
const rankingLimit = 12
// Helper function to format date in local timezone
const formatLocalDate = (date: Date): string => {
return `${date.getFullYear()}-${String(date.getMonth() + 1).padStart(2, '0')}-${String(date.getDate()).padStart(2, '0')}`
}
// Initialize date range immediately
const now = new Date()
const weekAgo = new Date(now)
weekAgo.setDate(weekAgo.getDate() - 6)
const getTodayLocalDate = () => formatLocalDate(new Date())
// Date range
const granularity = ref<'day' | 'hour'>('day')
const startDate = ref(formatLocalDate(weekAgo))
const endDate = ref(formatLocalDate(now))
const startDate = ref(getTodayLocalDate())
const endDate = ref(getTodayLocalDate())
// Granularity options for Select component
const granularityOptions = computed(() => [
@@ -415,23 +433,29 @@ const lineOptions = computed(() => ({
const userTrendChartData = computed(() => {
if (!userTrend.value?.length) return null
// Extract display name from email (part before @)
const getDisplayName = (email: string, userId: number): string => {
if (email && email.includes('@')) {
return email.split('@')[0]
const getDisplayName = (point: UserUsageTrendPoint): string => {
const username = point.username?.trim()
if (username) {
return username
}
return t('admin.redeem.userPrefix', { id: userId })
const email = point.email?.trim()
if (email) {
return email
}
return t('admin.redeem.userPrefix', { id: point.user_id })
}
// Group by user
const userGroups = new Map<string, { name: string; data: Map<string, number> }>()
// Group by user_id to avoid merging different users with the same display name
const userGroups = new Map<number, { name: string; data: Map<string, number> }>()
const allDates = new Set<string>()
userTrend.value.forEach((point) => {
allDates.add(point.date)
const key = getDisplayName(point.email, point.user_id)
const key = point.user_id
if (!userGroups.has(key)) {
userGroups.set(key, { name: key, data: new Map() })
userGroups.set(key, { name: getDisplayName(point), data: new Map() })
}
userGroups.get(key)!.data.set(point.date, point.tokens)
})
@@ -502,6 +526,17 @@ const formatDuration = (ms: number): string => {
return `${Math.round(ms)}ms`
}
const goToUserUsage = (item: UserSpendingRankingItem) => {
void router.push({
path: '/admin/usage',
query: {
user_id: String(item.user_id),
start_date: startDate.value,
end_date: endDate.value
}
})
}
// Date range change handler
const onDateRangeChange = (range: {
startDate: string
@@ -582,14 +617,46 @@ const loadUsersTrend = async () => {
}
}
const loadUserSpendingRanking = async () => {
const currentSeq = ++rankingLoadSeq
rankingLoading.value = true
rankingError.value = false
try {
const response = await adminAPI.dashboard.getUserSpendingRanking({
start_date: startDate.value,
end_date: endDate.value,
limit: rankingLimit
})
if (currentSeq !== rankingLoadSeq) return
rankingItems.value = response.ranking || []
rankingTotalActualCost.value = response.total_actual_cost || 0
} catch (error) {
if (currentSeq !== rankingLoadSeq) return
console.error('Error loading user spending ranking:', error)
rankingItems.value = []
rankingTotalActualCost.value = 0
rankingError.value = true
} finally {
if (currentSeq === rankingLoadSeq) {
rankingLoading.value = false
}
}
}
const loadDashboardStats = async () => {
await loadDashboardSnapshot(true)
void loadUsersTrend()
await Promise.all([
loadDashboardSnapshot(true),
loadUsersTrend(),
loadUserSpendingRanking()
])
}
const loadChartData = async () => {
await loadDashboardSnapshot(false)
void loadUsersTrend()
await Promise.all([
loadDashboardSnapshot(false),
loadUsersTrend(),
loadUserSpendingRanking()
])
}
onMounted(() => {

View File

@@ -89,6 +89,7 @@
import { ref, reactive, computed, onMounted, onUnmounted } from 'vue'
import { useI18n } from 'vue-i18n'
import { saveAs } from 'file-saver'
import { useRoute } from 'vue-router'
import { useAppStore } from '@/stores/app'; import { adminAPI } from '@/api/admin'; import { adminUsageAPI } from '@/api/admin/usage'
import { formatReasoningEffort } from '@/utils/format'
import { resolveUsageRequestType, requestTypeToLegacyStream } from '@/utils/usageRequestType'
@@ -104,7 +105,7 @@ import type { AdminUsageLog, TrendDataPoint, ModelStat, GroupStat, AdminUser } f
const { t } = useI18n()
const appStore = useAppStore()
type DistributionMetric = 'tokens' | 'actual_cost'
const route = useRoute()
const usageStats = ref<AdminUsageStatsResponse | null>(null); const usageLogs = ref<AdminUsageLog[]>([]); const loading = ref(false); const exporting = ref(false)
const trendData = ref<TrendDataPoint[]>([]); const modelStats = ref<ModelStat[]>([]); const groupStats = ref<GroupStat[]>([]); const chartsLoading = ref(false); const granularity = ref<'day' | 'hour'>('day')
const modelDistributionMetric = ref<DistributionMetric>('tokens')
@@ -135,11 +136,43 @@ const formatLD = (d: Date) => {
const day = String(d.getDate()).padStart(2, '0')
return `${year}-${month}-${day}`
}
const now = new Date(); const weekAgo = new Date(); weekAgo.setDate(weekAgo.getDate() - 6)
const startDate = ref(formatLD(weekAgo)); const endDate = ref(formatLD(now))
const getTodayLocalDate = () => formatLD(new Date())
const startDate = ref(getTodayLocalDate()); const endDate = ref(getTodayLocalDate())
const filters = ref<AdminUsageQueryParams>({ user_id: undefined, model: undefined, group_id: undefined, request_type: undefined, billing_type: null, start_date: startDate.value, end_date: endDate.value })
const pagination = reactive({ page: 1, page_size: 20, total: 0 })
const getSingleQueryValue = (value: string | null | Array<string | null> | undefined): string | undefined => {
if (Array.isArray(value)) return value.find((item): item is string => typeof item === 'string' && item.length > 0)
return typeof value === 'string' && value.length > 0 ? value : undefined
}
const getNumericQueryValue = (value: string | null | Array<string | null> | undefined): number | undefined => {
const raw = getSingleQueryValue(value)
if (!raw) return undefined
const parsed = Number(raw)
return Number.isFinite(parsed) ? parsed : undefined
}
const applyRouteQueryFilters = () => {
const queryStartDate = getSingleQueryValue(route.query.start_date)
const queryEndDate = getSingleQueryValue(route.query.end_date)
const queryUserId = getNumericQueryValue(route.query.user_id)
if (queryStartDate) {
startDate.value = queryStartDate
}
if (queryEndDate) {
endDate.value = queryEndDate
}
filters.value = {
...filters.value,
user_id: queryUserId,
start_date: startDate.value,
end_date: endDate.value
}
}
const loadLogs = async () => {
abortController?.abort(); const c = new AbortController(); abortController = c; loading.value = true
try {
@@ -191,7 +224,7 @@ const loadChartData = async () => {
}
const applyFilters = () => { pagination.page = 1; loadLogs(); loadStats(); loadChartData() }
const refreshData = () => { loadLogs(); loadStats(); loadChartData() }
const resetFilters = () => { startDate.value = formatLD(weekAgo); endDate.value = formatLD(now); filters.value = { start_date: startDate.value, end_date: endDate.value, request_type: undefined, billing_type: null }; granularity.value = 'day'; applyFilters() }
const resetFilters = () => { startDate.value = getTodayLocalDate(); endDate.value = getTodayLocalDate(); filters.value = { start_date: startDate.value, end_date: endDate.value, request_type: undefined, billing_type: null }; granularity.value = 'day'; applyFilters() }
const handlePageChange = (p: number) => { pagination.page = p; loadLogs() }
const handlePageSizeChange = (s: number) => { pagination.page_size = s; pagination.page = 1; loadLogs() }
const cancelExport = () => exportAbortController?.abort()
@@ -329,6 +362,7 @@ const handleColumnClickOutside = (event: MouseEvent) => {
}
onMounted(() => {
applyRouteQueryFilters()
loadLogs()
loadStats()
window.setTimeout(() => {