Compare commits

...

72 Commits

Author SHA1 Message Date
shaw
5f890e85e7 Merge PR #296: OAuth token cache invalidation on 401 and refresh 2026-01-15 16:31:37 +08:00
Wesley Liddick
10bc7f7042 Merge pull request #297 from LLLLLLiulei/feat/ip-management-enhancements
feat: add proxy geo location
2026-01-15 16:06:36 +08:00
yangjianbo
a65fd9dee8 Merge branch 'test' 2026-01-15 15:58:47 +08:00
yangjianbo
1bb4c76deb fix(账号管理): 移除调度切换后的冗余列表刷新
切换账号调度状态后,updateSchedulableInList 已完成局部更新,
无需再调用 load() 刷新整个列表。此修改减少不必要的 API 请求,
避免 UI 闪烁。

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 15:57:47 +08:00
LLLLLLiulei
aab44f9fc8 feat: add proxy geo location 2026-01-15 15:15:20 +08:00
yangjianbo
0a848e7578 Merge branch 'test' 2026-01-15 15:15:07 +08:00
yangjianbo
90bce60b85 feat: merge dev 2026-01-15 15:14:44 +08:00
程序猿MT
c22d51ee41 Merge branch 'Wei-Shaw:main' into main 2026-01-15 15:12:16 +08:00
yangjianbo
a458e684bc fix(认证): OAuth 401 直接标记错误状态
- OAuth 401 清理缓存并设置错误状态

- 移除 oauth_401_cooldown_minutes 配置及示例

- 更新 401 相关单测

破坏性变更: OAuth 401 不再临时不可调度,需手动恢复
2026-01-15 15:06:34 +08:00
LLLLLLiulei
87b4662993 Revert "feat: add proxy geo location"
This reverts commit 09c4f82927ddce1c9528c146a26457f53d02b034.
2026-01-15 15:01:50 +08:00
LLLLLLiulei
3a100339b9 feat: add proxy geo location 2026-01-15 15:01:50 +08:00
Wesley Liddick
47eb3c8888 Merge pull request #284 from longgexx/main
fix(admin): 修复使用记录页面趋势图筛选联动和日期选择问题
2026-01-15 13:43:00 +08:00
longgexx
4672a6fac3 merge: 合并上游 upstream/main 分支
解决冲突:
- usage_log_repo.go: 保留 groupID 和 stream 参数,同时合并上游的
  account_rate_multiplier 逻辑

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 13:09:26 +08:00
longgexx
82743704e4 fix(test): 添加测试辅助函数 truncateToDayUTC 修复编译错误
在 usage_log_repo_integration_test.go 中添加本地的 truncateToDayUTC
辅助函数,修复因主代码重命名该函数导致的测试编译错误。

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 11:42:35 +08:00
Wesley Liddick
cc2d064ab4 Merge pull request #289 from IanShaw027/fix/mobile-ui
fix(mobile): 优化移动端表格、操作栏和弹窗显示
2026-01-15 11:31:16 +08:00
Wesley Liddick
27214f8657 Merge pull request #285 from IanShaw027/fix/ops-bug
feat(ops): 增强错误日志管理、告警静默和前端 UI 优化
2026-01-15 11:26:16 +08:00
Wesley Liddick
28de614dfb Merge pull request #282 from LLLLLLiulei/feat/ip-management-enhancements
feat: enhance proxy management
2026-01-15 11:23:17 +08:00
longgexx
850183c269 fix(dashboard): 修复预聚合表使用UTC时区导致今日统计不准确的问题
将 dashboard_aggregation_repo.go 和 usage_log_repo.go 中的时区处理
从 UTC 改为使用服务器配置时区(默认 Asia/Shanghai),确保"今日"
统计数据与用户预期一致。

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 11:22:13 +08:00
shaw
2a5ef6d3f5 Merge PR #279: feat(计费): 账号计费倍率快照与账号口径费用统计 2026-01-15 11:11:55 +08:00
IanShaw027
1d231c6cc3 fix(mobile): 修复 UsersView 更多菜单定位并统一逻辑
**问题描述**:
- UsersView 的"更多"菜单仍然出现在页面左上角错误位置
- UsersView 使用 actionButtonRefs Map 获取按钮元素,导致定位失败
- UsersView 和 AccountsView 的菜单定位逻辑不一致,难以维护

**解决方案**:
- 修改 openActionMenu 函数签名,添加 MouseEvent 参数
- 使用 e.currentTarget 直接从事件对象获取触发元素
- 移除不必要的 actionButtonRefs Map 和 setActionButtonRef 函数
- 统一菜单宽度为 200px(与 AccountsView 一致)
- 完全复制 AccountsView 的定位逻辑,确保两者行为一致

**技术要点**:
- 移动端:菜单居中对齐按钮,优先显示在按钮下方
- 桌面端:使用鼠标位置定位,添加边界检测
- 简化代码,移除不必要的防御性检查
- 两个组件的菜单定位逻辑完全一致,便于维护
2026-01-15 10:21:35 +08:00
IanShaw027
20c71acb3b fix(mobile): 优化移动端表格、操作栏和弹窗显示
**问题描述**:
- 表格在移动端显示列过多,需要横向滚动,内容被截断
- 顶部操作栏按钮拥挤,占用过多空间
- 弹窗表单在小屏幕上布局不合理
- "更多"操作菜单定位错误,位置过高或超出屏幕
- 滚动页面时菜单不会自动关闭,与卡片分离

**解决方案**:

1. **DataTable 组件 - 移动端卡片视图**
   - 在 < 768px 时自动切换到卡片布局
   - 每个表格行渲染为独立卡片,所有字段清晰可见
   - 操作按钮在卡片底部,触摸目标足够大
   - 支持深色模式,包含加载和空状态
   - 自动应用于所有使用 DataTable 的管理页面

2. **UsersView 顶部操作栏优化**
   - 移动端:搜索框全宽 + 次要按钮显示为图标 + 创建按钮突出
   - 桌面端:保持原有布局(图标 + 文字)
   - 使用响应式 Tailwind classes

3. **UserCreateModal 弹窗优化**
   - 余额/并发数字段:移动端单列,桌面端双列
   - 弹窗边距:移动端 8px,桌面端 16px

4. **操作菜单定位修复**
   - UsersView: 移动端菜单居中对齐按钮,智能定位
   - AccountsView: 移动端菜单优先显示在按钮下方
   - 所有情况下确保菜单不超出屏幕边界
   - 添加滚动监听,滚动时自动关闭菜单

**影响范围**:
- 所有使用 DataTable 的管理页面(8 个页面)自动获得移动端卡片视图
- 用户管理和账号管理页面的操作菜单定位优化
- 创建用户弹窗的响应式布局优化

**技术要点**:
- 使用 Tailwind 响应式断点(md:, sm:)
- 触摸目标 ≥ 44px
- 完整支持深色模式
- 向后兼容,桌面端保持原有布局
2026-01-15 10:08:14 +08:00
yangjianbo
52ad7c6e9c Merge branch 'main' into dev 2026-01-15 09:30:29 +08:00
longgexx
5aaaffe4d1 fix(dashboard): 修复仪表盘今日统计使用UTC时区的问题
将仪表盘统计中的"今日"时间范围从UTC时区改为服务器配置时区,
使其与使用记录页面保持一致。

修改内容:
- GetDashboardStats: 使用 timezone.Now() 和 timezone.Today()
- GetDashboardStatsWithRange: 同上

影响的统计项:
- 今日请求 (TodayRequests)
- 今日 Token (TodayTokens)
- 今日费用 (TodayCost/TodayActualCost)
- 今日新用户 (TodayNewUsers)
- 今日活跃用户 (ActiveUsers)
2026-01-15 09:12:16 +08:00
IanShaw027
5354ba3662 fix(ops): 修复错误列表用户显示并区分上游错误和请求错误
- 修复错误列表中用户列显示 \n 的问题
- 上游错误显示账号(account),请求错误显示用户(user)
- 错误详情模态框同步调整显示逻辑
- 添加 accountId 国际化翻译
2026-01-15 00:11:44 +08:00
IanShaw027
2daf13c4c8 style(backend): 修复 ops_service.go 代码格式 2026-01-15 00:04:37 +08:00
IanShaw027
16a90f3d3a fix(types): 在 OpsErrorLog 类型中添加 user_email 字段
- 修复 TypeScript 编译错误
- 添加 user_email 字段到 OpsErrorLog 接口
2026-01-15 00:03:34 +08:00
IanShaw027
8a0ff15242 feat(i18n): 添加用户相关的国际化翻译
- 在中英文 i18n 文件中添加 errorLog.user 和 errorLog.userId
- 在中英文 i18n 文件中添加 errorDetail.user
- 支持错误日志和详情中的用户信息显示
2026-01-15 00:02:32 +08:00
IanShaw027
8c993dfd35 refactor(frontend): 将账号显示替换为用户显示
- 在错误日志表格中将账号列替换为用户列
- 在错误详情模态框中将账号信息替换为用户信息
- 显示用户邮箱而不是账号名称
- 上游错误的账号信息保留在上游错误上下文中
2026-01-14 23:59:26 +08:00
IanShaw027
2a6fb1e456 feat(ops): 添加用户信息显示和搜索功能
- 在错误日志列表和详情中显示用户邮箱
- 在 GetErrorLogByID 中关联 users 表获取用户邮箱
- 在 OpsErrorLogFilter 中添加 UserQuery 字段
- 在 buildOpsErrorLogsWhere 中添加用户邮箱搜索条件
- 在 GetErrorLogs handler 中支持 user_query 参数
2026-01-14 23:56:45 +08:00
IanShaw027
9e6cd36af4 feat(ops): 添加上游响应体字段到错误事件
- 在 OpsUpstreamErrorEvent 中添加 UpstreamResponseBody 字段
- 用于存储上游服务返回的响应内容
- 区分客户端响应和上游响应
2026-01-14 23:52:36 +08:00
IanShaw027
f25f992a30 fix(ops): 错误详情中显示账号和分组名称
- 在 GetErrorLogByID 查询中添加 LEFT JOIN 关联查询
- 关联 accounts 和 groups 表获取名称
- 填充 AccountName 和 GroupName 字段
2026-01-14 23:52:00 +08:00
IanShaw027
841d7ef2f2 fix(lint): 修复 golangci-lint 检查问题
- 格式化代码(gofmt)
- 修复空指针检查(staticcheck)
- 删除未使用的函数(unused)
2026-01-14 23:49:27 +08:00
IanShaw027
a7a49be850 refactor(ops): 使用TTFT替代Duration作为健康分数指标
- 业务健康分数:错误率 50% + TTFT 50%
- TTFT 阈值:1s → 100分,3s → 0分
- TTFT 对 AI 服务的用户体验更有意义
- 更新所有相关测试用例期望值
2026-01-14 23:47:43 +08:00
IanShaw027
d5eab7da3b refactor(ops): 优化健康分数计算逻辑和阈值
- 移除 SLA 组件(与错误率重复)
- 恢复延迟组件,阈值调整为 1s-2s
- 错误率阈值调整为 1%-10%(更宽松)
- 业务健康分数:错误率 50% + 延迟 50%
- 更新所有相关测试用例期望值
2026-01-14 23:43:12 +08:00
IanShaw027
9b10241561 test(ops): 修复健康分数测试用例期望值
- 更新 TestComputeBusinessHealth 中 SLA 95% 边界测试的期望值
- 更新 TestComputeDashboardHealthScore 中中等健康度测试的期望值
- 适配移除延迟组件后的新健康分数计算逻辑
2026-01-14 23:39:09 +08:00
IanShaw027
76448ab555 refactor(frontend): 优化ops看板骨架屏组件
- 添加 fullscreen 属性支持,适配全屏模式
- 优化骨架屏布局,更好地匹配实际看板结构
- 改进加载动画效果,提升用户体验
2026-01-14 23:26:34 +08:00
IanShaw027
9584af5cb4 fix(ops): 优化错误日志查询和详情展示
- 新增 GetErrorLogByID 接口用于获取单个错误日志详情
- 优化 GetErrorLogs 过滤逻辑,简化参数处理
- 简化前端错误详情模态框代码,提升可维护性
- 更新相关 API 接口和 i18n 翻译
2026-01-14 23:16:01 +08:00
longgexx
6fabddcb0b fix(test): 更新集成测试以匹配新的筛选参数签名
更新 usage_log_repo_integration_test.go 中的测试用例,
   使其与 GetUsageTrendWithFilters 和 GetModelStatsWithFilters
   方法的新签名保持一致。
2026-01-14 22:29:14 +08:00
longgexx
5efeabb0c6 fix(admin): 修复使用记录页面趋势图筛选联动和日期选择问题
修复两个问题:
   1. Token使用趋势图和模型分布图未响应筛选条件
   2. 上午时段选择今天刷新后日期回退到前一天

   前端修改:
   - 更新 dashboard API 类型定义,添加 model、account_id、group_id、stream 参数支持
   - 修改 UsageView 趋势图加载逻辑,传递所有筛选参数到后端
   - 修复日期格式化函数,使用本地时区避免 UTC 转换导致的日期偏移

   后端修改:
   - Handler 层:接收并解析所有筛选参数(model、account_id、group_id、stream)
   - Service 层:传递完整的筛选参数到 Repository 层
   - Repository 层:SQL 查询动态添加所有过滤条件
   - 更新接口定义和测试 mock 以保持一致性

   影响范围:
   - /admin/dashboard/trend 端点现支持完整筛选
   - /admin/dashboard/models 端点现支持完整筛选
   - 用户在后台使用记录页面选择任意筛选条件时,趋势图和模型分布图会实时响应
   - 日期选择器在任何时区下都能正确保持今天的选择
2026-01-14 22:02:56 +08:00
longgexx
806f402bba fix(admin): 修复使用记录页面趋势图筛选联动和日期选择问题
修复两个问题:
   1. Token使用趋势图和模型分布图未响应筛选条件
   2. 上午时段选择今天刷新后日期回退到前一天

   前端修改:
   - 更新 dashboard API 类型定义,添加 model、account_id、group_id、stream 参数支持
   - 修改 UsageView 趋势图加载逻辑,传递所有筛选参数到后端
   - 修复日期格式化函数,使用本地时区避免 UTC 转换导致的日期偏移

   后端修改:
   - Handler 层:接收并解析所有筛选参数(model、account_id、group_id、stream)
   - Service 层:传递完整的筛选参数到 Repository 层
   - Repository 层:SQL 查询动态添加所有过滤条件
   - 更新接口定义和所有调用点以保持一致性

   影响范围:
   - /admin/dashboard/trend 端点现支持完整筛选
   - /admin/dashboard/models 端点现支持完整筛选
   - 用户在后台使用记录页面选择任意筛选条件时,趋势图和模型分布图会实时响应
   - 日期选择器在任何时区下都能正确保持今天的选择
2026-01-14 21:52:50 +08:00
IanShaw027
5432087d96 refactor(frontend): 优化ops错误详情模态框代码格式和功能
- 重构OpsErrorDetailModal.vue代码格式,提升可读性
- 添加上游错误tab显示功能
- 完善i18n翻译(upstream_http)
- 优化其他ops组件代码格式
2026-01-14 20:49:18 +08:00
LLLLLLiulei
02cb14c7b8 test: fix proxy repo stub 2026-01-14 19:56:19 +08:00
LLLLLLiulei
9bdb45be7c feat: enhance proxy management 2026-01-14 19:45:29 +08:00
IanShaw027
514c0562e0 refactor(frontend): 清理OpsDashboardHeader中的i18n翻译
将技术术语的i18n翻译键替换为硬编码文本:
- ms (P99) - 毫秒和百分位数标识
- TTFT - Time To First Token缩写

这些是通用技术术语,不需要国际化。
2026-01-14 19:02:02 +08:00
IanShaw027
371275ec34 refactor(frontend): 清理ops组件中未使用的i18n翻译
- 移除i18n文件中未使用的翻译键(cpu, redis, qps, ttft等)
- 将技术术语改为硬编码(QPS, CPU, TPS等不需要翻译)
- 简化OpsDashboardHeader、OpsErrorDetailModal等组件的i18n调用
2026-01-14 17:04:30 +08:00
墨颜
ec24a3c361 Merge branch 'Wei-Shaw:main' into main 2026-01-14 16:42:30 +08:00
墨颜
d89e797bfc Merge branch 'main' of https://github.com/whoismonay/sub2api 2026-01-14 16:30:32 +08:00
IanShaw027
55e469c7fe fix(ops): 优化错误日志过滤和查询逻辑
后端改动:
- 添加 resolved 参数默认值处理(向后兼容,默认显示未解决错误)
- 新增 status_codes_other 查询参数支持
- 移除 service 层的高级设置过滤逻辑,简化错误日志查询流程

前端改动:
- 完善错误日志相关组件的国际化支持
- 优化 Ops 监控面板和设置对话框的用户体验
2026-01-14 16:26:33 +08:00
墨颜
fb99ceacc7 feat(计费): 支持账号计费倍率快照与统计展示
- 新增 accounts.rate_multiplier(默认 1.0,允许 0)
- 使用 usage_logs.account_rate_multiplier 记录倍率快照,避免历史回算
- 统计/导出/管理端展示账号口径费用(total_cost * account_rate_multiplier)
2026-01-14 16:12:08 +08:00
yangjianbo
daf10907e4 fix(认证): 修复 OAuth token 缓存失效与 401 处理
新增 token 缓存失效接口并在刷新后清理
401 限流支持自定义规则与可配置冷却时间
补齐缓存失效与 401 处理测试
测试: make test
2026-01-14 15:55:44 +08:00
Wesley Liddick
b3b2868f55 Merge pull request #278 from IanShaw027/fix/openai-account-schedulability-check
fix(网关): 修复账号选择中的调度器快照延迟问题
2026-01-14 15:52:35 +08:00
ianshaw
25b00abca1 fix(网关): 修复账号选择中的调度器快照延迟问题
## 问题描述
调度器快照更新存在0.5-1秒的延迟(Outbox轮询间隔),导致在账号被限流或过载后的短时间窗口内,
可能仍会被选中,造成请求失败。

## 根本原因
账号选择逻辑依赖调度器快照(listSchedulableAccounts),但快照更新有延迟:
- Outbox轮询: 每1秒检查一次变更事件
- 全量重建: 每300秒重建一次
- 时间窗口: 账号状态变更后0.5-1秒内,快照可能未更新

## 解决方案
在账号选择循环中添加IsSchedulable()实时检查,作为第二道防线:
1. 第一道防线: 调度器快照过滤(可能有延迟)
2. 第二道防线: IsSchedulable()实时检查(本次修复)

IsSchedulable()会检查:
- RateLimitResetAt: 限流重置时间
- OverloadUntil: 过载持续时间
- TempUnschedulableUntil: 临时不可调度时间
- Status: 账号状态
- Schedulable: 可调度标志

## 修改范围
### OpenAI Gateway Service
- SelectAccountForModelWithExclusions: 添加IsSchedulable()检查
- SelectAccountWithLoadAwareness: 添加IsSchedulable()检查

### Gateway Service (Claude/Gemini/Antigravity)
- 负载感知选择候选账号筛选: 添加IsSchedulable()检查
- selectAccountForModelWithPlatform: 添加IsSchedulable()检查
- selectAccountWithMixedScheduling: 添加IsSchedulable()检查

### 测试用例
- OpenAI: 添加2个测试用例验证限流账号过滤
- Gateway: 添加2个测试用例验证限流和过载账号过滤

### 其他修复
- ops_repo_preagg.go: 修复platform为NULL时的聚合问题

## 测试结果
所有单元测试通过 
2026-01-13 22:49:26 -08:00
IanShaw027
8d0767352b fix(ops): 修复ops handler逻辑 2026-01-14 14:30:41 +08:00
IanShaw027
918a253851 feat(frontend): 完善ops监控面板和组件功能 2026-01-14 14:30:18 +08:00
IanShaw027
63711067e6 refactor(ops): 完善gateway服务ops集成 2026-01-14 14:30:00 +08:00
IanShaw027
7158b38897 refactor(ops): 优化ops repository数据访问层 2026-01-14 14:29:39 +08:00
IanShaw027
7f317b9093 feat(ops): 增强ops核心服务功能和重试机制 2026-01-14 14:29:19 +08:00
IanShaw027
7c4309ea24 feat(ops): 添加ops handler和路由配置 2026-01-14 14:29:01 +08:00
IanShaw027
5013290486 feat(frontend): 优化ops监控UI组件 2026-01-14 12:41:24 +08:00
IanShaw027
8cf3e9a620 feat(frontend): 更新ops API接口和国际化文案 2026-01-14 12:41:05 +08:00
IanShaw027
060699c3b8 refactor(ops): 更新gateway服务集成ops功能 2026-01-14 12:40:49 +08:00
IanShaw027
2ca6c631ac refactor(ops): 重构ops handler和repository层 2026-01-14 12:40:34 +08:00
IanShaw027
967e25878f refactor(ops): 重构ops核心服务层代码 2026-01-14 12:40:12 +08:00
IanShaw027
182683814b refactor(ops): 移除duration相关告警指标,简化监控配置
主要改动:
- 移除 p95_latency_ms 和 p99_latency_ms 告警指标类型
- 移除配置中的 latency_p99_ms_max 阈值设置
- 简化健康分数计算(移除latency权重,重新归一化SLA和错误率)
- 移除duration相关的诊断规则和阈值检查
- 统一术语:延迟 → 请求时长
- 保留duration数据展示,但不再用于告警判断
- 聚焦TTFT作为主要的响应速度告警指标

影响范围:
- Backend: handler, service, models, tests
- Frontend: API types, i18n, components
2026-01-14 10:52:56 +08:00
IanShaw027
33f58d583d fix(ops): 修复告警状态验证和错误处理逻辑
- 增强告警事件状态验证,添加合法状态值检查
- 移除重试逻辑中的遗留字段赋值
- 修正仓库不可用时的错误类型
- 格式化测试文件代码
2026-01-14 09:39:18 +08:00
IanShaw027
1e169685f4 feat(i18n): 添加ops新功能的国际化文案
- 新增告警静默相关的中英文翻译
- 补充错误分类和重试状态的文案
- 完善ops管理界面的提示信息
2026-01-14 09:04:09 +08:00
IanShaw027
f38a3e7585 feat(ui): 优化ops监控面板和组件功能
- 增强告警事件卡片的交互和静默功能
- 完善错误详情弹窗的展示和操作
- 优化错误日志表格的筛选和排序
- 新增重试和解决状态的UI支持
2026-01-14 09:03:59 +08:00
IanShaw027
b8da5d45ce feat(api): 扩展前端ops API接口
- 新增告警静默相关API调用
- 增强错误日志查询和过滤接口
- 添加重试和解决状态管理接口
2026-01-14 09:03:45 +08:00
IanShaw027
659df6e220 feat(handler): 新增ops管理接口和路由
- 添加告警静默管理接口
- 扩展错误日志查询和操作接口
- 新增重试和解决状态相关端点
- 完善错误日志记录功能
2026-01-14 09:03:35 +08:00
IanShaw027
d601768016 feat(service): 增强ops业务逻辑和告警功能
- 实现告警静默功能的业务逻辑
- 优化错误分类和重试机制
- 扩展告警评估和通知功能
- 完善错误解决和重试结果处理
2026-01-14 09:03:16 +08:00
IanShaw027
16ddc6a83b feat(repository): 扩展ops数据访问层功能
- 新增告警静默相关数据库操作
- 增强错误日志查询和统计功能
- 优化重试结果和解决状态的存储
2026-01-14 09:03:01 +08:00
IanShaw027
340dc9cadb feat(db): 添加ops告警静默和错误分类优化迁移
- 添加ops告警静默功能的数据库结构
- 优化错误分类和重试结果字段标准化
2026-01-14 09:02:45 +08:00
126 changed files with 8015 additions and 1712 deletions

2
.gitignore vendored
View File

@@ -83,6 +83,8 @@ temp/
*.log
*.bak
.cache/
.dev/
.serena/
# ===================
# 构建产物

164
PR_DESCRIPTION.md Normal file
View File

@@ -0,0 +1,164 @@
## 概述
全面增强运维监控系统Ops的错误日志管理和告警静默功能优化前端 UI 组件代码质量和用户体验。本次更新重构了核心服务层和数据访问层,提升系统可维护性和运维效率。
## 主要改动
### 1. 错误日志查询优化
**功能特性:**
- 新增 GetErrorLogByID 接口,支持按 ID 精确查询错误详情
- 优化错误日志过滤逻辑,支持多维度筛选(平台、阶段、来源、所有者等)
- 改进查询参数处理,简化代码结构
- 增强错误分类和标准化处理
- 支持错误解决状态追踪resolved 字段)
**技术实现:**
- `ops_handler.go` - 新增单条错误日志查询接口
- `ops_repo.go` - 优化数据查询和过滤条件构建
- `ops_models.go` - 扩展错误日志数据模型
- 前端 API 接口同步更新
### 2. 告警静默功能
**功能特性:**
- 支持按规则、平台、分组、区域等维度静默告警
- 可设置静默时长和原因说明
- 静默记录可追溯,记录创建人和创建时间
- 自动过期机制,避免永久静默
**技术实现:**
- `037_ops_alert_silences.sql` - 新增告警静默表
- `ops_alerts.go` - 告警静默逻辑实现
- `ops_alerts_handler.go` - 告警静默 API 接口
- `OpsAlertEventsCard.vue` - 前端告警静默操作界面
**数据库结构:**
| 字段 | 类型 | 说明 |
|------|------|------|
| rule_id | BIGINT | 告警规则 ID |
| platform | VARCHAR(64) | 平台标识 |
| group_id | BIGINT | 分组 ID可选 |
| region | VARCHAR(64) | 区域(可选) |
| until | TIMESTAMPTZ | 静默截止时间 |
| reason | TEXT | 静默原因 |
| created_by | BIGINT | 创建人 ID |
### 3. 错误分类标准化
**功能特性:**
- 统一错误阶段分类request|auth|routing|upstream|network|internal
- 规范错误归属分类client|provider|platform
- 标准化错误来源分类client_request|upstream_http|gateway
- 自动迁移历史数据到新分类体系
**技术实现:**
- `038_ops_errors_resolution_retry_results_and_standardize_classification.sql` - 分类标准化迁移
- 自动映射历史遗留分类到新标准
- 自动解决已恢复的上游错误(客户端状态码 < 400
### 4. Gateway 服务集成
**功能特性:**
- 完善各 Gateway 服务的 Ops 集成
- 统一错误日志记录接口
- 增强上游错误追踪能力
**涉及服务:**
- `antigravity_gateway_service.go` - Antigravity 网关集成
- `gateway_service.go` - 通用网关集成
- `gemini_messages_compat_service.go` - Gemini 兼容层集成
- `openai_gateway_service.go` - OpenAI 网关集成
### 5. 前端 UI 优化
**代码重构:**
- 大幅简化错误详情模态框代码(从 828 行优化到 450 行)
- 优化错误日志表格组件,提升可读性
- 清理未使用的 i18n 翻译,减少冗余
- 统一组件代码风格和格式
- 优化骨架屏组件,更好匹配实际看板布局
**布局改进:**
- 修复模态框内容溢出和滚动问题
- 优化表格布局,使用 flex 布局确保正确显示
- 改进看板头部布局和交互
- 提升响应式体验
- 骨架屏支持全屏模式适配
**交互优化:**
- 优化告警事件卡片功能和展示
- 改进错误详情展示逻辑
- 增强请求详情模态框
- 完善运行时设置卡片
- 改进加载动画效果
### 6. 国际化完善
**文案补充:**
- 补充错误日志相关的英文翻译
- 添加告警静默功能的中英文文案
- 完善提示文本和错误信息
- 统一术语翻译标准
## 文件变更
**后端26 个文件):**
- `backend/internal/handler/admin/ops_alerts_handler.go` - 告警接口增强
- `backend/internal/handler/admin/ops_handler.go` - 错误日志接口优化
- `backend/internal/handler/ops_error_logger.go` - 错误记录器增强
- `backend/internal/repository/ops_repo.go` - 数据访问层重构
- `backend/internal/repository/ops_repo_alerts.go` - 告警数据访问增强
- `backend/internal/service/ops_*.go` - 核心服务层重构10 个文件)
- `backend/internal/service/*_gateway_service.go` - Gateway 集成4 个文件)
- `backend/internal/server/routes/admin.go` - 路由配置更新
- `backend/migrations/*.sql` - 数据库迁移2 个文件)
- 测试文件更新5 个文件)
**前端13 个文件):**
- `frontend/src/views/admin/ops/OpsDashboard.vue` - 看板主页优化
- `frontend/src/views/admin/ops/components/*.vue` - 组件重构10 个文件)
- `frontend/src/api/admin/ops.ts` - API 接口扩展
- `frontend/src/i18n/locales/*.ts` - 国际化文本2 个文件)
## 代码统计
- 44 个文件修改
- 3733 行新增
- 995 行删除
- 净增加 2738 行
## 核心改进
**可维护性提升:**
- 重构核心服务层,职责更清晰
- 简化前端组件代码,降低复杂度
- 统一代码风格和命名规范
- 清理冗余代码和未使用的翻译
- 标准化错误分类体系
**功能完善:**
- 告警静默功能,减少告警噪音
- 错误日志查询优化,提升运维效率
- Gateway 服务集成完善,统一监控能力
- 错误解决状态追踪,便于问题管理
**用户体验优化:**
- 修复多个 UI 布局问题
- 优化交互流程
- 完善国际化支持
- 提升响应式体验
- 改进加载状态展示
## 测试验证
- ✅ 错误日志查询和过滤功能
- ✅ 告警静默创建和自动过期
- ✅ 错误分类标准化迁移
- ✅ Gateway 服务错误日志记录
- ✅ 前端组件布局和交互
- ✅ 骨架屏全屏模式适配
- ✅ 国际化文本完整性
- ✅ API 接口功能正确性
- ✅ 数据库迁移执行成功

View File

@@ -67,7 +67,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userHandler := handler.NewUserHandler(userService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemCodeRepository := repository.NewRedeemCodeRepository(client)
@@ -76,15 +75,17 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
redeemHandler := handler.NewRedeemHandler(redeemService)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db)
dashboardStatsCache := repository.NewDashboardCache(redisClient, configConfig)
dashboardService := service.NewDashboardService(usageLogRepository, dashboardAggregationRepository, dashboardStatsCache, configConfig)
timingWheelService := service.ProvideTimingWheelService()
dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig)
dashboardService := service.NewDashboardService(usageLogRepository, dashboardAggregationRepository, dashboardStatsCache, configConfig)
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
accountRepository := repository.NewAccountRepository(client, db)
proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, apiKeyAuthCacheInvalidator)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator)
adminUserHandler := admin.NewUserHandler(adminService)
groupHandler := admin.NewGroupHandler(adminService)
claudeOAuthClient := repository.NewClaudeOAuthClient()
@@ -98,12 +99,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService)
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
tokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, tokenCacheInvalidator)
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
usageCache := service.NewUsageCache()
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache)
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
gatewayCache := repository.NewGatewayCache(redisClient)
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
@@ -112,9 +114,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
schedulerCache := repository.NewSchedulerCache(redisClient)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
oAuthHandler := admin.NewOAuthHandler(oAuthService)
@@ -125,6 +124,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminRedeemHandler := admin.NewRedeemHandler(adminService)
promoHandler := admin.NewPromoHandler(promoService)
opsRepository := repository.NewOpsRepository(db)
schedulerCache := repository.NewSchedulerCache(redisClient)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
if err != nil {
@@ -166,7 +168,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, tokenCacheInvalidator, configConfig)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
application := &Application{

View File

@@ -43,6 +43,8 @@ type Account struct {
Concurrency int `json:"concurrency,omitempty"`
// Priority holds the value of the "priority" field.
Priority int `json:"priority,omitempty"`
// RateMultiplier holds the value of the "rate_multiplier" field.
RateMultiplier float64 `json:"rate_multiplier,omitempty"`
// Status holds the value of the "status" field.
Status string `json:"status,omitempty"`
// ErrorMessage holds the value of the "error_message" field.
@@ -135,6 +137,8 @@ func (*Account) scanValues(columns []string) ([]any, error) {
values[i] = new([]byte)
case account.FieldAutoPauseOnExpired, account.FieldSchedulable:
values[i] = new(sql.NullBool)
case account.FieldRateMultiplier:
values[i] = new(sql.NullFloat64)
case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldPriority:
values[i] = new(sql.NullInt64)
case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldSessionWindowStatus:
@@ -241,6 +245,12 @@ func (_m *Account) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.Priority = int(value.Int64)
}
case account.FieldRateMultiplier:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field rate_multiplier", values[i])
} else if value.Valid {
_m.RateMultiplier = value.Float64
}
case account.FieldStatus:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field status", values[i])
@@ -420,6 +430,9 @@ func (_m *Account) String() string {
builder.WriteString("priority=")
builder.WriteString(fmt.Sprintf("%v", _m.Priority))
builder.WriteString(", ")
builder.WriteString("rate_multiplier=")
builder.WriteString(fmt.Sprintf("%v", _m.RateMultiplier))
builder.WriteString(", ")
builder.WriteString("status=")
builder.WriteString(_m.Status)
builder.WriteString(", ")

View File

@@ -39,6 +39,8 @@ const (
FieldConcurrency = "concurrency"
// FieldPriority holds the string denoting the priority field in the database.
FieldPriority = "priority"
// FieldRateMultiplier holds the string denoting the rate_multiplier field in the database.
FieldRateMultiplier = "rate_multiplier"
// FieldStatus holds the string denoting the status field in the database.
FieldStatus = "status"
// FieldErrorMessage holds the string denoting the error_message field in the database.
@@ -116,6 +118,7 @@ var Columns = []string{
FieldProxyID,
FieldConcurrency,
FieldPriority,
FieldRateMultiplier,
FieldStatus,
FieldErrorMessage,
FieldLastUsedAt,
@@ -174,6 +177,8 @@ var (
DefaultConcurrency int
// DefaultPriority holds the default value on creation for the "priority" field.
DefaultPriority int
// DefaultRateMultiplier holds the default value on creation for the "rate_multiplier" field.
DefaultRateMultiplier float64
// DefaultStatus holds the default value on creation for the "status" field.
DefaultStatus string
// StatusValidator is a validator for the "status" field. It is called by the builders before save.
@@ -244,6 +249,11 @@ func ByPriority(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldPriority, opts...).ToFunc()
}
// ByRateMultiplier orders the results by the rate_multiplier field.
func ByRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldRateMultiplier, opts...).ToFunc()
}
// ByStatus orders the results by the status field.
func ByStatus(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldStatus, opts...).ToFunc()

View File

@@ -105,6 +105,11 @@ func Priority(v int) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldPriority, v))
}
// RateMultiplier applies equality check predicate on the "rate_multiplier" field. It's identical to RateMultiplierEQ.
func RateMultiplier(v float64) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldRateMultiplier, v))
}
// Status applies equality check predicate on the "status" field. It's identical to StatusEQ.
func Status(v string) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldStatus, v))
@@ -675,6 +680,46 @@ func PriorityLTE(v int) predicate.Account {
return predicate.Account(sql.FieldLTE(FieldPriority, v))
}
// RateMultiplierEQ applies the EQ predicate on the "rate_multiplier" field.
func RateMultiplierEQ(v float64) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldRateMultiplier, v))
}
// RateMultiplierNEQ applies the NEQ predicate on the "rate_multiplier" field.
func RateMultiplierNEQ(v float64) predicate.Account {
return predicate.Account(sql.FieldNEQ(FieldRateMultiplier, v))
}
// RateMultiplierIn applies the In predicate on the "rate_multiplier" field.
func RateMultiplierIn(vs ...float64) predicate.Account {
return predicate.Account(sql.FieldIn(FieldRateMultiplier, vs...))
}
// RateMultiplierNotIn applies the NotIn predicate on the "rate_multiplier" field.
func RateMultiplierNotIn(vs ...float64) predicate.Account {
return predicate.Account(sql.FieldNotIn(FieldRateMultiplier, vs...))
}
// RateMultiplierGT applies the GT predicate on the "rate_multiplier" field.
func RateMultiplierGT(v float64) predicate.Account {
return predicate.Account(sql.FieldGT(FieldRateMultiplier, v))
}
// RateMultiplierGTE applies the GTE predicate on the "rate_multiplier" field.
func RateMultiplierGTE(v float64) predicate.Account {
return predicate.Account(sql.FieldGTE(FieldRateMultiplier, v))
}
// RateMultiplierLT applies the LT predicate on the "rate_multiplier" field.
func RateMultiplierLT(v float64) predicate.Account {
return predicate.Account(sql.FieldLT(FieldRateMultiplier, v))
}
// RateMultiplierLTE applies the LTE predicate on the "rate_multiplier" field.
func RateMultiplierLTE(v float64) predicate.Account {
return predicate.Account(sql.FieldLTE(FieldRateMultiplier, v))
}
// StatusEQ applies the EQ predicate on the "status" field.
func StatusEQ(v string) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldStatus, v))

