mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-09 09:34:46 +08:00
Compare commits
109 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
21f349c032 | ||
|
|
28e36f7925 | ||
|
|
6c02076333 | ||
|
|
7414bdf0e3 | ||
|
|
e6326b2929 | ||
|
|
17cdcebd04 | ||
|
|
a14babdc73 | ||
|
|
aadc6a763a | ||
|
|
f16af8bf88 | ||
|
|
5ceaef4500 | ||
|
|
1ac7219a92 | ||
|
|
d4cc9871c4 | ||
|
|
961c30e7c0 | ||
|
|
13e85b3147 | ||
|
|
50a3c7fa0b | ||
|
|
2005fc97a8 | ||
|
|
0772d9250e | ||
|
|
045cba78b4 | ||
|
|
8989d0d4b6 | ||
|
|
c521117b99 | ||
|
|
e0f52a8ab8 | ||
|
|
6c23fadf7e | ||
|
|
869952d113 | ||
|
|
07ab051ee4 | ||
|
|
f2d98fc0c7 | ||
|
|
2b41cec840 | ||
|
|
6cf77040e7 | ||
|
|
20b70bc5fd | ||
|
|
4905e7193a | ||
|
|
9c1f4b8e72 | ||
|
|
9857c17631 | ||
|
|
7e34bb946f | ||
|
|
47b748851b | ||
|
|
a6f99cf534 | ||
|
|
a120a6bc32 | ||
|
|
d557d1a190 | ||
|
|
e0286e5085 | ||
|
|
4b41e898a4 | ||
|
|
668e164793 | ||
|
|
fa2e6188d0 | ||
|
|
7fde9ebbc2 | ||
|
|
aef7c3b9bb | ||
|
|
a0b76bd608 | ||
|
|
c1fab7f8d8 | ||
|
|
f42c8f2abe | ||
|
|
aa5846b282 | ||
|
|
594a0ade38 | ||
|
|
d45cc23171 | ||
|
|
d795734352 | ||
|
|
4da9fdd1d5 | ||
|
|
6b218caa21 | ||
|
|
5c138007d0 | ||
|
|
1acfc46f46 | ||
|
|
fbffb08aae | ||
|
|
8640a62319 | ||
|
|
fa782e70a4 | ||
|
|
afd72abc6e | ||
|
|
71f72e167e | ||
|
|
6595c7601e | ||
|
|
67c0506290 | ||
|
|
6447be4534 | ||
|
|
3741617ebd | ||
|
|
ab4e8b2cf0 | ||
|
|
474165d7aa | ||
|
|
94e067a2e2 | ||
|
|
4293c89166 | ||
|
|
ec82c37da5 | ||
|
|
552a4b998a | ||
|
|
0d2061b268 | ||
|
|
8a260defc2 | ||
|
|
e14c87597a | ||
|
|
f3f19d35aa | ||
|
|
ced90e1d84 | ||
|
|
17e4033340 | ||
|
|
044d3a013d | ||
|
|
1fc9dd7b68 | ||
|
|
8147866c09 | ||
|
|
7bd1972f94 | ||
|
|
2c9dcfe27b | ||
|
|
1b79b0f3ff | ||
|
|
c637e6cf31 | ||
|
|
d3a9f5bb88 | ||
|
|
7eb0415a8a | ||
|
|
bdbc8fa08f | ||
|
|
63f3af0f94 | ||
|
|
686f890fbf | ||
|
|
220fbe6544 | ||
|
|
ae44a94325 | ||
|
|
3718d6dcd4 | ||
|
|
90b3838173 | ||
|
|
19d3ecc76f | ||
|
|
6fba4ebb13 | ||
|
|
c31974c913 | ||
|
|
6177fa5dd8 | ||
|
|
cfe72159d0 | ||
|
|
8321e4a647 | ||
|
|
3084330d0c | ||
|
|
b566649e79 | ||
|
|
10a6180e4a | ||
|
|
cbe9e78977 | ||
|
|
74145b1f39 | ||
|
|
359e56751b | ||
|
|
5899784aa4 | ||
|
|
9e8959c56d | ||
|
|
1bff2292a6 | ||
|
|
cf9247754e | ||
|
|
eefab15958 | ||
|
|
0e23732631 | ||
|
|
37c044fb4b |
7
.gitattributes
vendored
7
.gitattributes
vendored
@@ -4,6 +4,13 @@ backend/migrations/*.sql text eol=lf
|
|||||||
# Go 源代码文件
|
# Go 源代码文件
|
||||||
*.go text eol=lf
|
*.go text eol=lf
|
||||||
|
|
||||||
|
# 前端 源代码文件
|
||||||
|
*.ts text eol=lf
|
||||||
|
*.tsx text eol=lf
|
||||||
|
*.js text eol=lf
|
||||||
|
*.jsx text eol=lf
|
||||||
|
*.vue text eol=lf
|
||||||
|
|
||||||
# Shell 脚本
|
# Shell 脚本
|
||||||
*.sh text eol=lf
|
*.sh text eol=lf
|
||||||
|
|
||||||
|
|||||||
@@ -47,6 +47,8 @@ dockers:
|
|||||||
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:latest"
|
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:latest"
|
||||||
dockerfile: Dockerfile.goreleaser
|
dockerfile: Dockerfile.goreleaser
|
||||||
use: buildx
|
use: buildx
|
||||||
|
extra_files:
|
||||||
|
- deploy/docker-entrypoint.sh
|
||||||
build_flag_templates:
|
build_flag_templates:
|
||||||
- "--platform=linux/amd64"
|
- "--platform=linux/amd64"
|
||||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||||
|
|||||||
@@ -63,6 +63,8 @@ dockers:
|
|||||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||||
dockerfile: Dockerfile.goreleaser
|
dockerfile: Dockerfile.goreleaser
|
||||||
use: buildx
|
use: buildx
|
||||||
|
extra_files:
|
||||||
|
- deploy/docker-entrypoint.sh
|
||||||
build_flag_templates:
|
build_flag_templates:
|
||||||
- "--platform=linux/amd64"
|
- "--platform=linux/amd64"
|
||||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||||
@@ -76,6 +78,8 @@ dockers:
|
|||||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||||
dockerfile: Dockerfile.goreleaser
|
dockerfile: Dockerfile.goreleaser
|
||||||
use: buildx
|
use: buildx
|
||||||
|
extra_files:
|
||||||
|
- deploy/docker-entrypoint.sh
|
||||||
build_flag_templates:
|
build_flag_templates:
|
||||||
- "--platform=linux/arm64"
|
- "--platform=linux/arm64"
|
||||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||||
@@ -89,6 +93,8 @@ dockers:
|
|||||||
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-amd64"
|
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-amd64"
|
||||||
dockerfile: Dockerfile.goreleaser
|
dockerfile: Dockerfile.goreleaser
|
||||||
use: buildx
|
use: buildx
|
||||||
|
extra_files:
|
||||||
|
- deploy/docker-entrypoint.sh
|
||||||
build_flag_templates:
|
build_flag_templates:
|
||||||
- "--platform=linux/amd64"
|
- "--platform=linux/amd64"
|
||||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||||
@@ -102,6 +108,8 @@ dockers:
|
|||||||
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-arm64"
|
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-arm64"
|
||||||
dockerfile: Dockerfile.goreleaser
|
dockerfile: Dockerfile.goreleaser
|
||||||
use: buildx
|
use: buildx
|
||||||
|
extra_files:
|
||||||
|
- deploy/docker-entrypoint.sh
|
||||||
build_flag_templates:
|
build_flag_templates:
|
||||||
- "--platform=linux/arm64"
|
- "--platform=linux/arm64"
|
||||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||||
|
|||||||
11
Dockerfile
11
Dockerfile
@@ -92,6 +92,7 @@ LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
|
|||||||
RUN apk add --no-cache \
|
RUN apk add --no-cache \
|
||||||
ca-certificates \
|
ca-certificates \
|
||||||
tzdata \
|
tzdata \
|
||||||
|
su-exec \
|
||||||
libpq \
|
libpq \
|
||||||
zstd-libs \
|
zstd-libs \
|
||||||
lz4-libs \
|
lz4-libs \
|
||||||
@@ -120,8 +121,9 @@ COPY --from=backend-builder --chown=sub2api:sub2api /app/backend/resources /app/
|
|||||||
# Create data directory
|
# Create data directory
|
||||||
RUN mkdir -p /app/data && chown sub2api:sub2api /app/data
|
RUN mkdir -p /app/data && chown sub2api:sub2api /app/data
|
||||||
|
|
||||||
# Switch to non-root user
|
# Copy entrypoint script (fixes volume permissions then drops to sub2api)
|
||||||
USER sub2api
|
COPY deploy/docker-entrypoint.sh /app/docker-entrypoint.sh
|
||||||
|
RUN chmod +x /app/docker-entrypoint.sh
|
||||||
|
|
||||||
# Expose port (can be overridden by SERVER_PORT env var)
|
# Expose port (can be overridden by SERVER_PORT env var)
|
||||||
EXPOSE 8080
|
EXPOSE 8080
|
||||||
@@ -130,5 +132,6 @@ EXPOSE 8080
|
|||||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
||||||
CMD wget -q -T 5 -O /dev/null http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
CMD wget -q -T 5 -O /dev/null http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
||||||
|
|
||||||
# Run the application
|
# Run the application (entrypoint fixes /app/data ownership then execs as sub2api)
|
||||||
ENTRYPOINT ["/app/sub2api"]
|
ENTRYPOINT ["/app/docker-entrypoint.sh"]
|
||||||
|
CMD ["/app/sub2api"]
|
||||||
|
|||||||
@@ -5,7 +5,12 @@
|
|||||||
# It only packages the pre-built binary, no compilation needed.
|
# It only packages the pre-built binary, no compilation needed.
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
FROM alpine:3.19
|
ARG ALPINE_IMAGE=alpine:3.21
|
||||||
|
ARG POSTGRES_IMAGE=postgres:18-alpine
|
||||||
|
|
||||||
|
FROM ${POSTGRES_IMAGE} AS pg-client
|
||||||
|
|
||||||
|
FROM ${ALPINE_IMAGE}
|
||||||
|
|
||||||
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
|
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
|
||||||
LABEL description="Sub2API - AI API Gateway Platform"
|
LABEL description="Sub2API - AI API Gateway Platform"
|
||||||
@@ -16,8 +21,21 @@ RUN apk add --no-cache \
|
|||||||
ca-certificates \
|
ca-certificates \
|
||||||
tzdata \
|
tzdata \
|
||||||
curl \
|
curl \
|
||||||
|
su-exec \
|
||||||
|
libpq \
|
||||||
|
zstd-libs \
|
||||||
|
lz4-libs \
|
||||||
|
krb5-libs \
|
||||||
|
libldap \
|
||||||
|
libedit \
|
||||||
&& rm -rf /var/cache/apk/*
|
&& rm -rf /var/cache/apk/*
|
||||||
|
|
||||||
|
# Copy pg_dump and psql from a version-matched PostgreSQL image so backup and
|
||||||
|
# restore work in the runtime container without requiring Docker socket access.
|
||||||
|
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
|
||||||
|
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
|
||||||
|
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
|
||||||
|
|
||||||
# Create non-root user
|
# Create non-root user
|
||||||
RUN addgroup -g 1000 sub2api && \
|
RUN addgroup -g 1000 sub2api && \
|
||||||
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
||||||
@@ -30,11 +48,15 @@ COPY sub2api /app/sub2api
|
|||||||
# Create data directory
|
# Create data directory
|
||||||
RUN mkdir -p /app/data && chown -R sub2api:sub2api /app
|
RUN mkdir -p /app/data && chown -R sub2api:sub2api /app
|
||||||
|
|
||||||
USER sub2api
|
# Copy entrypoint script (fixes volume permissions then drops to sub2api)
|
||||||
|
COPY deploy/docker-entrypoint.sh /app/docker-entrypoint.sh
|
||||||
|
RUN chmod +x /app/docker-entrypoint.sh
|
||||||
|
|
||||||
EXPOSE 8080
|
EXPOSE 8080
|
||||||
|
|
||||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
||||||
CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
||||||
|
|
||||||
ENTRYPOINT ["/app/sub2api"]
|
# Run the application (entrypoint fixes /app/data ownership then execs as sub2api)
|
||||||
|
ENTRYPOINT ["/app/docker-entrypoint.sh"]
|
||||||
|
CMD ["/app/sub2api"]
|
||||||
|
|||||||
30
README.md
30
README.md
@@ -8,27 +8,31 @@
|
|||||||
[](https://redis.io/)
|
[](https://redis.io/)
|
||||||
[](https://www.docker.com/)
|
[](https://www.docker.com/)
|
||||||
|
|
||||||
|
<a href="https://trendshift.io/repositories/21823" target="_blank"><img src="https://trendshift.io/api/badge/repositories/21823" alt="Wei-Shaw%2Fsub2api | Trendshift" width="250" height="55"/></a>
|
||||||
|
|
||||||
**AI API Gateway Platform for Subscription Quota Distribution**
|
**AI API Gateway Platform for Subscription Quota Distribution**
|
||||||
|
|
||||||
English | [中文](README_CN.md)
|
English | [中文](README_CN.md)
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
> **Sub2API officially uses only the domains `sub2api.org` and `pincc.ai`. Other websites using the Sub2API name may be third-party deployments or services and are not affiliated with this project. Please verify and exercise your own judgment.**
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Demo
|
## Demo
|
||||||
|
|
||||||
Try Sub2API online: **https://demo.sub2api.org/**
|
Try Sub2API online: **[https://demo.sub2api.org/](https://demo.sub2api.org/)**
|
||||||
|
|
||||||
Demo credentials (shared demo environment; **not** created automatically for self-hosted installs):
|
Demo credentials (shared demo environment; **not** created automatically for self-hosted installs):
|
||||||
|
|
||||||
| Email | Password |
|
| Email | Password |
|
||||||
|-------|----------|
|
|-------|----------|
|
||||||
| admin@sub2api.com | admin123 |
|
| admin@sub2api.org | admin123 |
|
||||||
|
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions (like Claude Code $200/month). Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding.
|
Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions. Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding.
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
@@ -41,6 +45,15 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
|
|||||||
- **Admin Dashboard** - Web interface for monitoring and management
|
- **Admin Dashboard** - Web interface for monitoring and management
|
||||||
- **External System Integration** - Embed external systems (e.g. payment, ticketing) via iframe to extend the admin dashboard
|
- **External System Integration** - Embed external systems (e.g. payment, ticketing) via iframe to extend the admin dashboard
|
||||||
|
|
||||||
|
## Don't Want to Self-Host?
|
||||||
|
|
||||||
|
<table>
|
||||||
|
<tr>
|
||||||
|
<td width="180" align="center" valign="middle"><a href="https://shop.pincc.ai/"><img src="assets/partners/logos/pincc-logo.png" alt="pincc" width="120"></a></td>
|
||||||
|
<td valign="middle"><b><a href="https://shop.pincc.ai/">PinCC</a></b> is the official relay service built on Sub2API, offering stable access to Claude Code, Codex, Gemini and other popular models — ready to use, no deployment or maintenance required.</td>
|
||||||
|
</tr>
|
||||||
|
</table>
|
||||||
|
|
||||||
## Ecosystem
|
## Ecosystem
|
||||||
|
|
||||||
Community projects that extend or integrate with Sub2API:
|
Community projects that extend or integrate with Sub2API:
|
||||||
@@ -61,10 +74,15 @@ Community projects that extend or integrate with Sub2API:
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Documentation
|
## Nginx Reverse Proxy Note
|
||||||
|
|
||||||
- Dependency Security: `docs/dependency-security.md`
|
When using Nginx as a reverse proxy for Sub2API (or CRS) with Codex CLI, add the following to the `http` block in your Nginx configuration:
|
||||||
- Admin Payment Integration API: `docs/ADMIN_PAYMENT_INTEGRATION_API.md`
|
|
||||||
|
```nginx
|
||||||
|
underscores_in_headers on;
|
||||||
|
```
|
||||||
|
|
||||||
|
Nginx drops headers containing underscores by default (e.g. `session_id`), which breaks sticky session routing in multi-account setups.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
33
README_CN.md
33
README_CN.md
@@ -8,27 +8,30 @@
|
|||||||
[](https://redis.io/)
|
[](https://redis.io/)
|
||||||
[](https://www.docker.com/)
|
[](https://www.docker.com/)
|
||||||
|
|
||||||
|
<a href="https://trendshift.io/repositories/21823" target="_blank"><img src="https://trendshift.io/api/badge/repositories/21823" alt="Wei-Shaw%2Fsub2api | Trendshift" width="250" height="55"/></a>
|
||||||
|
|
||||||
**AI API 网关平台 - 订阅配额分发管理**
|
**AI API 网关平台 - 订阅配额分发管理**
|
||||||
|
|
||||||
[English](README.md) | 中文
|
[English](README.md) | 中文
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
> **Sub2API 官方仅使用 `sub2api.org` 与 `pincc.ai` 两个域名。其他使用 Sub2API 名义的网站可能为第三方部署或服务,与本项目无关,请自行甄别。**
|
||||||
---
|
---
|
||||||
|
|
||||||
## 在线体验
|
## 在线体验
|
||||||
|
|
||||||
体验地址:**https://v2.pincc.ai/**
|
体验地址:**[https://demo.sub2api.org/](https://demo.sub2api.org/)**
|
||||||
|
|
||||||
演示账号(共享演示环境;自建部署不会自动创建该账号):
|
演示账号(共享演示环境;自建部署不会自动创建该账号):
|
||||||
|
|
||||||
| 邮箱 | 密码 |
|
| 邮箱 | 密码 |
|
||||||
|------|------|
|
|------|------|
|
||||||
| admin@sub2api.com | admin123 |
|
| admin@sub2api.org | admin123 |
|
||||||
|
|
||||||
## 项目概述
|
## 项目概述
|
||||||
|
|
||||||
Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(如 Claude Code $200/月)的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。
|
Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。
|
||||||
|
|
||||||
## 核心功能
|
## 核心功能
|
||||||
|
|
||||||
@@ -41,6 +44,15 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
|
|||||||
- **管理后台** - Web 界面进行监控和管理
|
- **管理后台** - Web 界面进行监控和管理
|
||||||
- **外部系统集成** - 支持通过 iframe 嵌入外部系统(如支付、工单等),扩展管理后台功能
|
- **外部系统集成** - 支持通过 iframe 嵌入外部系统(如支付、工单等),扩展管理后台功能
|
||||||
|
|
||||||
|
## 不想自建?试试官方中转
|
||||||
|
|
||||||
|
<table>
|
||||||
|
<tr>
|
||||||
|
<td width="180" align="center" valign="middle"><a href="https://shop.pincc.ai/"><img src="assets/partners/logos/pincc-logo.png" alt="pincc" width="120"></a></td>
|
||||||
|
<td valign="middle"><b><a href="https://shop.pincc.ai/">PinCC</a></b> 是基于 Sub2API 搭建的官方中转服务,提供 Claude Code、Codex、Gemini 等主流模型的稳定中转,开箱即用,免去自建部署与运维烦恼。</td>
|
||||||
|
</tr>
|
||||||
|
</table>
|
||||||
|
|
||||||
## 生态项目
|
## 生态项目
|
||||||
|
|
||||||
围绕 Sub2API 的社区扩展与集成项目:
|
围绕 Sub2API 的社区扩展与集成项目:
|
||||||
@@ -61,17 +73,18 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 文档
|
## Nginx 反向代理注意事项
|
||||||
|
|
||||||
- 依赖安全:`docs/dependency-security.md`
|
通过 Nginx 反向代理 Sub2API(或 CRS 服务)并搭配 Codex CLI 使用时,需要在 Nginx 配置的 `http` 块中添加:
|
||||||
|
|
||||||
|
```nginx
|
||||||
|
underscores_in_headers on;
|
||||||
|
```
|
||||||
|
|
||||||
|
Nginx 默认会丢弃名称中含下划线的请求头(如 `session_id`),这会导致多账号环境下的粘性会话功能失效。
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## OpenAI Responses 兼容注意事项
|
|
||||||
|
|
||||||
- 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id` 的 `tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`。
|
|
||||||
- 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。
|
|
||||||
|
|
||||||
## 部署方式
|
## 部署方式
|
||||||
|
|
||||||
### 方式一:脚本安装(推荐)
|
### 方式一:脚本安装(推荐)
|
||||||
|
|||||||
BIN
assets/partners/logos/pincc-logo.png
Normal file
BIN
assets/partners/logos/pincc-logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 171 KiB |
@@ -110,7 +110,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
||||||
groupHandler := admin.NewGroupHandler(adminService)
|
|
||||||
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
||||||
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
||||||
openAIOAuthClient := repository.NewOpenAIOAuthClient()
|
openAIOAuthClient := repository.NewOpenAIOAuthClient()
|
||||||
@@ -124,6 +123,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
|
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
|
||||||
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
|
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
|
||||||
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
|
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
|
||||||
|
oauthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
||||||
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
|
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
|
||||||
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
|
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
|
||||||
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||||
@@ -132,16 +132,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
usageCache := service.NewUsageCache()
|
usageCache := service.NewUsageCache()
|
||||||
identityCache := repository.NewIdentityCache(redisClient)
|
identityCache := repository.NewIdentityCache(redisClient)
|
||||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache)
|
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache)
|
||||||
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
|
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oauthRefreshAPI)
|
||||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||||
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
|
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI)
|
||||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
|
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
|
||||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
|
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
|
||||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||||
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
|
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
|
||||||
rpmCache := repository.NewRPMCache(redisClient)
|
rpmCache := repository.NewRPMCache(redisClient)
|
||||||
|
groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
|
||||||
|
groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
|
||||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
|
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
|
||||||
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
|
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
|
||||||
dataManagementService := service.NewDataManagementService()
|
dataManagementService := service.NewDataManagementService()
|
||||||
@@ -166,10 +168,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
billingService := service.NewBillingService(configConfig, pricingService)
|
billingService := service.NewBillingService(configConfig, pricingService)
|
||||||
identityService := service.NewIdentityService(identityCache)
|
identityService := service.NewIdentityService(identityCache)
|
||||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||||
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI)
|
||||||
digestSessionStore := service.NewDigestSessionStore()
|
digestSessionStore := service.NewDigestSessionStore()
|
||||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
|
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
|
||||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI)
|
||||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||||
@@ -232,7 +234,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
||||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||||
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
|
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
|
||||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository)
|
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI)
|
||||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
|
|||||||
adminSvc := newStubAdminService()
|
adminSvc := newStubAdminService()
|
||||||
|
|
||||||
userHandler := NewUserHandler(adminSvc, nil)
|
userHandler := NewUserHandler(adminSvc, nil)
|
||||||
groupHandler := NewGroupHandler(adminSvc)
|
groupHandler := NewGroupHandler(adminSvc, nil, nil)
|
||||||
proxyHandler := NewProxyHandler(adminSvc)
|
proxyHandler := NewProxyHandler(adminSvc)
|
||||||
redeemHandler := NewRedeemHandler(adminSvc, nil)
|
redeemHandler := NewRedeemHandler(adminSvc, nil)
|
||||||
|
|
||||||
|
|||||||
@@ -98,12 +98,12 @@ func (h *BackupHandler) CreateBackup(c *gin.Context) {
|
|||||||
expireDays = *req.ExpireDays
|
expireDays = *req.ExpireDays
|
||||||
}
|
}
|
||||||
|
|
||||||
record, err := h.backupService.CreateBackup(c.Request.Context(), "manual", expireDays)
|
record, err := h.backupService.StartBackup(c.Request.Context(), "manual", expireDays)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
response.Success(c, record)
|
response.Accepted(c, record)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *BackupHandler) ListBackups(c *gin.Context) {
|
func (h *BackupHandler) ListBackups(c *gin.Context) {
|
||||||
@@ -196,9 +196,10 @@ func (h *BackupHandler) RestoreBackup(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.backupService.RestoreBackup(c.Request.Context(), backupID); err != nil {
|
record, err := h.backupService.StartRestore(c.Request.Context(), backupID)
|
||||||
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
response.Success(c, gin.H{"restored": true})
|
response.Accepted(c, record)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -512,6 +513,8 @@ func (h *DashboardHandler) GetUserSpendingRanking(c *gin.Context) {
|
|||||||
payload := gin.H{
|
payload := gin.H{
|
||||||
"ranking": ranking.Ranking,
|
"ranking": ranking.Ranking,
|
||||||
"total_actual_cost": ranking.TotalActualCost,
|
"total_actual_cost": ranking.TotalActualCost,
|
||||||
|
"total_requests": ranking.TotalRequests,
|
||||||
|
"total_tokens": ranking.TotalTokens,
|
||||||
"start_date": startTime.Format("2006-01-02"),
|
"start_date": startTime.Format("2006-01-02"),
|
||||||
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||||
}
|
}
|
||||||
@@ -602,3 +605,41 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
|
|||||||
c.Header("X-Snapshot-Cache", "miss")
|
c.Header("X-Snapshot-Cache", "miss")
|
||||||
response.Success(c, payload)
|
response.Success(c, payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUserBreakdown handles getting per-user usage breakdown within a dimension.
|
||||||
|
// GET /api/v1/admin/dashboard/user-breakdown
|
||||||
|
// Query params: start_date, end_date, group_id, model, endpoint, endpoint_type, limit
|
||||||
|
func (h *DashboardHandler) GetUserBreakdown(c *gin.Context) {
|
||||||
|
startTime, endTime := parseTimeRange(c)
|
||||||
|
|
||||||
|
dim := usagestats.UserBreakdownDimension{}
|
||||||
|
if v := c.Query("group_id"); v != "" {
|
||||||
|
if id, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||||
|
dim.GroupID = id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dim.Model = c.Query("model")
|
||||||
|
dim.Endpoint = c.Query("endpoint")
|
||||||
|
dim.EndpointType = c.DefaultQuery("endpoint_type", "inbound")
|
||||||
|
|
||||||
|
limit := 50
|
||||||
|
if v := c.Query("limit"); v != "" {
|
||||||
|
if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= 200 {
|
||||||
|
limit = n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
stats, err := h.dashboardService.GetUserBreakdownStats(
|
||||||
|
c.Request.Context(), startTime, endTime, dim, limit,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, 500, "Failed to get user breakdown stats")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"users": stats,
|
||||||
|
"start_date": startTime.Format("2006-01-02"),
|
||||||
|
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -61,6 +61,8 @@ func (s *dashboardUsageRepoCapture) GetUserSpendingRanking(
|
|||||||
return &usagestats.UserSpendingRankingResponse{
|
return &usagestats.UserSpendingRankingResponse{
|
||||||
Ranking: s.ranking,
|
Ranking: s.ranking,
|
||||||
TotalActualCost: s.rankingTotal,
|
TotalActualCost: s.rankingTotal,
|
||||||
|
TotalRequests: 44,
|
||||||
|
TotalTokens: 1234,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -164,6 +166,8 @@ func TestDashboardUsersRankingLimitAndCache(t *testing.T) {
|
|||||||
require.Equal(t, http.StatusOK, rec.Code)
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
require.Equal(t, 50, repo.rankingLimit)
|
require.Equal(t, 50, repo.rankingLimit)
|
||||||
require.Contains(t, rec.Body.String(), "\"total_actual_cost\":88.8")
|
require.Contains(t, rec.Body.String(), "\"total_actual_cost\":88.8")
|
||||||
|
require.Contains(t, rec.Body.String(), "\"total_requests\":44")
|
||||||
|
require.Contains(t, rec.Body.String(), "\"total_tokens\":1234")
|
||||||
require.Equal(t, "miss", rec.Header().Get("X-Snapshot-Cache"))
|
require.Equal(t, "miss", rec.Header().Get("X-Snapshot-Cache"))
|
||||||
|
|
||||||
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
|
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
|
||||||
|
|||||||
@@ -0,0 +1,203 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- mock repo ---
|
||||||
|
|
||||||
|
type userBreakdownRepoCapture struct {
|
||||||
|
service.UsageLogRepository
|
||||||
|
capturedDim usagestats.UserBreakdownDimension
|
||||||
|
capturedLimit int
|
||||||
|
result []usagestats.UserBreakdownItem
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *userBreakdownRepoCapture) GetUserBreakdownStats(
|
||||||
|
_ context.Context, _, _ time.Time,
|
||||||
|
dim usagestats.UserBreakdownDimension, limit int,
|
||||||
|
) ([]usagestats.UserBreakdownItem, error) {
|
||||||
|
r.capturedDim = dim
|
||||||
|
r.capturedLimit = limit
|
||||||
|
if r.result != nil {
|
||||||
|
return r.result, nil
|
||||||
|
}
|
||||||
|
return []usagestats.UserBreakdownItem{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUserBreakdownRouter(repo *userBreakdownRepoCapture) *gin.Engine {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := service.NewDashboardService(repo, nil, nil, nil)
|
||||||
|
h := NewDashboardHandler(svc, nil)
|
||||||
|
router := gin.New()
|
||||||
|
router.GET("/admin/dashboard/user-breakdown", h.GetUserBreakdown)
|
||||||
|
return router
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- tests ---
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_GroupIDFilter(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=42", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, int64(42), repo.capturedDim.GroupID)
|
||||||
|
require.Empty(t, repo.capturedDim.Model)
|
||||||
|
require.Empty(t, repo.capturedDim.Endpoint)
|
||||||
|
require.Equal(t, 50, repo.capturedLimit) // default limit
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_ModelFilter(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=claude-opus-4-6", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, "claude-opus-4-6", repo.capturedDim.Model)
|
||||||
|
require.Equal(t, int64(0), repo.capturedDim.GroupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_EndpointFilter(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&endpoint=/v1/messages&endpoint_type=upstream", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, "/v1/messages", repo.capturedDim.Endpoint)
|
||||||
|
require.Equal(t, "upstream", repo.capturedDim.EndpointType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_DefaultEndpointType(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&endpoint=/chat", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, "inbound", repo.capturedDim.EndpointType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_CustomLimit(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=test&limit=100", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, 100, repo.capturedLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_LimitClamped(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
// limit > 200 should fall back to default 50
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=test&limit=999", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, 50, repo.capturedLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_ResponseFormat(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{
|
||||||
|
result: []usagestats.UserBreakdownItem{
|
||||||
|
{UserID: 1, Email: "alice@test.com", Requests: 100, TotalTokens: 50000, Cost: 1.5, ActualCost: 1.2},
|
||||||
|
{UserID: 2, Email: "bob@test.com", Requests: 50, TotalTokens: 25000, Cost: 0.8, ActualCost: 0.6},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=1", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
var resp struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Data struct {
|
||||||
|
Users []usagestats.UserBreakdownItem `json:"users"`
|
||||||
|
StartDate string `json:"start_date"`
|
||||||
|
EndDate string `json:"end_date"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
err := json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 0, resp.Code)
|
||||||
|
require.Len(t, resp.Data.Users, 2)
|
||||||
|
require.Equal(t, int64(1), resp.Data.Users[0].UserID)
|
||||||
|
require.Equal(t, "alice@test.com", resp.Data.Users[0].Email)
|
||||||
|
require.Equal(t, int64(100), resp.Data.Users[0].Requests)
|
||||||
|
require.InDelta(t, 1.2, resp.Data.Users[0].ActualCost, 0.001)
|
||||||
|
require.Equal(t, "2026-03-01", resp.Data.StartDate)
|
||||||
|
require.Equal(t, "2026-03-16", resp.Data.EndDate)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_EmptyResult(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=999", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
var resp struct {
|
||||||
|
Data struct {
|
||||||
|
Users []usagestats.UserBreakdownItem `json:"users"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
err := json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Empty(t, resp.Data.Users)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_NoFilters(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, int64(0), repo.capturedDim.GroupID)
|
||||||
|
require.Empty(t, repo.capturedDim.Model)
|
||||||
|
require.Empty(t, repo.capturedDim.Endpoint)
|
||||||
|
}
|
||||||
@@ -1,11 +1,15 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -13,27 +17,80 @@ import (
|
|||||||
|
|
||||||
// GroupHandler handles admin group management
|
// GroupHandler handles admin group management
|
||||||
type GroupHandler struct {
|
type GroupHandler struct {
|
||||||
adminService service.AdminService
|
adminService service.AdminService
|
||||||
|
dashboardService *service.DashboardService
|
||||||
|
groupCapacityService *service.GroupCapacityService
|
||||||
|
}
|
||||||
|
|
||||||
|
// optionalLimitField is a JSON-decodable numeric limit that remembers whether
// its key was present in the payload at all, so callers can tell "key omitted"
// apart from "key sent as null / empty string / number".
type optionalLimitField struct {
	set   bool     // true once UnmarshalJSON has been invoked for this field
	value *float64 // decoded limit; nil when the payload carried null or a blank string
}

// isFiniteLimit reports whether v is neither NaN nor ±Inf.
// Multiplying by zero yields ±0 for every finite value and NaN otherwise, which
// avoids pulling in a new package just for this check.
func isFiniteLimit(v float64) bool {
	return v*0 == 0
}

// UnmarshalJSON accepts, after trimming surrounding whitespace:
//   - null                      -> value cleared
//   - a JSON number             -> value set to that number
//   - a JSON string, blank      -> value cleared
//   - a JSON string, numeric    -> value set to the parsed, finite number
//
// Anything else is rejected with an error. Numeric strings that parse to NaN or
// ±Inf (strconv.ParseFloat accepts spellings such as "NaN" and "Inf") are
// rejected as well, because a non-finite limit cannot be expressed as a JSON
// number and is never a meaningful limit. The field is marked as set as soon
// as the method is entered, even when decoding subsequently fails.
func (f *optionalLimitField) UnmarshalJSON(data []byte) error {
	f.set = true

	trimmed := bytes.TrimSpace(data)
	if bytes.Equal(trimmed, []byte("null")) {
		f.value = nil
		return nil
	}

	var number float64
	if err := json.Unmarshal(trimmed, &number); err == nil {
		f.value = &number
		return nil
	}

	var text string
	if err := json.Unmarshal(trimmed, &text); err != nil {
		// Neither a number nor a string: booleans, objects, arrays, malformed input.
		return fmt.Errorf("invalid limit value: %s", string(trimmed))
	}

	text = strings.TrimSpace(text)
	if text == "" {
		f.value = nil
		return nil
	}

	parsed, err := strconv.ParseFloat(text, 64)
	if err != nil {
		return fmt.Errorf("invalid numeric limit value %q: %w", text, err)
	}
	if !isFiniteLimit(parsed) {
		return fmt.Errorf("invalid numeric limit value %q: must be a finite number", text)
	}
	f.value = &parsed
	return nil
}

// ToServiceInput converts the field into the *float64 form handed to the
// service layer:
//   - key absent from the payload -> nil
//   - key present with a number   -> pointer to that number
//   - key present but cleared     -> pointer to 0
//
// NOTE(review): presumably the service treats nil as "leave unchanged" and 0 as
// "no limit" — confirm against the service input handling.
func (f optionalLimitField) ToServiceInput() *float64 {
	switch {
	case !f.set:
		return nil
	case f.value != nil:
		return f.value
	default:
		zero := 0.0
		return &zero
	}
}
|
||||||
|
|
||||||
// NewGroupHandler creates a new admin group handler
|
// NewGroupHandler creates a new admin group handler
|
||||||
func NewGroupHandler(adminService service.AdminService) *GroupHandler {
|
func NewGroupHandler(adminService service.AdminService, dashboardService *service.DashboardService, groupCapacityService *service.GroupCapacityService) *GroupHandler {
|
||||||
return &GroupHandler{
|
return &GroupHandler{
|
||||||
adminService: adminService,
|
adminService: adminService,
|
||||||
|
dashboardService: dashboardService,
|
||||||
|
groupCapacityService: groupCapacityService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateGroupRequest represents create group request
|
// CreateGroupRequest represents create group request
|
||||||
type CreateGroupRequest struct {
|
type CreateGroupRequest struct {
|
||||||
Name string `json:"name" binding:"required"`
|
Name string `json:"name" binding:"required"`
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||||
RateMultiplier float64 `json:"rate_multiplier"`
|
RateMultiplier float64 `json:"rate_multiplier"`
|
||||||
IsExclusive bool `json:"is_exclusive"`
|
IsExclusive bool `json:"is_exclusive"`
|
||||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
DailyLimitUSD optionalLimitField `json:"daily_limit_usd"`
|
||||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
|
||||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
|
||||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||||
@@ -62,16 +119,16 @@ type CreateGroupRequest struct {
|
|||||||
|
|
||||||
// UpdateGroupRequest represents update group request
|
// UpdateGroupRequest represents update group request
|
||||||
type UpdateGroupRequest struct {
|
type UpdateGroupRequest struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||||
IsExclusive *bool `json:"is_exclusive"`
|
IsExclusive *bool `json:"is_exclusive"`
|
||||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
DailyLimitUSD optionalLimitField `json:"daily_limit_usd"`
|
||||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
|
||||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
|
||||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||||
@@ -191,9 +248,9 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
|||||||
RateMultiplier: req.RateMultiplier,
|
RateMultiplier: req.RateMultiplier,
|
||||||
IsExclusive: req.IsExclusive,
|
IsExclusive: req.IsExclusive,
|
||||||
SubscriptionType: req.SubscriptionType,
|
SubscriptionType: req.SubscriptionType,
|
||||||
DailyLimitUSD: req.DailyLimitUSD,
|
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
|
||||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
|
||||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
|
||||||
ImagePrice1K: req.ImagePrice1K,
|
ImagePrice1K: req.ImagePrice1K,
|
||||||
ImagePrice2K: req.ImagePrice2K,
|
ImagePrice2K: req.ImagePrice2K,
|
||||||
ImagePrice4K: req.ImagePrice4K,
|
ImagePrice4K: req.ImagePrice4K,
|
||||||
@@ -244,9 +301,9 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
|||||||
IsExclusive: req.IsExclusive,
|
IsExclusive: req.IsExclusive,
|
||||||
Status: req.Status,
|
Status: req.Status,
|
||||||
SubscriptionType: req.SubscriptionType,
|
SubscriptionType: req.SubscriptionType,
|
||||||
DailyLimitUSD: req.DailyLimitUSD,
|
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
|
||||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
|
||||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
|
||||||
ImagePrice1K: req.ImagePrice1K,
|
ImagePrice1K: req.ImagePrice1K,
|
||||||
ImagePrice2K: req.ImagePrice2K,
|
ImagePrice2K: req.ImagePrice2K,
|
||||||
ImagePrice4K: req.ImagePrice4K,
|
ImagePrice4K: req.ImagePrice4K,
|
||||||
@@ -311,6 +368,33 @@ func (h *GroupHandler) GetStats(c *gin.Context) {
|
|||||||
_ = groupID // TODO: implement actual stats
|
_ = groupID // TODO: implement actual stats
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUsageSummary returns today's and cumulative cost for all groups.
|
||||||
|
// GET /api/v1/admin/groups/usage-summary?timezone=Asia/Shanghai
|
||||||
|
func (h *GroupHandler) GetUsageSummary(c *gin.Context) {
|
||||||
|
userTZ := c.Query("timezone")
|
||||||
|
now := timezone.NowInUserLocation(userTZ)
|
||||||
|
todayStart := timezone.StartOfDayInUserLocation(now, userTZ)
|
||||||
|
|
||||||
|
results, err := h.dashboardService.GetGroupUsageSummary(c.Request.Context(), todayStart)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, 500, "Failed to get group usage summary")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, results)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCapacitySummary returns aggregated capacity (concurrency/sessions/RPM) for all active groups.
|
||||||
|
// GET /api/v1/admin/groups/capacity-summary
|
||||||
|
func (h *GroupHandler) GetCapacitySummary(c *gin.Context) {
|
||||||
|
results, err := h.groupCapacityService.GetAllGroupCapacity(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, 500, "Failed to get group capacity summary")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, results)
|
||||||
|
}
|
||||||
|
|
||||||
// GetGroupAPIKeys handles getting API keys in a group
|
// GetGroupAPIKeys handles getting API keys in a group
|
||||||
// GET /api/v1/admin/groups/:id/api-keys
|
// GET /api/v1/admin/groups/:id/api-keys
|
||||||
func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
||||||
|
|||||||
@@ -80,6 +80,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
|||||||
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
|
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
|
||||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||||
|
FrontendURL: settings.FrontendURL,
|
||||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||||
TotpEnabled: settings.TotpEnabled,
|
TotpEnabled: settings.TotpEnabled,
|
||||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||||
@@ -137,6 +138,7 @@ type UpdateSettingsRequest struct {
|
|||||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||||
|
FrontendURL string `json:"frontend_url"`
|
||||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||||
|
|
||||||
@@ -326,6 +328,15 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Frontend URL 验证
|
||||||
|
req.FrontendURL = strings.TrimSpace(req.FrontendURL)
|
||||||
|
if req.FrontendURL != "" {
|
||||||
|
if err := config.ValidateAbsoluteHTTPURL(req.FrontendURL); err != nil {
|
||||||
|
response.BadRequest(c, "Frontend URL must be an absolute http(s) URL")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 自定义菜单项验证
|
// 自定义菜单项验证
|
||||||
const (
|
const (
|
||||||
maxCustomMenuItems = 20
|
maxCustomMenuItems = 20
|
||||||
@@ -437,6 +448,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
|
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
|
||||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||||
PasswordResetEnabled: req.PasswordResetEnabled,
|
PasswordResetEnabled: req.PasswordResetEnabled,
|
||||||
|
FrontendURL: req.FrontendURL,
|
||||||
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
||||||
TotpEnabled: req.TotpEnabled,
|
TotpEnabled: req.TotpEnabled,
|
||||||
SMTPHost: req.SMTPHost,
|
SMTPHost: req.SMTPHost,
|
||||||
@@ -531,6 +543,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
|
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
|
||||||
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
|
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
|
||||||
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
|
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
|
||||||
|
FrontendURL: updatedSettings.FrontendURL,
|
||||||
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
|
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
|
||||||
TotpEnabled: updatedSettings.TotpEnabled,
|
TotpEnabled: updatedSettings.TotpEnabled,
|
||||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||||
@@ -614,6 +627,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
|||||||
if before.PasswordResetEnabled != after.PasswordResetEnabled {
|
if before.PasswordResetEnabled != after.PasswordResetEnabled {
|
||||||
changed = append(changed, "password_reset_enabled")
|
changed = append(changed, "password_reset_enabled")
|
||||||
}
|
}
|
||||||
|
if before.FrontendURL != after.FrontendURL {
|
||||||
|
changed = append(changed, "frontend_url")
|
||||||
|
}
|
||||||
if before.TotpEnabled != after.TotpEnabled {
|
if before.TotpEnabled != after.TotpEnabled {
|
||||||
changed = append(changed, "totp_enabled")
|
changed = append(changed, "totp_enabled")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -77,12 +77,13 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
status := c.Query("status")
|
status := c.Query("status")
|
||||||
|
platform := c.Query("platform")
|
||||||
|
|
||||||
// Parse sorting parameters
|
// Parse sorting parameters
|
||||||
sortBy := c.DefaultQuery("sort_by", "created_at")
|
sortBy := c.DefaultQuery("sort_by", "created_at")
|
||||||
sortOrder := c.DefaultQuery("sort_order", "desc")
|
sortOrder := c.DefaultQuery("sort_order", "desc")
|
||||||
|
|
||||||
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder)
|
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, platform, sortBy, sortOrder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -159,8 +159,8 @@ func (h *UsageHandler) List(c *gin.Context) {
|
|||||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Set end time to end of day
|
// Use half-open range [start, end), move to next calendar day start (DST-safe).
|
||||||
t = t.Add(24*time.Hour - time.Nanosecond)
|
t = t.AddDate(0, 0, 1)
|
||||||
endTime = &t
|
endTime = &t
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -285,7 +285,8 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
|||||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
|
// 与 SQL 条件 created_at < end 对齐,使用次日 00:00 作为上边界(DST-safe)。
|
||||||
|
endTime = endTime.AddDate(0, 0, 1)
|
||||||
} else {
|
} else {
|
||||||
period := c.DefaultQuery("period", "today")
|
period := c.DefaultQuery("period", "today")
|
||||||
switch period {
|
switch period {
|
||||||
|
|||||||
@@ -459,9 +459,9 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
frontendBaseURL := strings.TrimSpace(h.cfg.Server.FrontendURL)
|
frontendBaseURL := strings.TrimSpace(h.settingSvc.GetFrontendURL(c.Request.Context()))
|
||||||
if frontendBaseURL == "" {
|
if frontendBaseURL == "" {
|
||||||
slog.Error("server.frontend_url not configured; cannot build password reset link")
|
slog.Error("frontend_url not configured in settings or config; cannot build password reset link")
|
||||||
response.InternalError(c, "Password reset is not configured")
|
response.InternalError(c, "Password reset is not configured")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -135,14 +135,16 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
out := &AdminGroup{
|
out := &AdminGroup{
|
||||||
Group: groupFromServiceBase(g),
|
Group: groupFromServiceBase(g),
|
||||||
ModelRouting: g.ModelRouting,
|
ModelRouting: g.ModelRouting,
|
||||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||||
MCPXMLInject: g.MCPXMLInject,
|
MCPXMLInject: g.MCPXMLInject,
|
||||||
DefaultMappedModel: g.DefaultMappedModel,
|
DefaultMappedModel: g.DefaultMappedModel,
|
||||||
SupportedModelScopes: g.SupportedModelScopes,
|
SupportedModelScopes: g.SupportedModelScopes,
|
||||||
AccountCount: g.AccountCount,
|
AccountCount: g.AccountCount,
|
||||||
SortOrder: g.SortOrder,
|
ActiveAccountCount: g.ActiveAccountCount,
|
||||||
|
RateLimitedAccountCount: g.RateLimitedAccountCount,
|
||||||
|
SortOrder: g.SortOrder,
|
||||||
}
|
}
|
||||||
if len(g.AccountGroups) > 0 {
|
if len(g.AccountGroups) > 0 {
|
||||||
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
||||||
@@ -523,6 +525,8 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
|||||||
Model: l.Model,
|
Model: l.Model,
|
||||||
ServiceTier: l.ServiceTier,
|
ServiceTier: l.ServiceTier,
|
||||||
ReasoningEffort: l.ReasoningEffort,
|
ReasoningEffort: l.ReasoningEffort,
|
||||||
|
InboundEndpoint: l.InboundEndpoint,
|
||||||
|
UpstreamEndpoint: l.UpstreamEndpoint,
|
||||||
GroupID: l.GroupID,
|
GroupID: l.GroupID,
|
||||||
SubscriptionID: l.SubscriptionID,
|
SubscriptionID: l.SubscriptionID,
|
||||||
InputTokens: l.InputTokens,
|
InputTokens: l.InputTokens,
|
||||||
|
|||||||
@@ -76,10 +76,14 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
serviceTier := "priority"
|
serviceTier := "priority"
|
||||||
|
inboundEndpoint := "/v1/chat/completions"
|
||||||
|
upstreamEndpoint := "/v1/responses"
|
||||||
log := &service.UsageLog{
|
log := &service.UsageLog{
|
||||||
RequestID: "req_3",
|
RequestID: "req_3",
|
||||||
Model: "gpt-5.4",
|
Model: "gpt-5.4",
|
||||||
ServiceTier: &serviceTier,
|
ServiceTier: &serviceTier,
|
||||||
|
InboundEndpoint: &inboundEndpoint,
|
||||||
|
UpstreamEndpoint: &upstreamEndpoint,
|
||||||
AccountRateMultiplier: f64Ptr(1.5),
|
AccountRateMultiplier: f64Ptr(1.5),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -88,8 +92,16 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
|
|||||||
|
|
||||||
require.NotNil(t, userDTO.ServiceTier)
|
require.NotNil(t, userDTO.ServiceTier)
|
||||||
require.Equal(t, serviceTier, *userDTO.ServiceTier)
|
require.Equal(t, serviceTier, *userDTO.ServiceTier)
|
||||||
|
require.NotNil(t, userDTO.InboundEndpoint)
|
||||||
|
require.Equal(t, inboundEndpoint, *userDTO.InboundEndpoint)
|
||||||
|
require.NotNil(t, userDTO.UpstreamEndpoint)
|
||||||
|
require.Equal(t, upstreamEndpoint, *userDTO.UpstreamEndpoint)
|
||||||
require.NotNil(t, adminDTO.ServiceTier)
|
require.NotNil(t, adminDTO.ServiceTier)
|
||||||
require.Equal(t, serviceTier, *adminDTO.ServiceTier)
|
require.Equal(t, serviceTier, *adminDTO.ServiceTier)
|
||||||
|
require.NotNil(t, adminDTO.InboundEndpoint)
|
||||||
|
require.Equal(t, inboundEndpoint, *adminDTO.InboundEndpoint)
|
||||||
|
require.NotNil(t, adminDTO.UpstreamEndpoint)
|
||||||
|
require.Equal(t, upstreamEndpoint, *adminDTO.UpstreamEndpoint)
|
||||||
require.NotNil(t, adminDTO.AccountRateMultiplier)
|
require.NotNil(t, adminDTO.AccountRateMultiplier)
|
||||||
require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12)
|
require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ type SystemSettings struct {
|
|||||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||||
|
FrontendURL string `json:"frontend_url"`
|
||||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||||
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
||||||
|
|||||||
@@ -122,9 +122,11 @@ type AdminGroup struct {
|
|||||||
DefaultMappedModel string `json:"default_mapped_model"`
|
DefaultMappedModel string `json:"default_mapped_model"`
|
||||||
|
|
||||||
// 支持的模型系列(仅 antigravity 平台使用)
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||||
AccountCount int64 `json:"account_count,omitempty"`
|
AccountCount int64 `json:"account_count,omitempty"`
|
||||||
|
ActiveAccountCount int64 `json:"active_account_count,omitempty"`
|
||||||
|
RateLimitedAccountCount int64 `json:"rate_limited_account_count,omitempty"`
|
||||||
|
|
||||||
// 分组排序
|
// 分组排序
|
||||||
SortOrder int `json:"sort_order"`
|
SortOrder int `json:"sort_order"`
|
||||||
@@ -334,9 +336,13 @@ type UsageLog struct {
|
|||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
|
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
|
||||||
ServiceTier *string `json:"service_tier,omitempty"`
|
ServiceTier *string `json:"service_tier,omitempty"`
|
||||||
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API).
|
// ReasoningEffort is the request's reasoning effort level.
|
||||||
// nil means not provided / not applicable.
|
// OpenAI: "low"/"medium"/"high"/"xhigh"; Claude: "low"/"medium"/"high"/"max".
|
||||||
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
||||||
|
// InboundEndpoint is the client-facing API endpoint path, e.g. /v1/chat/completions.
|
||||||
|
InboundEndpoint *string `json:"inbound_endpoint,omitempty"`
|
||||||
|
// UpstreamEndpoint is the normalized upstream endpoint path, e.g. /v1/responses.
|
||||||
|
UpstreamEndpoint *string `json:"upstream_endpoint,omitempty"`
|
||||||
|
|
||||||
GroupID *int64 `json:"group_id"`
|
GroupID *int64 `json:"group_id"`
|
||||||
SubscriptionID *int64 `json:"subscription_id"`
|
SubscriptionID *int64 `json:"subscription_id"`
|
||||||
|
|||||||
174
backend/internal/handler/endpoint.go
Normal file
174
backend/internal/handler/endpoint.go
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
// Canonical inbound / upstream endpoint paths.
|
||||||
|
// All normalization and derivation reference this single set
|
||||||
|
// of constants — add new paths HERE when a new API surface
|
||||||
|
// is introduced.
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
const (
|
||||||
|
EndpointMessages = "/v1/messages"
|
||||||
|
EndpointChatCompletions = "/v1/chat/completions"
|
||||||
|
EndpointResponses = "/v1/responses"
|
||||||
|
EndpointGeminiModels = "/v1beta/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
// gin.Context keys used by the middleware and helpers below.
|
||||||
|
const (
|
||||||
|
ctxKeyInboundEndpoint = "_gateway_inbound_endpoint"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
// Normalization functions
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// NormalizeInboundEndpoint maps a raw request path (which may carry
|
||||||
|
// prefixes like /antigravity, /openai, /sora) to its canonical form.
|
||||||
|
//
|
||||||
|
// "/antigravity/v1/messages" → "/v1/messages"
|
||||||
|
// "/v1/chat/completions" → "/v1/chat/completions"
|
||||||
|
// "/openai/v1/responses/foo" → "/v1/responses"
|
||||||
|
// "/v1beta/models/gemini:gen" → "/v1beta/models"
|
||||||
|
func NormalizeInboundEndpoint(path string) string {
|
||||||
|
path = strings.TrimSpace(path)
|
||||||
|
switch {
|
||||||
|
case strings.Contains(path, EndpointChatCompletions):
|
||||||
|
return EndpointChatCompletions
|
||||||
|
case strings.Contains(path, EndpointMessages):
|
||||||
|
return EndpointMessages
|
||||||
|
case strings.Contains(path, EndpointResponses):
|
||||||
|
return EndpointResponses
|
||||||
|
case strings.Contains(path, EndpointGeminiModels):
|
||||||
|
return EndpointGeminiModels
|
||||||
|
default:
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeriveUpstreamEndpoint determines the upstream endpoint from the
|
||||||
|
// account platform and the normalized inbound endpoint.
|
||||||
|
//
|
||||||
|
// Platform-specific rules:
|
||||||
|
// - OpenAI always forwards to /v1/responses (with optional subpath
|
||||||
|
// such as /v1/responses/compact preserved from the raw URL).
|
||||||
|
// - Anthropic → /v1/messages
|
||||||
|
// - Gemini → /v1beta/models
|
||||||
|
// - Sora → /v1/chat/completions
|
||||||
|
// - Antigravity routes may target either Claude or Gemini, so the
|
||||||
|
// inbound endpoint is used to distinguish.
|
||||||
|
func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string {
|
||||||
|
inbound = strings.TrimSpace(inbound)
|
||||||
|
|
||||||
|
switch platform {
|
||||||
|
case service.PlatformOpenAI:
|
||||||
|
// OpenAI forwards everything to the Responses API.
|
||||||
|
// Preserve subresource suffix (e.g. /v1/responses/compact).
|
||||||
|
if suffix := responsesSubpathSuffix(rawRequestPath); suffix != "" {
|
||||||
|
return EndpointResponses + suffix
|
||||||
|
}
|
||||||
|
return EndpointResponses
|
||||||
|
|
||||||
|
case service.PlatformAnthropic:
|
||||||
|
return EndpointMessages
|
||||||
|
|
||||||
|
case service.PlatformGemini:
|
||||||
|
return EndpointGeminiModels
|
||||||
|
|
||||||
|
case service.PlatformSora:
|
||||||
|
return EndpointChatCompletions
|
||||||
|
|
||||||
|
case service.PlatformAntigravity:
|
||||||
|
// Antigravity accounts serve both Claude and Gemini.
|
||||||
|
if inbound == EndpointGeminiModels {
|
||||||
|
return EndpointGeminiModels
|
||||||
|
}
|
||||||
|
return EndpointMessages
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unknown platform — fall back to inbound.
|
||||||
|
return inbound
|
||||||
|
}
|
||||||
|
|
||||||
|
// responsesSubpathSuffix returns the sub-path that follows the last complete
// "/responses" path segment of rawPath, e.g.
// "/openai/v1/responses/compact" → "/compact".
//
// Surrounding whitespace and trailing slashes are ignored. It returns ""
// when rawPath has no "/responses" segment or nothing follows that segment.
//
// Only occurrences that end on a segment boundary (followed by "/" or the end
// of the path) count as the segment. A later segment that merely starts with
// "responses" (for example "/v1/responses/responses_archive") is therefore
// treated as part of the sub-path instead of hiding the real segment.
func responsesSubpathSuffix(rawPath string) string {
	const marker = "/responses"

	trimmed := strings.TrimRight(strings.TrimSpace(rawPath), "/")

	// Walk the occurrences of marker from right to left until one sits on a
	// segment boundary. Occurrences cannot overlap (marker contains a single
	// '/'), so narrowing the search window to idx visits each of them once.
	for end := len(trimmed); ; {
		idx := strings.LastIndex(trimmed[:end], marker)
		if idx < 0 {
			return ""
		}
		suffix := trimmed[idx+len(marker):]
		if suffix == "" || suffix[0] == '/' {
			// Either nothing follows the segment ("") or suffix is the sub-path.
			return suffix
		}
		end = idx
	}
}
|
||||||
|
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
// Middleware
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// InboundEndpointMiddleware normalizes the request path and stores the
|
||||||
|
// canonical inbound endpoint in gin.Context so that every handler in
|
||||||
|
// the chain can read it via GetInboundEndpoint.
|
||||||
|
//
|
||||||
|
// Apply this middleware to all gateway route groups.
|
||||||
|
func InboundEndpointMiddleware() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
path := c.FullPath()
|
||||||
|
if path == "" && c.Request != nil && c.Request.URL != nil {
|
||||||
|
path = c.Request.URL.Path
|
||||||
|
}
|
||||||
|
c.Set(ctxKeyInboundEndpoint, NormalizeInboundEndpoint(path))
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
// Context helpers — used by handlers before building
|
||||||
|
// RecordUsageInput / RecordUsageLongContextInput.
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// GetInboundEndpoint returns the canonical inbound endpoint stored by
|
||||||
|
// InboundEndpointMiddleware. If the middleware did not run (e.g. in
|
||||||
|
// tests), it falls back to normalizing c.FullPath() on the fly.
|
||||||
|
func GetInboundEndpoint(c *gin.Context) string {
|
||||||
|
if v, ok := c.Get(ctxKeyInboundEndpoint); ok {
|
||||||
|
if s, ok := v.(string); ok && s != "" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Fallback: normalize on the fly.
|
||||||
|
path := ""
|
||||||
|
if c != nil {
|
||||||
|
path = c.FullPath()
|
||||||
|
if path == "" && c.Request != nil && c.Request.URL != nil {
|
||||||
|
path = c.Request.URL.Path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return NormalizeInboundEndpoint(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUpstreamEndpoint derives the upstream endpoint from the context
|
||||||
|
// and the account platform. Handlers call this after scheduling an
|
||||||
|
// account, passing account.Platform.
|
||||||
|
func GetUpstreamEndpoint(c *gin.Context, platform string) string {
|
||||||
|
inbound := GetInboundEndpoint(c)
|
||||||
|
rawPath := ""
|
||||||
|
if c != nil && c.Request != nil && c.Request.URL != nil {
|
||||||
|
rawPath = c.Request.URL.Path
|
||||||
|
}
|
||||||
|
return DeriveUpstreamEndpoint(inbound, rawPath, platform)
|
||||||
|
}
|
||||||
159
backend/internal/handler/endpoint_test.go
Normal file
159
backend/internal/handler/endpoint_test.go
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// init puts gin into test mode when the package is initialized.
func init() {
	gin.SetMode(gin.TestMode)
}
|
||||||
|
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
// NormalizeInboundEndpoint
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestNormalizeInboundEndpoint(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
path string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// Direct canonical paths.
|
||||||
|
{"/v1/messages", EndpointMessages},
|
||||||
|
{"/v1/chat/completions", EndpointChatCompletions},
|
||||||
|
{"/v1/responses", EndpointResponses},
|
||||||
|
{"/v1beta/models", EndpointGeminiModels},
|
||||||
|
|
||||||
|
// Prefixed paths (antigravity, openai, sora).
|
||||||
|
{"/antigravity/v1/messages", EndpointMessages},
|
||||||
|
{"/openai/v1/responses", EndpointResponses},
|
||||||
|
{"/openai/v1/responses/compact", EndpointResponses},
|
||||||
|
{"/sora/v1/chat/completions", EndpointChatCompletions},
|
||||||
|
{"/antigravity/v1beta/models/gemini:generateContent", EndpointGeminiModels},
|
||||||
|
|
||||||
|
// Gin route patterns with wildcards.
|
||||||
|
{"/v1beta/models/*modelAction", EndpointGeminiModels},
|
||||||
|
{"/v1/responses/*subpath", EndpointResponses},
|
||||||
|
|
||||||
|
// Unknown path is returned as-is.
|
||||||
|
{"/v1/embeddings", "/v1/embeddings"},
|
||||||
|
{"", ""},
|
||||||
|
{" /v1/messages ", EndpointMessages},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.path, func(t *testing.T) {
|
||||||
|
require.Equal(t, tt.want, NormalizeInboundEndpoint(tt.path))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
// DeriveUpstreamEndpoint
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestDeriveUpstreamEndpoint(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
inbound string
|
||||||
|
rawPath string
|
||||||
|
platform string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// Anthropic.
|
||||||
|
{"anthropic messages", EndpointMessages, "/v1/messages", service.PlatformAnthropic, EndpointMessages},
|
||||||
|
|
||||||
|
// Gemini.
|
||||||
|
{"gemini models", EndpointGeminiModels, "/v1beta/models/gemini:gen", service.PlatformGemini, EndpointGeminiModels},
|
||||||
|
|
||||||
|
// Sora.
|
||||||
|
{"sora completions", EndpointChatCompletions, "/sora/v1/chat/completions", service.PlatformSora, EndpointChatCompletions},
|
||||||
|
|
||||||
|
// OpenAI — always /v1/responses.
|
||||||
|
{"openai responses root", EndpointResponses, "/v1/responses", service.PlatformOpenAI, EndpointResponses},
|
||||||
|
{"openai responses compact", EndpointResponses, "/openai/v1/responses/compact", service.PlatformOpenAI, "/v1/responses/compact"},
|
||||||
|
{"openai responses nested", EndpointResponses, "/openai/v1/responses/compact/detail", service.PlatformOpenAI, "/v1/responses/compact/detail"},
|
||||||
|
{"openai from messages", EndpointMessages, "/v1/messages", service.PlatformOpenAI, EndpointResponses},
|
||||||
|
{"openai from completions", EndpointChatCompletions, "/v1/chat/completions", service.PlatformOpenAI, EndpointResponses},
|
||||||
|
|
||||||
|
// Antigravity — uses inbound to pick Claude vs Gemini upstream.
|
||||||
|
{"antigravity claude", EndpointMessages, "/antigravity/v1/messages", service.PlatformAntigravity, EndpointMessages},
|
||||||
|
{"antigravity gemini", EndpointGeminiModels, "/antigravity/v1beta/models", service.PlatformAntigravity, EndpointGeminiModels},
|
||||||
|
|
||||||
|
// Unknown platform — passthrough.
|
||||||
|
{"unknown platform", "/v1/embeddings", "/v1/embeddings", "unknown", "/v1/embeddings"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
require.Equal(t, tt.want, DeriveUpstreamEndpoint(tt.inbound, tt.rawPath, tt.platform))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
// responsesSubpathSuffix
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestResponsesSubpathSuffix(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
raw string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"/v1/responses", ""},
|
||||||
|
{"/v1/responses/", ""},
|
||||||
|
{"/v1/responses/compact", "/compact"},
|
||||||
|
{"/openai/v1/responses/compact/detail", "/compact/detail"},
|
||||||
|
{"/v1/messages", ""},
|
||||||
|
{"", ""},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.raw, func(t *testing.T) {
|
||||||
|
require.Equal(t, tt.want, responsesSubpathSuffix(tt.raw))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
// InboundEndpointMiddleware + context helpers
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestInboundEndpointMiddleware(t *testing.T) {
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(InboundEndpointMiddleware())
|
||||||
|
|
||||||
|
var captured string
|
||||||
|
router.POST("/v1/messages", func(c *gin.Context) {
|
||||||
|
captured = GetInboundEndpoint(c)
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, EndpointMessages, captured)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetInboundEndpoint_FallbackWithoutMiddleware(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/antigravity/v1/messages", nil)
|
||||||
|
|
||||||
|
// Middleware did not run — fallback to normalizing c.Request.URL.Path.
|
||||||
|
got := GetInboundEndpoint(c)
|
||||||
|
require.Equal(t, EndpointMessages, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUpstreamEndpoint_FullFlow(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses/compact", nil)
|
||||||
|
|
||||||
|
// Simulate middleware.
|
||||||
|
c.Set(ctxKeyInboundEndpoint, NormalizeInboundEndpoint(c.Request.URL.Path))
|
||||||
|
|
||||||
|
got := GetUpstreamEndpoint(c, service.PlatformOpenAI)
|
||||||
|
require.Equal(t, "/v1/responses/compact", got)
|
||||||
|
}
|
||||||
@@ -391,6 +391,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
if fs.SwitchCount > 0 {
|
if fs.SwitchCount > 0 {
|
||||||
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||||
}
|
}
|
||||||
|
// 记录 Forward 前已写入字节数,Forward 后若增加则说明 SSE 内容已发,禁止 failover
|
||||||
|
writerSizeBeforeForward := c.Writer.Size()
|
||||||
if account.Platform == service.PlatformAntigravity {
|
if account.Platform == service.PlatformAntigravity {
|
||||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
|
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
|
||||||
} else {
|
} else {
|
||||||
@@ -402,6 +404,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
|
// 流式内容已写入客户端,无法撤销,禁止 failover 以防止流拼接腐化
|
||||||
|
if c.Writer.Size() != writerSizeBeforeForward {
|
||||||
|
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true)
|
||||||
|
return
|
||||||
|
}
|
||||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||||
switch action {
|
switch action {
|
||||||
case FailoverContinue:
|
case FailoverContinue:
|
||||||
@@ -435,6 +442,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
clientIP := ip.GetClientIP(c)
|
clientIP := ip.GetClientIP(c)
|
||||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||||
|
inboundEndpoint := GetInboundEndpoint(c)
|
||||||
|
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||||
|
|
||||||
|
if result.ReasoningEffort == nil {
|
||||||
|
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
|
||||||
|
}
|
||||||
|
|
||||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
@@ -444,6 +457,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
|
InboundEndpoint: inboundEndpoint,
|
||||||
|
UpstreamEndpoint: upstreamEndpoint,
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
RequestPayloadHash: requestPayloadHash,
|
RequestPayloadHash: requestPayloadHash,
|
||||||
@@ -637,6 +652,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
if fs.SwitchCount > 0 {
|
if fs.SwitchCount > 0 {
|
||||||
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||||
}
|
}
|
||||||
|
// 记录 Forward 前已写入字节数,Forward 后若增加则说明 SSE 内容已发,禁止 failover
|
||||||
|
writerSizeBeforeForward := c.Writer.Size()
|
||||||
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||||
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
|
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
|
||||||
} else {
|
} else {
|
||||||
@@ -706,6 +723,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
|
// 流式内容已写入客户端,无法撤销,禁止 failover 以防止流拼接腐化
|
||||||
|
if c.Writer.Size() != writerSizeBeforeForward {
|
||||||
|
h.handleFailoverExhausted(c, failoverErr, account.Platform, true)
|
||||||
|
return
|
||||||
|
}
|
||||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||||
switch action {
|
switch action {
|
||||||
case FailoverContinue:
|
case FailoverContinue:
|
||||||
@@ -739,6 +761,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
clientIP := ip.GetClientIP(c)
|
clientIP := ip.GetClientIP(c)
|
||||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||||
|
inboundEndpoint := GetInboundEndpoint(c)
|
||||||
|
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||||
|
|
||||||
|
if result.ReasoningEffort == nil {
|
||||||
|
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
|
||||||
|
}
|
||||||
|
|
||||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
@@ -748,6 +776,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
User: currentAPIKey.User,
|
User: currentAPIKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: currentSubscription,
|
Subscription: currentSubscription,
|
||||||
|
InboundEndpoint: inboundEndpoint,
|
||||||
|
UpstreamEndpoint: upstreamEndpoint,
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
RequestPayloadHash: requestPayloadHash,
|
RequestPayloadHash: requestPayloadHash,
|
||||||
@@ -913,7 +943,7 @@ func (h *GatewayHandler) parseUsageDateRange(c *gin.Context) (time.Time, time.Ti
|
|||||||
}
|
}
|
||||||
if s := c.Query("end_date"); s != "" {
|
if s := c.Query("end_date"); s != "" {
|
||||||
if t, err := timezone.ParseInLocation("2006-01-02", s); err == nil {
|
if t, err := timezone.ParseInLocation("2006-01-02", s); err == nil {
|
||||||
endTime = t.Add(24*time.Hour - time.Second) // end of day
|
endTime = t.AddDate(0, 0, 1) // half-open range upper bound
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return startTime, endTime
|
return startTime, endTime
|
||||||
|
|||||||
122
backend/internal/handler/gateway_handler_stream_failover_test.go
Normal file
122
backend/internal/handler/gateway_handler_stream_failover_test.go
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// partialMessageStartSSE simulates the first batch of SSE events already written by handleStreamingResponse.
|
||||||
|
const partialMessageStartSSE = "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_01\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-5\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":10,\"output_tokens\":1}}}\n\n" +
|
||||||
|
"event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"
|
||||||
|
|
||||||
|
// TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten 验证:
|
||||||
|
// 当 Forward 在返回 UpstreamFailoverError 前已向客户端写入 SSE 内容时,
|
||||||
|
// 故障转移保护逻辑必须终止循环并发送 SSE 错误事件,而不是进行下一次 Forward。
|
||||||
|
// 具体验证:
|
||||||
|
// 1. c.Writer.Size() 检测条件正确触发(字节数已增加)
|
||||||
|
// 2. handleFailoverExhausted 以 streamStarted=true 调用后,响应体以 SSE 错误事件结尾
|
||||||
|
// 3. 响应体中只出现一个 message_start,不存在第二个(防止流拼接腐化)
|
||||||
|
func TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
|
||||||
|
// 步骤 1:记录 Forward 前的 writer size(模拟 writerSizeBeforeForward := c.Writer.Size())
|
||||||
|
sizeBeforeForward := c.Writer.Size()
|
||||||
|
require.Equal(t, -1, sizeBeforeForward, "gin writer 初始 Size 应为 -1(未写入任何字节)")
|
||||||
|
|
||||||
|
// 步骤 2:模拟 Forward 已向客户端写入部分 SSE 内容(message_start + content_block_start)
|
||||||
|
_, err := c.Writer.Write([]byte(partialMessageStartSSE))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 步骤 3:验证守卫条件成立(c.Writer.Size() != sizeBeforeForward)
|
||||||
|
require.NotEqual(t, sizeBeforeForward, c.Writer.Size(),
|
||||||
|
"写入 SSE 内容后 writer size 必须增加,守卫条件应为 true")
|
||||||
|
|
||||||
|
// 步骤 4:模拟 UpstreamFailoverError(上游在流中途返回 403)
|
||||||
|
failoverErr := &service.UpstreamFailoverError{
|
||||||
|
StatusCode: http.StatusForbidden,
|
||||||
|
ResponseBody: []byte(`{"error":{"type":"permission_error","message":"forbidden"}}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 步骤 5:守卫触发 → 调用 handleFailoverExhausted,streamStarted=true
|
||||||
|
h := &GatewayHandler{}
|
||||||
|
h.handleFailoverExhausted(c, failoverErr, service.PlatformAnthropic, true)
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
|
||||||
|
// 断言 A:响应体中包含最初写入的 message_start SSE 事件行
|
||||||
|
require.Contains(t, body, "event: message_start", "响应体应包含已写入的 message_start SSE 事件")
|
||||||
|
|
||||||
|
// 断言 B:响应体以 SSE 错误事件结尾(data: {"type":"error",...}\n\n)
|
||||||
|
require.True(t, strings.HasSuffix(strings.TrimRight(body, "\n"), "}"),
|
||||||
|
"响应体应以 JSON 对象结尾(SSE error event 的 data 字段)")
|
||||||
|
require.Contains(t, body, `"type":"error"`, "响应体末尾必须包含 SSE 错误事件")
|
||||||
|
|
||||||
|
// 断言 C:SSE event 行 "event: message_start" 只出现一次(防止双 message_start 拼接腐化)
|
||||||
|
firstIdx := strings.Index(body, "event: message_start")
|
||||||
|
lastIdx := strings.LastIndex(body, "event: message_start")
|
||||||
|
assert.Equal(t, firstIdx, lastIdx,
|
||||||
|
"响应体中 'event: message_start' 必须只出现一次,不得因 failover 拼接导致两次")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten 与上述测试相同,
|
||||||
|
// 验证 Gemini 路径使用 service.PlatformGemini(而非 account.Platform)时行为一致。
|
||||||
|
func TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.0-flash:streamGenerateContent", nil)
|
||||||
|
|
||||||
|
sizeBeforeForward := c.Writer.Size()
|
||||||
|
|
||||||
|
_, err := c.Writer.Write([]byte(partialMessageStartSSE))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NotEqual(t, sizeBeforeForward, c.Writer.Size())
|
||||||
|
|
||||||
|
failoverErr := &service.UpstreamFailoverError{
|
||||||
|
StatusCode: http.StatusForbidden,
|
||||||
|
}
|
||||||
|
|
||||||
|
h := &GatewayHandler{}
|
||||||
|
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true)
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
|
||||||
|
require.Contains(t, body, "event: message_start")
|
||||||
|
require.Contains(t, body, `"type":"error"`)
|
||||||
|
|
||||||
|
firstIdx := strings.Index(body, "event: message_start")
|
||||||
|
lastIdx := strings.LastIndex(body, "event: message_start")
|
||||||
|
assert.Equal(t, firstIdx, lastIdx, "Gemini 路径不得出现双 message_start")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered 验证反向场景:
|
||||||
|
// 当 Forward 返回 UpstreamFailoverError 时若未向客户端写入任何 SSE 内容,
|
||||||
|
// 守卫条件(c.Writer.Size() != sizeBeforeForward)为 false,不应中止 failover。
|
||||||
|
func TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
|
||||||
|
// 模拟 writerSizeBeforeForward:初始为 -1
|
||||||
|
sizeBeforeForward := c.Writer.Size()
|
||||||
|
|
||||||
|
// Forward 未写入任何字节直接返回错误(例如 401 发生在连接建立前)
|
||||||
|
// c.Writer.Size() 仍为 -1
|
||||||
|
|
||||||
|
// 守卫条件:sizeBeforeForward == c.Writer.Size() → 不触发
|
||||||
|
guardTriggered := c.Writer.Size() != sizeBeforeForward
|
||||||
|
require.False(t, guardTriggered,
|
||||||
|
"未写入任何字节时,守卫条件必须为 false,应允许正常 failover 继续")
|
||||||
|
}
|
||||||
@@ -76,7 +76,7 @@ func (f *fakeGroupRepo) ListActiveByPlatform(context.Context, string) ([]service
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil }
|
func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil }
|
||||||
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, error) { return 0, nil }
|
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, int64, error) { return 0, 0, nil }
|
||||||
func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
|
func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -136,7 +136,7 @@ func validClaudeCodeBodyJSON() []byte {
|
|||||||
return []byte(`{
|
return []byte(`{
|
||||||
"model":"claude-3-5-sonnet-20241022",
|
"model":"claude-3-5-sonnet-20241022",
|
||||||
"system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}],
|
"system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}],
|
||||||
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"}
|
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}
|
||||||
}`)
|
}`)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -190,7 +190,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing
|
|||||||
System: []any{
|
System: []any{
|
||||||
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
||||||
},
|
},
|
||||||
MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123",
|
MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
|
||||||
}
|
}
|
||||||
|
|
||||||
// body 非法 JSON,如果函数复用 parsedReq 成功则仍应判定为 Claude Code。
|
// body 非法 JSON,如果函数复用 parsedReq 成功则仍应判定为 Claude Code。
|
||||||
@@ -209,7 +209,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing
|
|||||||
"system": []any{
|
"system": []any{
|
||||||
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
||||||
},
|
},
|
||||||
"metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"},
|
"metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"},
|
||||||
})
|
})
|
||||||
|
|
||||||
SetClaudeCodeClientContext(c, []byte(`{invalid`), nil)
|
SetClaudeCodeClientContext(c, []byte(`{invalid`), nil)
|
||||||
|
|||||||
@@ -504,6 +504,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
|
|
||||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||||
|
inboundEndpoint := GetInboundEndpoint(c)
|
||||||
|
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
@@ -511,6 +513,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
|
InboundEndpoint: inboundEndpoint,
|
||||||
|
UpstreamEndpoint: upstreamEndpoint,
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
RequestPayloadHash: requestPayloadHash,
|
RequestPayloadHash: requestPayloadHash,
|
||||||
|
|||||||
@@ -256,14 +256,16 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
|
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
APIKey: apiKey,
|
APIKey: apiKey,
|
||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: userAgent,
|
InboundEndpoint: GetInboundEndpoint(c),
|
||||||
IPAddress: clientIP,
|
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||||
APIKeyService: h.apiKeyService,
|
UserAgent: userAgent,
|
||||||
|
IPAddress: clientIP,
|
||||||
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
zap.String("component", "handler.openai_gateway.chat_completions"),
|
zap.String("component", "handler.openai_gateway.chat_completions"),
|
||||||
|
|||||||
@@ -0,0 +1,56 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestOpenAIUpstreamEndpoint_ViaGetUpstreamEndpoint verifies that the
|
||||||
|
// unified GetUpstreamEndpoint helper produces the same results as the
|
||||||
|
// former normalizedOpenAIUpstreamEndpoint for OpenAI platform requests.
|
||||||
|
func TestOpenAIUpstreamEndpoint_ViaGetUpstreamEndpoint(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
path string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "responses root maps to responses upstream",
|
||||||
|
path: "/v1/responses",
|
||||||
|
want: EndpointResponses,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "responses compact keeps compact suffix",
|
||||||
|
path: "/openai/v1/responses/compact",
|
||||||
|
want: "/v1/responses/compact",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "responses nested suffix preserved",
|
||||||
|
path: "/openai/v1/responses/compact/detail",
|
||||||
|
want: "/v1/responses/compact/detail",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non responses path uses platform fallback",
|
||||||
|
path: "/v1/messages",
|
||||||
|
want: EndpointResponses,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, tt.path, nil)
|
||||||
|
|
||||||
|
got := GetUpstreamEndpoint(c, service.PlatformOpenAI)
|
||||||
|
require.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -362,6 +362,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
|
InboundEndpoint: GetInboundEndpoint(c),
|
||||||
|
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
RequestPayloadHash: requestPayloadHash,
|
RequestPayloadHash: requestPayloadHash,
|
||||||
@@ -738,6 +740,8 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
|||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
|
InboundEndpoint: GetInboundEndpoint(c),
|
||||||
|
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
RequestPayloadHash: requestPayloadHash,
|
RequestPayloadHash: requestPayloadHash,
|
||||||
@@ -1235,6 +1239,8 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
|||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
|
InboundEndpoint: GetInboundEndpoint(c),
|
||||||
|
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
|
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
|
||||||
|
|||||||
@@ -26,6 +26,22 @@ const (
|
|||||||
opsStreamKey = "ops_stream"
|
opsStreamKey = "ops_stream"
|
||||||
opsRequestBodyKey = "ops_request_body"
|
opsRequestBodyKey = "ops_request_body"
|
||||||
opsAccountIDKey = "ops_account_id"
|
opsAccountIDKey = "ops_account_id"
|
||||||
|
|
||||||
|
// 错误过滤匹配常量 — shouldSkipOpsErrorLog 和错误分类共用
|
||||||
|
opsErrContextCanceled = "context canceled"
|
||||||
|
opsErrNoAvailableAccounts = "no available accounts"
|
||||||
|
opsErrInvalidAPIKey = "invalid_api_key"
|
||||||
|
opsErrAPIKeyRequired = "api_key_required"
|
||||||
|
opsErrInsufficientBalance = "insufficient balance"
|
||||||
|
opsErrInsufficientAccountBalance = "insufficient account balance"
|
||||||
|
opsErrInsufficientQuota = "insufficient_quota"
|
||||||
|
|
||||||
|
// 上游错误码常量 — 错误分类 (normalizeOpsErrorType / classifyOpsPhase / classifyOpsIsBusinessLimited)
|
||||||
|
opsCodeInsufficientBalance = "INSUFFICIENT_BALANCE"
|
||||||
|
opsCodeUsageLimitExceeded = "USAGE_LIMIT_EXCEEDED"
|
||||||
|
opsCodeSubscriptionNotFound = "SUBSCRIPTION_NOT_FOUND"
|
||||||
|
opsCodeSubscriptionInvalid = "SUBSCRIPTION_INVALID"
|
||||||
|
opsCodeUserInactive = "USER_INACTIVE"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -1024,9 +1040,9 @@ func normalizeOpsErrorType(errType string, code string) string {
|
|||||||
return errType
|
return errType
|
||||||
}
|
}
|
||||||
switch strings.TrimSpace(code) {
|
switch strings.TrimSpace(code) {
|
||||||
case "INSUFFICIENT_BALANCE":
|
case opsCodeInsufficientBalance:
|
||||||
return "billing_error"
|
return "billing_error"
|
||||||
case "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
|
case opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
|
||||||
return "subscription_error"
|
return "subscription_error"
|
||||||
default:
|
default:
|
||||||
return "api_error"
|
return "api_error"
|
||||||
@@ -1038,7 +1054,7 @@ func classifyOpsPhase(errType, message, code string) string {
|
|||||||
// Standardized phases: request|auth|routing|upstream|network|internal
|
// Standardized phases: request|auth|routing|upstream|network|internal
|
||||||
// Map billing/concurrency/response => request; scheduling => routing.
|
// Map billing/concurrency/response => request; scheduling => routing.
|
||||||
switch strings.TrimSpace(code) {
|
switch strings.TrimSpace(code) {
|
||||||
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
|
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
|
||||||
return "request"
|
return "request"
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1057,7 +1073,7 @@ func classifyOpsPhase(errType, message, code string) string {
|
|||||||
case "upstream_error", "overloaded_error":
|
case "upstream_error", "overloaded_error":
|
||||||
return "upstream"
|
return "upstream"
|
||||||
case "api_error":
|
case "api_error":
|
||||||
if strings.Contains(msg, "no available accounts") {
|
if strings.Contains(msg, opsErrNoAvailableAccounts) {
|
||||||
return "routing"
|
return "routing"
|
||||||
}
|
}
|
||||||
return "internal"
|
return "internal"
|
||||||
@@ -1103,7 +1119,7 @@ func classifyOpsIsRetryable(errType string, statusCode int) bool {
|
|||||||
|
|
||||||
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
|
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
|
||||||
switch strings.TrimSpace(code) {
|
switch strings.TrimSpace(code) {
|
||||||
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID", "USER_INACTIVE":
|
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid, opsCodeUserInactive:
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if phase == "billing" || phase == "concurrency" {
|
if phase == "billing" || phase == "concurrency" {
|
||||||
@@ -1197,21 +1213,30 @@ func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message
|
|||||||
|
|
||||||
// Check if context canceled errors should be ignored (client disconnects)
|
// Check if context canceled errors should be ignored (client disconnects)
|
||||||
if settings.IgnoreContextCanceled {
|
if settings.IgnoreContextCanceled {
|
||||||
if strings.Contains(msgLower, "context canceled") || strings.Contains(bodyLower, "context canceled") {
|
if strings.Contains(msgLower, opsErrContextCanceled) || strings.Contains(bodyLower, opsErrContextCanceled) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if "no available accounts" errors should be ignored
|
// Check if "no available accounts" errors should be ignored
|
||||||
if settings.IgnoreNoAvailableAccounts {
|
if settings.IgnoreNoAvailableAccounts {
|
||||||
if strings.Contains(msgLower, "no available accounts") || strings.Contains(bodyLower, "no available accounts") {
|
if strings.Contains(msgLower, opsErrNoAvailableAccounts) || strings.Contains(bodyLower, opsErrNoAvailableAccounts) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if invalid/missing API key errors should be ignored (user misconfiguration)
|
// Check if invalid/missing API key errors should be ignored (user misconfiguration)
|
||||||
if settings.IgnoreInvalidApiKeyErrors {
|
if settings.IgnoreInvalidApiKeyErrors {
|
||||||
if strings.Contains(bodyLower, "invalid_api_key") || strings.Contains(bodyLower, "api_key_required") {
|
if strings.Contains(bodyLower, opsErrInvalidAPIKey) || strings.Contains(bodyLower, opsErrAPIKeyRequired) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if insufficient balance errors should be ignored
|
||||||
|
if settings.IgnoreInsufficientBalanceErrors {
|
||||||
|
if strings.Contains(bodyLower, opsErrInsufficientBalance) || strings.Contains(bodyLower, opsErrInsufficientAccountBalance) ||
|
||||||
|
strings.Contains(bodyLower, opsErrInsufficientQuota) ||
|
||||||
|
strings.Contains(msgLower, opsErrInsufficientBalance) || strings.Contains(msgLower, opsErrInsufficientAccountBalance) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -400,6 +400,8 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
clientIP := ip.GetClientIP(c)
|
clientIP := ip.GetClientIP(c)
|
||||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||||
|
inboundEndpoint := GetInboundEndpoint(c)
|
||||||
|
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||||
|
|
||||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
@@ -409,6 +411,8 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
|
InboundEndpoint: inboundEndpoint,
|
||||||
|
UpstreamEndpoint: upstreamEndpoint,
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
RequestPayloadHash: requestPayloadHash,
|
RequestPayloadHash: requestPayloadHash,
|
||||||
|
|||||||
@@ -273,8 +273,8 @@ func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform strin
|
|||||||
func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
|
func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
|
||||||
return 0, nil
|
return 0, 0, nil
|
||||||
}
|
}
|
||||||
func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
@@ -334,9 +334,23 @@ func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTi
|
|||||||
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stubUsageLogRepo) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||||
|
return []usagestats.EndpointStat{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||||
|
return []usagestats.EndpointStat{}, nil
|
||||||
|
}
|
||||||
func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -114,8 +114,8 @@ func (h *UsageHandler) List(c *gin.Context) {
|
|||||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Set end time to end of day
|
// Use half-open range [start, end), move to next calendar day start (DST-safe).
|
||||||
t = t.Add(24*time.Hour - time.Nanosecond)
|
t = t.AddDate(0, 0, 1)
|
||||||
endTime = &t
|
endTime = &t
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -227,8 +227,8 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
|||||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 设置结束时间为当天结束
|
// 与 SQL 条件 created_at < end 对齐,使用次日 00:00 作为上边界(DST-safe)。
|
||||||
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
|
endTime = endTime.AddDate(0, 0, 1)
|
||||||
} else {
|
} else {
|
||||||
// 使用 period 参数
|
// 使用 period 参数
|
||||||
period := c.DefaultQuery("period", "today")
|
period := c.DefaultQuery("period", "today")
|
||||||
|
|||||||
@@ -124,10 +124,68 @@ type IneligibleTier struct {
|
|||||||
type LoadCodeAssistResponse struct {
|
type LoadCodeAssistResponse struct {
|
||||||
CloudAICompanionProject string `json:"cloudaicompanionProject"`
|
CloudAICompanionProject string `json:"cloudaicompanionProject"`
|
||||||
CurrentTier *TierInfo `json:"currentTier,omitempty"`
|
CurrentTier *TierInfo `json:"currentTier,omitempty"`
|
||||||
PaidTier *TierInfo `json:"paidTier,omitempty"`
|
PaidTier *PaidTierInfo `json:"paidTier,omitempty"`
|
||||||
IneligibleTiers []*IneligibleTier `json:"ineligibleTiers,omitempty"`
|
IneligibleTiers []*IneligibleTier `json:"ineligibleTiers,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PaidTierInfo 付费等级信息,包含 AI Credits 余额。
|
||||||
|
type PaidTierInfo struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
AvailableCredits []AvailableCredit `json:"availableCredits,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON 兼容 paidTier 既可能是字符串也可能是对象的情况。
|
||||||
|
func (p *PaidTierInfo) UnmarshalJSON(data []byte) error {
|
||||||
|
data = bytes.TrimSpace(data)
|
||||||
|
if len(data) == 0 || string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if data[0] == '"' {
|
||||||
|
var id string
|
||||||
|
if err := json.Unmarshal(data, &id); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
p.ID = id
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
type alias PaidTierInfo
|
||||||
|
var raw alias
|
||||||
|
if err := json.Unmarshal(data, &raw); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*p = PaidTierInfo(raw)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AvailableCredit is a single AI Credits balance record.
type AvailableCredit struct {
	CreditType                  string `json:"creditType,omitempty"`
	CreditAmount                string `json:"creditAmount,omitempty"`
	MinimumCreditAmountForUsage string `json:"minimumCreditAmountForUsage,omitempty"`
}

// GetAmount parses CreditAmount as a float64.
//
// An empty string yields 0. Parsing is best-effort: the scan error is deliberately
// ignored, so input that cannot be scanned as a float also yields 0.
func (c *AvailableCredit) GetAmount() float64 {
	var amount float64
	if c.CreditAmount != "" {
		_, _ = fmt.Sscanf(c.CreditAmount, "%f", &amount)
	}
	return amount
}

// GetMinimumAmount parses MinimumCreditAmountForUsage as a float64.
//
// An empty string yields 0. Parsing is best-effort: the scan error is deliberately
// ignored, so input that cannot be scanned as a float also yields 0.
func (c *AvailableCredit) GetMinimumAmount() float64 {
	var minimum float64
	if c.MinimumCreditAmountForUsage != "" {
		_, _ = fmt.Sscanf(c.MinimumCreditAmountForUsage, "%f", &minimum)
	}
	return minimum
}
|
||||||
|
|
||||||
// OnboardUserRequest onboardUser 请求
|
// OnboardUserRequest onboardUser 请求
|
||||||
type OnboardUserRequest struct {
|
type OnboardUserRequest struct {
|
||||||
TierID string `json:"tierId"`
|
TierID string `json:"tierId"`
|
||||||
@@ -157,6 +215,14 @@ func (r *LoadCodeAssistResponse) GetTier() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAvailableCredits 返回 paid tier 中的 AI Credits 余额列表。
|
||||||
|
func (r *LoadCodeAssistResponse) GetAvailableCredits() []AvailableCredit {
|
||||||
|
if r.PaidTier == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return r.PaidTier.AvailableCredits
|
||||||
|
}
|
||||||
|
|
||||||
// Client Antigravity API 客户端
|
// Client Antigravity API 客户端
|
||||||
type Client struct {
|
type Client struct {
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
|
|||||||
@@ -190,7 +190,7 @@ func TestTierInfo_UnmarshalJSON_通过JSON嵌套结构(t *testing.T) {
|
|||||||
func TestGetTier_PaidTier优先(t *testing.T) {
|
func TestGetTier_PaidTier优先(t *testing.T) {
|
||||||
resp := &LoadCodeAssistResponse{
|
resp := &LoadCodeAssistResponse{
|
||||||
CurrentTier: &TierInfo{ID: "free-tier"},
|
CurrentTier: &TierInfo{ID: "free-tier"},
|
||||||
PaidTier: &TierInfo{ID: "g1-pro-tier"},
|
PaidTier: &PaidTierInfo{ID: "g1-pro-tier"},
|
||||||
}
|
}
|
||||||
if got := resp.GetTier(); got != "g1-pro-tier" {
|
if got := resp.GetTier(); got != "g1-pro-tier" {
|
||||||
t.Errorf("应返回 paidTier: got %s", got)
|
t.Errorf("应返回 paidTier: got %s", got)
|
||||||
@@ -209,7 +209,7 @@ func TestGetTier_回退到CurrentTier(t *testing.T) {
|
|||||||
func TestGetTier_PaidTier为空ID(t *testing.T) {
|
func TestGetTier_PaidTier为空ID(t *testing.T) {
|
||||||
resp := &LoadCodeAssistResponse{
|
resp := &LoadCodeAssistResponse{
|
||||||
CurrentTier: &TierInfo{ID: "free-tier"},
|
CurrentTier: &TierInfo{ID: "free-tier"},
|
||||||
PaidTier: &TierInfo{ID: ""},
|
PaidTier: &PaidTierInfo{ID: ""},
|
||||||
}
|
}
|
||||||
// paidTier.ID 为空时应回退到 currentTier
|
// paidTier.ID 为空时应回退到 currentTier
|
||||||
if got := resp.GetTier(); got != "free-tier" {
|
if got := resp.GetTier(); got != "free-tier" {
|
||||||
@@ -217,6 +217,32 @@ func TestGetTier_PaidTier为空ID(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetAvailableCredits(t *testing.T) {
|
||||||
|
resp := &LoadCodeAssistResponse{
|
||||||
|
PaidTier: &PaidTierInfo{
|
||||||
|
ID: "g1-pro-tier",
|
||||||
|
AvailableCredits: []AvailableCredit{
|
||||||
|
{
|
||||||
|
CreditType: "GOOGLE_ONE_AI",
|
||||||
|
CreditAmount: "25",
|
||||||
|
MinimumCreditAmountForUsage: "5",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
credits := resp.GetAvailableCredits()
|
||||||
|
if len(credits) != 1 {
|
||||||
|
t.Fatalf("AI Credits 数量不匹配: got %d", len(credits))
|
||||||
|
}
|
||||||
|
if credits[0].GetAmount() != 25 {
|
||||||
|
t.Errorf("CreditAmount 解析不正确: got %v", credits[0].GetAmount())
|
||||||
|
}
|
||||||
|
if credits[0].GetMinimumAmount() != 5 {
|
||||||
|
t.Errorf("MinimumCreditAmountForUsage 解析不正确: got %v", credits[0].GetMinimumAmount())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetTier_两者都为nil(t *testing.T) {
|
func TestGetTier_两者都为nil(t *testing.T) {
|
||||||
resp := &LoadCodeAssistResponse{}
|
resp := &LoadCodeAssistResponse{}
|
||||||
if got := resp.GetTier(); got != "" {
|
if got := resp.GetTier(); got != "" {
|
||||||
|
|||||||
@@ -49,8 +49,8 @@ const (
|
|||||||
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||||
)
|
)
|
||||||
|
|
||||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.4
|
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.5
|
||||||
var defaultUserAgentVersion = "1.20.4"
|
var defaultUserAgentVersion = "1.20.5"
|
||||||
|
|
||||||
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
||||||
var defaultClientSecret = "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"
|
var defaultClientSecret = "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"
|
||||||
|
|||||||
@@ -690,7 +690,7 @@ func TestConstants_值正确(t *testing.T) {
|
|||||||
if RedirectURI != "http://localhost:8085/callback" {
|
if RedirectURI != "http://localhost:8085/callback" {
|
||||||
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
|
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
|
||||||
}
|
}
|
||||||
if GetUserAgent() != "antigravity/1.20.4 windows/amd64" {
|
if GetUserAgent() != "antigravity/1.20.5 windows/amd64" {
|
||||||
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
|
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
|
||||||
}
|
}
|
||||||
if SessionTTL != 30*time.Minute {
|
if SessionTTL != 30*time.Minute {
|
||||||
|
|||||||
@@ -47,6 +47,15 @@ func Created(c *gin.Context, data any) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Accepted 返回异步接受响应 (HTTP 202)
|
||||||
|
func Accepted(c *gin.Context, data any) {
|
||||||
|
c.JSON(http.StatusAccepted, Response{
|
||||||
|
Code: 0,
|
||||||
|
Message: "accepted",
|
||||||
|
Data: data,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Error 返回错误响应
|
// Error 返回错误响应
|
||||||
func Error(c *gin.Context, statusCode int, message string) {
|
func Error(c *gin.Context, statusCode int, message string) {
|
||||||
c.JSON(statusCode, Response{
|
c.JSON(statusCode, Response{
|
||||||
|
|||||||
@@ -81,6 +81,22 @@ type ModelStat struct {
|
|||||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EndpointStat represents usage statistics for a single request endpoint.
|
||||||
|
type EndpointStat struct {
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
Requests int64 `json:"requests"`
|
||||||
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
|
Cost float64 `json:"cost"` // 标准计费
|
||||||
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupUsageSummary represents today's and cumulative cost for a single group.
|
||||||
|
type GroupUsageSummary struct {
|
||||||
|
GroupID int64 `json:"group_id"`
|
||||||
|
TodayCost float64 `json:"today_cost"`
|
||||||
|
TotalCost float64 `json:"total_cost"`
|
||||||
|
}
|
||||||
|
|
||||||
// GroupStat represents usage statistics for a single group
|
// GroupStat represents usage statistics for a single group
|
||||||
type GroupStat struct {
|
type GroupStat struct {
|
||||||
GroupID int64 `json:"group_id"`
|
GroupID int64 `json:"group_id"`
|
||||||
@@ -116,6 +132,26 @@ type UserSpendingRankingItem struct {
|
|||||||
type UserSpendingRankingResponse struct {
|
type UserSpendingRankingResponse struct {
|
||||||
Ranking []UserSpendingRankingItem `json:"ranking"`
|
Ranking []UserSpendingRankingItem `json:"ranking"`
|
||||||
TotalActualCost float64 `json:"total_actual_cost"`
|
TotalActualCost float64 `json:"total_actual_cost"`
|
||||||
|
TotalRequests int64 `json:"total_requests"`
|
||||||
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserBreakdownItem represents per-user usage breakdown within a dimension (group, model, endpoint).
|
||||||
|
type UserBreakdownItem struct {
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
Requests int64 `json:"requests"`
|
||||||
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
|
Cost float64 `json:"cost"` // 标准计费
|
||||||
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserBreakdownDimension specifies the dimension to filter for user breakdown.
|
||||||
|
type UserBreakdownDimension struct {
|
||||||
|
GroupID int64 // filter by group_id (>0 to enable)
|
||||||
|
Model string // filter by model name (non-empty to enable)
|
||||||
|
Endpoint string // filter by endpoint value (non-empty to enable)
|
||||||
|
EndpointType string // "inbound", "upstream", or "path"
|
||||||
}
|
}
|
||||||
|
|
||||||
// APIKeyUsageTrendPoint represents API key usage trend data point
|
// APIKeyUsageTrendPoint represents API key usage trend data point
|
||||||
@@ -179,15 +215,18 @@ type UsageLogFilters struct {
|
|||||||
|
|
||||||
// UsageStats represents usage statistics
|
// UsageStats represents usage statistics
|
||||||
type UsageStats struct {
|
type UsageStats struct {
|
||||||
TotalRequests int64 `json:"total_requests"`
|
TotalRequests int64 `json:"total_requests"`
|
||||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||||
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
||||||
TotalTokens int64 `json:"total_tokens"`
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
TotalCost float64 `json:"total_cost"`
|
TotalCost float64 `json:"total_cost"`
|
||||||
TotalActualCost float64 `json:"total_actual_cost"`
|
TotalActualCost float64 `json:"total_actual_cost"`
|
||||||
TotalAccountCost *float64 `json:"total_account_cost,omitempty"`
|
TotalAccountCost *float64 `json:"total_account_cost,omitempty"`
|
||||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||||
|
Endpoints []EndpointStat `json:"endpoints,omitempty"`
|
||||||
|
UpstreamEndpoints []EndpointStat `json:"upstream_endpoints,omitempty"`
|
||||||
|
EndpointPaths []EndpointStat `json:"endpoint_paths,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// BatchUserUsageStats represents usage stats for a single user
|
// BatchUserUsageStats represents usage stats for a single user
|
||||||
@@ -254,7 +293,9 @@ type AccountUsageSummary struct {
|
|||||||
|
|
||||||
// AccountUsageStatsResponse represents the full usage statistics response for an account
|
// AccountUsageStatsResponse represents the full usage statistics response for an account
|
||||||
type AccountUsageStatsResponse struct {
|
type AccountUsageStatsResponse struct {
|
||||||
History []AccountUsageHistory `json:"history"`
|
History []AccountUsageHistory `json:"history"`
|
||||||
Summary AccountUsageSummary `json:"summary"`
|
Summary AccountUsageSummary `json:"summary"`
|
||||||
Models []ModelStat `json:"models"`
|
Models []ModelStat `json:"models"`
|
||||||
|
Endpoints []EndpointStat `json:"endpoints"`
|
||||||
|
UpstreamEndpoints []EndpointStat `json:"upstream_endpoints"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -57,6 +57,7 @@ func NewS3BackupStoreFactory() service.BackupObjectStoreFactory {
|
|||||||
|
|
||||||
func (s *S3BackupStore) Upload(ctx context.Context, key string, body io.Reader, contentType string) (int64, error) {
|
func (s *S3BackupStore) Upload(ctx context.Context, key string, body io.Reader, contentType string) (int64, error) {
|
||||||
// 读取全部内容以获取大小(S3 PutObject 需要知道内容长度)
|
// 读取全部内容以获取大小(S3 PutObject 需要知道内容长度)
|
||||||
|
// 注意:阿里云 OSS 不兼容 s3manager 分片上传的签名方式,因此使用 PutObject
|
||||||
data, err := io.ReadAll(body)
|
data, err := io.ReadAll(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("read body: %w", err)
|
return 0, fmt.Errorf("read body: %w", err)
|
||||||
|
|||||||
@@ -20,6 +20,11 @@ const (
|
|||||||
billingCacheTTL = 5 * time.Minute
|
billingCacheTTL = 5 * time.Minute
|
||||||
billingCacheJitter = 30 * time.Second
|
billingCacheJitter = 30 * time.Second
|
||||||
rateLimitCacheTTL = 7 * 24 * time.Hour // 7 days matches the longest window
|
rateLimitCacheTTL = 7 * 24 * time.Hour // 7 days matches the longest window
|
||||||
|
|
||||||
|
// Rate limit window durations — must match service.RateLimitWindow* constants.
|
||||||
|
rateLimitWindow5h = 5 * time.Hour
|
||||||
|
rateLimitWindow1d = 24 * time.Hour
|
||||||
|
rateLimitWindow7d = 7 * 24 * time.Hour
|
||||||
)
|
)
|
||||||
|
|
||||||
// jitteredTTL 返回带随机抖动的 TTL,防止缓存雪崩
|
// jitteredTTL 返回带随机抖动的 TTL,防止缓存雪崩
|
||||||
@@ -90,17 +95,40 @@ var (
|
|||||||
return 1
|
return 1
|
||||||
`)
|
`)
|
||||||
|
|
||||||
// updateRateLimitUsageScript atomically increments all three rate limit usage counters.
|
// updateRateLimitUsageScript atomically increments all three rate limit usage counters
|
||||||
// Returns 0 if the key doesn't exist (cache miss), 1 on success.
|
// with window expiration checking. If a window has expired, its usage is reset to cost
|
||||||
|
// (instead of accumulated) and the window timestamp is updated, matching the DB-side
|
||||||
|
// IncrementRateLimitUsage semantics.
|
||||||
|
//
|
||||||
|
// ARGV: [1]=cost, [2]=ttl_seconds, [3]=now_unix, [4]=window_5h_seconds, [5]=window_1d_seconds, [6]=window_7d_seconds
|
||||||
updateRateLimitUsageScript = redis.NewScript(`
|
updateRateLimitUsageScript = redis.NewScript(`
|
||||||
local exists = redis.call('EXISTS', KEYS[1])
|
local exists = redis.call('EXISTS', KEYS[1])
|
||||||
if exists == 0 then
|
if exists == 0 then
|
||||||
return 0
|
return 0
|
||||||
end
|
end
|
||||||
local cost = tonumber(ARGV[1])
|
local cost = tonumber(ARGV[1])
|
||||||
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_5h', cost)
|
local now = tonumber(ARGV[3])
|
||||||
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_1d', cost)
|
local win5h = tonumber(ARGV[4])
|
||||||
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_7d', cost)
|
local win1d = tonumber(ARGV[5])
|
||||||
|
local win7d = tonumber(ARGV[6])
|
||||||
|
|
||||||
|
-- Helper: check if window is expired and update usage + window accordingly
|
||||||
|
-- Returns nothing, modifies the hash in-place.
|
||||||
|
local function update_window(usage_field, window_field, window_duration)
|
||||||
|
local w = tonumber(redis.call('HGET', KEYS[1], window_field) or 0)
|
||||||
|
if w == 0 or (now - w) >= window_duration then
|
||||||
|
-- Window expired or never started: reset usage to cost, start new window
|
||||||
|
redis.call('HSET', KEYS[1], usage_field, tostring(cost))
|
||||||
|
redis.call('HSET', KEYS[1], window_field, tostring(now))
|
||||||
|
else
|
||||||
|
-- Window still valid: accumulate
|
||||||
|
redis.call('HINCRBYFLOAT', KEYS[1], usage_field, cost)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
update_window('usage_5h', 'window_5h', win5h)
|
||||||
|
update_window('usage_1d', 'window_1d', win1d)
|
||||||
|
update_window('usage_7d', 'window_7d', win7d)
|
||||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||||
return 1
|
return 1
|
||||||
`)
|
`)
|
||||||
@@ -280,7 +308,15 @@ func (c *billingCache) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data
|
|||||||
|
|
||||||
func (c *billingCache) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
|
func (c *billingCache) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
|
||||||
key := billingRateLimitKey(keyID)
|
key := billingRateLimitKey(keyID)
|
||||||
_, err := updateRateLimitUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(rateLimitCacheTTL.Seconds())).Result()
|
now := time.Now().Unix()
|
||||||
|
_, err := updateRateLimitUsageScript.Run(ctx, c.rdb, []string{key},
|
||||||
|
cost,
|
||||||
|
int(rateLimitCacheTTL.Seconds()),
|
||||||
|
now,
|
||||||
|
int(rateLimitWindow5h.Seconds()),
|
||||||
|
int(rateLimitWindow1d.Seconds()),
|
||||||
|
int(rateLimitWindow7d.Seconds()),
|
||||||
|
).Result()
|
||||||
if err != nil && !errors.Is(err, redis.Nil) {
|
if err != nil && !errors.Is(err, redis.Nil) {
|
||||||
log.Printf("Warning: update rate limit usage cache failed for api key %d: %v", keyID, err)
|
log.Printf("Warning: update rate limit usage cache failed for api key %d: %v", keyID, err)
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -88,8 +88,9 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
count, _ := r.GetAccountCount(ctx, out.ID)
|
total, active, _ := r.GetAccountCount(ctx, out.ID)
|
||||||
out.AccountCount = count
|
out.AccountCount = total
|
||||||
|
out.ActiveAccountCount = active
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -256,7 +257,10 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
|
|||||||
counts, err := r.loadAccountCounts(ctx, groupIDs)
|
counts, err := r.loadAccountCounts(ctx, groupIDs)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
for i := range outGroups {
|
for i := range outGroups {
|
||||||
outGroups[i].AccountCount = counts[outGroups[i].ID]
|
c := counts[outGroups[i].ID]
|
||||||
|
outGroups[i].AccountCount = c.Total
|
||||||
|
outGroups[i].ActiveAccountCount = c.Active
|
||||||
|
outGroups[i].RateLimitedAccountCount = c.RateLimited
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -283,7 +287,10 @@ func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, erro
|
|||||||
counts, err := r.loadAccountCounts(ctx, groupIDs)
|
counts, err := r.loadAccountCounts(ctx, groupIDs)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
for i := range outGroups {
|
for i := range outGroups {
|
||||||
outGroups[i].AccountCount = counts[outGroups[i].ID]
|
c := counts[outGroups[i].ID]
|
||||||
|
outGroups[i].AccountCount = c.Total
|
||||||
|
outGroups[i].ActiveAccountCount = c.Active
|
||||||
|
outGroups[i].RateLimitedAccountCount = c.RateLimited
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -310,7 +317,10 @@ func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform str
|
|||||||
counts, err := r.loadAccountCounts(ctx, groupIDs)
|
counts, err := r.loadAccountCounts(ctx, groupIDs)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
for i := range outGroups {
|
for i := range outGroups {
|
||||||
outGroups[i].AccountCount = counts[outGroups[i].ID]
|
c := counts[outGroups[i].ID]
|
||||||
|
outGroups[i].AccountCount = c.Total
|
||||||
|
outGroups[i].ActiveAccountCount = c.Active
|
||||||
|
outGroups[i].RateLimitedAccountCount = c.RateLimited
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -369,12 +379,20 @@ func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (total int64, active int64, err error) {
|
||||||
var count int64
|
var rateLimited int64
|
||||||
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil {
|
err = scanSingleRow(ctx, r.sql,
|
||||||
return 0, err
|
`SELECT COUNT(*),
|
||||||
}
|
COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true),
|
||||||
return count, nil
|
COUNT(*) FILTER (WHERE a.status = 'active' AND (
|
||||||
|
a.rate_limit_reset_at > NOW() OR
|
||||||
|
a.overload_until > NOW() OR
|
||||||
|
a.temp_unschedulable_until > NOW()
|
||||||
|
))
|
||||||
|
FROM account_groups ag JOIN accounts a ON a.id = ag.account_id
|
||||||
|
WHERE ag.group_id = $1`,
|
||||||
|
[]any{groupID}, &total, &active, &rateLimited)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||||
@@ -500,15 +518,32 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
|
|||||||
return affectedUserIDs, nil
|
return affectedUserIDs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]int64, err error) {
|
type groupAccountCounts struct {
|
||||||
counts = make(map[int64]int64, len(groupIDs))
|
Total int64
|
||||||
|
Active int64
|
||||||
|
RateLimited int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]groupAccountCounts, err error) {
|
||||||
|
counts = make(map[int64]groupAccountCounts, len(groupIDs))
|
||||||
if len(groupIDs) == 0 {
|
if len(groupIDs) == 0 {
|
||||||
return counts, nil
|
return counts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := r.sql.QueryContext(
|
rows, err := r.sql.QueryContext(
|
||||||
ctx,
|
ctx,
|
||||||
"SELECT group_id, COUNT(*) FROM account_groups WHERE group_id = ANY($1) GROUP BY group_id",
|
`SELECT ag.group_id,
|
||||||
|
COUNT(*) AS total,
|
||||||
|
COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true) AS active,
|
||||||
|
COUNT(*) FILTER (WHERE a.status = 'active' AND (
|
||||||
|
a.rate_limit_reset_at > NOW() OR
|
||||||
|
a.overload_until > NOW() OR
|
||||||
|
a.temp_unschedulable_until > NOW()
|
||||||
|
)) AS rate_limited
|
||||||
|
FROM account_groups ag
|
||||||
|
JOIN accounts a ON a.id = ag.account_id
|
||||||
|
WHERE ag.group_id = ANY($1)
|
||||||
|
GROUP BY ag.group_id`,
|
||||||
pq.Array(groupIDs),
|
pq.Array(groupIDs),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -523,11 +558,11 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6
|
|||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var groupID int64
|
var groupID int64
|
||||||
var count int64
|
var c groupAccountCounts
|
||||||
if err = rows.Scan(&groupID, &count); err != nil {
|
if err = rows.Scan(&groupID, &c.Total, &c.Active, &c.RateLimited); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
counts[groupID] = count
|
counts[groupID] = c
|
||||||
}
|
}
|
||||||
if err = rows.Err(); err != nil {
|
if err = rows.Err(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -603,7 +603,7 @@ func (s *GroupRepoSuite) TestGetAccountCount() {
|
|||||||
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a2, group.ID, 2)
|
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a2, group.ID, 2)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
|
||||||
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
|
count, _, err := s.repo.GetAccountCount(s.ctx, group.ID)
|
||||||
s.Require().NoError(err, "GetAccountCount")
|
s.Require().NoError(err, "GetAccountCount")
|
||||||
s.Require().Equal(int64(2), count)
|
s.Require().Equal(int64(2), count)
|
||||||
}
|
}
|
||||||
@@ -619,7 +619,7 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
|
|||||||
}
|
}
|
||||||
s.Require().NoError(s.repo.Create(s.ctx, group))
|
s.Require().NoError(s.repo.Create(s.ctx, group))
|
||||||
|
|
||||||
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
|
count, _, err := s.repo.GetAccountCount(s.ctx, group.ID)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Zero(count)
|
s.Require().Zero(count)
|
||||||
}
|
}
|
||||||
@@ -651,7 +651,7 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
|
|||||||
s.Require().NoError(err, "DeleteAccountGroupsByGroupID")
|
s.Require().NoError(err, "DeleteAccountGroupsByGroupID")
|
||||||
s.Require().Equal(int64(1), affected, "expected 1 affected row")
|
s.Require().Equal(int64(1), affected, "expected 1 affected row")
|
||||||
|
|
||||||
count, err := s.repo.GetAccountCount(s.ctx, g.ID)
|
count, _, err := s.repo.GetAccountCount(s.ctx, g.ID)
|
||||||
s.Require().NoError(err, "GetAccountCount")
|
s.Require().NoError(err, "GetAccountCount")
|
||||||
s.Require().Equal(int64(0), count, "expected 0 account groups")
|
s.Require().Equal(int64(0), count, "expected 0 account groups")
|
||||||
}
|
}
|
||||||
@@ -692,7 +692,7 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
|
|||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Equal(int64(3), affected)
|
s.Require().Equal(int64(3), affected)
|
||||||
|
|
||||||
count, _ := s.repo.GetAccountCount(s.ctx, g.ID)
|
count, _, _ := s.repo.GetAccountCount(s.ctx, g.ID)
|
||||||
s.Require().Zero(count)
|
s.Require().Zero(count)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -132,7 +132,7 @@ func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, t
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if cmd.AccountQuotaCost > 0 && strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) {
|
if cmd.AccountQuotaCost > 0 && (strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) || strings.EqualFold(cmd.AccountType, service.AccountTypeBedrock)) {
|
||||||
if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil {
|
if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ import (
|
|||||||
gocache "github.com/patrickmn/go-cache"
|
gocache "github.com/patrickmn/go-cache"
|
||||||
)
|
)
|
||||||
|
|
||||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, cache_ttl_overridden, created_at"
|
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at"
|
||||||
|
|
||||||
var usageLogInsertArgTypes = [...]string{
|
var usageLogInsertArgTypes = [...]string{
|
||||||
"bigint",
|
"bigint",
|
||||||
@@ -65,6 +65,8 @@ var usageLogInsertArgTypes = [...]string{
|
|||||||
"text",
|
"text",
|
||||||
"text",
|
"text",
|
||||||
"text",
|
"text",
|
||||||
|
"text",
|
||||||
|
"text",
|
||||||
"boolean",
|
"boolean",
|
||||||
"timestamptz",
|
"timestamptz",
|
||||||
}
|
}
|
||||||
@@ -304,6 +306,8 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
|
|||||||
media_type,
|
media_type,
|
||||||
service_tier,
|
service_tier,
|
||||||
reasoning_effort,
|
reasoning_effort,
|
||||||
|
inbound_endpoint,
|
||||||
|
upstream_endpoint,
|
||||||
cache_ttl_overridden,
|
cache_ttl_overridden,
|
||||||
created_at
|
created_at
|
||||||
) VALUES (
|
) VALUES (
|
||||||
@@ -312,7 +316,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
|
|||||||
$8, $9, $10, $11,
|
$8, $9, $10, $11,
|
||||||
$12, $13,
|
$12, $13,
|
||||||
$14, $15, $16, $17, $18, $19,
|
$14, $15, $16, $17, $18, $19,
|
||||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36
|
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38
|
||||||
)
|
)
|
||||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||||
RETURNING id, created_at
|
RETURNING id, created_at
|
||||||
@@ -732,11 +736,13 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
|||||||
media_type,
|
media_type,
|
||||||
service_tier,
|
service_tier,
|
||||||
reasoning_effort,
|
reasoning_effort,
|
||||||
|
inbound_endpoint,
|
||||||
|
upstream_endpoint,
|
||||||
cache_ttl_overridden,
|
cache_ttl_overridden,
|
||||||
created_at
|
created_at
|
||||||
) AS (VALUES `)
|
) AS (VALUES `)
|
||||||
|
|
||||||
args := make([]any, 0, len(keys)*37)
|
args := make([]any, 0, len(keys)*38)
|
||||||
argPos := 1
|
argPos := 1
|
||||||
for idx, key := range keys {
|
for idx, key := range keys {
|
||||||
if idx > 0 {
|
if idx > 0 {
|
||||||
@@ -799,6 +805,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
|||||||
media_type,
|
media_type,
|
||||||
service_tier,
|
service_tier,
|
||||||
reasoning_effort,
|
reasoning_effort,
|
||||||
|
inbound_endpoint,
|
||||||
|
upstream_endpoint,
|
||||||
cache_ttl_overridden,
|
cache_ttl_overridden,
|
||||||
created_at
|
created_at
|
||||||
)
|
)
|
||||||
@@ -837,6 +845,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
|||||||
media_type,
|
media_type,
|
||||||
service_tier,
|
service_tier,
|
||||||
reasoning_effort,
|
reasoning_effort,
|
||||||
|
inbound_endpoint,
|
||||||
|
upstream_endpoint,
|
||||||
cache_ttl_overridden,
|
cache_ttl_overridden,
|
||||||
created_at
|
created_at
|
||||||
FROM input
|
FROM input
|
||||||
@@ -915,11 +925,13 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
|||||||
media_type,
|
media_type,
|
||||||
service_tier,
|
service_tier,
|
||||||
reasoning_effort,
|
reasoning_effort,
|
||||||
|
inbound_endpoint,
|
||||||
|
upstream_endpoint,
|
||||||
cache_ttl_overridden,
|
cache_ttl_overridden,
|
||||||
created_at
|
created_at
|
||||||
) AS (VALUES `)
|
) AS (VALUES `)
|
||||||
|
|
||||||
args := make([]any, 0, len(preparedList)*36)
|
args := make([]any, 0, len(preparedList)*38)
|
||||||
argPos := 1
|
argPos := 1
|
||||||
for idx, prepared := range preparedList {
|
for idx, prepared := range preparedList {
|
||||||
if idx > 0 {
|
if idx > 0 {
|
||||||
@@ -979,6 +991,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
|||||||
media_type,
|
media_type,
|
||||||
service_tier,
|
service_tier,
|
||||||
reasoning_effort,
|
reasoning_effort,
|
||||||
|
inbound_endpoint,
|
||||||
|
upstream_endpoint,
|
||||||
cache_ttl_overridden,
|
cache_ttl_overridden,
|
||||||
created_at
|
created_at
|
||||||
)
|
)
|
||||||
@@ -1017,6 +1031,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
|||||||
media_type,
|
media_type,
|
||||||
service_tier,
|
service_tier,
|
||||||
reasoning_effort,
|
reasoning_effort,
|
||||||
|
inbound_endpoint,
|
||||||
|
upstream_endpoint,
|
||||||
cache_ttl_overridden,
|
cache_ttl_overridden,
|
||||||
created_at
|
created_at
|
||||||
FROM input
|
FROM input
|
||||||
@@ -1063,6 +1079,8 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
|
|||||||
media_type,
|
media_type,
|
||||||
service_tier,
|
service_tier,
|
||||||
reasoning_effort,
|
reasoning_effort,
|
||||||
|
inbound_endpoint,
|
||||||
|
upstream_endpoint,
|
||||||
cache_ttl_overridden,
|
cache_ttl_overridden,
|
||||||
created_at
|
created_at
|
||||||
) VALUES (
|
) VALUES (
|
||||||
@@ -1071,7 +1089,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
|
|||||||
$8, $9, $10, $11,
|
$8, $9, $10, $11,
|
||||||
$12, $13,
|
$12, $13,
|
||||||
$14, $15, $16, $17, $18, $19,
|
$14, $15, $16, $17, $18, $19,
|
||||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36
|
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38
|
||||||
)
|
)
|
||||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||||
`, prepared.args...)
|
`, prepared.args...)
|
||||||
@@ -1101,6 +1119,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
|||||||
mediaType := nullString(log.MediaType)
|
mediaType := nullString(log.MediaType)
|
||||||
serviceTier := nullString(log.ServiceTier)
|
serviceTier := nullString(log.ServiceTier)
|
||||||
reasoningEffort := nullString(log.ReasoningEffort)
|
reasoningEffort := nullString(log.ReasoningEffort)
|
||||||
|
inboundEndpoint := nullString(log.InboundEndpoint)
|
||||||
|
upstreamEndpoint := nullString(log.UpstreamEndpoint)
|
||||||
|
|
||||||
var requestIDArg any
|
var requestIDArg any
|
||||||
if requestID != "" {
|
if requestID != "" {
|
||||||
@@ -1147,6 +1167,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
|||||||
mediaType,
|
mediaType,
|
||||||
serviceTier,
|
serviceTier,
|
||||||
reasoningEffort,
|
reasoningEffort,
|
||||||
|
inboundEndpoint,
|
||||||
|
upstreamEndpoint,
|
||||||
log.CacheTTLOverridden,
|
log.CacheTTLOverridden,
|
||||||
createdAt,
|
createdAt,
|
||||||
},
|
},
|
||||||
@@ -2139,7 +2161,9 @@ func (r *usageLogRepository) GetUserSpendingRanking(ctx context.Context, startTi
|
|||||||
actual_cost,
|
actual_cost,
|
||||||
requests,
|
requests,
|
||||||
tokens,
|
tokens,
|
||||||
COALESCE(SUM(actual_cost) OVER (), 0) as total_actual_cost
|
COALESCE(SUM(actual_cost) OVER (), 0) as total_actual_cost,
|
||||||
|
COALESCE(SUM(requests) OVER (), 0) as total_requests,
|
||||||
|
COALESCE(SUM(tokens) OVER (), 0) as total_tokens
|
||||||
FROM user_spend
|
FROM user_spend
|
||||||
ORDER BY actual_cost DESC, tokens DESC, user_id ASC
|
ORDER BY actual_cost DESC, tokens DESC, user_id ASC
|
||||||
LIMIT $3
|
LIMIT $3
|
||||||
@@ -2150,7 +2174,9 @@ func (r *usageLogRepository) GetUserSpendingRanking(ctx context.Context, startTi
|
|||||||
actual_cost,
|
actual_cost,
|
||||||
requests,
|
requests,
|
||||||
tokens,
|
tokens,
|
||||||
total_actual_cost
|
total_actual_cost,
|
||||||
|
total_requests,
|
||||||
|
total_tokens
|
||||||
FROM ranked
|
FROM ranked
|
||||||
ORDER BY actual_cost DESC, tokens DESC, user_id ASC
|
ORDER BY actual_cost DESC, tokens DESC, user_id ASC
|
||||||
`
|
`
|
||||||
@@ -2168,9 +2194,11 @@ func (r *usageLogRepository) GetUserSpendingRanking(ctx context.Context, startTi
|
|||||||
|
|
||||||
ranking := make([]UserSpendingRankingItem, 0)
|
ranking := make([]UserSpendingRankingItem, 0)
|
||||||
totalActualCost := 0.0
|
totalActualCost := 0.0
|
||||||
|
totalRequests := int64(0)
|
||||||
|
totalTokens := int64(0)
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var row UserSpendingRankingItem
|
var row UserSpendingRankingItem
|
||||||
if err = rows.Scan(&row.UserID, &row.Email, &row.ActualCost, &row.Requests, &row.Tokens, &totalActualCost); err != nil {
|
if err = rows.Scan(&row.UserID, &row.Email, &row.ActualCost, &row.Requests, &row.Tokens, &totalActualCost, &totalRequests, &totalTokens); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
ranking = append(ranking, row)
|
ranking = append(ranking, row)
|
||||||
@@ -2182,6 +2210,8 @@ func (r *usageLogRepository) GetUserSpendingRanking(ctx context.Context, startTi
|
|||||||
return &UserSpendingRankingResponse{
|
return &UserSpendingRankingResponse{
|
||||||
Ranking: ranking,
|
Ranking: ranking,
|
||||||
TotalActualCost: totalActualCost,
|
TotalActualCost: totalActualCost,
|
||||||
|
TotalRequests: totalRequests,
|
||||||
|
TotalTokens: totalTokens,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2505,7 +2535,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
|
|||||||
args = append(args, *filters.StartTime)
|
args = append(args, *filters.StartTime)
|
||||||
}
|
}
|
||||||
if filters.EndTime != nil {
|
if filters.EndTime != nil {
|
||||||
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", len(args)+1))
|
conditions = append(conditions, fmt.Sprintf("created_at < $%d", len(args)+1))
|
||||||
args = append(args, *filters.EndTime)
|
args = append(args, *filters.EndTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2970,6 +3000,120 @@ func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, start
|
|||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUserBreakdownStats returns per-user usage breakdown within a specific dimension.
|
||||||
|
func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) (results []usagestats.UserBreakdownItem, err error) {
|
||||||
|
query := `
|
||||||
|
SELECT
|
||||||
|
COALESCE(ul.user_id, 0) as user_id,
|
||||||
|
COALESCE(u.email, '') as email,
|
||||||
|
COUNT(*) as requests,
|
||||||
|
COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens,
|
||||||
|
COALESCE(SUM(ul.total_cost), 0) as cost,
|
||||||
|
COALESCE(SUM(ul.actual_cost), 0) as actual_cost
|
||||||
|
FROM usage_logs ul
|
||||||
|
LEFT JOIN users u ON u.id = ul.user_id
|
||||||
|
WHERE ul.created_at >= $1 AND ul.created_at < $2
|
||||||
|
`
|
||||||
|
args := []any{startTime, endTime}
|
||||||
|
|
||||||
|
if dim.GroupID > 0 {
|
||||||
|
query += fmt.Sprintf(" AND ul.group_id = $%d", len(args)+1)
|
||||||
|
args = append(args, dim.GroupID)
|
||||||
|
}
|
||||||
|
if dim.Model != "" {
|
||||||
|
query += fmt.Sprintf(" AND ul.model = $%d", len(args)+1)
|
||||||
|
args = append(args, dim.Model)
|
||||||
|
}
|
||||||
|
if dim.Endpoint != "" {
|
||||||
|
col := resolveEndpointColumn(dim.EndpointType)
|
||||||
|
query += fmt.Sprintf(" AND %s = $%d", col, len(args)+1)
|
||||||
|
args = append(args, dim.Endpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
query += " GROUP BY ul.user_id, u.email ORDER BY actual_cost DESC"
|
||||||
|
if limit > 0 {
|
||||||
|
query += fmt.Sprintf(" LIMIT %d", limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||||
|
err = closeErr
|
||||||
|
results = nil
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
results = make([]usagestats.UserBreakdownItem, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
var row usagestats.UserBreakdownItem
|
||||||
|
if err := rows.Scan(
|
||||||
|
&row.UserID,
|
||||||
|
&row.Email,
|
||||||
|
&row.Requests,
|
||||||
|
&row.TotalTokens,
|
||||||
|
&row.Cost,
|
||||||
|
&row.ActualCost,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
results = append(results, row)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllGroupUsageSummary returns today's and cumulative actual_cost for every group.
|
||||||
|
// todayStart is the start-of-day in the caller's timezone (UTC-based).
|
||||||
|
// TODO(perf): This query scans ALL usage_logs rows for total_cost aggregation.
|
||||||
|
// When usage_logs exceeds ~1M rows, consider adding a short-lived cache (30s)
|
||||||
|
// or a materialized view / pre-aggregation table for cumulative costs.
|
||||||
|
func (r *usageLogRepository) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
|
||||||
|
query := `
|
||||||
|
SELECT
|
||||||
|
g.id AS group_id,
|
||||||
|
COALESCE(SUM(ul.actual_cost), 0) AS total_cost,
|
||||||
|
COALESCE(SUM(CASE WHEN ul.created_at >= $1 THEN ul.actual_cost ELSE 0 END), 0) AS today_cost
|
||||||
|
FROM groups g
|
||||||
|
LEFT JOIN usage_logs ul ON ul.group_id = g.id
|
||||||
|
GROUP BY g.id
|
||||||
|
`
|
||||||
|
|
||||||
|
rows, err := r.sql.QueryContext(ctx, query, todayStart)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
var results []usagestats.GroupUsageSummary
|
||||||
|
for rows.Next() {
|
||||||
|
var row usagestats.GroupUsageSummary
|
||||||
|
if err := rows.Scan(&row.GroupID, &row.TotalCost, &row.TodayCost); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
results = append(results, row)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveEndpointColumn translates an endpoint type into the SQL expression,
// written against the "ul" alias, that it should be matched on:
//
//	"upstream" -> the upstream endpoint column
//	"path"     -> inbound and upstream endpoints joined by " -> "
//	otherwise  -> the inbound endpoint column
func resolveEndpointColumn(endpointType string) string {
	if endpointType == "upstream" {
		return "ul.upstream_endpoint"
	}
	if endpointType == "path" {
		return "ul.inbound_endpoint || ' -> ' || ul.upstream_endpoint"
	}
	// Everything else, including "" and unrecognised values, uses inbound.
	return "ul.inbound_endpoint"
}
|
||||||
|
|
||||||
// GetGlobalStats gets usage statistics for all users within a time range
|
// GetGlobalStats gets usage statistics for all users within a time range
|
||||||
func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) {
|
func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) {
|
||||||
query := `
|
query := `
|
||||||
@@ -2982,7 +3126,7 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
|
|||||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||||
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
WHERE created_at >= $1 AND created_at <= $2
|
WHERE created_at >= $1 AND created_at < $2
|
||||||
`
|
`
|
||||||
|
|
||||||
stats := &UsageStats{}
|
stats := &UsageStats{}
|
||||||
@@ -3040,7 +3184,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
|
|||||||
args = append(args, *filters.StartTime)
|
args = append(args, *filters.StartTime)
|
||||||
}
|
}
|
||||||
if filters.EndTime != nil {
|
if filters.EndTime != nil {
|
||||||
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", len(args)+1))
|
conditions = append(conditions, fmt.Sprintf("created_at < $%d", len(args)+1))
|
||||||
args = append(args, *filters.EndTime)
|
args = append(args, *filters.EndTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -3080,6 +3224,35 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
|
|||||||
stats.TotalAccountCost = &totalAccountCost
|
stats.TotalAccountCost = &totalAccountCost
|
||||||
}
|
}
|
||||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||||||
|
|
||||||
|
start := time.Unix(0, 0).UTC()
|
||||||
|
if filters.StartTime != nil {
|
||||||
|
start = *filters.StartTime
|
||||||
|
}
|
||||||
|
end := time.Now().UTC()
|
||||||
|
if filters.EndTime != nil {
|
||||||
|
end = *filters.EndTime
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoints, endpointErr := r.GetEndpointStatsWithFilters(ctx, start, end, filters.UserID, filters.APIKeyID, filters.AccountID, filters.GroupID, filters.Model, filters.RequestType, filters.Stream, filters.BillingType)
|
||||||
|
if endpointErr != nil {
|
||||||
|
logger.LegacyPrintf("repository.usage_log", "GetEndpointStatsWithFilters failed in GetStatsWithFilters: %v", endpointErr)
|
||||||
|
endpoints = []EndpointStat{}
|
||||||
|
}
|
||||||
|
upstreamEndpoints, upstreamEndpointErr := r.GetUpstreamEndpointStatsWithFilters(ctx, start, end, filters.UserID, filters.APIKeyID, filters.AccountID, filters.GroupID, filters.Model, filters.RequestType, filters.Stream, filters.BillingType)
|
||||||
|
if upstreamEndpointErr != nil {
|
||||||
|
logger.LegacyPrintf("repository.usage_log", "GetUpstreamEndpointStatsWithFilters failed in GetStatsWithFilters: %v", upstreamEndpointErr)
|
||||||
|
upstreamEndpoints = []EndpointStat{}
|
||||||
|
}
|
||||||
|
endpointPaths, endpointPathErr := r.getEndpointPathStatsWithFilters(ctx, start, end, filters.UserID, filters.APIKeyID, filters.AccountID, filters.GroupID, filters.Model, filters.RequestType, filters.Stream, filters.BillingType)
|
||||||
|
if endpointPathErr != nil {
|
||||||
|
logger.LegacyPrintf("repository.usage_log", "getEndpointPathStatsWithFilters failed in GetStatsWithFilters: %v", endpointPathErr)
|
||||||
|
endpointPaths = []EndpointStat{}
|
||||||
|
}
|
||||||
|
stats.Endpoints = endpoints
|
||||||
|
stats.UpstreamEndpoints = upstreamEndpoints
|
||||||
|
stats.EndpointPaths = endpointPaths
|
||||||
|
|
||||||
return stats, nil
|
return stats, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -3092,6 +3265,163 @@ type AccountUsageSummary = usagestats.AccountUsageSummary
|
|||||||
// AccountUsageStatsResponse represents the full usage statistics response for an account
|
// AccountUsageStatsResponse represents the full usage statistics response for an account
|
||||||
type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse
|
type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse
|
||||||
|
|
||||||
|
// EndpointStat represents a single row of endpoint usage statistics.
|
||||||
|
type EndpointStat = usagestats.EndpointStat
|
||||||
|
|
||||||
|
func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Context, endpointColumn string, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
|
||||||
|
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
||||||
|
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
||||||
|
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||||
|
}
|
||||||
|
|
||||||
|
query := fmt.Sprintf(`
|
||||||
|
SELECT
|
||||||
|
COALESCE(NULLIF(TRIM(%s), ''), 'unknown') AS endpoint,
|
||||||
|
COUNT(*) AS requests,
|
||||||
|
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS total_tokens,
|
||||||
|
COALESCE(SUM(total_cost), 0) as cost,
|
||||||
|
%s
|
||||||
|
FROM usage_logs
|
||||||
|
WHERE created_at >= $1 AND created_at < $2
|
||||||
|
`, endpointColumn, actualCostExpr)
|
||||||
|
|
||||||
|
args := []any{startTime, endTime}
|
||||||
|
if userID > 0 {
|
||||||
|
query += fmt.Sprintf(" AND user_id = $%d", len(args)+1)
|
||||||
|
args = append(args, userID)
|
||||||
|
}
|
||||||
|
if apiKeyID > 0 {
|
||||||
|
query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1)
|
||||||
|
args = append(args, apiKeyID)
|
||||||
|
}
|
||||||
|
if accountID > 0 {
|
||||||
|
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
|
||||||
|
args = append(args, accountID)
|
||||||
|
}
|
||||||
|
if groupID > 0 {
|
||||||
|
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
|
||||||
|
args = append(args, groupID)
|
||||||
|
}
|
||||||
|
if model != "" {
|
||||||
|
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
|
||||||
|
args = append(args, model)
|
||||||
|
}
|
||||||
|
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
|
||||||
|
if billingType != nil {
|
||||||
|
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
|
||||||
|
args = append(args, int16(*billingType))
|
||||||
|
}
|
||||||
|
query += " GROUP BY endpoint ORDER BY requests DESC"
|
||||||
|
|
||||||
|
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||||
|
err = closeErr
|
||||||
|
results = nil
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
results = make([]EndpointStat, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
var row EndpointStat
|
||||||
|
if err := rows.Scan(&row.Endpoint, &row.Requests, &row.TotalTokens, &row.Cost, &row.ActualCost); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
results = append(results, row)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
|
||||||
|
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
||||||
|
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
||||||
|
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||||
|
}
|
||||||
|
|
||||||
|
query := fmt.Sprintf(`
|
||||||
|
SELECT
|
||||||
|
CONCAT(
|
||||||
|
COALESCE(NULLIF(TRIM(inbound_endpoint), ''), 'unknown'),
|
||||||
|
' -> ',
|
||||||
|
COALESCE(NULLIF(TRIM(upstream_endpoint), ''), 'unknown')
|
||||||
|
) AS endpoint,
|
||||||
|
COUNT(*) AS requests,
|
||||||
|
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS total_tokens,
|
||||||
|
COALESCE(SUM(total_cost), 0) as cost,
|
||||||
|
%s
|
||||||
|
FROM usage_logs
|
||||||
|
WHERE created_at >= $1 AND created_at < $2
|
||||||
|
`, actualCostExpr)
|
||||||
|
|
||||||
|
args := []any{startTime, endTime}
|
||||||
|
if userID > 0 {
|
||||||
|
query += fmt.Sprintf(" AND user_id = $%d", len(args)+1)
|
||||||
|
args = append(args, userID)
|
||||||
|
}
|
||||||
|
if apiKeyID > 0 {
|
||||||
|
query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1)
|
||||||
|
args = append(args, apiKeyID)
|
||||||
|
}
|
||||||
|
if accountID > 0 {
|
||||||
|
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
|
||||||
|
args = append(args, accountID)
|
||||||
|
}
|
||||||
|
if groupID > 0 {
|
||||||
|
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
|
||||||
|
args = append(args, groupID)
|
||||||
|
}
|
||||||
|
if model != "" {
|
||||||
|
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
|
||||||
|
args = append(args, model)
|
||||||
|
}
|
||||||
|
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
|
||||||
|
if billingType != nil {
|
||||||
|
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
|
||||||
|
args = append(args, int16(*billingType))
|
||||||
|
}
|
||||||
|
query += " GROUP BY endpoint ORDER BY requests DESC"
|
||||||
|
|
||||||
|
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||||
|
err = closeErr
|
||||||
|
results = nil
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
results = make([]EndpointStat, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
var row EndpointStat
|
||||||
|
if err := rows.Scan(&row.Endpoint, &row.Requests, &row.TotalTokens, &row.Cost, &row.ActualCost); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
results = append(results, row)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetEndpointStatsWithFilters returns inbound endpoint statistics with optional filters.
|
||||||
|
func (r *usageLogRepository) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]EndpointStat, error) {
|
||||||
|
return r.getEndpointStatsByColumnWithFilters(ctx, "inbound_endpoint", startTime, endTime, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUpstreamEndpointStatsWithFilters returns upstream endpoint statistics with optional filters.
|
||||||
|
func (r *usageLogRepository) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]EndpointStat, error) {
|
||||||
|
return r.getEndpointStatsByColumnWithFilters(ctx, "upstream_endpoint", startTime, endTime, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||||
|
}
|
||||||
|
|
||||||
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
|
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
|
||||||
func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (resp *AccountUsageStatsResponse, err error) {
|
func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (resp *AccountUsageStatsResponse, err error) {
|
||||||
daysCount := int(endTime.Sub(startTime).Hours()/24) + 1
|
daysCount := int(endTime.Sub(startTime).Hours()/24) + 1
|
||||||
@@ -3254,11 +3584,23 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
models = []ModelStat{}
|
models = []ModelStat{}
|
||||||
}
|
}
|
||||||
|
endpoints, endpointErr := r.GetEndpointStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, "", nil, nil, nil)
|
||||||
|
if endpointErr != nil {
|
||||||
|
logger.LegacyPrintf("repository.usage_log", "GetEndpointStatsWithFilters failed in GetAccountUsageStats: %v", endpointErr)
|
||||||
|
endpoints = []EndpointStat{}
|
||||||
|
}
|
||||||
|
upstreamEndpoints, upstreamEndpointErr := r.GetUpstreamEndpointStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, "", nil, nil, nil)
|
||||||
|
if upstreamEndpointErr != nil {
|
||||||
|
logger.LegacyPrintf("repository.usage_log", "GetUpstreamEndpointStatsWithFilters failed in GetAccountUsageStats: %v", upstreamEndpointErr)
|
||||||
|
upstreamEndpoints = []EndpointStat{}
|
||||||
|
}
|
||||||
|
|
||||||
resp = &AccountUsageStatsResponse{
|
resp = &AccountUsageStatsResponse{
|
||||||
History: history,
|
History: history,
|
||||||
Summary: summary,
|
Summary: summary,
|
||||||
Models: models,
|
Models: models,
|
||||||
|
Endpoints: endpoints,
|
||||||
|
UpstreamEndpoints: upstreamEndpoints,
|
||||||
}
|
}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
@@ -3541,6 +3883,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
|||||||
mediaType sql.NullString
|
mediaType sql.NullString
|
||||||
serviceTier sql.NullString
|
serviceTier sql.NullString
|
||||||
reasoningEffort sql.NullString
|
reasoningEffort sql.NullString
|
||||||
|
inboundEndpoint sql.NullString
|
||||||
|
upstreamEndpoint sql.NullString
|
||||||
cacheTTLOverridden bool
|
cacheTTLOverridden bool
|
||||||
createdAt time.Time
|
createdAt time.Time
|
||||||
)
|
)
|
||||||
@@ -3581,6 +3925,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
|||||||
&mediaType,
|
&mediaType,
|
||||||
&serviceTier,
|
&serviceTier,
|
||||||
&reasoningEffort,
|
&reasoningEffort,
|
||||||
|
&inboundEndpoint,
|
||||||
|
&upstreamEndpoint,
|
||||||
&cacheTTLOverridden,
|
&cacheTTLOverridden,
|
||||||
&createdAt,
|
&createdAt,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
@@ -3656,6 +4002,12 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
|||||||
if reasoningEffort.Valid {
|
if reasoningEffort.Valid {
|
||||||
log.ReasoningEffort = &reasoningEffort.String
|
log.ReasoningEffort = &reasoningEffort.String
|
||||||
}
|
}
|
||||||
|
if inboundEndpoint.Valid {
|
||||||
|
log.InboundEndpoint = &inboundEndpoint.String
|
||||||
|
}
|
||||||
|
if upstreamEndpoint.Valid {
|
||||||
|
log.UpstreamEndpoint = &upstreamEndpoint.String
|
||||||
|
}
|
||||||
|
|
||||||
return log, nil
|
return log, nil
|
||||||
}
|
}
|
||||||
|
|||||||
29
backend/internal/repository/usage_log_repo_breakdown_test.go
Normal file
29
backend/internal/repository/usage_log_repo_breakdown_test.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestResolveEndpointColumn(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
endpointType string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"inbound", "ul.inbound_endpoint"},
|
||||||
|
{"upstream", "ul.upstream_endpoint"},
|
||||||
|
{"path", "ul.inbound_endpoint || ' -> ' || ul.upstream_endpoint"},
|
||||||
|
{"", "ul.inbound_endpoint"}, // default
|
||||||
|
{"unknown", "ul.inbound_endpoint"}, // fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.endpointType, func(t *testing.T) {
|
||||||
|
got := resolveEndpointColumn(tc.endpointType)
|
||||||
|
require.Equal(t, tc.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -73,6 +73,8 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
|||||||
sqlmock.AnyArg(), // media_type
|
sqlmock.AnyArg(), // media_type
|
||||||
sqlmock.AnyArg(), // service_tier
|
sqlmock.AnyArg(), // service_tier
|
||||||
sqlmock.AnyArg(), // reasoning_effort
|
sqlmock.AnyArg(), // reasoning_effort
|
||||||
|
sqlmock.AnyArg(), // inbound_endpoint
|
||||||
|
sqlmock.AnyArg(), // upstream_endpoint
|
||||||
log.CacheTTLOverridden,
|
log.CacheTTLOverridden,
|
||||||
createdAt,
|
createdAt,
|
||||||
).
|
).
|
||||||
@@ -141,6 +143,8 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
|
|||||||
sqlmock.AnyArg(),
|
sqlmock.AnyArg(),
|
||||||
serviceTier,
|
serviceTier,
|
||||||
sqlmock.AnyArg(),
|
sqlmock.AnyArg(),
|
||||||
|
sqlmock.AnyArg(),
|
||||||
|
sqlmock.AnyArg(),
|
||||||
log.CacheTTLOverridden,
|
log.CacheTTLOverridden,
|
||||||
createdAt,
|
createdAt,
|
||||||
).
|
).
|
||||||
@@ -255,10 +259,10 @@ func TestUsageLogRepositoryGetUserSpendingRanking(t *testing.T) {
|
|||||||
start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
|
start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||||
end := start.Add(24 * time.Hour)
|
end := start.Add(24 * time.Hour)
|
||||||
|
|
||||||
rows := sqlmock.NewRows([]string{"user_id", "email", "actual_cost", "requests", "tokens", "total_actual_cost"}).
|
rows := sqlmock.NewRows([]string{"user_id", "email", "actual_cost", "requests", "tokens", "total_actual_cost", "total_requests", "total_tokens"}).
|
||||||
AddRow(int64(2), "beta@example.com", 12.5, int64(9), int64(900), 40.0).
|
AddRow(int64(2), "beta@example.com", 12.5, int64(9), int64(900), 40.0, int64(30), int64(2600)).
|
||||||
AddRow(int64(1), "alpha@example.com", 12.5, int64(8), int64(800), 40.0).
|
AddRow(int64(1), "alpha@example.com", 12.5, int64(8), int64(800), 40.0, int64(30), int64(2600)).
|
||||||
AddRow(int64(3), "gamma@example.com", 4.25, int64(5), int64(300), 40.0)
|
AddRow(int64(3), "gamma@example.com", 4.25, int64(5), int64(300), 40.0, int64(30), int64(2600))
|
||||||
|
|
||||||
mock.ExpectQuery("WITH user_spend AS \\(").
|
mock.ExpectQuery("WITH user_spend AS \\(").
|
||||||
WithArgs(start, end, 12).
|
WithArgs(start, end, 12).
|
||||||
@@ -273,6 +277,8 @@ func TestUsageLogRepositoryGetUserSpendingRanking(t *testing.T) {
|
|||||||
{UserID: 3, Email: "gamma@example.com", ActualCost: 4.25, Requests: 5, Tokens: 300},
|
{UserID: 3, Email: "gamma@example.com", ActualCost: 4.25, Requests: 5, Tokens: 300},
|
||||||
},
|
},
|
||||||
TotalActualCost: 40.0,
|
TotalActualCost: 40.0,
|
||||||
|
TotalRequests: 30,
|
||||||
|
TotalTokens: 2600,
|
||||||
}, got)
|
}, got)
|
||||||
require.NoError(t, mock.ExpectationsWereMet())
|
require.NoError(t, mock.ExpectationsWereMet())
|
||||||
}
|
}
|
||||||
@@ -376,6 +382,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
|||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
sql.NullString{Valid: true, String: "priority"},
|
sql.NullString{Valid: true, String: "priority"},
|
||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
|
sql.NullString{},
|
||||||
|
sql.NullString{},
|
||||||
false,
|
false,
|
||||||
now,
|
now,
|
||||||
}})
|
}})
|
||||||
@@ -415,6 +423,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
|||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
sql.NullString{Valid: true, String: "flex"},
|
sql.NullString{Valid: true, String: "flex"},
|
||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
|
sql.NullString{},
|
||||||
|
sql.NullString{},
|
||||||
false,
|
false,
|
||||||
now,
|
now,
|
||||||
}})
|
}})
|
||||||
@@ -454,6 +464,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
|||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
sql.NullString{Valid: true, String: "priority"},
|
sql.NullString{Valid: true, String: "priority"},
|
||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
|
sql.NullString{},
|
||||||
|
sql.NullString{},
|
||||||
false,
|
false,
|
||||||
now,
|
now,
|
||||||
}})
|
}})
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
@@ -190,7 +191,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
|
|||||||
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
|
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||||
client := clientFromContext(ctx, r.client)
|
client := clientFromContext(ctx, r.client)
|
||||||
q := client.UserSubscription.Query()
|
q := client.UserSubscription.Query()
|
||||||
if userID != nil {
|
if userID != nil {
|
||||||
@@ -199,6 +200,9 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
|
|||||||
if groupID != nil {
|
if groupID != nil {
|
||||||
q = q.Where(usersubscription.GroupIDEQ(*groupID))
|
q = q.Where(usersubscription.GroupIDEQ(*groupID))
|
||||||
}
|
}
|
||||||
|
if platform != "" {
|
||||||
|
q = q.Where(usersubscription.HasGroupWith(group.PlatformEQ(platform)))
|
||||||
|
}
|
||||||
|
|
||||||
// Status filtering with real-time expiration check
|
// Status filtering with real-time expiration check
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|||||||
@@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
|
|||||||
group := s.mustCreateGroup("g-list")
|
group := s.mustCreateGroup("g-list")
|
||||||
s.mustCreateSubscription(user.ID, group.ID, nil)
|
s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||||
|
|
||||||
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "")
|
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "", "")
|
||||||
s.Require().NoError(err, "List")
|
s.Require().NoError(err, "List")
|
||||||
s.Require().Len(subs, 1)
|
s.Require().Len(subs, 1)
|
||||||
s.Require().Equal(int64(1), page.Total)
|
s.Require().Equal(int64(1), page.Total)
|
||||||
@@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
|
|||||||
s.mustCreateSubscription(user1.ID, group.ID, nil)
|
s.mustCreateSubscription(user1.ID, group.ID, nil)
|
||||||
s.mustCreateSubscription(user2.ID, group.ID, nil)
|
s.mustCreateSubscription(user2.ID, group.ID, nil)
|
||||||
|
|
||||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "")
|
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "", "")
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Len(subs, 1)
|
s.Require().Len(subs, 1)
|
||||||
s.Require().Equal(user1.ID, subs[0].UserID)
|
s.Require().Equal(user1.ID, subs[0].UserID)
|
||||||
@@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
|
|||||||
s.mustCreateSubscription(user.ID, g1.ID, nil)
|
s.mustCreateSubscription(user.ID, g1.ID, nil)
|
||||||
s.mustCreateSubscription(user.ID, g2.ID, nil)
|
s.mustCreateSubscription(user.ID, g2.ID, nil)
|
||||||
|
|
||||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "")
|
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "", "")
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Len(subs, 1)
|
s.Require().Len(subs, 1)
|
||||||
s.Require().Equal(g1.ID, subs[0].GroupID)
|
s.Require().Equal(g1.ID, subs[0].GroupID)
|
||||||
@@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
|
|||||||
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
|
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
|
||||||
})
|
})
|
||||||
|
|
||||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "")
|
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "", "")
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Len(subs, 1)
|
s.Require().Len(subs, 1)
|
||||||
s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)
|
s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)
|
||||||
|
|||||||
@@ -493,6 +493,7 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"registration_email_suffix_whitelist": [],
|
"registration_email_suffix_whitelist": [],
|
||||||
"promo_code_enabled": true,
|
"promo_code_enabled": true,
|
||||||
"password_reset_enabled": false,
|
"password_reset_enabled": false,
|
||||||
|
"frontend_url": "",
|
||||||
"totp_enabled": false,
|
"totp_enabled": false,
|
||||||
"totp_encryption_key_configured": false,
|
"totp_encryption_key_configured": false,
|
||||||
"smtp_host": "smtp.example.com",
|
"smtp_host": "smtp.example.com",
|
||||||
@@ -923,8 +924,8 @@ func (stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error
|
|||||||
return false, errors.New("not implemented")
|
return false, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
func (stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
|
||||||
return 0, errors.New("not implemented")
|
return 0, 0, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||||
@@ -1288,7 +1289,7 @@ func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userI
|
|||||||
func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, errors.New("not implemented")
|
return nil, nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, errors.New("not implemented")
|
return nil, nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
||||||
@@ -1624,10 +1625,22 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi
|
|||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *stubUsageLogRepo) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (r *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
func (r *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *stubUsageLogRepo) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (r *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
func (r *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
@@ -1773,6 +1786,9 @@ func (r *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID i
|
|||||||
func (r *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
|
func (r *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
func (r *stubUsageLogRepo) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
type stubSettingRepo struct {
|
type stubSettingRepo struct {
|
||||||
all map[string]string
|
all map[string]string
|
||||||
|
|||||||
@@ -135,7 +135,7 @@ func (f fakeGoogleSubscriptionRepo) ListActiveByUserID(ctx context.Context, user
|
|||||||
func (f fakeGoogleSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
func (f fakeGoogleSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, errors.New("not implemented")
|
return nil, nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, errors.New("not implemented")
|
return nil, nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
func (f fakeGoogleSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
func (f fakeGoogleSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
||||||
|
|||||||
@@ -646,7 +646,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in
|
|||||||
return nil, nil, errors.New("not implemented")
|
return nil, nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, errors.New("not implemented")
|
return nil, nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -198,6 +198,7 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
dashboard.GET("/users-ranking", h.Admin.Dashboard.GetUserSpendingRanking)
|
dashboard.GET("/users-ranking", h.Admin.Dashboard.GetUserSpendingRanking)
|
||||||
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
|
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
|
||||||
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
|
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
|
||||||
|
dashboard.GET("/user-breakdown", h.Admin.Dashboard.GetUserBreakdown)
|
||||||
dashboard.POST("/aggregation/backfill", h.Admin.Dashboard.BackfillAggregation)
|
dashboard.POST("/aggregation/backfill", h.Admin.Dashboard.BackfillAggregation)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -226,6 +227,8 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
{
|
{
|
||||||
groups.GET("", h.Admin.Group.List)
|
groups.GET("", h.Admin.Group.List)
|
||||||
groups.GET("/all", h.Admin.Group.GetAll)
|
groups.GET("/all", h.Admin.Group.GetAll)
|
||||||
|
groups.GET("/usage-summary", h.Admin.Group.GetUsageSummary)
|
||||||
|
groups.GET("/capacity-summary", h.Admin.Group.GetCapacitySummary)
|
||||||
groups.PUT("/sort-order", h.Admin.Group.UpdateSortOrder)
|
groups.PUT("/sort-order", h.Admin.Group.UpdateSortOrder)
|
||||||
groups.GET("/:id", h.Admin.Group.GetByID)
|
groups.GET("/:id", h.Admin.Group.GetByID)
|
||||||
groups.POST("", h.Admin.Group.Create)
|
groups.POST("", h.Admin.Group.Create)
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ func RegisterGatewayRoutes(
|
|||||||
soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize)
|
soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize)
|
||||||
clientRequestID := middleware.ClientRequestID()
|
clientRequestID := middleware.ClientRequestID()
|
||||||
opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService)
|
opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService)
|
||||||
|
endpointNorm := handler.InboundEndpointMiddleware()
|
||||||
|
|
||||||
// 未分组 Key 拦截中间件(按协议格式区分错误响应)
|
// 未分组 Key 拦截中间件(按协议格式区分错误响应)
|
||||||
requireGroupAnthropic := middleware.RequireGroupAssignment(settingService, middleware.AnthropicErrorWriter)
|
requireGroupAnthropic := middleware.RequireGroupAssignment(settingService, middleware.AnthropicErrorWriter)
|
||||||
@@ -40,6 +41,7 @@ func RegisterGatewayRoutes(
|
|||||||
gateway.Use(bodyLimit)
|
gateway.Use(bodyLimit)
|
||||||
gateway.Use(clientRequestID)
|
gateway.Use(clientRequestID)
|
||||||
gateway.Use(opsErrorLogger)
|
gateway.Use(opsErrorLogger)
|
||||||
|
gateway.Use(endpointNorm)
|
||||||
gateway.Use(gin.HandlerFunc(apiKeyAuth))
|
gateway.Use(gin.HandlerFunc(apiKeyAuth))
|
||||||
gateway.Use(requireGroupAnthropic)
|
gateway.Use(requireGroupAnthropic)
|
||||||
{
|
{
|
||||||
@@ -80,6 +82,7 @@ func RegisterGatewayRoutes(
|
|||||||
gemini.Use(bodyLimit)
|
gemini.Use(bodyLimit)
|
||||||
gemini.Use(clientRequestID)
|
gemini.Use(clientRequestID)
|
||||||
gemini.Use(opsErrorLogger)
|
gemini.Use(opsErrorLogger)
|
||||||
|
gemini.Use(endpointNorm)
|
||||||
gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
|
gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
|
||||||
gemini.Use(requireGroupGoogle)
|
gemini.Use(requireGroupGoogle)
|
||||||
{
|
{
|
||||||
@@ -90,11 +93,11 @@ func RegisterGatewayRoutes(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// OpenAI Responses API(不带v1前缀的别名)
|
// OpenAI Responses API(不带v1前缀的别名)
|
||||||
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||||
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||||
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
|
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
|
||||||
// OpenAI Chat Completions API(不带v1前缀的别名)
|
// OpenAI Chat Completions API(不带v1前缀的别名)
|
||||||
r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions)
|
r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions)
|
||||||
|
|
||||||
// Antigravity 模型列表
|
// Antigravity 模型列表
|
||||||
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)
|
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)
|
||||||
@@ -104,6 +107,7 @@ func RegisterGatewayRoutes(
|
|||||||
antigravityV1.Use(bodyLimit)
|
antigravityV1.Use(bodyLimit)
|
||||||
antigravityV1.Use(clientRequestID)
|
antigravityV1.Use(clientRequestID)
|
||||||
antigravityV1.Use(opsErrorLogger)
|
antigravityV1.Use(opsErrorLogger)
|
||||||
|
antigravityV1.Use(endpointNorm)
|
||||||
antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity))
|
antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity))
|
||||||
antigravityV1.Use(gin.HandlerFunc(apiKeyAuth))
|
antigravityV1.Use(gin.HandlerFunc(apiKeyAuth))
|
||||||
antigravityV1.Use(requireGroupAnthropic)
|
antigravityV1.Use(requireGroupAnthropic)
|
||||||
@@ -118,6 +122,7 @@ func RegisterGatewayRoutes(
|
|||||||
antigravityV1Beta.Use(bodyLimit)
|
antigravityV1Beta.Use(bodyLimit)
|
||||||
antigravityV1Beta.Use(clientRequestID)
|
antigravityV1Beta.Use(clientRequestID)
|
||||||
antigravityV1Beta.Use(opsErrorLogger)
|
antigravityV1Beta.Use(opsErrorLogger)
|
||||||
|
antigravityV1Beta.Use(endpointNorm)
|
||||||
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
|
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
|
||||||
antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
|
antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
|
||||||
antigravityV1Beta.Use(requireGroupGoogle)
|
antigravityV1Beta.Use(requireGroupGoogle)
|
||||||
@@ -132,6 +137,7 @@ func RegisterGatewayRoutes(
|
|||||||
soraV1.Use(soraBodyLimit)
|
soraV1.Use(soraBodyLimit)
|
||||||
soraV1.Use(clientRequestID)
|
soraV1.Use(clientRequestID)
|
||||||
soraV1.Use(opsErrorLogger)
|
soraV1.Use(opsErrorLogger)
|
||||||
|
soraV1.Use(endpointNorm)
|
||||||
soraV1.Use(middleware.ForcePlatform(service.PlatformSora))
|
soraV1.Use(middleware.ForcePlatform(service.PlatformSora))
|
||||||
soraV1.Use(gin.HandlerFunc(apiKeyAuth))
|
soraV1.Use(gin.HandlerFunc(apiKeyAuth))
|
||||||
soraV1.Use(requireGroupAnthropic)
|
soraV1.Use(requireGroupAnthropic)
|
||||||
|
|||||||
@@ -901,6 +901,22 @@ func (a *Account) IsMixedSchedulingEnabled() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsOveragesEnabled 检查 Antigravity 账号是否启用 AI Credits 超量请求。
|
||||||
|
func (a *Account) IsOveragesEnabled() bool {
|
||||||
|
if a.Platform != PlatformAntigravity {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if a.Extra == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if v, ok := a.Extra["allow_overages"]; ok {
|
||||||
|
if enabled, ok := v.(bool); ok {
|
||||||
|
return enabled
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用“自动透传(仅替换认证)”。
|
// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用“自动透传(仅替换认证)”。
|
||||||
//
|
//
|
||||||
// 新字段:accounts.extra.openai_passthrough。
|
// 新字段:accounts.extra.openai_passthrough。
|
||||||
|
|||||||
@@ -113,15 +113,18 @@ func (s *AccountTestService) validateUpstreamBaseURL(raw string) (string, error)
|
|||||||
return normalized, nil
|
return normalized, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateSessionString generates a Claude Code style session string
|
// generateSessionString generates a Claude Code style session string.
|
||||||
|
// The output format is determined by the UA version in claude.DefaultHeaders,
|
||||||
|
// ensuring consistency between the user_id format and the UA sent to upstream.
|
||||||
func generateSessionString() (string, error) {
|
func generateSessionString() (string, error) {
|
||||||
bytes := make([]byte, 32)
|
b := make([]byte, 32)
|
||||||
if _, err := rand.Read(bytes); err != nil {
|
if _, err := rand.Read(b); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
hex64 := hex.EncodeToString(bytes)
|
hex64 := hex.EncodeToString(b)
|
||||||
sessionUUID := uuid.New().String()
|
sessionUUID := uuid.New().String()
|
||||||
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID), nil
|
uaVersion := ExtractCLIVersion(claude.DefaultHeaders["User-Agent"])
|
||||||
|
return FormatMetadataUserID(hex64, "", sessionUUID, uaVersion), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// createTestPayload creates a Claude Code style test request payload
|
// createTestPayload creates a Claude Code style test request payload
|
||||||
|
|||||||
@@ -45,7 +45,11 @@ type UsageLogRepository interface {
|
|||||||
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
|
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
|
||||||
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
|
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
|
||||||
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
|
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
|
||||||
|
GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
|
||||||
|
GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
|
||||||
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
|
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
|
||||||
|
GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error)
|
||||||
|
GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error)
|
||||||
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
||||||
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
||||||
GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error)
|
GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error)
|
||||||
@@ -164,6 +168,13 @@ type AntigravityModelDetail struct {
|
|||||||
SupportedMimeTypes map[string]bool `json:"supported_mime_types,omitempty"`
|
SupportedMimeTypes map[string]bool `json:"supported_mime_types,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AICredit 表示 Antigravity 账号的 AI Credits 余额信息。
|
||||||
|
type AICredit struct {
|
||||||
|
CreditType string `json:"credit_type,omitempty"`
|
||||||
|
Amount float64 `json:"amount,omitempty"`
|
||||||
|
MinimumBalance float64 `json:"minimum_balance,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
// UsageInfo 账号使用量信息
|
// UsageInfo 账号使用量信息
|
||||||
type UsageInfo struct {
|
type UsageInfo struct {
|
||||||
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
|
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
|
||||||
@@ -187,6 +198,9 @@ type UsageInfo struct {
|
|||||||
// Antigravity 模型详细能力信息(与 antigravity_quota 同 key)
|
// Antigravity 模型详细能力信息(与 antigravity_quota 同 key)
|
||||||
AntigravityQuotaDetails map[string]*AntigravityModelDetail `json:"antigravity_quota_details,omitempty"`
|
AntigravityQuotaDetails map[string]*AntigravityModelDetail `json:"antigravity_quota_details,omitempty"`
|
||||||
|
|
||||||
|
// Antigravity AI Credits 余额
|
||||||
|
AICredits []AICredit `json:"ai_credits,omitempty"`
|
||||||
|
|
||||||
// Antigravity 废弃模型转发规则 (old_model_id -> new_model_id)
|
// Antigravity 废弃模型转发规则 (old_model_id -> new_model_id)
|
||||||
ModelForwardingRules map[string]string `json:"model_forwarding_rules,omitempty"`
|
ModelForwardingRules map[string]string `json:"model_forwarding_rules,omitempty"`
|
||||||
|
|
||||||
@@ -434,23 +448,17 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
|
|||||||
}
|
}
|
||||||
|
|
||||||
if stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, now.Add(-5*time.Hour)); err == nil {
|
if stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, now.Add(-5*time.Hour)); err == nil {
|
||||||
windowStats := windowStatsFromAccountStats(stats)
|
if usage.FiveHour == nil {
|
||||||
if hasMeaningfulWindowStats(windowStats) {
|
usage.FiveHour = &UsageProgress{Utilization: 0}
|
||||||
if usage.FiveHour == nil {
|
|
||||||
usage.FiveHour = &UsageProgress{Utilization: 0}
|
|
||||||
}
|
|
||||||
usage.FiveHour.WindowStats = windowStats
|
|
||||||
}
|
}
|
||||||
|
usage.FiveHour.WindowStats = windowStatsFromAccountStats(stats)
|
||||||
}
|
}
|
||||||
|
|
||||||
if stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, now.Add(-7*24*time.Hour)); err == nil {
|
if stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, now.Add(-7*24*time.Hour)); err == nil {
|
||||||
windowStats := windowStatsFromAccountStats(stats)
|
if usage.SevenDay == nil {
|
||||||
if hasMeaningfulWindowStats(windowStats) {
|
usage.SevenDay = &UsageProgress{Utilization: 0}
|
||||||
if usage.SevenDay == nil {
|
|
||||||
usage.SevenDay = &UsageProgress{Utilization: 0}
|
|
||||||
}
|
|
||||||
usage.SevenDay.WindowStats = windowStats
|
|
||||||
}
|
}
|
||||||
|
usage.SevenDay.WindowStats = windowStatsFromAccountStats(stats)
|
||||||
}
|
}
|
||||||
|
|
||||||
return usage, nil
|
return usage, nil
|
||||||
@@ -980,13 +988,6 @@ func windowStatsFromAccountStats(stats *usagestats.AccountStats) *WindowStats {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func hasMeaningfulWindowStats(stats *WindowStats) bool {
|
|
||||||
if stats == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return stats.Requests > 0 || stats.Tokens > 0 || stats.Cost > 0 || stats.StandardCost > 0 || stats.UserCost > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildCodexUsageProgressFromExtra(extra map[string]any, window string, now time.Time) *UsageProgress {
|
func buildCodexUsageProgressFromExtra(extra map[string]any, window string, now time.Time) *UsageProgress {
|
||||||
if len(extra) == 0 {
|
if len(extra) == 0 {
|
||||||
return nil
|
return nil
|
||||||
@@ -1043,6 +1044,11 @@ func buildCodexUsageProgressFromExtra(extra map[string]any, window string, now t
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 窗口已过期(resetAt 在 now 之前)→ 额度已重置,归零
|
||||||
|
if progress.ResetsAt != nil && !now.Before(*progress.ResetsAt) {
|
||||||
|
progress.Utilization = 0
|
||||||
|
}
|
||||||
|
|
||||||
return progress
|
return progress
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -148,3 +148,54 @@ func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *tes
|
|||||||
t.Fatal("waiting for codex probe rate limit persistence timed out")
|
t.Fatal("waiting for codex probe rate limit persistence timed out")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBuildCodexUsageProgressFromExtra_ZerosExpiredWindow(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
now := time.Date(2026, 3, 16, 12, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
t.Run("expired 5h window zeroes utilization", func(t *testing.T) {
|
||||||
|
extra := map[string]any{
|
||||||
|
"codex_5h_used_percent": 42.0,
|
||||||
|
"codex_5h_reset_at": "2026-03-16T10:00:00Z", // 2h ago
|
||||||
|
}
|
||||||
|
progress := buildCodexUsageProgressFromExtra(extra, "5h", now)
|
||||||
|
if progress == nil {
|
||||||
|
t.Fatal("expected non-nil progress")
|
||||||
|
}
|
||||||
|
if progress.Utilization != 0 {
|
||||||
|
t.Fatalf("expected Utilization=0 for expired window, got %v", progress.Utilization)
|
||||||
|
}
|
||||||
|
if progress.RemainingSeconds != 0 {
|
||||||
|
t.Fatalf("expected RemainingSeconds=0, got %v", progress.RemainingSeconds)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("active 5h window keeps utilization", func(t *testing.T) {
|
||||||
|
resetAt := now.Add(2 * time.Hour).Format(time.RFC3339)
|
||||||
|
extra := map[string]any{
|
||||||
|
"codex_5h_used_percent": 42.0,
|
||||||
|
"codex_5h_reset_at": resetAt,
|
||||||
|
}
|
||||||
|
progress := buildCodexUsageProgressFromExtra(extra, "5h", now)
|
||||||
|
if progress == nil {
|
||||||
|
t.Fatal("expected non-nil progress")
|
||||||
|
}
|
||||||
|
if progress.Utilization != 42.0 {
|
||||||
|
t.Fatalf("expected Utilization=42, got %v", progress.Utilization)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("expired 7d window zeroes utilization", func(t *testing.T) {
|
||||||
|
extra := map[string]any{
|
||||||
|
"codex_7d_used_percent": 88.0,
|
||||||
|
"codex_7d_reset_at": "2026-03-15T00:00:00Z", // yesterday
|
||||||
|
}
|
||||||
|
progress := buildCodexUsageProgressFromExtra(extra, "7d", now)
|
||||||
|
if progress == nil {
|
||||||
|
t.Fatal("expected non-nil progress")
|
||||||
|
}
|
||||||
|
if progress.Utilization != 0 {
|
||||||
|
t.Fatalf("expected Utilization=0 for expired 7d window, got %v", progress.Utilization)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -368,6 +368,10 @@ type ProxyExitInfoProber interface {
|
|||||||
ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
|
ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type groupExistenceBatchReader interface {
|
||||||
|
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
|
||||||
|
}
|
||||||
|
|
||||||
type proxyQualityTarget struct {
|
type proxyQualityTarget struct {
|
||||||
Target string
|
Target string
|
||||||
URL string
|
URL string
|
||||||
@@ -445,10 +449,6 @@ type userGroupRateBatchReader interface {
|
|||||||
GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error)
|
GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type groupExistenceBatchReader interface {
|
|
||||||
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewAdminService creates a new AdminService
|
// NewAdminService creates a new AdminService
|
||||||
func NewAdminService(
|
func NewAdminService(
|
||||||
userRepo UserRepository,
|
userRepo UserRepository,
|
||||||
@@ -832,7 +832,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
|||||||
subscriptionType = SubscriptionTypeStandard
|
subscriptionType = SubscriptionTypeStandard
|
||||||
}
|
}
|
||||||
|
|
||||||
// 限额字段:0 和 nil 都表示"无限制"
|
// 限额字段:nil/负数 表示"无限制",0 表示"不允许用量",正数表示具体限额
|
||||||
dailyLimit := normalizeLimit(input.DailyLimitUSD)
|
dailyLimit := normalizeLimit(input.DailyLimitUSD)
|
||||||
weeklyLimit := normalizeLimit(input.WeeklyLimitUSD)
|
weeklyLimit := normalizeLimit(input.WeeklyLimitUSD)
|
||||||
monthlyLimit := normalizeLimit(input.MonthlyLimitUSD)
|
monthlyLimit := normalizeLimit(input.MonthlyLimitUSD)
|
||||||
@@ -944,9 +944,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
|||||||
return group, nil
|
return group, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// normalizeLimit 将 0 或负数转换为 nil(表示无限制)
|
// normalizeLimit 将负数转换为 nil(表示无限制),0 保留(表示限额为零)
|
||||||
func normalizeLimit(limit *float64) *float64 {
|
func normalizeLimit(limit *float64) *float64 {
|
||||||
if limit == nil || *limit <= 0 {
|
if limit == nil || *limit < 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return limit
|
return limit
|
||||||
@@ -1058,16 +1058,11 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
|||||||
if input.SubscriptionType != "" {
|
if input.SubscriptionType != "" {
|
||||||
group.SubscriptionType = input.SubscriptionType
|
group.SubscriptionType = input.SubscriptionType
|
||||||
}
|
}
|
||||||
// 限额字段:0 和 nil 都表示"无限制",正数表示具体限额
|
// 限额字段:nil/负数 表示"无限制",0 表示"不允许用量",正数表示具体限额
|
||||||
if input.DailyLimitUSD != nil {
|
// 前端始终发送这三个字段,无需 nil 守卫
|
||||||
group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD)
|
group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD)
|
||||||
}
|
group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
|
||||||
if input.WeeklyLimitUSD != nil {
|
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
|
||||||
group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
|
|
||||||
}
|
|
||||||
if input.MonthlyLimitUSD != nil {
|
|
||||||
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
|
|
||||||
}
|
|
||||||
// 图片生成计费配置:负数表示清除(使用默认价格)
|
// 图片生成计费配置:负数表示清除(使用默认价格)
|
||||||
if input.ImagePrice1K != nil {
|
if input.ImagePrice1K != nil {
|
||||||
group.ImagePrice1K = normalizePrice(input.ImagePrice1K)
|
group.ImagePrice1K = normalizePrice(input.ImagePrice1K)
|
||||||
@@ -1521,6 +1516,7 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
wasOveragesEnabled := account.IsOveragesEnabled()
|
||||||
|
|
||||||
if input.Name != "" {
|
if input.Name != "" {
|
||||||
account.Name = input.Name
|
account.Name = input.Name
|
||||||
@@ -1534,7 +1530,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
|||||||
if len(input.Credentials) > 0 {
|
if len(input.Credentials) > 0 {
|
||||||
account.Credentials = input.Credentials
|
account.Credentials = input.Credentials
|
||||||
}
|
}
|
||||||
if len(input.Extra) > 0 {
|
// Extra 使用 map:需要区分“未提供(nil)”与“显式清空({})”。
|
||||||
|
// 关闭配额限制时前端会删除 quota_* 键并提交 extra:{},此时也必须落库。
|
||||||
|
if input.Extra != nil {
|
||||||
// 保留配额用量字段,防止编辑账号时意外重置
|
// 保留配额用量字段,防止编辑账号时意外重置
|
||||||
for _, key := range []string{"quota_used", "quota_daily_used", "quota_daily_start", "quota_weekly_used", "quota_weekly_start"} {
|
for _, key := range []string{"quota_used", "quota_daily_used", "quota_daily_start", "quota_weekly_used", "quota_weekly_start"} {
|
||||||
if v, ok := account.Extra[key]; ok {
|
if v, ok := account.Extra[key]; ok {
|
||||||
@@ -1542,6 +1540,17 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
account.Extra = input.Extra
|
account.Extra = input.Extra
|
||||||
|
if account.Platform == PlatformAntigravity && wasOveragesEnabled && !account.IsOveragesEnabled() {
|
||||||
|
delete(account.Extra, "antigravity_credits_overages") // 清理旧版 overages 运行态
|
||||||
|
// 清除 AICredits 限流 key
|
||||||
|
if rawLimits, ok := account.Extra[modelRateLimitsKey].(map[string]any); ok {
|
||||||
|
delete(rawLimits, creditsExhaustedKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if account.Platform == PlatformAntigravity && !wasOveragesEnabled && account.IsOveragesEnabled() {
|
||||||
|
delete(account.Extra, modelRateLimitsKey)
|
||||||
|
delete(account.Extra, "antigravity_credits_overages") // 清理旧版 overages 运行态
|
||||||
|
}
|
||||||
// 校验并预计算固定时间重置的下次重置时间
|
// 校验并预计算固定时间重置的下次重置时间
|
||||||
if err := ValidateQuotaResetConfig(account.Extra); err != nil {
|
if err := ValidateQuotaResetConfig(account.Extra); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -194,7 +194,7 @@ func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, stri
|
|||||||
func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) {
|
func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) {
|
||||||
panic("unexpected")
|
panic("unexpected")
|
||||||
}
|
}
|
||||||
func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, error) {
|
func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, int64, error) {
|
||||||
panic("unexpected")
|
panic("unexpected")
|
||||||
}
|
}
|
||||||
func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
|
func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
|
||||||
|
|||||||
@@ -160,7 +160,7 @@ func (s *groupRepoStub) ExistsByName(ctx context.Context, name string) (bool, er
|
|||||||
panic("unexpected ExistsByName call")
|
panic("unexpected ExistsByName call")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *groupRepoStub) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
func (s *groupRepoStub) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
|
||||||
panic("unexpected GetAccountCount call")
|
panic("unexpected GetAccountCount call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ func (s *groupRepoStubForAdmin) ExistsByName(_ context.Context, _ string) (bool,
|
|||||||
panic("unexpected ExistsByName call")
|
panic("unexpected ExistsByName call")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, error) {
|
func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
|
||||||
panic("unexpected GetAccountCount call")
|
panic("unexpected GetAccountCount call")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -383,7 +383,7 @@ func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string
|
|||||||
panic("unexpected ExistsByName call")
|
panic("unexpected ExistsByName call")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, error) {
|
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
|
||||||
panic("unexpected GetAccountCount call")
|
panic("unexpected GetAccountCount call")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -458,7 +458,7 @@ func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context,
|
|||||||
panic("unexpected ExistsByName call")
|
panic("unexpected ExistsByName call")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, error) {
|
func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
|
||||||
panic("unexpected GetAccountCount call")
|
panic("unexpected GetAccountCount call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
155
backend/internal/service/admin_service_overages_test.go
Normal file
155
backend/internal/service/admin_service_overages_test.go
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type updateAccountOveragesRepoStub struct {
|
||||||
|
mockAccountRepoForGemini
|
||||||
|
account *Account
|
||||||
|
updateCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *updateAccountOveragesRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||||
|
return r.account, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *updateAccountOveragesRepoStub) Update(ctx context.Context, account *Account) error {
|
||||||
|
r.updateCalls++
|
||||||
|
r.account = account
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateAccount_DisableOveragesClearsAICreditsKey(t *testing.T) {
|
||||||
|
accountID := int64(101)
|
||||||
|
repo := &updateAccountOveragesRepoStub{
|
||||||
|
account: &Account{
|
||||||
|
ID: accountID,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Status: StatusActive,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"allow_overages": true,
|
||||||
|
"mixed_scheduling": true,
|
||||||
|
modelRateLimitsKey: map[string]any{
|
||||||
|
"claude-sonnet-4-5": map[string]any{
|
||||||
|
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||||
|
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
|
||||||
|
},
|
||||||
|
creditsExhaustedKey: map[string]any{
|
||||||
|
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||||
|
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &adminServiceImpl{accountRepo: repo}
|
||||||
|
updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{
|
||||||
|
Extra: map[string]any{
|
||||||
|
"mixed_scheduling": true,
|
||||||
|
modelRateLimitsKey: map[string]any{
|
||||||
|
"claude-sonnet-4-5": map[string]any{
|
||||||
|
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||||
|
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
|
||||||
|
},
|
||||||
|
creditsExhaustedKey: map[string]any{
|
||||||
|
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||||
|
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, updated)
|
||||||
|
require.Equal(t, 1, repo.updateCalls)
|
||||||
|
require.False(t, updated.IsOveragesEnabled())
|
||||||
|
|
||||||
|
// 关闭 overages 后,AICredits key 应被清除
|
||||||
|
rawLimits, ok := repo.account.Extra[modelRateLimitsKey].(map[string]any)
|
||||||
|
if ok {
|
||||||
|
_, exists := rawLimits[creditsExhaustedKey]
|
||||||
|
require.False(t, exists, "关闭 overages 时应清除 AICredits 限流 key")
|
||||||
|
}
|
||||||
|
// 普通模型限流应保留
|
||||||
|
require.True(t, ok)
|
||||||
|
_, exists := rawLimits["claude-sonnet-4-5"]
|
||||||
|
require.True(t, exists, "普通模型限流应保留")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateAccount_EnableOveragesClearsModelRateLimitsBeforePersist(t *testing.T) {
|
||||||
|
accountID := int64(102)
|
||||||
|
repo := &updateAccountOveragesRepoStub{
|
||||||
|
account: &Account{
|
||||||
|
ID: accountID,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Status: StatusActive,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"mixed_scheduling": true,
|
||||||
|
modelRateLimitsKey: map[string]any{
|
||||||
|
"claude-sonnet-4-5": map[string]any{
|
||||||
|
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||||
|
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &adminServiceImpl{accountRepo: repo}
|
||||||
|
updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{
|
||||||
|
Extra: map[string]any{
|
||||||
|
"mixed_scheduling": true,
|
||||||
|
"allow_overages": true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, updated)
|
||||||
|
require.Equal(t, 1, repo.updateCalls)
|
||||||
|
require.True(t, updated.IsOveragesEnabled())
|
||||||
|
|
||||||
|
_, exists := repo.account.Extra[modelRateLimitsKey]
|
||||||
|
require.False(t, exists, "开启 overages 时应在持久化前清掉旧模型限流")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateAccount_EmptyExtraPayloadCanClearQuotaLimits(t *testing.T) {
|
||||||
|
accountID := int64(103)
|
||||||
|
repo := &updateAccountOveragesRepoStub{
|
||||||
|
account: &Account{
|
||||||
|
ID: accountID,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"quota_limit": 100.0,
|
||||||
|
"quota_daily_limit": 10.0,
|
||||||
|
"quota_weekly_limit": 40.0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &adminServiceImpl{accountRepo: repo}
|
||||||
|
updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{
|
||||||
|
// 显式空对象:语义是“清空 extra 中的可配置键”(例如关闭配额限制)
|
||||||
|
Extra: map[string]any{},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, updated)
|
||||||
|
require.Equal(t, 1, repo.updateCalls)
|
||||||
|
require.NotNil(t, repo.account.Extra)
|
||||||
|
require.NotContains(t, repo.account.Extra, "quota_limit")
|
||||||
|
require.NotContains(t, repo.account.Extra, "quota_daily_limit")
|
||||||
|
require.NotContains(t, repo.account.Extra, "quota_weekly_limit")
|
||||||
|
require.Len(t, repo.account.Extra, 0)
|
||||||
|
}
|
||||||
234
backend/internal/service/antigravity_credits_overages.go
Normal file
234
backend/internal/service/antigravity_credits_overages.go
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// creditsExhaustedKey is the special key inside model_rate_limits that marks
// an account's AI credits as exhausted. It has the same shape as an ordinary
// model rate limit and is read and written through SetModelRateLimit and
// isRateLimitActiveForKey.
const creditsExhaustedKey = "AICredits"

// creditsExhaustedDuration is how long a credits-exhausted mark stays active.
const creditsExhaustedDuration = 5 * time.Hour
|
||||||
|
|
||||||
|
// antigravity429Category labels the kind of 429 response returned by
// Antigravity.
type antigravity429Category string

// Categories produced by classifyAntigravity429.
const (
	// antigravity429Unknown is used when the body matches no known pattern.
	antigravity429Unknown antigravity429Category = "unknown"
	// antigravity429RateLimited marks a structured rate-limit response.
	antigravity429RateLimited antigravity429Category = "rate_limited"
	// antigravity429QuotaExhausted marks a body that reports quota exhaustion.
	antigravity429QuotaExhausted antigravity429Category = "quota_exhausted"
)
|
||||||
|
|
||||||
|
// antigravityQuotaExhaustedKeywords are the lower-case substrings that
// classifyAntigravity429 treats as a quota-exhausted 429 body.
var antigravityQuotaExhaustedKeywords = []string{
	"quota_exhausted",
	"quota exhausted",
}

// creditsExhaustedKeywords are the lower-case substrings that
// shouldMarkCreditsExhausted looks for in a failed credits response body.
var creditsExhaustedKeywords = []string{
	"google_one_ai",
	"insufficient credit",
	"insufficient credits",
	"not enough credit",
	"not enough credits",
	"credit exhausted",
	"credits exhausted",
	"credit balance",
	"minimumcreditamountforusage",
	"minimum credit amount for usage",
	"minimum credit",
}
|
||||||
|
|
||||||
|
// isCreditsExhausted 检查账号的 AICredits 限流 key 是否生效(积分是否耗尽)。
|
||||||
|
func (a *Account) isCreditsExhausted() bool {
|
||||||
|
if a == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return a.isRateLimitActiveForKey(creditsExhaustedKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// setCreditsExhausted 标记账号积分耗尽:写入 model_rate_limits["AICredits"] + 更新缓存。
|
||||||
|
func (s *AntigravityGatewayService) setCreditsExhausted(ctx context.Context, account *Account) {
|
||||||
|
if account == nil || account.ID == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resetAt := time.Now().Add(creditsExhaustedDuration)
|
||||||
|
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, creditsExhaustedKey, resetAt); err != nil {
|
||||||
|
logger.LegacyPrintf("service.antigravity_gateway", "set credits exhausted failed: account=%d err=%v", account.ID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.updateAccountModelRateLimitInCache(ctx, account, creditsExhaustedKey, resetAt)
|
||||||
|
logger.LegacyPrintf("service.antigravity_gateway", "credits_exhausted_marked account=%d reset_at=%s",
|
||||||
|
account.ID, resetAt.UTC().Format(time.RFC3339))
|
||||||
|
}
|
||||||
|
|
||||||
|
// clearCreditsExhausted 清除账号的 AICredits 限流 key。
|
||||||
|
func (s *AntigravityGatewayService) clearCreditsExhausted(ctx context.Context, account *Account) {
|
||||||
|
if account == nil || account.ID == 0 || account.Extra == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rawLimits, ok := account.Extra[modelRateLimitsKey].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, exists := rawLimits[creditsExhaustedKey]; !exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
delete(rawLimits, creditsExhaustedKey)
|
||||||
|
account.Extra[modelRateLimitsKey] = rawLimits
|
||||||
|
if err := s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{
|
||||||
|
modelRateLimitsKey: rawLimits,
|
||||||
|
}); err != nil {
|
||||||
|
logger.LegacyPrintf("service.antigravity_gateway", "clear credits exhausted failed: account=%d err=%v", account.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// classifyAntigravity429 将 Antigravity 的 429 响应归类为配额耗尽、限流或未知。
|
||||||
|
func classifyAntigravity429(body []byte) antigravity429Category {
|
||||||
|
if len(body) == 0 {
|
||||||
|
return antigravity429Unknown
|
||||||
|
}
|
||||||
|
lowerBody := strings.ToLower(string(body))
|
||||||
|
for _, keyword := range antigravityQuotaExhaustedKeywords {
|
||||||
|
if strings.Contains(lowerBody, keyword) {
|
||||||
|
return antigravity429QuotaExhausted
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if info := parseAntigravitySmartRetryInfo(body); info != nil && !info.IsModelCapacityExhausted {
|
||||||
|
return antigravity429RateLimited
|
||||||
|
}
|
||||||
|
return antigravity429Unknown
|
||||||
|
}
|
||||||
|
|
||||||
|
// injectEnabledCreditTypes adds the AI Credits type to an already serialized
// v1internal JSON body by setting the top-level "enabledCreditTypes" field to
// ["GOOGLE_ONE_AI"], replacing any existing value.
//
// It returns the re-serialized body, or nil when body is not a JSON object
// (invalid JSON, a non-object value, or the literal null) or cannot be
// re-encoded. Callers treat nil as "cannot inject".
func injectEnabledCreditTypes(body []byte) []byte {
	var payload map[string]any
	if err := json.Unmarshal(body, &payload); err != nil {
		return nil
	}
	// JSON null decodes without error but leaves the map nil; writing to a
	// nil map would panic, so treat it like any other non-object body.
	if payload == nil {
		return nil
	}

	payload["enabledCreditTypes"] = []string{"GOOGLE_ONE_AI"}
	result, err := json.Marshal(payload)
	if err != nil {
		return nil
	}
	return result
}
|
||||||
|
|
||||||
|
// resolveCreditsOveragesModelKey 解析当前请求对应的 overages 状态模型 key。
|
||||||
|
func resolveCreditsOveragesModelKey(ctx context.Context, account *Account, upstreamModelName, requestedModel string) string {
|
||||||
|
modelKey := strings.TrimSpace(upstreamModelName)
|
||||||
|
if modelKey != "" {
|
||||||
|
return modelKey
|
||||||
|
}
|
||||||
|
if account == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
modelKey = resolveFinalAntigravityModelKey(ctx, account, requestedModel)
|
||||||
|
if strings.TrimSpace(modelKey) != "" {
|
||||||
|
return modelKey
|
||||||
|
}
|
||||||
|
return resolveAntigravityModelKey(requestedModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// shouldMarkCreditsExhausted 判断一次 credits 请求失败是否应标记为 credits 耗尽。
|
||||||
|
func shouldMarkCreditsExhausted(resp *http.Response, respBody []byte, reqErr error) bool {
|
||||||
|
if reqErr != nil || resp == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if resp.StatusCode >= 500 || resp.StatusCode == http.StatusRequestTimeout {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if isURLLevelRateLimit(respBody) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if info := parseAntigravitySmartRetryInfo(respBody); info != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
bodyLower := strings.ToLower(string(respBody))
|
||||||
|
for _, keyword := range creditsExhaustedKeywords {
|
||||||
|
if strings.Contains(bodyLower, keyword) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// creditsOveragesRetryResult is the outcome of attemptCreditsOveragesRetry.
type creditsOveragesRetryResult struct {
	// handled is false only when the request body could not be rewritten
	// with credit types, so no credits attempt took place.
	handled bool
	// resp is the upstream response of a successful credits request. It is
	// nil when the attempt failed.
	resp *http.Response
}
|
||||||
|
|
||||||
|
// attemptCreditsOveragesRetry 在确认免费配额耗尽后,尝试注入 AI Credits 继续请求。
|
||||||
|
func (s *AntigravityGatewayService) attemptCreditsOveragesRetry(
|
||||||
|
p antigravityRetryLoopParams,
|
||||||
|
baseURL string,
|
||||||
|
modelName string,
|
||||||
|
waitDuration time.Duration,
|
||||||
|
originalStatusCode int,
|
||||||
|
respBody []byte,
|
||||||
|
) *creditsOveragesRetryResult {
|
||||||
|
creditsBody := injectEnabledCreditTypes(p.body)
|
||||||
|
if creditsBody == nil {
|
||||||
|
return &creditsOveragesRetryResult{handled: false}
|
||||||
|
}
|
||||||
|
modelKey := resolveCreditsOveragesModelKey(p.ctx, p.account, modelName, p.requestedModel)
|
||||||
|
logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 credit_overages_retry model=%s account=%d (injecting enabledCreditTypes)",
|
||||||
|
p.prefix, modelKey, p.account.ID)
|
||||||
|
|
||||||
|
creditsReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, creditsBody)
|
||||||
|
if err != nil {
|
||||||
|
logger.LegacyPrintf("service.antigravity_gateway", "%s credit_overages_failed model=%s account=%d build_request_err=%v",
|
||||||
|
p.prefix, modelKey, p.account.ID, err)
|
||||||
|
return &creditsOveragesRetryResult{handled: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
creditsResp, err := p.httpUpstream.Do(creditsReq, p.proxyURL, p.account.ID, p.account.Concurrency)
|
||||||
|
if err == nil && creditsResp != nil && creditsResp.StatusCode < 400 {
|
||||||
|
s.clearCreditsExhausted(p.ctx, p.account)
|
||||||
|
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d credit_overages_success model=%s account=%d",
|
||||||
|
p.prefix, creditsResp.StatusCode, modelKey, p.account.ID)
|
||||||
|
return &creditsOveragesRetryResult{handled: true, resp: creditsResp}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.handleCreditsRetryFailure(p.ctx, p.prefix, modelKey, p.account, creditsResp, err)
|
||||||
|
return &creditsOveragesRetryResult{handled: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AntigravityGatewayService) handleCreditsRetryFailure(
|
||||||
|
ctx context.Context,
|
||||||
|
prefix string,
|
||||||
|
modelKey string,
|
||||||
|
account *Account,
|
||||||
|
creditsResp *http.Response,
|
||||||
|
reqErr error,
|
||||||
|
) {
|
||||||
|
var creditsRespBody []byte
|
||||||
|
creditsStatusCode := 0
|
||||||
|
if creditsResp != nil {
|
||||||
|
creditsStatusCode = creditsResp.StatusCode
|
||||||
|
if creditsResp.Body != nil {
|
||||||
|
creditsRespBody, _ = io.ReadAll(io.LimitReader(creditsResp.Body, 64<<10))
|
||||||
|
_ = creditsResp.Body.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if shouldMarkCreditsExhausted(creditsResp, creditsRespBody, reqErr) && account != nil {
|
||||||
|
s.setCreditsExhausted(ctx, account)
|
||||||
|
logger.LegacyPrintf("service.antigravity_gateway", "%s credit_overages_failed model=%s account=%d marked_exhausted=true status=%d body=%s",
|
||||||
|
prefix, modelKey, account.ID, creditsStatusCode, truncateForLog(creditsRespBody, 200))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if account != nil {
|
||||||
|
logger.LegacyPrintf("service.antigravity_gateway", "%s credit_overages_failed model=%s account=%d marked_exhausted=false status=%d err=%v body=%s",
|
||||||
|
prefix, modelKey, account.ID, creditsStatusCode, reqErr, truncateForLog(creditsRespBody, 200))
|
||||||
|
}
|
||||||
|
}
|
||||||
538
backend/internal/service/antigravity_credits_overages_test.go
Normal file
538
backend/internal/service/antigravity_credits_overages_test.go
Normal file
@@ -0,0 +1,538 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClassifyAntigravity429(t *testing.T) {
|
||||||
|
t.Run("明确配额耗尽", func(t *testing.T) {
|
||||||
|
body := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`)
|
||||||
|
require.Equal(t, antigravity429QuotaExhausted, classifyAntigravity429(body))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("结构化限流", func(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"error": {
|
||||||
|
"status": "RESOURCE_EXHAUSTED",
|
||||||
|
"details": [
|
||||||
|
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||||
|
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
require.Equal(t, antigravity429RateLimited, classifyAntigravity429(body))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("未知429", func(t *testing.T) {
|
||||||
|
body := []byte(`{"error":{"message":"too many requests"}}`)
|
||||||
|
require.Equal(t, antigravity429Unknown, classifyAntigravity429(body))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsCreditsExhausted_UsesAICreditsKey(t *testing.T) {
|
||||||
|
t.Run("无 AICredits key 则积分可用", func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"allow_overages": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.False(t, account.isCreditsExhausted())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("AICredits key 生效则积分耗尽", func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
ID: 2,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"allow_overages": true,
|
||||||
|
modelRateLimitsKey: map[string]any{
|
||||||
|
creditsExhaustedKey: map[string]any{
|
||||||
|
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
|
||||||
|
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.True(t, account.isCreditsExhausted())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("AICredits key 过期则积分可用", func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
ID: 3,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"allow_overages": true,
|
||||||
|
modelRateLimitsKey: map[string]any{
|
||||||
|
creditsExhaustedKey: map[string]any{
|
||||||
|
"rate_limited_at": time.Now().Add(-6 * time.Hour).UTC().Format(time.RFC3339),
|
||||||
|
"rate_limit_reset_at": time.Now().Add(-1 * time.Hour).UTC().Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.False(t, account.isCreditsExhausted())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleSmartRetry_QuotaExhausted_UsesCreditsAndStoresIndependentState(t *testing.T) {
|
||||||
|
successResp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"ok":true}`)),
|
||||||
|
}
|
||||||
|
upstream := &mockSmartRetryUpstream{
|
||||||
|
responses: []*http.Response{successResp},
|
||||||
|
errors: []error{nil},
|
||||||
|
}
|
||||||
|
repo := &stubAntigravityAccountRepo{}
|
||||||
|
account := &Account{
|
||||||
|
ID: 101,
|
||||||
|
Name: "acc-101",
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"allow_overages": true,
|
||||||
|
},
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"claude-opus-4-6": "claude-sonnet-4-5",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`)
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusTooManyRequests,
|
||||||
|
Header: http.Header{},
|
||||||
|
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||||
|
}
|
||||||
|
params := antigravityRetryLoopParams{
|
||||||
|
ctx: context.Background(),
|
||||||
|
prefix: "[test]",
|
||||||
|
account: account,
|
||||||
|
accessToken: "token",
|
||||||
|
action: "generateContent",
|
||||||
|
body: []byte(`{"model":"claude-opus-4-6","request":{}}`),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
accountRepo: repo,
|
||||||
|
requestedModel: "claude-opus-4-6",
|
||||||
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{}
|
||||||
|
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"})
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||||
|
require.NotNil(t, result.resp)
|
||||||
|
require.Nil(t, result.switchError)
|
||||||
|
require.Len(t, upstream.requestBodies, 1)
|
||||||
|
require.Contains(t, string(upstream.requestBodies[0]), "enabledCreditTypes")
|
||||||
|
require.Empty(t, repo.modelRateLimitCalls, "overages 成功后不应写入普通 model_rate_limits")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleSmartRetry_RateLimited_DoesNotUseCredits(t *testing.T) {
|
||||||
|
successResp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"ok":true}`)),
|
||||||
|
}
|
||||||
|
upstream := &mockSmartRetryUpstream{
|
||||||
|
responses: []*http.Response{successResp},
|
||||||
|
errors: []error{nil},
|
||||||
|
}
|
||||||
|
repo := &stubAntigravityAccountRepo{}
|
||||||
|
account := &Account{
|
||||||
|
ID: 102,
|
||||||
|
Name: "acc-102",
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"allow_overages": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody := []byte(`{
|
||||||
|
"error": {
|
||||||
|
"status": "RESOURCE_EXHAUSTED",
|
||||||
|
"details": [
|
||||||
|
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||||
|
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusTooManyRequests,
|
||||||
|
Header: http.Header{},
|
||||||
|
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||||
|
}
|
||||||
|
params := antigravityRetryLoopParams{
|
||||||
|
ctx: context.Background(),
|
||||||
|
prefix: "[test]",
|
||||||
|
account: account,
|
||||||
|
accessToken: "token",
|
||||||
|
action: "generateContent",
|
||||||
|
body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
accountRepo: repo,
|
||||||
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{}
|
||||||
|
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"})
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||||
|
require.NotNil(t, result.resp)
|
||||||
|
require.Len(t, upstream.requestBodies, 1)
|
||||||
|
require.NotContains(t, string(upstream.requestBodies[0]), "enabledCreditTypes")
|
||||||
|
require.Empty(t, repo.extraUpdateCalls)
|
||||||
|
require.Empty(t, repo.modelRateLimitCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAntigravityRetryLoop_ModelRateLimited_InjectsCredits(t *testing.T) {
|
||||||
|
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
|
||||||
|
oldAvailability := antigravity.DefaultURLAvailability
|
||||||
|
defer func() {
|
||||||
|
antigravity.BaseURLs = oldBaseURLs
|
||||||
|
antigravity.DefaultURLAvailability = oldAvailability
|
||||||
|
}()
|
||||||
|
|
||||||
|
antigravity.BaseURLs = []string{"https://ag-1.test"}
|
||||||
|
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
|
||||||
|
|
||||||
|
upstream := &queuedHTTPUpstreamStub{
|
||||||
|
responses: []*http.Response{
|
||||||
|
{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"ok":true}`)),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
errors: []error{nil},
|
||||||
|
}
|
||||||
|
// 模型已限流 + overages 启用 + 无 AICredits key → 应直接注入积分
|
||||||
|
account := &Account{
|
||||||
|
ID: 103,
|
||||||
|
Name: "acc-103",
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"allow_overages": true,
|
||||||
|
modelRateLimitsKey: map[string]any{
|
||||||
|
"claude-sonnet-4-5": map[string]any{
|
||||||
|
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
|
||||||
|
"rate_limit_reset_at": time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{}
|
||||||
|
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||||
|
ctx: context.Background(),
|
||||||
|
prefix: "[test]",
|
||||||
|
account: account,
|
||||||
|
accessToken: "token",
|
||||||
|
action: "generateContent",
|
||||||
|
body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
requestedModel: "claude-sonnet-4-5",
|
||||||
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Len(t, upstream.requestBodies, 1)
|
||||||
|
require.Contains(t, string(upstream.requestBodies[0]), "enabledCreditTypes")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAntigravityRetryLoop_CreditsExhausted_DoesNotInject(t *testing.T) {
|
||||||
|
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
|
||||||
|
oldAvailability := antigravity.DefaultURLAvailability
|
||||||
|
defer func() {
|
||||||
|
antigravity.BaseURLs = oldBaseURLs
|
||||||
|
antigravity.DefaultURLAvailability = oldAvailability
|
||||||
|
}()
|
||||||
|
|
||||||
|
antigravity.BaseURLs = []string{"https://ag-1.test"}
|
||||||
|
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
|
||||||
|
|
||||||
|
// 模型限流 + overages 启用 + AICredits key 生效 → 不应注入积分,应切号
|
||||||
|
account := &Account{
|
||||||
|
ID: 104,
|
||||||
|
Name: "acc-104",
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"allow_overages": true,
|
||||||
|
modelRateLimitsKey: map[string]any{
|
||||||
|
"claude-sonnet-4-5": map[string]any{
|
||||||
|
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
|
||||||
|
"rate_limit_reset_at": time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
creditsExhaustedKey: map[string]any{
|
||||||
|
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
|
||||||
|
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{}
|
||||||
|
_, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||||
|
ctx: context.Background(),
|
||||||
|
prefix: "[test]",
|
||||||
|
account: account,
|
||||||
|
accessToken: "token",
|
||||||
|
action: "generateContent",
|
||||||
|
body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`),
|
||||||
|
requestedModel: "claude-sonnet-4-5",
|
||||||
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// 模型限流 + 积分耗尽 → 应触发切号错误
|
||||||
|
require.Error(t, err)
|
||||||
|
var switchErr *AntigravityAccountSwitchError
|
||||||
|
require.ErrorAs(t, err, &switchErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAntigravityRetryLoop_CreditErrorMarksExhausted(t *testing.T) {
|
||||||
|
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
|
||||||
|
oldAvailability := antigravity.DefaultURLAvailability
|
||||||
|
defer func() {
|
||||||
|
antigravity.BaseURLs = oldBaseURLs
|
||||||
|
antigravity.DefaultURLAvailability = oldAvailability
|
||||||
|
}()
|
||||||
|
|
||||||
|
antigravity.BaseURLs = []string{"https://ag-1.test"}
|
||||||
|
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
|
||||||
|
|
||||||
|
repo := &stubAntigravityAccountRepo{}
|
||||||
|
upstream := &queuedHTTPUpstreamStub{
|
||||||
|
responses: []*http.Response{
|
||||||
|
{
|
||||||
|
StatusCode: http.StatusForbidden,
|
||||||
|
Header: http.Header{},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"error":{"message":"Insufficient GOOGLE_ONE_AI credits"}}`)),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
errors: []error{nil},
|
||||||
|
}
|
||||||
|
// 模型限流 + overages 启用 + 积分可用 → 注入积分但上游返回积分不足
|
||||||
|
account := &Account{
|
||||||
|
ID: 105,
|
||||||
|
Name: "acc-105",
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"allow_overages": true,
|
||||||
|
modelRateLimitsKey: map[string]any{
|
||||||
|
"claude-sonnet-4-5": map[string]any{
|
||||||
|
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
|
||||||
|
"rate_limit_reset_at": time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||||
|
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||||
|
ctx: context.Background(),
|
||||||
|
prefix: "[test]",
|
||||||
|
account: account,
|
||||||
|
accessToken: "token",
|
||||||
|
action: "generateContent",
|
||||||
|
body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
accountRepo: repo,
|
||||||
|
requestedModel: "claude-sonnet-4-5",
|
||||||
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
// 验证 AICredits key 已通过 SetModelRateLimit 写入数据库
|
||||||
|
require.Len(t, repo.modelRateLimitCalls, 1, "应通过 SetModelRateLimit 写入 AICredits key")
|
||||||
|
require.Equal(t, creditsExhaustedKey, repo.modelRateLimitCalls[0].modelKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldMarkCreditsExhausted(t *testing.T) {
|
||||||
|
t.Run("reqErr 不为 nil 时不标记", func(t *testing.T) {
|
||||||
|
resp := &http.Response{StatusCode: http.StatusForbidden}
|
||||||
|
require.False(t, shouldMarkCreditsExhausted(resp, []byte(`{"error":"Insufficient credits"}`), io.ErrUnexpectedEOF))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("resp 为 nil 时不标记", func(t *testing.T) {
|
||||||
|
require.False(t, shouldMarkCreditsExhausted(nil, []byte(`{"error":"Insufficient credits"}`), nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("5xx 响应不标记", func(t *testing.T) {
|
||||||
|
resp := &http.Response{StatusCode: http.StatusInternalServerError}
|
||||||
|
require.False(t, shouldMarkCreditsExhausted(resp, []byte(`{"error":"Insufficient credits"}`), nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("408 RequestTimeout 不标记", func(t *testing.T) {
|
||||||
|
resp := &http.Response{StatusCode: http.StatusRequestTimeout}
|
||||||
|
require.False(t, shouldMarkCreditsExhausted(resp, []byte(`{"error":"Insufficient credits"}`), nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("URL 级限流不标记", func(t *testing.T) {
|
||||||
|
resp := &http.Response{StatusCode: http.StatusTooManyRequests}
|
||||||
|
body := []byte(`{"error":{"message":"Resource has been exhausted"}}`)
|
||||||
|
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("结构化限流不标记", func(t *testing.T) {
|
||||||
|
resp := &http.Response{StatusCode: http.StatusTooManyRequests}
|
||||||
|
body := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"RATE_LIMIT_EXCEEDED"},{"@type":"type.googleapis.com/google.rpc.RetryInfo","retryDelay":"0.5s"}]}}`)
|
||||||
|
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("含 credits 关键词时标记", func(t *testing.T) {
|
||||||
|
resp := &http.Response{StatusCode: http.StatusForbidden}
|
||||||
|
for _, keyword := range []string{
|
||||||
|
"Insufficient GOOGLE_ONE_AI credits",
|
||||||
|
"insufficient credit balance",
|
||||||
|
"not enough credits for this request",
|
||||||
|
"Credits exhausted",
|
||||||
|
"minimumCreditAmountForUsage requirement not met",
|
||||||
|
} {
|
||||||
|
body := []byte(`{"error":{"message":"` + keyword + `"}}`)
|
||||||
|
require.True(t, shouldMarkCreditsExhausted(resp, body, nil), "should mark for keyword: %s", keyword)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("无 credits 关键词时不标记", func(t *testing.T) {
|
||||||
|
resp := &http.Response{StatusCode: http.StatusForbidden}
|
||||||
|
body := []byte(`{"error":{"message":"permission denied"}}`)
|
||||||
|
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInjectEnabledCreditTypes(t *testing.T) {
|
||||||
|
t.Run("正常 JSON 注入成功", func(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"claude-sonnet-4-5","request":{}}`)
|
||||||
|
result := injectEnabledCreditTypes(body)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Contains(t, string(result), `"enabledCreditTypes"`)
|
||||||
|
require.Contains(t, string(result), `GOOGLE_ONE_AI`)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("非法 JSON 返回 nil", func(t *testing.T) {
|
||||||
|
require.Nil(t, injectEnabledCreditTypes([]byte(`not json`)))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("空 body 返回 nil", func(t *testing.T) {
|
||||||
|
require.Nil(t, injectEnabledCreditTypes([]byte{}))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("已有 enabledCreditTypes 会被覆盖", func(t *testing.T) {
|
||||||
|
body := []byte(`{"enabledCreditTypes":["OLD"],"model":"test"}`)
|
||||||
|
result := injectEnabledCreditTypes(body)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Contains(t, string(result), `GOOGLE_ONE_AI`)
|
||||||
|
require.NotContains(t, string(result), `OLD`)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClearCreditsExhausted(t *testing.T) {
|
||||||
|
t.Run("account 为 nil 不操作", func(t *testing.T) {
|
||||||
|
repo := &stubAntigravityAccountRepo{}
|
||||||
|
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||||
|
svc.clearCreditsExhausted(context.Background(), nil)
|
||||||
|
require.Empty(t, repo.extraUpdateCalls)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Extra 为 nil 不操作", func(t *testing.T) {
|
||||||
|
repo := &stubAntigravityAccountRepo{}
|
||||||
|
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||||
|
svc.clearCreditsExhausted(context.Background(), &Account{ID: 1})
|
||||||
|
require.Empty(t, repo.extraUpdateCalls)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("无 modelRateLimitsKey 不操作", func(t *testing.T) {
|
||||||
|
repo := &stubAntigravityAccountRepo{}
|
||||||
|
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||||
|
svc.clearCreditsExhausted(context.Background(), &Account{
|
||||||
|
ID: 1,
|
||||||
|
Extra: map[string]any{"some_key": "value"},
|
||||||
|
})
|
||||||
|
require.Empty(t, repo.extraUpdateCalls)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("无 AICredits key 不操作", func(t *testing.T) {
|
||||||
|
repo := &stubAntigravityAccountRepo{}
|
||||||
|
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||||
|
svc.clearCreditsExhausted(context.Background(), &Account{
|
||||||
|
ID: 1,
|
||||||
|
Extra: map[string]any{
|
||||||
|
modelRateLimitsKey: map[string]any{
|
||||||
|
"claude-sonnet-4-5": map[string]any{
|
||||||
|
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||||
|
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.Empty(t, repo.extraUpdateCalls)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("有 AICredits key 时删除并调用 UpdateExtra", func(t *testing.T) {
|
||||||
|
repo := &stubAntigravityAccountRepo{}
|
||||||
|
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Extra: map[string]any{
|
||||||
|
modelRateLimitsKey: map[string]any{
|
||||||
|
"claude-sonnet-4-5": map[string]any{
|
||||||
|
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||||
|
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
|
||||||
|
},
|
||||||
|
creditsExhaustedKey: map[string]any{
|
||||||
|
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||||
|
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc.clearCreditsExhausted(context.Background(), account)
|
||||||
|
require.Len(t, repo.extraUpdateCalls, 1)
|
||||||
|
// AICredits key 应被删除
|
||||||
|
rawLimits := account.Extra[modelRateLimitsKey].(map[string]any)
|
||||||
|
_, exists := rawLimits[creditsExhaustedKey]
|
||||||
|
require.False(t, exists, "AICredits key 应被删除")
|
||||||
|
// 普通模型限流应保留
|
||||||
|
_, exists = rawLimits["claude-sonnet-4-5"]
|
||||||
|
require.True(t, exists, "普通模型限流应保留")
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -188,9 +188,29 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
|
|||||||
return &smartRetryResult{action: smartRetryActionContinueURL}
|
return &smartRetryResult{action: smartRetryActionContinueURL}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
category := antigravity429Unknown
|
||||||
|
if resp.StatusCode == http.StatusTooManyRequests {
|
||||||
|
category = classifyAntigravity429(respBody)
|
||||||
|
}
|
||||||
|
|
||||||
// 判断是否触发智能重试
|
// 判断是否触发智能重试
|
||||||
shouldSmartRetry, shouldRateLimitModel, waitDuration, modelName, isModelCapacityExhausted := shouldTriggerAntigravitySmartRetry(p.account, respBody)
|
shouldSmartRetry, shouldRateLimitModel, waitDuration, modelName, isModelCapacityExhausted := shouldTriggerAntigravitySmartRetry(p.account, respBody)
|
||||||
|
|
||||||
|
// AI Credits 超量请求:
|
||||||
|
// 仅在上游明确返回免费配额耗尽时才允许切换到 credits。
|
||||||
|
if resp.StatusCode == http.StatusTooManyRequests &&
|
||||||
|
category == antigravity429QuotaExhausted &&
|
||||||
|
p.account.IsOveragesEnabled() &&
|
||||||
|
!p.account.isCreditsExhausted() {
|
||||||
|
result := s.attemptCreditsOveragesRetry(p, baseURL, modelName, waitDuration, resp.StatusCode, respBody)
|
||||||
|
if result.handled && result.resp != nil {
|
||||||
|
return &smartRetryResult{
|
||||||
|
action: smartRetryActionBreakWithResp,
|
||||||
|
resp: result.resp,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 情况1: retryDelay >= 阈值,限流模型并切换账号
|
// 情况1: retryDelay >= 阈值,限流模型并切换账号
|
||||||
if shouldRateLimitModel {
|
if shouldRateLimitModel {
|
||||||
// 单账号 503 退避重试模式:不设限流、不切换账号,改为原地等待+重试
|
// 单账号 503 退避重试模式:不设限流、不切换账号,改为原地等待+重试
|
||||||
@@ -532,14 +552,31 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace(
|
|||||||
|
|
||||||
// antigravityRetryLoop 执行带 URL fallback 的重试循环
|
// antigravityRetryLoop 执行带 URL fallback 的重试循环
|
||||||
func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
|
func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
|
||||||
|
// 预检查:模型限流 + overages 启用 + 积分未耗尽 → 直接注入 AI Credits
|
||||||
|
overagesInjected := false
|
||||||
|
if p.requestedModel != "" && p.account.Platform == PlatformAntigravity &&
|
||||||
|
p.account.IsOveragesEnabled() && !p.account.isCreditsExhausted() &&
|
||||||
|
p.account.isModelRateLimitedWithContext(p.ctx, p.requestedModel) {
|
||||||
|
if creditsBody := injectEnabledCreditTypes(p.body); creditsBody != nil {
|
||||||
|
p.body = creditsBody
|
||||||
|
overagesInjected = true
|
||||||
|
logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: model_rate_limited_credits_inject model=%s account=%d (injecting enabledCreditTypes)",
|
||||||
|
p.prefix, p.requestedModel, p.account.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 预检查:如果账号已限流,直接返回切换信号
|
// 预检查:如果账号已限流,直接返回切换信号
|
||||||
if p.requestedModel != "" {
|
if p.requestedModel != "" {
|
||||||
if remaining := p.account.GetRateLimitRemainingTimeWithContext(p.ctx, p.requestedModel); remaining > 0 {
|
if remaining := p.account.GetRateLimitRemainingTimeWithContext(p.ctx, p.requestedModel); remaining > 0 {
|
||||||
// 单账号 503 退避重试模式:跳过限流预检查,直接发请求。
|
// 已注入积分的请求不再受普通模型限流预检查阻断。
|
||||||
// 首次请求设的限流是为了多账号调度器跳过该账号,在单账号模式下无意义。
|
if overagesInjected {
|
||||||
// 如果上游确实还不可用,handleSmartRetry → handleSingleAccountRetryInPlace
|
logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: credits_injected_ignore_rate_limit remaining=%v model=%s account=%d",
|
||||||
// 会在 Service 层原地等待+重试,不需要在预检查这里等。
|
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
|
||||||
if isSingleAccountRetry(p.ctx) {
|
} else if isSingleAccountRetry(p.ctx) {
|
||||||
|
// 单账号 503 退避重试模式:跳过限流预检查,直接发请求。
|
||||||
|
// 首次请求设的限流是为了多账号调度器跳过该账号,在单账号模式下无意义。
|
||||||
|
// 如果上游确实还不可用,handleSmartRetry → handleSingleAccountRetryInPlace
|
||||||
|
// 会在 Service 层原地等待+重试,不需要在预检查这里等。
|
||||||
logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: single_account_retry skipping rate_limit remaining=%v model=%s account=%d (will retry in-place if 503)",
|
logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: single_account_retry skipping rate_limit remaining=%v model=%s account=%d (will retry in-place if 503)",
|
||||||
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
|
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
|
||||||
} else {
|
} else {
|
||||||
@@ -631,6 +668,15 @@ urlFallbackLoop:
|
|||||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
|
if overagesInjected && shouldMarkCreditsExhausted(resp, respBody, nil) {
|
||||||
|
modelKey := resolveCreditsOveragesModelKey(p.ctx, p.account, "", p.requestedModel)
|
||||||
|
s.handleCreditsRetryFailure(p.ctx, p.prefix, modelKey, p.account, &http.Response{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
Header: resp.Header.Clone(),
|
||||||
|
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||||
|
}, nil)
|
||||||
|
}
|
||||||
|
|
||||||
// ★ 统一入口:自定义错误码 + 临时不可调度
|
// ★ 统一入口:自定义错误码 + 临时不可调度
|
||||||
if handled, outStatus, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled {
|
if handled, outStatus, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled {
|
||||||
if policyErr != nil {
|
if policyErr != nil {
|
||||||
@@ -884,7 +930,7 @@ func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParam
|
|||||||
case ErrorPolicyTempUnscheduled:
|
case ErrorPolicyTempUnscheduled:
|
||||||
slog.Info("temp_unschedulable_matched",
|
slog.Info("temp_unschedulable_matched",
|
||||||
"prefix", p.prefix, "status_code", statusCode, "account_id", p.account.ID)
|
"prefix", p.prefix, "status_code", statusCode, "account_id", p.account.ID)
|
||||||
return true, statusCode, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession}
|
return true, statusCode, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, RateLimitedModel: p.requestedModel, IsStickySession: p.isStickySession}
|
||||||
}
|
}
|
||||||
return false, statusCode, nil
|
return false, statusCode, nil
|
||||||
}
|
}
|
||||||
@@ -955,8 +1001,9 @@ type TestConnectionResult struct {
|
|||||||
MappedModel string // 实际使用的模型
|
MappedModel string // 实际使用的模型
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费)
|
// TestConnection 测试 Antigravity 账号连接。
|
||||||
// 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择
|
// 复用 antigravityRetryLoop 的完整重试 / credits overages / 智能重试逻辑,
|
||||||
|
// 与真实调度行为一致。差异:不做账号切换(测试指定账号)、不记录 ops 错误。
|
||||||
func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) {
|
func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) {
|
||||||
|
|
||||||
// 获取 token
|
// 获取 token
|
||||||
@@ -980,10 +1027,8 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
|
|||||||
// 构建请求体
|
// 构建请求体
|
||||||
var requestBody []byte
|
var requestBody []byte
|
||||||
if strings.HasPrefix(modelID, "gemini-") {
|
if strings.HasPrefix(modelID, "gemini-") {
|
||||||
// Gemini 模型:直接使用 Gemini 格式
|
|
||||||
requestBody, err = s.buildGeminiTestRequest(projectID, mappedModel)
|
requestBody, err = s.buildGeminiTestRequest(projectID, mappedModel)
|
||||||
} else {
|
} else {
|
||||||
// Claude 模型:使用协议转换
|
|
||||||
requestBody, err = s.buildClaudeTestRequest(projectID, mappedModel)
|
requestBody, err = s.buildClaudeTestRequest(projectID, mappedModel)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -996,64 +1041,63 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
|
|||||||
proxyURL = account.Proxy.URL()
|
proxyURL = account.Proxy.URL()
|
||||||
}
|
}
|
||||||
|
|
||||||
baseURL := resolveAntigravityForwardBaseURL()
|
// 复用 antigravityRetryLoop:完整的重试 / credits overages / 智能重试
|
||||||
if baseURL == "" {
|
prefix := fmt.Sprintf("[antigravity-Test] account=%d(%s)", account.ID, account.Name)
|
||||||
return nil, errors.New("no antigravity forward base url configured")
|
p := antigravityRetryLoopParams{
|
||||||
}
|
ctx: ctx,
|
||||||
availableURLs := []string{baseURL}
|
prefix: prefix,
|
||||||
|
account: account,
|
||||||
var lastErr error
|
proxyURL: proxyURL,
|
||||||
for urlIdx, baseURL := range availableURLs {
|
accessToken: accessToken,
|
||||||
// 构建 HTTP 请求(总是使用流式 endpoint,与官方客户端一致)
|
action: "streamGenerateContent",
|
||||||
req, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, "streamGenerateContent", accessToken, requestBody)
|
body: requestBody,
|
||||||
if err != nil {
|
c: nil, // 无 gin.Context → 跳过 ops 追踪
|
||||||
lastErr = err
|
httpUpstream: s.httpUpstream,
|
||||||
continue
|
settingService: s.settingService,
|
||||||
}
|
accountRepo: s.accountRepo,
|
||||||
|
requestedModel: modelID,
|
||||||
// 调试日志:Test 请求信息
|
handleError: testConnectionHandleError,
|
||||||
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String())
|
|
||||||
|
|
||||||
// 发送请求
|
|
||||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
|
||||||
if err != nil {
|
|
||||||
lastErr = fmt.Errorf("请求失败: %w", err)
|
|
||||||
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
|
||||||
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return nil, lastErr
|
|
||||||
}
|
|
||||||
|
|
||||||
// 读取响应
|
|
||||||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
|
||||||
_ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查是否需要 URL 降级
|
|
||||||
if shouldAntigravityFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
|
|
||||||
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode >= 400 {
|
|
||||||
return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析流式响应,提取文本
|
|
||||||
text := extractTextFromSSEResponse(respBody)
|
|
||||||
|
|
||||||
// 标记成功的 URL,下次优先使用
|
|
||||||
antigravity.DefaultURLAvailability.MarkSuccess(baseURL)
|
|
||||||
return &TestConnectionResult{
|
|
||||||
Text: text,
|
|
||||||
MappedModel: mappedModel,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, lastErr
|
result, err := s.antigravityRetryLoop(p)
|
||||||
|
if err != nil {
|
||||||
|
// AccountSwitchError → 测试时不切换账号,返回友好提示
|
||||||
|
var switchErr *AntigravityAccountSwitchError
|
||||||
|
if errors.As(err, &switchErr) {
|
||||||
|
return nil, fmt.Errorf("该账号模型 %s 当前限流中,请稍后重试", switchErr.RateLimitedModel)
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if result == nil || result.resp == nil {
|
||||||
|
return nil, errors.New("upstream returned empty response")
|
||||||
|
}
|
||||||
|
defer func() { _ = result.resp.Body.Close() }()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(io.LimitReader(result.resp.Body, 2<<20))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.resp.StatusCode >= 400 {
|
||||||
|
return nil, fmt.Errorf("API 返回 %d: %s", result.resp.StatusCode, string(respBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
text := extractTextFromSSEResponse(respBody)
|
||||||
|
return &TestConnectionResult{Text: text, MappedModel: mappedModel}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// testConnectionHandleError 是 TestConnection 使用的轻量 handleError 回调。
|
||||||
|
// 仅记录日志,不做 ops 错误追踪或粘性会话清除。
|
||||||
|
func testConnectionHandleError(
|
||||||
|
_ context.Context, prefix string, account *Account,
|
||||||
|
statusCode int, _ http.Header, body []byte,
|
||||||
|
requestedModel string, _ int64, _ string, _ bool,
|
||||||
|
) *handleModelRateLimitResult {
|
||||||
|
logger.LegacyPrintf("service.antigravity_gateway",
|
||||||
|
"%s test_handle_error status=%d model=%s account=%d body=%s",
|
||||||
|
prefix, statusCode, requestedModel, account.ID, truncateForLog(body, 200))
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildGeminiTestRequest 构建 Gemini 格式测试请求
|
// buildGeminiTestRequest 构建 Gemini 格式测试请求
|
||||||
@@ -3033,6 +3077,22 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
|||||||
intervalCh = intervalTicker.C
|
intervalCh = intervalTicker.C
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 下游 keepalive:防止代理/Cloudflare Tunnel 因连接空闲而断开
|
||||||
|
keepaliveInterval := time.Duration(0)
|
||||||
|
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamKeepaliveInterval > 0 {
|
||||||
|
keepaliveInterval = time.Duration(s.settingService.cfg.Gateway.StreamKeepaliveInterval) * time.Second
|
||||||
|
}
|
||||||
|
var keepaliveTicker *time.Ticker
|
||||||
|
if keepaliveInterval > 0 {
|
||||||
|
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||||
|
defer keepaliveTicker.Stop()
|
||||||
|
}
|
||||||
|
var keepaliveCh <-chan time.Time
|
||||||
|
if keepaliveTicker != nil {
|
||||||
|
keepaliveCh = keepaliveTicker.C
|
||||||
|
}
|
||||||
|
lastDataAt := time.Now()
|
||||||
|
|
||||||
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity gemini")
|
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity gemini")
|
||||||
|
|
||||||
// 仅发送一次错误事件,避免多次写入导致协议混乱
|
// 仅发送一次错误事件,避免多次写入导致协议混乱
|
||||||
@@ -3065,6 +3125,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
|||||||
return nil, ev.err
|
return nil, ev.err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
lastDataAt = time.Now()
|
||||||
|
|
||||||
line := ev.line
|
line := ev.line
|
||||||
trimmed := strings.TrimRight(line, "\r\n")
|
trimmed := strings.TrimRight(line, "\r\n")
|
||||||
if strings.HasPrefix(trimmed, "data:") {
|
if strings.HasPrefix(trimmed, "data:") {
|
||||||
@@ -3124,6 +3186,19 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
|||||||
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)")
|
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)")
|
||||||
sendErrorEvent("stream_timeout")
|
sendErrorEvent("stream_timeout")
|
||||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||||
|
|
||||||
|
case <-keepaliveCh:
|
||||||
|
if cw.Disconnected() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if time.Since(lastDataAt) < keepaliveInterval {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// SSE ping/keepalive:保持连接活跃防止 Cloudflare Tunnel 等代理断开
|
||||||
|
if !cw.Fprintf(":\n\n") {
|
||||||
|
logger.LegacyPrintf("service.antigravity_gateway", "Client disconnected during keepalive ping (antigravity gemini), continuing to drain upstream for billing")
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -3849,6 +3924,22 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
|||||||
intervalCh = intervalTicker.C
|
intervalCh = intervalTicker.C
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 下游 keepalive:防止代理/Cloudflare Tunnel 因连接空闲而断开
|
||||||
|
keepaliveInterval := time.Duration(0)
|
||||||
|
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamKeepaliveInterval > 0 {
|
||||||
|
keepaliveInterval = time.Duration(s.settingService.cfg.Gateway.StreamKeepaliveInterval) * time.Second
|
||||||
|
}
|
||||||
|
var keepaliveTicker *time.Ticker
|
||||||
|
if keepaliveInterval > 0 {
|
||||||
|
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||||
|
defer keepaliveTicker.Stop()
|
||||||
|
}
|
||||||
|
var keepaliveCh <-chan time.Time
|
||||||
|
if keepaliveTicker != nil {
|
||||||
|
keepaliveCh = keepaliveTicker.C
|
||||||
|
}
|
||||||
|
lastDataAt := time.Now()
|
||||||
|
|
||||||
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity claude")
|
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity claude")
|
||||||
|
|
||||||
// 仅发送一次错误事件,避免多次写入导致协议混乱
|
// 仅发送一次错误事件,避免多次写入导致协议混乱
|
||||||
@@ -3901,6 +3992,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
|||||||
return nil, fmt.Errorf("stream read error: %w", ev.err)
|
return nil, fmt.Errorf("stream read error: %w", ev.err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
lastDataAt = time.Now()
|
||||||
|
|
||||||
// 处理 SSE 行,转换为 Claude 格式
|
// 处理 SSE 行,转换为 Claude 格式
|
||||||
claudeEvents := processor.ProcessLine(strings.TrimRight(ev.line, "\r\n"))
|
claudeEvents := processor.ProcessLine(strings.TrimRight(ev.line, "\r\n"))
|
||||||
if len(claudeEvents) > 0 {
|
if len(claudeEvents) > 0 {
|
||||||
@@ -3923,6 +4016,20 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
|||||||
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)")
|
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)")
|
||||||
sendErrorEvent("stream_timeout")
|
sendErrorEvent("stream_timeout")
|
||||||
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||||
|
|
||||||
|
case <-keepaliveCh:
|
||||||
|
if cw.Disconnected() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if time.Since(lastDataAt) < keepaliveInterval {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// SSE ping 事件:Anthropic 原生格式,客户端会正确处理,
|
||||||
|
// 同时保持连接活跃防止 Cloudflare Tunnel 等代理断开
|
||||||
|
if !cw.Fprintf("event: ping\ndata: {\"type\": \"ping\"}\n\n") {
|
||||||
|
logger.LegacyPrintf("service.antigravity_gateway", "Client disconnected during keepalive ping (antigravity claude), continuing to drain upstream for billing")
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -4253,6 +4360,22 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp
|
|||||||
intervalCh = intervalTicker.C
|
intervalCh = intervalTicker.C
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 下游 keepalive:防止代理/Cloudflare Tunnel 因连接空闲而断开
|
||||||
|
keepaliveInterval := time.Duration(0)
|
||||||
|
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamKeepaliveInterval > 0 {
|
||||||
|
keepaliveInterval = time.Duration(s.settingService.cfg.Gateway.StreamKeepaliveInterval) * time.Second
|
||||||
|
}
|
||||||
|
var keepaliveTicker *time.Ticker
|
||||||
|
if keepaliveInterval > 0 {
|
||||||
|
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||||
|
defer keepaliveTicker.Stop()
|
||||||
|
}
|
||||||
|
var keepaliveCh <-chan time.Time
|
||||||
|
if keepaliveTicker != nil {
|
||||||
|
keepaliveCh = keepaliveTicker.C
|
||||||
|
}
|
||||||
|
lastDataAt := time.Now()
|
||||||
|
|
||||||
flusher, _ := c.Writer.(http.Flusher)
|
flusher, _ := c.Writer.(http.Flusher)
|
||||||
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity upstream")
|
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity upstream")
|
||||||
|
|
||||||
@@ -4270,6 +4393,8 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp
|
|||||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
|
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
lastDataAt = time.Now()
|
||||||
|
|
||||||
line := ev.line
|
line := ev.line
|
||||||
|
|
||||||
// 记录首 token 时间
|
// 记录首 token 时间
|
||||||
@@ -4295,6 +4420,20 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp
|
|||||||
}
|
}
|
||||||
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity upstream)")
|
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity upstream)")
|
||||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
|
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
|
||||||
|
|
||||||
|
case <-keepaliveCh:
|
||||||
|
if cw.Disconnected() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if time.Since(lastDataAt) < keepaliveInterval {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// SSE ping 事件:Anthropic 原生格式,客户端会正确处理,
|
||||||
|
// 同时保持连接活跃防止 Cloudflare Tunnel 等代理断开
|
||||||
|
if !cw.Fprintf("event: ping\ndata: {\"type\": \"ping\"}\n\n") {
|
||||||
|
logger.LegacyPrintf("service.antigravity_gateway", "Client disconnected during keepalive ping (antigravity upstream), continuing to drain upstream for billing")
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -78,11 +78,11 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 调用 LoadCodeAssist 获取订阅等级(非关键路径,失败不影响主流程)
|
// 调用 LoadCodeAssist 获取订阅等级和 AI Credits 余额(非关键路径,失败不影响主流程)
|
||||||
tierRaw, tierNormalized := f.fetchSubscriptionTier(ctx, client, accessToken)
|
tierRaw, tierNormalized, loadResp := f.fetchSubscriptionTier(ctx, client, accessToken)
|
||||||
|
|
||||||
// 转换为 UsageInfo
|
// 转换为 UsageInfo
|
||||||
usageInfo := f.buildUsageInfo(modelsResp, tierRaw, tierNormalized)
|
usageInfo := f.buildUsageInfo(modelsResp, tierRaw, tierNormalized, loadResp)
|
||||||
|
|
||||||
return &QuotaResult{
|
return &QuotaResult{
|
||||||
UsageInfo: usageInfo,
|
UsageInfo: usageInfo,
|
||||||
@@ -90,20 +90,21 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// fetchSubscriptionTier 获取账号订阅等级,失败返回空字符串
|
// fetchSubscriptionTier 获取账号订阅等级,失败返回空字符串。
|
||||||
func (f *AntigravityQuotaFetcher) fetchSubscriptionTier(ctx context.Context, client *antigravity.Client, accessToken string) (raw, normalized string) {
|
// 同时返回 LoadCodeAssistResponse,以便提取 AI Credits 余额。
|
||||||
|
func (f *AntigravityQuotaFetcher) fetchSubscriptionTier(ctx context.Context, client *antigravity.Client, accessToken string) (raw, normalized string, loadResp *antigravity.LoadCodeAssistResponse) {
|
||||||
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
|
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("failed to fetch subscription tier", "error", err)
|
slog.Warn("failed to fetch subscription tier", "error", err)
|
||||||
return "", ""
|
return "", "", nil
|
||||||
}
|
}
|
||||||
if loadResp == nil {
|
if loadResp == nil {
|
||||||
return "", ""
|
return "", "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
raw = loadResp.GetTier() // 已有方法:paidTier > currentTier
|
raw = loadResp.GetTier() // 已有方法:paidTier > currentTier
|
||||||
normalized = normalizeTier(raw)
|
normalized = normalizeTier(raw)
|
||||||
return raw, normalized
|
return raw, normalized, loadResp
|
||||||
}
|
}
|
||||||
|
|
||||||
// normalizeTier 将原始 tier 字符串归一化为 FREE/PRO/ULTRA/UNKNOWN
|
// normalizeTier 将原始 tier 字符串归一化为 FREE/PRO/ULTRA/UNKNOWN
|
||||||
@@ -124,8 +125,8 @@ func normalizeTier(raw string) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildUsageInfo 将 API 响应转换为 UsageInfo
|
// buildUsageInfo 将 API 响应转换为 UsageInfo。
|
||||||
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse, tierRaw, tierNormalized string) *UsageInfo {
|
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse, tierRaw, tierNormalized string, loadResp *antigravity.LoadCodeAssistResponse) *UsageInfo {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
info := &UsageInfo{
|
info := &UsageInfo{
|
||||||
UpdatedAt: &now,
|
UpdatedAt: &now,
|
||||||
@@ -190,6 +191,16 @@ func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAv
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if loadResp != nil {
|
||||||
|
for _, credit := range loadResp.GetAvailableCredits() {
|
||||||
|
info.AICredits = append(info.AICredits, AICredit{
|
||||||
|
CreditType: credit.CreditType,
|
||||||
|
Amount: credit.GetAmount(),
|
||||||
|
MinimumBalance: credit.GetMinimumAmount(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return info
|
return info
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ func TestBuildUsageInfo_BasicModels(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
info := fetcher.buildUsageInfo(modelsResp, "g1-pro-tier", "PRO")
|
info := fetcher.buildUsageInfo(modelsResp, "g1-pro-tier", "PRO", nil)
|
||||||
|
|
||||||
// 基本字段
|
// 基本字段
|
||||||
require.NotNil(t, info.UpdatedAt, "UpdatedAt should be set")
|
require.NotNil(t, info.UpdatedAt, "UpdatedAt should be set")
|
||||||
@@ -141,7 +141,7 @@ func TestBuildUsageInfo_DeprecatedModels(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||||
|
|
||||||
require.Len(t, info.ModelForwardingRules, 2)
|
require.Len(t, info.ModelForwardingRules, 2)
|
||||||
require.Equal(t, "claude-sonnet-4-20250514", info.ModelForwardingRules["claude-3-sonnet-20240229"])
|
require.Equal(t, "claude-sonnet-4-20250514", info.ModelForwardingRules["claude-3-sonnet-20240229"])
|
||||||
@@ -159,7 +159,7 @@ func TestBuildUsageInfo_NoDeprecatedModels(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||||
|
|
||||||
require.Nil(t, info.ModelForwardingRules, "ModelForwardingRules should be nil when no deprecated models")
|
require.Nil(t, info.ModelForwardingRules, "ModelForwardingRules should be nil when no deprecated models")
|
||||||
}
|
}
|
||||||
@@ -171,7 +171,7 @@ func TestBuildUsageInfo_EmptyModels(t *testing.T) {
|
|||||||
Models: map[string]antigravity.ModelInfo{},
|
Models: map[string]antigravity.ModelInfo{},
|
||||||
}
|
}
|
||||||
|
|
||||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||||
|
|
||||||
require.NotNil(t, info)
|
require.NotNil(t, info)
|
||||||
require.NotNil(t, info.AntigravityQuota)
|
require.NotNil(t, info.AntigravityQuota)
|
||||||
@@ -193,7 +193,7 @@ func TestBuildUsageInfo_ModelWithNilQuotaInfo(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||||
|
|
||||||
require.NotNil(t, info)
|
require.NotNil(t, info)
|
||||||
require.Empty(t, info.AntigravityQuota, "models with nil QuotaInfo should be skipped")
|
require.Empty(t, info.AntigravityQuota, "models with nil QuotaInfo should be skipped")
|
||||||
@@ -222,7 +222,7 @@ func TestBuildUsageInfo_FiveHourPriorityOrder(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||||
|
|
||||||
require.NotNil(t, info.FiveHour, "FiveHour should be set when a priority model exists")
|
require.NotNil(t, info.FiveHour, "FiveHour should be set when a priority model exists")
|
||||||
// claude-sonnet-4-20250514 is first in priority list, so it should be used
|
// claude-sonnet-4-20250514 is first in priority list, so it should be used
|
||||||
@@ -251,7 +251,7 @@ func TestBuildUsageInfo_FiveHourFallbackToClaude4(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||||
|
|
||||||
require.NotNil(t, info.FiveHour)
|
require.NotNil(t, info.FiveHour)
|
||||||
expectedUtilization := (1.0 - 0.60) * 100 // 40
|
expectedUtilization := (1.0 - 0.60) * 100 // 40
|
||||||
@@ -277,7 +277,7 @@ func TestBuildUsageInfo_FiveHourFallbackToGemini(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||||
|
|
||||||
require.NotNil(t, info.FiveHour)
|
require.NotNil(t, info.FiveHour)
|
||||||
expectedUtilization := (1.0 - 0.30) * 100 // 70
|
expectedUtilization := (1.0 - 0.30) * 100 // 70
|
||||||
@@ -298,7 +298,7 @@ func TestBuildUsageInfo_FiveHourNoPriorityModel(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||||
|
|
||||||
require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists")
|
require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists")
|
||||||
}
|
}
|
||||||
@@ -317,7 +317,7 @@ func TestBuildUsageInfo_FiveHourWithEmptyResetTime(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||||
|
|
||||||
require.NotNil(t, info.FiveHour)
|
require.NotNil(t, info.FiveHour)
|
||||||
require.Nil(t, info.FiveHour.ResetsAt, "ResetsAt should be nil when ResetTime is empty")
|
require.Nil(t, info.FiveHour.ResetsAt, "ResetsAt should be nil when ResetTime is empty")
|
||||||
@@ -338,7 +338,7 @@ func TestBuildUsageInfo_FullUtilization(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||||
|
|
||||||
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
|
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
|
||||||
require.NotNil(t, quota)
|
require.NotNil(t, quota)
|
||||||
@@ -358,13 +358,38 @@ func TestBuildUsageInfo_ZeroUtilization(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||||
|
|
||||||
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
|
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
|
||||||
require.NotNil(t, quota)
|
require.NotNil(t, quota)
|
||||||
require.Equal(t, 0, quota.Utilization)
|
require.Equal(t, 0, quota.Utilization)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBuildUsageInfo_AICredits(t *testing.T) {
|
||||||
|
fetcher := &AntigravityQuotaFetcher{}
|
||||||
|
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||||
|
Models: map[string]antigravity.ModelInfo{},
|
||||||
|
}
|
||||||
|
loadResp := &antigravity.LoadCodeAssistResponse{
|
||||||
|
PaidTier: &antigravity.PaidTierInfo{
|
||||||
|
ID: "g1-pro-tier",
|
||||||
|
AvailableCredits: []antigravity.AvailableCredit{
|
||||||
|
{
|
||||||
|
CreditType: "GOOGLE_ONE_AI",
|
||||||
|
CreditAmount: "25",
|
||||||
|
MinimumCreditAmountForUsage: "5",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
info := fetcher.buildUsageInfo(modelsResp, "g1-pro-tier", "PRO", loadResp)
|
||||||
|
|
||||||
|
require.Len(t, info.AICredits, 1)
|
||||||
|
require.Equal(t, "GOOGLE_ONE_AI", info.AICredits[0].CreditType)
|
||||||
|
require.Equal(t, 25.0, info.AICredits[0].Amount)
|
||||||
|
require.Equal(t, 5.0, info.AICredits[0].MinimumBalance)
|
||||||
|
}
|
||||||
|
|
||||||
func TestFetchQuota_ForbiddenReturnsIsForbidden(t *testing.T) {
|
func TestFetchQuota_ForbiddenReturnsIsForbidden(t *testing.T) {
|
||||||
// 模拟 FetchQuota 遇到 403 时的行为:
|
// 模拟 FetchQuota 遇到 403 时的行为:
|
||||||
// FetchAvailableModels 返回 ForbiddenError → FetchQuota 应返回 is_forbidden=true
|
// FetchAvailableModels 返回 ForbiddenError → FetchQuota 应返回 is_forbidden=true
|
||||||
|
|||||||
@@ -32,6 +32,10 @@ func (a *Account) IsSchedulableForModelWithContext(ctx context.Context, requeste
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if a.isModelRateLimitedWithContext(ctx, requestedModel) {
|
if a.isModelRateLimitedWithContext(ctx, requestedModel) {
|
||||||
|
// Antigravity + overages 启用 + 积分未耗尽 → 放行(有积分可用)
|
||||||
|
if a.Platform == PlatformAntigravity && a.IsOveragesEnabled() && !a.isCreditsExhausted() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
|
|||||||
@@ -76,10 +76,16 @@ type modelRateLimitCall struct {
|
|||||||
resetAt time.Time
|
resetAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extraUpdateCall captures the arguments of one UpdateExtra invocation so a
// test can inspect them afterwards.
type extraUpdateCall struct {
	accountID int64          // id passed to UpdateExtra
	updates   map[string]any // updates map passed to UpdateExtra (stored as-is, not copied)
}
|
||||||
|
|
||||||
type stubAntigravityAccountRepo struct {
|
type stubAntigravityAccountRepo struct {
|
||||||
AccountRepository
|
AccountRepository
|
||||||
rateCalls []rateLimitCall
|
rateCalls []rateLimitCall
|
||||||
modelRateLimitCalls []modelRateLimitCall
|
modelRateLimitCalls []modelRateLimitCall
|
||||||
|
extraUpdateCalls []extraUpdateCall
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||||
@@ -92,6 +98,11 @@ func (s *stubAntigravityAccountRepo) SetModelRateLimit(ctx context.Context, id i
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stubAntigravityAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
|
||||||
|
s.extraUpdateCalls = append(s.extraUpdateCalls, extraUpdateCall{accountID: id, updates: updates})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestAntigravityRetryLoop_NoURLFallback_UsesConfiguredBaseURL(t *testing.T) {
|
func TestAntigravityRetryLoop_NoURLFallback_UsesConfiguredBaseURL(t *testing.T) {
|
||||||
t.Setenv(antigravityForwardBaseURLEnv, "")
|
t.Setenv(antigravityForwardBaseURLEnv, "")
|
||||||
|
|
||||||
|
|||||||
@@ -260,14 +260,15 @@ func TestHandleSmartRetry_429_LongDelay_SingleAccountRetry_StillSwitches(t *test
|
|||||||
|
|
||||||
// TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit
|
// TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit
|
||||||
// 503 + retryDelay < 7s + SingleAccountRetry → 智能重试耗尽后直接返回 503,不设限流
|
// 503 + retryDelay < 7s + SingleAccountRetry → 智能重试耗尽后直接返回 503,不设限流
|
||||||
|
// 使用 RATE_LIMIT_EXCEEDED(走 1 次智能重试),避免 MODEL_CAPACITY_EXHAUSTED 的 60 次重试导致测试超时
|
||||||
func TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit(t *testing.T) {
|
func TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit(t *testing.T) {
|
||||||
// 智能重试也返回 503
|
// 智能重试也返回 503
|
||||||
failRespBody := `{
|
failRespBody := `{
|
||||||
"error": {
|
"error": {
|
||||||
"code": 503,
|
"code": 503,
|
||||||
"status": "UNAVAILABLE",
|
"status": "RESOURCE_EXHAUSTED",
|
||||||
"details": [
|
"details": [
|
||||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@@ -278,8 +279,9 @@ func TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit(t *testi
|
|||||||
Body: io.NopCloser(strings.NewReader(failRespBody)),
|
Body: io.NopCloser(strings.NewReader(failRespBody)),
|
||||||
}
|
}
|
||||||
upstream := &mockSmartRetryUpstream{
|
upstream := &mockSmartRetryUpstream{
|
||||||
responses: []*http.Response{failResp},
|
responses: []*http.Response{failResp},
|
||||||
errors: []error{nil},
|
errors: []error{nil},
|
||||||
|
repeatLast: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
repo := &stubAntigravityAccountRepo{}
|
repo := &stubAntigravityAccountRepo{}
|
||||||
@@ -294,9 +296,9 @@ func TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit(t *testi
|
|||||||
respBody := []byte(`{
|
respBody := []byte(`{
|
||||||
"error": {
|
"error": {
|
||||||
"code": 503,
|
"code": 503,
|
||||||
"status": "UNAVAILABLE",
|
"status": "RESOURCE_EXHAUSTED",
|
||||||
"details": [
|
"details": [
|
||||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@@ -569,8 +571,9 @@ func TestHandleSingleAccountRetryInPlace_WaitDurationClamped(t *testing.T) {
|
|||||||
|
|
||||||
svc := &AntigravityGatewayService{}
|
svc := &AntigravityGatewayService{}
|
||||||
|
|
||||||
// 等待时间过大应被 clamp 到 antigravitySingleAccountSmartRetryMaxWait
|
// waitDuration=0 会被 clamp 到 antigravitySmartRetryMinWait=1s。
|
||||||
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 999*time.Second, "gemini-3-pro")
|
// 首次重试即成功(200),总耗时 ~1s。
|
||||||
|
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 0, "gemini-3-pro")
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||||
require.NotNil(t, result.resp)
|
require.NotNil(t, result.resp)
|
||||||
|
|||||||
@@ -32,20 +32,65 @@ func (c *stubSmartRetryCache) DeleteSessionAccountID(_ context.Context, groupID
|
|||||||
|
|
||||||
// mockSmartRetryUpstream is a scripted upstream used by the handleSmartRetry
// tests. The n-th call to Do is answered with responses[n] and errors[n].
type mockSmartRetryUpstream struct {
	responses      []*http.Response // scripted responses, indexed by call number
	responseBodies [][]byte         // response body bytes cached on first use so a response can be replayed
	errors         []error          // scripted errors, parallel to responses
	callIdx        int              // number of Do calls made so far
	calls          []string         // request URL of every call, in order
	requestBodies  [][]byte         // request body of every call (nil when the request had no body)
	repeatLast     bool             // once responses is exhausted, keep replaying the last entry
}

// Do records the request and returns the scripted reply for this call.
//
// The request URL and body are appended to calls and requestBodies; the
// request body is restored after being read so the caller can read it again.
// A nil request, or one without a URL, is recorded as "" instead of panicking.
//
// When the call number is past the end of responses, Do returns (nil, nil),
// unless repeatLast is set and at least one response exists, in which case the
// last scripted entry is replayed. A missing errors entry counts as a nil
// error. Every non-nil reply is a fresh *http.Response (status code, cloned
// header, new body reader over the cached bytes), so each caller gets a
// readable body however many times the entry is replayed. Only StatusCode,
// Header and Body are carried over from the scripted response.
func (m *mockSmartRetryUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
	callNo := m.callIdx
	m.callIdx++

	target := ""
	if req != nil && req.URL != nil {
		target = req.URL.String()
	}
	m.calls = append(m.calls, target)

	var sent []byte
	if req != nil && req.Body != nil {
		// Best-effort capture: a read error just leaves a partial body recorded.
		sent, _ = io.ReadAll(req.Body)
		req.Body = io.NopCloser(bytes.NewReader(sent))
	}
	m.requestBodies = append(m.requestBodies, sent)

	// Pick which scripted entry answers this call.
	slot := callNo
	if slot >= len(m.responses) {
		if !m.repeatLast || len(m.responses) == 0 {
			return nil, nil
		}
		slot = len(m.responses) - 1
	}

	var scriptedErr error
	if slot < len(m.errors) {
		scriptedErr = m.errors[slot]
	}
	scripted := m.responses[slot]
	if scripted == nil {
		return nil, scriptedErr
	}

	// Cache the body bytes the first time this entry is served.
	for len(m.responseBodies) <= slot {
		m.responseBodies = append(m.responseBodies, nil)
	}
	if m.responseBodies[slot] == nil && scripted.Body != nil {
		cached, _ := io.ReadAll(scripted.Body)
		_ = scripted.Body.Close()
		m.responseBodies[slot] = cached
	}

	// Hand out a new reader over the cached bytes; Body stays nil when the
	// scripted response never had one.
	var replay io.ReadCloser
	if cached := m.responseBodies[slot]; cached != nil {
		replay = io.NopCloser(bytes.NewReader(cached))
	}

	return &http.Response{
		StatusCode: scripted.StatusCode,
		Header:     scripted.Header.Clone(),
		Body:       replay,
	}, scriptedErr
}
|
||||||
|
|
||||||
func (m *mockSmartRetryUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
|
func (m *mockSmartRetryUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package service
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"log"
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -17,15 +16,18 @@ const (
|
|||||||
antigravityBackfillCooldown = 5 * time.Minute
|
antigravityBackfillCooldown = 5 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
// AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
|
// AntigravityTokenCache token cache interface.
|
||||||
type AntigravityTokenCache = GeminiTokenCache
|
type AntigravityTokenCache = GeminiTokenCache
|
||||||
|
|
||||||
// AntigravityTokenProvider 管理 Antigravity 账户的 access_token
|
// AntigravityTokenProvider manages access_token for antigravity accounts.
|
||||||
type AntigravityTokenProvider struct {
|
type AntigravityTokenProvider struct {
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
tokenCache AntigravityTokenCache
|
tokenCache AntigravityTokenCache
|
||||||
antigravityOAuthService *AntigravityOAuthService
|
antigravityOAuthService *AntigravityOAuthService
|
||||||
backfillCooldown sync.Map // key: int64 (account.ID) → value: time.Time
|
backfillCooldown sync.Map // key: accountID -> last attempt time
|
||||||
|
refreshAPI *OAuthRefreshAPI
|
||||||
|
executor OAuthRefreshExecutor
|
||||||
|
refreshPolicy ProviderRefreshPolicy
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAntigravityTokenProvider(
|
func NewAntigravityTokenProvider(
|
||||||
@@ -37,10 +39,22 @@ func NewAntigravityTokenProvider(
|
|||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
tokenCache: tokenCache,
|
tokenCache: tokenCache,
|
||||||
antigravityOAuthService: antigravityOAuthService,
|
antigravityOAuthService: antigravityOAuthService,
|
||||||
|
refreshPolicy: AntigravityProviderRefreshPolicy(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccessToken 获取有效的 access_token
|
// SetRefreshAPI injects unified OAuth refresh API and executor.
|
||||||
|
func (p *AntigravityTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
|
||||||
|
p.refreshAPI = api
|
||||||
|
p.executor = executor
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRefreshPolicy injects caller-side refresh policy.
|
||||||
|
func (p *AntigravityTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
|
||||||
|
p.refreshPolicy = policy
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccessToken returns a valid access_token.
|
||||||
func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||||
if account == nil {
|
if account == nil {
|
||||||
return "", errors.New("account is nil")
|
return "", errors.New("account is nil")
|
||||||
@@ -48,7 +62,8 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
|||||||
if account.Platform != PlatformAntigravity {
|
if account.Platform != PlatformAntigravity {
|
||||||
return "", errors.New("not an antigravity account")
|
return "", errors.New("not an antigravity account")
|
||||||
}
|
}
|
||||||
// upstream 类型:直接从 credentials 读取 api_key,不走 OAuth 刷新流程
|
|
||||||
|
// upstream accounts use static api_key and never refresh oauth token.
|
||||||
if account.Type == AccountTypeUpstream {
|
if account.Type == AccountTypeUpstream {
|
||||||
apiKey := account.GetCredential("api_key")
|
apiKey := account.GetCredential("api_key")
|
||||||
if apiKey == "" {
|
if apiKey == "" {
|
||||||
@@ -62,46 +77,38 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
|||||||
|
|
||||||
cacheKey := AntigravityTokenCacheKey(account)
|
cacheKey := AntigravityTokenCacheKey(account)
|
||||||
|
|
||||||
// 1. 先尝试缓存
|
// 1) Try cache first.
|
||||||
if p.tokenCache != nil {
|
if p.tokenCache != nil {
|
||||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 如果即将过期则刷新
|
// 2) Refresh if needed (pre-expiry skew).
|
||||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew
|
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew
|
||||||
if needsRefresh && p.tokenCache != nil {
|
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
|
||||||
|
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, antigravityTokenRefreshSkew)
|
||||||
|
if err != nil {
|
||||||
|
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
} else if result.LockHeld {
|
||||||
|
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil {
|
||||||
|
if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" {
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// default policy: continue with existing token.
|
||||||
|
} else {
|
||||||
|
account = result.Account
|
||||||
|
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||||
|
}
|
||||||
|
} else if needsRefresh && p.tokenCache != nil {
|
||||||
|
// Backward-compatible test path when refreshAPI is not injected.
|
||||||
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||||
if err == nil && locked {
|
if err == nil && locked {
|
||||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||||
|
|
||||||
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
|
|
||||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
|
||||||
return token, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 从数据库获取最新账户信息
|
|
||||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
|
||||||
if err == nil && fresh != nil {
|
|
||||||
account = fresh
|
|
||||||
}
|
|
||||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
|
||||||
if expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew {
|
|
||||||
if p.antigravityOAuthService == nil {
|
|
||||||
return "", errors.New("antigravity oauth service not configured")
|
|
||||||
}
|
|
||||||
tokenInfo, err := p.antigravityOAuthService.RefreshAccountToken(ctx, account)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
p.mergeCredentials(account, tokenInfo)
|
|
||||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
|
||||||
log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr)
|
|
||||||
}
|
|
||||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -110,32 +117,31 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
|||||||
return "", errors.New("access_token not found in credentials")
|
return "", errors.New("access_token not found in credentials")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果账号还没有 project_id,尝试在线补齐,避免请求 daily/sandbox 时出现
|
// Backfill project_id online when missing, with cooldown to avoid hammering.
|
||||||
// "Invalid project resource name projects/"。
|
|
||||||
// 仅调用 loadProjectIDWithRetry,不刷新 OAuth token;带冷却机制防止频繁重试。
|
|
||||||
if strings.TrimSpace(account.GetCredential("project_id")) == "" && p.antigravityOAuthService != nil {
|
if strings.TrimSpace(account.GetCredential("project_id")) == "" && p.antigravityOAuthService != nil {
|
||||||
if p.shouldAttemptBackfill(account.ID) {
|
if p.shouldAttemptBackfill(account.ID) {
|
||||||
p.markBackfillAttempted(account.ID)
|
p.markBackfillAttempted(account.ID)
|
||||||
if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" {
|
if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" {
|
||||||
account.Credentials["project_id"] = projectID
|
account.Credentials["project_id"] = projectID
|
||||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||||
log.Printf("[AntigravityTokenProvider] project_id 补齐持久化失败: %v", updateErr)
|
slog.Warn("antigravity_project_id_backfill_persist_failed",
|
||||||
|
"account_id", account.ID,
|
||||||
|
"error", updateErr,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
// 3) Populate cache with TTL.
|
||||||
if p.tokenCache != nil {
|
if p.tokenCache != nil {
|
||||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||||
if isStale && latestAccount != nil {
|
if isStale && latestAccount != nil {
|
||||||
// 版本过时,使用 DB 中的最新 token
|
|
||||||
slog.Debug("antigravity_token_version_stale_use_latest", "account_id", account.ID)
|
slog.Debug("antigravity_token_version_stale_use_latest", "account_id", account.ID)
|
||||||
accessToken = latestAccount.GetCredential("access_token")
|
accessToken = latestAccount.GetCredential("access_token")
|
||||||
if strings.TrimSpace(accessToken) == "" {
|
if strings.TrimSpace(accessToken) == "" {
|
||||||
return "", errors.New("access_token not found after version check")
|
return "", errors.New("access_token not found after version check")
|
||||||
}
|
}
|
||||||
// 不写入缓存,让下次请求重新处理
|
|
||||||
} else {
|
} else {
|
||||||
ttl := 30 * time.Minute
|
ttl := 30 * time.Minute
|
||||||
if expiresAt != nil {
|
if expiresAt != nil {
|
||||||
@@ -156,18 +162,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
|||||||
return accessToken, nil
|
return accessToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// mergeCredentials 将 tokenInfo 构建的凭证合并到 account 中,保留原有未覆盖的字段
|
// shouldAttemptBackfill checks backfill cooldown.
|
||||||
func (p *AntigravityTokenProvider) mergeCredentials(account *Account, tokenInfo *AntigravityTokenInfo) {
|
|
||||||
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
|
||||||
for k, v := range account.Credentials {
|
|
||||||
if _, exists := newCredentials[k]; !exists {
|
|
||||||
newCredentials[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
account.Credentials = newCredentials
|
|
||||||
}
|
|
||||||
|
|
||||||
// shouldAttemptBackfill 检查是否应该尝试补齐 project_id(冷却期内不重复尝试)
|
|
||||||
func (p *AntigravityTokenProvider) shouldAttemptBackfill(accountID int64) bool {
|
func (p *AntigravityTokenProvider) shouldAttemptBackfill(accountID int64) bool {
|
||||||
if v, ok := p.backfillCooldown.Load(accountID); ok {
|
if v, ok := p.backfillCooldown.Load(accountID); ok {
|
||||||
if lastAttempt, ok := v.(time.Time); ok {
|
if lastAttempt, ok := v.(time.Time); ok {
|
||||||
|
|||||||
@@ -25,6 +25,11 @@ func NewAntigravityTokenRefresher(antigravityOAuthService *AntigravityOAuthServi
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CacheKey 返回用于分布式锁的缓存键
|
||||||
|
func (r *AntigravityTokenRefresher) CacheKey(account *Account) string {
|
||||||
|
return AntigravityTokenCacheKey(account)
|
||||||
|
}
|
||||||
|
|
||||||
// CanRefresh 检查是否可以刷新此账户
|
// CanRefresh 检查是否可以刷新此账户
|
||||||
func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool {
|
func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool {
|
||||||
return account.Platform == PlatformAntigravity && account.Type == AccountTypeOAuth
|
return account.Platform == PlatformAntigravity && account.Type == AccountTypeOAuth
|
||||||
@@ -58,11 +63,7 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun
|
|||||||
|
|
||||||
newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||||
// 合并旧的 credentials,保留新 credentials 中不存在的字段
|
// 合并旧的 credentials,保留新 credentials 中不存在的字段
|
||||||
for k, v := range account.Credentials {
|
newCredentials = MergeCredentials(account.Credentials, newCredentials)
|
||||||
if _, exists := newCredentials[k]; !exists {
|
|
||||||
newCredentials[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
|
// 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
|
||||||
// 这确保了即使 LoadCodeAssist 失败,project_id 也不会丢失
|
// 这确保了即使 LoadCodeAssist 失败,project_id 也不会丢失
|
||||||
|
|||||||
@@ -22,8 +22,9 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// IsWindowExpired returns true if the window starting at windowStart has exceeded the given duration.
|
// IsWindowExpired returns true if the window starting at windowStart has exceeded the given duration.
|
||||||
|
// A nil windowStart is treated as expired — no initialized window means any accumulated usage is stale.
|
||||||
func IsWindowExpired(windowStart *time.Time, duration time.Duration) bool {
|
func IsWindowExpired(windowStart *time.Time, duration time.Duration) bool {
|
||||||
return windowStart != nil && time.Since(*windowStart) >= duration
|
return windowStart == nil || time.Since(*windowStart) >= duration
|
||||||
}
|
}
|
||||||
|
|
||||||
type APIKey struct {
|
type APIKey struct {
|
||||||
|
|||||||
@@ -15,10 +15,10 @@ func TestIsWindowExpired(t *testing.T) {
|
|||||||
want bool
|
want bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "nil window start",
|
name: "nil window start (treated as expired)",
|
||||||
start: nil,
|
start: nil,
|
||||||
duration: RateLimitWindow5h,
|
duration: RateLimitWindow5h,
|
||||||
want: false,
|
want: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "active window (started 1h ago, 5h window)",
|
name: "active window (started 1h ago, 5h window)",
|
||||||
@@ -113,7 +113,7 @@ func TestAPIKey_EffectiveUsage(t *testing.T) {
|
|||||||
want7d: 0,
|
want7d: 0,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "nil window starts return raw usage",
|
name: "nil window starts return 0 (stale usage reset)",
|
||||||
key: APIKey{
|
key: APIKey{
|
||||||
Usage5h: 5.0,
|
Usage5h: 5.0,
|
||||||
Usage1d: 10.0,
|
Usage1d: 10.0,
|
||||||
@@ -122,9 +122,9 @@ func TestAPIKey_EffectiveUsage(t *testing.T) {
|
|||||||
Window1dStart: nil,
|
Window1dStart: nil,
|
||||||
Window7dStart: nil,
|
Window7dStart: nil,
|
||||||
},
|
},
|
||||||
want5h: 5.0,
|
want5h: 0,
|
||||||
want1d: 10.0,
|
want1d: 0,
|
||||||
want7d: 50.0,
|
want7d: 0,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "mixed: 5h expired, 1d active, 7d nil",
|
name: "mixed: 5h expired, 1d active, 7d nil",
|
||||||
@@ -138,7 +138,7 @@ func TestAPIKey_EffectiveUsage(t *testing.T) {
|
|||||||
},
|
},
|
||||||
want5h: 0,
|
want5h: 0,
|
||||||
want1d: 10.0,
|
want1d: 10.0,
|
||||||
want7d: 50.0,
|
want7d: 0,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "zero usage with active windows",
|
name: "zero usage with active windows",
|
||||||
@@ -210,7 +210,7 @@ func TestAPIKeyRateLimitData_EffectiveUsage(t *testing.T) {
|
|||||||
want7d: 0,
|
want7d: 0,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "nil window starts return raw usage",
|
name: "nil window starts return 0 (stale usage reset)",
|
||||||
data: APIKeyRateLimitData{
|
data: APIKeyRateLimitData{
|
||||||
Usage5h: 3.0,
|
Usage5h: 3.0,
|
||||||
Usage1d: 8.0,
|
Usage1d: 8.0,
|
||||||
@@ -219,9 +219,9 @@ func TestAPIKeyRateLimitData_EffectiveUsage(t *testing.T) {
|
|||||||
Window1dStart: nil,
|
Window1dStart: nil,
|
||||||
Window7dStart: nil,
|
Window7dStart: nil,
|
||||||
},
|
},
|
||||||
want5h: 3.0,
|
want5h: 0,
|
||||||
want1d: 8.0,
|
want1d: 0,
|
||||||
want7d: 40.0,
|
want7d: 0,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,11 +4,13 @@ import (
|
|||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@@ -84,17 +86,21 @@ type BackupScheduleConfig struct {
|
|||||||
|
|
||||||
// BackupRecord 备份记录
|
// BackupRecord 备份记录
|
||||||
type BackupRecord struct {
|
type BackupRecord struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Status string `json:"status"` // pending, running, completed, failed
|
Status string `json:"status"` // pending, running, completed, failed
|
||||||
BackupType string `json:"backup_type"` // postgres
|
BackupType string `json:"backup_type"` // postgres
|
||||||
FileName string `json:"file_name"`
|
FileName string `json:"file_name"`
|
||||||
S3Key string `json:"s3_key"`
|
S3Key string `json:"s3_key"`
|
||||||
SizeBytes int64 `json:"size_bytes"`
|
SizeBytes int64 `json:"size_bytes"`
|
||||||
TriggeredBy string `json:"triggered_by"` // manual, scheduled
|
TriggeredBy string `json:"triggered_by"` // manual, scheduled
|
||||||
ErrorMsg string `json:"error_message,omitempty"`
|
ErrorMsg string `json:"error_message,omitempty"`
|
||||||
StartedAt string `json:"started_at"`
|
StartedAt string `json:"started_at"`
|
||||||
FinishedAt string `json:"finished_at,omitempty"`
|
FinishedAt string `json:"finished_at,omitempty"`
|
||||||
ExpiresAt string `json:"expires_at,omitempty"` // 过期时间
|
ExpiresAt string `json:"expires_at,omitempty"` // 过期时间
|
||||||
|
Progress string `json:"progress,omitempty"` // "dumping", "uploading", ""
|
||||||
|
RestoreStatus string `json:"restore_status,omitempty"` // "", "running", "completed", "failed"
|
||||||
|
RestoreError string `json:"restore_error,omitempty"`
|
||||||
|
RestoredAt string `json:"restored_at,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// BackupService 数据库备份恢复服务
|
// BackupService 数据库备份恢复服务
|
||||||
@@ -105,17 +111,24 @@ type BackupService struct {
|
|||||||
storeFactory BackupObjectStoreFactory
|
storeFactory BackupObjectStoreFactory
|
||||||
dumper DBDumper
|
dumper DBDumper
|
||||||
|
|
||||||
mu sync.Mutex
|
opMu sync.Mutex // 保护 backingUp/restoring 标志
|
||||||
store BackupObjectStore
|
|
||||||
s3Cfg *BackupS3Config
|
|
||||||
backingUp bool
|
backingUp bool
|
||||||
restoring bool
|
restoring bool
|
||||||
|
|
||||||
|
storeMu sync.Mutex // 保护 store/s3Cfg 缓存
|
||||||
|
store BackupObjectStore
|
||||||
|
s3Cfg *BackupS3Config
|
||||||
|
|
||||||
recordsMu sync.Mutex // 保护 records 的 load/save 操作
|
recordsMu sync.Mutex // 保护 records 的 load/save 操作
|
||||||
|
|
||||||
cronMu sync.Mutex
|
cronMu sync.Mutex
|
||||||
cronSched *cron.Cron
|
cronSched *cron.Cron
|
||||||
cronEntryID cron.EntryID
|
cronEntryID cron.EntryID
|
||||||
|
|
||||||
|
wg sync.WaitGroup // 追踪活跃的备份/恢复 goroutine
|
||||||
|
shuttingDown atomic.Bool // 阻止新备份启动
|
||||||
|
bgCtx context.Context // 所有后台操作的 parent context
|
||||||
|
bgCancel context.CancelFunc // 取消所有活跃后台操作
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBackupService(
|
func NewBackupService(
|
||||||
@@ -125,20 +138,26 @@ func NewBackupService(
|
|||||||
storeFactory BackupObjectStoreFactory,
|
storeFactory BackupObjectStoreFactory,
|
||||||
dumper DBDumper,
|
dumper DBDumper,
|
||||||
) *BackupService {
|
) *BackupService {
|
||||||
|
bgCtx, bgCancel := context.WithCancel(context.Background())
|
||||||
return &BackupService{
|
return &BackupService{
|
||||||
settingRepo: settingRepo,
|
settingRepo: settingRepo,
|
||||||
dbCfg: &cfg.Database,
|
dbCfg: &cfg.Database,
|
||||||
encryptor: encryptor,
|
encryptor: encryptor,
|
||||||
storeFactory: storeFactory,
|
storeFactory: storeFactory,
|
||||||
dumper: dumper,
|
dumper: dumper,
|
||||||
|
bgCtx: bgCtx,
|
||||||
|
bgCancel: bgCancel,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start 启动定时备份调度器
|
// Start 启动定时备份调度器并清理孤立记录
|
||||||
func (s *BackupService) Start() {
|
func (s *BackupService) Start() {
|
||||||
s.cronSched = cron.New()
|
s.cronSched = cron.New()
|
||||||
s.cronSched.Start()
|
s.cronSched.Start()
|
||||||
|
|
||||||
|
// 清理重启后孤立的 running 记录
|
||||||
|
s.recoverStaleRecords()
|
||||||
|
|
||||||
// 加载已有的定时配置
|
// 加载已有的定时配置
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@@ -154,13 +173,65 @@ func (s *BackupService) Start() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop 停止定时备份
|
// recoverStaleRecords 启动时将孤立的 running 记录标记为 failed
|
||||||
|
func (s *BackupService) recoverStaleRecords() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
records, err := s.loadRecords(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i := range records {
|
||||||
|
if records[i].Status == "running" {
|
||||||
|
records[i].Status = "failed"
|
||||||
|
records[i].ErrorMsg = "interrupted by server restart"
|
||||||
|
records[i].Progress = ""
|
||||||
|
records[i].FinishedAt = time.Now().Format(time.RFC3339)
|
||||||
|
_ = s.saveRecord(ctx, &records[i])
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] recovered stale running record: %s", records[i].ID)
|
||||||
|
}
|
||||||
|
if records[i].RestoreStatus == "running" {
|
||||||
|
records[i].RestoreStatus = "failed"
|
||||||
|
records[i].RestoreError = "interrupted by server restart"
|
||||||
|
_ = s.saveRecord(ctx, &records[i])
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] recovered stale restoring record: %s", records[i].ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop 停止定时备份并等待活跃操作完成
|
||||||
func (s *BackupService) Stop() {
|
func (s *BackupService) Stop() {
|
||||||
|
s.shuttingDown.Store(true)
|
||||||
|
|
||||||
s.cronMu.Lock()
|
s.cronMu.Lock()
|
||||||
defer s.cronMu.Unlock()
|
|
||||||
if s.cronSched != nil {
|
if s.cronSched != nil {
|
||||||
s.cronSched.Stop()
|
s.cronSched.Stop()
|
||||||
}
|
}
|
||||||
|
s.cronMu.Unlock()
|
||||||
|
|
||||||
|
// 等待活跃备份/恢复完成(最多 5 分钟)
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
s.wg.Wait()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] all active operations finished")
|
||||||
|
case <-time.After(5 * time.Minute):
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] shutdown timeout after 5min, cancelling active operations")
|
||||||
|
if s.bgCancel != nil {
|
||||||
|
s.bgCancel() // 取消所有后台操作
|
||||||
|
}
|
||||||
|
// 给 goroutine 时间响应取消并完成清理
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] active operations cancelled and cleaned up")
|
||||||
|
case <-time.After(10 * time.Second):
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] goroutine cleanup timed out")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ─── S3 配置管理 ───
|
// ─── S3 配置管理 ───
|
||||||
@@ -203,10 +274,10 @@ func (s *BackupService) UpdateS3Config(ctx context.Context, cfg BackupS3Config)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 清除缓存的 S3 客户端
|
// 清除缓存的 S3 客户端
|
||||||
s.mu.Lock()
|
s.storeMu.Lock()
|
||||||
s.store = nil
|
s.store = nil
|
||||||
s.s3Cfg = nil
|
s.s3Cfg = nil
|
||||||
s.mu.Unlock()
|
s.storeMu.Unlock()
|
||||||
|
|
||||||
cfg.SecretAccessKey = ""
|
cfg.SecretAccessKey = ""
|
||||||
return &cfg, nil
|
return &cfg, nil
|
||||||
@@ -314,7 +385,10 @@ func (s *BackupService) removeCronSchedule() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *BackupService) runScheduledBackup() {
|
func (s *BackupService) runScheduledBackup() {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
s.wg.Add(1)
|
||||||
|
defer s.wg.Done()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(s.bgCtx, 30*time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// 读取定时备份配置中的过期天数
|
// 读取定时备份配置中的过期天数
|
||||||
@@ -327,7 +401,11 @@ func (s *BackupService) runScheduledBackup() {
|
|||||||
logger.LegacyPrintf("service.backup", "[Backup] 开始执行定时备份, 过期天数: %d", expireDays)
|
logger.LegacyPrintf("service.backup", "[Backup] 开始执行定时备份, 过期天数: %d", expireDays)
|
||||||
record, err := s.CreateBackup(ctx, "scheduled", expireDays)
|
record, err := s.CreateBackup(ctx, "scheduled", expireDays)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.LegacyPrintf("service.backup", "[Backup] 定时备份失败: %v", err)
|
if errors.Is(err, ErrBackupInProgress) {
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] 定时备份跳过: 已有备份正在进行中")
|
||||||
|
} else {
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] 定时备份失败: %v", err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.LegacyPrintf("service.backup", "[Backup] 定时备份完成: id=%s size=%d", record.ID, record.SizeBytes)
|
logger.LegacyPrintf("service.backup", "[Backup] 定时备份完成: id=%s size=%d", record.ID, record.SizeBytes)
|
||||||
@@ -346,17 +424,21 @@ func (s *BackupService) runScheduledBackup() {
|
|||||||
// CreateBackup 创建全量数据库备份并上传到 S3(流式处理)
|
// CreateBackup 创建全量数据库备份并上传到 S3(流式处理)
|
||||||
// expireDays: 备份过期天数,0=永不过期,默认14天
|
// expireDays: 备份过期天数,0=永不过期,默认14天
|
||||||
func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, expireDays int) (*BackupRecord, error) {
|
func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, expireDays int) (*BackupRecord, error) {
|
||||||
s.mu.Lock()
|
if s.shuttingDown.Load() {
|
||||||
|
return nil, infraerrors.ServiceUnavailable("SERVER_SHUTTING_DOWN", "server is shutting down")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.opMu.Lock()
|
||||||
if s.backingUp {
|
if s.backingUp {
|
||||||
s.mu.Unlock()
|
s.opMu.Unlock()
|
||||||
return nil, ErrBackupInProgress
|
return nil, ErrBackupInProgress
|
||||||
}
|
}
|
||||||
s.backingUp = true
|
s.backingUp = true
|
||||||
s.mu.Unlock()
|
s.opMu.Unlock()
|
||||||
defer func() {
|
defer func() {
|
||||||
s.mu.Lock()
|
s.opMu.Lock()
|
||||||
s.backingUp = false
|
s.backingUp = false
|
||||||
s.mu.Unlock()
|
s.opMu.Unlock()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
s3Cfg, err := s.loadS3Config(ctx)
|
s3Cfg, err := s.loadS3Config(ctx)
|
||||||
@@ -405,36 +487,47 @@ func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, ex
|
|||||||
|
|
||||||
// 使用 io.Pipe 将 gzip 压缩数据流式传递给 S3 上传
|
// 使用 io.Pipe 将 gzip 压缩数据流式传递给 S3 上传
|
||||||
pr, pw := io.Pipe()
|
pr, pw := io.Pipe()
|
||||||
var gzipErr error
|
gzipDone := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
pw.CloseWithError(fmt.Errorf("gzip goroutine panic: %v", r)) //nolint:errcheck
|
||||||
|
gzipDone <- fmt.Errorf("gzip goroutine panic: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
gzWriter := gzip.NewWriter(pw)
|
gzWriter := gzip.NewWriter(pw)
|
||||||
_, gzipErr = io.Copy(gzWriter, dumpReader)
|
var gzErr error
|
||||||
if closeErr := gzWriter.Close(); closeErr != nil && gzipErr == nil {
|
_, gzErr = io.Copy(gzWriter, dumpReader)
|
||||||
gzipErr = closeErr
|
if closeErr := gzWriter.Close(); closeErr != nil && gzErr == nil {
|
||||||
|
gzErr = closeErr
|
||||||
}
|
}
|
||||||
if closeErr := dumpReader.Close(); closeErr != nil && gzipErr == nil {
|
if closeErr := dumpReader.Close(); closeErr != nil && gzErr == nil {
|
||||||
gzipErr = closeErr
|
gzErr = closeErr
|
||||||
}
|
}
|
||||||
if gzipErr != nil {
|
if gzErr != nil {
|
||||||
_ = pw.CloseWithError(gzipErr)
|
_ = pw.CloseWithError(gzErr)
|
||||||
} else {
|
} else {
|
||||||
_ = pw.Close()
|
_ = pw.Close()
|
||||||
}
|
}
|
||||||
|
gzipDone <- gzErr
|
||||||
}()
|
}()
|
||||||
|
|
||||||
contentType := "application/gzip"
|
contentType := "application/gzip"
|
||||||
sizeBytes, err := objectStore.Upload(ctx, s3Key, pr, contentType)
|
sizeBytes, err := objectStore.Upload(ctx, s3Key, pr, contentType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
_ = pr.CloseWithError(err) // 确保 gzip goroutine 不会悬挂
|
||||||
|
gzErr := <-gzipDone // 安全等待 gzip goroutine 完成
|
||||||
record.Status = "failed"
|
record.Status = "failed"
|
||||||
errMsg := fmt.Sprintf("S3 upload failed: %v", err)
|
errMsg := fmt.Sprintf("S3 upload failed: %v", err)
|
||||||
if gzipErr != nil {
|
if gzErr != nil {
|
||||||
errMsg = fmt.Sprintf("gzip/dump failed: %v", gzipErr)
|
errMsg = fmt.Sprintf("gzip/dump failed: %v", gzErr)
|
||||||
}
|
}
|
||||||
record.ErrorMsg = errMsg
|
record.ErrorMsg = errMsg
|
||||||
record.FinishedAt = time.Now().Format(time.RFC3339)
|
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||||
_ = s.saveRecord(ctx, record)
|
_ = s.saveRecord(ctx, record)
|
||||||
return record, fmt.Errorf("backup upload: %w", err)
|
return record, fmt.Errorf("backup upload: %w", err)
|
||||||
}
|
}
|
||||||
|
<-gzipDone // 确保 gzip goroutine 已退出
|
||||||
|
|
||||||
record.SizeBytes = sizeBytes
|
record.SizeBytes = sizeBytes
|
||||||
record.Status = "completed"
|
record.Status = "completed"
|
||||||
@@ -446,19 +539,187 @@ func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, ex
|
|||||||
return record, nil
|
return record, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StartBackup 异步创建备份,立即返回 running 状态的记录
|
||||||
|
func (s *BackupService) StartBackup(ctx context.Context, triggeredBy string, expireDays int) (*BackupRecord, error) {
|
||||||
|
if s.shuttingDown.Load() {
|
||||||
|
return nil, infraerrors.ServiceUnavailable("SERVER_SHUTTING_DOWN", "server is shutting down")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.opMu.Lock()
|
||||||
|
if s.backingUp {
|
||||||
|
s.opMu.Unlock()
|
||||||
|
return nil, ErrBackupInProgress
|
||||||
|
}
|
||||||
|
s.backingUp = true
|
||||||
|
s.opMu.Unlock()
|
||||||
|
|
||||||
|
// 初始化阶段出错时自动重置标志
|
||||||
|
launched := false
|
||||||
|
defer func() {
|
||||||
|
if !launched {
|
||||||
|
s.opMu.Lock()
|
||||||
|
s.backingUp = false
|
||||||
|
s.opMu.Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 在返回前加载 S3 配置和创建 store,避免 goroutine 中配置被修改
|
||||||
|
s3Cfg, err := s.loadS3Config(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if s3Cfg == nil || !s3Cfg.IsConfigured() {
|
||||||
|
return nil, ErrBackupS3NotConfigured
|
||||||
|
}
|
||||||
|
|
||||||
|
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("init object store: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
backupID := uuid.New().String()[:8]
|
||||||
|
fileName := fmt.Sprintf("%s_%s.sql.gz", s.dbCfg.DBName, now.Format("20060102_150405"))
|
||||||
|
s3Key := s.buildS3Key(s3Cfg, fileName)
|
||||||
|
|
||||||
|
var expiresAt string
|
||||||
|
if expireDays > 0 {
|
||||||
|
expiresAt = now.AddDate(0, 0, expireDays).Format(time.RFC3339)
|
||||||
|
}
|
||||||
|
|
||||||
|
record := &BackupRecord{
|
||||||
|
ID: backupID,
|
||||||
|
Status: "running",
|
||||||
|
BackupType: "postgres",
|
||||||
|
FileName: fileName,
|
||||||
|
S3Key: s3Key,
|
||||||
|
TriggeredBy: triggeredBy,
|
||||||
|
StartedAt: now.Format(time.RFC3339),
|
||||||
|
ExpiresAt: expiresAt,
|
||||||
|
Progress: "pending",
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.saveRecord(ctx, record); err != nil {
|
||||||
|
return nil, fmt.Errorf("save initial record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
launched = true
|
||||||
|
// 在启动 goroutine 前完成拷贝,避免数据竞争
|
||||||
|
result := *record
|
||||||
|
|
||||||
|
s.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer s.wg.Done()
|
||||||
|
defer func() {
|
||||||
|
s.opMu.Lock()
|
||||||
|
s.backingUp = false
|
||||||
|
s.opMu.Unlock()
|
||||||
|
}()
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] panic recovered: %v", r)
|
||||||
|
record.Status = "failed"
|
||||||
|
record.ErrorMsg = fmt.Sprintf("internal panic: %v", r)
|
||||||
|
record.Progress = ""
|
||||||
|
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||||
|
_ = s.saveRecord(context.Background(), record)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
s.executeBackup(record, objectStore)
|
||||||
|
}()
|
||||||
|
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeBackup 后台执行备份(独立于 HTTP context)
|
||||||
|
func (s *BackupService) executeBackup(record *BackupRecord, objectStore BackupObjectStore) {
|
||||||
|
ctx, cancel := context.WithTimeout(s.bgCtx, 30*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// 阶段1: pg_dump
|
||||||
|
record.Progress = "dumping"
|
||||||
|
_ = s.saveRecord(ctx, record)
|
||||||
|
|
||||||
|
dumpReader, err := s.dumper.Dump(ctx)
|
||||||
|
if err != nil {
|
||||||
|
record.Status = "failed"
|
||||||
|
record.ErrorMsg = fmt.Sprintf("pg_dump failed: %v", err)
|
||||||
|
record.Progress = ""
|
||||||
|
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||||
|
_ = s.saveRecord(context.Background(), record)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 阶段2: gzip + upload
|
||||||
|
record.Progress = "uploading"
|
||||||
|
_ = s.saveRecord(ctx, record)
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
gzipDone := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
pw.CloseWithError(fmt.Errorf("gzip goroutine panic: %v", r)) //nolint:errcheck
|
||||||
|
gzipDone <- fmt.Errorf("gzip goroutine panic: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
gzWriter := gzip.NewWriter(pw)
|
||||||
|
var gzErr error
|
||||||
|
_, gzErr = io.Copy(gzWriter, dumpReader)
|
||||||
|
if closeErr := gzWriter.Close(); closeErr != nil && gzErr == nil {
|
||||||
|
gzErr = closeErr
|
||||||
|
}
|
||||||
|
if closeErr := dumpReader.Close(); closeErr != nil && gzErr == nil {
|
||||||
|
gzErr = closeErr
|
||||||
|
}
|
||||||
|
if gzErr != nil {
|
||||||
|
_ = pw.CloseWithError(gzErr)
|
||||||
|
} else {
|
||||||
|
_ = pw.Close()
|
||||||
|
}
|
||||||
|
gzipDone <- gzErr
|
||||||
|
}()
|
||||||
|
|
||||||
|
contentType := "application/gzip"
|
||||||
|
sizeBytes, err := objectStore.Upload(ctx, record.S3Key, pr, contentType)
|
||||||
|
if err != nil {
|
||||||
|
_ = pr.CloseWithError(err) // 确保 gzip goroutine 不会悬挂
|
||||||
|
gzErr := <-gzipDone // 安全等待 gzip goroutine 完成
|
||||||
|
record.Status = "failed"
|
||||||
|
errMsg := fmt.Sprintf("S3 upload failed: %v", err)
|
||||||
|
if gzErr != nil {
|
||||||
|
errMsg = fmt.Sprintf("gzip/dump failed: %v", gzErr)
|
||||||
|
}
|
||||||
|
record.ErrorMsg = errMsg
|
||||||
|
record.Progress = ""
|
||||||
|
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||||
|
_ = s.saveRecord(context.Background(), record)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
<-gzipDone // 确保 gzip goroutine 已退出
|
||||||
|
|
||||||
|
record.SizeBytes = sizeBytes
|
||||||
|
record.Status = "completed"
|
||||||
|
record.Progress = ""
|
||||||
|
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||||
|
if err := s.saveRecord(context.Background(), record); err != nil {
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] 保存备份记录失败: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// RestoreBackup 从 S3 下载备份并流式恢复到数据库
|
// RestoreBackup 从 S3 下载备份并流式恢复到数据库
|
||||||
func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) error {
|
func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) error {
|
||||||
s.mu.Lock()
|
s.opMu.Lock()
|
||||||
if s.restoring {
|
if s.restoring {
|
||||||
s.mu.Unlock()
|
s.opMu.Unlock()
|
||||||
return ErrRestoreInProgress
|
return ErrRestoreInProgress
|
||||||
}
|
}
|
||||||
s.restoring = true
|
s.restoring = true
|
||||||
s.mu.Unlock()
|
s.opMu.Unlock()
|
||||||
defer func() {
|
defer func() {
|
||||||
s.mu.Lock()
|
s.opMu.Lock()
|
||||||
s.restoring = false
|
s.restoring = false
|
||||||
s.mu.Unlock()
|
s.opMu.Unlock()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
record, err := s.GetBackupRecord(ctx, backupID)
|
record, err := s.GetBackupRecord(ctx, backupID)
|
||||||
@@ -500,6 +761,112 @@ func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) erro
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StartRestore 异步恢复备份,立即返回
|
||||||
|
func (s *BackupService) StartRestore(ctx context.Context, backupID string) (*BackupRecord, error) {
|
||||||
|
if s.shuttingDown.Load() {
|
||||||
|
return nil, infraerrors.ServiceUnavailable("SERVER_SHUTTING_DOWN", "server is shutting down")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.opMu.Lock()
|
||||||
|
if s.restoring {
|
||||||
|
s.opMu.Unlock()
|
||||||
|
return nil, ErrRestoreInProgress
|
||||||
|
}
|
||||||
|
s.restoring = true
|
||||||
|
s.opMu.Unlock()
|
||||||
|
|
||||||
|
// 初始化阶段出错时自动重置标志
|
||||||
|
launched := false
|
||||||
|
defer func() {
|
||||||
|
if !launched {
|
||||||
|
s.opMu.Lock()
|
||||||
|
s.restoring = false
|
||||||
|
s.opMu.Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
record, err := s.GetBackupRecord(ctx, backupID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if record.Status != "completed" {
|
||||||
|
return nil, infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "can only restore from a completed backup")
|
||||||
|
}
|
||||||
|
|
||||||
|
s3Cfg, err := s.loadS3Config(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("init object store: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
record.RestoreStatus = "running"
|
||||||
|
_ = s.saveRecord(ctx, record)
|
||||||
|
|
||||||
|
launched = true
|
||||||
|
result := *record
|
||||||
|
|
||||||
|
s.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer s.wg.Done()
|
||||||
|
defer func() {
|
||||||
|
s.opMu.Lock()
|
||||||
|
s.restoring = false
|
||||||
|
s.opMu.Unlock()
|
||||||
|
}()
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] restore panic recovered: %v", r)
|
||||||
|
record.RestoreStatus = "failed"
|
||||||
|
record.RestoreError = fmt.Sprintf("internal panic: %v", r)
|
||||||
|
_ = s.saveRecord(context.Background(), record)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
s.executeRestore(record, objectStore)
|
||||||
|
}()
|
||||||
|
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeRestore 后台执行恢复
|
||||||
|
func (s *BackupService) executeRestore(record *BackupRecord, objectStore BackupObjectStore) {
|
||||||
|
ctx, cancel := context.WithTimeout(s.bgCtx, 30*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
body, err := objectStore.Download(ctx, record.S3Key)
|
||||||
|
if err != nil {
|
||||||
|
record.RestoreStatus = "failed"
|
||||||
|
record.RestoreError = fmt.Sprintf("S3 download failed: %v", err)
|
||||||
|
_ = s.saveRecord(context.Background(), record)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() { _ = body.Close() }()
|
||||||
|
|
||||||
|
gzReader, err := gzip.NewReader(body)
|
||||||
|
if err != nil {
|
||||||
|
record.RestoreStatus = "failed"
|
||||||
|
record.RestoreError = fmt.Sprintf("gzip reader: %v", err)
|
||||||
|
_ = s.saveRecord(context.Background(), record)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() { _ = gzReader.Close() }()
|
||||||
|
|
||||||
|
if err := s.dumper.Restore(ctx, gzReader); err != nil {
|
||||||
|
record.RestoreStatus = "failed"
|
||||||
|
record.RestoreError = fmt.Sprintf("pg restore: %v", err)
|
||||||
|
_ = s.saveRecord(context.Background(), record)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
record.RestoreStatus = "completed"
|
||||||
|
record.RestoredAt = time.Now().Format(time.RFC3339)
|
||||||
|
if err := s.saveRecord(context.Background(), record); err != nil {
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] 保存恢复记录失败: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ─── 备份记录管理 ───
|
// ─── 备份记录管理 ───
|
||||||
|
|
||||||
func (s *BackupService) ListBackups(ctx context.Context) ([]BackupRecord, error) {
|
func (s *BackupService) ListBackups(ctx context.Context) ([]BackupRecord, error) {
|
||||||
@@ -614,8 +981,8 @@ func (s *BackupService) loadS3Config(ctx context.Context) (*BackupS3Config, erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *BackupService) getOrCreateStore(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error) {
|
func (s *BackupService) getOrCreateStore(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error) {
|
||||||
s.mu.Lock()
|
s.storeMu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.storeMu.Unlock()
|
||||||
|
|
||||||
if s.store != nil && s.s3Cfg != nil {
|
if s.store != nil && s.s3Cfg != nil {
|
||||||
return s.store, nil
|
return s.store, nil
|
||||||
|
|||||||
@@ -134,6 +134,30 @@ func (m *mockDumper) Restore(_ context.Context, data io.Reader) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// blockingDumper 可控延迟的 dumper,用于测试异步行为
|
||||||
|
type blockingDumper struct {
|
||||||
|
blockCh chan struct{}
|
||||||
|
data []byte
|
||||||
|
restErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *blockingDumper) Dump(ctx context.Context) (io.ReadCloser, error) {
|
||||||
|
select {
|
||||||
|
case <-d.blockCh:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
return io.NopCloser(bytes.NewReader(d.data)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *blockingDumper) Restore(_ context.Context, data io.Reader) error {
|
||||||
|
if d.restErr != nil {
|
||||||
|
return d.restErr
|
||||||
|
}
|
||||||
|
_, _ = io.ReadAll(data)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type mockObjectStore struct {
|
type mockObjectStore struct {
|
||||||
objects map[string][]byte
|
objects map[string][]byte
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
@@ -179,7 +203,7 @@ func (m *mockObjectStore) HeadBucket(_ context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestBackupService(repo *mockSettingRepo, dumper *mockDumper, store *mockObjectStore) *BackupService {
|
func newTestBackupService(repo *mockSettingRepo, dumper DBDumper, store *mockObjectStore) *BackupService {
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
Database: config.DatabaseConfig{
|
Database: config.DatabaseConfig{
|
||||||
Host: "localhost",
|
Host: "localhost",
|
||||||
@@ -361,9 +385,9 @@ func TestBackupService_CreateBackup_ConcurrentBlocked(t *testing.T) {
|
|||||||
svc := newTestBackupService(repo, dumper, store)
|
svc := newTestBackupService(repo, dumper, store)
|
||||||
|
|
||||||
// 手动设置 backingUp 标志
|
// 手动设置 backingUp 标志
|
||||||
svc.mu.Lock()
|
svc.opMu.Lock()
|
||||||
svc.backingUp = true
|
svc.backingUp = true
|
||||||
svc.mu.Unlock()
|
svc.opMu.Unlock()
|
||||||
|
|
||||||
_, err := svc.CreateBackup(context.Background(), "manual", 14)
|
_, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||||
require.ErrorIs(t, err, ErrBackupInProgress)
|
require.ErrorIs(t, err, ErrBackupInProgress)
|
||||||
@@ -526,3 +550,154 @@ func TestBackupService_LoadS3Config_Corrupted(t *testing.T) {
|
|||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Nil(t, cfg)
|
require.Nil(t, cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ─── Async Backup Tests ───
|
||||||
|
|
||||||
|
func TestStartBackup_ReturnsImmediately(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
seedS3Config(t, repo)
|
||||||
|
|
||||||
|
dumper := &blockingDumper{blockCh: make(chan struct{}), data: []byte("data")}
|
||||||
|
store := newMockObjectStore()
|
||||||
|
svc := newTestBackupService(repo, dumper, store)
|
||||||
|
|
||||||
|
record, err := svc.StartBackup(context.Background(), "manual", 14)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "running", record.Status)
|
||||||
|
require.NotEmpty(t, record.ID)
|
||||||
|
|
||||||
|
// 释放 dumper 让后台完成
|
||||||
|
close(dumper.blockCh)
|
||||||
|
svc.wg.Wait()
|
||||||
|
|
||||||
|
// 验证最终状态
|
||||||
|
final, err := svc.GetBackupRecord(context.Background(), record.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "completed", final.Status)
|
||||||
|
require.Greater(t, final.SizeBytes, int64(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartBackup_ConcurrentBlocked(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
seedS3Config(t, repo)
|
||||||
|
|
||||||
|
dumper := &blockingDumper{blockCh: make(chan struct{}), data: []byte("data")}
|
||||||
|
store := newMockObjectStore()
|
||||||
|
svc := newTestBackupService(repo, dumper, store)
|
||||||
|
|
||||||
|
// 第一次启动
|
||||||
|
_, err := svc.StartBackup(context.Background(), "manual", 14)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 第二次应被阻塞
|
||||||
|
_, err = svc.StartBackup(context.Background(), "manual", 14)
|
||||||
|
require.ErrorIs(t, err, ErrBackupInProgress)
|
||||||
|
|
||||||
|
close(dumper.blockCh)
|
||||||
|
svc.wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartBackup_ShuttingDown(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
seedS3Config(t, repo)
|
||||||
|
svc := newTestBackupService(repo, &mockDumper{dumpData: []byte("data")}, newMockObjectStore())
|
||||||
|
|
||||||
|
svc.shuttingDown.Store(true)
|
||||||
|
|
||||||
|
_, err := svc.StartBackup(context.Background(), "manual", 14)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "shutting down")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecoverStaleRecords(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||||
|
|
||||||
|
// 模拟一条孤立的 running 记录
|
||||||
|
_ = svc.saveRecord(context.Background(), &BackupRecord{
|
||||||
|
ID: "stale-1",
|
||||||
|
Status: "running",
|
||||||
|
StartedAt: time.Now().Add(-1 * time.Hour).Format(time.RFC3339),
|
||||||
|
})
|
||||||
|
// 模拟一条孤立的恢复中记录
|
||||||
|
_ = svc.saveRecord(context.Background(), &BackupRecord{
|
||||||
|
ID: "stale-2",
|
||||||
|
Status: "completed",
|
||||||
|
RestoreStatus: "running",
|
||||||
|
StartedAt: time.Now().Add(-1 * time.Hour).Format(time.RFC3339),
|
||||||
|
})
|
||||||
|
|
||||||
|
svc.recoverStaleRecords()
|
||||||
|
|
||||||
|
r1, _ := svc.GetBackupRecord(context.Background(), "stale-1")
|
||||||
|
require.Equal(t, "failed", r1.Status)
|
||||||
|
require.Contains(t, r1.ErrorMsg, "server restart")
|
||||||
|
|
||||||
|
r2, _ := svc.GetBackupRecord(context.Background(), "stale-2")
|
||||||
|
require.Equal(t, "failed", r2.RestoreStatus)
|
||||||
|
require.Contains(t, r2.RestoreError, "server restart")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGracefulShutdown(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
seedS3Config(t, repo)
|
||||||
|
|
||||||
|
dumper := &blockingDumper{blockCh: make(chan struct{}), data: []byte("data")}
|
||||||
|
store := newMockObjectStore()
|
||||||
|
svc := newTestBackupService(repo, dumper, store)
|
||||||
|
|
||||||
|
_, err := svc.StartBackup(context.Background(), "manual", 14)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Stop 应该等待备份完成
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
svc.Stop()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 短暂等待确认 Stop 还在等待
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
t.Fatal("Stop returned before backup finished")
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
// 预期:Stop 还在等待
|
||||||
|
}
|
||||||
|
|
||||||
|
// 释放备份
|
||||||
|
close(dumper.blockCh)
|
||||||
|
|
||||||
|
// 现在 Stop 应该完成
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
// 预期
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("Stop did not return after backup finished")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartRestore_Async(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
seedS3Config(t, repo)
|
||||||
|
|
||||||
|
dumpContent := "-- PostgreSQL dump\nCREATE TABLE test (id int);\n"
|
||||||
|
dumper := &mockDumper{dumpData: []byte(dumpContent)}
|
||||||
|
store := newMockObjectStore()
|
||||||
|
svc := newTestBackupService(repo, dumper, store)
|
||||||
|
|
||||||
|
// 先创建一个备份(同步方式)
|
||||||
|
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 异步恢复
|
||||||
|
restored, err := svc.StartRestore(context.Background(), record.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "running", restored.RestoreStatus)
|
||||||
|
|
||||||
|
svc.wg.Wait()
|
||||||
|
|
||||||
|
// 验证最终状态
|
||||||
|
final, err := svc.GetBackupRecord(context.Background(), record.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "completed", final.RestoreStatus)
|
||||||
|
}
|
||||||
|
|||||||
@@ -21,9 +21,6 @@ var (
|
|||||||
// 带捕获组的版本提取正则
|
// 带捕获组的版本提取正则
|
||||||
claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`)
|
claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`)
|
||||||
|
|
||||||
// metadata.user_id 格式: user_{64位hex}_account__session_{uuid}
|
|
||||||
userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[\w-]+$`)
|
|
||||||
|
|
||||||
// System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致)
|
// System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致)
|
||||||
systemPromptThreshold = 0.5
|
systemPromptThreshold = 0.5
|
||||||
)
|
)
|
||||||
@@ -124,7 +121,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if !userIDPattern.MatchString(userID) {
|
if ParseMetadataUserID(userID) == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -278,11 +275,7 @@ func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context
|
|||||||
// ExtractVersion 从 User-Agent 中提取 Claude Code 版本号
|
// ExtractVersion 从 User-Agent 中提取 Claude Code 版本号
|
||||||
// 返回 "2.1.22" 形式的版本号,如果不匹配返回空字符串
|
// 返回 "2.1.22" 形式的版本号,如果不匹配返回空字符串
|
||||||
func (v *ClaudeCodeValidator) ExtractVersion(ua string) string {
|
func (v *ClaudeCodeValidator) ExtractVersion(ua string) string {
|
||||||
matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua)
|
return ExtractCLIVersion(ua)
|
||||||
if len(matches) >= 2 {
|
|
||||||
return matches[1]
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetClaudeCodeVersion 将 Claude Code 版本号设置到 context 中
|
// SetClaudeCodeVersion 将 Claude Code 版本号设置到 context 中
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -15,14 +14,17 @@ const (
|
|||||||
claudeLockWaitTime = 200 * time.Millisecond
|
claudeLockWaitTime = 200 * time.Millisecond
|
||||||
)
|
)
|
||||||
|
|
||||||
// ClaudeTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
|
// ClaudeTokenCache token cache interface.
|
||||||
type ClaudeTokenCache = GeminiTokenCache
|
type ClaudeTokenCache = GeminiTokenCache
|
||||||
|
|
||||||
// ClaudeTokenProvider 管理 Claude (Anthropic) OAuth 账户的 access_token
|
// ClaudeTokenProvider manages access_token for Claude OAuth accounts.
|
||||||
type ClaudeTokenProvider struct {
|
type ClaudeTokenProvider struct {
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
tokenCache ClaudeTokenCache
|
tokenCache ClaudeTokenCache
|
||||||
oauthService *OAuthService
|
oauthService *OAuthService
|
||||||
|
refreshAPI *OAuthRefreshAPI
|
||||||
|
executor OAuthRefreshExecutor
|
||||||
|
refreshPolicy ProviderRefreshPolicy
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClaudeTokenProvider(
|
func NewClaudeTokenProvider(
|
||||||
@@ -31,13 +33,25 @@ func NewClaudeTokenProvider(
|
|||||||
oauthService *OAuthService,
|
oauthService *OAuthService,
|
||||||
) *ClaudeTokenProvider {
|
) *ClaudeTokenProvider {
|
||||||
return &ClaudeTokenProvider{
|
return &ClaudeTokenProvider{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
tokenCache: tokenCache,
|
tokenCache: tokenCache,
|
||||||
oauthService: oauthService,
|
oauthService: oauthService,
|
||||||
|
refreshPolicy: ClaudeProviderRefreshPolicy(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccessToken 获取有效的 access_token
|
// SetRefreshAPI injects unified OAuth refresh API and executor.
|
||||||
|
func (p *ClaudeTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
|
||||||
|
p.refreshAPI = api
|
||||||
|
p.executor = executor
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRefreshPolicy injects caller-side refresh policy.
|
||||||
|
func (p *ClaudeTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
|
||||||
|
p.refreshPolicy = policy
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccessToken returns a valid access_token.
|
||||||
func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||||
if account == nil {
|
if account == nil {
|
||||||
return "", errors.New("account is nil")
|
return "", errors.New("account is nil")
|
||||||
@@ -48,7 +62,7 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
|
|
||||||
cacheKey := ClaudeTokenCacheKey(account)
|
cacheKey := ClaudeTokenCacheKey(account)
|
||||||
|
|
||||||
// 1. 先尝试缓存
|
// 1) Try cache first.
|
||||||
if p.tokenCache != nil {
|
if p.tokenCache != nil {
|
||||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||||
slog.Debug("claude_token_cache_hit", "account_id", account.ID)
|
slog.Debug("claude_token_cache_hit", "account_id", account.ID)
|
||||||
@@ -60,114 +74,39 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
|
|
||||||
slog.Debug("claude_token_cache_miss", "account_id", account.ID)
|
slog.Debug("claude_token_cache_miss", "account_id", account.ID)
|
||||||
|
|
||||||
// 2. 如果即将过期则刷新
|
// 2) Refresh if needed (pre-expiry skew).
|
||||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
|
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
|
||||||
refreshFailed := false
|
refreshFailed := false
|
||||||
if needsRefresh && p.tokenCache != nil {
|
|
||||||
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
|
||||||
if lockErr == nil && locked {
|
|
||||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
|
||||||
|
|
||||||
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
|
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
|
||||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, claudeTokenRefreshSkew)
|
||||||
return token, nil
|
if err != nil {
|
||||||
|
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
|
||||||
|
return "", err
|
||||||
}
|
}
|
||||||
|
slog.Warn("claude_token_refresh_failed", "account_id", account.ID, "error", err)
|
||||||
// 从数据库获取最新账户信息
|
refreshFailed = true
|
||||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
} else if result.LockHeld {
|
||||||
if err == nil && fresh != nil {
|
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil {
|
||||||
account = fresh
|
time.Sleep(claudeLockWaitTime)
|
||||||
}
|
if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" {
|
||||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
|
||||||
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
|
return token, nil
|
||||||
if p.oauthService == nil {
|
|
||||||
slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
|
|
||||||
refreshFailed = true // 无法刷新,标记失败
|
|
||||||
} else {
|
|
||||||
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
|
|
||||||
if err != nil {
|
|
||||||
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
|
|
||||||
slog.Warn("claude_token_refresh_failed", "account_id", account.ID, "error", err)
|
|
||||||
refreshFailed = true // 刷新失败,标记以使用短 TTL
|
|
||||||
} else {
|
|
||||||
// 构建新 credentials,保留原有字段
|
|
||||||
newCredentials := make(map[string]any)
|
|
||||||
for k, v := range account.Credentials {
|
|
||||||
newCredentials[k] = v
|
|
||||||
}
|
|
||||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
|
||||||
newCredentials["token_type"] = tokenInfo.TokenType
|
|
||||||
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
|
|
||||||
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
|
||||||
if tokenInfo.RefreshToken != "" {
|
|
||||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
|
||||||
}
|
|
||||||
if tokenInfo.Scope != "" {
|
|
||||||
newCredentials["scope"] = tokenInfo.Scope
|
|
||||||
}
|
|
||||||
account.Credentials = newCredentials
|
|
||||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
|
||||||
slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
|
|
||||||
}
|
|
||||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if lockErr != nil {
|
|
||||||
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
|
|
||||||
slog.Warn("claude_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
|
|
||||||
|
|
||||||
// 检查 ctx 是否已取消
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
return "", ctx.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 从数据库获取最新账户信息
|
|
||||||
if p.accountRepo != nil {
|
|
||||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
|
||||||
if err == nil && fresh != nil {
|
|
||||||
account = fresh
|
|
||||||
}
|
|
||||||
}
|
|
||||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
|
||||||
|
|
||||||
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
|
|
||||||
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
|
|
||||||
if p.oauthService == nil {
|
|
||||||
slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
|
|
||||||
refreshFailed = true
|
|
||||||
} else {
|
|
||||||
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
|
|
||||||
if err != nil {
|
|
||||||
slog.Warn("claude_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
|
|
||||||
refreshFailed = true
|
|
||||||
} else {
|
|
||||||
// 构建新 credentials,保留原有字段
|
|
||||||
newCredentials := make(map[string]any)
|
|
||||||
for k, v := range account.Credentials {
|
|
||||||
newCredentials[k] = v
|
|
||||||
}
|
|
||||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
|
||||||
newCredentials["token_type"] = tokenInfo.TokenType
|
|
||||||
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
|
|
||||||
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
|
||||||
if tokenInfo.RefreshToken != "" {
|
|
||||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
|
||||||
}
|
|
||||||
if tokenInfo.Scope != "" {
|
|
||||||
newCredentials["scope"] = tokenInfo.Scope
|
|
||||||
}
|
|
||||||
account.Credentials = newCredentials
|
|
||||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
|
||||||
slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
|
|
||||||
}
|
|
||||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
|
account = result.Account
|
||||||
|
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||||
|
}
|
||||||
|
} else if needsRefresh && p.tokenCache != nil {
|
||||||
|
// Backward-compatible test path when refreshAPI is not injected.
|
||||||
|
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||||
|
if lockErr == nil && locked {
|
||||||
|
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||||
|
} else if lockErr != nil {
|
||||||
|
slog.Warn("claude_token_lock_failed", "account_id", account.ID, "error", lockErr)
|
||||||
|
} else {
|
||||||
time.Sleep(claudeLockWaitTime)
|
time.Sleep(claudeLockWaitTime)
|
||||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||||
slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
|
slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
|
||||||
@@ -181,22 +120,23 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
return "", errors.New("access_token not found in credentials")
|
return "", errors.New("access_token not found in credentials")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
// 3) Populate cache with TTL.
|
||||||
if p.tokenCache != nil {
|
if p.tokenCache != nil {
|
||||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||||
if isStale && latestAccount != nil {
|
if isStale && latestAccount != nil {
|
||||||
// 版本过时,使用 DB 中的最新 token
|
|
||||||
slog.Debug("claude_token_version_stale_use_latest", "account_id", account.ID)
|
slog.Debug("claude_token_version_stale_use_latest", "account_id", account.ID)
|
||||||
accessToken = latestAccount.GetCredential("access_token")
|
accessToken = latestAccount.GetCredential("access_token")
|
||||||
if strings.TrimSpace(accessToken) == "" {
|
if strings.TrimSpace(accessToken) == "" {
|
||||||
return "", errors.New("access_token not found after version check")
|
return "", errors.New("access_token not found after version check")
|
||||||
}
|
}
|
||||||
// 不写入缓存,让下次请求重新处理
|
|
||||||
} else {
|
} else {
|
||||||
ttl := 30 * time.Minute
|
ttl := 30 * time.Minute
|
||||||
if refreshFailed {
|
if refreshFailed {
|
||||||
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
|
if p.refreshPolicy.FailureTTL > 0 {
|
||||||
ttl = time.Minute
|
ttl = p.refreshPolicy.FailureTTL
|
||||||
|
} else {
|
||||||
|
ttl = time.Minute
|
||||||
|
}
|
||||||
slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
|
slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
|
||||||
} else if expiresAt != nil {
|
} else if expiresAt != nil {
|
||||||
until := time.Until(*expiresAt)
|
until := time.Until(*expiresAt)
|
||||||
|
|||||||
@@ -148,6 +148,15 @@ func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTi
|
|||||||
return stats, nil
|
return stats, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetGroupUsageSummary returns today's and cumulative cost for all groups.
|
||||||
|
func (s *DashboardService) GetGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
|
||||||
|
results, err := s.usageRepo.GetAllGroupUsageSummary(ctx, todayStart)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get group usage summary: %w", err)
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) {
|
func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) {
|
||||||
data, err := s.cache.GetDashboardStats(ctx)
|
data, err := s.cache.GetDashboardStats(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -335,6 +344,14 @@ func (s *DashboardService) GetUserSpendingRanking(ctx context.Context, startTime
|
|||||||
return ranking, nil
|
return ranking, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *DashboardService) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) {
|
||||||
|
stats, err := s.usageRepo.GetUserBreakdownStats(ctx, startTime, endTime, dim, limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get user breakdown stats: %w", err)
|
||||||
|
}
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||||
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime)
|
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -80,6 +80,7 @@ const (
|
|||||||
SettingKeyRegistrationEmailSuffixWhitelist = "registration_email_suffix_whitelist" // 注册邮箱后缀白名单(JSON 数组)
|
SettingKeyRegistrationEmailSuffixWhitelist = "registration_email_suffix_whitelist" // 注册邮箱后缀白名单(JSON 数组)
|
||||||
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
|
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
|
||||||
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
|
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
|
||||||
|
SettingKeyFrontendURL = "frontend_url" // 前端基础URL,用于生成邮件中的重置密码链接
|
||||||
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
|
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
|
||||||
|
|
||||||
// 邮件服务设置
|
// 邮件服务设置
|
||||||
|
|||||||
@@ -278,8 +278,8 @@ func (m *mockGroupRepoForGateway) ListActiveByPlatform(ctx context.Context, plat
|
|||||||
func (m *mockGroupRepoForGateway) ExistsByName(ctx context.Context, name string) (bool, error) {
|
func (m *mockGroupRepoForGateway) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
|
||||||
return 0, nil
|
return 0, 0, nil
|
||||||
}
|
}
|
||||||
func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
@@ -440,7 +440,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts(t
|
|||||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Nil(t, acc)
|
require.Nil(t, acc)
|
||||||
require.Contains(t, err.Error(), "no available accounts")
|
require.ErrorIs(t, err, ErrNoAvailableAccounts)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestGatewayService_SelectAccountForModelWithPlatform_AllExcluded 测试所有账户被排除
|
// TestGatewayService_SelectAccountForModelWithPlatform_AllExcluded 测试所有账户被排除
|
||||||
@@ -1073,7 +1073,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_NoAccounts(t *testing.
|
|||||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformAnthropic)
|
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformAnthropic)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Nil(t, acc)
|
require.Nil(t, acc)
|
||||||
require.Contains(t, err.Error(), "no available accounts")
|
require.ErrorIs(t, err, ErrNoAvailableAccounts)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
|
func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
|
||||||
@@ -1734,7 +1734,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
|
|||||||
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Nil(t, acc)
|
require.Nil(t, acc)
|
||||||
require.Contains(t, err.Error(), "no available accounts")
|
require.ErrorIs(t, err, ErrNoAvailableAccounts)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("混合调度-不支持模型返回错误", func(t *testing.T) {
|
t.Run("混合调度-不支持模型返回错误", func(t *testing.T) {
|
||||||
@@ -2290,7 +2290,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Nil(t, result)
|
require.Nil(t, result)
|
||||||
require.Contains(t, err.Error(), "no available accounts")
|
require.ErrorIs(t, err, ErrNoAvailableAccounts)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("过滤不可调度账号-限流账号被跳过", func(t *testing.T) {
|
t.Run("过滤不可调度账号-限流账号被跳过", func(t *testing.T) {
|
||||||
|
|||||||
@@ -369,3 +369,54 @@ func TestGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T)
|
|||||||
require.Equal(t, 1, billingRepo.calls)
|
require.Equal(t, 1, billingRepo.calls)
|
||||||
require.Equal(t, 0, usageRepo.calls)
|
require.Equal(t, 0, usageRepo.calls)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGatewayServiceRecordUsage_ReasoningEffortPersisted(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageBestEffortLogRepoStub{}
|
||||||
|
svc := newGatewayRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||||
|
|
||||||
|
effort := "max"
|
||||||
|
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||||
|
Result: &ForwardResult{
|
||||||
|
RequestID: "effort_test",
|
||||||
|
Usage: ClaudeUsage{
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 5,
|
||||||
|
},
|
||||||
|
Model: "claude-opus-4-6",
|
||||||
|
Duration: time.Second,
|
||||||
|
ReasoningEffort: &effort,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{ID: 1},
|
||||||
|
User: &User{ID: 1},
|
||||||
|
Account: &Account{ID: 1},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, usageRepo.lastLog)
|
||||||
|
require.NotNil(t, usageRepo.lastLog.ReasoningEffort)
|
||||||
|
require.Equal(t, "max", *usageRepo.lastLog.ReasoningEffort)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayServiceRecordUsage_ReasoningEffortNil(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageBestEffortLogRepoStub{}
|
||||||
|
svc := newGatewayRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||||
|
|
||||||
|
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||||
|
Result: &ForwardResult{
|
||||||
|
RequestID: "no_effort_test",
|
||||||
|
Usage: ClaudeUsage{
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 5,
|
||||||
|
},
|
||||||
|
Model: "claude-sonnet-4",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{ID: 1},
|
||||||
|
User: &User{ID: 1},
|
||||||
|
Account: &Account{ID: 1},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, usageRepo.lastLog)
|
||||||
|
require.Nil(t, usageRepo.lastLog.ReasoningEffort)
|
||||||
|
}
|
||||||
|
|||||||
@@ -60,6 +60,7 @@ type ParsedRequest struct {
|
|||||||
Messages []any // messages 数组
|
Messages []any // messages 数组
|
||||||
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
|
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
|
||||||
ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名)
|
ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名)
|
||||||
|
OutputEffort string // output_config.effort(Claude API 的推理强度控制)
|
||||||
MaxTokens int // max_tokens 值(用于探测请求拦截)
|
MaxTokens int // max_tokens 值(用于探测请求拦截)
|
||||||
SessionContext *SessionContext // 可选:请求上下文区分因子(nil 时行为不变)
|
SessionContext *SessionContext // 可选:请求上下文区分因子(nil 时行为不变)
|
||||||
|
|
||||||
@@ -116,6 +117,9 @@ func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
|
|||||||
parsed.ThinkingEnabled = true
|
parsed.ThinkingEnabled = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// output_config.effort: Claude API 的推理强度控制参数
|
||||||
|
parsed.OutputEffort = strings.TrimSpace(gjson.Get(jsonStr, "output_config.effort").String())
|
||||||
|
|
||||||
// max_tokens: 仅接受整数值
|
// max_tokens: 仅接受整数值
|
||||||
maxTokensResult := gjson.Get(jsonStr, "max_tokens")
|
maxTokensResult := gjson.Get(jsonStr, "max_tokens")
|
||||||
if maxTokensResult.Exists() && maxTokensResult.Type == gjson.Number {
|
if maxTokensResult.Exists() && maxTokensResult.Type == gjson.Number {
|
||||||
@@ -747,6 +751,21 @@ func filterThinkingBlocksInternal(body []byte, _ bool) []byte {
|
|||||||
return newBody
|
return newBody
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NormalizeClaudeOutputEffort normalizes Claude's output_config.effort value.
|
||||||
|
// Returns nil for empty or unrecognized values.
|
||||||
|
func NormalizeClaudeOutputEffort(raw string) *string {
|
||||||
|
value := strings.ToLower(strings.TrimSpace(raw))
|
||||||
|
if value == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
switch value {
|
||||||
|
case "low", "medium", "high", "max":
|
||||||
|
return &value
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// =========================
|
// =========================
|
||||||
// Thinking Budget Rectifier
|
// Thinking Budget Rectifier
|
||||||
// =========================
|
// =========================
|
||||||
|
|||||||
@@ -972,6 +972,76 @@ func BenchmarkParseGatewayRequest_Old_Large(b *testing.B) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseGatewayRequest_OutputEffort(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
wantEffort string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "output_config.effort present",
|
||||||
|
body: `{"model":"claude-opus-4-6","output_config":{"effort":"medium"},"messages":[]}`,
|
||||||
|
wantEffort: "medium",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "output_config.effort max",
|
||||||
|
body: `{"model":"claude-opus-4-6","output_config":{"effort":"max"},"messages":[]}`,
|
||||||
|
wantEffort: "max",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "output_config without effort",
|
||||||
|
body: `{"model":"claude-opus-4-6","output_config":{},"messages":[]}`,
|
||||||
|
wantEffort: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no output_config",
|
||||||
|
body: `{"model":"claude-opus-4-6","messages":[]}`,
|
||||||
|
wantEffort: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "effort with whitespace trimmed",
|
||||||
|
body: `{"model":"claude-opus-4-6","output_config":{"effort":" high "},"messages":[]}`,
|
||||||
|
wantEffort: "high",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
parsed, err := ParseGatewayRequest([]byte(tt.body), "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, tt.wantEffort, parsed.OutputEffort)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeClaudeOutputEffort(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
want *string
|
||||||
|
}{
|
||||||
|
{"low", strPtr("low")},
|
||||||
|
{"medium", strPtr("medium")},
|
||||||
|
{"high", strPtr("high")},
|
||||||
|
{"max", strPtr("max")},
|
||||||
|
{"LOW", strPtr("low")},
|
||||||
|
{"Max", strPtr("max")},
|
||||||
|
{" medium ", strPtr("medium")},
|
||||||
|
{"", nil},
|
||||||
|
{"unknown", nil},
|
||||||
|
{"xhigh", nil},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.input, func(t *testing.T) {
|
||||||
|
got := NormalizeClaudeOutputEffort(tt.input)
|
||||||
|
if tt.want == nil {
|
||||||
|
require.Nil(t, got)
|
||||||
|
} else {
|
||||||
|
require.NotNil(t, got)
|
||||||
|
require.Equal(t, *tt.want, *got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkParseGatewayRequest_New_Large(b *testing.B) {
|
func BenchmarkParseGatewayRequest_New_Large(b *testing.B) {
|
||||||
data := buildLargeJSON()
|
data := buildLargeJSON()
|
||||||
b.SetBytes(int64(len(data)))
|
b.SetBytes(int64(len(data)))
|
||||||
|
|||||||
@@ -326,7 +326,6 @@ func isClaudeCodeCredentialScopeError(msg string) bool {
|
|||||||
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
||||||
var (
|
var (
|
||||||
sseDataRe = regexp.MustCompile(`^data:\s*`)
|
sseDataRe = regexp.MustCompile(`^data:\s*`)
|
||||||
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
|
|
||||||
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
|
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
|
||||||
|
|
||||||
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
|
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
|
||||||
@@ -346,6 +345,9 @@ var systemBlockFilterPrefixes = []string{
|
|||||||
"x-anthropic-billing-header",
|
"x-anthropic-billing-header",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ErrNoAvailableAccounts 表示没有可用的账号
|
||||||
|
var ErrNoAvailableAccounts = errors.New("no available accounts")
|
||||||
|
|
||||||
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
|
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
|
||||||
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
|
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
|
||||||
|
|
||||||
@@ -492,6 +494,7 @@ type ForwardResult struct {
|
|||||||
Duration time.Duration
|
Duration time.Duration
|
||||||
FirstTokenMs *int // 首字时间(流式请求)
|
FirstTokenMs *int // 首字时间(流式请求)
|
||||||
ClientDisconnect bool // 客户端是否在流式传输过程中断开
|
ClientDisconnect bool // 客户端是否在流式传输过程中断开
|
||||||
|
ReasoningEffort *string
|
||||||
|
|
||||||
// 图片生成计费字段(图片生成模型使用)
|
// 图片生成计费字段(图片生成模型使用)
|
||||||
ImageCount int // 生成的图片数量
|
ImageCount int // 生成的图片数量
|
||||||
@@ -640,8 +643,8 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
|
|||||||
|
|
||||||
// 1. 最高优先级:从 metadata.user_id 提取 session_xxx
|
// 1. 最高优先级:从 metadata.user_id 提取 session_xxx
|
||||||
if parsed.MetadataUserID != "" {
|
if parsed.MetadataUserID != "" {
|
||||||
if match := sessionIDRegex.FindStringSubmatch(parsed.MetadataUserID); len(match) > 1 {
|
if uid := ParseMetadataUserID(parsed.MetadataUserID); uid != nil && uid.SessionID != "" {
|
||||||
return match[1]
|
return uid.SessionID
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1022,13 +1025,13 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account
|
|||||||
sessionID = generateSessionUUID(seed)
|
sessionID = generateSessionUUID(seed)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prefer the newer format that includes account_uuid (if present),
|
// 根据指纹 UA 版本选择输出格式
|
||||||
// otherwise fall back to the legacy Claude Code format.
|
var uaVersion string
|
||||||
accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid"))
|
if fp != nil {
|
||||||
if accountUUID != "" {
|
uaVersion = ExtractCLIVersion(fp.UserAgent)
|
||||||
return fmt.Sprintf("user_%s_account_%s_session_%s", userID, accountUUID, sessionID)
|
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("user_%s_account__session_%s", userID, sessionID)
|
accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid"))
|
||||||
|
return FormatMetadataUserID(userID, accountUUID, sessionID, uaVersion)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateSessionUUID creates a deterministic UUID4 from a seed string.
|
// GenerateSessionUUID creates a deterministic UUID4 from a seed string.
|
||||||
@@ -1204,7 +1207,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if len(accounts) == 0 {
|
if len(accounts) == 0 {
|
||||||
return nil, errors.New("no available accounts")
|
return nil, ErrNoAvailableAccounts
|
||||||
}
|
}
|
||||||
ctx = s.withWindowCostPrefetch(ctx, accounts)
|
ctx = s.withWindowCostPrefetch(ctx, accounts)
|
||||||
ctx = s.withRPMPrefetch(ctx, accounts)
|
ctx = s.withRPMPrefetch(ctx, accounts)
|
||||||
@@ -1552,7 +1555,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(candidates) == 0 {
|
if len(candidates) == 0 {
|
||||||
return nil, errors.New("no available accounts")
|
return nil, ErrNoAvailableAccounts
|
||||||
}
|
}
|
||||||
|
|
||||||
accountLoads := make([]AccountWithConcurrency, 0, len(candidates))
|
accountLoads := make([]AccountWithConcurrency, 0, len(candidates))
|
||||||
@@ -1641,7 +1644,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
},
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
return nil, errors.New("no available accounts")
|
return nil, ErrNoAvailableAccounts
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
|
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
|
||||||
@@ -2851,9 +2854,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
|||||||
if selected == nil {
|
if selected == nil {
|
||||||
stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, platform, accounts, excludedIDs, false)
|
stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, platform, accounts, excludedIDs, false)
|
||||||
if requestedModel != "" {
|
if requestedModel != "" {
|
||||||
return nil, fmt.Errorf("no available accounts supporting model: %s (%s)", requestedModel, summarizeSelectionFailureStats(stats))
|
return nil, fmt.Errorf("%w supporting model: %s (%s)", ErrNoAvailableAccounts, requestedModel, summarizeSelectionFailureStats(stats))
|
||||||
}
|
}
|
||||||
return nil, errors.New("no available accounts")
|
return nil, ErrNoAvailableAccounts
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. 建立粘性绑定
|
// 4. 建立粘性绑定
|
||||||
@@ -3089,9 +3092,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
|||||||
if selected == nil {
|
if selected == nil {
|
||||||
stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, nativePlatform, accounts, excludedIDs, true)
|
stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, nativePlatform, accounts, excludedIDs, true)
|
||||||
if requestedModel != "" {
|
if requestedModel != "" {
|
||||||
return nil, fmt.Errorf("no available accounts supporting model: %s (%s)", requestedModel, summarizeSelectionFailureStats(stats))
|
return nil, fmt.Errorf("%w supporting model: %s (%s)", ErrNoAvailableAccounts, requestedModel, summarizeSelectionFailureStats(stats))
|
||||||
}
|
}
|
||||||
return nil, errors.New("no available accounts")
|
return nil, ErrNoAvailableAccounts
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. 建立粘性绑定
|
// 4. 建立粘性绑定
|
||||||
@@ -5529,7 +5532,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
|||||||
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
|
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
|
||||||
accountUUID := account.GetExtraString("account_uuid")
|
accountUUID := account.GetExtraString("account_uuid")
|
||||||
if accountUUID != "" && fp.ClientID != "" {
|
if accountUUID != "" && fp.ClientID != "" {
|
||||||
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
|
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID, fp.UserAgent); err == nil && len(newBody) > 0 {
|
||||||
body = newBody
|
body = newBody
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -7126,6 +7129,8 @@ type RecordUsageInput struct {
|
|||||||
User *User
|
User *User
|
||||||
Account *Account
|
Account *Account
|
||||||
Subscription *UserSubscription // 可选:订阅信息
|
Subscription *UserSubscription // 可选:订阅信息
|
||||||
|
InboundEndpoint string // 入站端点(客户端请求路径)
|
||||||
|
UpstreamEndpoint string // 上游端点(标准化后的上游路径)
|
||||||
UserAgent string // 请求的 User-Agent
|
UserAgent string // 请求的 User-Agent
|
||||||
IPAddress string // 请求的客户端 IP 地址
|
IPAddress string // 请求的客户端 IP 地址
|
||||||
RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险
|
RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险
|
||||||
@@ -7523,6 +7528,9 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
Model: result.Model,
|
Model: result.Model,
|
||||||
|
ReasoningEffort: result.ReasoningEffort,
|
||||||
|
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
|
||||||
|
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
|
||||||
InputTokens: result.Usage.InputTokens,
|
InputTokens: result.Usage.InputTokens,
|
||||||
OutputTokens: result.Usage.OutputTokens,
|
OutputTokens: result.Usage.OutputTokens,
|
||||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||||
@@ -7603,6 +7611,8 @@ type RecordUsageLongContextInput struct {
|
|||||||
User *User
|
User *User
|
||||||
Account *Account
|
Account *Account
|
||||||
Subscription *UserSubscription // 可选:订阅信息
|
Subscription *UserSubscription // 可选:订阅信息
|
||||||
|
InboundEndpoint string // 入站端点(客户端请求路径)
|
||||||
|
UpstreamEndpoint string // 上游端点(标准化后的上游路径)
|
||||||
UserAgent string // 请求的 User-Agent
|
UserAgent string // 请求的 User-Agent
|
||||||
IPAddress string // 请求的客户端 IP 地址
|
IPAddress string // 请求的客户端 IP 地址
|
||||||
RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险
|
RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险
|
||||||
@@ -7699,6 +7709,9 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
|||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
Model: result.Model,
|
Model: result.Model,
|
||||||
|
ReasoningEffort: result.ReasoningEffort,
|
||||||
|
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
|
||||||
|
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
|
||||||
InputTokens: result.Usage.InputTokens,
|
InputTokens: result.Usage.InputTokens,
|
||||||
OutputTokens: result.Usage.OutputTokens,
|
OutputTokens: result.Usage.OutputTokens,
|
||||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||||
@@ -8147,7 +8160,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
accountUUID := account.GetExtraString("account_uuid")
|
accountUUID := account.GetExtraString("account_uuid")
|
||||||
if accountUUID != "" && fp.ClientID != "" {
|
if accountUUID != "" && fp.ClientID != "" {
|
||||||
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
|
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID, fp.UserAgent); err == nil && len(newBody) > 0 {
|
||||||
body = newBody
|
body = newBody
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3235,7 +3235,7 @@ func cleanToolSchema(schema any) any {
|
|||||||
for key, value := range v {
|
for key, value := range v {
|
||||||
// 跳过不支持的字段
|
// 跳过不支持的字段
|
||||||
if key == "$schema" || key == "$id" || key == "$ref" ||
|
if key == "$schema" || key == "$id" || key == "$ref" ||
|
||||||
key == "additionalProperties" || key == "minLength" ||
|
key == "additionalProperties" || key == "patternProperties" || key == "minLength" ||
|
||||||
key == "maxLength" || key == "minItems" || key == "maxItems" {
|
key == "maxLength" || key == "minItems" || key == "maxItems" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -230,8 +230,8 @@ func (m *mockGroupRepoForGemini) ListActiveByPlatform(ctx context.Context, platf
|
|||||||
func (m *mockGroupRepoForGemini) ExistsByName(ctx context.Context, name string) (bool, error) {
|
func (m *mockGroupRepoForGemini) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
func (m *mockGroupRepoForGemini) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
func (m *mockGroupRepoForGemini) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
|
||||||
return 0, nil
|
return 0, 0, nil
|
||||||
}
|
}
|
||||||
func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
|
|||||||
@@ -15,10 +15,14 @@ const (
|
|||||||
geminiTokenCacheSkew = 5 * time.Minute
|
geminiTokenCacheSkew = 5 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// GeminiTokenProvider manages access_token for Gemini OAuth accounts.
|
||||||
type GeminiTokenProvider struct {
|
type GeminiTokenProvider struct {
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
tokenCache GeminiTokenCache
|
tokenCache GeminiTokenCache
|
||||||
geminiOAuthService *GeminiOAuthService
|
geminiOAuthService *GeminiOAuthService
|
||||||
|
refreshAPI *OAuthRefreshAPI
|
||||||
|
executor OAuthRefreshExecutor
|
||||||
|
refreshPolicy ProviderRefreshPolicy
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewGeminiTokenProvider(
|
func NewGeminiTokenProvider(
|
||||||
@@ -30,9 +34,21 @@ func NewGeminiTokenProvider(
|
|||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
tokenCache: tokenCache,
|
tokenCache: tokenCache,
|
||||||
geminiOAuthService: geminiOAuthService,
|
geminiOAuthService: geminiOAuthService,
|
||||||
|
refreshPolicy: GeminiProviderRefreshPolicy(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRefreshAPI injects unified OAuth refresh API and executor.
|
||||||
|
func (p *GeminiTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
|
||||||
|
p.refreshAPI = api
|
||||||
|
p.executor = executor
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRefreshPolicy injects caller-side refresh policy.
|
||||||
|
func (p *GeminiTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
|
||||||
|
p.refreshPolicy = policy
|
||||||
|
}
|
||||||
|
|
||||||
func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||||
if account == nil {
|
if account == nil {
|
||||||
return "", errors.New("account is nil")
|
return "", errors.New("account is nil")
|
||||||
@@ -53,39 +69,31 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
// 2) Refresh if needed (pre-expiry skew).
|
// 2) Refresh if needed (pre-expiry skew).
|
||||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew
|
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew
|
||||||
if needsRefresh && p.tokenCache != nil {
|
|
||||||
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
|
||||||
if err == nil && locked {
|
|
||||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
|
||||||
|
|
||||||
// Re-check after lock (another worker may have refreshed).
|
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
|
||||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, geminiTokenRefreshSkew)
|
||||||
return token, nil
|
if err != nil {
|
||||||
|
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
|
||||||
|
return "", err
|
||||||
}
|
}
|
||||||
|
} else if result.LockHeld {
|
||||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil {
|
||||||
if err == nil && fresh != nil {
|
if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" {
|
||||||
account = fresh
|
return token, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
slog.Debug("gemini_token_lock_held_use_old", "account_id", account.ID)
|
||||||
|
} else {
|
||||||
|
account = result.Account
|
||||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||||
if expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew {
|
}
|
||||||
if p.geminiOAuthService == nil {
|
} else if needsRefresh && p.tokenCache != nil {
|
||||||
return "", errors.New("gemini oauth service not configured")
|
// Backward-compatible test path when refreshAPI is not injected.
|
||||||
}
|
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||||
tokenInfo, err := p.geminiOAuthService.RefreshAccountToken(ctx, account)
|
if lockErr == nil && locked {
|
||||||
if err != nil {
|
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||||
return "", err
|
} else if lockErr != nil {
|
||||||
}
|
slog.Warn("gemini_token_lock_failed", "account_id", account.ID, "error", lockErr)
|
||||||
newCredentials := p.geminiOAuthService.BuildAccountCredentials(tokenInfo)
|
|
||||||
for k, v := range account.Credentials {
|
|
||||||
if _, exists := newCredentials[k]; !exists {
|
|
||||||
newCredentials[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
account.Credentials = newCredentials
|
|
||||||
_ = p.accountRepo.Update(ctx, account)
|
|
||||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -95,15 +103,14 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
}
|
}
|
||||||
|
|
||||||
// project_id is optional now:
|
// project_id is optional now:
|
||||||
// - If present: will use Code Assist API (requires project_id)
|
// - If present: use Code Assist API (requires project_id)
|
||||||
// - If absent: will use AI Studio API with OAuth token (like regular API key mode)
|
// - If absent: use AI Studio API with OAuth token.
|
||||||
// Auto-detect project_id only if explicitly enabled via a credential flag
|
|
||||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||||
autoDetectProjectID := account.GetCredential("auto_detect_project_id") == "true"
|
autoDetectProjectID := account.GetCredential("auto_detect_project_id") == "true"
|
||||||
|
|
||||||
if projectID == "" && autoDetectProjectID {
|
if projectID == "" && autoDetectProjectID {
|
||||||
if p.geminiOAuthService == nil {
|
if p.geminiOAuthService == nil {
|
||||||
return accessToken, nil // Fallback to AI Studio API mode
|
return accessToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var proxyURL string
|
var proxyURL string
|
||||||
@@ -132,17 +139,15 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3) Populate cache with TTL(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
// 3) Populate cache with TTL.
|
||||||
if p.tokenCache != nil {
|
if p.tokenCache != nil {
|
||||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||||
if isStale && latestAccount != nil {
|
if isStale && latestAccount != nil {
|
||||||
// 版本过时,使用 DB 中的最新 token
|
|
||||||
slog.Debug("gemini_token_version_stale_use_latest", "account_id", account.ID)
|
slog.Debug("gemini_token_version_stale_use_latest", "account_id", account.ID)
|
||||||
accessToken = latestAccount.GetCredential("access_token")
|
accessToken = latestAccount.GetCredential("access_token")
|
||||||
if strings.TrimSpace(accessToken) == "" {
|
if strings.TrimSpace(accessToken) == "" {
|
||||||
return "", errors.New("access_token not found after version check")
|
return "", errors.New("access_token not found after version check")
|
||||||
}
|
}
|
||||||
// 不写入缓存,让下次请求重新处理
|
|
||||||
} else {
|
} else {
|
||||||
ttl := 30 * time.Minute
|
ttl := 30 * time.Minute
|
||||||
if expiresAt != nil {
|
if expiresAt != nil {
|
||||||
|
|||||||
@@ -13,6 +13,11 @@ func NewGeminiTokenRefresher(geminiOAuthService *GeminiOAuthService) *GeminiToke
|
|||||||
return &GeminiTokenRefresher{geminiOAuthService: geminiOAuthService}
|
return &GeminiTokenRefresher{geminiOAuthService: geminiOAuthService}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CacheKey 返回用于分布式锁的缓存键
|
||||||
|
func (r *GeminiTokenRefresher) CacheKey(account *Account) string {
|
||||||
|
return GeminiTokenCacheKey(account)
|
||||||
|
}
|
||||||
|
|
||||||
func (r *GeminiTokenRefresher) CanRefresh(account *Account) bool {
|
func (r *GeminiTokenRefresher) CanRefresh(account *Account) bool {
|
||||||
return account.Platform == PlatformGemini && account.Type == AccountTypeOAuth
|
return account.Platform == PlatformGemini && account.Type == AccountTypeOAuth
|
||||||
}
|
}
|
||||||
@@ -35,11 +40,7 @@ func (r *GeminiTokenRefresher) Refresh(ctx context.Context, account *Account) (m
|
|||||||
}
|
}
|
||||||
|
|
||||||
newCredentials := r.geminiOAuthService.BuildAccountCredentials(tokenInfo)
|
newCredentials := r.geminiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||||
for k, v := range account.Credentials {
|
newCredentials = MergeCredentials(account.Credentials, newCredentials)
|
||||||
if _, exists := newCredentials[k]; !exists {
|
|
||||||
newCredentials[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return newCredentials, nil
|
return newCredentials, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ func TestGenerateSessionHash_MetadataHasHighestPriority(t *testing.T) {
|
|||||||
svc := &GatewayService{}
|
svc := &GatewayService{}
|
||||||
|
|
||||||
parsed := &ParsedRequest{
|
parsed := &ParsedRequest{
|
||||||
MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000",
|
MetadataUserID: "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000",
|
||||||
System: "You are a helpful assistant.",
|
System: "You are a helpful assistant.",
|
||||||
HasSystem: true,
|
HasSystem: true,
|
||||||
Messages: []any{
|
Messages: []any{
|
||||||
@@ -196,7 +196,7 @@ func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) {
|
|||||||
svc := &GatewayService{}
|
svc := &GatewayService{}
|
||||||
|
|
||||||
parsed := &ParsedRequest{
|
parsed := &ParsedRequest{
|
||||||
MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000",
|
MetadataUserID: "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000",
|
||||||
Messages: []any{
|
Messages: []any{
|
||||||
map[string]any{"role": "user", "content": "hello"},
|
map[string]any{"role": "user", "content": "hello"},
|
||||||
},
|
},
|
||||||
@@ -212,6 +212,22 @@ func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) {
|
|||||||
"metadata session_id should take priority over SessionContext")
|
"metadata session_id should take priority over SessionContext")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGenerateSessionHash_MetadataJSON_HasHighestPriority(t *testing.T) {
|
||||||
|
svc := &GatewayService{}
|
||||||
|
|
||||||
|
parsed := &ParsedRequest{
|
||||||
|
MetadataUserID: `{"device_id":"a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`,
|
||||||
|
System: "You are a helpful assistant.",
|
||||||
|
HasSystem: true,
|
||||||
|
Messages: []any{
|
||||||
|
map[string]any{"role": "user", "content": "hello"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
hash := svc.GenerateSessionHash(parsed)
|
||||||
|
require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", hash, "JSON format metadata session_id should have highest priority")
|
||||||
|
}
|
||||||
|
|
||||||
func TestGenerateSessionHash_NilSessionContextBackwardCompatible(t *testing.T) {
|
func TestGenerateSessionHash_NilSessionContextBackwardCompatible(t *testing.T) {
|
||||||
svc := &GatewayService{}
|
svc := &GatewayService{}
|
||||||
|
|
||||||
|
|||||||
@@ -64,8 +64,10 @@ type Group struct {
|
|||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
|
|
||||||
AccountGroups []AccountGroup
|
AccountGroups []AccountGroup
|
||||||
AccountCount int64
|
AccountCount int64
|
||||||
|
ActiveAccountCount int64
|
||||||
|
RateLimitedAccountCount int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Group) IsActive() bool {
|
func (g *Group) IsActive() bool {
|
||||||
|
|||||||
131
backend/internal/service/group_capacity_service.go
Normal file
131
backend/internal/service/group_capacity_service.go
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GroupCapacitySummary holds aggregated capacity for a single group.
|
||||||
|
type GroupCapacitySummary struct {
|
||||||
|
GroupID int64 `json:"group_id"`
|
||||||
|
ConcurrencyUsed int `json:"concurrency_used"`
|
||||||
|
ConcurrencyMax int `json:"concurrency_max"`
|
||||||
|
SessionsUsed int `json:"sessions_used"`
|
||||||
|
SessionsMax int `json:"sessions_max"`
|
||||||
|
RPMUsed int `json:"rpm_used"`
|
||||||
|
RPMMax int `json:"rpm_max"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupCapacityService aggregates per-group capacity from runtime data.
|
||||||
|
type GroupCapacityService struct {
|
||||||
|
accountRepo AccountRepository
|
||||||
|
groupRepo GroupRepository
|
||||||
|
concurrencyService *ConcurrencyService
|
||||||
|
sessionLimitCache SessionLimitCache
|
||||||
|
rpmCache RPMCache
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewGroupCapacityService creates a new GroupCapacityService.
|
||||||
|
func NewGroupCapacityService(
|
||||||
|
accountRepo AccountRepository,
|
||||||
|
groupRepo GroupRepository,
|
||||||
|
concurrencyService *ConcurrencyService,
|
||||||
|
sessionLimitCache SessionLimitCache,
|
||||||
|
rpmCache RPMCache,
|
||||||
|
) *GroupCapacityService {
|
||||||
|
return &GroupCapacityService{
|
||||||
|
accountRepo: accountRepo,
|
||||||
|
groupRepo: groupRepo,
|
||||||
|
concurrencyService: concurrencyService,
|
||||||
|
sessionLimitCache: sessionLimitCache,
|
||||||
|
rpmCache: rpmCache,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllGroupCapacity returns capacity summary for all active groups.
|
||||||
|
func (s *GroupCapacityService) GetAllGroupCapacity(ctx context.Context) ([]GroupCapacitySummary, error) {
|
||||||
|
groups, err := s.groupRepo.ListActive(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
results := make([]GroupCapacitySummary, 0, len(groups))
|
||||||
|
for i := range groups {
|
||||||
|
cap, err := s.getGroupCapacity(ctx, groups[i].ID)
|
||||||
|
if err != nil {
|
||||||
|
// Skip groups with errors, return partial results
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cap.GroupID = groups[i].ID
|
||||||
|
results = append(results, cap)
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GroupCapacityService) getGroupCapacity(ctx context.Context, groupID int64) (GroupCapacitySummary, error) {
|
||||||
|
accounts, err := s.accountRepo.ListSchedulableByGroupID(ctx, groupID)
|
||||||
|
if err != nil {
|
||||||
|
return GroupCapacitySummary{}, err
|
||||||
|
}
|
||||||
|
if len(accounts) == 0 {
|
||||||
|
return GroupCapacitySummary{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect account IDs and config values
|
||||||
|
accountIDs := make([]int64, 0, len(accounts))
|
||||||
|
sessionTimeouts := make(map[int64]time.Duration)
|
||||||
|
var concurrencyMax, sessionsMax, rpmMax int
|
||||||
|
|
||||||
|
for i := range accounts {
|
||||||
|
acc := &accounts[i]
|
||||||
|
accountIDs = append(accountIDs, acc.ID)
|
||||||
|
concurrencyMax += acc.Concurrency
|
||||||
|
|
||||||
|
if ms := acc.GetMaxSessions(); ms > 0 {
|
||||||
|
sessionsMax += ms
|
||||||
|
timeout := time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
|
||||||
|
if timeout <= 0 {
|
||||||
|
timeout = 5 * time.Minute
|
||||||
|
}
|
||||||
|
sessionTimeouts[acc.ID] = timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
if rpm := acc.GetBaseRPM(); rpm > 0 {
|
||||||
|
rpmMax += rpm
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Batch query runtime data from Redis
|
||||||
|
concurrencyMap, _ := s.concurrencyService.GetAccountConcurrencyBatch(ctx, accountIDs)
|
||||||
|
|
||||||
|
var sessionsMap map[int64]int
|
||||||
|
if sessionsMax > 0 && s.sessionLimitCache != nil {
|
||||||
|
sessionsMap, _ = s.sessionLimitCache.GetActiveSessionCountBatch(ctx, accountIDs, sessionTimeouts)
|
||||||
|
}
|
||||||
|
|
||||||
|
var rpmMap map[int64]int
|
||||||
|
if rpmMax > 0 && s.rpmCache != nil {
|
||||||
|
rpmMap, _ = s.rpmCache.GetRPMBatch(ctx, accountIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Aggregate
|
||||||
|
var concurrencyUsed, sessionsUsed, rpmUsed int
|
||||||
|
for _, id := range accountIDs {
|
||||||
|
concurrencyUsed += concurrencyMap[id]
|
||||||
|
if sessionsMap != nil {
|
||||||
|
sessionsUsed += sessionsMap[id]
|
||||||
|
}
|
||||||
|
if rpmMap != nil {
|
||||||
|
rpmUsed += rpmMap[id]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return GroupCapacitySummary{
|
||||||
|
ConcurrencyUsed: concurrencyUsed,
|
||||||
|
ConcurrencyMax: concurrencyMax,
|
||||||
|
SessionsUsed: sessionsUsed,
|
||||||
|
SessionsMax: sessionsMax,
|
||||||
|
RPMUsed: rpmUsed,
|
||||||
|
RPMMax: rpmMax,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
@@ -27,7 +27,7 @@ type GroupRepository interface {
|
|||||||
ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error)
|
ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error)
|
||||||
|
|
||||||
ExistsByName(ctx context.Context, name string) (bool, error)
|
ExistsByName(ctx context.Context, name string) (bool, error)
|
||||||
GetAccountCount(ctx context.Context, groupID int64) (int64, error)
|
GetAccountCount(ctx context.Context, groupID int64) (total int64, active int64, err error)
|
||||||
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
|
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||||
// GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID(去重)
|
// GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID(去重)
|
||||||
GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error)
|
GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error)
|
||||||
@@ -202,7 +202,7 @@ func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]any,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取账号数量
|
// 获取账号数量
|
||||||
accountCount, err := s.groupRepo.GetAccountCount(ctx, id)
|
accountCount, _, err := s.groupRepo.GetAccountCount(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get account count: %w", err)
|
return nil, fmt.Errorf("get account count: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,10 +19,6 @@ import (
|
|||||||
|
|
||||||
// 预编译正则表达式(避免每次调用重新编译)
|
// 预编译正则表达式(避免每次调用重新编译)
|
||||||
var (
|
var (
|
||||||
// 匹配 user_id 格式:
|
|
||||||
// 旧格式: user_{64位hex}_account__session_{uuid} (account 后无 UUID)
|
|
||||||
// 新格式: user_{64位hex}_account_{uuid}_session_{uuid} (account 后有 UUID)
|
|
||||||
userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account_([a-f0-9-]*)_session_([a-f0-9-]{36})$`)
|
|
||||||
// 匹配 User-Agent 版本号: xxx/x.y.z
|
// 匹配 User-Agent 版本号: xxx/x.y.z
|
||||||
userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
|
userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
|
||||||
)
|
)
|
||||||
@@ -209,12 +205,12 @@ func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RewriteUserID 重写body中的metadata.user_id
|
// RewriteUserID 重写body中的metadata.user_id
|
||||||
// 输入格式:user_{clientId}_account__session_{sessionUUID}
|
// 支持旧拼接格式和新 JSON 格式的 user_id 解析,
|
||||||
// 输出格式:user_{cachedClientID}_account_{accountUUID}_session_{newHash}
|
// 根据 fingerprintUA 版本选择输出格式。
|
||||||
//
|
//
|
||||||
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
|
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
|
||||||
// 避免重新序列化导致 thinking 块等内容被修改。
|
// 避免重新序列化导致 thinking 块等内容被修改。
|
||||||
func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID string) ([]byte, error) {
|
func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID, fingerprintUA string) ([]byte, error) {
|
||||||
if len(body) == 0 || accountUUID == "" || cachedClientID == "" {
|
if len(body) == 0 || accountUUID == "" || cachedClientID == "" {
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
@@ -241,24 +237,21 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
|
|||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 匹配格式:
|
// 解析 user_id(兼容旧拼接格式和新 JSON 格式)
|
||||||
// 旧格式: user_{64位hex}_account__session_{uuid}
|
parsed := ParseMetadataUserID(userID)
|
||||||
// 新格式: user_{64位hex}_account_{uuid}_session_{uuid}
|
if parsed == nil {
|
||||||
matches := userIDRegex.FindStringSubmatch(userID)
|
|
||||||
if matches == nil {
|
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// matches[1] = account UUID (可能为空), matches[2] = session UUID
|
sessionTail := parsed.SessionID // 原始session UUID
|
||||||
sessionTail := matches[2] // 原始session UUID
|
|
||||||
|
|
||||||
// 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式
|
// 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式
|
||||||
seed := fmt.Sprintf("%d::%s", accountID, sessionTail)
|
seed := fmt.Sprintf("%d::%s", accountID, sessionTail)
|
||||||
newSessionHash := generateUUIDFromSeed(seed)
|
newSessionHash := generateUUIDFromSeed(seed)
|
||||||
|
|
||||||
// 构建新的user_id
|
// 根据客户端版本选择输出格式
|
||||||
// 格式: user_{cachedClientID}_account_{account_uuid}_session_{newSessionHash}
|
version := ExtractCLIVersion(fingerprintUA)
|
||||||
newUserID := fmt.Sprintf("user_%s_account_%s_session_%s", cachedClientID, accountUUID, newSessionHash)
|
newUserID := FormatMetadataUserID(cachedClientID, accountUUID, newSessionHash, version)
|
||||||
|
|
||||||
metadata["user_id"] = newUserID
|
metadata["user_id"] = newUserID
|
||||||
|
|
||||||
@@ -278,9 +271,9 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
|
|||||||
//
|
//
|
||||||
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
|
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
|
||||||
// 避免重新序列化导致 thinking 块等内容被修改。
|
// 避免重新序列化导致 thinking 块等内容被修改。
|
||||||
func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID string) ([]byte, error) {
|
func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID, fingerprintUA string) ([]byte, error) {
|
||||||
// 先执行常规的 RewriteUserID 逻辑
|
// 先执行常规的 RewriteUserID 逻辑
|
||||||
newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID)
|
newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID, fingerprintUA)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return newBody, err
|
return newBody, err
|
||||||
}
|
}
|
||||||
@@ -312,10 +305,9 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
|
|||||||
return newBody, nil
|
return newBody, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 查找 _session_ 的位置,替换其后的内容
|
// 解析已重写的 user_id
|
||||||
const sessionMarker = "_session_"
|
uidParsed := ParseMetadataUserID(userID)
|
||||||
idx := strings.LastIndex(userID, sessionMarker)
|
if uidParsed == nil {
|
||||||
if idx == -1 {
|
|
||||||
return newBody, nil
|
return newBody, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -337,8 +329,9 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
|
|||||||
logger.LegacyPrintf("service.identity", "Warning: failed to set masked session ID for account %d: %v", account.ID, err)
|
logger.LegacyPrintf("service.identity", "Warning: failed to set masked session ID for account %d: %v", account.ID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 替换 session 部分:保留 _session_ 之前的内容,替换之后的内容
|
// 用 FormatMetadataUserID 重建(保持与 RewriteUserID 相同的格式)
|
||||||
newUserID := userID[:idx+len(sessionMarker)] + maskedSessionID
|
version := ExtractCLIVersion(fingerprintUA)
|
||||||
|
newUserID := FormatMetadataUserID(uidParsed.DeviceID, uidParsed.AccountUUID, maskedSessionID, version)
|
||||||
|
|
||||||
slog.Debug("session_id_masking_applied",
|
slog.Debug("session_id_masking_applied",
|
||||||
"account_id", account.ID,
|
"account_id", account.ID,
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user