mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-03 06:52:13 +08:00
feat: decouple billing correctness from usage log batching
This commit is contained in:
307
PR_REPORT_20260311_db_write_hotspots.md
Normal file
@@ -0,0 +1,307 @@
# PR Report: DB Write Hotspots and Admin Query Congestion Investigation

## Background

Several distinct symptoms showed up in production during peak hours:

- The admin dashboard endpoint frequently timed out; `/api/v1/admin/dashboard/snapshot-v2` at one point took over 50s
- The admin top-up endpoint `/api/v1/admin/users/:id/balance` timed out at 15s or more
- Session refresh, billing deduction, and error logging hit large numbers of `context deadline exceeded` errors at peak
- PostgreSQL at one point exhausted its connections; after the connection pool change was rolled back, the main problem shifted to WAL/flush congestion

This report is based on the current source under `/home/ius/sub2api`; the goal is a fix plan that can be split directly into PRs.

## Conclusion

The root cause of this incident was not a single slow SQL statement. The request success path performs too many synchronous DB writes, compounded by several admin queries that still scan `usage_logs` directly; together these amplified PostgreSQL WAL flushing, hot-row updates, and the outbox rebuild chain all at once.

There are 6 core problems at the code level.

### 1. Too many synchronous DB writes on the request success path

`postUsageBilling` at `backend/internal/service/gateway_service.go:6594` may synchronously trigger all of the following writes after a single successful request:

- `userRepo.DeductBalance`
- `APIKeyService.UpdateQuotaUsed`
- `APIKeyService.UpdateRateLimitUsage`
- `accountRepo.IncrementQuotaUsed`
- `deferredService.ScheduleLastUsedUpdate` (already deferred and batched; this is the right direction)

In other words, one successful request is not 1 DB write but 3 to 5.

This matches what was observed in production:

- `UPDATE accounts SET extra = ...`
- `INSERT INTO usage_logs ...`
- `INSERT INTO ops_error_logs ...`
- `scheduler_outbox` backlog

### 2. API key quota updates add extra read/write amplification

`UpdateQuotaUsed` at `backend/internal/service/api_key_service.go:815` currently does:

1. `IncrementQuotaUsed`
2. `GetByID`
3. `Update` again if the quota is exceeded

The corresponding repository implementation:

- `backend/internal/repository/api_key_repo.go:441` only performs the increment
- the service then reads the full API key back from the table
- and may then update the whole row to flip its status

This turns "update the API key quota after every deduction" from 1 SQL statement into up to 3 database round trips.

### 3. `accounts.extra` is being used as a high-frequency hot-write field

The two heaviest hotspots both land on `accounts.extra`:

- `backend/internal/repository/account_repo.go:1159` `UpdateExtra`
- `backend/internal/repository/account_repo.go:1683` `IncrementQuotaUsed`

There are two problems:

1. Both rewrite the entire JSONB blob and touch `updated_at`
2. `UpdateExtra` additionally inserts a `scheduler_outbox` row on every write

`UpdateExtra` in particular is now called frequently from multiple places:

- `backend/internal/service/openai_gateway_service.go:4039` persists the Codex rate-limit snapshot
- `backend/internal/service/ratelimit_service.go:903` persists the OpenAI Codex snapshot
- `backend/internal/service/ratelimit_service.go:1013` / `1025` update session window utilization

These "monitoring/quota snapshot" writes do not change whether an account is schedulable, yet they still go through:

- a JSONB update
- `updated_at`
- `scheduler_outbox`

This is clear write amplification.

### 4. `scheduler_outbox` is designed around "one row per state change", which back-pressures the scheduler at peak

`enqueueSchedulerOutbox` at `backend/internal/repository/scheduler_outbox_repo.go:79` is very cheap per call, but it is called a lot.

For example:

- `UpdateExtra` enqueues an `AccountChanged` event every time
- `BatchUpdateLastUsed` also enqueues an `AccountLastUsed` event
- every account rate-limit, overload, and error state transition enqueues as well

The corresponding outbox worker lives at:

- `backend/internal/service/scheduler_snapshot_service.go:199`
- `backend/internal/service/scheduler_snapshot_service.go:219`

It continuously drains the outbox and then triggers `GetByID`, `rebuildBucket`, and `loadAccountsFromDB`.

So when high-frequency writes grow the outbox, the system not only writes more, it also drags in more reads and cache rebuilds.

### 5. Only part of the dashboard uses pre-aggregation; `models/groups/users-trend` still scan `usage_logs` directly

The good news is that the dashboard `stats` query is already backed by pre-aggregation tables:

- `backend/internal/repository/usage_log_repo.go:306`
- `backend/internal/repository/usage_log_repo.go:420`
- the pre-aggregation tables are defined in `backend/migrations/034_usage_dashboard_aggregation_tables.sql:1`

But `stats` is not the only slow part of the admin UI.

By default `snapshot-v2` fetches all of:

- stats
- trend
- model stats

See:

- `backend/internal/handler/admin/dashboard_snapshot_v2_handler.go:68`

Of these:

- `GetUsageTrendWithFilters` only uses pre-aggregation in the "no filters, day/hour granularity" case, see `usage_log_repo.go:1657`
- `GetModelStatsWithFilters` scans `usage_logs` directly, see `usage_log_repo.go:1805`
- `GetGroupStatsWithFilters` scans `usage_logs` directly, see `usage_log_repo.go:1872`
- `GetUserUsageTrend` scans `usage_logs` directly, see `usage_log_repo.go:1101`
- `GetAPIKeyUsageTrend` scans `usage_logs` directly, see `usage_log_repo.go:1046`

So in production we see:

- stats is fast
- but `snapshot-v2` is still slow
- `/admin/dashboard/users-trend` is slow on its own as well

This matches the logs you saw in production exactly.

### 6. Admin top-up is "read user -> update whole user -> insert audit record"

`UpdateUserBalance` at `backend/internal/service/admin_service.go:694` currently does:

1. `GetByID`
2. modify the balance in memory
3. `userRepo.Update`
4. `redeemCodeRepo.Create` to record the admin adjustment history

And `userRepo.Update` updates the whole user object, including a transaction that syncs allowed groups:

- `backend/internal/repository/user_repo.go:118`

This endpoint is not necessarily heavy in normal operation, but when the database is already struggling it is more fragile than an atomic `UPDATE users SET balance = balance + $1`.
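A minimal sketch of what such an atomic update could look like. The SQL, column names, and operation names are assumptions for illustration, not the repository's actual code; `balanceAfter` mirrors the guard in the SQL so the arithmetic can be tested in isolation:

```go
package main

import (
	"errors"
	"fmt"
)

// Hypothetical atomic balance adjustment: one UPDATE instead of
// read-modify-write, with the non-negative guard pushed into the WHERE
// clause. It never touches the allowed-groups sync transaction.
const addBalanceSQL = `
UPDATE users
SET balance = balance + $1, updated_at = now()
WHERE id = $2 AND balance + $1 >= 0
RETURNING balance`

// balanceAfter mirrors the set/add/subtract semantics and the non-negative
// guard, so the decision logic is testable without a database.
func balanceAfter(op string, balance, amount int64) (int64, error) {
	var next int64
	switch op {
	case "set":
		next = amount
	case "add":
		next = balance + amount
	case "subtract":
		next = balance - amount
	default:
		return 0, fmt.Errorf("unknown op %q", op)
	}
	if next < 0 {
		return 0, errors.New("balance would go negative")
	}
	return next, nil
}

func main() {
	b, _ := balanceAfter("subtract", 100, 30)
	fmt.Println(b) // 70
}
```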

## Additional observations

### `ops_error_logs` is already asynchronous, but each insert is still heavy

The error-logging middleware already smooths bursts through a queue:

- `backend/internal/handler/ops_error_logger.go:69`
- `backend/internal/handler/ops_error_logger.go:106`

That direction is correct.

But the target table itself is heavy:

- `backend/internal/repository/ops_repo.go:23`
- `backend/migrations/033_ops_monitoring_vnext.sql:69`
- `backend/migrations/033_ops_monitoring_vnext.sql:470`

`ops_error_logs` not only has many columns, it also carries several B-Tree and trigram indexes. At a high error rate, even asynchronous inserts still push WAL and I/O up.
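One mitigation, sketched here as an assumption rather than existing code, is to aggregate repeated errors by (status, route) in memory and flush counts on an interval, so a burst of identical 429/5xx responses becomes a handful of counted rows instead of thousands of inserts:

```go
package main

import "fmt"

// errorKey identifies one class of repeated errors. The fields are a sketch,
// not the ops_error_logs schema.
type errorKey struct {
	Status int
	Route  string
}

// ErrorAggregator counts errors in memory; Flush returns the counts and
// resets, which is where a single batched insert would happen.
type ErrorAggregator struct {
	counts map[errorKey]int
}

func NewErrorAggregator() *ErrorAggregator {
	return &ErrorAggregator{counts: make(map[errorKey]int)}
}

func (a *ErrorAggregator) Record(status int, route string) {
	a.counts[errorKey{Status: status, Route: route}]++
}

func (a *ErrorAggregator) Flush() map[errorKey]int {
	out := a.counts
	a.counts = make(map[errorKey]int)
	return out
}

func main() {
	agg := NewErrorAggregator()
	for i := 0; i < 1000; i++ {
		agg.Record(429, "/api/v1/messages")
	}
	agg.Record(502, "/api/v1/messages")
	rows := agg.Flush()
	fmt.Println(len(rows)) // 2 counted rows instead of 1001 inserts
}
```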

## Suggested PR breakdown

Split the work into 4 PRs; do not change the data model, the admin queries, and the admin endpoints in a single PR.

### PR 1: Shrink the number of synchronous writes on the success path

Goal: reduce the synchronous writes per successful request from 3-5 down to 1-2.

Suggested changes:

1. Make `APIKeyService.UpdateQuotaUsed` a single SQL statement
   - add a repo method, e.g. `IncrementQuotaUsedAndMaybeExhaust`
   - perform both `quota_used += ?` and `status = quota_exhausted` in one statement
   - return only the minimal `key/status/quota/quota_used` fields and invalidate the cache directly
   - delete the current `Increment -> GetByID -> Update` flow

2. Move the account quota counters out of `accounts.extra`
   - best: add structured columns or a separate `account_quota_counters` table
   - second best: at minimum pull `quota_used/quota_daily_used/quota_weekly_used` out of the JSONB

3. Stop enqueueing outbox events for monitoring-only `extra` fields
   - e.g. the Codex snapshot and session_window_utilization
   - these fields do not affect scheduling and should not trigger `SchedulerOutboxEventAccountChanged`

4. Reuse the existing `DeferredService` approach
   - `last_used` is already flushed in batches, see `deferred_service.go:41`
   - this can be extended into a deferred quota snapshot flush
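Item 1 above could be sketched as follows. The SQL, the column names, and the `quota_exhausted` status string are assumptions, not the repository's actual schema; `nextStatus` mirrors the `CASE` expression so the decision is unit-testable without a database:

```go
package main

import "fmt"

// Hypothetical single-statement quota update: increment quota_used and flip
// the status to quota_exhausted in the same UPDATE, returning only the
// fields the service needs to invalidate its cache.
const incrementQuotaSQL = `
UPDATE api_keys
SET quota_used = quota_used + $1,
    status = CASE
        WHEN quota > 0 AND quota_used + $1 >= quota THEN 'quota_exhausted'
        ELSE status
    END
WHERE id = $2
RETURNING key, status, quota, quota_used`

// nextStatus mirrors the CASE expression above. quota <= 0 is treated as
// "unlimited" here, which is an assumption about the schema's convention.
func nextStatus(cur string, quota, used, delta int64) string {
	if quota > 0 && used+delta >= quota {
		return "quota_exhausted"
	}
	return cur
}

func main() {
	fmt.Println(nextStatus("active", 100, 90, 10)) // quota_exhausted
	fmt.Println(nextStatus("active", 0, 90, 10))   // active (unlimited)
}
```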

Expected benefits:

- directly reduces WAL volume
- lowers lock contention on hot `accounts` rows
- slows outbox growth

### PR 2: Complete pre-aggregation/caching for the dashboard so it stops scanning `usage_logs`

Goal: admin dashboard endpoints no longer scan the large table directly.

Suggested changes:

1. Add hourly/daily pre-aggregation tables for `users-trend` / `api-keys-trend`
2. Add daily aggregation tables for `model stats` / `group stats`
3. Add per-section caching to `snapshot-v2`
   - `stats`
   - `trend`
   - `models`
   - `groups`
   - `users_trend`
   so that one section miss does not force the whole snapshot to rescan the database
4. Optional: change the default of `include_model_stats` from `true` to `false`
   - at least get the default dashboard usable again, then load the heavy modules on demand

Expected benefits:

- `snapshot-v2`
- `/admin/dashboard/users-trend`
- `/admin/dashboard/api-keys-trend`

These endpoints go from "degrading linearly with data volume" to "near-constant cost".

### PR 3: Simplify the admin top-up path

Goal: admin balance top-up/deduction no longer depends on whole-user-object updates.

Suggested changes:

1. Add atomic repo methods
   - `SetBalance(userID, amount)`
   - `AddBalance(userID, delta)`
   - `SubtractBalance(userID, delta)`

2. Change `UpdateUserBalance` to:
   - run the atomic SQL first
   - then read back only the minimal fields needed for the response
   - write the audit record asynchronously or with graceful degradation

3. Rename the audit record or move it to its own table
   - stuffing admin adjustments into `redeem_codes` is semantically unclean

Expected benefits:

- `/api/v1/admin/users/:id/balance` stays stable while the database is struggling
- the failure surface shrinks; the endpoint is no longer dragged down by the allowed-groups sync transaction

### PR 4: Add drop policies and circuit-breaker metrics to the heavy write paths

Goal: protect the main path at peak; do not let non-critical writes drag the database down.

Suggested changes:

1. `ops_error_logs`
   - add sampling or severity-level switches
   - aggregate repeated 429/5xx into counters instead of inserting row by row
   - put stricter switches on request body / header storage

2. `scheduler_outbox`
   - add a coalesce/merge mechanism
   - merge multiple `AccountChanged` events for the same account within a short window into one

3. Fill in metrics
   - outbox backlog
   - ops error queue dropped
   - deferred flush lag
   - account extra write QPS
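The coalesce/merge idea in item 2 can be sketched as a pure function over pending outbox rows (the field names are assumptions): keeping only the latest event per (account, kind) means N rapid `AccountChanged` rows for one account drain as a single rebuild.

```go
package main

import "fmt"

// outboxEvent is a sketch of a scheduler_outbox row; only the fields needed
// for coalescing are shown.
type outboxEvent struct {
	AccountID int64
	Kind      string // e.g. "AccountChanged", "AccountLastUsed"
}

// coalesce keeps the latest event per (account, kind) while preserving the
// order of first appearance, so the worker rebuilds each account once.
func coalesce(events []outboxEvent) []outboxEvent {
	type key struct {
		id   int64
		kind string
	}
	seen := make(map[key]int)
	var out []outboxEvent
	for _, e := range events {
		k := key{e.AccountID, e.Kind}
		if i, ok := seen[k]; ok {
			out[i] = e // overwrite with the newer event
			continue
		}
		seen[k] = len(out)
		out = append(out, e)
	}
	return out
}

func main() {
	in := []outboxEvent{
		{1, "AccountChanged"}, {1, "AccountChanged"},
		{2, "AccountChanged"}, {1, "AccountChanged"},
	}
	fmt.Println(len(coalesce(in))) // 2
}
```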

## Recommended order of implementation

1. PR 1 first
   - this is the main path of the incident
2. Then PR 2
   - fixes the slow admin dashboard
3. Then PR 3
   - fixes the fragile admin top-up endpoint
4. PR 4 last
   - long-term protection

## Verification plan

Run the same set of checks before merging each PR:

1. Load-test the successful request path and record SQL statements per request
2. Observe PostgreSQL:
   - `pg_stat_activity`
   - `pg_stat_statements`
   - `WALWrite` / `WALSync` wait events
   - WAL growth per minute
3. Observe endpoints:
   - `/api/v1/auth/refresh`
   - `/api/v1/admin/dashboard/snapshot-v2`
   - `/api/v1/admin/dashboard/users-trend`
   - `/api/v1/admin/users/:id/balance`
4. Observe queues:
   - `ops_error_logs` queue length / dropped
   - `scheduler_outbox` backlog

## Summary usable directly as the PR description

This PR reduces database write amplification on the request success path and removes several hot-path writes from `accounts.extra` + `scheduler_outbox`. It also prepares dashboard endpoints to rely on pre-aggregated data instead of scanning `usage_logs` under load. The goal is to keep admin dashboard, balance update, auth refresh, and billing-related paths stable under sustained 500+ RPS traffic.
@@ -81,6 +81,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
 userHandler := handler.NewUserHandler(userService)
 apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
 usageLogRepository := repository.NewUsageLogRepository(client, db)
+usageBillingRepository := repository.NewUsageBillingRepository(client, db)
 usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
 usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
 redeemHandler := handler.NewRedeemHandler(redeemService)
@@ -162,9 +163,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
 deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
 claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
 digestSessionStore := service.NewDigestSessionStore()
-gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
+gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
 openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
-openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
+openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
 geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
 opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
 opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
@@ -934,9 +934,10 @@ type DashboardAggregationConfig struct {
 
 // DashboardAggregationRetentionConfig 预聚合保留窗口
 type DashboardAggregationRetentionConfig struct {
 	UsageLogsDays int `mapstructure:"usage_logs_days"`
+	UsageBillingDedupDays int `mapstructure:"usage_billing_dedup_days"`
 	HourlyDays int `mapstructure:"hourly_days"`
 	DailyDays int `mapstructure:"daily_days"`
 }
 
 // UsageCleanupConfig 使用记录清理任务配置
@@ -1301,6 +1302,7 @@ func setDefaults() {
 viper.SetDefault("dashboard_aggregation.backfill_enabled", false)
 viper.SetDefault("dashboard_aggregation.backfill_max_days", 31)
 viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90)
+viper.SetDefault("dashboard_aggregation.retention.usage_billing_dedup_days", 365)
 viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180)
 viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
 viper.SetDefault("dashboard_aggregation.recompute_days", 2)
@@ -1758,6 +1760,12 @@ func (c *Config) Validate() error {
 if c.DashboardAgg.Retention.UsageLogsDays <= 0 {
 	return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive")
 }
+if c.DashboardAgg.Retention.UsageBillingDedupDays <= 0 {
+	return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be positive")
+}
+if c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
+	return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
+}
 if c.DashboardAgg.Retention.HourlyDays <= 0 {
 	return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive")
 }
@@ -1780,6 +1788,14 @@ func (c *Config) Validate() error {
 if c.DashboardAgg.Retention.UsageLogsDays < 0 {
 	return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative")
 }
+if c.DashboardAgg.Retention.UsageBillingDedupDays < 0 {
+	return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be non-negative")
+}
+if c.DashboardAgg.Retention.UsageBillingDedupDays > 0 &&
+	c.DashboardAgg.Retention.UsageLogsDays > 0 &&
+	c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
+	return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
+}
 if c.DashboardAgg.Retention.HourlyDays < 0 {
 	return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative")
 }
@@ -441,6 +441,9 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
 if cfg.DashboardAgg.Retention.UsageLogsDays != 90 {
 	t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays)
 }
+if cfg.DashboardAgg.Retention.UsageBillingDedupDays != 365 {
+	t.Fatalf("DashboardAgg.Retention.UsageBillingDedupDays = %d, want 365", cfg.DashboardAgg.Retention.UsageBillingDedupDays)
+}
 if cfg.DashboardAgg.Retention.HourlyDays != 180 {
 	t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays)
 }
@@ -1016,6 +1019,23 @@ func TestValidateConfigErrors(t *testing.T) {
 mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 },
 wantErr: "dashboard_aggregation.retention.usage_logs_days",
 },
+{
+	name: "dashboard aggregation dedup retention",
+	mutate: func(c *Config) {
+		c.DashboardAgg.Enabled = true
+		c.DashboardAgg.Retention.UsageBillingDedupDays = 0
+	},
+	wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
+},
+{
+	name: "dashboard aggregation dedup retention smaller than usage logs",
+	mutate: func(c *Config) {
+		c.DashboardAgg.Enabled = true
+		c.DashboardAgg.Retention.UsageLogsDays = 30
+		c.DashboardAgg.Retention.UsageBillingDedupDays = 29
+	},
+	wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
+},
 {
 name: "dashboard aggregation disabled interval",
 mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 },
@@ -434,19 +434,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
 // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
 userAgent := c.GetHeader("User-Agent")
 clientIP := ip.GetClientIP(c)
+requestPayloadHash := service.HashUsageRequestPayload(body)
 
 // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
 h.submitUsageRecordTask(func(ctx context.Context) {
 	if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
 		Result: result,
 		APIKey: apiKey,
 		User: apiKey.User,
 		Account: account,
 		Subscription: subscription,
 		UserAgent: userAgent,
 		IPAddress: clientIP,
+		RequestPayloadHash: requestPayloadHash,
 		ForceCacheBilling: fs.ForceCacheBilling,
 		APIKeyService: h.apiKeyService,
 	}); err != nil {
 		logger.L().With(
 			zap.String("component", "handler.gateway.messages"),
@@ -736,19 +738,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
 // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
 userAgent := c.GetHeader("User-Agent")
 clientIP := ip.GetClientIP(c)
+requestPayloadHash := service.HashUsageRequestPayload(body)
 
 // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
 h.submitUsageRecordTask(func(ctx context.Context) {
 	if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
 		Result: result,
 		APIKey: currentAPIKey,
 		User: currentAPIKey.User,
 		Account: account,
 		Subscription: currentSubscription,
 		UserAgent: userAgent,
 		IPAddress: clientIP,
+		RequestPayloadHash: requestPayloadHash,
 		ForceCacheBilling: fs.ForceCacheBilling,
 		APIKeyService: h.apiKeyService,
 	}); err != nil {
 		logger.L().With(
 			zap.String("component", "handler.gateway.messages"),
@@ -139,6 +139,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
 nil, // accountRepo (not used: scheduler snapshot hit)
 &fakeGroupRepo{group: group},
 nil, // usageLogRepo
+nil, // usageBillingRepo
 nil, // userRepo
 nil, // userSubRepo
 nil, // userGroupRateRepo
@@ -503,6 +503,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
 }
 
 // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
+requestPayloadHash := service.HashUsageRequestPayload(body)
 h.submitUsageRecordTask(func(ctx context.Context) {
 	if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
 		Result: result,
@@ -512,6 +513,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
 Subscription: subscription,
 UserAgent: userAgent,
 IPAddress: clientIP,
+RequestPayloadHash: requestPayloadHash,
 LongContextThreshold: 200000, // Gemini 200K 阈值
 LongContextMultiplier: 2.0, // 超出部分双倍计费
 ForceCacheBilling: fs.ForceCacheBilling,
@@ -352,18 +352,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
 // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
 userAgent := c.GetHeader("User-Agent")
 clientIP := ip.GetClientIP(c)
+requestPayloadHash := service.HashUsageRequestPayload(body)
 
 // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
 h.submitUsageRecordTask(func(ctx context.Context) {
 	if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
 		Result: result,
 		APIKey: apiKey,
 		User: apiKey.User,
 		Account: account,
 		Subscription: subscription,
 		UserAgent: userAgent,
 		IPAddress: clientIP,
+		RequestPayloadHash: requestPayloadHash,
 		APIKeyService: h.apiKeyService,
 	}); err != nil {
 		logger.L().With(
 			zap.String("component", "handler.openai_gateway.responses"),
@@ -732,17 +734,19 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
 
 userAgent := c.GetHeader("User-Agent")
 clientIP := ip.GetClientIP(c)
+requestPayloadHash := service.HashUsageRequestPayload(body)
 
 h.submitUsageRecordTask(func(ctx context.Context) {
 	if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
 		Result: result,
 		APIKey: apiKey,
 		User: apiKey.User,
 		Account: account,
 		Subscription: subscription,
 		UserAgent: userAgent,
 		IPAddress: clientIP,
+		RequestPayloadHash: requestPayloadHash,
 		APIKeyService: h.apiKeyService,
 	}); err != nil {
 		logger.L().With(
 			zap.String("component", "handler.openai_gateway.messages"),
@@ -1231,14 +1235,15 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
 h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
 h.submitUsageRecordTask(func(taskCtx context.Context) {
 	if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
 		Result: result,
 		APIKey: apiKey,
 		User: apiKey.User,
 		Account: account,
 		Subscription: subscription,
 		UserAgent: userAgent,
 		IPAddress: clientIP,
+		RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
 		APIKeyService: h.apiKeyService,
 	}); err != nil {
 		reqLog.Error("openai.websocket_record_usage_failed",
 			zap.Int64("account_id", account.ID),
@@ -2206,7 +2206,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac
 // newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。
 func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
 	return service.NewGatewayService(
-		accountRepo, nil, nil, nil, nil, nil, nil, nil,
+		accountRepo, nil, nil, nil, nil, nil, nil, nil, nil,
 		nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
 	)
 }
@@ -399,17 +399,19 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
 
 	userAgent := c.GetHeader("User-Agent")
 	clientIP := ip.GetClientIP(c)
+	requestPayloadHash := service.HashUsageRequestPayload(body)
 
 	// Usage recording is submitted through a bounded worker pool, avoiding unbounded goroutine creation on the request hot path.
 	h.submitUsageRecordTask(func(ctx context.Context) {
 		if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
 			Result:             result,
 			APIKey:             apiKey,
 			User:               apiKey.User,
 			Account:            account,
 			Subscription:       subscription,
 			UserAgent:          userAgent,
 			IPAddress:          clientIP,
+			RequestPayloadHash: requestPayloadHash,
 		}); err != nil {
 			logger.L().With(
 				zap.String("component", "handler.sora_gateway.chat_completions"),
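Both handlers above now attach a payload hash so that a retry of the same request ID with a different body can be rejected as a billing conflict. `service.HashUsageRequestPayload` itself is not shown in this diff; a plausible sketch (an assumption, not the project's actual implementation) is a hex-encoded SHA-256 of the raw body, which also matches the 64-character `request_fingerprint` column asserted in the migration test below:

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// hashUsageRequestPayload is a hypothetical stand-in for
// service.HashUsageRequestPayload: a stable fingerprint of the raw
// request body, always 64 hex characters, so it fits a varchar(64)
// request_fingerprint column.
func hashUsageRequestPayload(body []byte) string {
	sum := sha256.Sum256(body)
	return hex.EncodeToString(sum[:])
}

func main() {
	fmt.Println(hashUsageRequestPayload([]byte("hello")))
	// 2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824
}
```

Any stable digest of the canonicalized body works here; the only requirements are determinism and a fixed output width.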
@@ -431,6 +431,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
 		nil,
 		nil,
 		nil,
+		nil,
 		testutil.StubGatewayCache{},
 		cfg,
 		nil,
@@ -17,6 +17,9 @@ type dashboardAggregationRepository struct {
 	sql sqlExecutor
 }
 
+const usageLogsCleanupBatchSize = 10000
+const usageBillingDedupCleanupBatchSize = 10000
+
 // NewDashboardAggregationRepository creates the dashboard pre-aggregation repository.
 func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository {
 	if sqlDB == nil {
@@ -42,6 +45,9 @@ func isPostgresDriver(db *sql.DB) bool {
 }
 
 func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
+	if r == nil || r.sql == nil {
+		return nil
+	}
 	loc := timezone.Location()
 	startLocal := start.In(loc)
 	endLocal := end.In(loc)
@@ -61,6 +67,22 @@ func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, sta
 		dayEnd = dayEnd.Add(24 * time.Hour)
 	}
 
+	if db, ok := r.sql.(*sql.DB); ok {
+		tx, err := db.BeginTx(ctx, nil)
+		if err != nil {
+			return err
+		}
+		txRepo := newDashboardAggregationRepositoryWithSQL(tx)
+		if err := txRepo.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd); err != nil {
+			_ = tx.Rollback()
+			return err
+		}
+		return tx.Commit()
+	}
+	return r.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd)
+}
+
+func (r *dashboardAggregationRepository) aggregateRangeInTx(ctx context.Context, hourStart, hourEnd, dayStart, dayEnd time.Time) error {
 	// Aggregate on bucket boundaries, allowing coverage of the remaining interval of the bucket containing end.
 	if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
 		return err
@@ -195,8 +217,58 @@ func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, c
 	if isPartitioned {
 		return r.dropUsageLogsPartitions(ctx, cutoff)
 	}
-	_, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC())
-	return err
+	for {
+		res, err := r.sql.ExecContext(ctx, `
+			WITH victims AS (
+				SELECT ctid
+				FROM usage_logs
+				WHERE created_at < $1
+				LIMIT $2
+			)
+			DELETE FROM usage_logs
+			WHERE ctid IN (SELECT ctid FROM victims)
+		`, cutoff.UTC(), usageLogsCleanupBatchSize)
+		if err != nil {
+			return err
+		}
+		affected, err := res.RowsAffected()
+		if err != nil {
+			return err
+		}
+		if affected < usageLogsCleanupBatchSize {
+			return nil
+		}
+	}
+}
+
+func (r *dashboardAggregationRepository) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
+	for {
+		res, err := r.sql.ExecContext(ctx, `
+			WITH victims AS (
+				SELECT ctid, request_id, api_key_id, request_fingerprint, created_at
+				FROM usage_billing_dedup
+				WHERE created_at < $1
+				LIMIT $2
+			), archived AS (
+				INSERT INTO usage_billing_dedup_archive (request_id, api_key_id, request_fingerprint, created_at)
+				SELECT request_id, api_key_id, request_fingerprint, created_at
+				FROM victims
+				ON CONFLICT (request_id, api_key_id) DO NOTHING
+			)
+			DELETE FROM usage_billing_dedup
+			WHERE ctid IN (SELECT ctid FROM victims)
+		`, cutoff.UTC(), usageBillingDedupCleanupBatchSize)
+		if err != nil {
+			return err
+		}
+		affected, err := res.RowsAffected()
+		if err != nil {
+			return err
+		}
+		if affected < usageBillingDedupCleanupBatchSize {
+			return nil
+		}
+	}
 }
 
 func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
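The batched cleanup above terminates when a batch comes back short: deleting fewer than `LIMIT` rows means no older rows remain. That termination rule can be sketched in isolation (with a hypothetical `deleteBatch` callback standing in for the `DELETE ... LIMIT` round trip; this is not the repository's actual API):

```go
package main

import "fmt"

// cleanupInBatches repeatedly invokes deleteBatch until it deletes fewer
// rows than batchSize. A short batch means the WHERE clause matched fewer
// rows than the limit, so no eligible rows remain and the loop stops.
func cleanupInBatches(batchSize int64, deleteBatch func() int64) (passes int64) {
	for {
		passes++
		if deleteBatch() < batchSize {
			return passes
		}
	}
}

func main() {
	// Simulate 25k expired rows with a batch size of 10k: 10k + 10k + 5k.
	remaining := int64(25000)
	batch := func() int64 {
		n := int64(10000)
		if remaining < n {
			n = remaining
		}
		remaining -= n
		return n
	}
	fmt.Println(cleanupInBatches(10000, batch)) // 3
}
```

Each pass is its own statement (and, in autocommit, its own transaction), so WAL is generated in bounded chunks instead of one long-running delete holding locks and bloating a single transaction.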
@@ -45,6 +45,20 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
 	requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false)
 	requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false)
 
+	// usage_billing_dedup: billing idempotency narrow table
+	var usageBillingDedupRegclass sql.NullString
+	require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup')").Scan(&usageBillingDedupRegclass))
+	require.True(t, usageBillingDedupRegclass.Valid, "expected usage_billing_dedup table to exist")
+	requireColumn(t, tx, "usage_billing_dedup", "request_fingerprint", "character varying", 64, false)
+	requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_request_api_key")
+	requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_created_at_brin")
+
+	var usageBillingDedupArchiveRegclass sql.NullString
+	require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup_archive')").Scan(&usageBillingDedupArchiveRegclass))
+	require.True(t, usageBillingDedupArchiveRegclass.Valid, "expected usage_billing_dedup_archive table to exist")
+	requireColumn(t, tx, "usage_billing_dedup_archive", "request_fingerprint", "character varying", 64, false)
+	requireIndex(t, tx, "usage_billing_dedup_archive", "usage_billing_dedup_archive_pkey")
+
 	// settings table should exist
 	var settingsRegclass sql.NullString
 	require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass))
@@ -75,6 +89,23 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
 	requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
 }
 
+func requireIndex(t *testing.T, tx *sql.Tx, table, index string) {
+	t.Helper()
+
+	var exists bool
+	err := tx.QueryRowContext(context.Background(), `
+		SELECT EXISTS (
+			SELECT 1
+			FROM pg_indexes
+			WHERE schemaname = 'public'
+			  AND tablename = $1
+			  AND indexname = $2
+		)
+	`, table, index).Scan(&exists)
+	require.NoError(t, err, "query pg_indexes for %s.%s", table, index)
+	require.True(t, exists, "expected index %s on %s", index, table)
+}
+
 func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
 	t.Helper()
308
backend/internal/repository/usage_billing_repo.go
Normal file
@@ -0,0 +1,308 @@
+package repository
+
+import (
+	"context"
+	"database/sql"
+	"errors"
+	"strings"
+
+	dbent "github.com/Wei-Shaw/sub2api/ent"
+	"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+	"github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+type usageBillingRepository struct {
+	db *sql.DB
+}
+
+func NewUsageBillingRepository(_ *dbent.Client, sqlDB *sql.DB) service.UsageBillingRepository {
+	return &usageBillingRepository{db: sqlDB}
+}
+
+func (r *usageBillingRepository) Apply(ctx context.Context, cmd *service.UsageBillingCommand) (_ *service.UsageBillingApplyResult, err error) {
+	if cmd == nil {
+		return &service.UsageBillingApplyResult{}, nil
+	}
+	if r == nil || r.db == nil {
+		return nil, errors.New("usage billing repository db is nil")
+	}
+
+	cmd.Normalize()
+	if cmd.RequestID == "" {
+		return nil, service.ErrUsageBillingRequestIDRequired
+	}
+
+	tx, err := r.db.BeginTx(ctx, nil)
+	if err != nil {
+		return nil, err
+	}
+	defer func() {
+		if tx != nil {
+			_ = tx.Rollback()
+		}
+	}()
+
+	applied, err := r.claimUsageBillingKey(ctx, tx, cmd)
+	if err != nil {
+		return nil, err
+	}
+	if !applied {
+		return &service.UsageBillingApplyResult{Applied: false}, nil
+	}
+
+	result := &service.UsageBillingApplyResult{Applied: true}
+	if err := r.applyUsageBillingEffects(ctx, tx, cmd, result); err != nil {
+		return nil, err
+	}
+
+	if err := tx.Commit(); err != nil {
+		return nil, err
+	}
+	tx = nil
+	return result, nil
+}
+
+func (r *usageBillingRepository) claimUsageBillingKey(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand) (bool, error) {
+	var id int64
+	err := tx.QueryRowContext(ctx, `
+		INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint)
+		VALUES ($1, $2, $3)
+		ON CONFLICT (request_id, api_key_id) DO NOTHING
+		RETURNING id
+	`, cmd.RequestID, cmd.APIKeyID, cmd.RequestFingerprint).Scan(&id)
+	if errors.Is(err, sql.ErrNoRows) {
+		var existingFingerprint string
+		if err := tx.QueryRowContext(ctx, `
+			SELECT request_fingerprint
+			FROM usage_billing_dedup
+			WHERE request_id = $1 AND api_key_id = $2
+		`, cmd.RequestID, cmd.APIKeyID).Scan(&existingFingerprint); err != nil {
+			return false, err
+		}
+		if strings.TrimSpace(existingFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
+			return false, service.ErrUsageBillingRequestConflict
+		}
+		return false, nil
+	}
+	if err != nil {
+		return false, err
+	}
+	var archivedFingerprint string
+	err = tx.QueryRowContext(ctx, `
+		SELECT request_fingerprint
+		FROM usage_billing_dedup_archive
+		WHERE request_id = $1 AND api_key_id = $2
+	`, cmd.RequestID, cmd.APIKeyID).Scan(&archivedFingerprint)
+	if err == nil {
+		if strings.TrimSpace(archivedFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
+			return false, service.ErrUsageBillingRequestConflict
+		}
+		return false, nil
+	}
+	if !errors.Is(err, sql.ErrNoRows) {
+		return false, err
+	}
+	return true, nil
+}
+
+func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand, result *service.UsageBillingApplyResult) error {
+	if cmd.SubscriptionCost > 0 && cmd.SubscriptionID != nil {
+		if err := incrementUsageBillingSubscription(ctx, tx, *cmd.SubscriptionID, cmd.SubscriptionCost); err != nil {
+			return err
+		}
+	}
+
+	if cmd.BalanceCost > 0 {
+		if err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost); err != nil {
+			return err
+		}
+	}
+
+	if cmd.APIKeyQuotaCost > 0 {
+		exhausted, err := incrementUsageBillingAPIKeyQuota(ctx, tx, cmd.APIKeyID, cmd.APIKeyQuotaCost)
+		if err != nil {
+			return err
+		}
+		result.APIKeyQuotaExhausted = exhausted
+	}
+
+	if cmd.APIKeyRateLimitCost > 0 {
+		if err := incrementUsageBillingAPIKeyRateLimit(ctx, tx, cmd.APIKeyID, cmd.APIKeyRateLimitCost); err != nil {
+			return err
+		}
+	}
+
+	if cmd.AccountQuotaCost > 0 && strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) {
+		if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func incrementUsageBillingSubscription(ctx context.Context, tx *sql.Tx, subscriptionID int64, costUSD float64) error {
+	const updateSQL = `
+		UPDATE user_subscriptions us
+		SET
+			daily_usage_usd = us.daily_usage_usd + $1,
+			weekly_usage_usd = us.weekly_usage_usd + $1,
+			monthly_usage_usd = us.monthly_usage_usd + $1,
+			updated_at = NOW()
+		FROM groups g
+		WHERE us.id = $2
+		  AND us.deleted_at IS NULL
+		  AND us.group_id = g.id
+		  AND g.deleted_at IS NULL
+	`
+	res, err := tx.ExecContext(ctx, updateSQL, costUSD, subscriptionID)
+	if err != nil {
+		return err
+	}
+	affected, err := res.RowsAffected()
+	if err != nil {
+		return err
+	}
+	if affected > 0 {
+		return nil
+	}
+	return service.ErrSubscriptionNotFound
+}
+
+func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) error {
+	res, err := tx.ExecContext(ctx, `
+		UPDATE users
+		SET balance = balance - $1,
+		    updated_at = NOW()
+		WHERE id = $2 AND deleted_at IS NULL
+	`, amount, userID)
+	if err != nil {
+		return err
+	}
+	affected, err := res.RowsAffected()
+	if err != nil {
+		return err
+	}
+	if affected > 0 {
+		return nil
+	}
+	return service.ErrUserNotFound
+}
+
+func incrementUsageBillingAPIKeyQuota(ctx context.Context, tx *sql.Tx, apiKeyID int64, amount float64) (bool, error) {
+	var exhausted bool
+	err := tx.QueryRowContext(ctx, `
+		UPDATE api_keys
+		SET quota_used = quota_used + $1,
+		    status = CASE
+		        WHEN quota > 0
+		         AND status = $3
+		         AND quota_used < quota
+		         AND quota_used + $1 >= quota
+		        THEN $4
+		        ELSE status
+		    END,
+		    updated_at = NOW()
+		WHERE id = $2 AND deleted_at IS NULL
+		RETURNING quota > 0 AND quota_used >= quota AND quota_used - $1 < quota
+	`, amount, apiKeyID, service.StatusAPIKeyActive, service.StatusAPIKeyQuotaExhausted).Scan(&exhausted)
+	if errors.Is(err, sql.ErrNoRows) {
+		return false, service.ErrAPIKeyNotFound
+	}
+	if err != nil {
+		return false, err
+	}
+	return exhausted, nil
+}
+
+func incrementUsageBillingAPIKeyRateLimit(ctx context.Context, tx *sql.Tx, apiKeyID int64, cost float64) error {
+	res, err := tx.ExecContext(ctx, `
+		UPDATE api_keys SET
+			usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END,
+			usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END,
+			usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END,
+			window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
+			window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
+			window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
+			updated_at = NOW()
+		WHERE id = $2 AND deleted_at IS NULL
+	`, cost, apiKeyID)
+	if err != nil {
+		return err
+	}
+	affected, err := res.RowsAffected()
+	if err != nil {
+		return err
+	}
+	if affected == 0 {
+		return service.ErrAPIKeyNotFound
+	}
+	return nil
+}
+
+func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) error {
+	rows, err := tx.QueryContext(ctx,
+		`UPDATE accounts SET extra = (
+			COALESCE(extra, '{}'::jsonb)
+			|| jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1)
+			|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
+				jsonb_build_object(
+					'quota_daily_used',
+					CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+						+ '24 hours'::interval <= NOW()
+					THEN $1
+					ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
+					'quota_daily_start',
+					CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+						+ '24 hours'::interval <= NOW()
+					THEN `+nowUTC+`
+					ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
+				)
+			ELSE '{}'::jsonb END
+			|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
+				jsonb_build_object(
+					'quota_weekly_used',
+					CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+						+ '168 hours'::interval <= NOW()
+					THEN $1
+					ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
+					'quota_weekly_start',
+					CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+						+ '168 hours'::interval <= NOW()
+					THEN `+nowUTC+`
+					ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
+				)
+			ELSE '{}'::jsonb END
+		), updated_at = NOW()
+		WHERE id = $2 AND deleted_at IS NULL
+		RETURNING
+			COALESCE((extra->>'quota_used')::numeric, 0),
+			COALESCE((extra->>'quota_limit')::numeric, 0)`,
+		amount, accountID)
+	if err != nil {
+		return err
+	}
+	defer func() { _ = rows.Close() }()
+
+	var newUsed, limit float64
+	if rows.Next() {
+		if err := rows.Scan(&newUsed, &limit); err != nil {
+			return err
+		}
+	} else {
+		if err := rows.Err(); err != nil {
+			return err
+		}
+		return service.ErrAccountNotFound
+	}
+	if err := rows.Err(); err != nil {
+		return err
+	}
+	if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
+		if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil {
+			logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err)
+			return err
+		}
+	}
+	return nil
+}
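Note how the `RETURNING quota > 0 AND quota_used >= quota AND quota_used - $1 < quota` expression in `incrementUsageBillingAPIKeyQuota` reports only the update that crosses the quota boundary, so the caller fires the exhaustion transition exactly once. The same predicate as a pure-Go sketch (a hypothetical helper for illustration, not part of the repository):

```go
package main

import "fmt"

// crossedQuota reports whether adding amount to used crosses the quota
// boundary: true only on the update that moves used from below quota to
// at-or-above quota. It is false for quota <= 0 (unlimited) and for
// updates that start already past the boundary, so the transition fires
// exactly once.
func crossedQuota(quota, used, amount float64) bool {
	newUsed := used + amount
	return quota > 0 && newUsed >= quota && newUsed-amount < quota
}

func main() {
	fmt.Println(crossedQuota(1, 0, 1.25))    // true: 0 -> 1.25 crosses quota 1
	fmt.Println(crossedQuota(1, 1.25, 1.25)) // false: was already exhausted
	fmt.Println(crossedQuota(0, 0, 1.25))    // false: quota 0 means unlimited
}
```

Evaluating this inside the `UPDATE ... RETURNING` keeps read-modify-check in one statement, avoiding the increment-then-read-back round trip the report criticizes in `UpdateQuotaUsed`.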
@@ -0,0 +1,279 @@
+//go:build integration
+
+package repository
+
+import (
+	"context"
+	"fmt"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/google/uuid"
+	"github.com/stretchr/testify/require"
+
+	"github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+func TestUsageBillingRepositoryApply_DeduplicatesBalanceBilling(t *testing.T) {
+	ctx := context.Background()
+	client := testEntClient(t)
+	repo := NewUsageBillingRepository(client, integrationDB)
+
+	user := mustCreateUser(t, client, &service.User{
+		Email:        fmt.Sprintf("usage-billing-user-%d@example.com", time.Now().UnixNano()),
+		PasswordHash: "hash",
+		Balance:      100,
+	})
+	apiKey := mustCreateApiKey(t, client, &service.APIKey{
+		UserID: user.ID,
+		Key:    "sk-usage-billing-" + uuid.NewString(),
+		Name:   "billing",
+		Quota:  1,
+	})
+	account := mustCreateAccount(t, client, &service.Account{
+		Name: "usage-billing-account-" + uuid.NewString(),
+		Type: service.AccountTypeAPIKey,
+	})
+
+	requestID := uuid.NewString()
+	cmd := &service.UsageBillingCommand{
+		RequestID:           requestID,
+		APIKeyID:            apiKey.ID,
+		UserID:              user.ID,
+		AccountID:           account.ID,
+		AccountType:         service.AccountTypeAPIKey,
+		BalanceCost:         1.25,
+		APIKeyQuotaCost:     1.25,
+		APIKeyRateLimitCost: 1.25,
+	}
+
+	result1, err := repo.Apply(ctx, cmd)
+	require.NoError(t, err)
+	require.NotNil(t, result1)
+	require.True(t, result1.Applied)
+	require.True(t, result1.APIKeyQuotaExhausted)
+
+	result2, err := repo.Apply(ctx, cmd)
+	require.NoError(t, err)
+	require.NotNil(t, result2)
+	require.False(t, result2.Applied)
+
+	var balance float64
+	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
+	require.InDelta(t, 98.75, balance, 0.000001)
+
+	var quotaUsed float64
+	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT quota_used FROM api_keys WHERE id = $1", apiKey.ID).Scan(&quotaUsed))
+	require.InDelta(t, 1.25, quotaUsed, 0.000001)
+
+	var usage5h float64
+	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT usage_5h FROM api_keys WHERE id = $1", apiKey.ID).Scan(&usage5h))
+	require.InDelta(t, 1.25, usage5h, 0.000001)
+
+	var status string
+	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT status FROM api_keys WHERE id = $1", apiKey.ID).Scan(&status))
+	require.Equal(t, service.StatusAPIKeyQuotaExhausted, status)
+
+	var dedupCount int
+	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&dedupCount))
+	require.Equal(t, 1, dedupCount)
+}
+
+func TestUsageBillingRepositoryApply_DeduplicatesSubscriptionBilling(t *testing.T) {
+	ctx := context.Background()
+	client := testEntClient(t)
+	repo := NewUsageBillingRepository(client, integrationDB)
+
+	user := mustCreateUser(t, client, &service.User{
+		Email:        fmt.Sprintf("usage-billing-sub-user-%d@example.com", time.Now().UnixNano()),
+		PasswordHash: "hash",
+	})
+	group := mustCreateGroup(t, client, &service.Group{
+		Name:             "usage-billing-group-" + uuid.NewString(),
+		Platform:         service.PlatformAnthropic,
+		SubscriptionType: service.SubscriptionTypeSubscription,
+	})
+	apiKey := mustCreateApiKey(t, client, &service.APIKey{
+		UserID:  user.ID,
+		GroupID: &group.ID,
+		Key:     "sk-usage-billing-sub-" + uuid.NewString(),
+		Name:    "billing-sub",
+	})
+	subscription := mustCreateSubscription(t, client, &service.UserSubscription{
+		UserID:  user.ID,
+		GroupID: group.ID,
+	})
+
+	requestID := uuid.NewString()
+	cmd := &service.UsageBillingCommand{
+		RequestID:        requestID,
+		APIKeyID:         apiKey.ID,
+		UserID:           user.ID,
+		AccountID:        0,
+		SubscriptionID:   &subscription.ID,
+		SubscriptionCost: 2.5,
+	}
+
+	result1, err := repo.Apply(ctx, cmd)
+	require.NoError(t, err)
+	require.True(t, result1.Applied)
+
+	result2, err := repo.Apply(ctx, cmd)
+	require.NoError(t, err)
+	require.False(t, result2.Applied)
+
+	var dailyUsage float64
+	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT daily_usage_usd FROM user_subscriptions WHERE id = $1", subscription.ID).Scan(&dailyUsage))
+	require.InDelta(t, 2.5, dailyUsage, 0.000001)
+}
+
+func TestUsageBillingRepositoryApply_RequestFingerprintConflict(t *testing.T) {
+	ctx := context.Background()
+	client := testEntClient(t)
+	repo := NewUsageBillingRepository(client, integrationDB)
+
+	user := mustCreateUser(t, client, &service.User{
+		Email:        fmt.Sprintf("usage-billing-conflict-user-%d@example.com", time.Now().UnixNano()),
+		PasswordHash: "hash",
+		Balance:      100,
+	})
+	apiKey := mustCreateApiKey(t, client, &service.APIKey{
+		UserID: user.ID,
+		Key:    "sk-usage-billing-conflict-" + uuid.NewString(),
+		Name:   "billing-conflict",
+	})
+
+	requestID := uuid.NewString()
+	_, err := repo.Apply(ctx, &service.UsageBillingCommand{
+		RequestID:   requestID,
+		APIKeyID:    apiKey.ID,
+		UserID:      user.ID,
+		BalanceCost: 1.25,
+	})
+	require.NoError(t, err)
+
+	_, err = repo.Apply(ctx, &service.UsageBillingCommand{
+		RequestID:   requestID,
+		APIKeyID:    apiKey.ID,
+		UserID:      user.ID,
+		BalanceCost: 2.50,
+	})
+	require.ErrorIs(t, err, service.ErrUsageBillingRequestConflict)
+}
+
+func TestUsageBillingRepositoryApply_UpdatesAccountQuota(t *testing.T) {
+	ctx := context.Background()
+	client := testEntClient(t)
+	repo := NewUsageBillingRepository(client, integrationDB)
+
+	user := mustCreateUser(t, client, &service.User{
+		Email:        fmt.Sprintf("usage-billing-account-user-%d@example.com", time.Now().UnixNano()),
+		PasswordHash: "hash",
+	})
+	apiKey := mustCreateApiKey(t, client, &service.APIKey{
+		UserID: user.ID,
+		Key:    "sk-usage-billing-account-" + uuid.NewString(),
+		Name:   "billing-account",
+	})
+	account := mustCreateAccount(t, client, &service.Account{
+		Name: "usage-billing-account-quota-" + uuid.NewString(),
+		Type: service.AccountTypeAPIKey,
+		Extra: map[string]any{
+			"quota_limit": 100.0,
+		},
+	})
+
+	_, err := repo.Apply(ctx, &service.UsageBillingCommand{
+		RequestID:        uuid.NewString(),
+		APIKeyID:         apiKey.ID,
+		UserID:           user.ID,
+		AccountID:        account.ID,
+		AccountType:      service.AccountTypeAPIKey,
+		AccountQuotaCost: 3.5,
+	})
+	require.NoError(t, err)
+
+	var quotaUsed float64
+	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COALESCE((extra->>'quota_used')::numeric, 0) FROM accounts WHERE id = $1", account.ID).Scan(&quotaUsed))
+	require.InDelta(t, 3.5, quotaUsed, 0.000001)
+}
+
+func TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows(t *testing.T) {
+	ctx := context.Background()
+	repo := newDashboardAggregationRepositoryWithSQL(integrationDB)
+
+	oldRequestID := "dedup-old-" + uuid.NewString()
+	newRequestID := "dedup-new-" + uuid.NewString()
+	oldCreatedAt := time.Now().UTC().AddDate(0, 0, -400)
+	newCreatedAt := time.Now().UTC().Add(-time.Hour)
+
+	_, err := integrationDB.ExecContext(ctx, `
+		INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint, created_at)
+		VALUES ($1, 1, $2, $3), ($4, 1, $5, $6)
+	`,
+		oldRequestID, strings.Repeat("a", 64), oldCreatedAt,
+		newRequestID, strings.Repeat("b", 64), newCreatedAt,
+	)
+	require.NoError(t, err)
+
+	require.NoError(t, repo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))
+
+	var oldCount int
+	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", oldRequestID).Scan(&oldCount))
+	require.Equal(t, 0, oldCount)
+
+	var newCount int
+	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", newRequestID).Scan(&newCount))
+	require.Equal(t, 1, newCount)
+
+	var archivedCount int
+	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup_archive WHERE request_id = $1", oldRequestID).Scan(&archivedCount))
+	require.Equal(t, 1, archivedCount)
+}
+
+func TestUsageBillingRepositoryApply_DeduplicatesAgainstArchivedKey(t *testing.T) {
+	ctx := context.Background()
+	client := testEntClient(t)
+	repo := NewUsageBillingRepository(client, integrationDB)
+	aggRepo := newDashboardAggregationRepositoryWithSQL(integrationDB)
+
+	user := mustCreateUser(t, client, &service.User{
+		Email:        fmt.Sprintf("usage-billing-archive-user-%d@example.com", time.Now().UnixNano()),
+		PasswordHash: "hash",
+		Balance:      100,
+	})
+	apiKey := mustCreateApiKey(t, client, &service.APIKey{
+		UserID: user.ID,
+		Key:    "sk-usage-billing-archive-" + uuid.NewString(),
+		Name:   "billing-archive",
+	})
+
+	requestID := uuid.NewString()
+	cmd := &service.UsageBillingCommand{
+		RequestID:   requestID,
+		APIKeyID:    apiKey.ID,
+		UserID:      user.ID,
+		BalanceCost: 1.25,
+	}
+
+	result1, err := repo.Apply(ctx, cmd)
+	require.NoError(t, err)
+	require.True(t, result1.Applied)
+
+	_, err = integrationDB.ExecContext(ctx, `
+		UPDATE usage_billing_dedup
+		SET created_at = $1
|
||||||
|
WHERE request_id = $2 AND api_key_id = $3
|
||||||
|
`, time.Now().UTC().AddDate(0, 0, -400), requestID, apiKey.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, aggRepo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))
|
||||||
|
|
||||||
|
result2, err := repo.Apply(ctx, cmd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, result2.Applied)
|
||||||
|
|
||||||
|
var balance float64
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
|
||||||
|
require.InDelta(t, 98.75, balance, 0.000001)
|
||||||
|
}
|
||||||
@@ -3,12 +3,14 @@ package repository
 import (
 	"context"
 	"database/sql"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"os"
 	"strconv"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -17,11 +19,13 @@ import (
 	dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
 	dbuser "github.com/Wei-Shaw/sub2api/ent/user"
 	dbusersub "github.com/Wei-Shaw/sub2api/ent/usersubscription"
+	"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
 	"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
 	"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
 	"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
 	"github.com/Wei-Shaw/sub2api/internal/service"
 	"github.com/lib/pq"
+	gocache "github.com/patrickmn/go-cache"
 )
 
 const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, cache_ttl_overridden, created_at"
@@ -47,18 +51,29 @@ type usageLogRepository struct {
 	sql sqlExecutor
 	db  *sql.DB
 
 	createBatchOnce sync.Once
 	createBatchCh   chan usageLogCreateRequest
+	bestEffortBatchOnce sync.Once
+	bestEffortBatchCh   chan usageLogBestEffortRequest
+	bestEffortRecent    *gocache.Cache
 }
 
 const (
 	usageLogCreateBatchMaxSize  = 64
 	usageLogCreateBatchWindow   = 3 * time.Millisecond
 	usageLogCreateBatchQueueCap = 4096
+	usageLogCreateCancelWait    = 2 * time.Second
+
+	usageLogBestEffortBatchMaxSize  = 256
+	usageLogBestEffortBatchWindow   = 20 * time.Millisecond
+	usageLogBestEffortBatchQueueCap = 32768
+	usageLogBestEffortRecentTTL     = 30 * time.Second
 )
 
 type usageLogCreateRequest struct {
 	log      *service.UsageLog
+	prepared usageLogInsertPrepared
+	shared   *usageLogCreateShared
 	resultCh chan usageLogCreateResult
 }
 
@@ -67,6 +82,12 @@ type usageLogCreateResult struct {
 	err error
 }
 
+type usageLogBestEffortRequest struct {
+	prepared usageLogInsertPrepared
+	apiKeyID int64
+	resultCh chan error
+}
+
 type usageLogInsertPrepared struct {
 	createdAt time.Time
 	requestID string
@@ -80,6 +101,25 @@ type usageLogBatchState struct {
 	CreatedAt time.Time
 }
 
+type usageLogBatchRow struct {
+	RequestID string    `json:"request_id"`
+	APIKeyID  int64     `json:"api_key_id"`
+	ID        int64     `json:"id"`
+	CreatedAt time.Time `json:"created_at"`
+	Inserted  bool      `json:"inserted"`
+}
+
+type usageLogCreateShared struct {
+	state atomic.Int32
+}
+
+const (
+	usageLogCreateStateQueued int32 = iota
+	usageLogCreateStateProcessing
+	usageLogCreateStateCompleted
+	usageLogCreateStateCanceled
+)
+
 func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLogRepository {
 	return newUsageLogRepositoryWithSQL(client, sqlDB)
 }
@@ -90,6 +130,7 @@ func newUsageLogRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usage
 	if db, ok := sqlq.(*sql.DB); ok {
 		repo.db = db
 	}
+	repo.bestEffortRecent = gocache.New(usageLogBestEffortRecentTTL, time.Minute)
 	return repo
 }
 
@@ -124,9 +165,6 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
 	if tx := dbent.TxFromContext(ctx); tx != nil {
 		return r.createSingle(ctx, tx.Client(), log)
 	}
-	if r.db == nil {
-		return r.createSingle(ctx, r.sql, log)
-	}
 	requestID := strings.TrimSpace(log.RequestID)
 	if requestID == "" {
 		return r.createSingle(ctx, r.sql, log)
@@ -135,11 +173,61 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
 	return r.createBatched(ctx, log)
 }
 
+func (r *usageLogRepository) CreateBestEffort(ctx context.Context, log *service.UsageLog) error {
+	if log == nil {
+		return nil
+	}
+
+	if tx := dbent.TxFromContext(ctx); tx != nil {
+		_, err := r.createSingle(ctx, tx.Client(), log)
+		return err
+	}
+	if r.db == nil {
+		_, err := r.createSingle(ctx, r.sql, log)
+		return err
+	}
+
+	r.ensureBestEffortBatcher()
+	if r.bestEffortBatchCh == nil {
+		_, err := r.createSingle(ctx, r.sql, log)
+		return err
+	}
+
+	req := usageLogBestEffortRequest{
+		prepared: prepareUsageLogInsert(log),
+		apiKeyID: log.APIKeyID,
+		resultCh: make(chan error, 1),
+	}
+	if key, ok := r.bestEffortRecentKey(req.prepared.requestID, req.apiKeyID); ok {
+		if _, exists := r.bestEffortRecent.Get(key); exists {
+			return nil
+		}
+	}
+
+	select {
+	case r.bestEffortBatchCh <- req:
+	case <-ctx.Done():
+		return ctx.Err()
+	default:
+		return errors.New("usage log best-effort queue full")
+	}
+
+	select {
+	case err := <-req.resultCh:
+		return err
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+}
+
 func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, log *service.UsageLog) (bool, error) {
 	prepared := prepareUsageLogInsert(log)
 	if sqlq == nil {
 		sqlq = r.sql
 	}
+	if ctx != nil && ctx.Err() != nil {
+		return false, service.MarkUsageLogCreateNotPersisted(ctx.Err())
+	}
+
 	query := `
 	INSERT INTO usage_logs (
@@ -218,13 +306,15 @@ func (r *usageLogRepository) createBatched(ctx context.Context, log *service.Usa
 
 	req := usageLogCreateRequest{
 		log:      log,
+		prepared: prepareUsageLogInsert(log),
+		shared:   &usageLogCreateShared{},
 		resultCh: make(chan usageLogCreateResult, 1),
 	}
 
 	select {
 	case r.createBatchCh <- req:
 	case <-ctx.Done():
-		return false, ctx.Err()
+		return false, service.MarkUsageLogCreateNotPersisted(ctx.Err())
 	default:
 		return r.createSingle(ctx, r.sql, log)
 	}
@@ -233,7 +323,17 @@ func (r *usageLogRepository) createBatched(ctx context.Context, log *service.Usa
 	case res := <-req.resultCh:
 		return res.inserted, res.err
 	case <-ctx.Done():
-		return false, ctx.Err()
+		if req.shared != nil && req.shared.state.CompareAndSwap(usageLogCreateStateQueued, usageLogCreateStateCanceled) {
+			return false, service.MarkUsageLogCreateNotPersisted(ctx.Err())
+		}
+		timer := time.NewTimer(usageLogCreateCancelWait)
+		defer timer.Stop()
+		select {
+		case res := <-req.resultCh:
+			return res.inserted, res.err
+		case <-timer.C:
+			return false, ctx.Err()
+		}
 	}
 }
 
@@ -247,6 +347,16 @@ func (r *usageLogRepository) ensureCreateBatcher() {
 	})
 }
 
+func (r *usageLogRepository) ensureBestEffortBatcher() {
+	if r == nil || r.db == nil {
+		return
+	}
+	r.bestEffortBatchOnce.Do(func() {
+		r.bestEffortBatchCh = make(chan usageLogBestEffortRequest, usageLogBestEffortBatchQueueCap)
+		go r.runBestEffortBatcher(r.db)
+	})
+}
+
 func (r *usageLogRepository) runCreateBatcher(db *sql.DB) {
 	for {
 		first, ok := <-r.createBatchCh
@@ -281,6 +391,40 @@ func (r *usageLogRepository) runCreateBatcher(db *sql.DB) {
 	}
 }
 
+func (r *usageLogRepository) runBestEffortBatcher(db *sql.DB) {
+	for {
+		first, ok := <-r.bestEffortBatchCh
+		if !ok {
+			return
+		}
+
+		batch := make([]usageLogBestEffortRequest, 0, usageLogBestEffortBatchMaxSize)
+		batch = append(batch, first)
+
+		timer := time.NewTimer(usageLogBestEffortBatchWindow)
+	bestEffortLoop:
+		for len(batch) < usageLogBestEffortBatchMaxSize {
+			select {
+			case req, ok := <-r.bestEffortBatchCh:
+				if !ok {
+					break bestEffortLoop
+				}
+				batch = append(batch, req)
+			case <-timer.C:
+				break bestEffortLoop
+			}
+		}
+		if !timer.Stop() {
+			select {
+			case <-timer.C:
+			default:
+			}
+		}
+
+		r.flushBestEffortBatch(db, batch)
+	}
+}
+
 func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreateRequest) {
 	if len(batch) == 0 {
 		return
@@ -293,10 +437,19 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate
 
 	for _, req := range batch {
 		if req.log == nil {
-			sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{inserted: false, err: nil})
+			completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil})
 			continue
 		}
-		prepared := prepareUsageLogInsert(req.log)
+		if req.shared != nil && !req.shared.state.CompareAndSwap(usageLogCreateStateQueued, usageLogCreateStateProcessing) {
+			if req.shared.state.Load() == usageLogCreateStateCanceled {
+				completeUsageLogCreateRequest(req, usageLogCreateResult{
+					inserted: false,
+					err:      service.MarkUsageLogCreateNotPersisted(context.Canceled),
+				})
+				continue
+			}
+		}
+		prepared := req.prepared
 		if prepared.requestID == "" {
 			fallback = append(fallback, req)
 			continue
@@ -310,10 +463,37 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate
 	}
 
 	if len(uniqueOrder) > 0 {
-		insertedMap, stateMap, err := r.batchInsertUsageLogs(db, uniqueOrder, preparedByKey)
+		insertedMap, stateMap, safeFallback, err := r.batchInsertUsageLogs(db, uniqueOrder, preparedByKey)
 		if err != nil {
-			for _, key := range uniqueOrder {
-				fallback = append(fallback, requestsByKey[key]...)
+			if safeFallback {
+				for _, key := range uniqueOrder {
+					fallback = append(fallback, requestsByKey[key]...)
+				}
+			} else {
+				for _, key := range uniqueOrder {
+					reqs := requestsByKey[key]
+					state, hasState := stateMap[key]
+					inserted := insertedMap[key]
+					for idx, req := range reqs {
+						req.log.RateMultiplier = preparedByKey[key].rateMultiplier
+						if hasState {
+							req.log.ID = state.ID
+							req.log.CreatedAt = state.CreatedAt
+						}
+						switch {
+						case inserted && idx == 0:
+							completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: true, err: nil})
+						case inserted:
+							completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil})
+						case hasState:
+							completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil})
+						case idx == 0:
+							completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: err})
+						default:
+							completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil})
+						}
+					}
+				}
 			}
 		} else {
 			for _, key := range uniqueOrder {
@@ -321,7 +501,7 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate
 			state, ok := stateMap[key]
 			if !ok {
 				for _, req := range reqs {
-					sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{
+					completeUsageLogCreateRequest(req, usageLogCreateResult{
 						inserted: false,
 						err:      fmt.Errorf("usage log batch state missing for key=%s", key),
 					})
@@ -332,7 +512,7 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate
 				req.log.ID = state.ID
 				req.log.CreatedAt = state.CreatedAt
 				req.log.RateMultiplier = preparedByKey[key].rateMultiplier
-				sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{
+				completeUsageLogCreateRequest(req, usageLogCreateResult{
 					inserted: idx == 0 && insertedMap[key],
 					err:      nil,
 				})
@@ -345,56 +525,366 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate
 		return
 	}
 
-	fallbackCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
-	defer cancel()
 	for _, req := range fallback {
+		fallbackCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 		inserted, err := r.createSingle(fallbackCtx, db, req.log)
-		sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{inserted: inserted, err: err})
+		cancel()
+		completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: inserted, err: err})
 	}
 }
 
-func (r *usageLogRepository) batchInsertUsageLogs(db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]bool, map[string]usageLogBatchState, error) {
+func (r *usageLogRepository) flushBestEffortBatch(db *sql.DB, batch []usageLogBestEffortRequest) {
+	if len(batch) == 0 {
+		return
+	}
+
+	type bestEffortGroup struct {
+		prepared usageLogInsertPrepared
+		apiKeyID int64
+		key      string
+		reqs     []usageLogBestEffortRequest
+	}
+
+	groupsByKey := make(map[string]*bestEffortGroup, len(batch))
+	groupOrder := make([]*bestEffortGroup, 0, len(batch))
+	preparedList := make([]usageLogInsertPrepared, 0, len(batch))
+
+	for idx, req := range batch {
+		prepared := req.prepared
+		key := fmt.Sprintf("__best_effort_%d", idx)
+		if prepared.requestID != "" {
+			key = usageLogBatchKey(prepared.requestID, req.apiKeyID)
+		}
+		group, exists := groupsByKey[key]
+		if !exists {
+			group = &bestEffortGroup{
+				prepared: prepared,
+				apiKeyID: req.apiKeyID,
+				key:      key,
+			}
+			groupsByKey[key] = group
+			groupOrder = append(groupOrder, group)
+			preparedList = append(preparedList, prepared)
+		}
+		group.reqs = append(group.reqs, req)
+	}
+
+	if len(preparedList) == 0 {
+		for _, req := range batch {
+			sendUsageLogBestEffortResult(req.resultCh, nil)
+		}
+		return
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+
+	query, args := buildUsageLogBestEffortInsertQuery(preparedList)
+	if _, err := db.ExecContext(ctx, query, args...); err != nil {
+		logger.LegacyPrintf("repository.usage_log", "best-effort batch insert failed: %v", err)
+		for _, group := range groupOrder {
+			singleErr := execUsageLogInsertNoResult(ctx, db, group.prepared)
+			if singleErr != nil {
+				logger.LegacyPrintf("repository.usage_log", "best-effort single fallback insert failed: %v", singleErr)
+			} else if group.prepared.requestID != "" && r != nil && r.bestEffortRecent != nil {
+				r.bestEffortRecent.SetDefault(group.key, struct{}{})
+			}
+			for _, req := range group.reqs {
+				sendUsageLogBestEffortResult(req.resultCh, singleErr)
+			}
+		}
+		return
+	}
+	for _, group := range groupOrder {
+		if group.prepared.requestID != "" && r != nil && r.bestEffortRecent != nil {
+			r.bestEffortRecent.SetDefault(group.key, struct{}{})
+		}
+		for _, req := range group.reqs {
+			sendUsageLogBestEffortResult(req.resultCh, nil)
+		}
+	}
+}
+
+func sendUsageLogBestEffortResult(ch chan error, err error) {
+	if ch == nil {
+		return
+	}
+	select {
+	case ch <- err:
+	default:
+	}
+}
+
+func completeUsageLogCreateRequest(req usageLogCreateRequest, res usageLogCreateResult) {
+	if req.shared != nil {
+		req.shared.state.Store(usageLogCreateStateCompleted)
+	}
+	sendUsageLogCreateResult(req.resultCh, res)
+}
+
+func (r *usageLogRepository) batchInsertUsageLogs(db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]bool, map[string]usageLogBatchState, bool, error) {
 	if len(keys) == 0 {
-		return map[string]bool{}, map[string]usageLogBatchState{}, nil
+		return map[string]bool{}, map[string]usageLogBatchState{}, false, nil
 	}
 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 	defer cancel()
 
 	query, args := buildUsageLogBatchInsertQuery(keys, preparedByKey)
-	rows, err := db.QueryContext(ctx, query, args...)
-	if err != nil {
-		return nil, nil, err
+	var payload []byte
+	if err := db.QueryRowContext(ctx, query, args...).Scan(&payload); err != nil {
+		return nil, nil, true, err
+	}
+	var rows []usageLogBatchRow
+	if err := json.Unmarshal(payload, &rows); err != nil {
+		return nil, nil, false, err
 	}
 	insertedMap := make(map[string]bool, len(keys))
-	for rows.Next() {
-		var (
-			requestID string
-			apiKeyID  int64
-			id        int64
-			createdAt time.Time
-		)
-		if err := rows.Scan(&requestID, &apiKeyID, &id, &createdAt); err != nil {
-			_ = rows.Close()
-			return nil, nil, err
-		}
-		insertedMap[usageLogBatchKey(requestID, apiKeyID)] = true
-	}
-	if err := rows.Err(); err != nil {
-		_ = rows.Close()
-		return nil, nil, err
-	}
-	_ = rows.Close()
-
-	stateMap, err := loadUsageLogBatchStates(ctx, db, keys, preparedByKey)
-	if err != nil {
-		return nil, nil, err
-	}
-	return insertedMap, stateMap, nil
+	stateMap := make(map[string]usageLogBatchState, len(keys))
+	for _, row := range rows {
+		key := usageLogBatchKey(row.RequestID, row.APIKeyID)
+		insertedMap[key] = row.Inserted
+		stateMap[key] = usageLogBatchState{
+			ID:        row.ID,
+			CreatedAt: row.CreatedAt,
+		}
+	}
+	if len(stateMap) != len(keys) {
+		return insertedMap, stateMap, false, fmt.Errorf("usage log batch state count mismatch: got=%d want=%d", len(stateMap), len(keys))
+	}
+	return insertedMap, stateMap, false, nil
 }
 
 func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usageLogInsertPrepared) (string, []any) {
 	var query strings.Builder
 	_, _ = query.WriteString(`
+	WITH input (
+		input_idx,
+		user_id,
+		api_key_id,
+		account_id,
+		request_id,
+		model,
+		group_id,
+		subscription_id,
+		input_tokens,
+		output_tokens,
+		cache_creation_tokens,
+		cache_read_tokens,
+		cache_creation_5m_tokens,
+		cache_creation_1h_tokens,
+		input_cost,
+		output_cost,
+		cache_creation_cost,
+		cache_read_cost,
+		total_cost,
+		actual_cost,
+		rate_multiplier,
+		account_rate_multiplier,
+		billing_type,
+		request_type,
+		stream,
+		openai_ws_mode,
+		duration_ms,
+		first_token_ms,
+		user_agent,
+		ip_address,
+		image_count,
+		image_size,
+		media_type,
+		service_tier,
+		reasoning_effort,
+		cache_ttl_overridden,
+		created_at
+	) AS (VALUES `)
+
+	args := make([]any, 0, len(keys)*37)
+	argPos := 1
+	for idx, key := range keys {
+		if idx > 0 {
+			_, _ = query.WriteString(",")
+		}
+		_, _ = query.WriteString("(")
+		_, _ = query.WriteString("$")
+		_, _ = query.WriteString(strconv.Itoa(argPos))
+		args = append(args, idx)
+		argPos++
+		prepared := preparedByKey[key]
+		for i := 0; i < len(prepared.args); i++ {
+			_, _ = query.WriteString(",")
+			_, _ = query.WriteString("$")
+			_, _ = query.WriteString(strconv.Itoa(argPos))
+			argPos++
+		}
+		_, _ = query.WriteString(")")
+		args = append(args, prepared.args...)
+	}
+	_, _ = query.WriteString(`
+	),
+	inserted AS (
+		INSERT INTO usage_logs (
+			user_id,
+			api_key_id,
+			account_id,
+			request_id,
+			model,
+			group_id,
+			subscription_id,
+			input_tokens,
+			output_tokens,
+			cache_creation_tokens,
+			cache_read_tokens,
+			cache_creation_5m_tokens,
+			cache_creation_1h_tokens,
+			input_cost,
+			output_cost,
+			cache_creation_cost,
+			cache_read_cost,
+			total_cost,
+			actual_cost,
+			rate_multiplier,
+			account_rate_multiplier,
+			billing_type,
+			request_type,
+			stream,
+			openai_ws_mode,
+			duration_ms,
+			first_token_ms,
+			user_agent,
+			ip_address,
+			image_count,
+			image_size,
+			media_type,
+			service_tier,
+			reasoning_effort,
+			cache_ttl_overridden,
+			created_at
+		)
+		SELECT
+			user_id,
+			api_key_id,
+			account_id,
+			request_id,
+			model,
+			group_id,
+			subscription_id,
+			input_tokens,
+			output_tokens,
+			cache_creation_tokens,
+			cache_read_tokens,
+			cache_creation_5m_tokens,
+			cache_creation_1h_tokens,
+			input_cost,
+			output_cost,
+			cache_creation_cost,
+			cache_read_cost,
+			total_cost,
+			actual_cost,
+			rate_multiplier,
+			account_rate_multiplier,
+			billing_type,
+			request_type,
+			stream,
+			openai_ws_mode,
+			duration_ms,
+			first_token_ms,
+			user_agent,
+			ip_address,
+			image_count,
+			image_size,
+			media_type,
+			service_tier,
+			reasoning_effort,
+			cache_ttl_overridden,
+			created_at
+		FROM input
+		ON CONFLICT (request_id, api_key_id) DO UPDATE
+		SET request_id = usage_logs.request_id
+		RETURNING request_id, api_key_id, id, created_at, (xmax = 0) AS inserted
+	)
+	SELECT COALESCE(
+		json_agg(
+			json_build_object(
+				'request_id', inserted.request_id,
+				'api_key_id', inserted.api_key_id,
+				'id', inserted.id,
+				'created_at', inserted.created_at,
+				'inserted', inserted.inserted
+			)
+			ORDER BY input.input_idx
+		),
+		'[]'::json
+	)
+	FROM input
+	JOIN inserted
+		ON inserted.request_id = input.request_id
+		AND inserted.api_key_id = input.api_key_id
+	`)
+	return query.String(), args
+}
+
+func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (string, []any) {
+	var query strings.Builder
+	_, _ = query.WriteString(`
+	WITH input (
+		user_id,
+		api_key_id,
+		account_id,
+		request_id,
+		model,
+		group_id,
+		subscription_id,
+		input_tokens,
+		output_tokens,
+		cache_creation_tokens,
+		cache_read_tokens,
+		cache_creation_5m_tokens,
+		cache_creation_1h_tokens,
+		input_cost,
+		output_cost,
+		cache_creation_cost,
+		cache_read_cost,
+		total_cost,
+		actual_cost,
+		rate_multiplier,
+		account_rate_multiplier,
+		billing_type,
+		request_type,
+		stream,
+		openai_ws_mode,
+		duration_ms,
+		first_token_ms,
+		user_agent,
+		ip_address,
+		image_count,
+		image_size,
+		media_type,
+		service_tier,
+		reasoning_effort,
+		cache_ttl_overridden,
+		created_at
+	) AS (VALUES `)
+
+	args := make([]any, 0, len(preparedList)*36)
+	argPos := 1
+	for idx, prepared := range preparedList {
+		if idx > 0 {
+			_, _ = query.WriteString(",")
+		}
+		_, _ = query.WriteString("(")
+		for i := 0; i < len(prepared.args); i++ {
+			if i > 0 {
+				_, _ = query.WriteString(",")
+			}
+			_, _ = query.WriteString("$")
+			_, _ = query.WriteString(strconv.Itoa(argPos))
+			argPos++
+		}
+		_, _ = query.WriteString(")")
+		args = append(args, prepared.args...)
+	}
+
+	_, _ = query.WriteString(`
+	)
 	INSERT INTO usage_logs (
 		user_id,
 		api_key_id,
@@ -432,80 +922,101 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
|||||||
reasoning_effort,
|
reasoning_effort,
|
||||||
cache_ttl_overridden,
|
cache_ttl_overridden,
|
||||||
created_at
|
created_at
|
||||||
) VALUES `)
|
)
|
||||||
|
SELECT
|
||||||
args := make([]any, 0, len(keys)*36)
|
user_id,
|
||||||
argPos := 1
|
api_key_id,
|
||||||
for idx, key := range keys {
|
account_id,
|
||||||
if idx > 0 {
|
request_id,
|
||||||
_, _ = query.WriteString(",")
|
model,
|
||||||
}
|
group_id,
|
||||||
_, _ = query.WriteString("(")
|
subscription_id,
|
||||||
prepared := preparedByKey[key]
|
input_tokens,
|
||||||
for i := 0; i < len(prepared.args); i++ {
|
output_tokens,
|
||||||
if i > 0 {
|
cache_creation_tokens,
|
||||||
_, _ = query.WriteString(",")
|
cache_read_tokens,
|
||||||
}
|
cache_creation_5m_tokens,
|
||||||
_, _ = query.WriteString("$")
|
cache_creation_1h_tokens,
|
||||||
_, _ = query.WriteString(strconv.Itoa(argPos))
|
input_cost,
|
||||||
argPos++
|
output_cost,
|
||||||
}
|
cache_creation_cost,
|
||||||
_, _ = query.WriteString(")")
|
cache_read_cost,
|
||||||
args = append(args, prepared.args...)
|
total_cost,
|
||||||
}
|
actual_cost,
|
||||||
_, _ = query.WriteString(`
|
rate_multiplier,
|
||||||
|
account_rate_multiplier,
|
||||||
|
billing_type,
|
||||||
|
request_type,
|
||||||
|
stream,
|
||||||
|
openai_ws_mode,
|
||||||
|
duration_ms,
|
||||||
|
first_token_ms,
|
||||||
|
user_agent,
|
||||||
|
ip_address,
|
||||||
|
image_count,
|
||||||
|
image_size,
|
||||||
|
media_type,
|
||||||
|
service_tier,
|
||||||
|
reasoning_effort,
|
||||||
|
cache_ttl_overridden,
|
||||||
|
created_at
|
||||||
|
FROM input
|
||||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||||
RETURNING request_id, api_key_id, id, created_at
|
|
||||||
`)
|
`)
|
||||||
|
|
||||||
return query.String(), args
|
return query.String(), args
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadUsageLogBatchStates(ctx context.Context, db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]usageLogBatchState, error) {
|
func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared usageLogInsertPrepared) error {
|
||||||
var query strings.Builder
|
_, err := sqlq.ExecContext(ctx, `
|
||||||
_, _ = query.WriteString(`SELECT request_id, api_key_id, id, created_at FROM usage_logs WHERE `)
|
INSERT INTO usage_logs (
|
||||||
args := make([]any, 0, len(keys)*2)
|
user_id,
|
||||||
argPos := 1
|
api_key_id,
|
||||||
for idx, key := range keys {
|
account_id,
|
||||||
if idx > 0 {
|
request_id,
|
||||||
_, _ = query.WriteString(" OR ")
|
model,
|
||||||
}
|
group_id,
|
||||||
prepared := preparedByKey[key]
|
subscription_id,
|
||||||
apiKeyID := prepared.args[1]
|
input_tokens,
|
||||||
_, _ = query.WriteString("(request_id = $")
|
output_tokens,
|
||||||
_, _ = query.WriteString(strconv.Itoa(argPos))
|
cache_creation_tokens,
|
||||||
_, _ = query.WriteString(" AND api_key_id = $")
|
cache_read_tokens,
|
||||||
_, _ = query.WriteString(strconv.Itoa(argPos + 1))
|
cache_creation_5m_tokens,
|
||||||
_, _ = query.WriteString(")")
|
cache_creation_1h_tokens,
|
||||||
args = append(args, prepared.requestID, apiKeyID)
|
input_cost,
|
||||||
argPos += 2
|
output_cost,
|
||||||
}
|
cache_creation_cost,
|
||||||
|
cache_read_cost,
|
||||||
rows, err := db.QueryContext(ctx, query.String(), args...)
|
total_cost,
|
||||||
if err != nil {
|
actual_cost,
|
||||||
return nil, err
|
rate_multiplier,
|
||||||
}
|
account_rate_multiplier,
|
||||||
defer func() { _ = rows.Close() }()
|
billing_type,
|
||||||
|
request_type,
|
||||||
stateMap := make(map[string]usageLogBatchState, len(keys))
|
stream,
|
||||||
for rows.Next() {
|
openai_ws_mode,
|
||||||
var (
|
duration_ms,
|
||||||
requestID string
|
first_token_ms,
|
||||||
apiKeyID int64
|
user_agent,
|
||||||
id int64
|
ip_address,
|
||||||
createdAt time.Time
|
image_count,
|
||||||
|
image_size,
|
||||||
|
media_type,
|
||||||
|
service_tier,
|
||||||
|
reasoning_effort,
|
||||||
|
cache_ttl_overridden,
|
||||||
|
created_at
|
||||||
|
) VALUES (
|
||||||
|
$1, $2, $3, $4, $5,
|
||||||
|
$6, $7,
|
||||||
|
$8, $9, $10, $11,
|
||||||
|
$12, $13,
|
||||||
|
$14, $15, $16, $17, $18, $19,
|
||||||
|
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36
|
||||||
)
|
)
|
||||||
if err := rows.Scan(&requestID, &apiKeyID, &id, &createdAt); err != nil {
|
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||||
return nil, err
|
`, prepared.args...)
|
||||||
}
|
return err
|
||||||
stateMap[usageLogBatchKey(requestID, apiKeyID)] = usageLogBatchState{
|
|
||||||
ID: id,
|
|
||||||
CreatedAt: createdAt,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := rows.Err(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return stateMap, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
||||||
@@ -597,6 +1108,14 @@ func sendUsageLogCreateResult(ch chan usageLogCreateResult, res usageLogCreateRe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *usageLogRepository) bestEffortRecentKey(requestID string, apiKeyID int64) (string, bool) {
|
||||||
|
requestID = strings.TrimSpace(requestID)
|
||||||
|
if requestID == "" || r == nil || r.bestEffortRecent == nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return usageLogBatchKey(requestID, apiKeyID), true
|
||||||
|
}
|
||||||
|
|
||||||
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) {
|
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) {
|
||||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1"
|
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1"
|
||||||
rows, err := r.sql.QueryContext(ctx, query, id)
|
rows, err := r.sql.QueryContext(ctx, query, id)
|
||||||
|
|||||||
@@ -183,6 +183,214 @@ func TestUsageLogRepositoryCreate_BatchPathDuplicateRequestID(t *testing.T) {
|
|||||||
require.Equal(t, 1, count)
|
require.Equal(t, 1, count)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryFlushCreateBatch_DeduplicatesSameKeyInMemory(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-memdup-%d@example.com", time.Now().UnixNano())})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-memdup-" + uuid.NewString(), Name: "k"})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-memdup-" + uuid.NewString()})
|
||||||
|
requestID := uuid.NewString()
|
||||||
|
|
||||||
|
const total = 8
|
||||||
|
batch := make([]usageLogCreateRequest, 0, total)
|
||||||
|
logs := make([]*service.UsageLog, 0, total)
|
||||||
|
|
||||||
|
for i := 0; i < total; i++ {
|
||||||
|
log := &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: requestID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10 + i,
|
||||||
|
OutputTokens: 20 + i,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
logs = append(logs, log)
|
||||||
|
batch = append(batch, usageLogCreateRequest{
|
||||||
|
log: log,
|
||||||
|
prepared: prepareUsageLogInsert(log),
|
||||||
|
resultCh: make(chan usageLogCreateResult, 1),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
repo.flushCreateBatch(integrationDB, batch)
|
||||||
|
|
||||||
|
insertedCount := 0
|
||||||
|
var firstID int64
|
||||||
|
for idx, req := range batch {
|
||||||
|
res := <-req.resultCh
|
||||||
|
require.NoError(t, res.err)
|
||||||
|
if res.inserted {
|
||||||
|
insertedCount++
|
||||||
|
}
|
||||||
|
require.NotZero(t, logs[idx].ID)
|
||||||
|
if idx == 0 {
|
||||||
|
firstID = logs[idx].ID
|
||||||
|
} else {
|
||||||
|
require.Equal(t, firstID, logs[idx].ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, 1, insertedCount)
|
||||||
|
|
||||||
|
var count int
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count))
|
||||||
|
require.Equal(t, 1, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryCreateBestEffort_BatchPathDuplicateRequestID(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-dup-%d@example.com", time.Now().UnixNano())})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-dup-" + uuid.NewString(), Name: "k"})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-dup-" + uuid.NewString()})
|
||||||
|
requestID := uuid.NewString()
|
||||||
|
|
||||||
|
log1 := &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: requestID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
log2 := &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: requestID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, repo.CreateBestEffort(ctx, log1))
|
||||||
|
require.NoError(t, repo.CreateBestEffort(ctx, log2))
|
||||||
|
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
var count int
|
||||||
|
err := integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count)
|
||||||
|
return err == nil && count == 1
|
||||||
|
}, 3*time.Second, 20*time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *testing.T) {
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-%d@example.com", time.Now().UnixNano())})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-" + uuid.NewString(), Name: "k"})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-" + uuid.NewString()})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
inserted, err := repo.Create(ctx, &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: uuid.NewString(),
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
})
|
||||||
|
|
||||||
|
require.False(t, inserted)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.True(t, service.IsUsageLogCreateNotPersisted(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryCreate_BatchPathCanceledAfterQueueMarksNotPersisted(t *testing.T) {
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||||
|
repo.createBatchCh = make(chan usageLogCreateRequest, 1)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-queued-%d@example.com", time.Now().UnixNano())})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-queued-" + uuid.NewString(), Name: "k"})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-queued-" + uuid.NewString()})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_, err := repo.createBatched(ctx, &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: uuid.NewString(),
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
})
|
||||||
|
errCh <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
req := <-repo.createBatchCh
|
||||||
|
require.NotNil(t, req.shared)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
err := <-errCh
|
||||||
|
require.Error(t, err)
|
||||||
|
require.True(t, service.IsUsageLogCreateNotPersisted(err))
|
||||||
|
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: service.MarkUsageLogCreateNotPersisted(context.Canceled)})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryFlushCreateBatch_CanceledRequestReturnsNotPersisted(t *testing.T) {
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-flush-cancel-%d@example.com", time.Now().UnixNano())})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-flush-cancel-" + uuid.NewString(), Name: "k"})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-flush-cancel-" + uuid.NewString()})
|
||||||
|
|
||||||
|
log := &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: uuid.NewString(),
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
req := usageLogCreateRequest{
|
||||||
|
log: log,
|
||||||
|
prepared: prepareUsageLogInsert(log),
|
||||||
|
shared: &usageLogCreateShared{},
|
||||||
|
resultCh: make(chan usageLogCreateResult, 1),
|
||||||
|
}
|
||||||
|
req.shared.state.Store(usageLogCreateStateCanceled)
|
||||||
|
|
||||||
|
repo.flushCreateBatch(integrationDB, []usageLogCreateRequest{req})
|
||||||
|
|
||||||
|
res := <-req.resultCh
|
||||||
|
require.False(t, res.inserted)
|
||||||
|
require.Error(t, res.err)
|
||||||
|
require.True(t, service.IsUsageLogCreateNotPersisted(res.err))
|
||||||
|
}
|
||||||
|
|
||||||
func (s *UsageLogRepoSuite) TestGetByID() {
|
func (s *UsageLogRepoSuite) TestGetByID() {
|
||||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
|
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
|
||||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
|
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
|
||||||
|
|||||||
@@ -62,6 +62,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewAnnouncementRepository,
|
NewAnnouncementRepository,
|
||||||
NewAnnouncementReadRepository,
|
NewAnnouncementReadRepository,
|
||||||
NewUsageLogRepository,
|
NewUsageLogRepository,
|
||||||
|
NewUsageBillingRepository,
|
||||||
NewIdempotencyRepository,
|
NewIdempotencyRepository,
|
||||||
NewUsageCleanupRepository,
|
NewUsageCleanupRepository,
|
||||||
NewDashboardAggregationRepository,
|
NewDashboardAggregationRepository,
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ type DashboardAggregationRepository interface {
|
|||||||
UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
|
UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
|
||||||
CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
|
CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
|
||||||
CleanupUsageLogs(ctx context.Context, cutoff time.Time) error
|
CleanupUsageLogs(ctx context.Context, cutoff time.Time) error
|
||||||
|
CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error
|
||||||
EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error
|
EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -296,6 +297,7 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
|
|||||||
hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays)
|
hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays)
|
||||||
dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays)
|
dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays)
|
||||||
usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays)
|
usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays)
|
||||||
|
dedupCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageBillingDedupDays)
|
||||||
|
|
||||||
aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff)
|
aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff)
|
||||||
if aggErr != nil {
|
if aggErr != nil {
|
||||||
@@ -305,7 +307,11 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
|
|||||||
if usageErr != nil {
|
if usageErr != nil {
|
||||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr)
|
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr)
|
||||||
}
|
}
|
||||||
if aggErr == nil && usageErr == nil {
|
dedupErr := s.repo.CleanupUsageBillingDedup(ctx, dedupCutoff)
|
||||||
|
if dedupErr != nil {
|
||||||
|
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_billing_dedup 保留清理失败: %v", dedupErr)
|
||||||
|
}
|
||||||
|
if aggErr == nil && usageErr == nil && dedupErr == nil {
|
||||||
s.lastRetentionCleanup.Store(now)
|
s.lastRetentionCleanup.Store(now)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,12 +12,18 @@ import (
|
|||||||
|
|
||||||
type dashboardAggregationRepoTestStub struct {
|
type dashboardAggregationRepoTestStub struct {
|
||||||
aggregateCalls int
|
aggregateCalls int
|
||||||
|
recomputeCalls int
|
||||||
|
cleanupUsageCalls int
|
||||||
|
cleanupDedupCalls int
|
||||||
|
ensurePartitionCalls int
|
||||||
lastStart time.Time
|
lastStart time.Time
|
||||||
lastEnd time.Time
|
lastEnd time.Time
|
||||||
watermark time.Time
|
watermark time.Time
|
||||||
aggregateErr error
|
aggregateErr error
|
||||||
cleanupAggregatesErr error
|
cleanupAggregatesErr error
|
||||||
cleanupUsageErr error
|
cleanupUsageErr error
|
||||||
|
cleanupDedupErr error
|
||||||
|
ensurePartitionErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error {
|
func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||||
@@ -28,6 +34,7 @@ func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
|
func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
|
||||||
|
s.recomputeCalls++
|
||||||
return s.AggregateRange(ctx, start, end)
|
return s.AggregateRange(ctx, start, end)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,11 +51,18 @@ func (s *dashboardAggregationRepoTestStub) CleanupAggregates(ctx context.Context
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
|
func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
|
||||||
|
s.cleanupUsageCalls++
|
||||||
return s.cleanupUsageErr
|
return s.cleanupUsageErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *dashboardAggregationRepoTestStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
|
||||||
|
s.cleanupDedupCalls++
|
||||||
|
return s.cleanupDedupErr
|
||||||
|
}
|
||||||
|
|
||||||
func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||||
return nil
|
s.ensurePartitionCalls++
|
||||||
|
return s.ensurePartitionErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) {
|
func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) {
|
||||||
@@ -90,6 +104,50 @@ func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *te
|
|||||||
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
|
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
|
||||||
|
|
||||||
require.Nil(t, svc.lastRetentionCleanup.Load())
|
require.Nil(t, svc.lastRetentionCleanup.Load())
|
||||||
|
require.Equal(t, 1, repo.cleanupUsageCalls)
|
||||||
|
require.Equal(t, 1, repo.cleanupDedupCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardAggregationService_CleanupDedupFailure_DoesNotRecord(t *testing.T) {
|
||||||
|
repo := &dashboardAggregationRepoTestStub{cleanupDedupErr: errors.New("dedup cleanup failed")}
|
||||||
|
svc := &DashboardAggregationService{
|
||||||
|
repo: repo,
|
||||||
|
cfg: config.DashboardAggregationConfig{
|
||||||
|
Retention: config.DashboardAggregationRetentionConfig{
|
||||||
|
UsageLogsDays: 1,
|
||||||
|
HourlyDays: 1,
|
||||||
|
DailyDays: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
|
||||||
|
|
||||||
|
require.Nil(t, svc.lastRetentionCleanup.Load())
|
||||||
|
require.Equal(t, 1, repo.cleanupDedupCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardAggregationService_PartitionFailure_DoesNotAggregate(t *testing.T) {
|
||||||
|
repo := &dashboardAggregationRepoTestStub{ensurePartitionErr: errors.New("partition failed")}
|
||||||
|
svc := &DashboardAggregationService{
|
||||||
|
repo: repo,
|
||||||
|
cfg: config.DashboardAggregationConfig{
|
||||||
|
Enabled: true,
|
||||||
|
IntervalSeconds: 60,
|
||||||
|
LookbackSeconds: 120,
|
||||||
|
Retention: config.DashboardAggregationRetentionConfig{
|
||||||
|
UsageLogsDays: 1,
|
||||||
|
UsageBillingDedupDays: 2,
|
||||||
|
HourlyDays: 1,
|
||||||
|
DailyDays: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc.runScheduledAggregation()
|
||||||
|
|
||||||
|
require.Equal(t, 1, repo.ensurePartitionCalls)
|
||||||
|
require.Equal(t, 1, repo.aggregateCalls)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) {
|
func TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) {
|
||||||
|
|||||||
@@ -124,6 +124,10 @@ func (s *dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cut
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *dashboardAggregationRepoStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -136,16 +136,18 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
svc := &GatewayService{
|
cfg := &config.Config{
|
||||||
cfg: &config.Config{
|
Gateway: config.GatewayConfig{
|
||||||
Gateway: config.GatewayConfig{
|
MaxLineSize: defaultMaxLineSize,
|
||||||
MaxLineSize: defaultMaxLineSize,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
httpUpstream: upstream,
|
}
|
||||||
rateLimitService: &RateLimitService{},
|
svc := &GatewayService{
|
||||||
deferredService: &DeferredService{},
|
cfg: cfg,
|
||||||
billingCacheService: nil,
|
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
rateLimitService: &RateLimitService{},
|
||||||
|
deferredService: &DeferredService{},
|
||||||
|
billingCacheService: nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
account := &Account{
|
account := &Account{
|
||||||
@@ -221,14 +223,16 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
svc := &GatewayService{
|
cfg := &config.Config{
|
||||||
cfg: &config.Config{
|
Gateway: config.GatewayConfig{
|
||||||
Gateway: config.GatewayConfig{
|
MaxLineSize: defaultMaxLineSize,
|
||||||
MaxLineSize: defaultMaxLineSize,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
httpUpstream: upstream,
|
}
|
||||||
rateLimitService: &RateLimitService{},
|
svc := &GatewayService{
|
||||||
|
cfg: cfg,
|
||||||
|
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
rateLimitService: &RateLimitService{},
|
||||||
}
|
}
|
||||||
|
|
||||||
account := &Account{
|
account := &Account{
|
||||||
@@ -727,6 +731,39 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAf
|
|||||||
require.Equal(t, 5, result.usage.OutputTokens)
|
require.Equal(t, 5, result.usage.OutputTokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGatewayService_AnthropicAPIKeyPassthrough_MissingTerminalEventReturnsError(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
cfg: &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
MaxLineSize: defaultMaxLineSize,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
rateLimitService: &RateLimitService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||||
|
`data: {"type":"message_start","message":{"usage":{"input_tokens":11}}}`,
|
||||||
|
"",
|
||||||
|
`data: {"type":"message_delta","usage":{"output_tokens":5}}`,
|
||||||
|
"",
|
||||||
|
}, "\n"))),
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "claude-3-7-sonnet-20250219")
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "missing terminal event")
|
||||||
|
require.NotNil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuccess(t *testing.T) {
|
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuccess(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
@@ -1074,7 +1111,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingTimeoutAfterClientDi
|
|||||||
_ = pr.Close()
|
_ = pr.Close()
|
||||||
<-done
|
<-done
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "stream usage incomplete after timeout")
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.True(t, result.clientDisconnect)
|
require.True(t, result.clientDisconnect)
|
||||||
require.Equal(t, 9, result.usage.InputTokens)
|
require.Equal(t, 9, result.usage.InputTokens)
|
||||||
@@ -1103,7 +1141,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingContextCanceled(t *t
|
|||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 3}, time.Now(), "claude-3-7-sonnet-20250219")
|
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 3}, time.Now(), "claude-3-7-sonnet-20250219")
|
||||||
require.NoError(t, err)
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "stream usage incomplete")
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.True(t, result.clientDisconnect)
|
require.True(t, result.clientDisconnect)
|
||||||
}
|
}
|
||||||
@@ -1133,7 +1172,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingUpstreamReadErrorAft
|
|||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 4}, time.Now(), "claude-3-7-sonnet-20250219")
|
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 4}, time.Now(), "claude-3-7-sonnet-20250219")
|
||||||
require.NoError(t, err)
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "stream usage incomplete after disconnect")
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.True(t, result.clientDisconnect)
|
require.True(t, result.clientDisconnect)
|
||||||
require.Equal(t, 8, result.usage.InputTokens)
|
require.Equal(t, 8, result.usage.InputTokens)
|
||||||
|
|||||||
261	backend/internal/service/gateway_record_usage_test.go	Normal file
@@ -0,0 +1,261 @@
+//go:build unit
+
+package service
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/Wei-Shaw/sub2api/internal/config"
+	"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
+	"github.com/stretchr/testify/require"
+)
+
+func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService {
+	cfg := &config.Config{}
+	cfg.Default.RateMultiplier = 1.1
+	return NewGatewayService(
+		nil,
+		nil,
+		usageRepo,
+		nil,
+		userRepo,
+		subRepo,
+		nil,
+		nil,
+		cfg,
+		nil,
+		nil,
+		NewBillingService(cfg, nil),
+		nil,
+		&BillingCacheService{},
+		nil,
+		nil,
+		&DeferredService{},
+		nil,
+		nil,
+		nil,
+		nil,
+		nil,
+	)
+}
+
+func newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo UsageLogRepository, billingRepo UsageBillingRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService {
+	svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+	svc.usageBillingRepo = billingRepo
+	return svc
+}
+
+func TestGatewayServiceRecordUsage_BillingUsesDetachedContext(t *testing.T) {
+	usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
+	userRepo := &openAIRecordUsageUserRepoStub{}
+	subRepo := &openAIRecordUsageSubRepoStub{}
+	quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
+	svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+
+	reqCtx, cancel := context.WithCancel(context.Background())
+	cancel()
+
+	err := svc.RecordUsage(reqCtx, &RecordUsageInput{
+		Result: &ForwardResult{
+			RequestID: "gateway_detached_ctx",
+			Usage: ClaudeUsage{
+				InputTokens:  10,
+				OutputTokens: 6,
+			},
+			Model:    "claude-sonnet-4",
+			Duration: time.Second,
+		},
+		APIKey: &APIKey{
+			ID:    501,
+			Quota: 100,
+		},
+		User:          &User{ID: 601},
+		Account:       &Account{ID: 701},
+		APIKeyService: quotaSvc,
+	})
+
+	require.NoError(t, err)
+	require.Equal(t, 1, usageRepo.calls)
+	require.Equal(t, 1, userRepo.deductCalls)
+	require.NoError(t, userRepo.lastCtxErr)
+	require.Equal(t, 1, quotaSvc.quotaCalls)
+	require.NoError(t, quotaSvc.lastQuotaCtxErr)
+}
+
+func TestGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash(t *testing.T) {
+	usageRepo := &openAIRecordUsageLogRepoStub{}
+	billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
+	svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
+
+	payloadHash := HashUsageRequestPayload([]byte(`{"messages":[{"role":"user","content":"hello"}]}`))
+	err := svc.RecordUsage(context.Background(), &RecordUsageInput{
+		Result: &ForwardResult{
+			RequestID: "gateway_payload_hash",
+			Usage: ClaudeUsage{
+				InputTokens:  10,
+				OutputTokens: 6,
+			},
+			Model:    "claude-sonnet-4",
+			Duration: time.Second,
+		},
+		APIKey:             &APIKey{ID: 501, Quota: 100},
+		User:               &User{ID: 601},
+		Account:            &Account{ID: 701},
+		RequestPayloadHash: payloadHash,
+	})
+	require.NoError(t, err)
+	require.NotNil(t, billingRepo.lastCmd)
+	require.Equal(t, payloadHash, billingRepo.lastCmd.RequestPayloadHash)
+}
+
+func TestGatewayServiceRecordUsage_BillingFingerprintFallsBackToContextRequestID(t *testing.T) {
+	usageRepo := &openAIRecordUsageLogRepoStub{}
+	billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
+	svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
+
+	ctx := context.WithValue(context.Background(), ctxkey.RequestID, "req-local-123")
+	err := svc.RecordUsage(ctx, &RecordUsageInput{
+		Result: &ForwardResult{
+			RequestID: "gateway_payload_fallback",
+			Usage: ClaudeUsage{
+				InputTokens:  10,
+				OutputTokens: 6,
+			},
+			Model:    "claude-sonnet-4",
+			Duration: time.Second,
+		},
+		APIKey:  &APIKey{ID: 501, Quota: 100},
+		User:    &User{ID: 601},
+		Account: &Account{ID: 701},
+	})
+	require.NoError(t, err)
+	require.NotNil(t, billingRepo.lastCmd)
+	require.Equal(t, "local:req-local-123", billingRepo.lastCmd.RequestPayloadHash)
+}
+
+func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) {
+	usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)}
+	userRepo := &openAIRecordUsageUserRepoStub{}
+	subRepo := &openAIRecordUsageSubRepoStub{}
+	quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
+	svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+
+	err := svc.RecordUsage(context.Background(), &RecordUsageInput{
+		Result: &ForwardResult{
+			RequestID: "gateway_not_persisted",
+			Usage: ClaudeUsage{
+				InputTokens:  10,
+				OutputTokens: 6,
+			},
+			Model:    "claude-sonnet-4",
+			Duration: time.Second,
+		},
+		APIKey: &APIKey{
+			ID:    503,
+			Quota: 100,
+		},
+		User:          &User{ID: 603},
+		Account:       &Account{ID: 703},
+		APIKeyService: quotaSvc,
+	})
+
+	require.NoError(t, err)
+	require.Equal(t, 1, usageRepo.calls)
+	require.Equal(t, 1, userRepo.deductCalls)
+	require.Equal(t, 1, quotaSvc.quotaCalls)
+}
+
+func TestGatewayServiceRecordUsageWithLongContext_BillingUsesDetachedContext(t *testing.T) {
+	usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
+	userRepo := &openAIRecordUsageUserRepoStub{}
+	subRepo := &openAIRecordUsageSubRepoStub{}
+	quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
+	svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+
+	reqCtx, cancel := context.WithCancel(context.Background())
+	cancel()
+
+	err := svc.RecordUsageWithLongContext(reqCtx, &RecordUsageLongContextInput{
+		Result: &ForwardResult{
+			RequestID: "gateway_long_context_detached_ctx",
+			Usage: ClaudeUsage{
+				InputTokens:  12,
+				OutputTokens: 8,
+			},
+			Model:    "claude-sonnet-4",
+			Duration: time.Second,
+		},
+		APIKey: &APIKey{
+			ID:    502,
+			Quota: 100,
+		},
+		User:                  &User{ID: 602},
+		Account:               &Account{ID: 702},
+		LongContextThreshold:  200000,
+		LongContextMultiplier: 2,
+		APIKeyService:         quotaSvc,
+	})
+
+	require.NoError(t, err)
+	require.Equal(t, 1, usageRepo.calls)
+	require.Equal(t, 1, userRepo.deductCalls)
+	require.NoError(t, userRepo.lastCtxErr)
+	require.Equal(t, 1, quotaSvc.quotaCalls)
+	require.NoError(t, quotaSvc.lastQuotaCtxErr)
+}
+
+func TestGatewayServiceRecordUsage_UsesFallbackRequestIDForUsageLog(t *testing.T) {
+	usageRepo := &openAIRecordUsageLogRepoStub{}
+	userRepo := &openAIRecordUsageUserRepoStub{}
+	subRepo := &openAIRecordUsageSubRepoStub{}
+	svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+
+	ctx := context.WithValue(context.Background(), ctxkey.RequestID, "gateway-local-fallback")
+	err := svc.RecordUsage(ctx, &RecordUsageInput{
+		Result: &ForwardResult{
+			RequestID: "",
+			Usage: ClaudeUsage{
+				InputTokens:  10,
+				OutputTokens: 6,
+			},
+			Model:    "claude-sonnet-4",
+			Duration: time.Second,
+		},
+		APIKey:  &APIKey{ID: 504},
+		User:    &User{ID: 604},
+		Account: &Account{ID: 704},
+	})
+
+	require.NoError(t, err)
+	require.NotNil(t, usageRepo.lastLog)
+	require.Equal(t, "local:gateway-local-fallback", usageRepo.lastLog.RequestID)
+}
+
+func TestGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) {
+	usageRepo := &openAIRecordUsageLogRepoStub{}
+	billingRepo := &openAIRecordUsageBillingRepoStub{err: context.DeadlineExceeded}
+	userRepo := &openAIRecordUsageUserRepoStub{}
+	subRepo := &openAIRecordUsageSubRepoStub{}
+	svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo)
+
+	err := svc.RecordUsage(context.Background(), &RecordUsageInput{
+		Result: &ForwardResult{
+			RequestID: "gateway_billing_fail",
+			Usage: ClaudeUsage{
+				InputTokens:  10,
+				OutputTokens: 6,
+			},
+			Model:    "claude-sonnet-4",
+			Duration: time.Second,
+		},
+		APIKey:  &APIKey{ID: 505},
+		User:    &User{ID: 605},
+		Account: &Account{ID: 705},
+	})
+
+	require.Error(t, err)
+	require.Equal(t, 1, billingRepo.calls)
+	require.Equal(t, 0, usageRepo.calls)
+}
@@ -50,6 +50,7 @@ const (
 	defaultUserGroupRateCacheTTL = 30 * time.Second
 	defaultModelsListCacheTTL    = 15 * time.Second
+	postUsageBillingTimeout      = 15 * time.Second
 )
 
 const (
@@ -106,6 +107,52 @@ func GatewayModelsListCacheStats() (cacheHit, cacheMiss, store int64) {
 	return modelsListCacheHitTotal.Load(), modelsListCacheMissTotal.Load(), modelsListCacheStoreTotal.Load()
 }
+
+func claudeUsageHasAnyTokens(usage *ClaudeUsage) bool {
+	return usage != nil && (usage.InputTokens > 0 ||
+		usage.OutputTokens > 0 ||
+		usage.CacheCreationInputTokens > 0 ||
+		usage.CacheReadInputTokens > 0 ||
+		usage.CacheCreation5mTokens > 0 ||
+		usage.CacheCreation1hTokens > 0)
+}
+
+func openAIUsageHasAnyTokens(usage *OpenAIUsage) bool {
+	return usage != nil && (usage.InputTokens > 0 ||
+		usage.OutputTokens > 0 ||
+		usage.CacheCreationInputTokens > 0 ||
+		usage.CacheReadInputTokens > 0)
+}
+
+func openAIStreamEventIsTerminal(data string) bool {
+	trimmed := strings.TrimSpace(data)
+	if trimmed == "" {
+		return false
+	}
+	if trimmed == "[DONE]" {
+		return true
+	}
+	switch gjson.Get(trimmed, "type").String() {
+	case "response.completed", "response.done", "response.failed":
+		return true
+	default:
+		return false
+	}
+}
+
+func anthropicStreamEventIsTerminal(eventName, data string) bool {
+	if strings.EqualFold(strings.TrimSpace(eventName), "message_stop") {
+		return true
+	}
+	trimmed := strings.TrimSpace(data)
+	if trimmed == "" {
+		return false
+	}
+	if trimmed == "[DONE]" {
+		return true
+	}
+	return gjson.Get(trimmed, "type").String() == "message_stop"
+}
+
 func cloneStringSlice(src []string) []string {
 	if len(src) == 0 {
 		return nil
@@ -504,6 +551,7 @@ type GatewayService struct {
 	accountRepo       AccountRepository
 	groupRepo         GroupRepository
 	usageLogRepo      UsageLogRepository
+	usageBillingRepo  UsageBillingRepository
 	userRepo          UserRepository
 	userSubRepo       UserSubscriptionRepository
 	userGroupRateRepo UserGroupRateRepository
@@ -537,6 +585,7 @@ func NewGatewayService(
 	accountRepo AccountRepository,
 	groupRepo GroupRepository,
 	usageLogRepo UsageLogRepository,
+	usageBillingRepo UsageBillingRepository,
 	userRepo UserRepository,
 	userSubRepo UserSubscriptionRepository,
 	userGroupRateRepo UserGroupRateRepository,
@@ -563,6 +612,7 @@ func NewGatewayService(
 		accountRepo:       accountRepo,
 		groupRepo:         groupRepo,
 		usageLogRepo:      usageLogRepo,
+		usageBillingRepo:  usageBillingRepo,
 		userRepo:          userRepo,
 		userSubRepo:       userSubRepo,
 		userGroupRateRepo: userGroupRateRepo,
@@ -4049,7 +4099,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
 	retryStart := time.Now()
 	for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
 		// Build the upstream request (rebuilt on every retry because the request body must be re-read)
-		upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
+		upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
+		upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
+		releaseUpstreamCtx()
 		if err != nil {
 			return nil, err
 		}
@@ -4127,7 +4179,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
 	// also downgrade tool_use/tool_result blocks to text.
 
 	filteredBody := FilterThinkingBlocksForRetry(body)
-	retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
+	retryCtx, releaseRetryCtx := detachStreamUpstreamContext(ctx, reqStream)
+	retryReq, buildErr := s.buildUpstreamRequest(retryCtx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
+	releaseRetryCtx()
 	if buildErr == nil {
 		retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
 		if retryErr == nil {
@@ -4159,7 +4213,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
 	if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed {
 		logger.LegacyPrintf("service.gateway", "Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID)
 		filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
-		retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
+		retryCtx2, releaseRetryCtx2 := detachStreamUpstreamContext(ctx, reqStream)
+		retryReq2, buildErr2 := s.buildUpstreamRequest(retryCtx2, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
+		releaseRetryCtx2()
 		if buildErr2 == nil {
 			retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
 			if retryErr2 == nil {
@@ -4226,7 +4282,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
 	rectifiedBody, applied := RectifyThinkingBudget(body)
 	if applied && time.Since(retryStart) < maxRetryElapsed {
 		logger.LegacyPrintf("service.gateway", "Account %d: detected budget_tokens constraint error, retrying with rectified budget (budget_tokens=%d, max_tokens=%d)", account.ID, BudgetRectifyBudgetTokens, BudgetRectifyMaxTokens)
-		budgetRetryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, rectifiedBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
+		budgetRetryCtx, releaseBudgetRetryCtx := detachStreamUpstreamContext(ctx, reqStream)
+		budgetRetryReq, buildErr := s.buildUpstreamRequest(budgetRetryCtx, c, account, rectifiedBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
+		releaseBudgetRetryCtx()
 		if buildErr == nil {
 			budgetRetryResp, retryErr := s.httpUpstream.DoWithTLS(budgetRetryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
 			if retryErr == nil {
@@ -4498,7 +4556,9 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
 	var resp *http.Response
 	retryStart := time.Now()
 	for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
-		upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(ctx, c, account, body, token)
+		upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
+		upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(upstreamCtx, c, account, body, token)
+		releaseUpstreamCtx()
 		if err != nil {
 			return nil, err
 		}
@@ -4774,6 +4834,7 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
 	usage := &ClaudeUsage{}
 	var firstTokenMs *int
 	clientDisconnected := false
+	sawTerminalEvent := false
 
 	scanner := bufio.NewScanner(resp.Body)
 	maxLineSize := defaultMaxLineSize
@@ -4836,17 +4897,20 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
 				// Fallback flush so the last event still reaches the client promptly even if it does not end with a blank line.
 				flusher.Flush()
 			}
+			if !sawTerminalEvent {
+				return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, fmt.Errorf("stream usage incomplete: missing terminal event")
+			}
 			return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
 		}
 		if ev.err != nil {
+			if sawTerminalEvent {
+				return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
+			}
 			if clientDisconnected {
-				logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, ev.err)
-				return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
+				return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after disconnect: %w", ev.err)
 			}
 			if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
-				logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] 流读取被取消: account=%d request_id=%s err=%v ctx_err=%v",
-					account.ID, resp.Header.Get("x-request-id"), ev.err, ctx.Err())
-				return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
+				return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete: %w", ev.err)
 			}
 			if errors.Is(ev.err, bufio.ErrTooLong) {
 				logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
@@ -4858,11 +4922,19 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
 			line := ev.line
 			if data, ok := extractAnthropicSSEDataLine(line); ok {
 				trimmed := strings.TrimSpace(data)
+				if anthropicStreamEventIsTerminal("", trimmed) {
+					sawTerminalEvent = true
+				}
 				if firstTokenMs == nil && trimmed != "" && trimmed != "[DONE]" {
 					ms := int(time.Since(startTime).Milliseconds())
 					firstTokenMs = &ms
 				}
 				s.parseSSEUsagePassthrough(data, usage)
+			} else {
+				trimmed := strings.TrimSpace(line)
+				if strings.HasPrefix(trimmed, "event:") && anthropicStreamEventIsTerminal(strings.TrimSpace(strings.TrimPrefix(trimmed, "event:")), "") {
+					sawTerminalEvent = true
+				}
 			}
 
 			if !clientDisconnected {
@@ -4884,8 +4956,7 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
 				continue
 			}
 			if clientDisconnected {
-				logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream timeout after client disconnect: account=%d model=%s", account.ID, model)
-				return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
+				return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout")
 			}
 			logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval)
 			if s.rateLimitService != nil {
@@ -6011,6 +6082,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
 
 	needModelReplace := originalModel != mappedModel
 	clientDisconnected := false // client-disconnect flag; keep reading upstream after disconnect to collect complete usage
+	sawTerminalEvent := false
 
 	pendingEventLines := make([]string, 0, 4)
 
@@ -6041,6 +6113,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
 		}
 
 		if dataLine == "[DONE]" {
+			sawTerminalEvent = true
 			block := ""
 			if eventName != "" {
 				block = "event: " + eventName + "\n"
@@ -6107,6 +6180,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
 		}
 
 		usagePatch := s.extractSSEUsagePatch(event)
+		if anthropicStreamEventIsTerminal(eventName, dataLine) {
+			sawTerminalEvent = true
+		}
 		if !eventChanged {
 			block := ""
 			if eventName != "" {
@@ -6140,18 +6216,22 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
 		case ev, ok := <-events:
 			if !ok {
 				// Upstream finished; return the result
+				if !sawTerminalEvent {
+					return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, fmt.Errorf("stream usage incomplete: missing terminal event")
+				}
 				return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
 			}
 			if ev.err != nil {
+				if sawTerminalEvent {
+					return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
+				}
 				// Detect context cancellation (a client disconnect cancels the context, which in turn breaks the upstream read)
 				if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
-					logger.LegacyPrintf("service.gateway", "Context canceled during streaming, returning collected usage")
-					return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
+					return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete: %w", ev.err)
 				}
 				// The client disconnect was already detected via a write failure and the upstream errored too
 				if clientDisconnected {
-					logger.LegacyPrintf("service.gateway", "Upstream read error after client disconnect: %v, returning collected usage", ev.err)
-					return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
+					return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after disconnect: %w", ev.err)
 				}
 				// Client still connected: normal error handling
 				if errors.Is(ev.err, bufio.ErrTooLong) {
@@ -6209,9 +6289,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
 				continue
 			}
 			if clientDisconnected {
-				// The client disconnected and the upstream also timed out; return the collected usage
-				logger.LegacyPrintf("service.gateway", "Upstream timeout after client disconnect, returning collected usage")
-				return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
+				return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout")
 			}
 			logger.LegacyPrintf("service.gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
 			// Handle the stream timeout; the account may be marked temporarily unschedulable or as errored
@@ -6557,15 +6635,16 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID,
|
|||||||
|
|
||||||
// RecordUsageInput 记录使用量的输入参数
|
// RecordUsageInput 记录使用量的输入参数
|
||||||
type RecordUsageInput struct {
|
type RecordUsageInput struct {
|
||||||
Result *ForwardResult
|
Result *ForwardResult
|
||||||
APIKey *APIKey
|
APIKey *APIKey
|
||||||
User *User
|
User *User
|
||||||
Account *Account
|
Account *Account
|
||||||
Subscription *UserSubscription // 可选:订阅信息
|
Subscription *UserSubscription // 可选:订阅信息
|
||||||
UserAgent string // 请求的 User-Agent
|
UserAgent string // 请求的 User-Agent
|
||||||
IPAddress string // 请求的客户端 IP 地址
|
IPAddress string // 请求的客户端 IP 地址
|
||||||
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
|
RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险
|
||||||
APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额
|
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
|
||||||
|
APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额
|
||||||
}
|
}
|
||||||
|
|
||||||
// APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage
|
// APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage
|
||||||
@@ -6574,6 +6653,14 @@ type APIKeyQuotaUpdater interface {
 	UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error
 }
 
+type apiKeyAuthCacheInvalidator interface {
+	InvalidateAuthCacheByKey(ctx context.Context, key string)
+}
+
+type usageLogBestEffortWriter interface {
+	CreateBestEffort(ctx context.Context, log *UsageLog) error
+}
+
 // postUsageBillingParams holds the parameters required by the unified billing path
 type postUsageBillingParams struct {
 	Cost *CostBreakdown
@@ -6581,6 +6668,7 @@ type postUsageBillingParams struct {
 	APIKey *APIKey
 	Account *Account
 	Subscription *UserSubscription
+	RequestPayloadHash string
 	IsSubscriptionBill bool
 	AccountRateMultiplier float64
 	APIKeyService APIKeyQuotaUpdater
@@ -6592,19 +6680,22 @@ type postUsageBillingParams struct {
 // - API Key rate-limit usage update
 // - account quota usage update (account basis: TotalCost × account billing rate multiplier)
 func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) {
+	billingCtx, cancel := detachedBillingContext(ctx)
+	defer cancel()
+
 	cost := p.Cost
 
 	// 1. subscription / balance deduction
 	if p.IsSubscriptionBill {
 		if cost.TotalCost > 0 {
-			if err := deps.userSubRepo.IncrementUsage(ctx, p.Subscription.ID, cost.TotalCost); err != nil {
+			if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil {
 				slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err)
 			}
 			deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost)
 		}
 	} else {
 		if cost.ActualCost > 0 {
-			if err := deps.userRepo.DeductBalance(ctx, p.User.ID, cost.ActualCost); err != nil {
+			if err := deps.userRepo.DeductBalance(billingCtx, p.User.ID, cost.ActualCost); err != nil {
 				slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err)
 			}
 			deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost)
@@ -6613,31 +6704,187 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
 
 	// 2. API Key quota
 	if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil {
-		if err := p.APIKeyService.UpdateQuotaUsed(ctx, p.APIKey.ID, cost.ActualCost); err != nil {
+		if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
 			slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err)
 		}
 	}
 
 	// 3. API Key rate-limit usage
 	if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
-		if err := p.APIKeyService.UpdateRateLimitUsage(ctx, p.APIKey.ID, cost.ActualCost); err != nil {
+		if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
 			slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err)
 		}
-		deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, cost.ActualCost)
 	}
 
 	// 4. account quota usage (account basis: TotalCost × account billing rate multiplier)
 	if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() {
 		accountCost := cost.TotalCost * p.AccountRateMultiplier
-		if err := deps.accountRepo.IncrementQuotaUsed(ctx, p.Account.ID, accountCost); err != nil {
+		if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil {
 			slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
 		}
 	}
 
-	// 5. update the account's last-used time
+	finalizePostUsageBilling(p, deps)
+}
+
+func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string {
+	if requestID := strings.TrimSpace(upstreamRequestID); requestID != "" {
+		return requestID
+	}
+	if ctx != nil {
+		if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" {
+			return "client:" + strings.TrimSpace(clientRequestID)
+		}
+		if requestID, _ := ctx.Value(ctxkey.RequestID).(string); strings.TrimSpace(requestID) != "" {
+			return "local:" + strings.TrimSpace(requestID)
+		}
+	}
+	return ""
+}
+
+func resolveUsageBillingPayloadFingerprint(ctx context.Context, requestPayloadHash string) string {
+	if payloadHash := strings.TrimSpace(requestPayloadHash); payloadHash != "" {
+		return payloadHash
+	}
+	if ctx != nil {
+		if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" {
+			return "client:" + strings.TrimSpace(clientRequestID)
+		}
+		if requestID, _ := ctx.Value(ctxkey.RequestID).(string); strings.TrimSpace(requestID) != "" {
+			return "local:" + strings.TrimSpace(requestID)
+		}
+	}
+	return ""
+}
+
+func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsageBillingParams) *UsageBillingCommand {
+	if p == nil || p.Cost == nil || p.APIKey == nil || p.User == nil || p.Account == nil {
+		return nil
+	}
+
+	cmd := &UsageBillingCommand{
+		RequestID: requestID,
+		APIKeyID: p.APIKey.ID,
+		UserID: p.User.ID,
+		AccountID: p.Account.ID,
+		AccountType: p.Account.Type,
+		RequestPayloadHash: strings.TrimSpace(p.RequestPayloadHash),
+	}
+	if usageLog != nil {
+		cmd.Model = usageLog.Model
+		cmd.BillingType = usageLog.BillingType
+		cmd.InputTokens = usageLog.InputTokens
+		cmd.OutputTokens = usageLog.OutputTokens
+		cmd.CacheCreationTokens = usageLog.CacheCreationTokens
+		cmd.CacheReadTokens = usageLog.CacheReadTokens
+		cmd.ImageCount = usageLog.ImageCount
+		if usageLog.MediaType != nil {
+			cmd.MediaType = *usageLog.MediaType
+		}
+		if usageLog.ServiceTier != nil {
+			cmd.ServiceTier = *usageLog.ServiceTier
+		}
+		if usageLog.ReasoningEffort != nil {
+			cmd.ReasoningEffort = *usageLog.ReasoningEffort
+		}
+		if usageLog.SubscriptionID != nil {
+			cmd.SubscriptionID = usageLog.SubscriptionID
+		}
+	}
+
+	if p.IsSubscriptionBill && p.Subscription != nil && p.Cost.TotalCost > 0 {
+		cmd.SubscriptionID = &p.Subscription.ID
+		cmd.SubscriptionCost = p.Cost.TotalCost
+	} else if p.Cost.ActualCost > 0 {
+		cmd.BalanceCost = p.Cost.ActualCost
+	}
+
+	if p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil {
+		cmd.APIKeyQuotaCost = p.Cost.ActualCost
+	}
+	if p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
+		cmd.APIKeyRateLimitCost = p.Cost.ActualCost
+	}
+	if p.Cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() {
+		cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier
+	}
+
+	cmd.Normalize()
+	return cmd
+}
+
+func applyUsageBilling(ctx context.Context, requestID string, usageLog *UsageLog, p *postUsageBillingParams, deps *billingDeps, repo UsageBillingRepository) (bool, error) {
+	if p == nil || deps == nil {
+		return false, nil
+	}
+
+	cmd := buildUsageBillingCommand(requestID, usageLog, p)
+	if cmd == nil || cmd.RequestID == "" || repo == nil {
+		postUsageBilling(ctx, p, deps)
+		return true, nil
+	}
+
+	billingCtx, cancel := detachedBillingContext(ctx)
+	defer cancel()
+
+	result, err := repo.Apply(billingCtx, cmd)
+	if err != nil {
+		return false, err
+	}
+
+	if result == nil || !result.Applied {
+		deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
+		return false, nil
+	}
+
+	if result.APIKeyQuotaExhausted {
+		if invalidator, ok := p.APIKeyService.(apiKeyAuthCacheInvalidator); ok && p.APIKey != nil && p.APIKey.Key != "" {
+			invalidator.InvalidateAuthCacheByKey(billingCtx, p.APIKey.Key)
+		}
+	}
+
+	finalizePostUsageBilling(p, deps)
+	return true, nil
+}
+
+func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) {
+	if p == nil || p.Cost == nil || deps == nil {
+		return
+	}
+
+	if p.IsSubscriptionBill {
+		if p.Cost.TotalCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil {
+			deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.TotalCost)
+		}
+	} else if p.Cost.ActualCost > 0 && p.User != nil {
+		deps.billingCacheService.QueueDeductBalance(p.User.ID, p.Cost.ActualCost)
+	}
+
+	if p.Cost.ActualCost > 0 && p.APIKey != nil && p.APIKey.HasRateLimits() {
+		deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, p.Cost.ActualCost)
+	}
+
 	deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
 }
 
+func detachedBillingContext(ctx context.Context) (context.Context, context.CancelFunc) {
+	base := context.Background()
+	if ctx != nil {
+		base = context.WithoutCancel(ctx)
+	}
+	return context.WithTimeout(base, postUsageBillingTimeout)
+}
+
+func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Context, context.CancelFunc) {
+	if !stream {
+		return ctx, func() {}
+	}
+	if ctx == nil {
+		return context.Background(), func() {}
+	}
+	return context.WithoutCancel(ctx), func() {}
+}
+
 // billingDeps lists the services the billing logic depends on (provided by each gateway service)
 type billingDeps struct {
 	accountRepo AccountRepository
@@ -6657,6 +6904,28 @@ func (s *GatewayService) billingDeps() {
 	}
 }
 
+func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usageLog *UsageLog, logKey string) {
+	if repo == nil || usageLog == nil {
+		return
+	}
+	usageCtx, cancel := detachedBillingContext(ctx)
+	defer cancel()
+
+	if writer, ok := repo.(usageLogBestEffortWriter); ok {
+		if err := writer.CreateBestEffort(usageCtx, usageLog); err != nil {
+			logger.LegacyPrintf(logKey, "Create usage log failed: %v", err)
+			if _, syncErr := repo.Create(usageCtx, usageLog); syncErr != nil {
+				logger.LegacyPrintf(logKey, "Create usage log sync fallback failed: %v", syncErr)
+			}
+		}
+		return
+	}
+
+	if _, err := repo.Create(usageCtx, usageLog); err != nil {
+		logger.LegacyPrintf(logKey, "Create usage log failed: %v", err)
+	}
+}
+
 // RecordUsage records usage and performs billing (or updates subscription usage)
 func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
 	result := input.Result
@@ -6758,11 +7027,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
 		mediaType = &result.MediaType
 	}
 	accountRateMultiplier := account.BillingRateMultiplier()
+	requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
 	usageLog := &UsageLog{
 		UserID: user.ID,
 		APIKeyID: apiKey.ID,
 		AccountID: account.ID,
-		RequestID: result.RequestID,
+		RequestID: requestID,
 		Model: result.Model,
 		InputTokens: result.Usage.InputTokens,
 		OutputTokens: result.Usage.OutputTokens,
@@ -6807,33 +7077,32 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
 		usageLog.SubscriptionID = &subscription.ID
 	}
 
-	inserted, err := s.usageLogRepo.Create(ctx, usageLog)
-	if err != nil {
-		logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err)
-	}
-
 	if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
+		writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
 		logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
 		s.deferredService.ScheduleLastUsedUpdate(account.ID)
 		return nil
 	}
 
-	shouldBill := inserted || err != nil
-	if shouldBill {
-		postUsageBilling(ctx, &postUsageBillingParams{
+	billingErr := func() error {
+		_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
 			Cost: cost,
 			User: user,
 			APIKey: apiKey,
 			Account: account,
 			Subscription: subscription,
+			RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
 			IsSubscriptionBill: isSubscriptionBilling,
 			AccountRateMultiplier: accountRateMultiplier,
 			APIKeyService: input.APIKeyService,
-		}, s.billingDeps())
-	} else {
-		s.deferredService.ScheduleLastUsedUpdate(account.ID)
+		}, s.billingDeps(), s.usageBillingRepo)
+		return err
+	}()
+
+	if billingErr != nil {
+		return billingErr
 	}
+	writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
 
 	return nil
 }
@@ -6844,13 +7113,14 @@ type RecordUsageLongContextInput struct {
 	APIKey *APIKey
 	User *User
 	Account *Account
 	Subscription *UserSubscription // optional: subscription info
 	UserAgent string // User-Agent of the request
 	IPAddress string // client IP address of the request
+	RequestPayloadHash string // semantic hash of the request body; lowers the risk of silently deduplicating billing when a request_id is accidentally reused
 	LongContextThreshold int // long-context threshold (e.g. 200000)
 	LongContextMultiplier float64 // multiplier for the portion above the threshold (e.g. 2.0)
 	ForceCacheBilling bool // force cache billing: bill input_tokens as cache_read (used for sticky-session switching)
-	APIKeyService *APIKeyService // API Key quota service (optional)
+	APIKeyService APIKeyQuotaUpdater // API Key quota service (optional)
 }
 
 // RecordUsageWithLongContext records usage and bills it, supporting long-context double billing (used for Gemini)
@@ -6933,11 +7203,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
 		imageSize = &result.ImageSize
 	}
 	accountRateMultiplier := account.BillingRateMultiplier()
+	requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
 	usageLog := &UsageLog{
 		UserID: user.ID,
 		APIKeyID: apiKey.ID,
 		AccountID: account.ID,
-		RequestID: result.RequestID,
+		RequestID: requestID,
 		Model: result.Model,
 		InputTokens: result.Usage.InputTokens,
 		OutputTokens: result.Usage.OutputTokens,
@@ -6981,33 +7252,32 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
 		usageLog.SubscriptionID = &subscription.ID
 	}
 
-	inserted, err := s.usageLogRepo.Create(ctx, usageLog)
-	if err != nil {
-		logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err)
-	}
-
 	if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
+		writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
 		logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
 		s.deferredService.ScheduleLastUsedUpdate(account.ID)
 		return nil
 	}
 
-	shouldBill := inserted || err != nil
-	if shouldBill {
-		postUsageBilling(ctx, &postUsageBillingParams{
+	billingErr := func() error {
+		_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
 			Cost: cost,
 			User: user,
 			APIKey: apiKey,
 			Account: account,
 			Subscription: subscription,
+			RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
 			IsSubscriptionBill: isSubscriptionBilling,
 			AccountRateMultiplier: accountRateMultiplier,
 			APIKeyService: input.APIKeyService,
-		}, s.billingDeps())
-	} else {
-		s.deferredService.ScheduleLastUsedUpdate(account.ID)
+		}, s.billingDeps(), s.usageBillingRepo)
+		return err
+	}()
+
+	if billingErr != nil {
+		return billingErr
 	}
+	writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
 
 	return nil
 }
@@ -181,7 +181,8 @@ func TestHandleStreamingResponse_EmptyStream(t *testing.T) {
 
 	result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
 	_ = pr.Close()
-	require.NoError(t, err)
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "missing terminal event")
 	require.NotNil(t, result)
 }
 
@@ -7,35 +7,63 @@ import (
 	"time"
 
 	"github.com/Wei-Shaw/sub2api/internal/config"
+	"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
 	"github.com/stretchr/testify/require"
 )
 
 type openAIRecordUsageLogRepoStub struct {
 	UsageLogRepository
 
 	inserted bool
 	err error
 	calls int
 	lastLog *UsageLog
+	lastCtxErr error
 }
 
 func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) {
 	s.calls++
 	s.lastLog = log
+	s.lastCtxErr = ctx.Err()
 	return s.inserted, s.err
 }
 
+type openAIRecordUsageBillingRepoStub struct {
+	UsageBillingRepository
+
+	result *UsageBillingApplyResult
+	err error
+	calls int
+	lastCmd *UsageBillingCommand
+	lastCtxErr error
+}
+
+func (s *openAIRecordUsageBillingRepoStub) Apply(ctx context.Context, cmd *UsageBillingCommand) (*UsageBillingApplyResult, error) {
+	s.calls++
+	s.lastCmd = cmd
+	s.lastCtxErr = ctx.Err()
+	if s.err != nil {
+		return nil, s.err
+	}
+	if s.result != nil {
+		return s.result, nil
+	}
+	return &UsageBillingApplyResult{Applied: true}, nil
+}
+
 type openAIRecordUsageUserRepoStub struct {
 	UserRepository
 
 	deductCalls int
 	deductErr error
 	lastAmount float64
+	lastCtxErr error
 }
 
 func (s *openAIRecordUsageUserRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error {
 	s.deductCalls++
 	s.lastAmount = amount
+	s.lastCtxErr = ctx.Err()
 	return s.deductErr
 }
 
@@ -44,29 +72,35 @@ type openAIRecordUsageSubRepoStub struct {
 
 	incrementCalls int
 	incrementErr error
+	lastCtxErr error
 }
 
 func (s *openAIRecordUsageSubRepoStub) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
 	s.incrementCalls++
+	s.lastCtxErr = ctx.Err()
 	return s.incrementErr
 }
 
 type openAIRecordUsageAPIKeyQuotaStub struct {
 	quotaCalls int
 	rateLimitCalls int
 	err error
 	lastAmount float64
+	lastQuotaCtxErr error
+	lastRateLimitCtxErr error
 }
 
 func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error {
 	s.quotaCalls++
 	s.lastAmount = cost
+	s.lastQuotaCtxErr = ctx.Err()
 	return s.err
 }
 
 func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error {
 	s.rateLimitCalls++
 	s.lastAmount = cost
+	s.lastRateLimitCtxErr = ctx.Err()
 	return s.err
 }
 
@@ -93,23 +127,38 @@ func i64p(v int64) *int64 {
 func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService {
 	cfg := &config.Config{}
 	cfg.Default.RateMultiplier = 1.1
-	return &OpenAIGatewayService{
-		usageLogRepo: usageRepo,
-		userRepo: userRepo,
-		userSubRepo: subRepo,
-		cfg: cfg,
-		billingService: NewBillingService(cfg, nil),
-		billingCacheService: &BillingCacheService{},
-		deferredService: &DeferredService{},
-		userGroupRateResolver: newUserGroupRateResolver(
-			rateRepo,
-			nil,
-			resolveUserGroupRateCacheTTL(cfg),
-			nil,
-			"service.openai_gateway.test",
-		),
-	}
+	svc := NewOpenAIGatewayService(
+		nil,
+		usageRepo,
+		nil,
+		userRepo,
+		subRepo,
+		rateRepo,
+		nil,
+		cfg,
+		nil,
+		nil,
+		NewBillingService(cfg, nil),
+		nil,
+		&BillingCacheService{},
+		nil,
+		&DeferredService{},
+		nil,
+	)
+	svc.userGroupRateResolver = newUserGroupRateResolver(
+		rateRepo,
+		nil,
+		resolveUserGroupRateCacheTTL(cfg),
+		nil,
+		"service.openai_gateway.test",
+	)
+	return svc
+}
+
+func newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo UsageLogRepository, billingRepo UsageBillingRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService {
+	svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo)
+	svc.usageBillingRepo = billingRepo
+	return svc
 }
 
 func expectedOpenAICost(t *testing.T, svc *OpenAIGatewayService, model string, usage OpenAIUsage, multiplier float64) *CostBreakdown {
@@ -252,9 +301,10 @@ func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateWhenResolver
 
 func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testing.T) {
 	usageRepo := &openAIRecordUsageLogRepoStub{inserted: false}
+	billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: false}}
 	userRepo := &openAIRecordUsageUserRepoStub{}
 	subRepo := &openAIRecordUsageSubRepoStub{}
-	svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
+	svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
 
 	err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
 		Result: &OpenAIForwardResult{
@@ -272,11 +322,254 @@ func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testin
 	})
 
 	require.NoError(t, err)
+	require.Equal(t, 1, billingRepo.calls)
 	require.Equal(t, 1, usageRepo.calls)
 	require.Equal(t, 0, userRepo.deductCalls)
 	require.Equal(t, 0, subRepo.incrementCalls)
 }
 
+func TestOpenAIGatewayServiceRecordUsage_DuplicateBillingKeySkipsBillingWithRepo(t *testing.T) {
+	usageRepo := &openAIRecordUsageLogRepoStub{inserted: false}
+	billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: false}}
+	userRepo := &openAIRecordUsageUserRepoStub{}
+	subRepo := &openAIRecordUsageSubRepoStub{}
+	quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
+	svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
+
+	err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+		Result: &OpenAIForwardResult{
+			RequestID: "resp_duplicate_billing_key",
+			Usage: OpenAIUsage{
+				InputTokens: 8,
+				OutputTokens: 4,
+			},
+			Model: "gpt-5.1",
+			Duration: time.Second,
+		},
+		APIKey: &APIKey{
+			ID: 10045,
+			Quota: 100,
+		},
+		User: &User{ID: 20045},
+		Account: &Account{ID: 30045},
+		APIKeyService: quotaSvc,
+	})
+
+	require.NoError(t, err)
+	require.Equal(t, 1, billingRepo.calls)
+	require.Equal(t, 1, usageRepo.calls)
+	require.Equal(t, 0, userRepo.deductCalls)
+	require.Equal(t, 0, subRepo.incrementCalls)
+	require.Equal(t, 0, quotaSvc.quotaCalls)
+}
+
|
|
||||||
|
func TestOpenAIGatewayServiceRecordUsage_BillsWhenUsageLogCreateReturnsError(t *testing.T) {
|
||||||
|
usage := OpenAIUsage{InputTokens: 8, OutputTokens: 4}
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: errors.New("usage log batch state uncertain")}
|
||||||
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
|
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||||
|
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||||
|
|
||||||
|
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||||
|
Result: &OpenAIForwardResult{
|
||||||
|
RequestID: "resp_usage_log_error",
|
||||||
|
Usage: usage,
|
||||||
|
Model: "gpt-5.1",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{ID: 10041},
|
||||||
|
User: &User{ID: 20041},
|
||||||
|
Account: &Account{ID: 30041},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, usageRepo.calls)
|
||||||
|
require.Equal(t, 1, userRepo.deductCalls)
|
||||||
|
require.Equal(t, 0, subRepo.incrementCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)}
|
||||||
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
|
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||||
|
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||||
|
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||||
|
|
||||||
|
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||||
|
Result: &OpenAIForwardResult{
|
||||||
|
RequestID: "resp_not_persisted",
|
||||||
|
Usage: OpenAIUsage{
|
||||||
|
InputTokens: 8,
|
||||||
|
OutputTokens: 4,
|
||||||
|
},
|
||||||
|
Model: "gpt-5.1",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{
|
||||||
|
ID: 10043,
|
||||||
|
Quota: 100,
|
||||||
|
},
|
||||||
|
User: &User{ID: 20043},
|
||||||
|
Account: &Account{ID: 30043},
|
||||||
|
APIKeyService: quotaSvc,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, usageRepo.calls)
|
||||||
|
require.Equal(t, 1, userRepo.deductCalls)
|
||||||
|
require.Equal(t, 0, subRepo.incrementCalls)
|
||||||
|
require.Equal(t, 1, quotaSvc.quotaCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceRecordUsage_BillingUsesDetachedContext(t *testing.T) {
|
||||||
|
usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2}
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
|
||||||
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
|
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||||
|
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||||
|
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||||
|
|
||||||
|
reqCtx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
err := svc.RecordUsage(reqCtx, &OpenAIRecordUsageInput{
|
||||||
|
Result: &OpenAIForwardResult{
|
||||||
|
RequestID: "resp_detached_billing_ctx",
|
||||||
|
Usage: usage,
|
||||||
|
Model: "gpt-5.1",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{
|
||||||
|
ID: 10042,
|
||||||
|
Quota: 100,
|
||||||
|
},
|
||||||
|
User: &User{ID: 20042},
|
||||||
|
Account: &Account{ID: 30042},
|
||||||
|
APIKeyService: quotaSvc,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, userRepo.deductCalls)
|
||||||
|
require.NoError(t, userRepo.lastCtxErr)
|
||||||
|
require.Equal(t, 1, quotaSvc.quotaCalls)
|
||||||
|
require.NoError(t, quotaSvc.lastQuotaCtxErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceRecordUsage_BillingRepoUsesDetachedContext(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||||
|
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||||
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
|
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||||
|
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||||
|
|
||||||
|
reqCtx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
err := svc.RecordUsage(reqCtx, &OpenAIRecordUsageInput{
|
||||||
|
Result: &OpenAIForwardResult{
|
||||||
|
RequestID: "resp_detached_billing_repo_ctx",
|
||||||
|
Usage: OpenAIUsage{
|
||||||
|
InputTokens: 8,
|
||||||
|
OutputTokens: 4,
|
||||||
|
},
|
||||||
|
Model: "gpt-5.1",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{ID: 10046},
|
||||||
|
User: &User{ID: 20046},
|
||||||
|
Account: &Account{ID: 30046},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, billingRepo.calls)
|
||||||
|
require.NoError(t, billingRepo.lastCtxErr)
|
||||||
|
require.Equal(t, 1, usageRepo.calls)
|
||||||
|
require.NoError(t, usageRepo.lastCtxErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||||
|
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||||
|
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)
|
||||||
|
|
||||||
|
payloadHash := HashUsageRequestPayload([]byte(`{"model":"gpt-5","input":"hello"}`))
|
||||||
|
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||||
|
Result: &OpenAIForwardResult{
|
||||||
|
RequestID: "openai_payload_hash",
|
||||||
|
Usage: OpenAIUsage{
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 6,
|
||||||
|
},
|
||||||
|
Model: "gpt-5",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{ID: 501, Quota: 100},
|
||||||
|
User: &User{ID: 601},
|
||||||
|
Account: &Account{ID: 701},
|
||||||
|
RequestPayloadHash: payloadHash,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, billingRepo.lastCmd)
|
||||||
|
require.Equal(t, payloadHash, billingRepo.lastCmd.RequestPayloadHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceRecordUsage_UsesFallbackRequestIDForBillingAndUsageLog(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||||
|
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||||
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
|
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||||
|
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||||
|
|
||||||
|
ctx := context.WithValue(context.Background(), ctxkey.RequestID, "req-local-fallback")
|
||||||
|
err := svc.RecordUsage(ctx, &OpenAIRecordUsageInput{
|
||||||
|
Result: &OpenAIForwardResult{
|
||||||
|
RequestID: "",
|
||||||
|
Usage: OpenAIUsage{
|
||||||
|
InputTokens: 8,
|
||||||
|
OutputTokens: 4,
|
||||||
|
},
|
||||||
|
Model: "gpt-5.1",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{ID: 10047},
|
||||||
|
User: &User{ID: 20047},
|
||||||
|
Account: &Account{ID: 30047},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, billingRepo.lastCmd)
|
||||||
|
require.Equal(t, "local:req-local-fallback", billingRepo.lastCmd.RequestID)
|
||||||
|
require.NotNil(t, usageRepo.lastLog)
|
||||||
|
require.Equal(t, "local:req-local-fallback", usageRepo.lastLog.RequestID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||||
|
billingRepo := &openAIRecordUsageBillingRepoStub{err: errors.New("billing tx failed")}
|
||||||
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
|
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||||
|
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||||
|
|
||||||
|
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||||
|
Result: &OpenAIForwardResult{
|
||||||
|
RequestID: "resp_billing_fail",
|
||||||
|
Usage: OpenAIUsage{
|
||||||
|
InputTokens: 8,
|
||||||
|
OutputTokens: 4,
|
||||||
|
},
|
||||||
|
Model: "gpt-5.1",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{ID: 10048},
|
||||||
|
User: &User{ID: 20048},
|
||||||
|
Account: &Account{ID: 30048},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, 1, billingRepo.calls)
|
||||||
|
require.Equal(t, 0, usageRepo.calls)
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayServiceRecordUsage_UpdatesAPIKeyQuotaWhenConfigured(t *testing.T) {
|
func TestOpenAIGatewayServiceRecordUsage_UpdatesAPIKeyQuotaWhenConfigured(t *testing.T) {
|
||||||
usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2}
|
usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2}
|
||||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||||
|
|||||||
@@ -259,6 +259,7 @@ type openAIWSRetryMetrics struct {
 type OpenAIGatewayService struct {
 	accountRepo AccountRepository
 	usageLogRepo UsageLogRepository
+	usageBillingRepo UsageBillingRepository
 	userRepo UserRepository
 	userSubRepo UserSubscriptionRepository
 	cache GatewayCache
@@ -295,6 +296,7 @@ type OpenAIGatewayService struct {
 func NewOpenAIGatewayService(
 	accountRepo AccountRepository,
 	usageLogRepo UsageLogRepository,
+	usageBillingRepo UsageBillingRepository,
 	userRepo UserRepository,
 	userSubRepo UserSubscriptionRepository,
 	userGroupRateRepo UserGroupRateRepository,
@@ -312,6 +314,7 @@ func NewOpenAIGatewayService(
 	svc := &OpenAIGatewayService{
 		accountRepo: accountRepo,
 		usageLogRepo: usageLogRepo,
+		usageBillingRepo: usageBillingRepo,
 		userRepo: userRepo,
 		userSubRepo: userSubRepo,
 		cache: cache,
@@ -2014,7 +2017,9 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
 	}
 
 	// Build upstream request
-	upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
+	upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
+	upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
+	releaseUpstreamCtx()
 	if err != nil {
 		return nil, err
 	}
@@ -2206,7 +2211,9 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
 		return nil, err
 	}
 
-	upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(ctx, c, account, body, token)
+	upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
+	upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token)
+	releaseUpstreamCtx()
 	if err != nil {
 		return nil, err
 	}
@@ -2543,6 +2550,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
 	var firstTokenMs *int
 	clientDisconnected := false
 	sawDone := false
+	sawTerminalEvent := false
 	upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id"))
 
 	scanner := bufio.NewScanner(resp.Body)
@@ -2562,6 +2570,9 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
 		if trimmedData == "[DONE]" {
 			sawDone = true
 		}
+		if openAIStreamEventIsTerminal(trimmedData) {
+			sawTerminalEvent = true
+		}
 		if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" {
 			ms := int(time.Since(startTime).Milliseconds())
 			firstTokenMs = &ms
@@ -2579,19 +2590,14 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
 		}
 	}
 	if err := scanner.Err(); err != nil {
-		if clientDisconnected {
-			logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, err)
+		if sawTerminalEvent {
 			return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
 		}
+		if clientDisconnected {
+			return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err)
+		}
 		if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
-			logger.LegacyPrintf("service.openai_gateway",
-				"[OpenAI passthrough] 流读取被取消,可能发生断流: account=%d request_id=%s err=%v ctx_err=%v",
-				account.ID,
-				upstreamRequestID,
-				err,
-				ctx.Err(),
-			)
-			return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
+			return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err)
 		}
 		if errors.Is(err, bufio.ErrTooLong) {
 			logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
@@ -2605,12 +2611,13 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
 		)
 		return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
 	}
-	if !clientDisconnected && !sawDone && ctx.Err() == nil {
+	if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil {
 		logger.FromContext(ctx).With(
 			zap.String("component", "service.openai_gateway"),
 			zap.Int64("account_id", account.ID),
 			zap.String("upstream_request_id", upstreamRequestID),
 		).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流")
+		return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event")
 	}
 
 	return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
@@ -3030,6 +3037,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
 	// 否则下游 SDK(例如 OpenCode)会因为类型校验失败而报错。
 	errorEventSent := false
 	clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage
+	sawTerminalEvent := false
 	sendErrorEvent := func(reason string) {
 		if errorEventSent || clientDisconnected {
 			return
@@ -3060,22 +3068,27 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
 			logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage")
 		}
 	}
+	if !sawTerminalEvent {
+		return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
+	}
 	return resultWithUsage(), nil
 }
 handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) {
 	if scanErr == nil {
 		return nil, nil, false
 	}
+	if sawTerminalEvent {
+		logger.LegacyPrintf("service.openai_gateway", "Upstream scan ended after terminal event: %v", scanErr)
+		return resultWithUsage(), nil, true
+	}
 	// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
 	// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
 	if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) {
-		logger.LegacyPrintf("service.openai_gateway", "Context canceled during streaming, returning collected usage")
-		return resultWithUsage(), nil, true
+		return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", scanErr), true
 	}
 	// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
 	if clientDisconnected {
-		logger.LegacyPrintf("service.openai_gateway", "Upstream read error after client disconnect: %v, returning collected usage", scanErr)
-		return resultWithUsage(), nil, true
+		return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true
 	}
 	if errors.Is(scanErr, bufio.ErrTooLong) {
 		logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr)
@@ -3098,6 +3111,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
 		}
 
 		dataBytes := []byte(data)
+		if openAIStreamEventIsTerminal(data) {
+			sawTerminalEvent = true
+		}
 
 		// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
 		if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected {
@@ -3214,8 +3230,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
 				continue
 			}
 			if clientDisconnected {
-				logger.LegacyPrintf("service.openai_gateway", "Upstream timeout after client disconnect, returning collected usage")
-				return resultWithUsage(), nil
+				return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout")
 			}
 			logger.LegacyPrintf("service.openai_gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
 			// 处理流超时,可能标记账户为临时不可调度或错误状态
@@ -3313,11 +3328,12 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag
 	if usage == nil || len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) {
 		return
 	}
-	// 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。
-	if len(data) < 80 || !bytes.Contains(data, []byte(`"response.completed"`)) {
+	// 选择性解析:仅在数据中包含终止事件标识时才进入字段提取。
+	if len(data) < 72 {
 		return
 	}
-	if gjson.GetBytes(data, "type").String() != "response.completed" {
+	eventType := gjson.GetBytes(data, "type").String()
+	if eventType != "response.completed" && eventType != "response.done" {
 		return
 	}
 
@@ -3670,14 +3686,15 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
 
 // OpenAIRecordUsageInput input for recording usage
 type OpenAIRecordUsageInput struct {
 	Result *OpenAIForwardResult
 	APIKey *APIKey
 	User *User
 	Account *Account
 	Subscription *UserSubscription
 	UserAgent string // 请求的 User-Agent
 	IPAddress string // 请求的客户端 IP 地址
-	APIKeyService APIKeyQuotaUpdater
+	RequestPayloadHash string
+	APIKeyService APIKeyQuotaUpdater
 }
 
 // RecordUsage records usage and deducts balance
@@ -3743,11 +3760,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
 	// Create usage log
 	durationMs := int(result.Duration.Milliseconds())
 	accountRateMultiplier := account.BillingRateMultiplier()
+	requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
 	usageLog := &UsageLog{
 		UserID: user.ID,
 		APIKeyID: apiKey.ID,
 		AccountID: account.ID,
-		RequestID: result.RequestID,
+		RequestID: requestID,
 		Model: billingModel,
 		ServiceTier: result.ServiceTier,
 		ReasoningEffort: result.ReasoningEffort,
@@ -3788,29 +3806,32 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
 		usageLog.SubscriptionID = &subscription.ID
 	}
 
-	inserted, err := s.usageLogRepo.Create(ctx, usageLog)
 	if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
+		writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway")
 		logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
 		s.deferredService.ScheduleLastUsedUpdate(account.ID)
 		return nil
 	}
 
-	shouldBill := inserted || err != nil
-	if shouldBill {
-		postUsageBilling(ctx, &postUsageBillingParams{
+	billingErr := func() error {
+		_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
 			Cost: cost,
 			User: user,
 			APIKey: apiKey,
 			Account: account,
 			Subscription: subscription,
+			RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
 			IsSubscriptionBill: isSubscriptionBilling,
 			AccountRateMultiplier: accountRateMultiplier,
 			APIKeyService: input.APIKeyService,
-		}, s.billingDeps())
-	} else {
-		s.deferredService.ScheduleLastUsedUpdate(account.ID)
+		}, s.billingDeps(), s.usageBillingRepo)
+		return err
+	}()
+
+	if billingErr != nil {
+		return billingErr
 	}
+	writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway")
 
 	return nil
 }
@@ -916,7 +916,7 @@ func TestOpenAIStreamingTimeout(t *testing.T) {
 	}
 }
 
-func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) {
+func TestOpenAIStreamingContextCanceledReturnsIncompleteErrorWithoutInjectingErrorEvent(t *testing.T) {
 	gin.SetMode(gin.TestMode)
 	cfg := &config.Config{
 		Gateway: config.GatewayConfig{
@@ -940,8 +940,8 @@ func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) {
 	}
 
 	_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
-	if err != nil {
-		t.Fatalf("expected nil error, got %v", err)
+	if err == nil || !strings.Contains(err.Error(), "stream usage incomplete") {
+		t.Fatalf("expected incomplete stream error, got %v", err)
 	}
 	if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "stream_read_error") {
 		t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
@@ -993,6 +993,107 @@ func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIStreamingMissingTerminalEventReturnsIncompleteError(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
cfg := &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
StreamDataIntervalTimeout: 0,
|
||||||
|
StreamKeepaliveInterval: 0,
|
||||||
|
MaxLineSize: defaultMaxLineSize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &OpenAIGatewayService{cfg: cfg}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: pr,
|
||||||
|
Header: http.Header{},
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() { _ = pw.Close() }()
|
||||||
|
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
||||||
|
_ = pr.Close()
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "missing terminal event") {
|
||||||
|
t.Fatalf("expected missing terminal event error, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
cfg := &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
MaxLineSize: defaultMaxLineSize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &OpenAIGatewayService{cfg: cfg}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: pr,
|
||||||
|
Header: http.Header{},
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() { _ = pw.Close() }()
|
||||||
|
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now())
|
||||||
|
_ = pr.Close()
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "missing terminal event") {
|
||||||
|
t.Fatalf("expected missing terminal event error, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t *testing.T) {
	gin.SetMode(gin.TestMode)
	cfg := &config.Config{
		Gateway: config.GatewayConfig{
			MaxLineSize: defaultMaxLineSize,
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)

	pr, pw := io.Pipe()
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       pr,
		Header:     http.Header{},
	}

	go func() {
		defer func() { _ = pw.Close() }()
		_, _ = pw.Write([]byte("data: {\"type\":\"response.done\",\"response\":{\"usage\":{\"input_tokens\":2,\"output_tokens\":3,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
	}()

	result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now())
	_ = pr.Close()
	require.NoError(t, err)
	require.NotNil(t, result)
	require.NotNil(t, result.usage)
	require.Equal(t, 2, result.usage.InputTokens)
	require.Equal(t, 3, result.usage.OutputTokens)
	require.Equal(t, 1, result.usage.CacheReadInputTokens)
}

func TestOpenAIStreamingTooLong(t *testing.T) {
	gin.SetMode(gin.TestMode)
	cfg := &config.Config{
@@ -1124,7 +1225,7 @@ func TestOpenAIStreamingHeadersOverride(t *testing.T) {

 	go func() {
 		defer func() { _ = pw.Close() }()
-		_, _ = pw.Write([]byte("data: {}\n\n"))
+		_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{}}\n\n"))
 	}()

 	_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
@@ -1674,6 +1775,12 @@ func TestParseSSEUsage_SelectiveParsing(t *testing.T) {
 	require.Equal(t, 3, usage.InputTokens)
 	require.Equal(t, 5, usage.OutputTokens)
 	require.Equal(t, 2, usage.CacheReadInputTokens)
+
+	// the done event may also carry the final usage
+	svc.parseSSEUsage(`{"type":"response.done","response":{"usage":{"input_tokens":13,"output_tokens":15,"input_tokens_details":{"cached_tokens":4}}}}`, usage)
+	require.Equal(t, 13, usage.InputTokens)
+	require.Equal(t, 15, usage.OutputTokens)
+	require.Equal(t, 4, usage.CacheReadInputTokens)
 }

 func TestExtractCodexFinalResponse_SampleReplay(t *testing.T) {
@@ -392,6 +392,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
 		nil,
 		nil,
 		nil,
+		nil,
 		cfg,
 		nil,
 		nil,
110
backend/internal/service/usage_billing.go
Normal file
@@ -0,0 +1,110 @@
package service

import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"fmt"
	"strings"
)

var ErrUsageBillingRequestIDRequired = errors.New("usage billing request_id is required")
var ErrUsageBillingRequestConflict = errors.New("usage billing request fingerprint conflict")

// UsageBillingCommand describes one billable request that must be applied at most once.
type UsageBillingCommand struct {
	RequestID          string
	APIKeyID           int64
	RequestFingerprint string
	RequestPayloadHash string

	UserID              int64
	AccountID           int64
	SubscriptionID      *int64
	AccountType         string
	Model               string
	ServiceTier         string
	ReasoningEffort     string
	BillingType         int8
	InputTokens         int
	OutputTokens        int
	CacheCreationTokens int
	CacheReadTokens     int
	ImageCount          int
	MediaType           string

	BalanceCost         float64
	SubscriptionCost    float64
	APIKeyQuotaCost     float64
	APIKeyRateLimitCost float64
	AccountQuotaCost    float64
}

func (c *UsageBillingCommand) Normalize() {
	if c == nil {
		return
	}
	c.RequestID = strings.TrimSpace(c.RequestID)
	if strings.TrimSpace(c.RequestFingerprint) == "" {
		c.RequestFingerprint = buildUsageBillingFingerprint(c)
	}
}

func buildUsageBillingFingerprint(c *UsageBillingCommand) string {
	if c == nil {
		return ""
	}
	raw := fmt.Sprintf(
		"%d|%d|%d|%s|%s|%s|%s|%d|%d|%d|%d|%d|%d|%s|%d|%0.10f|%0.10f|%0.10f|%0.10f|%0.10f",
		c.UserID,
		c.AccountID,
		c.APIKeyID,
		strings.TrimSpace(c.AccountType),
		strings.TrimSpace(c.Model),
		strings.TrimSpace(c.ServiceTier),
		strings.TrimSpace(c.ReasoningEffort),
		c.BillingType,
		c.InputTokens,
		c.OutputTokens,
		c.CacheCreationTokens,
		c.CacheReadTokens,
		c.ImageCount,
		strings.TrimSpace(c.MediaType),
		valueOrZero(c.SubscriptionID),
		c.BalanceCost,
		c.SubscriptionCost,
		c.APIKeyQuotaCost,
		c.APIKeyRateLimitCost,
		c.AccountQuotaCost,
	)
	if payloadHash := strings.TrimSpace(c.RequestPayloadHash); payloadHash != "" {
		raw += "|" + payloadHash
	}
	sum := sha256.Sum256([]byte(raw))
	return hex.EncodeToString(sum[:])
}

func HashUsageRequestPayload(payload []byte) string {
	if len(payload) == 0 {
		return ""
	}
	sum := sha256.Sum256(payload)
	return hex.EncodeToString(sum[:])
}

func valueOrZero(v *int64) int64 {
	if v == nil {
		return 0
	}
	return *v
}

type UsageBillingApplyResult struct {
	Applied              bool
	APIKeyQuotaExhausted bool
}

type UsageBillingRepository interface {
	Apply(ctx context.Context, cmd *UsageBillingCommand) (*UsageBillingApplyResult, error)
}
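The fingerprint above is a pure function of the billing-relevant fields, so retrying the same command always reproduces the same dedup key. A minimal standalone sketch (field set reduced for illustration; `fingerprint` is a hypothetical stand-in for `buildUsageBillingFingerprint`) demonstrates that determinism:

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// fingerprint mirrors the construction in buildUsageBillingFingerprint:
// join the billing fields in a fixed order, then hash, so two equal
// commands always yield the same idempotency key.
func fingerprint(userID, accountID, apiKeyID int64, model string, inputTokens, outputTokens int, cost float64) string {
	raw := fmt.Sprintf("%d|%d|%d|%s|%d|%d|%0.10f",
		userID, accountID, apiKeyID, model, inputTokens, outputTokens, cost)
	sum := sha256.Sum256([]byte(raw))
	return hex.EncodeToString(sum[:])
}

func main() {
	a := fingerprint(1, 2, 3, "gpt-5", 10, 20, 0.0123)
	b := fingerprint(1, 2, 3, "gpt-5", 10, 20, 0.0123)
	c := fingerprint(1, 2, 3, "gpt-5", 10, 21, 0.0123) // one token count differs
	fmt.Println(a == b, a == c)                        // prints: true false
}
```

Because every cost and token field participates in the hash, a retried request with identical usage dedupes, while any material difference produces a distinct fingerprint and surfaces as a conflict rather than a silent double charge.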
@@ -56,7 +56,8 @@ type cleanupRepoStub struct {
 }

 type dashboardRepoStub struct {
 	recomputeErr   error
+	recomputeCalls int
 }

 func (s *dashboardRepoStub) AggregateRange(ctx context.Context, start, end time.Time) error {
@@ -64,6 +65,7 @@ func (s *dashboardRepoStub) AggregateRange(ctx context.Context, start, end time.
 }

 func (s *dashboardRepoStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
+	s.recomputeCalls++
 	return s.recomputeErr
 }

@@ -83,6 +85,10 @@ func (s *dashboardRepoStub) CleanupUsageLogs(ctx context.Context, cutoff time.Ti
 	return nil
 }

+func (s *dashboardRepoStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
+	return nil
+}
+
 func (s *dashboardRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
 	return nil
 }
@@ -550,13 +556,14 @@ func TestUsageCleanupServiceExecuteTaskMarkFailedUpdateError(t *testing.T) {
 }

 func TestUsageCleanupServiceExecuteTaskDashboardRecomputeError(t *testing.T) {
+	dashboardRepo := &dashboardRepoStub{recomputeErr: errors.New("recompute failed")}
 	repo := &cleanupRepoStub{
 		deleteQueue: []cleanupDeleteResponse{
 			{deleted: 0},
 		},
 	}
-	dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{
-		DashboardAgg: config.DashboardAggregationConfig{Enabled: false},
+	dashboard := NewDashboardAggregationService(dashboardRepo, nil, &config.Config{
+		DashboardAgg: config.DashboardAggregationConfig{Enabled: true},
 	})
 	cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
 	svc := NewUsageCleanupService(repo, nil, dashboard, cfg)
@@ -573,15 +580,17 @@ func TestUsageCleanupServiceExecuteTaskDashboardRecomputeError(t *testing.T) {
 	repo.mu.Lock()
 	defer repo.mu.Unlock()
 	require.Len(t, repo.markSucceeded, 1)
+	require.Eventually(t, func() bool { return dashboardRepo.recomputeCalls == 1 }, time.Second, 10*time.Millisecond)
 }

 func TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess(t *testing.T) {
+	dashboardRepo := &dashboardRepoStub{}
 	repo := &cleanupRepoStub{
 		deleteQueue: []cleanupDeleteResponse{
 			{deleted: 0},
 		},
 	}
-	dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{
+	dashboard := NewDashboardAggregationService(dashboardRepo, nil, &config.Config{
 		DashboardAgg: config.DashboardAggregationConfig{Enabled: true},
 	})
 	cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
@@ -599,6 +608,7 @@ func TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess(t *testing.T) {
 	repo.mu.Lock()
 	defer repo.mu.Unlock()
 	require.Len(t, repo.markSucceeded, 1)
+	require.Eventually(t, func() bool { return dashboardRepo.recomputeCalls == 1 }, time.Second, 10*time.Millisecond)
 }

 func TestUsageCleanupServiceExecuteTaskCanceled(t *testing.T) {
60
backend/internal/service/usage_log_create_result.go
Normal file
@@ -0,0 +1,60 @@
package service

import "errors"

type usageLogCreateDisposition int

const (
	usageLogCreateDispositionUnknown usageLogCreateDisposition = iota
	usageLogCreateDispositionNotPersisted
)

type UsageLogCreateError struct {
	err         error
	disposition usageLogCreateDisposition
}

func (e *UsageLogCreateError) Error() string {
	if e == nil || e.err == nil {
		return "usage log create error"
	}
	return e.err.Error()
}

func (e *UsageLogCreateError) Unwrap() error {
	if e == nil {
		return nil
	}
	return e.err
}

func MarkUsageLogCreateNotPersisted(err error) error {
	if err == nil {
		return nil
	}
	return &UsageLogCreateError{
		err:         err,
		disposition: usageLogCreateDispositionNotPersisted,
	}
}

func IsUsageLogCreateNotPersisted(err error) bool {
	if err == nil {
		return false
	}
	var target *UsageLogCreateError
	if !errors.As(err, &target) {
		return false
	}
	return target.disposition == usageLogCreateDispositionNotPersisted
}

func ShouldBillAfterUsageLogCreate(inserted bool, err error) bool {
	if inserted {
		return true
	}
	if err == nil {
		return false
	}
	return !IsUsageLogCreateNotPersisted(err)
}
13
backend/migrations/071_add_usage_billing_dedup.sql
Normal file
@@ -0,0 +1,13 @@
-- Narrow-table billing idempotency keys: decouple "already billed" from usage_logs
-- Idempotent: safe to run repeatedly

CREATE TABLE IF NOT EXISTS usage_billing_dedup (
    id BIGSERIAL PRIMARY KEY,
    request_id VARCHAR(255) NOT NULL,
    api_key_id BIGINT NOT NULL,
    request_fingerprint VARCHAR(64) NOT NULL,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

CREATE UNIQUE INDEX IF NOT EXISTS idx_usage_billing_dedup_request_api_key
    ON usage_billing_dedup (request_id, api_key_id);
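One way the unique index above is typically consumed (an assumption about the consumer, not shown in this migration) is an insert-first gate: attempt the insert and let the affected-row count say whether this request was billed before.

```sql
-- Hypothetical consumer of the dedup key (not part of the migration):
-- insert first; 0 rows affected means this (request_id, api_key_id)
-- was already billed, so the caller skips the balance/quota writes.
INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint)
VALUES ($1, $2, $3)
ON CONFLICT (request_id, api_key_id) DO NOTHING;
```

With this shape, concurrent retries race on the unique index instead of on application locks, and exactly one of them proceeds to apply the billing writes.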
@@ -0,0 +1,7 @@
-- usage_billing_dedup is an append-only idempotency narrow table ordered by time.
-- BRIN supports retention-window batch cleanup by created_at while keeping write
-- amplification low. CONCURRENTLY avoids blocking writes on a hot table for long.

CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_usage_billing_dedup_created_at_brin
    ON usage_billing_dedup
    USING BRIN (created_at);
10
backend/migrations/073_add_usage_billing_dedup_archive.sql
Normal file
@@ -0,0 +1,10 @@
-- Cold-archive old billing idempotency keys: shrinks the hot table's index and
-- cleanup window without losing long-term dedup capability.

CREATE TABLE IF NOT EXISTS usage_billing_dedup_archive (
    request_id VARCHAR(255) NOT NULL,
    api_key_id BIGINT NOT NULL,
    request_fingerprint VARCHAR(64) NOT NULL,
    created_at TIMESTAMPTZ NOT NULL,
    archived_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    PRIMARY KEY (request_id, api_key_id)
);
0
deploy/build_image.sh
Executable file → Normal file
0
deploy/install-datamanagementd.sh
Executable file → Normal file