View File

@@ -153,6 +153,20 @@ func (_c *AccountCreate) SetNillablePriority(v *int) *AccountCreate {
return _c
}
// SetRateMultiplier sets the "rate_multiplier" field.
func (_c *AccountCreate) SetRateMultiplier(v float64) *AccountCreate {
_c.mutation.SetRateMultiplier(v)
return _c
}
// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil.
func (_c *AccountCreate) SetNillableRateMultiplier(v *float64) *AccountCreate {
if v != nil {
_c.SetRateMultiplier(*v)
}
return _c
}
// SetStatus sets the "status" field.
func (_c *AccountCreate) SetStatus(v string) *AccountCreate {
_c.mutation.SetStatus(v)
@@ -429,6 +443,10 @@ func (_c *AccountCreate) defaults() error {
v := account.DefaultPriority
_c.mutation.SetPriority(v)
}
if _, ok := _c.mutation.RateMultiplier(); !ok {
v := account.DefaultRateMultiplier
_c.mutation.SetRateMultiplier(v)
}
if _, ok := _c.mutation.Status(); !ok {
v := account.DefaultStatus
_c.mutation.SetStatus(v)
@@ -488,6 +506,9 @@ func (_c *AccountCreate) check() error {
if _, ok := _c.mutation.Priority(); !ok {
return &ValidationError{Name: "priority", err: errors.New(`ent: missing required field "Account.priority"`)}
}
if _, ok := _c.mutation.RateMultiplier(); !ok {
return &ValidationError{Name: "rate_multiplier", err: errors.New(`ent: missing required field "Account.rate_multiplier"`)}
}
if _, ok := _c.mutation.Status(); !ok {
return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "Account.status"`)}
}
@@ -578,6 +599,10 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) {
_spec.SetField(account.FieldPriority, field.TypeInt, value)
_node.Priority = value
}
if value, ok := _c.mutation.RateMultiplier(); ok {
_spec.SetField(account.FieldRateMultiplier, field.TypeFloat64, value)
_node.RateMultiplier = value
}
if value, ok := _c.mutation.Status(); ok {
_spec.SetField(account.FieldStatus, field.TypeString, value)
_node.Status = value
@@ -893,6 +918,24 @@ func (u *AccountUpsert) AddPriority(v int) *AccountUpsert {
return u
}
// SetRateMultiplier sets the "rate_multiplier" field.
func (u *AccountUpsert) SetRateMultiplier(v float64) *AccountUpsert {
u.Set(account.FieldRateMultiplier, v)
return u
}
// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create.
func (u *AccountUpsert) UpdateRateMultiplier() *AccountUpsert {
u.SetExcluded(account.FieldRateMultiplier)
return u
}
// AddRateMultiplier adds v to the "rate_multiplier" field.
func (u *AccountUpsert) AddRateMultiplier(v float64) *AccountUpsert {
u.Add(account.FieldRateMultiplier, v)
return u
}
// SetStatus sets the "status" field.
func (u *AccountUpsert) SetStatus(v string) *AccountUpsert {
u.Set(account.FieldStatus, v)
@@ -1325,6 +1368,27 @@ func (u *AccountUpsertOne) UpdatePriority() *AccountUpsertOne {
})
}
// SetRateMultiplier sets the "rate_multiplier" field.
func (u *AccountUpsertOne) SetRateMultiplier(v float64) *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) {
s.SetRateMultiplier(v)
})
}
// AddRateMultiplier adds v to the "rate_multiplier" field.
func (u *AccountUpsertOne) AddRateMultiplier(v float64) *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) {
s.AddRateMultiplier(v)
})
}
// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create.
func (u *AccountUpsertOne) UpdateRateMultiplier() *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) {
s.UpdateRateMultiplier()
})
}
// SetStatus sets the "status" field.
func (u *AccountUpsertOne) SetStatus(v string) *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) {
@@ -1956,6 +2020,27 @@ func (u *AccountUpsertBulk) UpdatePriority() *AccountUpsertBulk {
})
}
// SetRateMultiplier sets the "rate_multiplier" field.
func (u *AccountUpsertBulk) SetRateMultiplier(v float64) *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) {
s.SetRateMultiplier(v)
})
}
// AddRateMultiplier adds v to the "rate_multiplier" field.
func (u *AccountUpsertBulk) AddRateMultiplier(v float64) *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) {
s.AddRateMultiplier(v)
})
}
// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create.
func (u *AccountUpsertBulk) UpdateRateMultiplier() *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) {
s.UpdateRateMultiplier()
})
}
// SetStatus sets the "status" field.
func (u *AccountUpsertBulk) SetStatus(v string) *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) {

View File

@@ -193,6 +193,27 @@ func (_u *AccountUpdate) AddPriority(v int) *AccountUpdate {
return _u
}
// SetRateMultiplier sets the "rate_multiplier" field.
func (_u *AccountUpdate) SetRateMultiplier(v float64) *AccountUpdate {
_u.mutation.ResetRateMultiplier()
_u.mutation.SetRateMultiplier(v)
return _u
}
// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil.
func (_u *AccountUpdate) SetNillableRateMultiplier(v *float64) *AccountUpdate {
if v != nil {
_u.SetRateMultiplier(*v)
}
return _u
}
// AddRateMultiplier adds value to the "rate_multiplier" field.
func (_u *AccountUpdate) AddRateMultiplier(v float64) *AccountUpdate {
_u.mutation.AddRateMultiplier(v)
return _u
}
// SetStatus sets the "status" field.
func (_u *AccountUpdate) SetStatus(v string) *AccountUpdate {
_u.mutation.SetStatus(v)
@@ -629,6 +650,12 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.AddedPriority(); ok {
_spec.AddField(account.FieldPriority, field.TypeInt, value)
}
if value, ok := _u.mutation.RateMultiplier(); ok {
_spec.SetField(account.FieldRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
_spec.AddField(account.FieldRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(account.FieldStatus, field.TypeString, value)
}
@@ -1005,6 +1032,27 @@ func (_u *AccountUpdateOne) AddPriority(v int) *AccountUpdateOne {
return _u
}
// SetRateMultiplier sets the "rate_multiplier" field.
func (_u *AccountUpdateOne) SetRateMultiplier(v float64) *AccountUpdateOne {
_u.mutation.ResetRateMultiplier()
_u.mutation.SetRateMultiplier(v)
return _u
}
// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil.
func (_u *AccountUpdateOne) SetNillableRateMultiplier(v *float64) *AccountUpdateOne {
if v != nil {
_u.SetRateMultiplier(*v)
}
return _u
}
// AddRateMultiplier adds value to the "rate_multiplier" field.
func (_u *AccountUpdateOne) AddRateMultiplier(v float64) *AccountUpdateOne {
_u.mutation.AddRateMultiplier(v)
return _u
}
// SetStatus sets the "status" field.
func (_u *AccountUpdateOne) SetStatus(v string) *AccountUpdateOne {
_u.mutation.SetStatus(v)
@@ -1471,6 +1519,12 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er
if value, ok := _u.mutation.AddedPriority(); ok {
_spec.AddField(account.FieldPriority, field.TypeInt, value)
}
if value, ok := _u.mutation.RateMultiplier(); ok {
_spec.SetField(account.FieldRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
_spec.AddField(account.FieldRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(account.FieldStatus, field.TypeString, value)
}

View File

@@ -79,6 +79,7 @@ var (
{Name: "extra", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "concurrency", Type: field.TypeInt, Default: 3},
{Name: "priority", Type: field.TypeInt, Default: 50},
{Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
{Name: "error_message", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
{Name: "last_used_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
@@ -101,7 +102,7 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "accounts_proxies_proxy",
Columns: []*schema.Column{AccountsColumns[24]},
Columns: []*schema.Column{AccountsColumns[25]},
RefColumns: []*schema.Column{ProxiesColumns[0]},
OnDelete: schema.SetNull,
},
@@ -120,12 +121,12 @@ var (
{
Name: "account_status",
Unique: false,
Columns: []*schema.Column{AccountsColumns[12]},
Columns: []*schema.Column{AccountsColumns[13]},
},
{
Name: "account_proxy_id",
Unique: false,
Columns: []*schema.Column{AccountsColumns[24]},
Columns: []*schema.Column{AccountsColumns[25]},
},
{
Name: "account_priority",
@@ -135,27 +136,27 @@ var (
{
Name: "account_last_used_at",
Unique: false,
Columns: []*schema.Column{AccountsColumns[14]},
Columns: []*schema.Column{AccountsColumns[15]},
},
{
Name: "account_schedulable",
Unique: false,
Columns: []*schema.Column{AccountsColumns[17]},
Columns: []*schema.Column{AccountsColumns[18]},
},
{
Name: "account_rate_limited_at",
Unique: false,
Columns: []*schema.Column{AccountsColumns[18]},
Columns: []*schema.Column{AccountsColumns[19]},
},
{
Name: "account_rate_limit_reset_at",
Unique: false,
Columns: []*schema.Column{AccountsColumns[19]},
Columns: []*schema.Column{AccountsColumns[20]},
},
{
Name: "account_overload_until",
Unique: false,
Columns: []*schema.Column{AccountsColumns[20]},
Columns: []*schema.Column{AccountsColumns[21]},
},
{
Name: "account_deleted_at",
@@ -449,6 +450,7 @@ var (
{Name: "total_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
{Name: "actual_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
{Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
{Name: "account_rate_multiplier", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
{Name: "billing_type", Type: field.TypeInt8, Default: 0},
{Name: "stream", Type: field.TypeBool, Default: false},
{Name: "duration_ms", Type: field.TypeInt, Nullable: true},
@@ -472,31 +474,31 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "usage_logs_api_keys_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[25]},
Columns: []*schema.Column{UsageLogsColumns[26]},
RefColumns: []*schema.Column{APIKeysColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_accounts_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[26]},
Columns: []*schema.Column{UsageLogsColumns[27]},
RefColumns: []*schema.Column{AccountsColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_groups_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[27]},
Columns: []*schema.Column{UsageLogsColumns[28]},
RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull,
},
{
Symbol: "usage_logs_users_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[28]},
Columns: []*schema.Column{UsageLogsColumns[29]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_user_subscriptions_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[29]},
Columns: []*schema.Column{UsageLogsColumns[30]},
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
OnDelete: schema.SetNull,
},
@@ -505,32 +507,32 @@ var (
{
Name: "usagelog_user_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[28]},
Columns: []*schema.Column{UsageLogsColumns[29]},
},
{
Name: "usagelog_api_key_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[25]},
Columns: []*schema.Column{UsageLogsColumns[26]},
},
{
Name: "usagelog_account_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[26]},
Columns: []*schema.Column{UsageLogsColumns[27]},
},
{
Name: "usagelog_group_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[27]},
Columns: []*schema.Column{UsageLogsColumns[28]},
},
{
Name: "usagelog_subscription_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[29]},
Columns: []*schema.Column{UsageLogsColumns[30]},
},
{
Name: "usagelog_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[24]},
Columns: []*schema.Column{UsageLogsColumns[25]},
},
{
Name: "usagelog_model",
@@ -545,12 +547,12 @@ var (
{
Name: "usagelog_user_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[24]},
Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[25]},
},
{
Name: "usagelog_api_key_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[25], UsageLogsColumns[24]},
Columns: []*schema.Column{UsageLogsColumns[26], UsageLogsColumns[25]},
},
},
}

View File

@@ -1187,6 +1187,8 @@ type AccountMutation struct {
addconcurrency *int
priority *int
addpriority *int
rate_multiplier *float64
addrate_multiplier *float64
status *string
error_message *string
last_used_at *time.Time
@@ -1822,6 +1824,62 @@ func (m *AccountMutation) ResetPriority() {
m.addpriority = nil
}
// SetRateMultiplier sets the "rate_multiplier" field.
func (m *AccountMutation) SetRateMultiplier(f float64) {
m.rate_multiplier = &f
m.addrate_multiplier = nil
}
// RateMultiplier returns the value of the "rate_multiplier" field in the mutation.
func (m *AccountMutation) RateMultiplier() (r float64, exists bool) {
v := m.rate_multiplier
if v == nil {
return
}
return *v, true
}
// OldRateMultiplier returns the old "rate_multiplier" field's value of the Account entity.
// If the Account object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *AccountMutation) OldRateMultiplier(ctx context.Context) (v float64, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldRateMultiplier is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldRateMultiplier requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldRateMultiplier: %w", err)
}
return oldValue.RateMultiplier, nil
}
// AddRateMultiplier adds f to the "rate_multiplier" field.
func (m *AccountMutation) AddRateMultiplier(f float64) {
if m.addrate_multiplier != nil {
*m.addrate_multiplier += f
} else {
m.addrate_multiplier = &f
}
}
// AddedRateMultiplier returns the value that was added to the "rate_multiplier" field in this mutation.
func (m *AccountMutation) AddedRateMultiplier() (r float64, exists bool) {
v := m.addrate_multiplier
if v == nil {
return
}
return *v, true
}
// ResetRateMultiplier resets all changes to the "rate_multiplier" field.
func (m *AccountMutation) ResetRateMultiplier() {
m.rate_multiplier = nil
m.addrate_multiplier = nil
}
// SetStatus sets the "status" field.
func (m *AccountMutation) SetStatus(s string) {
m.status = &s
@@ -2540,7 +2598,7 @@ func (m *AccountMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *AccountMutation) Fields() []string {
fields := make([]string, 0, 24)
fields := make([]string, 0, 25)
if m.created_at != nil {
fields = append(fields, account.FieldCreatedAt)
}
@@ -2577,6 +2635,9 @@ func (m *AccountMutation) Fields() []string {
if m.priority != nil {
fields = append(fields, account.FieldPriority)
}
if m.rate_multiplier != nil {
fields = append(fields, account.FieldRateMultiplier)
}
if m.status != nil {
fields = append(fields, account.FieldStatus)
}
@@ -2645,6 +2706,8 @@ func (m *AccountMutation) Field(name string) (ent.Value, bool) {
return m.Concurrency()
case account.FieldPriority:
return m.Priority()
case account.FieldRateMultiplier:
return m.RateMultiplier()
case account.FieldStatus:
return m.Status()
case account.FieldErrorMessage:
@@ -2702,6 +2765,8 @@ func (m *AccountMutation) OldField(ctx context.Context, name string) (ent.Value,
return m.OldConcurrency(ctx)
case account.FieldPriority:
return m.OldPriority(ctx)
case account.FieldRateMultiplier:
return m.OldRateMultiplier(ctx)
case account.FieldStatus:
return m.OldStatus(ctx)
case account.FieldErrorMessage:
@@ -2819,6 +2884,13 @@ func (m *AccountMutation) SetField(name string, value ent.Value) error {
}
m.SetPriority(v)
return nil
case account.FieldRateMultiplier:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetRateMultiplier(v)
return nil
case account.FieldStatus:
v, ok := value.(string)
if !ok {
@@ -2917,6 +2989,9 @@ func (m *AccountMutation) AddedFields() []string {
if m.addpriority != nil {
fields = append(fields, account.FieldPriority)
}
if m.addrate_multiplier != nil {
fields = append(fields, account.FieldRateMultiplier)
}
return fields
}
@@ -2929,6 +3004,8 @@ func (m *AccountMutation) AddedField(name string) (ent.Value, bool) {
return m.AddedConcurrency()
case account.FieldPriority:
return m.AddedPriority()
case account.FieldRateMultiplier:
return m.AddedRateMultiplier()
}
return nil, false
}
@@ -2952,6 +3029,13 @@ func (m *AccountMutation) AddField(name string, value ent.Value) error {
}
m.AddPriority(v)
return nil
case account.FieldRateMultiplier:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.AddRateMultiplier(v)
return nil
}
return fmt.Errorf("unknown Account numeric field %s", name)
}
@@ -3090,6 +3174,9 @@ func (m *AccountMutation) ResetField(name string) error {
case account.FieldPriority:
m.ResetPriority()
return nil
case account.FieldRateMultiplier:
m.ResetRateMultiplier()
return nil
case account.FieldStatus:
m.ResetStatus()
return nil
@@ -10190,6 +10277,8 @@ type UsageLogMutation struct {
addactual_cost *float64
rate_multiplier *float64
addrate_multiplier *float64
account_rate_multiplier *float64
addaccount_rate_multiplier *float64
billing_type *int8
addbilling_type *int8
stream *bool
@@ -11323,6 +11412,76 @@ func (m *UsageLogMutation) ResetRateMultiplier() {
m.addrate_multiplier = nil
}
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
func (m *UsageLogMutation) SetAccountRateMultiplier(f float64) {
m.account_rate_multiplier = &f
m.addaccount_rate_multiplier = nil
}
// AccountRateMultiplier returns the value of the "account_rate_multiplier" field in the mutation.
func (m *UsageLogMutation) AccountRateMultiplier() (r float64, exists bool) {
v := m.account_rate_multiplier
if v == nil {
return
}
return *v, true
}
// OldAccountRateMultiplier returns the old "account_rate_multiplier" field's value of the UsageLog entity.
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UsageLogMutation) OldAccountRateMultiplier(ctx context.Context) (v *float64, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldAccountRateMultiplier is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldAccountRateMultiplier requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldAccountRateMultiplier: %w", err)
}
return oldValue.AccountRateMultiplier, nil
}
// AddAccountRateMultiplier adds f to the "account_rate_multiplier" field.
func (m *UsageLogMutation) AddAccountRateMultiplier(f float64) {
if m.addaccount_rate_multiplier != nil {
*m.addaccount_rate_multiplier += f
} else {
m.addaccount_rate_multiplier = &f
}
}
// AddedAccountRateMultiplier returns the value that was added to the "account_rate_multiplier" field in this mutation.
func (m *UsageLogMutation) AddedAccountRateMultiplier() (r float64, exists bool) {
v := m.addaccount_rate_multiplier
if v == nil {
return
}
return *v, true
}
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
func (m *UsageLogMutation) ClearAccountRateMultiplier() {
m.account_rate_multiplier = nil
m.addaccount_rate_multiplier = nil
m.clearedFields[usagelog.FieldAccountRateMultiplier] = struct{}{}
}
// AccountRateMultiplierCleared returns if the "account_rate_multiplier" field was cleared in this mutation.
func (m *UsageLogMutation) AccountRateMultiplierCleared() bool {
_, ok := m.clearedFields[usagelog.FieldAccountRateMultiplier]
return ok
}
// ResetAccountRateMultiplier resets all changes to the "account_rate_multiplier" field.
func (m *UsageLogMutation) ResetAccountRateMultiplier() {
m.account_rate_multiplier = nil
m.addaccount_rate_multiplier = nil
delete(m.clearedFields, usagelog.FieldAccountRateMultiplier)
}
// SetBillingType sets the "billing_type" field.
func (m *UsageLogMutation) SetBillingType(i int8) {
m.billing_type = &i
@@ -11963,7 +12122,7 @@ func (m *UsageLogMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UsageLogMutation) Fields() []string {
fields := make([]string, 0, 29)
fields := make([]string, 0, 30)
if m.user != nil {
fields = append(fields, usagelog.FieldUserID)
}
@@ -12024,6 +12183,9 @@ func (m *UsageLogMutation) Fields() []string {
if m.rate_multiplier != nil {
fields = append(fields, usagelog.FieldRateMultiplier)
}
if m.account_rate_multiplier != nil {
fields = append(fields, usagelog.FieldAccountRateMultiplier)
}
if m.billing_type != nil {
fields = append(fields, usagelog.FieldBillingType)
}
@@ -12099,6 +12261,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
return m.ActualCost()
case usagelog.FieldRateMultiplier:
return m.RateMultiplier()
case usagelog.FieldAccountRateMultiplier:
return m.AccountRateMultiplier()
case usagelog.FieldBillingType:
return m.BillingType()
case usagelog.FieldStream:
@@ -12166,6 +12330,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
return m.OldActualCost(ctx)
case usagelog.FieldRateMultiplier:
return m.OldRateMultiplier(ctx)
case usagelog.FieldAccountRateMultiplier:
return m.OldAccountRateMultiplier(ctx)
case usagelog.FieldBillingType:
return m.OldBillingType(ctx)
case usagelog.FieldStream:
@@ -12333,6 +12499,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
}
m.SetRateMultiplier(v)
return nil
case usagelog.FieldAccountRateMultiplier:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetAccountRateMultiplier(v)
return nil
case usagelog.FieldBillingType:
v, ok := value.(int8)
if !ok {
@@ -12443,6 +12616,9 @@ func (m *UsageLogMutation) AddedFields() []string {
if m.addrate_multiplier != nil {
fields = append(fields, usagelog.FieldRateMultiplier)
}
if m.addaccount_rate_multiplier != nil {
fields = append(fields, usagelog.FieldAccountRateMultiplier)
}
if m.addbilling_type != nil {
fields = append(fields, usagelog.FieldBillingType)
}
@@ -12489,6 +12665,8 @@ func (m *UsageLogMutation) AddedField(name string) (ent.Value, bool) {
return m.AddedActualCost()
case usagelog.FieldRateMultiplier:
return m.AddedRateMultiplier()
case usagelog.FieldAccountRateMultiplier:
return m.AddedAccountRateMultiplier()
case usagelog.FieldBillingType:
return m.AddedBillingType()
case usagelog.FieldDurationMs:
@@ -12597,6 +12775,13 @@ func (m *UsageLogMutation) AddField(name string, value ent.Value) error {
}
m.AddRateMultiplier(v)
return nil
case usagelog.FieldAccountRateMultiplier:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.AddAccountRateMultiplier(v)
return nil
case usagelog.FieldBillingType:
v, ok := value.(int8)
if !ok {
@@ -12639,6 +12824,9 @@ func (m *UsageLogMutation) ClearedFields() []string {
if m.FieldCleared(usagelog.FieldSubscriptionID) {
fields = append(fields, usagelog.FieldSubscriptionID)
}
if m.FieldCleared(usagelog.FieldAccountRateMultiplier) {
fields = append(fields, usagelog.FieldAccountRateMultiplier)
}
if m.FieldCleared(usagelog.FieldDurationMs) {
fields = append(fields, usagelog.FieldDurationMs)
}
@@ -12674,6 +12862,9 @@ func (m *UsageLogMutation) ClearField(name string) error {
case usagelog.FieldSubscriptionID:
m.ClearSubscriptionID()
return nil
case usagelog.FieldAccountRateMultiplier:
m.ClearAccountRateMultiplier()
return nil
case usagelog.FieldDurationMs:
m.ClearDurationMs()
return nil
@@ -12757,6 +12948,9 @@ func (m *UsageLogMutation) ResetField(name string) error {
case usagelog.FieldRateMultiplier:
m.ResetRateMultiplier()
return nil
case usagelog.FieldAccountRateMultiplier:
m.ResetAccountRateMultiplier()
return nil
case usagelog.FieldBillingType:
m.ResetBillingType()
return nil

View File

@@ -177,22 +177,26 @@ func init() {
accountDescPriority := accountFields[8].Descriptor()
// account.DefaultPriority holds the default value on creation for the priority field.
account.DefaultPriority = accountDescPriority.Default.(int)
// accountDescRateMultiplier is the schema descriptor for rate_multiplier field.
accountDescRateMultiplier := accountFields[9].Descriptor()
// account.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
account.DefaultRateMultiplier = accountDescRateMultiplier.Default.(float64)
// accountDescStatus is the schema descriptor for status field.
accountDescStatus := accountFields[9].Descriptor()
accountDescStatus := accountFields[10].Descriptor()
// account.DefaultStatus holds the default value on creation for the status field.
account.DefaultStatus = accountDescStatus.Default.(string)
// account.StatusValidator is a validator for the "status" field. It is called by the builders before save.
account.StatusValidator = accountDescStatus.Validators[0].(func(string) error)
// accountDescAutoPauseOnExpired is the schema descriptor for auto_pause_on_expired field.
accountDescAutoPauseOnExpired := accountFields[13].Descriptor()
accountDescAutoPauseOnExpired := accountFields[14].Descriptor()
// account.DefaultAutoPauseOnExpired holds the default value on creation for the auto_pause_on_expired field.
account.DefaultAutoPauseOnExpired = accountDescAutoPauseOnExpired.Default.(bool)
// accountDescSchedulable is the schema descriptor for schedulable field.
accountDescSchedulable := accountFields[14].Descriptor()
accountDescSchedulable := accountFields[15].Descriptor()
// account.DefaultSchedulable holds the default value on creation for the schedulable field.
account.DefaultSchedulable = accountDescSchedulable.Default.(bool)
// accountDescSessionWindowStatus is the schema descriptor for session_window_status field.
accountDescSessionWindowStatus := accountFields[20].Descriptor()
accountDescSessionWindowStatus := accountFields[21].Descriptor()
// account.SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save.
account.SessionWindowStatusValidator = accountDescSessionWindowStatus.Validators[0].(func(string) error)
accountgroupFields := schema.AccountGroup{}.Fields()
@@ -578,31 +582,31 @@ func init() {
// usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64)
// usagelogDescBillingType is the schema descriptor for billing_type field.
usagelogDescBillingType := usagelogFields[20].Descriptor()
usagelogDescBillingType := usagelogFields[21].Descriptor()
// usagelog.DefaultBillingType holds the default value on creation for the billing_type field.
usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8)
// usagelogDescStream is the schema descriptor for stream field.
usagelogDescStream := usagelogFields[21].Descriptor()
usagelogDescStream := usagelogFields[22].Descriptor()
// usagelog.DefaultStream holds the default value on creation for the stream field.
usagelog.DefaultStream = usagelogDescStream.Default.(bool)
// usagelogDescUserAgent is the schema descriptor for user_agent field.
usagelogDescUserAgent := usagelogFields[24].Descriptor()
usagelogDescUserAgent := usagelogFields[25].Descriptor()
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
// usagelogDescIPAddress is the schema descriptor for ip_address field.
usagelogDescIPAddress := usagelogFields[25].Descriptor()
usagelogDescIPAddress := usagelogFields[26].Descriptor()
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
// usagelogDescImageCount is the schema descriptor for image_count field.
usagelogDescImageCount := usagelogFields[26].Descriptor()
usagelogDescImageCount := usagelogFields[27].Descriptor()
// usagelog.DefaultImageCount holds the default value on creation for the image_count field.
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
// usagelogDescImageSize is the schema descriptor for image_size field.
usagelogDescImageSize := usagelogFields[27].Descriptor()
usagelogDescImageSize := usagelogFields[28].Descriptor()
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
// usagelogDescCreatedAt is the schema descriptor for created_at field.
usagelogDescCreatedAt := usagelogFields[28].Descriptor()
usagelogDescCreatedAt := usagelogFields[29].Descriptor()
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
userMixin := schema.User{}.Mixin()

View File

@@ -102,6 +102,12 @@ func (Account) Fields() []ent.Field {
field.Int("priority").
Default(50),
// rate_multiplier: 账号计费倍率(>=0允许 0 表示该账号计费为 0
// 仅影响账号维度计费口径,不影响用户/API Key 扣费(分组倍率)
field.Float("rate_multiplier").
SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}).
Default(1.0),
// status: 账户状态,如 "active", "error", "disabled"
field.String("status").
MaxLen(20).

View File

@@ -85,6 +85,12 @@ func (UsageLog) Fields() []ent.Field {
Default(1).
SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}),
// account_rate_multiplier: 账号计费倍率快照NULL 表示按 1.0 处理)
field.Float("account_rate_multiplier").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}),
// 其他字段
field.Int8("billing_type").
Default(0),

View File

@@ -62,6 +62,8 @@ type UsageLog struct {
ActualCost float64 `json:"actual_cost,omitempty"`
// RateMultiplier holds the value of the "rate_multiplier" field.
RateMultiplier float64 `json:"rate_multiplier,omitempty"`
// AccountRateMultiplier holds the value of the "account_rate_multiplier" field.
AccountRateMultiplier *float64 `json:"account_rate_multiplier,omitempty"`
// BillingType holds the value of the "billing_type" field.
BillingType int8 `json:"billing_type,omitempty"`
// Stream holds the value of the "stream" field.
@@ -165,7 +167,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
switch columns[i] {
case usagelog.FieldStream:
values[i] = new(sql.NullBool)
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier:
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier:
values[i] = new(sql.NullFloat64)
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
values[i] = new(sql.NullInt64)
@@ -316,6 +318,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.RateMultiplier = value.Float64
}
case usagelog.FieldAccountRateMultiplier:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field account_rate_multiplier", values[i])
} else if value.Valid {
_m.AccountRateMultiplier = new(float64)
*_m.AccountRateMultiplier = value.Float64
}
case usagelog.FieldBillingType:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field billing_type", values[i])
@@ -500,6 +509,11 @@ func (_m *UsageLog) String() string {
builder.WriteString("rate_multiplier=")
builder.WriteString(fmt.Sprintf("%v", _m.RateMultiplier))
builder.WriteString(", ")
if v := _m.AccountRateMultiplier; v != nil {
builder.WriteString("account_rate_multiplier=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
builder.WriteString("billing_type=")
builder.WriteString(fmt.Sprintf("%v", _m.BillingType))
builder.WriteString(", ")

View File

@@ -54,6 +54,8 @@ const (
FieldActualCost = "actual_cost"
// FieldRateMultiplier holds the string denoting the rate_multiplier field in the database.
FieldRateMultiplier = "rate_multiplier"
// FieldAccountRateMultiplier holds the string denoting the account_rate_multiplier field in the database.
FieldAccountRateMultiplier = "account_rate_multiplier"
// FieldBillingType holds the string denoting the billing_type field in the database.
FieldBillingType = "billing_type"
// FieldStream holds the string denoting the stream field in the database.
@@ -144,6 +146,7 @@ var Columns = []string{
FieldTotalCost,
FieldActualCost,
FieldRateMultiplier,
FieldAccountRateMultiplier,
FieldBillingType,
FieldStream,
FieldDurationMs,
@@ -320,6 +323,11 @@ func ByRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldRateMultiplier, opts...).ToFunc()
}
// ByAccountRateMultiplier orders the results by the account_rate_multiplier field.
func ByAccountRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldAccountRateMultiplier, opts...).ToFunc()
}
// ByBillingType orders the results by the billing_type field.
func ByBillingType(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBillingType, opts...).ToFunc()

View File

@@ -155,6 +155,11 @@ func RateMultiplier(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldRateMultiplier, v))
}
// AccountRateMultiplier applies equality check predicate on the "account_rate_multiplier" field. It's identical to AccountRateMultiplierEQ.
func AccountRateMultiplier(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldAccountRateMultiplier, v))
}
// BillingType applies equality check predicate on the "billing_type" field. It's identical to BillingTypeEQ.
func BillingType(v int8) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldBillingType, v))
@@ -970,6 +975,56 @@ func RateMultiplierLTE(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldRateMultiplier, v))
}
// AccountRateMultiplierEQ applies the EQ predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierEQ(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldAccountRateMultiplier, v))
}
// AccountRateMultiplierNEQ applies the NEQ predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierNEQ(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldAccountRateMultiplier, v))
}
// AccountRateMultiplierIn applies the In predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierIn(vs ...float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldAccountRateMultiplier, vs...))
}
// AccountRateMultiplierNotIn applies the NotIn predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierNotIn(vs ...float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldAccountRateMultiplier, vs...))
}
// AccountRateMultiplierGT applies the GT predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierGT(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldAccountRateMultiplier, v))
}
// AccountRateMultiplierGTE applies the GTE predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierGTE(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldAccountRateMultiplier, v))
}
// AccountRateMultiplierLT applies the LT predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierLT(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldAccountRateMultiplier, v))
}
// AccountRateMultiplierLTE applies the LTE predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierLTE(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldAccountRateMultiplier, v))
}
// AccountRateMultiplierIsNil applies the IsNil predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldAccountRateMultiplier))
}
// AccountRateMultiplierNotNil applies the NotNil predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldAccountRateMultiplier))
}
// BillingTypeEQ applies the EQ predicate on the "billing_type" field.
func BillingTypeEQ(v int8) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldBillingType, v))

View File

@@ -267,6 +267,20 @@ func (_c *UsageLogCreate) SetNillableRateMultiplier(v *float64) *UsageLogCreate
return _c
}
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
func (_c *UsageLogCreate) SetAccountRateMultiplier(v float64) *UsageLogCreate {
_c.mutation.SetAccountRateMultiplier(v)
return _c
}
// SetNillableAccountRateMultiplier sets the "account_rate_multiplier" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableAccountRateMultiplier(v *float64) *UsageLogCreate {
if v != nil {
_c.SetAccountRateMultiplier(*v)
}
return _c
}
// SetBillingType sets the "billing_type" field.
func (_c *UsageLogCreate) SetBillingType(v int8) *UsageLogCreate {
_c.mutation.SetBillingType(v)
@@ -712,6 +726,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldRateMultiplier, field.TypeFloat64, value)
_node.RateMultiplier = value
}
if value, ok := _c.mutation.AccountRateMultiplier(); ok {
_spec.SetField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
_node.AccountRateMultiplier = &value
}
if value, ok := _c.mutation.BillingType(); ok {
_spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value)
_node.BillingType = value
@@ -1215,6 +1233,30 @@ func (u *UsageLogUpsert) AddRateMultiplier(v float64) *UsageLogUpsert {
return u
}
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
func (u *UsageLogUpsert) SetAccountRateMultiplier(v float64) *UsageLogUpsert {
u.Set(usagelog.FieldAccountRateMultiplier, v)
return u
}
// UpdateAccountRateMultiplier sets the "account_rate_multiplier" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateAccountRateMultiplier() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldAccountRateMultiplier)
return u
}
// AddAccountRateMultiplier adds v to the "account_rate_multiplier" field.
func (u *UsageLogUpsert) AddAccountRateMultiplier(v float64) *UsageLogUpsert {
u.Add(usagelog.FieldAccountRateMultiplier, v)
return u
}
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
func (u *UsageLogUpsert) ClearAccountRateMultiplier() *UsageLogUpsert {
u.SetNull(usagelog.FieldAccountRateMultiplier)
return u
}
// SetBillingType sets the "billing_type" field.
func (u *UsageLogUpsert) SetBillingType(v int8) *UsageLogUpsert {
u.Set(usagelog.FieldBillingType, v)
@@ -1795,6 +1837,34 @@ func (u *UsageLogUpsertOne) UpdateRateMultiplier() *UsageLogUpsertOne {
})
}
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
func (u *UsageLogUpsertOne) SetAccountRateMultiplier(v float64) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetAccountRateMultiplier(v)
})
}
// AddAccountRateMultiplier adds v to the "account_rate_multiplier" field.
func (u *UsageLogUpsertOne) AddAccountRateMultiplier(v float64) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.AddAccountRateMultiplier(v)
})
}
// UpdateAccountRateMultiplier sets the "account_rate_multiplier" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateAccountRateMultiplier() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateAccountRateMultiplier()
})
}
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
func (u *UsageLogUpsertOne) ClearAccountRateMultiplier() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearAccountRateMultiplier()
})
}
// SetBillingType sets the "billing_type" field.
func (u *UsageLogUpsertOne) SetBillingType(v int8) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
@@ -2566,6 +2636,34 @@ func (u *UsageLogUpsertBulk) UpdateRateMultiplier() *UsageLogUpsertBulk {
})
}
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
func (u *UsageLogUpsertBulk) SetAccountRateMultiplier(v float64) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetAccountRateMultiplier(v)
})
}
// AddAccountRateMultiplier adds v to the "account_rate_multiplier" field.
func (u *UsageLogUpsertBulk) AddAccountRateMultiplier(v float64) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.AddAccountRateMultiplier(v)
})
}
// UpdateAccountRateMultiplier sets the "account_rate_multiplier" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateAccountRateMultiplier() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateAccountRateMultiplier()
})
}
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
func (u *UsageLogUpsertBulk) ClearAccountRateMultiplier() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearAccountRateMultiplier()
})
}
// SetBillingType sets the "billing_type" field.
func (u *UsageLogUpsertBulk) SetBillingType(v int8) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {

View File

@@ -415,6 +415,33 @@ func (_u *UsageLogUpdate) AddRateMultiplier(v float64) *UsageLogUpdate {
return _u
}
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
func (_u *UsageLogUpdate) SetAccountRateMultiplier(v float64) *UsageLogUpdate {
_u.mutation.ResetAccountRateMultiplier()
_u.mutation.SetAccountRateMultiplier(v)
return _u
}
// SetNillableAccountRateMultiplier sets the "account_rate_multiplier" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableAccountRateMultiplier(v *float64) *UsageLogUpdate {
if v != nil {
_u.SetAccountRateMultiplier(*v)
}
return _u
}
// AddAccountRateMultiplier adds value to the "account_rate_multiplier" field.
func (_u *UsageLogUpdate) AddAccountRateMultiplier(v float64) *UsageLogUpdate {
_u.mutation.AddAccountRateMultiplier(v)
return _u
}
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
func (_u *UsageLogUpdate) ClearAccountRateMultiplier() *UsageLogUpdate {
_u.mutation.ClearAccountRateMultiplier()
return _u
}
// SetBillingType sets the "billing_type" field.
func (_u *UsageLogUpdate) SetBillingType(v int8) *UsageLogUpdate {
_u.mutation.ResetBillingType()
@@ -807,6 +834,15 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
_spec.AddField(usagelog.FieldRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AccountRateMultiplier(); ok {
_spec.SetField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedAccountRateMultiplier(); ok {
_spec.AddField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
}
if _u.mutation.AccountRateMultiplierCleared() {
_spec.ClearField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64)
}
if value, ok := _u.mutation.BillingType(); ok {
_spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value)
}
@@ -1406,6 +1442,33 @@ func (_u *UsageLogUpdateOne) AddRateMultiplier(v float64) *UsageLogUpdateOne {
return _u
}
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
func (_u *UsageLogUpdateOne) SetAccountRateMultiplier(v float64) *UsageLogUpdateOne {
_u.mutation.ResetAccountRateMultiplier()
_u.mutation.SetAccountRateMultiplier(v)
return _u
}
// SetNillableAccountRateMultiplier sets the "account_rate_multiplier" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableAccountRateMultiplier(v *float64) *UsageLogUpdateOne {
if v != nil {
_u.SetAccountRateMultiplier(*v)
}
return _u
}
// AddAccountRateMultiplier adds value to the "account_rate_multiplier" field.
func (_u *UsageLogUpdateOne) AddAccountRateMultiplier(v float64) *UsageLogUpdateOne {
_u.mutation.AddAccountRateMultiplier(v)
return _u
}
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
func (_u *UsageLogUpdateOne) ClearAccountRateMultiplier() *UsageLogUpdateOne {
_u.mutation.ClearAccountRateMultiplier()
return _u
}
// SetBillingType sets the "billing_type" field.
func (_u *UsageLogUpdateOne) SetBillingType(v int8) *UsageLogUpdateOne {
_u.mutation.ResetBillingType()
@@ -1828,6 +1891,15 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
_spec.AddField(usagelog.FieldRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AccountRateMultiplier(); ok {
_spec.SetField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedAccountRateMultiplier(); ok {
_spec.AddField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
}
if _u.mutation.AccountRateMultiplierCleared() {
_spec.ClearField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64)
}
if value, ok := _u.mutation.BillingType(); ok {
_spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value)
}

View File

@@ -84,6 +84,7 @@ type CreateAccountRequest struct {
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
RateMultiplier *float64 `json:"rate_multiplier"`
GroupIDs []int64 `json:"group_ids"`
ExpiresAt *int64 `json:"expires_at"`
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
@@ -101,6 +102,7 @@ type UpdateAccountRequest struct {
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
RateMultiplier *float64 `json:"rate_multiplier"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
GroupIDs *[]int64 `json:"group_ids"`
ExpiresAt *int64 `json:"expires_at"`
@@ -115,6 +117,7 @@ type BulkUpdateAccountsRequest struct {
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
RateMultiplier *float64 `json:"rate_multiplier"`
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
Schedulable *bool `json:"schedulable"`
GroupIDs *[]int64 `json:"group_ids"`
@@ -199,6 +202,10 @@ func (h *AccountHandler) Create(c *gin.Context) {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if req.RateMultiplier != nil && *req.RateMultiplier < 0 {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
@@ -213,6 +220,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
RateMultiplier: req.RateMultiplier,
GroupIDs: req.GroupIDs,
ExpiresAt: req.ExpiresAt,
AutoPauseOnExpired: req.AutoPauseOnExpired,
@@ -258,6 +266,10 @@ func (h *AccountHandler) Update(c *gin.Context) {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if req.RateMultiplier != nil && *req.RateMultiplier < 0 {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
@@ -271,6 +283,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
ProxyID: req.ProxyID,
Concurrency: req.Concurrency, // 指针类型nil 表示未提供
Priority: req.Priority, // 指针类型nil 表示未提供
RateMultiplier: req.RateMultiplier,
Status: req.Status,
GroupIDs: req.GroupIDs,
ExpiresAt: req.ExpiresAt,
@@ -652,6 +665,10 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if req.RateMultiplier != nil && *req.RateMultiplier < 0 {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
@@ -660,6 +677,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
req.ProxyID != nil ||
req.Concurrency != nil ||
req.Priority != nil ||
req.RateMultiplier != nil ||
req.Status != "" ||
req.Schedulable != nil ||
req.GroupIDs != nil ||
@@ -677,6 +695,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
RateMultiplier: req.RateMultiplier,
Status: req.Status,
Schedulable: req.Schedulable,
GroupIDs: req.GroupIDs,

View File

@@ -186,13 +186,16 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
// GetUsageTrend handles getting usage trend data
// GET /api/v1/admin/dashboard/trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
// Parse optional filter params
var userID, apiKeyID int64
var userID, apiKeyID, accountID, groupID int64
var model string
var stream *bool
if userIDStr := c.Query("user_id"); userIDStr != "" {
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
userID = id
@@ -203,8 +206,26 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
apiKeyID = id
}
}
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
if id, err := strconv.ParseInt(accountIDStr, 10, 64); err == nil {
accountID = id
}
}
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil {
groupID = id
}
}
if modelStr := c.Query("model"); modelStr != "" {
model = modelStr
}
if streamStr := c.Query("stream"); streamStr != "" {
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
stream = &streamVal
}
}
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID)
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream)
if err != nil {
response.Error(c, 500, "Failed to get usage trend")
return
@@ -220,12 +241,14 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
// GetModelStats handles getting model usage statistics
// GET /api/v1/admin/dashboard/models
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
// Parse optional filter params
var userID, apiKeyID int64
var userID, apiKeyID, accountID, groupID int64
var stream *bool
if userIDStr := c.Query("user_id"); userIDStr != "" {
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
userID = id
@@ -236,8 +259,23 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
apiKeyID = id
}
}
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
if id, err := strconv.ParseInt(accountIDStr, 10, 64); err == nil {
accountID = id
}
}
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil {
groupID = id
}
}
if streamStr := c.Query("stream"); streamStr != "" {
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
stream = &streamVal
}
}
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID)
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream)
if err != nil {
response.Error(c, 500, "Failed to get model statistics")
return

View File

@@ -7,8 +7,10 @@ import (
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
@@ -18,8 +20,6 @@ var validOpsAlertMetricTypes = []string{
"success_rate",
"error_rate",
"upstream_error_rate",
"p95_latency_ms",
"p99_latency_ms",
"cpu_usage_percent",
"memory_usage_percent",
"concurrency_queue_depth",
@@ -372,8 +372,135 @@ func (h *OpsHandler) DeleteAlertRule(c *gin.Context) {
response.Success(c, gin.H{"deleted": true})
}
// GetAlertEvent returns a single ops alert event.
// GET /api/v1/admin/ops/alert-events/:id
func (h *OpsHandler) GetAlertEvent(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid event ID")
return
}
ev, err := h.opsService.GetAlertEventByID(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, ev)
}
// UpdateAlertEventStatus updates an ops alert event status.
// PUT /api/v1/admin/ops/alert-events/:id/status
func (h *OpsHandler) UpdateAlertEventStatus(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid event ID")
return
}
var payload struct {
Status string `json:"status"`
}
if err := c.ShouldBindJSON(&payload); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
payload.Status = strings.TrimSpace(payload.Status)
if payload.Status == "" {
response.BadRequest(c, "Invalid status")
return
}
if payload.Status != service.OpsAlertStatusResolved && payload.Status != service.OpsAlertStatusManualResolved {
response.BadRequest(c, "Invalid status")
return
}
var resolvedAt *time.Time
if payload.Status == service.OpsAlertStatusResolved || payload.Status == service.OpsAlertStatusManualResolved {
now := time.Now().UTC()
resolvedAt = &now
}
if err := h.opsService.UpdateAlertEventStatus(c.Request.Context(), id, payload.Status, resolvedAt); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"updated": true})
}
// ListAlertEvents lists recent ops alert events.
// GET /api/v1/admin/ops/alert-events
// CreateAlertSilence creates a scoped silence for ops alerts.
// POST /api/v1/admin/ops/alert-silences
func (h *OpsHandler) CreateAlertSilence(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
var payload struct {
RuleID int64 `json:"rule_id"`
Platform string `json:"platform"`
GroupID *int64 `json:"group_id"`
Region *string `json:"region"`
Until string `json:"until"`
Reason string `json:"reason"`
}
if err := c.ShouldBindJSON(&payload); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
until, err := time.Parse(time.RFC3339, strings.TrimSpace(payload.Until))
if err != nil {
response.BadRequest(c, "Invalid until")
return
}
createdBy := (*int64)(nil)
if subject, ok := middleware.GetAuthSubjectFromContext(c); ok {
uid := subject.UserID
createdBy = &uid
}
silence := &service.OpsAlertSilence{
RuleID: payload.RuleID,
Platform: strings.TrimSpace(payload.Platform),
GroupID: payload.GroupID,
Region: payload.Region,
Until: until,
Reason: strings.TrimSpace(payload.Reason),
CreatedBy: createdBy,
}
created, err := h.opsService.CreateAlertSilence(c.Request.Context(), silence)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, created)
}
func (h *OpsHandler) ListAlertEvents(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
@@ -384,7 +511,7 @@ func (h *OpsHandler) ListAlertEvents(c *gin.Context) {
return
}
limit := 100
limit := 20
if raw := strings.TrimSpace(c.Query("limit")); raw != "" {
n, err := strconv.Atoi(raw)
if err != nil || n <= 0 {
@@ -400,6 +527,49 @@ func (h *OpsHandler) ListAlertEvents(c *gin.Context) {
Severity: strings.TrimSpace(c.Query("severity")),
}
if v := strings.TrimSpace(c.Query("email_sent")); v != "" {
vv := strings.ToLower(v)
switch vv {
case "true", "1":
b := true
filter.EmailSent = &b
case "false", "0":
b := false
filter.EmailSent = &b
default:
response.BadRequest(c, "Invalid email_sent")
return
}
}
// Cursor pagination: both params must be provided together.
rawTS := strings.TrimSpace(c.Query("before_fired_at"))
rawID := strings.TrimSpace(c.Query("before_id"))
if (rawTS == "") != (rawID == "") {
response.BadRequest(c, "before_fired_at and before_id must be provided together")
return
}
if rawTS != "" {
ts, err := time.Parse(time.RFC3339Nano, rawTS)
if err != nil {
if t2, err2 := time.Parse(time.RFC3339, rawTS); err2 == nil {
ts = t2
} else {
response.BadRequest(c, "Invalid before_fired_at")
return
}
}
filter.BeforeFiredAt = &ts
}
if rawID != "" {
id, err := strconv.ParseInt(rawID, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid before_id")
return
}
filter.BeforeID = &id
}
// Optional global filter support (platform/group/time range).
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
filter.Platform = platform

View File

@@ -19,6 +19,57 @@ type OpsHandler struct {
opsService *service.OpsService
}
// GetErrorLogByID returns ops error log detail.
// GET /api/v1/admin/ops/errors/:id
func (h *OpsHandler) GetErrorLogByID(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
detail, err := h.opsService.GetErrorLogByID(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, detail)
}
const (
opsListViewErrors = "errors"
opsListViewExcluded = "excluded"
opsListViewAll = "all"
)
func parseOpsViewParam(c *gin.Context) string {
if c == nil {
return ""
}
v := strings.ToLower(strings.TrimSpace(c.Query("view")))
switch v {
case "", opsListViewErrors:
return opsListViewErrors
case opsListViewExcluded:
return opsListViewExcluded
case opsListViewAll:
return opsListViewAll
default:
return opsListViewErrors
}
}
func NewOpsHandler(opsService *service.OpsService) *OpsHandler {
return &OpsHandler{opsService: opsService}
}
@@ -47,16 +98,26 @@ func (h *OpsHandler) GetErrorLogs(c *gin.Context) {
return
}
filter := &service.OpsErrorLogFilter{
Page: page,
PageSize: pageSize,
}
filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
if !startTime.IsZero() {
filter.StartTime = &startTime
}
if !endTime.IsZero() {
filter.EndTime = &endTime
}
filter.View = parseOpsViewParam(c)
filter.Phase = strings.TrimSpace(c.Query("phase"))
filter.Owner = strings.TrimSpace(c.Query("error_owner"))
filter.Source = strings.TrimSpace(c.Query("error_source"))
filter.Query = strings.TrimSpace(c.Query("q"))
filter.UserQuery = strings.TrimSpace(c.Query("user_query"))
// Force request errors: client-visible status >= 400.
// buildOpsErrorLogsWhere already applies this for non-upstream phase.
if strings.EqualFold(strings.TrimSpace(filter.Phase), "upstream") {
filter.Phase = ""
}
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
filter.Platform = platform
@@ -77,11 +138,19 @@ func (h *OpsHandler) GetErrorLogs(c *gin.Context) {
}
filter.AccountID = &id
}
if phase := strings.TrimSpace(c.Query("phase")); phase != "" {
filter.Phase = phase
}
if q := strings.TrimSpace(c.Query("q")); q != "" {
filter.Query = q
if v := strings.TrimSpace(c.Query("resolved")); v != "" {
switch strings.ToLower(v) {
case "1", "true", "yes":
b := true
filter.Resolved = &b
case "0", "false", "no":
b := false
filter.Resolved = &b
default:
response.BadRequest(c, "Invalid resolved")
return
}
}
if statusCodesStr := strings.TrimSpace(c.Query("status_codes")); statusCodesStr != "" {
parts := strings.Split(statusCodesStr, ",")
@@ -106,13 +175,120 @@ func (h *OpsHandler) GetErrorLogs(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
}
// GetErrorLogByID returns a single error log detail.
// GET /api/v1/admin/ops/errors/:id
func (h *OpsHandler) GetErrorLogByID(c *gin.Context) {
// ListRequestErrors lists client-visible request errors.
// GET /api/v1/admin/ops/request-errors
func (h *OpsHandler) ListRequestErrors(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
page, pageSize := response.ParsePagination(c)
if pageSize > 500 {
pageSize = 500
}
startTime, endTime, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
if !startTime.IsZero() {
filter.StartTime = &startTime
}
if !endTime.IsZero() {
filter.EndTime = &endTime
}
filter.View = parseOpsViewParam(c)
filter.Phase = strings.TrimSpace(c.Query("phase"))
filter.Owner = strings.TrimSpace(c.Query("error_owner"))
filter.Source = strings.TrimSpace(c.Query("error_source"))
filter.Query = strings.TrimSpace(c.Query("q"))
filter.UserQuery = strings.TrimSpace(c.Query("user_query"))
// Force request errors: client-visible status >= 400.
// buildOpsErrorLogsWhere already applies this for non-upstream phase.
if strings.EqualFold(strings.TrimSpace(filter.Phase), "upstream") {
filter.Phase = ""
}
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
filter.Platform = platform
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &id
}
if v := strings.TrimSpace(c.Query("account_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid account_id")
return
}
filter.AccountID = &id
}
if v := strings.TrimSpace(c.Query("resolved")); v != "" {
switch strings.ToLower(v) {
case "1", "true", "yes":
b := true
filter.Resolved = &b
case "0", "false", "no":
b := false
filter.Resolved = &b
default:
response.BadRequest(c, "Invalid resolved")
return
}
}
if statusCodesStr := strings.TrimSpace(c.Query("status_codes")); statusCodesStr != "" {
parts := strings.Split(statusCodesStr, ",")
out := make([]int, 0, len(parts))
for _, part := range parts {
p := strings.TrimSpace(part)
if p == "" {
continue
}
n, err := strconv.Atoi(p)
if err != nil || n < 0 {
response.BadRequest(c, "Invalid status_codes")
return
}
out = append(out, n)
}
filter.StatusCodes = out
}
result, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
}
// GetRequestError returns request error detail.
// GET /api/v1/admin/ops/request-errors/:id
func (h *OpsHandler) GetRequestError(c *gin.Context) {
// same storage; just proxy to existing detail
h.GetErrorLogByID(c)
}
// ListRequestErrorUpstreamErrors lists upstream error logs correlated to a request error.
// GET /api/v1/admin/ops/request-errors/:id/upstream-errors
func (h *OpsHandler) ListRequestErrorUpstreamErrors(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
@@ -129,15 +305,306 @@ func (h *OpsHandler) GetErrorLogByID(c *gin.Context) {
return
}
// Load request error to get correlation keys.
detail, err := h.opsService.GetErrorLogByID(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, detail)
// Correlate by request_id/client_request_id.
requestID := strings.TrimSpace(detail.RequestID)
clientRequestID := strings.TrimSpace(detail.ClientRequestID)
if requestID == "" && clientRequestID == "" {
response.Paginated(c, []*service.OpsErrorLog{}, 0, 1, 10)
return
}
page, pageSize := response.ParsePagination(c)
if pageSize > 500 {
pageSize = 500
}
// Keep correlation window wide enough so linked upstream errors
// are discoverable even when UI defaults to 1h elsewhere.
startTime, endTime, err := parseOpsTimeRange(c, "30d")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
if !startTime.IsZero() {
filter.StartTime = &startTime
}
if !endTime.IsZero() {
filter.EndTime = &endTime
}
filter.View = "all"
filter.Phase = "upstream"
filter.Owner = "provider"
filter.Source = strings.TrimSpace(c.Query("error_source"))
filter.Query = strings.TrimSpace(c.Query("q"))
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
filter.Platform = platform
}
// Prefer exact match on request_id; if missing, fall back to client_request_id.
if requestID != "" {
filter.RequestID = requestID
} else {
filter.ClientRequestID = clientRequestID
}
result, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
// If client asks for details, expand each upstream error log to include upstream response fields.
includeDetail := strings.TrimSpace(c.Query("include_detail"))
if includeDetail == "1" || strings.EqualFold(includeDetail, "true") || strings.EqualFold(includeDetail, "yes") {
details := make([]*service.OpsErrorLogDetail, 0, len(result.Errors))
for _, item := range result.Errors {
if item == nil {
continue
}
d, err := h.opsService.GetErrorLogByID(c.Request.Context(), item.ID)
if err != nil || d == nil {
continue
}
details = append(details, d)
}
response.Paginated(c, details, int64(result.Total), result.Page, result.PageSize)
return
}
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
}
// RetryRequestErrorClient retries the client request based on stored request body.
// POST /api/v1/admin/ops/request-errors/:id/retry-client
func (h *OpsHandler) RetryRequestErrorClient(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, service.OpsRetryModeClient, nil)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// RetryRequestErrorUpstreamEvent retries a specific upstream attempt using captured upstream_request_body.
// POST /api/v1/admin/ops/request-errors/:id/upstream-errors/:idx/retry
func (h *OpsHandler) RetryRequestErrorUpstreamEvent(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
idxStr := strings.TrimSpace(c.Param("idx"))
idx, err := strconv.Atoi(idxStr)
if err != nil || idx < 0 {
response.BadRequest(c, "Invalid upstream idx")
return
}
result, err := h.opsService.RetryUpstreamEvent(c.Request.Context(), subject.UserID, id, idx)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// ResolveRequestError toggles resolved status.
// PUT /api/v1/admin/ops/request-errors/:id/resolve
func (h *OpsHandler) ResolveRequestError(c *gin.Context) {
h.UpdateErrorResolution(c)
}
// ListUpstreamErrors lists independent upstream errors.
// GET /api/v1/admin/ops/upstream-errors
func (h *OpsHandler) ListUpstreamErrors(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
page, pageSize := response.ParsePagination(c)
if pageSize > 500 {
pageSize = 500
}
startTime, endTime, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
if !startTime.IsZero() {
filter.StartTime = &startTime
}
if !endTime.IsZero() {
filter.EndTime = &endTime
}
filter.View = parseOpsViewParam(c)
filter.Phase = "upstream"
filter.Owner = "provider"
filter.Source = strings.TrimSpace(c.Query("error_source"))
filter.Query = strings.TrimSpace(c.Query("q"))
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
filter.Platform = platform
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &id
}
if v := strings.TrimSpace(c.Query("account_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid account_id")
return
}
filter.AccountID = &id
}
if v := strings.TrimSpace(c.Query("resolved")); v != "" {
switch strings.ToLower(v) {
case "1", "true", "yes":
b := true
filter.Resolved = &b
case "0", "false", "no":
b := false
filter.Resolved = &b
default:
response.BadRequest(c, "Invalid resolved")
return
}
}
if statusCodesStr := strings.TrimSpace(c.Query("status_codes")); statusCodesStr != "" {
parts := strings.Split(statusCodesStr, ",")
out := make([]int, 0, len(parts))
for _, part := range parts {
p := strings.TrimSpace(part)
if p == "" {
continue
}
n, err := strconv.Atoi(p)
if err != nil || n < 0 {
response.BadRequest(c, "Invalid status_codes")
return
}
out = append(out, n)
}
filter.StatusCodes = out
}
result, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
}
// GetUpstreamError returns upstream error detail.
// GET /api/v1/admin/ops/upstream-errors/:id
func (h *OpsHandler) GetUpstreamError(c *gin.Context) {
h.GetErrorLogByID(c)
}
// RetryUpstreamError retries upstream error using the original account_id.
// POST /api/v1/admin/ops/upstream-errors/:id/retry
func (h *OpsHandler) RetryUpstreamError(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, service.OpsRetryModeUpstream, nil)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// ResolveUpstreamError toggles resolved status.
// PUT /api/v1/admin/ops/upstream-errors/:id/resolve
func (h *OpsHandler) ResolveUpstreamError(c *gin.Context) {
h.UpdateErrorResolution(c)
}
// ==================== Existing endpoints ====================
// ListRequestDetails returns a request-level list (success + error) for drill-down.
// GET /api/v1/admin/ops/requests
func (h *OpsHandler) ListRequestDetails(c *gin.Context) {
@@ -242,6 +709,11 @@ func (h *OpsHandler) ListRequestDetails(c *gin.Context) {
type opsRetryRequest struct {
Mode string `json:"mode"`
PinnedAccountID *int64 `json:"pinned_account_id"`
Force bool `json:"force"`
}
type opsResolveRequest struct {
Resolved bool `json:"resolved"`
}
// RetryErrorRequest retries a failed request using stored request_body.
@@ -278,6 +750,16 @@ func (h *OpsHandler) RetryErrorRequest(c *gin.Context) {
req.Mode = service.OpsRetryModeClient
}
// Force flag is currently a UI-level acknowledgement. Server may still enforce safety constraints.
_ = req.Force
// Legacy endpoint safety: only allow retrying the client request here.
// Upstream retries must go through the split endpoints.
if strings.EqualFold(strings.TrimSpace(req.Mode), service.OpsRetryModeUpstream) {
response.BadRequest(c, "upstream retry is not supported on this endpoint")
return
}
result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, req.Mode, req.PinnedAccountID)
if err != nil {
response.ErrorFrom(c, err)
@@ -287,6 +769,81 @@ func (h *OpsHandler) RetryErrorRequest(c *gin.Context) {
response.Success(c, result)
}
// ListRetryAttempts lists retry attempts for an error log.
// GET /api/v1/admin/ops/errors/:id/retries
func (h *OpsHandler) ListRetryAttempts(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
limit := 50
if v := strings.TrimSpace(c.Query("limit")); v != "" {
n, err := strconv.Atoi(v)
if err != nil || n <= 0 {
response.BadRequest(c, "Invalid limit")
return
}
limit = n
}
items, err := h.opsService.ListRetryAttemptsByErrorID(c.Request.Context(), id, limit)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, items)
}
// UpdateErrorResolution allows manual resolve/unresolve.
// PUT /api/v1/admin/ops/errors/:id/resolve
func (h *OpsHandler) UpdateErrorResolution(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
var req opsResolveRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
uid := subject.UserID
if err := h.opsService.UpdateErrorResolution(c.Request.Context(), id, req.Resolved, &uid, nil); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"ok": true})
}
func parseOpsTimeRange(c *gin.Context, defaultRange string) (time.Time, time.Time, error) {
startStr := strings.TrimSpace(c.Query("start_time"))
endStr := strings.TrimSpace(c.Query("end_time"))
@@ -358,6 +915,10 @@ func parseOpsDuration(v string) (time.Duration, bool) {
return 6 * time.Hour, true
case "24h":
return 24 * time.Hour, true
case "7d":
return 7 * 24 * time.Hour, true
case "30d":
return 30 * 24 * time.Hour, true
default:
return 0, false
}

View File

@@ -196,6 +196,28 @@ func (h *ProxyHandler) Delete(c *gin.Context) {
response.Success(c, gin.H{"message": "Proxy deleted successfully"})
}
// BatchDelete handles batch deleting proxies
// POST /api/v1/admin/proxies/batch-delete
func (h *ProxyHandler) BatchDelete(c *gin.Context) {
type BatchDeleteRequest struct {
IDs []int64 `json:"ids" binding:"required,min=1"`
}
var req BatchDeleteRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
result, err := h.adminService.BatchDeleteProxies(c.Request.Context(), req.IDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// Test handles testing proxy connectivity
// POST /api/v1/admin/proxies/:id/test
func (h *ProxyHandler) Test(c *gin.Context) {
@@ -243,19 +265,17 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
return
}
page, pageSize := response.ParsePagination(c)
accounts, total, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID, page, pageSize)
accounts, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.Account, 0, len(accounts))
out := make([]dto.ProxyAccountSummary, 0, len(accounts))
for i := range accounts {
out = append(out, *dto.AccountFromService(&accounts[i]))
out = append(out, *dto.ProxyAccountSummaryFromService(&accounts[i]))
}
response.Paginated(c, out, total, page, pageSize)
response.Success(c, out)
}
// BatchCreateProxyItem represents a single proxy in batch create request

