mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-04 23:42:13 +08:00
Compare commits
103 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0236b97d49 | ||
|
|
26f6b1eeff | ||
|
|
dc447ccebe | ||
|
|
7ec29638f4 | ||
|
|
4c9562af20 | ||
|
|
71942fd322 | ||
|
|
550b979ac5 | ||
|
|
3878a5a46f | ||
|
|
e443a6a1ea | ||
|
|
963494ec6f | ||
|
|
525cdb8830 | ||
|
|
a6764e82f2 | ||
|
|
8027531d07 | ||
|
|
30706355a4 | ||
|
|
dfe99507b8 | ||
|
|
c1717c9a6c | ||
|
|
1fd1a58a7a | ||
|
|
fad07507be | ||
|
|
a20c211162 | ||
|
|
9f6ab6b817 | ||
|
|
bf3d6c0e6e | ||
|
|
241023f3fc | ||
|
|
1292c44b41 | ||
|
|
b4fce47049 | ||
|
|
e7780cd8c8 | ||
|
|
af96c8ea53 | ||
|
|
7d26b81075 | ||
|
|
b8ada63ac3 | ||
|
|
cfaac12af1 | ||
|
|
6028efd26c | ||
|
|
62a566ef2c | ||
|
|
94419f434c | ||
|
|
21f349c032 | ||
|
|
28e36f7925 | ||
|
|
6c02076333 | ||
|
|
7414bdf0e3 | ||
|
|
e6326b2929 | ||
|
|
17cdcebd04 | ||
|
|
a14babdc73 | ||
|
|
aadc6a763a | ||
|
|
f16af8bf88 | ||
|
|
5ceaef4500 | ||
|
|
1ac7219a92 | ||
|
|
d4cc9871c4 | ||
|
|
961c30e7c0 | ||
|
|
13e85b3147 | ||
|
|
50a3c7fa0b | ||
|
|
bd9d2671d7 | ||
|
|
62b40636e0 | ||
|
|
eeff451bc5 | ||
|
|
56fcb20f94 | ||
|
|
7134266acf | ||
|
|
2e4ac88ad9 | ||
|
|
51547fa216 | ||
|
|
2005fc97a8 | ||
|
|
0772d9250e | ||
|
|
aa6047c460 | ||
|
|
045cba78b4 | ||
|
|
8989d0d4b6 | ||
|
|
c521117b99 | ||
|
|
e0f52a8ab8 | ||
|
|
6c23fadf7e | ||
|
|
869952d113 | ||
|
|
07ab051ee4 | ||
|
|
f2d98fc0c7 | ||
|
|
2b41cec840 | ||
|
|
6cf77040e7 | ||
|
|
20b70bc5fd | ||
|
|
4905e7193a | ||
|
|
9c1f4b8e72 | ||
|
|
9857c17631 | ||
|
|
7e34bb946f | ||
|
|
47b748851b | ||
|
|
a6f99cf534 | ||
|
|
a120a6bc32 | ||
|
|
d557d1a190 | ||
|
|
e0286e5085 | ||
|
|
4b41e898a4 | ||
|
|
668e164793 | ||
|
|
fa2e6188d0 | ||
|
|
7fde9ebbc2 | ||
|
|
aef7c3b9bb | ||
|
|
a0b76bd608 | ||
|
|
c1fab7f8d8 | ||
|
|
f42c8f2abe | ||
|
|
aa5846b282 | ||
|
|
594a0ade38 | ||
|
|
d45cc23171 | ||
|
|
d795734352 | ||
|
|
4da9fdd1d5 | ||
|
|
6b218caa21 | ||
|
|
5c138007d0 | ||
|
|
1acfc46f46 | ||
|
|
fbffb08aae | ||
|
|
8640a62319 | ||
|
|
fa782e70a4 | ||
|
|
afd72abc6e | ||
|
|
71f72e167e | ||
|
|
6595c7601e | ||
|
|
67c0506290 | ||
|
|
6447be4534 | ||
|
|
3741617ebd | ||
|
|
ab4e8b2cf0 |
7
.gitattributes
vendored
7
.gitattributes
vendored
@@ -4,6 +4,13 @@ backend/migrations/*.sql text eol=lf
|
|||||||
# Go 源代码文件
|
# Go 源代码文件
|
||||||
*.go text eol=lf
|
*.go text eol=lf
|
||||||
|
|
||||||
|
# 前端 源代码文件
|
||||||
|
*.ts text eol=lf
|
||||||
|
*.tsx text eol=lf
|
||||||
|
*.js text eol=lf
|
||||||
|
*.jsx text eol=lf
|
||||||
|
*.vue text eol=lf
|
||||||
|
|
||||||
# Shell 脚本
|
# Shell 脚本
|
||||||
*.sh text eol=lf
|
*.sh text eol=lf
|
||||||
|
|
||||||
|
|||||||
33
.github/workflows/release.yml
vendored
33
.github/workflows/release.yml
vendored
@@ -271,3 +271,36 @@ jobs:
|
|||||||
parse_mode: "Markdown",
|
parse_mode: "Markdown",
|
||||||
disable_web_page_preview: true
|
disable_web_page_preview: true
|
||||||
}')"
|
}')"
|
||||||
|
|
||||||
|
sync-version-file:
|
||||||
|
needs: [release]
|
||||||
|
if: ${{ needs.release.result == 'success' }}
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout default branch
|
||||||
|
uses: actions/checkout@v6
|
||||||
|
with:
|
||||||
|
ref: ${{ github.event.repository.default_branch }}
|
||||||
|
|
||||||
|
- name: Sync VERSION file to released tag
|
||||||
|
run: |
|
||||||
|
if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
|
||||||
|
VERSION=${{ github.event.inputs.tag }}
|
||||||
|
VERSION=${VERSION#v}
|
||||||
|
else
|
||||||
|
VERSION=${GITHUB_REF#refs/tags/v}
|
||||||
|
fi
|
||||||
|
|
||||||
|
CURRENT_VERSION=$(tr -d '\r\n' < backend/cmd/server/VERSION || true)
|
||||||
|
if [ "$CURRENT_VERSION" = "$VERSION" ]; then
|
||||||
|
echo "VERSION file already matches $VERSION"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "$VERSION" > backend/cmd/server/VERSION
|
||||||
|
|
||||||
|
git config user.name "github-actions[bot]"
|
||||||
|
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||||
|
git add backend/cmd/server/VERSION
|
||||||
|
git commit -m "chore: sync VERSION to ${VERSION} [skip ci]"
|
||||||
|
git push origin HEAD:${{ github.event.repository.default_branch }}
|
||||||
|
|||||||
@@ -47,6 +47,8 @@ dockers:
|
|||||||
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:latest"
|
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:latest"
|
||||||
dockerfile: Dockerfile.goreleaser
|
dockerfile: Dockerfile.goreleaser
|
||||||
use: buildx
|
use: buildx
|
||||||
|
extra_files:
|
||||||
|
- deploy/docker-entrypoint.sh
|
||||||
build_flag_templates:
|
build_flag_templates:
|
||||||
- "--platform=linux/amd64"
|
- "--platform=linux/amd64"
|
||||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||||
|
|||||||
@@ -63,6 +63,8 @@ dockers:
|
|||||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||||
dockerfile: Dockerfile.goreleaser
|
dockerfile: Dockerfile.goreleaser
|
||||||
use: buildx
|
use: buildx
|
||||||
|
extra_files:
|
||||||
|
- deploy/docker-entrypoint.sh
|
||||||
build_flag_templates:
|
build_flag_templates:
|
||||||
- "--platform=linux/amd64"
|
- "--platform=linux/amd64"
|
||||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||||
@@ -76,6 +78,8 @@ dockers:
|
|||||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||||
dockerfile: Dockerfile.goreleaser
|
dockerfile: Dockerfile.goreleaser
|
||||||
use: buildx
|
use: buildx
|
||||||
|
extra_files:
|
||||||
|
- deploy/docker-entrypoint.sh
|
||||||
build_flag_templates:
|
build_flag_templates:
|
||||||
- "--platform=linux/arm64"
|
- "--platform=linux/arm64"
|
||||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||||
@@ -89,6 +93,8 @@ dockers:
|
|||||||
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-amd64"
|
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-amd64"
|
||||||
dockerfile: Dockerfile.goreleaser
|
dockerfile: Dockerfile.goreleaser
|
||||||
use: buildx
|
use: buildx
|
||||||
|
extra_files:
|
||||||
|
- deploy/docker-entrypoint.sh
|
||||||
build_flag_templates:
|
build_flag_templates:
|
||||||
- "--platform=linux/amd64"
|
- "--platform=linux/amd64"
|
||||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||||
@@ -102,6 +108,8 @@ dockers:
|
|||||||
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-arm64"
|
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-arm64"
|
||||||
dockerfile: Dockerfile.goreleaser
|
dockerfile: Dockerfile.goreleaser
|
||||||
use: buildx
|
use: buildx
|
||||||
|
extra_files:
|
||||||
|
- deploy/docker-entrypoint.sh
|
||||||
build_flag_templates:
|
build_flag_templates:
|
||||||
- "--platform=linux/arm64"
|
- "--platform=linux/arm64"
|
||||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||||
|
|||||||
11
Dockerfile
11
Dockerfile
@@ -92,6 +92,7 @@ LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
|
|||||||
RUN apk add --no-cache \
|
RUN apk add --no-cache \
|
||||||
ca-certificates \
|
ca-certificates \
|
||||||
tzdata \
|
tzdata \
|
||||||
|
su-exec \
|
||||||
libpq \
|
libpq \
|
||||||
zstd-libs \
|
zstd-libs \
|
||||||
lz4-libs \
|
lz4-libs \
|
||||||
@@ -120,8 +121,9 @@ COPY --from=backend-builder --chown=sub2api:sub2api /app/backend/resources /app/
|
|||||||
# Create data directory
|
# Create data directory
|
||||||
RUN mkdir -p /app/data && chown sub2api:sub2api /app/data
|
RUN mkdir -p /app/data && chown sub2api:sub2api /app/data
|
||||||
|
|
||||||
# Switch to non-root user
|
# Copy entrypoint script (fixes volume permissions then drops to sub2api)
|
||||||
USER sub2api
|
COPY deploy/docker-entrypoint.sh /app/docker-entrypoint.sh
|
||||||
|
RUN chmod +x /app/docker-entrypoint.sh
|
||||||
|
|
||||||
# Expose port (can be overridden by SERVER_PORT env var)
|
# Expose port (can be overridden by SERVER_PORT env var)
|
||||||
EXPOSE 8080
|
EXPOSE 8080
|
||||||
@@ -130,5 +132,6 @@ EXPOSE 8080
|
|||||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
||||||
CMD wget -q -T 5 -O /dev/null http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
CMD wget -q -T 5 -O /dev/null http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
||||||
|
|
||||||
# Run the application
|
# Run the application (entrypoint fixes /app/data ownership then execs as sub2api)
|
||||||
ENTRYPOINT ["/app/sub2api"]
|
ENTRYPOINT ["/app/docker-entrypoint.sh"]
|
||||||
|
CMD ["/app/sub2api"]
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ RUN apk add --no-cache \
|
|||||||
ca-certificates \
|
ca-certificates \
|
||||||
tzdata \
|
tzdata \
|
||||||
curl \
|
curl \
|
||||||
|
su-exec \
|
||||||
libpq \
|
libpq \
|
||||||
zstd-libs \
|
zstd-libs \
|
||||||
lz4-libs \
|
lz4-libs \
|
||||||
@@ -47,11 +48,15 @@ COPY sub2api /app/sub2api
|
|||||||
# Create data directory
|
# Create data directory
|
||||||
RUN mkdir -p /app/data && chown -R sub2api:sub2api /app
|
RUN mkdir -p /app/data && chown -R sub2api:sub2api /app
|
||||||
|
|
||||||
USER sub2api
|
# Copy entrypoint script (fixes volume permissions then drops to sub2api)
|
||||||
|
COPY deploy/docker-entrypoint.sh /app/docker-entrypoint.sh
|
||||||
|
RUN chmod +x /app/docker-entrypoint.sh
|
||||||
|
|
||||||
EXPOSE 8080
|
EXPOSE 8080
|
||||||
|
|
||||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
||||||
CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
||||||
|
|
||||||
ENTRYPOINT ["/app/sub2api"]
|
# Run the application (entrypoint fixes /app/data ownership then execs as sub2api)
|
||||||
|
ENTRYPOINT ["/app/docker-entrypoint.sh"]
|
||||||
|
CMD ["/app/sub2api"]
|
||||||
|
|||||||
30
README.md
30
README.md
@@ -8,27 +8,31 @@
|
|||||||
[](https://redis.io/)
|
[](https://redis.io/)
|
||||||
[](https://www.docker.com/)
|
[](https://www.docker.com/)
|
||||||
|
|
||||||
|
<a href="https://trendshift.io/repositories/21823" target="_blank"><img src="https://trendshift.io/api/badge/repositories/21823" alt="Wei-Shaw%2Fsub2api | Trendshift" width="250" height="55"/></a>
|
||||||
|
|
||||||
**AI API Gateway Platform for Subscription Quota Distribution**
|
**AI API Gateway Platform for Subscription Quota Distribution**
|
||||||
|
|
||||||
English | [中文](README_CN.md)
|
English | [中文](README_CN.md)
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
> **Sub2API officially uses only the domains `sub2api.org` and `pincc.ai`. Other websites using the Sub2API name may be third-party deployments or services and are not affiliated with this project. Please verify and exercise your own judgment.**
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Demo
|
## Demo
|
||||||
|
|
||||||
Try Sub2API online: **https://demo.sub2api.org/**
|
Try Sub2API online: **[https://demo.sub2api.org/](https://demo.sub2api.org/)**
|
||||||
|
|
||||||
Demo credentials (shared demo environment; **not** created automatically for self-hosted installs):
|
Demo credentials (shared demo environment; **not** created automatically for self-hosted installs):
|
||||||
|
|
||||||
| Email | Password |
|
| Email | Password |
|
||||||
|-------|----------|
|
|-------|----------|
|
||||||
| admin@sub2api.com | admin123 |
|
| admin@sub2api.org | admin123 |
|
||||||
|
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions (like Claude Code $200/month). Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding.
|
Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions. Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding.
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
@@ -41,6 +45,15 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
|
|||||||
- **Admin Dashboard** - Web interface for monitoring and management
|
- **Admin Dashboard** - Web interface for monitoring and management
|
||||||
- **External System Integration** - Embed external systems (e.g. payment, ticketing) via iframe to extend the admin dashboard
|
- **External System Integration** - Embed external systems (e.g. payment, ticketing) via iframe to extend the admin dashboard
|
||||||
|
|
||||||
|
## Don't Want to Self-Host?
|
||||||
|
|
||||||
|
<table>
|
||||||
|
<tr>
|
||||||
|
<td width="180" align="center" valign="middle"><a href="https://shop.pincc.ai/"><img src="assets/partners/logos/pincc-logo.png" alt="pincc" width="120"></a></td>
|
||||||
|
<td valign="middle"><b><a href="https://shop.pincc.ai/">PinCC</a></b> is the official relay service built on Sub2API, offering stable access to Claude Code, Codex, Gemini and other popular models — ready to use, no deployment or maintenance required.</td>
|
||||||
|
</tr>
|
||||||
|
</table>
|
||||||
|
|
||||||
## Ecosystem
|
## Ecosystem
|
||||||
|
|
||||||
Community projects that extend or integrate with Sub2API:
|
Community projects that extend or integrate with Sub2API:
|
||||||
@@ -61,10 +74,15 @@ Community projects that extend or integrate with Sub2API:
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Documentation
|
## Nginx Reverse Proxy Note
|
||||||
|
|
||||||
- Dependency Security: `docs/dependency-security.md`
|
When using Nginx as a reverse proxy for Sub2API (or CRS) with Codex CLI, add the following to the `http` block in your Nginx configuration:
|
||||||
- Admin Payment Integration API: `docs/ADMIN_PAYMENT_INTEGRATION_API.md`
|
|
||||||
|
```nginx
|
||||||
|
underscores_in_headers on;
|
||||||
|
```
|
||||||
|
|
||||||
|
Nginx drops headers containing underscores by default (e.g. `session_id`), which breaks sticky session routing in multi-account setups.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
33
README_CN.md
33
README_CN.md
@@ -8,27 +8,30 @@
|
|||||||
[](https://redis.io/)
|
[](https://redis.io/)
|
||||||
[](https://www.docker.com/)
|
[](https://www.docker.com/)
|
||||||
|
|
||||||
|
<a href="https://trendshift.io/repositories/21823" target="_blank"><img src="https://trendshift.io/api/badge/repositories/21823" alt="Wei-Shaw%2Fsub2api | Trendshift" width="250" height="55"/></a>
|
||||||
|
|
||||||
**AI API 网关平台 - 订阅配额分发管理**
|
**AI API 网关平台 - 订阅配额分发管理**
|
||||||
|
|
||||||
[English](README.md) | 中文
|
[English](README.md) | 中文
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
> **Sub2API 官方仅使用 `sub2api.org` 与 `pincc.ai` 两个域名。其他使用 Sub2API 名义的网站可能为第三方部署或服务,与本项目无关,请自行甄别。**
|
||||||
---
|
---
|
||||||
|
|
||||||
## 在线体验
|
## 在线体验
|
||||||
|
|
||||||
体验地址:**https://v2.pincc.ai/**
|
体验地址:**[https://demo.sub2api.org/](https://demo.sub2api.org/)**
|
||||||
|
|
||||||
演示账号(共享演示环境;自建部署不会自动创建该账号):
|
演示账号(共享演示环境;自建部署不会自动创建该账号):
|
||||||
|
|
||||||
| 邮箱 | 密码 |
|
| 邮箱 | 密码 |
|
||||||
|------|------|
|
|------|------|
|
||||||
| admin@sub2api.com | admin123 |
|
| admin@sub2api.org | admin123 |
|
||||||
|
|
||||||
## 项目概述
|
## 项目概述
|
||||||
|
|
||||||
Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(如 Claude Code $200/月)的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。
|
Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。
|
||||||
|
|
||||||
## 核心功能
|
## 核心功能
|
||||||
|
|
||||||
@@ -41,6 +44,15 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
|
|||||||
- **管理后台** - Web 界面进行监控和管理
|
- **管理后台** - Web 界面进行监控和管理
|
||||||
- **外部系统集成** - 支持通过 iframe 嵌入外部系统(如支付、工单等),扩展管理后台功能
|
- **外部系统集成** - 支持通过 iframe 嵌入外部系统(如支付、工单等),扩展管理后台功能
|
||||||
|
|
||||||
|
## 不想自建?试试官方中转
|
||||||
|
|
||||||
|
<table>
|
||||||
|
<tr>
|
||||||
|
<td width="180" align="center" valign="middle"><a href="https://shop.pincc.ai/"><img src="assets/partners/logos/pincc-logo.png" alt="pincc" width="120"></a></td>
|
||||||
|
<td valign="middle"><b><a href="https://shop.pincc.ai/">PinCC</a></b> 是基于 Sub2API 搭建的官方中转服务,提供 Claude Code、Codex、Gemini 等主流模型的稳定中转,开箱即用,免去自建部署与运维烦恼。</td>
|
||||||
|
</tr>
|
||||||
|
</table>
|
||||||
|
|
||||||
## 生态项目
|
## 生态项目
|
||||||
|
|
||||||
围绕 Sub2API 的社区扩展与集成项目:
|
围绕 Sub2API 的社区扩展与集成项目:
|
||||||
@@ -61,17 +73,18 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 文档
|
## Nginx 反向代理注意事项
|
||||||
|
|
||||||
- 依赖安全:`docs/dependency-security.md`
|
通过 Nginx 反向代理 Sub2API(或 CRS 服务)并搭配 Codex CLI 使用时,需要在 Nginx 配置的 `http` 块中添加:
|
||||||
|
|
||||||
|
```nginx
|
||||||
|
underscores_in_headers on;
|
||||||
|
```
|
||||||
|
|
||||||
|
Nginx 默认会丢弃名称中含下划线的请求头(如 `session_id`),这会导致多账号环境下的粘性会话功能失效。
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## OpenAI Responses 兼容注意事项
|
|
||||||
|
|
||||||
- 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id` 的 `tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`。
|
|
||||||
- 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。
|
|
||||||
|
|
||||||
## 部署方式
|
## 部署方式
|
||||||
|
|
||||||
### 方式一:脚本安装(推荐)
|
### 方式一:脚本安装(推荐)
|
||||||
|
|||||||
BIN
assets/partners/logos/pincc-logo.png
Normal file
BIN
assets/partners/logos/pincc-logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 171 KiB |
@@ -110,7 +110,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
||||||
groupHandler := admin.NewGroupHandler(adminService)
|
|
||||||
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
||||||
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
||||||
openAIOAuthClient := repository.NewOpenAIOAuthClient()
|
openAIOAuthClient := repository.NewOpenAIOAuthClient()
|
||||||
@@ -143,6 +142,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||||
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
|
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
|
||||||
rpmCache := repository.NewRPMCache(redisClient)
|
rpmCache := repository.NewRPMCache(redisClient)
|
||||||
|
groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
|
||||||
|
groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
|
||||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
|
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
|
||||||
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
|
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
|
||||||
dataManagementService := service.NewDataManagementService()
|
dataManagementService := service.NewDataManagementService()
|
||||||
|
|||||||
@@ -716,6 +716,7 @@ var (
|
|||||||
{Name: "id", Type: field.TypeInt64, Increment: true},
|
{Name: "id", Type: field.TypeInt64, Increment: true},
|
||||||
{Name: "request_id", Type: field.TypeString, Size: 64},
|
{Name: "request_id", Type: field.TypeString, Size: 64},
|
||||||
{Name: "model", Type: field.TypeString, Size: 100},
|
{Name: "model", Type: field.TypeString, Size: 100},
|
||||||
|
{Name: "upstream_model", Type: field.TypeString, Nullable: true, Size: 100},
|
||||||
{Name: "input_tokens", Type: field.TypeInt, Default: 0},
|
{Name: "input_tokens", Type: field.TypeInt, Default: 0},
|
||||||
{Name: "output_tokens", Type: field.TypeInt, Default: 0},
|
{Name: "output_tokens", Type: field.TypeInt, Default: 0},
|
||||||
{Name: "cache_creation_tokens", Type: field.TypeInt, Default: 0},
|
{Name: "cache_creation_tokens", Type: field.TypeInt, Default: 0},
|
||||||
@@ -755,31 +756,31 @@ var (
|
|||||||
ForeignKeys: []*schema.ForeignKey{
|
ForeignKeys: []*schema.ForeignKey{
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_api_keys_usage_logs",
|
Symbol: "usage_logs_api_keys_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||||
RefColumns: []*schema.Column{APIKeysColumns[0]},
|
RefColumns: []*schema.Column{APIKeysColumns[0]},
|
||||||
OnDelete: schema.NoAction,
|
OnDelete: schema.NoAction,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_accounts_usage_logs",
|
Symbol: "usage_logs_accounts_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||||
RefColumns: []*schema.Column{AccountsColumns[0]},
|
RefColumns: []*schema.Column{AccountsColumns[0]},
|
||||||
OnDelete: schema.NoAction,
|
OnDelete: schema.NoAction,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_groups_usage_logs",
|
Symbol: "usage_logs_groups_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[30]},
|
Columns: []*schema.Column{UsageLogsColumns[31]},
|
||||||
RefColumns: []*schema.Column{GroupsColumns[0]},
|
RefColumns: []*schema.Column{GroupsColumns[0]},
|
||||||
OnDelete: schema.SetNull,
|
OnDelete: schema.SetNull,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_users_usage_logs",
|
Symbol: "usage_logs_users_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[31]},
|
Columns: []*schema.Column{UsageLogsColumns[32]},
|
||||||
RefColumns: []*schema.Column{UsersColumns[0]},
|
RefColumns: []*schema.Column{UsersColumns[0]},
|
||||||
OnDelete: schema.NoAction,
|
OnDelete: schema.NoAction,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_user_subscriptions_usage_logs",
|
Symbol: "usage_logs_user_subscriptions_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[32]},
|
Columns: []*schema.Column{UsageLogsColumns[33]},
|
||||||
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
|
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
|
||||||
OnDelete: schema.SetNull,
|
OnDelete: schema.SetNull,
|
||||||
},
|
},
|
||||||
@@ -788,32 +789,32 @@ var (
|
|||||||
{
|
{
|
||||||
Name: "usagelog_user_id",
|
Name: "usagelog_user_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[31]},
|
Columns: []*schema.Column{UsageLogsColumns[32]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_api_key_id",
|
Name: "usagelog_api_key_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_account_id",
|
Name: "usagelog_account_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_group_id",
|
Name: "usagelog_group_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[30]},
|
Columns: []*schema.Column{UsageLogsColumns[31]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_subscription_id",
|
Name: "usagelog_subscription_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[32]},
|
Columns: []*schema.Column{UsageLogsColumns[33]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_created_at",
|
Name: "usagelog_created_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[27]},
|
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_model",
|
Name: "usagelog_model",
|
||||||
@@ -828,17 +829,17 @@ var (
|
|||||||
{
|
{
|
||||||
Name: "usagelog_user_id_created_at",
|
Name: "usagelog_user_id_created_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[27]},
|
Columns: []*schema.Column{UsageLogsColumns[32], UsageLogsColumns[28]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_api_key_id_created_at",
|
Name: "usagelog_api_key_id_created_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[27]},
|
Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[28]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_group_id_created_at",
|
Name: "usagelog_group_id_created_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[27]},
|
Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[28]},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18239,6 +18239,7 @@ type UsageLogMutation struct {
|
|||||||
id *int64
|
id *int64
|
||||||
request_id *string
|
request_id *string
|
||||||
model *string
|
model *string
|
||||||
|
upstream_model *string
|
||||||
input_tokens *int
|
input_tokens *int
|
||||||
addinput_tokens *int
|
addinput_tokens *int
|
||||||
output_tokens *int
|
output_tokens *int
|
||||||
@@ -18576,6 +18577,55 @@ func (m *UsageLogMutation) ResetModel() {
|
|||||||
m.model = nil
|
m.model = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetUpstreamModel sets the "upstream_model" field.
|
||||||
|
func (m *UsageLogMutation) SetUpstreamModel(s string) {
|
||||||
|
m.upstream_model = &s
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModel returns the value of the "upstream_model" field in the mutation.
|
||||||
|
func (m *UsageLogMutation) UpstreamModel() (r string, exists bool) {
|
||||||
|
v := m.upstream_model
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldUpstreamModel returns the old "upstream_model" field's value of the UsageLog entity.
|
||||||
|
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *UsageLogMutation) OldUpstreamModel(ctx context.Context) (v *string, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldUpstreamModel is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldUpstreamModel requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldUpstreamModel: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.UpstreamModel, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearUpstreamModel clears the value of the "upstream_model" field.
|
||||||
|
func (m *UsageLogMutation) ClearUpstreamModel() {
|
||||||
|
m.upstream_model = nil
|
||||||
|
m.clearedFields[usagelog.FieldUpstreamModel] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelCleared returns if the "upstream_model" field was cleared in this mutation.
|
||||||
|
func (m *UsageLogMutation) UpstreamModelCleared() bool {
|
||||||
|
_, ok := m.clearedFields[usagelog.FieldUpstreamModel]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetUpstreamModel resets all changes to the "upstream_model" field.
|
||||||
|
func (m *UsageLogMutation) ResetUpstreamModel() {
|
||||||
|
m.upstream_model = nil
|
||||||
|
delete(m.clearedFields, usagelog.FieldUpstreamModel)
|
||||||
|
}
|
||||||
|
|
||||||
// SetGroupID sets the "group_id" field.
|
// SetGroupID sets the "group_id" field.
|
||||||
func (m *UsageLogMutation) SetGroupID(i int64) {
|
func (m *UsageLogMutation) SetGroupID(i int64) {
|
||||||
m.group = &i
|
m.group = &i
|
||||||
@@ -20197,7 +20247,7 @@ func (m *UsageLogMutation) Type() string {
|
|||||||
// order to get all numeric fields that were incremented/decremented, call
|
// order to get all numeric fields that were incremented/decremented, call
|
||||||
// AddedFields().
|
// AddedFields().
|
||||||
func (m *UsageLogMutation) Fields() []string {
|
func (m *UsageLogMutation) Fields() []string {
|
||||||
fields := make([]string, 0, 32)
|
fields := make([]string, 0, 33)
|
||||||
if m.user != nil {
|
if m.user != nil {
|
||||||
fields = append(fields, usagelog.FieldUserID)
|
fields = append(fields, usagelog.FieldUserID)
|
||||||
}
|
}
|
||||||
@@ -20213,6 +20263,9 @@ func (m *UsageLogMutation) Fields() []string {
|
|||||||
if m.model != nil {
|
if m.model != nil {
|
||||||
fields = append(fields, usagelog.FieldModel)
|
fields = append(fields, usagelog.FieldModel)
|
||||||
}
|
}
|
||||||
|
if m.upstream_model != nil {
|
||||||
|
fields = append(fields, usagelog.FieldUpstreamModel)
|
||||||
|
}
|
||||||
if m.group != nil {
|
if m.group != nil {
|
||||||
fields = append(fields, usagelog.FieldGroupID)
|
fields = append(fields, usagelog.FieldGroupID)
|
||||||
}
|
}
|
||||||
@@ -20312,6 +20365,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
|
|||||||
return m.RequestID()
|
return m.RequestID()
|
||||||
case usagelog.FieldModel:
|
case usagelog.FieldModel:
|
||||||
return m.Model()
|
return m.Model()
|
||||||
|
case usagelog.FieldUpstreamModel:
|
||||||
|
return m.UpstreamModel()
|
||||||
case usagelog.FieldGroupID:
|
case usagelog.FieldGroupID:
|
||||||
return m.GroupID()
|
return m.GroupID()
|
||||||
case usagelog.FieldSubscriptionID:
|
case usagelog.FieldSubscriptionID:
|
||||||
@@ -20385,6 +20440,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
|
|||||||
return m.OldRequestID(ctx)
|
return m.OldRequestID(ctx)
|
||||||
case usagelog.FieldModel:
|
case usagelog.FieldModel:
|
||||||
return m.OldModel(ctx)
|
return m.OldModel(ctx)
|
||||||
|
case usagelog.FieldUpstreamModel:
|
||||||
|
return m.OldUpstreamModel(ctx)
|
||||||
case usagelog.FieldGroupID:
|
case usagelog.FieldGroupID:
|
||||||
return m.OldGroupID(ctx)
|
return m.OldGroupID(ctx)
|
||||||
case usagelog.FieldSubscriptionID:
|
case usagelog.FieldSubscriptionID:
|
||||||
@@ -20483,6 +20540,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.SetModel(v)
|
m.SetModel(v)
|
||||||
return nil
|
return nil
|
||||||
|
case usagelog.FieldUpstreamModel:
|
||||||
|
v, ok := value.(string)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetUpstreamModel(v)
|
||||||
|
return nil
|
||||||
case usagelog.FieldGroupID:
|
case usagelog.FieldGroupID:
|
||||||
v, ok := value.(int64)
|
v, ok := value.(int64)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -20921,6 +20985,9 @@ func (m *UsageLogMutation) AddField(name string, value ent.Value) error {
|
|||||||
// mutation.
|
// mutation.
|
||||||
func (m *UsageLogMutation) ClearedFields() []string {
|
func (m *UsageLogMutation) ClearedFields() []string {
|
||||||
var fields []string
|
var fields []string
|
||||||
|
if m.FieldCleared(usagelog.FieldUpstreamModel) {
|
||||||
|
fields = append(fields, usagelog.FieldUpstreamModel)
|
||||||
|
}
|
||||||
if m.FieldCleared(usagelog.FieldGroupID) {
|
if m.FieldCleared(usagelog.FieldGroupID) {
|
||||||
fields = append(fields, usagelog.FieldGroupID)
|
fields = append(fields, usagelog.FieldGroupID)
|
||||||
}
|
}
|
||||||
@@ -20962,6 +21029,9 @@ func (m *UsageLogMutation) FieldCleared(name string) bool {
|
|||||||
// error if the field is not defined in the schema.
|
// error if the field is not defined in the schema.
|
||||||
func (m *UsageLogMutation) ClearField(name string) error {
|
func (m *UsageLogMutation) ClearField(name string) error {
|
||||||
switch name {
|
switch name {
|
||||||
|
case usagelog.FieldUpstreamModel:
|
||||||
|
m.ClearUpstreamModel()
|
||||||
|
return nil
|
||||||
case usagelog.FieldGroupID:
|
case usagelog.FieldGroupID:
|
||||||
m.ClearGroupID()
|
m.ClearGroupID()
|
||||||
return nil
|
return nil
|
||||||
@@ -21012,6 +21082,9 @@ func (m *UsageLogMutation) ResetField(name string) error {
|
|||||||
case usagelog.FieldModel:
|
case usagelog.FieldModel:
|
||||||
m.ResetModel()
|
m.ResetModel()
|
||||||
return nil
|
return nil
|
||||||
|
case usagelog.FieldUpstreamModel:
|
||||||
|
m.ResetUpstreamModel()
|
||||||
|
return nil
|
||||||
case usagelog.FieldGroupID:
|
case usagelog.FieldGroupID:
|
||||||
m.ResetGroupID()
|
m.ResetGroupID()
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -821,92 +821,96 @@ func init() {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
// usagelogDescUpstreamModel is the schema descriptor for upstream_model field.
|
||||||
|
usagelogDescUpstreamModel := usagelogFields[5].Descriptor()
|
||||||
|
// usagelog.UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
|
||||||
|
usagelog.UpstreamModelValidator = usagelogDescUpstreamModel.Validators[0].(func(string) error)
|
||||||
// usagelogDescInputTokens is the schema descriptor for input_tokens field.
|
// usagelogDescInputTokens is the schema descriptor for input_tokens field.
|
||||||
usagelogDescInputTokens := usagelogFields[7].Descriptor()
|
usagelogDescInputTokens := usagelogFields[8].Descriptor()
|
||||||
// usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field.
|
// usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field.
|
||||||
usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int)
|
usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int)
|
||||||
// usagelogDescOutputTokens is the schema descriptor for output_tokens field.
|
// usagelogDescOutputTokens is the schema descriptor for output_tokens field.
|
||||||
usagelogDescOutputTokens := usagelogFields[8].Descriptor()
|
usagelogDescOutputTokens := usagelogFields[9].Descriptor()
|
||||||
// usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field.
|
// usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field.
|
||||||
usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int)
|
usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int)
|
||||||
// usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field.
|
// usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field.
|
||||||
usagelogDescCacheCreationTokens := usagelogFields[9].Descriptor()
|
usagelogDescCacheCreationTokens := usagelogFields[10].Descriptor()
|
||||||
// usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field.
|
// usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field.
|
||||||
usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int)
|
usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int)
|
||||||
// usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field.
|
// usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field.
|
||||||
usagelogDescCacheReadTokens := usagelogFields[10].Descriptor()
|
usagelogDescCacheReadTokens := usagelogFields[11].Descriptor()
|
||||||
// usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field.
|
// usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field.
|
||||||
usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int)
|
usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int)
|
||||||
// usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field.
|
// usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field.
|
||||||
usagelogDescCacheCreation5mTokens := usagelogFields[11].Descriptor()
|
usagelogDescCacheCreation5mTokens := usagelogFields[12].Descriptor()
|
||||||
// usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field.
|
// usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field.
|
||||||
usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int)
|
usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int)
|
||||||
// usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field.
|
// usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field.
|
||||||
usagelogDescCacheCreation1hTokens := usagelogFields[12].Descriptor()
|
usagelogDescCacheCreation1hTokens := usagelogFields[13].Descriptor()
|
||||||
// usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field.
|
// usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field.
|
||||||
usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int)
|
usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int)
|
||||||
// usagelogDescInputCost is the schema descriptor for input_cost field.
|
// usagelogDescInputCost is the schema descriptor for input_cost field.
|
||||||
usagelogDescInputCost := usagelogFields[13].Descriptor()
|
usagelogDescInputCost := usagelogFields[14].Descriptor()
|
||||||
// usagelog.DefaultInputCost holds the default value on creation for the input_cost field.
|
// usagelog.DefaultInputCost holds the default value on creation for the input_cost field.
|
||||||
usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64)
|
usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64)
|
||||||
// usagelogDescOutputCost is the schema descriptor for output_cost field.
|
// usagelogDescOutputCost is the schema descriptor for output_cost field.
|
||||||
usagelogDescOutputCost := usagelogFields[14].Descriptor()
|
usagelogDescOutputCost := usagelogFields[15].Descriptor()
|
||||||
// usagelog.DefaultOutputCost holds the default value on creation for the output_cost field.
|
// usagelog.DefaultOutputCost holds the default value on creation for the output_cost field.
|
||||||
usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64)
|
usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64)
|
||||||
// usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field.
|
// usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field.
|
||||||
usagelogDescCacheCreationCost := usagelogFields[15].Descriptor()
|
usagelogDescCacheCreationCost := usagelogFields[16].Descriptor()
|
||||||
// usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field.
|
// usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field.
|
||||||
usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64)
|
usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64)
|
||||||
// usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field.
|
// usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field.
|
||||||
usagelogDescCacheReadCost := usagelogFields[16].Descriptor()
|
usagelogDescCacheReadCost := usagelogFields[17].Descriptor()
|
||||||
// usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field.
|
// usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field.
|
||||||
usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64)
|
usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64)
|
||||||
// usagelogDescTotalCost is the schema descriptor for total_cost field.
|
// usagelogDescTotalCost is the schema descriptor for total_cost field.
|
||||||
usagelogDescTotalCost := usagelogFields[17].Descriptor()
|
usagelogDescTotalCost := usagelogFields[18].Descriptor()
|
||||||
// usagelog.DefaultTotalCost holds the default value on creation for the total_cost field.
|
// usagelog.DefaultTotalCost holds the default value on creation for the total_cost field.
|
||||||
usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64)
|
usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64)
|
||||||
// usagelogDescActualCost is the schema descriptor for actual_cost field.
|
// usagelogDescActualCost is the schema descriptor for actual_cost field.
|
||||||
usagelogDescActualCost := usagelogFields[18].Descriptor()
|
usagelogDescActualCost := usagelogFields[19].Descriptor()
|
||||||
// usagelog.DefaultActualCost holds the default value on creation for the actual_cost field.
|
// usagelog.DefaultActualCost holds the default value on creation for the actual_cost field.
|
||||||
usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64)
|
usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64)
|
||||||
// usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field.
|
// usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field.
|
||||||
usagelogDescRateMultiplier := usagelogFields[19].Descriptor()
|
usagelogDescRateMultiplier := usagelogFields[20].Descriptor()
|
||||||
// usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
|
// usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
|
||||||
usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64)
|
usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64)
|
||||||
// usagelogDescBillingType is the schema descriptor for billing_type field.
|
// usagelogDescBillingType is the schema descriptor for billing_type field.
|
||||||
usagelogDescBillingType := usagelogFields[21].Descriptor()
|
usagelogDescBillingType := usagelogFields[22].Descriptor()
|
||||||
// usagelog.DefaultBillingType holds the default value on creation for the billing_type field.
|
// usagelog.DefaultBillingType holds the default value on creation for the billing_type field.
|
||||||
usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8)
|
usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8)
|
||||||
// usagelogDescStream is the schema descriptor for stream field.
|
// usagelogDescStream is the schema descriptor for stream field.
|
||||||
usagelogDescStream := usagelogFields[22].Descriptor()
|
usagelogDescStream := usagelogFields[23].Descriptor()
|
||||||
// usagelog.DefaultStream holds the default value on creation for the stream field.
|
// usagelog.DefaultStream holds the default value on creation for the stream field.
|
||||||
usagelog.DefaultStream = usagelogDescStream.Default.(bool)
|
usagelog.DefaultStream = usagelogDescStream.Default.(bool)
|
||||||
// usagelogDescUserAgent is the schema descriptor for user_agent field.
|
// usagelogDescUserAgent is the schema descriptor for user_agent field.
|
||||||
usagelogDescUserAgent := usagelogFields[25].Descriptor()
|
usagelogDescUserAgent := usagelogFields[26].Descriptor()
|
||||||
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
|
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
|
||||||
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
|
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
|
||||||
// usagelogDescIPAddress is the schema descriptor for ip_address field.
|
// usagelogDescIPAddress is the schema descriptor for ip_address field.
|
||||||
usagelogDescIPAddress := usagelogFields[26].Descriptor()
|
usagelogDescIPAddress := usagelogFields[27].Descriptor()
|
||||||
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
|
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
|
||||||
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
|
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
|
||||||
// usagelogDescImageCount is the schema descriptor for image_count field.
|
// usagelogDescImageCount is the schema descriptor for image_count field.
|
||||||
usagelogDescImageCount := usagelogFields[27].Descriptor()
|
usagelogDescImageCount := usagelogFields[28].Descriptor()
|
||||||
// usagelog.DefaultImageCount holds the default value on creation for the image_count field.
|
// usagelog.DefaultImageCount holds the default value on creation for the image_count field.
|
||||||
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
|
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
|
||||||
// usagelogDescImageSize is the schema descriptor for image_size field.
|
// usagelogDescImageSize is the schema descriptor for image_size field.
|
||||||
usagelogDescImageSize := usagelogFields[28].Descriptor()
|
usagelogDescImageSize := usagelogFields[29].Descriptor()
|
||||||
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
||||||
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
|
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
|
||||||
// usagelogDescMediaType is the schema descriptor for media_type field.
|
// usagelogDescMediaType is the schema descriptor for media_type field.
|
||||||
usagelogDescMediaType := usagelogFields[29].Descriptor()
|
usagelogDescMediaType := usagelogFields[30].Descriptor()
|
||||||
// usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
|
// usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
|
||||||
usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
|
usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
|
||||||
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
|
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
|
||||||
usagelogDescCacheTTLOverridden := usagelogFields[30].Descriptor()
|
usagelogDescCacheTTLOverridden := usagelogFields[31].Descriptor()
|
||||||
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
|
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
|
||||||
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
|
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
|
||||||
// usagelogDescCreatedAt is the schema descriptor for created_at field.
|
// usagelogDescCreatedAt is the schema descriptor for created_at field.
|
||||||
usagelogDescCreatedAt := usagelogFields[31].Descriptor()
|
usagelogDescCreatedAt := usagelogFields[32].Descriptor()
|
||||||
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
|
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||||
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
|
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
|
||||||
userMixin := schema.User{}.Mixin()
|
userMixin := schema.User{}.Mixin()
|
||||||
|
|||||||
@@ -41,6 +41,12 @@ func (UsageLog) Fields() []ent.Field {
|
|||||||
field.String("model").
|
field.String("model").
|
||||||
MaxLen(100).
|
MaxLen(100).
|
||||||
NotEmpty(),
|
NotEmpty(),
|
||||||
|
// UpstreamModel stores the actual upstream model name when model mapping
|
||||||
|
// is applied. NULL means no mapping — the requested model was used as-is.
|
||||||
|
field.String("upstream_model").
|
||||||
|
MaxLen(100).
|
||||||
|
Optional().
|
||||||
|
Nillable(),
|
||||||
field.Int64("group_id").
|
field.Int64("group_id").
|
||||||
Optional().
|
Optional().
|
||||||
Nillable(),
|
Nillable(),
|
||||||
|
|||||||
@@ -32,6 +32,8 @@ type UsageLog struct {
|
|||||||
RequestID string `json:"request_id,omitempty"`
|
RequestID string `json:"request_id,omitempty"`
|
||||||
// Model holds the value of the "model" field.
|
// Model holds the value of the "model" field.
|
||||||
Model string `json:"model,omitempty"`
|
Model string `json:"model,omitempty"`
|
||||||
|
// UpstreamModel holds the value of the "upstream_model" field.
|
||||||
|
UpstreamModel *string `json:"upstream_model,omitempty"`
|
||||||
// GroupID holds the value of the "group_id" field.
|
// GroupID holds the value of the "group_id" field.
|
||||||
GroupID *int64 `json:"group_id,omitempty"`
|
GroupID *int64 `json:"group_id,omitempty"`
|
||||||
// SubscriptionID holds the value of the "subscription_id" field.
|
// SubscriptionID holds the value of the "subscription_id" field.
|
||||||
@@ -175,7 +177,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
|
|||||||
values[i] = new(sql.NullFloat64)
|
values[i] = new(sql.NullFloat64)
|
||||||
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
|
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
|
||||||
values[i] = new(sql.NullInt64)
|
values[i] = new(sql.NullInt64)
|
||||||
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
|
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUpstreamModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
|
||||||
values[i] = new(sql.NullString)
|
values[i] = new(sql.NullString)
|
||||||
case usagelog.FieldCreatedAt:
|
case usagelog.FieldCreatedAt:
|
||||||
values[i] = new(sql.NullTime)
|
values[i] = new(sql.NullTime)
|
||||||
@@ -230,6 +232,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
|
|||||||
} else if value.Valid {
|
} else if value.Valid {
|
||||||
_m.Model = value.String
|
_m.Model = value.String
|
||||||
}
|
}
|
||||||
|
case usagelog.FieldUpstreamModel:
|
||||||
|
if value, ok := values[i].(*sql.NullString); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field upstream_model", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.UpstreamModel = new(string)
|
||||||
|
*_m.UpstreamModel = value.String
|
||||||
|
}
|
||||||
case usagelog.FieldGroupID:
|
case usagelog.FieldGroupID:
|
||||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||||
return fmt.Errorf("unexpected type %T for field group_id", values[i])
|
return fmt.Errorf("unexpected type %T for field group_id", values[i])
|
||||||
@@ -477,6 +486,11 @@ func (_m *UsageLog) String() string {
|
|||||||
builder.WriteString("model=")
|
builder.WriteString("model=")
|
||||||
builder.WriteString(_m.Model)
|
builder.WriteString(_m.Model)
|
||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
|
if v := _m.UpstreamModel; v != nil {
|
||||||
|
builder.WriteString("upstream_model=")
|
||||||
|
builder.WriteString(*v)
|
||||||
|
}
|
||||||
|
builder.WriteString(", ")
|
||||||
if v := _m.GroupID; v != nil {
|
if v := _m.GroupID; v != nil {
|
||||||
builder.WriteString("group_id=")
|
builder.WriteString("group_id=")
|
||||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ const (
|
|||||||
FieldRequestID = "request_id"
|
FieldRequestID = "request_id"
|
||||||
// FieldModel holds the string denoting the model field in the database.
|
// FieldModel holds the string denoting the model field in the database.
|
||||||
FieldModel = "model"
|
FieldModel = "model"
|
||||||
|
// FieldUpstreamModel holds the string denoting the upstream_model field in the database.
|
||||||
|
FieldUpstreamModel = "upstream_model"
|
||||||
// FieldGroupID holds the string denoting the group_id field in the database.
|
// FieldGroupID holds the string denoting the group_id field in the database.
|
||||||
FieldGroupID = "group_id"
|
FieldGroupID = "group_id"
|
||||||
// FieldSubscriptionID holds the string denoting the subscription_id field in the database.
|
// FieldSubscriptionID holds the string denoting the subscription_id field in the database.
|
||||||
@@ -135,6 +137,7 @@ var Columns = []string{
|
|||||||
FieldAccountID,
|
FieldAccountID,
|
||||||
FieldRequestID,
|
FieldRequestID,
|
||||||
FieldModel,
|
FieldModel,
|
||||||
|
FieldUpstreamModel,
|
||||||
FieldGroupID,
|
FieldGroupID,
|
||||||
FieldSubscriptionID,
|
FieldSubscriptionID,
|
||||||
FieldInputTokens,
|
FieldInputTokens,
|
||||||
@@ -179,6 +182,8 @@ var (
|
|||||||
RequestIDValidator func(string) error
|
RequestIDValidator func(string) error
|
||||||
// ModelValidator is a validator for the "model" field. It is called by the builders before save.
|
// ModelValidator is a validator for the "model" field. It is called by the builders before save.
|
||||||
ModelValidator func(string) error
|
ModelValidator func(string) error
|
||||||
|
// UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
|
||||||
|
UpstreamModelValidator func(string) error
|
||||||
// DefaultInputTokens holds the default value on creation for the "input_tokens" field.
|
// DefaultInputTokens holds the default value on creation for the "input_tokens" field.
|
||||||
DefaultInputTokens int
|
DefaultInputTokens int
|
||||||
// DefaultOutputTokens holds the default value on creation for the "output_tokens" field.
|
// DefaultOutputTokens holds the default value on creation for the "output_tokens" field.
|
||||||
@@ -258,6 +263,11 @@ func ByModel(opts ...sql.OrderTermOption) OrderOption {
|
|||||||
return sql.OrderByField(FieldModel, opts...).ToFunc()
|
return sql.OrderByField(FieldModel, opts...).ToFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ByUpstreamModel orders the results by the upstream_model field.
|
||||||
|
func ByUpstreamModel(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldUpstreamModel, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
// ByGroupID orders the results by the group_id field.
|
// ByGroupID orders the results by the group_id field.
|
||||||
func ByGroupID(opts ...sql.OrderTermOption) OrderOption {
|
func ByGroupID(opts ...sql.OrderTermOption) OrderOption {
|
||||||
return sql.OrderByField(FieldGroupID, opts...).ToFunc()
|
return sql.OrderByField(FieldGroupID, opts...).ToFunc()
|
||||||
|
|||||||
@@ -80,6 +80,11 @@ func Model(v string) predicate.UsageLog {
|
|||||||
return predicate.UsageLog(sql.FieldEQ(FieldModel, v))
|
return predicate.UsageLog(sql.FieldEQ(FieldModel, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpstreamModel applies equality check predicate on the "upstream_model" field. It's identical to UpstreamModelEQ.
|
||||||
|
func UpstreamModel(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
// GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ.
|
// GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ.
|
||||||
func GroupID(v int64) predicate.UsageLog {
|
func GroupID(v int64) predicate.UsageLog {
|
||||||
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))
|
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))
|
||||||
@@ -405,6 +410,81 @@ func ModelContainsFold(v string) predicate.UsageLog {
|
|||||||
return predicate.UsageLog(sql.FieldContainsFold(FieldModel, v))
|
return predicate.UsageLog(sql.FieldContainsFold(FieldModel, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpstreamModelEQ applies the EQ predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelEQ(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelNEQ applies the NEQ predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelNEQ(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldNEQ(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelIn applies the In predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelIn(vs ...string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldIn(FieldUpstreamModel, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelNotIn applies the NotIn predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelNotIn(vs ...string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldNotIn(FieldUpstreamModel, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelGT applies the GT predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelGT(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldGT(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelGTE applies the GTE predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelGTE(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldGTE(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelLT applies the LT predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelLT(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldLT(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelLTE applies the LTE predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelLTE(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldLTE(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelContains applies the Contains predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelContains(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldContains(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelHasPrefix applies the HasPrefix predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelHasPrefix(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldHasPrefix(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelHasSuffix applies the HasSuffix predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelHasSuffix(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldHasSuffix(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelIsNil applies the IsNil predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelIsNil() predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldIsNull(FieldUpstreamModel))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelNotNil applies the NotNil predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelNotNil() predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldNotNull(FieldUpstreamModel))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelEqualFold applies the EqualFold predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelEqualFold(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldEqualFold(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelContainsFold applies the ContainsFold predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelContainsFold(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldContainsFold(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
// GroupIDEQ applies the EQ predicate on the "group_id" field.
|
// GroupIDEQ applies the EQ predicate on the "group_id" field.
|
||||||
func GroupIDEQ(v int64) predicate.UsageLog {
|
func GroupIDEQ(v int64) predicate.UsageLog {
|
||||||
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))
|
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))
|
||||||
|
|||||||
@@ -57,6 +57,20 @@ func (_c *UsageLogCreate) SetModel(v string) *UsageLogCreate {
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetUpstreamModel sets the "upstream_model" field.
|
||||||
|
func (_c *UsageLogCreate) SetUpstreamModel(v string) *UsageLogCreate {
|
||||||
|
_c.mutation.SetUpstreamModel(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil.
|
||||||
|
func (_c *UsageLogCreate) SetNillableUpstreamModel(v *string) *UsageLogCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetUpstreamModel(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
// SetGroupID sets the "group_id" field.
|
// SetGroupID sets the "group_id" field.
|
||||||
func (_c *UsageLogCreate) SetGroupID(v int64) *UsageLogCreate {
|
func (_c *UsageLogCreate) SetGroupID(v int64) *UsageLogCreate {
|
||||||
_c.mutation.SetGroupID(v)
|
_c.mutation.SetGroupID(v)
|
||||||
@@ -596,6 +610,11 @@ func (_c *UsageLogCreate) check() error {
|
|||||||
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
|
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if v, ok := _c.mutation.UpstreamModel(); ok {
|
||||||
|
if err := usagelog.UpstreamModelValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
if _, ok := _c.mutation.InputTokens(); !ok {
|
if _, ok := _c.mutation.InputTokens(); !ok {
|
||||||
return &ValidationError{Name: "input_tokens", err: errors.New(`ent: missing required field "UsageLog.input_tokens"`)}
|
return &ValidationError{Name: "input_tokens", err: errors.New(`ent: missing required field "UsageLog.input_tokens"`)}
|
||||||
}
|
}
|
||||||
@@ -714,6 +733,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
|
|||||||
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
|
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
|
||||||
_node.Model = value
|
_node.Model = value
|
||||||
}
|
}
|
||||||
|
if value, ok := _c.mutation.UpstreamModel(); ok {
|
||||||
|
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
|
||||||
|
_node.UpstreamModel = &value
|
||||||
|
}
|
||||||
if value, ok := _c.mutation.InputTokens(); ok {
|
if value, ok := _c.mutation.InputTokens(); ok {
|
||||||
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
|
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
|
||||||
_node.InputTokens = value
|
_node.InputTokens = value
|
||||||
@@ -1011,6 +1034,24 @@ func (u *UsageLogUpsert) UpdateModel() *UsageLogUpsert {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetUpstreamModel sets the "upstream_model" field.
|
||||||
|
func (u *UsageLogUpsert) SetUpstreamModel(v string) *UsageLogUpsert {
|
||||||
|
u.Set(usagelog.FieldUpstreamModel, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create.
|
||||||
|
func (u *UsageLogUpsert) UpdateUpstreamModel() *UsageLogUpsert {
|
||||||
|
u.SetExcluded(usagelog.FieldUpstreamModel)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearUpstreamModel clears the value of the "upstream_model" field.
|
||||||
|
func (u *UsageLogUpsert) ClearUpstreamModel() *UsageLogUpsert {
|
||||||
|
u.SetNull(usagelog.FieldUpstreamModel)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
// SetGroupID sets the "group_id" field.
|
// SetGroupID sets the "group_id" field.
|
||||||
func (u *UsageLogUpsert) SetGroupID(v int64) *UsageLogUpsert {
|
func (u *UsageLogUpsert) SetGroupID(v int64) *UsageLogUpsert {
|
||||||
u.Set(usagelog.FieldGroupID, v)
|
u.Set(usagelog.FieldGroupID, v)
|
||||||
@@ -1600,6 +1641,27 @@ func (u *UsageLogUpsertOne) UpdateModel() *UsageLogUpsertOne {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetUpstreamModel sets the "upstream_model" field.
|
||||||
|
func (u *UsageLogUpsertOne) SetUpstreamModel(v string) *UsageLogUpsertOne {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.SetUpstreamModel(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create.
|
||||||
|
func (u *UsageLogUpsertOne) UpdateUpstreamModel() *UsageLogUpsertOne {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.UpdateUpstreamModel()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearUpstreamModel clears the value of the "upstream_model" field.
|
||||||
|
func (u *UsageLogUpsertOne) ClearUpstreamModel() *UsageLogUpsertOne {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.ClearUpstreamModel()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// SetGroupID sets the "group_id" field.
|
// SetGroupID sets the "group_id" field.
|
||||||
func (u *UsageLogUpsertOne) SetGroupID(v int64) *UsageLogUpsertOne {
|
func (u *UsageLogUpsertOne) SetGroupID(v int64) *UsageLogUpsertOne {
|
||||||
return u.Update(func(s *UsageLogUpsert) {
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
@@ -2434,6 +2496,27 @@ func (u *UsageLogUpsertBulk) UpdateModel() *UsageLogUpsertBulk {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetUpstreamModel sets the "upstream_model" field.
|
||||||
|
func (u *UsageLogUpsertBulk) SetUpstreamModel(v string) *UsageLogUpsertBulk {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.SetUpstreamModel(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create.
|
||||||
|
func (u *UsageLogUpsertBulk) UpdateUpstreamModel() *UsageLogUpsertBulk {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.UpdateUpstreamModel()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearUpstreamModel clears the value of the "upstream_model" field.
|
||||||
|
func (u *UsageLogUpsertBulk) ClearUpstreamModel() *UsageLogUpsertBulk {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.ClearUpstreamModel()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// SetGroupID sets the "group_id" field.
|
// SetGroupID sets the "group_id" field.
|
||||||
func (u *UsageLogUpsertBulk) SetGroupID(v int64) *UsageLogUpsertBulk {
|
func (u *UsageLogUpsertBulk) SetGroupID(v int64) *UsageLogUpsertBulk {
|
||||||
return u.Update(func(s *UsageLogUpsert) {
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
|||||||
@@ -102,6 +102,26 @@ func (_u *UsageLogUpdate) SetNillableModel(v *string) *UsageLogUpdate {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetUpstreamModel sets the "upstream_model" field.
|
||||||
|
func (_u *UsageLogUpdate) SetUpstreamModel(v string) *UsageLogUpdate {
|
||||||
|
_u.mutation.SetUpstreamModel(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil.
|
||||||
|
func (_u *UsageLogUpdate) SetNillableUpstreamModel(v *string) *UsageLogUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetUpstreamModel(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearUpstreamModel clears the value of the "upstream_model" field.
|
||||||
|
func (_u *UsageLogUpdate) ClearUpstreamModel() *UsageLogUpdate {
|
||||||
|
_u.mutation.ClearUpstreamModel()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetGroupID sets the "group_id" field.
|
// SetGroupID sets the "group_id" field.
|
||||||
func (_u *UsageLogUpdate) SetGroupID(v int64) *UsageLogUpdate {
|
func (_u *UsageLogUpdate) SetGroupID(v int64) *UsageLogUpdate {
|
||||||
_u.mutation.SetGroupID(v)
|
_u.mutation.SetGroupID(v)
|
||||||
@@ -745,6 +765,11 @@ func (_u *UsageLogUpdate) check() error {
|
|||||||
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
|
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if v, ok := _u.mutation.UpstreamModel(); ok {
|
||||||
|
if err := usagelog.UpstreamModelValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
if v, ok := _u.mutation.UserAgent(); ok {
|
if v, ok := _u.mutation.UserAgent(); ok {
|
||||||
if err := usagelog.UserAgentValidator(v); err != nil {
|
if err := usagelog.UserAgentValidator(v); err != nil {
|
||||||
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
||||||
@@ -795,6 +820,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
if value, ok := _u.mutation.Model(); ok {
|
if value, ok := _u.mutation.Model(); ok {
|
||||||
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
|
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.UpstreamModel(); ok {
|
||||||
|
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.UpstreamModelCleared() {
|
||||||
|
_spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString)
|
||||||
|
}
|
||||||
if value, ok := _u.mutation.InputTokens(); ok {
|
if value, ok := _u.mutation.InputTokens(); ok {
|
||||||
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
|
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
|
||||||
}
|
}
|
||||||
@@ -1177,6 +1208,26 @@ func (_u *UsageLogUpdateOne) SetNillableModel(v *string) *UsageLogUpdateOne {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetUpstreamModel sets the "upstream_model" field.
|
||||||
|
func (_u *UsageLogUpdateOne) SetUpstreamModel(v string) *UsageLogUpdateOne {
|
||||||
|
_u.mutation.SetUpstreamModel(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil.
|
||||||
|
func (_u *UsageLogUpdateOne) SetNillableUpstreamModel(v *string) *UsageLogUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetUpstreamModel(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearUpstreamModel clears the value of the "upstream_model" field.
|
||||||
|
func (_u *UsageLogUpdateOne) ClearUpstreamModel() *UsageLogUpdateOne {
|
||||||
|
_u.mutation.ClearUpstreamModel()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetGroupID sets the "group_id" field.
|
// SetGroupID sets the "group_id" field.
|
||||||
func (_u *UsageLogUpdateOne) SetGroupID(v int64) *UsageLogUpdateOne {
|
func (_u *UsageLogUpdateOne) SetGroupID(v int64) *UsageLogUpdateOne {
|
||||||
_u.mutation.SetGroupID(v)
|
_u.mutation.SetGroupID(v)
|
||||||
@@ -1833,6 +1884,11 @@ func (_u *UsageLogUpdateOne) check() error {
|
|||||||
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
|
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if v, ok := _u.mutation.UpstreamModel(); ok {
|
||||||
|
if err := usagelog.UpstreamModelValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
if v, ok := _u.mutation.UserAgent(); ok {
|
if v, ok := _u.mutation.UserAgent(); ok {
|
||||||
if err := usagelog.UserAgentValidator(v); err != nil {
|
if err := usagelog.UserAgentValidator(v); err != nil {
|
||||||
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
||||||
@@ -1900,6 +1956,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
|
|||||||
if value, ok := _u.mutation.Model(); ok {
|
if value, ok := _u.mutation.Model(); ok {
|
||||||
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
|
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.UpstreamModel(); ok {
|
||||||
|
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.UpstreamModelCleared() {
|
||||||
|
_spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString)
|
||||||
|
}
|
||||||
if value, ok := _u.mutation.InputTokens(); ok {
|
if value, ok := _u.mutation.InputTokens(); ok {
|
||||||
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
|
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,8 +22,6 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo
|
|||||||
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
||||||
github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=
|
github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=
|
||||||
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
|
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
|
||||||
github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls=
|
|
||||||
github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4=
|
|
||||||
github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA=
|
github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA=
|
||||||
github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
|
github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
|
||||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
|
||||||
@@ -60,8 +58,6 @@ github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 h1:edCcNp9eGIUDUCrzoCu1jWA
|
|||||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15/go.mod h1:lyRQKED9xWfgkYC/wmmYfv7iVIM68Z5OQ88ZdcV1QbU=
|
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15/go.mod h1:lyRQKED9xWfgkYC/wmmYfv7iVIM68Z5OQ88ZdcV1QbU=
|
||||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb83BbyggcUBVksN7c=
|
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb83BbyggcUBVksN7c=
|
||||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
|
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
|
||||||
github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0=
|
|
||||||
github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
|
|
||||||
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
|
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
|
||||||
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||||
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
|
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
|
||||||
@@ -98,10 +94,6 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL
|
|||||||
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
||||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
||||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
||||||
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
|
|
||||||
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
|
|
||||||
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
|
|
||||||
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
|
|
||||||
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
|
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
|
||||||
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
||||||
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
|
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
|
||||||
@@ -238,8 +230,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
|
|||||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
|
|
||||||
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
|
||||||
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
||||||
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||||
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
||||||
@@ -273,8 +263,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
|||||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
|
||||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
|
||||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||||
@@ -326,8 +314,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
|||||||
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
||||||
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
||||||
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||||
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
|
|
||||||
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
|
|
||||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||||
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
|
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
|
||||||
|
|||||||
@@ -82,8 +82,8 @@ var DefaultAntigravityModelMapping = map[string]string{
|
|||||||
"claude-opus-4-5-20251101": "claude-opus-4-6-thinking", // 迁移旧模型
|
"claude-opus-4-5-20251101": "claude-opus-4-6-thinking", // 迁移旧模型
|
||||||
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
|
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
|
||||||
// Claude Haiku → Sonnet(无 Haiku 支持)
|
// Claude Haiku → Sonnet(无 Haiku 支持)
|
||||||
"claude-haiku-4-5": "claude-sonnet-4-5",
|
"claude-haiku-4-5": "claude-sonnet-4-6",
|
||||||
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
|
"claude-haiku-4-5-20251001": "claude-sonnet-4-6",
|
||||||
// Gemini 2.5 白名单
|
// Gemini 2.5 白名单
|
||||||
"gemini-2.5-flash": "gemini-2.5-flash",
|
"gemini-2.5-flash": "gemini-2.5-flash",
|
||||||
"gemini-2.5-flash-image": "gemini-2.5-flash-image",
|
"gemini-2.5-flash-image": "gemini-2.5-flash-image",
|
||||||
|
|||||||
@@ -165,6 +165,8 @@ type AccountWithConcurrency struct {
|
|||||||
CurrentRPM *int `json:"current_rpm,omitempty"` // 当前分钟 RPM 计数
|
CurrentRPM *int `json:"current_rpm,omitempty"` // 当前分钟 RPM 计数
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const accountListGroupUngroupedQueryValue = "ungrouped"
|
||||||
|
|
||||||
func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency {
|
func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency {
|
||||||
item := AccountWithConcurrency{
|
item := AccountWithConcurrency{
|
||||||
Account: dto.AccountFromService(account),
|
Account: dto.AccountFromService(account),
|
||||||
@@ -226,7 +228,20 @@ func (h *AccountHandler) List(c *gin.Context) {
|
|||||||
|
|
||||||
var groupID int64
|
var groupID int64
|
||||||
if groupIDStr := c.Query("group"); groupIDStr != "" {
|
if groupIDStr := c.Query("group"); groupIDStr != "" {
|
||||||
groupID, _ = strconv.ParseInt(groupIDStr, 10, 64)
|
if groupIDStr == accountListGroupUngroupedQueryValue {
|
||||||
|
groupID = service.AccountListGroupUngrouped
|
||||||
|
} else {
|
||||||
|
parsedGroupID, parseErr := strconv.ParseInt(groupIDStr, 10, 64)
|
||||||
|
if parseErr != nil {
|
||||||
|
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_GROUP_FILTER", "invalid group filter"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if parsedGroupID < 0 {
|
||||||
|
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_GROUP_FILTER", "invalid group filter"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
groupID = parsedGroupID
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID)
|
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID)
|
||||||
@@ -1496,7 +1511,7 @@ func (h *OAuthHandler) SetupTokenCookieAuth(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetUsage handles getting account usage information
|
// GetUsage handles getting account usage information
|
||||||
// GET /api/v1/admin/accounts/:id/usage
|
// GET /api/v1/admin/accounts/:id/usage?source=passive|active
|
||||||
func (h *AccountHandler) GetUsage(c *gin.Context) {
|
func (h *AccountHandler) GetUsage(c *gin.Context) {
|
||||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1504,7 +1519,14 @@ func (h *AccountHandler) GetUsage(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
usage, err := h.accountUsageService.GetUsage(c.Request.Context(), accountID)
|
source := c.DefaultQuery("source", "active")
|
||||||
|
|
||||||
|
var usage *service.UsageInfo
|
||||||
|
if source == "passive" {
|
||||||
|
usage, err = h.accountUsageService.GetPassiveUsage(c.Request.Context(), accountID)
|
||||||
|
} else {
|
||||||
|
usage, err = h.accountUsageService.GetUsage(c.Request.Context(), accountID)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
|
|||||||
adminSvc := newStubAdminService()
|
adminSvc := newStubAdminService()
|
||||||
|
|
||||||
userHandler := NewUserHandler(adminSvc, nil)
|
userHandler := NewUserHandler(adminSvc, nil)
|
||||||
groupHandler := NewGroupHandler(adminSvc)
|
groupHandler := NewGroupHandler(adminSvc, nil, nil)
|
||||||
proxyHandler := NewProxyHandler(adminSvc)
|
proxyHandler := NewProxyHandler(adminSvc)
|
||||||
redeemHandler := NewRedeemHandler(adminSvc, nil)
|
redeemHandler := NewRedeemHandler(adminSvc, nil)
|
||||||
|
|
||||||
|
|||||||
@@ -98,12 +98,12 @@ func (h *BackupHandler) CreateBackup(c *gin.Context) {
|
|||||||
expireDays = *req.ExpireDays
|
expireDays = *req.ExpireDays
|
||||||
}
|
}
|
||||||
|
|
||||||
record, err := h.backupService.CreateBackup(c.Request.Context(), "manual", expireDays)
|
record, err := h.backupService.StartBackup(c.Request.Context(), "manual", expireDays)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
response.Success(c, record)
|
response.Accepted(c, record)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *BackupHandler) ListBackups(c *gin.Context) {
|
func (h *BackupHandler) ListBackups(c *gin.Context) {
|
||||||
@@ -196,9 +196,10 @@ func (h *BackupHandler) RestoreBackup(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.backupService.RestoreBackup(c.Request.Context(), backupID); err != nil {
|
record, err := h.backupService.StartRestore(c.Request.Context(), backupID)
|
||||||
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
response.Success(c, gin.H{"restored": true})
|
response.Accepted(c, record)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -272,6 +273,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
|||||||
|
|
||||||
// Parse optional filter params
|
// Parse optional filter params
|
||||||
var userID, apiKeyID, accountID, groupID int64
|
var userID, apiKeyID, accountID, groupID int64
|
||||||
|
modelSource := usagestats.ModelSourceRequested
|
||||||
var requestType *int16
|
var requestType *int16
|
||||||
var stream *bool
|
var stream *bool
|
||||||
var billingType *int8
|
var billingType *int8
|
||||||
@@ -296,6 +298,13 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
|||||||
groupID = id
|
groupID = id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if rawModelSource := strings.TrimSpace(c.Query("model_source")); rawModelSource != "" {
|
||||||
|
if !usagestats.IsValidModelSource(rawModelSource) {
|
||||||
|
response.BadRequest(c, "Invalid model_source, use requested/upstream/mapping")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
modelSource = rawModelSource
|
||||||
|
}
|
||||||
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
||||||
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -322,7 +331,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, modelSource, requestType, stream, billingType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get model statistics")
|
response.Error(c, 500, "Failed to get model statistics")
|
||||||
return
|
return
|
||||||
@@ -604,3 +613,47 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
|
|||||||
c.Header("X-Snapshot-Cache", "miss")
|
c.Header("X-Snapshot-Cache", "miss")
|
||||||
response.Success(c, payload)
|
response.Success(c, payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUserBreakdown handles getting per-user usage breakdown within a dimension.
|
||||||
|
// GET /api/v1/admin/dashboard/user-breakdown
|
||||||
|
// Query params: start_date, end_date, group_id, model, endpoint, endpoint_type, limit
|
||||||
|
func (h *DashboardHandler) GetUserBreakdown(c *gin.Context) {
|
||||||
|
startTime, endTime := parseTimeRange(c)
|
||||||
|
|
||||||
|
dim := usagestats.UserBreakdownDimension{}
|
||||||
|
if v := c.Query("group_id"); v != "" {
|
||||||
|
if id, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||||
|
dim.GroupID = id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dim.Model = c.Query("model")
|
||||||
|
rawModelSource := strings.TrimSpace(c.DefaultQuery("model_source", usagestats.ModelSourceRequested))
|
||||||
|
if !usagestats.IsValidModelSource(rawModelSource) {
|
||||||
|
response.BadRequest(c, "Invalid model_source, use requested/upstream/mapping")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
dim.ModelType = rawModelSource
|
||||||
|
dim.Endpoint = c.Query("endpoint")
|
||||||
|
dim.EndpointType = c.DefaultQuery("endpoint_type", "inbound")
|
||||||
|
|
||||||
|
limit := 50
|
||||||
|
if v := c.Query("limit"); v != "" {
|
||||||
|
if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= 200 {
|
||||||
|
limit = n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
stats, err := h.dashboardService.GetUserBreakdownStats(
|
||||||
|
c.Request.Context(), startTime, endTime, dim, limit,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, 500, "Failed to get user breakdown stats")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"users": stats,
|
||||||
|
"start_date": startTime.Format("2006-01-02"),
|
||||||
|
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -149,6 +149,28 @@ func TestDashboardModelStatsInvalidStream(t *testing.T) {
|
|||||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDashboardModelStatsInvalidModelSource(t *testing.T) {
|
||||||
|
repo := &dashboardUsageRepoCapture{}
|
||||||
|
router := newDashboardRequestTypeTestRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?model_source=invalid", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardModelStatsValidModelSource(t *testing.T) {
|
||||||
|
repo := &dashboardUsageRepoCapture{}
|
||||||
|
router := newDashboardRequestTypeTestRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?model_source=upstream", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
func TestDashboardUsersRankingLimitAndCache(t *testing.T) {
|
func TestDashboardUsersRankingLimitAndCache(t *testing.T) {
|
||||||
dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
|
dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
|
||||||
repo := &dashboardUsageRepoCapture{
|
repo := &dashboardUsageRepoCapture{
|
||||||
|
|||||||
@@ -0,0 +1,229 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- mock repo ---
|
||||||
|
|
||||||
|
type userBreakdownRepoCapture struct {
|
||||||
|
service.UsageLogRepository
|
||||||
|
capturedDim usagestats.UserBreakdownDimension
|
||||||
|
capturedLimit int
|
||||||
|
result []usagestats.UserBreakdownItem
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *userBreakdownRepoCapture) GetUserBreakdownStats(
|
||||||
|
_ context.Context, _, _ time.Time,
|
||||||
|
dim usagestats.UserBreakdownDimension, limit int,
|
||||||
|
) ([]usagestats.UserBreakdownItem, error) {
|
||||||
|
r.capturedDim = dim
|
||||||
|
r.capturedLimit = limit
|
||||||
|
if r.result != nil {
|
||||||
|
return r.result, nil
|
||||||
|
}
|
||||||
|
return []usagestats.UserBreakdownItem{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUserBreakdownRouter(repo *userBreakdownRepoCapture) *gin.Engine {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := service.NewDashboardService(repo, nil, nil, nil)
|
||||||
|
h := NewDashboardHandler(svc, nil)
|
||||||
|
router := gin.New()
|
||||||
|
router.GET("/admin/dashboard/user-breakdown", h.GetUserBreakdown)
|
||||||
|
return router
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- tests ---
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_GroupIDFilter(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=42", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, int64(42), repo.capturedDim.GroupID)
|
||||||
|
require.Empty(t, repo.capturedDim.Model)
|
||||||
|
require.Empty(t, repo.capturedDim.Endpoint)
|
||||||
|
require.Equal(t, 50, repo.capturedLimit) // default limit
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_ModelFilter(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=claude-opus-4-6", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, "claude-opus-4-6", repo.capturedDim.Model)
|
||||||
|
require.Equal(t, usagestats.ModelSourceRequested, repo.capturedDim.ModelType)
|
||||||
|
require.Equal(t, int64(0), repo.capturedDim.GroupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_ModelSourceFilter(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=claude-opus-4-6&model_source=upstream", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, usagestats.ModelSourceUpstream, repo.capturedDim.ModelType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_InvalidModelSource(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model_source=foobar", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadRequest, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_EndpointFilter(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&endpoint=/v1/messages&endpoint_type=upstream", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, "/v1/messages", repo.capturedDim.Endpoint)
|
||||||
|
require.Equal(t, "upstream", repo.capturedDim.EndpointType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_DefaultEndpointType(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&endpoint=/chat", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, "inbound", repo.capturedDim.EndpointType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_CustomLimit(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=test&limit=100", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, 100, repo.capturedLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_LimitClamped(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
// limit > 200 should fall back to default 50
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=test&limit=999", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, 50, repo.capturedLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_ResponseFormat(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{
|
||||||
|
result: []usagestats.UserBreakdownItem{
|
||||||
|
{UserID: 1, Email: "alice@test.com", Requests: 100, TotalTokens: 50000, Cost: 1.5, ActualCost: 1.2},
|
||||||
|
{UserID: 2, Email: "bob@test.com", Requests: 50, TotalTokens: 25000, Cost: 0.8, ActualCost: 0.6},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=1", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
var resp struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Data struct {
|
||||||
|
Users []usagestats.UserBreakdownItem `json:"users"`
|
||||||
|
StartDate string `json:"start_date"`
|
||||||
|
EndDate string `json:"end_date"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
err := json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 0, resp.Code)
|
||||||
|
require.Len(t, resp.Data.Users, 2)
|
||||||
|
require.Equal(t, int64(1), resp.Data.Users[0].UserID)
|
||||||
|
require.Equal(t, "alice@test.com", resp.Data.Users[0].Email)
|
||||||
|
require.Equal(t, int64(100), resp.Data.Users[0].Requests)
|
||||||
|
require.InDelta(t, 1.2, resp.Data.Users[0].ActualCost, 0.001)
|
||||||
|
require.Equal(t, "2026-03-01", resp.Data.StartDate)
|
||||||
|
require.Equal(t, "2026-03-16", resp.Data.EndDate)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_EmptyResult(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=999", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
var resp struct {
|
||||||
|
Data struct {
|
||||||
|
Users []usagestats.UserBreakdownItem `json:"users"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
err := json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Empty(t, resp.Data.Users)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_NoFilters(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, int64(0), repo.capturedDim.GroupID)
|
||||||
|
require.Empty(t, repo.capturedDim.Model)
|
||||||
|
require.Empty(t, repo.capturedDim.Endpoint)
|
||||||
|
}
|
||||||
@@ -38,6 +38,7 @@ type dashboardModelGroupCacheKey struct {
|
|||||||
APIKeyID int64 `json:"api_key_id"`
|
APIKeyID int64 `json:"api_key_id"`
|
||||||
AccountID int64 `json:"account_id"`
|
AccountID int64 `json:"account_id"`
|
||||||
GroupID int64 `json:"group_id"`
|
GroupID int64 `json:"group_id"`
|
||||||
|
ModelSource string `json:"model_source,omitempty"`
|
||||||
RequestType *int16 `json:"request_type"`
|
RequestType *int16 `json:"request_type"`
|
||||||
Stream *bool `json:"stream"`
|
Stream *bool `json:"stream"`
|
||||||
BillingType *int8 `json:"billing_type"`
|
BillingType *int8 `json:"billing_type"`
|
||||||
@@ -111,6 +112,7 @@ func (h *DashboardHandler) getModelStatsCached(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
startTime, endTime time.Time,
|
startTime, endTime time.Time,
|
||||||
userID, apiKeyID, accountID, groupID int64,
|
userID, apiKeyID, accountID, groupID int64,
|
||||||
|
modelSource string,
|
||||||
requestType *int16,
|
requestType *int16,
|
||||||
stream *bool,
|
stream *bool,
|
||||||
billingType *int8,
|
billingType *int8,
|
||||||
@@ -122,12 +124,13 @@ func (h *DashboardHandler) getModelStatsCached(
|
|||||||
APIKeyID: apiKeyID,
|
APIKeyID: apiKeyID,
|
||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
GroupID: groupID,
|
GroupID: groupID,
|
||||||
|
ModelSource: usagestats.NormalizeModelSource(modelSource),
|
||||||
RequestType: requestType,
|
RequestType: requestType,
|
||||||
Stream: stream,
|
Stream: stream,
|
||||||
BillingType: billingType,
|
BillingType: billingType,
|
||||||
})
|
})
|
||||||
entry, hit, err := dashboardModelStatsCache.GetOrLoad(key, func() (any, error) {
|
entry, hit, err := dashboardModelStatsCache.GetOrLoad(key, func() (any, error) {
|
||||||
return h.dashboardService.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
return h.dashboardService.GetModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, modelSource)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, hit, err
|
return nil, hit, err
|
||||||
|
|||||||
@@ -200,6 +200,7 @@ func (h *DashboardHandler) buildSnapshotV2Response(
|
|||||||
filters.APIKeyID,
|
filters.APIKeyID,
|
||||||
filters.AccountID,
|
filters.AccountID,
|
||||||
filters.GroupID,
|
filters.GroupID,
|
||||||
|
usagestats.ModelSourceRequested,
|
||||||
filters.RequestType,
|
filters.RequestType,
|
||||||
filters.Stream,
|
filters.Stream,
|
||||||
filters.BillingType,
|
filters.BillingType,
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -16,7 +17,9 @@ import (
|
|||||||
|
|
||||||
// GroupHandler handles admin group management
|
// GroupHandler handles admin group management
|
||||||
type GroupHandler struct {
|
type GroupHandler struct {
|
||||||
adminService service.AdminService
|
adminService service.AdminService
|
||||||
|
dashboardService *service.DashboardService
|
||||||
|
groupCapacityService *service.GroupCapacityService
|
||||||
}
|
}
|
||||||
|
|
||||||
type optionalLimitField struct {
|
type optionalLimitField struct {
|
||||||
@@ -69,9 +72,11 @@ func (f optionalLimitField) ToServiceInput() *float64 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewGroupHandler creates a new admin group handler
|
// NewGroupHandler creates a new admin group handler
|
||||||
func NewGroupHandler(adminService service.AdminService) *GroupHandler {
|
func NewGroupHandler(adminService service.AdminService, dashboardService *service.DashboardService, groupCapacityService *service.GroupCapacityService) *GroupHandler {
|
||||||
return &GroupHandler{
|
return &GroupHandler{
|
||||||
adminService: adminService,
|
adminService: adminService,
|
||||||
|
dashboardService: dashboardService,
|
||||||
|
groupCapacityService: groupCapacityService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -363,6 +368,33 @@ func (h *GroupHandler) GetStats(c *gin.Context) {
|
|||||||
_ = groupID // TODO: implement actual stats
|
_ = groupID // TODO: implement actual stats
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUsageSummary returns today's and cumulative cost for all groups.
|
||||||
|
// GET /api/v1/admin/groups/usage-summary?timezone=Asia/Shanghai
|
||||||
|
func (h *GroupHandler) GetUsageSummary(c *gin.Context) {
|
||||||
|
userTZ := c.Query("timezone")
|
||||||
|
now := timezone.NowInUserLocation(userTZ)
|
||||||
|
todayStart := timezone.StartOfDayInUserLocation(now, userTZ)
|
||||||
|
|
||||||
|
results, err := h.dashboardService.GetGroupUsageSummary(c.Request.Context(), todayStart)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, 500, "Failed to get group usage summary")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, results)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCapacitySummary returns aggregated capacity (concurrency/sessions/RPM) for all active groups.
|
||||||
|
// GET /api/v1/admin/groups/capacity-summary
|
||||||
|
func (h *GroupHandler) GetCapacitySummary(c *gin.Context) {
|
||||||
|
results, err := h.groupCapacityService.GetAllGroupCapacity(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, 500, "Failed to get group capacity summary")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, results)
|
||||||
|
}
|
||||||
|
|
||||||
// GetGroupAPIKeys handles getting API keys in a group
|
// GetGroupAPIKeys handles getting API keys in a group
|
||||||
// GET /api/v1/admin/groups/:id/api-keys
|
// GET /api/v1/admin/groups/:id/api-keys
|
||||||
func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
||||||
|
|||||||
@@ -977,6 +977,58 @@ func (h *SettingHandler) DeleteAdminAPIKey(c *gin.Context) {
|
|||||||
response.Success(c, gin.H{"message": "Admin API key deleted"})
|
response.Success(c, gin.H{"message": "Admin API key deleted"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetOverloadCooldownSettings 获取529过载冷却配置
|
||||||
|
// GET /api/v1/admin/settings/overload-cooldown
|
||||||
|
func (h *SettingHandler) GetOverloadCooldownSettings(c *gin.Context) {
|
||||||
|
settings, err := h.settingService.GetOverloadCooldownSettings(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, dto.OverloadCooldownSettings{
|
||||||
|
Enabled: settings.Enabled,
|
||||||
|
CooldownMinutes: settings.CooldownMinutes,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateOverloadCooldownSettingsRequest 更新529过载冷却配置请求
|
||||||
|
type UpdateOverloadCooldownSettingsRequest struct {
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
CooldownMinutes int `json:"cooldown_minutes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateOverloadCooldownSettings 更新529过载冷却配置
|
||||||
|
// PUT /api/v1/admin/settings/overload-cooldown
|
||||||
|
func (h *SettingHandler) UpdateOverloadCooldownSettings(c *gin.Context) {
|
||||||
|
var req UpdateOverloadCooldownSettingsRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
settings := &service.OverloadCooldownSettings{
|
||||||
|
Enabled: req.Enabled,
|
||||||
|
CooldownMinutes: req.CooldownMinutes,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.settingService.SetOverloadCooldownSettings(c.Request.Context(), settings); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedSettings, err := h.settingService.GetOverloadCooldownSettings(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, dto.OverloadCooldownSettings{
|
||||||
|
Enabled: updatedSettings.Enabled,
|
||||||
|
CooldownMinutes: updatedSettings.CooldownMinutes,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// GetStreamTimeoutSettings 获取流超时处理配置
|
// GetStreamTimeoutSettings 获取流超时处理配置
|
||||||
// GET /api/v1/admin/settings/stream-timeout
|
// GET /api/v1/admin/settings/stream-timeout
|
||||||
func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {
|
func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {
|
||||||
|
|||||||
@@ -77,12 +77,13 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
status := c.Query("status")
|
status := c.Query("status")
|
||||||
|
platform := c.Query("platform")
|
||||||
|
|
||||||
// Parse sorting parameters
|
// Parse sorting parameters
|
||||||
sortBy := c.DefaultQuery("sort_by", "created_at")
|
sortBy := c.DefaultQuery("sort_by", "created_at")
|
||||||
sortOrder := c.DefaultQuery("sort_order", "desc")
|
sortOrder := c.DefaultQuery("sort_order", "desc")
|
||||||
|
|
||||||
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder)
|
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, platform, sortBy, sortOrder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -135,14 +135,16 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
out := &AdminGroup{
|
out := &AdminGroup{
|
||||||
Group: groupFromServiceBase(g),
|
Group: groupFromServiceBase(g),
|
||||||
ModelRouting: g.ModelRouting,
|
ModelRouting: g.ModelRouting,
|
||||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||||
MCPXMLInject: g.MCPXMLInject,
|
MCPXMLInject: g.MCPXMLInject,
|
||||||
DefaultMappedModel: g.DefaultMappedModel,
|
DefaultMappedModel: g.DefaultMappedModel,
|
||||||
SupportedModelScopes: g.SupportedModelScopes,
|
SupportedModelScopes: g.SupportedModelScopes,
|
||||||
AccountCount: g.AccountCount,
|
AccountCount: g.AccountCount,
|
||||||
SortOrder: g.SortOrder,
|
ActiveAccountCount: g.ActiveAccountCount,
|
||||||
|
RateLimitedAccountCount: g.RateLimitedAccountCount,
|
||||||
|
SortOrder: g.SortOrder,
|
||||||
}
|
}
|
||||||
if len(g.AccountGroups) > 0 {
|
if len(g.AccountGroups) > 0 {
|
||||||
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
||||||
@@ -521,6 +523,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
|||||||
AccountID: l.AccountID,
|
AccountID: l.AccountID,
|
||||||
RequestID: l.RequestID,
|
RequestID: l.RequestID,
|
||||||
Model: l.Model,
|
Model: l.Model,
|
||||||
|
UpstreamModel: l.UpstreamModel,
|
||||||
ServiceTier: l.ServiceTier,
|
ServiceTier: l.ServiceTier,
|
||||||
ReasoningEffort: l.ReasoningEffort,
|
ReasoningEffort: l.ReasoningEffort,
|
||||||
InboundEndpoint: l.InboundEndpoint,
|
InboundEndpoint: l.InboundEndpoint,
|
||||||
|
|||||||
@@ -157,6 +157,12 @@ type ListSoraS3ProfilesResponse struct {
|
|||||||
Items []SoraS3Profile `json:"items"`
|
Items []SoraS3Profile `json:"items"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OverloadCooldownSettings 529过载冷却配置 DTO
|
||||||
|
type OverloadCooldownSettings struct {
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
CooldownMinutes int `json:"cooldown_minutes"`
|
||||||
|
}
|
||||||
|
|
||||||
// StreamTimeoutSettings 流超时处理配置 DTO
|
// StreamTimeoutSettings 流超时处理配置 DTO
|
||||||
type StreamTimeoutSettings struct {
|
type StreamTimeoutSettings struct {
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
|
|||||||
@@ -122,9 +122,11 @@ type AdminGroup struct {
|
|||||||
DefaultMappedModel string `json:"default_mapped_model"`
|
DefaultMappedModel string `json:"default_mapped_model"`
|
||||||
|
|
||||||
// 支持的模型系列(仅 antigravity 平台使用)
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||||
AccountCount int64 `json:"account_count,omitempty"`
|
AccountCount int64 `json:"account_count,omitempty"`
|
||||||
|
ActiveAccountCount int64 `json:"active_account_count,omitempty"`
|
||||||
|
RateLimitedAccountCount int64 `json:"rate_limited_account_count,omitempty"`
|
||||||
|
|
||||||
// 分组排序
|
// 分组排序
|
||||||
SortOrder int `json:"sort_order"`
|
SortOrder int `json:"sort_order"`
|
||||||
@@ -332,6 +334,9 @@ type UsageLog struct {
|
|||||||
AccountID int64 `json:"account_id"`
|
AccountID int64 `json:"account_id"`
|
||||||
RequestID string `json:"request_id"`
|
RequestID string `json:"request_id"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
|
// UpstreamModel is the actual model sent to the upstream provider after mapping.
|
||||||
|
// Omitted when no mapping was applied (requested model was used as-is).
|
||||||
|
UpstreamModel *string `json:"upstream_model,omitempty"`
|
||||||
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
|
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
|
||||||
ServiceTier *string `json:"service_tier,omitempty"`
|
ServiceTier *string `json:"service_tier,omitempty"`
|
||||||
// ReasoningEffort is the request's reasoning effort level.
|
// ReasoningEffort is the request's reasoning effort level.
|
||||||
|
|||||||
@@ -1219,6 +1219,10 @@ func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *se
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 记录原始上游状态码,以便 ops 错误日志捕获真实的上游错误
|
||||||
|
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||||
|
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
|
||||||
|
|
||||||
// 使用默认的错误映射
|
// 使用默认的错误映射
|
||||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||||
@@ -1227,6 +1231,7 @@ func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *se
|
|||||||
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
|
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
|
||||||
func (h *GatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
|
func (h *GatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
|
||||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||||
|
service.SetOpsUpstreamError(c, statusCode, errMsg, "")
|
||||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ func (f *fakeGroupRepo) ListActiveByPlatform(context.Context, string) ([]service
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil }
|
func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil }
|
||||||
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, error) { return 0, nil }
|
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, int64, error) { return 0, 0, nil }
|
||||||
func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
|
func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -136,7 +136,7 @@ func validClaudeCodeBodyJSON() []byte {
|
|||||||
return []byte(`{
|
return []byte(`{
|
||||||
"model":"claude-3-5-sonnet-20241022",
|
"model":"claude-3-5-sonnet-20241022",
|
||||||
"system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}],
|
"system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}],
|
||||||
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"}
|
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}
|
||||||
}`)
|
}`)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -190,7 +190,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing
|
|||||||
System: []any{
|
System: []any{
|
||||||
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
||||||
},
|
},
|
||||||
MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123",
|
MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
|
||||||
}
|
}
|
||||||
|
|
||||||
// body 非法 JSON,如果函数复用 parsedReq 成功则仍应判定为 Claude Code。
|
// body 非法 JSON,如果函数复用 parsedReq 成功则仍应判定为 Claude Code。
|
||||||
@@ -209,7 +209,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing
|
|||||||
"system": []any{
|
"system": []any{
|
||||||
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
||||||
},
|
},
|
||||||
"metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"},
|
"metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"},
|
||||||
})
|
})
|
||||||
|
|
||||||
SetClaudeCodeClientContext(c, []byte(`{invalid`), nil)
|
SetClaudeCodeClientContext(c, []byte(`{invalid`), nil)
|
||||||
|
|||||||
@@ -593,6 +593,10 @@ func (h *GatewayHandler) handleGeminiFailoverExhausted(c *gin.Context, failoverE
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 记录原始上游状态码,以便 ops 错误日志捕获真实的上游错误
|
||||||
|
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||||
|
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
|
||||||
|
|
||||||
// 使用默认的错误映射
|
// 使用默认的错误映射
|
||||||
status, message := mapGeminiUpstreamError(statusCode)
|
status, message := mapGeminiUpstreamError(statusCode)
|
||||||
googleError(c, status, message)
|
googleError(c, status, message)
|
||||||
|
|||||||
@@ -1435,6 +1435,10 @@ func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverE
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 记录原始上游状态码,以便 ops 错误日志捕获真实的上游错误
|
||||||
|
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||||
|
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
|
||||||
|
|
||||||
// 使用默认的错误映射
|
// 使用默认的错误映射
|
||||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||||
@@ -1443,6 +1447,7 @@ func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverE
|
|||||||
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
|
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
|
||||||
func (h *OpenAIGatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
|
func (h *OpenAIGatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
|
||||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||||
|
service.SetOpsUpstreamError(c, statusCode, errMsg, "")
|
||||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -484,6 +484,9 @@ func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, s
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) {
|
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) {
|
||||||
|
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||||
|
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
|
||||||
|
|
||||||
status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody)
|
status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody)
|
||||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -273,8 +273,8 @@ func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform strin
|
|||||||
func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
|
func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
|
||||||
return 0, nil
|
return 0, 0, nil
|
||||||
}
|
}
|
||||||
func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
@@ -345,6 +345,12 @@ func (s *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Conte
|
|||||||
func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,8 +49,8 @@ const (
|
|||||||
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||||
)
|
)
|
||||||
|
|
||||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.4
|
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.5
|
||||||
var defaultUserAgentVersion = "1.20.4"
|
var defaultUserAgentVersion = "1.20.5"
|
||||||
|
|
||||||
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
||||||
var defaultClientSecret = "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"
|
var defaultClientSecret = "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"
|
||||||
|
|||||||
@@ -690,7 +690,7 @@ func TestConstants_值正确(t *testing.T) {
|
|||||||
if RedirectURI != "http://localhost:8085/callback" {
|
if RedirectURI != "http://localhost:8085/callback" {
|
||||||
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
|
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
|
||||||
}
|
}
|
||||||
if GetUserAgent() != "antigravity/1.20.4 windows/amd64" {
|
if GetUserAgent() != "antigravity/1.20.5 windows/amd64" {
|
||||||
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
|
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
|
||||||
}
|
}
|
||||||
if SessionTTL != 30*time.Minute {
|
if SessionTTL != 30*time.Minute {
|
||||||
|
|||||||
@@ -275,21 +275,6 @@ func filterOpenCodePrompt(text string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// systemBlockFilterPrefixes 需要从 system 中过滤的文本前缀列表
|
|
||||||
var systemBlockFilterPrefixes = []string{
|
|
||||||
"x-anthropic-billing-header",
|
|
||||||
}
|
|
||||||
|
|
||||||
// filterSystemBlockByPrefix 如果文本匹配过滤前缀,返回空字符串
|
|
||||||
func filterSystemBlockByPrefix(text string) string {
|
|
||||||
for _, prefix := range systemBlockFilterPrefixes {
|
|
||||||
if strings.HasPrefix(text, prefix) {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return text
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildSystemInstruction 构建 systemInstruction(与 Antigravity-Manager 保持一致)
|
// buildSystemInstruction 构建 systemInstruction(与 Antigravity-Manager 保持一致)
|
||||||
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent {
|
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent {
|
||||||
var parts []GeminiPart
|
var parts []GeminiPart
|
||||||
@@ -306,8 +291,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
|
|||||||
if strings.Contains(sysStr, "You are Antigravity") {
|
if strings.Contains(sysStr, "You are Antigravity") {
|
||||||
userHasAntigravityIdentity = true
|
userHasAntigravityIdentity = true
|
||||||
}
|
}
|
||||||
// 过滤 OpenCode 默认提示词和黑名单前缀
|
// 过滤 OpenCode 默认提示词
|
||||||
filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(sysStr))
|
filtered := filterOpenCodePrompt(sysStr)
|
||||||
if filtered != "" {
|
if filtered != "" {
|
||||||
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
||||||
}
|
}
|
||||||
@@ -321,8 +306,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
|
|||||||
if strings.Contains(block.Text, "You are Antigravity") {
|
if strings.Contains(block.Text, "You are Antigravity") {
|
||||||
userHasAntigravityIdentity = true
|
userHasAntigravityIdentity = true
|
||||||
}
|
}
|
||||||
// 过滤 OpenCode 默认提示词和黑名单前缀
|
// 过滤 OpenCode 默认提示词
|
||||||
filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(block.Text))
|
filtered := filterOpenCodePrompt(block.Text)
|
||||||
if filtered != "" {
|
if filtered != "" {
|
||||||
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,10 @@ package antigravity
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
|
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
|
||||||
@@ -349,3 +352,51 @@ func TestBuildGenerationConfig_ThinkingDynamicBudget(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTransformClaudeToGeminiWithOptions_PreservesBillingHeaderSystemBlock(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
system json.RawMessage
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "system array",
|
||||||
|
system: json.RawMessage(`[{"type":"text","text":"x-anthropic-billing-header keep"}]`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "system string",
|
||||||
|
system: json.RawMessage(`"x-anthropic-billing-header keep"`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
claudeReq := &ClaudeRequest{
|
||||||
|
Model: "claude-3-5-sonnet-latest",
|
||||||
|
System: tt.system,
|
||||||
|
Messages: []ClaudeMessage{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: json.RawMessage(`[{"type":"text","text":"hello"}]`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "gemini-2.5-flash", DefaultTransformOptions())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var req V1InternalRequest
|
||||||
|
require.NoError(t, json.Unmarshal(body, &req))
|
||||||
|
require.NotNil(t, req.Request.SystemInstruction)
|
||||||
|
|
||||||
|
found := false
|
||||||
|
for _, part := range req.Request.SystemInstruction.Parts {
|
||||||
|
if strings.Contains(part.Text, "x-anthropic-billing-header keep") {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
require.True(t, found, "转换后的 systemInstruction 应保留 x-anthropic-billing-header 内容")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1008,3 +1008,114 @@ func TestAnthropicToResponses_ImageEmptyMediaType(t *testing.T) {
|
|||||||
// Should default to image/png when media_type is empty.
|
// Should default to image/png when media_type is empty.
|
||||||
assert.Equal(t, "data:image/png;base64,iVBOR", parts[0].ImageURL)
|
assert.Equal(t, "data:image/png;base64,iVBOR", parts[0].ImageURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// normalizeToolParameters tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestNormalizeToolParameters(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input json.RawMessage
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil input",
|
||||||
|
input: nil,
|
||||||
|
expected: `{"type":"object","properties":{}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty input",
|
||||||
|
input: json.RawMessage(``),
|
||||||
|
expected: `{"type":"object","properties":{}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "null input",
|
||||||
|
input: json.RawMessage(`null`),
|
||||||
|
expected: `{"type":"object","properties":{}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "object without properties",
|
||||||
|
input: json.RawMessage(`{"type":"object"}`),
|
||||||
|
expected: `{"type":"object","properties":{}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "object with properties",
|
||||||
|
input: json.RawMessage(`{"type":"object","properties":{"city":{"type":"string"}}}`),
|
||||||
|
expected: `{"type":"object","properties":{"city":{"type":"string"}}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-object type",
|
||||||
|
input: json.RawMessage(`{"type":"string"}`),
|
||||||
|
expected: `{"type":"string"}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "object with additional fields preserved",
|
||||||
|
input: json.RawMessage(`{"type":"object","required":["name"]}`),
|
||||||
|
expected: `{"type":"object","required":["name"],"properties":{}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid JSON passthrough",
|
||||||
|
input: json.RawMessage(`not json`),
|
||||||
|
expected: `not json`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := normalizeToolParameters(tt.input)
|
||||||
|
if tt.name == "invalid JSON passthrough" {
|
||||||
|
assert.Equal(t, tt.expected, string(result))
|
||||||
|
} else {
|
||||||
|
assert.JSONEq(t, tt.expected, string(result))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnthropicToResponses_ToolWithoutProperties(t *testing.T) {
|
||||||
|
req := &AnthropicRequest{
|
||||||
|
Model: "gpt-5.2",
|
||||||
|
MaxTokens: 1024,
|
||||||
|
Messages: []AnthropicMessage{
|
||||||
|
{Role: "user", Content: json.RawMessage(`"Hello"`)},
|
||||||
|
},
|
||||||
|
Tools: []AnthropicTool{
|
||||||
|
{Name: "mcp__pencil__get_style_guide_tags", Description: "Get style tags", InputSchema: json.RawMessage(`{"type":"object"}`)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := AnthropicToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Len(t, resp.Tools, 1)
|
||||||
|
assert.Equal(t, "function", resp.Tools[0].Type)
|
||||||
|
assert.Equal(t, "mcp__pencil__get_style_guide_tags", resp.Tools[0].Name)
|
||||||
|
|
||||||
|
// Parameters must have "properties" field after normalization.
|
||||||
|
var params map[string]json.RawMessage
|
||||||
|
require.NoError(t, json.Unmarshal(resp.Tools[0].Parameters, ¶ms))
|
||||||
|
assert.Contains(t, params, "properties")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnthropicToResponses_ToolWithNilSchema(t *testing.T) {
|
||||||
|
req := &AnthropicRequest{
|
||||||
|
Model: "gpt-5.2",
|
||||||
|
MaxTokens: 1024,
|
||||||
|
Messages: []AnthropicMessage{
|
||||||
|
{Role: "user", Content: json.RawMessage(`"Hello"`)},
|
||||||
|
},
|
||||||
|
Tools: []AnthropicTool{
|
||||||
|
{Name: "simple_tool", Description: "A tool"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := AnthropicToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Len(t, resp.Tools, 1)
|
||||||
|
var params map[string]json.RawMessage
|
||||||
|
require.NoError(t, json.Unmarshal(resp.Tools[0].Parameters, ¶ms))
|
||||||
|
assert.JSONEq(t, `"object"`, string(params["type"]))
|
||||||
|
assert.JSONEq(t, `{}`, string(params["properties"]))
|
||||||
|
}
|
||||||
|
|||||||
@@ -409,8 +409,41 @@ func convertAnthropicToolsToResponses(tools []AnthropicTool) []ResponsesTool {
|
|||||||
Type: "function",
|
Type: "function",
|
||||||
Name: t.Name,
|
Name: t.Name,
|
||||||
Description: t.Description,
|
Description: t.Description,
|
||||||
Parameters: t.InputSchema,
|
Parameters: normalizeToolParameters(t.InputSchema),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// normalizeToolParameters ensures the tool parameter schema is valid for
|
||||||
|
// OpenAI's Responses API, which requires "properties" on object schemas.
|
||||||
|
//
|
||||||
|
// - nil/empty → {"type":"object","properties":{}}
|
||||||
|
// - type=object without properties → adds "properties": {}
|
||||||
|
// - otherwise → returned unchanged
|
||||||
|
func normalizeToolParameters(schema json.RawMessage) json.RawMessage {
|
||||||
|
if len(schema) == 0 || string(schema) == "null" {
|
||||||
|
return json.RawMessage(`{"type":"object","properties":{}}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
var m map[string]json.RawMessage
|
||||||
|
if err := json.Unmarshal(schema, &m); err != nil {
|
||||||
|
return schema
|
||||||
|
}
|
||||||
|
|
||||||
|
typ := m["type"]
|
||||||
|
if string(typ) != `"object"` {
|
||||||
|
return schema
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := m["properties"]; ok {
|
||||||
|
return schema
|
||||||
|
}
|
||||||
|
|
||||||
|
m["properties"] = json.RawMessage(`{}`)
|
||||||
|
out, err := json.Marshal(m)
|
||||||
|
if err != nil {
|
||||||
|
return schema
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|||||||
@@ -47,6 +47,15 @@ func Created(c *gin.Context, data any) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Accepted 返回异步接受响应 (HTTP 202)
|
||||||
|
func Accepted(c *gin.Context, data any) {
|
||||||
|
c.JSON(http.StatusAccepted, Response{
|
||||||
|
Code: 0,
|
||||||
|
Message: "accepted",
|
||||||
|
Data: data,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Error 返回错误响应
|
// Error 返回错误响应
|
||||||
func Error(c *gin.Context, statusCode int, message string) {
|
func Error(c *gin.Context, statusCode int, message string) {
|
||||||
c.JSON(statusCode, Response{
|
c.JSON(statusCode, Response{
|
||||||
|
|||||||
@@ -3,6 +3,28 @@ package usagestats
|
|||||||
|
|
||||||
import "time"
|
import "time"
|
||||||
|
|
||||||
|
const (
|
||||||
|
ModelSourceRequested = "requested"
|
||||||
|
ModelSourceUpstream = "upstream"
|
||||||
|
ModelSourceMapping = "mapping"
|
||||||
|
)
|
||||||
|
|
||||||
|
func IsValidModelSource(source string) bool {
|
||||||
|
switch source {
|
||||||
|
case ModelSourceRequested, ModelSourceUpstream, ModelSourceMapping:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NormalizeModelSource(source string) string {
|
||||||
|
if IsValidModelSource(source) {
|
||||||
|
return source
|
||||||
|
}
|
||||||
|
return ModelSourceRequested
|
||||||
|
}
|
||||||
|
|
||||||
// DashboardStats 仪表盘统计
|
// DashboardStats 仪表盘统计
|
||||||
type DashboardStats struct {
|
type DashboardStats struct {
|
||||||
// 用户统计
|
// 用户统计
|
||||||
@@ -90,6 +112,13 @@ type EndpointStat struct {
|
|||||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GroupUsageSummary represents today's and cumulative cost for a single group.
|
||||||
|
type GroupUsageSummary struct {
|
||||||
|
GroupID int64 `json:"group_id"`
|
||||||
|
TodayCost float64 `json:"today_cost"`
|
||||||
|
TotalCost float64 `json:"total_cost"`
|
||||||
|
}
|
||||||
|
|
||||||
// GroupStat represents usage statistics for a single group
|
// GroupStat represents usage statistics for a single group
|
||||||
type GroupStat struct {
|
type GroupStat struct {
|
||||||
GroupID int64 `json:"group_id"`
|
GroupID int64 `json:"group_id"`
|
||||||
@@ -129,6 +158,25 @@ type UserSpendingRankingResponse struct {
|
|||||||
TotalTokens int64 `json:"total_tokens"`
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UserBreakdownItem represents per-user usage breakdown within a dimension (group, model, endpoint).
|
||||||
|
type UserBreakdownItem struct {
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
Requests int64 `json:"requests"`
|
||||||
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
|
Cost float64 `json:"cost"` // 标准计费
|
||||||
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserBreakdownDimension specifies the dimension to filter for user breakdown.
|
||||||
|
type UserBreakdownDimension struct {
|
||||||
|
GroupID int64 // filter by group_id (>0 to enable)
|
||||||
|
Model string // filter by model name (non-empty to enable)
|
||||||
|
ModelType string // "requested", "upstream", or "mapping"
|
||||||
|
Endpoint string // filter by endpoint value (non-empty to enable)
|
||||||
|
EndpointType string // "inbound", "upstream", or "path"
|
||||||
|
}
|
||||||
|
|
||||||
// APIKeyUsageTrendPoint represents API key usage trend data point
|
// APIKeyUsageTrendPoint represents API key usage trend data point
|
||||||
type APIKeyUsageTrendPoint struct {
|
type APIKeyUsageTrendPoint struct {
|
||||||
Date string `json:"date"`
|
Date string `json:"date"`
|
||||||
|
|||||||
47
backend/internal/pkg/usagestats/usage_log_types_test.go
Normal file
47
backend/internal/pkg/usagestats/usage_log_types_test.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package usagestats
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestIsValidModelSource(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
source string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{name: "requested", source: ModelSourceRequested, want: true},
|
||||||
|
{name: "upstream", source: ModelSourceUpstream, want: true},
|
||||||
|
{name: "mapping", source: ModelSourceMapping, want: true},
|
||||||
|
{name: "invalid", source: "foobar", want: false},
|
||||||
|
{name: "empty", source: "", want: false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
if got := IsValidModelSource(tc.source); got != tc.want {
|
||||||
|
t.Fatalf("IsValidModelSource(%q)=%v want %v", tc.source, got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeModelSource(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
source string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{name: "requested", source: ModelSourceRequested, want: ModelSourceRequested},
|
||||||
|
{name: "upstream", source: ModelSourceUpstream, want: ModelSourceUpstream},
|
||||||
|
{name: "mapping", source: ModelSourceMapping, want: ModelSourceMapping},
|
||||||
|
{name: "invalid falls back", source: "foobar", want: ModelSourceRequested},
|
||||||
|
{name: "empty falls back", source: "", want: ModelSourceRequested},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
if got := NormalizeModelSource(tc.source); got != tc.want {
|
||||||
|
t.Fatalf("NormalizeModelSource(%q)=%q want %q", tc.source, got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -56,6 +56,7 @@ var schedulerNeutralExtraKeyPrefixes = []string{
|
|||||||
"codex_secondary_",
|
"codex_secondary_",
|
||||||
"codex_5h_",
|
"codex_5h_",
|
||||||
"codex_7d_",
|
"codex_7d_",
|
||||||
|
"passive_usage_",
|
||||||
}
|
}
|
||||||
|
|
||||||
var schedulerNeutralExtraKeys = map[string]struct{}{
|
var schedulerNeutralExtraKeys = map[string]struct{}{
|
||||||
@@ -473,7 +474,9 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
|
|||||||
if search != "" {
|
if search != "" {
|
||||||
q = q.Where(dbaccount.NameContainsFold(search))
|
q = q.Where(dbaccount.NameContainsFold(search))
|
||||||
}
|
}
|
||||||
if groupID > 0 {
|
if groupID == service.AccountListGroupUngrouped {
|
||||||
|
q = q.Where(dbaccount.Not(dbaccount.HasAccountGroups()))
|
||||||
|
} else if groupID > 0 {
|
||||||
q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID)))
|
q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -214,6 +214,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
|||||||
accType string
|
accType string
|
||||||
status string
|
status string
|
||||||
search string
|
search string
|
||||||
|
groupID int64
|
||||||
wantCount int
|
wantCount int
|
||||||
validate func(accounts []service.Account)
|
validate func(accounts []service.Account)
|
||||||
}{
|
}{
|
||||||
@@ -265,6 +266,21 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
|||||||
s.Require().Contains(accounts[0].Name, "alpha")
|
s.Require().Contains(accounts[0].Name, "alpha")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "filter_by_ungrouped",
|
||||||
|
setup: func(client *dbent.Client) {
|
||||||
|
group := mustCreateGroup(s.T(), client, &service.Group{Name: "g-ungrouped"})
|
||||||
|
grouped := mustCreateAccount(s.T(), client, &service.Account{Name: "grouped-account"})
|
||||||
|
mustCreateAccount(s.T(), client, &service.Account{Name: "ungrouped-account"})
|
||||||
|
mustBindAccountToGroup(s.T(), client, grouped.ID, group.ID, 1)
|
||||||
|
},
|
||||||
|
groupID: service.AccountListGroupUngrouped,
|
||||||
|
wantCount: 1,
|
||||||
|
validate: func(accounts []service.Account) {
|
||||||
|
s.Require().Equal("ungrouped-account", accounts[0].Name)
|
||||||
|
s.Require().Empty(accounts[0].GroupIDs)
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -277,7 +293,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
|||||||
|
|
||||||
tt.setup(client)
|
tt.setup(client)
|
||||||
|
|
||||||
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, 0)
|
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, tt.groupID)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Len(accounts, tt.wantCount)
|
s.Require().Len(accounts, tt.wantCount)
|
||||||
if tt.validate != nil {
|
if tt.validate != nil {
|
||||||
|
|||||||
@@ -57,6 +57,7 @@ func NewS3BackupStoreFactory() service.BackupObjectStoreFactory {
|
|||||||
|
|
||||||
func (s *S3BackupStore) Upload(ctx context.Context, key string, body io.Reader, contentType string) (int64, error) {
|
func (s *S3BackupStore) Upload(ctx context.Context, key string, body io.Reader, contentType string) (int64, error) {
|
||||||
// 读取全部内容以获取大小(S3 PutObject 需要知道内容长度)
|
// 读取全部内容以获取大小(S3 PutObject 需要知道内容长度)
|
||||||
|
// 注意:阿里云 OSS 不兼容 s3manager 分片上传的签名方式,因此使用 PutObject
|
||||||
data, err := io.ReadAll(body)
|
data, err := io.ReadAll(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("read body: %w", err)
|
return 0, fmt.Errorf("read body: %w", err)
|
||||||
|
|||||||
@@ -20,6 +20,11 @@ const (
|
|||||||
billingCacheTTL = 5 * time.Minute
|
billingCacheTTL = 5 * time.Minute
|
||||||
billingCacheJitter = 30 * time.Second
|
billingCacheJitter = 30 * time.Second
|
||||||
rateLimitCacheTTL = 7 * 24 * time.Hour // 7 days matches the longest window
|
rateLimitCacheTTL = 7 * 24 * time.Hour // 7 days matches the longest window
|
||||||
|
|
||||||
|
// Rate limit window durations — must match service.RateLimitWindow* constants.
|
||||||
|
rateLimitWindow5h = 5 * time.Hour
|
||||||
|
rateLimitWindow1d = 24 * time.Hour
|
||||||
|
rateLimitWindow7d = 7 * 24 * time.Hour
|
||||||
)
|
)
|
||||||
|
|
||||||
// jitteredTTL 返回带随机抖动的 TTL,防止缓存雪崩
|
// jitteredTTL 返回带随机抖动的 TTL,防止缓存雪崩
|
||||||
@@ -90,17 +95,40 @@ var (
|
|||||||
return 1
|
return 1
|
||||||
`)
|
`)
|
||||||
|
|
||||||
// updateRateLimitUsageScript atomically increments all three rate limit usage counters.
|
// updateRateLimitUsageScript atomically increments all three rate limit usage counters
|
||||||
// Returns 0 if the key doesn't exist (cache miss), 1 on success.
|
// with window expiration checking. If a window has expired, its usage is reset to cost
|
||||||
|
// (instead of accumulated) and the window timestamp is updated, matching the DB-side
|
||||||
|
// IncrementRateLimitUsage semantics.
|
||||||
|
//
|
||||||
|
// ARGV: [1]=cost, [2]=ttl_seconds, [3]=now_unix, [4]=window_5h_seconds, [5]=window_1d_seconds, [6]=window_7d_seconds
|
||||||
updateRateLimitUsageScript = redis.NewScript(`
|
updateRateLimitUsageScript = redis.NewScript(`
|
||||||
local exists = redis.call('EXISTS', KEYS[1])
|
local exists = redis.call('EXISTS', KEYS[1])
|
||||||
if exists == 0 then
|
if exists == 0 then
|
||||||
return 0
|
return 0
|
||||||
end
|
end
|
||||||
local cost = tonumber(ARGV[1])
|
local cost = tonumber(ARGV[1])
|
||||||
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_5h', cost)
|
local now = tonumber(ARGV[3])
|
||||||
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_1d', cost)
|
local win5h = tonumber(ARGV[4])
|
||||||
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_7d', cost)
|
local win1d = tonumber(ARGV[5])
|
||||||
|
local win7d = tonumber(ARGV[6])
|
||||||
|
|
||||||
|
-- Helper: check if window is expired and update usage + window accordingly
|
||||||
|
-- Returns nothing, modifies the hash in-place.
|
||||||
|
local function update_window(usage_field, window_field, window_duration)
|
||||||
|
local w = tonumber(redis.call('HGET', KEYS[1], window_field) or 0)
|
||||||
|
if w == 0 or (now - w) >= window_duration then
|
||||||
|
-- Window expired or never started: reset usage to cost, start new window
|
||||||
|
redis.call('HSET', KEYS[1], usage_field, tostring(cost))
|
||||||
|
redis.call('HSET', KEYS[1], window_field, tostring(now))
|
||||||
|
else
|
||||||
|
-- Window still valid: accumulate
|
||||||
|
redis.call('HINCRBYFLOAT', KEYS[1], usage_field, cost)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
update_window('usage_5h', 'window_5h', win5h)
|
||||||
|
update_window('usage_1d', 'window_1d', win1d)
|
||||||
|
update_window('usage_7d', 'window_7d', win7d)
|
||||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||||
return 1
|
return 1
|
||||||
`)
|
`)
|
||||||
@@ -280,7 +308,15 @@ func (c *billingCache) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data
|
|||||||
|
|
||||||
func (c *billingCache) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
|
func (c *billingCache) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
|
||||||
key := billingRateLimitKey(keyID)
|
key := billingRateLimitKey(keyID)
|
||||||
_, err := updateRateLimitUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(rateLimitCacheTTL.Seconds())).Result()
|
now := time.Now().Unix()
|
||||||
|
_, err := updateRateLimitUsageScript.Run(ctx, c.rdb, []string{key},
|
||||||
|
cost,
|
||||||
|
int(rateLimitCacheTTL.Seconds()),
|
||||||
|
now,
|
||||||
|
int(rateLimitWindow5h.Seconds()),
|
||||||
|
int(rateLimitWindow1d.Seconds()),
|
||||||
|
int(rateLimitWindow7d.Seconds()),
|
||||||
|
).Result()
|
||||||
if err != nil && !errors.Is(err, redis.Nil) {
|
if err != nil && !errors.Is(err, redis.Nil) {
|
||||||
log.Printf("Warning: update rate limit usage cache failed for api key %d: %v", keyID, err)
|
log.Printf("Warning: update rate limit usage cache failed for api key %d: %v", keyID, err)
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -88,8 +88,9 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
count, _ := r.GetAccountCount(ctx, out.ID)
|
total, active, _ := r.GetAccountCount(ctx, out.ID)
|
||||||
out.AccountCount = count
|
out.AccountCount = total
|
||||||
|
out.ActiveAccountCount = active
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -256,7 +257,10 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
|
|||||||
counts, err := r.loadAccountCounts(ctx, groupIDs)
|
counts, err := r.loadAccountCounts(ctx, groupIDs)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
for i := range outGroups {
|
for i := range outGroups {
|
||||||
outGroups[i].AccountCount = counts[outGroups[i].ID]
|
c := counts[outGroups[i].ID]
|
||||||
|
outGroups[i].AccountCount = c.Total
|
||||||
|
outGroups[i].ActiveAccountCount = c.Active
|
||||||
|
outGroups[i].RateLimitedAccountCount = c.RateLimited
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -283,7 +287,10 @@ func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, erro
|
|||||||
counts, err := r.loadAccountCounts(ctx, groupIDs)
|
counts, err := r.loadAccountCounts(ctx, groupIDs)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
for i := range outGroups {
|
for i := range outGroups {
|
||||||
outGroups[i].AccountCount = counts[outGroups[i].ID]
|
c := counts[outGroups[i].ID]
|
||||||
|
outGroups[i].AccountCount = c.Total
|
||||||
|
outGroups[i].ActiveAccountCount = c.Active
|
||||||
|
outGroups[i].RateLimitedAccountCount = c.RateLimited
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -310,7 +317,10 @@ func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform str
|
|||||||
counts, err := r.loadAccountCounts(ctx, groupIDs)
|
counts, err := r.loadAccountCounts(ctx, groupIDs)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
for i := range outGroups {
|
for i := range outGroups {
|
||||||
outGroups[i].AccountCount = counts[outGroups[i].ID]
|
c := counts[outGroups[i].ID]
|
||||||
|
outGroups[i].AccountCount = c.Total
|
||||||
|
outGroups[i].ActiveAccountCount = c.Active
|
||||||
|
outGroups[i].RateLimitedAccountCount = c.RateLimited
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -369,12 +379,20 @@ func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (total int64, active int64, err error) {
|
||||||
var count int64
|
var rateLimited int64
|
||||||
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil {
|
err = scanSingleRow(ctx, r.sql,
|
||||||
return 0, err
|
`SELECT COUNT(*),
|
||||||
}
|
COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true),
|
||||||
return count, nil
|
COUNT(*) FILTER (WHERE a.status = 'active' AND (
|
||||||
|
a.rate_limit_reset_at > NOW() OR
|
||||||
|
a.overload_until > NOW() OR
|
||||||
|
a.temp_unschedulable_until > NOW()
|
||||||
|
))
|
||||||
|
FROM account_groups ag JOIN accounts a ON a.id = ag.account_id
|
||||||
|
WHERE ag.group_id = $1`,
|
||||||
|
[]any{groupID}, &total, &active, &rateLimited)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||||
@@ -500,15 +518,32 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
|
|||||||
return affectedUserIDs, nil
|
return affectedUserIDs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]int64, err error) {
|
type groupAccountCounts struct {
|
||||||
counts = make(map[int64]int64, len(groupIDs))
|
Total int64
|
||||||
|
Active int64
|
||||||
|
RateLimited int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]groupAccountCounts, err error) {
|
||||||
|
counts = make(map[int64]groupAccountCounts, len(groupIDs))
|
||||||
if len(groupIDs) == 0 {
|
if len(groupIDs) == 0 {
|
||||||
return counts, nil
|
return counts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := r.sql.QueryContext(
|
rows, err := r.sql.QueryContext(
|
||||||
ctx,
|
ctx,
|
||||||
"SELECT group_id, COUNT(*) FROM account_groups WHERE group_id = ANY($1) GROUP BY group_id",
|
`SELECT ag.group_id,
|
||||||
|
COUNT(*) AS total,
|
||||||
|
COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true) AS active,
|
||||||
|
COUNT(*) FILTER (WHERE a.status = 'active' AND (
|
||||||
|
a.rate_limit_reset_at > NOW() OR
|
||||||
|
a.overload_until > NOW() OR
|
||||||
|
a.temp_unschedulable_until > NOW()
|
||||||
|
)) AS rate_limited
|
||||||
|
FROM account_groups ag
|
||||||
|
JOIN accounts a ON a.id = ag.account_id
|
||||||
|
WHERE ag.group_id = ANY($1)
|
||||||
|
GROUP BY ag.group_id`,
|
||||||
pq.Array(groupIDs),
|
pq.Array(groupIDs),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -523,11 +558,11 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6
|
|||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var groupID int64
|
var groupID int64
|
||||||
var count int64
|
var c groupAccountCounts
|
||||||
if err = rows.Scan(&groupID, &count); err != nil {
|
if err = rows.Scan(&groupID, &c.Total, &c.Active, &c.RateLimited); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
counts[groupID] = count
|
counts[groupID] = c
|
||||||
}
|
}
|
||||||
if err = rows.Err(); err != nil {
|
if err = rows.Err(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -603,7 +603,7 @@ func (s *GroupRepoSuite) TestGetAccountCount() {
|
|||||||
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a2, group.ID, 2)
|
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a2, group.ID, 2)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
|
||||||
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
|
count, _, err := s.repo.GetAccountCount(s.ctx, group.ID)
|
||||||
s.Require().NoError(err, "GetAccountCount")
|
s.Require().NoError(err, "GetAccountCount")
|
||||||
s.Require().Equal(int64(2), count)
|
s.Require().Equal(int64(2), count)
|
||||||
}
|
}
|
||||||
@@ -619,7 +619,7 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
|
|||||||
}
|
}
|
||||||
s.Require().NoError(s.repo.Create(s.ctx, group))
|
s.Require().NoError(s.repo.Create(s.ctx, group))
|
||||||
|
|
||||||
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
|
count, _, err := s.repo.GetAccountCount(s.ctx, group.ID)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Zero(count)
|
s.Require().Zero(count)
|
||||||
}
|
}
|
||||||
@@ -651,7 +651,7 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
|
|||||||
s.Require().NoError(err, "DeleteAccountGroupsByGroupID")
|
s.Require().NoError(err, "DeleteAccountGroupsByGroupID")
|
||||||
s.Require().Equal(int64(1), affected, "expected 1 affected row")
|
s.Require().Equal(int64(1), affected, "expected 1 affected row")
|
||||||
|
|
||||||
count, err := s.repo.GetAccountCount(s.ctx, g.ID)
|
count, _, err := s.repo.GetAccountCount(s.ctx, g.ID)
|
||||||
s.Require().NoError(err, "GetAccountCount")
|
s.Require().NoError(err, "GetAccountCount")
|
||||||
s.Require().Equal(int64(0), count, "expected 0 account groups")
|
s.Require().Equal(int64(0), count, "expected 0 account groups")
|
||||||
}
|
}
|
||||||
@@ -692,7 +692,7 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
|
|||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Equal(int64(3), affected)
|
s.Require().Equal(int64(3), affected)
|
||||||
|
|
||||||
count, _ := s.repo.GetAccountCount(s.ctx, g.ID)
|
count, _, _ := s.repo.GetAccountCount(s.ctx, g.ID)
|
||||||
s.Require().Zero(count)
|
s.Require().Zero(count)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ import (
|
|||||||
gocache "github.com/patrickmn/go-cache"
|
gocache "github.com/patrickmn/go-cache"
|
||||||
)
|
)
|
||||||
|
|
||||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at"
|
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at"
|
||||||
|
|
||||||
var usageLogInsertArgTypes = [...]string{
|
var usageLogInsertArgTypes = [...]string{
|
||||||
"bigint",
|
"bigint",
|
||||||
@@ -36,6 +36,7 @@ var usageLogInsertArgTypes = [...]string{
|
|||||||
"bigint",
|
"bigint",
|
||||||
"text",
|
"text",
|
||||||
"text",
|
"text",
|
||||||
|
"text",
|
||||||
"bigint",
|
"bigint",
|
||||||
"bigint",
|
"bigint",
|
||||||
"integer",
|
"integer",
|
||||||
@@ -277,6 +278,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
|
|||||||
account_id,
|
account_id,
|
||||||
request_id,
|
request_id,
|
||||||
model,
|
model,
|
||||||
|
upstream_model,
|
||||||
group_id,
|
group_id,
|
||||||
subscription_id,
|
subscription_id,
|
||||||
input_tokens,
|
input_tokens,
|
||||||
@@ -311,12 +313,12 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
|
|||||||
cache_ttl_overridden,
|
cache_ttl_overridden,
|
||||||
created_at
|
created_at
|
||||||
) VALUES (
|
) VALUES (
|
||||||
$1, $2, $3, $4, $5,
|
$1, $2, $3, $4, $5, $6,
|
||||||
$6, $7,
|
$7, $8,
|
||||||
$8, $9, $10, $11,
|
$9, $10, $11, $12,
|
||||||
$12, $13,
|
$13, $14,
|
||||||
$14, $15, $16, $17, $18, $19,
|
$15, $16, $17, $18, $19, $20,
|
||||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38
|
$21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39
|
||||||
)
|
)
|
||||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||||
RETURNING id, created_at
|
RETURNING id, created_at
|
||||||
@@ -707,6 +709,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
|||||||
account_id,
|
account_id,
|
||||||
request_id,
|
request_id,
|
||||||
model,
|
model,
|
||||||
|
upstream_model,
|
||||||
group_id,
|
group_id,
|
||||||
subscription_id,
|
subscription_id,
|
||||||
input_tokens,
|
input_tokens,
|
||||||
@@ -742,7 +745,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
|||||||
created_at
|
created_at
|
||||||
) AS (VALUES `)
|
) AS (VALUES `)
|
||||||
|
|
||||||
args := make([]any, 0, len(keys)*38)
|
args := make([]any, 0, len(keys)*39)
|
||||||
argPos := 1
|
argPos := 1
|
||||||
for idx, key := range keys {
|
for idx, key := range keys {
|
||||||
if idx > 0 {
|
if idx > 0 {
|
||||||
@@ -776,6 +779,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
|||||||
account_id,
|
account_id,
|
||||||
request_id,
|
request_id,
|
||||||
model,
|
model,
|
||||||
|
upstream_model,
|
||||||
group_id,
|
group_id,
|
||||||
subscription_id,
|
subscription_id,
|
||||||
input_tokens,
|
input_tokens,
|
||||||
@@ -816,6 +820,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
|||||||
account_id,
|
account_id,
|
||||||
request_id,
|
request_id,
|
||||||
model,
|
model,
|
||||||
|
upstream_model,
|
||||||
group_id,
|
group_id,
|
||||||
subscription_id,
|
subscription_id,
|
||||||
input_tokens,
|
input_tokens,
|
||||||
@@ -896,6 +901,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
|||||||
account_id,
|
account_id,
|
||||||
request_id,
|
request_id,
|
||||||
model,
|
model,
|
||||||
|
upstream_model,
|
||||||
group_id,
|
group_id,
|
||||||
subscription_id,
|
subscription_id,
|
||||||
input_tokens,
|
input_tokens,
|
||||||
@@ -931,7 +937,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
|||||||
created_at
|
created_at
|
||||||
) AS (VALUES `)
|
) AS (VALUES `)
|
||||||
|
|
||||||
args := make([]any, 0, len(preparedList)*38)
|
args := make([]any, 0, len(preparedList)*39)
|
||||||
argPos := 1
|
argPos := 1
|
||||||
for idx, prepared := range preparedList {
|
for idx, prepared := range preparedList {
|
||||||
if idx > 0 {
|
if idx > 0 {
|
||||||
@@ -962,6 +968,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
|||||||
account_id,
|
account_id,
|
||||||
request_id,
|
request_id,
|
||||||
model,
|
model,
|
||||||
|
upstream_model,
|
||||||
group_id,
|
group_id,
|
||||||
subscription_id,
|
subscription_id,
|
||||||
input_tokens,
|
input_tokens,
|
||||||
@@ -1002,6 +1009,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
|||||||
account_id,
|
account_id,
|
||||||
request_id,
|
request_id,
|
||||||
model,
|
model,
|
||||||
|
upstream_model,
|
||||||
group_id,
|
group_id,
|
||||||
subscription_id,
|
subscription_id,
|
||||||
input_tokens,
|
input_tokens,
|
||||||
@@ -1050,6 +1058,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
|
|||||||
account_id,
|
account_id,
|
||||||
request_id,
|
request_id,
|
||||||
model,
|
model,
|
||||||
|
upstream_model,
|
||||||
group_id,
|
group_id,
|
||||||
subscription_id,
|
subscription_id,
|
||||||
input_tokens,
|
input_tokens,
|
||||||
@@ -1084,12 +1093,12 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
|
|||||||
cache_ttl_overridden,
|
cache_ttl_overridden,
|
||||||
created_at
|
created_at
|
||||||
) VALUES (
|
) VALUES (
|
||||||
$1, $2, $3, $4, $5,
|
$1, $2, $3, $4, $5, $6,
|
||||||
$6, $7,
|
$7, $8,
|
||||||
$8, $9, $10, $11,
|
$9, $10, $11, $12,
|
||||||
$12, $13,
|
$13, $14,
|
||||||
$14, $15, $16, $17, $18, $19,
|
$15, $16, $17, $18, $19, $20,
|
||||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38
|
$21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39
|
||||||
)
|
)
|
||||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||||
`, prepared.args...)
|
`, prepared.args...)
|
||||||
@@ -1121,6 +1130,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
|||||||
reasoningEffort := nullString(log.ReasoningEffort)
|
reasoningEffort := nullString(log.ReasoningEffort)
|
||||||
inboundEndpoint := nullString(log.InboundEndpoint)
|
inboundEndpoint := nullString(log.InboundEndpoint)
|
||||||
upstreamEndpoint := nullString(log.UpstreamEndpoint)
|
upstreamEndpoint := nullString(log.UpstreamEndpoint)
|
||||||
|
upstreamModel := nullString(log.UpstreamModel)
|
||||||
|
|
||||||
var requestIDArg any
|
var requestIDArg any
|
||||||
if requestID != "" {
|
if requestID != "" {
|
||||||
@@ -1138,6 +1148,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
|||||||
log.AccountID,
|
log.AccountID,
|
||||||
requestIDArg,
|
requestIDArg,
|
||||||
log.Model,
|
log.Model,
|
||||||
|
upstreamModel,
|
||||||
groupID,
|
groupID,
|
||||||
subscriptionID,
|
subscriptionID,
|
||||||
log.InputTokens,
|
log.InputTokens,
|
||||||
@@ -2864,15 +2875,26 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st
|
|||||||
|
|
||||||
// GetModelStatsWithFilters returns model statistics with optional filters
|
// GetModelStatsWithFilters returns model statistics with optional filters
|
||||||
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) {
|
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) {
|
||||||
|
return r.getModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, usagestats.ModelSourceRequested)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModelStatsWithFiltersBySource returns model statistics with optional filters and model source dimension.
|
||||||
|
// source: requested | upstream | mapping.
|
||||||
|
func (r *usageLogRepository) GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) (results []ModelStat, err error) {
|
||||||
|
return r.getModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, source)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) (results []ModelStat, err error) {
|
||||||
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
||||||
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
|
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
|
||||||
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
||||||
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||||
}
|
}
|
||||||
|
modelExpr := resolveModelDimensionExpression(source)
|
||||||
|
|
||||||
query := fmt.Sprintf(`
|
query := fmt.Sprintf(`
|
||||||
SELECT
|
SELECT
|
||||||
model,
|
%s as model,
|
||||||
COUNT(*) as requests,
|
COUNT(*) as requests,
|
||||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||||
@@ -2883,7 +2905,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
|
|||||||
%s
|
%s
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
WHERE created_at >= $1 AND created_at < $2
|
WHERE created_at >= $1 AND created_at < $2
|
||||||
`, actualCostExpr)
|
`, modelExpr, actualCostExpr)
|
||||||
|
|
||||||
args := []any{startTime, endTime}
|
args := []any{startTime, endTime}
|
||||||
if userID > 0 {
|
if userID > 0 {
|
||||||
@@ -2907,7 +2929,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
|
|||||||
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
|
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
|
||||||
args = append(args, int16(*billingType))
|
args = append(args, int16(*billingType))
|
||||||
}
|
}
|
||||||
query += " GROUP BY model ORDER BY total_tokens DESC"
|
query += fmt.Sprintf(" GROUP BY %s ORDER BY total_tokens DESC", modelExpr)
|
||||||
|
|
||||||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -3000,6 +3022,132 @@ func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, start
|
|||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUserBreakdownStats returns per-user usage breakdown within a specific dimension.
|
||||||
|
func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) (results []usagestats.UserBreakdownItem, err error) {
|
||||||
|
query := `
|
||||||
|
SELECT
|
||||||
|
COALESCE(ul.user_id, 0) as user_id,
|
||||||
|
COALESCE(u.email, '') as email,
|
||||||
|
COUNT(*) as requests,
|
||||||
|
COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens,
|
||||||
|
COALESCE(SUM(ul.total_cost), 0) as cost,
|
||||||
|
COALESCE(SUM(ul.actual_cost), 0) as actual_cost
|
||||||
|
FROM usage_logs ul
|
||||||
|
LEFT JOIN users u ON u.id = ul.user_id
|
||||||
|
WHERE ul.created_at >= $1 AND ul.created_at < $2
|
||||||
|
`
|
||||||
|
args := []any{startTime, endTime}
|
||||||
|
|
||||||
|
if dim.GroupID > 0 {
|
||||||
|
query += fmt.Sprintf(" AND ul.group_id = $%d", len(args)+1)
|
||||||
|
args = append(args, dim.GroupID)
|
||||||
|
}
|
||||||
|
if dim.Model != "" {
|
||||||
|
query += fmt.Sprintf(" AND %s = $%d", resolveModelDimensionExpression(dim.ModelType), len(args)+1)
|
||||||
|
args = append(args, dim.Model)
|
||||||
|
}
|
||||||
|
if dim.Endpoint != "" {
|
||||||
|
col := resolveEndpointColumn(dim.EndpointType)
|
||||||
|
query += fmt.Sprintf(" AND %s = $%d", col, len(args)+1)
|
||||||
|
args = append(args, dim.Endpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
query += " GROUP BY ul.user_id, u.email ORDER BY actual_cost DESC"
|
||||||
|
if limit > 0 {
|
||||||
|
query += fmt.Sprintf(" LIMIT %d", limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||||
|
err = closeErr
|
||||||
|
results = nil
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
results = make([]usagestats.UserBreakdownItem, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
var row usagestats.UserBreakdownItem
|
||||||
|
if err := rows.Scan(
|
||||||
|
&row.UserID,
|
||||||
|
&row.Email,
|
||||||
|
&row.Requests,
|
||||||
|
&row.TotalTokens,
|
||||||
|
&row.Cost,
|
||||||
|
&row.ActualCost,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
results = append(results, row)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllGroupUsageSummary returns today's and cumulative actual_cost for every group.
|
||||||
|
// todayStart is the start-of-day in the caller's timezone (UTC-based).
|
||||||
|
// TODO(perf): This query scans ALL usage_logs rows for total_cost aggregation.
|
||||||
|
// When usage_logs exceeds ~1M rows, consider adding a short-lived cache (30s)
|
||||||
|
// or a materialized view / pre-aggregation table for cumulative costs.
|
||||||
|
func (r *usageLogRepository) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
|
||||||
|
query := `
|
||||||
|
SELECT
|
||||||
|
g.id AS group_id,
|
||||||
|
COALESCE(SUM(ul.actual_cost), 0) AS total_cost,
|
||||||
|
COALESCE(SUM(CASE WHEN ul.created_at >= $1 THEN ul.actual_cost ELSE 0 END), 0) AS today_cost
|
||||||
|
FROM groups g
|
||||||
|
LEFT JOIN usage_logs ul ON ul.group_id = g.id
|
||||||
|
GROUP BY g.id
|
||||||
|
`
|
||||||
|
|
||||||
|
rows, err := r.sql.QueryContext(ctx, query, todayStart)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
var results []usagestats.GroupUsageSummary
|
||||||
|
for rows.Next() {
|
||||||
|
var row usagestats.GroupUsageSummary
|
||||||
|
if err := rows.Scan(&row.GroupID, &row.TotalCost, &row.TodayCost); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
results = append(results, row)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveModelDimensionExpression maps model source type to a safe SQL expression.
|
||||||
|
func resolveModelDimensionExpression(modelType string) string {
|
||||||
|
switch usagestats.NormalizeModelSource(modelType) {
|
||||||
|
case usagestats.ModelSourceUpstream:
|
||||||
|
return "COALESCE(NULLIF(TRIM(upstream_model), ''), model)"
|
||||||
|
case usagestats.ModelSourceMapping:
|
||||||
|
return "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))"
|
||||||
|
default:
|
||||||
|
return "model"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveEndpointColumn maps endpoint type to the corresponding DB column name.
|
||||||
|
func resolveEndpointColumn(endpointType string) string {
|
||||||
|
switch endpointType {
|
||||||
|
case "upstream":
|
||||||
|
return "ul.upstream_endpoint"
|
||||||
|
case "path":
|
||||||
|
return "ul.inbound_endpoint || ' -> ' || ul.upstream_endpoint"
|
||||||
|
default:
|
||||||
|
return "ul.inbound_endpoint"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// GetGlobalStats gets usage statistics for all users within a time range
|
// GetGlobalStats gets usage statistics for all users within a time range
|
||||||
func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) {
|
func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) {
|
||||||
query := `
|
query := `
|
||||||
@@ -3740,6 +3888,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
|||||||
accountID int64
|
accountID int64
|
||||||
requestID sql.NullString
|
requestID sql.NullString
|
||||||
model string
|
model string
|
||||||
|
upstreamModel sql.NullString
|
||||||
groupID sql.NullInt64
|
groupID sql.NullInt64
|
||||||
subscriptionID sql.NullInt64
|
subscriptionID sql.NullInt64
|
||||||
inputTokens int
|
inputTokens int
|
||||||
@@ -3782,6 +3931,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
|||||||
&accountID,
|
&accountID,
|
||||||
&requestID,
|
&requestID,
|
||||||
&model,
|
&model,
|
||||||
|
&upstreamModel,
|
||||||
&groupID,
|
&groupID,
|
||||||
&subscriptionID,
|
&subscriptionID,
|
||||||
&inputTokens,
|
&inputTokens,
|
||||||
@@ -3894,6 +4044,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
|||||||
if upstreamEndpoint.Valid {
|
if upstreamEndpoint.Valid {
|
||||||
log.UpstreamEndpoint = &upstreamEndpoint.String
|
log.UpstreamEndpoint = &upstreamEndpoint.String
|
||||||
}
|
}
|
||||||
|
if upstreamModel.Valid {
|
||||||
|
log.UpstreamModel = &upstreamModel.String
|
||||||
|
}
|
||||||
|
|
||||||
return log, nil
|
return log, nil
|
||||||
}
|
}
|
||||||
|
|||||||
50
backend/internal/repository/usage_log_repo_breakdown_test.go
Normal file
50
backend/internal/repository/usage_log_repo_breakdown_test.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestResolveEndpointColumn(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
endpointType string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"inbound", "ul.inbound_endpoint"},
|
||||||
|
{"upstream", "ul.upstream_endpoint"},
|
||||||
|
{"path", "ul.inbound_endpoint || ' -> ' || ul.upstream_endpoint"},
|
||||||
|
{"", "ul.inbound_endpoint"}, // default
|
||||||
|
{"unknown", "ul.inbound_endpoint"}, // fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.endpointType, func(t *testing.T) {
|
||||||
|
got := resolveEndpointColumn(tc.endpointType)
|
||||||
|
require.Equal(t, tc.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveModelDimensionExpression(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
modelType string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{usagestats.ModelSourceRequested, "model"},
|
||||||
|
{usagestats.ModelSourceUpstream, "COALESCE(NULLIF(TRIM(upstream_model), ''), model)"},
|
||||||
|
{usagestats.ModelSourceMapping, "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))"},
|
||||||
|
{"", "model"},
|
||||||
|
{"invalid", "model"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.modelType, func(t *testing.T) {
|
||||||
|
got := resolveModelDimensionExpression(tc.modelType)
|
||||||
|
require.Equal(t, tc.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -44,6 +44,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
|||||||
log.AccountID,
|
log.AccountID,
|
||||||
log.RequestID,
|
log.RequestID,
|
||||||
log.Model,
|
log.Model,
|
||||||
|
sqlmock.AnyArg(), // upstream_model
|
||||||
sqlmock.AnyArg(), // group_id
|
sqlmock.AnyArg(), // group_id
|
||||||
sqlmock.AnyArg(), // subscription_id
|
sqlmock.AnyArg(), // subscription_id
|
||||||
log.InputTokens,
|
log.InputTokens,
|
||||||
@@ -116,6 +117,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
|
|||||||
log.Model,
|
log.Model,
|
||||||
sqlmock.AnyArg(),
|
sqlmock.AnyArg(),
|
||||||
sqlmock.AnyArg(),
|
sqlmock.AnyArg(),
|
||||||
|
sqlmock.AnyArg(),
|
||||||
log.InputTokens,
|
log.InputTokens,
|
||||||
log.OutputTokens,
|
log.OutputTokens,
|
||||||
log.CacheCreationTokens,
|
log.CacheCreationTokens,
|
||||||
@@ -353,6 +355,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
|||||||
int64(30), // account_id
|
int64(30), // account_id
|
||||||
sql.NullString{Valid: true, String: "req-1"},
|
sql.NullString{Valid: true, String: "req-1"},
|
||||||
"gpt-5", // model
|
"gpt-5", // model
|
||||||
|
sql.NullString{}, // upstream_model
|
||||||
sql.NullInt64{}, // group_id
|
sql.NullInt64{}, // group_id
|
||||||
sql.NullInt64{}, // subscription_id
|
sql.NullInt64{}, // subscription_id
|
||||||
1, // input_tokens
|
1, // input_tokens
|
||||||
@@ -404,6 +407,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
|||||||
int64(31),
|
int64(31),
|
||||||
sql.NullString{Valid: true, String: "req-2"},
|
sql.NullString{Valid: true, String: "req-2"},
|
||||||
"gpt-5",
|
"gpt-5",
|
||||||
|
sql.NullString{},
|
||||||
sql.NullInt64{},
|
sql.NullInt64{},
|
||||||
sql.NullInt64{},
|
sql.NullInt64{},
|
||||||
1, 2, 3, 4, 5, 6,
|
1, 2, 3, 4, 5, 6,
|
||||||
@@ -445,6 +449,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
|||||||
int64(32),
|
int64(32),
|
||||||
sql.NullString{Valid: true, String: "req-3"},
|
sql.NullString{Valid: true, String: "req-3"},
|
||||||
"gpt-5.4",
|
"gpt-5.4",
|
||||||
|
sql.NullString{},
|
||||||
sql.NullInt64{},
|
sql.NullInt64{},
|
||||||
sql.NullInt64{},
|
sql.NullInt64{},
|
||||||
1, 2, 3, 4, 5, 6,
|
1, 2, 3, 4, 5, 6,
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
@@ -190,7 +191,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
|
|||||||
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
|
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||||
client := clientFromContext(ctx, r.client)
|
client := clientFromContext(ctx, r.client)
|
||||||
q := client.UserSubscription.Query()
|
q := client.UserSubscription.Query()
|
||||||
if userID != nil {
|
if userID != nil {
|
||||||
@@ -199,6 +200,9 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
|
|||||||
if groupID != nil {
|
if groupID != nil {
|
||||||
q = q.Where(usersubscription.GroupIDEQ(*groupID))
|
q = q.Where(usersubscription.GroupIDEQ(*groupID))
|
||||||
}
|
}
|
||||||
|
if platform != "" {
|
||||||
|
q = q.Where(usersubscription.HasGroupWith(group.PlatformEQ(platform)))
|
||||||
|
}
|
||||||
|
|
||||||
// Status filtering with real-time expiration check
|
// Status filtering with real-time expiration check
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|||||||
@@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
|
|||||||
group := s.mustCreateGroup("g-list")
|
group := s.mustCreateGroup("g-list")
|
||||||
s.mustCreateSubscription(user.ID, group.ID, nil)
|
s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||||
|
|
||||||
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "")
|
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "", "")
|
||||||
s.Require().NoError(err, "List")
|
s.Require().NoError(err, "List")
|
||||||
s.Require().Len(subs, 1)
|
s.Require().Len(subs, 1)
|
||||||
s.Require().Equal(int64(1), page.Total)
|
s.Require().Equal(int64(1), page.Total)
|
||||||
@@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
|
|||||||
s.mustCreateSubscription(user1.ID, group.ID, nil)
|
s.mustCreateSubscription(user1.ID, group.ID, nil)
|
||||||
s.mustCreateSubscription(user2.ID, group.ID, nil)
|
s.mustCreateSubscription(user2.ID, group.ID, nil)
|
||||||
|
|
||||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "")
|
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "", "")
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Len(subs, 1)
|
s.Require().Len(subs, 1)
|
||||||
s.Require().Equal(user1.ID, subs[0].UserID)
|
s.Require().Equal(user1.ID, subs[0].UserID)
|
||||||
@@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
|
|||||||
s.mustCreateSubscription(user.ID, g1.ID, nil)
|
s.mustCreateSubscription(user.ID, g1.ID, nil)
|
||||||
s.mustCreateSubscription(user.ID, g2.ID, nil)
|
s.mustCreateSubscription(user.ID, g2.ID, nil)
|
||||||
|
|
||||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "")
|
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "", "")
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Len(subs, 1)
|
s.Require().Len(subs, 1)
|
||||||
s.Require().Equal(g1.ID, subs[0].GroupID)
|
s.Require().Equal(g1.ID, subs[0].GroupID)
|
||||||
@@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
|
|||||||
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
|
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
|
||||||
})
|
})
|
||||||
|
|
||||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "")
|
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "", "")
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Len(subs, 1)
|
s.Require().Len(subs, 1)
|
||||||
s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)
|
s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)
|
||||||
|
|||||||
@@ -924,8 +924,8 @@ func (stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error
|
|||||||
return false, errors.New("not implemented")
|
return false, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
func (stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
|
||||||
return 0, errors.New("not implemented")
|
return 0, 0, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||||
@@ -1289,7 +1289,7 @@ func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userI
|
|||||||
func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, errors.New("not implemented")
|
return nil, nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, errors.New("not implemented")
|
return nil, nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
||||||
@@ -1637,6 +1637,10 @@ func (r *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTi
|
|||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *stubUsageLogRepo) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (r *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
func (r *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
@@ -1782,6 +1786,9 @@ func (r *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID i
|
|||||||
func (r *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
|
func (r *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
func (r *stubUsageLogRepo) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
type stubSettingRepo struct {
|
type stubSettingRepo struct {
|
||||||
all map[string]string
|
all map[string]string
|
||||||
|
|||||||
@@ -135,7 +135,7 @@ func (f fakeGoogleSubscriptionRepo) ListActiveByUserID(ctx context.Context, user
|
|||||||
func (f fakeGoogleSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
func (f fakeGoogleSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, errors.New("not implemented")
|
return nil, nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, errors.New("not implemented")
|
return nil, nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
func (f fakeGoogleSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
func (f fakeGoogleSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
||||||
|
|||||||
@@ -646,7 +646,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in
|
|||||||
return nil, nil, errors.New("not implemented")
|
return nil, nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, errors.New("not implemented")
|
return nil, nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -198,6 +198,7 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
dashboard.GET("/users-ranking", h.Admin.Dashboard.GetUserSpendingRanking)
|
dashboard.GET("/users-ranking", h.Admin.Dashboard.GetUserSpendingRanking)
|
||||||
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
|
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
|
||||||
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
|
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
|
||||||
|
dashboard.GET("/user-breakdown", h.Admin.Dashboard.GetUserBreakdown)
|
||||||
dashboard.POST("/aggregation/backfill", h.Admin.Dashboard.BackfillAggregation)
|
dashboard.POST("/aggregation/backfill", h.Admin.Dashboard.BackfillAggregation)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -226,6 +227,8 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
{
|
{
|
||||||
groups.GET("", h.Admin.Group.List)
|
groups.GET("", h.Admin.Group.List)
|
||||||
groups.GET("/all", h.Admin.Group.GetAll)
|
groups.GET("/all", h.Admin.Group.GetAll)
|
||||||
|
groups.GET("/usage-summary", h.Admin.Group.GetUsageSummary)
|
||||||
|
groups.GET("/capacity-summary", h.Admin.Group.GetCapacitySummary)
|
||||||
groups.PUT("/sort-order", h.Admin.Group.UpdateSortOrder)
|
groups.PUT("/sort-order", h.Admin.Group.UpdateSortOrder)
|
||||||
groups.GET("/:id", h.Admin.Group.GetByID)
|
groups.GET("/:id", h.Admin.Group.GetByID)
|
||||||
groups.POST("", h.Admin.Group.Create)
|
groups.POST("", h.Admin.Group.Create)
|
||||||
@@ -399,6 +402,9 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey)
|
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey)
|
||||||
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey)
|
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey)
|
||||||
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminAPIKey)
|
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminAPIKey)
|
||||||
|
// 529过载冷却配置
|
||||||
|
adminSettings.GET("/overload-cooldown", h.Admin.Setting.GetOverloadCooldownSettings)
|
||||||
|
adminSettings.PUT("/overload-cooldown", h.Admin.Setting.UpdateOverloadCooldownSettings)
|
||||||
// 流超时处理配置
|
// 流超时处理配置
|
||||||
adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings)
|
adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings)
|
||||||
adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings)
|
adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings)
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ var (
|
|||||||
ErrAccountNilInput = infraerrors.BadRequest("ACCOUNT_NIL_INPUT", "account input cannot be nil")
|
ErrAccountNilInput = infraerrors.BadRequest("ACCOUNT_NIL_INPUT", "account input cannot be nil")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const AccountListGroupUngrouped int64 = -1
|
||||||
|
|
||||||
type AccountRepository interface {
|
type AccountRepository interface {
|
||||||
Create(ctx context.Context, account *Account) error
|
Create(ctx context.Context, account *Account) error
|
||||||
GetByID(ctx context.Context, id int64) (*Account, error)
|
GetByID(ctx context.Context, id int64) (*Account, error)
|
||||||
|
|||||||
@@ -113,15 +113,18 @@ func (s *AccountTestService) validateUpstreamBaseURL(raw string) (string, error)
|
|||||||
return normalized, nil
|
return normalized, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateSessionString generates a Claude Code style session string
|
// generateSessionString generates a Claude Code style session string.
|
||||||
|
// The output format is determined by the UA version in claude.DefaultHeaders,
|
||||||
|
// ensuring consistency between the user_id format and the UA sent to upstream.
|
||||||
func generateSessionString() (string, error) {
|
func generateSessionString() (string, error) {
|
||||||
bytes := make([]byte, 32)
|
b := make([]byte, 32)
|
||||||
if _, err := rand.Read(bytes); err != nil {
|
if _, err := rand.Read(b); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
hex64 := hex.EncodeToString(bytes)
|
hex64 := hex.EncodeToString(b)
|
||||||
sessionUUID := uuid.New().String()
|
sessionUUID := uuid.New().String()
|
||||||
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID), nil
|
uaVersion := ExtractCLIVersion(claude.DefaultHeaders["User-Agent"])
|
||||||
|
return FormatMetadataUserID(hex64, "", sessionUUID, uaVersion), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// createTestPayload creates a Claude Code style test request payload
|
// createTestPayload creates a Claude Code style test request payload
|
||||||
@@ -305,7 +308,14 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
|||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
|
errMsg := fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))
|
||||||
|
|
||||||
|
// 403 表示账号被上游封禁,标记为 error 状态
|
||||||
|
if resp.StatusCode == http.StatusForbidden {
|
||||||
|
_ = s.accountRepo.SetError(ctx, account.ID, errMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.sendErrorAndEnd(c, errMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process SSE stream
|
// Process SSE stream
|
||||||
|
|||||||
@@ -48,6 +48,8 @@ type UsageLogRepository interface {
|
|||||||
GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
|
GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
|
||||||
GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
|
GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
|
||||||
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
|
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
|
||||||
|
GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error)
|
||||||
|
GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error)
|
||||||
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
||||||
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
||||||
GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error)
|
GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error)
|
||||||
@@ -175,6 +177,7 @@ type AICredit struct {
|
|||||||
|
|
||||||
// UsageInfo 账号使用量信息
|
// UsageInfo 账号使用量信息
|
||||||
type UsageInfo struct {
|
type UsageInfo struct {
|
||||||
|
Source string `json:"source,omitempty"` // "passive" or "active"
|
||||||
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
|
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
|
||||||
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
|
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
|
||||||
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
|
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
|
||||||
@@ -391,6 +394,9 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
|||||||
// 4. 添加窗口统计(有独立缓存,1 分钟)
|
// 4. 添加窗口统计(有独立缓存,1 分钟)
|
||||||
s.addWindowStats(ctx, account, usage)
|
s.addWindowStats(ctx, account, usage)
|
||||||
|
|
||||||
|
// 5. 将主动查询结果同步到被动缓存,下次 passive 加载即为最新值
|
||||||
|
s.syncActiveToPassive(ctx, account.ID, usage)
|
||||||
|
|
||||||
s.tryClearRecoverableAccountError(ctx, account)
|
s.tryClearRecoverableAccountError(ctx, account)
|
||||||
return usage, nil
|
return usage, nil
|
||||||
}
|
}
|
||||||
@@ -407,6 +413,81 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
|||||||
return nil, fmt.Errorf("account type %s does not support usage query", account.Type)
|
return nil, fmt.Errorf("account type %s does not support usage query", account.Type)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPassiveUsage 从 Account.Extra 中的被动采样数据构建 UsageInfo,不调用外部 API。
|
||||||
|
// 仅适用于 Anthropic OAuth / SetupToken 账号。
|
||||||
|
func (s *AccountUsageService) GetPassiveUsage(ctx context.Context, accountID int64) (*UsageInfo, error) {
|
||||||
|
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get account failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !account.IsAnthropicOAuthOrSetupToken() {
|
||||||
|
return nil, fmt.Errorf("passive usage only supported for Anthropic OAuth/SetupToken accounts")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 复用 estimateSetupTokenUsage 构建 5h 窗口(OAuth 和 SetupToken 逻辑一致)
|
||||||
|
info := s.estimateSetupTokenUsage(account)
|
||||||
|
info.Source = "passive"
|
||||||
|
|
||||||
|
// 设置采样时间
|
||||||
|
if raw, ok := account.Extra["passive_usage_sampled_at"]; ok {
|
||||||
|
if str, ok := raw.(string); ok {
|
||||||
|
if t, err := time.Parse(time.RFC3339, str); err == nil {
|
||||||
|
info.UpdatedAt = &t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建 7d 窗口(从被动采样数据)
|
||||||
|
util7d := parseExtraFloat64(account.Extra["passive_usage_7d_utilization"])
|
||||||
|
reset7dRaw := parseExtraFloat64(account.Extra["passive_usage_7d_reset"])
|
||||||
|
if util7d > 0 || reset7dRaw > 0 {
|
||||||
|
var resetAt *time.Time
|
||||||
|
var remaining int
|
||||||
|
if reset7dRaw > 0 {
|
||||||
|
t := time.Unix(int64(reset7dRaw), 0)
|
||||||
|
resetAt = &t
|
||||||
|
remaining = int(time.Until(t).Seconds())
|
||||||
|
if remaining < 0 {
|
||||||
|
remaining = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
info.SevenDay = &UsageProgress{
|
||||||
|
Utilization: util7d * 100,
|
||||||
|
ResetsAt: resetAt,
|
||||||
|
RemainingSeconds: remaining,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加窗口统计
|
||||||
|
s.addWindowStats(ctx, account, info)
|
||||||
|
|
||||||
|
return info, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// syncActiveToPassive 将主动查询的最新数据回写到 Extra 被动缓存,
|
||||||
|
// 这样下次被动加载时能看到最新值。
|
||||||
|
func (s *AccountUsageService) syncActiveToPassive(ctx context.Context, accountID int64, usage *UsageInfo) {
|
||||||
|
extraUpdates := make(map[string]any, 4)
|
||||||
|
|
||||||
|
if usage.FiveHour != nil {
|
||||||
|
extraUpdates["session_window_utilization"] = usage.FiveHour.Utilization / 100
|
||||||
|
}
|
||||||
|
if usage.SevenDay != nil {
|
||||||
|
extraUpdates["passive_usage_7d_utilization"] = usage.SevenDay.Utilization / 100
|
||||||
|
if usage.SevenDay.ResetsAt != nil {
|
||||||
|
extraUpdates["passive_usage_7d_reset"] = usage.SevenDay.ResetsAt.Unix()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(extraUpdates) > 0 {
|
||||||
|
extraUpdates["passive_usage_sampled_at"] = time.Now().UTC().Format(time.RFC3339)
|
||||||
|
if err := s.accountRepo.UpdateExtra(ctx, accountID, extraUpdates); err != nil {
|
||||||
|
slog.Warn("sync_active_to_passive_failed", "account_id", accountID, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
|
func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
usage := &UsageInfo{UpdatedAt: &now}
|
usage := &UsageInfo{UpdatedAt: &now}
|
||||||
@@ -446,23 +527,17 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
|
|||||||
}
|
}
|
||||||
|
|
||||||
if stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, now.Add(-5*time.Hour)); err == nil {
|
if stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, now.Add(-5*time.Hour)); err == nil {
|
||||||
windowStats := windowStatsFromAccountStats(stats)
|
if usage.FiveHour == nil {
|
||||||
if hasMeaningfulWindowStats(windowStats) {
|
usage.FiveHour = &UsageProgress{Utilization: 0}
|
||||||
if usage.FiveHour == nil {
|
|
||||||
usage.FiveHour = &UsageProgress{Utilization: 0}
|
|
||||||
}
|
|
||||||
usage.FiveHour.WindowStats = windowStats
|
|
||||||
}
|
}
|
||||||
|
usage.FiveHour.WindowStats = windowStatsFromAccountStats(stats)
|
||||||
}
|
}
|
||||||
|
|
||||||
if stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, now.Add(-7*24*time.Hour)); err == nil {
|
if stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, now.Add(-7*24*time.Hour)); err == nil {
|
||||||
windowStats := windowStatsFromAccountStats(stats)
|
if usage.SevenDay == nil {
|
||||||
if hasMeaningfulWindowStats(windowStats) {
|
usage.SevenDay = &UsageProgress{Utilization: 0}
|
||||||
if usage.SevenDay == nil {
|
|
||||||
usage.SevenDay = &UsageProgress{Utilization: 0}
|
|
||||||
}
|
|
||||||
usage.SevenDay.WindowStats = windowStats
|
|
||||||
}
|
}
|
||||||
|
usage.SevenDay.WindowStats = windowStatsFromAccountStats(stats)
|
||||||
}
|
}
|
||||||
|
|
||||||
return usage, nil
|
return usage, nil
|
||||||
@@ -992,13 +1067,6 @@ func windowStatsFromAccountStats(stats *usagestats.AccountStats) *WindowStats {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func hasMeaningfulWindowStats(stats *WindowStats) bool {
|
|
||||||
if stats == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return stats.Requests > 0 || stats.Tokens > 0 || stats.Cost > 0 || stats.StandardCost > 0 || stats.UserCost > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildCodexUsageProgressFromExtra(extra map[string]any, window string, now time.Time) *UsageProgress {
|
func buildCodexUsageProgressFromExtra(extra map[string]any, window string, now time.Time) *UsageProgress {
|
||||||
if len(extra) == 0 {
|
if len(extra) == 0 {
|
||||||
return nil
|
return nil
|
||||||
@@ -1055,6 +1123,11 @@ func buildCodexUsageProgressFromExtra(extra map[string]any, window string, now t
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 窗口已过期(resetAt 在 now 之前)→ 额度已重置,归零
|
||||||
|
if progress.ResetsAt != nil && !now.Before(*progress.ResetsAt) {
|
||||||
|
progress.Utilization = 0
|
||||||
|
}
|
||||||
|
|
||||||
return progress
|
return progress
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -148,3 +148,54 @@ func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *tes
|
|||||||
t.Fatal("waiting for codex probe rate limit persistence timed out")
|
t.Fatal("waiting for codex probe rate limit persistence timed out")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBuildCodexUsageProgressFromExtra_ZerosExpiredWindow(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
now := time.Date(2026, 3, 16, 12, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
t.Run("expired 5h window zeroes utilization", func(t *testing.T) {
|
||||||
|
extra := map[string]any{
|
||||||
|
"codex_5h_used_percent": 42.0,
|
||||||
|
"codex_5h_reset_at": "2026-03-16T10:00:00Z", // 2h ago
|
||||||
|
}
|
||||||
|
progress := buildCodexUsageProgressFromExtra(extra, "5h", now)
|
||||||
|
if progress == nil {
|
||||||
|
t.Fatal("expected non-nil progress")
|
||||||
|
}
|
||||||
|
if progress.Utilization != 0 {
|
||||||
|
t.Fatalf("expected Utilization=0 for expired window, got %v", progress.Utilization)
|
||||||
|
}
|
||||||
|
if progress.RemainingSeconds != 0 {
|
||||||
|
t.Fatalf("expected RemainingSeconds=0, got %v", progress.RemainingSeconds)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("active 5h window keeps utilization", func(t *testing.T) {
|
||||||
|
resetAt := now.Add(2 * time.Hour).Format(time.RFC3339)
|
||||||
|
extra := map[string]any{
|
||||||
|
"codex_5h_used_percent": 42.0,
|
||||||
|
"codex_5h_reset_at": resetAt,
|
||||||
|
}
|
||||||
|
progress := buildCodexUsageProgressFromExtra(extra, "5h", now)
|
||||||
|
if progress == nil {
|
||||||
|
t.Fatal("expected non-nil progress")
|
||||||
|
}
|
||||||
|
if progress.Utilization != 42.0 {
|
||||||
|
t.Fatalf("expected Utilization=42, got %v", progress.Utilization)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("expired 7d window zeroes utilization", func(t *testing.T) {
|
||||||
|
extra := map[string]any{
|
||||||
|
"codex_7d_used_percent": 88.0,
|
||||||
|
"codex_7d_reset_at": "2026-03-15T00:00:00Z", // yesterday
|
||||||
|
}
|
||||||
|
progress := buildCodexUsageProgressFromExtra(extra, "7d", now)
|
||||||
|
if progress == nil {
|
||||||
|
t.Fatal("expected non-nil progress")
|
||||||
|
}
|
||||||
|
if progress.Utilization != 0 {
|
||||||
|
t.Fatalf("expected Utilization=0 for expired 7d window, got %v", progress.Utilization)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -1530,7 +1530,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
|||||||
if len(input.Credentials) > 0 {
|
if len(input.Credentials) > 0 {
|
||||||
account.Credentials = input.Credentials
|
account.Credentials = input.Credentials
|
||||||
}
|
}
|
||||||
if len(input.Extra) > 0 {
|
// Extra 使用 map:需要区分“未提供(nil)”与“显式清空({})”。
|
||||||
|
// 关闭配额限制时前端会删除 quota_* 键并提交 extra:{},此时也必须落库。
|
||||||
|
if input.Extra != nil {
|
||||||
// 保留配额用量字段,防止编辑账号时意外重置
|
// 保留配额用量字段,防止编辑账号时意外重置
|
||||||
for _, key := range []string{"quota_used", "quota_daily_used", "quota_daily_start", "quota_weekly_used", "quota_weekly_start"} {
|
for _, key := range []string{"quota_used", "quota_daily_used", "quota_daily_start", "quota_weekly_used", "quota_weekly_start"} {
|
||||||
if v, ok := account.Extra[key]; ok {
|
if v, ok := account.Extra[key]; ok {
|
||||||
|
|||||||
@@ -194,7 +194,7 @@ func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, stri
|
|||||||
func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) {
|
func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) {
|
||||||
panic("unexpected")
|
panic("unexpected")
|
||||||
}
|
}
|
||||||
func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, error) {
|
func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, int64, error) {
|
||||||
panic("unexpected")
|
panic("unexpected")
|
||||||
}
|
}
|
||||||
func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
|
func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
|
||||||
|
|||||||
@@ -160,7 +160,7 @@ func (s *groupRepoStub) ExistsByName(ctx context.Context, name string) (bool, er
|
|||||||
panic("unexpected ExistsByName call")
|
panic("unexpected ExistsByName call")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *groupRepoStub) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
func (s *groupRepoStub) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
|
||||||
panic("unexpected GetAccountCount call")
|
panic("unexpected GetAccountCount call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ func (s *groupRepoStubForAdmin) ExistsByName(_ context.Context, _ string) (bool,
|
|||||||
panic("unexpected ExistsByName call")
|
panic("unexpected ExistsByName call")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, error) {
|
func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
|
||||||
panic("unexpected GetAccountCount call")
|
panic("unexpected GetAccountCount call")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -383,7 +383,7 @@ func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string
|
|||||||
panic("unexpected ExistsByName call")
|
panic("unexpected ExistsByName call")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, error) {
|
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
|
||||||
panic("unexpected GetAccountCount call")
|
panic("unexpected GetAccountCount call")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -458,7 +458,7 @@ func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context,
|
|||||||
panic("unexpected ExistsByName call")
|
panic("unexpected ExistsByName call")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, error) {
|
func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
|
||||||
panic("unexpected GetAccountCount call")
|
panic("unexpected GetAccountCount call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -121,3 +121,35 @@ func TestUpdateAccount_EnableOveragesClearsModelRateLimitsBeforePersist(t *testi
|
|||||||
_, exists := repo.account.Extra[modelRateLimitsKey]
|
_, exists := repo.account.Extra[modelRateLimitsKey]
|
||||||
require.False(t, exists, "开启 overages 时应在持久化前清掉旧模型限流")
|
require.False(t, exists, "开启 overages 时应在持久化前清掉旧模型限流")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpdateAccount_EmptyExtraPayloadCanClearQuotaLimits(t *testing.T) {
|
||||||
|
accountID := int64(103)
|
||||||
|
repo := &updateAccountOveragesRepoStub{
|
||||||
|
account: &Account{
|
||||||
|
ID: accountID,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"quota_limit": 100.0,
|
||||||
|
"quota_daily_limit": 10.0,
|
||||||
|
"quota_weekly_limit": 40.0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &adminServiceImpl{accountRepo: repo}
|
||||||
|
updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{
|
||||||
|
// 显式空对象:语义是“清空 extra 中的可配置键”(例如关闭配额限制)
|
||||||
|
Extra: map[string]any{},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, updated)
|
||||||
|
require.Equal(t, 1, repo.updateCalls)
|
||||||
|
require.NotNil(t, repo.account.Extra)
|
||||||
|
require.NotContains(t, repo.account.Extra, "quota_limit")
|
||||||
|
require.NotContains(t, repo.account.Extra, "quota_daily_limit")
|
||||||
|
require.NotContains(t, repo.account.Extra, "quota_weekly_limit")
|
||||||
|
require.Len(t, repo.account.Extra, 0)
|
||||||
|
}
|
||||||
|
|||||||
@@ -930,7 +930,7 @@ func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParam
|
|||||||
case ErrorPolicyTempUnscheduled:
|
case ErrorPolicyTempUnscheduled:
|
||||||
slog.Info("temp_unschedulable_matched",
|
slog.Info("temp_unschedulable_matched",
|
||||||
"prefix", p.prefix, "status_code", statusCode, "account_id", p.account.ID)
|
"prefix", p.prefix, "status_code", statusCode, "account_id", p.account.ID)
|
||||||
return true, statusCode, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession}
|
return true, statusCode, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, RateLimitedModel: p.requestedModel, IsStickySession: p.isStickySession}
|
||||||
}
|
}
|
||||||
return false, statusCode, nil
|
return false, statusCode, nil
|
||||||
}
|
}
|
||||||
@@ -1001,8 +1001,9 @@ type TestConnectionResult struct {
|
|||||||
MappedModel string // 实际使用的模型
|
MappedModel string // 实际使用的模型
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费)
|
// TestConnection 测试 Antigravity 账号连接。
|
||||||
// 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择
|
// 复用 antigravityRetryLoop 的完整重试 / credits overages / 智能重试逻辑,
|
||||||
|
// 与真实调度行为一致。差异:不做账号切换(测试指定账号)、不记录 ops 错误。
|
||||||
func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) {
|
func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) {
|
||||||
|
|
||||||
// 获取 token
|
// 获取 token
|
||||||
@@ -1026,10 +1027,8 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
|
|||||||
// 构建请求体
|
// 构建请求体
|
||||||
var requestBody []byte
|
var requestBody []byte
|
||||||
if strings.HasPrefix(modelID, "gemini-") {
|
if strings.HasPrefix(modelID, "gemini-") {
|
||||||
// Gemini 模型:直接使用 Gemini 格式
|
|
||||||
requestBody, err = s.buildGeminiTestRequest(projectID, mappedModel)
|
requestBody, err = s.buildGeminiTestRequest(projectID, mappedModel)
|
||||||
} else {
|
} else {
|
||||||
// Claude 模型:使用协议转换
|
|
||||||
requestBody, err = s.buildClaudeTestRequest(projectID, mappedModel)
|
requestBody, err = s.buildClaudeTestRequest(projectID, mappedModel)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1042,64 +1041,63 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
|
|||||||
proxyURL = account.Proxy.URL()
|
proxyURL = account.Proxy.URL()
|
||||||
}
|
}
|
||||||
|
|
||||||
baseURL := resolveAntigravityForwardBaseURL()
|
// 复用 antigravityRetryLoop:完整的重试 / credits overages / 智能重试
|
||||||
if baseURL == "" {
|
prefix := fmt.Sprintf("[antigravity-Test] account=%d(%s)", account.ID, account.Name)
|
||||||
return nil, errors.New("no antigravity forward base url configured")
|
p := antigravityRetryLoopParams{
|
||||||
}
|
ctx: ctx,
|
||||||
availableURLs := []string{baseURL}
|
prefix: prefix,
|
||||||
|
account: account,
|
||||||
var lastErr error
|
proxyURL: proxyURL,
|
||||||
for urlIdx, baseURL := range availableURLs {
|
accessToken: accessToken,
|
||||||
// 构建 HTTP 请求(总是使用流式 endpoint,与官方客户端一致)
|
action: "streamGenerateContent",
|
||||||
req, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, "streamGenerateContent", accessToken, requestBody)
|
body: requestBody,
|
||||||
if err != nil {
|
c: nil, // 无 gin.Context → 跳过 ops 追踪
|
||||||
lastErr = err
|
httpUpstream: s.httpUpstream,
|
||||||
continue
|
settingService: s.settingService,
|
||||||
}
|
accountRepo: s.accountRepo,
|
||||||
|
requestedModel: modelID,
|
||||||
// 调试日志:Test 请求信息
|
handleError: testConnectionHandleError,
|
||||||
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String())
|
|
||||||
|
|
||||||
// 发送请求
|
|
||||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
|
||||||
if err != nil {
|
|
||||||
lastErr = fmt.Errorf("请求失败: %w", err)
|
|
||||||
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
|
||||||
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return nil, lastErr
|
|
||||||
}
|
|
||||||
|
|
||||||
// 读取响应
|
|
||||||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
|
||||||
_ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查是否需要 URL 降级
|
|
||||||
if shouldAntigravityFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
|
|
||||||
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode >= 400 {
|
|
||||||
return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析流式响应,提取文本
|
|
||||||
text := extractTextFromSSEResponse(respBody)
|
|
||||||
|
|
||||||
// 标记成功的 URL,下次优先使用
|
|
||||||
antigravity.DefaultURLAvailability.MarkSuccess(baseURL)
|
|
||||||
return &TestConnectionResult{
|
|
||||||
Text: text,
|
|
||||||
MappedModel: mappedModel,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, lastErr
|
result, err := s.antigravityRetryLoop(p)
|
||||||
|
if err != nil {
|
||||||
|
// AccountSwitchError → 测试时不切换账号,返回友好提示
|
||||||
|
var switchErr *AntigravityAccountSwitchError
|
||||||
|
if errors.As(err, &switchErr) {
|
||||||
|
return nil, fmt.Errorf("该账号模型 %s 当前限流中,请稍后重试", switchErr.RateLimitedModel)
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if result == nil || result.resp == nil {
|
||||||
|
return nil, errors.New("upstream returned empty response")
|
||||||
|
}
|
||||||
|
defer func() { _ = result.resp.Body.Close() }()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(io.LimitReader(result.resp.Body, 2<<20))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.resp.StatusCode >= 400 {
|
||||||
|
return nil, fmt.Errorf("API 返回 %d: %s", result.resp.StatusCode, string(respBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
text := extractTextFromSSEResponse(respBody)
|
||||||
|
return &TestConnectionResult{Text: text, MappedModel: mappedModel}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// testConnectionHandleError 是 TestConnection 使用的轻量 handleError 回调。
|
||||||
|
// 仅记录日志,不做 ops 错误追踪或粘性会话清除。
|
||||||
|
func testConnectionHandleError(
|
||||||
|
_ context.Context, prefix string, account *Account,
|
||||||
|
statusCode int, _ http.Header, body []byte,
|
||||||
|
requestedModel string, _ int64, _ string, _ bool,
|
||||||
|
) *handleModelRateLimitResult {
|
||||||
|
logger.LegacyPrintf("service.antigravity_gateway",
|
||||||
|
"%s test_handle_error status=%d model=%s account=%d body=%s",
|
||||||
|
prefix, statusCode, requestedModel, account.ID, truncateForLog(body, 200))
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildGeminiTestRequest 构建 Gemini 格式测试请求
|
// buildGeminiTestRequest 构建 Gemini 格式测试请求
|
||||||
@@ -3079,6 +3077,22 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
|||||||
intervalCh = intervalTicker.C
|
intervalCh = intervalTicker.C
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 下游 keepalive:防止代理/Cloudflare Tunnel 因连接空闲而断开
|
||||||
|
keepaliveInterval := time.Duration(0)
|
||||||
|
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamKeepaliveInterval > 0 {
|
||||||
|
keepaliveInterval = time.Duration(s.settingService.cfg.Gateway.StreamKeepaliveInterval) * time.Second
|
||||||
|
}
|
||||||
|
var keepaliveTicker *time.Ticker
|
||||||
|
if keepaliveInterval > 0 {
|
||||||
|
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||||
|
defer keepaliveTicker.Stop()
|
||||||
|
}
|
||||||
|
var keepaliveCh <-chan time.Time
|
||||||
|
if keepaliveTicker != nil {
|
||||||
|
keepaliveCh = keepaliveTicker.C
|
||||||
|
}
|
||||||
|
lastDataAt := time.Now()
|
||||||
|
|
||||||
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity gemini")
|
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity gemini")
|
||||||
|
|
||||||
// 仅发送一次错误事件,避免多次写入导致协议混乱
|
// 仅发送一次错误事件,避免多次写入导致协议混乱
|
||||||
@@ -3111,6 +3125,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
|||||||
return nil, ev.err
|
return nil, ev.err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
lastDataAt = time.Now()
|
||||||
|
|
||||||
line := ev.line
|
line := ev.line
|
||||||
trimmed := strings.TrimRight(line, "\r\n")
|
trimmed := strings.TrimRight(line, "\r\n")
|
||||||
if strings.HasPrefix(trimmed, "data:") {
|
if strings.HasPrefix(trimmed, "data:") {
|
||||||
@@ -3170,6 +3186,19 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
|||||||
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)")
|
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)")
|
||||||
sendErrorEvent("stream_timeout")
|
sendErrorEvent("stream_timeout")
|
||||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||||
|
|
||||||
|
case <-keepaliveCh:
|
||||||
|
if cw.Disconnected() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if time.Since(lastDataAt) < keepaliveInterval {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// SSE ping/keepalive:保持连接活跃防止 Cloudflare Tunnel 等代理断开
|
||||||
|
if !cw.Fprintf(":\n\n") {
|
||||||
|
logger.LegacyPrintf("service.antigravity_gateway", "Client disconnected during keepalive ping (antigravity gemini), continuing to drain upstream for billing")
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -3895,6 +3924,22 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
|||||||
intervalCh = intervalTicker.C
|
intervalCh = intervalTicker.C
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 下游 keepalive:防止代理/Cloudflare Tunnel 因连接空闲而断开
|
||||||
|
keepaliveInterval := time.Duration(0)
|
||||||
|
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamKeepaliveInterval > 0 {
|
||||||
|
keepaliveInterval = time.Duration(s.settingService.cfg.Gateway.StreamKeepaliveInterval) * time.Second
|
||||||
|
}
|
||||||
|
var keepaliveTicker *time.Ticker
|
||||||
|
if keepaliveInterval > 0 {
|
||||||
|
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||||
|
defer keepaliveTicker.Stop()
|
||||||
|
}
|
||||||
|
var keepaliveCh <-chan time.Time
|
||||||
|
if keepaliveTicker != nil {
|
||||||
|
keepaliveCh = keepaliveTicker.C
|
||||||
|
}
|
||||||
|
lastDataAt := time.Now()
|
||||||
|
|
||||||
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity claude")
|
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity claude")
|
||||||
|
|
||||||
// 仅发送一次错误事件,避免多次写入导致协议混乱
|
// 仅发送一次错误事件,避免多次写入导致协议混乱
|
||||||
@@ -3947,6 +3992,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
|||||||
return nil, fmt.Errorf("stream read error: %w", ev.err)
|
return nil, fmt.Errorf("stream read error: %w", ev.err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
lastDataAt = time.Now()
|
||||||
|
|
||||||
// 处理 SSE 行,转换为 Claude 格式
|
// 处理 SSE 行,转换为 Claude 格式
|
||||||
claudeEvents := processor.ProcessLine(strings.TrimRight(ev.line, "\r\n"))
|
claudeEvents := processor.ProcessLine(strings.TrimRight(ev.line, "\r\n"))
|
||||||
if len(claudeEvents) > 0 {
|
if len(claudeEvents) > 0 {
|
||||||
@@ -3969,6 +4016,20 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
|||||||
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)")
|
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)")
|
||||||
sendErrorEvent("stream_timeout")
|
sendErrorEvent("stream_timeout")
|
||||||
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||||
|
|
||||||
|
case <-keepaliveCh:
|
||||||
|
if cw.Disconnected() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if time.Since(lastDataAt) < keepaliveInterval {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// SSE ping 事件:Anthropic 原生格式,客户端会正确处理,
|
||||||
|
// 同时保持连接活跃防止 Cloudflare Tunnel 等代理断开
|
||||||
|
if !cw.Fprintf("event: ping\ndata: {\"type\": \"ping\"}\n\n") {
|
||||||
|
logger.LegacyPrintf("service.antigravity_gateway", "Client disconnected during keepalive ping (antigravity claude), continuing to drain upstream for billing")
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -4299,6 +4360,22 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp
|
|||||||
intervalCh = intervalTicker.C
|
intervalCh = intervalTicker.C
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 下游 keepalive:防止代理/Cloudflare Tunnel 因连接空闲而断开
|
||||||
|
keepaliveInterval := time.Duration(0)
|
||||||
|
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamKeepaliveInterval > 0 {
|
||||||
|
keepaliveInterval = time.Duration(s.settingService.cfg.Gateway.StreamKeepaliveInterval) * time.Second
|
||||||
|
}
|
||||||
|
var keepaliveTicker *time.Ticker
|
||||||
|
if keepaliveInterval > 0 {
|
||||||
|
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||||
|
defer keepaliveTicker.Stop()
|
||||||
|
}
|
||||||
|
var keepaliveCh <-chan time.Time
|
||||||
|
if keepaliveTicker != nil {
|
||||||
|
keepaliveCh = keepaliveTicker.C
|
||||||
|
}
|
||||||
|
lastDataAt := time.Now()
|
||||||
|
|
||||||
flusher, _ := c.Writer.(http.Flusher)
|
flusher, _ := c.Writer.(http.Flusher)
|
||||||
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity upstream")
|
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity upstream")
|
||||||
|
|
||||||
@@ -4316,6 +4393,8 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp
|
|||||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
|
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
lastDataAt = time.Now()
|
||||||
|
|
||||||
line := ev.line
|
line := ev.line
|
||||||
|
|
||||||
// 记录首 token 时间
|
// 记录首 token 时间
|
||||||
@@ -4341,6 +4420,20 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp
|
|||||||
}
|
}
|
||||||
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity upstream)")
|
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity upstream)")
|
||||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
|
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
|
||||||
|
|
||||||
|
case <-keepaliveCh:
|
||||||
|
if cw.Disconnected() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if time.Since(lastDataAt) < keepaliveInterval {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// SSE ping 事件:Anthropic 原生格式,客户端会正确处理,
|
||||||
|
// 同时保持连接活跃防止 Cloudflare Tunnel 等代理断开
|
||||||
|
if !cw.Fprintf("event: ping\ndata: {\"type\": \"ping\"}\n\n") {
|
||||||
|
logger.LegacyPrintf("service.antigravity_gateway", "Client disconnected during keepalive ping (antigravity upstream), continuing to drain upstream for billing")
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -57,16 +57,16 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
|
|||||||
expected: "claude-opus-4-6-thinking",
|
expected: "claude-opus-4-6-thinking",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "默认映射 - claude-haiku-4-5 → claude-sonnet-4-5",
|
name: "默认映射 - claude-haiku-4-5 → claude-sonnet-4-6",
|
||||||
requestedModel: "claude-haiku-4-5",
|
requestedModel: "claude-haiku-4-5",
|
||||||
accountMapping: nil,
|
accountMapping: nil,
|
||||||
expected: "claude-sonnet-4-5",
|
expected: "claude-sonnet-4-6",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "默认映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-5",
|
name: "默认映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-6",
|
||||||
requestedModel: "claude-haiku-4-5-20251001",
|
requestedModel: "claude-haiku-4-5-20251001",
|
||||||
accountMapping: nil,
|
accountMapping: nil,
|
||||||
expected: "claude-sonnet-4-5",
|
expected: "claude-sonnet-4-6",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "默认映射 - claude-sonnet-4-5-20250929 → claude-sonnet-4-5",
|
name: "默认映射 - claude-sonnet-4-5-20250929 → claude-sonnet-4-5",
|
||||||
|
|||||||
@@ -260,14 +260,15 @@ func TestHandleSmartRetry_429_LongDelay_SingleAccountRetry_StillSwitches(t *test
|
|||||||
|
|
||||||
// TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit
|
// TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit
|
||||||
// 503 + retryDelay < 7s + SingleAccountRetry → 智能重试耗尽后直接返回 503,不设限流
|
// 503 + retryDelay < 7s + SingleAccountRetry → 智能重试耗尽后直接返回 503,不设限流
|
||||||
|
// 使用 RATE_LIMIT_EXCEEDED(走 1 次智能重试),避免 MODEL_CAPACITY_EXHAUSTED 的 60 次重试导致测试超时
|
||||||
func TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit(t *testing.T) {
|
func TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit(t *testing.T) {
|
||||||
// 智能重试也返回 503
|
// 智能重试也返回 503
|
||||||
failRespBody := `{
|
failRespBody := `{
|
||||||
"error": {
|
"error": {
|
||||||
"code": 503,
|
"code": 503,
|
||||||
"status": "UNAVAILABLE",
|
"status": "RESOURCE_EXHAUSTED",
|
||||||
"details": [
|
"details": [
|
||||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@@ -278,8 +279,9 @@ func TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit(t *testi
|
|||||||
Body: io.NopCloser(strings.NewReader(failRespBody)),
|
Body: io.NopCloser(strings.NewReader(failRespBody)),
|
||||||
}
|
}
|
||||||
upstream := &mockSmartRetryUpstream{
|
upstream := &mockSmartRetryUpstream{
|
||||||
responses: []*http.Response{failResp},
|
responses: []*http.Response{failResp},
|
||||||
errors: []error{nil},
|
errors: []error{nil},
|
||||||
|
repeatLast: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
repo := &stubAntigravityAccountRepo{}
|
repo := &stubAntigravityAccountRepo{}
|
||||||
@@ -294,9 +296,9 @@ func TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit(t *testi
|
|||||||
respBody := []byte(`{
|
respBody := []byte(`{
|
||||||
"error": {
|
"error": {
|
||||||
"code": 503,
|
"code": 503,
|
||||||
"status": "UNAVAILABLE",
|
"status": "RESOURCE_EXHAUSTED",
|
||||||
"details": [
|
"details": [
|
||||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@@ -569,8 +571,9 @@ func TestHandleSingleAccountRetryInPlace_WaitDurationClamped(t *testing.T) {
|
|||||||
|
|
||||||
svc := &AntigravityGatewayService{}
|
svc := &AntigravityGatewayService{}
|
||||||
|
|
||||||
// 等待时间过大应被 clamp 到 antigravitySingleAccountSmartRetryMaxWait
|
// waitDuration=0 会被 clamp 到 antigravitySmartRetryMinWait=1s。
|
||||||
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 999*time.Second, "gemini-3-pro")
|
// 首次重试即成功(200),总耗时 ~1s。
|
||||||
|
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 0, "gemini-3-pro")
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||||
require.NotNil(t, result.resp)
|
require.NotNil(t, result.resp)
|
||||||
|
|||||||
@@ -32,11 +32,13 @@ func (c *stubSmartRetryCache) DeleteSessionAccountID(_ context.Context, groupID
|
|||||||
|
|
||||||
// mockSmartRetryUpstream 用于 handleSmartRetry 测试的 mock upstream
|
// mockSmartRetryUpstream 用于 handleSmartRetry 测试的 mock upstream
|
||||||
type mockSmartRetryUpstream struct {
|
type mockSmartRetryUpstream struct {
|
||||||
responses []*http.Response
|
responses []*http.Response
|
||||||
errors []error
|
responseBodies [][]byte // 缓存的 response body 字节(用于 repeatLast 重建)
|
||||||
callIdx int
|
errors []error
|
||||||
calls []string
|
callIdx int
|
||||||
requestBodies [][]byte
|
calls []string
|
||||||
|
requestBodies [][]byte
|
||||||
|
repeatLast bool // 超出范围时重复最后一个响应
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockSmartRetryUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
|
func (m *mockSmartRetryUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
|
||||||
@@ -50,10 +52,45 @@ func (m *mockSmartRetryUpstream) Do(req *http.Request, proxyURL string, accountI
|
|||||||
m.requestBodies = append(m.requestBodies, nil)
|
m.requestBodies = append(m.requestBodies, nil)
|
||||||
}
|
}
|
||||||
m.callIdx++
|
m.callIdx++
|
||||||
if idx < len(m.responses) {
|
|
||||||
return m.responses[idx], m.errors[idx]
|
// 确定使用哪个索引
|
||||||
|
respIdx := idx
|
||||||
|
if respIdx >= len(m.responses) {
|
||||||
|
if !m.repeatLast || len(m.responses) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
respIdx = len(m.responses) - 1
|
||||||
}
|
}
|
||||||
return nil, nil
|
|
||||||
|
resp := m.responses[respIdx]
|
||||||
|
respErr := m.errors[respIdx]
|
||||||
|
if resp == nil {
|
||||||
|
return nil, respErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// 首次调用时缓存 body 字节
|
||||||
|
if respIdx >= len(m.responseBodies) {
|
||||||
|
for len(m.responseBodies) <= respIdx {
|
||||||
|
m.responseBodies = append(m.responseBodies, nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if m.responseBodies[respIdx] == nil && resp.Body != nil {
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
m.responseBodies[respIdx] = bodyBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
// 用缓存的 body 字节重建新的 reader
|
||||||
|
var body io.ReadCloser
|
||||||
|
if m.responseBodies[respIdx] != nil {
|
||||||
|
body = io.NopCloser(bytes.NewReader(m.responseBodies[respIdx]))
|
||||||
|
}
|
||||||
|
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
Header: resp.Header.Clone(),
|
||||||
|
Body: body,
|
||||||
|
}, respErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockSmartRetryUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
|
func (m *mockSmartRetryUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
|
||||||
|
|||||||
@@ -4,11 +4,13 @@ import (
|
|||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@@ -84,17 +86,21 @@ type BackupScheduleConfig struct {
|
|||||||
|
|
||||||
// BackupRecord 备份记录
|
// BackupRecord 备份记录
|
||||||
type BackupRecord struct {
|
type BackupRecord struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Status string `json:"status"` // pending, running, completed, failed
|
Status string `json:"status"` // pending, running, completed, failed
|
||||||
BackupType string `json:"backup_type"` // postgres
|
BackupType string `json:"backup_type"` // postgres
|
||||||
FileName string `json:"file_name"`
|
FileName string `json:"file_name"`
|
||||||
S3Key string `json:"s3_key"`
|
S3Key string `json:"s3_key"`
|
||||||
SizeBytes int64 `json:"size_bytes"`
|
SizeBytes int64 `json:"size_bytes"`
|
||||||
TriggeredBy string `json:"triggered_by"` // manual, scheduled
|
TriggeredBy string `json:"triggered_by"` // manual, scheduled
|
||||||
ErrorMsg string `json:"error_message,omitempty"`
|
ErrorMsg string `json:"error_message,omitempty"`
|
||||||
StartedAt string `json:"started_at"`
|
StartedAt string `json:"started_at"`
|
||||||
FinishedAt string `json:"finished_at,omitempty"`
|
FinishedAt string `json:"finished_at,omitempty"`
|
||||||
ExpiresAt string `json:"expires_at,omitempty"` // 过期时间
|
ExpiresAt string `json:"expires_at,omitempty"` // 过期时间
|
||||||
|
Progress string `json:"progress,omitempty"` // "dumping", "uploading", ""
|
||||||
|
RestoreStatus string `json:"restore_status,omitempty"` // "", "running", "completed", "failed"
|
||||||
|
RestoreError string `json:"restore_error,omitempty"`
|
||||||
|
RestoredAt string `json:"restored_at,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// BackupService 数据库备份恢复服务
|
// BackupService 数据库备份恢复服务
|
||||||
@@ -105,17 +111,24 @@ type BackupService struct {
|
|||||||
storeFactory BackupObjectStoreFactory
|
storeFactory BackupObjectStoreFactory
|
||||||
dumper DBDumper
|
dumper DBDumper
|
||||||
|
|
||||||
mu sync.Mutex
|
opMu sync.Mutex // 保护 backingUp/restoring 标志
|
||||||
store BackupObjectStore
|
|
||||||
s3Cfg *BackupS3Config
|
|
||||||
backingUp bool
|
backingUp bool
|
||||||
restoring bool
|
restoring bool
|
||||||
|
|
||||||
|
storeMu sync.Mutex // 保护 store/s3Cfg 缓存
|
||||||
|
store BackupObjectStore
|
||||||
|
s3Cfg *BackupS3Config
|
||||||
|
|
||||||
recordsMu sync.Mutex // 保护 records 的 load/save 操作
|
recordsMu sync.Mutex // 保护 records 的 load/save 操作
|
||||||
|
|
||||||
cronMu sync.Mutex
|
cronMu sync.Mutex
|
||||||
cronSched *cron.Cron
|
cronSched *cron.Cron
|
||||||
cronEntryID cron.EntryID
|
cronEntryID cron.EntryID
|
||||||
|
|
||||||
|
wg sync.WaitGroup // 追踪活跃的备份/恢复 goroutine
|
||||||
|
shuttingDown atomic.Bool // 阻止新备份启动
|
||||||
|
bgCtx context.Context // 所有后台操作的 parent context
|
||||||
|
bgCancel context.CancelFunc // 取消所有活跃后台操作
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBackupService(
|
func NewBackupService(
|
||||||
@@ -125,20 +138,26 @@ func NewBackupService(
|
|||||||
storeFactory BackupObjectStoreFactory,
|
storeFactory BackupObjectStoreFactory,
|
||||||
dumper DBDumper,
|
dumper DBDumper,
|
||||||
) *BackupService {
|
) *BackupService {
|
||||||
|
bgCtx, bgCancel := context.WithCancel(context.Background())
|
||||||
return &BackupService{
|
return &BackupService{
|
||||||
settingRepo: settingRepo,
|
settingRepo: settingRepo,
|
||||||
dbCfg: &cfg.Database,
|
dbCfg: &cfg.Database,
|
||||||
encryptor: encryptor,
|
encryptor: encryptor,
|
||||||
storeFactory: storeFactory,
|
storeFactory: storeFactory,
|
||||||
dumper: dumper,
|
dumper: dumper,
|
||||||
|
bgCtx: bgCtx,
|
||||||
|
bgCancel: bgCancel,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start 启动定时备份调度器
|
// Start 启动定时备份调度器并清理孤立记录
|
||||||
func (s *BackupService) Start() {
|
func (s *BackupService) Start() {
|
||||||
s.cronSched = cron.New()
|
s.cronSched = cron.New()
|
||||||
s.cronSched.Start()
|
s.cronSched.Start()
|
||||||
|
|
||||||
|
// 清理重启后孤立的 running 记录
|
||||||
|
s.recoverStaleRecords()
|
||||||
|
|
||||||
// 加载已有的定时配置
|
// 加载已有的定时配置
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@@ -154,13 +173,65 @@ func (s *BackupService) Start() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop 停止定时备份
|
// recoverStaleRecords 启动时将孤立的 running 记录标记为 failed
|
||||||
|
func (s *BackupService) recoverStaleRecords() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
records, err := s.loadRecords(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i := range records {
|
||||||
|
if records[i].Status == "running" {
|
||||||
|
records[i].Status = "failed"
|
||||||
|
records[i].ErrorMsg = "interrupted by server restart"
|
||||||
|
records[i].Progress = ""
|
||||||
|
records[i].FinishedAt = time.Now().Format(time.RFC3339)
|
||||||
|
_ = s.saveRecord(ctx, &records[i])
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] recovered stale running record: %s", records[i].ID)
|
||||||
|
}
|
||||||
|
if records[i].RestoreStatus == "running" {
|
||||||
|
records[i].RestoreStatus = "failed"
|
||||||
|
records[i].RestoreError = "interrupted by server restart"
|
||||||
|
_ = s.saveRecord(ctx, &records[i])
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] recovered stale restoring record: %s", records[i].ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop 停止定时备份并等待活跃操作完成
|
||||||
func (s *BackupService) Stop() {
|
func (s *BackupService) Stop() {
|
||||||
|
s.shuttingDown.Store(true)
|
||||||
|
|
||||||
s.cronMu.Lock()
|
s.cronMu.Lock()
|
||||||
defer s.cronMu.Unlock()
|
|
||||||
if s.cronSched != nil {
|
if s.cronSched != nil {
|
||||||
s.cronSched.Stop()
|
s.cronSched.Stop()
|
||||||
}
|
}
|
||||||
|
s.cronMu.Unlock()
|
||||||
|
|
||||||
|
// 等待活跃备份/恢复完成(最多 5 分钟)
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
s.wg.Wait()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] all active operations finished")
|
||||||
|
case <-time.After(5 * time.Minute):
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] shutdown timeout after 5min, cancelling active operations")
|
||||||
|
if s.bgCancel != nil {
|
||||||
|
s.bgCancel() // 取消所有后台操作
|
||||||
|
}
|
||||||
|
// 给 goroutine 时间响应取消并完成清理
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] active operations cancelled and cleaned up")
|
||||||
|
case <-time.After(10 * time.Second):
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] goroutine cleanup timed out")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ─── S3 配置管理 ───
|
// ─── S3 配置管理 ───
|
||||||
@@ -203,10 +274,10 @@ func (s *BackupService) UpdateS3Config(ctx context.Context, cfg BackupS3Config)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 清除缓存的 S3 客户端
|
// 清除缓存的 S3 客户端
|
||||||
s.mu.Lock()
|
s.storeMu.Lock()
|
||||||
s.store = nil
|
s.store = nil
|
||||||
s.s3Cfg = nil
|
s.s3Cfg = nil
|
||||||
s.mu.Unlock()
|
s.storeMu.Unlock()
|
||||||
|
|
||||||
cfg.SecretAccessKey = ""
|
cfg.SecretAccessKey = ""
|
||||||
return &cfg, nil
|
return &cfg, nil
|
||||||
@@ -314,7 +385,10 @@ func (s *BackupService) removeCronSchedule() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *BackupService) runScheduledBackup() {
|
func (s *BackupService) runScheduledBackup() {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
s.wg.Add(1)
|
||||||
|
defer s.wg.Done()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(s.bgCtx, 30*time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// 读取定时备份配置中的过期天数
|
// 读取定时备份配置中的过期天数
|
||||||
@@ -327,7 +401,11 @@ func (s *BackupService) runScheduledBackup() {
|
|||||||
logger.LegacyPrintf("service.backup", "[Backup] 开始执行定时备份, 过期天数: %d", expireDays)
|
logger.LegacyPrintf("service.backup", "[Backup] 开始执行定时备份, 过期天数: %d", expireDays)
|
||||||
record, err := s.CreateBackup(ctx, "scheduled", expireDays)
|
record, err := s.CreateBackup(ctx, "scheduled", expireDays)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.LegacyPrintf("service.backup", "[Backup] 定时备份失败: %v", err)
|
if errors.Is(err, ErrBackupInProgress) {
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] 定时备份跳过: 已有备份正在进行中")
|
||||||
|
} else {
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] 定时备份失败: %v", err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.LegacyPrintf("service.backup", "[Backup] 定时备份完成: id=%s size=%d", record.ID, record.SizeBytes)
|
logger.LegacyPrintf("service.backup", "[Backup] 定时备份完成: id=%s size=%d", record.ID, record.SizeBytes)
|
||||||
@@ -346,17 +424,21 @@ func (s *BackupService) runScheduledBackup() {
|
|||||||
// CreateBackup 创建全量数据库备份并上传到 S3(流式处理)
|
// CreateBackup 创建全量数据库备份并上传到 S3(流式处理)
|
||||||
// expireDays: 备份过期天数,0=永不过期,默认14天
|
// expireDays: 备份过期天数,0=永不过期,默认14天
|
||||||
func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, expireDays int) (*BackupRecord, error) {
|
func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, expireDays int) (*BackupRecord, error) {
|
||||||
s.mu.Lock()
|
if s.shuttingDown.Load() {
|
||||||
|
return nil, infraerrors.ServiceUnavailable("SERVER_SHUTTING_DOWN", "server is shutting down")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.opMu.Lock()
|
||||||
if s.backingUp {
|
if s.backingUp {
|
||||||
s.mu.Unlock()
|
s.opMu.Unlock()
|
||||||
return nil, ErrBackupInProgress
|
return nil, ErrBackupInProgress
|
||||||
}
|
}
|
||||||
s.backingUp = true
|
s.backingUp = true
|
||||||
s.mu.Unlock()
|
s.opMu.Unlock()
|
||||||
defer func() {
|
defer func() {
|
||||||
s.mu.Lock()
|
s.opMu.Lock()
|
||||||
s.backingUp = false
|
s.backingUp = false
|
||||||
s.mu.Unlock()
|
s.opMu.Unlock()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
s3Cfg, err := s.loadS3Config(ctx)
|
s3Cfg, err := s.loadS3Config(ctx)
|
||||||
@@ -405,36 +487,47 @@ func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, ex
|
|||||||
|
|
||||||
// 使用 io.Pipe 将 gzip 压缩数据流式传递给 S3 上传
|
// 使用 io.Pipe 将 gzip 压缩数据流式传递给 S3 上传
|
||||||
pr, pw := io.Pipe()
|
pr, pw := io.Pipe()
|
||||||
var gzipErr error
|
gzipDone := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
pw.CloseWithError(fmt.Errorf("gzip goroutine panic: %v", r)) //nolint:errcheck
|
||||||
|
gzipDone <- fmt.Errorf("gzip goroutine panic: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
gzWriter := gzip.NewWriter(pw)
|
gzWriter := gzip.NewWriter(pw)
|
||||||
_, gzipErr = io.Copy(gzWriter, dumpReader)
|
var gzErr error
|
||||||
if closeErr := gzWriter.Close(); closeErr != nil && gzipErr == nil {
|
_, gzErr = io.Copy(gzWriter, dumpReader)
|
||||||
gzipErr = closeErr
|
if closeErr := gzWriter.Close(); closeErr != nil && gzErr == nil {
|
||||||
|
gzErr = closeErr
|
||||||
}
|
}
|
||||||
if closeErr := dumpReader.Close(); closeErr != nil && gzipErr == nil {
|
if closeErr := dumpReader.Close(); closeErr != nil && gzErr == nil {
|
||||||
gzipErr = closeErr
|
gzErr = closeErr
|
||||||
}
|
}
|
||||||
if gzipErr != nil {
|
if gzErr != nil {
|
||||||
_ = pw.CloseWithError(gzipErr)
|
_ = pw.CloseWithError(gzErr)
|
||||||
} else {
|
} else {
|
||||||
_ = pw.Close()
|
_ = pw.Close()
|
||||||
}
|
}
|
||||||
|
gzipDone <- gzErr
|
||||||
}()
|
}()
|
||||||
|
|
||||||
contentType := "application/gzip"
|
contentType := "application/gzip"
|
||||||
sizeBytes, err := objectStore.Upload(ctx, s3Key, pr, contentType)
|
sizeBytes, err := objectStore.Upload(ctx, s3Key, pr, contentType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
_ = pr.CloseWithError(err) // 确保 gzip goroutine 不会悬挂
|
||||||
|
gzErr := <-gzipDone // 安全等待 gzip goroutine 完成
|
||||||
record.Status = "failed"
|
record.Status = "failed"
|
||||||
errMsg := fmt.Sprintf("S3 upload failed: %v", err)
|
errMsg := fmt.Sprintf("S3 upload failed: %v", err)
|
||||||
if gzipErr != nil {
|
if gzErr != nil {
|
||||||
errMsg = fmt.Sprintf("gzip/dump failed: %v", gzipErr)
|
errMsg = fmt.Sprintf("gzip/dump failed: %v", gzErr)
|
||||||
}
|
}
|
||||||
record.ErrorMsg = errMsg
|
record.ErrorMsg = errMsg
|
||||||
record.FinishedAt = time.Now().Format(time.RFC3339)
|
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||||
_ = s.saveRecord(ctx, record)
|
_ = s.saveRecord(ctx, record)
|
||||||
return record, fmt.Errorf("backup upload: %w", err)
|
return record, fmt.Errorf("backup upload: %w", err)
|
||||||
}
|
}
|
||||||
|
<-gzipDone // 确保 gzip goroutine 已退出
|
||||||
|
|
||||||
record.SizeBytes = sizeBytes
|
record.SizeBytes = sizeBytes
|
||||||
record.Status = "completed"
|
record.Status = "completed"
|
||||||
@@ -446,19 +539,187 @@ func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, ex
|
|||||||
return record, nil
|
return record, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StartBackup 异步创建备份,立即返回 running 状态的记录
|
||||||
|
func (s *BackupService) StartBackup(ctx context.Context, triggeredBy string, expireDays int) (*BackupRecord, error) {
|
||||||
|
if s.shuttingDown.Load() {
|
||||||
|
return nil, infraerrors.ServiceUnavailable("SERVER_SHUTTING_DOWN", "server is shutting down")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.opMu.Lock()
|
||||||
|
if s.backingUp {
|
||||||
|
s.opMu.Unlock()
|
||||||
|
return nil, ErrBackupInProgress
|
||||||
|
}
|
||||||
|
s.backingUp = true
|
||||||
|
s.opMu.Unlock()
|
||||||
|
|
||||||
|
// 初始化阶段出错时自动重置标志
|
||||||
|
launched := false
|
||||||
|
defer func() {
|
||||||
|
if !launched {
|
||||||
|
s.opMu.Lock()
|
||||||
|
s.backingUp = false
|
||||||
|
s.opMu.Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 在返回前加载 S3 配置和创建 store,避免 goroutine 中配置被修改
|
||||||
|
s3Cfg, err := s.loadS3Config(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if s3Cfg == nil || !s3Cfg.IsConfigured() {
|
||||||
|
return nil, ErrBackupS3NotConfigured
|
||||||
|
}
|
||||||
|
|
||||||
|
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("init object store: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
backupID := uuid.New().String()[:8]
|
||||||
|
fileName := fmt.Sprintf("%s_%s.sql.gz", s.dbCfg.DBName, now.Format("20060102_150405"))
|
||||||
|
s3Key := s.buildS3Key(s3Cfg, fileName)
|
||||||
|
|
||||||
|
var expiresAt string
|
||||||
|
if expireDays > 0 {
|
||||||
|
expiresAt = now.AddDate(0, 0, expireDays).Format(time.RFC3339)
|
||||||
|
}
|
||||||
|
|
||||||
|
record := &BackupRecord{
|
||||||
|
ID: backupID,
|
||||||
|
Status: "running",
|
||||||
|
BackupType: "postgres",
|
||||||
|
FileName: fileName,
|
||||||
|
S3Key: s3Key,
|
||||||
|
TriggeredBy: triggeredBy,
|
||||||
|
StartedAt: now.Format(time.RFC3339),
|
||||||
|
ExpiresAt: expiresAt,
|
||||||
|
Progress: "pending",
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.saveRecord(ctx, record); err != nil {
|
||||||
|
return nil, fmt.Errorf("save initial record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
launched = true
|
||||||
|
// 在启动 goroutine 前完成拷贝,避免数据竞争
|
||||||
|
result := *record
|
||||||
|
|
||||||
|
s.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer s.wg.Done()
|
||||||
|
defer func() {
|
||||||
|
s.opMu.Lock()
|
||||||
|
s.backingUp = false
|
||||||
|
s.opMu.Unlock()
|
||||||
|
}()
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] panic recovered: %v", r)
|
||||||
|
record.Status = "failed"
|
||||||
|
record.ErrorMsg = fmt.Sprintf("internal panic: %v", r)
|
||||||
|
record.Progress = ""
|
||||||
|
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||||
|
_ = s.saveRecord(context.Background(), record)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
s.executeBackup(record, objectStore)
|
||||||
|
}()
|
||||||
|
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeBackup 后台执行备份(独立于 HTTP context)
|
||||||
|
func (s *BackupService) executeBackup(record *BackupRecord, objectStore BackupObjectStore) {
|
||||||
|
ctx, cancel := context.WithTimeout(s.bgCtx, 30*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// 阶段1: pg_dump
|
||||||
|
record.Progress = "dumping"
|
||||||
|
_ = s.saveRecord(ctx, record)
|
||||||
|
|
||||||
|
dumpReader, err := s.dumper.Dump(ctx)
|
||||||
|
if err != nil {
|
||||||
|
record.Status = "failed"
|
||||||
|
record.ErrorMsg = fmt.Sprintf("pg_dump failed: %v", err)
|
||||||
|
record.Progress = ""
|
||||||
|
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||||
|
_ = s.saveRecord(context.Background(), record)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 阶段2: gzip + upload
|
||||||
|
record.Progress = "uploading"
|
||||||
|
_ = s.saveRecord(ctx, record)
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
gzipDone := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
pw.CloseWithError(fmt.Errorf("gzip goroutine panic: %v", r)) //nolint:errcheck
|
||||||
|
gzipDone <- fmt.Errorf("gzip goroutine panic: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
gzWriter := gzip.NewWriter(pw)
|
||||||
|
var gzErr error
|
||||||
|
_, gzErr = io.Copy(gzWriter, dumpReader)
|
||||||
|
if closeErr := gzWriter.Close(); closeErr != nil && gzErr == nil {
|
||||||
|
gzErr = closeErr
|
||||||
|
}
|
||||||
|
if closeErr := dumpReader.Close(); closeErr != nil && gzErr == nil {
|
||||||
|
gzErr = closeErr
|
||||||
|
}
|
||||||
|
if gzErr != nil {
|
||||||
|
_ = pw.CloseWithError(gzErr)
|
||||||
|
} else {
|
||||||
|
_ = pw.Close()
|
||||||
|
}
|
||||||
|
gzipDone <- gzErr
|
||||||
|
}()
|
||||||
|
|
||||||
|
contentType := "application/gzip"
|
||||||
|
sizeBytes, err := objectStore.Upload(ctx, record.S3Key, pr, contentType)
|
||||||
|
if err != nil {
|
||||||
|
_ = pr.CloseWithError(err) // 确保 gzip goroutine 不会悬挂
|
||||||
|
gzErr := <-gzipDone // 安全等待 gzip goroutine 完成
|
||||||
|
record.Status = "failed"
|
||||||
|
errMsg := fmt.Sprintf("S3 upload failed: %v", err)
|
||||||
|
if gzErr != nil {
|
||||||
|
errMsg = fmt.Sprintf("gzip/dump failed: %v", gzErr)
|
||||||
|
}
|
||||||
|
record.ErrorMsg = errMsg
|
||||||
|
record.Progress = ""
|
||||||
|
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||||
|
_ = s.saveRecord(context.Background(), record)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
<-gzipDone // 确保 gzip goroutine 已退出
|
||||||
|
|
||||||
|
record.SizeBytes = sizeBytes
|
||||||
|
record.Status = "completed"
|
||||||
|
record.Progress = ""
|
||||||
|
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||||
|
if err := s.saveRecord(context.Background(), record); err != nil {
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] 保存备份记录失败: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// RestoreBackup 从 S3 下载备份并流式恢复到数据库
|
// RestoreBackup 从 S3 下载备份并流式恢复到数据库
|
||||||
func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) error {
|
func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) error {
|
||||||
s.mu.Lock()
|
s.opMu.Lock()
|
||||||
if s.restoring {
|
if s.restoring {
|
||||||
s.mu.Unlock()
|
s.opMu.Unlock()
|
||||||
return ErrRestoreInProgress
|
return ErrRestoreInProgress
|
||||||
}
|
}
|
||||||
s.restoring = true
|
s.restoring = true
|
||||||
s.mu.Unlock()
|
s.opMu.Unlock()
|
||||||
defer func() {
|
defer func() {
|
||||||
s.mu.Lock()
|
s.opMu.Lock()
|
||||||
s.restoring = false
|
s.restoring = false
|
||||||
s.mu.Unlock()
|
s.opMu.Unlock()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
record, err := s.GetBackupRecord(ctx, backupID)
|
record, err := s.GetBackupRecord(ctx, backupID)
|
||||||
@@ -500,6 +761,112 @@ func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) erro
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StartRestore 异步恢复备份,立即返回
|
||||||
|
func (s *BackupService) StartRestore(ctx context.Context, backupID string) (*BackupRecord, error) {
|
||||||
|
if s.shuttingDown.Load() {
|
||||||
|
return nil, infraerrors.ServiceUnavailable("SERVER_SHUTTING_DOWN", "server is shutting down")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.opMu.Lock()
|
||||||
|
if s.restoring {
|
||||||
|
s.opMu.Unlock()
|
||||||
|
return nil, ErrRestoreInProgress
|
||||||
|
}
|
||||||
|
s.restoring = true
|
||||||
|
s.opMu.Unlock()
|
||||||
|
|
||||||
|
// 初始化阶段出错时自动重置标志
|
||||||
|
launched := false
|
||||||
|
defer func() {
|
||||||
|
if !launched {
|
||||||
|
s.opMu.Lock()
|
||||||
|
s.restoring = false
|
||||||
|
s.opMu.Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
record, err := s.GetBackupRecord(ctx, backupID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if record.Status != "completed" {
|
||||||
|
return nil, infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "can only restore from a completed backup")
|
||||||
|
}
|
||||||
|
|
||||||
|
s3Cfg, err := s.loadS3Config(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("init object store: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
record.RestoreStatus = "running"
|
||||||
|
_ = s.saveRecord(ctx, record)
|
||||||
|
|
||||||
|
launched = true
|
||||||
|
result := *record
|
||||||
|
|
||||||
|
s.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer s.wg.Done()
|
||||||
|
defer func() {
|
||||||
|
s.opMu.Lock()
|
||||||
|
s.restoring = false
|
||||||
|
s.opMu.Unlock()
|
||||||
|
}()
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] restore panic recovered: %v", r)
|
||||||
|
record.RestoreStatus = "failed"
|
||||||
|
record.RestoreError = fmt.Sprintf("internal panic: %v", r)
|
||||||
|
_ = s.saveRecord(context.Background(), record)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
s.executeRestore(record, objectStore)
|
||||||
|
}()
|
||||||
|
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeRestore 后台执行恢复
|
||||||
|
func (s *BackupService) executeRestore(record *BackupRecord, objectStore BackupObjectStore) {
|
||||||
|
ctx, cancel := context.WithTimeout(s.bgCtx, 30*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
body, err := objectStore.Download(ctx, record.S3Key)
|
||||||
|
if err != nil {
|
||||||
|
record.RestoreStatus = "failed"
|
||||||
|
record.RestoreError = fmt.Sprintf("S3 download failed: %v", err)
|
||||||
|
_ = s.saveRecord(context.Background(), record)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() { _ = body.Close() }()
|
||||||
|
|
||||||
|
gzReader, err := gzip.NewReader(body)
|
||||||
|
if err != nil {
|
||||||
|
record.RestoreStatus = "failed"
|
||||||
|
record.RestoreError = fmt.Sprintf("gzip reader: %v", err)
|
||||||
|
_ = s.saveRecord(context.Background(), record)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() { _ = gzReader.Close() }()
|
||||||
|
|
||||||
|
if err := s.dumper.Restore(ctx, gzReader); err != nil {
|
||||||
|
record.RestoreStatus = "failed"
|
||||||
|
record.RestoreError = fmt.Sprintf("pg restore: %v", err)
|
||||||
|
_ = s.saveRecord(context.Background(), record)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
record.RestoreStatus = "completed"
|
||||||
|
record.RestoredAt = time.Now().Format(time.RFC3339)
|
||||||
|
if err := s.saveRecord(context.Background(), record); err != nil {
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] 保存恢复记录失败: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ─── 备份记录管理 ───
|
// ─── 备份记录管理 ───
|
||||||
|
|
||||||
func (s *BackupService) ListBackups(ctx context.Context) ([]BackupRecord, error) {
|
func (s *BackupService) ListBackups(ctx context.Context) ([]BackupRecord, error) {
|
||||||
@@ -614,8 +981,8 @@ func (s *BackupService) loadS3Config(ctx context.Context) (*BackupS3Config, erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *BackupService) getOrCreateStore(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error) {
|
func (s *BackupService) getOrCreateStore(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error) {
|
||||||
s.mu.Lock()
|
s.storeMu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.storeMu.Unlock()
|
||||||
|
|
||||||
if s.store != nil && s.s3Cfg != nil {
|
if s.store != nil && s.s3Cfg != nil {
|
||||||
return s.store, nil
|
return s.store, nil
|
||||||
|
|||||||
@@ -134,6 +134,30 @@ func (m *mockDumper) Restore(_ context.Context, data io.Reader) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// blockingDumper 可控延迟的 dumper,用于测试异步行为
|
||||||
|
type blockingDumper struct {
|
||||||
|
blockCh chan struct{}
|
||||||
|
data []byte
|
||||||
|
restErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *blockingDumper) Dump(ctx context.Context) (io.ReadCloser, error) {
|
||||||
|
select {
|
||||||
|
case <-d.blockCh:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
return io.NopCloser(bytes.NewReader(d.data)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *blockingDumper) Restore(_ context.Context, data io.Reader) error {
|
||||||
|
if d.restErr != nil {
|
||||||
|
return d.restErr
|
||||||
|
}
|
||||||
|
_, _ = io.ReadAll(data)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type mockObjectStore struct {
|
type mockObjectStore struct {
|
||||||
objects map[string][]byte
|
objects map[string][]byte
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
@@ -179,7 +203,7 @@ func (m *mockObjectStore) HeadBucket(_ context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestBackupService(repo *mockSettingRepo, dumper *mockDumper, store *mockObjectStore) *BackupService {
|
func newTestBackupService(repo *mockSettingRepo, dumper DBDumper, store *mockObjectStore) *BackupService {
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
Database: config.DatabaseConfig{
|
Database: config.DatabaseConfig{
|
||||||
Host: "localhost",
|
Host: "localhost",
|
||||||
@@ -361,9 +385,9 @@ func TestBackupService_CreateBackup_ConcurrentBlocked(t *testing.T) {
|
|||||||
svc := newTestBackupService(repo, dumper, store)
|
svc := newTestBackupService(repo, dumper, store)
|
||||||
|
|
||||||
// 手动设置 backingUp 标志
|
// 手动设置 backingUp 标志
|
||||||
svc.mu.Lock()
|
svc.opMu.Lock()
|
||||||
svc.backingUp = true
|
svc.backingUp = true
|
||||||
svc.mu.Unlock()
|
svc.opMu.Unlock()
|
||||||
|
|
||||||
_, err := svc.CreateBackup(context.Background(), "manual", 14)
|
_, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||||
require.ErrorIs(t, err, ErrBackupInProgress)
|
require.ErrorIs(t, err, ErrBackupInProgress)
|
||||||
@@ -526,3 +550,154 @@ func TestBackupService_LoadS3Config_Corrupted(t *testing.T) {
|
|||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Nil(t, cfg)
|
require.Nil(t, cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ─── Async Backup Tests ───
|
||||||
|
|
||||||
|
func TestStartBackup_ReturnsImmediately(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
seedS3Config(t, repo)
|
||||||
|
|
||||||
|
dumper := &blockingDumper{blockCh: make(chan struct{}), data: []byte("data")}
|
||||||
|
store := newMockObjectStore()
|
||||||
|
svc := newTestBackupService(repo, dumper, store)
|
||||||
|
|
||||||
|
record, err := svc.StartBackup(context.Background(), "manual", 14)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "running", record.Status)
|
||||||
|
require.NotEmpty(t, record.ID)
|
||||||
|
|
||||||
|
// 释放 dumper 让后台完成
|
||||||
|
close(dumper.blockCh)
|
||||||
|
svc.wg.Wait()
|
||||||
|
|
||||||
|
// 验证最终状态
|
||||||
|
final, err := svc.GetBackupRecord(context.Background(), record.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "completed", final.Status)
|
||||||
|
require.Greater(t, final.SizeBytes, int64(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartBackup_ConcurrentBlocked(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
seedS3Config(t, repo)
|
||||||
|
|
||||||
|
dumper := &blockingDumper{blockCh: make(chan struct{}), data: []byte("data")}
|
||||||
|
store := newMockObjectStore()
|
||||||
|
svc := newTestBackupService(repo, dumper, store)
|
||||||
|
|
||||||
|
// 第一次启动
|
||||||
|
_, err := svc.StartBackup(context.Background(), "manual", 14)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 第二次应被阻塞
|
||||||
|
_, err = svc.StartBackup(context.Background(), "manual", 14)
|
||||||
|
require.ErrorIs(t, err, ErrBackupInProgress)
|
||||||
|
|
||||||
|
close(dumper.blockCh)
|
||||||
|
svc.wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartBackup_ShuttingDown(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
seedS3Config(t, repo)
|
||||||
|
svc := newTestBackupService(repo, &mockDumper{dumpData: []byte("data")}, newMockObjectStore())
|
||||||
|
|
||||||
|
svc.shuttingDown.Store(true)
|
||||||
|
|
||||||
|
_, err := svc.StartBackup(context.Background(), "manual", 14)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "shutting down")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecoverStaleRecords(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||||
|
|
||||||
|
// 模拟一条孤立的 running 记录
|
||||||
|
_ = svc.saveRecord(context.Background(), &BackupRecord{
|
||||||
|
ID: "stale-1",
|
||||||
|
Status: "running",
|
||||||
|
StartedAt: time.Now().Add(-1 * time.Hour).Format(time.RFC3339),
|
||||||
|
})
|
||||||
|
// 模拟一条孤立的恢复中记录
|
||||||
|
_ = svc.saveRecord(context.Background(), &BackupRecord{
|
||||||
|
ID: "stale-2",
|
||||||
|
Status: "completed",
|
||||||
|
RestoreStatus: "running",
|
||||||
|
StartedAt: time.Now().Add(-1 * time.Hour).Format(time.RFC3339),
|
||||||
|
})
|
||||||
|
|
||||||
|
svc.recoverStaleRecords()
|
||||||
|
|
||||||
|
r1, _ := svc.GetBackupRecord(context.Background(), "stale-1")
|
||||||
|
require.Equal(t, "failed", r1.Status)
|
||||||
|
require.Contains(t, r1.ErrorMsg, "server restart")
|
||||||
|
|
||||||
|
r2, _ := svc.GetBackupRecord(context.Background(), "stale-2")
|
||||||
|
require.Equal(t, "failed", r2.RestoreStatus)
|
||||||
|
require.Contains(t, r2.RestoreError, "server restart")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGracefulShutdown(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
seedS3Config(t, repo)
|
||||||
|
|
||||||
|
dumper := &blockingDumper{blockCh: make(chan struct{}), data: []byte("data")}
|
||||||
|
store := newMockObjectStore()
|
||||||
|
svc := newTestBackupService(repo, dumper, store)
|
||||||
|
|
||||||
|
_, err := svc.StartBackup(context.Background(), "manual", 14)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Stop 应该等待备份完成
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
svc.Stop()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 短暂等待确认 Stop 还在等待
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
t.Fatal("Stop returned before backup finished")
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
// 预期:Stop 还在等待
|
||||||
|
}
|
||||||
|
|
||||||
|
// 释放备份
|
||||||
|
close(dumper.blockCh)
|
||||||
|
|
||||||
|
// 现在 Stop 应该完成
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
// 预期
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("Stop did not return after backup finished")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartRestore_Async(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
seedS3Config(t, repo)
|
||||||
|
|
||||||
|
dumpContent := "-- PostgreSQL dump\nCREATE TABLE test (id int);\n"
|
||||||
|
dumper := &mockDumper{dumpData: []byte(dumpContent)}
|
||||||
|
store := newMockObjectStore()
|
||||||
|
svc := newTestBackupService(repo, dumper, store)
|
||||||
|
|
||||||
|
// 先创建一个备份(同步方式)
|
||||||
|
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 异步恢复
|
||||||
|
restored, err := svc.StartRestore(context.Background(), record.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "running", restored.RestoreStatus)
|
||||||
|
|
||||||
|
svc.wg.Wait()
|
||||||
|
|
||||||
|
// 验证最终状态
|
||||||
|
final, err := svc.GetBackupRecord(context.Background(), record.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "completed", final.RestoreStatus)
|
||||||
|
}
|
||||||
|
|||||||
@@ -21,9 +21,6 @@ var (
|
|||||||
// 带捕获组的版本提取正则
|
// 带捕获组的版本提取正则
|
||||||
claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`)
|
claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`)
|
||||||
|
|
||||||
// metadata.user_id 格式: user_{64位hex}_account__session_{uuid}
|
|
||||||
userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[\w-]+$`)
|
|
||||||
|
|
||||||
// System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致)
|
// System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致)
|
||||||
systemPromptThreshold = 0.5
|
systemPromptThreshold = 0.5
|
||||||
)
|
)
|
||||||
@@ -124,7 +121,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if !userIDPattern.MatchString(userID) {
|
if ParseMetadataUserID(userID) == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -278,11 +275,7 @@ func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context
|
|||||||
// ExtractVersion 从 User-Agent 中提取 Claude Code 版本号
|
// ExtractVersion 从 User-Agent 中提取 Claude Code 版本号
|
||||||
// 返回 "2.1.22" 形式的版本号,如果不匹配返回空字符串
|
// 返回 "2.1.22" 形式的版本号,如果不匹配返回空字符串
|
||||||
func (v *ClaudeCodeValidator) ExtractVersion(ua string) string {
|
func (v *ClaudeCodeValidator) ExtractVersion(ua string) string {
|
||||||
matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua)
|
return ExtractCLIVersion(ua)
|
||||||
if len(matches) >= 2 {
|
|
||||||
return matches[1]
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetClaudeCodeVersion 将 Claude Code 版本号设置到 context 中
|
// SetClaudeCodeVersion 将 Claude Code 版本号设置到 context 中
|
||||||
|
|||||||
@@ -140,6 +140,27 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi
|
|||||||
return stats, nil
|
return stats, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *DashboardService) GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, modelSource string) ([]usagestats.ModelStat, error) {
|
||||||
|
normalizedSource := usagestats.NormalizeModelSource(modelSource)
|
||||||
|
if normalizedSource == usagestats.ModelSourceRequested {
|
||||||
|
return s.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||||
|
}
|
||||||
|
|
||||||
|
type modelStatsBySourceRepo interface {
|
||||||
|
GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) ([]usagestats.ModelStat, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sourceRepo, ok := s.usageRepo.(modelStatsBySourceRepo); ok {
|
||||||
|
stats, err := sourceRepo.GetModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, normalizedSource)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get model stats with filters by source: %w", err)
|
||||||
|
}
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
||||||
stats, err := s.usageRepo.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
stats, err := s.usageRepo.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -148,6 +169,15 @@ func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTi
|
|||||||
return stats, nil
|
return stats, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetGroupUsageSummary returns today's and cumulative cost for all groups.
|
||||||
|
func (s *DashboardService) GetGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
|
||||||
|
results, err := s.usageRepo.GetAllGroupUsageSummary(ctx, todayStart)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get group usage summary: %w", err)
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) {
|
func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) {
|
||||||
data, err := s.cache.GetDashboardStats(ctx)
|
data, err := s.cache.GetDashboardStats(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -335,6 +365,14 @@ func (s *DashboardService) GetUserSpendingRanking(ctx context.Context, startTime
|
|||||||
return ranking, nil
|
return ranking, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *DashboardService) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) {
|
||||||
|
stats, err := s.usageRepo.GetUserBreakdownStats(ctx, startTime, endTime, dim, limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get user breakdown stats: %w", err)
|
||||||
|
}
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||||
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime)
|
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -170,6 +170,13 @@ const (
|
|||||||
// SettingKeyOpsRuntimeLogConfig stores JSON config for runtime log settings.
|
// SettingKeyOpsRuntimeLogConfig stores JSON config for runtime log settings.
|
||||||
SettingKeyOpsRuntimeLogConfig = "ops_runtime_log_config"
|
SettingKeyOpsRuntimeLogConfig = "ops_runtime_log_config"
|
||||||
|
|
||||||
|
// =========================
|
||||||
|
// Overload Cooldown (529)
|
||||||
|
// =========================
|
||||||
|
|
||||||
|
// SettingKeyOverloadCooldownSettings stores JSON config for 529 overload cooldown handling.
|
||||||
|
SettingKeyOverloadCooldownSettings = "overload_cooldown_settings"
|
||||||
|
|
||||||
// =========================
|
// =========================
|
||||||
// Stream Timeout Handling
|
// Stream Timeout Handling
|
||||||
// =========================
|
// =========================
|
||||||
|
|||||||
@@ -688,6 +688,83 @@ func TestGatewayService_AnthropicOAuth_NotAffectedByAPIKeyPassthroughToggle(t *t
|
|||||||
require.Contains(t, req.Header.Get("anthropic-beta"), claude.BetaOAuth, "OAuth 链路仍应按原逻辑补齐 oauth beta")
|
require.Contains(t, req.Header.Get("anthropic-beta"), claude.BetaOAuth, "OAuth 链路仍应按原逻辑补齐 oauth beta")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGatewayService_AnthropicOAuth_ForwardPreservesBillingHeaderSystemBlock(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "system array",
|
||||||
|
body: `{"model":"claude-3-5-sonnet-latest","system":[{"type":"text","text":"x-anthropic-billing-header keep"}],"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "system string",
|
||||||
|
body: `{"model":"claude-3-5-sonnet-latest","system":"x-anthropic-billing-header keep","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
|
||||||
|
parsed, err := ParseGatewayRequest([]byte(tt.body), PlatformAnthropic)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
upstream := &anthropicHTTPUpstreamRecorder{
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{
|
||||||
|
"Content-Type": []string{"application/json"},
|
||||||
|
"x-request-id": []string{"rid-oauth-preserve"},
|
||||||
|
},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"id":"msg_1","type":"message","role":"assistant","model":"claude-3-5-sonnet-20241022","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":12,"output_tokens":7}}`)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
MaxLineSize: defaultMaxLineSize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &GatewayService{
|
||||||
|
cfg: cfg,
|
||||||
|
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
rateLimitService: &RateLimitService{},
|
||||||
|
deferredService: &DeferredService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 301,
|
||||||
|
Name: "anthropic-oauth-preserve",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "oauth-token",
|
||||||
|
},
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.Forward(context.Background(), c, account, parsed)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, upstream.lastReq)
|
||||||
|
require.Equal(t, "Bearer oauth-token", upstream.lastReq.Header.Get("authorization"))
|
||||||
|
require.Contains(t, upstream.lastReq.Header.Get("anthropic-beta"), claude.BetaOAuth)
|
||||||
|
|
||||||
|
system := gjson.GetBytes(upstream.lastBody, "system")
|
||||||
|
require.True(t, system.Exists())
|
||||||
|
require.Contains(t, system.Raw, "x-anthropic-billing-header keep")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAfterClientDisconnect(t *testing.T) {
|
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAfterClientDisconnect(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
@@ -788,7 +865,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuc
|
|||||||
rateLimitService: &RateLimitService{},
|
rateLimitService: &RateLimitService{},
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), body, "claude-3-5-sonnet-latest", false, time.Now())
|
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), body, "claude-3-5-sonnet-latest", "claude-3-5-sonnet-latest", false, time.Now())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.Equal(t, 12, result.Usage.InputTokens)
|
require.Equal(t, 12, result.Usage.InputTokens)
|
||||||
@@ -815,7 +892,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_InvalidTokenTyp
|
|||||||
}
|
}
|
||||||
svc := &GatewayService{}
|
svc := &GatewayService{}
|
||||||
|
|
||||||
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "claude-3-5-sonnet-latest", false, time.Now())
|
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "claude-3-5-sonnet-latest", "claude-3-5-sonnet-latest", false, time.Now())
|
||||||
require.Nil(t, result)
|
require.Nil(t, result)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), "requires apikey token")
|
require.Contains(t, err.Error(), "requires apikey token")
|
||||||
@@ -840,7 +917,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_UpstreamRequest
|
|||||||
}
|
}
|
||||||
account := newAnthropicAPIKeyAccountForTest()
|
account := newAnthropicAPIKeyAccountForTest()
|
||||||
|
|
||||||
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{"model":"x"}`), "x", false, time.Now())
|
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{"model":"x"}`), "x", "x", false, time.Now())
|
||||||
require.Nil(t, result)
|
require.Nil(t, result)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), "upstream request failed")
|
require.Contains(t, err.Error(), "upstream request failed")
|
||||||
@@ -873,7 +950,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_EmptyResponseBo
|
|||||||
httpUpstream: upstream,
|
httpUpstream: upstream,
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), []byte(`{"model":"x"}`), "x", false, time.Now())
|
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), []byte(`{"model":"x"}`), "x", "x", false, time.Now())
|
||||||
require.Nil(t, result)
|
require.Nil(t, result)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), "empty response")
|
require.Contains(t, err.Error(), "empty response")
|
||||||
|
|||||||
72
backend/internal/service/gateway_body_order_test.go
Normal file
72
backend/internal/service/gateway_body_order_test.go
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func assertJSONTokenOrder(t *testing.T, body string, tokens ...string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
last := -1
|
||||||
|
for _, token := range tokens {
|
||||||
|
pos := strings.Index(body, token)
|
||||||
|
require.NotEqualf(t, -1, pos, "missing token %s in body %s", token, body)
|
||||||
|
require.Greaterf(t, pos, last, "token %s should appear after previous tokens in body %s", token, body)
|
||||||
|
last = pos
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReplaceModelInBody_PreservesTopLevelFieldOrder(t *testing.T) {
|
||||||
|
svc := &GatewayService{}
|
||||||
|
body := []byte(`{"alpha":1,"model":"claude-3-5-sonnet-latest","messages":[],"omega":2}`)
|
||||||
|
|
||||||
|
result := svc.replaceModelInBody(body, "claude-3-5-sonnet-20241022")
|
||||||
|
resultStr := string(result)
|
||||||
|
|
||||||
|
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"messages"`, `"omega"`)
|
||||||
|
require.Contains(t, resultStr, `"model":"claude-3-5-sonnet-20241022"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeClaudeOAuthRequestBody_PreservesTopLevelFieldOrder(t *testing.T) {
|
||||||
|
body := []byte(`{"alpha":1,"model":"claude-3-5-sonnet-latest","temperature":0.2,"system":"You are OpenCode, the best coding agent on the planet.","messages":[],"tool_choice":{"type":"auto"},"omega":2}`)
|
||||||
|
|
||||||
|
result, modelID := normalizeClaudeOAuthRequestBody(body, "claude-3-5-sonnet-latest", claudeOAuthNormalizeOptions{
|
||||||
|
injectMetadata: true,
|
||||||
|
metadataUserID: "user-1",
|
||||||
|
})
|
||||||
|
resultStr := string(result)
|
||||||
|
|
||||||
|
require.Equal(t, claude.NormalizeModelID("claude-3-5-sonnet-latest"), modelID)
|
||||||
|
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"system"`, `"messages"`, `"omega"`, `"tools"`, `"metadata"`)
|
||||||
|
require.NotContains(t, resultStr, `"temperature"`)
|
||||||
|
require.NotContains(t, resultStr, `"tool_choice"`)
|
||||||
|
require.Contains(t, resultStr, `"system":"`+claudeCodeSystemPrompt+`"`)
|
||||||
|
require.Contains(t, resultStr, `"tools":[]`)
|
||||||
|
require.Contains(t, resultStr, `"metadata":{"user_id":"user-1"}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInjectClaudeCodePrompt_PreservesFieldOrder(t *testing.T) {
|
||||||
|
body := []byte(`{"alpha":1,"system":[{"id":"block-1","type":"text","text":"Custom"}],"messages":[],"omega":2}`)
|
||||||
|
|
||||||
|
result := injectClaudeCodePrompt(body, []any{
|
||||||
|
map[string]any{"id": "block-1", "type": "text", "text": "Custom"},
|
||||||
|
})
|
||||||
|
resultStr := string(result)
|
||||||
|
|
||||||
|
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"system"`, `"messages"`, `"omega"`)
|
||||||
|
require.Contains(t, resultStr, `{"id":"block-1","type":"text","text":"`+claudeCodeSystemPrompt+`\n\nCustom"}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnforceCacheControlLimit_PreservesTopLevelFieldOrder(t *testing.T) {
|
||||||
|
body := []byte(`{"alpha":1,"system":[{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}},{"type":"text","text":"s2","cache_control":{"type":"ephemeral"}}],"messages":[{"role":"user","content":[{"type":"text","text":"m1","cache_control":{"type":"ephemeral"}},{"type":"text","text":"m2","cache_control":{"type":"ephemeral"}},{"type":"text","text":"m3","cache_control":{"type":"ephemeral"}}]}],"omega":2}`)
|
||||||
|
|
||||||
|
result := enforceCacheControlLimit(body)
|
||||||
|
resultStr := string(result)
|
||||||
|
|
||||||
|
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"system"`, `"messages"`, `"omega"`)
|
||||||
|
require.Equal(t, 4, strings.Count(resultStr, `"cache_control"`))
|
||||||
|
}
|
||||||
34
backend/internal/service/gateway_debug_env_test.go
Normal file
34
backend/internal/service/gateway_debug_env_test.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestDebugGatewayBodyLoggingEnabled(t *testing.T) {
|
||||||
|
t.Run("default disabled", func(t *testing.T) {
|
||||||
|
t.Setenv(debugGatewayBodyEnv, "")
|
||||||
|
if debugGatewayBodyLoggingEnabled() {
|
||||||
|
t.Fatalf("expected debug gateway body logging to be disabled by default")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("enabled with true-like values", func(t *testing.T) {
|
||||||
|
for _, value := range []string{"1", "true", "TRUE", "yes", "on"} {
|
||||||
|
t.Run(value, func(t *testing.T) {
|
||||||
|
t.Setenv(debugGatewayBodyEnv, value)
|
||||||
|
if !debugGatewayBodyLoggingEnabled() {
|
||||||
|
t.Fatalf("expected debug gateway body logging to be enabled for %q", value)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("disabled with other values", func(t *testing.T) {
|
||||||
|
for _, value := range []string{"0", "false", "off", "debug"} {
|
||||||
|
t.Run(value, func(t *testing.T) {
|
||||||
|
t.Setenv(debugGatewayBodyEnv, value)
|
||||||
|
if debugGatewayBodyLoggingEnabled() {
|
||||||
|
t.Fatalf("expected debug gateway body logging to be disabled for %q", value)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -278,8 +278,8 @@ func (m *mockGroupRepoForGateway) ListActiveByPlatform(ctx context.Context, plat
|
|||||||
func (m *mockGroupRepoForGateway) ExistsByName(ctx context.Context, name string) (bool, error) {
|
func (m *mockGroupRepoForGateway) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
|
||||||
return 0, nil
|
return 0, 0, nil
|
||||||
}
|
}
|
||||||
func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
|
|||||||
@@ -28,6 +28,12 @@ var (
|
|||||||
patternEmptyContentSpaced = []byte(`"content": []`)
|
patternEmptyContentSpaced = []byte(`"content": []`)
|
||||||
patternEmptyContentSp1 = []byte(`"content" : []`)
|
patternEmptyContentSp1 = []byte(`"content" : []`)
|
||||||
patternEmptyContentSp2 = []byte(`"content" :[]`)
|
patternEmptyContentSp2 = []byte(`"content" :[]`)
|
||||||
|
|
||||||
|
// Fast-path patterns for empty text blocks: {"type":"text","text":""}
|
||||||
|
patternEmptyText = []byte(`"text":""`)
|
||||||
|
patternEmptyTextSpaced = []byte(`"text": ""`)
|
||||||
|
patternEmptyTextSp1 = []byte(`"text" : ""`)
|
||||||
|
patternEmptyTextSp2 = []byte(`"text" :""`)
|
||||||
)
|
)
|
||||||
|
|
||||||
// SessionContext 粘性会话上下文,用于区分不同来源的请求。
|
// SessionContext 粘性会话上下文,用于区分不同来源的请求。
|
||||||
@@ -233,15 +239,22 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
|
|||||||
bytes.Contains(body, patternThinkingField) ||
|
bytes.Contains(body, patternThinkingField) ||
|
||||||
bytes.Contains(body, patternThinkingFieldSpaced)
|
bytes.Contains(body, patternThinkingFieldSpaced)
|
||||||
|
|
||||||
// Also check for empty content arrays that need fixing.
|
// Also check for empty content arrays and empty text blocks that need fixing.
|
||||||
// Note: This is a heuristic check; the actual empty content handling is done below.
|
// Note: This is a heuristic check; the actual empty content handling is done below.
|
||||||
hasEmptyContent := bytes.Contains(body, patternEmptyContent) ||
|
hasEmptyContent := bytes.Contains(body, patternEmptyContent) ||
|
||||||
bytes.Contains(body, patternEmptyContentSpaced) ||
|
bytes.Contains(body, patternEmptyContentSpaced) ||
|
||||||
bytes.Contains(body, patternEmptyContentSp1) ||
|
bytes.Contains(body, patternEmptyContentSp1) ||
|
||||||
bytes.Contains(body, patternEmptyContentSp2)
|
bytes.Contains(body, patternEmptyContentSp2)
|
||||||
|
|
||||||
|
// Check for empty text blocks: {"type":"text","text":""}
|
||||||
|
// These cause upstream 400: "text content blocks must be non-empty"
|
||||||
|
hasEmptyTextBlock := bytes.Contains(body, patternEmptyText) ||
|
||||||
|
bytes.Contains(body, patternEmptyTextSpaced) ||
|
||||||
|
bytes.Contains(body, patternEmptyTextSp1) ||
|
||||||
|
bytes.Contains(body, patternEmptyTextSp2)
|
||||||
|
|
||||||
// Fast path: nothing to process
|
// Fast path: nothing to process
|
||||||
if !hasThinkingContent && !hasEmptyContent {
|
if !hasThinkingContent && !hasEmptyContent && !hasEmptyTextBlock {
|
||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -260,7 +273,7 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
|
|||||||
bytes.Contains(body, patternTypeRedactedThinking) ||
|
bytes.Contains(body, patternTypeRedactedThinking) ||
|
||||||
bytes.Contains(body, patternTypeRedactedSpaced) ||
|
bytes.Contains(body, patternTypeRedactedSpaced) ||
|
||||||
bytes.Contains(body, patternThinkingFieldSpaced)
|
bytes.Contains(body, patternThinkingFieldSpaced)
|
||||||
if !hasEmptyContent && !containsThinkingBlocks {
|
if !hasEmptyContent && !hasEmptyTextBlock && !containsThinkingBlocks {
|
||||||
if topThinking := gjson.Get(jsonStr, "thinking"); topThinking.Exists() {
|
if topThinking := gjson.Get(jsonStr, "thinking"); topThinking.Exists() {
|
||||||
if out, err := sjson.DeleteBytes(body, "thinking"); err == nil {
|
if out, err := sjson.DeleteBytes(body, "thinking"); err == nil {
|
||||||
out = removeThinkingDependentContextStrategies(out)
|
out = removeThinkingDependentContextStrategies(out)
|
||||||
@@ -320,6 +333,16 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
|
|||||||
|
|
||||||
blockType, _ := blockMap["type"].(string)
|
blockType, _ := blockMap["type"].(string)
|
||||||
|
|
||||||
|
// Strip empty text blocks: {"type":"text","text":""}
|
||||||
|
// Upstream rejects these with 400: "text content blocks must be non-empty"
|
||||||
|
if blockType == "text" {
|
||||||
|
if txt, _ := blockMap["text"].(string); txt == "" {
|
||||||
|
modifiedThisMsg = true
|
||||||
|
ensureNewContent(bi)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Convert thinking blocks to text (preserve content) and drop redacted_thinking.
|
// Convert thinking blocks to text (preserve content) and drop redacted_thinking.
|
||||||
switch blockType {
|
switch blockType {
|
||||||
case "thinking":
|
case "thinking":
|
||||||
|
|||||||
@@ -404,6 +404,51 @@ func TestFilterThinkingBlocksForRetry_EmptyContentGetsPlaceholder(t *testing.T)
|
|||||||
require.NotEmpty(t, content0["text"])
|
require.NotEmpty(t, content0["text"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFilterThinkingBlocksForRetry_StripsEmptyTextBlocks(t *testing.T) {
|
||||||
|
// Empty text blocks cause upstream 400: "text content blocks must be non-empty"
|
||||||
|
input := []byte(`{
|
||||||
|
"messages":[
|
||||||
|
{"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":""}]},
|
||||||
|
{"role":"assistant","content":[{"type":"text","text":""}]}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out := FilterThinkingBlocksForRetry(input)
|
||||||
|
|
||||||
|
var req map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(out, &req))
|
||||||
|
msgs, ok := req["messages"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
// First message: empty text block stripped, "hello" preserved
|
||||||
|
msg0 := msgs[0].(map[string]any)
|
||||||
|
content0 := msg0["content"].([]any)
|
||||||
|
require.Len(t, content0, 1)
|
||||||
|
require.Equal(t, "hello", content0[0].(map[string]any)["text"])
|
||||||
|
|
||||||
|
// Second message: only had empty text block → gets placeholder
|
||||||
|
msg1 := msgs[1].(map[string]any)
|
||||||
|
content1 := msg1["content"].([]any)
|
||||||
|
require.Len(t, content1, 1)
|
||||||
|
block1 := content1[0].(map[string]any)
|
||||||
|
require.Equal(t, "text", block1["type"])
|
||||||
|
require.NotEmpty(t, block1["text"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterThinkingBlocksForRetry_PreservesNonEmptyTextBlocks(t *testing.T) {
|
||||||
|
// Non-empty text blocks should pass through unchanged
|
||||||
|
input := []byte(`{
|
||||||
|
"messages":[
|
||||||
|
{"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":"world"}]}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out := FilterThinkingBlocksForRetry(input)
|
||||||
|
|
||||||
|
// Fast path: no thinking content, no empty content, no empty text blocks → unchanged
|
||||||
|
require.Equal(t, input, out)
|
||||||
|
}
|
||||||
|
|
||||||
func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) {
|
func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) {
|
||||||
input := []byte(`{
|
input := []byte(`{
|
||||||
"thinking":{"type":"enabled","budget_tokens":1024},
|
"thinking":{"type":"enabled","budget_tokens":1024},
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -230,8 +230,8 @@ func (m *mockGroupRepoForGemini) ListActiveByPlatform(ctx context.Context, platf
|
|||||||
func (m *mockGroupRepoForGemini) ExistsByName(ctx context.Context, name string) (bool, error) {
|
func (m *mockGroupRepoForGemini) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
func (m *mockGroupRepoForGemini) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
func (m *mockGroupRepoForGemini) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
|
||||||
return 0, nil
|
return 0, 0, nil
|
||||||
}
|
}
|
||||||
func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ func TestGenerateSessionHash_MetadataHasHighestPriority(t *testing.T) {
|
|||||||
svc := &GatewayService{}
|
svc := &GatewayService{}
|
||||||
|
|
||||||
parsed := &ParsedRequest{
|
parsed := &ParsedRequest{
|
||||||
MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000",
|
MetadataUserID: "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000",
|
||||||
System: "You are a helpful assistant.",
|
System: "You are a helpful assistant.",
|
||||||
HasSystem: true,
|
HasSystem: true,
|
||||||
Messages: []any{
|
Messages: []any{
|
||||||
@@ -196,7 +196,7 @@ func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) {
|
|||||||
svc := &GatewayService{}
|
svc := &GatewayService{}
|
||||||
|
|
||||||
parsed := &ParsedRequest{
|
parsed := &ParsedRequest{
|
||||||
MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000",
|
MetadataUserID: "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000",
|
||||||
Messages: []any{
|
Messages: []any{
|
||||||
map[string]any{"role": "user", "content": "hello"},
|
map[string]any{"role": "user", "content": "hello"},
|
||||||
},
|
},
|
||||||
@@ -212,6 +212,22 @@ func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) {
|
|||||||
"metadata session_id should take priority over SessionContext")
|
"metadata session_id should take priority over SessionContext")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGenerateSessionHash_MetadataJSON_HasHighestPriority(t *testing.T) {
|
||||||
|
svc := &GatewayService{}
|
||||||
|
|
||||||
|
parsed := &ParsedRequest{
|
||||||
|
MetadataUserID: `{"device_id":"a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`,
|
||||||
|
System: "You are a helpful assistant.",
|
||||||
|
HasSystem: true,
|
||||||
|
Messages: []any{
|
||||||
|
map[string]any{"role": "user", "content": "hello"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
hash := svc.GenerateSessionHash(parsed)
|
||||||
|
require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", hash, "JSON format metadata session_id should have highest priority")
|
||||||
|
}
|
||||||
|
|
||||||
func TestGenerateSessionHash_NilSessionContextBackwardCompatible(t *testing.T) {
|
func TestGenerateSessionHash_NilSessionContextBackwardCompatible(t *testing.T) {
|
||||||
svc := &GatewayService{}
|
svc := &GatewayService{}
|
||||||
|
|
||||||
|
|||||||
@@ -64,8 +64,10 @@ type Group struct {
|
|||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
|
|
||||||
AccountGroups []AccountGroup
|
AccountGroups []AccountGroup
|
||||||
AccountCount int64
|
AccountCount int64
|
||||||
|
ActiveAccountCount int64
|
||||||
|
RateLimitedAccountCount int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Group) IsActive() bool {
|
func (g *Group) IsActive() bool {
|
||||||
|
|||||||
131
backend/internal/service/group_capacity_service.go
Normal file
131
backend/internal/service/group_capacity_service.go
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GroupCapacitySummary holds aggregated capacity for a single group.
|
||||||
|
type GroupCapacitySummary struct {
|
||||||
|
GroupID int64 `json:"group_id"`
|
||||||
|
ConcurrencyUsed int `json:"concurrency_used"`
|
||||||
|
ConcurrencyMax int `json:"concurrency_max"`
|
||||||
|
SessionsUsed int `json:"sessions_used"`
|
||||||
|
SessionsMax int `json:"sessions_max"`
|
||||||
|
RPMUsed int `json:"rpm_used"`
|
||||||
|
RPMMax int `json:"rpm_max"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupCapacityService aggregates per-group capacity from runtime data.
|
||||||
|
type GroupCapacityService struct {
|
||||||
|
accountRepo AccountRepository
|
||||||
|
groupRepo GroupRepository
|
||||||
|
concurrencyService *ConcurrencyService
|
||||||
|
sessionLimitCache SessionLimitCache
|
||||||
|
rpmCache RPMCache
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewGroupCapacityService creates a new GroupCapacityService.
|
||||||
|
func NewGroupCapacityService(
|
||||||
|
accountRepo AccountRepository,
|
||||||
|
groupRepo GroupRepository,
|
||||||
|
concurrencyService *ConcurrencyService,
|
||||||
|
sessionLimitCache SessionLimitCache,
|
||||||
|
rpmCache RPMCache,
|
||||||
|
) *GroupCapacityService {
|
||||||
|
return &GroupCapacityService{
|
||||||
|
accountRepo: accountRepo,
|
||||||
|
groupRepo: groupRepo,
|
||||||
|
concurrencyService: concurrencyService,
|
||||||
|
sessionLimitCache: sessionLimitCache,
|
||||||
|
rpmCache: rpmCache,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllGroupCapacity returns capacity summary for all active groups.
|
||||||
|
func (s *GroupCapacityService) GetAllGroupCapacity(ctx context.Context) ([]GroupCapacitySummary, error) {
|
||||||
|
groups, err := s.groupRepo.ListActive(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
results := make([]GroupCapacitySummary, 0, len(groups))
|
||||||
|
for i := range groups {
|
||||||
|
cap, err := s.getGroupCapacity(ctx, groups[i].ID)
|
||||||
|
if err != nil {
|
||||||
|
// Skip groups with errors, return partial results
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cap.GroupID = groups[i].ID
|
||||||
|
results = append(results, cap)
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GroupCapacityService) getGroupCapacity(ctx context.Context, groupID int64) (GroupCapacitySummary, error) {
|
||||||
|
accounts, err := s.accountRepo.ListSchedulableByGroupID(ctx, groupID)
|
||||||
|
if err != nil {
|
||||||
|
return GroupCapacitySummary{}, err
|
||||||
|
}
|
||||||
|
if len(accounts) == 0 {
|
||||||
|
return GroupCapacitySummary{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect account IDs and config values
|
||||||
|
accountIDs := make([]int64, 0, len(accounts))
|
||||||
|
sessionTimeouts := make(map[int64]time.Duration)
|
||||||
|
var concurrencyMax, sessionsMax, rpmMax int
|
||||||
|
|
||||||
|
for i := range accounts {
|
||||||
|
acc := &accounts[i]
|
||||||
|
accountIDs = append(accountIDs, acc.ID)
|
||||||
|
concurrencyMax += acc.Concurrency
|
||||||
|
|
||||||
|
if ms := acc.GetMaxSessions(); ms > 0 {
|
||||||
|
sessionsMax += ms
|
||||||
|
timeout := time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
|
||||||
|
if timeout <= 0 {
|
||||||
|
timeout = 5 * time.Minute
|
||||||
|
}
|
||||||
|
sessionTimeouts[acc.ID] = timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
if rpm := acc.GetBaseRPM(); rpm > 0 {
|
||||||
|
rpmMax += rpm
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Batch query runtime data from Redis
|
||||||
|
concurrencyMap, _ := s.concurrencyService.GetAccountConcurrencyBatch(ctx, accountIDs)
|
||||||
|
|
||||||
|
var sessionsMap map[int64]int
|
||||||
|
if sessionsMax > 0 && s.sessionLimitCache != nil {
|
||||||
|
sessionsMap, _ = s.sessionLimitCache.GetActiveSessionCountBatch(ctx, accountIDs, sessionTimeouts)
|
||||||
|
}
|
||||||
|
|
||||||
|
var rpmMap map[int64]int
|
||||||
|
if rpmMax > 0 && s.rpmCache != nil {
|
||||||
|
rpmMap, _ = s.rpmCache.GetRPMBatch(ctx, accountIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Aggregate
|
||||||
|
var concurrencyUsed, sessionsUsed, rpmUsed int
|
||||||
|
for _, id := range accountIDs {
|
||||||
|
concurrencyUsed += concurrencyMap[id]
|
||||||
|
if sessionsMap != nil {
|
||||||
|
sessionsUsed += sessionsMap[id]
|
||||||
|
}
|
||||||
|
if rpmMap != nil {
|
||||||
|
rpmUsed += rpmMap[id]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return GroupCapacitySummary{
|
||||||
|
ConcurrencyUsed: concurrencyUsed,
|
||||||
|
ConcurrencyMax: concurrencyMax,
|
||||||
|
SessionsUsed: sessionsUsed,
|
||||||
|
SessionsMax: sessionsMax,
|
||||||
|
RPMUsed: rpmUsed,
|
||||||
|
RPMMax: rpmMax,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
@@ -27,7 +27,7 @@ type GroupRepository interface {
|
|||||||
ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error)
|
ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error)
|
||||||
|
|
||||||
ExistsByName(ctx context.Context, name string) (bool, error)
|
ExistsByName(ctx context.Context, name string) (bool, error)
|
||||||
GetAccountCount(ctx context.Context, groupID int64) (int64, error)
|
GetAccountCount(ctx context.Context, groupID int64) (total int64, active int64, err error)
|
||||||
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
|
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||||
// GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID(去重)
|
// GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID(去重)
|
||||||
GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error)
|
GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error)
|
||||||
@@ -202,7 +202,7 @@ func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]any,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取账号数量
|
// 获取账号数量
|
||||||
accountCount, err := s.groupRepo.GetAccountCount(ctx, id)
|
accountCount, _, err := s.groupRepo.GetAccountCount(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get account count: %w", err)
|
return nil, fmt.Errorf("get account count: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -15,14 +14,12 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// 预编译正则表达式(避免每次调用重新编译)
|
// 预编译正则表达式(避免每次调用重新编译)
|
||||||
var (
|
var (
|
||||||
// 匹配 user_id 格式:
|
|
||||||
// 旧格式: user_{64位hex}_account__session_{uuid} (account 后无 UUID)
|
|
||||||
// 新格式: user_{64位hex}_account_{uuid}_session_{uuid} (account 后有 UUID)
|
|
||||||
userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account_([a-f0-9-]*)_session_([a-f0-9-]{36})$`)
|
|
||||||
// 匹配 User-Agent 版本号: xxx/x.y.z
|
// 匹配 User-Agent 版本号: xxx/x.y.z
|
||||||
userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
|
userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
|
||||||
)
|
)
|
||||||
@@ -209,67 +206,57 @@ func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RewriteUserID 重写body中的metadata.user_id
|
// RewriteUserID 重写body中的metadata.user_id
|
||||||
// 输入格式:user_{clientId}_account__session_{sessionUUID}
|
// 支持旧拼接格式和新 JSON 格式的 user_id 解析,
|
||||||
// 输出格式:user_{cachedClientID}_account_{accountUUID}_session_{newHash}
|
// 根据 fingerprintUA 版本选择输出格式。
|
||||||
//
|
//
|
||||||
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
|
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
|
||||||
// 避免重新序列化导致 thinking 块等内容被修改。
|
// 避免重新序列化导致 thinking 块等内容被修改。
|
||||||
func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID string) ([]byte, error) {
|
func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID, fingerprintUA string) ([]byte, error) {
|
||||||
if len(body) == 0 || accountUUID == "" || cachedClientID == "" {
|
if len(body) == 0 || accountUUID == "" || cachedClientID == "" {
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用 RawMessage 保留其他字段的原始字节
|
metadata := gjson.GetBytes(body, "metadata")
|
||||||
var reqMap map[string]json.RawMessage
|
if !metadata.Exists() || metadata.Type == gjson.Null {
|
||||||
if err := json.Unmarshal(body, &reqMap); err != nil {
|
return body, nil
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(strings.TrimSpace(metadata.Raw), "{") {
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析 metadata 字段
|
userIDResult := metadata.Get("user_id")
|
||||||
metadataRaw, ok := reqMap["metadata"]
|
if !userIDResult.Exists() || userIDResult.Type != gjson.String {
|
||||||
if !ok {
|
return body, nil
|
||||||
|
}
|
||||||
|
userID := userIDResult.String()
|
||||||
|
if userID == "" {
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var metadata map[string]any
|
// 解析 user_id(兼容旧拼接格式和新 JSON 格式)
|
||||||
if err := json.Unmarshal(metadataRaw, &metadata); err != nil {
|
parsed := ParseMetadataUserID(userID)
|
||||||
|
if parsed == nil {
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
userID, ok := metadata["user_id"].(string)
|
sessionTail := parsed.SessionID // 原始session UUID
|
||||||
if !ok || userID == "" {
|
|
||||||
return body, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 匹配格式:
|
|
||||||
// 旧格式: user_{64位hex}_account__session_{uuid}
|
|
||||||
// 新格式: user_{64位hex}_account_{uuid}_session_{uuid}
|
|
||||||
matches := userIDRegex.FindStringSubmatch(userID)
|
|
||||||
if matches == nil {
|
|
||||||
return body, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// matches[1] = account UUID (可能为空), matches[2] = session UUID
|
|
||||||
sessionTail := matches[2] // 原始session UUID
|
|
||||||
|
|
||||||
// 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式
|
// 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式
|
||||||
seed := fmt.Sprintf("%d::%s", accountID, sessionTail)
|
seed := fmt.Sprintf("%d::%s", accountID, sessionTail)
|
||||||
newSessionHash := generateUUIDFromSeed(seed)
|
newSessionHash := generateUUIDFromSeed(seed)
|
||||||
|
|
||||||
// 构建新的user_id
|
// 根据客户端版本选择输出格式
|
||||||
// 格式: user_{cachedClientID}_account_{account_uuid}_session_{newSessionHash}
|
version := ExtractCLIVersion(fingerprintUA)
|
||||||
newUserID := fmt.Sprintf("user_%s_account_%s_session_%s", cachedClientID, accountUUID, newSessionHash)
|
newUserID := FormatMetadataUserID(cachedClientID, accountUUID, newSessionHash, version)
|
||||||
|
if newUserID == userID {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
metadata["user_id"] = newUserID
|
newBody, err := sjson.SetBytes(body, "metadata.user_id", newUserID)
|
||||||
|
|
||||||
// 只重新序列化 metadata 字段
|
|
||||||
newMetadataRaw, err := json.Marshal(metadata)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
reqMap["metadata"] = newMetadataRaw
|
return newBody, nil
|
||||||
|
|
||||||
return json.Marshal(reqMap)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装
|
// RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装
|
||||||
@@ -278,9 +265,9 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
|
|||||||
//
|
//
|
||||||
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
|
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
|
||||||
// 避免重新序列化导致 thinking 块等内容被修改。
|
// 避免重新序列化导致 thinking 块等内容被修改。
|
||||||
func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID string) ([]byte, error) {
|
func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID, fingerprintUA string) ([]byte, error) {
|
||||||
// 先执行常规的 RewriteUserID 逻辑
|
// 先执行常规的 RewriteUserID 逻辑
|
||||||
newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID)
|
newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID, fingerprintUA)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return newBody, err
|
return newBody, err
|
||||||
}
|
}
|
||||||
@@ -290,32 +277,26 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
|
|||||||
return newBody, nil
|
return newBody, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用 RawMessage 保留其他字段的原始字节
|
metadata := gjson.GetBytes(newBody, "metadata")
|
||||||
var reqMap map[string]json.RawMessage
|
if !metadata.Exists() || metadata.Type == gjson.Null {
|
||||||
if err := json.Unmarshal(newBody, &reqMap); err != nil {
|
return newBody, nil
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(strings.TrimSpace(metadata.Raw), "{") {
|
||||||
return newBody, nil
|
return newBody, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析 metadata 字段
|
userIDResult := metadata.Get("user_id")
|
||||||
metadataRaw, ok := reqMap["metadata"]
|
if !userIDResult.Exists() || userIDResult.Type != gjson.String {
|
||||||
if !ok {
|
return newBody, nil
|
||||||
|
}
|
||||||
|
userID := userIDResult.String()
|
||||||
|
if userID == "" {
|
||||||
return newBody, nil
|
return newBody, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var metadata map[string]any
|
// 解析已重写的 user_id
|
||||||
if err := json.Unmarshal(metadataRaw, &metadata); err != nil {
|
uidParsed := ParseMetadataUserID(userID)
|
||||||
return newBody, nil
|
if uidParsed == nil {
|
||||||
}
|
|
||||||
|
|
||||||
userID, ok := metadata["user_id"].(string)
|
|
||||||
if !ok || userID == "" {
|
|
||||||
return newBody, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 查找 _session_ 的位置,替换其后的内容
|
|
||||||
const sessionMarker = "_session_"
|
|
||||||
idx := strings.LastIndex(userID, sessionMarker)
|
|
||||||
if idx == -1 {
|
|
||||||
return newBody, nil
|
return newBody, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -337,8 +318,9 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
|
|||||||
logger.LegacyPrintf("service.identity", "Warning: failed to set masked session ID for account %d: %v", account.ID, err)
|
logger.LegacyPrintf("service.identity", "Warning: failed to set masked session ID for account %d: %v", account.ID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 替换 session 部分:保留 _session_ 之前的内容,替换之后的内容
|
// 用 FormatMetadataUserID 重建(保持与 RewriteUserID 相同的格式)
|
||||||
newUserID := userID[:idx+len(sessionMarker)] + maskedSessionID
|
version := ExtractCLIVersion(fingerprintUA)
|
||||||
|
newUserID := FormatMetadataUserID(uidParsed.DeviceID, uidParsed.AccountUUID, maskedSessionID, version)
|
||||||
|
|
||||||
slog.Debug("session_id_masking_applied",
|
slog.Debug("session_id_masking_applied",
|
||||||
"account_id", account.ID,
|
"account_id", account.ID,
|
||||||
@@ -346,16 +328,15 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
|
|||||||
"after", newUserID,
|
"after", newUserID,
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata["user_id"] = newUserID
|
if newUserID == userID {
|
||||||
|
|
||||||
// 只重新序列化 metadata 字段
|
|
||||||
newMetadataRaw, marshalErr := json.Marshal(metadata)
|
|
||||||
if marshalErr != nil {
|
|
||||||
return newBody, nil
|
return newBody, nil
|
||||||
}
|
}
|
||||||
reqMap["metadata"] = newMetadataRaw
|
|
||||||
|
|
||||||
return json.Marshal(reqMap)
|
maskedBody, setErr := sjson.SetBytes(newBody, "metadata.user_id", newUserID)
|
||||||
|
if setErr != nil {
|
||||||
|
return newBody, nil
|
||||||
|
}
|
||||||
|
return maskedBody, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateRandomUUID 生成随机 UUID v4 格式字符串
|
// generateRandomUUID 生成随机 UUID v4 格式字符串
|
||||||
|
|||||||
82
backend/internal/service/identity_service_order_test.go
Normal file
82
backend/internal/service/identity_service_order_test.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type identityCacheStub struct {
|
||||||
|
maskedSessionID string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *identityCacheStub) GetFingerprint(_ context.Context, _ int64) (*Fingerprint, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *identityCacheStub) SetFingerprint(_ context.Context, _ int64, _ *Fingerprint) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *identityCacheStub) GetMaskedSessionID(_ context.Context, _ int64) (string, error) {
|
||||||
|
return s.maskedSessionID, nil
|
||||||
|
}
|
||||||
|
func (s *identityCacheStub) SetMaskedSessionID(_ context.Context, _ int64, sessionID string) error {
|
||||||
|
s.maskedSessionID = sessionID
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIdentityService_RewriteUserID_PreservesTopLevelFieldOrder(t *testing.T) {
|
||||||
|
cache := &identityCacheStub{}
|
||||||
|
svc := NewIdentityService(cache)
|
||||||
|
|
||||||
|
originalUserID := FormatMetadataUserID(
|
||||||
|
"d61f76d0730d2b920763648949bad5c79742155c27037fc77ac3f9805cb90169",
|
||||||
|
"",
|
||||||
|
"7578cf37-aaca-46e4-a45c-71285d9dbb83",
|
||||||
|
"2.1.78",
|
||||||
|
)
|
||||||
|
body := []byte(`{"alpha":1,"messages":[],"metadata":{"user_id":` + strconvQuote(originalUserID) + `},"max_tokens":64000,"thinking":{"type":"adaptive"},"output_config":{"effort":"high"},"stream":true}`)
|
||||||
|
|
||||||
|
result, err := svc.RewriteUserID(body, 123, "acc-uuid", "client-xyz", "claude-cli/2.1.78 (external, cli)")
|
||||||
|
require.NoError(t, err)
|
||||||
|
resultStr := string(result)
|
||||||
|
|
||||||
|
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"messages"`, `"metadata"`, `"max_tokens"`, `"thinking"`, `"output_config"`, `"stream"`)
|
||||||
|
require.NotContains(t, resultStr, originalUserID)
|
||||||
|
require.Contains(t, resultStr, `"metadata":{"user_id":"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIdentityService_RewriteUserIDWithMasking_PreservesTopLevelFieldOrder(t *testing.T) {
|
||||||
|
cache := &identityCacheStub{maskedSessionID: "11111111-2222-4333-8444-555555555555"}
|
||||||
|
svc := NewIdentityService(cache)
|
||||||
|
|
||||||
|
originalUserID := FormatMetadataUserID(
|
||||||
|
"d61f76d0730d2b920763648949bad5c79742155c27037fc77ac3f9805cb90169",
|
||||||
|
"",
|
||||||
|
"7578cf37-aaca-46e4-a45c-71285d9dbb83",
|
||||||
|
"2.1.78",
|
||||||
|
)
|
||||||
|
body := []byte(`{"alpha":1,"messages":[],"metadata":{"user_id":` + strconvQuote(originalUserID) + `},"max_tokens":64000,"thinking":{"type":"adaptive"},"output_config":{"effort":"high"},"stream":true}`)
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 123,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"session_id_masking_enabled": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.RewriteUserIDWithMasking(context.Background(), body, account, "acc-uuid", "client-xyz", "claude-cli/2.1.78 (external, cli)")
|
||||||
|
require.NoError(t, err)
|
||||||
|
resultStr := string(result)
|
||||||
|
|
||||||
|
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"messages"`, `"metadata"`, `"max_tokens"`, `"thinking"`, `"output_config"`, `"stream"`)
|
||||||
|
require.Contains(t, resultStr, cache.maskedSessionID)
|
||||||
|
require.True(t, strings.Contains(resultStr, `"metadata":{"user_id":"`))
|
||||||
|
}
|
||||||
|
|
||||||
|
func strconvQuote(v string) string {
|
||||||
|
return `"` + strings.ReplaceAll(strings.ReplaceAll(v, `\`, `\\`), `"`, `\"`) + `"`
|
||||||
|
}
|
||||||
104
backend/internal/service/metadata_userid.go
Normal file
104
backend/internal/service/metadata_userid.go
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewMetadataFormatMinVersion is the minimum Claude Code version that uses
|
||||||
|
// JSON-formatted metadata.user_id instead of the legacy concatenated string.
|
||||||
|
const NewMetadataFormatMinVersion = "2.1.78"
|
||||||
|
|
||||||
|
// ParsedUserID represents the components extracted from a metadata.user_id value.
|
||||||
|
type ParsedUserID struct {
|
||||||
|
DeviceID string // 64-char hex (or arbitrary client id)
|
||||||
|
AccountUUID string // may be empty
|
||||||
|
SessionID string // UUID
|
||||||
|
IsNewFormat bool // true if the original was JSON format
|
||||||
|
}
|
||||||
|
|
||||||
|
// legacyUserIDRegex matches the legacy user_id format:
|
||||||
|
//
|
||||||
|
// user_{64hex}_account_{optional_uuid}_session_{uuid}
|
||||||
|
var legacyUserIDRegex = regexp.MustCompile(`^user_([a-fA-F0-9]{64})_account_([a-fA-F0-9-]*)_session_([a-fA-F0-9-]{36})$`)
|
||||||
|
|
||||||
|
// jsonUserID is the JSON structure for the new metadata.user_id format.
|
||||||
|
type jsonUserID struct {
|
||||||
|
DeviceID string `json:"device_id"`
|
||||||
|
AccountUUID string `json:"account_uuid"`
|
||||||
|
SessionID string `json:"session_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseMetadataUserID parses a metadata.user_id string in either format.
|
||||||
|
// Returns nil if the input cannot be parsed.
|
||||||
|
func ParseMetadataUserID(raw string) *ParsedUserID {
|
||||||
|
raw = strings.TrimSpace(raw)
|
||||||
|
if raw == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try JSON format first (starts with '{')
|
||||||
|
if raw[0] == '{' {
|
||||||
|
var j jsonUserID
|
||||||
|
if err := json.Unmarshal([]byte(raw), &j); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if j.DeviceID == "" || j.SessionID == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &ParsedUserID{
|
||||||
|
DeviceID: j.DeviceID,
|
||||||
|
AccountUUID: j.AccountUUID,
|
||||||
|
SessionID: j.SessionID,
|
||||||
|
IsNewFormat: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try legacy format
|
||||||
|
matches := legacyUserIDRegex.FindStringSubmatch(raw)
|
||||||
|
if matches == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &ParsedUserID{
|
||||||
|
DeviceID: matches[1],
|
||||||
|
AccountUUID: matches[2],
|
||||||
|
SessionID: matches[3],
|
||||||
|
IsNewFormat: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatMetadataUserID builds a metadata.user_id string in the format
|
||||||
|
// appropriate for the given CLI version. Components are the rewritten values
|
||||||
|
// (not necessarily the originals).
|
||||||
|
func FormatMetadataUserID(deviceID, accountUUID, sessionID, uaVersion string) string {
|
||||||
|
if IsNewMetadataFormatVersion(uaVersion) {
|
||||||
|
b, _ := json.Marshal(jsonUserID{
|
||||||
|
DeviceID: deviceID,
|
||||||
|
AccountUUID: accountUUID,
|
||||||
|
SessionID: sessionID,
|
||||||
|
})
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
// Legacy format
|
||||||
|
return "user_" + deviceID + "_account_" + accountUUID + "_session_" + sessionID
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsNewMetadataFormatVersion returns true if the given CLI version uses the
|
||||||
|
// new JSON metadata.user_id format (>= 2.1.78).
|
||||||
|
func IsNewMetadataFormatVersion(version string) bool {
|
||||||
|
if version == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return CompareVersions(version, NewMetadataFormatMinVersion) >= 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractCLIVersion extracts the Claude Code version from a User-Agent string.
|
||||||
|
// Returns "" if the UA doesn't match the expected pattern.
|
||||||
|
func ExtractCLIVersion(ua string) string {
|
||||||
|
matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua)
|
||||||
|
if len(matches) >= 2 {
|
||||||
|
return matches[1]
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
183
backend/internal/service/metadata_userid_test.go
Normal file
183
backend/internal/service/metadata_userid_test.go
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ============ ParseMetadataUserID Tests ============
|
||||||
|
|
||||||
|
func TestParseMetadataUserID_LegacyFormat_WithoutAccountUUID(t *testing.T) {
|
||||||
|
raw := "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000"
|
||||||
|
parsed := ParseMetadataUserID(raw)
|
||||||
|
require.NotNil(t, parsed)
|
||||||
|
require.Equal(t, "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", parsed.DeviceID)
|
||||||
|
require.Equal(t, "", parsed.AccountUUID)
|
||||||
|
require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", parsed.SessionID)
|
||||||
|
require.False(t, parsed.IsNewFormat)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseMetadataUserID_LegacyFormat_WithAccountUUID(t *testing.T) {
|
||||||
|
raw := "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account_550e8400-e29b-41d4-a716-446655440000_session_123e4567-e89b-12d3-a456-426614174000"
|
||||||
|
parsed := ParseMetadataUserID(raw)
|
||||||
|
require.NotNil(t, parsed)
|
||||||
|
require.Equal(t, "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", parsed.DeviceID)
|
||||||
|
require.Equal(t, "550e8400-e29b-41d4-a716-446655440000", parsed.AccountUUID)
|
||||||
|
require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", parsed.SessionID)
|
||||||
|
require.False(t, parsed.IsNewFormat)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseMetadataUserID_JSONFormat_WithoutAccountUUID(t *testing.T) {
|
||||||
|
raw := `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`
|
||||||
|
parsed := ParseMetadataUserID(raw)
|
||||||
|
require.NotNil(t, parsed)
|
||||||
|
require.Equal(t, "d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677", parsed.DeviceID)
|
||||||
|
require.Equal(t, "", parsed.AccountUUID)
|
||||||
|
require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", parsed.SessionID)
|
||||||
|
require.True(t, parsed.IsNewFormat)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseMetadataUserID_JSONFormat_WithAccountUUID(t *testing.T) {
|
||||||
|
raw := `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"550e8400-e29b-41d4-a716-446655440000","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`
|
||||||
|
parsed := ParseMetadataUserID(raw)
|
||||||
|
require.NotNil(t, parsed)
|
||||||
|
require.Equal(t, "d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677", parsed.DeviceID)
|
||||||
|
require.Equal(t, "550e8400-e29b-41d4-a716-446655440000", parsed.AccountUUID)
|
||||||
|
require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", parsed.SessionID)
|
||||||
|
require.True(t, parsed.IsNewFormat)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseMetadataUserID_InvalidInputs(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
raw string
|
||||||
|
}{
|
||||||
|
{"empty string", ""},
|
||||||
|
{"whitespace only", " "},
|
||||||
|
{"random text", "not-a-valid-user-id"},
|
||||||
|
{"partial legacy format", "session_123e4567-e89b-12d3-a456-426614174000"},
|
||||||
|
{"invalid JSON", `{"device_id":}`},
|
||||||
|
{"JSON missing device_id", `{"account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`},
|
||||||
|
{"JSON missing session_id", `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":""}`},
|
||||||
|
{"JSON empty device_id", `{"device_id":"","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`},
|
||||||
|
{"JSON empty session_id", `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"","session_id":""}`},
|
||||||
|
{"legacy format short hex", "user_a1b2c3d4_account__session_123e4567-e89b-12d3-a456-426614174000"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
require.Nil(t, ParseMetadataUserID(tt.raw), "should return nil for: %s", tt.raw)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseMetadataUserID_HexCaseInsensitive(t *testing.T) {
|
||||||
|
// Legacy format should accept both upper and lower case hex
|
||||||
|
rawUpper := "user_A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2_account__session_123e4567-e89b-12d3-a456-426614174000"
|
||||||
|
parsed := ParseMetadataUserID(rawUpper)
|
||||||
|
require.NotNil(t, parsed, "legacy format should accept uppercase hex")
|
||||||
|
require.Equal(t, "A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2", parsed.DeviceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============ FormatMetadataUserID Tests ============
|
||||||
|
|
||||||
|
func TestFormatMetadataUserID_LegacyVersion(t *testing.T) {
|
||||||
|
result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "acc-uuid", "sess-uuid", "2.1.77")
|
||||||
|
require.Equal(t, "user_deadbeef00112233445566778899aabbccddeeff0011223344556677_account_acc-uuid_session_sess-uuid", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatMetadataUserID_NewVersion(t *testing.T) {
|
||||||
|
result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "acc-uuid", "sess-uuid", "2.1.78")
|
||||||
|
require.Equal(t, `{"device_id":"deadbeef00112233445566778899aabbccddeeff0011223344556677","account_uuid":"acc-uuid","session_id":"sess-uuid"}`, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatMetadataUserID_EmptyVersion_Legacy(t *testing.T) {
|
||||||
|
result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "")
|
||||||
|
require.Equal(t, "user_deadbeef00112233445566778899aabbccddeeff0011223344556677_account__session_sess-uuid", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatMetadataUserID_EmptyAccountUUID(t *testing.T) {
|
||||||
|
// Legacy format with empty account UUID → double underscore
|
||||||
|
result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "2.1.22")
|
||||||
|
require.Contains(t, result, "_account__session_")
|
||||||
|
|
||||||
|
// New format with empty account UUID → empty string in JSON
|
||||||
|
result = FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "2.1.78")
|
||||||
|
require.Contains(t, result, `"account_uuid":""`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============ IsNewMetadataFormatVersion Tests ============
|
||||||
|
|
||||||
|
func TestIsNewMetadataFormatVersion(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
version string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"", false},
|
||||||
|
{"2.1.77", false},
|
||||||
|
{"2.1.78", true},
|
||||||
|
{"2.1.79", true},
|
||||||
|
{"2.2.0", true},
|
||||||
|
{"3.0.0", true},
|
||||||
|
{"2.0.100", false},
|
||||||
|
{"1.9.99", false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.version, func(t *testing.T) {
|
||||||
|
require.Equal(t, tt.want, IsNewMetadataFormatVersion(tt.version))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============ Round-trip Tests ============
|
||||||
|
|
||||||
|
func TestParseFormat_RoundTrip_Legacy(t *testing.T) {
|
||||||
|
deviceID := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
|
||||||
|
accountUUID := "550e8400-e29b-41d4-a716-446655440000"
|
||||||
|
sessionID := "123e4567-e89b-12d3-a456-426614174000"
|
||||||
|
|
||||||
|
formatted := FormatMetadataUserID(deviceID, accountUUID, sessionID, "2.1.22")
|
||||||
|
parsed := ParseMetadataUserID(formatted)
|
||||||
|
require.NotNil(t, parsed)
|
||||||
|
require.Equal(t, deviceID, parsed.DeviceID)
|
||||||
|
require.Equal(t, accountUUID, parsed.AccountUUID)
|
||||||
|
require.Equal(t, sessionID, parsed.SessionID)
|
||||||
|
require.False(t, parsed.IsNewFormat)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFormat_RoundTrip_JSON(t *testing.T) {
|
||||||
|
deviceID := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
|
||||||
|
accountUUID := "550e8400-e29b-41d4-a716-446655440000"
|
||||||
|
sessionID := "123e4567-e89b-12d3-a456-426614174000"
|
||||||
|
|
||||||
|
formatted := FormatMetadataUserID(deviceID, accountUUID, sessionID, "2.1.78")
|
||||||
|
parsed := ParseMetadataUserID(formatted)
|
||||||
|
require.NotNil(t, parsed)
|
||||||
|
require.Equal(t, deviceID, parsed.DeviceID)
|
||||||
|
require.Equal(t, accountUUID, parsed.AccountUUID)
|
||||||
|
require.Equal(t, sessionID, parsed.SessionID)
|
||||||
|
require.True(t, parsed.IsNewFormat)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFormat_RoundTrip_EmptyAccountUUID(t *testing.T) {
|
||||||
|
deviceID := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
|
||||||
|
sessionID := "123e4567-e89b-12d3-a456-426614174000"
|
||||||
|
|
||||||
|
// Legacy round-trip with empty account UUID
|
||||||
|
formatted := FormatMetadataUserID(deviceID, "", sessionID, "2.1.22")
|
||||||
|
parsed := ParseMetadataUserID(formatted)
|
||||||
|
require.NotNil(t, parsed)
|
||||||
|
require.Equal(t, deviceID, parsed.DeviceID)
|
||||||
|
require.Equal(t, "", parsed.AccountUUID)
|
||||||
|
require.Equal(t, sessionID, parsed.SessionID)
|
||||||
|
|
||||||
|
// JSON round-trip with empty account UUID
|
||||||
|
formatted = FormatMetadataUserID(deviceID, "", sessionID, "2.1.78")
|
||||||
|
parsed = ParseMetadataUserID(formatted)
|
||||||
|
require.NotNil(t, parsed)
|
||||||
|
require.Equal(t, deviceID, parsed.DeviceID)
|
||||||
|
require.Equal(t, "", parsed.AccountUUID)
|
||||||
|
require.Equal(t, sessionID, parsed.SessionID)
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user