mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-04 07:22:13 +08:00
Compare commits
151 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
21f349c032 | ||
|
|
28e36f7925 | ||
|
|
6c02076333 | ||
|
|
7414bdf0e3 | ||
|
|
e6326b2929 | ||
|
|
17cdcebd04 | ||
|
|
a14babdc73 | ||
|
|
aadc6a763a | ||
|
|
f16af8bf88 | ||
|
|
5ceaef4500 | ||
|
|
1ac7219a92 | ||
|
|
d4cc9871c4 | ||
|
|
961c30e7c0 | ||
|
|
13e85b3147 | ||
|
|
50a3c7fa0b | ||
|
|
2005fc97a8 | ||
|
|
0772d9250e | ||
|
|
045cba78b4 | ||
|
|
8989d0d4b6 | ||
|
|
c521117b99 | ||
|
|
e0f52a8ab8 | ||
|
|
6c23fadf7e | ||
|
|
869952d113 | ||
|
|
07ab051ee4 | ||
|
|
f2d98fc0c7 | ||
|
|
2b41cec840 | ||
|
|
6cf77040e7 | ||
|
|
20b70bc5fd | ||
|
|
4905e7193a | ||
|
|
9c1f4b8e72 | ||
|
|
9857c17631 | ||
|
|
7e34bb946f | ||
|
|
47b748851b | ||
|
|
a6f99cf534 | ||
|
|
a120a6bc32 | ||
|
|
d557d1a190 | ||
|
|
e0286e5085 | ||
|
|
4b41e898a4 | ||
|
|
668e164793 | ||
|
|
fa2e6188d0 | ||
|
|
7fde9ebbc2 | ||
|
|
aef7c3b9bb | ||
|
|
a0b76bd608 | ||
|
|
c1fab7f8d8 | ||
|
|
f42c8f2abe | ||
|
|
aa5846b282 | ||
|
|
594a0ade38 | ||
|
|
d45cc23171 | ||
|
|
d795734352 | ||
|
|
4da9fdd1d5 | ||
|
|
6b218caa21 | ||
|
|
5c138007d0 | ||
|
|
1acfc46f46 | ||
|
|
fbffb08aae | ||
|
|
8640a62319 | ||
|
|
fa782e70a4 | ||
|
|
afd72abc6e | ||
|
|
71f72e167e | ||
|
|
6595c7601e | ||
|
|
67c0506290 | ||
|
|
6447be4534 | ||
|
|
3741617ebd | ||
|
|
ab4e8b2cf0 | ||
|
|
474165d7aa | ||
|
|
94e067a2e2 | ||
|
|
4293c89166 | ||
|
|
ec82c37da5 | ||
|
|
552a4b998a | ||
|
|
0d2061b268 | ||
|
|
8a260defc2 | ||
|
|
e14c87597a | ||
|
|
f3f19d35aa | ||
|
|
ced90e1d84 | ||
|
|
17e4033340 | ||
|
|
044d3a013d | ||
|
|
1fc9dd7b68 | ||
|
|
8147866c09 | ||
|
|
7bd1972f94 | ||
|
|
2c9dcfe27b | ||
|
|
1b79b0f3ff | ||
|
|
c637e6cf31 | ||
|
|
d3a9f5bb88 | ||
|
|
7eb0415a8a | ||
|
|
bdbc8fa08f | ||
|
|
63f3af0f94 | ||
|
|
686f890fbf | ||
|
|
220fbe6544 | ||
|
|
ae44a94325 | ||
|
|
3718d6dcd4 | ||
|
|
90b3838173 | ||
|
|
19d3ecc76f | ||
|
|
6fba4ebb13 | ||
|
|
c31974c913 | ||
|
|
6177fa5dd8 | ||
|
|
cfe72159d0 | ||
|
|
8321e4a647 | ||
|
|
3084330d0c | ||
|
|
b566649e79 | ||
|
|
10a6180e4a | ||
|
|
cbe9e78977 | ||
|
|
74145b1f39 | ||
|
|
359e56751b | ||
|
|
5899784aa4 | ||
|
|
9e8959c56d | ||
|
|
1bff2292a6 | ||
|
|
cf9247754e | ||
|
|
eefab15958 | ||
|
|
0e23732631 | ||
|
|
37c044fb4b | ||
|
|
6da5fa01b9 | ||
|
|
616930f9d3 | ||
|
|
b9c31fa7c4 | ||
|
|
17b339972c | ||
|
|
39f8bd91b9 | ||
|
|
aa4e37d085 | ||
|
|
f59b66b7d4 | ||
|
|
8f0ea7a02d | ||
|
|
a1dc00890e | ||
|
|
dfbcc363d1 | ||
|
|
1047f973d5 | ||
|
|
e32977dd73 | ||
|
|
b5f78ec1e8 | ||
|
|
e0f290fdc8 | ||
|
|
fc00a4e3b2 | ||
|
|
db1f6ded88 | ||
|
|
4644af2ccc | ||
|
|
2e3e8687e1 | ||
|
|
ca42a45802 | ||
|
|
9350ecb62b | ||
|
|
a4a026e8da | ||
|
|
342fd03e72 | ||
|
|
e3f1fd9b63 | ||
|
|
e4a4dfd038 | ||
|
|
a377e99088 | ||
|
|
1d3d7a3033 | ||
|
|
e7086cb3a3 | ||
|
|
01ef7340aa | ||
|
|
1c960d22c1 | ||
|
|
ece0606fed | ||
|
|
2666422b99 | ||
|
|
4e8615f276 | ||
|
|
45456fa24c | ||
|
|
6344fa2a86 | ||
|
|
29b0e4a8a5 | ||
|
|
f7177be3b6 | ||
|
|
875b417fde | ||
|
|
2573107b32 | ||
|
|
5b85005945 | ||
|
|
53ad1645cf | ||
|
|
af9c4a7dd0 | ||
|
|
6826149a8f |
7
.gitattributes
vendored
7
.gitattributes
vendored
@@ -4,6 +4,13 @@ backend/migrations/*.sql text eol=lf
|
||||
# Go 源代码文件
|
||||
*.go text eol=lf
|
||||
|
||||
# 前端 源代码文件
|
||||
*.ts text eol=lf
|
||||
*.tsx text eol=lf
|
||||
*.js text eol=lf
|
||||
*.jsx text eol=lf
|
||||
*.vue text eol=lf
|
||||
|
||||
# Shell 脚本
|
||||
*.sh text eol=lf
|
||||
|
||||
|
||||
@@ -47,6 +47,8 @@ dockers:
|
||||
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:latest"
|
||||
dockerfile: Dockerfile.goreleaser
|
||||
use: buildx
|
||||
extra_files:
|
||||
- deploy/docker-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||
|
||||
@@ -63,6 +63,8 @@ dockers:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||
dockerfile: Dockerfile.goreleaser
|
||||
use: buildx
|
||||
extra_files:
|
||||
- deploy/docker-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||
@@ -76,6 +78,8 @@ dockers:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||
dockerfile: Dockerfile.goreleaser
|
||||
use: buildx
|
||||
extra_files:
|
||||
- deploy/docker-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||
@@ -89,6 +93,8 @@ dockers:
|
||||
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-amd64"
|
||||
dockerfile: Dockerfile.goreleaser
|
||||
use: buildx
|
||||
extra_files:
|
||||
- deploy/docker-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||
@@ -102,6 +108,8 @@ dockers:
|
||||
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-arm64"
|
||||
dockerfile: Dockerfile.goreleaser
|
||||
use: buildx
|
||||
extra_files:
|
||||
- deploy/docker-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||
|
||||
31
Dockerfile
31
Dockerfile
@@ -9,6 +9,7 @@
|
||||
ARG NODE_IMAGE=node:24-alpine
|
||||
ARG GOLANG_IMAGE=golang:1.26.1-alpine
|
||||
ARG ALPINE_IMAGE=alpine:3.21
|
||||
ARG POSTGRES_IMAGE=postgres:18-alpine
|
||||
ARG GOPROXY=https://goproxy.cn,direct
|
||||
ARG GOSUMDB=sum.golang.google.cn
|
||||
|
||||
@@ -73,7 +74,12 @@ RUN VERSION_VALUE="${VERSION}" && \
|
||||
./cmd/server
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Stage 3: Final Runtime Image
|
||||
# Stage 3: PostgreSQL Client (version-matched with docker-compose)
|
||||
# -----------------------------------------------------------------------------
|
||||
FROM ${POSTGRES_IMAGE} AS pg-client
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Stage 4: Final Runtime Image
|
||||
# -----------------------------------------------------------------------------
|
||||
FROM ${ALPINE_IMAGE}
|
||||
|
||||
@@ -86,8 +92,21 @@ LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
|
||||
RUN apk add --no-cache \
|
||||
ca-certificates \
|
||||
tzdata \
|
||||
su-exec \
|
||||
libpq \
|
||||
zstd-libs \
|
||||
lz4-libs \
|
||||
krb5-libs \
|
||||
libldap \
|
||||
libedit \
|
||||
&& rm -rf /var/cache/apk/*
|
||||
|
||||
# Copy pg_dump and psql from the same postgres image used in docker-compose
|
||||
# This ensures version consistency between backup tools and the database server
|
||||
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
|
||||
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
|
||||
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
|
||||
|
||||
# Create non-root user
|
||||
RUN addgroup -g 1000 sub2api && \
|
||||
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
||||
@@ -102,8 +121,9 @@ COPY --from=backend-builder --chown=sub2api:sub2api /app/backend/resources /app/
|
||||
# Create data directory
|
||||
RUN mkdir -p /app/data && chown sub2api:sub2api /app/data
|
||||
|
||||
# Switch to non-root user
|
||||
USER sub2api
|
||||
# Copy entrypoint script (fixes volume permissions then drops to sub2api)
|
||||
COPY deploy/docker-entrypoint.sh /app/docker-entrypoint.sh
|
||||
RUN chmod +x /app/docker-entrypoint.sh
|
||||
|
||||
# Expose port (can be overridden by SERVER_PORT env var)
|
||||
EXPOSE 8080
|
||||
@@ -112,5 +132,6 @@ EXPOSE 8080
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
||||
CMD wget -q -T 5 -O /dev/null http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
||||
|
||||
# Run the application
|
||||
ENTRYPOINT ["/app/sub2api"]
|
||||
# Run the application (entrypoint fixes /app/data ownership then execs as sub2api)
|
||||
ENTRYPOINT ["/app/docker-entrypoint.sh"]
|
||||
CMD ["/app/sub2api"]
|
||||
|
||||
@@ -5,7 +5,12 @@
|
||||
# It only packages the pre-built binary, no compilation needed.
|
||||
# =============================================================================
|
||||
|
||||
FROM alpine:3.19
|
||||
ARG ALPINE_IMAGE=alpine:3.21
|
||||
ARG POSTGRES_IMAGE=postgres:18-alpine
|
||||
|
||||
FROM ${POSTGRES_IMAGE} AS pg-client
|
||||
|
||||
FROM ${ALPINE_IMAGE}
|
||||
|
||||
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
|
||||
LABEL description="Sub2API - AI API Gateway Platform"
|
||||
@@ -16,8 +21,21 @@ RUN apk add --no-cache \
|
||||
ca-certificates \
|
||||
tzdata \
|
||||
curl \
|
||||
su-exec \
|
||||
libpq \
|
||||
zstd-libs \
|
||||
lz4-libs \
|
||||
krb5-libs \
|
||||
libldap \
|
||||
libedit \
|
||||
&& rm -rf /var/cache/apk/*
|
||||
|
||||
# Copy pg_dump and psql from a version-matched PostgreSQL image so backup and
|
||||
# restore work in the runtime container without requiring Docker socket access.
|
||||
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
|
||||
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
|
||||
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
|
||||
|
||||
# Create non-root user
|
||||
RUN addgroup -g 1000 sub2api && \
|
||||
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
||||
@@ -30,11 +48,15 @@ COPY sub2api /app/sub2api
|
||||
# Create data directory
|
||||
RUN mkdir -p /app/data && chown -R sub2api:sub2api /app
|
||||
|
||||
USER sub2api
|
||||
# Copy entrypoint script (fixes volume permissions then drops to sub2api)
|
||||
COPY deploy/docker-entrypoint.sh /app/docker-entrypoint.sh
|
||||
RUN chmod +x /app/docker-entrypoint.sh
|
||||
|
||||
EXPOSE 8080
|
||||
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
||||
CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
||||
|
||||
ENTRYPOINT ["/app/sub2api"]
|
||||
# Run the application (entrypoint fixes /app/data ownership then execs as sub2api)
|
||||
ENTRYPOINT ["/app/docker-entrypoint.sh"]
|
||||
CMD ["/app/sub2api"]
|
||||
|
||||
30
README.md
30
README.md
@@ -8,27 +8,31 @@
|
||||
[](https://redis.io/)
|
||||
[](https://www.docker.com/)
|
||||
|
||||
<a href="https://trendshift.io/repositories/21823" target="_blank"><img src="https://trendshift.io/api/badge/repositories/21823" alt="Wei-Shaw%2Fsub2api | Trendshift" width="250" height="55"/></a>
|
||||
|
||||
**AI API Gateway Platform for Subscription Quota Distribution**
|
||||
|
||||
English | [中文](README_CN.md)
|
||||
|
||||
</div>
|
||||
|
||||
> **Sub2API officially uses only the domains `sub2api.org` and `pincc.ai`. Other websites using the Sub2API name may be third-party deployments or services and are not affiliated with this project. Please verify and exercise your own judgment.**
|
||||
|
||||
---
|
||||
|
||||
## Demo
|
||||
|
||||
Try Sub2API online: **https://demo.sub2api.org/**
|
||||
Try Sub2API online: **[https://demo.sub2api.org/](https://demo.sub2api.org/)**
|
||||
|
||||
Demo credentials (shared demo environment; **not** created automatically for self-hosted installs):
|
||||
|
||||
| Email | Password |
|
||||
|-------|----------|
|
||||
| admin@sub2api.com | admin123 |
|
||||
| admin@sub2api.org | admin123 |
|
||||
|
||||
## Overview
|
||||
|
||||
Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions (like Claude Code $200/month). Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding.
|
||||
Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions. Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding.
|
||||
|
||||
## Features
|
||||
|
||||
@@ -41,6 +45,15 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
|
||||
- **Admin Dashboard** - Web interface for monitoring and management
|
||||
- **External System Integration** - Embed external systems (e.g. payment, ticketing) via iframe to extend the admin dashboard
|
||||
|
||||
## Don't Want to Self-Host?
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td width="180" align="center" valign="middle"><a href="https://shop.pincc.ai/"><img src="assets/partners/logos/pincc-logo.png" alt="pincc" width="120"></a></td>
|
||||
<td valign="middle"><b><a href="https://shop.pincc.ai/">PinCC</a></b> is the official relay service built on Sub2API, offering stable access to Claude Code, Codex, Gemini and other popular models — ready to use, no deployment or maintenance required.</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Ecosystem
|
||||
|
||||
Community projects that extend or integrate with Sub2API:
|
||||
@@ -61,10 +74,15 @@ Community projects that extend or integrate with Sub2API:
|
||||
|
||||
---
|
||||
|
||||
## Documentation
|
||||
## Nginx Reverse Proxy Note
|
||||
|
||||
- Dependency Security: `docs/dependency-security.md`
|
||||
- Admin Payment Integration API: `docs/ADMIN_PAYMENT_INTEGRATION_API.md`
|
||||
When using Nginx as a reverse proxy for Sub2API (or CRS) with Codex CLI, add the following to the `http` block in your Nginx configuration:
|
||||
|
||||
```nginx
|
||||
underscores_in_headers on;
|
||||
```
|
||||
|
||||
Nginx drops headers containing underscores by default (e.g. `session_id`), which breaks sticky session routing in multi-account setups.
|
||||
|
||||
---
|
||||
|
||||
|
||||
33
README_CN.md
33
README_CN.md
@@ -8,27 +8,30 @@
|
||||
[](https://redis.io/)
|
||||
[](https://www.docker.com/)
|
||||
|
||||
<a href="https://trendshift.io/repositories/21823" target="_blank"><img src="https://trendshift.io/api/badge/repositories/21823" alt="Wei-Shaw%2Fsub2api | Trendshift" width="250" height="55"/></a>
|
||||
|
||||
**AI API 网关平台 - 订阅配额分发管理**
|
||||
|
||||
[English](README.md) | 中文
|
||||
|
||||
</div>
|
||||
|
||||
> **Sub2API 官方仅使用 `sub2api.org` 与 `pincc.ai` 两个域名。其他使用 Sub2API 名义的网站可能为第三方部署或服务,与本项目无关,请自行甄别。**
|
||||
---
|
||||
|
||||
## 在线体验
|
||||
|
||||
体验地址:**https://v2.pincc.ai/**
|
||||
体验地址:**[https://demo.sub2api.org/](https://demo.sub2api.org/)**
|
||||
|
||||
演示账号(共享演示环境;自建部署不会自动创建该账号):
|
||||
|
||||
| 邮箱 | 密码 |
|
||||
|------|------|
|
||||
| admin@sub2api.com | admin123 |
|
||||
| admin@sub2api.org | admin123 |
|
||||
|
||||
## 项目概述
|
||||
|
||||
Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(如 Claude Code $200/月)的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。
|
||||
Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。
|
||||
|
||||
## 核心功能
|
||||
|
||||
@@ -41,6 +44,15 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
|
||||
- **管理后台** - Web 界面进行监控和管理
|
||||
- **外部系统集成** - 支持通过 iframe 嵌入外部系统(如支付、工单等),扩展管理后台功能
|
||||
|
||||
## 不想自建?试试官方中转
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td width="180" align="center" valign="middle"><a href="https://shop.pincc.ai/"><img src="assets/partners/logos/pincc-logo.png" alt="pincc" width="120"></a></td>
|
||||
<td valign="middle"><b><a href="https://shop.pincc.ai/">PinCC</a></b> 是基于 Sub2API 搭建的官方中转服务,提供 Claude Code、Codex、Gemini 等主流模型的稳定中转,开箱即用,免去自建部署与运维烦恼。</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## 生态项目
|
||||
|
||||
围绕 Sub2API 的社区扩展与集成项目:
|
||||
@@ -61,17 +73,18 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
|
||||
|
||||
---
|
||||
|
||||
## 文档
|
||||
## Nginx 反向代理注意事项
|
||||
|
||||
- 依赖安全:`docs/dependency-security.md`
|
||||
通过 Nginx 反向代理 Sub2API(或 CRS 服务)并搭配 Codex CLI 使用时,需要在 Nginx 配置的 `http` 块中添加:
|
||||
|
||||
```nginx
|
||||
underscores_in_headers on;
|
||||
```
|
||||
|
||||
Nginx 默认会丢弃名称中含下划线的请求头(如 `session_id`),这会导致多账号环境下的粘性会话功能失效。
|
||||
|
||||
---
|
||||
|
||||
## OpenAI Responses 兼容注意事项
|
||||
|
||||
- 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id` 的 `tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`。
|
||||
- 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。
|
||||
|
||||
## 部署方式
|
||||
|
||||
### 方式一:脚本安装(推荐)
|
||||
|
||||
BIN
assets/partners/logos/pincc-logo.png
Normal file
BIN
assets/partners/logos/pincc-logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 171 KiB |
@@ -94,6 +94,7 @@ func provideCleanup(
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
backupSvc *service.BackupService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -230,6 +231,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"BackupService", func() error {
|
||||
if backupSvc != nil {
|
||||
backupSvc.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
}
|
||||
|
||||
infraSteps := []cleanupStep{
|
||||
|
||||
@@ -110,7 +110,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
||||
groupHandler := admin.NewGroupHandler(adminService)
|
||||
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
||||
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
||||
openAIOAuthClient := repository.NewOpenAIOAuthClient()
|
||||
@@ -124,6 +123,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
|
||||
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
|
||||
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
|
||||
oauthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
||||
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
|
||||
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
|
||||
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||
@@ -132,20 +132,26 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
usageCache := service.NewUsageCache()
|
||||
identityCache := repository.NewIdentityCache(redisClient)
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache)
|
||||
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
|
||||
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oauthRefreshAPI)
|
||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
|
||||
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
|
||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
|
||||
rpmCache := repository.NewRPMCache(redisClient)
|
||||
groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
|
||||
groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
|
||||
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
|
||||
dataManagementService := service.NewDataManagementService()
|
||||
dataManagementHandler := admin.NewDataManagementHandler(dataManagementService)
|
||||
backupObjectStoreFactory := repository.NewS3BackupStoreFactory()
|
||||
dbDumper := repository.NewPgDumper(configConfig)
|
||||
backupService := service.ProvideBackupService(settingRepository, configConfig, secretEncryptor, backupObjectStoreFactory, dbDumper)
|
||||
backupHandler := admin.NewBackupHandler(backupService, userService)
|
||||
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
||||
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
||||
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
|
||||
@@ -162,10 +168,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
billingService := service.NewBillingService(configConfig, pricingService)
|
||||
identityService := service.NewIdentityService(identityCache)
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
||||
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI)
|
||||
digestSessionStore := service.NewDigestSessionStore()
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
|
||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
||||
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||
@@ -201,7 +207,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
|
||||
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
|
||||
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
|
||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||
@@ -228,11 +234,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
@@ -285,6 +291,7 @@ func provideCleanup(
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
backupSvc *service.BackupService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -420,6 +427,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"BackupService", func() error {
|
||||
if backupSvc != nil {
|
||||
backupSvc.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
}
|
||||
|
||||
infraSteps := []cleanupStep{
|
||||
|
||||
@@ -75,6 +75,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||
antigravityOAuthSvc,
|
||||
nil, // openAIGateway
|
||||
nil, // scheduledTestRunner
|
||||
nil, // backupSvc
|
||||
)
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
|
||||
@@ -27,12 +27,11 @@ const (
|
||||
|
||||
// Account type constants
|
||||
const (
|
||||
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名连接 Bedrock)
|
||||
AccountTypeBedrockAPIKey = "bedrock-apikey" // AWS Bedrock API Key 类型账号(通过 Bearer Token 连接 Bedrock)
|
||||
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
|
||||
)
|
||||
|
||||
// Redeem type constants
|
||||
|
||||
@@ -97,7 +97,7 @@ type CreateAccountRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Notes *string `json:"notes"`
|
||||
Platform string `json:"platform" binding:"required"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock bedrock-apikey"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock"`
|
||||
Credentials map[string]any `json:"credentials" binding:"required"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
@@ -116,7 +116,7 @@ type CreateAccountRequest struct {
|
||||
type UpdateAccountRequest struct {
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock bedrock-apikey"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
@@ -1718,13 +1718,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
|
||||
// Handle OpenAI accounts
|
||||
if account.IsOpenAI() {
|
||||
// For OAuth accounts: return default OpenAI models
|
||||
if account.IsOAuth() {
|
||||
// OpenAI 自动透传会绕过常规模型改写,测试/模型列表也应回落到默认模型集。
|
||||
if account.IsOpenAIPassthroughEnabled() {
|
||||
response.Success(c, openai.DefaultModels)
|
||||
return
|
||||
}
|
||||
|
||||
// For API Key accounts: check model_mapping
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) == 0 {
|
||||
response.Success(c, openai.DefaultModels)
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type availableModelsAdminService struct {
|
||||
*stubAdminService
|
||||
account service.Account
|
||||
}
|
||||
|
||||
func (s *availableModelsAdminService) GetAccount(_ context.Context, id int64) (*service.Account, error) {
|
||||
if s.account.ID == id {
|
||||
acc := s.account
|
||||
return &acc, nil
|
||||
}
|
||||
return s.stubAdminService.GetAccount(context.Background(), id)
|
||||
}
|
||||
|
||||
func setupAvailableModelsRouter(adminSvc service.AdminService) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
router.GET("/api/v1/admin/accounts/:id/models", handler.GetAvailableModels)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_OpenAIOAuthUsesExplicitModelMapping(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 42,
|
||||
Name: "openai-oauth",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.1",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/42/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Len(t, resp.Data, 1)
|
||||
require.Equal(t, "gpt-5", resp.Data[0].ID)
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefaults(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 43,
|
||||
Name: "openai-oauth-passthrough",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.1",
|
||||
},
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"openai_passthrough": true,
|
||||
},
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/43/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.NotEmpty(t, resp.Data)
|
||||
require.NotEqual(t, "gpt-5", resp.Data[0].ID)
|
||||
}
|
||||
@@ -17,7 +17,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
|
||||
adminSvc := newStubAdminService()
|
||||
|
||||
userHandler := NewUserHandler(adminSvc, nil)
|
||||
groupHandler := NewGroupHandler(adminSvc)
|
||||
groupHandler := NewGroupHandler(adminSvc, nil, nil)
|
||||
proxyHandler := NewProxyHandler(adminSvc)
|
||||
redeemHandler := NewRedeemHandler(adminSvc, nil)
|
||||
|
||||
|
||||
205
backend/internal/handler/admin/backup_handler.go
Normal file
205
backend/internal/handler/admin/backup_handler.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type BackupHandler struct {
|
||||
backupService *service.BackupService
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
func NewBackupHandler(backupService *service.BackupService, userService *service.UserService) *BackupHandler {
|
||||
return &BackupHandler{
|
||||
backupService: backupService,
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
// ─── S3 配置 ───
|
||||
|
||||
func (h *BackupHandler) GetS3Config(c *gin.Context) {
|
||||
cfg, err := h.backupService.GetS3Config(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) UpdateS3Config(c *gin.Context) {
|
||||
var req service.BackupS3Config
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
cfg, err := h.backupService.UpdateS3Config(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) TestS3Connection(c *gin.Context) {
|
||||
var req service.BackupS3Config
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
err := h.backupService.TestS3Connection(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.Success(c, gin.H{"ok": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"ok": true, "message": "connection successful"})
|
||||
}
|
||||
|
||||
// ─── 定时备份 ───
|
||||
|
||||
func (h *BackupHandler) GetSchedule(c *gin.Context) {
|
||||
cfg, err := h.backupService.GetSchedule(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) UpdateSchedule(c *gin.Context) {
|
||||
var req service.BackupScheduleConfig
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
cfg, err := h.backupService.UpdateSchedule(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
// ─── 备份操作 ───
|
||||
|
||||
type CreateBackupRequest struct {
|
||||
ExpireDays *int `json:"expire_days"` // nil=使用默认值14,0=永不过期
|
||||
}
|
||||
|
||||
func (h *BackupHandler) CreateBackup(c *gin.Context) {
|
||||
var req CreateBackupRequest
|
||||
_ = c.ShouldBindJSON(&req) // 允许空 body
|
||||
|
||||
expireDays := 14 // 默认14天过期
|
||||
if req.ExpireDays != nil {
|
||||
expireDays = *req.ExpireDays
|
||||
}
|
||||
|
||||
record, err := h.backupService.StartBackup(c.Request.Context(), "manual", expireDays)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Accepted(c, record)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) ListBackups(c *gin.Context) {
|
||||
records, err := h.backupService.ListBackups(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if records == nil {
|
||||
records = []service.BackupRecord{}
|
||||
}
|
||||
response.Success(c, gin.H{"items": records})
|
||||
}
|
||||
|
||||
func (h *BackupHandler) GetBackup(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
record, err := h.backupService.GetBackupRecord(c.Request.Context(), backupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, record)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) DeleteBackup(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
if err := h.backupService.DeleteBackup(c.Request.Context(), backupID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
func (h *BackupHandler) GetDownloadURL(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
url, err := h.backupService.GetBackupDownloadURL(c.Request.Context(), backupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"url": url})
|
||||
}
|
||||
|
||||
// ─── 恢复操作(需要重新输入管理员密码) ───
|
||||
|
||||
type RestoreBackupRequest struct {
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
func (h *BackupHandler) RestoreBackup(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
var req RestoreBackupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "password is required for restore operation")
|
||||
return
|
||||
}
|
||||
|
||||
// 从上下文获取当前管理员用户 ID
|
||||
sub, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取管理员用户并验证密码
|
||||
user, err := h.userService.GetByID(c.Request.Context(), sub.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if !user.CheckPassword(req.Password) {
|
||||
response.BadRequest(c, "incorrect admin password")
|
||||
return
|
||||
}
|
||||
|
||||
record, err := h.backupService.StartRestore(c.Request.Context(), backupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Accepted(c, record)
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -512,6 +513,8 @@ func (h *DashboardHandler) GetUserSpendingRanking(c *gin.Context) {
|
||||
payload := gin.H{
|
||||
"ranking": ranking.Ranking,
|
||||
"total_actual_cost": ranking.TotalActualCost,
|
||||
"total_requests": ranking.TotalRequests,
|
||||
"total_tokens": ranking.TotalTokens,
|
||||
"start_date": startTime.Format("2006-01-02"),
|
||||
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
}
|
||||
@@ -602,3 +605,41 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, payload)
|
||||
}
|
||||
|
||||
// GetUserBreakdown handles getting per-user usage breakdown within a dimension.
|
||||
// GET /api/v1/admin/dashboard/user-breakdown
|
||||
// Query params: start_date, end_date, group_id, model, endpoint, endpoint_type, limit
|
||||
func (h *DashboardHandler) GetUserBreakdown(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
|
||||
dim := usagestats.UserBreakdownDimension{}
|
||||
if v := c.Query("group_id"); v != "" {
|
||||
if id, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
dim.GroupID = id
|
||||
}
|
||||
}
|
||||
dim.Model = c.Query("model")
|
||||
dim.Endpoint = c.Query("endpoint")
|
||||
dim.EndpointType = c.DefaultQuery("endpoint_type", "inbound")
|
||||
|
||||
limit := 50
|
||||
if v := c.Query("limit"); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= 200 {
|
||||
limit = n
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetUserBreakdownStats(
|
||||
c.Request.Context(), startTime, endTime, dim, limit,
|
||||
)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user breakdown stats")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"users": stats,
|
||||
"start_date": startTime.Format("2006-01-02"),
|
||||
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -61,6 +61,8 @@ func (s *dashboardUsageRepoCapture) GetUserSpendingRanking(
|
||||
return &usagestats.UserSpendingRankingResponse{
|
||||
Ranking: s.ranking,
|
||||
TotalActualCost: s.rankingTotal,
|
||||
TotalRequests: 44,
|
||||
TotalTokens: 1234,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -164,6 +166,8 @@ func TestDashboardUsersRankingLimitAndCache(t *testing.T) {
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, 50, repo.rankingLimit)
|
||||
require.Contains(t, rec.Body.String(), "\"total_actual_cost\":88.8")
|
||||
require.Contains(t, rec.Body.String(), "\"total_requests\":44")
|
||||
require.Contains(t, rec.Body.String(), "\"total_tokens\":1234")
|
||||
require.Equal(t, "miss", rec.Header().Get("X-Snapshot-Cache"))
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
|
||||
|
||||
@@ -0,0 +1,203 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- mock repo ---
|
||||
|
||||
type userBreakdownRepoCapture struct {
|
||||
service.UsageLogRepository
|
||||
capturedDim usagestats.UserBreakdownDimension
|
||||
capturedLimit int
|
||||
result []usagestats.UserBreakdownItem
|
||||
}
|
||||
|
||||
func (r *userBreakdownRepoCapture) GetUserBreakdownStats(
|
||||
_ context.Context, _, _ time.Time,
|
||||
dim usagestats.UserBreakdownDimension, limit int,
|
||||
) ([]usagestats.UserBreakdownItem, error) {
|
||||
r.capturedDim = dim
|
||||
r.capturedLimit = limit
|
||||
if r.result != nil {
|
||||
return r.result, nil
|
||||
}
|
||||
return []usagestats.UserBreakdownItem{}, nil
|
||||
}
|
||||
|
||||
func newUserBreakdownRouter(repo *userBreakdownRepoCapture) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := service.NewDashboardService(repo, nil, nil, nil)
|
||||
h := NewDashboardHandler(svc, nil)
|
||||
router := gin.New()
|
||||
router.GET("/admin/dashboard/user-breakdown", h.GetUserBreakdown)
|
||||
return router
|
||||
}
|
||||
|
||||
// --- tests ---
|
||||
|
||||
func TestGetUserBreakdown_GroupIDFilter(t *testing.T) {
|
||||
repo := &userBreakdownRepoCapture{}
|
||||
router := newUserBreakdownRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=42", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
require.Equal(t, int64(42), repo.capturedDim.GroupID)
|
||||
require.Empty(t, repo.capturedDim.Model)
|
||||
require.Empty(t, repo.capturedDim.Endpoint)
|
||||
require.Equal(t, 50, repo.capturedLimit) // default limit
|
||||
}
|
||||
|
||||
func TestGetUserBreakdown_ModelFilter(t *testing.T) {
|
||||
repo := &userBreakdownRepoCapture{}
|
||||
router := newUserBreakdownRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=claude-opus-4-6", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
require.Equal(t, "claude-opus-4-6", repo.capturedDim.Model)
|
||||
require.Equal(t, int64(0), repo.capturedDim.GroupID)
|
||||
}
|
||||
|
||||
func TestGetUserBreakdown_EndpointFilter(t *testing.T) {
|
||||
repo := &userBreakdownRepoCapture{}
|
||||
router := newUserBreakdownRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&endpoint=/v1/messages&endpoint_type=upstream", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
require.Equal(t, "/v1/messages", repo.capturedDim.Endpoint)
|
||||
require.Equal(t, "upstream", repo.capturedDim.EndpointType)
|
||||
}
|
||||
|
||||
func TestGetUserBreakdown_DefaultEndpointType(t *testing.T) {
|
||||
repo := &userBreakdownRepoCapture{}
|
||||
router := newUserBreakdownRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&endpoint=/chat", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
require.Equal(t, "inbound", repo.capturedDim.EndpointType)
|
||||
}
|
||||
|
||||
func TestGetUserBreakdown_CustomLimit(t *testing.T) {
|
||||
repo := &userBreakdownRepoCapture{}
|
||||
router := newUserBreakdownRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=test&limit=100", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
require.Equal(t, 100, repo.capturedLimit)
|
||||
}
|
||||
|
||||
func TestGetUserBreakdown_LimitClamped(t *testing.T) {
|
||||
repo := &userBreakdownRepoCapture{}
|
||||
router := newUserBreakdownRouter(repo)
|
||||
|
||||
// limit > 200 should fall back to default 50
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=test&limit=999", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
require.Equal(t, 50, repo.capturedLimit)
|
||||
}
|
||||
|
||||
func TestGetUserBreakdown_ResponseFormat(t *testing.T) {
|
||||
repo := &userBreakdownRepoCapture{
|
||||
result: []usagestats.UserBreakdownItem{
|
||||
{UserID: 1, Email: "alice@test.com", Requests: 100, TotalTokens: 50000, Cost: 1.5, ActualCost: 1.2},
|
||||
{UserID: 2, Email: "bob@test.com", Requests: 50, TotalTokens: 25000, Cost: 0.8, ActualCost: 0.6},
|
||||
},
|
||||
}
|
||||
router := newUserBreakdownRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=1", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp struct {
|
||||
Code int `json:"code"`
|
||||
Data struct {
|
||||
Users []usagestats.UserBreakdownItem `json:"users"`
|
||||
StartDate string `json:"start_date"`
|
||||
EndDate string `json:"end_date"`
|
||||
} `json:"data"`
|
||||
}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, resp.Code)
|
||||
require.Len(t, resp.Data.Users, 2)
|
||||
require.Equal(t, int64(1), resp.Data.Users[0].UserID)
|
||||
require.Equal(t, "alice@test.com", resp.Data.Users[0].Email)
|
||||
require.Equal(t, int64(100), resp.Data.Users[0].Requests)
|
||||
require.InDelta(t, 1.2, resp.Data.Users[0].ActualCost, 0.001)
|
||||
require.Equal(t, "2026-03-01", resp.Data.StartDate)
|
||||
require.Equal(t, "2026-03-16", resp.Data.EndDate)
|
||||
}
|
||||
|
||||
func TestGetUserBreakdown_EmptyResult(t *testing.T) {
|
||||
repo := &userBreakdownRepoCapture{}
|
||||
router := newUserBreakdownRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=999", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp struct {
|
||||
Data struct {
|
||||
Users []usagestats.UserBreakdownItem `json:"users"`
|
||||
} `json:"data"`
|
||||
}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, resp.Data.Users)
|
||||
}
|
||||
|
||||
func TestGetUserBreakdown_NoFilters(t *testing.T) {
|
||||
repo := &userBreakdownRepoCapture{}
|
||||
router := newUserBreakdownRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
require.Equal(t, int64(0), repo.capturedDim.GroupID)
|
||||
require.Empty(t, repo.capturedDim.Model)
|
||||
require.Empty(t, repo.capturedDim.Endpoint)
|
||||
}
|
||||
@@ -1,11 +1,15 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -13,27 +17,80 @@ import (
|
||||
|
||||
// GroupHandler handles admin group management
|
||||
type GroupHandler struct {
|
||||
adminService service.AdminService
|
||||
adminService service.AdminService
|
||||
dashboardService *service.DashboardService
|
||||
groupCapacityService *service.GroupCapacityService
|
||||
}
|
||||
|
||||
type optionalLimitField struct {
|
||||
set bool
|
||||
value *float64
|
||||
}
|
||||
|
||||
func (f *optionalLimitField) UnmarshalJSON(data []byte) error {
|
||||
f.set = true
|
||||
|
||||
trimmed := bytes.TrimSpace(data)
|
||||
if bytes.Equal(trimmed, []byte("null")) {
|
||||
f.value = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
var number float64
|
||||
if err := json.Unmarshal(trimmed, &number); err == nil {
|
||||
f.value = &number
|
||||
return nil
|
||||
}
|
||||
|
||||
var text string
|
||||
if err := json.Unmarshal(trimmed, &text); err == nil {
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
f.value = nil
|
||||
return nil
|
||||
}
|
||||
number, err = strconv.ParseFloat(text, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid numeric limit value %q: %w", text, err)
|
||||
}
|
||||
f.value = &number
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid limit value: %s", string(trimmed))
|
||||
}
|
||||
|
||||
func (f optionalLimitField) ToServiceInput() *float64 {
|
||||
if !f.set {
|
||||
return nil
|
||||
}
|
||||
if f.value != nil {
|
||||
return f.value
|
||||
}
|
||||
zero := 0.0
|
||||
return &zero
|
||||
}
|
||||
|
||||
// NewGroupHandler creates a new admin group handler
|
||||
func NewGroupHandler(adminService service.AdminService) *GroupHandler {
|
||||
func NewGroupHandler(adminService service.AdminService, dashboardService *service.DashboardService, groupCapacityService *service.GroupCapacityService) *GroupHandler {
|
||||
return &GroupHandler{
|
||||
adminService: adminService,
|
||||
adminService: adminService,
|
||||
dashboardService: dashboardService,
|
||||
groupCapacityService: groupCapacityService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateGroupRequest represents create group request
|
||||
type CreateGroupRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD optionalLimitField `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
@@ -62,16 +119,16 @@ type CreateGroupRequest struct {
|
||||
|
||||
// UpdateGroupRequest represents update group request
|
||||
type UpdateGroupRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD optionalLimitField `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
@@ -191,9 +248,9 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
@@ -244,9 +301,9 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
IsExclusive: req.IsExclusive,
|
||||
Status: req.Status,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
@@ -311,6 +368,33 @@ func (h *GroupHandler) GetStats(c *gin.Context) {
|
||||
_ = groupID // TODO: implement actual stats
|
||||
}
|
||||
|
||||
// GetUsageSummary returns today's and cumulative cost for all groups.
|
||||
// GET /api/v1/admin/groups/usage-summary?timezone=Asia/Shanghai
|
||||
func (h *GroupHandler) GetUsageSummary(c *gin.Context) {
|
||||
userTZ := c.Query("timezone")
|
||||
now := timezone.NowInUserLocation(userTZ)
|
||||
todayStart := timezone.StartOfDayInUserLocation(now, userTZ)
|
||||
|
||||
results, err := h.dashboardService.GetGroupUsageSummary(c.Request.Context(), todayStart)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get group usage summary")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, results)
|
||||
}
|
||||
|
||||
// GetCapacitySummary returns aggregated capacity (concurrency/sessions/RPM) for all active groups.
|
||||
// GET /api/v1/admin/groups/capacity-summary
|
||||
func (h *GroupHandler) GetCapacitySummary(c *gin.Context) {
|
||||
results, err := h.groupCapacityService.GetAllGroupCapacity(c.Request.Context())
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get group capacity summary")
|
||||
return
|
||||
}
|
||||
response.Success(c, results)
|
||||
}
|
||||
|
||||
// GetGroupAPIKeys handles getting API keys in a group
|
||||
// GET /api/v1/admin/groups/:id/api-keys
|
||||
func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
||||
|
||||
@@ -80,6 +80,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
FrontendURL: settings.FrontendURL,
|
||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||
TotpEnabled: settings.TotpEnabled,
|
||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||
@@ -125,6 +126,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
|
||||
MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -136,6 +138,7 @@ type UpdateSettingsRequest struct {
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
FrontendURL string `json:"frontend_url"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
|
||||
@@ -199,6 +202,9 @@ type UpdateSettingsRequest struct {
|
||||
|
||||
// 分组隔离
|
||||
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
||||
|
||||
// Backend Mode
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
}
|
||||
|
||||
// UpdateSettings 更新系统设置
|
||||
@@ -322,6 +328,15 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Frontend URL 验证
|
||||
req.FrontendURL = strings.TrimSpace(req.FrontendURL)
|
||||
if req.FrontendURL != "" {
|
||||
if err := config.ValidateAbsoluteHTTPURL(req.FrontendURL); err != nil {
|
||||
response.BadRequest(c, "Frontend URL must be an absolute http(s) URL")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 自定义菜单项验证
|
||||
const (
|
||||
maxCustomMenuItems = 20
|
||||
@@ -433,6 +448,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||
PasswordResetEnabled: req.PasswordResetEnabled,
|
||||
FrontendURL: req.FrontendURL,
|
||||
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
||||
TotpEnabled: req.TotpEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
@@ -473,6 +489,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: req.BackendModeEnabled,
|
||||
OpsMonitoringEnabled: func() bool {
|
||||
if req.OpsMonitoringEnabled != nil {
|
||||
return *req.OpsMonitoringEnabled
|
||||
@@ -526,6 +543,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
|
||||
FrontendURL: updatedSettings.FrontendURL,
|
||||
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
|
||||
TotpEnabled: updatedSettings.TotpEnabled,
|
||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||
@@ -571,6 +589,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
|
||||
MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: updatedSettings.BackendModeEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -608,6 +627,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.PasswordResetEnabled != after.PasswordResetEnabled {
|
||||
changed = append(changed, "password_reset_enabled")
|
||||
}
|
||||
if before.FrontendURL != after.FrontendURL {
|
||||
changed = append(changed, "frontend_url")
|
||||
}
|
||||
if before.TotpEnabled != after.TotpEnabled {
|
||||
changed = append(changed, "totp_enabled")
|
||||
}
|
||||
@@ -725,6 +747,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling {
|
||||
changed = append(changed, "allow_ungrouped_key_scheduling")
|
||||
}
|
||||
if before.BackendModeEnabled != after.BackendModeEnabled {
|
||||
changed = append(changed, "backend_mode_enabled")
|
||||
}
|
||||
if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled {
|
||||
changed = append(changed, "purchase_subscription_enabled")
|
||||
}
|
||||
|
||||
@@ -77,12 +77,13 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
status := c.Query("status")
|
||||
platform := c.Query("platform")
|
||||
|
||||
// Parse sorting parameters
|
||||
sortBy := c.DefaultQuery("sort_by", "created_at")
|
||||
sortOrder := c.DefaultQuery("sort_order", "desc")
|
||||
|
||||
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder)
|
||||
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, platform, sortBy, sortOrder)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -159,8 +159,8 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
// Set end time to end of day
|
||||
t = t.Add(24*time.Hour - time.Nanosecond)
|
||||
// Use half-open range [start, end), move to next calendar day start (DST-safe).
|
||||
t = t.AddDate(0, 0, 1)
|
||||
endTime = &t
|
||||
}
|
||||
|
||||
@@ -285,7 +285,8 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
|
||||
// 与 SQL 条件 created_at < end 对齐,使用次日 00:00 作为上边界(DST-safe)。
|
||||
endTime = endTime.AddDate(0, 0, 1)
|
||||
} else {
|
||||
period := c.DefaultQuery("period", "today")
|
||||
switch period {
|
||||
|
||||
@@ -194,6 +194,12 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Backend mode: only admin can login
|
||||
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
|
||||
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||
return
|
||||
}
|
||||
|
||||
h.respondWithTokenPair(c, user)
|
||||
}
|
||||
|
||||
@@ -250,16 +256,22 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the login session
|
||||
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
||||
|
||||
// Get the user
|
||||
// Get the user (before session deletion so we can check backend mode)
|
||||
user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Backend mode: only admin can login (check BEFORE deleting session)
|
||||
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
|
||||
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the login session (only after all checks pass)
|
||||
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
||||
|
||||
h.respondWithTokenPair(c, user)
|
||||
}
|
||||
|
||||
@@ -447,9 +459,9 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
frontendBaseURL := strings.TrimSpace(h.cfg.Server.FrontendURL)
|
||||
frontendBaseURL := strings.TrimSpace(h.settingSvc.GetFrontendURL(c.Request.Context()))
|
||||
if frontendBaseURL == "" {
|
||||
slog.Error("server.frontend_url not configured; cannot build password reset link")
|
||||
slog.Error("frontend_url not configured in settings or config; cannot build password reset link")
|
||||
response.InternalError(c, "Password reset is not configured")
|
||||
return
|
||||
}
|
||||
@@ -522,16 +534,22 @@ func (h *AuthHandler) RefreshToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
|
||||
result, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Backend mode: block non-admin token refresh
|
||||
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && result.UserRole != "admin" {
|
||||
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, RefreshTokenResponse{
|
||||
AccessToken: tokenPair.AccessToken,
|
||||
RefreshToken: tokenPair.RefreshToken,
|
||||
ExpiresIn: tokenPair.ExpiresIn,
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
ExpiresIn: result.ExpiresIn,
|
||||
TokenType: "Bearer",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -135,14 +135,16 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
||||
return nil
|
||||
}
|
||||
out := &AdminGroup{
|
||||
Group: groupFromServiceBase(g),
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
MCPXMLInject: g.MCPXMLInject,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
AccountCount: g.AccountCount,
|
||||
SortOrder: g.SortOrder,
|
||||
Group: groupFromServiceBase(g),
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
MCPXMLInject: g.MCPXMLInject,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
AccountCount: g.AccountCount,
|
||||
ActiveAccountCount: g.ActiveAccountCount,
|
||||
RateLimitedAccountCount: g.RateLimitedAccountCount,
|
||||
SortOrder: g.SortOrder,
|
||||
}
|
||||
if len(g.AccountGroups) > 0 {
|
||||
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
||||
@@ -264,8 +266,8 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
}
|
||||
}
|
||||
|
||||
// 提取 API Key 账号配额限制(仅 apikey 类型有效)
|
||||
if a.Type == service.AccountTypeAPIKey {
|
||||
// 提取账号配额限制(apikey / bedrock 类型有效)
|
||||
if a.IsAPIKeyOrBedrock() {
|
||||
if limit := a.GetQuotaLimit(); limit > 0 {
|
||||
out.QuotaLimit = &limit
|
||||
used := a.GetQuotaUsed()
|
||||
@@ -281,6 +283,31 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
used := a.GetQuotaWeeklyUsed()
|
||||
out.QuotaWeeklyUsed = &used
|
||||
}
|
||||
// 固定时间重置配置
|
||||
if mode := a.GetQuotaDailyResetMode(); mode == "fixed" {
|
||||
out.QuotaDailyResetMode = &mode
|
||||
hour := a.GetQuotaDailyResetHour()
|
||||
out.QuotaDailyResetHour = &hour
|
||||
}
|
||||
if mode := a.GetQuotaWeeklyResetMode(); mode == "fixed" {
|
||||
out.QuotaWeeklyResetMode = &mode
|
||||
day := a.GetQuotaWeeklyResetDay()
|
||||
out.QuotaWeeklyResetDay = &day
|
||||
hour := a.GetQuotaWeeklyResetHour()
|
||||
out.QuotaWeeklyResetHour = &hour
|
||||
}
|
||||
if a.GetQuotaDailyResetMode() == "fixed" || a.GetQuotaWeeklyResetMode() == "fixed" {
|
||||
tz := a.GetQuotaResetTimezone()
|
||||
out.QuotaResetTimezone = &tz
|
||||
}
|
||||
if a.Extra != nil {
|
||||
if v, ok := a.Extra["quota_daily_reset_at"].(string); ok && v != "" {
|
||||
out.QuotaDailyResetAt = &v
|
||||
}
|
||||
if v, ok := a.Extra["quota_weekly_reset_at"].(string); ok && v != "" {
|
||||
out.QuotaWeeklyResetAt = &v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
@@ -498,6 +525,8 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
Model: l.Model,
|
||||
ServiceTier: l.ServiceTier,
|
||||
ReasoningEffort: l.ReasoningEffort,
|
||||
InboundEndpoint: l.InboundEndpoint,
|
||||
UpstreamEndpoint: l.UpstreamEndpoint,
|
||||
GroupID: l.GroupID,
|
||||
SubscriptionID: l.SubscriptionID,
|
||||
InputTokens: l.InputTokens,
|
||||
|
||||
@@ -76,10 +76,14 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serviceTier := "priority"
|
||||
inboundEndpoint := "/v1/chat/completions"
|
||||
upstreamEndpoint := "/v1/responses"
|
||||
log := &service.UsageLog{
|
||||
RequestID: "req_3",
|
||||
Model: "gpt-5.4",
|
||||
ServiceTier: &serviceTier,
|
||||
InboundEndpoint: &inboundEndpoint,
|
||||
UpstreamEndpoint: &upstreamEndpoint,
|
||||
AccountRateMultiplier: f64Ptr(1.5),
|
||||
}
|
||||
|
||||
@@ -88,8 +92,16 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
|
||||
|
||||
require.NotNil(t, userDTO.ServiceTier)
|
||||
require.Equal(t, serviceTier, *userDTO.ServiceTier)
|
||||
require.NotNil(t, userDTO.InboundEndpoint)
|
||||
require.Equal(t, inboundEndpoint, *userDTO.InboundEndpoint)
|
||||
require.NotNil(t, userDTO.UpstreamEndpoint)
|
||||
require.Equal(t, upstreamEndpoint, *userDTO.UpstreamEndpoint)
|
||||
require.NotNil(t, adminDTO.ServiceTier)
|
||||
require.Equal(t, serviceTier, *adminDTO.ServiceTier)
|
||||
require.NotNil(t, adminDTO.InboundEndpoint)
|
||||
require.Equal(t, inboundEndpoint, *adminDTO.InboundEndpoint)
|
||||
require.NotNil(t, adminDTO.UpstreamEndpoint)
|
||||
require.Equal(t, upstreamEndpoint, *adminDTO.UpstreamEndpoint)
|
||||
require.NotNil(t, adminDTO.AccountRateMultiplier)
|
||||
require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ type SystemSettings struct {
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
FrontendURL string `json:"frontend_url"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
||||
@@ -81,6 +82,9 @@ type SystemSettings struct {
|
||||
|
||||
// 分组隔离
|
||||
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
||||
|
||||
// Backend Mode
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
}
|
||||
|
||||
type DefaultSubscriptionSetting struct {
|
||||
@@ -111,6 +115,7 @@ type PublicSettings struct {
|
||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
|
||||
@@ -122,9 +122,11 @@ type AdminGroup struct {
|
||||
DefaultMappedModel string `json:"default_mapped_model"`
|
||||
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
AccountCount int64 `json:"account_count,omitempty"`
|
||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
AccountCount int64 `json:"account_count,omitempty"`
|
||||
ActiveAccountCount int64 `json:"active_account_count,omitempty"`
|
||||
RateLimitedAccountCount int64 `json:"rate_limited_account_count,omitempty"`
|
||||
|
||||
// 分组排序
|
||||
SortOrder int `json:"sort_order"`
|
||||
@@ -203,6 +205,16 @@ type Account struct {
|
||||
QuotaWeeklyLimit *float64 `json:"quota_weekly_limit,omitempty"`
|
||||
QuotaWeeklyUsed *float64 `json:"quota_weekly_used,omitempty"`
|
||||
|
||||
// 配额固定时间重置配置
|
||||
QuotaDailyResetMode *string `json:"quota_daily_reset_mode,omitempty"`
|
||||
QuotaDailyResetHour *int `json:"quota_daily_reset_hour,omitempty"`
|
||||
QuotaWeeklyResetMode *string `json:"quota_weekly_reset_mode,omitempty"`
|
||||
QuotaWeeklyResetDay *int `json:"quota_weekly_reset_day,omitempty"`
|
||||
QuotaWeeklyResetHour *int `json:"quota_weekly_reset_hour,omitempty"`
|
||||
QuotaResetTimezone *string `json:"quota_reset_timezone,omitempty"`
|
||||
QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"`
|
||||
QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"`
|
||||
|
||||
Proxy *Proxy `json:"proxy,omitempty"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
|
||||
@@ -324,9 +336,13 @@ type UsageLog struct {
|
||||
Model string `json:"model"`
|
||||
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
|
||||
ServiceTier *string `json:"service_tier,omitempty"`
|
||||
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API).
|
||||
// nil means not provided / not applicable.
|
||||
// ReasoningEffort is the request's reasoning effort level.
|
||||
// OpenAI: "low"/"medium"/"high"/"xhigh"; Claude: "low"/"medium"/"high"/"max".
|
||||
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
||||
// InboundEndpoint is the client-facing API endpoint path, e.g. /v1/chat/completions.
|
||||
InboundEndpoint *string `json:"inbound_endpoint,omitempty"`
|
||||
// UpstreamEndpoint is the normalized upstream endpoint path, e.g. /v1/responses.
|
||||
UpstreamEndpoint *string `json:"upstream_endpoint,omitempty"`
|
||||
|
||||
GroupID *int64 `json:"group_id"`
|
||||
SubscriptionID *int64 `json:"subscription_id"`
|
||||
|
||||
174
backend/internal/handler/endpoint.go
Normal file
174
backend/internal/handler/endpoint.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// Canonical inbound / upstream endpoint paths.
|
||||
// All normalization and derivation reference this single set
|
||||
// of constants — add new paths HERE when a new API surface
|
||||
// is introduced.
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
const (
|
||||
EndpointMessages = "/v1/messages"
|
||||
EndpointChatCompletions = "/v1/chat/completions"
|
||||
EndpointResponses = "/v1/responses"
|
||||
EndpointGeminiModels = "/v1beta/models"
|
||||
)
|
||||
|
||||
// gin.Context keys used by the middleware and helpers below.
|
||||
const (
|
||||
ctxKeyInboundEndpoint = "_gateway_inbound_endpoint"
|
||||
)
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// Normalization functions
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
// NormalizeInboundEndpoint maps a raw request path (which may carry
|
||||
// prefixes like /antigravity, /openai, /sora) to its canonical form.
|
||||
//
|
||||
// "/antigravity/v1/messages" → "/v1/messages"
|
||||
// "/v1/chat/completions" → "/v1/chat/completions"
|
||||
// "/openai/v1/responses/foo" → "/v1/responses"
|
||||
// "/v1beta/models/gemini:gen" → "/v1beta/models"
|
||||
func NormalizeInboundEndpoint(path string) string {
|
||||
path = strings.TrimSpace(path)
|
||||
switch {
|
||||
case strings.Contains(path, EndpointChatCompletions):
|
||||
return EndpointChatCompletions
|
||||
case strings.Contains(path, EndpointMessages):
|
||||
return EndpointMessages
|
||||
case strings.Contains(path, EndpointResponses):
|
||||
return EndpointResponses
|
||||
case strings.Contains(path, EndpointGeminiModels):
|
||||
return EndpointGeminiModels
|
||||
default:
|
||||
return path
|
||||
}
|
||||
}
|
||||
|
||||
// DeriveUpstreamEndpoint determines the upstream endpoint from the
|
||||
// account platform and the normalized inbound endpoint.
|
||||
//
|
||||
// Platform-specific rules:
|
||||
// - OpenAI always forwards to /v1/responses (with optional subpath
|
||||
// such as /v1/responses/compact preserved from the raw URL).
|
||||
// - Anthropic → /v1/messages
|
||||
// - Gemini → /v1beta/models
|
||||
// - Sora → /v1/chat/completions
|
||||
// - Antigravity routes may target either Claude or Gemini, so the
|
||||
// inbound endpoint is used to distinguish.
|
||||
func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string {
|
||||
inbound = strings.TrimSpace(inbound)
|
||||
|
||||
switch platform {
|
||||
case service.PlatformOpenAI:
|
||||
// OpenAI forwards everything to the Responses API.
|
||||
// Preserve subresource suffix (e.g. /v1/responses/compact).
|
||||
if suffix := responsesSubpathSuffix(rawRequestPath); suffix != "" {
|
||||
return EndpointResponses + suffix
|
||||
}
|
||||
return EndpointResponses
|
||||
|
||||
case service.PlatformAnthropic:
|
||||
return EndpointMessages
|
||||
|
||||
case service.PlatformGemini:
|
||||
return EndpointGeminiModels
|
||||
|
||||
case service.PlatformSora:
|
||||
return EndpointChatCompletions
|
||||
|
||||
case service.PlatformAntigravity:
|
||||
// Antigravity accounts serve both Claude and Gemini.
|
||||
if inbound == EndpointGeminiModels {
|
||||
return EndpointGeminiModels
|
||||
}
|
||||
return EndpointMessages
|
||||
}
|
||||
|
||||
// Unknown platform — fall back to inbound.
|
||||
return inbound
|
||||
}
|
||||
|
||||
// responsesSubpathSuffix extracts the part after "/responses" in a raw
|
||||
// request path, e.g. "/openai/v1/responses/compact" → "/compact".
|
||||
// Returns "" when there is no meaningful suffix.
|
||||
func responsesSubpathSuffix(rawPath string) string {
|
||||
trimmed := strings.TrimRight(strings.TrimSpace(rawPath), "/")
|
||||
idx := strings.LastIndex(trimmed, "/responses")
|
||||
if idx < 0 {
|
||||
return ""
|
||||
}
|
||||
suffix := trimmed[idx+len("/responses"):]
|
||||
if suffix == "" || suffix == "/" {
|
||||
return ""
|
||||
}
|
||||
if !strings.HasPrefix(suffix, "/") {
|
||||
return ""
|
||||
}
|
||||
return suffix
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// Middleware
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
// InboundEndpointMiddleware normalizes the request path and stores the
|
||||
// canonical inbound endpoint in gin.Context so that every handler in
|
||||
// the chain can read it via GetInboundEndpoint.
|
||||
//
|
||||
// Apply this middleware to all gateway route groups.
|
||||
func InboundEndpointMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
path := c.FullPath()
|
||||
if path == "" && c.Request != nil && c.Request.URL != nil {
|
||||
path = c.Request.URL.Path
|
||||
}
|
||||
c.Set(ctxKeyInboundEndpoint, NormalizeInboundEndpoint(path))
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// Context helpers — used by handlers before building
|
||||
// RecordUsageInput / RecordUsageLongContextInput.
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
// GetInboundEndpoint returns the canonical inbound endpoint stored by
|
||||
// InboundEndpointMiddleware. If the middleware did not run (e.g. in
|
||||
// tests), it falls back to normalizing c.FullPath() on the fly.
|
||||
func GetInboundEndpoint(c *gin.Context) string {
|
||||
if v, ok := c.Get(ctxKeyInboundEndpoint); ok {
|
||||
if s, ok := v.(string); ok && s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
// Fallback: normalize on the fly.
|
||||
path := ""
|
||||
if c != nil {
|
||||
path = c.FullPath()
|
||||
if path == "" && c.Request != nil && c.Request.URL != nil {
|
||||
path = c.Request.URL.Path
|
||||
}
|
||||
}
|
||||
return NormalizeInboundEndpoint(path)
|
||||
}
|
||||
|
||||
// GetUpstreamEndpoint derives the upstream endpoint from the context
|
||||
// and the account platform. Handlers call this after scheduling an
|
||||
// account, passing account.Platform.
|
||||
func GetUpstreamEndpoint(c *gin.Context, platform string) string {
|
||||
inbound := GetInboundEndpoint(c)
|
||||
rawPath := ""
|
||||
if c != nil && c.Request != nil && c.Request.URL != nil {
|
||||
rawPath = c.Request.URL.Path
|
||||
}
|
||||
return DeriveUpstreamEndpoint(inbound, rawPath, platform)
|
||||
}
|
||||
159
backend/internal/handler/endpoint_test.go
Normal file
159
backend/internal/handler/endpoint_test.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func init() { gin.SetMode(gin.TestMode) }
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// NormalizeInboundEndpoint
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
func TestNormalizeInboundEndpoint(t *testing.T) {
|
||||
tests := []struct {
|
||||
path string
|
||||
want string
|
||||
}{
|
||||
// Direct canonical paths.
|
||||
{"/v1/messages", EndpointMessages},
|
||||
{"/v1/chat/completions", EndpointChatCompletions},
|
||||
{"/v1/responses", EndpointResponses},
|
||||
{"/v1beta/models", EndpointGeminiModels},
|
||||
|
||||
// Prefixed paths (antigravity, openai, sora).
|
||||
{"/antigravity/v1/messages", EndpointMessages},
|
||||
{"/openai/v1/responses", EndpointResponses},
|
||||
{"/openai/v1/responses/compact", EndpointResponses},
|
||||
{"/sora/v1/chat/completions", EndpointChatCompletions},
|
||||
{"/antigravity/v1beta/models/gemini:generateContent", EndpointGeminiModels},
|
||||
|
||||
// Gin route patterns with wildcards.
|
||||
{"/v1beta/models/*modelAction", EndpointGeminiModels},
|
||||
{"/v1/responses/*subpath", EndpointResponses},
|
||||
|
||||
// Unknown path is returned as-is.
|
||||
{"/v1/embeddings", "/v1/embeddings"},
|
||||
{"", ""},
|
||||
{" /v1/messages ", EndpointMessages},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.path, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, NormalizeInboundEndpoint(tt.path))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// DeriveUpstreamEndpoint
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
func TestDeriveUpstreamEndpoint(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inbound string
|
||||
rawPath string
|
||||
platform string
|
||||
want string
|
||||
}{
|
||||
// Anthropic.
|
||||
{"anthropic messages", EndpointMessages, "/v1/messages", service.PlatformAnthropic, EndpointMessages},
|
||||
|
||||
// Gemini.
|
||||
{"gemini models", EndpointGeminiModels, "/v1beta/models/gemini:gen", service.PlatformGemini, EndpointGeminiModels},
|
||||
|
||||
// Sora.
|
||||
{"sora completions", EndpointChatCompletions, "/sora/v1/chat/completions", service.PlatformSora, EndpointChatCompletions},
|
||||
|
||||
// OpenAI — always /v1/responses.
|
||||
{"openai responses root", EndpointResponses, "/v1/responses", service.PlatformOpenAI, EndpointResponses},
|
||||
{"openai responses compact", EndpointResponses, "/openai/v1/responses/compact", service.PlatformOpenAI, "/v1/responses/compact"},
|
||||
{"openai responses nested", EndpointResponses, "/openai/v1/responses/compact/detail", service.PlatformOpenAI, "/v1/responses/compact/detail"},
|
||||
{"openai from messages", EndpointMessages, "/v1/messages", service.PlatformOpenAI, EndpointResponses},
|
||||
{"openai from completions", EndpointChatCompletions, "/v1/chat/completions", service.PlatformOpenAI, EndpointResponses},
|
||||
|
||||
// Antigravity — uses inbound to pick Claude vs Gemini upstream.
|
||||
{"antigravity claude", EndpointMessages, "/antigravity/v1/messages", service.PlatformAntigravity, EndpointMessages},
|
||||
{"antigravity gemini", EndpointGeminiModels, "/antigravity/v1beta/models", service.PlatformAntigravity, EndpointGeminiModels},
|
||||
|
||||
// Unknown platform — passthrough.
|
||||
{"unknown platform", "/v1/embeddings", "/v1/embeddings", "unknown", "/v1/embeddings"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, DeriveUpstreamEndpoint(tt.inbound, tt.rawPath, tt.platform))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// responsesSubpathSuffix
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
func TestResponsesSubpathSuffix(t *testing.T) {
|
||||
tests := []struct {
|
||||
raw string
|
||||
want string
|
||||
}{
|
||||
{"/v1/responses", ""},
|
||||
{"/v1/responses/", ""},
|
||||
{"/v1/responses/compact", "/compact"},
|
||||
{"/openai/v1/responses/compact/detail", "/compact/detail"},
|
||||
{"/v1/messages", ""},
|
||||
{"", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.raw, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, responsesSubpathSuffix(tt.raw))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// InboundEndpointMiddleware + context helpers
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
func TestInboundEndpointMiddleware(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(InboundEndpointMiddleware())
|
||||
|
||||
var captured string
|
||||
router.POST("/v1/messages", func(c *gin.Context) {
|
||||
captured = GetInboundEndpoint(c)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, EndpointMessages, captured)
|
||||
}
|
||||
|
||||
func TestGetInboundEndpoint_FallbackWithoutMiddleware(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/antigravity/v1/messages", nil)
|
||||
|
||||
// Middleware did not run — fallback to normalizing c.Request.URL.Path.
|
||||
got := GetInboundEndpoint(c)
|
||||
require.Equal(t, EndpointMessages, got)
|
||||
}
|
||||
|
||||
func TestGetUpstreamEndpoint_FullFlow(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses/compact", nil)
|
||||
|
||||
// Simulate middleware.
|
||||
c.Set(ctxKeyInboundEndpoint, NormalizeInboundEndpoint(c.Request.URL.Path))
|
||||
|
||||
got := GetUpstreamEndpoint(c, service.PlatformOpenAI)
|
||||
require.Equal(t, "/v1/responses/compact", got)
|
||||
}
|
||||
@@ -391,6 +391,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if fs.SwitchCount > 0 {
|
||||
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||
}
|
||||
// 记录 Forward 前已写入字节数,Forward 后若增加则说明 SSE 内容已发,禁止 failover
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
|
||||
} else {
|
||||
@@ -402,6 +404,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
// 流式内容已写入客户端,无法撤销,禁止 failover 以防止流拼接腐化
|
||||
if c.Writer.Size() != writerSizeBeforeForward {
|
||||
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true)
|
||||
return
|
||||
}
|
||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
@@ -435,6 +442,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
if result.ReasoningEffort == nil {
|
||||
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
|
||||
}
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
@@ -444,6 +457,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
@@ -637,6 +652,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if fs.SwitchCount > 0 {
|
||||
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||
}
|
||||
// 记录 Forward 前已写入字节数,Forward 后若增加则说明 SSE 内容已发,禁止 failover
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
|
||||
} else {
|
||||
@@ -706,6 +723,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
// 流式内容已写入客户端,无法撤销,禁止 failover 以防止流拼接腐化
|
||||
if c.Writer.Size() != writerSizeBeforeForward {
|
||||
h.handleFailoverExhausted(c, failoverErr, account.Platform, true)
|
||||
return
|
||||
}
|
||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
@@ -739,6 +761,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
if result.ReasoningEffort == nil {
|
||||
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
|
||||
}
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
@@ -748,6 +776,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
User: currentAPIKey.User,
|
||||
Account: account,
|
||||
Subscription: currentSubscription,
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
@@ -913,7 +943,7 @@ func (h *GatewayHandler) parseUsageDateRange(c *gin.Context) (time.Time, time.Ti
|
||||
}
|
||||
if s := c.Query("end_date"); s != "" {
|
||||
if t, err := timezone.ParseInLocation("2006-01-02", s); err == nil {
|
||||
endTime = t.Add(24*time.Hour - time.Second) // end of day
|
||||
endTime = t.AddDate(0, 0, 1) // half-open range upper bound
|
||||
}
|
||||
}
|
||||
return startTime, endTime
|
||||
|
||||
122
backend/internal/handler/gateway_handler_stream_failover_test.go
Normal file
122
backend/internal/handler/gateway_handler_stream_failover_test.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// partialMessageStartSSE 模拟 handleStreamingResponse 已写入的首批 SSE 事件。
|
||||
const partialMessageStartSSE = "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_01\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-5\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":10,\"output_tokens\":1}}}\n\n" +
|
||||
"event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"
|
||||
|
||||
// TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten 验证:
|
||||
// 当 Forward 在返回 UpstreamFailoverError 前已向客户端写入 SSE 内容时,
|
||||
// 故障转移保护逻辑必须终止循环并发送 SSE 错误事件,而不是进行下一次 Forward。
|
||||
// 具体验证:
|
||||
// 1. c.Writer.Size() 检测条件正确触发(字节数已增加)
|
||||
// 2. handleFailoverExhausted 以 streamStarted=true 调用后,响应体以 SSE 错误事件结尾
|
||||
// 3. 响应体中只出现一个 message_start,不存在第二个(防止流拼接腐化)
|
||||
func TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
// 步骤 1:记录 Forward 前的 writer size(模拟 writerSizeBeforeForward := c.Writer.Size())
|
||||
sizeBeforeForward := c.Writer.Size()
|
||||
require.Equal(t, -1, sizeBeforeForward, "gin writer 初始 Size 应为 -1(未写入任何字节)")
|
||||
|
||||
// 步骤 2:模拟 Forward 已向客户端写入部分 SSE 内容(message_start + content_block_start)
|
||||
_, err := c.Writer.Write([]byte(partialMessageStartSSE))
|
||||
require.NoError(t, err)
|
||||
|
||||
// 步骤 3:验证守卫条件成立(c.Writer.Size() != sizeBeforeForward)
|
||||
require.NotEqual(t, sizeBeforeForward, c.Writer.Size(),
|
||||
"写入 SSE 内容后 writer size 必须增加,守卫条件应为 true")
|
||||
|
||||
// 步骤 4:模拟 UpstreamFailoverError(上游在流中途返回 403)
|
||||
failoverErr := &service.UpstreamFailoverError{
|
||||
StatusCode: http.StatusForbidden,
|
||||
ResponseBody: []byte(`{"error":{"type":"permission_error","message":"forbidden"}}`),
|
||||
}
|
||||
|
||||
// 步骤 5:守卫触发 → 调用 handleFailoverExhausted,streamStarted=true
|
||||
h := &GatewayHandler{}
|
||||
h.handleFailoverExhausted(c, failoverErr, service.PlatformAnthropic, true)
|
||||
|
||||
body := w.Body.String()
|
||||
|
||||
// 断言 A:响应体中包含最初写入的 message_start SSE 事件行
|
||||
require.Contains(t, body, "event: message_start", "响应体应包含已写入的 message_start SSE 事件")
|
||||
|
||||
// 断言 B:响应体以 SSE 错误事件结尾(data: {"type":"error",...}\n\n)
|
||||
require.True(t, strings.HasSuffix(strings.TrimRight(body, "\n"), "}"),
|
||||
"响应体应以 JSON 对象结尾(SSE error event 的 data 字段)")
|
||||
require.Contains(t, body, `"type":"error"`, "响应体末尾必须包含 SSE 错误事件")
|
||||
|
||||
// 断言 C:SSE event 行 "event: message_start" 只出现一次(防止双 message_start 拼接腐化)
|
||||
firstIdx := strings.Index(body, "event: message_start")
|
||||
lastIdx := strings.LastIndex(body, "event: message_start")
|
||||
assert.Equal(t, firstIdx, lastIdx,
|
||||
"响应体中 'event: message_start' 必须只出现一次,不得因 failover 拼接导致两次")
|
||||
}
|
||||
|
||||
// TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten 与上述测试相同,
|
||||
// 验证 Gemini 路径使用 service.PlatformGemini(而非 account.Platform)时行为一致。
|
||||
func TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.0-flash:streamGenerateContent", nil)
|
||||
|
||||
sizeBeforeForward := c.Writer.Size()
|
||||
|
||||
_, err := c.Writer.Write([]byte(partialMessageStartSSE))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotEqual(t, sizeBeforeForward, c.Writer.Size())
|
||||
|
||||
failoverErr := &service.UpstreamFailoverError{
|
||||
StatusCode: http.StatusForbidden,
|
||||
}
|
||||
|
||||
h := &GatewayHandler{}
|
||||
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true)
|
||||
|
||||
body := w.Body.String()
|
||||
|
||||
require.Contains(t, body, "event: message_start")
|
||||
require.Contains(t, body, `"type":"error"`)
|
||||
|
||||
firstIdx := strings.Index(body, "event: message_start")
|
||||
lastIdx := strings.LastIndex(body, "event: message_start")
|
||||
assert.Equal(t, firstIdx, lastIdx, "Gemini 路径不得出现双 message_start")
|
||||
}
|
||||
|
||||
// TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered 验证反向场景:
|
||||
// 当 Forward 返回 UpstreamFailoverError 时若未向客户端写入任何 SSE 内容,
|
||||
// 守卫条件(c.Writer.Size() != sizeBeforeForward)为 false,不应中止 failover。
|
||||
func TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
// 模拟 writerSizeBeforeForward:初始为 -1
|
||||
sizeBeforeForward := c.Writer.Size()
|
||||
|
||||
// Forward 未写入任何字节直接返回错误(例如 401 发生在连接建立前)
|
||||
// c.Writer.Size() 仍为 -1
|
||||
|
||||
// 守卫条件:sizeBeforeForward == c.Writer.Size() → 不触发
|
||||
guardTriggered := c.Writer.Size() != sizeBeforeForward
|
||||
require.False(t, guardTriggered,
|
||||
"未写入任何字节时,守卫条件必须为 false,应允许正常 failover 继续")
|
||||
}
|
||||
@@ -76,7 +76,7 @@ func (f *fakeGroupRepo) ListActiveByPlatform(context.Context, string) ([]service
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil }
|
||||
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, error) { return 0, nil }
|
||||
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, int64, error) { return 0, 0, nil }
|
||||
func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
@@ -136,7 +136,7 @@ func validClaudeCodeBodyJSON() []byte {
|
||||
return []byte(`{
|
||||
"model":"claude-3-5-sonnet-20241022",
|
||||
"system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}],
|
||||
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"}
|
||||
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}
|
||||
}`)
|
||||
}
|
||||
|
||||
@@ -190,7 +190,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing
|
||||
System: []any{
|
||||
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
||||
},
|
||||
MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123",
|
||||
MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
|
||||
}
|
||||
|
||||
// body 非法 JSON,如果函数复用 parsedReq 成功则仍应判定为 Claude Code。
|
||||
@@ -209,7 +209,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing
|
||||
"system": []any{
|
||||
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
||||
},
|
||||
"metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"},
|
||||
"metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"},
|
||||
})
|
||||
|
||||
SetClaudeCodeClientContext(c, []byte(`{invalid`), nil)
|
||||
|
||||
@@ -504,6 +504,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
||||
Result: result,
|
||||
@@ -511,6 +513,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
|
||||
@@ -12,6 +12,7 @@ type AdminHandlers struct {
|
||||
Account *admin.AccountHandler
|
||||
Announcement *admin.AnnouncementHandler
|
||||
DataManagement *admin.DataManagementHandler
|
||||
Backup *admin.BackupHandler
|
||||
OAuth *admin.OAuthHandler
|
||||
OpenAIOAuth *admin.OpenAIOAuthHandler
|
||||
GeminiOAuth *admin.GeminiOAuthHandler
|
||||
|
||||
@@ -181,13 +181,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
defaultMappedModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultMappedModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
if fallbackModel := c.GetString("openai_chat_completions_fallback_model"); fallbackModel != "" {
|
||||
defaultMappedModel = fallbackModel
|
||||
}
|
||||
defaultMappedModel := c.GetString("openai_chat_completions_fallback_model")
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
@@ -262,14 +256,16 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.chat_completions"),
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestOpenAIUpstreamEndpoint_ViaGetUpstreamEndpoint verifies that the
|
||||
// unified GetUpstreamEndpoint helper produces the same results as the
|
||||
// former normalizedOpenAIUpstreamEndpoint for OpenAI platform requests.
|
||||
func TestOpenAIUpstreamEndpoint_ViaGetUpstreamEndpoint(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "responses root maps to responses upstream",
|
||||
path: "/v1/responses",
|
||||
want: EndpointResponses,
|
||||
},
|
||||
{
|
||||
name: "responses compact keeps compact suffix",
|
||||
path: "/openai/v1/responses/compact",
|
||||
want: "/v1/responses/compact",
|
||||
},
|
||||
{
|
||||
name: "responses nested suffix preserved",
|
||||
path: "/openai/v1/responses/compact/detail",
|
||||
want: "/v1/responses/compact/detail",
|
||||
},
|
||||
{
|
||||
name: "non responses path uses platform fallback",
|
||||
path: "/v1/messages",
|
||||
want: EndpointResponses,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, tt.path, nil)
|
||||
|
||||
got := GetUpstreamEndpoint(c, service.PlatformOpenAI)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -362,6 +362,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
@@ -655,14 +657,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
defaultMappedModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultMappedModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
// 如果使用了降级模型调度,强制使用降级模型
|
||||
if fallbackModel := c.GetString("openai_messages_fallback_model"); fallbackModel != "" {
|
||||
defaultMappedModel = fallbackModel
|
||||
}
|
||||
// 仅在调度时实际触发了降级(原模型无可用账号、改用默认模型重试成功)时,
|
||||
// 才将降级模型传给 Forward 层做模型替换;否则保持用户请求的原始模型。
|
||||
defaultMappedModel := c.GetString("openai_messages_fallback_model")
|
||||
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
@@ -743,6 +740,8 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
@@ -1240,6 +1239,8 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
|
||||
|
||||
@@ -26,6 +26,22 @@ const (
|
||||
opsStreamKey = "ops_stream"
|
||||
opsRequestBodyKey = "ops_request_body"
|
||||
opsAccountIDKey = "ops_account_id"
|
||||
|
||||
// 错误过滤匹配常量 — shouldSkipOpsErrorLog 和错误分类共用
|
||||
opsErrContextCanceled = "context canceled"
|
||||
opsErrNoAvailableAccounts = "no available accounts"
|
||||
opsErrInvalidAPIKey = "invalid_api_key"
|
||||
opsErrAPIKeyRequired = "api_key_required"
|
||||
opsErrInsufficientBalance = "insufficient balance"
|
||||
opsErrInsufficientAccountBalance = "insufficient account balance"
|
||||
opsErrInsufficientQuota = "insufficient_quota"
|
||||
|
||||
// 上游错误码常量 — 错误分类 (normalizeOpsErrorType / classifyOpsPhase / classifyOpsIsBusinessLimited)
|
||||
opsCodeInsufficientBalance = "INSUFFICIENT_BALANCE"
|
||||
opsCodeUsageLimitExceeded = "USAGE_LIMIT_EXCEEDED"
|
||||
opsCodeSubscriptionNotFound = "SUBSCRIPTION_NOT_FOUND"
|
||||
opsCodeSubscriptionInvalid = "SUBSCRIPTION_INVALID"
|
||||
opsCodeUserInactive = "USER_INACTIVE"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -1024,9 +1040,9 @@ func normalizeOpsErrorType(errType string, code string) string {
|
||||
return errType
|
||||
}
|
||||
switch strings.TrimSpace(code) {
|
||||
case "INSUFFICIENT_BALANCE":
|
||||
case opsCodeInsufficientBalance:
|
||||
return "billing_error"
|
||||
case "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
|
||||
case opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
|
||||
return "subscription_error"
|
||||
default:
|
||||
return "api_error"
|
||||
@@ -1038,7 +1054,7 @@ func classifyOpsPhase(errType, message, code string) string {
|
||||
// Standardized phases: request|auth|routing|upstream|network|internal
|
||||
// Map billing/concurrency/response => request; scheduling => routing.
|
||||
switch strings.TrimSpace(code) {
|
||||
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
|
||||
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
|
||||
return "request"
|
||||
}
|
||||
|
||||
@@ -1057,7 +1073,7 @@ func classifyOpsPhase(errType, message, code string) string {
|
||||
case "upstream_error", "overloaded_error":
|
||||
return "upstream"
|
||||
case "api_error":
|
||||
if strings.Contains(msg, "no available accounts") {
|
||||
if strings.Contains(msg, opsErrNoAvailableAccounts) {
|
||||
return "routing"
|
||||
}
|
||||
return "internal"
|
||||
@@ -1103,7 +1119,7 @@ func classifyOpsIsRetryable(errType string, statusCode int) bool {
|
||||
|
||||
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
|
||||
switch strings.TrimSpace(code) {
|
||||
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID", "USER_INACTIVE":
|
||||
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid, opsCodeUserInactive:
|
||||
return true
|
||||
}
|
||||
if phase == "billing" || phase == "concurrency" {
|
||||
@@ -1197,21 +1213,30 @@ func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message
|
||||
|
||||
// Check if context canceled errors should be ignored (client disconnects)
|
||||
if settings.IgnoreContextCanceled {
|
||||
if strings.Contains(msgLower, "context canceled") || strings.Contains(bodyLower, "context canceled") {
|
||||
if strings.Contains(msgLower, opsErrContextCanceled) || strings.Contains(bodyLower, opsErrContextCanceled) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if "no available accounts" errors should be ignored
|
||||
if settings.IgnoreNoAvailableAccounts {
|
||||
if strings.Contains(msgLower, "no available accounts") || strings.Contains(bodyLower, "no available accounts") {
|
||||
if strings.Contains(msgLower, opsErrNoAvailableAccounts) || strings.Contains(bodyLower, opsErrNoAvailableAccounts) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if invalid/missing API key errors should be ignored (user misconfiguration)
|
||||
if settings.IgnoreInvalidApiKeyErrors {
|
||||
if strings.Contains(bodyLower, "invalid_api_key") || strings.Contains(bodyLower, "api_key_required") {
|
||||
if strings.Contains(bodyLower, opsErrInvalidAPIKey) || strings.Contains(bodyLower, opsErrAPIKeyRequired) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if insufficient balance errors should be ignored
|
||||
if settings.IgnoreInsufficientBalanceErrors {
|
||||
if strings.Contains(bodyLower, opsErrInsufficientBalance) || strings.Contains(bodyLower, opsErrInsufficientAccountBalance) ||
|
||||
strings.Contains(bodyLower, opsErrInsufficientQuota) ||
|
||||
strings.Contains(msgLower, opsErrInsufficientBalance) || strings.Contains(msgLower, opsErrInsufficientAccountBalance) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,6 +54,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
SoraClientEnabled: settings.SoraClientEnabled,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
Version: h.version,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -400,6 +400,8 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
@@ -409,6 +411,8 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
|
||||
@@ -273,8 +273,8 @@ func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform strin
|
||||
func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, nil
|
||||
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
|
||||
return 0, 0, nil
|
||||
}
|
||||
func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, nil
|
||||
@@ -334,9 +334,23 @@ func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTi
|
||||
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *stubUsageLogRepo) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||
return []usagestats.EndpointStat{}, nil
|
||||
}
|
||||
|
||||
func (s *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||
return []usagestats.EndpointStat{}, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -114,8 +114,8 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
// Set end time to end of day
|
||||
t = t.Add(24*time.Hour - time.Nanosecond)
|
||||
// Use half-open range [start, end), move to next calendar day start (DST-safe).
|
||||
t = t.AddDate(0, 0, 1)
|
||||
endTime = &t
|
||||
}
|
||||
|
||||
@@ -227,8 +227,8 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
// 设置结束时间为当天结束
|
||||
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
|
||||
// 与 SQL 条件 created_at < end 对齐,使用次日 00:00 作为上边界(DST-safe)。
|
||||
endTime = endTime.AddDate(0, 0, 1)
|
||||
} else {
|
||||
// 使用 period 参数
|
||||
period := c.DefaultQuery("period", "today")
|
||||
|
||||
@@ -15,6 +15,7 @@ func ProvideAdminHandlers(
|
||||
accountHandler *admin.AccountHandler,
|
||||
announcementHandler *admin.AnnouncementHandler,
|
||||
dataManagementHandler *admin.DataManagementHandler,
|
||||
backupHandler *admin.BackupHandler,
|
||||
oauthHandler *admin.OAuthHandler,
|
||||
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
||||
geminiOAuthHandler *admin.GeminiOAuthHandler,
|
||||
@@ -39,6 +40,7 @@ func ProvideAdminHandlers(
|
||||
Account: accountHandler,
|
||||
Announcement: announcementHandler,
|
||||
DataManagement: dataManagementHandler,
|
||||
Backup: backupHandler,
|
||||
OAuth: oauthHandler,
|
||||
OpenAIOAuth: openaiOAuthHandler,
|
||||
GeminiOAuth: geminiOAuthHandler,
|
||||
@@ -128,6 +130,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewAccountHandler,
|
||||
admin.NewAnnouncementHandler,
|
||||
admin.NewDataManagementHandler,
|
||||
admin.NewBackupHandler,
|
||||
admin.NewOAuthHandler,
|
||||
admin.NewOpenAIOAuthHandler,
|
||||
admin.NewGeminiOAuthHandler,
|
||||
|
||||
@@ -19,6 +19,16 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
||||
)
|
||||
|
||||
// ForbiddenError 表示上游返回 403 Forbidden
|
||||
type ForbiddenError struct {
|
||||
StatusCode int
|
||||
Body string
|
||||
}
|
||||
|
||||
func (e *ForbiddenError) Error() string {
|
||||
return fmt.Sprintf("fetchAvailableModels 失败 (HTTP %d): %s", e.StatusCode, e.Body)
|
||||
}
|
||||
|
||||
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
|
||||
func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
|
||||
// 构建 URL,流式请求添加 ?alt=sse 参数
|
||||
@@ -114,10 +124,68 @@ type IneligibleTier struct {
|
||||
type LoadCodeAssistResponse struct {
|
||||
CloudAICompanionProject string `json:"cloudaicompanionProject"`
|
||||
CurrentTier *TierInfo `json:"currentTier,omitempty"`
|
||||
PaidTier *TierInfo `json:"paidTier,omitempty"`
|
||||
PaidTier *PaidTierInfo `json:"paidTier,omitempty"`
|
||||
IneligibleTiers []*IneligibleTier `json:"ineligibleTiers,omitempty"`
|
||||
}
|
||||
|
||||
// PaidTierInfo 付费等级信息,包含 AI Credits 余额。
|
||||
type PaidTierInfo struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
AvailableCredits []AvailableCredit `json:"availableCredits,omitempty"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON 兼容 paidTier 既可能是字符串也可能是对象的情况。
|
||||
func (p *PaidTierInfo) UnmarshalJSON(data []byte) error {
|
||||
data = bytes.TrimSpace(data)
|
||||
if len(data) == 0 || string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
if data[0] == '"' {
|
||||
var id string
|
||||
if err := json.Unmarshal(data, &id); err != nil {
|
||||
return err
|
||||
}
|
||||
p.ID = id
|
||||
return nil
|
||||
}
|
||||
type alias PaidTierInfo
|
||||
var raw alias
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return err
|
||||
}
|
||||
*p = PaidTierInfo(raw)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AvailableCredit 表示一条 AI Credits 余额记录。
|
||||
type AvailableCredit struct {
|
||||
CreditType string `json:"creditType,omitempty"`
|
||||
CreditAmount string `json:"creditAmount,omitempty"`
|
||||
MinimumCreditAmountForUsage string `json:"minimumCreditAmountForUsage,omitempty"`
|
||||
}
|
||||
|
||||
// GetAmount 将 creditAmount 解析为浮点数。
|
||||
func (c *AvailableCredit) GetAmount() float64 {
|
||||
if c.CreditAmount == "" {
|
||||
return 0
|
||||
}
|
||||
var value float64
|
||||
_, _ = fmt.Sscanf(c.CreditAmount, "%f", &value)
|
||||
return value
|
||||
}
|
||||
|
||||
// GetMinimumAmount 将 minimumCreditAmountForUsage 解析为浮点数。
|
||||
func (c *AvailableCredit) GetMinimumAmount() float64 {
|
||||
if c.MinimumCreditAmountForUsage == "" {
|
||||
return 0
|
||||
}
|
||||
var value float64
|
||||
_, _ = fmt.Sscanf(c.MinimumCreditAmountForUsage, "%f", &value)
|
||||
return value
|
||||
}
|
||||
|
||||
// OnboardUserRequest onboardUser 请求
|
||||
type OnboardUserRequest struct {
|
||||
TierID string `json:"tierId"`
|
||||
@@ -147,6 +215,14 @@ func (r *LoadCodeAssistResponse) GetTier() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetAvailableCredits 返回 paid tier 中的 AI Credits 余额列表。
|
||||
func (r *LoadCodeAssistResponse) GetAvailableCredits() []AvailableCredit {
|
||||
if r.PaidTier == nil {
|
||||
return nil
|
||||
}
|
||||
return r.PaidTier.AvailableCredits
|
||||
}
|
||||
|
||||
// Client Antigravity API 客户端
|
||||
type Client struct {
|
||||
httpClient *http.Client
|
||||
@@ -514,7 +590,20 @@ type ModelQuotaInfo struct {
|
||||
|
||||
// ModelInfo 模型信息
|
||||
type ModelInfo struct {
|
||||
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
|
||||
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
|
||||
DisplayName string `json:"displayName,omitempty"`
|
||||
SupportsImages *bool `json:"supportsImages,omitempty"`
|
||||
SupportsThinking *bool `json:"supportsThinking,omitempty"`
|
||||
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
|
||||
Recommended *bool `json:"recommended,omitempty"`
|
||||
MaxTokens *int `json:"maxTokens,omitempty"`
|
||||
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
|
||||
SupportedMimeTypes map[string]bool `json:"supportedMimeTypes,omitempty"`
|
||||
}
|
||||
|
||||
// DeprecatedModelInfo 废弃模型转发信息
|
||||
type DeprecatedModelInfo struct {
|
||||
NewModelID string `json:"newModelId"`
|
||||
}
|
||||
|
||||
// FetchAvailableModelsRequest fetchAvailableModels 请求
|
||||
@@ -524,7 +613,8 @@ type FetchAvailableModelsRequest struct {
|
||||
|
||||
// FetchAvailableModelsResponse fetchAvailableModels 响应
|
||||
type FetchAvailableModelsResponse struct {
|
||||
Models map[string]ModelInfo `json:"models"`
|
||||
Models map[string]ModelInfo `json:"models"`
|
||||
DeprecatedModelIDs map[string]DeprecatedModelInfo `json:"deprecatedModelIds,omitempty"`
|
||||
}
|
||||
|
||||
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
|
||||
@@ -573,6 +663,13 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusForbidden {
|
||||
return nil, nil, &ForbiddenError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Body: string(respBodyBytes),
|
||||
}
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
||||
}
|
||||
|
||||
@@ -190,7 +190,7 @@ func TestTierInfo_UnmarshalJSON_通过JSON嵌套结构(t *testing.T) {
|
||||
func TestGetTier_PaidTier优先(t *testing.T) {
|
||||
resp := &LoadCodeAssistResponse{
|
||||
CurrentTier: &TierInfo{ID: "free-tier"},
|
||||
PaidTier: &TierInfo{ID: "g1-pro-tier"},
|
||||
PaidTier: &PaidTierInfo{ID: "g1-pro-tier"},
|
||||
}
|
||||
if got := resp.GetTier(); got != "g1-pro-tier" {
|
||||
t.Errorf("应返回 paidTier: got %s", got)
|
||||
@@ -209,7 +209,7 @@ func TestGetTier_回退到CurrentTier(t *testing.T) {
|
||||
func TestGetTier_PaidTier为空ID(t *testing.T) {
|
||||
resp := &LoadCodeAssistResponse{
|
||||
CurrentTier: &TierInfo{ID: "free-tier"},
|
||||
PaidTier: &TierInfo{ID: ""},
|
||||
PaidTier: &PaidTierInfo{ID: ""},
|
||||
}
|
||||
// paidTier.ID 为空时应回退到 currentTier
|
||||
if got := resp.GetTier(); got != "free-tier" {
|
||||
@@ -217,6 +217,32 @@ func TestGetTier_PaidTier为空ID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAvailableCredits(t *testing.T) {
|
||||
resp := &LoadCodeAssistResponse{
|
||||
PaidTier: &PaidTierInfo{
|
||||
ID: "g1-pro-tier",
|
||||
AvailableCredits: []AvailableCredit{
|
||||
{
|
||||
CreditType: "GOOGLE_ONE_AI",
|
||||
CreditAmount: "25",
|
||||
MinimumCreditAmountForUsage: "5",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
credits := resp.GetAvailableCredits()
|
||||
if len(credits) != 1 {
|
||||
t.Fatalf("AI Credits 数量不匹配: got %d", len(credits))
|
||||
}
|
||||
if credits[0].GetAmount() != 25 {
|
||||
t.Errorf("CreditAmount 解析不正确: got %v", credits[0].GetAmount())
|
||||
}
|
||||
if credits[0].GetMinimumAmount() != 5 {
|
||||
t.Errorf("MinimumCreditAmountForUsage 解析不正确: got %v", credits[0].GetMinimumAmount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTier_两者都为nil(t *testing.T) {
|
||||
resp := &LoadCodeAssistResponse{}
|
||||
if got := resp.GetTier(); got != "" {
|
||||
|
||||
@@ -49,8 +49,8 @@ const (
|
||||
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
)
|
||||
|
||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.4
|
||||
var defaultUserAgentVersion = "1.20.4"
|
||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.5
|
||||
var defaultUserAgentVersion = "1.20.5"
|
||||
|
||||
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
||||
var defaultClientSecret = "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"
|
||||
|
||||
@@ -690,7 +690,7 @@ func TestConstants_值正确(t *testing.T) {
|
||||
if RedirectURI != "http://localhost:8085/callback" {
|
||||
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
|
||||
}
|
||||
if GetUserAgent() != "antigravity/1.20.4 windows/amd64" {
|
||||
if GetUserAgent() != "antigravity/1.20.5 windows/amd64" {
|
||||
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
|
||||
}
|
||||
if SessionTTL != 30*time.Minute {
|
||||
|
||||
@@ -105,6 +105,7 @@ func TestAnthropicToResponses_ToolUse(t *testing.T) {
|
||||
assert.Equal(t, "assistant", items[1].Role)
|
||||
assert.Equal(t, "function_call", items[2].Type)
|
||||
assert.Equal(t, "fc_call_1", items[2].CallID)
|
||||
assert.Empty(t, items[2].ID)
|
||||
assert.Equal(t, "function_call_output", items[3].Type)
|
||||
assert.Equal(t, "fc_call_1", items[3].CallID)
|
||||
assert.Equal(t, "Sunny, 72°F", items[3].Output)
|
||||
|
||||
@@ -277,7 +277,6 @@ func anthropicAssistantToResponses(raw json.RawMessage) ([]ResponsesInputItem, e
|
||||
CallID: fcID,
|
||||
Name: b.Name,
|
||||
Arguments: args,
|
||||
ID: fcID,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -99,6 +99,7 @@ func TestChatCompletionsToResponses_ToolCalls(t *testing.T) {
|
||||
// Check function_call item
|
||||
assert.Equal(t, "function_call", items[1].Type)
|
||||
assert.Equal(t, "call_1", items[1].CallID)
|
||||
assert.Empty(t, items[1].ID)
|
||||
assert.Equal(t, "ping", items[1].Name)
|
||||
|
||||
// Check function_call_output item
|
||||
@@ -252,6 +253,55 @@ func TestChatCompletionsToResponses_AssistantWithTextAndToolCalls(t *testing.T)
|
||||
assert.Equal(t, "user", items[0].Role)
|
||||
assert.Equal(t, "assistant", items[1].Role)
|
||||
assert.Equal(t, "function_call", items[2].Type)
|
||||
assert.Empty(t, items[2].ID)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_AssistantArrayContentPreserved(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||
{Role: "assistant", Content: json.RawMessage(`[{"type":"text","text":"A"},{"type":"text","text":"B"}]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 2)
|
||||
assert.Equal(t, "assistant", items[1].Role)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[1].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "output_text", parts[0].Type)
|
||||
assert.Equal(t, "AB", parts[0].Text)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_AssistantThinkingTagPreserved(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||
{Role: "assistant", Content: json.RawMessage(`[{"type":"thinking","thinking":"internal plan"},{"type":"text","text":"final answer"}]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 2)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[1].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "output_text", parts[0].Type)
|
||||
assert.Contains(t, parts[0].Text, "<thinking>internal plan</thinking>")
|
||||
assert.Contains(t, parts[0].Text, "final answer")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -344,8 +394,8 @@ func TestResponsesToChatCompletions_Reasoning(t *testing.T) {
|
||||
|
||||
var content string
|
||||
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
|
||||
// Reasoning summary is prepended to text
|
||||
assert.Equal(t, "I thought about it.The answer is 42.", content)
|
||||
assert.Equal(t, "The answer is 42.", content)
|
||||
assert.Equal(t, "I thought about it.", chat.Choices[0].Message.ReasoningContent)
|
||||
}
|
||||
|
||||
func TestResponsesToChatCompletions_Incomplete(t *testing.T) {
|
||||
@@ -582,8 +632,35 @@ func TestResponsesEventToChatChunks_ReasoningDelta(t *testing.T) {
|
||||
Delta: "Thinking...",
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent)
|
||||
assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.ReasoningContent)
|
||||
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.reasoning_summary_text.done",
|
||||
}, state)
|
||||
require.Len(t, chunks, 0)
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_ReasoningThenTextAutoCloseTag(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.SentRole = true
|
||||
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.reasoning_summary_text.delta",
|
||||
Delta: "plan",
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent)
|
||||
assert.Equal(t, "plan", *chunks[0].Choices[0].Delta.ReasoningContent)
|
||||
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.output_text.delta",
|
||||
Delta: "answer",
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].Delta.Content)
|
||||
assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.Content)
|
||||
assert.Equal(t, "answer", *chunks[0].Choices[0].Delta.Content)
|
||||
}
|
||||
|
||||
func TestFinalizeResponsesChatStream(t *testing.T) {
|
||||
|
||||
@@ -3,6 +3,7 @@ package apicompat
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ChatCompletionsToResponses converts a Chat Completions request into a
|
||||
@@ -174,8 +175,11 @@ func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
|
||||
// Emit assistant message with output_text if content is non-empty.
|
||||
if len(m.Content) > 0 {
|
||||
var s string
|
||||
if err := json.Unmarshal(m.Content, &s); err == nil && s != "" {
|
||||
s, err := parseAssistantContent(m.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s != "" {
|
||||
parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
|
||||
partsJSON, err := json.Marshal(parts)
|
||||
if err != nil {
|
||||
@@ -196,13 +200,82 @@ func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
CallID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: args,
|
||||
ID: tc.ID,
|
||||
})
|
||||
}
|
||||
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// parseAssistantContent returns assistant content as plain text.
|
||||
//
|
||||
// Supported formats:
|
||||
// - JSON string
|
||||
// - JSON array of typed parts (e.g. [{"type":"text","text":"..."}])
|
||||
//
|
||||
// For structured thinking/reasoning parts, it preserves semantics by wrapping
|
||||
// the text in explicit tags so downstream can still distinguish it from normal text.
|
||||
func parseAssistantContent(raw json.RawMessage) (string, error) {
|
||||
if len(raw) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
var parts []map[string]any
|
||||
if err := json.Unmarshal(raw, &parts); err != nil {
|
||||
// Keep compatibility with prior behavior: unsupported assistant content
|
||||
// formats are ignored instead of failing the whole request conversion.
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
write := func(v string) error {
|
||||
_, err := b.WriteString(v)
|
||||
return err
|
||||
}
|
||||
for _, p := range parts {
|
||||
typ, _ := p["type"].(string)
|
||||
text, _ := p["text"].(string)
|
||||
thinking, _ := p["thinking"].(string)
|
||||
|
||||
switch typ {
|
||||
case "thinking", "reasoning":
|
||||
if thinking != "" {
|
||||
if err := write("<thinking>"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := write(thinking); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := write("</thinking>"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
} else if text != "" {
|
||||
if err := write("<thinking>"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := write(text); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := write("</thinking>"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
default:
|
||||
if text != "" {
|
||||
if err := write(text); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
|
||||
// chatToolToResponses converts a tool result message (role=tool) into a
|
||||
// function_call_output item.
|
||||
func chatToolToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
|
||||
@@ -29,6 +29,7 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
|
||||
}
|
||||
|
||||
var contentText string
|
||||
var reasoningText string
|
||||
var toolCalls []ChatToolCall
|
||||
|
||||
for _, item := range resp.Output {
|
||||
@@ -51,7 +52,7 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
|
||||
case "reasoning":
|
||||
for _, s := range item.Summary {
|
||||
if s.Type == "summary_text" && s.Text != "" {
|
||||
contentText += s.Text
|
||||
reasoningText += s.Text
|
||||
}
|
||||
}
|
||||
case "web_search_call":
|
||||
@@ -67,6 +68,9 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
|
||||
raw, _ := json.Marshal(contentText)
|
||||
msg.Content = raw
|
||||
}
|
||||
if reasoningText != "" {
|
||||
msg.ReasoningContent = reasoningText
|
||||
}
|
||||
|
||||
finishReason := responsesStatusToChatFinishReason(resp.Status, resp.IncompleteDetails, toolCalls)
|
||||
|
||||
@@ -153,6 +157,8 @@ func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEvent
|
||||
return resToChatHandleFuncArgsDelta(evt, state)
|
||||
case "response.reasoning_summary_text.delta":
|
||||
return resToChatHandleReasoningDelta(evt, state)
|
||||
case "response.reasoning_summary_text.done":
|
||||
return nil
|
||||
case "response.completed", "response.incomplete", "response.failed":
|
||||
return resToChatHandleCompleted(evt, state)
|
||||
default:
|
||||
@@ -276,8 +282,8 @@ func resToChatHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEv
|
||||
if evt.Delta == "" {
|
||||
return nil
|
||||
}
|
||||
content := evt.Delta
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})}
|
||||
reasoning := evt.Delta
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{ReasoningContent: &reasoning})}
|
||||
}
|
||||
|
||||
func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
|
||||
@@ -361,11 +361,12 @@ type ChatStreamOptions struct {
|
||||
|
||||
// ChatMessage is a single message in the Chat Completions conversation.
|
||||
type ChatMessage struct {
|
||||
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
|
||||
// Legacy function calling
|
||||
FunctionCall *ChatFunctionCall `json:"function_call,omitempty"`
|
||||
@@ -466,9 +467,10 @@ type ChatChunkChoice struct {
|
||||
|
||||
// ChatDelta carries incremental content in a streaming chunk.
|
||||
type ChatDelta struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
|
||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
|
||||
ReasoningContent *string `json:"reasoning_content,omitempty"`
|
||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@@ -47,6 +47,15 @@ func Created(c *gin.Context, data any) {
|
||||
})
|
||||
}
|
||||
|
||||
// Accepted 返回异步接受响应 (HTTP 202)
|
||||
func Accepted(c *gin.Context, data any) {
|
||||
c.JSON(http.StatusAccepted, Response{
|
||||
Code: 0,
|
||||
Message: "accepted",
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
// Error 返回错误响应
|
||||
func Error(c *gin.Context, statusCode int, message string) {
|
||||
c.JSON(statusCode, Response{
|
||||
|
||||
@@ -81,6 +81,22 @@ type ModelStat struct {
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// EndpointStat represents usage statistics for a single request endpoint.
|
||||
type EndpointStat struct {
|
||||
Endpoint string `json:"endpoint"`
|
||||
Requests int64 `json:"requests"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// GroupUsageSummary represents today's and cumulative cost for a single group.
|
||||
type GroupUsageSummary struct {
|
||||
GroupID int64 `json:"group_id"`
|
||||
TodayCost float64 `json:"today_cost"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
}
|
||||
|
||||
// GroupStat represents usage statistics for a single group
|
||||
type GroupStat struct {
|
||||
GroupID int64 `json:"group_id"`
|
||||
@@ -116,6 +132,26 @@ type UserSpendingRankingItem struct {
|
||||
type UserSpendingRankingResponse struct {
|
||||
Ranking []UserSpendingRankingItem `json:"ranking"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
}
|
||||
|
||||
// UserBreakdownItem represents per-user usage breakdown within a dimension (group, model, endpoint).
|
||||
type UserBreakdownItem struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Requests int64 `json:"requests"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// UserBreakdownDimension specifies the dimension to filter for user breakdown.
|
||||
type UserBreakdownDimension struct {
|
||||
GroupID int64 // filter by group_id (>0 to enable)
|
||||
Model string // filter by model name (non-empty to enable)
|
||||
Endpoint string // filter by endpoint value (non-empty to enable)
|
||||
EndpointType string // "inbound", "upstream", or "path"
|
||||
}
|
||||
|
||||
// APIKeyUsageTrendPoint represents API key usage trend data point
|
||||
@@ -179,15 +215,18 @@ type UsageLogFilters struct {
|
||||
|
||||
// UsageStats represents usage statistics
|
||||
type UsageStats struct {
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
TotalAccountCost *float64 `json:"total_account_cost,omitempty"`
|
||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
TotalAccountCost *float64 `json:"total_account_cost,omitempty"`
|
||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||
Endpoints []EndpointStat `json:"endpoints,omitempty"`
|
||||
UpstreamEndpoints []EndpointStat `json:"upstream_endpoints,omitempty"`
|
||||
EndpointPaths []EndpointStat `json:"endpoint_paths,omitempty"`
|
||||
}
|
||||
|
||||
// BatchUserUsageStats represents usage stats for a single user
|
||||
@@ -254,7 +293,9 @@ type AccountUsageSummary struct {
|
||||
|
||||
// AccountUsageStatsResponse represents the full usage statistics response for an account
|
||||
type AccountUsageStatsResponse struct {
|
||||
History []AccountUsageHistory `json:"history"`
|
||||
Summary AccountUsageSummary `json:"summary"`
|
||||
Models []ModelStat `json:"models"`
|
||||
History []AccountUsageHistory `json:"history"`
|
||||
Summary AccountUsageSummary `json:"summary"`
|
||||
Models []ModelStat `json:"models"`
|
||||
Endpoints []EndpointStat `json:"endpoints"`
|
||||
UpstreamEndpoints []EndpointStat `json:"upstream_endpoints"`
|
||||
}
|
||||
|
||||
@@ -397,9 +397,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
|
||||
}
|
||||
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
|
||||
r.syncSchedulerAccountSnapshot(ctx, account.ID)
|
||||
}
|
||||
// 普通账号编辑(如 model_mapping / credentials)也需要立即刷新单账号快照,
|
||||
// 否则网关在 outbox worker 延迟或异常时仍可能读到旧配置。
|
||||
r.syncSchedulerAccountSnapshot(ctx, account.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1727,8 +1727,96 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va
|
||||
// nowUTC is a SQL expression to generate a UTC RFC3339 timestamp string.
|
||||
const nowUTC = `to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"')`
|
||||
|
||||
// dailyExpiredExpr is a SQL expression that evaluates to TRUE when daily quota period has expired.
|
||||
// Supports both rolling (24h from start) and fixed (pre-computed reset_at) modes.
|
||||
const dailyExpiredExpr = `(
|
||||
CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed'
|
||||
THEN NOW() >= COALESCE((extra->>'quota_daily_reset_at')::timestamptz, '1970-01-01'::timestamptz)
|
||||
ELSE COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
END
|
||||
)`
|
||||
|
||||
// weeklyExpiredExpr is a SQL expression that evaluates to TRUE when weekly quota period has expired.
|
||||
const weeklyExpiredExpr = `(
|
||||
CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed'
|
||||
THEN NOW() >= COALESCE((extra->>'quota_weekly_reset_at')::timestamptz, '1970-01-01'::timestamptz)
|
||||
ELSE COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
END
|
||||
)`
|
||||
|
||||
// nextDailyResetAtExpr is a SQL expression to compute the next daily reset_at when a reset occurs.
|
||||
// For fixed mode: computes the next future reset time based on NOW(), timezone, and configured hour.
|
||||
// This correctly handles long-inactive accounts by jumping directly to the next valid reset point.
|
||||
const nextDailyResetAtExpr = `(
|
||||
CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed'
|
||||
THEN to_char((
|
||||
-- Compute today's reset point in the configured timezone, then pick next future one
|
||||
CASE WHEN NOW() >= (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
-- NOW() is at or past today's reset point → next reset is tomorrow
|
||||
THEN (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||
+ '1 day'::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
-- NOW() is before today's reset point → next reset is today
|
||||
ELSE (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
END
|
||||
) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
|
||||
ELSE NULL END
|
||||
)`
|
||||
|
||||
// nextWeeklyResetAtExpr is a SQL expression to compute the next weekly reset_at when a reset occurs.
|
||||
// For fixed mode: computes the next future reset time based on NOW(), timezone, configured day and hour.
|
||||
// This correctly handles long-inactive accounts by jumping directly to the next valid reset point.
|
||||
const nextWeeklyResetAtExpr = `(
|
||||
CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed'
|
||||
THEN to_char((
|
||||
-- Compute this week's reset point in the configured timezone
|
||||
-- Step 1: get today's date at reset hour in configured tz
|
||||
-- Step 2: compute days forward to target weekday
|
||||
-- Step 3: if same day but past reset hour, advance 7 days
|
||||
CASE
|
||||
WHEN (
|
||||
-- days_forward = (target_day - current_day + 7) % 7
|
||||
(COALESCE((extra->>'quota_weekly_reset_day')::int, 1)
|
||||
- EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int
|
||||
+ 7) % 7
|
||||
) = 0 AND NOW() >= (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
-- Same weekday and past reset hour → next week
|
||||
THEN (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||
+ '7 days'::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
ELSE (
|
||||
-- Advance to target weekday this week (or next if days_forward > 0)
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||
+ ((
|
||||
(COALESCE((extra->>'quota_weekly_reset_day')::int, 1)
|
||||
- EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int
|
||||
+ 7) % 7
|
||||
) || ' days')::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
END
|
||||
) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
|
||||
ELSE NULL END
|
||||
)`
|
||||
|
||||
// IncrementQuotaUsed 原子递增账号的配额用量(总/日/周三个维度)
|
||||
// 日/周额度在周期过期时自动重置为 0 再递增。
|
||||
// 支持滚动窗口(rolling)和固定时间(fixed)两种重置模式。
|
||||
func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
rows, err := r.sql.QueryContext(ctx,
|
||||
`UPDATE accounts SET extra = (
|
||||
@@ -1739,31 +1827,35 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|
||||
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_daily_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
CASE WHEN `+dailyExpiredExpr+`
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
|
||||
'quota_daily_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
CASE WHEN `+dailyExpiredExpr+`
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
|
||||
)
|
||||
-- 固定模式重置时更新下次重置时间
|
||||
|| CASE WHEN `+dailyExpiredExpr+` AND `+nextDailyResetAtExpr+` IS NOT NULL
|
||||
THEN jsonb_build_object('quota_daily_reset_at', `+nextDailyResetAtExpr+`)
|
||||
ELSE '{}'::jsonb END
|
||||
ELSE '{}'::jsonb END
|
||||
-- 周额度:仅在 quota_weekly_limit > 0 时处理
|
||||
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_weekly_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
CASE WHEN `+weeklyExpiredExpr+`
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
|
||||
'quota_weekly_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
CASE WHEN `+weeklyExpiredExpr+`
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
|
||||
)
|
||||
-- 固定模式重置时更新下次重置时间
|
||||
|| CASE WHEN `+weeklyExpiredExpr+` AND `+nextWeeklyResetAtExpr+` IS NOT NULL
|
||||
THEN jsonb_build_object('quota_weekly_reset_at', `+nextWeeklyResetAtExpr+`)
|
||||
ELSE '{}'::jsonb END
|
||||
ELSE '{}'::jsonb END
|
||||
), updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
@@ -1796,12 +1888,13 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|
||||
}
|
||||
|
||||
// ResetQuotaUsed 重置账号所有维度的配额用量为 0
|
||||
// 保留固定重置模式的配置字段(quota_daily_reset_mode 等),仅清零用量和窗口起始时间
|
||||
func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
_, err := r.sql.ExecContext(ctx,
|
||||
`UPDATE accounts SET extra = (
|
||||
COALESCE(extra, '{}'::jsonb)
|
||||
|| '{"quota_used": 0, "quota_daily_used": 0, "quota_weekly_used": 0}'::jsonb
|
||||
) - 'quota_daily_start' - 'quota_weekly_start', updated_at = NOW()
|
||||
) - 'quota_daily_start' - 'quota_weekly_start' - 'quota_daily_reset_at' - 'quota_weekly_reset_at', updated_at = NOW()
|
||||
WHERE id = $1 AND deleted_at IS NULL`,
|
||||
id)
|
||||
if err != nil {
|
||||
|
||||
@@ -142,6 +142,35 @@ func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnDisabled() {
|
||||
s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnCredentialsChange() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "sync-credentials-update",
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.1",
|
||||
},
|
||||
},
|
||||
})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
account.Credentials = map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.2",
|
||||
},
|
||||
}
|
||||
err := s.repo.Update(s.ctx, account)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
mapping, ok := cacheRecorder.setAccounts[0].Credentials["model_mapping"].(map[string]any)
|
||||
s.Require().True(ok)
|
||||
s.Require().Equal("gpt-5.2", mapping["gpt-5"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestDelete() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
|
||||
|
||||
|
||||
98
backend/internal/repository/backup_pg_dumper.go
Normal file
98
backend/internal/repository/backup_pg_dumper.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// PgDumper implements service.DBDumper using pg_dump/psql
|
||||
type PgDumper struct {
|
||||
cfg *config.DatabaseConfig
|
||||
}
|
||||
|
||||
// NewPgDumper creates a new PgDumper
|
||||
func NewPgDumper(cfg *config.Config) service.DBDumper {
|
||||
return &PgDumper{cfg: &cfg.Database}
|
||||
}
|
||||
|
||||
// Dump executes pg_dump and returns a streaming reader of the output
|
||||
func (d *PgDumper) Dump(ctx context.Context) (io.ReadCloser, error) {
|
||||
args := []string{
|
||||
"-h", d.cfg.Host,
|
||||
"-p", fmt.Sprintf("%d", d.cfg.Port),
|
||||
"-U", d.cfg.User,
|
||||
"-d", d.cfg.DBName,
|
||||
"--no-owner",
|
||||
"--no-acl",
|
||||
"--clean",
|
||||
"--if-exists",
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "pg_dump", args...)
|
||||
if d.cfg.Password != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password)
|
||||
}
|
||||
if d.cfg.SSLMode != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode)
|
||||
}
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("start pg_dump: %w", err)
|
||||
}
|
||||
|
||||
// 返回一个 ReadCloser:读 stdout,关闭时等待进程退出
|
||||
return &cmdReadCloser{ReadCloser: stdout, cmd: cmd}, nil
|
||||
}
|
||||
|
||||
// Restore executes psql to restore from a streaming reader
|
||||
func (d *PgDumper) Restore(ctx context.Context, data io.Reader) error {
|
||||
args := []string{
|
||||
"-h", d.cfg.Host,
|
||||
"-p", fmt.Sprintf("%d", d.cfg.Port),
|
||||
"-U", d.cfg.User,
|
||||
"-d", d.cfg.DBName,
|
||||
"--single-transaction",
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "psql", args...)
|
||||
if d.cfg.Password != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password)
|
||||
}
|
||||
if d.cfg.SSLMode != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode)
|
||||
}
|
||||
|
||||
cmd.Stdin = data
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%v: %s", err, string(output))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cmdReadCloser wraps a command stdout pipe and waits for the process on Close
|
||||
type cmdReadCloser struct {
|
||||
io.ReadCloser
|
||||
cmd *exec.Cmd
|
||||
}
|
||||
|
||||
func (c *cmdReadCloser) Close() error {
|
||||
// Close the pipe first
|
||||
_ = c.ReadCloser.Close()
|
||||
// Wait for the process to exit
|
||||
if err := c.cmd.Wait(); err != nil {
|
||||
return fmt.Errorf("pg_dump exited with error: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
117
backend/internal/repository/backup_s3_store.go
Normal file
117
backend/internal/repository/backup_s3_store.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
|
||||
awsconfig "github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// S3BackupStore implements service.BackupObjectStore using AWS S3 compatible storage
|
||||
type S3BackupStore struct {
|
||||
client *s3.Client
|
||||
bucket string
|
||||
}
|
||||
|
||||
// NewS3BackupStoreFactory returns a BackupObjectStoreFactory that creates S3-backed stores
|
||||
func NewS3BackupStoreFactory() service.BackupObjectStoreFactory {
|
||||
return func(ctx context.Context, cfg *service.BackupS3Config) (service.BackupObjectStore, error) {
|
||||
region := cfg.Region
|
||||
if region == "" {
|
||||
region = "auto" // Cloudflare R2 默认 region
|
||||
}
|
||||
|
||||
awsCfg, err := awsconfig.LoadDefaultConfig(ctx,
|
||||
awsconfig.WithRegion(region),
|
||||
awsconfig.WithCredentialsProvider(
|
||||
credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load aws config: %w", err)
|
||||
}
|
||||
|
||||
client := s3.NewFromConfig(awsCfg, func(o *s3.Options) {
|
||||
if cfg.Endpoint != "" {
|
||||
o.BaseEndpoint = &cfg.Endpoint
|
||||
}
|
||||
if cfg.ForcePathStyle {
|
||||
o.UsePathStyle = true
|
||||
}
|
||||
o.APIOptions = append(o.APIOptions, v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware)
|
||||
o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
|
||||
})
|
||||
|
||||
return &S3BackupStore{client: client, bucket: cfg.Bucket}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) Upload(ctx context.Context, key string, body io.Reader, contentType string) (int64, error) {
|
||||
// 读取全部内容以获取大小(S3 PutObject 需要知道内容长度)
|
||||
// 注意:阿里云 OSS 不兼容 s3manager 分片上传的签名方式,因此使用 PutObject
|
||||
data, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("read body: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.client.PutObject(ctx, &s3.PutObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
Body: bytes.NewReader(data),
|
||||
ContentType: &contentType,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("S3 PutObject: %w", err)
|
||||
}
|
||||
return int64(len(data)), nil
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) Download(ctx context.Context, key string) (io.ReadCloser, error) {
|
||||
result, err := s.client.GetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("S3 GetObject: %w", err)
|
||||
}
|
||||
return result.Body, nil
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) Delete(ctx context.Context, key string) error {
|
||||
_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) PresignURL(ctx context.Context, key string, expiry time.Duration) (string, error) {
|
||||
presignClient := s3.NewPresignClient(s.client)
|
||||
result, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
}, s3.WithPresignExpires(expiry))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("presign url: %w", err)
|
||||
}
|
||||
return result.URL, nil
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) HeadBucket(ctx context.Context) error {
|
||||
_, err := s.client.HeadBucket(ctx, &s3.HeadBucketInput{
|
||||
Bucket: &s.bucket,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("S3 HeadBucket failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -20,6 +20,11 @@ const (
|
||||
billingCacheTTL = 5 * time.Minute
|
||||
billingCacheJitter = 30 * time.Second
|
||||
rateLimitCacheTTL = 7 * 24 * time.Hour // 7 days matches the longest window
|
||||
|
||||
// Rate limit window durations — must match service.RateLimitWindow* constants.
|
||||
rateLimitWindow5h = 5 * time.Hour
|
||||
rateLimitWindow1d = 24 * time.Hour
|
||||
rateLimitWindow7d = 7 * 24 * time.Hour
|
||||
)
|
||||
|
||||
// jitteredTTL 返回带随机抖动的 TTL,防止缓存雪崩
|
||||
@@ -90,17 +95,40 @@ var (
|
||||
return 1
|
||||
`)
|
||||
|
||||
// updateRateLimitUsageScript atomically increments all three rate limit usage counters.
|
||||
// Returns 0 if the key doesn't exist (cache miss), 1 on success.
|
||||
// updateRateLimitUsageScript atomically increments all three rate limit usage counters
|
||||
// with window expiration checking. If a window has expired, its usage is reset to cost
|
||||
// (instead of accumulated) and the window timestamp is updated, matching the DB-side
|
||||
// IncrementRateLimitUsage semantics.
|
||||
//
|
||||
// ARGV: [1]=cost, [2]=ttl_seconds, [3]=now_unix, [4]=window_5h_seconds, [5]=window_1d_seconds, [6]=window_7d_seconds
|
||||
updateRateLimitUsageScript = redis.NewScript(`
|
||||
local exists = redis.call('EXISTS', KEYS[1])
|
||||
if exists == 0 then
|
||||
return 0
|
||||
end
|
||||
local cost = tonumber(ARGV[1])
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_5h', cost)
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_1d', cost)
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_7d', cost)
|
||||
local now = tonumber(ARGV[3])
|
||||
local win5h = tonumber(ARGV[4])
|
||||
local win1d = tonumber(ARGV[5])
|
||||
local win7d = tonumber(ARGV[6])
|
||||
|
||||
-- Helper: check if window is expired and update usage + window accordingly
|
||||
-- Returns nothing, modifies the hash in-place.
|
||||
local function update_window(usage_field, window_field, window_duration)
|
||||
local w = tonumber(redis.call('HGET', KEYS[1], window_field) or 0)
|
||||
if w == 0 or (now - w) >= window_duration then
|
||||
-- Window expired or never started: reset usage to cost, start new window
|
||||
redis.call('HSET', KEYS[1], usage_field, tostring(cost))
|
||||
redis.call('HSET', KEYS[1], window_field, tostring(now))
|
||||
else
|
||||
-- Window still valid: accumulate
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], usage_field, cost)
|
||||
end
|
||||
end
|
||||
|
||||
update_window('usage_5h', 'window_5h', win5h)
|
||||
update_window('usage_1d', 'window_1d', win1d)
|
||||
update_window('usage_7d', 'window_7d', win7d)
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
`)
|
||||
@@ -280,7 +308,15 @@ func (c *billingCache) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data
|
||||
|
||||
func (c *billingCache) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
|
||||
key := billingRateLimitKey(keyID)
|
||||
_, err := updateRateLimitUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(rateLimitCacheTTL.Seconds())).Result()
|
||||
now := time.Now().Unix()
|
||||
_, err := updateRateLimitUsageScript.Run(ctx, c.rdb, []string{key},
|
||||
cost,
|
||||
int(rateLimitCacheTTL.Seconds()),
|
||||
now,
|
||||
int(rateLimitWindow5h.Seconds()),
|
||||
int(rateLimitWindow1d.Seconds()),
|
||||
int(rateLimitWindow7d.Seconds()),
|
||||
).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: update rate limit usage cache failed for api key %d: %v", keyID, err)
|
||||
return err
|
||||
|
||||
@@ -88,8 +88,9 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
count, _ := r.GetAccountCount(ctx, out.ID)
|
||||
out.AccountCount = count
|
||||
total, active, _ := r.GetAccountCount(ctx, out.ID)
|
||||
out.AccountCount = total
|
||||
out.ActiveAccountCount = active
|
||||
return out, nil
|
||||
}
|
||||
|
||||
@@ -256,7 +257,10 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
|
||||
counts, err := r.loadAccountCounts(ctx, groupIDs)
|
||||
if err == nil {
|
||||
for i := range outGroups {
|
||||
outGroups[i].AccountCount = counts[outGroups[i].ID]
|
||||
c := counts[outGroups[i].ID]
|
||||
outGroups[i].AccountCount = c.Total
|
||||
outGroups[i].ActiveAccountCount = c.Active
|
||||
outGroups[i].RateLimitedAccountCount = c.RateLimited
|
||||
}
|
||||
}
|
||||
|
||||
@@ -283,7 +287,10 @@ func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, erro
|
||||
counts, err := r.loadAccountCounts(ctx, groupIDs)
|
||||
if err == nil {
|
||||
for i := range outGroups {
|
||||
outGroups[i].AccountCount = counts[outGroups[i].ID]
|
||||
c := counts[outGroups[i].ID]
|
||||
outGroups[i].AccountCount = c.Total
|
||||
outGroups[i].ActiveAccountCount = c.Active
|
||||
outGroups[i].RateLimitedAccountCount = c.RateLimited
|
||||
}
|
||||
}
|
||||
|
||||
@@ -310,7 +317,10 @@ func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform str
|
||||
counts, err := r.loadAccountCounts(ctx, groupIDs)
|
||||
if err == nil {
|
||||
for i := range outGroups {
|
||||
outGroups[i].AccountCount = counts[outGroups[i].ID]
|
||||
c := counts[outGroups[i].ID]
|
||||
outGroups[i].AccountCount = c.Total
|
||||
outGroups[i].ActiveAccountCount = c.Active
|
||||
outGroups[i].RateLimitedAccountCount = c.RateLimited
|
||||
}
|
||||
}
|
||||
|
||||
@@ -369,12 +379,20 @@ func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||
var count int64
|
||||
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (total int64, active int64, err error) {
|
||||
var rateLimited int64
|
||||
err = scanSingleRow(ctx, r.sql,
|
||||
`SELECT COUNT(*),
|
||||
COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true),
|
||||
COUNT(*) FILTER (WHERE a.status = 'active' AND (
|
||||
a.rate_limit_reset_at > NOW() OR
|
||||
a.overload_until > NOW() OR
|
||||
a.temp_unschedulable_until > NOW()
|
||||
))
|
||||
FROM account_groups ag JOIN accounts a ON a.id = ag.account_id
|
||||
WHERE ag.group_id = $1`,
|
||||
[]any{groupID}, &total, &active, &rateLimited)
|
||||
return
|
||||
}
|
||||
|
||||
func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
@@ -500,15 +518,32 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
|
||||
return affectedUserIDs, nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]int64, err error) {
|
||||
counts = make(map[int64]int64, len(groupIDs))
|
||||
type groupAccountCounts struct {
|
||||
Total int64
|
||||
Active int64
|
||||
RateLimited int64
|
||||
}
|
||||
|
||||
func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]groupAccountCounts, err error) {
|
||||
counts = make(map[int64]groupAccountCounts, len(groupIDs))
|
||||
if len(groupIDs) == 0 {
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
rows, err := r.sql.QueryContext(
|
||||
ctx,
|
||||
"SELECT group_id, COUNT(*) FROM account_groups WHERE group_id = ANY($1) GROUP BY group_id",
|
||||
`SELECT ag.group_id,
|
||||
COUNT(*) AS total,
|
||||
COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true) AS active,
|
||||
COUNT(*) FILTER (WHERE a.status = 'active' AND (
|
||||
a.rate_limit_reset_at > NOW() OR
|
||||
a.overload_until > NOW() OR
|
||||
a.temp_unschedulable_until > NOW()
|
||||
)) AS rate_limited
|
||||
FROM account_groups ag
|
||||
JOIN accounts a ON a.id = ag.account_id
|
||||
WHERE ag.group_id = ANY($1)
|
||||
GROUP BY ag.group_id`,
|
||||
pq.Array(groupIDs),
|
||||
)
|
||||
if err != nil {
|
||||
@@ -523,11 +558,11 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6
|
||||
|
||||
for rows.Next() {
|
||||
var groupID int64
|
||||
var count int64
|
||||
if err = rows.Scan(&groupID, &count); err != nil {
|
||||
var c groupAccountCounts
|
||||
if err = rows.Scan(&groupID, &c.Total, &c.Active, &c.RateLimited); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts[groupID] = count
|
||||
counts[groupID] = c
|
||||
}
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -603,7 +603,7 @@ func (s *GroupRepoSuite) TestGetAccountCount() {
|
||||
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a2, group.ID, 2)
|
||||
s.Require().NoError(err)
|
||||
|
||||
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
|
||||
count, _, err := s.repo.GetAccountCount(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "GetAccountCount")
|
||||
s.Require().Equal(int64(2), count)
|
||||
}
|
||||
@@ -619,7 +619,7 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, group))
|
||||
|
||||
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
|
||||
count, _, err := s.repo.GetAccountCount(s.ctx, group.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(count)
|
||||
}
|
||||
@@ -651,7 +651,7 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
|
||||
s.Require().NoError(err, "DeleteAccountGroupsByGroupID")
|
||||
s.Require().Equal(int64(1), affected, "expected 1 affected row")
|
||||
|
||||
count, err := s.repo.GetAccountCount(s.ctx, g.ID)
|
||||
count, _, err := s.repo.GetAccountCount(s.ctx, g.ID)
|
||||
s.Require().NoError(err, "GetAccountCount")
|
||||
s.Require().Equal(int64(0), count, "expected 0 account groups")
|
||||
}
|
||||
@@ -692,7 +692,7 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(int64(3), affected)
|
||||
|
||||
count, _ := s.repo.GetAccountCount(s.ctx, g.ID)
|
||||
count, _, _ := s.repo.GetAccountCount(s.ctx, g.ID)
|
||||
s.Require().Zero(count)
|
||||
}
|
||||
|
||||
|
||||
@@ -132,7 +132,7 @@ func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, t
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.AccountQuotaCost > 0 && strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) {
|
||||
if cmd.AccountQuotaCost > 0 && (strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) || strings.EqualFold(cmd.AccountType, service.AccountTypeBedrock)) {
|
||||
if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ import (
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
)
|
||||
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, cache_ttl_overridden, created_at"
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at"
|
||||
|
||||
var usageLogInsertArgTypes = [...]string{
|
||||
"bigint",
|
||||
@@ -65,6 +65,8 @@ var usageLogInsertArgTypes = [...]string{
|
||||
"text",
|
||||
"text",
|
||||
"text",
|
||||
"text",
|
||||
"text",
|
||||
"boolean",
|
||||
"timestamptz",
|
||||
}
|
||||
@@ -304,6 +306,8 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
) VALUES (
|
||||
@@ -312,7 +316,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
|
||||
$8, $9, $10, $11,
|
||||
$12, $13,
|
||||
$14, $15, $16, $17, $18, $19,
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
@@ -732,11 +736,13 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
) AS (VALUES `)
|
||||
|
||||
args := make([]any, 0, len(keys)*37)
|
||||
args := make([]any, 0, len(keys)*38)
|
||||
argPos := 1
|
||||
for idx, key := range keys {
|
||||
if idx > 0 {
|
||||
@@ -799,6 +805,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
)
|
||||
@@ -837,6 +845,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
FROM input
|
||||
@@ -915,11 +925,13 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
) AS (VALUES `)
|
||||
|
||||
args := make([]any, 0, len(preparedList)*36)
|
||||
args := make([]any, 0, len(preparedList)*38)
|
||||
argPos := 1
|
||||
for idx, prepared := range preparedList {
|
||||
if idx > 0 {
|
||||
@@ -979,6 +991,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
)
|
||||
@@ -1017,6 +1031,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
FROM input
|
||||
@@ -1063,6 +1079,8 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
) VALUES (
|
||||
@@ -1071,7 +1089,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
|
||||
$8, $9, $10, $11,
|
||||
$12, $13,
|
||||
$14, $15, $16, $17, $18, $19,
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
`, prepared.args...)
|
||||
@@ -1101,6 +1119,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
||||
mediaType := nullString(log.MediaType)
|
||||
serviceTier := nullString(log.ServiceTier)
|
||||
reasoningEffort := nullString(log.ReasoningEffort)
|
||||
inboundEndpoint := nullString(log.InboundEndpoint)
|
||||
upstreamEndpoint := nullString(log.UpstreamEndpoint)
|
||||
|
||||
var requestIDArg any
|
||||
if requestID != "" {
|
||||
@@ -1147,6 +1167,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
||||
mediaType,
|
||||
serviceTier,
|
||||
reasoningEffort,
|
||||
inboundEndpoint,
|
||||
upstreamEndpoint,
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
},
|
||||
@@ -2139,7 +2161,9 @@ func (r *usageLogRepository) GetUserSpendingRanking(ctx context.Context, startTi
|
||||
actual_cost,
|
||||
requests,
|
||||
tokens,
|
||||
COALESCE(SUM(actual_cost) OVER (), 0) as total_actual_cost
|
||||
COALESCE(SUM(actual_cost) OVER (), 0) as total_actual_cost,
|
||||
COALESCE(SUM(requests) OVER (), 0) as total_requests,
|
||||
COALESCE(SUM(tokens) OVER (), 0) as total_tokens
|
||||
FROM user_spend
|
||||
ORDER BY actual_cost DESC, tokens DESC, user_id ASC
|
||||
LIMIT $3
|
||||
@@ -2150,7 +2174,9 @@ func (r *usageLogRepository) GetUserSpendingRanking(ctx context.Context, startTi
|
||||
actual_cost,
|
||||
requests,
|
||||
tokens,
|
||||
total_actual_cost
|
||||
total_actual_cost,
|
||||
total_requests,
|
||||
total_tokens
|
||||
FROM ranked
|
||||
ORDER BY actual_cost DESC, tokens DESC, user_id ASC
|
||||
`
|
||||
@@ -2168,9 +2194,11 @@ func (r *usageLogRepository) GetUserSpendingRanking(ctx context.Context, startTi
|
||||
|
||||
ranking := make([]UserSpendingRankingItem, 0)
|
||||
totalActualCost := 0.0
|
||||
totalRequests := int64(0)
|
||||
totalTokens := int64(0)
|
||||
for rows.Next() {
|
||||
var row UserSpendingRankingItem
|
||||
if err = rows.Scan(&row.UserID, &row.Email, &row.ActualCost, &row.Requests, &row.Tokens, &totalActualCost); err != nil {
|
||||
if err = rows.Scan(&row.UserID, &row.Email, &row.ActualCost, &row.Requests, &row.Tokens, &totalActualCost, &totalRequests, &totalTokens); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ranking = append(ranking, row)
|
||||
@@ -2182,6 +2210,8 @@ func (r *usageLogRepository) GetUserSpendingRanking(ctx context.Context, startTi
|
||||
return &UserSpendingRankingResponse{
|
||||
Ranking: ranking,
|
||||
TotalActualCost: totalActualCost,
|
||||
TotalRequests: totalRequests,
|
||||
TotalTokens: totalTokens,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -2505,7 +2535,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
|
||||
args = append(args, *filters.StartTime)
|
||||
}
|
||||
if filters.EndTime != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", len(args)+1))
|
||||
conditions = append(conditions, fmt.Sprintf("created_at < $%d", len(args)+1))
|
||||
args = append(args, *filters.EndTime)
|
||||
}
|
||||
|
||||
@@ -2970,6 +3000,120 @@ func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, start
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetUserBreakdownStats returns per-user usage breakdown within a specific dimension.
|
||||
func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) (results []usagestats.UserBreakdownItem, err error) {
|
||||
query := `
|
||||
SELECT
|
||||
COALESCE(ul.user_id, 0) as user_id,
|
||||
COALESCE(u.email, '') as email,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(ul.total_cost), 0) as cost,
|
||||
COALESCE(SUM(ul.actual_cost), 0) as actual_cost
|
||||
FROM usage_logs ul
|
||||
LEFT JOIN users u ON u.id = ul.user_id
|
||||
WHERE ul.created_at >= $1 AND ul.created_at < $2
|
||||
`
|
||||
args := []any{startTime, endTime}
|
||||
|
||||
if dim.GroupID > 0 {
|
||||
query += fmt.Sprintf(" AND ul.group_id = $%d", len(args)+1)
|
||||
args = append(args, dim.GroupID)
|
||||
}
|
||||
if dim.Model != "" {
|
||||
query += fmt.Sprintf(" AND ul.model = $%d", len(args)+1)
|
||||
args = append(args, dim.Model)
|
||||
}
|
||||
if dim.Endpoint != "" {
|
||||
col := resolveEndpointColumn(dim.EndpointType)
|
||||
query += fmt.Sprintf(" AND %s = $%d", col, len(args)+1)
|
||||
args = append(args, dim.Endpoint)
|
||||
}
|
||||
|
||||
query += " GROUP BY ul.user_id, u.email ORDER BY actual_cost DESC"
|
||||
if limit > 0 {
|
||||
query += fmt.Sprintf(" LIMIT %d", limit)
|
||||
}
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
results = nil
|
||||
}
|
||||
}()
|
||||
|
||||
results = make([]usagestats.UserBreakdownItem, 0)
|
||||
for rows.Next() {
|
||||
var row usagestats.UserBreakdownItem
|
||||
if err := rows.Scan(
|
||||
&row.UserID,
|
||||
&row.Email,
|
||||
&row.Requests,
|
||||
&row.TotalTokens,
|
||||
&row.Cost,
|
||||
&row.ActualCost,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, row)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetAllGroupUsageSummary returns today's and cumulative actual_cost for every group.
|
||||
// todayStart is the start-of-day in the caller's timezone (UTC-based).
|
||||
// TODO(perf): This query scans ALL usage_logs rows for total_cost aggregation.
|
||||
// When usage_logs exceeds ~1M rows, consider adding a short-lived cache (30s)
|
||||
// or a materialized view / pre-aggregation table for cumulative costs.
|
||||
func (r *usageLogRepository) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
|
||||
query := `
|
||||
SELECT
|
||||
g.id AS group_id,
|
||||
COALESCE(SUM(ul.actual_cost), 0) AS total_cost,
|
||||
COALESCE(SUM(CASE WHEN ul.created_at >= $1 THEN ul.actual_cost ELSE 0 END), 0) AS today_cost
|
||||
FROM groups g
|
||||
LEFT JOIN usage_logs ul ON ul.group_id = g.id
|
||||
GROUP BY g.id
|
||||
`
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, query, todayStart)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
var results []usagestats.GroupUsageSummary
|
||||
for rows.Next() {
|
||||
var row usagestats.GroupUsageSummary
|
||||
if err := rows.Scan(&row.GroupID, &row.TotalCost, &row.TodayCost); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, row)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// resolveEndpointColumn maps endpoint type to the corresponding DB column name.
|
||||
func resolveEndpointColumn(endpointType string) string {
|
||||
switch endpointType {
|
||||
case "upstream":
|
||||
return "ul.upstream_endpoint"
|
||||
case "path":
|
||||
return "ul.inbound_endpoint || ' -> ' || ul.upstream_endpoint"
|
||||
default:
|
||||
return "ul.inbound_endpoint"
|
||||
}
|
||||
}
|
||||
|
||||
// GetGlobalStats gets usage statistics for all users within a time range
|
||||
func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) {
|
||||
query := `
|
||||
@@ -2982,7 +3126,7 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
|
||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at <= $2
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
`
|
||||
|
||||
stats := &UsageStats{}
|
||||
@@ -3040,7 +3184,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
|
||||
args = append(args, *filters.StartTime)
|
||||
}
|
||||
if filters.EndTime != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", len(args)+1))
|
||||
conditions = append(conditions, fmt.Sprintf("created_at < $%d", len(args)+1))
|
||||
args = append(args, *filters.EndTime)
|
||||
}
|
||||
|
||||
@@ -3080,6 +3224,35 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
|
||||
stats.TotalAccountCost = &totalAccountCost
|
||||
}
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||||
|
||||
start := time.Unix(0, 0).UTC()
|
||||
if filters.StartTime != nil {
|
||||
start = *filters.StartTime
|
||||
}
|
||||
end := time.Now().UTC()
|
||||
if filters.EndTime != nil {
|
||||
end = *filters.EndTime
|
||||
}
|
||||
|
||||
endpoints, endpointErr := r.GetEndpointStatsWithFilters(ctx, start, end, filters.UserID, filters.APIKeyID, filters.AccountID, filters.GroupID, filters.Model, filters.RequestType, filters.Stream, filters.BillingType)
|
||||
if endpointErr != nil {
|
||||
logger.LegacyPrintf("repository.usage_log", "GetEndpointStatsWithFilters failed in GetStatsWithFilters: %v", endpointErr)
|
||||
endpoints = []EndpointStat{}
|
||||
}
|
||||
upstreamEndpoints, upstreamEndpointErr := r.GetUpstreamEndpointStatsWithFilters(ctx, start, end, filters.UserID, filters.APIKeyID, filters.AccountID, filters.GroupID, filters.Model, filters.RequestType, filters.Stream, filters.BillingType)
|
||||
if upstreamEndpointErr != nil {
|
||||
logger.LegacyPrintf("repository.usage_log", "GetUpstreamEndpointStatsWithFilters failed in GetStatsWithFilters: %v", upstreamEndpointErr)
|
||||
upstreamEndpoints = []EndpointStat{}
|
||||
}
|
||||
endpointPaths, endpointPathErr := r.getEndpointPathStatsWithFilters(ctx, start, end, filters.UserID, filters.APIKeyID, filters.AccountID, filters.GroupID, filters.Model, filters.RequestType, filters.Stream, filters.BillingType)
|
||||
if endpointPathErr != nil {
|
||||
logger.LegacyPrintf("repository.usage_log", "getEndpointPathStatsWithFilters failed in GetStatsWithFilters: %v", endpointPathErr)
|
||||
endpointPaths = []EndpointStat{}
|
||||
}
|
||||
stats.Endpoints = endpoints
|
||||
stats.UpstreamEndpoints = upstreamEndpoints
|
||||
stats.EndpointPaths = endpointPaths
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
@@ -3092,6 +3265,163 @@ type AccountUsageSummary = usagestats.AccountUsageSummary
|
||||
// AccountUsageStatsResponse represents the full usage statistics response for an account
|
||||
type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse
|
||||
|
||||
// EndpointStat represents endpoint usage statistics row.
|
||||
type EndpointStat = usagestats.EndpointStat
|
||||
|
||||
func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Context, endpointColumn string, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
|
||||
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
||||
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
||||
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
COALESCE(NULLIF(TRIM(%s), ''), 'unknown') AS endpoint,
|
||||
COUNT(*) AS requests,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
%s
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
`, endpointColumn, actualCostExpr)
|
||||
|
||||
args := []any{startTime, endTime}
|
||||
if userID > 0 {
|
||||
query += fmt.Sprintf(" AND user_id = $%d", len(args)+1)
|
||||
args = append(args, userID)
|
||||
}
|
||||
if apiKeyID > 0 {
|
||||
query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1)
|
||||
args = append(args, apiKeyID)
|
||||
}
|
||||
if accountID > 0 {
|
||||
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
|
||||
args = append(args, accountID)
|
||||
}
|
||||
if groupID > 0 {
|
||||
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
|
||||
args = append(args, groupID)
|
||||
}
|
||||
if model != "" {
|
||||
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
|
||||
args = append(args, model)
|
||||
}
|
||||
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
|
||||
if billingType != nil {
|
||||
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
|
||||
args = append(args, int16(*billingType))
|
||||
}
|
||||
query += " GROUP BY endpoint ORDER BY requests DESC"
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
results = nil
|
||||
}
|
||||
}()
|
||||
|
||||
results = make([]EndpointStat, 0)
|
||||
for rows.Next() {
|
||||
var row EndpointStat
|
||||
if err := rows.Scan(&row.Endpoint, &row.Requests, &row.TotalTokens, &row.Cost, &row.ActualCost); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, row)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
|
||||
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
||||
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
||||
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
CONCAT(
|
||||
COALESCE(NULLIF(TRIM(inbound_endpoint), ''), 'unknown'),
|
||||
' -> ',
|
||||
COALESCE(NULLIF(TRIM(upstream_endpoint), ''), 'unknown')
|
||||
) AS endpoint,
|
||||
COUNT(*) AS requests,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
%s
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
`, actualCostExpr)
|
||||
|
||||
args := []any{startTime, endTime}
|
||||
if userID > 0 {
|
||||
query += fmt.Sprintf(" AND user_id = $%d", len(args)+1)
|
||||
args = append(args, userID)
|
||||
}
|
||||
if apiKeyID > 0 {
|
||||
query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1)
|
||||
args = append(args, apiKeyID)
|
||||
}
|
||||
if accountID > 0 {
|
||||
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
|
||||
args = append(args, accountID)
|
||||
}
|
||||
if groupID > 0 {
|
||||
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
|
||||
args = append(args, groupID)
|
||||
}
|
||||
if model != "" {
|
||||
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
|
||||
args = append(args, model)
|
||||
}
|
||||
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
|
||||
if billingType != nil {
|
||||
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
|
||||
args = append(args, int16(*billingType))
|
||||
}
|
||||
query += " GROUP BY endpoint ORDER BY requests DESC"
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
results = nil
|
||||
}
|
||||
}()
|
||||
|
||||
results = make([]EndpointStat, 0)
|
||||
for rows.Next() {
|
||||
var row EndpointStat
|
||||
if err := rows.Scan(&row.Endpoint, &row.Requests, &row.TotalTokens, &row.Cost, &row.ActualCost); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, row)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetEndpointStatsWithFilters returns inbound endpoint statistics with optional filters.
|
||||
func (r *usageLogRepository) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]EndpointStat, error) {
|
||||
return r.getEndpointStatsByColumnWithFilters(ctx, "inbound_endpoint", startTime, endTime, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||
}
|
||||
|
||||
// GetUpstreamEndpointStatsWithFilters returns upstream endpoint statistics with optional filters.
|
||||
func (r *usageLogRepository) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]EndpointStat, error) {
|
||||
return r.getEndpointStatsByColumnWithFilters(ctx, "upstream_endpoint", startTime, endTime, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||
}
|
||||
|
||||
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
|
||||
func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (resp *AccountUsageStatsResponse, err error) {
|
||||
daysCount := int(endTime.Sub(startTime).Hours()/24) + 1
|
||||
@@ -3254,11 +3584,23 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
||||
if err != nil {
|
||||
models = []ModelStat{}
|
||||
}
|
||||
endpoints, endpointErr := r.GetEndpointStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, "", nil, nil, nil)
|
||||
if endpointErr != nil {
|
||||
logger.LegacyPrintf("repository.usage_log", "GetEndpointStatsWithFilters failed in GetAccountUsageStats: %v", endpointErr)
|
||||
endpoints = []EndpointStat{}
|
||||
}
|
||||
upstreamEndpoints, upstreamEndpointErr := r.GetUpstreamEndpointStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, "", nil, nil, nil)
|
||||
if upstreamEndpointErr != nil {
|
||||
logger.LegacyPrintf("repository.usage_log", "GetUpstreamEndpointStatsWithFilters failed in GetAccountUsageStats: %v", upstreamEndpointErr)
|
||||
upstreamEndpoints = []EndpointStat{}
|
||||
}
|
||||
|
||||
resp = &AccountUsageStatsResponse{
|
||||
History: history,
|
||||
Summary: summary,
|
||||
Models: models,
|
||||
History: history,
|
||||
Summary: summary,
|
||||
Models: models,
|
||||
Endpoints: endpoints,
|
||||
UpstreamEndpoints: upstreamEndpoints,
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -3541,6 +3883,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
mediaType sql.NullString
|
||||
serviceTier sql.NullString
|
||||
reasoningEffort sql.NullString
|
||||
inboundEndpoint sql.NullString
|
||||
upstreamEndpoint sql.NullString
|
||||
cacheTTLOverridden bool
|
||||
createdAt time.Time
|
||||
)
|
||||
@@ -3581,6 +3925,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
&mediaType,
|
||||
&serviceTier,
|
||||
&reasoningEffort,
|
||||
&inboundEndpoint,
|
||||
&upstreamEndpoint,
|
||||
&cacheTTLOverridden,
|
||||
&createdAt,
|
||||
); err != nil {
|
||||
@@ -3656,6 +4002,12 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
if reasoningEffort.Valid {
|
||||
log.ReasoningEffort = &reasoningEffort.String
|
||||
}
|
||||
if inboundEndpoint.Valid {
|
||||
log.InboundEndpoint = &inboundEndpoint.String
|
||||
}
|
||||
if upstreamEndpoint.Valid {
|
||||
log.UpstreamEndpoint = &upstreamEndpoint.String
|
||||
}
|
||||
|
||||
return log, nil
|
||||
}
|
||||
|
||||
29
backend/internal/repository/usage_log_repo_breakdown_test.go
Normal file
29
backend/internal/repository/usage_log_repo_breakdown_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestResolveEndpointColumn(t *testing.T) {
|
||||
tests := []struct {
|
||||
endpointType string
|
||||
want string
|
||||
}{
|
||||
{"inbound", "ul.inbound_endpoint"},
|
||||
{"upstream", "ul.upstream_endpoint"},
|
||||
{"path", "ul.inbound_endpoint || ' -> ' || ul.upstream_endpoint"},
|
||||
{"", "ul.inbound_endpoint"}, // default
|
||||
{"unknown", "ul.inbound_endpoint"}, // fallback
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.endpointType, func(t *testing.T) {
|
||||
got := resolveEndpointColumn(tc.endpointType)
|
||||
require.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -73,6 +73,8 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
||||
sqlmock.AnyArg(), // media_type
|
||||
sqlmock.AnyArg(), // service_tier
|
||||
sqlmock.AnyArg(), // reasoning_effort
|
||||
sqlmock.AnyArg(), // inbound_endpoint
|
||||
sqlmock.AnyArg(), // upstream_endpoint
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
).
|
||||
@@ -141,6 +143,8 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
|
||||
sqlmock.AnyArg(),
|
||||
serviceTier,
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
).
|
||||
@@ -255,10 +259,10 @@ func TestUsageLogRepositoryGetUserSpendingRanking(t *testing.T) {
|
||||
start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(24 * time.Hour)
|
||||
|
||||
rows := sqlmock.NewRows([]string{"user_id", "email", "actual_cost", "requests", "tokens", "total_actual_cost"}).
|
||||
AddRow(int64(2), "beta@example.com", 12.5, int64(9), int64(900), 40.0).
|
||||
AddRow(int64(1), "alpha@example.com", 12.5, int64(8), int64(800), 40.0).
|
||||
AddRow(int64(3), "gamma@example.com", 4.25, int64(5), int64(300), 40.0)
|
||||
rows := sqlmock.NewRows([]string{"user_id", "email", "actual_cost", "requests", "tokens", "total_actual_cost", "total_requests", "total_tokens"}).
|
||||
AddRow(int64(2), "beta@example.com", 12.5, int64(9), int64(900), 40.0, int64(30), int64(2600)).
|
||||
AddRow(int64(1), "alpha@example.com", 12.5, int64(8), int64(800), 40.0, int64(30), int64(2600)).
|
||||
AddRow(int64(3), "gamma@example.com", 4.25, int64(5), int64(300), 40.0, int64(30), int64(2600))
|
||||
|
||||
mock.ExpectQuery("WITH user_spend AS \\(").
|
||||
WithArgs(start, end, 12).
|
||||
@@ -273,6 +277,8 @@ func TestUsageLogRepositoryGetUserSpendingRanking(t *testing.T) {
|
||||
{UserID: 3, Email: "gamma@example.com", ActualCost: 4.25, Requests: 5, Tokens: 300},
|
||||
},
|
||||
TotalActualCost: 40.0,
|
||||
TotalRequests: 30,
|
||||
TotalTokens: 2600,
|
||||
}, got)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
@@ -376,6 +382,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullString{},
|
||||
sql.NullString{Valid: true, String: "priority"},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
false,
|
||||
now,
|
||||
}})
|
||||
@@ -415,6 +423,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullString{},
|
||||
sql.NullString{Valid: true, String: "flex"},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
false,
|
||||
now,
|
||||
}})
|
||||
@@ -454,6 +464,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullString{},
|
||||
sql.NullString{Valid: true, String: "priority"},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
false,
|
||||
now,
|
||||
}})
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -190,7 +191,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
|
||||
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
q := client.UserSubscription.Query()
|
||||
if userID != nil {
|
||||
@@ -199,6 +200,9 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
|
||||
if groupID != nil {
|
||||
q = q.Where(usersubscription.GroupIDEQ(*groupID))
|
||||
}
|
||||
if platform != "" {
|
||||
q = q.Where(usersubscription.HasGroupWith(group.PlatformEQ(platform)))
|
||||
}
|
||||
|
||||
// Status filtering with real-time expiration check
|
||||
now := time.Now()
|
||||
|
||||
@@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
|
||||
group := s.mustCreateGroup("g-list")
|
||||
s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "")
|
||||
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "", "")
|
||||
s.Require().NoError(err, "List")
|
||||
s.Require().Len(subs, 1)
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
@@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
|
||||
s.mustCreateSubscription(user1.ID, group.ID, nil)
|
||||
s.mustCreateSubscription(user2.ID, group.ID, nil)
|
||||
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "")
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(subs, 1)
|
||||
s.Require().Equal(user1.ID, subs[0].UserID)
|
||||
@@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
|
||||
s.mustCreateSubscription(user.ID, g1.ID, nil)
|
||||
s.mustCreateSubscription(user.ID, g2.ID, nil)
|
||||
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "")
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(subs, 1)
|
||||
s.Require().Equal(g1.ID, subs[0].GroupID)
|
||||
@@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
|
||||
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
|
||||
})
|
||||
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "")
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(subs, 1)
|
||||
s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)
|
||||
|
||||
@@ -100,6 +100,10 @@ var ProviderSet = wire.NewSet(
|
||||
// Encryptors
|
||||
NewAESEncryptor,
|
||||
|
||||
// Backup infrastructure
|
||||
NewPgDumper,
|
||||
NewS3BackupStoreFactory,
|
||||
|
||||
// HTTP service ports (DI Strategy A: return interface directly)
|
||||
NewTurnstileVerifier,
|
||||
ProvidePricingRemoteClient,
|
||||
|
||||
@@ -493,6 +493,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"registration_email_suffix_whitelist": [],
|
||||
"promo_code_enabled": true,
|
||||
"password_reset_enabled": false,
|
||||
"frontend_url": "",
|
||||
"totp_enabled": false,
|
||||
"totp_encryption_key_configured": false,
|
||||
"smtp_host": "smtp.example.com",
|
||||
@@ -537,6 +538,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"purchase_subscription_url": "",
|
||||
"min_claude_code_version": "",
|
||||
"allow_ungrouped_key_scheduling": false,
|
||||
"backend_mode_enabled": false,
|
||||
"custom_menu_items": []
|
||||
}
|
||||
}`,
|
||||
@@ -922,8 +924,8 @@ func (stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error
|
||||
return false, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
func (stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
|
||||
return 0, 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
@@ -1287,7 +1289,7 @@ func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userI
|
||||
func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
||||
@@ -1623,10 +1625,22 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
@@ -1772,6 +1786,9 @@ func (r *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID i
|
||||
func (r *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (r *stubUsageLogRepo) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
type stubSettingRepo struct {
|
||||
all map[string]string
|
||||
|
||||
@@ -135,7 +135,7 @@ func (f fakeGoogleSubscriptionRepo) ListActiveByUserID(ctx context.Context, user
|
||||
func (f fakeGoogleSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
||||
|
||||
@@ -646,7 +646,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
51
backend/internal/server/middleware/backend_mode_guard.go
Normal file
51
backend/internal/server/middleware/backend_mode_guard.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// BackendModeUserGuard blocks non-admin users from accessing user routes when backend mode is enabled.
|
||||
// Must be placed AFTER JWT auth middleware so that the user role is available in context.
|
||||
func BackendModeUserGuard(settingService *service.SettingService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
role, _ := GetUserRoleFromContext(c)
|
||||
if role == "admin" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
response.Forbidden(c, "Backend mode is active. User self-service is disabled.")
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
// BackendModeAuthGuard selectively blocks auth endpoints when backend mode is enabled.
|
||||
// Allows: login, login/2fa, logout, refresh (admin needs these).
|
||||
// Blocks: register, forgot-password, reset-password, OAuth, etc.
|
||||
func BackendModeAuthGuard(settingService *service.SettingService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
path := c.Request.URL.Path
|
||||
// Allow login, 2FA, logout, refresh, public settings
|
||||
allowedSuffixes := []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"}
|
||||
for _, suffix := range allowedSuffixes {
|
||||
if strings.HasSuffix(path, suffix) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
response.Forbidden(c, "Backend mode is active. Registration and self-service auth flows are disabled.")
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
239
backend/internal/server/middleware/backend_mode_guard_test.go
Normal file
239
backend/internal/server/middleware/backend_mode_guard_test.go
Normal file
@@ -0,0 +1,239 @@
|
||||
//go:build unit
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type bmSettingRepo struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) Get(_ context.Context, _ string) (*service.Setting, error) {
|
||||
panic("unexpected Get call")
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) GetValue(_ context.Context, key string) (string, error) {
|
||||
v, ok := r.values[key]
|
||||
if !ok {
|
||||
return "", service.ErrSettingNotFound
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) Set(_ context.Context, _, _ string) error {
|
||||
panic("unexpected Set call")
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) GetMultiple(_ context.Context, _ []string) (map[string]string, error) {
|
||||
panic("unexpected GetMultiple call")
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
|
||||
if r.values == nil {
|
||||
r.values = make(map[string]string, len(settings))
|
||||
}
|
||||
for key, value := range settings {
|
||||
r.values[key] = value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) GetAll(_ context.Context) (map[string]string, error) {
|
||||
panic("unexpected GetAll call")
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) Delete(_ context.Context, _ string) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func newBackendModeSettingService(t *testing.T, enabled string) *service.SettingService {
|
||||
t.Helper()
|
||||
|
||||
repo := &bmSettingRepo{
|
||||
values: map[string]string{
|
||||
service.SettingKeyBackendModeEnabled: enabled,
|
||||
},
|
||||
}
|
||||
svc := service.NewSettingService(repo, &config.Config{})
|
||||
require.NoError(t, svc.UpdateSettings(context.Background(), &service.SystemSettings{
|
||||
BackendModeEnabled: enabled == "true",
|
||||
}))
|
||||
|
||||
return svc
|
||||
}
|
||||
|
||||
func stringPtr(v string) *string {
|
||||
return &v
|
||||
}
|
||||
|
||||
func TestBackendModeUserGuard(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
nilService bool
|
||||
enabled string
|
||||
role *string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "disabled_allows_all",
|
||||
enabled: "false",
|
||||
role: stringPtr("user"),
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "nil_service_allows_all",
|
||||
nilService: true,
|
||||
role: stringPtr("user"),
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_admin_allowed",
|
||||
enabled: "true",
|
||||
role: stringPtr("admin"),
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_user_blocked",
|
||||
enabled: "true",
|
||||
role: stringPtr("user"),
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_no_role_blocked",
|
||||
enabled: "true",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_empty_role_blocked",
|
||||
enabled: "true",
|
||||
role: stringPtr(""),
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
if tc.role != nil {
|
||||
role := *tc.role
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set(string(ContextKeyUserRole), role)
|
||||
c.Next()
|
||||
})
|
||||
}
|
||||
|
||||
var svc *service.SettingService
|
||||
if !tc.nilService {
|
||||
svc = newBackendModeSettingService(t, tc.enabled)
|
||||
}
|
||||
|
||||
r.Use(BackendModeUserGuard(svc))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, tc.wantStatus, w.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendModeAuthGuard(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
nilService bool
|
||||
enabled string
|
||||
path string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "disabled_allows_all",
|
||||
enabled: "false",
|
||||
path: "/api/v1/auth/register",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "nil_service_allows_all",
|
||||
nilService: true,
|
||||
path: "/api/v1/auth/register",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_login",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/login",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_login_2fa",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/login/2fa",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_logout",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/logout",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_refresh",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/refresh",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_blocks_register",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/register",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_blocks_forgot_password",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/forgot-password",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
|
||||
var svc *service.SettingService
|
||||
if !tc.nilService {
|
||||
svc = newBackendModeSettingService(t, tc.enabled)
|
||||
}
|
||||
|
||||
r.Use(BackendModeAuthGuard(svc))
|
||||
r.Any("/*path", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, tc.path, nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, tc.wantStatus, w.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -107,9 +107,9 @@ func registerRoutes(
|
||||
v1 := r.Group("/api/v1")
|
||||
|
||||
// 注册各模块路由
|
||||
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient)
|
||||
routes.RegisterUserRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterSoraClientRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient, settingService)
|
||||
routes.RegisterUserRoutes(v1, h, jwtAuth, settingService)
|
||||
routes.RegisterSoraClientRoutes(v1, h, jwtAuth, settingService)
|
||||
routes.RegisterAdminRoutes(v1, h, adminAuth)
|
||||
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
|
||||
}
|
||||
|
||||
@@ -58,6 +58,9 @@ func RegisterAdminRoutes(
|
||||
// 数据管理
|
||||
registerDataManagementRoutes(admin, h)
|
||||
|
||||
// 数据库备份恢复
|
||||
registerBackupRoutes(admin, h)
|
||||
|
||||
// 运维监控(Ops)
|
||||
registerOpsRoutes(admin, h)
|
||||
|
||||
@@ -195,6 +198,7 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
dashboard.GET("/users-ranking", h.Admin.Dashboard.GetUserSpendingRanking)
|
||||
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
|
||||
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
|
||||
dashboard.GET("/user-breakdown", h.Admin.Dashboard.GetUserBreakdown)
|
||||
dashboard.POST("/aggregation/backfill", h.Admin.Dashboard.BackfillAggregation)
|
||||
}
|
||||
}
|
||||
@@ -223,6 +227,8 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
{
|
||||
groups.GET("", h.Admin.Group.List)
|
||||
groups.GET("/all", h.Admin.Group.GetAll)
|
||||
groups.GET("/usage-summary", h.Admin.Group.GetUsageSummary)
|
||||
groups.GET("/capacity-summary", h.Admin.Group.GetCapacitySummary)
|
||||
groups.PUT("/sort-order", h.Admin.Group.UpdateSortOrder)
|
||||
groups.GET("/:id", h.Admin.Group.GetByID)
|
||||
groups.POST("", h.Admin.Group.Create)
|
||||
@@ -440,6 +446,30 @@ func registerDataManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
}
|
||||
}
|
||||
|
||||
func registerBackupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
backup := admin.Group("/backups")
|
||||
{
|
||||
// S3 存储配置
|
||||
backup.GET("/s3-config", h.Admin.Backup.GetS3Config)
|
||||
backup.PUT("/s3-config", h.Admin.Backup.UpdateS3Config)
|
||||
backup.POST("/s3-config/test", h.Admin.Backup.TestS3Connection)
|
||||
|
||||
// 定时备份配置
|
||||
backup.GET("/schedule", h.Admin.Backup.GetSchedule)
|
||||
backup.PUT("/schedule", h.Admin.Backup.UpdateSchedule)
|
||||
|
||||
// 备份操作
|
||||
backup.POST("", h.Admin.Backup.CreateBackup)
|
||||
backup.GET("", h.Admin.Backup.ListBackups)
|
||||
backup.GET("/:id", h.Admin.Backup.GetBackup)
|
||||
backup.DELETE("/:id", h.Admin.Backup.DeleteBackup)
|
||||
backup.GET("/:id/download-url", h.Admin.Backup.GetDownloadURL)
|
||||
|
||||
// 恢复操作
|
||||
backup.POST("/:id/restore", h.Admin.Backup.RestoreBackup)
|
||||
}
|
||||
}
|
||||
|
||||
func registerSystemRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
system := admin.Group("/system")
|
||||
{
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/middleware"
|
||||
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
@@ -17,12 +18,14 @@ func RegisterAuthRoutes(
|
||||
h *handler.Handlers,
|
||||
jwtAuth servermiddleware.JWTAuthMiddleware,
|
||||
redisClient *redis.Client,
|
||||
settingService *service.SettingService,
|
||||
) {
|
||||
// 创建速率限制器
|
||||
rateLimiter := middleware.NewRateLimiter(redisClient)
|
||||
|
||||
// 公开接口
|
||||
auth := v1.Group("/auth")
|
||||
auth.Use(servermiddleware.BackendModeAuthGuard(settingService))
|
||||
{
|
||||
// 注册/登录/2FA/验证码发送均属于高风险入口,增加服务端兜底限流(Redis 故障时 fail-close)
|
||||
auth.POST("/register", rateLimiter.LimitWithOptions("auth-register", 5, time.Minute, middleware.RateLimitOptions{
|
||||
@@ -78,6 +81,7 @@ func RegisterAuthRoutes(
|
||||
// 需要认证的当前用户信息
|
||||
authenticated := v1.Group("")
|
||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
||||
authenticated.Use(servermiddleware.BackendModeUserGuard(settingService))
|
||||
{
|
||||
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
|
||||
// 撤销所有会话(需要认证)
|
||||
|
||||
@@ -29,6 +29,7 @@ func newAuthRoutesTestRouter(redisClient *redis.Client) *gin.Engine {
|
||||
c.Next()
|
||||
}),
|
||||
redisClient,
|
||||
nil,
|
||||
)
|
||||
|
||||
return router
|
||||
|
||||
@@ -30,6 +30,7 @@ func RegisterGatewayRoutes(
|
||||
soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize)
|
||||
clientRequestID := middleware.ClientRequestID()
|
||||
opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService)
|
||||
endpointNorm := handler.InboundEndpointMiddleware()
|
||||
|
||||
// 未分组 Key 拦截中间件(按协议格式区分错误响应)
|
||||
requireGroupAnthropic := middleware.RequireGroupAssignment(settingService, middleware.AnthropicErrorWriter)
|
||||
@@ -40,6 +41,7 @@ func RegisterGatewayRoutes(
|
||||
gateway.Use(bodyLimit)
|
||||
gateway.Use(clientRequestID)
|
||||
gateway.Use(opsErrorLogger)
|
||||
gateway.Use(endpointNorm)
|
||||
gateway.Use(gin.HandlerFunc(apiKeyAuth))
|
||||
gateway.Use(requireGroupAnthropic)
|
||||
{
|
||||
@@ -80,6 +82,7 @@ func RegisterGatewayRoutes(
|
||||
gemini.Use(bodyLimit)
|
||||
gemini.Use(clientRequestID)
|
||||
gemini.Use(opsErrorLogger)
|
||||
gemini.Use(endpointNorm)
|
||||
gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
|
||||
gemini.Use(requireGroupGoogle)
|
||||
{
|
||||
@@ -90,11 +93,11 @@ func RegisterGatewayRoutes(
|
||||
}
|
||||
|
||||
// OpenAI Responses API(不带v1前缀的别名)
|
||||
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
|
||||
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
|
||||
// OpenAI Chat Completions API(不带v1前缀的别名)
|
||||
r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions)
|
||||
r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions)
|
||||
|
||||
// Antigravity 模型列表
|
||||
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)
|
||||
@@ -104,6 +107,7 @@ func RegisterGatewayRoutes(
|
||||
antigravityV1.Use(bodyLimit)
|
||||
antigravityV1.Use(clientRequestID)
|
||||
antigravityV1.Use(opsErrorLogger)
|
||||
antigravityV1.Use(endpointNorm)
|
||||
antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity))
|
||||
antigravityV1.Use(gin.HandlerFunc(apiKeyAuth))
|
||||
antigravityV1.Use(requireGroupAnthropic)
|
||||
@@ -118,6 +122,7 @@ func RegisterGatewayRoutes(
|
||||
antigravityV1Beta.Use(bodyLimit)
|
||||
antigravityV1Beta.Use(clientRequestID)
|
||||
antigravityV1Beta.Use(opsErrorLogger)
|
||||
antigravityV1Beta.Use(endpointNorm)
|
||||
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
|
||||
antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
|
||||
antigravityV1Beta.Use(requireGroupGoogle)
|
||||
@@ -132,6 +137,7 @@ func RegisterGatewayRoutes(
|
||||
soraV1.Use(soraBodyLimit)
|
||||
soraV1.Use(clientRequestID)
|
||||
soraV1.Use(opsErrorLogger)
|
||||
soraV1.Use(endpointNorm)
|
||||
soraV1.Use(middleware.ForcePlatform(service.PlatformSora))
|
||||
soraV1.Use(gin.HandlerFunc(apiKeyAuth))
|
||||
soraV1.Use(requireGroupAnthropic)
|
||||
|
||||
@@ -3,6 +3,7 @@ package routes
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -12,6 +13,7 @@ func RegisterSoraClientRoutes(
|
||||
v1 *gin.RouterGroup,
|
||||
h *handler.Handlers,
|
||||
jwtAuth middleware.JWTAuthMiddleware,
|
||||
settingService *service.SettingService,
|
||||
) {
|
||||
if h.SoraClient == nil {
|
||||
return
|
||||
@@ -19,6 +21,7 @@ func RegisterSoraClientRoutes(
|
||||
|
||||
authenticated := v1.Group("/sora")
|
||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
||||
authenticated.Use(middleware.BackendModeUserGuard(settingService))
|
||||
{
|
||||
authenticated.POST("/generate", h.SoraClient.Generate)
|
||||
authenticated.GET("/generations", h.SoraClient.ListGenerations)
|
||||
|
||||
@@ -3,6 +3,7 @@ package routes
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -12,9 +13,11 @@ func RegisterUserRoutes(
|
||||
v1 *gin.RouterGroup,
|
||||
h *handler.Handlers,
|
||||
jwtAuth middleware.JWTAuthMiddleware,
|
||||
settingService *service.SettingService,
|
||||
) {
|
||||
authenticated := v1.Group("")
|
||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
||||
authenticated.Use(middleware.BackendModeUserGuard(settingService))
|
||||
{
|
||||
// 用户接口
|
||||
user := authenticated.Group("/user")
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"hash/fnv"
|
||||
"reflect"
|
||||
"sort"
|
||||
@@ -522,16 +523,23 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
|
||||
// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)
|
||||
// 如果未配置 mapping,返回原始模型名
|
||||
func (a *Account) GetMappedModel(requestedModel string) string {
|
||||
mappedModel, _ := a.ResolveMappedModel(requestedModel)
|
||||
return mappedModel
|
||||
}
|
||||
|
||||
// ResolveMappedModel 获取映射后的模型名,并返回是否命中了账号级映射。
|
||||
// matched=true 表示命中了精确映射或通配符映射,即使映射结果与原模型名相同。
|
||||
func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string, matched bool) {
|
||||
mapping := a.GetModelMapping()
|
||||
if len(mapping) == 0 {
|
||||
return requestedModel
|
||||
return requestedModel, false
|
||||
}
|
||||
// 精确匹配优先
|
||||
if mappedModel, exists := mapping[requestedModel]; exists {
|
||||
return mappedModel
|
||||
return mappedModel, true
|
||||
}
|
||||
// 通配符匹配(最长优先)
|
||||
return matchWildcardMapping(mapping, requestedModel)
|
||||
return matchWildcardMappingResult(mapping, requestedModel)
|
||||
}
|
||||
|
||||
func (a *Account) GetBaseURL() string {
|
||||
@@ -605,9 +613,7 @@ func matchWildcard(pattern, str string) bool {
|
||||
return matchAntigravityWildcard(pattern, str)
|
||||
}
|
||||
|
||||
// matchWildcardMapping 通配符映射匹配(最长优先)
|
||||
// 如果没有匹配,返回原始字符串
|
||||
func matchWildcardMapping(mapping map[string]string, requestedModel string) string {
|
||||
func matchWildcardMappingResult(mapping map[string]string, requestedModel string) (string, bool) {
|
||||
// 收集所有匹配的 pattern,按长度降序排序(最长优先)
|
||||
type patternMatch struct {
|
||||
pattern string
|
||||
@@ -622,7 +628,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri
|
||||
}
|
||||
|
||||
if len(matches) == 0 {
|
||||
return requestedModel // 无匹配,返回原始模型名
|
||||
return requestedModel, false // 无匹配,返回原始模型名
|
||||
}
|
||||
|
||||
// 按 pattern 长度降序排序
|
||||
@@ -633,7 +639,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri
|
||||
return matches[i].pattern < matches[j].pattern
|
||||
})
|
||||
|
||||
return matches[0].target
|
||||
return matches[0].target, true
|
||||
}
|
||||
|
||||
func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||
@@ -651,7 +657,7 @@ func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||
// IsPoolMode 检查 API Key 账号是否启用池模式。
|
||||
// 池模式下,上游错误不标记本地账号状态,而是在同一账号上重试。
|
||||
func (a *Account) IsPoolMode() bool {
|
||||
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
|
||||
if !a.IsAPIKeyOrBedrock() || a.Credentials == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Credentials["pool_mode"]; ok {
|
||||
@@ -766,11 +772,16 @@ func (a *Account) IsInterceptWarmupEnabled() bool {
|
||||
}
|
||||
|
||||
func (a *Account) IsBedrock() bool {
|
||||
return a.Platform == PlatformAnthropic && (a.Type == AccountTypeBedrock || a.Type == AccountTypeBedrockAPIKey)
|
||||
return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrock
|
||||
}
|
||||
|
||||
func (a *Account) IsBedrockAPIKey() bool {
|
||||
return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrockAPIKey
|
||||
return a.IsBedrock() && a.GetCredential("auth_mode") == "apikey"
|
||||
}
|
||||
|
||||
// IsAPIKeyOrBedrock 返回账号类型是否支持配额和池模式等特性
|
||||
func (a *Account) IsAPIKeyOrBedrock() bool {
|
||||
return a.Type == AccountTypeAPIKey || a.Type == AccountTypeBedrock
|
||||
}
|
||||
|
||||
func (a *Account) IsOpenAI() bool {
|
||||
@@ -890,6 +901,22 @@ func (a *Account) IsMixedSchedulingEnabled() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// IsOveragesEnabled 检查 Antigravity 账号是否启用 AI Credits 超量请求。
|
||||
func (a *Account) IsOveragesEnabled() bool {
|
||||
if a.Platform != PlatformAntigravity {
|
||||
return false
|
||||
}
|
||||
if a.Extra == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Extra["allow_overages"]; ok {
|
||||
if enabled, ok := v.(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用“自动透传(仅替换认证)”。
|
||||
//
|
||||
// 新字段:accounts.extra.openai_passthrough。
|
||||
@@ -1269,6 +1296,240 @@ func (a *Account) getExtraTime(key string) time.Time {
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
// getExtraString 从 Extra 中读取指定 key 的字符串值
|
||||
func (a *Account) getExtraString(key string) string {
|
||||
if a.Extra == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := a.Extra[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getExtraInt 从 Extra 中读取指定 key 的 int 值
|
||||
func (a *Account) getExtraInt(key string) int {
|
||||
if a.Extra == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := a.Extra[key]; ok {
|
||||
return int(parseExtraFloat64(v))
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetQuotaDailyResetMode 获取日额度重置模式:"rolling"(默认)或 "fixed"
|
||||
func (a *Account) GetQuotaDailyResetMode() string {
|
||||
if m := a.getExtraString("quota_daily_reset_mode"); m == "fixed" {
|
||||
return "fixed"
|
||||
}
|
||||
return "rolling"
|
||||
}
|
||||
|
||||
// GetQuotaDailyResetHour 获取固定重置的小时(0-23),默认 0
|
||||
func (a *Account) GetQuotaDailyResetHour() int {
|
||||
return a.getExtraInt("quota_daily_reset_hour")
|
||||
}
|
||||
|
||||
// GetQuotaWeeklyResetMode 获取周额度重置模式:"rolling"(默认)或 "fixed"
|
||||
func (a *Account) GetQuotaWeeklyResetMode() string {
|
||||
if m := a.getExtraString("quota_weekly_reset_mode"); m == "fixed" {
|
||||
return "fixed"
|
||||
}
|
||||
return "rolling"
|
||||
}
|
||||
|
||||
// GetQuotaWeeklyResetDay 获取固定重置的星期几(0=周日, 1=周一, ..., 6=周六),默认 1(周一)
|
||||
func (a *Account) GetQuotaWeeklyResetDay() int {
|
||||
if a.Extra == nil {
|
||||
return 1
|
||||
}
|
||||
if _, ok := a.Extra["quota_weekly_reset_day"]; !ok {
|
||||
return 1
|
||||
}
|
||||
return a.getExtraInt("quota_weekly_reset_day")
|
||||
}
|
||||
|
||||
// GetQuotaWeeklyResetHour 获取周配额固定重置的小时(0-23),默认 0
|
||||
func (a *Account) GetQuotaWeeklyResetHour() int {
|
||||
return a.getExtraInt("quota_weekly_reset_hour")
|
||||
}
|
||||
|
||||
// GetQuotaResetTimezone 获取固定重置的时区名(IANA),默认 "UTC"
|
||||
func (a *Account) GetQuotaResetTimezone() string {
|
||||
if tz := a.getExtraString("quota_reset_timezone"); tz != "" {
|
||||
return tz
|
||||
}
|
||||
return "UTC"
|
||||
}
|
||||
|
||||
// nextFixedDailyReset 计算在 after 之后的下一个每日固定重置时间点
|
||||
func nextFixedDailyReset(hour int, tz *time.Location, after time.Time) time.Time {
|
||||
t := after.In(tz)
|
||||
today := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
|
||||
if !after.Before(today) {
|
||||
return today.AddDate(0, 0, 1)
|
||||
}
|
||||
return today
|
||||
}
|
||||
|
||||
// lastFixedDailyReset 计算 now 之前最近一次的每日固定重置时间点
|
||||
func lastFixedDailyReset(hour int, tz *time.Location, now time.Time) time.Time {
|
||||
t := now.In(tz)
|
||||
today := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
|
||||
if now.Before(today) {
|
||||
return today.AddDate(0, 0, -1)
|
||||
}
|
||||
return today
|
||||
}
|
||||
|
||||
// nextFixedWeeklyReset 计算在 after 之后的下一个每周固定重置时间点
|
||||
// day: 0=Sunday, 1=Monday, ..., 6=Saturday
|
||||
func nextFixedWeeklyReset(day, hour int, tz *time.Location, after time.Time) time.Time {
|
||||
t := after.In(tz)
|
||||
todayReset := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
|
||||
currentDay := int(todayReset.Weekday())
|
||||
|
||||
daysForward := (day - currentDay + 7) % 7
|
||||
if daysForward == 0 && !after.Before(todayReset) {
|
||||
daysForward = 7
|
||||
}
|
||||
return todayReset.AddDate(0, 0, daysForward)
|
||||
}
|
||||
|
||||
// lastFixedWeeklyReset 计算 now 之前最近一次的每周固定重置时间点
|
||||
func lastFixedWeeklyReset(day, hour int, tz *time.Location, now time.Time) time.Time {
|
||||
t := now.In(tz)
|
||||
todayReset := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
|
||||
currentDay := int(todayReset.Weekday())
|
||||
|
||||
daysBack := (currentDay - day + 7) % 7
|
||||
if daysBack == 0 && now.Before(todayReset) {
|
||||
daysBack = 7
|
||||
}
|
||||
return todayReset.AddDate(0, 0, -daysBack)
|
||||
}
|
||||
|
||||
// isFixedDailyPeriodExpired 检查日配额是否在固定时间模式下已过期
|
||||
func (a *Account) isFixedDailyPeriodExpired(periodStart time.Time) bool {
|
||||
if periodStart.IsZero() {
|
||||
return true
|
||||
}
|
||||
tz, err := time.LoadLocation(a.GetQuotaResetTimezone())
|
||||
if err != nil {
|
||||
tz = time.UTC
|
||||
}
|
||||
lastReset := lastFixedDailyReset(a.GetQuotaDailyResetHour(), tz, time.Now())
|
||||
return periodStart.Before(lastReset)
|
||||
}
|
||||
|
||||
// isFixedWeeklyPeriodExpired 检查周配额是否在固定时间模式下已过期
|
||||
func (a *Account) isFixedWeeklyPeriodExpired(periodStart time.Time) bool {
|
||||
if periodStart.IsZero() {
|
||||
return true
|
||||
}
|
||||
tz, err := time.LoadLocation(a.GetQuotaResetTimezone())
|
||||
if err != nil {
|
||||
tz = time.UTC
|
||||
}
|
||||
lastReset := lastFixedWeeklyReset(a.GetQuotaWeeklyResetDay(), a.GetQuotaWeeklyResetHour(), tz, time.Now())
|
||||
return periodStart.Before(lastReset)
|
||||
}
|
||||
|
||||
// ComputeQuotaResetAt 根据当前配置计算并填充 extra 中的 quota_daily_reset_at / quota_weekly_reset_at
|
||||
// 在保存账号配置时调用
|
||||
func ComputeQuotaResetAt(extra map[string]any) {
|
||||
now := time.Now()
|
||||
tzName, _ := extra["quota_reset_timezone"].(string)
|
||||
if tzName == "" {
|
||||
tzName = "UTC"
|
||||
}
|
||||
tz, err := time.LoadLocation(tzName)
|
||||
if err != nil {
|
||||
tz = time.UTC
|
||||
}
|
||||
|
||||
// 日配额固定重置时间
|
||||
if mode, _ := extra["quota_daily_reset_mode"].(string); mode == "fixed" {
|
||||
hour := int(parseExtraFloat64(extra["quota_daily_reset_hour"]))
|
||||
if hour < 0 || hour > 23 {
|
||||
hour = 0
|
||||
}
|
||||
resetAt := nextFixedDailyReset(hour, tz, now)
|
||||
extra["quota_daily_reset_at"] = resetAt.UTC().Format(time.RFC3339)
|
||||
} else {
|
||||
delete(extra, "quota_daily_reset_at")
|
||||
}
|
||||
|
||||
// 周配额固定重置时间
|
||||
if mode, _ := extra["quota_weekly_reset_mode"].(string); mode == "fixed" {
|
||||
day := 1 // 默认周一
|
||||
if d, ok := extra["quota_weekly_reset_day"]; ok {
|
||||
day = int(parseExtraFloat64(d))
|
||||
}
|
||||
if day < 0 || day > 6 {
|
||||
day = 1
|
||||
}
|
||||
hour := int(parseExtraFloat64(extra["quota_weekly_reset_hour"]))
|
||||
if hour < 0 || hour > 23 {
|
||||
hour = 0
|
||||
}
|
||||
resetAt := nextFixedWeeklyReset(day, hour, tz, now)
|
||||
extra["quota_weekly_reset_at"] = resetAt.UTC().Format(time.RFC3339)
|
||||
} else {
|
||||
delete(extra, "quota_weekly_reset_at")
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateQuotaResetConfig 校验配额固定重置时间配置的合法性
|
||||
func ValidateQuotaResetConfig(extra map[string]any) error {
|
||||
if extra == nil {
|
||||
return nil
|
||||
}
|
||||
// 校验时区
|
||||
if tz, ok := extra["quota_reset_timezone"].(string); ok && tz != "" {
|
||||
if _, err := time.LoadLocation(tz); err != nil {
|
||||
return errors.New("invalid quota_reset_timezone: must be a valid IANA timezone name")
|
||||
}
|
||||
}
|
||||
// 日配额重置模式
|
||||
if mode, ok := extra["quota_daily_reset_mode"].(string); ok {
|
||||
if mode != "rolling" && mode != "fixed" {
|
||||
return errors.New("quota_daily_reset_mode must be 'rolling' or 'fixed'")
|
||||
}
|
||||
}
|
||||
// 日配额重置小时
|
||||
if v, ok := extra["quota_daily_reset_hour"]; ok {
|
||||
hour := int(parseExtraFloat64(v))
|
||||
if hour < 0 || hour > 23 {
|
||||
return errors.New("quota_daily_reset_hour must be between 0 and 23")
|
||||
}
|
||||
}
|
||||
// 周配额重置模式
|
||||
if mode, ok := extra["quota_weekly_reset_mode"].(string); ok {
|
||||
if mode != "rolling" && mode != "fixed" {
|
||||
return errors.New("quota_weekly_reset_mode must be 'rolling' or 'fixed'")
|
||||
}
|
||||
}
|
||||
// 周配额重置星期几
|
||||
if v, ok := extra["quota_weekly_reset_day"]; ok {
|
||||
day := int(parseExtraFloat64(v))
|
||||
if day < 0 || day > 6 {
|
||||
return errors.New("quota_weekly_reset_day must be between 0 (Sunday) and 6 (Saturday)")
|
||||
}
|
||||
}
|
||||
// 周配额重置小时
|
||||
if v, ok := extra["quota_weekly_reset_hour"]; ok {
|
||||
hour := int(parseExtraFloat64(v))
|
||||
if hour < 0 || hour > 23 {
|
||||
return errors.New("quota_weekly_reset_hour must be between 0 and 23")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasAnyQuotaLimit 检查是否配置了任一维度的配额限制
|
||||
func (a *Account) HasAnyQuotaLimit() bool {
|
||||
return a.GetQuotaLimit() > 0 || a.GetQuotaDailyLimit() > 0 || a.GetQuotaWeeklyLimit() > 0
|
||||
@@ -1291,14 +1552,26 @@ func (a *Account) IsQuotaExceeded() bool {
|
||||
// 日额度(周期过期视为未超限,下次 increment 会重置)
|
||||
if limit := a.GetQuotaDailyLimit(); limit > 0 {
|
||||
start := a.getExtraTime("quota_daily_start")
|
||||
if !isPeriodExpired(start, 24*time.Hour) && a.GetQuotaDailyUsed() >= limit {
|
||||
var expired bool
|
||||
if a.GetQuotaDailyResetMode() == "fixed" {
|
||||
expired = a.isFixedDailyPeriodExpired(start)
|
||||
} else {
|
||||
expired = isPeriodExpired(start, 24*time.Hour)
|
||||
}
|
||||
if !expired && a.GetQuotaDailyUsed() >= limit {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// 周额度
|
||||
if limit := a.GetQuotaWeeklyLimit(); limit > 0 {
|
||||
start := a.getExtraTime("quota_weekly_start")
|
||||
if !isPeriodExpired(start, 7*24*time.Hour) && a.GetQuotaWeeklyUsed() >= limit {
|
||||
var expired bool
|
||||
if a.GetQuotaWeeklyResetMode() == "fixed" {
|
||||
expired = a.isFixedWeeklyPeriodExpired(start)
|
||||
} else {
|
||||
expired = isPeriodExpired(start, 7*24*time.Hour)
|
||||
}
|
||||
if !expired && a.GetQuotaWeeklyUsed() >= limit {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
516
backend/internal/service/account_quota_reset_test.go
Normal file
516
backend/internal/service/account_quota_reset_test.go
Normal file
@@ -0,0 +1,516 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// nextFixedDailyReset
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNextFixedDailyReset_BeforeResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-14 06:00 UTC, reset hour = 9
|
||||
after := time.Date(2026, 3, 14, 6, 0, 0, 0, tz)
|
||||
got := nextFixedDailyReset(9, tz, after)
|
||||
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedDailyReset_AtResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// Exactly at reset hour → should return tomorrow
|
||||
after := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
got := nextFixedDailyReset(9, tz, after)
|
||||
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedDailyReset_AfterResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// After reset hour → should return tomorrow
|
||||
after := time.Date(2026, 3, 14, 15, 30, 0, 0, tz)
|
||||
got := nextFixedDailyReset(9, tz, after)
|
||||
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedDailyReset_MidnightReset(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// Reset at hour 0 (midnight), currently 23:59
|
||||
after := time.Date(2026, 3, 14, 23, 59, 0, 0, tz)
|
||||
got := nextFixedDailyReset(0, tz, after)
|
||||
want := time.Date(2026, 3, 15, 0, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedDailyReset_NonUTCTimezone(t *testing.T) {
|
||||
tz, err := time.LoadLocation("Asia/Shanghai")
|
||||
require.NoError(t, err)
|
||||
|
||||
// 2026-03-14 07:00 UTC = 2026-03-14 15:00 CST, reset hour = 9 (CST)
|
||||
after := time.Date(2026, 3, 14, 7, 0, 0, 0, time.UTC)
|
||||
got := nextFixedDailyReset(9, tz, after)
|
||||
// Already past 9:00 CST today → tomorrow 9:00 CST = 2026-03-15 01:00 UTC
|
||||
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// lastFixedDailyReset
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestLastFixedDailyReset_BeforeResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
now := time.Date(2026, 3, 14, 6, 0, 0, 0, tz)
|
||||
got := lastFixedDailyReset(9, tz, now)
|
||||
// Before today's 9:00 → yesterday 9:00
|
||||
want := time.Date(2026, 3, 13, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestLastFixedDailyReset_AtResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
now := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
got := lastFixedDailyReset(9, tz, now)
|
||||
// At exactly 9:00 → today 9:00
|
||||
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestLastFixedDailyReset_AfterResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
now := time.Date(2026, 3, 14, 15, 0, 0, 0, tz)
|
||||
got := lastFixedDailyReset(9, tz, now)
|
||||
// After 9:00 → today 9:00
|
||||
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// nextFixedWeeklyReset
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayAhead(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-14 is Saturday (day=6), target = Monday (day=1), hour = 9
|
||||
after := time.Date(2026, 3, 14, 10, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Next Monday = 2026-03-16
|
||||
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayToday_BeforeHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday (day=1), target = Monday, hour = 9, before 9:00
|
||||
after := time.Date(2026, 3, 16, 6, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Today at 9:00
|
||||
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayToday_AtHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday, target = Monday, hour = 9, exactly at 9:00
|
||||
after := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Next Monday at 9:00
|
||||
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayToday_AfterHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday, target = Monday, hour = 9, after 9:00
|
||||
after := time.Date(2026, 3, 16, 15, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Next Monday at 9:00
|
||||
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayPast(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-18 is Wednesday (day=3), target = Monday (day=1)
|
||||
after := time.Date(2026, 3, 18, 10, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Next Monday = 2026-03-23
|
||||
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_Sunday(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-14 is Saturday (day=6), target = Sunday (day=0)
|
||||
after := time.Date(2026, 3, 14, 10, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(0, 0, tz, after)
|
||||
// Next Sunday = 2026-03-15
|
||||
want := time.Date(2026, 3, 15, 0, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// lastFixedWeeklyReset
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestLastFixedWeeklyReset_SameDay_AfterHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday (day=1), target = Monday, hour = 9, now = 15:00
|
||||
now := time.Date(2026, 3, 16, 15, 0, 0, 0, tz)
|
||||
got := lastFixedWeeklyReset(1, 9, tz, now)
|
||||
// Today at 9:00
|
||||
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestLastFixedWeeklyReset_SameDay_BeforeHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday, target = Monday, hour = 9, now = 06:00
|
||||
now := time.Date(2026, 3, 16, 6, 0, 0, 0, tz)
|
||||
got := lastFixedWeeklyReset(1, 9, tz, now)
|
||||
// Last Monday at 9:00 = 2026-03-09
|
||||
want := time.Date(2026, 3, 9, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestLastFixedWeeklyReset_DifferentDay(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-18 is Wednesday (day=3), target = Monday (day=1)
|
||||
now := time.Date(2026, 3, 18, 10, 0, 0, 0, tz)
|
||||
got := lastFixedWeeklyReset(1, 9, tz, now)
|
||||
// Last Monday = 2026-03-16
|
||||
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// isFixedDailyPeriodExpired
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestIsFixedDailyPeriodExpired_ZeroPeriodStart(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
assert.True(t, a.isFixedDailyPeriodExpired(time.Time{}))
|
||||
}
|
||||
|
||||
func TestIsFixedDailyPeriodExpired_NotExpired(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
// Period started after the most recent reset → not expired
|
||||
// (This test uses a time very close to "now", which is after the last reset)
|
||||
periodStart := time.Now().Add(-1 * time.Minute)
|
||||
assert.False(t, a.isFixedDailyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
func TestIsFixedDailyPeriodExpired_Expired(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
// Period started 3 days ago → definitely expired
|
||||
periodStart := time.Now().Add(-72 * time.Hour)
|
||||
assert.True(t, a.isFixedDailyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
func TestIsFixedDailyPeriodExpired_InvalidTimezone(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "Invalid/Timezone",
|
||||
}}
|
||||
// Invalid timezone falls back to UTC
|
||||
periodStart := time.Now().Add(-72 * time.Hour)
|
||||
assert.True(t, a.isFixedDailyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// isFixedWeeklyPeriodExpired
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestIsFixedWeeklyPeriodExpired_ZeroPeriodStart(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1),
|
||||
"quota_weekly_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
assert.True(t, a.isFixedWeeklyPeriodExpired(time.Time{}))
|
||||
}
|
||||
|
||||
func TestIsFixedWeeklyPeriodExpired_NotExpired(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1),
|
||||
"quota_weekly_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
// Period started 1 minute ago → not expired
|
||||
periodStart := time.Now().Add(-1 * time.Minute)
|
||||
assert.False(t, a.isFixedWeeklyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
func TestIsFixedWeeklyPeriodExpired_Expired(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1),
|
||||
"quota_weekly_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
// Period started 10 days ago → definitely expired
|
||||
periodStart := time.Now().Add(-240 * time.Hour)
|
||||
assert.True(t, a.isFixedWeeklyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ValidateQuotaResetConfig
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestValidateQuotaResetConfig_NilExtra(t *testing.T) {
|
||||
assert.NoError(t, ValidateQuotaResetConfig(nil))
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_EmptyExtra(t *testing.T) {
|
||||
assert.NoError(t, ValidateQuotaResetConfig(map[string]any{}))
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_ValidFixed(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1),
|
||||
"quota_weekly_reset_hour": float64(0),
|
||||
"quota_reset_timezone": "Asia/Shanghai",
|
||||
}
|
||||
assert.NoError(t, ValidateQuotaResetConfig(extra))
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_ValidRolling(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "rolling",
|
||||
"quota_weekly_reset_mode": "rolling",
|
||||
}
|
||||
assert.NoError(t, ValidateQuotaResetConfig(extra))
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidTimezone(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_reset_timezone": "Not/A/Timezone",
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_reset_timezone")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidDailyMode(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "invalid",
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_daily_reset_mode")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidDailyHour_TooHigh(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_hour": float64(24),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_daily_reset_hour")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidDailyHour_Negative(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_hour": float64(-1),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_daily_reset_hour")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidWeeklyMode(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_mode": "unknown",
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_weekly_reset_mode")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidWeeklyDay_TooHigh(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_day": float64(7),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_weekly_reset_day")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidWeeklyDay_Negative(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_day": float64(-1),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_weekly_reset_day")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidWeeklyHour(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_hour": float64(25),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_weekly_reset_hour")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_BoundaryValues(t *testing.T) {
|
||||
// All boundary values should be valid
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_hour": float64(23),
|
||||
"quota_weekly_reset_day": float64(0), // Sunday
|
||||
"quota_weekly_reset_hour": float64(0),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}
|
||||
assert.NoError(t, ValidateQuotaResetConfig(extra))
|
||||
|
||||
extra2 := map[string]any{
|
||||
"quota_daily_reset_hour": float64(0),
|
||||
"quota_weekly_reset_day": float64(6), // Saturday
|
||||
"quota_weekly_reset_hour": float64(23),
|
||||
}
|
||||
assert.NoError(t, ValidateQuotaResetConfig(extra2))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ComputeQuotaResetAt
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestComputeQuotaResetAt_RollingMode_NoResetAt(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "rolling",
|
||||
"quota_weekly_reset_mode": "rolling",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
_, hasDailyResetAt := extra["quota_daily_reset_at"]
|
||||
_, hasWeeklyResetAt := extra["quota_weekly_reset_at"]
|
||||
assert.False(t, hasDailyResetAt, "rolling mode should not set quota_daily_reset_at")
|
||||
assert.False(t, hasWeeklyResetAt, "rolling mode should not set quota_weekly_reset_at")
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_RollingMode_ClearsExistingResetAt(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "rolling",
|
||||
"quota_weekly_reset_mode": "rolling",
|
||||
"quota_daily_reset_at": "2026-03-14T09:00:00Z",
|
||||
"quota_weekly_reset_at": "2026-03-16T09:00:00Z",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
_, hasDailyResetAt := extra["quota_daily_reset_at"]
|
||||
_, hasWeeklyResetAt := extra["quota_weekly_reset_at"]
|
||||
assert.False(t, hasDailyResetAt, "rolling mode should remove quota_daily_reset_at")
|
||||
assert.False(t, hasWeeklyResetAt, "rolling mode should remove quota_weekly_reset_at")
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_FixedDaily_SetsResetAt(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
|
||||
require.True(t, ok, "quota_daily_reset_at should be set")
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// Reset time should be in the future
|
||||
assert.True(t, resetAt.After(time.Now()), "reset_at should be in the future")
|
||||
// Reset hour should be 9 UTC
|
||||
assert.Equal(t, 9, resetAt.UTC().Hour())
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_FixedWeekly_SetsResetAt(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1), // Monday
|
||||
"quota_weekly_reset_hour": float64(0),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_weekly_reset_at"].(string)
|
||||
require.True(t, ok, "quota_weekly_reset_at should be set")
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// Reset time should be in the future
|
||||
assert.True(t, resetAt.After(time.Now()), "reset_at should be in the future")
|
||||
// Reset day should be Monday
|
||||
assert.Equal(t, time.Monday, resetAt.UTC().Weekday())
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_FixedDaily_WithTimezone(t *testing.T) {
|
||||
tz, err := time.LoadLocation("Asia/Shanghai")
|
||||
require.NoError(t, err)
|
||||
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "Asia/Shanghai",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
|
||||
require.True(t, ok)
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// In Shanghai timezone, the hour should be 9
|
||||
assert.Equal(t, 9, resetAt.In(tz).Hour())
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_DefaultTimezone(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(12),
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
|
||||
require.True(t, ok)
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// Default timezone is UTC
|
||||
assert.Equal(t, 12, resetAt.UTC().Hour())
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_InvalidHour_ClampedToZero(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(99),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
|
||||
require.True(t, ok)
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// Invalid hour → clamped to 0
|
||||
assert.Equal(t, 0, resetAt.UTC().Hour())
|
||||
}
|
||||
@@ -113,15 +113,18 @@ func (s *AccountTestService) validateUpstreamBaseURL(raw string) (string, error)
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
// generateSessionString generates a Claude Code style session string
|
||||
// generateSessionString generates a Claude Code style session string.
|
||||
// The output format is determined by the UA version in claude.DefaultHeaders,
|
||||
// ensuring consistency between the user_id format and the UA sent to upstream.
|
||||
func generateSessionString() (string, error) {
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
hex64 := hex.EncodeToString(bytes)
|
||||
hex64 := hex.EncodeToString(b)
|
||||
sessionUUID := uuid.New().String()
|
||||
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID), nil
|
||||
uaVersion := ExtractCLIVersion(claude.DefaultHeaders["User-Agent"])
|
||||
return FormatMetadataUserID(hex64, "", sessionUUID, uaVersion), nil
|
||||
}
|
||||
|
||||
// createTestPayload creates a Claude Code style test request payload
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -44,7 +45,11 @@ type UsageLogRepository interface {
|
||||
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
|
||||
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
|
||||
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
|
||||
GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
|
||||
GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
|
||||
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
|
||||
GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error)
|
||||
GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error)
|
||||
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
||||
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
||||
GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error)
|
||||
@@ -100,6 +105,7 @@ type antigravityUsageCache struct {
|
||||
const (
|
||||
apiCacheTTL = 3 * time.Minute
|
||||
apiErrorCacheTTL = 1 * time.Minute // 负缓存 TTL:429 等错误缓存 1 分钟
|
||||
antigravityErrorTTL = 1 * time.Minute // Antigravity 错误缓存 TTL(可恢复错误)
|
||||
apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟
|
||||
windowStatsCacheTTL = 1 * time.Minute
|
||||
openAIProbeCacheTTL = 10 * time.Minute
|
||||
@@ -108,11 +114,12 @@ const (
|
||||
|
||||
// UsageCache 封装账户使用量相关的缓存
|
||||
type UsageCache struct {
|
||||
apiCache sync.Map // accountID -> *apiUsageCache
|
||||
windowStatsCache sync.Map // accountID -> *windowStatsCache
|
||||
antigravityCache sync.Map // accountID -> *antigravityUsageCache
|
||||
apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存
|
||||
openAIProbeCache sync.Map // accountID -> time.Time
|
||||
apiCache sync.Map // accountID -> *apiUsageCache
|
||||
windowStatsCache sync.Map // accountID -> *windowStatsCache
|
||||
antigravityCache sync.Map // accountID -> *antigravityUsageCache
|
||||
apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存(Anthropic)
|
||||
antigravityFlight singleflight.Group // 防止同一 Antigravity 账号的并发请求击穿缓存
|
||||
openAIProbeCache sync.Map // accountID -> time.Time
|
||||
}
|
||||
|
||||
// NewUsageCache 创建 UsageCache 实例
|
||||
@@ -149,6 +156,25 @@ type AntigravityModelQuota struct {
|
||||
ResetTime string `json:"reset_time"` // 重置时间 ISO8601
|
||||
}
|
||||
|
||||
// AntigravityModelDetail Antigravity 单个模型的详细能力信息
|
||||
type AntigravityModelDetail struct {
|
||||
DisplayName string `json:"display_name,omitempty"`
|
||||
SupportsImages *bool `json:"supports_images,omitempty"`
|
||||
SupportsThinking *bool `json:"supports_thinking,omitempty"`
|
||||
ThinkingBudget *int `json:"thinking_budget,omitempty"`
|
||||
Recommended *bool `json:"recommended,omitempty"`
|
||||
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
|
||||
SupportedMimeTypes map[string]bool `json:"supported_mime_types,omitempty"`
|
||||
}
|
||||
|
||||
// AICredit 表示 Antigravity 账号的 AI Credits 余额信息。
|
||||
type AICredit struct {
|
||||
CreditType string `json:"credit_type,omitempty"`
|
||||
Amount float64 `json:"amount,omitempty"`
|
||||
MinimumBalance float64 `json:"minimum_balance,omitempty"`
|
||||
}
|
||||
|
||||
// UsageInfo 账号使用量信息
|
||||
type UsageInfo struct {
|
||||
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
|
||||
@@ -164,6 +190,36 @@ type UsageInfo struct {
|
||||
|
||||
// Antigravity 多模型配额
|
||||
AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"`
|
||||
|
||||
// Antigravity 账号级信息
|
||||
SubscriptionTier string `json:"subscription_tier,omitempty"` // 归一化订阅等级: FREE/PRO/ULTRA/UNKNOWN
|
||||
SubscriptionTierRaw string `json:"subscription_tier_raw,omitempty"` // 上游原始订阅等级名称
|
||||
|
||||
// Antigravity 模型详细能力信息(与 antigravity_quota 同 key)
|
||||
AntigravityQuotaDetails map[string]*AntigravityModelDetail `json:"antigravity_quota_details,omitempty"`
|
||||
|
||||
// Antigravity AI Credits 余额
|
||||
AICredits []AICredit `json:"ai_credits,omitempty"`
|
||||
|
||||
// Antigravity 废弃模型转发规则 (old_model_id -> new_model_id)
|
||||
ModelForwardingRules map[string]string `json:"model_forwarding_rules,omitempty"`
|
||||
|
||||
// Antigravity 账号是否被上游禁止 (HTTP 403)
|
||||
IsForbidden bool `json:"is_forbidden,omitempty"`
|
||||
ForbiddenReason string `json:"forbidden_reason,omitempty"`
|
||||
ForbiddenType string `json:"forbidden_type,omitempty"` // "validation" / "violation" / "forbidden"
|
||||
ValidationURL string `json:"validation_url,omitempty"` // 验证/申诉链接
|
||||
|
||||
// 状态标记(从 ForbiddenType / HTTP 错误码推导)
|
||||
NeedsVerify bool `json:"needs_verify,omitempty"` // 需要人工验证(forbidden_type=validation)
|
||||
IsBanned bool `json:"is_banned,omitempty"` // 账号被封(forbidden_type=violation)
|
||||
NeedsReauth bool `json:"needs_reauth,omitempty"` // token 失效需重新授权(401)
|
||||
|
||||
// 错误码(机器可读):forbidden / unauthenticated / rate_limited / network_error
|
||||
ErrorCode string `json:"error_code,omitempty"`
|
||||
|
||||
// 获取 usage 时的错误信息(降级返回,而非 500)
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// ClaudeUsageResponse Anthropic API返回的usage结构
|
||||
@@ -392,23 +448,17 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
|
||||
}
|
||||
|
||||
if stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, now.Add(-5*time.Hour)); err == nil {
|
||||
windowStats := windowStatsFromAccountStats(stats)
|
||||
if hasMeaningfulWindowStats(windowStats) {
|
||||
if usage.FiveHour == nil {
|
||||
usage.FiveHour = &UsageProgress{Utilization: 0}
|
||||
}
|
||||
usage.FiveHour.WindowStats = windowStats
|
||||
if usage.FiveHour == nil {
|
||||
usage.FiveHour = &UsageProgress{Utilization: 0}
|
||||
}
|
||||
usage.FiveHour.WindowStats = windowStatsFromAccountStats(stats)
|
||||
}
|
||||
|
||||
if stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, now.Add(-7*24*time.Hour)); err == nil {
|
||||
windowStats := windowStatsFromAccountStats(stats)
|
||||
if hasMeaningfulWindowStats(windowStats) {
|
||||
if usage.SevenDay == nil {
|
||||
usage.SevenDay = &UsageProgress{Utilization: 0}
|
||||
}
|
||||
usage.SevenDay.WindowStats = windowStats
|
||||
if usage.SevenDay == nil {
|
||||
usage.SevenDay = &UsageProgress{Utilization: 0}
|
||||
}
|
||||
usage.SevenDay.WindowStats = windowStatsFromAccountStats(stats)
|
||||
}
|
||||
|
||||
return usage, nil
|
||||
@@ -648,34 +698,157 @@ func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account *
|
||||
return &UsageInfo{UpdatedAt: &now}, nil
|
||||
}
|
||||
|
||||
// 1. 检查缓存(10 分钟)
|
||||
// 1. 检查缓存
|
||||
if cached, ok := s.cache.antigravityCache.Load(account.ID); ok {
|
||||
if cache, ok := cached.(*antigravityUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL {
|
||||
// 重新计算 RemainingSeconds
|
||||
usage := cache.usageInfo
|
||||
if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil {
|
||||
usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds())
|
||||
if cache, ok := cached.(*antigravityUsageCache); ok {
|
||||
ttl := antigravityCacheTTL(cache.usageInfo)
|
||||
if time.Since(cache.timestamp) < ttl {
|
||||
usage := cache.usageInfo
|
||||
if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil {
|
||||
usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds())
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 获取代理 URL
|
||||
proxyURL := s.antigravityQuotaFetcher.GetProxyURL(ctx, account)
|
||||
// 2. singleflight 防止并发击穿
|
||||
flightKey := fmt.Sprintf("ag-usage:%d", account.ID)
|
||||
result, flightErr, _ := s.cache.antigravityFlight.Do(flightKey, func() (any, error) {
|
||||
// 再次检查缓存(等待期间可能已被填充)
|
||||
if cached, ok := s.cache.antigravityCache.Load(account.ID); ok {
|
||||
if cache, ok := cached.(*antigravityUsageCache); ok {
|
||||
ttl := antigravityCacheTTL(cache.usageInfo)
|
||||
if time.Since(cache.timestamp) < ttl {
|
||||
usage := cache.usageInfo
|
||||
// 重新计算 RemainingSeconds,避免返回过时的剩余秒数
|
||||
recalcAntigravityRemainingSeconds(usage)
|
||||
return usage, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 调用 API 获取额度
|
||||
result, err := s.antigravityQuotaFetcher.FetchQuota(ctx, account, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch antigravity quota failed: %w", err)
|
||||
}
|
||||
// 使用独立 context,避免调用方 cancel 导致所有共享 flight 的请求失败
|
||||
fetchCtx, fetchCancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer fetchCancel()
|
||||
|
||||
// 4. 缓存结果
|
||||
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
|
||||
usageInfo: result.UsageInfo,
|
||||
timestamp: time.Now(),
|
||||
proxyURL := s.antigravityQuotaFetcher.GetProxyURL(fetchCtx, account)
|
||||
fetchResult, err := s.antigravityQuotaFetcher.FetchQuota(fetchCtx, account, proxyURL)
|
||||
if err != nil {
|
||||
degraded := buildAntigravityDegradedUsage(err)
|
||||
enrichUsageWithAccountError(degraded, account)
|
||||
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
|
||||
usageInfo: degraded,
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
return degraded, nil
|
||||
}
|
||||
|
||||
enrichUsageWithAccountError(fetchResult.UsageInfo, account)
|
||||
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
|
||||
usageInfo: fetchResult.UsageInfo,
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
return fetchResult.UsageInfo, nil
|
||||
})
|
||||
|
||||
return result.UsageInfo, nil
|
||||
if flightErr != nil {
|
||||
return nil, flightErr
|
||||
}
|
||||
usage, ok := result.(*UsageInfo)
|
||||
if !ok || usage == nil {
|
||||
now := time.Now()
|
||||
return &UsageInfo{UpdatedAt: &now}, nil
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
// recalcAntigravityRemainingSeconds 重新计算 Antigravity UsageInfo 中各窗口的 RemainingSeconds
|
||||
// 用于从缓存取出时更新倒计时,避免返回过时的剩余秒数
|
||||
func recalcAntigravityRemainingSeconds(info *UsageInfo) {
|
||||
if info == nil {
|
||||
return
|
||||
}
|
||||
if info.FiveHour != nil && info.FiveHour.ResetsAt != nil {
|
||||
remaining := int(time.Until(*info.FiveHour.ResetsAt).Seconds())
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
info.FiveHour.RemainingSeconds = remaining
|
||||
}
|
||||
}
|
||||
|
||||
// antigravityCacheTTL 根据 UsageInfo 内容决定缓存 TTL
|
||||
// 403 forbidden 状态稳定,缓存与成功相同(3 分钟);
|
||||
// 其他错误(401/网络)可能快速恢复,缓存 1 分钟。
|
||||
func antigravityCacheTTL(info *UsageInfo) time.Duration {
|
||||
if info == nil {
|
||||
return antigravityErrorTTL
|
||||
}
|
||||
if info.IsForbidden {
|
||||
return apiCacheTTL // 封号/验证状态不会很快变
|
||||
}
|
||||
if info.ErrorCode != "" || info.Error != "" {
|
||||
return antigravityErrorTTL
|
||||
}
|
||||
return apiCacheTTL
|
||||
}
|
||||
|
||||
// buildAntigravityDegradedUsage 从 FetchQuota 错误构建降级 UsageInfo
|
||||
func buildAntigravityDegradedUsage(err error) *UsageInfo {
|
||||
now := time.Now()
|
||||
errMsg := fmt.Sprintf("usage API error: %v", err)
|
||||
slog.Warn("antigravity usage fetch failed, returning degraded response", "error", err)
|
||||
|
||||
info := &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
Error: errMsg,
|
||||
}
|
||||
|
||||
// 从错误信息推断 error_code 和状态标记
|
||||
// 错误格式来自 antigravity/client.go: "fetchAvailableModels 失败 (HTTP %d): ..."
|
||||
errStr := err.Error()
|
||||
switch {
|
||||
case strings.Contains(errStr, "HTTP 401") ||
|
||||
strings.Contains(errStr, "UNAUTHENTICATED") ||
|
||||
strings.Contains(errStr, "invalid_grant"):
|
||||
info.ErrorCode = errorCodeUnauthenticated
|
||||
info.NeedsReauth = true
|
||||
case strings.Contains(errStr, "HTTP 429"):
|
||||
info.ErrorCode = errorCodeRateLimited
|
||||
default:
|
||||
info.ErrorCode = errorCodeNetworkError
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// enrichUsageWithAccountError 结合账号错误状态修正 UsageInfo
|
||||
// 场景 1(成功路径):FetchAvailableModels 正常返回,但账号已因 403 被标记为 error,
|
||||
//
|
||||
// 需要在正常 usage 数据上附加 forbidden/validation 信息。
|
||||
//
|
||||
// 场景 2(降级路径):被封号的账号 OAuth token 失效,FetchAvailableModels 返回 401,
|
||||
//
|
||||
// 降级逻辑设置了 needs_reauth,但账号实际是 403 封号/需验证,需覆盖为正确状态。
|
||||
func enrichUsageWithAccountError(info *UsageInfo, account *Account) {
|
||||
if info == nil || account == nil || account.Status != StatusError {
|
||||
return
|
||||
}
|
||||
msg := strings.ToLower(account.ErrorMessage)
|
||||
if !strings.Contains(msg, "403") && !strings.Contains(msg, "forbidden") &&
|
||||
!strings.Contains(msg, "violation") && !strings.Contains(msg, "validation") {
|
||||
return
|
||||
}
|
||||
fbType := classifyForbiddenType(account.ErrorMessage)
|
||||
info.IsForbidden = true
|
||||
info.ForbiddenType = fbType
|
||||
info.ForbiddenReason = account.ErrorMessage
|
||||
info.NeedsVerify = fbType == forbiddenTypeValidation
|
||||
info.IsBanned = fbType == forbiddenTypeViolation
|
||||
info.ValidationURL = extractValidationURL(account.ErrorMessage)
|
||||
info.ErrorCode = errorCodeForbidden
|
||||
info.NeedsReauth = false
|
||||
}
|
||||
|
||||
// addWindowStats 为 usage 数据添加窗口期统计
|
||||
@@ -815,13 +988,6 @@ func windowStatsFromAccountStats(stats *usagestats.AccountStats) *WindowStats {
|
||||
}
|
||||
}
|
||||
|
||||
func hasMeaningfulWindowStats(stats *WindowStats) bool {
|
||||
if stats == nil {
|
||||
return false
|
||||
}
|
||||
return stats.Requests > 0 || stats.Tokens > 0 || stats.Cost > 0 || stats.StandardCost > 0 || stats.UserCost > 0
|
||||
}
|
||||
|
||||
func buildCodexUsageProgressFromExtra(extra map[string]any, window string, now time.Time) *UsageProgress {
|
||||
if len(extra) == 0 {
|
||||
return nil
|
||||
@@ -878,6 +1044,11 @@ func buildCodexUsageProgressFromExtra(extra map[string]any, window string, now t
|
||||
}
|
||||
}
|
||||
|
||||
// 窗口已过期(resetAt 在 now 之前)→ 额度已重置,归零
|
||||
if progress.ResetsAt != nil && !now.Before(*progress.ResetsAt) {
|
||||
progress.Utilization = 0
|
||||
}
|
||||
|
||||
return progress
|
||||
}
|
||||
|
||||
|
||||
@@ -148,3 +148,54 @@ func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *tes
|
||||
t.Fatal("waiting for codex probe rate limit persistence timed out")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCodexUsageProgressFromExtra_ZerosExpiredWindow(t *testing.T) {
|
||||
t.Parallel()
|
||||
now := time.Date(2026, 3, 16, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
t.Run("expired 5h window zeroes utilization", func(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"codex_5h_used_percent": 42.0,
|
||||
"codex_5h_reset_at": "2026-03-16T10:00:00Z", // 2h ago
|
||||
}
|
||||
progress := buildCodexUsageProgressFromExtra(extra, "5h", now)
|
||||
if progress == nil {
|
||||
t.Fatal("expected non-nil progress")
|
||||
}
|
||||
if progress.Utilization != 0 {
|
||||
t.Fatalf("expected Utilization=0 for expired window, got %v", progress.Utilization)
|
||||
}
|
||||
if progress.RemainingSeconds != 0 {
|
||||
t.Fatalf("expected RemainingSeconds=0, got %v", progress.RemainingSeconds)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("active 5h window keeps utilization", func(t *testing.T) {
|
||||
resetAt := now.Add(2 * time.Hour).Format(time.RFC3339)
|
||||
extra := map[string]any{
|
||||
"codex_5h_used_percent": 42.0,
|
||||
"codex_5h_reset_at": resetAt,
|
||||
}
|
||||
progress := buildCodexUsageProgressFromExtra(extra, "5h", now)
|
||||
if progress == nil {
|
||||
t.Fatal("expected non-nil progress")
|
||||
}
|
||||
if progress.Utilization != 42.0 {
|
||||
t.Fatalf("expected Utilization=42, got %v", progress.Utilization)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("expired 7d window zeroes utilization", func(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"codex_7d_used_percent": 88.0,
|
||||
"codex_7d_reset_at": "2026-03-15T00:00:00Z", // yesterday
|
||||
}
|
||||
progress := buildCodexUsageProgressFromExtra(extra, "7d", now)
|
||||
if progress == nil {
|
||||
t.Fatal("expected non-nil progress")
|
||||
}
|
||||
if progress.Utilization != 0 {
|
||||
t.Fatalf("expected Utilization=0 for expired 7d window, got %v", progress.Utilization)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -43,12 +43,13 @@ func TestMatchWildcard(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchWildcardMapping(t *testing.T) {
|
||||
func TestMatchWildcardMappingResult(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mapping map[string]string
|
||||
requestedModel string
|
||||
expected string
|
||||
matched bool
|
||||
}{
|
||||
// 精确匹配优先于通配符
|
||||
{
|
||||
@@ -59,6 +60,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
expected: "claude-sonnet-4-5-exact",
|
||||
matched: true,
|
||||
},
|
||||
|
||||
// 最长通配符优先
|
||||
@@ -71,6 +73,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
expected: "claude-sonnet-4-series",
|
||||
matched: true,
|
||||
},
|
||||
|
||||
// 单个通配符
|
||||
@@ -81,6 +84,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
},
|
||||
requestedModel: "claude-opus-4-5",
|
||||
expected: "claude-mapped",
|
||||
matched: true,
|
||||
},
|
||||
|
||||
// 无匹配返回原始模型
|
||||
@@ -91,6 +95,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
},
|
||||
requestedModel: "gemini-3-flash",
|
||||
expected: "gemini-3-flash",
|
||||
matched: false,
|
||||
},
|
||||
|
||||
// 空映射返回原始模型
|
||||
@@ -99,6 +104,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
mapping: map[string]string{},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
expected: "claude-sonnet-4-5",
|
||||
matched: false,
|
||||
},
|
||||
|
||||
// Gemini 模型映射
|
||||
@@ -110,14 +116,15 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
},
|
||||
requestedModel: "gemini-3-flash-preview",
|
||||
expected: "gemini-3-pro-high",
|
||||
matched: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := matchWildcardMapping(tt.mapping, tt.requestedModel)
|
||||
if result != tt.expected {
|
||||
t.Errorf("matchWildcardMapping(%v, %q) = %q, want %q", tt.mapping, tt.requestedModel, result, tt.expected)
|
||||
result, matched := matchWildcardMappingResult(tt.mapping, tt.requestedModel)
|
||||
if result != tt.expected || matched != tt.matched {
|
||||
t.Errorf("matchWildcardMappingResult(%v, %q) = (%q, %v), want (%q, %v)", tt.mapping, tt.requestedModel, result, matched, tt.expected, tt.matched)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -268,6 +275,69 @@ func TestAccountGetMappedModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountResolveMappedModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
credentials map[string]any
|
||||
requestedModel string
|
||||
expectedModel string
|
||||
expectedMatch bool
|
||||
}{
|
||||
{
|
||||
name: "no mapping reports unmatched",
|
||||
credentials: nil,
|
||||
requestedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
expectedMatch: false,
|
||||
},
|
||||
{
|
||||
name: "exact passthrough mapping still counts as matched",
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5.4": "gpt-5.4",
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
expectedMatch: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard passthrough mapping still counts as matched",
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-*": "gpt-5.4",
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
expectedMatch: true,
|
||||
},
|
||||
{
|
||||
name: "missing mapping reports unmatched",
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5.2": "gpt-5.2",
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
expectedMatch: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
account := &Account{
|
||||
Credentials: tt.credentials,
|
||||
}
|
||||
mappedModel, matched := account.ResolveMappedModel(tt.requestedModel)
|
||||
if mappedModel != tt.expectedModel || matched != tt.expectedMatch {
|
||||
t.Fatalf("ResolveMappedModel(%q) = (%q, %v), want (%q, %v)", tt.requestedModel, mappedModel, matched, tt.expectedModel, tt.expectedMatch)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountGetModelMapping_AntigravityEnsuresGeminiDefaultPassthroughs(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
|
||||
@@ -368,6 +368,10 @@ type ProxyExitInfoProber interface {
|
||||
ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
|
||||
}
|
||||
|
||||
type groupExistenceBatchReader interface {
|
||||
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
|
||||
}
|
||||
|
||||
type proxyQualityTarget struct {
|
||||
Target string
|
||||
URL string
|
||||
@@ -445,10 +449,6 @@ type userGroupRateBatchReader interface {
|
||||
GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error)
|
||||
}
|
||||
|
||||
type groupExistenceBatchReader interface {
|
||||
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
|
||||
}
|
||||
|
||||
// NewAdminService creates a new AdminService
|
||||
func NewAdminService(
|
||||
userRepo UserRepository,
|
||||
@@ -832,7 +832,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
subscriptionType = SubscriptionTypeStandard
|
||||
}
|
||||
|
||||
// 限额字段:0 和 nil 都表示"无限制"
|
||||
// 限额字段:nil/负数 表示"无限制",0 表示"不允许用量",正数表示具体限额
|
||||
dailyLimit := normalizeLimit(input.DailyLimitUSD)
|
||||
weeklyLimit := normalizeLimit(input.WeeklyLimitUSD)
|
||||
monthlyLimit := normalizeLimit(input.MonthlyLimitUSD)
|
||||
@@ -944,9 +944,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
return group, nil
|
||||
}
|
||||
|
||||
// normalizeLimit 将 0 或负数转换为 nil(表示无限制)
|
||||
// normalizeLimit 将负数转换为 nil(表示无限制),0 保留(表示限额为零)
|
||||
func normalizeLimit(limit *float64) *float64 {
|
||||
if limit == nil || *limit <= 0 {
|
||||
if limit == nil || *limit < 0 {
|
||||
return nil
|
||||
}
|
||||
return limit
|
||||
@@ -1058,16 +1058,11 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
if input.SubscriptionType != "" {
|
||||
group.SubscriptionType = input.SubscriptionType
|
||||
}
|
||||
// 限额字段:0 和 nil 都表示"无限制",正数表示具体限额
|
||||
if input.DailyLimitUSD != nil {
|
||||
group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD)
|
||||
}
|
||||
if input.WeeklyLimitUSD != nil {
|
||||
group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
|
||||
}
|
||||
if input.MonthlyLimitUSD != nil {
|
||||
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
|
||||
}
|
||||
// 限额字段:nil/负数 表示"无限制",0 表示"不允许用量",正数表示具体限额
|
||||
// 前端始终发送这三个字段,无需 nil 守卫
|
||||
group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD)
|
||||
group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
|
||||
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
|
||||
// 图片生成计费配置:负数表示清除(使用默认价格)
|
||||
if input.ImagePrice1K != nil {
|
||||
group.ImagePrice1K = normalizePrice(input.ImagePrice1K)
|
||||
@@ -1462,6 +1457,13 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
// 预计算固定时间重置的下次重置时间
|
||||
if account.Extra != nil {
|
||||
if err := ValidateQuotaResetConfig(account.Extra); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ComputeQuotaResetAt(account.Extra)
|
||||
}
|
||||
if input.ExpiresAt != nil && *input.ExpiresAt > 0 {
|
||||
expiresAt := time.Unix(*input.ExpiresAt, 0)
|
||||
account.ExpiresAt = &expiresAt
|
||||
@@ -1514,6 +1516,7 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
wasOveragesEnabled := account.IsOveragesEnabled()
|
||||
|
||||
if input.Name != "" {
|
||||
account.Name = input.Name
|
||||
@@ -1527,7 +1530,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
if len(input.Credentials) > 0 {
|
||||
account.Credentials = input.Credentials
|
||||
}
|
||||
if len(input.Extra) > 0 {
|
||||
// Extra 使用 map:需要区分“未提供(nil)”与“显式清空({})”。
|
||||
// 关闭配额限制时前端会删除 quota_* 键并提交 extra:{},此时也必须落库。
|
||||
if input.Extra != nil {
|
||||
// 保留配额用量字段,防止编辑账号时意外重置
|
||||
for _, key := range []string{"quota_used", "quota_daily_used", "quota_daily_start", "quota_weekly_used", "quota_weekly_start"} {
|
||||
if v, ok := account.Extra[key]; ok {
|
||||
@@ -1535,6 +1540,22 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
}
|
||||
}
|
||||
account.Extra = input.Extra
|
||||
if account.Platform == PlatformAntigravity && wasOveragesEnabled && !account.IsOveragesEnabled() {
|
||||
delete(account.Extra, "antigravity_credits_overages") // 清理旧版 overages 运行态
|
||||
// 清除 AICredits 限流 key
|
||||
if rawLimits, ok := account.Extra[modelRateLimitsKey].(map[string]any); ok {
|
||||
delete(rawLimits, creditsExhaustedKey)
|
||||
}
|
||||
}
|
||||
if account.Platform == PlatformAntigravity && !wasOveragesEnabled && account.IsOveragesEnabled() {
|
||||
delete(account.Extra, modelRateLimitsKey)
|
||||
delete(account.Extra, "antigravity_credits_overages") // 清理旧版 overages 运行态
|
||||
}
|
||||
// 校验并预计算固定时间重置的下次重置时间
|
||||
if err := ValidateQuotaResetConfig(account.Extra); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ComputeQuotaResetAt(account.Extra)
|
||||
}
|
||||
if input.ProxyID != nil {
|
||||
// 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图)
|
||||
|
||||
@@ -194,7 +194,7 @@ func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, stri
|
||||
func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, error) {
|
||||
func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, int64, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
|
||||
|
||||
@@ -160,7 +160,7 @@ func (s *groupRepoStub) ExistsByName(ctx context.Context, name string) (bool, er
|
||||
panic("unexpected ExistsByName call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||
func (s *groupRepoStub) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
|
||||
panic("unexpected GetAccountCount call")
|
||||
}
|
||||
|
||||
|
||||
@@ -100,7 +100,7 @@ func (s *groupRepoStubForAdmin) ExistsByName(_ context.Context, _ string) (bool,
|
||||
panic("unexpected ExistsByName call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, error) {
|
||||
func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
|
||||
panic("unexpected GetAccountCount call")
|
||||
}
|
||||
|
||||
@@ -383,7 +383,7 @@ func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string
|
||||
panic("unexpected ExistsByName call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, error) {
|
||||
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
|
||||
panic("unexpected GetAccountCount call")
|
||||
}
|
||||
|
||||
@@ -458,7 +458,7 @@ func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context,
|
||||
panic("unexpected ExistsByName call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, error) {
|
||||
func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
|
||||
panic("unexpected GetAccountCount call")
|
||||
}
|
||||
|
||||
|
||||
155
backend/internal/service/admin_service_overages_test.go
Normal file
155
backend/internal/service/admin_service_overages_test.go
Normal file
@@ -0,0 +1,155 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type updateAccountOveragesRepoStub struct {
|
||||
mockAccountRepoForGemini
|
||||
account *Account
|
||||
updateCalls int
|
||||
}
|
||||
|
||||
func (r *updateAccountOveragesRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
return r.account, nil
|
||||
}
|
||||
|
||||
func (r *updateAccountOveragesRepoStub) Update(ctx context.Context, account *Account) error {
|
||||
r.updateCalls++
|
||||
r.account = account
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestUpdateAccount_DisableOveragesClearsAICreditsKey(t *testing.T) {
|
||||
accountID := int64(101)
|
||||
repo := &updateAccountOveragesRepoStub{
|
||||
account: &Account{
|
||||
ID: accountID,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
"mixed_scheduling": true,
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
|
||||
},
|
||||
creditsExhaustedKey: map[string]any{
|
||||
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{
|
||||
Extra: map[string]any{
|
||||
"mixed_scheduling": true,
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
|
||||
},
|
||||
creditsExhaustedKey: map[string]any{
|
||||
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.False(t, updated.IsOveragesEnabled())
|
||||
|
||||
// 关闭 overages 后,AICredits key 应被清除
|
||||
rawLimits, ok := repo.account.Extra[modelRateLimitsKey].(map[string]any)
|
||||
if ok {
|
||||
_, exists := rawLimits[creditsExhaustedKey]
|
||||
require.False(t, exists, "关闭 overages 时应清除 AICredits 限流 key")
|
||||
}
|
||||
// 普通模型限流应保留
|
||||
require.True(t, ok)
|
||||
_, exists := rawLimits["claude-sonnet-4-5"]
|
||||
require.True(t, exists, "普通模型限流应保留")
|
||||
}
|
||||
|
||||
func TestUpdateAccount_EnableOveragesClearsModelRateLimitsBeforePersist(t *testing.T) {
|
||||
accountID := int64(102)
|
||||
repo := &updateAccountOveragesRepoStub{
|
||||
account: &Account{
|
||||
ID: accountID,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Extra: map[string]any{
|
||||
"mixed_scheduling": true,
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{
|
||||
Extra: map[string]any{
|
||||
"mixed_scheduling": true,
|
||||
"allow_overages": true,
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.True(t, updated.IsOveragesEnabled())
|
||||
|
||||
_, exists := repo.account.Extra[modelRateLimitsKey]
|
||||
require.False(t, exists, "开启 overages 时应在持久化前清掉旧模型限流")
|
||||
}
|
||||
|
||||
func TestUpdateAccount_EmptyExtraPayloadCanClearQuotaLimits(t *testing.T) {
|
||||
accountID := int64(103)
|
||||
repo := &updateAccountOveragesRepoStub{
|
||||
account: &Account{
|
||||
ID: accountID,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Extra: map[string]any{
|
||||
"quota_limit": 100.0,
|
||||
"quota_daily_limit": 10.0,
|
||||
"quota_weekly_limit": 40.0,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{
|
||||
// 显式空对象:语义是“清空 extra 中的可配置键”(例如关闭配额限制)
|
||||
Extra: map[string]any{},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.NotNil(t, repo.account.Extra)
|
||||
require.NotContains(t, repo.account.Extra, "quota_limit")
|
||||
require.NotContains(t, repo.account.Extra, "quota_daily_limit")
|
||||
require.NotContains(t, repo.account.Extra, "quota_weekly_limit")
|
||||
require.Len(t, repo.account.Extra, 0)
|
||||
}
|
||||
234
backend/internal/service/antigravity_credits_overages.go
Normal file
234
backend/internal/service/antigravity_credits_overages.go
Normal file
@@ -0,0 +1,234 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
// creditsExhaustedKey 是 model_rate_limits 中标记积分耗尽的特殊 key。
|
||||
// 与普通模型限流完全同构:通过 SetModelRateLimit / isRateLimitActiveForKey 读写。
|
||||
creditsExhaustedKey = "AICredits"
|
||||
creditsExhaustedDuration = 5 * time.Hour
|
||||
)
|
||||
|
||||
type antigravity429Category string
|
||||
|
||||
const (
|
||||
antigravity429Unknown antigravity429Category = "unknown"
|
||||
antigravity429RateLimited antigravity429Category = "rate_limited"
|
||||
antigravity429QuotaExhausted antigravity429Category = "quota_exhausted"
|
||||
)
|
||||
|
||||
var (
|
||||
antigravityQuotaExhaustedKeywords = []string{
|
||||
"quota_exhausted",
|
||||
"quota exhausted",
|
||||
}
|
||||
|
||||
creditsExhaustedKeywords = []string{
|
||||
"google_one_ai",
|
||||
"insufficient credit",
|
||||
"insufficient credits",
|
||||
"not enough credit",
|
||||
"not enough credits",
|
||||
"credit exhausted",
|
||||
"credits exhausted",
|
||||
"credit balance",
|
||||
"minimumcreditamountforusage",
|
||||
"minimum credit amount for usage",
|
||||
"minimum credit",
|
||||
}
|
||||
)
|
||||
|
||||
// isCreditsExhausted 检查账号的 AICredits 限流 key 是否生效(积分是否耗尽)。
|
||||
func (a *Account) isCreditsExhausted() bool {
|
||||
if a == nil {
|
||||
return false
|
||||
}
|
||||
return a.isRateLimitActiveForKey(creditsExhaustedKey)
|
||||
}
|
||||
|
||||
// setCreditsExhausted 标记账号积分耗尽:写入 model_rate_limits["AICredits"] + 更新缓存。
|
||||
func (s *AntigravityGatewayService) setCreditsExhausted(ctx context.Context, account *Account) {
|
||||
if account == nil || account.ID == 0 {
|
||||
return
|
||||
}
|
||||
resetAt := time.Now().Add(creditsExhaustedDuration)
|
||||
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, creditsExhaustedKey, resetAt); err != nil {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "set credits exhausted failed: account=%d err=%v", account.ID, err)
|
||||
return
|
||||
}
|
||||
s.updateAccountModelRateLimitInCache(ctx, account, creditsExhaustedKey, resetAt)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "credits_exhausted_marked account=%d reset_at=%s",
|
||||
account.ID, resetAt.UTC().Format(time.RFC3339))
|
||||
}
|
||||
|
||||
// clearCreditsExhausted 清除账号的 AICredits 限流 key。
|
||||
func (s *AntigravityGatewayService) clearCreditsExhausted(ctx context.Context, account *Account) {
|
||||
if account == nil || account.ID == 0 || account.Extra == nil {
|
||||
return
|
||||
}
|
||||
rawLimits, ok := account.Extra[modelRateLimitsKey].(map[string]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, exists := rawLimits[creditsExhaustedKey]; !exists {
|
||||
return
|
||||
}
|
||||
delete(rawLimits, creditsExhaustedKey)
|
||||
account.Extra[modelRateLimitsKey] = rawLimits
|
||||
if err := s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{
|
||||
modelRateLimitsKey: rawLimits,
|
||||
}); err != nil {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "clear credits exhausted failed: account=%d err=%v", account.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// classifyAntigravity429 将 Antigravity 的 429 响应归类为配额耗尽、限流或未知。
|
||||
func classifyAntigravity429(body []byte) antigravity429Category {
|
||||
if len(body) == 0 {
|
||||
return antigravity429Unknown
|
||||
}
|
||||
lowerBody := strings.ToLower(string(body))
|
||||
for _, keyword := range antigravityQuotaExhaustedKeywords {
|
||||
if strings.Contains(lowerBody, keyword) {
|
||||
return antigravity429QuotaExhausted
|
||||
}
|
||||
}
|
||||
if info := parseAntigravitySmartRetryInfo(body); info != nil && !info.IsModelCapacityExhausted {
|
||||
return antigravity429RateLimited
|
||||
}
|
||||
return antigravity429Unknown
|
||||
}
|
||||
|
||||
// injectEnabledCreditTypes 在已序列化的 v1internal JSON body 中注入 AI Credits 类型。
|
||||
func injectEnabledCreditTypes(body []byte) []byte {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return nil
|
||||
}
|
||||
payload["enabledCreditTypes"] = []string{"GOOGLE_ONE_AI"}
|
||||
result, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// resolveCreditsOveragesModelKey 解析当前请求对应的 overages 状态模型 key。
|
||||
func resolveCreditsOveragesModelKey(ctx context.Context, account *Account, upstreamModelName, requestedModel string) string {
|
||||
modelKey := strings.TrimSpace(upstreamModelName)
|
||||
if modelKey != "" {
|
||||
return modelKey
|
||||
}
|
||||
if account == nil {
|
||||
return ""
|
||||
}
|
||||
modelKey = resolveFinalAntigravityModelKey(ctx, account, requestedModel)
|
||||
if strings.TrimSpace(modelKey) != "" {
|
||||
return modelKey
|
||||
}
|
||||
return resolveAntigravityModelKey(requestedModel)
|
||||
}
|
||||
|
||||
// shouldMarkCreditsExhausted 判断一次 credits 请求失败是否应标记为 credits 耗尽。
|
||||
func shouldMarkCreditsExhausted(resp *http.Response, respBody []byte, reqErr error) bool {
|
||||
if reqErr != nil || resp == nil {
|
||||
return false
|
||||
}
|
||||
if resp.StatusCode >= 500 || resp.StatusCode == http.StatusRequestTimeout {
|
||||
return false
|
||||
}
|
||||
if isURLLevelRateLimit(respBody) {
|
||||
return false
|
||||
}
|
||||
if info := parseAntigravitySmartRetryInfo(respBody); info != nil {
|
||||
return false
|
||||
}
|
||||
bodyLower := strings.ToLower(string(respBody))
|
||||
for _, keyword := range creditsExhaustedKeywords {
|
||||
if strings.Contains(bodyLower, keyword) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type creditsOveragesRetryResult struct {
|
||||
handled bool
|
||||
resp *http.Response
|
||||
}
|
||||
|
||||
// attemptCreditsOveragesRetry 在确认免费配额耗尽后,尝试注入 AI Credits 继续请求。
|
||||
func (s *AntigravityGatewayService) attemptCreditsOveragesRetry(
|
||||
p antigravityRetryLoopParams,
|
||||
baseURL string,
|
||||
modelName string,
|
||||
waitDuration time.Duration,
|
||||
originalStatusCode int,
|
||||
respBody []byte,
|
||||
) *creditsOveragesRetryResult {
|
||||
creditsBody := injectEnabledCreditTypes(p.body)
|
||||
if creditsBody == nil {
|
||||
return &creditsOveragesRetryResult{handled: false}
|
||||
}
|
||||
modelKey := resolveCreditsOveragesModelKey(p.ctx, p.account, modelName, p.requestedModel)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 credit_overages_retry model=%s account=%d (injecting enabledCreditTypes)",
|
||||
p.prefix, modelKey, p.account.ID)
|
||||
|
||||
creditsReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, creditsBody)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s credit_overages_failed model=%s account=%d build_request_err=%v",
|
||||
p.prefix, modelKey, p.account.ID, err)
|
||||
return &creditsOveragesRetryResult{handled: true}
|
||||
}
|
||||
|
||||
creditsResp, err := p.httpUpstream.Do(creditsReq, p.proxyURL, p.account.ID, p.account.Concurrency)
|
||||
if err == nil && creditsResp != nil && creditsResp.StatusCode < 400 {
|
||||
s.clearCreditsExhausted(p.ctx, p.account)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d credit_overages_success model=%s account=%d",
|
||||
p.prefix, creditsResp.StatusCode, modelKey, p.account.ID)
|
||||
return &creditsOveragesRetryResult{handled: true, resp: creditsResp}
|
||||
}
|
||||
|
||||
s.handleCreditsRetryFailure(p.ctx, p.prefix, modelKey, p.account, creditsResp, err)
|
||||
return &creditsOveragesRetryResult{handled: true}
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) handleCreditsRetryFailure(
|
||||
ctx context.Context,
|
||||
prefix string,
|
||||
modelKey string,
|
||||
account *Account,
|
||||
creditsResp *http.Response,
|
||||
reqErr error,
|
||||
) {
|
||||
var creditsRespBody []byte
|
||||
creditsStatusCode := 0
|
||||
if creditsResp != nil {
|
||||
creditsStatusCode = creditsResp.StatusCode
|
||||
if creditsResp.Body != nil {
|
||||
creditsRespBody, _ = io.ReadAll(io.LimitReader(creditsResp.Body, 64<<10))
|
||||
_ = creditsResp.Body.Close()
|
||||
}
|
||||
}
|
||||
|
||||
if shouldMarkCreditsExhausted(creditsResp, creditsRespBody, reqErr) && account != nil {
|
||||
s.setCreditsExhausted(ctx, account)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s credit_overages_failed model=%s account=%d marked_exhausted=true status=%d body=%s",
|
||||
prefix, modelKey, account.ID, creditsStatusCode, truncateForLog(creditsRespBody, 200))
|
||||
return
|
||||
}
|
||||
if account != nil {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s credit_overages_failed model=%s account=%d marked_exhausted=false status=%d err=%v body=%s",
|
||||
prefix, modelKey, account.ID, creditsStatusCode, reqErr, truncateForLog(creditsRespBody, 200))
|
||||
}
|
||||
}
|
||||
538
backend/internal/service/antigravity_credits_overages_test.go
Normal file
538
backend/internal/service/antigravity_credits_overages_test.go
Normal file
@@ -0,0 +1,538 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClassifyAntigravity429(t *testing.T) {
|
||||
t.Run("明确配额耗尽", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`)
|
||||
require.Equal(t, antigravity429QuotaExhausted, classifyAntigravity429(body))
|
||||
})
|
||||
|
||||
t.Run("结构化限流", func(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"error": {
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
require.Equal(t, antigravity429RateLimited, classifyAntigravity429(body))
|
||||
})
|
||||
|
||||
t.Run("未知429", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"message":"too many requests"}}`)
|
||||
require.Equal(t, antigravity429Unknown, classifyAntigravity429(body))
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsCreditsExhausted_UsesAICreditsKey(t *testing.T) {
|
||||
t.Run("无 AICredits key 则积分可用", func(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformAntigravity,
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
},
|
||||
}
|
||||
require.False(t, account.isCreditsExhausted())
|
||||
})
|
||||
|
||||
t.Run("AICredits key 生效则积分耗尽", func(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 2,
|
||||
Platform: PlatformAntigravity,
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
modelRateLimitsKey: map[string]any{
|
||||
creditsExhaustedKey: map[string]any{
|
||||
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
require.True(t, account.isCreditsExhausted())
|
||||
})
|
||||
|
||||
t.Run("AICredits key 过期则积分可用", func(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 3,
|
||||
Platform: PlatformAntigravity,
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
modelRateLimitsKey: map[string]any{
|
||||
creditsExhaustedKey: map[string]any{
|
||||
"rate_limited_at": time.Now().Add(-6 * time.Hour).UTC().Format(time.RFC3339),
|
||||
"rate_limit_reset_at": time.Now().Add(-1 * time.Hour).UTC().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
require.False(t, account.isCreditsExhausted())
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleSmartRetry_QuotaExhausted_UsesCreditsAndStoresIndependentState(t *testing.T) {
|
||||
successResp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"ok":true}`)),
|
||||
}
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{successResp},
|
||||
errors: []error{nil},
|
||||
}
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
account := &Account{
|
||||
ID: 101,
|
||||
Name: "acc-101",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
},
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"claude-opus-4-6": "claude-sonnet-4-5",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
respBody := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"model":"claude-opus-4-6","request":{}}`),
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
requestedModel: "claude-opus-4-6",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"})
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
require.NotNil(t, result.resp)
|
||||
require.Nil(t, result.switchError)
|
||||
require.Len(t, upstream.requestBodies, 1)
|
||||
require.Contains(t, string(upstream.requestBodies[0]), "enabledCreditTypes")
|
||||
require.Empty(t, repo.modelRateLimitCalls, "overages 成功后不应写入普通 model_rate_limits")
|
||||
}
|
||||
|
||||
func TestHandleSmartRetry_RateLimited_DoesNotUseCredits(t *testing.T) {
|
||||
successResp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"ok":true}`)),
|
||||
}
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{successResp},
|
||||
errors: []error{nil},
|
||||
}
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
account := &Account{
|
||||
ID: 102,
|
||||
Name: "acc-102",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
},
|
||||
}
|
||||
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`),
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"})
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
require.NotNil(t, result.resp)
|
||||
require.Len(t, upstream.requestBodies, 1)
|
||||
require.NotContains(t, string(upstream.requestBodies[0]), "enabledCreditTypes")
|
||||
require.Empty(t, repo.extraUpdateCalls)
|
||||
require.Empty(t, repo.modelRateLimitCalls)
|
||||
}
|
||||
|
||||
func TestAntigravityRetryLoop_ModelRateLimited_InjectsCredits(t *testing.T) {
|
||||
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
|
||||
oldAvailability := antigravity.DefaultURLAvailability
|
||||
defer func() {
|
||||
antigravity.BaseURLs = oldBaseURLs
|
||||
antigravity.DefaultURLAvailability = oldAvailability
|
||||
}()
|
||||
|
||||
antigravity.BaseURLs = []string{"https://ag-1.test"}
|
||||
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
|
||||
|
||||
upstream := &queuedHTTPUpstreamStub{
|
||||
responses: []*http.Response{
|
||||
{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"ok":true}`)),
|
||||
},
|
||||
},
|
||||
errors: []error{nil},
|
||||
}
|
||||
// 模型已限流 + overages 启用 + 无 AICredits key → 应直接注入积分
|
||||
account := &Account{
|
||||
ID: 103,
|
||||
Name: "acc-103",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"rate_limit_reset_at": time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`),
|
||||
httpUpstream: upstream,
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Len(t, upstream.requestBodies, 1)
|
||||
require.Contains(t, string(upstream.requestBodies[0]), "enabledCreditTypes")
|
||||
}
|
||||
|
||||
func TestAntigravityRetryLoop_CreditsExhausted_DoesNotInject(t *testing.T) {
|
||||
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
|
||||
oldAvailability := antigravity.DefaultURLAvailability
|
||||
defer func() {
|
||||
antigravity.BaseURLs = oldBaseURLs
|
||||
antigravity.DefaultURLAvailability = oldAvailability
|
||||
}()
|
||||
|
||||
antigravity.BaseURLs = []string{"https://ag-1.test"}
|
||||
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
|
||||
|
||||
// 模型限流 + overages 启用 + AICredits key 生效 → 不应注入积分,应切号
|
||||
account := &Account{
|
||||
ID: 104,
|
||||
Name: "acc-104",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"rate_limit_reset_at": time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339),
|
||||
},
|
||||
creditsExhaustedKey: map[string]any{
|
||||
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
_, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`),
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
// 模型限流 + 积分耗尽 → 应触发切号错误
|
||||
require.Error(t, err)
|
||||
var switchErr *AntigravityAccountSwitchError
|
||||
require.ErrorAs(t, err, &switchErr)
|
||||
}
|
||||
|
||||
func TestAntigravityRetryLoop_CreditErrorMarksExhausted(t *testing.T) {
|
||||
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
|
||||
oldAvailability := antigravity.DefaultURLAvailability
|
||||
defer func() {
|
||||
antigravity.BaseURLs = oldBaseURLs
|
||||
antigravity.DefaultURLAvailability = oldAvailability
|
||||
}()
|
||||
|
||||
antigravity.BaseURLs = []string{"https://ag-1.test"}
|
||||
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
|
||||
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
upstream := &queuedHTTPUpstreamStub{
|
||||
responses: []*http.Response{
|
||||
{
|
||||
StatusCode: http.StatusForbidden,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"error":{"message":"Insufficient GOOGLE_ONE_AI credits"}}`)),
|
||||
},
|
||||
},
|
||||
errors: []error{nil},
|
||||
}
|
||||
// 模型限流 + overages 启用 + 积分可用 → 注入积分但上游返回积分不足
|
||||
account := &Account{
|
||||
ID: 105,
|
||||
Name: "acc-105",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"rate_limit_reset_at": time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`),
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
// 验证 AICredits key 已通过 SetModelRateLimit 写入数据库
|
||||
require.Len(t, repo.modelRateLimitCalls, 1, "应通过 SetModelRateLimit 写入 AICredits key")
|
||||
require.Equal(t, creditsExhaustedKey, repo.modelRateLimitCalls[0].modelKey)
|
||||
}
|
||||
|
||||
func TestShouldMarkCreditsExhausted(t *testing.T) {
|
||||
t.Run("reqErr 不为 nil 时不标记", func(t *testing.T) {
|
||||
resp := &http.Response{StatusCode: http.StatusForbidden}
|
||||
require.False(t, shouldMarkCreditsExhausted(resp, []byte(`{"error":"Insufficient credits"}`), io.ErrUnexpectedEOF))
|
||||
})
|
||||
|
||||
t.Run("resp 为 nil 时不标记", func(t *testing.T) {
|
||||
require.False(t, shouldMarkCreditsExhausted(nil, []byte(`{"error":"Insufficient credits"}`), nil))
|
||||
})
|
||||
|
||||
t.Run("5xx 响应不标记", func(t *testing.T) {
|
||||
resp := &http.Response{StatusCode: http.StatusInternalServerError}
|
||||
require.False(t, shouldMarkCreditsExhausted(resp, []byte(`{"error":"Insufficient credits"}`), nil))
|
||||
})
|
||||
|
||||
t.Run("408 RequestTimeout 不标记", func(t *testing.T) {
|
||||
resp := &http.Response{StatusCode: http.StatusRequestTimeout}
|
||||
require.False(t, shouldMarkCreditsExhausted(resp, []byte(`{"error":"Insufficient credits"}`), nil))
|
||||
})
|
||||
|
||||
t.Run("URL 级限流不标记", func(t *testing.T) {
|
||||
resp := &http.Response{StatusCode: http.StatusTooManyRequests}
|
||||
body := []byte(`{"error":{"message":"Resource has been exhausted"}}`)
|
||||
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
|
||||
})
|
||||
|
||||
t.Run("结构化限流不标记", func(t *testing.T) {
|
||||
resp := &http.Response{StatusCode: http.StatusTooManyRequests}
|
||||
body := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"RATE_LIMIT_EXCEEDED"},{"@type":"type.googleapis.com/google.rpc.RetryInfo","retryDelay":"0.5s"}]}}`)
|
||||
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
|
||||
})
|
||||
|
||||
t.Run("含 credits 关键词时标记", func(t *testing.T) {
|
||||
resp := &http.Response{StatusCode: http.StatusForbidden}
|
||||
for _, keyword := range []string{
|
||||
"Insufficient GOOGLE_ONE_AI credits",
|
||||
"insufficient credit balance",
|
||||
"not enough credits for this request",
|
||||
"Credits exhausted",
|
||||
"minimumCreditAmountForUsage requirement not met",
|
||||
} {
|
||||
body := []byte(`{"error":{"message":"` + keyword + `"}}`)
|
||||
require.True(t, shouldMarkCreditsExhausted(resp, body, nil), "should mark for keyword: %s", keyword)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("无 credits 关键词时不标记", func(t *testing.T) {
|
||||
resp := &http.Response{StatusCode: http.StatusForbidden}
|
||||
body := []byte(`{"error":{"message":"permission denied"}}`)
|
||||
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
|
||||
})
|
||||
}
|
||||
|
||||
func TestInjectEnabledCreditTypes(t *testing.T) {
|
||||
t.Run("正常 JSON 注入成功", func(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-sonnet-4-5","request":{}}`)
|
||||
result := injectEnabledCreditTypes(body)
|
||||
require.NotNil(t, result)
|
||||
require.Contains(t, string(result), `"enabledCreditTypes"`)
|
||||
require.Contains(t, string(result), `GOOGLE_ONE_AI`)
|
||||
})
|
||||
|
||||
t.Run("非法 JSON 返回 nil", func(t *testing.T) {
|
||||
require.Nil(t, injectEnabledCreditTypes([]byte(`not json`)))
|
||||
})
|
||||
|
||||
t.Run("空 body 返回 nil", func(t *testing.T) {
|
||||
require.Nil(t, injectEnabledCreditTypes([]byte{}))
|
||||
})
|
||||
|
||||
t.Run("已有 enabledCreditTypes 会被覆盖", func(t *testing.T) {
|
||||
body := []byte(`{"enabledCreditTypes":["OLD"],"model":"test"}`)
|
||||
result := injectEnabledCreditTypes(body)
|
||||
require.NotNil(t, result)
|
||||
require.Contains(t, string(result), `GOOGLE_ONE_AI`)
|
||||
require.NotContains(t, string(result), `OLD`)
|
||||
})
|
||||
}
|
||||
|
||||
func TestClearCreditsExhausted(t *testing.T) {
|
||||
t.Run("account 为 nil 不操作", func(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
svc.clearCreditsExhausted(context.Background(), nil)
|
||||
require.Empty(t, repo.extraUpdateCalls)
|
||||
})
|
||||
|
||||
t.Run("Extra 为 nil 不操作", func(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
svc.clearCreditsExhausted(context.Background(), &Account{ID: 1})
|
||||
require.Empty(t, repo.extraUpdateCalls)
|
||||
})
|
||||
|
||||
t.Run("无 modelRateLimitsKey 不操作", func(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
svc.clearCreditsExhausted(context.Background(), &Account{
|
||||
ID: 1,
|
||||
Extra: map[string]any{"some_key": "value"},
|
||||
})
|
||||
require.Empty(t, repo.extraUpdateCalls)
|
||||
})
|
||||
|
||||
t.Run("无 AICredits key 不操作", func(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
svc.clearCreditsExhausted(context.Background(), &Account{
|
||||
ID: 1,
|
||||
Extra: map[string]any{
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.Empty(t, repo.extraUpdateCalls)
|
||||
})
|
||||
|
||||
t.Run("有 AICredits key 时删除并调用 UpdateExtra", func(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Extra: map[string]any{
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
|
||||
},
|
||||
creditsExhaustedKey: map[string]any{
|
||||
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
svc.clearCreditsExhausted(context.Background(), account)
|
||||
require.Len(t, repo.extraUpdateCalls, 1)
|
||||
// AICredits key 应被删除
|
||||
rawLimits := account.Extra[modelRateLimitsKey].(map[string]any)
|
||||
_, exists := rawLimits[creditsExhaustedKey]
|
||||
require.False(t, exists, "AICredits key 应被删除")
|
||||
// 普通模型限流应保留
|
||||
_, exists = rawLimits["claude-sonnet-4-5"]
|
||||
require.True(t, exists, "普通模型限流应保留")
|
||||
})
|
||||
}
|
||||
@@ -188,9 +188,29 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
|
||||
return &smartRetryResult{action: smartRetryActionContinueURL}
|
||||
}
|
||||
|
||||
category := antigravity429Unknown
|
||||
if resp.StatusCode == http.StatusTooManyRequests {
|
||||
category = classifyAntigravity429(respBody)
|
||||
}
|
||||
|
||||
// 判断是否触发智能重试
|
||||
shouldSmartRetry, shouldRateLimitModel, waitDuration, modelName, isModelCapacityExhausted := shouldTriggerAntigravitySmartRetry(p.account, respBody)
|
||||
|
||||
// AI Credits 超量请求:
|
||||
// 仅在上游明确返回免费配额耗尽时才允许切换到 credits。
|
||||
if resp.StatusCode == http.StatusTooManyRequests &&
|
||||
category == antigravity429QuotaExhausted &&
|
||||
p.account.IsOveragesEnabled() &&
|
||||
!p.account.isCreditsExhausted() {
|
||||
result := s.attemptCreditsOveragesRetry(p, baseURL, modelName, waitDuration, resp.StatusCode, respBody)
|
||||
if result.handled && result.resp != nil {
|
||||
return &smartRetryResult{
|
||||
action: smartRetryActionBreakWithResp,
|
||||
resp: result.resp,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 情况1: retryDelay >= 阈值,限流模型并切换账号
|
||||
if shouldRateLimitModel {
|
||||
// 单账号 503 退避重试模式:不设限流、不切换账号,改为原地等待+重试
|
||||
@@ -532,14 +552,31 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace(
|
||||
|
||||
// antigravityRetryLoop 执行带 URL fallback 的重试循环
|
||||
func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
|
||||
// 预检查:模型限流 + overages 启用 + 积分未耗尽 → 直接注入 AI Credits
|
||||
overagesInjected := false
|
||||
if p.requestedModel != "" && p.account.Platform == PlatformAntigravity &&
|
||||
p.account.IsOveragesEnabled() && !p.account.isCreditsExhausted() &&
|
||||
p.account.isModelRateLimitedWithContext(p.ctx, p.requestedModel) {
|
||||
if creditsBody := injectEnabledCreditTypes(p.body); creditsBody != nil {
|
||||
p.body = creditsBody
|
||||
overagesInjected = true
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: model_rate_limited_credits_inject model=%s account=%d (injecting enabledCreditTypes)",
|
||||
p.prefix, p.requestedModel, p.account.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// 预检查:如果账号已限流,直接返回切换信号
|
||||
if p.requestedModel != "" {
|
||||
if remaining := p.account.GetRateLimitRemainingTimeWithContext(p.ctx, p.requestedModel); remaining > 0 {
|
||||
// 单账号 503 退避重试模式:跳过限流预检查,直接发请求。
|
||||
// 首次请求设的限流是为了多账号调度器跳过该账号,在单账号模式下无意义。
|
||||
// 如果上游确实还不可用,handleSmartRetry → handleSingleAccountRetryInPlace
|
||||
// 会在 Service 层原地等待+重试,不需要在预检查这里等。
|
||||
if isSingleAccountRetry(p.ctx) {
|
||||
// 已注入积分的请求不再受普通模型限流预检查阻断。
|
||||
if overagesInjected {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: credits_injected_ignore_rate_limit remaining=%v model=%s account=%d",
|
||||
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
|
||||
} else if isSingleAccountRetry(p.ctx) {
|
||||
// 单账号 503 退避重试模式:跳过限流预检查,直接发请求。
|
||||
// 首次请求设的限流是为了多账号调度器跳过该账号,在单账号模式下无意义。
|
||||
// 如果上游确实还不可用,handleSmartRetry → handleSingleAccountRetryInPlace
|
||||
// 会在 Service 层原地等待+重试,不需要在预检查这里等。
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: single_account_retry skipping rate_limit remaining=%v model=%s account=%d (will retry in-place if 503)",
|
||||
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
|
||||
} else {
|
||||
@@ -631,6 +668,15 @@ urlFallbackLoop:
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if overagesInjected && shouldMarkCreditsExhausted(resp, respBody, nil) {
|
||||
modelKey := resolveCreditsOveragesModelKey(p.ctx, p.account, "", p.requestedModel)
|
||||
s.handleCreditsRetryFailure(p.ctx, p.prefix, modelKey, p.account, &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: resp.Header.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}, nil)
|
||||
}
|
||||
|
||||
// ★ 统一入口:自定义错误码 + 临时不可调度
|
||||
if handled, outStatus, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled {
|
||||
if policyErr != nil {
|
||||
@@ -884,7 +930,7 @@ func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParam
|
||||
case ErrorPolicyTempUnscheduled:
|
||||
slog.Info("temp_unschedulable_matched",
|
||||
"prefix", p.prefix, "status_code", statusCode, "account_id", p.account.ID)
|
||||
return true, statusCode, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession}
|
||||
return true, statusCode, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, RateLimitedModel: p.requestedModel, IsStickySession: p.isStickySession}
|
||||
}
|
||||
return false, statusCode, nil
|
||||
}
|
||||
@@ -955,8 +1001,9 @@ type TestConnectionResult struct {
|
||||
MappedModel string // 实际使用的模型
|
||||
}
|
||||
|
||||
// TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费)
|
||||
// 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择
|
||||
// TestConnection 测试 Antigravity 账号连接。
|
||||
// 复用 antigravityRetryLoop 的完整重试 / credits overages / 智能重试逻辑,
|
||||
// 与真实调度行为一致。差异:不做账号切换(测试指定账号)、不记录 ops 错误。
|
||||
func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) {
|
||||
|
||||
// 获取 token
|
||||
@@ -980,10 +1027,8 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
|
||||
// 构建请求体
|
||||
var requestBody []byte
|
||||
if strings.HasPrefix(modelID, "gemini-") {
|
||||
// Gemini 模型:直接使用 Gemini 格式
|
||||
requestBody, err = s.buildGeminiTestRequest(projectID, mappedModel)
|
||||
} else {
|
||||
// Claude 模型:使用协议转换
|
||||
requestBody, err = s.buildClaudeTestRequest(projectID, mappedModel)
|
||||
}
|
||||
if err != nil {
|
||||
@@ -996,64 +1041,63 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
baseURL := resolveAntigravityForwardBaseURL()
|
||||
if baseURL == "" {
|
||||
return nil, errors.New("no antigravity forward base url configured")
|
||||
}
|
||||
availableURLs := []string{baseURL}
|
||||
|
||||
var lastErr error
|
||||
for urlIdx, baseURL := range availableURLs {
|
||||
// 构建 HTTP 请求(总是使用流式 endpoint,与官方客户端一致)
|
||||
req, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, "streamGenerateContent", accessToken, requestBody)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
|
||||
// 调试日志:Test 请求信息
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String())
|
||||
|
||||
// 发送请求
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("请求失败: %w", err)
|
||||
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
|
||||
continue
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
// 读取响应
|
||||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查是否需要 URL 降级
|
||||
if shouldAntigravityFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
// 解析流式响应,提取文本
|
||||
text := extractTextFromSSEResponse(respBody)
|
||||
|
||||
// 标记成功的 URL,下次优先使用
|
||||
antigravity.DefaultURLAvailability.MarkSuccess(baseURL)
|
||||
return &TestConnectionResult{
|
||||
Text: text,
|
||||
MappedModel: mappedModel,
|
||||
}, nil
|
||||
// 复用 antigravityRetryLoop:完整的重试 / credits overages / 智能重试
|
||||
prefix := fmt.Sprintf("[antigravity-Test] account=%d(%s)", account.ID, account.Name)
|
||||
p := antigravityRetryLoopParams{
|
||||
ctx: ctx,
|
||||
prefix: prefix,
|
||||
account: account,
|
||||
proxyURL: proxyURL,
|
||||
accessToken: accessToken,
|
||||
action: "streamGenerateContent",
|
||||
body: requestBody,
|
||||
c: nil, // 无 gin.Context → 跳过 ops 追踪
|
||||
httpUpstream: s.httpUpstream,
|
||||
settingService: s.settingService,
|
||||
accountRepo: s.accountRepo,
|
||||
requestedModel: modelID,
|
||||
handleError: testConnectionHandleError,
|
||||
}
|
||||
|
||||
return nil, lastErr
|
||||
result, err := s.antigravityRetryLoop(p)
|
||||
if err != nil {
|
||||
// AccountSwitchError → 测试时不切换账号,返回友好提示
|
||||
var switchErr *AntigravityAccountSwitchError
|
||||
if errors.As(err, &switchErr) {
|
||||
return nil, fmt.Errorf("该账号模型 %s 当前限流中,请稍后重试", switchErr.RateLimitedModel)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result == nil || result.resp == nil {
|
||||
return nil, errors.New("upstream returned empty response")
|
||||
}
|
||||
defer func() { _ = result.resp.Body.Close() }()
|
||||
|
||||
respBody, err := io.ReadAll(io.LimitReader(result.resp.Body, 2<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
if result.resp.StatusCode >= 400 {
|
||||
return nil, fmt.Errorf("API 返回 %d: %s", result.resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
text := extractTextFromSSEResponse(respBody)
|
||||
return &TestConnectionResult{Text: text, MappedModel: mappedModel}, nil
|
||||
}
|
||||
|
||||
// testConnectionHandleError 是 TestConnection 使用的轻量 handleError 回调。
|
||||
// 仅记录日志,不做 ops 错误追踪或粘性会话清除。
|
||||
func testConnectionHandleError(
|
||||
_ context.Context, prefix string, account *Account,
|
||||
statusCode int, _ http.Header, body []byte,
|
||||
requestedModel string, _ int64, _ string, _ bool,
|
||||
) *handleModelRateLimitResult {
|
||||
logger.LegacyPrintf("service.antigravity_gateway",
|
||||
"%s test_handle_error status=%d model=%s account=%d body=%s",
|
||||
prefix, statusCode, requestedModel, account.ID, truncateForLog(body, 200))
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildGeminiTestRequest 构建 Gemini 格式测试请求
|
||||
@@ -3033,6 +3077,22 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
||||
intervalCh = intervalTicker.C
|
||||
}
|
||||
|
||||
// 下游 keepalive:防止代理/Cloudflare Tunnel 因连接空闲而断开
|
||||
keepaliveInterval := time.Duration(0)
|
||||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamKeepaliveInterval > 0 {
|
||||
keepaliveInterval = time.Duration(s.settingService.cfg.Gateway.StreamKeepaliveInterval) * time.Second
|
||||
}
|
||||
var keepaliveTicker *time.Ticker
|
||||
if keepaliveInterval > 0 {
|
||||
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||
defer keepaliveTicker.Stop()
|
||||
}
|
||||
var keepaliveCh <-chan time.Time
|
||||
if keepaliveTicker != nil {
|
||||
keepaliveCh = keepaliveTicker.C
|
||||
}
|
||||
lastDataAt := time.Now()
|
||||
|
||||
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity gemini")
|
||||
|
||||
// 仅发送一次错误事件,避免多次写入导致协议混乱
|
||||
@@ -3065,6 +3125,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
||||
return nil, ev.err
|
||||
}
|
||||
|
||||
lastDataAt = time.Now()
|
||||
|
||||
line := ev.line
|
||||
trimmed := strings.TrimRight(line, "\r\n")
|
||||
if strings.HasPrefix(trimmed, "data:") {
|
||||
@@ -3124,6 +3186,19 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)")
|
||||
sendErrorEvent("stream_timeout")
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
|
||||
case <-keepaliveCh:
|
||||
if cw.Disconnected() {
|
||||
continue
|
||||
}
|
||||
if time.Since(lastDataAt) < keepaliveInterval {
|
||||
continue
|
||||
}
|
||||
// SSE ping/keepalive:保持连接活跃防止 Cloudflare Tunnel 等代理断开
|
||||
if !cw.Fprintf(":\n\n") {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Client disconnected during keepalive ping (antigravity gemini), continuing to drain upstream for billing")
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3849,6 +3924,22 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
intervalCh = intervalTicker.C
|
||||
}
|
||||
|
||||
// 下游 keepalive:防止代理/Cloudflare Tunnel 因连接空闲而断开
|
||||
keepaliveInterval := time.Duration(0)
|
||||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamKeepaliveInterval > 0 {
|
||||
keepaliveInterval = time.Duration(s.settingService.cfg.Gateway.StreamKeepaliveInterval) * time.Second
|
||||
}
|
||||
var keepaliveTicker *time.Ticker
|
||||
if keepaliveInterval > 0 {
|
||||
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||
defer keepaliveTicker.Stop()
|
||||
}
|
||||
var keepaliveCh <-chan time.Time
|
||||
if keepaliveTicker != nil {
|
||||
keepaliveCh = keepaliveTicker.C
|
||||
}
|
||||
lastDataAt := time.Now()
|
||||
|
||||
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity claude")
|
||||
|
||||
// 仅发送一次错误事件,避免多次写入导致协议混乱
|
||||
@@ -3901,6 +3992,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
return nil, fmt.Errorf("stream read error: %w", ev.err)
|
||||
}
|
||||
|
||||
lastDataAt = time.Now()
|
||||
|
||||
// 处理 SSE 行,转换为 Claude 格式
|
||||
claudeEvents := processor.ProcessLine(strings.TrimRight(ev.line, "\r\n"))
|
||||
if len(claudeEvents) > 0 {
|
||||
@@ -3923,6 +4016,20 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)")
|
||||
sendErrorEvent("stream_timeout")
|
||||
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
|
||||
case <-keepaliveCh:
|
||||
if cw.Disconnected() {
|
||||
continue
|
||||
}
|
||||
if time.Since(lastDataAt) < keepaliveInterval {
|
||||
continue
|
||||
}
|
||||
// SSE ping 事件:Anthropic 原生格式,客户端会正确处理,
|
||||
// 同时保持连接活跃防止 Cloudflare Tunnel 等代理断开
|
||||
if !cw.Fprintf("event: ping\ndata: {\"type\": \"ping\"}\n\n") {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Client disconnected during keepalive ping (antigravity claude), continuing to drain upstream for billing")
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4253,6 +4360,22 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp
|
||||
intervalCh = intervalTicker.C
|
||||
}
|
||||
|
||||
// 下游 keepalive:防止代理/Cloudflare Tunnel 因连接空闲而断开
|
||||
keepaliveInterval := time.Duration(0)
|
||||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamKeepaliveInterval > 0 {
|
||||
keepaliveInterval = time.Duration(s.settingService.cfg.Gateway.StreamKeepaliveInterval) * time.Second
|
||||
}
|
||||
var keepaliveTicker *time.Ticker
|
||||
if keepaliveInterval > 0 {
|
||||
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||
defer keepaliveTicker.Stop()
|
||||
}
|
||||
var keepaliveCh <-chan time.Time
|
||||
if keepaliveTicker != nil {
|
||||
keepaliveCh = keepaliveTicker.C
|
||||
}
|
||||
lastDataAt := time.Now()
|
||||
|
||||
flusher, _ := c.Writer.(http.Flusher)
|
||||
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity upstream")
|
||||
|
||||
@@ -4270,6 +4393,8 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
|
||||
}
|
||||
|
||||
lastDataAt = time.Now()
|
||||
|
||||
line := ev.line
|
||||
|
||||
// 记录首 token 时间
|
||||
@@ -4295,6 +4420,20 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp
|
||||
}
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity upstream)")
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
|
||||
|
||||
case <-keepaliveCh:
|
||||
if cw.Disconnected() {
|
||||
continue
|
||||
}
|
||||
if time.Since(lastDataAt) < keepaliveInterval {
|
||||
continue
|
||||
}
|
||||
// SSE ping 事件:Anthropic 原生格式,客户端会正确处理,
|
||||
// 同时保持连接活跃防止 Cloudflare Tunnel 等代理断开
|
||||
if !cw.Fprintf("event: ping\ndata: {\"type\": \"ping\"}\n\n") {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Client disconnected during keepalive ping (antigravity upstream), continuing to drain upstream for billing")
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,12 +2,29 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
const (
|
||||
forbiddenTypeValidation = "validation"
|
||||
forbiddenTypeViolation = "violation"
|
||||
forbiddenTypeForbidden = "forbidden"
|
||||
|
||||
// 机器可读的错误码
|
||||
errorCodeForbidden = "forbidden"
|
||||
errorCodeUnauthenticated = "unauthenticated"
|
||||
errorCodeRateLimited = "rate_limited"
|
||||
errorCodeNetworkError = "network_error"
|
||||
)
|
||||
|
||||
// AntigravityQuotaFetcher 从 Antigravity API 获取额度
|
||||
type AntigravityQuotaFetcher struct {
|
||||
proxyRepo ProxyRepository
|
||||
@@ -40,11 +57,32 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
|
||||
// 调用 API 获取配额
|
||||
modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID)
|
||||
if err != nil {
|
||||
// 403 Forbidden: 不报错,返回 is_forbidden 标记
|
||||
var forbiddenErr *antigravity.ForbiddenError
|
||||
if errors.As(err, &forbiddenErr) {
|
||||
now := time.Now()
|
||||
fbType := classifyForbiddenType(forbiddenErr.Body)
|
||||
return &QuotaResult{
|
||||
UsageInfo: &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
IsForbidden: true,
|
||||
ForbiddenReason: forbiddenErr.Body,
|
||||
ForbiddenType: fbType,
|
||||
ValidationURL: extractValidationURL(forbiddenErr.Body),
|
||||
NeedsVerify: fbType == forbiddenTypeValidation,
|
||||
IsBanned: fbType == forbiddenTypeViolation,
|
||||
ErrorCode: errorCodeForbidden,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 调用 LoadCodeAssist 获取订阅等级和 AI Credits 余额(非关键路径,失败不影响主流程)
|
||||
tierRaw, tierNormalized, loadResp := f.fetchSubscriptionTier(ctx, client, accessToken)
|
||||
|
||||
// 转换为 UsageInfo
|
||||
usageInfo := f.buildUsageInfo(modelsResp)
|
||||
usageInfo := f.buildUsageInfo(modelsResp, tierRaw, tierNormalized, loadResp)
|
||||
|
||||
return &QuotaResult{
|
||||
UsageInfo: usageInfo,
|
||||
@@ -52,15 +90,53 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildUsageInfo 将 API 响应转换为 UsageInfo
|
||||
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse) *UsageInfo {
|
||||
now := time.Now()
|
||||
info := &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
AntigravityQuota: make(map[string]*AntigravityModelQuota),
|
||||
// fetchSubscriptionTier 获取账号订阅等级,失败返回空字符串。
|
||||
// 同时返回 LoadCodeAssistResponse,以便提取 AI Credits 余额。
|
||||
func (f *AntigravityQuotaFetcher) fetchSubscriptionTier(ctx context.Context, client *antigravity.Client, accessToken string) (raw, normalized string, loadResp *antigravity.LoadCodeAssistResponse) {
|
||||
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
|
||||
if err != nil {
|
||||
slog.Warn("failed to fetch subscription tier", "error", err)
|
||||
return "", "", nil
|
||||
}
|
||||
if loadResp == nil {
|
||||
return "", "", nil
|
||||
}
|
||||
|
||||
// 遍历所有模型,填充 AntigravityQuota
|
||||
raw = loadResp.GetTier() // 已有方法:paidTier > currentTier
|
||||
normalized = normalizeTier(raw)
|
||||
return raw, normalized, loadResp
|
||||
}
|
||||
|
||||
// normalizeTier 将原始 tier 字符串归一化为 FREE/PRO/ULTRA/UNKNOWN
|
||||
func normalizeTier(raw string) string {
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
lower := strings.ToLower(raw)
|
||||
switch {
|
||||
case strings.Contains(lower, "ultra"):
|
||||
return "ULTRA"
|
||||
case strings.Contains(lower, "pro"):
|
||||
return "PRO"
|
||||
case strings.Contains(lower, "free"):
|
||||
return "FREE"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
// buildUsageInfo 将 API 响应转换为 UsageInfo。
|
||||
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse, tierRaw, tierNormalized string, loadResp *antigravity.LoadCodeAssistResponse) *UsageInfo {
|
||||
now := time.Now()
|
||||
info := &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
AntigravityQuota: make(map[string]*AntigravityModelQuota),
|
||||
AntigravityQuotaDetails: make(map[string]*AntigravityModelDetail),
|
||||
SubscriptionTier: tierNormalized,
|
||||
SubscriptionTierRaw: tierRaw,
|
||||
}
|
||||
|
||||
// 遍历所有模型,填充 AntigravityQuota 和 AntigravityQuotaDetails
|
||||
for modelName, modelInfo := range modelsResp.Models {
|
||||
if modelInfo.QuotaInfo == nil {
|
||||
continue
|
||||
@@ -73,6 +149,27 @@ func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAv
|
||||
Utilization: utilization,
|
||||
ResetTime: modelInfo.QuotaInfo.ResetTime,
|
||||
}
|
||||
|
||||
// 填充模型详细能力信息
|
||||
detail := &AntigravityModelDetail{
|
||||
DisplayName: modelInfo.DisplayName,
|
||||
SupportsImages: modelInfo.SupportsImages,
|
||||
SupportsThinking: modelInfo.SupportsThinking,
|
||||
ThinkingBudget: modelInfo.ThinkingBudget,
|
||||
Recommended: modelInfo.Recommended,
|
||||
MaxTokens: modelInfo.MaxTokens,
|
||||
MaxOutputTokens: modelInfo.MaxOutputTokens,
|
||||
SupportedMimeTypes: modelInfo.SupportedMimeTypes,
|
||||
}
|
||||
info.AntigravityQuotaDetails[modelName] = detail
|
||||
}
|
||||
|
||||
// 废弃模型转发规则
|
||||
if len(modelsResp.DeprecatedModelIDs) > 0 {
|
||||
info.ModelForwardingRules = make(map[string]string, len(modelsResp.DeprecatedModelIDs))
|
||||
for oldID, deprecated := range modelsResp.DeprecatedModelIDs {
|
||||
info.ModelForwardingRules[oldID] = deprecated.NewModelID
|
||||
}
|
||||
}
|
||||
|
||||
// 同时设置 FiveHour 用于兼容展示(取主要模型)
|
||||
@@ -94,6 +191,16 @@ func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAv
|
||||
}
|
||||
}
|
||||
|
||||
if loadResp != nil {
|
||||
for _, credit := range loadResp.GetAvailableCredits() {
|
||||
info.AICredits = append(info.AICredits, AICredit{
|
||||
CreditType: credit.CreditType,
|
||||
Amount: credit.GetAmount(),
|
||||
MinimumBalance: credit.GetMinimumAmount(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
@@ -108,3 +215,58 @@ func (f *AntigravityQuotaFetcher) GetProxyURL(ctx context.Context, account *Acco
|
||||
}
|
||||
return proxy.URL()
|
||||
}
|
||||
|
||||
// classifyForbiddenType 根据 403 响应体判断禁止类型
|
||||
func classifyForbiddenType(body string) string {
|
||||
lower := strings.ToLower(body)
|
||||
switch {
|
||||
case strings.Contains(lower, "validation_required") ||
|
||||
strings.Contains(lower, "verify your account") ||
|
||||
strings.Contains(lower, "validation_url"):
|
||||
return forbiddenTypeValidation
|
||||
case strings.Contains(lower, "terms of service") ||
|
||||
strings.Contains(lower, "violation"):
|
||||
return forbiddenTypeViolation
|
||||
default:
|
||||
return forbiddenTypeForbidden
|
||||
}
|
||||
}
|
||||
|
||||
// urlPattern 用于从 403 响应体中提取 URL(降级方案)
|
||||
var urlPattern = regexp.MustCompile(`https://[^\s"'\\]+`)
|
||||
|
||||
// extractValidationURL 从 403 响应 JSON 中提取验证/申诉链接
|
||||
func extractValidationURL(body string) string {
|
||||
// 1. 尝试结构化 JSON 提取: /error/details[*]/metadata/validation_url 或 appeal_url
|
||||
var parsed struct {
|
||||
Error struct {
|
||||
Details []struct {
|
||||
Metadata map[string]string `json:"metadata"`
|
||||
} `json:"details"`
|
||||
} `json:"error"`
|
||||
}
|
||||
if json.Unmarshal([]byte(body), &parsed) == nil {
|
||||
for _, detail := range parsed.Error.Details {
|
||||
if u := detail.Metadata["validation_url"]; u != "" {
|
||||
return u
|
||||
}
|
||||
if u := detail.Metadata["appeal_url"]; u != "" {
|
||||
return u
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 降级:正则匹配 URL
|
||||
lower := strings.ToLower(body)
|
||||
if !strings.Contains(lower, "validation") &&
|
||||
!strings.Contains(lower, "verify") &&
|
||||
!strings.Contains(lower, "appeal") {
|
||||
return ""
|
||||
}
|
||||
// 先解码常见转义再匹配
|
||||
normalized := strings.ReplaceAll(body, `\u0026`, "&")
|
||||
if m := urlPattern.FindString(normalized); m != "" {
|
||||
return m
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
522
backend/internal/service/antigravity_quota_fetcher_test.go
Normal file
522
backend/internal/service/antigravity_quota_fetcher_test.go
Normal file
@@ -0,0 +1,522 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// normalizeTier
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNormalizeTier(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
raw string
|
||||
expected string
|
||||
}{
|
||||
{name: "empty string", raw: "", expected: ""},
|
||||
{name: "free-tier", raw: "free-tier", expected: "FREE"},
|
||||
{name: "g1-pro-tier", raw: "g1-pro-tier", expected: "PRO"},
|
||||
{name: "g1-ultra-tier", raw: "g1-ultra-tier", expected: "ULTRA"},
|
||||
{name: "unknown-something", raw: "unknown-something", expected: "UNKNOWN"},
|
||||
{name: "Google AI Pro contains pro keyword", raw: "Google AI Pro", expected: "PRO"},
|
||||
{name: "case insensitive FREE", raw: "FREE-TIER", expected: "FREE"},
|
||||
{name: "case insensitive Ultra", raw: "Ultra Plan", expected: "ULTRA"},
|
||||
{name: "arbitrary unrecognized string", raw: "enterprise-custom", expected: "UNKNOWN"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := normalizeTier(tt.raw)
|
||||
require.Equal(t, tt.expected, got, "normalizeTier(%q)", tt.raw)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// buildUsageInfo
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func aqfBoolPtr(v bool) *bool { return &v }
|
||||
func aqfIntPtr(v int) *int { return &v }
|
||||
|
||||
func TestBuildUsageInfo_BasicModels(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.75,
|
||||
ResetTime: "2026-03-08T12:00:00Z",
|
||||
},
|
||||
DisplayName: "Claude Sonnet 4",
|
||||
SupportsImages: aqfBoolPtr(true),
|
||||
SupportsThinking: aqfBoolPtr(false),
|
||||
ThinkingBudget: aqfIntPtr(0),
|
||||
Recommended: aqfBoolPtr(true),
|
||||
MaxTokens: aqfIntPtr(200000),
|
||||
MaxOutputTokens: aqfIntPtr(16384),
|
||||
SupportedMimeTypes: map[string]bool{
|
||||
"image/png": true,
|
||||
"image/jpeg": true,
|
||||
},
|
||||
},
|
||||
"gemini-2.5-pro": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.50,
|
||||
ResetTime: "2026-03-08T15:00:00Z",
|
||||
},
|
||||
DisplayName: "Gemini 2.5 Pro",
|
||||
MaxTokens: aqfIntPtr(1000000),
|
||||
MaxOutputTokens: aqfIntPtr(65536),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "g1-pro-tier", "PRO", nil)
|
||||
|
||||
// 基本字段
|
||||
require.NotNil(t, info.UpdatedAt, "UpdatedAt should be set")
|
||||
require.Equal(t, "PRO", info.SubscriptionTier)
|
||||
require.Equal(t, "g1-pro-tier", info.SubscriptionTierRaw)
|
||||
|
||||
// AntigravityQuota
|
||||
require.Len(t, info.AntigravityQuota, 2)
|
||||
|
||||
sonnetQuota := info.AntigravityQuota["claude-sonnet-4-20250514"]
|
||||
require.NotNil(t, sonnetQuota)
|
||||
require.Equal(t, 25, sonnetQuota.Utilization) // (1 - 0.75) * 100 = 25
|
||||
require.Equal(t, "2026-03-08T12:00:00Z", sonnetQuota.ResetTime)
|
||||
|
||||
geminiQuota := info.AntigravityQuota["gemini-2.5-pro"]
|
||||
require.NotNil(t, geminiQuota)
|
||||
require.Equal(t, 50, geminiQuota.Utilization) // (1 - 0.50) * 100 = 50
|
||||
require.Equal(t, "2026-03-08T15:00:00Z", geminiQuota.ResetTime)
|
||||
|
||||
// AntigravityQuotaDetails
|
||||
require.Len(t, info.AntigravityQuotaDetails, 2)
|
||||
|
||||
sonnetDetail := info.AntigravityQuotaDetails["claude-sonnet-4-20250514"]
|
||||
require.NotNil(t, sonnetDetail)
|
||||
require.Equal(t, "Claude Sonnet 4", sonnetDetail.DisplayName)
|
||||
require.Equal(t, aqfBoolPtr(true), sonnetDetail.SupportsImages)
|
||||
require.Equal(t, aqfBoolPtr(false), sonnetDetail.SupportsThinking)
|
||||
require.Equal(t, aqfIntPtr(0), sonnetDetail.ThinkingBudget)
|
||||
require.Equal(t, aqfBoolPtr(true), sonnetDetail.Recommended)
|
||||
require.Equal(t, aqfIntPtr(200000), sonnetDetail.MaxTokens)
|
||||
require.Equal(t, aqfIntPtr(16384), sonnetDetail.MaxOutputTokens)
|
||||
require.Equal(t, map[string]bool{"image/png": true, "image/jpeg": true}, sonnetDetail.SupportedMimeTypes)
|
||||
|
||||
geminiDetail := info.AntigravityQuotaDetails["gemini-2.5-pro"]
|
||||
require.NotNil(t, geminiDetail)
|
||||
require.Equal(t, "Gemini 2.5 Pro", geminiDetail.DisplayName)
|
||||
require.Nil(t, geminiDetail.SupportsImages)
|
||||
require.Nil(t, geminiDetail.SupportsThinking)
|
||||
require.Equal(t, aqfIntPtr(1000000), geminiDetail.MaxTokens)
|
||||
require.Equal(t, aqfIntPtr(65536), geminiDetail.MaxOutputTokens)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_DeprecatedModels(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 1.0,
|
||||
},
|
||||
},
|
||||
},
|
||||
DeprecatedModelIDs: map[string]antigravity.DeprecatedModelInfo{
|
||||
"claude-3-sonnet-20240229": {NewModelID: "claude-sonnet-4-20250514"},
|
||||
"claude-3-haiku-20240307": {NewModelID: "claude-haiku-3.5-latest"},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
require.Len(t, info.ModelForwardingRules, 2)
|
||||
require.Equal(t, "claude-sonnet-4-20250514", info.ModelForwardingRules["claude-3-sonnet-20240229"])
|
||||
require.Equal(t, "claude-haiku-3.5-latest", info.ModelForwardingRules["claude-3-haiku-20240307"])
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_NoDeprecatedModels(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"some-model": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{RemainingFraction: 0.9},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
require.Nil(t, info.ModelForwardingRules, "ModelForwardingRules should be nil when no deprecated models")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_EmptyModels(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
require.NotNil(t, info)
|
||||
require.NotNil(t, info.AntigravityQuota)
|
||||
require.Empty(t, info.AntigravityQuota)
|
||||
require.NotNil(t, info.AntigravityQuotaDetails)
|
||||
require.Empty(t, info.AntigravityQuotaDetails)
|
||||
require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_ModelWithNilQuotaInfo(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"model-without-quota": {
|
||||
DisplayName: "No Quota Model",
|
||||
// QuotaInfo is nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
require.NotNil(t, info)
|
||||
require.Empty(t, info.AntigravityQuota, "models with nil QuotaInfo should be skipped")
|
||||
require.Empty(t, info.AntigravityQuotaDetails, "models with nil QuotaInfo should be skipped from details too")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourPriorityOrder(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
// priorityModels = ["claude-sonnet-4-20250514", "claude-sonnet-4", "gemini-2.5-pro"]
|
||||
// When the first priority model exists, it should be used for FiveHour
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"gemini-2.5-pro": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.40,
|
||||
ResetTime: "2026-03-08T18:00:00Z",
|
||||
},
|
||||
},
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.80,
|
||||
ResetTime: "2026-03-08T12:00:00Z",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
require.NotNil(t, info.FiveHour, "FiveHour should be set when a priority model exists")
|
||||
// claude-sonnet-4-20250514 is first in priority list, so it should be used
|
||||
expectedUtilization := (1.0 - 0.80) * 100 // 20
|
||||
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
|
||||
require.NotNil(t, info.FiveHour.ResetsAt, "ResetsAt should be parsed from ResetTime")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourFallbackToClaude4(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
// Only claude-sonnet-4 exists (second in priority list), not claude-sonnet-4-20250514
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.60,
|
||||
ResetTime: "2026-03-08T14:00:00Z",
|
||||
},
|
||||
},
|
||||
"gemini-2.5-pro": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.30,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
require.NotNil(t, info.FiveHour)
|
||||
expectedUtilization := (1.0 - 0.60) * 100 // 40
|
||||
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourFallbackToGemini(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
// Only gemini-2.5-pro exists (third in priority list)
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"gemini-2.5-pro": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.30,
|
||||
},
|
||||
},
|
||||
"other-model": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.90,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
require.NotNil(t, info.FiveHour)
|
||||
expectedUtilization := (1.0 - 0.30) * 100 // 70
|
||||
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourNoPriorityModel(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
// None of the priority models exist
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"some-other-model": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.50,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourWithEmptyResetTime(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.50,
|
||||
ResetTime: "", // empty reset time
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
require.NotNil(t, info.FiveHour)
|
||||
require.Nil(t, info.FiveHour.ResetsAt, "ResetsAt should be nil when ResetTime is empty")
|
||||
require.Equal(t, 0, info.FiveHour.RemainingSeconds)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FullUtilization(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.0, // fully used
|
||||
ResetTime: "2026-03-08T12:00:00Z",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
|
||||
require.NotNil(t, quota)
|
||||
require.Equal(t, 100, quota.Utilization)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_ZeroUtilization(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 1.0, // fully available
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
|
||||
require.NotNil(t, quota)
|
||||
require.Equal(t, 0, quota.Utilization)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_AICredits(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{},
|
||||
}
|
||||
loadResp := &antigravity.LoadCodeAssistResponse{
|
||||
PaidTier: &antigravity.PaidTierInfo{
|
||||
ID: "g1-pro-tier",
|
||||
AvailableCredits: []antigravity.AvailableCredit{
|
||||
{
|
||||
CreditType: "GOOGLE_ONE_AI",
|
||||
CreditAmount: "25",
|
||||
MinimumCreditAmountForUsage: "5",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "g1-pro-tier", "PRO", loadResp)
|
||||
|
||||
require.Len(t, info.AICredits, 1)
|
||||
require.Equal(t, "GOOGLE_ONE_AI", info.AICredits[0].CreditType)
|
||||
require.Equal(t, 25.0, info.AICredits[0].Amount)
|
||||
require.Equal(t, 5.0, info.AICredits[0].MinimumBalance)
|
||||
}
|
||||
|
||||
func TestFetchQuota_ForbiddenReturnsIsForbidden(t *testing.T) {
|
||||
// 模拟 FetchQuota 遇到 403 时的行为:
|
||||
// FetchAvailableModels 返回 ForbiddenError → FetchQuota 应返回 is_forbidden=true
|
||||
forbiddenErr := &antigravity.ForbiddenError{
|
||||
StatusCode: 403,
|
||||
Body: "Access denied",
|
||||
}
|
||||
|
||||
// 验证 ForbiddenError 满足 errors.As
|
||||
var target *antigravity.ForbiddenError
|
||||
require.True(t, errors.As(forbiddenErr, &target))
|
||||
require.Equal(t, 403, target.StatusCode)
|
||||
require.Equal(t, "Access denied", target.Body)
|
||||
require.Contains(t, forbiddenErr.Error(), "403")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// classifyForbiddenType
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestClassifyForbiddenType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "VALIDATION_REQUIRED keyword",
|
||||
body: `{"error":{"message":"VALIDATION_REQUIRED"}}`,
|
||||
expected: "validation",
|
||||
},
|
||||
{
|
||||
name: "verify your account",
|
||||
body: `Please verify your account to continue`,
|
||||
expected: "validation",
|
||||
},
|
||||
{
|
||||
name: "contains validation_url field",
|
||||
body: `{"error":{"details":[{"metadata":{"validation_url":"https://..."}}]}}`,
|
||||
expected: "validation",
|
||||
},
|
||||
{
|
||||
name: "terms of service violation",
|
||||
body: `Your account has been suspended for Terms of Service violation`,
|
||||
expected: "violation",
|
||||
},
|
||||
{
|
||||
name: "violation keyword",
|
||||
body: `Account suspended due to policy violation`,
|
||||
expected: "violation",
|
||||
},
|
||||
{
|
||||
name: "generic 403",
|
||||
body: `Access denied`,
|
||||
expected: "forbidden",
|
||||
},
|
||||
{
|
||||
name: "empty body",
|
||||
body: "",
|
||||
expected: "forbidden",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := classifyForbiddenType(tt.body)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// extractValidationURL
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExtractValidationURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "structured validation_url",
|
||||
body: `{"error":{"details":[{"metadata":{"validation_url":"https://accounts.google.com/verify?token=abc"}}]}}`,
|
||||
expected: "https://accounts.google.com/verify?token=abc",
|
||||
},
|
||||
{
|
||||
name: "structured appeal_url",
|
||||
body: `{"error":{"details":[{"metadata":{"appeal_url":"https://support.google.com/appeal/123"}}]}}`,
|
||||
expected: "https://support.google.com/appeal/123",
|
||||
},
|
||||
{
|
||||
name: "validation_url takes priority over appeal_url",
|
||||
body: `{"error":{"details":[{"metadata":{"validation_url":"https://v.com","appeal_url":"https://a.com"}}]}}`,
|
||||
expected: "https://v.com",
|
||||
},
|
||||
{
|
||||
name: "fallback regex with verify keyword",
|
||||
body: `Please verify your account at https://accounts.google.com/verify`,
|
||||
expected: "https://accounts.google.com/verify",
|
||||
},
|
||||
{
|
||||
name: "no URL in generic forbidden",
|
||||
body: `Access denied`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "empty body",
|
||||
body: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "URL present but no validation keywords",
|
||||
body: `Error at https://example.com/something`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "unicode escaped ampersand",
|
||||
body: `validation required: https://accounts.google.com/verify?a=1\u0026b=2`,
|
||||
expected: "https://accounts.google.com/verify?a=1&b=2",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := extractValidationURL(tt.body)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -32,6 +32,10 @@ func (a *Account) IsSchedulableForModelWithContext(ctx context.Context, requeste
|
||||
return false
|
||||
}
|
||||
if a.isModelRateLimitedWithContext(ctx, requestedModel) {
|
||||
// Antigravity + overages 启用 + 积分未耗尽 → 放行(有积分可用)
|
||||
if a.Platform == PlatformAntigravity && a.IsOveragesEnabled() && !a.isCreditsExhausted() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user