View File

@@ -125,6 +125,7 @@ func AccountFromServiceShallow(a *service.Account) *Account {
ProxyID: a.ProxyID,
Concurrency: a.Concurrency,
Priority: a.Priority,
RateMultiplier: a.BillingRateMultiplier(),
Status: a.Status,
ErrorMessage: a.ErrorMessage,
LastUsedAt: a.LastUsedAt,
@@ -212,8 +213,29 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi
return nil
}
return &ProxyWithAccountCount{
Proxy: *ProxyFromService(&p.Proxy),
AccountCount: p.AccountCount,
Proxy: *ProxyFromService(&p.Proxy),
AccountCount: p.AccountCount,
LatencyMs: p.LatencyMs,
LatencyStatus: p.LatencyStatus,
LatencyMessage: p.LatencyMessage,
IPAddress: p.IPAddress,
Country: p.Country,
CountryCode: p.CountryCode,
Region: p.Region,
City: p.City,
}
}
func ProxyAccountSummaryFromService(a *service.ProxyAccountSummary) *ProxyAccountSummary {
if a == nil {
return nil
}
return &ProxyAccountSummary{
ID: a.ID,
Name: a.Name,
Platform: a.Platform,
Type: a.Type,
Notes: a.Notes,
}
}
@@ -279,6 +301,7 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, inclu
TotalCost: l.TotalCost,
ActualCost: l.ActualCost,
RateMultiplier: l.RateMultiplier,
AccountRateMultiplier: l.AccountRateMultiplier,
BillingType: l.BillingType,
Stream: l.Stream,
DurationMs: l.DurationMs,

View File

@@ -76,6 +76,7 @@ type Account struct {
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
RateMultiplier float64 `json:"rate_multiplier"`
Status string `json:"status"`
ErrorMessage string `json:"error_message"`
LastUsedAt *time.Time `json:"last_used_at"`
@@ -129,7 +130,23 @@ type Proxy struct {
type ProxyWithAccountCount struct {
Proxy
AccountCount int64 `json:"account_count"`
AccountCount int64 `json:"account_count"`
LatencyMs *int64 `json:"latency_ms,omitempty"`
LatencyStatus string `json:"latency_status,omitempty"`
LatencyMessage string `json:"latency_message,omitempty"`
IPAddress string `json:"ip_address,omitempty"`
Country string `json:"country,omitempty"`
CountryCode string `json:"country_code,omitempty"`
Region string `json:"region,omitempty"`
City string `json:"city,omitempty"`
}
type ProxyAccountSummary struct {
ID int64 `json:"id"`
Name string `json:"name"`
Platform string `json:"platform"`
Type string `json:"type"`
Notes *string `json:"notes,omitempty"`
}
type RedeemCode struct {
@@ -169,13 +186,14 @@ type UsageLog struct {
CacheCreation5mTokens int `json:"cache_creation_5m_tokens"`
CacheCreation1hTokens int `json:"cache_creation_1h_tokens"`
InputCost float64 `json:"input_cost"`
OutputCost float64 `json:"output_cost"`
CacheCreationCost float64 `json:"cache_creation_cost"`
CacheReadCost float64 `json:"cache_read_cost"`
TotalCost float64 `json:"total_cost"`
ActualCost float64 `json:"actual_cost"`
RateMultiplier float64 `json:"rate_multiplier"`
InputCost float64 `json:"input_cost"`
OutputCost float64 `json:"output_cost"`
CacheCreationCost float64 `json:"cache_creation_cost"`
CacheReadCost float64 `json:"cache_read_cost"`
TotalCost float64 `json:"total_cost"`
ActualCost float64 `json:"actual_cost"`
RateMultiplier float64 `json:"rate_multiplier"`
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
BillingType int8 `json:"billing_type"`
Stream bool `json:"stream"`

View File

@@ -544,6 +544,11 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
body := w.buf.Bytes()
parsed := parseOpsErrorResponse(body)
// Skip logging if the error should be filtered based on settings
if shouldSkipOpsErrorLog(c.Request.Context(), ops, parsed.Message, string(body), c.Request.URL.Path) {
return
}
apiKey, _ := middleware2.GetAPIKeyFromContext(c)
clientRequestID, _ := c.Request.Context().Value(ctxkey.ClientRequestID).(string)
@@ -832,28 +837,30 @@ func normalizeOpsErrorType(errType string, code string) string {
func classifyOpsPhase(errType, message, code string) string {
msg := strings.ToLower(message)
// Standardized phases: request|auth|routing|upstream|network|internal
// Map billing/concurrency/response => request; scheduling => routing.
switch strings.TrimSpace(code) {
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
return "billing"
return "request"
}
switch errType {
case "authentication_error":
return "auth"
case "billing_error", "subscription_error":
return "billing"
return "request"
case "rate_limit_error":
if strings.Contains(msg, "concurrency") || strings.Contains(msg, "pending") || strings.Contains(msg, "queue") {
return "concurrency"
return "request"
}
return "upstream"
case "invalid_request_error":
return "response"
return "request"
case "upstream_error", "overloaded_error":
return "upstream"
case "api_error":
if strings.Contains(msg, "no available accounts") {
return "scheduling"
return "routing"
}
return "internal"
default:
@@ -914,34 +921,38 @@ func classifyOpsIsBusinessLimited(errType, phase, code string, status int, messa
}
func classifyOpsErrorOwner(phase string, message string) string {
// Standardized owners: client|provider|platform
switch phase {
case "upstream", "network":
return "provider"
case "billing", "concurrency", "auth", "response":
case "request", "auth":
return "client"
case "routing", "internal":
return "platform"
default:
if strings.Contains(strings.ToLower(message), "upstream") {
return "provider"
}
return "sub2api"
return "platform"
}
}
func classifyOpsErrorSource(phase string, message string) string {
// Standardized sources: client_request|upstream_http|gateway
switch phase {
case "upstream":
return "upstream_http"
case "network":
return "upstream_network"
case "billing":
return "billing"
case "concurrency":
return "concurrency"
return "gateway"
case "request", "auth":
return "client_request"
case "routing", "internal":
return "gateway"
default:
if strings.Contains(strings.ToLower(message), "upstream") {
return "upstream_http"
}
return "internal"
return "gateway"
}
}
@@ -963,3 +974,42 @@ func truncateString(s string, max int) string {
func strconvItoa(v int) string {
return strconv.Itoa(v)
}
// shouldSkipOpsErrorLog determines if an error should be skipped from logging based on settings.
// Returns true for errors that should be filtered according to OpsAdvancedSettings.
func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message, body, requestPath string) bool {
if ops == nil {
return false
}
// Get advanced settings to check filter configuration
settings, err := ops.GetOpsAdvancedSettings(ctx)
if err != nil || settings == nil {
// If we can't get settings, don't skip (fail open)
return false
}
msgLower := strings.ToLower(message)
bodyLower := strings.ToLower(body)
// Check if count_tokens errors should be ignored
if settings.IgnoreCountTokensErrors && strings.Contains(requestPath, "/count_tokens") {
return true
}
// Check if context canceled errors should be ignored (client disconnects)
if settings.IgnoreContextCanceled {
if strings.Contains(msgLower, "context canceled") || strings.Contains(bodyLower, "context canceled") {
return true
}
}
// Check if "no available accounts" errors should be ignored
if settings.IgnoreNoAvailableAccounts {
if strings.Contains(msgLower, "no available accounts") || strings.Contains(bodyLower, "no available accounts") {
return true
}
}
return false
}

View File

@@ -1,8 +1,14 @@
package usagestats
// AccountStats 账号使用统计
//
// cost: 账号口径费用(使用 total_cost * account_rate_multiplier
// standard_cost: 标准费用(使用 total_cost不含倍率
// user_cost: 用户/API Key 口径费用(使用 actual_cost受分组倍率影响
type AccountStats struct {
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
StandardCost float64 `json:"standard_cost"`
UserCost float64 `json:"user_cost"`
}

View File

@@ -147,14 +147,15 @@ type UsageLogFilters struct {
// UsageStats represents usage statistics
type UsageStats struct {
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheTokens int64 `json:"total_cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
AverageDurationMs float64 `json:"average_duration_ms"`
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheTokens int64 `json:"total_cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
TotalAccountCost *float64 `json:"total_account_cost,omitempty"`
AverageDurationMs float64 `json:"average_duration_ms"`
}
// BatchUserUsageStats represents usage stats for a single user
@@ -177,25 +178,29 @@ type AccountUsageHistory struct {
Label string `json:"label"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
ActualCost float64 `json:"actual_cost"`
Cost float64 `json:"cost"` // 标准计费total_cost
ActualCost float64 `json:"actual_cost"` // 账号口径费用total_cost * account_rate_multiplier
UserCost float64 `json:"user_cost"` // 用户口径费用actual_cost受分组倍率影响
}
// AccountUsageSummary represents summary statistics for an account
type AccountUsageSummary struct {
Days int `json:"days"`
ActualDaysUsed int `json:"actual_days_used"`
TotalCost float64 `json:"total_cost"`
TotalCost float64 `json:"total_cost"` // 账号口径费用
TotalUserCost float64 `json:"total_user_cost"` // 用户口径费用
TotalStandardCost float64 `json:"total_standard_cost"`
TotalRequests int64 `json:"total_requests"`
TotalTokens int64 `json:"total_tokens"`
AvgDailyCost float64 `json:"avg_daily_cost"`
AvgDailyCost float64 `json:"avg_daily_cost"` // 账号口径日均
AvgDailyUserCost float64 `json:"avg_daily_user_cost"`
AvgDailyRequests float64 `json:"avg_daily_requests"`
AvgDailyTokens float64 `json:"avg_daily_tokens"`
AvgDurationMs float64 `json:"avg_duration_ms"`
Today *struct {
Date string `json:"date"`
Cost float64 `json:"cost"`
UserCost float64 `json:"user_cost"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
} `json:"today"`
@@ -203,6 +208,7 @@ type AccountUsageSummary struct {
Date string `json:"date"`
Label string `json:"label"`
Cost float64 `json:"cost"`
UserCost float64 `json:"user_cost"`
Requests int64 `json:"requests"`
} `json:"highest_cost_day"`
HighestRequestDay *struct {
@@ -210,6 +216,7 @@ type AccountUsageSummary struct {
Label string `json:"label"`
Requests int64 `json:"requests"`
Cost float64 `json:"cost"`
UserCost float64 `json:"user_cost"`
} `json:"highest_request_day"`
}

View File

@@ -80,6 +80,10 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
SetSchedulable(account.Schedulable).
SetAutoPauseOnExpired(account.AutoPauseOnExpired)
if account.RateMultiplier != nil {
builder.SetRateMultiplier(*account.RateMultiplier)
}
if account.ProxyID != nil {
builder.SetProxyID(*account.ProxyID)
}
@@ -291,6 +295,10 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
SetSchedulable(account.Schedulable).
SetAutoPauseOnExpired(account.AutoPauseOnExpired)
if account.RateMultiplier != nil {
builder.SetRateMultiplier(*account.RateMultiplier)
}
if account.ProxyID != nil {
builder.SetProxyID(*account.ProxyID)
} else {
@@ -999,6 +1007,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
args = append(args, *updates.Priority)
idx++
}
if updates.RateMultiplier != nil {
setClauses = append(setClauses, "rate_multiplier = $"+itoa(idx))
args = append(args, *updates.RateMultiplier)
idx++
}
if updates.Status != nil {
setClauses = append(setClauses, "status = $"+itoa(idx))
args = append(args, *updates.Status)
@@ -1347,6 +1360,8 @@ func accountEntityToService(m *dbent.Account) *service.Account {
return nil
}
rateMultiplier := m.RateMultiplier
return &service.Account{
ID: m.ID,
Name: m.Name,
@@ -1358,6 +1373,7 @@ func accountEntityToService(m *dbent.Account) *service.Account {
ProxyID: m.ProxyID,
Concurrency: m.Concurrency,
Priority: m.Priority,
RateMultiplier: &rateMultiplier,
Status: m.Status,
ErrorMessage: derefString(m.ErrorMessage),
LastUsedAt: m.LastUsedAt,

View File

@@ -8,6 +8,7 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
@@ -41,21 +42,22 @@ func isPostgresDriver(db *sql.DB) bool {
}
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
startUTC := start.UTC()
endUTC := end.UTC()
if !endUTC.After(startUTC) {
loc := timezone.Location()
startLocal := start.In(loc)
endLocal := end.In(loc)
if !endLocal.After(startLocal) {
return nil
}
hourStart := startUTC.Truncate(time.Hour)
hourEnd := endUTC.Truncate(time.Hour)
if endUTC.After(hourEnd) {
hourStart := startLocal.Truncate(time.Hour)
hourEnd := endLocal.Truncate(time.Hour)
if endLocal.After(hourEnd) {
hourEnd = hourEnd.Add(time.Hour)
}
dayStart := truncateToDayUTC(startUTC)
dayEnd := truncateToDayUTC(endUTC)
if endUTC.After(dayEnd) {
dayStart := truncateToDay(startLocal)
dayEnd := truncateToDay(endLocal)
if endLocal.After(dayEnd) {
dayEnd = dayEnd.Add(24 * time.Hour)
}
@@ -146,38 +148,41 @@ func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.C
}
func (r *dashboardAggregationRepository) insertHourlyActiveUsers(ctx context.Context, start, end time.Time) error {
tzName := timezone.Name()
query := `
INSERT INTO usage_dashboard_hourly_users (bucket_start, user_id)
SELECT DISTINCT
date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
date_trunc('hour', created_at AT TIME ZONE $3) AT TIME ZONE $3 AS bucket_start,
user_id
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
ON CONFLICT DO NOTHING
`
_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC())
_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
return err
}
func (r *dashboardAggregationRepository) insertDailyActiveUsers(ctx context.Context, start, end time.Time) error {
tzName := timezone.Name()
query := `
INSERT INTO usage_dashboard_daily_users (bucket_date, user_id)
SELECT DISTINCT
(bucket_start AT TIME ZONE 'UTC')::date AS bucket_date,
(bucket_start AT TIME ZONE $3)::date AS bucket_date,
user_id
FROM usage_dashboard_hourly_users
WHERE bucket_start >= $1 AND bucket_start < $2
ON CONFLICT DO NOTHING
`
_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC())
_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
return err
}
func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Context, start, end time.Time) error {
tzName := timezone.Name()
query := `
WITH hourly AS (
SELECT
date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
date_trunc('hour', created_at AT TIME ZONE $3) AT TIME ZONE $3 AS bucket_start,
COUNT(*) AS total_requests,
COALESCE(SUM(input_tokens), 0) AS input_tokens,
COALESCE(SUM(output_tokens), 0) AS output_tokens,
@@ -236,15 +241,16 @@ func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Cont
active_users = EXCLUDED.active_users,
computed_at = EXCLUDED.computed_at
`
_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC())
_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
return err
}
func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Context, start, end time.Time) error {
tzName := timezone.Name()
query := `
WITH daily AS (
SELECT
(bucket_start AT TIME ZONE 'UTC')::date AS bucket_date,
(bucket_start AT TIME ZONE $5)::date AS bucket_date,
COALESCE(SUM(total_requests), 0) AS total_requests,
COALESCE(SUM(input_tokens), 0) AS input_tokens,
COALESCE(SUM(output_tokens), 0) AS output_tokens,
@@ -255,7 +261,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte
COALESCE(SUM(total_duration_ms), 0) AS total_duration_ms
FROM usage_dashboard_hourly
WHERE bucket_start >= $1 AND bucket_start < $2
GROUP BY (bucket_start AT TIME ZONE 'UTC')::date
GROUP BY (bucket_start AT TIME ZONE $5)::date
),
user_counts AS (
SELECT bucket_date, COUNT(*) AS active_users
@@ -303,7 +309,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte
active_users = EXCLUDED.active_users,
computed_at = EXCLUDED.computed_at
`
_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC(), start.UTC(), end.UTC())
_, err := r.sql.ExecContext(ctx, query, start, end, start, end, tzName)
return err
}
@@ -376,9 +382,8 @@ func (r *dashboardAggregationRepository) createUsageLogsPartition(ctx context.Co
return err
}
func truncateToDayUTC(t time.Time) time.Time {
t = t.UTC()
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
func truncateToDay(t time.Time) time.Time {
return timezone.StartOfDay(t)
}
func truncateToMonthUTC(t time.Time) time.Time {

View File

@@ -33,6 +33,11 @@ func (c *geminiTokenCache) SetAccessToken(ctx context.Context, cacheKey string,
return c.rdb.Set(ctx, key, token, ttl).Err()
}
func (c *geminiTokenCache) DeleteAccessToken(ctx context.Context, cacheKey string) error {
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
return c.rdb.Del(ctx, key).Err()
}
func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
return c.rdb.SetNX(ctx, key, 1, ttl).Result()

View File

@@ -0,0 +1,47 @@
//go:build integration
package repository
import (
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type GeminiTokenCacheSuite struct {
IntegrationRedisSuite
cache service.GeminiTokenCache
}
func (s *GeminiTokenCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewGeminiTokenCache(s.rdb)
}
func (s *GeminiTokenCacheSuite) TestDeleteAccessToken() {
cacheKey := "project-123"
token := "token-value"
require.NoError(s.T(), s.cache.SetAccessToken(s.ctx, cacheKey, token, time.Minute))
got, err := s.cache.GetAccessToken(s.ctx, cacheKey)
require.NoError(s.T(), err)
require.Equal(s.T(), token, got)
require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, cacheKey))
_, err = s.cache.GetAccessToken(s.ctx, cacheKey)
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
}
func (s *GeminiTokenCacheSuite) TestDeleteAccessToken_MissingKey() {
require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, "missing-key"))
}
func TestGeminiTokenCacheSuite(t *testing.T) {
suite.Run(t, new(GeminiTokenCacheSuite))
}

View File

@@ -0,0 +1,28 @@
//go:build unit
package repository
import (
"context"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
func TestGeminiTokenCache_DeleteAccessToken_RedisError(t *testing.T) {
rdb := redis.NewClient(&redis.Options{
Addr: "127.0.0.1:1",
DialTimeout: 50 * time.Millisecond,
ReadTimeout: 50 * time.Millisecond,
WriteTimeout: 50 * time.Millisecond,
})
t.Cleanup(func() {
_ = rdb.Close()
})
cache := NewGeminiTokenCache(rdb)
err := cache.DeleteAccessToken(context.Background(), "broken")
require.Error(t, err)
}

View File

@@ -55,7 +55,6 @@ INSERT INTO ops_error_logs (
upstream_error_message,
upstream_error_detail,
upstream_errors,
duration_ms,
time_to_first_token_ms,
request_body,
request_body_truncated,
@@ -65,7 +64,7 @@ INSERT INTO ops_error_logs (
retry_count,
created_at
) VALUES (
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34
) RETURNING id`
var id int64
@@ -98,7 +97,6 @@ INSERT INTO ops_error_logs (
opsNullString(input.UpstreamErrorMessage),
opsNullString(input.UpstreamErrorDetail),
opsNullString(input.UpstreamErrorsJSON),
opsNullInt(input.DurationMs),
opsNullInt64(input.TimeToFirstTokenMs),
opsNullString(input.RequestBodyJSON),
input.RequestBodyTruncated,
@@ -135,7 +133,7 @@ func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsEr
}
where, args := buildOpsErrorLogsWhere(filter)
countSQL := "SELECT COUNT(*) FROM ops_error_logs " + where
countSQL := "SELECT COUNT(*) FROM ops_error_logs e " + where
var total int
if err := r.db.QueryRowContext(ctx, countSQL, args...).Scan(&total); err != nil {
@@ -146,28 +144,43 @@ func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsEr
argsWithLimit := append(args, pageSize, offset)
selectSQL := `
SELECT
id,
created_at,
error_phase,
error_type,
severity,
COALESCE(upstream_status_code, status_code, 0),
COALESCE(platform, ''),
COALESCE(model, ''),
duration_ms,
COALESCE(client_request_id, ''),
COALESCE(request_id, ''),
COALESCE(error_message, ''),
user_id,
api_key_id,
account_id,
group_id,
CASE WHEN client_ip IS NULL THEN NULL ELSE client_ip::text END,
COALESCE(request_path, ''),
stream
FROM ops_error_logs
e.id,
e.created_at,
e.error_phase,
e.error_type,
COALESCE(e.error_owner, ''),
COALESCE(e.error_source, ''),
e.severity,
COALESCE(e.upstream_status_code, e.status_code, 0),
COALESCE(e.platform, ''),
COALESCE(e.model, ''),
COALESCE(e.is_retryable, false),
COALESCE(e.retry_count, 0),
COALESCE(e.resolved, false),
e.resolved_at,
e.resolved_by_user_id,
COALESCE(u2.email, ''),
e.resolved_retry_id,
COALESCE(e.client_request_id, ''),
COALESCE(e.request_id, ''),
COALESCE(e.error_message, ''),
e.user_id,
COALESCE(u.email, ''),
e.api_key_id,
e.account_id,
COALESCE(a.name, ''),
e.group_id,
COALESCE(g.name, ''),
CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END,
COALESCE(e.request_path, ''),
e.stream
FROM ops_error_logs e
LEFT JOIN accounts a ON e.account_id = a.id
LEFT JOIN groups g ON e.group_id = g.id
LEFT JOIN users u ON e.user_id = u.id
LEFT JOIN users u2 ON e.resolved_by_user_id = u2.id
` + where + `
ORDER BY created_at DESC
ORDER BY e.created_at DESC
LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
rows, err := r.db.QueryContext(ctx, selectSQL, argsWithLimit...)
@@ -179,39 +192,65 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
out := make([]*service.OpsErrorLog, 0, pageSize)
for rows.Next() {
var item service.OpsErrorLog
var latency sql.NullInt64
var statusCode sql.NullInt64
var clientIP sql.NullString
var userID sql.NullInt64
var apiKeyID sql.NullInt64
var accountID sql.NullInt64
var accountName string
var groupID sql.NullInt64
var groupName string
var userEmail string
var resolvedAt sql.NullTime
var resolvedBy sql.NullInt64
var resolvedByName string
var resolvedRetryID sql.NullInt64
if err := rows.Scan(
&item.ID,
&item.CreatedAt,
&item.Phase,
&item.Type,
&item.Owner,
&item.Source,
&item.Severity,
&statusCode,
&item.Platform,
&item.Model,
&latency,
&item.IsRetryable,
&item.RetryCount,
&item.Resolved,
&resolvedAt,
&resolvedBy,
&resolvedByName,
&resolvedRetryID,
&item.ClientRequestID,
&item.RequestID,
&item.Message,
&userID,
&userEmail,
&apiKeyID,
&accountID,
&accountName,
&groupID,
&groupName,
&clientIP,
&item.RequestPath,
&item.Stream,
); err != nil {
return nil, err
}
if latency.Valid {
v := int(latency.Int64)
item.LatencyMs = &v
if resolvedAt.Valid {
t := resolvedAt.Time
item.ResolvedAt = &t
}
if resolvedBy.Valid {
v := resolvedBy.Int64
item.ResolvedByUserID = &v
}
item.ResolvedByUserName = resolvedByName
if resolvedRetryID.Valid {
v := resolvedRetryID.Int64
item.ResolvedRetryID = &v
}
item.StatusCode = int(statusCode.Int64)
if clientIP.Valid {
@@ -222,6 +261,7 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
v := userID.Int64
item.UserID = &v
}
item.UserEmail = userEmail
if apiKeyID.Valid {
v := apiKeyID.Int64
item.APIKeyID = &v
@@ -230,10 +270,12 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
v := accountID.Int64
item.AccountID = &v
}
item.AccountName = accountName
if groupID.Valid {
v := groupID.Int64
item.GroupID = &v
}
item.GroupName = groupName
out = append(out, &item)
}
if err := rows.Err(); err != nil {
@@ -258,49 +300,64 @@ func (r *opsRepository) GetErrorLogByID(ctx context.Context, id int64) (*service
q := `
SELECT
id,
created_at,
error_phase,
error_type,
severity,
COALESCE(upstream_status_code, status_code, 0),
COALESCE(platform, ''),
COALESCE(model, ''),
duration_ms,
COALESCE(client_request_id, ''),
COALESCE(request_id, ''),
COALESCE(error_message, ''),
COALESCE(error_body, ''),
upstream_status_code,
COALESCE(upstream_error_message, ''),
COALESCE(upstream_error_detail, ''),
COALESCE(upstream_errors::text, ''),
is_business_limited,
user_id,
api_key_id,
account_id,
group_id,
CASE WHEN client_ip IS NULL THEN NULL ELSE client_ip::text END,
COALESCE(request_path, ''),
stream,
COALESCE(user_agent, ''),
auth_latency_ms,
routing_latency_ms,
upstream_latency_ms,
response_latency_ms,
time_to_first_token_ms,
COALESCE(request_body::text, ''),
request_body_truncated,
request_body_bytes,
COALESCE(request_headers::text, '')
FROM ops_error_logs
WHERE id = $1
e.id,
e.created_at,
e.error_phase,
e.error_type,
COALESCE(e.error_owner, ''),
COALESCE(e.error_source, ''),
e.severity,
COALESCE(e.upstream_status_code, e.status_code, 0),
COALESCE(e.platform, ''),
COALESCE(e.model, ''),
COALESCE(e.is_retryable, false),
COALESCE(e.retry_count, 0),
COALESCE(e.resolved, false),
e.resolved_at,
e.resolved_by_user_id,
e.resolved_retry_id,
COALESCE(e.client_request_id, ''),
COALESCE(e.request_id, ''),
COALESCE(e.error_message, ''),
COALESCE(e.error_body, ''),
e.upstream_status_code,
COALESCE(e.upstream_error_message, ''),
COALESCE(e.upstream_error_detail, ''),
COALESCE(e.upstream_errors::text, ''),
e.is_business_limited,
e.user_id,
COALESCE(u.email, ''),
e.api_key_id,
e.account_id,
COALESCE(a.name, ''),
e.group_id,
COALESCE(g.name, ''),
CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END,
COALESCE(e.request_path, ''),
e.stream,
COALESCE(e.user_agent, ''),
e.auth_latency_ms,
e.routing_latency_ms,
e.upstream_latency_ms,
e.response_latency_ms,
e.time_to_first_token_ms,
COALESCE(e.request_body::text, ''),
e.request_body_truncated,
e.request_body_bytes,
COALESCE(e.request_headers::text, '')
FROM ops_error_logs e
LEFT JOIN users u ON e.user_id = u.id
LEFT JOIN accounts a ON e.account_id = a.id
LEFT JOIN groups g ON e.group_id = g.id
WHERE e.id = $1
LIMIT 1`
var out service.OpsErrorLogDetail
var latency sql.NullInt64
var statusCode sql.NullInt64
var upstreamStatusCode sql.NullInt64
var resolvedAt sql.NullTime
var resolvedBy sql.NullInt64
var resolvedRetryID sql.NullInt64
var clientIP sql.NullString
var userID sql.NullInt64
var apiKeyID sql.NullInt64
@@ -318,11 +375,18 @@ LIMIT 1`
&out.CreatedAt,
&out.Phase,
&out.Type,
&out.Owner,
&out.Source,
&out.Severity,
&statusCode,
&out.Platform,
&out.Model,
&latency,
&out.IsRetryable,
&out.RetryCount,
&out.Resolved,
&resolvedAt,
&resolvedBy,
&resolvedRetryID,
&out.ClientRequestID,
&out.RequestID,
&out.Message,
@@ -333,9 +397,12 @@ LIMIT 1`
&out.UpstreamErrors,
&out.IsBusinessLimited,
&userID,
&out.UserEmail,
&apiKeyID,
&accountID,
&out.AccountName,
&groupID,
&out.GroupName,
&clientIP,
&out.RequestPath,
&out.Stream,
@@ -355,9 +422,17 @@ LIMIT 1`
}
out.StatusCode = int(statusCode.Int64)
if latency.Valid {
v := int(latency.Int64)
out.LatencyMs = &v
if resolvedAt.Valid {
t := resolvedAt.Time
out.ResolvedAt = &t
}
if resolvedBy.Valid {
v := resolvedBy.Int64
out.ResolvedByUserID = &v
}
if resolvedRetryID.Valid {
v := resolvedRetryID.Int64
out.ResolvedRetryID = &v
}
if clientIP.Valid {
s := clientIP.String
@@ -487,9 +562,15 @@ SET
status = $2,
finished_at = $3,
duration_ms = $4,
result_request_id = $5,
result_error_id = $6,
error_message = $7
success = $5,
http_status_code = $6,
upstream_request_id = $7,
used_account_id = $8,
response_preview = $9,
response_truncated = $10,
result_request_id = $11,
result_error_id = $12,
error_message = $13
WHERE id = $1`
_, err := r.db.ExecContext(
@@ -499,8 +580,14 @@ WHERE id = $1`
strings.TrimSpace(input.Status),
nullTime(input.FinishedAt),
input.DurationMs,
nullBool(input.Success),
nullInt(input.HTTPStatusCode),
opsNullString(input.UpstreamRequestID),
nullInt64(input.UsedAccountID),
opsNullString(input.ResponsePreview),
nullBool(input.ResponseTruncated),
opsNullString(input.ResultRequestID),
opsNullInt64(input.ResultErrorID),
nullInt64(input.ResultErrorID),
opsNullString(input.ErrorMessage),
)
return err
@@ -526,6 +613,12 @@ SELECT
started_at,
finished_at,
duration_ms,
success,
http_status_code,
upstream_request_id,
used_account_id,
response_preview,
response_truncated,
result_request_id,
result_error_id,
error_message
@@ -540,6 +633,12 @@ LIMIT 1`
var startedAt sql.NullTime
var finishedAt sql.NullTime
var durationMs sql.NullInt64
var success sql.NullBool
var httpStatusCode sql.NullInt64
var upstreamRequestID sql.NullString
var usedAccountID sql.NullInt64
var responsePreview sql.NullString
var responseTruncated sql.NullBool
var resultRequestID sql.NullString
var resultErrorID sql.NullInt64
var errorMessage sql.NullString
@@ -555,6 +654,12 @@ LIMIT 1`
&startedAt,
&finishedAt,
&durationMs,
&success,
&httpStatusCode,
&upstreamRequestID,
&usedAccountID,
&responsePreview,
&responseTruncated,
&resultRequestID,
&resultErrorID,
&errorMessage,
@@ -579,6 +684,30 @@ LIMIT 1`
v := durationMs.Int64
out.DurationMs = &v
}
if success.Valid {
v := success.Bool
out.Success = &v
}
if httpStatusCode.Valid {
v := int(httpStatusCode.Int64)
out.HTTPStatusCode = &v
}
if upstreamRequestID.Valid {
s := upstreamRequestID.String
out.UpstreamRequestID = &s
}
if usedAccountID.Valid {
v := usedAccountID.Int64
out.UsedAccountID = &v
}
if responsePreview.Valid {
s := responsePreview.String
out.ResponsePreview = &s
}
if responseTruncated.Valid {
v := responseTruncated.Bool
out.ResponseTruncated = &v
}
if resultRequestID.Valid {
s := resultRequestID.String
out.ResultRequestID = &s
@@ -602,30 +731,234 @@ func nullTime(t time.Time) sql.NullTime {
return sql.NullTime{Time: t, Valid: true}
}
func nullBool(v *bool) sql.NullBool {
if v == nil {
return sql.NullBool{}
}
return sql.NullBool{Bool: *v, Valid: true}
}
func (r *opsRepository) ListRetryAttemptsByErrorID(ctx context.Context, sourceErrorID int64, limit int) ([]*service.OpsRetryAttempt, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if sourceErrorID <= 0 {
return nil, fmt.Errorf("invalid source_error_id")
}
if limit <= 0 {
limit = 50
}
if limit > 200 {
limit = 200
}
q := `
SELECT
r.id,
r.created_at,
COALESCE(r.requested_by_user_id, 0),
r.source_error_id,
COALESCE(r.mode, ''),
r.pinned_account_id,
COALESCE(pa.name, ''),
COALESCE(r.status, ''),
r.started_at,
r.finished_at,
r.duration_ms,
r.success,
r.http_status_code,
r.upstream_request_id,
r.used_account_id,
COALESCE(ua.name, ''),
r.response_preview,
r.response_truncated,
r.result_request_id,
r.result_error_id,
r.error_message
FROM ops_retry_attempts r
LEFT JOIN accounts pa ON r.pinned_account_id = pa.id
LEFT JOIN accounts ua ON r.used_account_id = ua.id
WHERE r.source_error_id = $1
ORDER BY r.created_at DESC
LIMIT $2`
rows, err := r.db.QueryContext(ctx, q, sourceErrorID, limit)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
out := make([]*service.OpsRetryAttempt, 0, 16)
for rows.Next() {
var item service.OpsRetryAttempt
var pinnedAccountID sql.NullInt64
var pinnedAccountName string
var requestedBy sql.NullInt64
var startedAt sql.NullTime
var finishedAt sql.NullTime
var durationMs sql.NullInt64
var success sql.NullBool
var httpStatusCode sql.NullInt64
var upstreamRequestID sql.NullString
var usedAccountID sql.NullInt64
var usedAccountName string
var responsePreview sql.NullString
var responseTruncated sql.NullBool
var resultRequestID sql.NullString
var resultErrorID sql.NullInt64
var errorMessage sql.NullString
if err := rows.Scan(
&item.ID,
&item.CreatedAt,
&requestedBy,
&item.SourceErrorID,
&item.Mode,
&pinnedAccountID,
&pinnedAccountName,
&item.Status,
&startedAt,
&finishedAt,
&durationMs,
&success,
&httpStatusCode,
&upstreamRequestID,
&usedAccountID,
&usedAccountName,
&responsePreview,
&responseTruncated,
&resultRequestID,
&resultErrorID,
&errorMessage,
); err != nil {
return nil, err
}
item.RequestedByUserID = requestedBy.Int64
if pinnedAccountID.Valid {
v := pinnedAccountID.Int64
item.PinnedAccountID = &v
}
item.PinnedAccountName = pinnedAccountName
if startedAt.Valid {
t := startedAt.Time
item.StartedAt = &t
}
if finishedAt.Valid {
t := finishedAt.Time
item.FinishedAt = &t
}
if durationMs.Valid {
v := durationMs.Int64
item.DurationMs = &v
}
if success.Valid {
v := success.Bool
item.Success = &v
}
if httpStatusCode.Valid {
v := int(httpStatusCode.Int64)
item.HTTPStatusCode = &v
}
if upstreamRequestID.Valid {
item.UpstreamRequestID = &upstreamRequestID.String
}
if usedAccountID.Valid {
v := usedAccountID.Int64
item.UsedAccountID = &v
}
item.UsedAccountName = usedAccountName
if responsePreview.Valid {
item.ResponsePreview = &responsePreview.String
}
if responseTruncated.Valid {
v := responseTruncated.Bool
item.ResponseTruncated = &v
}
if resultRequestID.Valid {
item.ResultRequestID = &resultRequestID.String
}
if resultErrorID.Valid {
v := resultErrorID.Int64
item.ResultErrorID = &v
}
if errorMessage.Valid {
item.ErrorMessage = &errorMessage.String
}
out = append(out, &item)
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
func (r *opsRepository) UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedRetryID *int64, resolvedAt *time.Time) error {
if r == nil || r.db == nil {
return fmt.Errorf("nil ops repository")
}
if errorID <= 0 {
return fmt.Errorf("invalid error id")
}
q := `
UPDATE ops_error_logs
SET
resolved = $2,
resolved_at = $3,
resolved_by_user_id = $4,
resolved_retry_id = $5
WHERE id = $1`
at := sql.NullTime{}
if resolvedAt != nil && !resolvedAt.IsZero() {
at = sql.NullTime{Time: resolvedAt.UTC(), Valid: true}
} else if resolved {
now := time.Now().UTC()
at = sql.NullTime{Time: now, Valid: true}
}
_, err := r.db.ExecContext(
ctx,
q,
errorID,
resolved,
at,
nullInt64(resolvedByUserID),
nullInt64(resolvedRetryID),
)
return err
}
func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
clauses := make([]string, 0, 8)
args := make([]any, 0, 8)
clauses := make([]string, 0, 12)
args := make([]any, 0, 12)
clauses = append(clauses, "1=1")
phaseFilter := ""
if filter != nil {
phaseFilter = strings.TrimSpace(strings.ToLower(filter.Phase))
}
// ops_error_logs primarily stores client-visible error requests (status>=400),
// ops_error_logs stores client-visible error requests (status>=400),
// but we also persist "recovered" upstream errors (status<400) for upstream health visibility.
// By default, keep list endpoints scoped to client errors unless explicitly filtering upstream phase.
// If Resolved is not specified, do not filter by resolved state (backward-compatible).
resolvedFilter := (*bool)(nil)
if filter != nil {
resolvedFilter = filter.Resolved
}
// Keep list endpoints scoped to client errors unless explicitly filtering upstream phase.
if phaseFilter != "upstream" {
clauses = append(clauses, "COALESCE(status_code, 0) >= 400")
}
if filter.StartTime != nil && !filter.StartTime.IsZero() {
args = append(args, filter.StartTime.UTC())
clauses = append(clauses, "created_at >= $"+itoa(len(args)))
clauses = append(clauses, "e.created_at >= $"+itoa(len(args)))
}
if filter.EndTime != nil && !filter.EndTime.IsZero() {
args = append(args, filter.EndTime.UTC())
// Keep time-window semantics consistent with other ops queries: [start, end)
clauses = append(clauses, "created_at < $"+itoa(len(args)))
clauses = append(clauses, "e.created_at < $"+itoa(len(args)))
}
if p := strings.TrimSpace(filter.Platform); p != "" {
args = append(args, p)
@@ -643,10 +976,59 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
args = append(args, phase)
clauses = append(clauses, "error_phase = $"+itoa(len(args)))
}
if filter != nil {
if owner := strings.TrimSpace(strings.ToLower(filter.Owner)); owner != "" {
args = append(args, owner)
clauses = append(clauses, "LOWER(COALESCE(error_owner,'')) = $"+itoa(len(args)))
}
if source := strings.TrimSpace(strings.ToLower(filter.Source)); source != "" {
args = append(args, source)
clauses = append(clauses, "LOWER(COALESCE(error_source,'')) = $"+itoa(len(args)))
}
}
if resolvedFilter != nil {
args = append(args, *resolvedFilter)
clauses = append(clauses, "COALESCE(resolved,false) = $"+itoa(len(args)))
}
// View filter: errors vs excluded vs all.
// Excluded = upstream 429/529 and business-limited (quota/concurrency/billing) errors.
view := ""
if filter != nil {
view = strings.ToLower(strings.TrimSpace(filter.View))
}
switch view {
case "", "errors":
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)")
case "excluded":
clauses = append(clauses, "(COALESCE(is_business_limited,false) = true OR COALESCE(upstream_status_code, status_code, 0) IN (429, 529))")
case "all":
// no-op
default:
// treat unknown as default 'errors'
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)")
}
if len(filter.StatusCodes) > 0 {
args = append(args, pq.Array(filter.StatusCodes))
clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+")")
} else if filter.StatusCodesOther {
// "Other" means: status codes not in the common list.
known := []int{400, 401, 403, 404, 409, 422, 429, 500, 502, 503, 504, 529}
args = append(args, pq.Array(known))
clauses = append(clauses, "NOT (COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+"))")
}
// Exact correlation keys (preferred for request↔upstream linkage).
if rid := strings.TrimSpace(filter.RequestID); rid != "" {
args = append(args, rid)
clauses = append(clauses, "COALESCE(request_id,'') = $"+itoa(len(args)))
}
if crid := strings.TrimSpace(filter.ClientRequestID); crid != "" {
args = append(args, crid)
clauses = append(clauses, "COALESCE(client_request_id,'') = $"+itoa(len(args)))
}
if q := strings.TrimSpace(filter.Query); q != "" {
like := "%" + q + "%"
args = append(args, like)
@@ -654,6 +1036,13 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
clauses = append(clauses, "(request_id ILIKE $"+n+" OR client_request_id ILIKE $"+n+" OR error_message ILIKE $"+n+")")
}
if userQuery := strings.TrimSpace(filter.UserQuery); userQuery != "" {
like := "%" + userQuery + "%"
args = append(args, like)
n := itoa(len(args))
clauses = append(clauses, "u.email ILIKE $"+n)
}
return "WHERE " + strings.Join(clauses, " AND "), args
}

View File

@@ -354,7 +354,7 @@ SELECT
created_at
FROM ops_alert_events
` + where + `
ORDER BY fired_at DESC
ORDER BY fired_at DESC, id DESC
LIMIT ` + limitArg
rows, err := r.db.QueryContext(ctx, q, args...)
@@ -413,6 +413,43 @@ LIMIT ` + limitArg
return out, nil
}
func (r *opsRepository) GetAlertEventByID(ctx context.Context, eventID int64) (*service.OpsAlertEvent, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if eventID <= 0 {
return nil, fmt.Errorf("invalid event id")
}
q := `
SELECT
id,
COALESCE(rule_id, 0),
COALESCE(severity, ''),
COALESCE(status, ''),
COALESCE(title, ''),
COALESCE(description, ''),
metric_value,
threshold_value,
dimensions,
fired_at,
resolved_at,
email_sent,
created_at
FROM ops_alert_events
WHERE id = $1`
row := r.db.QueryRowContext(ctx, q, eventID)
ev, err := scanOpsAlertEvent(row)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
return ev, nil
}
func (r *opsRepository) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*service.OpsAlertEvent, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
@@ -591,6 +628,121 @@ type opsAlertEventRow interface {
Scan(dest ...any) error
}
func (r *opsRepository) CreateAlertSilence(ctx context.Context, input *service.OpsAlertSilence) (*service.OpsAlertSilence, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if input == nil {
return nil, fmt.Errorf("nil input")
}
if input.RuleID <= 0 {
return nil, fmt.Errorf("invalid rule_id")
}
platform := strings.TrimSpace(input.Platform)
if platform == "" {
return nil, fmt.Errorf("invalid platform")
}
if input.Until.IsZero() {
return nil, fmt.Errorf("invalid until")
}
q := `
INSERT INTO ops_alert_silences (
rule_id,
platform,
group_id,
region,
until,
reason,
created_by,
created_at
) VALUES (
$1,$2,$3,$4,$5,$6,$7,NOW()
)
RETURNING id, rule_id, platform, group_id, region, until, COALESCE(reason,''), created_by, created_at`
row := r.db.QueryRowContext(
ctx,
q,
input.RuleID,
platform,
opsNullInt64(input.GroupID),
opsNullString(input.Region),
input.Until,
opsNullString(input.Reason),
opsNullInt64(input.CreatedBy),
)
var out service.OpsAlertSilence
var groupID sql.NullInt64
var region sql.NullString
var createdBy sql.NullInt64
if err := row.Scan(
&out.ID,
&out.RuleID,
&out.Platform,
&groupID,
&region,
&out.Until,
&out.Reason,
&createdBy,
&out.CreatedAt,
); err != nil {
return nil, err
}
if groupID.Valid {
v := groupID.Int64
out.GroupID = &v
}
if region.Valid {
v := strings.TrimSpace(region.String)
if v != "" {
out.Region = &v
}
}
if createdBy.Valid {
v := createdBy.Int64
out.CreatedBy = &v
}
return &out, nil
}
func (r *opsRepository) IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error) {
if r == nil || r.db == nil {
return false, fmt.Errorf("nil ops repository")
}
if ruleID <= 0 {
return false, fmt.Errorf("invalid rule id")
}
platform = strings.TrimSpace(platform)
if platform == "" {
return false, nil
}
if now.IsZero() {
now = time.Now().UTC()
}
q := `
SELECT 1
FROM ops_alert_silences
WHERE rule_id = $1
AND platform = $2
AND (group_id IS NOT DISTINCT FROM $3)
AND (region IS NOT DISTINCT FROM $4)
AND until > $5
LIMIT 1`
var dummy int
err := r.db.QueryRowContext(ctx, q, ruleID, platform, opsNullInt64(groupID), opsNullString(region), now).Scan(&dummy)
if err != nil {
if err == sql.ErrNoRows {
return false, nil
}
return false, err
}
return true, nil
}
func scanOpsAlertEvent(row opsAlertEventRow) (*service.OpsAlertEvent, error) {
var ev service.OpsAlertEvent
var metricValue sql.NullFloat64
@@ -652,6 +804,10 @@ func buildOpsAlertEventsWhere(filter *service.OpsAlertEventFilter) (string, []an
args = append(args, severity)
clauses = append(clauses, "severity = $"+itoa(len(args)))
}
if filter.EmailSent != nil {
args = append(args, *filter.EmailSent)
clauses = append(clauses, "email_sent = $"+itoa(len(args)))
}
if filter.StartTime != nil && !filter.StartTime.IsZero() {
args = append(args, *filter.StartTime)
clauses = append(clauses, "fired_at >= $"+itoa(len(args)))
@@ -661,6 +817,14 @@ func buildOpsAlertEventsWhere(filter *service.OpsAlertEventFilter) (string, []an
clauses = append(clauses, "fired_at < $"+itoa(len(args)))
}
// Cursor pagination (descending by fired_at, then id)
if filter.BeforeFiredAt != nil && !filter.BeforeFiredAt.IsZero() && filter.BeforeID != nil && *filter.BeforeID > 0 {
args = append(args, *filter.BeforeFiredAt)
tsArg := "$" + itoa(len(args))
args = append(args, *filter.BeforeID)
idArg := "$" + itoa(len(args))
clauses = append(clauses, fmt.Sprintf("(fired_at < %s OR (fired_at = %s AND id < %s))", tsArg, tsArg, idArg))
}
// Dimensions are stored in JSONB. We filter best-effort without requiring GIN indexes.
if platform := strings.TrimSpace(filter.Platform); platform != "" {
args = append(args, platform)

View File

@@ -71,7 +71,9 @@ usage_agg AS (
error_base AS (
SELECT
date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
platform AS platform,
-- platform is NULL for some early-phase errors (e.g. before routing); map to a sentinel
-- value so platform-level GROUPING SETS don't collide with the overall (platform=NULL) row.
COALESCE(platform, 'unknown') AS platform,
group_id AS group_id,
is_business_limited AS is_business_limited,
error_owner AS error_owner,

View File

@@ -0,0 +1,74 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const proxyLatencyKeyPrefix = "proxy:latency:"
func proxyLatencyKey(proxyID int64) string {
return fmt.Sprintf("%s%d", proxyLatencyKeyPrefix, proxyID)
}
type proxyLatencyCache struct {
rdb *redis.Client
}
func NewProxyLatencyCache(rdb *redis.Client) service.ProxyLatencyCache {
return &proxyLatencyCache{rdb: rdb}
}
func (c *proxyLatencyCache) GetProxyLatencies(ctx context.Context, proxyIDs []int64) (map[int64]*service.ProxyLatencyInfo, error) {
results := make(map[int64]*service.ProxyLatencyInfo)
if len(proxyIDs) == 0 {
return results, nil
}
keys := make([]string, 0, len(proxyIDs))
for _, id := range proxyIDs {
keys = append(keys, proxyLatencyKey(id))
}
values, err := c.rdb.MGet(ctx, keys...).Result()
if err != nil {
return results, err
}
for i, raw := range values {
if raw == nil {
continue
}
var payload []byte
switch v := raw.(type) {
case string:
payload = []byte(v)
case []byte:
payload = v
default:
continue
}
var info service.ProxyLatencyInfo
if err := json.Unmarshal(payload, &info); err != nil {
continue
}
results[proxyIDs[i]] = &info
}
return results, nil
}
func (c *proxyLatencyCache) SetProxyLatency(ctx context.Context, proxyID int64, info *service.ProxyLatencyInfo) error {
if info == nil {
return nil
}
payload, err := json.Marshal(info)
if err != nil {
return err
}
return c.rdb.Set(ctx, proxyLatencyKey(proxyID), payload, 0).Err()
}

View File

@@ -7,6 +7,7 @@ import (
"io"
"log"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -34,7 +35,10 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
}
}
const defaultIPInfoURL = "https://ipinfo.io/json"
const (
defaultIPInfoURL = "http://ip-api.com/json/?lang=zh-CN"
defaultProxyProbeTimeout = 30 * time.Second
)
type proxyProbeService struct {
ipInfoURL string
@@ -46,7 +50,7 @@ type proxyProbeService struct {
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 15 * time.Second,
Timeout: defaultProxyProbeTimeout,
InsecureSkipVerify: s.insecureSkipVerify,
ProxyStrict: true,
ValidateResolvedIP: s.validateResolvedIP,
@@ -75,10 +79,14 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
}
var ipInfo struct {
IP string `json:"ip"`
City string `json:"city"`
Region string `json:"region"`
Country string `json:"country"`
Status string `json:"status"`
Message string `json:"message"`
Query string `json:"query"`
City string `json:"city"`
Region string `json:"region"`
RegionName string `json:"regionName"`
Country string `json:"country"`
CountryCode string `json:"countryCode"`
}
body, err := io.ReadAll(resp.Body)
@@ -89,11 +97,22 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
if err := json.Unmarshal(body, &ipInfo); err != nil {
return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err)
}
if strings.ToLower(ipInfo.Status) != "success" {
if ipInfo.Message == "" {
ipInfo.Message = "ip-api request failed"
}
return nil, latencyMs, fmt.Errorf("ip-api request failed: %s", ipInfo.Message)
}
region := ipInfo.RegionName
if region == "" {
region = ipInfo.Region
}
return &service.ProxyExitInfo{
IP: ipInfo.IP,
City: ipInfo.City,
Region: ipInfo.Region,
Country: ipInfo.Country,
IP: ipInfo.Query,
City: ipInfo.City,
Region: region,
Country: ipInfo.Country,
CountryCode: ipInfo.CountryCode,
}, latencyMs, nil
}

View File

@@ -21,7 +21,7 @@ type ProxyProbeServiceSuite struct {
func (s *ProxyProbeServiceSuite) SetupTest() {
s.ctx = context.Background()
s.prober = &proxyProbeService{
ipInfoURL: "http://ipinfo.test/json",
ipInfoURL: "http://ip-api.test/json/?lang=zh-CN",
allowPrivateHosts: true,
}
}
@@ -54,7 +54,7 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seen <- r.RequestURI
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"ip":"1.2.3.4","city":"c","region":"r","country":"cc"}`)
_, _ = io.WriteString(w, `{"status":"success","query":"1.2.3.4","city":"c","regionName":"r","country":"cc","countryCode":"CC"}`)
}))
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
@@ -64,11 +64,12 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
require.Equal(s.T(), "c", info.City)
require.Equal(s.T(), "r", info.Region)
require.Equal(s.T(), "cc", info.Country)
require.Equal(s.T(), "CC", info.CountryCode)
// Verify proxy received the request
select {
case uri := <-seen:
require.Contains(s.T(), uri, "ipinfo.test", "expected request to go through proxy")
require.Contains(s.T(), uri, "ip-api.test", "expected request to go through proxy")
default:
require.Fail(s.T(), "expected proxy to receive request")
}

View File

@@ -219,12 +219,54 @@ func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string,
// CountAccountsByProxyID returns the number of accounts using a specific proxy
func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
var count int64
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1", []any{proxyID}, &count); err != nil {
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1 AND deleted_at IS NULL", []any{proxyID}, &count); err != nil {
return 0, err
}
return count, nil
}
func (r *proxyRepository) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) {
rows, err := r.sql.QueryContext(ctx, `
SELECT id, name, platform, type, notes
FROM accounts
WHERE proxy_id = $1 AND deleted_at IS NULL
ORDER BY id DESC
`, proxyID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
out := make([]service.ProxyAccountSummary, 0)
for rows.Next() {
var (
id int64
name string
platform string
accType string
notes sql.NullString
)
if err := rows.Scan(&id, &name, &platform, &accType, &notes); err != nil {
return nil, err
}
var notesPtr *string
if notes.Valid {
notesPtr = &notes.String
}
out = append(out, service.ProxyAccountSummary{
ID: id,
Name: name,
Platform: platform,
Type: accType,
Notes: notesPtr,
})
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (counts map[int64]int64, err error) {
rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL AND deleted_at IS NULL GROUP BY proxy_id")

View File

@@ -27,7 +27,7 @@ func TestSchedulerSnapshotOutboxReplay(t *testing.T) {
RunMode: config.RunModeStandard,
Gateway: config.GatewayConfig{
Scheduling: config.GatewaySchedulingConfig{
OutboxPollIntervalSeconds: 1,
OutboxPollIntervalSeconds: 1,
FullRebuildIntervalSeconds: 0,
DbFallbackEnabled: true,
},

View File

@@ -22,7 +22,7 @@ import (
"github.com/lib/pq"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at"
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at"
type usageLogRepository struct {
client *dbent.Client
@@ -105,6 +105,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
total_cost,
actual_cost,
rate_multiplier,
account_rate_multiplier,
billing_type,
stream,
duration_ms,
@@ -120,7 +121,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
@@ -160,6 +161,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
log.TotalCost,
log.ActualCost,
rateMultiplier,
log.AccountRateMultiplier,
log.BillingType,
log.Stream,
duration,
@@ -270,13 +272,13 @@ type DashboardStats = usagestats.DashboardStats
func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
stats := &DashboardStats{}
now := time.Now().UTC()
todayUTC := truncateToDayUTC(now)
now := timezone.Now()
todayStart := timezone.Today()
if err := r.fillDashboardEntityStats(ctx, stats, todayUTC, now); err != nil {
if err := r.fillDashboardEntityStats(ctx, stats, todayStart, now); err != nil {
return nil, err
}
if err := r.fillDashboardUsageStatsAggregated(ctx, stats, todayUTC, now); err != nil {
if err := r.fillDashboardUsageStatsAggregated(ctx, stats, todayStart, now); err != nil {
return nil, err
}
@@ -298,13 +300,13 @@ func (r *usageLogRepository) GetDashboardStatsWithRange(ctx context.Context, sta
}
stats := &DashboardStats{}
now := time.Now().UTC()
todayUTC := truncateToDayUTC(now)
now := timezone.Now()
todayStart := timezone.Today()
if err := r.fillDashboardEntityStats(ctx, stats, todayUTC, now); err != nil {
if err := r.fillDashboardEntityStats(ctx, stats, todayStart, now); err != nil {
return nil, err
}
if err := r.fillDashboardUsageStatsFromUsageLogs(ctx, stats, startUTC, endUTC, todayUTC, now); err != nil {
if err := r.fillDashboardUsageStatsFromUsageLogs(ctx, stats, startUTC, endUTC, todayStart, now); err != nil {
return nil, err
}
@@ -455,7 +457,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte
FROM usage_dashboard_hourly
WHERE bucket_start = $1
`
hourStart := now.UTC().Truncate(time.Hour)
hourStart := now.In(timezone.Location()).Truncate(time.Hour)
if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart}, &stats.HourlyActiveUsers); err != nil {
if err != sql.ErrNoRows {
return err
@@ -835,7 +837,9 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
SELECT
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(actual_cost), 0) as cost
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2
`
@@ -849,6 +853,8 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
&stats.Requests,
&stats.Tokens,
&stats.Cost,
&stats.StandardCost,
&stats.UserCost,
); err != nil {
return nil, err
}
@@ -861,7 +867,9 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
SELECT
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(actual_cost), 0) as cost
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2
`
@@ -875,6 +883,8 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
&stats.Requests,
&stats.Tokens,
&stats.Cost,
&stats.StandardCost,
&stats.UserCost,
); err != nil {
return nil, err
}
@@ -1400,8 +1410,8 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
return result, nil
}
// GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) (results []TrendDataPoint, err error) {
// GetUsageTrendWithFilters returns usage trend data with optional filters
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) (results []TrendDataPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
@@ -1430,6 +1440,22 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1)
args = append(args, apiKeyID)
}
if accountID > 0 {
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
args = append(args, accountID)
}
if groupID > 0 {
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
args = append(args, groupID)
}
if model != "" {
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
args = append(args, model)
}
if stream != nil {
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
args = append(args, *stream)
}
query += " GROUP BY date ORDER BY date ASC"
rows, err := r.sql.QueryContext(ctx, query, args...)
@@ -1452,9 +1478,15 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
return results, nil
}
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) (results []ModelStat, err error) {
query := `
// GetModelStatsWithFilters returns model statistics with optional filters
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) (results []ModelStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时实际费用使用账号倍率total_cost * account_rate_multiplier
if accountID > 0 && userID == 0 && apiKeyID == 0 {
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
query := fmt.Sprintf(`
SELECT
model,
COUNT(*) as requests,
@@ -1462,10 +1494,10 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
%s
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`
`, actualCostExpr)
args := []any{startTime, endTime}
if userID > 0 {
@@ -1480,6 +1512,14 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
args = append(args, accountID)
}
if groupID > 0 {
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
args = append(args, groupID)
}
if stream != nil {
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
args = append(args, *stream)
}
query += " GROUP BY model ORDER BY total_tokens DESC"
rows, err := r.sql.QueryContext(ctx, query, args...)
@@ -1587,12 +1627,14 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
%s
`, buildWhere(conditions))
stats := &UsageStats{}
var totalAccountCost float64
if err := scanSingleRow(
ctx,
r.sql,
@@ -1604,10 +1646,14 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&totalAccountCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
if filters.AccountID > 0 {
stats.TotalAccountCost = &totalAccountCost
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
return stats, nil
}
@@ -1634,7 +1680,8 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
GROUP BY date
@@ -1661,7 +1708,8 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
var tokens int64
var cost float64
var actualCost float64
if err = rows.Scan(&date, &requests, &tokens, &cost, &actualCost); err != nil {
var userCost float64
if err = rows.Scan(&date, &requests, &tokens, &cost, &actualCost, &userCost); err != nil {
return nil, err
}
t, _ := time.Parse("2006-01-02", date)
@@ -1672,19 +1720,21 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Tokens: tokens,
Cost: cost,
ActualCost: actualCost,
UserCost: userCost,
})
}
if err = rows.Err(); err != nil {
return nil, err
}
var totalActualCost, totalStandardCost float64
var totalAccountCost, totalUserCost, totalStandardCost float64
var totalRequests, totalTokens int64
var highestCostDay, highestRequestDay *AccountUsageHistory
for i := range history {
h := &history[i]
totalActualCost += h.ActualCost
totalAccountCost += h.ActualCost
totalUserCost += h.UserCost
totalStandardCost += h.Cost
totalRequests += h.Requests
totalTokens += h.Tokens
@@ -1711,11 +1761,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
summary := AccountUsageSummary{
Days: daysCount,
ActualDaysUsed: actualDaysUsed,
TotalCost: totalActualCost,
TotalCost: totalAccountCost,
TotalUserCost: totalUserCost,
TotalStandardCost: totalStandardCost,
TotalRequests: totalRequests,
TotalTokens: totalTokens,
AvgDailyCost: totalActualCost / float64(actualDaysUsed),
AvgDailyCost: totalAccountCost / float64(actualDaysUsed),
AvgDailyUserCost: totalUserCost / float64(actualDaysUsed),
AvgDailyRequests: float64(totalRequests) / float64(actualDaysUsed),
AvgDailyTokens: float64(totalTokens) / float64(actualDaysUsed),
AvgDurationMs: avgDuration,
@@ -1727,11 +1779,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
summary.Today = &struct {
Date string `json:"date"`
Cost float64 `json:"cost"`
UserCost float64 `json:"user_cost"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
}{
Date: history[i].Date,
Cost: history[i].ActualCost,
UserCost: history[i].UserCost,
Requests: history[i].Requests,
Tokens: history[i].Tokens,
}
@@ -1744,11 +1798,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Date string `json:"date"`
Label string `json:"label"`
Cost float64 `json:"cost"`
UserCost float64 `json:"user_cost"`
Requests int64 `json:"requests"`
}{
Date: highestCostDay.Date,
Label: highestCostDay.Label,
Cost: highestCostDay.ActualCost,
UserCost: highestCostDay.UserCost,
Requests: highestCostDay.Requests,
}
}
@@ -1759,15 +1815,17 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Label string `json:"label"`
Requests int64 `json:"requests"`
Cost float64 `json:"cost"`
UserCost float64 `json:"user_cost"`
}{
Date: highestRequestDay.Date,
Label: highestRequestDay.Label,
Requests: highestRequestDay.Requests,
Cost: highestRequestDay.ActualCost,
UserCost: highestRequestDay.UserCost,
}
}
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID)
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil)
if err != nil {
models = []ModelStat{}
}
@@ -1994,36 +2052,37 @@ func (r *usageLogRepository) loadSubscriptions(ctx context.Context, ids []int64)
func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, error) {
var (
id int64
userID int64
apiKeyID int64
accountID int64
requestID sql.NullString
model string
groupID sql.NullInt64
subscriptionID sql.NullInt64
inputTokens int
outputTokens int
cacheCreationTokens int
cacheReadTokens int
cacheCreation5m int
cacheCreation1h int
inputCost float64
outputCost float64
cacheCreationCost float64
cacheReadCost float64
totalCost float64
actualCost float64
rateMultiplier float64
billingType int16
stream bool
durationMs sql.NullInt64
firstTokenMs sql.NullInt64
userAgent sql.NullString
ipAddress sql.NullString
imageCount int
imageSize sql.NullString
createdAt time.Time
id int64
userID int64
apiKeyID int64
accountID int64
requestID sql.NullString
model string
groupID sql.NullInt64
subscriptionID sql.NullInt64
inputTokens int
outputTokens int
cacheCreationTokens int
cacheReadTokens int
cacheCreation5m int
cacheCreation1h int
inputCost float64
outputCost float64
cacheCreationCost float64
cacheReadCost float64
totalCost float64
actualCost float64
rateMultiplier float64
accountRateMultiplier sql.NullFloat64
billingType int16
stream bool
durationMs sql.NullInt64
firstTokenMs sql.NullInt64
userAgent sql.NullString
ipAddress sql.NullString
imageCount int
imageSize sql.NullString
createdAt time.Time
)
if err := scanner.Scan(
@@ -2048,6 +2107,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&totalCost,
&actualCost,
&rateMultiplier,
&accountRateMultiplier,
&billingType,
&stream,
&durationMs,
@@ -2080,6 +2140,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
TotalCost: totalCost,
ActualCost: actualCost,
RateMultiplier: rateMultiplier,
AccountRateMultiplier: nullFloat64Ptr(accountRateMultiplier),
BillingType: int8(billingType),
Stream: stream,
ImageCount: imageCount,
@@ -2186,6 +2247,14 @@ func nullInt(v *int) sql.NullInt64 {
return sql.NullInt64{Int64: int64(*v), Valid: true}
}
func nullFloat64Ptr(v sql.NullFloat64) *float64 {
if !v.Valid {
return nil
}
out := v.Float64
return &out
}
func nullString(v *string) sql.NullString {
if v == nil || *v == "" {
return sql.NullString{}

View File

@@ -11,6 +11,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
@@ -36,6 +37,12 @@ func TestUsageLogRepoSuite(t *testing.T) {
suite.Run(t, new(UsageLogRepoSuite))
}
// truncateToDayUTC 截断到 UTC 日期边界(测试辅助函数)
func truncateToDayUTC(t time.Time) time.Time {
t = t.UTC()
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
}
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.APIKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
log := &service.UsageLog{
UserID: user.ID,
@@ -95,6 +102,34 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *UsageLogRepoSuite) TestGetByID_ReturnsAccountRateMultiplier() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid-mult@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid-mult", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid-mult"})
m := 0.5
log := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 1.0,
ActualCost: 2.0,
AccountRateMultiplier: &m,
CreatedAt: timezone.Today().Add(2 * time.Hour),
}
_, err := s.repo.Create(s.ctx, log)
s.Require().NoError(err)
got, err := s.repo.GetByID(s.ctx, log.ID)
s.Require().NoError(err)
s.Require().NotNil(got.AccountRateMultiplier)
s.Require().InEpsilon(0.5, *got.AccountRateMultiplier, 0.0001)
}
// --- Delete ---
func (s *UsageLogRepoSuite) TestDelete() {
@@ -403,12 +438,49 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-today"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
createdAt := timezone.Today().Add(1 * time.Hour)
m1 := 1.5
m2 := 0.0
_, err := s.repo.Create(s.ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 1.0,
ActualCost: 2.0,
AccountRateMultiplier: &m1,
CreatedAt: createdAt,
})
s.Require().NoError(err)
_, err = s.repo.Create(s.ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "claude-3",
InputTokens: 5,
OutputTokens: 5,
TotalCost: 0.5,
ActualCost: 1.0,
AccountRateMultiplier: &m2,
CreatedAt: createdAt,
})
s.Require().NoError(err)
stats, err := s.repo.GetAccountTodayStats(s.ctx, account.ID)
s.Require().NoError(err, "GetAccountTodayStats")
s.Require().Equal(int64(1), stats.Requests)
s.Require().Equal(int64(30), stats.Tokens)
s.Require().Equal(int64(2), stats.Requests)
s.Require().Equal(int64(40), stats.Tokens)
// account cost = SUM(total_cost * account_rate_multiplier)
s.Require().InEpsilon(1.5, stats.Cost, 0.0001)
// standard cost = SUM(total_cost)
s.Require().InEpsilon(1.5, stats.StandardCost, 0.0001)
// user cost = SUM(actual_cost)
s.Require().InEpsilon(3.0, stats.UserCost, 0.0001)
}
func (s *UsageLogRepoSuite) TestDashboardAggregationConsistency() {
@@ -416,8 +488,8 @@ func (s *UsageLogRepoSuite) TestDashboardAggregationConsistency() {
// 使用固定的时间偏移确保 hour1 和 hour2 在同一天且都在过去
// 选择当天 02:00 和 03:00 作为测试时间点(基于 now 的日期)
dayStart := truncateToDayUTC(now)
hour1 := dayStart.Add(2 * time.Hour) // 当天 02:00
hour2 := dayStart.Add(3 * time.Hour) // 当天 03:00
hour1 := dayStart.Add(2 * time.Hour) // 当天 02:00
hour2 := dayStart.Add(3 * time.Hour) // 当天 03:00
// 如果当前时间早于 hour2则使用昨天的时间
if now.Before(hour2.Add(time.Hour)) {
dayStart = dayStart.Add(-24 * time.Hour)
@@ -872,17 +944,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
endTime := base.Add(48 * time.Hour)
// Test with user filter
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0)
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil)
s.Require().NoError(err, "GetUsageTrendWithFilters user filter")
s.Require().Len(trend, 2)
// Test with apiKey filter
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID)
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil)
s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter")
s.Require().Len(trend, 2)
// Test with both filters
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID)
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil)
s.Require().NoError(err, "GetUsageTrendWithFilters both filters")
s.Require().Len(trend, 2)
}
@@ -899,7 +971,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour)
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0)
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil)
s.Require().NoError(err, "GetUsageTrendWithFilters hourly")
s.Require().Len(trend, 2)
}
@@ -945,17 +1017,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
endTime := base.Add(2 * time.Hour)
// Test with user filter
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0)
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil)
s.Require().NoError(err, "GetModelStatsWithFilters user filter")
s.Require().Len(stats, 2)
// Test with apiKey filter
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0)
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil)
s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter")
s.Require().Len(stats, 2)
// Test with account filter
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID)
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil)
s.Require().NoError(err, "GetModelStatsWithFilters account filter")
s.Require().Len(stats, 2)
}

View File

@@ -69,6 +69,7 @@ var ProviderSet = wire.NewSet(
NewGeminiTokenCache,
NewSchedulerCache,
NewSchedulerOutboxRepository,
NewProxyLatencyCache,
// HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier,

View File

@@ -239,9 +239,10 @@ func TestAPIContracts(t *testing.T) {
"cache_creation_cost": 0,
"cache_read_cost": 0,
"total_cost": 0.5,
"actual_cost": 0.5,
"rate_multiplier": 1,
"billing_type": 0,
"actual_cost": 0.5,
"rate_multiplier": 1,
"account_rate_multiplier": null,
"billing_type": 0,
"stream": true,
"duration_ms": 100,
"first_token_ms": 50,
@@ -262,11 +263,11 @@ func TestAPIContracts(t *testing.T) {
name: "GET /api/v1/admin/settings",
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
deps.settingRepo.SetAll(map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyEmailVerifyEnabled: "false",
deps.settingRepo.SetAll(map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyEmailVerifyEnabled: "false",
service.SettingKeySMTPHost: "smtp.example.com",
service.SettingKeySMTPHost: "smtp.example.com",
service.SettingKeySMTPPort: "587",
service.SettingKeySMTPUsername: "user",
service.SettingKeySMTPPassword: "secret",
@@ -285,15 +286,15 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyContactInfo: "support",
service.SettingKeyDocURL: "https://docs.example.com",
service.SettingKeyDefaultConcurrency: "5",
service.SettingKeyDefaultBalance: "1.25",
service.SettingKeyDefaultConcurrency: "5",
service.SettingKeyDefaultBalance: "1.25",
service.SettingKeyOpsMonitoringEnabled: "false",
service.SettingKeyOpsRealtimeMonitoringEnabled: "true",
service.SettingKeyOpsQueryModeDefault: "auto",
service.SettingKeyOpsMetricsIntervalSeconds: "60",
})
},
service.SettingKeyOpsMonitoringEnabled: "false",
service.SettingKeyOpsRealtimeMonitoringEnabled: "true",
service.SettingKeyOpsQueryModeDefault: "auto",
service.SettingKeyOpsMetricsIntervalSeconds: "60",
})
},
method: http.MethodGet,
path: "/api/v1/admin/settings",
wantStatus: http.StatusOK,
@@ -435,7 +436,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
@@ -858,6 +859,10 @@ func (stubProxyRepo) CountAccountsByProxyID(ctx context.Context, proxyID int64)
return 0, errors.New("not implemented")
}
func (stubProxyRepo) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) {
return nil, errors.New("not implemented")
}
type stubRedeemCodeRepo struct{}
func (stubRedeemCodeRepo) Create(ctx context.Context, code *service.RedeemCode) error {
@@ -1229,11 +1234,11 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error) {
func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error) {
func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) {
return nil, errors.New("not implemented")
}

View File

@@ -81,6 +81,9 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
ops.PUT("/alert-rules/:id", h.Admin.Ops.UpdateAlertRule)
ops.DELETE("/alert-rules/:id", h.Admin.Ops.DeleteAlertRule)
ops.GET("/alert-events", h.Admin.Ops.ListAlertEvents)
ops.GET("/alert-events/:id", h.Admin.Ops.GetAlertEvent)
ops.PUT("/alert-events/:id/status", h.Admin.Ops.UpdateAlertEventStatus)
ops.POST("/alert-silences", h.Admin.Ops.CreateAlertSilence)
// Email notification config (DB-backed)
ops.GET("/email-notification/config", h.Admin.Ops.GetEmailNotificationConfig)
@@ -110,10 +113,26 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
ws.GET("/qps", h.Admin.Ops.QPSWSHandler)
}
// Error logs (MVP-1)
// Error logs (legacy)
ops.GET("/errors", h.Admin.Ops.GetErrorLogs)
ops.GET("/errors/:id", h.Admin.Ops.GetErrorLogByID)
ops.GET("/errors/:id/retries", h.Admin.Ops.ListRetryAttempts)
ops.POST("/errors/:id/retry", h.Admin.Ops.RetryErrorRequest)
ops.PUT("/errors/:id/resolve", h.Admin.Ops.UpdateErrorResolution)
// Request errors (client-visible failures)
ops.GET("/request-errors", h.Admin.Ops.ListRequestErrors)
ops.GET("/request-errors/:id", h.Admin.Ops.GetRequestError)
ops.GET("/request-errors/:id/upstream-errors", h.Admin.Ops.ListRequestErrorUpstreamErrors)
ops.POST("/request-errors/:id/retry-client", h.Admin.Ops.RetryRequestErrorClient)
ops.POST("/request-errors/:id/upstream-errors/:idx/retry", h.Admin.Ops.RetryRequestErrorUpstreamEvent)
ops.PUT("/request-errors/:id/resolve", h.Admin.Ops.ResolveRequestError)
// Upstream errors (independent upstream failures)
ops.GET("/upstream-errors", h.Admin.Ops.ListUpstreamErrors)
ops.GET("/upstream-errors/:id", h.Admin.Ops.GetUpstreamError)
ops.POST("/upstream-errors/:id/retry", h.Admin.Ops.RetryUpstreamError)
ops.PUT("/upstream-errors/:id/resolve", h.Admin.Ops.ResolveUpstreamError)
// Request drilldown (success + error)
ops.GET("/requests", h.Admin.Ops.ListRequestDetails)
@@ -250,6 +269,7 @@ func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
proxies.POST("/:id/test", h.Admin.Proxy.Test)
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
proxies.POST("/batch-delete", h.Admin.Proxy.BatchDelete)
proxies.POST("/batch", h.Admin.Proxy.BatchCreate)
}
}

View File

@@ -9,16 +9,19 @@ import (
)
type Account struct {
ID int64
Name string
Notes *string
Platform string
Type string
Credentials map[string]any
Extra map[string]any
ProxyID *int64
Concurrency int
Priority int
ID int64
Name string
Notes *string
Platform string
Type string
Credentials map[string]any
Extra map[string]any
ProxyID *int64
Concurrency int
Priority int
// RateMultiplier 账号计费倍率(>=0允许 0 表示该账号计费为 0
// 使用指针用于兼容旧版本调度缓存Redis中缺字段的情况nil 表示按 1.0 处理。
RateMultiplier *float64
Status string
ErrorMessage string
LastUsedAt *time.Time
@@ -57,6 +60,20 @@ func (a *Account) IsActive() bool {
return a.Status == StatusActive
}
// BillingRateMultiplier 返回账号计费倍率。
// - nil 表示未配置/旧缓存缺字段,按 1.0 处理
// - 允许 0表示该账号计费为 0
// - 负数属于非法数据,出于安全考虑按 1.0 处理
func (a *Account) BillingRateMultiplier() float64 {
if a == nil || a.RateMultiplier == nil {
return 1.0
}
if *a.RateMultiplier < 0 {
return 1.0
}
return *a.RateMultiplier
}
func (a *Account) IsSchedulable() bool {
if !a.IsActive() || !a.Schedulable {
return false

View File

@@ -0,0 +1,27 @@
package service
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func TestAccount_BillingRateMultiplier_DefaultsToOneWhenNil(t *testing.T) {
var a Account
require.NoError(t, json.Unmarshal([]byte(`{"id":1,"name":"acc","status":"active"}`), &a))
require.Nil(t, a.RateMultiplier)
require.Equal(t, 1.0, a.BillingRateMultiplier())
}
func TestAccount_BillingRateMultiplier_AllowsZero(t *testing.T) {
v := 0.0
a := Account{RateMultiplier: &v}
require.Equal(t, 0.0, a.BillingRateMultiplier())
}
func TestAccount_BillingRateMultiplier_NegativeFallsBackToOne(t *testing.T) {
v := -1.0
a := Account{RateMultiplier: &v}
require.Equal(t, 1.0, a.BillingRateMultiplier())
}

View File

@@ -63,14 +63,15 @@ type AccountRepository interface {
// AccountBulkUpdate describes the fields that can be updated in a bulk operation.
// Nil pointers mean "do not change".
type AccountBulkUpdate struct {
Name *string
ProxyID *int64
Concurrency *int
Priority *int
Status *string
Schedulable *bool
Credentials map[string]any
Extra map[string]any
Name *string
ProxyID *int64
Concurrency *int
Priority *int
RateMultiplier *float64
Status *string
Schedulable *bool
Credentials map[string]any
Extra map[string]any
}
// CreateAccountRequest 创建账号请求

View File

@@ -32,8 +32,8 @@ type UsageLogRepository interface {
// Admin dashboard stats
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
@@ -96,10 +96,16 @@ func NewUsageCache() *UsageCache {
}
// WindowStats 窗口期统计
//
// cost: 账号口径费用total_cost * account_rate_multiplier
// standard_cost: 标准费用total_cost不含倍率
// user_cost: 用户/API Key 口径费用actual_cost受分组倍率影响
type WindowStats struct {
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
StandardCost float64 `json:"standard_cost"`
UserCost float64 `json:"user_cost"`
}
// UsageProgress 使用量进度
@@ -266,7 +272,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
}
dayStart := geminiDailyWindowStart(now)
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID)
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil)
if err != nil {
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
}
@@ -288,7 +294,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
minuteStart := now.Truncate(time.Minute)
minuteResetAt := minuteStart.Add(time.Minute)
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID)
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil)
if err != nil {
return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
}
@@ -377,9 +383,11 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
}
windowStats = &WindowStats{
Requests: stats.Requests,
Tokens: stats.Tokens,
Cost: stats.Cost,
Requests: stats.Requests,
Tokens: stats.Tokens,
Cost: stats.Cost,
StandardCost: stats.StandardCost,
UserCost: stats.UserCost,
}
// 缓存窗口统计1 分钟)
@@ -403,9 +411,11 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
}
return &WindowStats{
Requests: stats.Requests,
Tokens: stats.Tokens,
Cost: stats.Cost,
Requests: stats.Requests,
Tokens: stats.Tokens,
Cost: stats.Cost,
StandardCost: stats.StandardCost,
UserCost: stats.UserCost,
}, nil
}

View File

@@ -54,7 +54,8 @@ type AdminService interface {
CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error)
UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error)
DeleteProxy(ctx context.Context, id int64) error
GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error)
BatchDeleteProxies(ctx context.Context, ids []int64) (*ProxyBatchDeleteResult, error)
GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error)
CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error)
TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error)
@@ -136,6 +137,7 @@ type CreateAccountInput struct {
ProxyID *int64
Concurrency int
Priority int
RateMultiplier *float64 // 账号计费倍率(>=0允许 0
GroupIDs []int64
ExpiresAt *int64
AutoPauseOnExpired *bool
@@ -151,8 +153,9 @@ type UpdateAccountInput struct {
Credentials map[string]any
Extra map[string]any
ProxyID *int64
Concurrency *int // 使用指针区分"未提供"和"设置为0"
Priority *int // 使用指针区分"未提供"和"设置为0"
Concurrency *int // 使用指针区分"未提供"和"设置为0"
Priority *int // 使用指针区分"未提供"和"设置为0"
RateMultiplier *float64 // 账号计费倍率(>=0允许 0
Status string
GroupIDs *[]int64
ExpiresAt *int64
@@ -162,16 +165,17 @@ type UpdateAccountInput struct {
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
type BulkUpdateAccountsInput struct {
AccountIDs []int64
Name string
ProxyID *int64
Concurrency *int
Priority *int
Status string
Schedulable *bool
GroupIDs *[]int64
Credentials map[string]any
Extra map[string]any
AccountIDs []int64
Name string
ProxyID *int64
Concurrency *int
Priority *int
RateMultiplier *float64 // 账号计费倍率(>=0允许 0
Status string
Schedulable *bool
GroupIDs *[]int64
Credentials map[string]any
Extra map[string]any
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
// This should only be set when the caller has explicitly confirmed the risk.
SkipMixedChannelCheck bool
@@ -220,23 +224,35 @@ type GenerateRedeemCodesInput struct {
ValidityDays int // 订阅类型专用:有效天数
}
// ProxyTestResult represents the result of testing a proxy
type ProxyTestResult struct {
Success bool `json:"success"`
Message string `json:"message"`
LatencyMs int64 `json:"latency_ms,omitempty"`
IPAddress string `json:"ip_address,omitempty"`
City string `json:"city,omitempty"`
Region string `json:"region,omitempty"`
Country string `json:"country,omitempty"`
type ProxyBatchDeleteResult struct {
DeletedIDs []int64 `json:"deleted_ids"`
Skipped []ProxyBatchDeleteSkipped `json:"skipped"`
}
// ProxyExitInfo represents proxy exit information from ipinfo.io
type ProxyBatchDeleteSkipped struct {
ID int64 `json:"id"`
Reason string `json:"reason"`
}
// ProxyTestResult represents the result of testing a proxy
type ProxyTestResult struct {
Success bool `json:"success"`
Message string `json:"message"`
LatencyMs int64 `json:"latency_ms,omitempty"`
IPAddress string `json:"ip_address,omitempty"`
City string `json:"city,omitempty"`
Region string `json:"region,omitempty"`
Country string `json:"country,omitempty"`
CountryCode string `json:"country_code,omitempty"`
}
// ProxyExitInfo represents proxy exit information from ip-api.com
type ProxyExitInfo struct {
IP string
City string
Region string
Country string
IP string
City string
Region string
Country string
CountryCode string
}
// ProxyExitInfoProber tests proxy connectivity and retrieves exit information
@@ -254,6 +270,7 @@ type adminServiceImpl struct {
redeemCodeRepo RedeemCodeRepository
billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber
proxyLatencyCache ProxyLatencyCache
authCacheInvalidator APIKeyAuthCacheInvalidator
}
@@ -267,6 +284,7 @@ func NewAdminService(
redeemCodeRepo RedeemCodeRepository,
billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber,
proxyLatencyCache ProxyLatencyCache,
authCacheInvalidator APIKeyAuthCacheInvalidator,
) AdminService {
return &adminServiceImpl{
@@ -278,6 +296,7 @@ func NewAdminService(
redeemCodeRepo: redeemCodeRepo,
billingCacheService: billingCacheService,
proxyProber: proxyProber,
proxyLatencyCache: proxyLatencyCache,
authCacheInvalidator: authCacheInvalidator,
}
}
@@ -817,6 +836,12 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
} else {
account.AutoPauseOnExpired = true
}
if input.RateMultiplier != nil {
if *input.RateMultiplier < 0 {
return nil, errors.New("rate_multiplier must be >= 0")
}
account.RateMultiplier = input.RateMultiplier
}
if err := s.accountRepo.Create(ctx, account); err != nil {
return nil, err
}
@@ -869,6 +894,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
if input.Priority != nil {
account.Priority = *input.Priority
}
if input.RateMultiplier != nil {
if *input.RateMultiplier < 0 {
return nil, errors.New("rate_multiplier must be >= 0")
}
account.RateMultiplier = input.RateMultiplier
}
if input.Status != "" {
account.Status = input.Status
}
@@ -942,6 +973,12 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
}
}
if input.RateMultiplier != nil {
if *input.RateMultiplier < 0 {
return nil, errors.New("rate_multiplier must be >= 0")
}
}
// Prepare bulk updates for columns and JSONB fields.
repoUpdates := AccountBulkUpdate{
Credentials: input.Credentials,
@@ -959,6 +996,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
if input.Priority != nil {
repoUpdates.Priority = input.Priority
}
if input.RateMultiplier != nil {
repoUpdates.RateMultiplier = input.RateMultiplier
}
if input.Status != "" {
repoUpdates.Status = &input.Status
}
@@ -1069,6 +1109,7 @@ func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page
if err != nil {
return nil, 0, err
}
s.attachProxyLatency(ctx, proxies)
return proxies, result.Total, nil
}
@@ -1077,7 +1118,12 @@ func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) {
}
func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) {
return s.proxyRepo.ListActiveWithAccountCount(ctx)
proxies, err := s.proxyRepo.ListActiveWithAccountCount(ctx)
if err != nil {
return nil, err
}
s.attachProxyLatency(ctx, proxies)
return proxies, nil
}
func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*Proxy, error) {
@@ -1097,6 +1143,8 @@ func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyIn
if err := s.proxyRepo.Create(ctx, proxy); err != nil {
return nil, err
}
// Probe latency asynchronously so creation isn't blocked by network timeout.
go s.probeProxyLatency(context.Background(), proxy)
return proxy, nil
}
@@ -1135,12 +1183,53 @@ func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *Upd
}
func (s *adminServiceImpl) DeleteProxy(ctx context.Context, id int64) error {
count, err := s.proxyRepo.CountAccountsByProxyID(ctx, id)
if err != nil {
return err
}
if count > 0 {
return ErrProxyInUse
}
return s.proxyRepo.Delete(ctx, id)
}
func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error) {
// Return mock data for now - would need a dedicated repository method
return []Account{}, 0, nil
func (s *adminServiceImpl) BatchDeleteProxies(ctx context.Context, ids []int64) (*ProxyBatchDeleteResult, error) {
result := &ProxyBatchDeleteResult{}
if len(ids) == 0 {
return result, nil
}
for _, id := range ids {
count, err := s.proxyRepo.CountAccountsByProxyID(ctx, id)
if err != nil {
result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{
ID: id,
Reason: err.Error(),
})
continue
}
if count > 0 {
result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{
ID: id,
Reason: ErrProxyInUse.Error(),
})
continue
}
if err := s.proxyRepo.Delete(ctx, id); err != nil {
result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{
ID: id,
Reason: err.Error(),
})
continue
}
result.DeletedIDs = append(result.DeletedIDs, id)
}
return result, nil
}
func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) {
return s.proxyRepo.ListAccountSummariesByProxyID(ctx, proxyID)
}
func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) {
@@ -1240,23 +1329,69 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
proxyURL := proxy.URL()
exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL)
if err != nil {
s.saveProxyLatency(ctx, id, &ProxyLatencyInfo{
Success: false,
Message: err.Error(),
UpdatedAt: time.Now(),
})
return &ProxyTestResult{
Success: false,
Message: err.Error(),
}, nil
}
latency := latencyMs
s.saveProxyLatency(ctx, id, &ProxyLatencyInfo{
Success: true,
LatencyMs: &latency,
Message: "Proxy is accessible",
IPAddress: exitInfo.IP,
Country: exitInfo.Country,
CountryCode: exitInfo.CountryCode,
Region: exitInfo.Region,
City: exitInfo.City,
UpdatedAt: time.Now(),
})
return &ProxyTestResult{
Success: true,
Message: "Proxy is accessible",
LatencyMs: latencyMs,
IPAddress: exitInfo.IP,
City: exitInfo.City,
Region: exitInfo.Region,
Country: exitInfo.Country,
Success: true,
Message: "Proxy is accessible",
LatencyMs: latencyMs,
IPAddress: exitInfo.IP,
City: exitInfo.City,
Region: exitInfo.Region,
Country: exitInfo.Country,
CountryCode: exitInfo.CountryCode,
}, nil
}
func (s *adminServiceImpl) probeProxyLatency(ctx context.Context, proxy *Proxy) {
if s.proxyProber == nil || proxy == nil {
return
}
exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxy.URL())
if err != nil {
s.saveProxyLatency(ctx, proxy.ID, &ProxyLatencyInfo{
Success: false,
Message: err.Error(),
UpdatedAt: time.Now(),
})
return
}
latency := latencyMs
s.saveProxyLatency(ctx, proxy.ID, &ProxyLatencyInfo{
Success: true,
LatencyMs: &latency,
Message: "Proxy is accessible",
IPAddress: exitInfo.IP,
Country: exitInfo.Country,
CountryCode: exitInfo.CountryCode,
Region: exitInfo.Region,
City: exitInfo.City,
UpdatedAt: time.Now(),
})
}
// checkMixedChannelRisk 检查分组中是否存在混合渠道Antigravity + Anthropic
// 如果存在混合,返回错误提示用户确认
func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
@@ -1306,6 +1441,51 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc
return nil
}
func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []ProxyWithAccountCount) {
if s.proxyLatencyCache == nil || len(proxies) == 0 {
return
}
ids := make([]int64, 0, len(proxies))
for i := range proxies {
ids = append(ids, proxies[i].ID)
}
latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, ids)
if err != nil {
log.Printf("Warning: load proxy latency cache failed: %v", err)
return
}
for i := range proxies {
info := latencies[proxies[i].ID]
if info == nil {
continue
}
if info.Success {
proxies[i].LatencyStatus = "success"
proxies[i].LatencyMs = info.LatencyMs
} else {
proxies[i].LatencyStatus = "failed"
}
proxies[i].LatencyMessage = info.Message
proxies[i].IPAddress = info.IPAddress
proxies[i].Country = info.Country
proxies[i].CountryCode = info.CountryCode
proxies[i].Region = info.Region
proxies[i].City = info.City
}
}
func (s *adminServiceImpl) saveProxyLatency(ctx context.Context, proxyID int64, info *ProxyLatencyInfo) {
if s.proxyLatencyCache == nil || info == nil {
return
}
if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, info); err != nil {
log.Printf("Warning: store proxy latency cache failed: %v", err)
}
}
// getAccountPlatform 根据账号 platform 判断混合渠道检查用的平台标识
func getAccountPlatform(accountPlatform string) string {
switch strings.ToLower(strings.TrimSpace(accountPlatform)) {

View File

@@ -12,9 +12,9 @@ import (
type accountRepoStubForBulkUpdate struct {
accountRepoStub
bulkUpdateErr error
bulkUpdateIDs []int64
bindGroupErrByID map[int64]error
bulkUpdateErr error
bulkUpdateIDs []int64
bindGroupErrByID map[int64]error
}
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {

View File

@@ -153,8 +153,10 @@ func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupI
}
type proxyRepoStub struct {
deleteErr error
deletedIDs []int64
deleteErr error
countErr error
accountCount int64
deletedIDs []int64
}
func (s *proxyRepoStub) Create(ctx context.Context, proxy *Proxy) error {
@@ -199,7 +201,14 @@ func (s *proxyRepoStub) ExistsByHostPortAuth(ctx context.Context, host string, p
}
func (s *proxyRepoStub) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
panic("unexpected CountAccountsByProxyID call")
if s.countErr != nil {
return 0, s.countErr
}
return s.accountCount, nil
}
func (s *proxyRepoStub) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) {
panic("unexpected ListAccountSummariesByProxyID call")
}
type redeemRepoStub struct {
@@ -409,6 +418,15 @@ func TestAdminService_DeleteProxy_Idempotent(t *testing.T) {
require.Equal(t, []int64{404}, repo.deletedIDs)
}
func TestAdminService_DeleteProxy_InUse(t *testing.T) {
repo := &proxyRepoStub{accountCount: 2}
svc := &adminServiceImpl{proxyRepo: repo}
err := svc.DeleteProxy(context.Background(), 77)
require.ErrorIs(t, err, ErrProxyInUse)
require.Empty(t, repo.deletedIDs)
}
func TestAdminService_DeleteProxy_Error(t *testing.T) {
deleteErr := errors.New("delete failed")
repo := &proxyRepoStub{deleteErr: deleteErr}

View File

@@ -564,6 +564,10 @@ urlFallbackLoop:
}
upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, action, accessToken, geminiBody)
// Capture upstream request body for ops retry of this attempt.
if c != nil {
c.Set(OpsUpstreamRequestBodyKey, string(geminiBody))
}
if err != nil {
return nil, err
}
@@ -574,6 +578,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
@@ -615,6 +620,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry",
@@ -645,6 +651,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry",
@@ -697,6 +704,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "signature_error",
@@ -740,6 +748,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "signature_retry_request_error",
Message: sanitizeUpstreamErrorMessage(retryErr.Error()),
@@ -770,6 +779,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: retryResp.StatusCode,
UpstreamRequestID: retryResp.Header.Get("x-request-id"),
Kind: kind,
@@ -817,6 +827,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
@@ -1371,6 +1382,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
@@ -1412,6 +1424,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry",
@@ -1442,6 +1455,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry",
@@ -1543,6 +1557,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "failover",
@@ -1559,6 +1574,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "http_error",
@@ -2039,6 +2055,7 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: upstreamStatus,
UpstreamRequestID: upstreamRequestID,
Kind: "http_error",

View File

@@ -45,7 +45,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return "", errors.New("not an antigravity oauth account")
}
cacheKey := antigravityTokenCacheKey(account)
cacheKey := AntigravityTokenCacheKey(account)
// 1. 先尝试缓存
if p.tokenCache != nil {
@@ -121,7 +121,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return accessToken, nil
}
func antigravityTokenCacheKey(account *Account) string {
func AntigravityTokenCacheKey(account *Account) string {
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" {
return "ag:" + projectID

View File

@@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
return stats, nil
}
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error) {
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID)
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) {
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream)
if err != nil {
return nil, fmt.Errorf("get usage trend with filters: %w", err)
}
return trend, nil
}
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID int64) ([]usagestats.ModelStat, error) {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, 0)
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream)
if err != nil {
return nil, fmt.Errorf("get model stats with filters: %w", err)
}

View File

@@ -1211,6 +1211,72 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
require.Nil(t, result)
require.Contains(t, err.Error(), "no available accounts")
})
t.Run("过滤不可调度账号-限流账号被跳过", func(t *testing.T) {
now := time.Now()
resetAt := now.Add(10 * time.Minute)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, RateLimitResetAt: &resetAt},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "应跳过限流账号,选择可用账号")
})
t.Run("过滤不可调度账号-过载账号被跳过", func(t *testing.T) {
now := time.Now()
overloadUntil := now.Add(10 * time.Minute)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, OverloadUntil: &overloadUntil},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "应跳过过载账号,选择可用账号")
})
}
func TestGatewayService_GroupResolution_ReusesContextGroup(t *testing.T) {

View File

@@ -511,6 +511,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if isExcluded(acc.ID) {
continue
}
// Scheduler snapshots can be temporarily stale (bucket rebuild is throttled);
// re-check schedulability here so recently rate-limited/overloaded accounts
// are not selected again before the bucket is rebuilt.
if !acc.IsSchedulable() {
continue
}
if !s.isAccountAllowedForPlatform(acc, platform, useMixed) {
continue
}
@@ -893,6 +899,11 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
if !acc.IsSchedulable() {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
continue
}
@@ -977,6 +988,11 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
if !acc.IsSchedulable() {
continue
}
// 过滤原生平台直接通过antigravity 需要启用混合调度
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
@@ -1450,6 +1466,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel)
// Capture upstream request body for ops retry of this attempt.
c.Set(OpsUpstreamRequestBodyKey, string(body))
if err != nil {
return nil, err
}
@@ -1466,6 +1485,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
@@ -1490,6 +1510,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "signature_error",
@@ -1541,6 +1562,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: retryResp.StatusCode,
UpstreamRequestID: retryResp.Header.Get("x-request-id"),
Kind: "signature_retry_thinking",
@@ -1569,6 +1591,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "signature_retry_tools_request_error",
Message: sanitizeUpstreamErrorMessage(retryErr2.Error()),
@@ -1627,6 +1650,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry",
@@ -1675,6 +1699,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry_exhausted_failover",
@@ -1741,6 +1766,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover_on_400",
@@ -2618,30 +2644,32 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
if result.ImageSize != "" {
imageSize = &result.ImageSize
}
accountRateMultiplier := account.BillingRateMultiplier()
usageLog := &UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
InputCost: cost.InputCost,
OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost,
ActualCost: cost.ActualCost,
RateMultiplier: multiplier,
BillingType: billingType,
Stream: result.Stream,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
ImageCount: result.ImageCount,
ImageSize: imageSize,
CreatedAt: time.Now(),
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
InputCost: cost.InputCost,
OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost,
ActualCost: cost.ActualCost,
RateMultiplier: multiplier,
AccountRateMultiplier: &accountRateMultiplier,
BillingType: billingType,
Stream: result.Stream,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
ImageCount: result.ImageCount,
ImageSize: imageSize,
CreatedAt: time.Now(),
}
// 添加 UserAgent

View File

@@ -545,12 +545,19 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
requestIDHeader = idHeader
// Capture upstream request body for ops retry of this attempt.
if c != nil {
// In this code path `body` is already the JSON sent to upstream.
c.Set(OpsUpstreamRequestBodyKey, string(body))
}
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
@@ -588,6 +595,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "signature_error",
@@ -662,6 +670,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "retry",
@@ -711,6 +720,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "failover",
@@ -737,6 +747,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "failover",
@@ -972,12 +983,19 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}
requestIDHeader = idHeader
// Capture upstream request body for ops retry of this attempt.
if c != nil {
// In this code path `body` is already the JSON sent to upstream.
c.Set(OpsUpstreamRequestBodyKey, string(body))
}
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
@@ -1036,6 +1054,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "retry",
@@ -1120,6 +1139,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "failover",
@@ -1143,6 +1163,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "failover",
@@ -1168,6 +1189,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "http_error",
@@ -1300,6 +1322,7 @@ func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, acc
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: upstreamStatus,
UpstreamRequestID: upstreamRequestID,
Kind: "http_error",

View File

@@ -10,6 +10,7 @@ type GeminiTokenCache interface {
// cacheKey should be stable for the token scope; for GeminiCli OAuth we primarily use project_id.
GetAccessToken(ctx context.Context, cacheKey string) (string, error)
SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error
DeleteAccessToken(ctx context.Context, cacheKey string) error
AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error)
ReleaseRefreshLock(ctx context.Context, cacheKey string) error

View File

@@ -40,7 +40,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return "", errors.New("not a gemini oauth account")
}
cacheKey := geminiTokenCacheKey(account)
cacheKey := GeminiTokenCacheKey(account)
// 1) Try cache first.
if p.tokenCache != nil {
@@ -151,7 +151,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return accessToken, nil
}
func geminiTokenCacheKey(account *Account) string {
func GeminiTokenCacheKey(account *Account) string {
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" {
return projectID

View File

@@ -186,6 +186,11 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
if !acc.IsSchedulable() {
continue
}
// Check model support
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue
@@ -332,6 +337,12 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if isExcluded(acc.ID) {
continue
}
// Scheduler snapshots can be temporarily stale (bucket rebuild is throttled);
// re-check schedulability here so recently rate-limited/overloaded accounts
// are not selected again before the bucket is rebuilt.
if !acc.IsSchedulable() {
continue
}
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue
}
@@ -653,6 +664,11 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
proxyURL = account.Proxy.URL()
}
// Capture upstream request body for ops retry of this attempt.
if c != nil {
c.Set(OpsUpstreamRequestBodyKey, string(body))
}
// Send request
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
@@ -662,6 +678,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
@@ -696,6 +713,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
@@ -853,6 +871,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "http_error",
@@ -883,6 +902,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: kind,
@@ -1432,28 +1452,30 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
// Create usage log
durationMs := int(result.Duration.Milliseconds())
accountRateMultiplier := account.BillingRateMultiplier()
usageLog := &UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
InputTokens: actualInputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
InputCost: cost.InputCost,
OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost,
ActualCost: cost.ActualCost,
RateMultiplier: multiplier,
BillingType: billingType,
Stream: result.Stream,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
CreatedAt: time.Now(),
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
InputTokens: actualInputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
InputCost: cost.InputCost,
OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost,
ActualCost: cost.ActualCost,
RateMultiplier: multiplier,
AccountRateMultiplier: &accountRateMultiplier,
BillingType: billingType,
Stream: result.Stream,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
CreatedAt: time.Now(),
}
// 添加 UserAgent

View File

@@ -3,6 +3,7 @@ package service
import (
"bufio"
"bytes"
"context"
"errors"
"io"
"net/http"
@@ -15,6 +16,129 @@ import (
"github.com/gin-gonic/gin"
)
type stubOpenAIAccountRepo struct {
AccountRepository
accounts []Account
}
func (r stubOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
return append([]Account(nil), r.accounts...), nil
}
func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
return append([]Account(nil), r.accounts...), nil
}
type stubConcurrencyCache struct {
ConcurrencyCache
}
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil
}
func (c stubConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
return nil
}
func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
out := make(map[int64]*AccountLoadInfo, len(accounts))
for _, acc := range accounts {
out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0}
}
return out, nil
}
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
now := time.Now()
resetAt := now.Add(10 * time.Minute)
groupID := int64(1)
rateLimited := Account{
ID: 1,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
RateLimitResetAt: &resetAt,
}
available := Account{
ID: 2,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 1,
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{rateLimited, available}},
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
}
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-5.2", nil)
if err != nil {
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
}
if selection == nil || selection.Account == nil {
t.Fatalf("expected selection with account")
}
if selection.Account.ID != available.ID {
t.Fatalf("expected account %d, got %d", available.ID, selection.Account.ID)
}
if selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
}
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurrencyService(t *testing.T) {
now := time.Now()
resetAt := now.Add(10 * time.Minute)
groupID := int64(1)
rateLimited := Account{
ID: 1,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
RateLimitResetAt: &resetAt,
}
available := Account{
ID: 2,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 1,
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{rateLimited, available}},
// concurrencyService is nil, forcing the non-load-batch selection path.
}
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-5.2", nil)
if err != nil {
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
}
if selection == nil || selection.Account == nil {
t.Fatalf("expected selection with account")
}
if selection.Account.ID != available.ID {
t.Fatalf("expected account %d, got %d", available.ID, selection.Account.ID)
}
if selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
}
func TestOpenAIStreamingTimeout(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{

View File

@@ -206,7 +206,7 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
continue
}
scopePlatform, scopeGroupID := parseOpsAlertRuleScope(rule.Filters)
scopePlatform, scopeGroupID, scopeRegion := parseOpsAlertRuleScope(rule.Filters)
windowMinutes := rule.WindowMinutes
if windowMinutes <= 0 {
@@ -236,6 +236,17 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
continue
}
// Scoped silencing: if a matching silence exists, skip creating a firing event.
if s.opsService != nil {
platform := strings.TrimSpace(scopePlatform)
region := scopeRegion
if platform != "" {
if ok, err := s.opsService.IsAlertSilenced(ctx, rule.ID, platform, scopeGroupID, region, now); err == nil && ok {
continue
}
}
}
latestEvent, err := s.opsRepo.GetLatestAlertEvent(ctx, rule.ID)
if err != nil {
log.Printf("[OpsAlertEvaluator] get latest event failed (rule=%d): %v", rule.ID, err)
@@ -359,9 +370,9 @@ func requiredSustainedBreaches(sustainedMinutes int, interval time.Duration) int
return required
}
func parseOpsAlertRuleScope(filters map[string]any) (platform string, groupID *int64) {
func parseOpsAlertRuleScope(filters map[string]any) (platform string, groupID *int64, region *string) {
if filters == nil {
return "", nil
return "", nil, nil
}
if v, ok := filters["platform"]; ok {
if s, ok := v.(string); ok {
@@ -392,7 +403,15 @@ func parseOpsAlertRuleScope(filters map[string]any) (platform string, groupID *i
}
}
}
return platform, groupID
if v, ok := filters["region"]; ok {
if s, ok := v.(string); ok {
vv := strings.TrimSpace(s)
if vv != "" {
region = &vv
}
}
}
return platform, groupID, region
}
func (s *OpsAlertEvaluatorService) computeRuleMetric(
@@ -504,16 +523,6 @@ func (s *OpsAlertEvaluatorService) computeRuleMetric(
return 0, false
}
return overview.UpstreamErrorRate * 100, true
case "p95_latency_ms":
if overview.Duration.P95 == nil {
return 0, false
}
return float64(*overview.Duration.P95), true
case "p99_latency_ms":
if overview.Duration.P99 == nil {
return 0, false
}
return float64(*overview.Duration.P99), true
default:
return 0, false
}

View File

@@ -8,8 +8,9 @@ import "time"
// with the existing ops dashboard frontend (backup style).
const (
OpsAlertStatusFiring = "firing"
OpsAlertStatusResolved = "resolved"
OpsAlertStatusFiring = "firing"
OpsAlertStatusResolved = "resolved"
OpsAlertStatusManualResolved = "manual_resolved"
)
type OpsAlertRule struct {
@@ -58,12 +59,32 @@ type OpsAlertEvent struct {
CreatedAt time.Time `json:"created_at"`
}
type OpsAlertSilence struct {
ID int64 `json:"id"`
RuleID int64 `json:"rule_id"`
Platform string `json:"platform"`
GroupID *int64 `json:"group_id,omitempty"`
Region *string `json:"region,omitempty"`
Until time.Time `json:"until"`
Reason string `json:"reason"`
CreatedBy *int64 `json:"created_by,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
type OpsAlertEventFilter struct {
Limit int
// Cursor pagination (descending by fired_at, then id).
BeforeFiredAt *time.Time
BeforeID *int64
// Optional filters.
Status string
Severity string
Status string
Severity string
EmailSent *bool
StartTime *time.Time
EndTime *time.Time

View File

@@ -88,6 +88,29 @@ func (s *OpsService) ListAlertEvents(ctx context.Context, filter *OpsAlertEventF
return s.opsRepo.ListAlertEvents(ctx, filter)
}
func (s *OpsService) GetAlertEventByID(ctx context.Context, eventID int64) (*OpsAlertEvent, error) {
if err := s.RequireMonitoringEnabled(ctx); err != nil {
return nil, err
}
if s.opsRepo == nil {
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
}
if eventID <= 0 {
return nil, infraerrors.BadRequest("INVALID_EVENT_ID", "invalid event id")
}
ev, err := s.opsRepo.GetAlertEventByID(ctx, eventID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, infraerrors.NotFound("OPS_ALERT_EVENT_NOT_FOUND", "alert event not found")
}
return nil, err
}
if ev == nil {
return nil, infraerrors.NotFound("OPS_ALERT_EVENT_NOT_FOUND", "alert event not found")
}
return ev, nil
}
func (s *OpsService) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
if err := s.RequireMonitoringEnabled(ctx); err != nil {
return nil, err
@@ -101,6 +124,49 @@ func (s *OpsService) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*Op
return s.opsRepo.GetActiveAlertEvent(ctx, ruleID)
}
func (s *OpsService) CreateAlertSilence(ctx context.Context, input *OpsAlertSilence) (*OpsAlertSilence, error) {
if err := s.RequireMonitoringEnabled(ctx); err != nil {
return nil, err
}
if s.opsRepo == nil {
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
}
if input == nil {
return nil, infraerrors.BadRequest("INVALID_SILENCE", "invalid silence")
}
if input.RuleID <= 0 {
return nil, infraerrors.BadRequest("INVALID_RULE_ID", "invalid rule id")
}
if strings.TrimSpace(input.Platform) == "" {
return nil, infraerrors.BadRequest("INVALID_PLATFORM", "invalid platform")
}
if input.Until.IsZero() {
return nil, infraerrors.BadRequest("INVALID_UNTIL", "invalid until")
}
created, err := s.opsRepo.CreateAlertSilence(ctx, input)
if err != nil {
return nil, err
}
return created, nil
}
func (s *OpsService) IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error) {
if err := s.RequireMonitoringEnabled(ctx); err != nil {
return false, err
}
if s.opsRepo == nil {
return false, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
}
if ruleID <= 0 {
return false, infraerrors.BadRequest("INVALID_RULE_ID", "invalid rule id")
}
if strings.TrimSpace(platform) == "" {
return false, nil
}
return s.opsRepo.IsAlertSilenced(ctx, ruleID, platform, groupID, region, now)
}
func (s *OpsService) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
if err := s.RequireMonitoringEnabled(ctx); err != nil {
return nil, err
@@ -142,7 +208,11 @@ func (s *OpsService) UpdateAlertEventStatus(ctx context.Context, eventID int64,
if eventID <= 0 {
return infraerrors.BadRequest("INVALID_EVENT_ID", "invalid event id")
}
if strings.TrimSpace(status) == "" {
status = strings.TrimSpace(status)
if status == "" {
return infraerrors.BadRequest("INVALID_STATUS", "invalid status")
}
if status != OpsAlertStatusResolved && status != OpsAlertStatusManualResolved {
return infraerrors.BadRequest("INVALID_STATUS", "invalid status")
}
return s.opsRepo.UpdateAlertEventStatus(ctx, eventID, status, resolvedAt)

View File

@@ -32,49 +32,38 @@ func computeDashboardHealthScore(now time.Time, overview *OpsDashboardOverview)
}
// computeBusinessHealth calculates business health score (0-100)
// Components: SLA (50%) + Error Rate (30%) + Latency (20%)
// Components: Error Rate (50%) + TTFT (50%)
func computeBusinessHealth(overview *OpsDashboardOverview) float64 {
// SLA score: 99.5% → 100, 95% → 0 (linear)
slaScore := 100.0
slaPct := clampFloat64(overview.SLA*100, 0, 100)
if slaPct < 99.5 {
if slaPct >= 95 {
slaScore = (slaPct - 95) / 4.5 * 100
} else {
slaScore = 0
}
}
// Error rate score: 0.5% → 100, 5% → 0 (linear)
// Error rate score: 1% → 100, 10% → 0 (linear)
// Combines request errors and upstream errors
errorScore := 100.0
errorPct := clampFloat64(overview.ErrorRate*100, 0, 100)
upstreamPct := clampFloat64(overview.UpstreamErrorRate*100, 0, 100)
combinedErrorPct := math.Max(errorPct, upstreamPct) // Use worst case
if combinedErrorPct > 0.5 {
if combinedErrorPct <= 5 {
errorScore = (5 - combinedErrorPct) / 4.5 * 100
if combinedErrorPct > 1.0 {
if combinedErrorPct <= 10.0 {
errorScore = (10.0 - combinedErrorPct) / 9.0 * 100
} else {
errorScore = 0
}
}
// Latency score: 1s → 100, 10s → 0 (linear)
// Uses P99 of duration (TTFT is less critical for overall health)
latencyScore := 100.0
if overview.Duration.P99 != nil {
p99 := float64(*overview.Duration.P99)
// TTFT score: 1s → 100, 3s → 0 (linear)
// Time to first token is critical for user experience
ttftScore := 100.0
if overview.TTFT.P99 != nil {
p99 := float64(*overview.TTFT.P99)
if p99 > 1000 {
if p99 <= 10000 {
latencyScore = (10000 - p99) / 9000 * 100
if p99 <= 3000 {
ttftScore = (3000 - p99) / 2000 * 100
} else {
latencyScore = 0
ttftScore = 0
}
}
}
// Weighted combination
return slaScore*0.5 + errorScore*0.3 + latencyScore*0.2
// Weighted combination: 50% error rate + 50% TTFT
return errorScore*0.5 + ttftScore*0.5
}
// computeInfraHealth calculates infrastructure health score (0-100)

View File

@@ -127,8 +127,8 @@ func TestComputeDashboardHealthScore_Comprehensive(t *testing.T) {
MemoryUsagePercent: float64Ptr(75),
},
},
wantMin: 60,
wantMax: 85,
wantMin: 96,
wantMax: 97,
},
{
name: "DB failure",
@@ -203,8 +203,8 @@ func TestComputeDashboardHealthScore_Comprehensive(t *testing.T) {
MemoryUsagePercent: float64Ptr(30),
},
},
wantMin: 25,
wantMax: 50,
wantMin: 84,
wantMax: 85,
},
{
name: "combined failures - business healthy + infra degraded",
@@ -277,30 +277,41 @@ func TestComputeBusinessHealth(t *testing.T) {
UpstreamErrorRate: 0,
Duration: OpsPercentiles{P99: intPtr(500)},
},
wantMin: 50,
wantMax: 60,
wantMin: 100,
wantMax: 100,
},
{
name: "error rate boundary 0.5%",
name: "error rate boundary 1%",
overview: &OpsDashboardOverview{
SLA: 0.995,
ErrorRate: 0.005,
SLA: 0.99,
ErrorRate: 0.01,
UpstreamErrorRate: 0,
Duration: OpsPercentiles{P99: intPtr(500)},
},
wantMin: 95,
wantMin: 100,
wantMax: 100,
},
{
name: "latency boundary 1000ms",
name: "error rate 5%",
overview: &OpsDashboardOverview{
SLA: 0.995,
SLA: 0.95,
ErrorRate: 0.05,
UpstreamErrorRate: 0,
Duration: OpsPercentiles{P99: intPtr(500)},
},
wantMin: 77,
wantMax: 78,
},
{
name: "TTFT boundary 2s",
overview: &OpsDashboardOverview{
SLA: 0.99,
ErrorRate: 0,
UpstreamErrorRate: 0,
Duration: OpsPercentiles{P99: intPtr(1000)},
TTFT: OpsPercentiles{P99: intPtr(2000)},
},
wantMin: 95,
wantMax: 100,
wantMin: 75,
wantMax: 75,
},
{
name: "upstream error dominates",
@@ -310,7 +321,7 @@ func TestComputeBusinessHealth(t *testing.T) {
UpstreamErrorRate: 0.03,
Duration: OpsPercentiles{P99: intPtr(500)},
},
wantMin: 75,
wantMin: 88,
wantMax: 90,
},
}

View File

@@ -6,24 +6,43 @@ type OpsErrorLog struct {
ID int64 `json:"id"`
CreatedAt time.Time `json:"created_at"`
Phase string `json:"phase"`
Type string `json:"type"`
// Standardized classification
// - phase: request|auth|routing|upstream|network|internal
// - owner: client|provider|platform
// - source: client_request|upstream_http|gateway
Phase string `json:"phase"`
Type string `json:"type"`
Owner string `json:"error_owner"`
Source string `json:"error_source"`
Severity string `json:"severity"`
StatusCode int `json:"status_code"`
Platform string `json:"platform"`
Model string `json:"model"`
LatencyMs *int `json:"latency_ms"`
IsRetryable bool `json:"is_retryable"`
RetryCount int `json:"retry_count"`
Resolved bool `json:"resolved"`
ResolvedAt *time.Time `json:"resolved_at"`
ResolvedByUserID *int64 `json:"resolved_by_user_id"`
ResolvedByUserName string `json:"resolved_by_user_name"`
ResolvedRetryID *int64 `json:"resolved_retry_id"`
ResolvedStatusRaw string `json:"-"`
ClientRequestID string `json:"client_request_id"`
RequestID string `json:"request_id"`
Message string `json:"message"`
UserID *int64 `json:"user_id"`
APIKeyID *int64 `json:"api_key_id"`
AccountID *int64 `json:"account_id"`
GroupID *int64 `json:"group_id"`
UserID *int64 `json:"user_id"`
UserEmail string `json:"user_email"`
APIKeyID *int64 `json:"api_key_id"`
AccountID *int64 `json:"account_id"`
AccountName string `json:"account_name"`
GroupID *int64 `json:"group_id"`
GroupName string `json:"group_name"`
ClientIP *string `json:"client_ip"`
RequestPath string `json:"request_path"`
@@ -67,9 +86,24 @@ type OpsErrorLogFilter struct {
GroupID *int64
AccountID *int64
StatusCodes []int
Phase string
Query string
StatusCodes []int
StatusCodesOther bool
Phase string
Owner string
Source string
Resolved *bool
Query string
UserQuery string // Search by user email
// Optional correlation keys for exact matching.
RequestID string
ClientRequestID string
// View controls error categorization for list endpoints.
// - errors: show actionable errors (exclude business-limited / 429 / 529)
// - excluded: only show excluded errors
// - all: show everything
View string
Page int
PageSize int
@@ -90,12 +124,23 @@ type OpsRetryAttempt struct {
SourceErrorID int64 `json:"source_error_id"`
Mode string `json:"mode"`
PinnedAccountID *int64 `json:"pinned_account_id"`
PinnedAccountName string `json:"pinned_account_name"`
Status string `json:"status"`
StartedAt *time.Time `json:"started_at"`
FinishedAt *time.Time `json:"finished_at"`
DurationMs *int64 `json:"duration_ms"`
// Persisted execution results (best-effort)
Success *bool `json:"success"`
HTTPStatusCode *int `json:"http_status_code"`
UpstreamRequestID *string `json:"upstream_request_id"`
UsedAccountID *int64 `json:"used_account_id"`
UsedAccountName string `json:"used_account_name"`
ResponsePreview *string `json:"response_preview"`
ResponseTruncated *bool `json:"response_truncated"`
// Optional correlation
ResultRequestID *string `json:"result_request_id"`
ResultErrorID *int64 `json:"result_error_id"`

View File

@@ -14,6 +14,8 @@ type OpsRepository interface {
InsertRetryAttempt(ctx context.Context, input *OpsInsertRetryAttemptInput) (int64, error)
UpdateRetryAttempt(ctx context.Context, input *OpsUpdateRetryAttemptInput) error
GetLatestRetryAttemptForError(ctx context.Context, sourceErrorID int64) (*OpsRetryAttempt, error)
ListRetryAttemptsByErrorID(ctx context.Context, sourceErrorID int64, limit int) ([]*OpsRetryAttempt, error)
UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedRetryID *int64, resolvedAt *time.Time) error
// Lightweight window stats (for realtime WS / quick sampling).
GetWindowStats(ctx context.Context, filter *OpsDashboardFilter) (*OpsWindowStats, error)
@@ -39,12 +41,17 @@ type OpsRepository interface {
DeleteAlertRule(ctx context.Context, id int64) error
ListAlertEvents(ctx context.Context, filter *OpsAlertEventFilter) ([]*OpsAlertEvent, error)
GetAlertEventByID(ctx context.Context, eventID int64) (*OpsAlertEvent, error)
GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error)
GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error)
CreateAlertEvent(ctx context.Context, event *OpsAlertEvent) (*OpsAlertEvent, error)
UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error
UpdateAlertEventEmailSent(ctx context.Context, eventID int64, emailSent bool) error
// Alert silences
CreateAlertSilence(ctx context.Context, input *OpsAlertSilence) (*OpsAlertSilence, error)
IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error)
// Pre-aggregation (hourly/daily) used for long-window dashboard performance.
UpsertHourlyMetrics(ctx context.Context, startTime, endTime time.Time) error
UpsertDailyMetrics(ctx context.Context, startTime, endTime time.Time) error
@@ -91,7 +98,6 @@ type OpsInsertErrorLogInput struct {
// It is set by OpsService.RecordError before persisting.
UpstreamErrorsJSON *string
DurationMs *int
TimeToFirstTokenMs *int64
RequestBodyJSON *string // sanitized json string (not raw bytes)
@@ -124,7 +130,15 @@ type OpsUpdateRetryAttemptInput struct {
FinishedAt time.Time
DurationMs int64
// Optional correlation
// Persisted execution results (best-effort)
Success *bool
HTTPStatusCode *int
UpstreamRequestID *string
UsedAccountID *int64
ResponsePreview *string
ResponseTruncated *bool
// Optional correlation (legacy fields kept)
ResultRequestID *string
ResultErrorID *int64

View File

@@ -108,6 +108,10 @@ func (w *limitedResponseWriter) truncated() bool {
return w.totalWritten > int64(w.limit)
}
const (
OpsRetryModeUpstreamEvent = "upstream_event"
)
func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, errorID int64, mode string, pinnedAccountID *int64) (*OpsRetryResult, error) {
if err := s.RequireMonitoringEnabled(ctx); err != nil {
return nil, err
@@ -123,6 +127,81 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
return nil, infraerrors.BadRequest("OPS_RETRY_INVALID_MODE", "mode must be client or upstream")
}
errorLog, err := s.GetErrorLogByID(ctx, errorID)
if err != nil {
return nil, err
}
if errorLog == nil {
return nil, infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
}
if strings.TrimSpace(errorLog.RequestBody) == "" {
return nil, infraerrors.BadRequest("OPS_RETRY_NO_REQUEST_BODY", "No request body found to retry")
}
var pinned *int64
if mode == OpsRetryModeUpstream {
if pinnedAccountID != nil && *pinnedAccountID > 0 {
pinned = pinnedAccountID
} else if errorLog.AccountID != nil && *errorLog.AccountID > 0 {
pinned = errorLog.AccountID
} else {
return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "pinned_account_id is required for upstream retry")
}
}
return s.retryWithErrorLog(ctx, requestedByUserID, errorID, mode, mode, pinned, errorLog)
}
// RetryUpstreamEvent retries a specific upstream attempt captured inside ops_error_logs.upstream_errors.
// idx is 0-based. It always pins the original event account_id.
func (s *OpsService) RetryUpstreamEvent(ctx context.Context, requestedByUserID int64, errorID int64, idx int) (*OpsRetryResult, error) {
if err := s.RequireMonitoringEnabled(ctx); err != nil {
return nil, err
}
if s.opsRepo == nil {
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
}
if idx < 0 {
return nil, infraerrors.BadRequest("OPS_RETRY_INVALID_UPSTREAM_IDX", "invalid upstream idx")
}
errorLog, err := s.GetErrorLogByID(ctx, errorID)
if err != nil {
return nil, err
}
if errorLog == nil {
return nil, infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
}
events, err := ParseOpsUpstreamErrors(errorLog.UpstreamErrors)
if err != nil {
return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_EVENTS_INVALID", "invalid upstream_errors")
}
if idx >= len(events) {
return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_IDX_OOB", "upstream idx out of range")
}
ev := events[idx]
if ev == nil {
return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_EVENT_MISSING", "upstream event missing")
}
if ev.AccountID <= 0 {
return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "account_id is required for upstream retry")
}
upstreamBody := strings.TrimSpace(ev.UpstreamRequestBody)
if upstreamBody == "" {
return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_NO_REQUEST_BODY", "No upstream request body found to retry")
}
override := *errorLog
override.RequestBody = upstreamBody
pinned := ev.AccountID
// Persist as upstream_event, execute as upstream pinned retry.
return s.retryWithErrorLog(ctx, requestedByUserID, errorID, OpsRetryModeUpstreamEvent, OpsRetryModeUpstream, &pinned, &override)
}
func (s *OpsService) retryWithErrorLog(ctx context.Context, requestedByUserID int64, errorID int64, mode string, execMode string, pinnedAccountID *int64, errorLog *OpsErrorLogDetail) (*OpsRetryResult, error) {
latest, err := s.opsRepo.GetLatestRetryAttemptForError(ctx, errorID)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, infraerrors.InternalServer("OPS_RETRY_LOAD_LATEST_FAILED", "Failed to check retry status").WithCause(err)
@@ -144,22 +223,18 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
}
}
errorLog, err := s.GetErrorLogByID(ctx, errorID)
if err != nil {
return nil, err
}
if strings.TrimSpace(errorLog.RequestBody) == "" {
if errorLog == nil || strings.TrimSpace(errorLog.RequestBody) == "" {
return nil, infraerrors.BadRequest("OPS_RETRY_NO_REQUEST_BODY", "No request body found to retry")
}
var pinned *int64
if mode == OpsRetryModeUpstream {
if execMode == OpsRetryModeUpstream {
if pinnedAccountID != nil && *pinnedAccountID > 0 {
pinned = pinnedAccountID
} else if errorLog.AccountID != nil && *errorLog.AccountID > 0 {
pinned = errorLog.AccountID
} else {
return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "pinned_account_id is required for upstream retry")
return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "account_id is required for upstream retry")
}
}
@@ -196,7 +271,7 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
execCtx, cancel := context.WithTimeout(ctx, opsRetryTimeout)
defer cancel()
execRes := s.executeRetry(execCtx, errorLog, mode, pinned)
execRes := s.executeRetry(execCtx, errorLog, execMode, pinned)
finishedAt := time.Now()
result.FinishedAt = finishedAt
@@ -220,27 +295,40 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
msg := result.ErrorMessage
updateErrMsg = &msg
}
// Keep legacy result_request_id empty; use upstream_request_id instead.
var resultRequestID *string
if strings.TrimSpace(result.UpstreamRequestID) != "" {
v := result.UpstreamRequestID
resultRequestID = &v
}
finalStatus := result.Status
if strings.TrimSpace(finalStatus) == "" {
finalStatus = opsRetryStatusFailed
}
success := strings.EqualFold(finalStatus, opsRetryStatusSucceeded)
httpStatus := result.HTTPStatusCode
upstreamReqID := result.UpstreamRequestID
usedAccountID := result.UsedAccountID
preview := result.ResponsePreview
truncated := result.ResponseTruncated
if err := s.opsRepo.UpdateRetryAttempt(updateCtx, &OpsUpdateRetryAttemptInput{
ID: attemptID,
Status: finalStatus,
FinishedAt: finishedAt,
DurationMs: result.DurationMs,
ResultRequestID: resultRequestID,
ErrorMessage: updateErrMsg,
ID: attemptID,
Status: finalStatus,
FinishedAt: finishedAt,
DurationMs: result.DurationMs,
Success: &success,
HTTPStatusCode: &httpStatus,
UpstreamRequestID: &upstreamReqID,
UsedAccountID: usedAccountID,
ResponsePreview: &preview,
ResponseTruncated: &truncated,
ResultRequestID: resultRequestID,
ErrorMessage: updateErrMsg,
}); err != nil {
// Best-effort: retry itself already executed; do not fail the API response.
log.Printf("[Ops] UpdateRetryAttempt failed: %v", err)
} else if success {
if err := s.opsRepo.UpdateErrorResolution(updateCtx, errorID, true, &requestedByUserID, &attemptID, &finishedAt); err != nil {
log.Printf("[Ops] UpdateErrorResolution failed: %v", err)
}
}
return result, nil

View File

@@ -208,6 +208,25 @@ func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogIn
out.Detail = ""
}
out.UpstreamRequestBody = strings.TrimSpace(out.UpstreamRequestBody)
if out.UpstreamRequestBody != "" {
// Reuse the same sanitization/trimming strategy as request body storage.
// Keep it small so it is safe to persist in ops_error_logs JSON.
sanitized, truncated, _ := sanitizeAndTrimRequestBody([]byte(out.UpstreamRequestBody), 10*1024)
if sanitized != "" {
out.UpstreamRequestBody = sanitized
if truncated {
out.Kind = strings.TrimSpace(out.Kind)
if out.Kind == "" {
out.Kind = "upstream"
}
out.Kind = out.Kind + ":request_body_truncated"
}
} else {
out.UpstreamRequestBody = ""
}
}
// Drop fully-empty events (can happen if only status code was known).
if out.UpstreamStatusCode == 0 && out.Message == "" && out.Detail == "" {
continue
@@ -236,7 +255,13 @@ func (s *OpsService) GetErrorLogs(ctx context.Context, filter *OpsErrorLogFilter
if s.opsRepo == nil {
return &OpsErrorLogList{Errors: []*OpsErrorLog{}, Total: 0, Page: 1, PageSize: 20}, nil
}
return s.opsRepo.ListErrorLogs(ctx, filter)
result, err := s.opsRepo.ListErrorLogs(ctx, filter)
if err != nil {
log.Printf("[Ops] GetErrorLogs failed: %v", err)
return nil, err
}
return result, nil
}
func (s *OpsService) GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLogDetail, error) {
@@ -256,6 +281,46 @@ func (s *OpsService) GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLo
return detail, nil
}
func (s *OpsService) ListRetryAttemptsByErrorID(ctx context.Context, errorID int64, limit int) ([]*OpsRetryAttempt, error) {
if err := s.RequireMonitoringEnabled(ctx); err != nil {
return nil, err
}
if s.opsRepo == nil {
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
}
if errorID <= 0 {
return nil, infraerrors.BadRequest("OPS_ERROR_INVALID_ID", "invalid error id")
}
items, err := s.opsRepo.ListRetryAttemptsByErrorID(ctx, errorID, limit)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return []*OpsRetryAttempt{}, nil
}
return nil, infraerrors.InternalServer("OPS_RETRY_LIST_FAILED", "Failed to list retry attempts").WithCause(err)
}
return items, nil
}
func (s *OpsService) UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedRetryID *int64) error {
if err := s.RequireMonitoringEnabled(ctx); err != nil {
return err
}
if s.opsRepo == nil {
return infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
}
if errorID <= 0 {
return infraerrors.BadRequest("OPS_ERROR_INVALID_ID", "invalid error id")
}
// Best-effort ensure the error exists
if _, err := s.opsRepo.GetErrorLogByID(ctx, errorID); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
}
return infraerrors.InternalServer("OPS_ERROR_LOAD_FAILED", "Failed to load ops error log").WithCause(err)
}
return s.opsRepo.UpdateErrorResolution(ctx, errorID, resolved, resolvedByUserID, resolvedRetryID, nil)
}
func sanitizeAndTrimRequestBody(raw []byte, maxBytes int) (jsonString string, truncated bool, bytesLen int) {
bytesLen = len(raw)
if len(raw) == 0 {
@@ -296,14 +361,34 @@ func sanitizeAndTrimRequestBody(raw []byte, maxBytes int) (jsonString string, tr
}
}
// Last resort: store a minimal placeholder (still valid JSON).
placeholder := map[string]any{
"request_body_truncated": true,
// Last resort: keep JSON shape but drop big fields.
// This avoids downstream code that expects certain top-level keys from crashing.
if root, ok := decoded.(map[string]any); ok {
placeholder := shallowCopyMap(root)
placeholder["request_body_truncated"] = true
// Replace potentially huge arrays/strings, but keep the keys present.
for _, k := range []string{"messages", "contents", "input", "prompt"} {
if _, exists := placeholder[k]; exists {
placeholder[k] = []any{}
}
}
for _, k := range []string{"text"} {
if _, exists := placeholder[k]; exists {
placeholder[k] = ""
}
}
encoded4, err4 := json.Marshal(placeholder)
if err4 == nil {
if len(encoded4) <= maxBytes {
return string(encoded4), true, bytesLen
}
}
}
if model := extractString(decoded, "model"); model != "" {
placeholder["model"] = model
}
encoded4, err4 := json.Marshal(placeholder)
// Final fallback: minimal valid JSON.
encoded4, err4 := json.Marshal(map[string]any{"request_body_truncated": true})
if err4 != nil {
return "", true, bytesLen
}
@@ -526,12 +611,3 @@ func sanitizeErrorBodyForStorage(raw string, maxBytes int) (sanitized string, tr
}
return raw, false
}
func extractString(v any, key string) string {
root, ok := v.(map[string]any)
if !ok {
return ""
}
s, _ := root[key].(string)
return strings.TrimSpace(s)
}

View File

@@ -368,9 +368,11 @@ func defaultOpsAdvancedSettings() *OpsAdvancedSettings {
Aggregation: OpsAggregationSettings{
AggregationEnabled: false,
},
IgnoreCountTokensErrors: false,
AutoRefreshEnabled: false,
AutoRefreshIntervalSec: 30,
IgnoreCountTokensErrors: false,
IgnoreContextCanceled: true, // Default to true - client disconnects are not errors
IgnoreNoAvailableAccounts: false, // Default to false - this is a real routing issue
AutoRefreshEnabled: false,
AutoRefreshIntervalSec: 30,
}
}
@@ -482,13 +484,11 @@ const SettingKeyOpsMetricThresholds = "ops_metric_thresholds"
func defaultOpsMetricThresholds() *OpsMetricThresholds {
slaMin := 99.5
latencyMax := 2000.0
ttftMax := 500.0
reqErrMax := 5.0
upstreamErrMax := 5.0
return &OpsMetricThresholds{
SLAPercentMin: &slaMin,
LatencyP99MsMax: &latencyMax,
TTFTp99MsMax: &ttftMax,
RequestErrorRatePercentMax: &reqErrMax,
UpstreamErrorRatePercentMax: &upstreamErrMax,
@@ -538,9 +538,6 @@ func (s *OpsService) UpdateMetricThresholds(ctx context.Context, cfg *OpsMetricT
if cfg.SLAPercentMin != nil && (*cfg.SLAPercentMin < 0 || *cfg.SLAPercentMin > 100) {
return nil, errors.New("sla_percent_min must be between 0 and 100")
}
if cfg.LatencyP99MsMax != nil && *cfg.LatencyP99MsMax < 0 {
return nil, errors.New("latency_p99_ms_max must be >= 0")
}
if cfg.TTFTp99MsMax != nil && *cfg.TTFTp99MsMax < 0 {
return nil, errors.New("ttft_p99_ms_max must be >= 0")
}

View File

@@ -63,7 +63,6 @@ type OpsAlertSilencingSettings struct {
type OpsMetricThresholds struct {
SLAPercentMin *float64 `json:"sla_percent_min,omitempty"` // SLA低于此值变红
LatencyP99MsMax *float64 `json:"latency_p99_ms_max,omitempty"` // 延迟P99高于此值变红
TTFTp99MsMax *float64 `json:"ttft_p99_ms_max,omitempty"` // TTFT P99高于此值变红
RequestErrorRatePercentMax *float64 `json:"request_error_rate_percent_max,omitempty"` // 请求错误率高于此值变红
UpstreamErrorRatePercentMax *float64 `json:"upstream_error_rate_percent_max,omitempty"` // 上游错误率高于此值变红
@@ -79,11 +78,13 @@ type OpsAlertRuntimeSettings struct {
// OpsAdvancedSettings stores advanced ops configuration (data retention, aggregation).
type OpsAdvancedSettings struct {
DataRetention OpsDataRetentionSettings `json:"data_retention"`
Aggregation OpsAggregationSettings `json:"aggregation"`
IgnoreCountTokensErrors bool `json:"ignore_count_tokens_errors"`
AutoRefreshEnabled bool `json:"auto_refresh_enabled"`
AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"`
DataRetention OpsDataRetentionSettings `json:"data_retention"`
Aggregation OpsAggregationSettings `json:"aggregation"`
IgnoreCountTokensErrors bool `json:"ignore_count_tokens_errors"`
IgnoreContextCanceled bool `json:"ignore_context_canceled"`
IgnoreNoAvailableAccounts bool `json:"ignore_no_available_accounts"`
AutoRefreshEnabled bool `json:"auto_refresh_enabled"`
AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"`
}
type OpsDataRetentionSettings struct {

View File

@@ -15,6 +15,11 @@ const (
OpsUpstreamErrorMessageKey = "ops_upstream_error_message"
OpsUpstreamErrorDetailKey = "ops_upstream_error_detail"
OpsUpstreamErrorsKey = "ops_upstream_errors"
// Best-effort capture of the current upstream request body so ops can
// retry the specific upstream attempt (not just the client request).
// This value is sanitized+trimmed before being persisted.
OpsUpstreamRequestBodyKey = "ops_upstream_request_body"
)
func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) {
@@ -38,13 +43,21 @@ type OpsUpstreamErrorEvent struct {
AtUnixMs int64 `json:"at_unix_ms,omitempty"`
// Context
Platform string `json:"platform,omitempty"`
AccountID int64 `json:"account_id,omitempty"`
Platform string `json:"platform,omitempty"`
AccountID int64 `json:"account_id,omitempty"`
AccountName string `json:"account_name,omitempty"`
// Outcome
UpstreamStatusCode int `json:"upstream_status_code,omitempty"`
UpstreamRequestID string `json:"upstream_request_id,omitempty"`
// Best-effort upstream request capture (sanitized+trimmed).
// Required for retrying a specific upstream attempt.
UpstreamRequestBody string `json:"upstream_request_body,omitempty"`
// Best-effort upstream response capture (sanitized+trimmed).
UpstreamResponseBody string `json:"upstream_response_body,omitempty"`
// Kind: http_error | request_error | retry_exhausted | failover
Kind string `json:"kind,omitempty"`
@@ -61,6 +74,8 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
}
ev.Platform = strings.TrimSpace(ev.Platform)
ev.UpstreamRequestID = strings.TrimSpace(ev.UpstreamRequestID)
ev.UpstreamRequestBody = strings.TrimSpace(ev.UpstreamRequestBody)
ev.UpstreamResponseBody = strings.TrimSpace(ev.UpstreamResponseBody)
ev.Kind = strings.TrimSpace(ev.Kind)
ev.Message = strings.TrimSpace(ev.Message)
ev.Detail = strings.TrimSpace(ev.Detail)
@@ -68,6 +83,16 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
ev.Message = sanitizeUpstreamErrorMessage(ev.Message)
}
// If the caller didn't explicitly pass upstream request body but the gateway
// stored it on the context, attach it so ops can retry this specific attempt.
if ev.UpstreamRequestBody == "" {
if v, ok := c.Get(OpsUpstreamRequestBodyKey); ok {
if s, ok := v.(string); ok {
ev.UpstreamRequestBody = strings.TrimSpace(s)
}
}
}
var existing []*OpsUpstreamErrorEvent
if v, ok := c.Get(OpsUpstreamErrorsKey); ok {
if arr, ok := v.([]*OpsUpstreamErrorEvent); ok {
@@ -92,3 +117,15 @@ func marshalOpsUpstreamErrors(events []*OpsUpstreamErrorEvent) *string {
s := string(raw)
return &s
}
func ParseOpsUpstreamErrors(raw string) ([]*OpsUpstreamErrorEvent, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return []*OpsUpstreamErrorEvent{}, nil
}
var out []*OpsUpstreamErrorEvent
if err := json.Unmarshal([]byte(raw), &out); err != nil {
return nil, err
}
return out, nil
}

View File

@@ -31,5 +31,21 @@ func (p *Proxy) URL() string {
type ProxyWithAccountCount struct {
Proxy
AccountCount int64
AccountCount int64
LatencyMs *int64
LatencyStatus string
LatencyMessage string
IPAddress string
Country string
CountryCode string
Region string
City string
}
type ProxyAccountSummary struct {
ID int64
Name string
Platform string
Type string
Notes *string
}

View File

@@ -0,0 +1,23 @@
package service
import (
"context"
"time"
)
type ProxyLatencyInfo struct {
Success bool `json:"success"`
LatencyMs *int64 `json:"latency_ms,omitempty"`
Message string `json:"message,omitempty"`
IPAddress string `json:"ip_address,omitempty"`
Country string `json:"country,omitempty"`
CountryCode string `json:"country_code,omitempty"`
Region string `json:"region,omitempty"`
City string `json:"city,omitempty"`
UpdatedAt time.Time `json:"updated_at"`
}
type ProxyLatencyCache interface {
GetProxyLatencies(ctx context.Context, proxyIDs []int64) (map[int64]*ProxyLatencyInfo, error)
SetProxyLatency(ctx context.Context, proxyID int64, info *ProxyLatencyInfo) error
}

View File

@@ -10,6 +10,7 @@ import (
var (
ErrProxyNotFound = infraerrors.NotFound("PROXY_NOT_FOUND", "proxy not found")
ErrProxyInUse = infraerrors.Conflict("PROXY_IN_USE", "proxy is in use by accounts")
)
type ProxyRepository interface {
@@ -26,6 +27,7 @@ type ProxyRepository interface {
ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error)
CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error)
ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error)
}
// CreateProxyRequest 创建代理请求

View File

@@ -3,7 +3,7 @@ package service
import (
"context"
"encoding/json"
"log"
"log/slog"
"net/http"
"strconv"
"strings"
@@ -15,15 +15,16 @@ import (
// RateLimitService 处理限流和过载状态管理
type RateLimitService struct {
accountRepo AccountRepository
usageRepo UsageLogRepository
cfg *config.Config
geminiQuotaService *GeminiQuotaService
tempUnschedCache TempUnschedCache
timeoutCounterCache TimeoutCounterCache
settingService *SettingService
usageCacheMu sync.RWMutex
usageCache map[int64]*geminiUsageCacheEntry
accountRepo AccountRepository
usageRepo UsageLogRepository
cfg *config.Config
geminiQuotaService *GeminiQuotaService
tempUnschedCache TempUnschedCache
timeoutCounterCache TimeoutCounterCache
settingService *SettingService
tokenCacheInvalidator TokenCacheInvalidator
usageCacheMu sync.RWMutex
usageCache map[int64]*geminiUsageCacheEntry
}
type geminiUsageCacheEntry struct {
@@ -56,6 +57,11 @@ func (s *RateLimitService) SetSettingService(settingService *SettingService) {
s.settingService = settingService
}
// SetTokenCacheInvalidator 设置 token 缓存清理器(可选依赖)
func (s *RateLimitService) SetTokenCacheInvalidator(invalidator TokenCacheInvalidator) {
s.tokenCacheInvalidator = invalidator
}
// HandleUpstreamError 处理上游错误响应,标记账号状态
// 返回是否应该停止该账号的调度
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
@@ -63,11 +69,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
// 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载)
customErrorCodesEnabled := account.IsCustomErrorCodesEnabled()
if !account.ShouldHandleErrorCode(statusCode) {
log.Printf("Account %d: error %d skipped (not in custom error codes)", account.ID, statusCode)
slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode)
return false
}
tempMatched := s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
tempMatched := false
if statusCode != 401 {
tempMatched = s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
}
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(responseBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if upstreamMsg != "" {
@@ -76,7 +85,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
switch statusCode {
case 401:
// 认证失败:停止调度,记录错误
if account.Type == AccountTypeOAuth &&
(account.Platform == PlatformAntigravity || account.Platform == PlatformGemini) {
if s.tokenCacheInvalidator != nil {
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err)
}
}
}
msg := "Authentication failed (401): invalid or expired credentials"
if upstreamMsg != "" {
msg = "Authentication failed (401): " + upstreamMsg
@@ -116,7 +132,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
shouldDisable = true
} else if statusCode >= 500 {
// 未启用自定义错误码时仅记录5xx错误
log.Printf("Account %d received upstream error %d", account.ID, statusCode)
slog.Warn("account_upstream_error", "account_id", account.ID, "status_code", statusCode)
shouldDisable = false
}
}
@@ -163,7 +179,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
start := geminiDailyWindowStart(now)
totals, ok := s.getGeminiUsageTotals(account.ID, start, now)
if !ok {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil)
if err != nil {
return true, err
}
@@ -188,7 +204,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
// NOTE:
// - This is a local precheck to reduce upstream 429s.
// - Do NOT mark the account as rate-limited here; rate_limit_reset_at should reflect real upstream 429s.
log.Printf("[Gemini PreCheck] Account %d reached daily quota (%d/%d), skip until %v", account.ID, used, limit, resetAt)
slog.Info("gemini_precheck_daily_quota_reached", "account_id", account.ID, "used", used, "limit", limit, "reset_at", resetAt)
return false, nil
}
}
@@ -210,7 +226,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
if limit > 0 {
start := now.Truncate(time.Minute)
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil)
if err != nil {
return true, err
}
@@ -231,7 +247,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
if used >= limit {
resetAt := start.Add(time.Minute)
// Do not persist "rate limited" status from local precheck. See note above.
log.Printf("[Gemini PreCheck] Account %d reached minute quota (%d/%d), skip until %v", account.ID, used, limit, resetAt)
slog.Info("gemini_precheck_minute_quota_reached", "account_id", account.ID, "used", used, "limit", limit, "reset_at", resetAt)
return false, nil
}
}
@@ -288,20 +304,20 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account)
// handleAuthError 处理认证类错误(401/403),停止账号调度
func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) {
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
log.Printf("SetError failed for account %d: %v", account.ID, err)
slog.Warn("account_set_error_failed", "account_id", account.ID, "error", err)
return
}
log.Printf("Account %d disabled due to auth error: %s", account.ID, errorMsg)
slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg)
}
// handleCustomErrorCode 处理自定义错误码,停止账号调度
func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) {
msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg
if err := s.accountRepo.SetError(ctx, account.ID, msg); err != nil {
log.Printf("SetError failed for account %d: %v", account.ID, err)
slog.Warn("account_set_error_failed", "account_id", account.ID, "status_code", statusCode, "error", err)
return
}
log.Printf("Account %d disabled due to custom error code %d: %s", account.ID, statusCode, errorMsg)
slog.Warn("account_disabled_custom_error", "account_id", account.ID, "status_code", statusCode, "error", errorMsg)
}
// handle429 处理429限流错误
@@ -313,7 +329,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
// 没有重置时间使用默认5分钟
resetAt := time.Now().Add(5 * time.Minute)
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
}
return
}
@@ -321,10 +337,10 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
// 解析Unix时间戳
ts, err := strconv.ParseInt(resetTimestamp, 10, 64)
if err != nil {
log.Printf("Parse reset timestamp failed: %v", err)
slog.Warn("rate_limit_reset_parse_failed", "reset_timestamp", resetTimestamp, "error", err)
resetAt := time.Now().Add(5 * time.Minute)
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
}
return
}
@@ -333,7 +349,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
// 标记限流状态
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return
}
@@ -341,10 +357,10 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
windowEnd := resetAt
windowStart := resetAt.Add(-5 * time.Hour)
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
slog.Warn("rate_limit_update_session_window_failed", "account_id", account.ID, "error", err)
}
log.Printf("Account %d rate limited until %v", account.ID, resetAt)
slog.Info("account_rate_limited", "account_id", account.ID, "reset_at", resetAt)
}
// handle529 处理529过载错误
@@ -357,11 +373,11 @@ func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil {
log.Printf("SetOverloaded failed for account %d: %v", account.ID, err)
slog.Warn("overload_set_failed", "account_id", account.ID, "error", err)
return
}
log.Printf("Account %d overloaded until %v", account.ID, until)
slog.Info("account_overloaded", "account_id", account.ID, "until", until)
}
// UpdateSessionWindow 从成功响应更新5h窗口状态
@@ -384,17 +400,17 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
end := start.Add(5 * time.Hour)
windowStart = &start
windowEnd = &end
log.Printf("Account %d: initializing 5h window from %v to %v (status: %s)", account.ID, start, end, status)
slog.Info("account_session_window_initialized", "account_id", account.ID, "window_start", start, "window_end", end, "status", status)
}
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil {
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
slog.Warn("session_window_update_failed", "account_id", account.ID, "error", err)
}
// 如果状态为allowed且之前有限流说明窗口已重置清除限流状态
if status == "allowed" && account.IsRateLimited() {
if err := s.ClearRateLimit(ctx, account.ID); err != nil {
log.Printf("ClearRateLimit failed for account %d: %v", account.ID, err)
slog.Warn("rate_limit_clear_failed", "account_id", account.ID, "error", err)
}
}
}
@@ -413,7 +429,7 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID
}
if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.DeleteTempUnsched(ctx, accountID); err != nil {
log.Printf("DeleteTempUnsched failed for account %d: %v", accountID, err)
slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err)
}
}
return nil
@@ -460,7 +476,7 @@ func (s *RateLimitService) GetTempUnschedStatus(ctx context.Context, accountID i
if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.SetTempUnsched(ctx, accountID, state); err != nil {
log.Printf("SetTempUnsched failed for account %d: %v", accountID, err)
slog.Warn("temp_unsched_cache_set_failed", "account_id", accountID, "error", err)
}
}
@@ -563,17 +579,17 @@ func (s *RateLimitService) triggerTempUnschedulable(ctx context.Context, account
}
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
log.Printf("SetTempUnschedulable failed for account %d: %v", account.ID, err)
slog.Warn("temp_unsched_set_failed", "account_id", account.ID, "error", err)
return false
}
if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
log.Printf("SetTempUnsched cache failed for account %d: %v", account.ID, err)
slog.Warn("temp_unsched_cache_set_failed", "account_id", account.ID, "error", err)
}
}
log.Printf("Account %d temp unschedulable until %v (rule %d, code %d)", account.ID, until, ruleIndex, statusCode)
slog.Info("account_temp_unschedulable", "account_id", account.ID, "until", until, "rule_index", ruleIndex, "status_code", statusCode)
return true
}
@@ -597,13 +613,13 @@ func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Acc
// 获取系统设置
if s.settingService == nil {
log.Printf("[StreamTimeout] settingService not configured, skipping timeout handling for account %d", account.ID)
slog.Warn("stream_timeout_setting_service_missing", "account_id", account.ID)
return false
}
settings, err := s.settingService.GetStreamTimeoutSettings(ctx)
if err != nil {
log.Printf("[StreamTimeout] Failed to get settings: %v", err)
slog.Warn("stream_timeout_get_settings_failed", "account_id", account.ID, "error", err)
return false
}
@@ -620,14 +636,13 @@ func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Acc
if s.timeoutCounterCache != nil {
count, err = s.timeoutCounterCache.IncrementTimeoutCount(ctx, account.ID, settings.ThresholdWindowMinutes)
if err != nil {
log.Printf("[StreamTimeout] Failed to increment timeout count for account %d: %v", account.ID, err)
slog.Warn("stream_timeout_increment_count_failed", "account_id", account.ID, "error", err)
// 继续处理,使用 count=1
count = 1
}
}
log.Printf("[StreamTimeout] Account %d timeout count: %d/%d (window: %d min, model: %s)",
account.ID, count, settings.ThresholdCount, settings.ThresholdWindowMinutes, model)
slog.Info("stream_timeout_count", "account_id", account.ID, "count", count, "threshold", settings.ThresholdCount, "window_minutes", settings.ThresholdWindowMinutes, "model", model)
// 检查是否达到阈值
if count < int64(settings.ThresholdCount) {
@@ -668,24 +683,24 @@ func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context,
}
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
log.Printf("[StreamTimeout] SetTempUnschedulable failed for account %d: %v", account.ID, err)
slog.Warn("stream_timeout_set_temp_unsched_failed", "account_id", account.ID, "error", err)
return false
}
if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
log.Printf("[StreamTimeout] SetTempUnsched cache failed for account %d: %v", account.ID, err)
slog.Warn("stream_timeout_set_temp_unsched_cache_failed", "account_id", account.ID, "error", err)
}
}
// 重置超时计数
if s.timeoutCounterCache != nil {
if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil {
log.Printf("[StreamTimeout] ResetTimeoutCount failed for account %d: %v", account.ID, err)
slog.Warn("stream_timeout_reset_count_failed", "account_id", account.ID, "error", err)
}
}
log.Printf("[StreamTimeout] Account %d marked as temp unschedulable until %v (model: %s)", account.ID, until, model)
slog.Info("stream_timeout_temp_unschedulable", "account_id", account.ID, "until", until, "model", model)
return true
}
@@ -694,17 +709,17 @@ func (s *RateLimitService) triggerStreamTimeoutError(ctx context.Context, accoun
errorMsg := "Stream data interval timeout (repeated failures) for model: " + model
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
log.Printf("[StreamTimeout] SetError failed for account %d: %v", account.ID, err)
slog.Warn("stream_timeout_set_error_failed", "account_id", account.ID, "error", err)
return false
}
// 重置超时计数
if s.timeoutCounterCache != nil {
if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil {
log.Printf("[StreamTimeout] ResetTimeoutCount failed for account %d: %v", account.ID, err)
slog.Warn("stream_timeout_reset_count_failed", "account_id", account.ID, "error", err)
}
}
log.Printf("[StreamTimeout] Account %d marked as error (model: %s)", account.ID, model)
slog.Warn("stream_timeout_account_error", "account_id", account.ID, "model", model)
return true
}

View File

@@ -0,0 +1,121 @@
//go:build unit
package service
import (
"context"
"errors"
"net/http"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type rateLimitAccountRepoStub struct {
mockAccountRepoForGemini
setErrorCalls int
tempCalls int
lastErrorMsg string
}
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
r.setErrorCalls++
r.lastErrorMsg = errorMsg
return nil
}
func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
r.tempCalls++
return nil
}
type tokenCacheInvalidatorRecorder struct {
accounts []*Account
err error
}
func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error {
r.accounts = append(r.accounts, account)
return r.err
}
func TestRateLimitService_HandleUpstreamError_OAuth401MarksError(t *testing.T) {
tests := []struct {
name string
platform string
}{
{name: "gemini", platform: PlatformGemini},
{name: "antigravity", platform: PlatformAntigravity},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 100,
Platform: tt.platform,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": 401,
"keywords": []any{"unauthorized"},
"duration_minutes": 30,
"description": "custom rule",
},
},
},
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrorCalls)
require.Equal(t, 0, repo.tempCalls)
require.Contains(t, repo.lastErrorMsg, "Authentication failed (401)")
require.Len(t, invalidator.accounts, 1)
})
}
}
func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{err: errors.New("boom")}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 101,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrorCalls)
require.Len(t, invalidator.accounts, 1)
}
func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 102,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrorCalls)
require.Empty(t, invalidator.accounts)
}

View File

@@ -0,0 +1,35 @@
package service
import "context"
type TokenCacheInvalidator interface {
InvalidateToken(ctx context.Context, account *Account) error
}
type CompositeTokenCacheInvalidator struct {
geminiCache GeminiTokenCache
}
func NewCompositeTokenCacheInvalidator(geminiCache GeminiTokenCache) *CompositeTokenCacheInvalidator {
return &CompositeTokenCacheInvalidator{
geminiCache: geminiCache,
}
}
func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, account *Account) error {
if c == nil || c.geminiCache == nil || account == nil {
return nil
}
if account.Type != AccountTypeOAuth {
return nil
}
switch account.Platform {
case PlatformGemini:
return c.geminiCache.DeleteAccessToken(ctx, GeminiTokenCacheKey(account))
case PlatformAntigravity:
return c.geminiCache.DeleteAccessToken(ctx, AntigravityTokenCacheKey(account))
default:
return nil
}
}

View File

@@ -0,0 +1,97 @@
//go:build unit
package service
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
)
type geminiTokenCacheStub struct {
deletedKeys []string
deleteErr error
}
func (s *geminiTokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
return "", nil
}
func (s *geminiTokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
return nil
}
func (s *geminiTokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error {
s.deletedKeys = append(s.deletedKeys, cacheKey)
return s.deleteErr
}
func (s *geminiTokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
return true, nil
}
func (s *geminiTokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
return nil
}
func TestCompositeTokenCacheInvalidator_Gemini(t *testing.T) {
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
account := &Account{
ID: 10,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"project_id": "project-x",
},
}
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, []string{"project-x"}, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) {
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
account := &Account{
ID: 99,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"project_id": "ag-project",
},
}
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, []string{"ag:ag-project"}, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_SkipNonOAuth(t *testing.T) {
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
account := &Account{
ID: 1,
Platform: PlatformGemini,
Type: AccountTypeAPIKey,
}
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
require.Empty(t, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_NilCache(t *testing.T) {
invalidator := NewCompositeTokenCacheInvalidator(nil)
account := &Account{
ID: 2,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
}

View File

@@ -0,0 +1,153 @@
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestGeminiTokenCacheKey(t *testing.T) {
tests := []struct {
name string
account *Account
expected string
}{
{
name: "with_project_id",
account: &Account{
ID: 100,
Credentials: map[string]any{
"project_id": "my-project-123",
},
},
expected: "my-project-123",
},
{
name: "project_id_with_whitespace",
account: &Account{
ID: 101,
Credentials: map[string]any{
"project_id": " project-with-spaces ",
},
},
expected: "project-with-spaces",
},
{
name: "empty_project_id_fallback_to_account_id",
account: &Account{
ID: 102,
Credentials: map[string]any{
"project_id": "",
},
},
expected: "account:102",
},
{
name: "whitespace_only_project_id_fallback_to_account_id",
account: &Account{
ID: 103,
Credentials: map[string]any{
"project_id": " ",
},
},
expected: "account:103",
},
{
name: "no_project_id_key_fallback_to_account_id",
account: &Account{
ID: 104,
Credentials: map[string]any{},
},
expected: "account:104",
},
{
name: "nil_credentials_fallback_to_account_id",
account: &Account{
ID: 105,
Credentials: nil,
},
expected: "account:105",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GeminiTokenCacheKey(tt.account)
require.Equal(t, tt.expected, result)
})
}
}
func TestAntigravityTokenCacheKey(t *testing.T) {
tests := []struct {
name string
account *Account
expected string
}{
{
name: "with_project_id",
account: &Account{
ID: 200,
Credentials: map[string]any{
"project_id": "ag-project-456",
},
},
expected: "ag:ag-project-456",
},
{
name: "project_id_with_whitespace",
account: &Account{
ID: 201,
Credentials: map[string]any{
"project_id": " ag-project-spaces ",
},
},
expected: "ag:ag-project-spaces",
},
{
name: "empty_project_id_fallback_to_account_id",
account: &Account{
ID: 202,
Credentials: map[string]any{
"project_id": "",
},
},
expected: "ag:account:202",
},
{
name: "whitespace_only_project_id_fallback_to_account_id",
account: &Account{
ID: 203,
Credentials: map[string]any{
"project_id": " ",
},
},
expected: "ag:account:203",
},
{
name: "no_project_id_key_fallback_to_account_id",
account: &Account{
ID: 204,
Credentials: map[string]any{},
},
expected: "ag:account:204",
},
{
name: "nil_credentials_fallback_to_account_id",
account: &Account{
ID: 205,
Credentials: nil,
},
expected: "ag:account:205",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := AntigravityTokenCacheKey(tt.account)
require.Equal(t, tt.expected, result)
})
}
}

View File

@@ -14,9 +14,10 @@ import (
// TokenRefreshService OAuth token自动刷新服务
// 定期检查并刷新即将过期的token
type TokenRefreshService struct {
accountRepo AccountRepository
refreshers []TokenRefresher
cfg *config.TokenRefreshConfig
accountRepo AccountRepository
refreshers []TokenRefresher
cfg *config.TokenRefreshConfig
cacheInvalidator TokenCacheInvalidator
stopCh chan struct{}
wg sync.WaitGroup
@@ -29,12 +30,14 @@ func NewTokenRefreshService(
openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService,
antigravityOAuthService *AntigravityOAuthService,
cacheInvalidator TokenCacheInvalidator,
cfg *config.Config,
) *TokenRefreshService {
s := &TokenRefreshService{
accountRepo: accountRepo,
cfg: &cfg.TokenRefresh,
stopCh: make(chan struct{}),
accountRepo: accountRepo,
cfg: &cfg.TokenRefresh,
cacheInvalidator: cacheInvalidator,
stopCh: make(chan struct{}),
}
// 注册平台特定的刷新器
@@ -169,6 +172,14 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
if err := s.accountRepo.Update(ctx, account); err != nil {
return fmt.Errorf("failed to save credentials: %w", err)
}
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth &&
(account.Platform == PlatformGemini || account.Platform == PlatformAntigravity) {
if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil {
log.Printf("[TokenRefresh] Failed to invalidate token cache for account %d: %v", account.ID, err)
} else {
log.Printf("[TokenRefresh] Token cache invalidated for account %d", account.ID)
}
}
return nil
}

View File

@@ -0,0 +1,361 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type tokenRefreshAccountRepo struct {
mockAccountRepoForGemini
updateCalls int
setErrorCalls int
lastAccount *Account
updateErr error
}
func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error {
r.updateCalls++
r.lastAccount = account
return r.updateErr
}
func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error {
r.setErrorCalls++
return nil
}
type tokenCacheInvalidatorStub struct {
calls int
err error
}
func (s *tokenCacheInvalidatorStub) InvalidateToken(ctx context.Context, account *Account) error {
s.calls++
return s.err
}
type tokenRefresherStub struct {
credentials map[string]any
err error
}
func (r *tokenRefresherStub) CanRefresh(account *Account) bool {
return true
}
func (r *tokenRefresherStub) NeedsRefresh(account *Account, refreshWindowDuration time.Duration) bool {
return true
}
func (r *tokenRefresherStub) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
if r.err != nil {
return nil, r.err
}
return r.credentials, nil
}
func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
account := &Account{
ID: 5,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "new-token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, invalidator.calls)
require.Equal(t, "new-token", account.GetCredential("access_token"))
}
func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{err: errors.New("invalidate failed")}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
account := &Account{
ID: 6,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, invalidator.calls)
}
func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, cfg)
account := &Account{
ID: 7,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
}
// TestTokenRefreshService_RefreshWithRetry_Antigravity 测试 Antigravity 平台的缓存失效
func TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
account := &Account{
ID: 8,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "ag-token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, invalidator.calls) // Antigravity 也应触发缓存失效
}
// TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount 测试非 OAuth 账号不触发缓存失效
func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
account := &Account{
ID: 9,
Platform: PlatformGemini,
Type: AccountTypeAPIKey, // 非 OAuth
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 0, invalidator.calls) // 非 OAuth 不触发缓存失效
}
// TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth 测试其他平台的 OAuth 账号不触发缓存失效
func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
account := &Account{
ID: 10,
Platform: PlatformOpenAI, // 其他平台
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 0, invalidator.calls) // 其他平台不触发缓存失效
}
// TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况
func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
repo := &tokenRefreshAccountRepo{updateErr: errors.New("update failed")}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
account := &Account{
ID: 11,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.Error(t, err)
require.Contains(t, err.Error(), "failed to save credentials")
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 0, invalidator.calls) // 更新失败时不应触发缓存失效
}
// TestTokenRefreshService_RefreshWithRetry_RefreshFailed 测试刷新失败的情况
func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 2,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
account := &Account{
ID: 12,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
err: errors.New("refresh failed"),
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.Error(t, err)
require.Equal(t, 0, repo.updateCalls) // 刷新失败不应更新
require.Equal(t, 0, invalidator.calls) // 刷新失败不应触发缓存失效
require.Equal(t, 1, repo.setErrorCalls) // 应设置错误状态
}
// TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed 测试 Antigravity 刷新失败不设置错误状态
func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
account := &Account{
ID: 13,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
err: errors.New("network error"), // 可重试错误
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.Error(t, err)
require.Equal(t, 0, repo.updateCalls)
require.Equal(t, 0, invalidator.calls)
require.Equal(t, 0, repo.setErrorCalls) // Antigravity 可重试错误不设置错误状态
}
// TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError 测试 Antigravity 不可重试错误
func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 3,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
account := &Account{
ID: 14,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
err: errors.New("invalid_grant: token revoked"), // 不可重试错误
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.Error(t, err)
require.Equal(t, 0, repo.updateCalls)
require.Equal(t, 0, invalidator.calls)
require.Equal(t, 1, repo.setErrorCalls) // 不可重试错误应设置错误状态
}
// TestIsNonRetryableRefreshError 测试不可重试错误判断
func TestIsNonRetryableRefreshError(t *testing.T) {
tests := []struct {
name string
err error
expected bool
}{
{name: "nil_error", err: nil, expected: false},
{name: "network_error", err: errors.New("network timeout"), expected: false},
{name: "invalid_grant", err: errors.New("invalid_grant"), expected: true},
{name: "invalid_client", err: errors.New("invalid_client"), expected: true},
{name: "unauthorized_client", err: errors.New("unauthorized_client"), expected: true},
{name: "access_denied", err: errors.New("access_denied"), expected: true},
{name: "invalid_grant_with_desc", err: errors.New("Error: invalid_grant - token revoked"), expected: true},
{name: "case_insensitive", err: errors.New("INVALID_GRANT"), expected: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isNonRetryableRefreshError(tt.err)
require.Equal(t, tt.expected, result)
})
}
}

View File

@@ -33,6 +33,8 @@ type UsageLog struct {
TotalCost float64
ActualCost float64
RateMultiplier float64
// AccountRateMultiplier 账号计费倍率快照nil 表示历史数据,按 1.0 处理)
AccountRateMultiplier *float64
BillingType int8
Stream bool

View File

@@ -42,9 +42,10 @@ func ProvideTokenRefreshService(
openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService,
antigravityOAuthService *AntigravityOAuthService,
cacheInvalidator TokenCacheInvalidator,
cfg *config.Config,
) *TokenRefreshService {
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cfg)
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, cfg)
svc.Start()
return svc
}
@@ -108,10 +109,12 @@ func ProvideRateLimitService(
tempUnschedCache TempUnschedCache,
timeoutCounterCache TimeoutCounterCache,
settingService *SettingService,
tokenCacheInvalidator TokenCacheInvalidator,
) *RateLimitService {
svc := NewRateLimitService(accountRepo, usageRepo, cfg, geminiQuotaService, tempUnschedCache)
svc.SetTimeoutCounterCache(timeoutCounterCache)
svc.SetSettingService(settingService)
svc.SetTokenCacheInvalidator(tokenCacheInvalidator)
return svc
}
@@ -210,6 +213,7 @@ var ProviderSet = wire.NewSet(
NewOpenAIOAuthService,
NewGeminiOAuthService,
NewGeminiQuotaService,
NewCompositeTokenCacheInvalidator,
NewAntigravityOAuthService,
NewGeminiTokenProvider,
NewGeminiMessagesCompatService,

View File

@@ -0,0 +1,14 @@
-- Add account billing rate multiplier and per-usage snapshot.
--
-- accounts.rate_multiplier: 账号计费倍率(>=0允许 0 表示该账号计费为 0
-- usage_logs.account_rate_multiplier: 每条 usage log 的账号倍率快照,用于实现
-- “倍率调整仅影响之后请求”,并支持同一天分段倍率加权统计。
--
-- 注意usage_logs.account_rate_multiplier 不做回填、不设置 NOT NULL。
-- 老数据为 NULL 时,统计口径按 1.0 处理COALESCE
ALTER TABLE IF EXISTS accounts
ADD COLUMN IF NOT EXISTS rate_multiplier DECIMAL(10,4) NOT NULL DEFAULT 1.0;
ALTER TABLE IF EXISTS usage_logs
ADD COLUMN IF NOT EXISTS account_rate_multiplier DECIMAL(10,4);

View File

@@ -0,0 +1,28 @@
-- +goose Up
-- +goose StatementBegin
-- Ops alert silences: scoped (rule_id + platform + group_id + region)
CREATE TABLE IF NOT EXISTS ops_alert_silences (
id BIGSERIAL PRIMARY KEY,
rule_id BIGINT NOT NULL,
platform VARCHAR(64) NOT NULL,
group_id BIGINT,
region VARCHAR(64),
until TIMESTAMPTZ NOT NULL,
reason TEXT,
created_by BIGINT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_ops_alert_silences_lookup
ON ops_alert_silences (rule_id, platform, group_id, region, until);
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
DROP TABLE IF EXISTS ops_alert_silences;
-- +goose StatementEnd

View File

@@ -0,0 +1,111 @@
-- Add resolution tracking to ops_error_logs, persist retry results, and standardize error classification enums.
--
-- This migration is intentionally idempotent.
SET LOCAL lock_timeout = '5s';
SET LOCAL statement_timeout = '10min';
-- ============================================
-- 1) ops_error_logs: resolution fields
-- ============================================
ALTER TABLE ops_error_logs
ADD COLUMN IF NOT EXISTS resolved BOOLEAN NOT NULL DEFAULT false;
ALTER TABLE ops_error_logs
ADD COLUMN IF NOT EXISTS resolved_at TIMESTAMPTZ;
ALTER TABLE ops_error_logs
ADD COLUMN IF NOT EXISTS resolved_by_user_id BIGINT;
ALTER TABLE ops_error_logs
ADD COLUMN IF NOT EXISTS resolved_retry_id BIGINT;
CREATE INDEX IF NOT EXISTS idx_ops_error_logs_resolved_time
ON ops_error_logs (resolved, created_at DESC);
CREATE INDEX IF NOT EXISTS idx_ops_error_logs_unresolved_time
ON ops_error_logs (created_at DESC)
WHERE resolved = false;
-- ============================================
-- 2) ops_retry_attempts: persist execution results
-- ============================================
ALTER TABLE ops_retry_attempts
ADD COLUMN IF NOT EXISTS success BOOLEAN;
ALTER TABLE ops_retry_attempts
ADD COLUMN IF NOT EXISTS http_status_code INT;
ALTER TABLE ops_retry_attempts
ADD COLUMN IF NOT EXISTS upstream_request_id VARCHAR(128);
ALTER TABLE ops_retry_attempts
ADD COLUMN IF NOT EXISTS used_account_id BIGINT;
ALTER TABLE ops_retry_attempts
ADD COLUMN IF NOT EXISTS response_preview TEXT;
ALTER TABLE ops_retry_attempts
ADD COLUMN IF NOT EXISTS response_truncated BOOLEAN NOT NULL DEFAULT false;
CREATE INDEX IF NOT EXISTS idx_ops_retry_attempts_success_time
ON ops_retry_attempts (success, created_at DESC);
-- Backfill best-effort fields for existing rows.
UPDATE ops_retry_attempts
SET success = (LOWER(COALESCE(status, '')) = 'succeeded')
WHERE success IS NULL;
UPDATE ops_retry_attempts
SET upstream_request_id = result_request_id
WHERE upstream_request_id IS NULL AND result_request_id IS NOT NULL;
-- ============================================
-- 3) Standardize classification enums in ops_error_logs
--
-- New enums:
-- error_phase: request|auth|routing|upstream|network|internal
-- error_owner: client|provider|platform
-- error_source: client_request|upstream_http|gateway
-- ============================================
-- Owner: legacy sub2api => platform.
UPDATE ops_error_logs
SET error_owner = 'platform'
WHERE LOWER(COALESCE(error_owner, '')) = 'sub2api';
-- Owner: normalize empty/null to platform (best-effort).
UPDATE ops_error_logs
SET error_owner = 'platform'
WHERE COALESCE(TRIM(error_owner), '') = '';
-- Phase: map legacy phases.
UPDATE ops_error_logs
SET error_phase = CASE
WHEN COALESCE(TRIM(error_phase), '') = '' THEN 'internal'
WHEN LOWER(error_phase) IN ('billing', 'concurrency', 'response') THEN 'request'
WHEN LOWER(error_phase) IN ('scheduling') THEN 'routing'
WHEN LOWER(error_phase) IN ('request', 'auth', 'routing', 'upstream', 'network', 'internal') THEN LOWER(error_phase)
ELSE 'internal'
END;
-- Source: map legacy sources.
UPDATE ops_error_logs
SET error_source = CASE
WHEN COALESCE(TRIM(error_source), '') = '' THEN 'gateway'
WHEN LOWER(error_source) IN ('billing', 'concurrency') THEN 'client_request'
WHEN LOWER(error_source) IN ('upstream_http') THEN 'upstream_http'
WHEN LOWER(error_source) IN ('upstream_network') THEN 'gateway'
WHEN LOWER(error_source) IN ('internal') THEN 'gateway'
WHEN LOWER(error_source) IN ('client_request', 'upstream_http', 'gateway') THEN LOWER(error_source)
ELSE 'gateway'
END;
-- Auto-resolve recovered upstream errors (client status < 400).
UPDATE ops_error_logs
SET
resolved = true,
resolved_at = COALESCE(resolved_at, created_at)
WHERE resolved = false AND COALESCE(status_code, 0) > 0 AND COALESCE(status_code, 0) < 400;

View File

@@ -69,6 +69,14 @@ JWT_EXPIRE_HOUR=24
# Leave unset to use default ./config.yaml
#CONFIG_FILE=./config.yaml
# -----------------------------------------------------------------------------
# Rate Limiting (Optional)
# 速率限制(可选)
# -----------------------------------------------------------------------------
# Cooldown time (in minutes) when upstream returns 529 (overloaded)
# 上游返回 529过载时的冷却时间分钟
RATE_LIMIT_OVERLOAD_COOLDOWN_MINUTES=10
# -----------------------------------------------------------------------------
# Gateway Scheduling (Optional)
# 调度缓存与受控回源配置(缓存就绪且命中时不读 DB

View File

@@ -46,6 +46,10 @@ export interface TrendParams {
granularity?: 'day' | 'hour'
user_id?: number
api_key_id?: number
model?: string
account_id?: number
group_id?: number
stream?: boolean
}
export interface TrendResponse {
@@ -70,6 +74,10 @@ export interface ModelStatsParams {
end_date?: string
user_id?: number
api_key_id?: number
model?: string
account_id?: number
group_id?: number
stream?: boolean
}
export interface ModelStatsResponse {

View File

@@ -17,6 +17,47 @@ export interface OpsRequestOptions {
export interface OpsRetryRequest {
mode: OpsRetryMode
pinned_account_id?: number
force?: boolean
}
export interface OpsRetryAttempt {
id: number
created_at: string
requested_by_user_id: number
source_error_id: number
mode: string
pinned_account_id?: number | null
pinned_account_name?: string
status: string
started_at?: string | null
finished_at?: string | null
duration_ms?: number | null
success?: boolean | null
http_status_code?: number | null
upstream_request_id?: string | null
used_account_id?: number | null
used_account_name?: string
response_preview?: string | null
response_truncated?: boolean | null
result_request_id?: string | null
result_error_id?: number | null
error_message?: string | null
}
export type OpsUpstreamErrorEvent = {
at_unix_ms?: number
platform?: string
account_id?: number
account_name?: string
upstream_status_code?: number
upstream_request_id?: string
upstream_request_body?: string
kind?: string
message?: string
detail?: string
}
export interface OpsRetryResult {
@@ -626,8 +667,6 @@ export type MetricType =
| 'success_rate'
| 'error_rate'
| 'upstream_error_rate'
| 'p95_latency_ms'
| 'p99_latency_ms'
| 'cpu_usage_percent'
| 'memory_usage_percent'
| 'concurrency_queue_depth'
@@ -663,7 +702,7 @@ export interface AlertEvent {
id: number
rule_id: number
severity: OpsSeverity | string
status: 'firing' | 'resolved' | string
status: 'firing' | 'resolved' | 'manual_resolved' | string
title?: string
description?: string
metric_value?: number
@@ -701,10 +740,9 @@ export interface EmailNotificationConfig {
}
export interface OpsMetricThresholds {
sla_percent_min?: number | null // SLA低于此值变红
latency_p99_ms_max?: number | null // 延迟P99高于此值变红
ttft_p99_ms_max?: number | null // TTFT P99高于此值变红
request_error_rate_percent_max?: number | null // 请求错误率高于此值变红
sla_percent_min?: number | null // SLA低于此值变红
ttft_p99_ms_max?: number | null // TTFT P99高于此值变红
request_error_rate_percent_max?: number | null // 请求错误率高于此值变红
upstream_error_rate_percent_max?: number | null // 上游错误率高于此值变红
}
@@ -735,6 +773,8 @@ export interface OpsAdvancedSettings {
data_retention: OpsDataRetentionSettings
aggregation: OpsAggregationSettings
ignore_count_tokens_errors: boolean
ignore_context_canceled: boolean
ignore_no_available_accounts: boolean
auto_refresh_enabled: boolean
auto_refresh_interval_seconds: number
}
@@ -754,21 +794,37 @@ export interface OpsAggregationSettings {
export interface OpsErrorLog {
id: number
created_at: string
// Standardized classification
phase: OpsPhase
type: string
error_owner: 'client' | 'provider' | 'platform' | string
error_source: 'client_request' | 'upstream_http' | 'gateway' | string
severity: OpsSeverity
status_code: number
platform: string
model: string
latency_ms?: number | null
is_retryable: boolean
retry_count: number
resolved: boolean
resolved_at?: string | null
resolved_by_user_id?: number | null
resolved_retry_id?: number | null
client_request_id: string
request_id: string
message: string
user_id?: number | null
user_email: string
api_key_id?: number | null
account_id?: number | null
account_name: string
group_id?: number | null
group_name: string
client_ip?: string | null
request_path?: string
@@ -890,7 +946,9 @@ export async function getErrorDistribution(
return data
}
export async function listErrorLogs(params: {
export type OpsErrorListView = 'errors' | 'excluded' | 'all'
export type OpsErrorListQueryParams = {
page?: number
page_size?: number
time_range?: string
@@ -899,10 +957,20 @@ export async function listErrorLogs(params: {
platform?: string
group_id?: number | null
account_id?: number | null
phase?: string
error_owner?: string
error_source?: string
resolved?: string
view?: OpsErrorListView
q?: string
status_codes?: string
}): Promise<OpsErrorLogsResponse> {
status_codes_other?: string
}
// Legacy unified endpoints
export async function listErrorLogs(params: OpsErrorListQueryParams): Promise<OpsErrorLogsResponse> {
const { data } = await apiClient.get<OpsErrorLogsResponse>('/admin/ops/errors', { params })
return data
}
@@ -917,6 +985,70 @@ export async function retryErrorRequest(id: number, req: OpsRetryRequest): Promi
return data
}
export async function listRetryAttempts(errorId: number, limit = 50): Promise<OpsRetryAttempt[]> {
const { data } = await apiClient.get<OpsRetryAttempt[]>(`/admin/ops/errors/${errorId}/retries`, { params: { limit } })
return data
}
export async function updateErrorResolved(errorId: number, resolved: boolean): Promise<void> {
await apiClient.put(`/admin/ops/errors/${errorId}/resolve`, { resolved })
}
// New split endpoints
export async function listRequestErrors(params: OpsErrorListQueryParams): Promise<OpsErrorLogsResponse> {
const { data } = await apiClient.get<OpsErrorLogsResponse>('/admin/ops/request-errors', { params })
return data
}
export async function listUpstreamErrors(params: OpsErrorListQueryParams): Promise<OpsErrorLogsResponse> {
const { data } = await apiClient.get<OpsErrorLogsResponse>('/admin/ops/upstream-errors', { params })
return data
}
export async function getRequestErrorDetail(id: number): Promise<OpsErrorDetail> {
const { data } = await apiClient.get<OpsErrorDetail>(`/admin/ops/request-errors/${id}`)
return data
}
export async function getUpstreamErrorDetail(id: number): Promise<OpsErrorDetail> {
const { data } = await apiClient.get<OpsErrorDetail>(`/admin/ops/upstream-errors/${id}`)
return data
}
export async function retryRequestErrorClient(id: number): Promise<OpsRetryResult> {
const { data } = await apiClient.post<OpsRetryResult>(`/admin/ops/request-errors/${id}/retry-client`, {})
return data
}
export async function retryRequestErrorUpstreamEvent(id: number, idx: number): Promise<OpsRetryResult> {
const { data } = await apiClient.post<OpsRetryResult>(`/admin/ops/request-errors/${id}/upstream-errors/${idx}/retry`, {})
return data
}
export async function retryUpstreamError(id: number): Promise<OpsRetryResult> {
const { data } = await apiClient.post<OpsRetryResult>(`/admin/ops/upstream-errors/${id}/retry`, {})
return data
}
export async function updateRequestErrorResolved(errorId: number, resolved: boolean): Promise<void> {
await apiClient.put(`/admin/ops/request-errors/${errorId}/resolve`, { resolved })
}
export async function updateUpstreamErrorResolved(errorId: number, resolved: boolean): Promise<void> {
await apiClient.put(`/admin/ops/upstream-errors/${errorId}/resolve`, { resolved })
}
export async function listRequestErrorUpstreamErrors(
id: number,
params: OpsErrorListQueryParams = {},
options: { include_detail?: boolean } = {}
): Promise<PaginatedResponse<OpsErrorDetail>> {
const query: Record<string, any> = { ...params }
if (options.include_detail) query.include_detail = '1'
const { data } = await apiClient.get<PaginatedResponse<OpsErrorDetail>>(`/admin/ops/request-errors/${id}/upstream-errors`, { params: query })
return data
}
export async function listRequestDetails(params: OpsRequestDetailsParams): Promise<OpsRequestDetailsResponse> {
const { data } = await apiClient.get<OpsRequestDetailsResponse>('/admin/ops/requests', { params })
return data
@@ -942,11 +1074,45 @@ export async function deleteAlertRule(id: number): Promise<void> {
await apiClient.delete(`/admin/ops/alert-rules/${id}`)
}
export async function listAlertEvents(limit = 100): Promise<AlertEvent[]> {
const { data } = await apiClient.get<AlertEvent[]>('/admin/ops/alert-events', { params: { limit } })
export interface AlertEventsQuery {
limit?: number
status?: string
severity?: string
email_sent?: boolean
time_range?: string
start_time?: string
end_time?: string
before_fired_at?: string
before_id?: number
platform?: string
group_id?: number
}
export async function listAlertEvents(params: AlertEventsQuery = {}): Promise<AlertEvent[]> {
const { data } = await apiClient.get<AlertEvent[]>('/admin/ops/alert-events', { params })
return data
}
export async function getAlertEvent(id: number): Promise<AlertEvent> {
const { data } = await apiClient.get<AlertEvent>(`/admin/ops/alert-events/${id}`)
return data
}
export async function updateAlertEventStatus(id: number, status: 'resolved' | 'manual_resolved'): Promise<void> {
await apiClient.put(`/admin/ops/alert-events/${id}/status`, { status })
}
export async function createAlertSilence(payload: {
rule_id: number
platform: string
group_id?: number | null
region?: string | null
until: string
reason?: string
}): Promise<void> {
await apiClient.post('/admin/ops/alert-silences', payload)
}
// Email notification config
export async function getEmailNotificationConfig(): Promise<EmailNotificationConfig> {
const { data } = await apiClient.get<EmailNotificationConfig>('/admin/ops/email-notification/config')
@@ -1001,15 +1167,35 @@ export const opsAPI = {
getAccountAvailabilityStats,
getRealtimeTrafficSummary,
subscribeQPS,
// Legacy unified endpoints
listErrorLogs,
getErrorLogDetail,
retryErrorRequest,
listRetryAttempts,
updateErrorResolved,
// New split endpoints
listRequestErrors,
listUpstreamErrors,
getRequestErrorDetail,
getUpstreamErrorDetail,
retryRequestErrorClient,
retryRequestErrorUpstreamEvent,
retryUpstreamError,
updateRequestErrorResolved,
updateUpstreamErrorResolved,
listRequestErrorUpstreamErrors,
listRequestDetails,
listAlertRules,
createAlertRule,
updateAlertRule,
deleteAlertRule,
listAlertEvents,
getAlertEvent,
updateAlertEventStatus,
createAlertSilence,
getEmailNotificationConfig,
updateEmailNotificationConfig,
getAlertRuntimeSettings,

View File

@@ -4,7 +4,13 @@
*/
import { apiClient } from '../client'
import type { Proxy, CreateProxyRequest, UpdateProxyRequest, PaginatedResponse } from '@/types'
import type {
Proxy,
ProxyAccountSummary,
CreateProxyRequest,
UpdateProxyRequest,
PaginatedResponse
} from '@/types'
/**
* List all proxies with pagination
@@ -120,6 +126,7 @@ export async function testProxy(id: number): Promise<{
city?: string
region?: string
country?: string
country_code?: string
}> {
const { data } = await apiClient.post<{
success: boolean
@@ -129,6 +136,7 @@ export async function testProxy(id: number): Promise<{
city?: string
region?: string
country?: string
country_code?: string
}>(`/admin/proxies/${id}/test`)
return data
}
@@ -160,8 +168,8 @@ export async function getStats(id: number): Promise<{
* @param id - Proxy ID
* @returns List of accounts using the proxy
*/
export async function getProxyAccounts(id: number): Promise<PaginatedResponse<any>> {
const { data } = await apiClient.get<PaginatedResponse<any>>(`/admin/proxies/${id}/accounts`)
export async function getProxyAccounts(id: number): Promise<ProxyAccountSummary[]> {
const { data } = await apiClient.get<ProxyAccountSummary[]>(`/admin/proxies/${id}/accounts`)
return data
}
@@ -189,6 +197,17 @@ export async function batchCreate(
return data
}
export async function batchDelete(ids: number[]): Promise<{
deleted_ids: number[]
skipped: Array<{ id: number; reason: string }>
}> {
const { data } = await apiClient.post<{
deleted_ids: number[]
skipped: Array<{ id: number; reason: string }>
}>('/admin/proxies/batch-delete', { ids })
return data
}
export const proxiesAPI = {
list,
getAll,
@@ -201,7 +220,8 @@ export const proxiesAPI = {
testProxy,
getStats,
getProxyAccounts,
batchCreate
batchCreate,
batchDelete
}
export default proxiesAPI

View File

@@ -16,6 +16,7 @@ export interface AdminUsageStatsResponse {
total_tokens: number
total_cost: number
total_actual_cost: number
total_account_cost?: number
average_duration_ms: number
}

View File

@@ -73,11 +73,12 @@
</p>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.stats.accumulatedCost') }}
<span class="text-gray-400 dark:text-gray-500"
>({{ t('admin.accounts.stats.standardCost') }}: ${{
<span class="text-gray-400 dark:text-gray-500">
({{ t('usage.userBilled') }}: ${{ formatCost(stats.summary.total_user_cost) }} ·
{{ t('admin.accounts.stats.standardCost') }}: ${{
formatCost(stats.summary.total_standard_cost)
}})</span
>
}})
</span>
</p>
</div>
@@ -121,12 +122,15 @@
<p class="text-2xl font-bold text-gray-900 dark:text-white">
${{ formatCost(stats.summary.avg_daily_cost) }}
</p>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{
t('admin.accounts.stats.basedOnActualDays', {
days: stats.summary.actual_days_used
})
}}
<span class="text-gray-400 dark:text-gray-500">
({{ t('usage.userBilled') }}: ${{ formatCost(stats.summary.avg_daily_user_cost) }})
</span>
</p>
</div>
@@ -189,13 +193,17 @@
</div>
<div class="space-y-2">
<div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{
t('admin.accounts.stats.cost')
}}</span>
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('usage.accountBilled') }}</span>
<span class="text-sm font-semibold text-gray-900 dark:text-white"
>${{ formatCost(stats.summary.today?.cost || 0) }}</span
>
</div>
<div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('usage.userBilled') }}</span>
<span class="text-sm font-semibold text-gray-900 dark:text-white"
>${{ formatCost(stats.summary.today?.user_cost || 0) }}</span
>
</div>
<div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{
t('admin.accounts.stats.requests')
@@ -240,13 +248,17 @@
}}</span>
</div>
<div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{
t('admin.accounts.stats.cost')
}}</span>
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('usage.accountBilled') }}</span>
<span class="text-sm font-semibold text-orange-600 dark:text-orange-400"
>${{ formatCost(stats.summary.highest_cost_day?.cost || 0) }}</span
>
</div>
<div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('usage.userBilled') }}</span>
<span class="text-sm font-semibold text-gray-900 dark:text-white"
>${{ formatCost(stats.summary.highest_cost_day?.user_cost || 0) }}</span
>
</div>
<div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{
t('admin.accounts.stats.requests')
@@ -291,13 +303,17 @@
}}</span>
</div>
<div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{
t('admin.accounts.stats.cost')
}}</span>
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('usage.accountBilled') }}</span>
<span class="text-sm font-semibold text-gray-900 dark:text-white"
>${{ formatCost(stats.summary.highest_request_day?.cost || 0) }}</span
>
</div>
<div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('usage.userBilled') }}</span>
<span class="text-sm font-semibold text-gray-900 dark:text-white"
>${{ formatCost(stats.summary.highest_request_day?.user_cost || 0) }}</span
>
</div>
</div>
</div>
</div>
@@ -397,13 +413,17 @@
}}</span>
</div>
<div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{
t('admin.accounts.stats.todayCost')
}}</span>
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('usage.accountBilled') }}</span>
<span class="text-sm font-semibold text-gray-900 dark:text-white"
>${{ formatCost(stats.summary.today?.cost || 0) }}</span
>
</div>
<div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('usage.userBilled') }}</span>
<span class="text-sm font-semibold text-gray-900 dark:text-white"
>${{ formatCost(stats.summary.today?.user_cost || 0) }}</span
>
</div>
</div>
</div>
</div>
@@ -517,14 +537,24 @@ const trendChartData = computed(() => {
labels: stats.value.history.map((h) => h.label),
datasets: [
{
label: t('admin.accounts.stats.cost') + ' (USD)',
data: stats.value.history.map((h) => h.cost),
label: t('usage.accountBilled') + ' (USD)',
data: stats.value.history.map((h) => h.actual_cost),
borderColor: '#3b82f6',
backgroundColor: 'rgba(59, 130, 246, 0.1)',
fill: true,
tension: 0.3,
yAxisID: 'y'
},
{
label: t('usage.userBilled') + ' (USD)',
data: stats.value.history.map((h) => h.user_cost),
borderColor: '#10b981',
backgroundColor: 'rgba(16, 185, 129, 0.08)',
fill: false,
tension: 0.3,
borderDash: [5, 5],
yAxisID: 'y'
},
{
label: t('admin.accounts.stats.requests'),
data: stats.value.history.map((h) => h.requests),
@@ -602,7 +632,7 @@ const lineChartOptions = computed(() => ({
},
title: {
display: true,
text: t('admin.accounts.stats.cost') + ' (USD)',
text: t('usage.accountBilled') + ' (USD)',
color: '#3b82f6',
font: {
size: 11

View File

@@ -32,15 +32,20 @@
formatTokens(stats.tokens)
}}</span>
</div>
<!-- Cost -->
<!-- Cost (Account) -->
<div class="flex items-center gap-1">
<span class="text-gray-500 dark:text-gray-400"
>{{ t('admin.accounts.stats.cost') }}:</span
>
<span class="text-gray-500 dark:text-gray-400">{{ t('usage.accountBilled') }}:</span>
<span class="font-medium text-emerald-600 dark:text-emerald-400">{{
formatCurrency(stats.cost)
}}</span>
</div>
<!-- Cost (User/API Key) -->
<div v-if="stats.user_cost != null" class="flex items-center gap-1">
<span class="text-gray-500 dark:text-gray-400">{{ t('usage.userBilled') }}:</span>
<span class="font-medium text-gray-700 dark:text-gray-300">{{
formatCurrency(stats.user_cost)
}}</span>
</div>
</div>
<!-- No data -->

View File

@@ -459,7 +459,7 @@
</div>
<!-- Concurrency & Priority -->
<div class="grid grid-cols-2 gap-4 border-t border-gray-200 pt-4 dark:border-dark-600">
<div class="grid grid-cols-2 gap-4 border-t border-gray-200 pt-4 dark:border-dark-600 lg:grid-cols-3">
<div>
<div class="mb-3 flex items-center justify-between">
<label
@@ -516,6 +516,36 @@
aria-labelledby="bulk-edit-priority-label"
/>
</div>
<div>
<div class="mb-3 flex items-center justify-between">
<label
id="bulk-edit-rate-multiplier-label"
class="input-label mb-0"
for="bulk-edit-rate-multiplier-enabled"
>
{{ t('admin.accounts.billingRateMultiplier') }}
</label>
<input
v-model="enableRateMultiplier"
id="bulk-edit-rate-multiplier-enabled"
type="checkbox"
aria-controls="bulk-edit-rate-multiplier"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
/>
</div>
<input
v-model.number="rateMultiplier"
id="bulk-edit-rate-multiplier"
type="number"
min="0"
step="0.01"
:disabled="!enableRateMultiplier"
class="input"
:class="!enableRateMultiplier && 'cursor-not-allowed opacity-50'"
aria-labelledby="bulk-edit-rate-multiplier-label"
/>
<p class="input-hint">{{ t('admin.accounts.billingRateMultiplierHint') }}</p>
</div>
</div>
<!-- Status -->
@@ -655,6 +685,7 @@ const enableInterceptWarmup = ref(false)
const enableProxy = ref(false)
const enableConcurrency = ref(false)
const enablePriority = ref(false)
const enableRateMultiplier = ref(false)
const enableStatus = ref(false)
const enableGroups = ref(false)
@@ -670,6 +701,7 @@ const interceptWarmupRequests = ref(false)
const proxyId = ref<number | null>(null)
const concurrency = ref(1)
const priority = ref(1)
const rateMultiplier = ref(1)
const status = ref<'active' | 'inactive'>('active')
const groupIds = ref<number[]>([])
@@ -863,6 +895,10 @@ const buildUpdatePayload = (): Record<string, unknown> | null => {
updates.priority = priority.value
}
if (enableRateMultiplier.value) {
updates.rate_multiplier = rateMultiplier.value
}
if (enableStatus.value) {
updates.status = status.value
}
@@ -923,6 +959,7 @@ const handleSubmit = async () => {
enableProxy.value ||
enableConcurrency.value ||
enablePriority.value ||
enableRateMultiplier.value ||
enableStatus.value ||
enableGroups.value
@@ -977,6 +1014,7 @@ watch(
enableProxy.value = false
enableConcurrency.value = false
enablePriority.value = false
enableRateMultiplier.value = false
enableStatus.value = false
enableGroups.value = false
@@ -991,6 +1029,7 @@ watch(
proxyId.value = null
concurrency.value = 1
priority.value = 1
rateMultiplier.value = 1
status.value = 'active'
groupIds.value = []
}

View File

@@ -1196,7 +1196,7 @@
<ProxySelector v-model="form.proxy_id" :proxies="proxies" />
</div>
<div class="grid grid-cols-2 gap-4">
<div class="grid grid-cols-2 gap-4 lg:grid-cols-3">
<div>
<label class="input-label">{{ t('admin.accounts.concurrency') }}</label>
<input v-model.number="form.concurrency" type="number" min="1" class="input" />
@@ -1212,6 +1212,11 @@
/>
<p class="input-hint">{{ t('admin.accounts.priorityHint') }}</p>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.billingRateMultiplier') }}</label>
<input v-model.number="form.rate_multiplier" type="number" min="0" step="0.01" class="input" />
<p class="input-hint">{{ t('admin.accounts.billingRateMultiplierHint') }}</p>
</div>
</div>
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
<label class="input-label">{{ t('admin.accounts.expiresAt') }}</label>
@@ -1832,6 +1837,7 @@ const form = reactive({
proxy_id: null as number | null,
concurrency: 10,
priority: 1,
rate_multiplier: 1,
group_ids: [] as number[],
expires_at: null as number | null
})
@@ -2119,6 +2125,7 @@ const resetForm = () => {
form.proxy_id = null
form.concurrency = 10
form.priority = 1
form.rate_multiplier = 1
form.group_ids = []
form.expires_at = null
accountCategory.value = 'oauth-based'
@@ -2272,6 +2279,7 @@ const createAccountAndFinish = async (
proxy_id: form.proxy_id,
concurrency: form.concurrency,
priority: form.priority,
rate_multiplier: form.rate_multiplier,
group_ids: form.group_ids,
expires_at: form.expires_at,
auto_pause_on_expired: autoPauseOnExpired.value
@@ -2490,6 +2498,7 @@ const handleCookieAuth = async (sessionKey: string) => {
proxy_id: form.proxy_id,
concurrency: form.concurrency,
priority: form.priority,
rate_multiplier: form.rate_multiplier,
group_ids: form.group_ids,
expires_at: form.expires_at,
auto_pause_on_expired: autoPauseOnExpired.value

View File

@@ -549,7 +549,7 @@
<ProxySelector v-model="form.proxy_id" :proxies="proxies" />
</div>
<div class="grid grid-cols-2 gap-4">
<div class="grid grid-cols-2 gap-4 lg:grid-cols-3">
<div>
<label class="input-label">{{ t('admin.accounts.concurrency') }}</label>
<input v-model.number="form.concurrency" type="number" min="1" class="input" />
@@ -564,6 +564,11 @@
data-tour="account-form-priority"
/>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.billingRateMultiplier') }}</label>
<input v-model.number="form.rate_multiplier" type="number" min="0" step="0.01" class="input" />
<p class="input-hint">{{ t('admin.accounts.billingRateMultiplierHint') }}</p>
</div>
</div>
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
<label class="input-label">{{ t('admin.accounts.expiresAt') }}</label>
@@ -807,6 +812,7 @@ const form = reactive({
proxy_id: null as number | null,
concurrency: 1,
priority: 1,
rate_multiplier: 1,
status: 'active' as 'active' | 'inactive',
group_ids: [] as number[],
expires_at: null as number | null
@@ -834,6 +840,7 @@ watch(
form.proxy_id = newAccount.proxy_id
form.concurrency = newAccount.concurrency
form.priority = newAccount.priority
form.rate_multiplier = newAccount.rate_multiplier ?? 1
form.status = newAccount.status as 'active' | 'inactive'
form.group_ids = newAccount.group_ids || []
form.expires_at = newAccount.expires_at ?? null

Some files were not shown because too many files have changed in this diff Show More