mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-04 07:22:13 +08:00
Compare commits
268 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0236b97d49 | ||
|
|
26f6b1eeff | ||
|
|
dc447ccebe | ||
|
|
7ec29638f4 | ||
|
|
4c9562af20 | ||
|
|
71942fd322 | ||
|
|
550b979ac5 | ||
|
|
3878a5a46f | ||
|
|
e443a6a1ea | ||
|
|
963494ec6f | ||
|
|
525cdb8830 | ||
|
|
a6764e82f2 | ||
|
|
8027531d07 | ||
|
|
30706355a4 | ||
|
|
dfe99507b8 | ||
|
|
c1717c9a6c | ||
|
|
1fd1a58a7a | ||
|
|
fad07507be | ||
|
|
a20c211162 | ||
|
|
9f6ab6b817 | ||
|
|
bf3d6c0e6e | ||
|
|
241023f3fc | ||
|
|
1292c44b41 | ||
|
|
b4fce47049 | ||
|
|
e7780cd8c8 | ||
|
|
af96c8ea53 | ||
|
|
7d26b81075 | ||
|
|
b8ada63ac3 | ||
|
|
cfaac12af1 | ||
|
|
6028efd26c | ||
|
|
62a566ef2c | ||
|
|
94419f434c | ||
|
|
21f349c032 | ||
|
|
28e36f7925 | ||
|
|
6c02076333 | ||
|
|
7414bdf0e3 | ||
|
|
e6326b2929 | ||
|
|
17cdcebd04 | ||
|
|
a14babdc73 | ||
|
|
aadc6a763a | ||
|
|
f16af8bf88 | ||
|
|
5ceaef4500 | ||
|
|
1ac7219a92 | ||
|
|
d4cc9871c4 | ||
|
|
961c30e7c0 | ||
|
|
13e85b3147 | ||
|
|
50a3c7fa0b | ||
|
|
bd9d2671d7 | ||
|
|
62b40636e0 | ||
|
|
eeff451bc5 | ||
|
|
56fcb20f94 | ||
|
|
7134266acf | ||
|
|
2e4ac88ad9 | ||
|
|
51547fa216 | ||
|
|
2005fc97a8 | ||
|
|
0772d9250e | ||
|
|
aa6047c460 | ||
|
|
045cba78b4 | ||
|
|
8989d0d4b6 | ||
|
|
c521117b99 | ||
|
|
e0f52a8ab8 | ||
|
|
6c23fadf7e | ||
|
|
869952d113 | ||
|
|
07ab051ee4 | ||
|
|
f2d98fc0c7 | ||
|
|
2b41cec840 | ||
|
|
6cf77040e7 | ||
|
|
20b70bc5fd | ||
|
|
4905e7193a | ||
|
|
9c1f4b8e72 | ||
|
|
9857c17631 | ||
|
|
7e34bb946f | ||
|
|
47b748851b | ||
|
|
a6f99cf534 | ||
|
|
a120a6bc32 | ||
|
|
d557d1a190 | ||
|
|
e0286e5085 | ||
|
|
4b41e898a4 | ||
|
|
668e164793 | ||
|
|
fa2e6188d0 | ||
|
|
7fde9ebbc2 | ||
|
|
aef7c3b9bb | ||
|
|
a0b76bd608 | ||
|
|
c1fab7f8d8 | ||
|
|
f42c8f2abe | ||
|
|
aa5846b282 | ||
|
|
594a0ade38 | ||
|
|
d45cc23171 | ||
|
|
d795734352 | ||
|
|
4da9fdd1d5 | ||
|
|
6b218caa21 | ||
|
|
5c138007d0 | ||
|
|
1acfc46f46 | ||
|
|
fbffb08aae | ||
|
|
8640a62319 | ||
|
|
fa782e70a4 | ||
|
|
afd72abc6e | ||
|
|
71f72e167e | ||
|
|
6595c7601e | ||
|
|
67c0506290 | ||
|
|
6447be4534 | ||
|
|
3741617ebd | ||
|
|
ab4e8b2cf0 | ||
|
|
474165d7aa | ||
|
|
94e067a2e2 | ||
|
|
4293c89166 | ||
|
|
ec82c37da5 | ||
|
|
552a4b998a | ||
|
|
0d2061b268 | ||
|
|
8a260defc2 | ||
|
|
e14c87597a | ||
|
|
f3f19d35aa | ||
|
|
ced90e1d84 | ||
|
|
17e4033340 | ||
|
|
044d3a013d | ||
|
|
1fc9dd7b68 | ||
|
|
8147866c09 | ||
|
|
7bd1972f94 | ||
|
|
2c9dcfe27b | ||
|
|
1b79b0f3ff | ||
|
|
c637e6cf31 | ||
|
|
d3a9f5bb88 | ||
|
|
7eb0415a8a | ||
|
|
bdbc8fa08f | ||
|
|
63f3af0f94 | ||
|
|
686f890fbf | ||
|
|
220fbe6544 | ||
|
|
ae44a94325 | ||
|
|
3718d6dcd4 | ||
|
|
90b3838173 | ||
|
|
19d3ecc76f | ||
|
|
6fba4ebb13 | ||
|
|
c31974c913 | ||
|
|
6177fa5dd8 | ||
|
|
cfe72159d0 | ||
|
|
8321e4a647 | ||
|
|
3084330d0c | ||
|
|
b566649e79 | ||
|
|
10a6180e4a | ||
|
|
cbe9e78977 | ||
|
|
74145b1f39 | ||
|
|
359e56751b | ||
|
|
5899784aa4 | ||
|
|
9e8959c56d | ||
|
|
1bff2292a6 | ||
|
|
cf9247754e | ||
|
|
eefab15958 | ||
|
|
0e23732631 | ||
|
|
37c044fb4b | ||
|
|
6da5fa01b9 | ||
|
|
616930f9d3 | ||
|
|
b9c31fa7c4 | ||
|
|
17b339972c | ||
|
|
39f8bd91b9 | ||
|
|
aa4e37d085 | ||
|
|
f59b66b7d4 | ||
|
|
8f0ea7a02d | ||
|
|
a1dc00890e | ||
|
|
dfbcc363d1 | ||
|
|
1047f973d5 | ||
|
|
e32977dd73 | ||
|
|
b5f78ec1e8 | ||
|
|
e0f290fdc8 | ||
|
|
fc00a4e3b2 | ||
|
|
db1f6ded88 | ||
|
|
4644af2ccc | ||
|
|
2e3e8687e1 | ||
|
|
ca42a45802 | ||
|
|
9350ecb62b | ||
|
|
a4a026e8da | ||
|
|
342fd03e72 | ||
|
|
e3f1fd9b63 | ||
|
|
e4a4dfd038 | ||
|
|
a377e99088 | ||
|
|
1d3d7a3033 | ||
|
|
e7086cb3a3 | ||
|
|
4f2a97073e | ||
|
|
7407e3b45d | ||
|
|
01ef7340aa | ||
|
|
1c960d22c1 | ||
|
|
ece0606fed | ||
|
|
2666422b99 | ||
|
|
e6d59216d4 | ||
|
|
4e8615f276 | ||
|
|
91e4d95660 | ||
|
|
45456fa24c | ||
|
|
4588258d80 | ||
|
|
c12e48f966 | ||
|
|
ec8f50a658 | ||
|
|
99c9191784 | ||
|
|
6bb02d141f | ||
|
|
07bb2a5f3f | ||
|
|
417861a48e | ||
|
|
b7e878de64 | ||
|
|
05edb5514b | ||
|
|
e90ec847b6 | ||
|
|
6344fa2a86 | ||
|
|
7e288acc90 | ||
|
|
29b0e4a8a5 | ||
|
|
27ff222cfb | ||
|
|
11f7b83522 | ||
|
|
f7177be3b6 | ||
|
|
875b417fde | ||
|
|
2573107b32 | ||
|
|
5b85005945 | ||
|
|
1ee984478f | ||
|
|
fd693dc526 | ||
|
|
e73531ce9b | ||
|
|
53ad1645cf | ||
|
|
ecea13757b | ||
|
|
af9c4a7dd0 | ||
|
|
80d8d6c3bc | ||
|
|
d648811233 | ||
|
|
34695acb85 | ||
|
|
a63de12182 | ||
|
|
f16910d616 | ||
|
|
64b3f3cec1 | ||
|
|
6a685727d0 | ||
|
|
32d25f76fc | ||
|
|
69cafe8674 | ||
|
|
18ba8d9166 | ||
|
|
e97fd7e81c | ||
|
|
cdb64b0d33 | ||
|
|
8d4d3b03bb | ||
|
|
addefe79e1 | ||
|
|
b764d3b8f6 | ||
|
|
611fd884bd | ||
|
|
826090e099 | ||
|
|
7399de6ecc | ||
|
|
25cb5e7505 | ||
|
|
5c13ec3121 | ||
|
|
d8aff3a7e3 | ||
|
|
f44927b9f8 | ||
|
|
c0110cb5af | ||
|
|
1f8e1142a0 | ||
|
|
1e51de88d6 | ||
|
|
30995b5397 | ||
|
|
eb60f67054 | ||
|
|
78193ceec1 | ||
|
|
f0e08e7687 | ||
|
|
10b8259259 | ||
|
|
6826149a8f | ||
|
|
eb0b77bf4d | ||
|
|
9d81467937 | ||
|
|
fd8ccaf01a | ||
|
|
c9debc50b1 | ||
|
|
2b30e3b6d7 | ||
|
|
6e90ec6111 | ||
|
|
8dd38f4775 | ||
|
|
fbd73f248f | ||
|
|
3fcefe6c32 | ||
|
|
f740d2c291 | ||
|
|
bf6585a40f | ||
|
|
8c2dd7b3f0 | ||
|
|
4167c437a8 | ||
|
|
0ddaef3c9a | ||
|
|
2fc6aaf936 | ||
|
|
1c0519f1c7 | ||
|
|
6bbe7800be | ||
|
|
2694149489 | ||
|
|
a17ac50118 | ||
|
|
656a77d585 | ||
|
|
7455476c60 | ||
|
|
36cda57c81 | ||
|
|
9f1f203b84 | ||
|
|
b41a8ca93f | ||
|
|
e3cf0c0e10 | ||
|
|
de18bce9aa |
7
.gitattributes
vendored
7
.gitattributes
vendored
@@ -4,6 +4,13 @@ backend/migrations/*.sql text eol=lf
|
||||
# Go 源代码文件
|
||||
*.go text eol=lf
|
||||
|
||||
# 前端 源代码文件
|
||||
*.ts text eol=lf
|
||||
*.tsx text eol=lf
|
||||
*.js text eol=lf
|
||||
*.jsx text eol=lf
|
||||
*.vue text eol=lf
|
||||
|
||||
# Shell 脚本
|
||||
*.sh text eol=lf
|
||||
|
||||
|
||||
33
.github/workflows/release.yml
vendored
33
.github/workflows/release.yml
vendored
@@ -271,3 +271,36 @@ jobs:
|
||||
parse_mode: "Markdown",
|
||||
disable_web_page_preview: true
|
||||
}')"
|
||||
|
||||
sync-version-file:
|
||||
needs: [release]
|
||||
if: ${{ needs.release.result == 'success' }}
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout default branch
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
ref: ${{ github.event.repository.default_branch }}
|
||||
|
||||
- name: Sync VERSION file to released tag
|
||||
run: |
|
||||
if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
|
||||
VERSION=${{ github.event.inputs.tag }}
|
||||
VERSION=${VERSION#v}
|
||||
else
|
||||
VERSION=${GITHUB_REF#refs/tags/v}
|
||||
fi
|
||||
|
||||
CURRENT_VERSION=$(tr -d '\r\n' < backend/cmd/server/VERSION || true)
|
||||
if [ "$CURRENT_VERSION" = "$VERSION" ]; then
|
||||
echo "VERSION file already matches $VERSION"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "$VERSION" > backend/cmd/server/VERSION
|
||||
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||
git add backend/cmd/server/VERSION
|
||||
git commit -m "chore: sync VERSION to ${VERSION} [skip ci]"
|
||||
git push origin HEAD:${{ github.event.repository.default_branch }}
|
||||
|
||||
@@ -47,6 +47,8 @@ dockers:
|
||||
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:latest"
|
||||
dockerfile: Dockerfile.goreleaser
|
||||
use: buildx
|
||||
extra_files:
|
||||
- deploy/docker-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||
|
||||
@@ -63,6 +63,8 @@ dockers:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||
dockerfile: Dockerfile.goreleaser
|
||||
use: buildx
|
||||
extra_files:
|
||||
- deploy/docker-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||
@@ -76,6 +78,8 @@ dockers:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||
dockerfile: Dockerfile.goreleaser
|
||||
use: buildx
|
||||
extra_files:
|
||||
- deploy/docker-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||
@@ -89,6 +93,8 @@ dockers:
|
||||
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-amd64"
|
||||
dockerfile: Dockerfile.goreleaser
|
||||
use: buildx
|
||||
extra_files:
|
||||
- deploy/docker-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||
@@ -102,6 +108,8 @@ dockers:
|
||||
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-arm64"
|
||||
dockerfile: Dockerfile.goreleaser
|
||||
use: buildx
|
||||
extra_files:
|
||||
- deploy/docker-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||
|
||||
31
Dockerfile
31
Dockerfile
@@ -9,6 +9,7 @@
|
||||
ARG NODE_IMAGE=node:24-alpine
|
||||
ARG GOLANG_IMAGE=golang:1.26.1-alpine
|
||||
ARG ALPINE_IMAGE=alpine:3.21
|
||||
ARG POSTGRES_IMAGE=postgres:18-alpine
|
||||
ARG GOPROXY=https://goproxy.cn,direct
|
||||
ARG GOSUMDB=sum.golang.google.cn
|
||||
|
||||
@@ -73,7 +74,12 @@ RUN VERSION_VALUE="${VERSION}" && \
|
||||
./cmd/server
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Stage 3: Final Runtime Image
|
||||
# Stage 3: PostgreSQL Client (version-matched with docker-compose)
|
||||
# -----------------------------------------------------------------------------
|
||||
FROM ${POSTGRES_IMAGE} AS pg-client
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Stage 4: Final Runtime Image
|
||||
# -----------------------------------------------------------------------------
|
||||
FROM ${ALPINE_IMAGE}
|
||||
|
||||
@@ -86,8 +92,21 @@ LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
|
||||
RUN apk add --no-cache \
|
||||
ca-certificates \
|
||||
tzdata \
|
||||
su-exec \
|
||||
libpq \
|
||||
zstd-libs \
|
||||
lz4-libs \
|
||||
krb5-libs \
|
||||
libldap \
|
||||
libedit \
|
||||
&& rm -rf /var/cache/apk/*
|
||||
|
||||
# Copy pg_dump and psql from the same postgres image used in docker-compose
|
||||
# This ensures version consistency between backup tools and the database server
|
||||
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
|
||||
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
|
||||
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
|
||||
|
||||
# Create non-root user
|
||||
RUN addgroup -g 1000 sub2api && \
|
||||
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
||||
@@ -102,8 +121,9 @@ COPY --from=backend-builder --chown=sub2api:sub2api /app/backend/resources /app/
|
||||
# Create data directory
|
||||
RUN mkdir -p /app/data && chown sub2api:sub2api /app/data
|
||||
|
||||
# Switch to non-root user
|
||||
USER sub2api
|
||||
# Copy entrypoint script (fixes volume permissions then drops to sub2api)
|
||||
COPY deploy/docker-entrypoint.sh /app/docker-entrypoint.sh
|
||||
RUN chmod +x /app/docker-entrypoint.sh
|
||||
|
||||
# Expose port (can be overridden by SERVER_PORT env var)
|
||||
EXPOSE 8080
|
||||
@@ -112,5 +132,6 @@ EXPOSE 8080
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
||||
CMD wget -q -T 5 -O /dev/null http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
||||
|
||||
# Run the application
|
||||
ENTRYPOINT ["/app/sub2api"]
|
||||
# Run the application (entrypoint fixes /app/data ownership then execs as sub2api)
|
||||
ENTRYPOINT ["/app/docker-entrypoint.sh"]
|
||||
CMD ["/app/sub2api"]
|
||||
|
||||
@@ -5,7 +5,12 @@
|
||||
# It only packages the pre-built binary, no compilation needed.
|
||||
# =============================================================================
|
||||
|
||||
FROM alpine:3.19
|
||||
ARG ALPINE_IMAGE=alpine:3.21
|
||||
ARG POSTGRES_IMAGE=postgres:18-alpine
|
||||
|
||||
FROM ${POSTGRES_IMAGE} AS pg-client
|
||||
|
||||
FROM ${ALPINE_IMAGE}
|
||||
|
||||
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
|
||||
LABEL description="Sub2API - AI API Gateway Platform"
|
||||
@@ -16,8 +21,21 @@ RUN apk add --no-cache \
|
||||
ca-certificates \
|
||||
tzdata \
|
||||
curl \
|
||||
su-exec \
|
||||
libpq \
|
||||
zstd-libs \
|
||||
lz4-libs \
|
||||
krb5-libs \
|
||||
libldap \
|
||||
libedit \
|
||||
&& rm -rf /var/cache/apk/*
|
||||
|
||||
# Copy pg_dump and psql from a version-matched PostgreSQL image so backup and
|
||||
# restore work in the runtime container without requiring Docker socket access.
|
||||
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
|
||||
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
|
||||
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
|
||||
|
||||
# Create non-root user
|
||||
RUN addgroup -g 1000 sub2api && \
|
||||
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
||||
@@ -30,11 +48,15 @@ COPY sub2api /app/sub2api
|
||||
# Create data directory
|
||||
RUN mkdir -p /app/data && chown -R sub2api:sub2api /app
|
||||
|
||||
USER sub2api
|
||||
# Copy entrypoint script (fixes volume permissions then drops to sub2api)
|
||||
COPY deploy/docker-entrypoint.sh /app/docker-entrypoint.sh
|
||||
RUN chmod +x /app/docker-entrypoint.sh
|
||||
|
||||
EXPOSE 8080
|
||||
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
||||
CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
||||
|
||||
ENTRYPOINT ["/app/sub2api"]
|
||||
# Run the application (entrypoint fixes /app/data ownership then execs as sub2api)
|
||||
ENTRYPOINT ["/app/docker-entrypoint.sh"]
|
||||
CMD ["/app/sub2api"]
|
||||
|
||||
40
README.md
40
README.md
@@ -8,27 +8,31 @@
|
||||
[](https://redis.io/)
|
||||
[](https://www.docker.com/)
|
||||
|
||||
<a href="https://trendshift.io/repositories/21823" target="_blank"><img src="https://trendshift.io/api/badge/repositories/21823" alt="Wei-Shaw%2Fsub2api | Trendshift" width="250" height="55"/></a>
|
||||
|
||||
**AI API Gateway Platform for Subscription Quota Distribution**
|
||||
|
||||
English | [中文](README_CN.md)
|
||||
|
||||
</div>
|
||||
|
||||
> **Sub2API officially uses only the domains `sub2api.org` and `pincc.ai`. Other websites using the Sub2API name may be third-party deployments or services and are not affiliated with this project. Please verify and exercise your own judgment.**
|
||||
|
||||
---
|
||||
|
||||
## Demo
|
||||
|
||||
Try Sub2API online: **https://demo.sub2api.org/**
|
||||
Try Sub2API online: **[https://demo.sub2api.org/](https://demo.sub2api.org/)**
|
||||
|
||||
Demo credentials (shared demo environment; **not** created automatically for self-hosted installs):
|
||||
|
||||
| Email | Password |
|
||||
|-------|----------|
|
||||
| admin@sub2api.com | admin123 |
|
||||
| admin@sub2api.org | admin123 |
|
||||
|
||||
## Overview
|
||||
|
||||
Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions (like Claude Code $200/month). Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding.
|
||||
Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions. Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding.
|
||||
|
||||
## Features
|
||||
|
||||
@@ -39,6 +43,25 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
|
||||
- **Concurrency Control** - Per-user and per-account concurrency limits
|
||||
- **Rate Limiting** - Configurable request and token rate limits
|
||||
- **Admin Dashboard** - Web interface for monitoring and management
|
||||
- **External System Integration** - Embed external systems (e.g. payment, ticketing) via iframe to extend the admin dashboard
|
||||
|
||||
## Don't Want to Self-Host?
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td width="180" align="center" valign="middle"><a href="https://shop.pincc.ai/"><img src="assets/partners/logos/pincc-logo.png" alt="pincc" width="120"></a></td>
|
||||
<td valign="middle"><b><a href="https://shop.pincc.ai/">PinCC</a></b> is the official relay service built on Sub2API, offering stable access to Claude Code, Codex, Gemini and other popular models — ready to use, no deployment or maintenance required.</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Ecosystem
|
||||
|
||||
Community projects that extend or integrate with Sub2API:
|
||||
|
||||
| Project | Description | Features |
|
||||
|---------|-------------|----------|
|
||||
| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | Self-service payment system | Self-service top-up and subscription purchase; supports YiPay protocol, WeChat Pay, Alipay, Stripe; embeddable via iframe |
|
||||
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | Mobile admin console | Cross-platform app (iOS/Android/Web) for user management, account management, monitoring dashboard, and multi-backend switching; built with Expo + React Native |
|
||||
|
||||
## Tech Stack
|
||||
|
||||
@@ -51,10 +74,15 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
|
||||
|
||||
---
|
||||
|
||||
## Documentation
|
||||
## Nginx Reverse Proxy Note
|
||||
|
||||
- Dependency Security: `docs/dependency-security.md`
|
||||
- Admin Payment Integration API: `docs/ADMIN_PAYMENT_INTEGRATION_API.md`
|
||||
When using Nginx as a reverse proxy for Sub2API (or CRS) with Codex CLI, add the following to the `http` block in your Nginx configuration:
|
||||
|
||||
```nginx
|
||||
underscores_in_headers on;
|
||||
```
|
||||
|
||||
Nginx drops headers containing underscores by default (e.g. `session_id`), which breaks sticky session routing in multi-account setups.
|
||||
|
||||
---
|
||||
|
||||
|
||||
43
README_CN.md
43
README_CN.md
@@ -8,27 +8,30 @@
|
||||
[](https://redis.io/)
|
||||
[](https://www.docker.com/)
|
||||
|
||||
<a href="https://trendshift.io/repositories/21823" target="_blank"><img src="https://trendshift.io/api/badge/repositories/21823" alt="Wei-Shaw%2Fsub2api | Trendshift" width="250" height="55"/></a>
|
||||
|
||||
**AI API 网关平台 - 订阅配额分发管理**
|
||||
|
||||
[English](README.md) | 中文
|
||||
|
||||
</div>
|
||||
|
||||
> **Sub2API 官方仅使用 `sub2api.org` 与 `pincc.ai` 两个域名。其他使用 Sub2API 名义的网站可能为第三方部署或服务,与本项目无关,请自行甄别。**
|
||||
---
|
||||
|
||||
## 在线体验
|
||||
|
||||
体验地址:**https://v2.pincc.ai/**
|
||||
体验地址:**[https://demo.sub2api.org/](https://demo.sub2api.org/)**
|
||||
|
||||
演示账号(共享演示环境;自建部署不会自动创建该账号):
|
||||
|
||||
| 邮箱 | 密码 |
|
||||
|------|------|
|
||||
| admin@sub2api.com | admin123 |
|
||||
| admin@sub2api.org | admin123 |
|
||||
|
||||
## 项目概述
|
||||
|
||||
Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(如 Claude Code $200/月)的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。
|
||||
Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。
|
||||
|
||||
## 核心功能
|
||||
|
||||
@@ -39,6 +42,25 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
|
||||
- **并发控制** - 用户级和账号级并发限制
|
||||
- **速率限制** - 可配置的请求和 Token 速率限制
|
||||
- **管理后台** - Web 界面进行监控和管理
|
||||
- **外部系统集成** - 支持通过 iframe 嵌入外部系统(如支付、工单等),扩展管理后台功能
|
||||
|
||||
## 不想自建?试试官方中转
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td width="180" align="center" valign="middle"><a href="https://shop.pincc.ai/"><img src="assets/partners/logos/pincc-logo.png" alt="pincc" width="120"></a></td>
|
||||
<td valign="middle"><b><a href="https://shop.pincc.ai/">PinCC</a></b> 是基于 Sub2API 搭建的官方中转服务,提供 Claude Code、Codex、Gemini 等主流模型的稳定中转,开箱即用,免去自建部署与运维烦恼。</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## 生态项目
|
||||
|
||||
围绕 Sub2API 的社区扩展与集成项目:
|
||||
|
||||
| 项目 | 说明 | 功能 |
|
||||
|------|------|------|
|
||||
| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | 自助支付系统 | 用户自助充值、自助订阅购买;兼容易支付协议、微信官方支付、支付宝官方支付、Stripe;支持 iframe 嵌入管理后台 |
|
||||
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | 移动端管理控制台 | 跨平台应用(iOS/Android/Web),支持用户管理、账号管理、监控看板、多后端切换;基于 Expo + React Native 构建 |
|
||||
|
||||
## 技术栈
|
||||
|
||||
@@ -51,17 +73,18 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
|
||||
|
||||
---
|
||||
|
||||
## 文档
|
||||
## Nginx 反向代理注意事项
|
||||
|
||||
- 依赖安全:`docs/dependency-security.md`
|
||||
通过 Nginx 反向代理 Sub2API(或 CRS 服务)并搭配 Codex CLI 使用时,需要在 Nginx 配置的 `http` 块中添加:
|
||||
|
||||
```nginx
|
||||
underscores_in_headers on;
|
||||
```
|
||||
|
||||
Nginx 默认会丢弃名称中含下划线的请求头(如 `session_id`),这会导致多账号环境下的粘性会话功能失效。
|
||||
|
||||
---
|
||||
|
||||
## OpenAI Responses 兼容注意事项
|
||||
|
||||
- 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id` 的 `tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`。
|
||||
- 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。
|
||||
|
||||
## 部署方式
|
||||
|
||||
### 方式一:脚本安装(推荐)
|
||||
|
||||
BIN
assets/partners/logos/pincc-logo.png
Normal file
BIN
assets/partners/logos/pincc-logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 171 KiB |
@@ -41,6 +41,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
// Server layer ProviderSet
|
||||
server.ProviderSet,
|
||||
|
||||
// Privacy client factory for OpenAI training opt-out
|
||||
providePrivacyClientFactory,
|
||||
|
||||
// BuildInfo provider
|
||||
provideServiceBuildInfo,
|
||||
|
||||
@@ -53,6 +56,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func providePrivacyClientFactory() service.PrivacyClientFactory {
|
||||
return repository.CreatePrivacyReqClient
|
||||
}
|
||||
|
||||
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||
return service.BuildInfo{
|
||||
Version: buildInfo.Version,
|
||||
@@ -87,6 +94,7 @@ func provideCleanup(
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
backupSvc *service.BackupService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -223,6 +231,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"BackupService", func() error {
|
||||
if backupSvc != nil {
|
||||
backupSvc.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
}
|
||||
|
||||
infraSteps := []cleanupStep{
|
||||
|
||||
@@ -81,6 +81,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
userHandler := handler.NewUserHandler(userService)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
||||
usageBillingRepository := repository.NewUsageBillingRepository(client, db)
|
||||
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||
@@ -104,11 +105,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
proxyRepository := repository.NewProxyRepository(client, db)
|
||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository)
|
||||
privacyClientFactory := providePrivacyClientFactory()
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
|
||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
||||
groupHandler := admin.NewGroupHandler(adminService)
|
||||
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
||||
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
||||
openAIOAuthClient := repository.NewOpenAIOAuthClient()
|
||||
@@ -122,6 +123,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
|
||||
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
|
||||
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
|
||||
oauthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
||||
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
|
||||
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
|
||||
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||
@@ -130,20 +132,26 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
usageCache := service.NewUsageCache()
|
||||
identityCache := repository.NewIdentityCache(redisClient)
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache)
|
||||
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
|
||||
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oauthRefreshAPI)
|
||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
|
||||
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
|
||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
|
||||
rpmCache := repository.NewRPMCache(redisClient)
|
||||
groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
|
||||
groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
|
||||
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
|
||||
dataManagementService := service.NewDataManagementService()
|
||||
dataManagementHandler := admin.NewDataManagementHandler(dataManagementService)
|
||||
backupObjectStoreFactory := repository.NewS3BackupStoreFactory()
|
||||
dbDumper := repository.NewPgDumper(configConfig)
|
||||
backupService := service.ProvideBackupService(settingRepository, configConfig, secretEncryptor, backupObjectStoreFactory, dbDumper)
|
||||
backupHandler := admin.NewBackupHandler(backupService, userService)
|
||||
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
||||
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
||||
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
|
||||
@@ -160,11 +168,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
billingService := service.NewBillingService(configConfig, pricingService)
|
||||
identityService := service.NewIdentityService(identityCache)
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
||||
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI)
|
||||
digestSessionStore := service.NewDigestSessionStore()
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
|
||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
|
||||
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||
@@ -199,7 +207,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
|
||||
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
|
||||
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
|
||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||
@@ -226,11 +234,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
@@ -245,6 +253,10 @@ type Application struct {
|
||||
Cleanup func()
|
||||
}
|
||||
|
||||
func providePrivacyClientFactory() service.PrivacyClientFactory {
|
||||
return repository.CreatePrivacyReqClient
|
||||
}
|
||||
|
||||
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||
return service.BuildInfo{
|
||||
Version: buildInfo.Version,
|
||||
@@ -279,6 +291,7 @@ func provideCleanup(
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
backupSvc *service.BackupService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -414,6 +427,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"BackupService", func() error {
|
||||
if backupSvc != nil {
|
||||
backupSvc.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
}
|
||||
|
||||
infraSteps := []cleanupStep{
|
||||
|
||||
@@ -75,6 +75,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||
antigravityOAuthSvc,
|
||||
nil, // openAIGateway
|
||||
nil, // scheduledTestRunner
|
||||
nil, // backupSvc
|
||||
)
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
|
||||
@@ -716,6 +716,7 @@ var (
|
||||
{Name: "id", Type: field.TypeInt64, Increment: true},
|
||||
{Name: "request_id", Type: field.TypeString, Size: 64},
|
||||
{Name: "model", Type: field.TypeString, Size: 100},
|
||||
{Name: "upstream_model", Type: field.TypeString, Nullable: true, Size: 100},
|
||||
{Name: "input_tokens", Type: field.TypeInt, Default: 0},
|
||||
{Name: "output_tokens", Type: field.TypeInt, Default: 0},
|
||||
{Name: "cache_creation_tokens", Type: field.TypeInt, Default: 0},
|
||||
@@ -755,31 +756,31 @@ var (
|
||||
ForeignKeys: []*schema.ForeignKey{
|
||||
{
|
||||
Symbol: "usage_logs_api_keys_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||
RefColumns: []*schema.Column{APIKeysColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_accounts_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||
RefColumns: []*schema.Column{AccountsColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_groups_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[31]},
|
||||
RefColumns: []*schema.Column{GroupsColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_users_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[31]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[32]},
|
||||
RefColumns: []*schema.Column{UsersColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_user_subscriptions_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[32]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[33]},
|
||||
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
@@ -788,32 +789,32 @@ var (
|
||||
{
|
||||
Name: "usagelog_user_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[31]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[32]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_api_key_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_account_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_group_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[31]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_subscription_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[32]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[33]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[27]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_model",
|
||||
@@ -828,17 +829,17 @@ var (
|
||||
{
|
||||
Name: "usagelog_user_id_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[27]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[32], UsageLogsColumns[28]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_api_key_id_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[27]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[28]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_group_id_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[27]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[28]},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -18239,6 +18239,7 @@ type UsageLogMutation struct {
|
||||
id *int64
|
||||
request_id *string
|
||||
model *string
|
||||
upstream_model *string
|
||||
input_tokens *int
|
||||
addinput_tokens *int
|
||||
output_tokens *int
|
||||
@@ -18576,6 +18577,55 @@ func (m *UsageLogMutation) ResetModel() {
|
||||
m.model = nil
|
||||
}
|
||||
|
||||
// SetUpstreamModel sets the "upstream_model" field.
|
||||
func (m *UsageLogMutation) SetUpstreamModel(s string) {
|
||||
m.upstream_model = &s
|
||||
}
|
||||
|
||||
// UpstreamModel returns the value of the "upstream_model" field in the mutation.
|
||||
func (m *UsageLogMutation) UpstreamModel() (r string, exists bool) {
|
||||
v := m.upstream_model
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldUpstreamModel returns the old "upstream_model" field's value of the UsageLog entity.
|
||||
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *UsageLogMutation) OldUpstreamModel(ctx context.Context) (v *string, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldUpstreamModel is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldUpstreamModel requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldUpstreamModel: %w", err)
|
||||
}
|
||||
return oldValue.UpstreamModel, nil
|
||||
}
|
||||
|
||||
// ClearUpstreamModel clears the value of the "upstream_model" field.
|
||||
func (m *UsageLogMutation) ClearUpstreamModel() {
|
||||
m.upstream_model = nil
|
||||
m.clearedFields[usagelog.FieldUpstreamModel] = struct{}{}
|
||||
}
|
||||
|
||||
// UpstreamModelCleared returns if the "upstream_model" field was cleared in this mutation.
|
||||
func (m *UsageLogMutation) UpstreamModelCleared() bool {
|
||||
_, ok := m.clearedFields[usagelog.FieldUpstreamModel]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ResetUpstreamModel resets all changes to the "upstream_model" field.
|
||||
func (m *UsageLogMutation) ResetUpstreamModel() {
|
||||
m.upstream_model = nil
|
||||
delete(m.clearedFields, usagelog.FieldUpstreamModel)
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (m *UsageLogMutation) SetGroupID(i int64) {
|
||||
m.group = &i
|
||||
@@ -20197,7 +20247,7 @@ func (m *UsageLogMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *UsageLogMutation) Fields() []string {
|
||||
fields := make([]string, 0, 32)
|
||||
fields := make([]string, 0, 33)
|
||||
if m.user != nil {
|
||||
fields = append(fields, usagelog.FieldUserID)
|
||||
}
|
||||
@@ -20213,6 +20263,9 @@ func (m *UsageLogMutation) Fields() []string {
|
||||
if m.model != nil {
|
||||
fields = append(fields, usagelog.FieldModel)
|
||||
}
|
||||
if m.upstream_model != nil {
|
||||
fields = append(fields, usagelog.FieldUpstreamModel)
|
||||
}
|
||||
if m.group != nil {
|
||||
fields = append(fields, usagelog.FieldGroupID)
|
||||
}
|
||||
@@ -20312,6 +20365,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.RequestID()
|
||||
case usagelog.FieldModel:
|
||||
return m.Model()
|
||||
case usagelog.FieldUpstreamModel:
|
||||
return m.UpstreamModel()
|
||||
case usagelog.FieldGroupID:
|
||||
return m.GroupID()
|
||||
case usagelog.FieldSubscriptionID:
|
||||
@@ -20385,6 +20440,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
|
||||
return m.OldRequestID(ctx)
|
||||
case usagelog.FieldModel:
|
||||
return m.OldModel(ctx)
|
||||
case usagelog.FieldUpstreamModel:
|
||||
return m.OldUpstreamModel(ctx)
|
||||
case usagelog.FieldGroupID:
|
||||
return m.OldGroupID(ctx)
|
||||
case usagelog.FieldSubscriptionID:
|
||||
@@ -20483,6 +20540,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetModel(v)
|
||||
return nil
|
||||
case usagelog.FieldUpstreamModel:
|
||||
v, ok := value.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetUpstreamModel(v)
|
||||
return nil
|
||||
case usagelog.FieldGroupID:
|
||||
v, ok := value.(int64)
|
||||
if !ok {
|
||||
@@ -20921,6 +20985,9 @@ func (m *UsageLogMutation) AddField(name string, value ent.Value) error {
|
||||
// mutation.
|
||||
func (m *UsageLogMutation) ClearedFields() []string {
|
||||
var fields []string
|
||||
if m.FieldCleared(usagelog.FieldUpstreamModel) {
|
||||
fields = append(fields, usagelog.FieldUpstreamModel)
|
||||
}
|
||||
if m.FieldCleared(usagelog.FieldGroupID) {
|
||||
fields = append(fields, usagelog.FieldGroupID)
|
||||
}
|
||||
@@ -20962,6 +21029,9 @@ func (m *UsageLogMutation) FieldCleared(name string) bool {
|
||||
// error if the field is not defined in the schema.
|
||||
func (m *UsageLogMutation) ClearField(name string) error {
|
||||
switch name {
|
||||
case usagelog.FieldUpstreamModel:
|
||||
m.ClearUpstreamModel()
|
||||
return nil
|
||||
case usagelog.FieldGroupID:
|
||||
m.ClearGroupID()
|
||||
return nil
|
||||
@@ -21012,6 +21082,9 @@ func (m *UsageLogMutation) ResetField(name string) error {
|
||||
case usagelog.FieldModel:
|
||||
m.ResetModel()
|
||||
return nil
|
||||
case usagelog.FieldUpstreamModel:
|
||||
m.ResetUpstreamModel()
|
||||
return nil
|
||||
case usagelog.FieldGroupID:
|
||||
m.ResetGroupID()
|
||||
return nil
|
||||
|
||||
@@ -821,92 +821,96 @@ func init() {
|
||||
return nil
|
||||
}
|
||||
}()
|
||||
// usagelogDescUpstreamModel is the schema descriptor for upstream_model field.
|
||||
usagelogDescUpstreamModel := usagelogFields[5].Descriptor()
|
||||
// usagelog.UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
|
||||
usagelog.UpstreamModelValidator = usagelogDescUpstreamModel.Validators[0].(func(string) error)
|
||||
// usagelogDescInputTokens is the schema descriptor for input_tokens field.
|
||||
usagelogDescInputTokens := usagelogFields[7].Descriptor()
|
||||
usagelogDescInputTokens := usagelogFields[8].Descriptor()
|
||||
// usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field.
|
||||
usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int)
|
||||
// usagelogDescOutputTokens is the schema descriptor for output_tokens field.
|
||||
usagelogDescOutputTokens := usagelogFields[8].Descriptor()
|
||||
usagelogDescOutputTokens := usagelogFields[9].Descriptor()
|
||||
// usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field.
|
||||
usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int)
|
||||
// usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field.
|
||||
usagelogDescCacheCreationTokens := usagelogFields[9].Descriptor()
|
||||
usagelogDescCacheCreationTokens := usagelogFields[10].Descriptor()
|
||||
// usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field.
|
||||
usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int)
|
||||
// usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field.
|
||||
usagelogDescCacheReadTokens := usagelogFields[10].Descriptor()
|
||||
usagelogDescCacheReadTokens := usagelogFields[11].Descriptor()
|
||||
// usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field.
|
||||
usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int)
|
||||
// usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field.
|
||||
usagelogDescCacheCreation5mTokens := usagelogFields[11].Descriptor()
|
||||
usagelogDescCacheCreation5mTokens := usagelogFields[12].Descriptor()
|
||||
// usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field.
|
||||
usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int)
|
||||
// usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field.
|
||||
usagelogDescCacheCreation1hTokens := usagelogFields[12].Descriptor()
|
||||
usagelogDescCacheCreation1hTokens := usagelogFields[13].Descriptor()
|
||||
// usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field.
|
||||
usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int)
|
||||
// usagelogDescInputCost is the schema descriptor for input_cost field.
|
||||
usagelogDescInputCost := usagelogFields[13].Descriptor()
|
||||
usagelogDescInputCost := usagelogFields[14].Descriptor()
|
||||
// usagelog.DefaultInputCost holds the default value on creation for the input_cost field.
|
||||
usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64)
|
||||
// usagelogDescOutputCost is the schema descriptor for output_cost field.
|
||||
usagelogDescOutputCost := usagelogFields[14].Descriptor()
|
||||
usagelogDescOutputCost := usagelogFields[15].Descriptor()
|
||||
// usagelog.DefaultOutputCost holds the default value on creation for the output_cost field.
|
||||
usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64)
|
||||
// usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field.
|
||||
usagelogDescCacheCreationCost := usagelogFields[15].Descriptor()
|
||||
usagelogDescCacheCreationCost := usagelogFields[16].Descriptor()
|
||||
// usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field.
|
||||
usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64)
|
||||
// usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field.
|
||||
usagelogDescCacheReadCost := usagelogFields[16].Descriptor()
|
||||
usagelogDescCacheReadCost := usagelogFields[17].Descriptor()
|
||||
// usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field.
|
||||
usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64)
|
||||
// usagelogDescTotalCost is the schema descriptor for total_cost field.
|
||||
usagelogDescTotalCost := usagelogFields[17].Descriptor()
|
||||
usagelogDescTotalCost := usagelogFields[18].Descriptor()
|
||||
// usagelog.DefaultTotalCost holds the default value on creation for the total_cost field.
|
||||
usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64)
|
||||
// usagelogDescActualCost is the schema descriptor for actual_cost field.
|
||||
usagelogDescActualCost := usagelogFields[18].Descriptor()
|
||||
usagelogDescActualCost := usagelogFields[19].Descriptor()
|
||||
// usagelog.DefaultActualCost holds the default value on creation for the actual_cost field.
|
||||
usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64)
|
||||
// usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field.
|
||||
usagelogDescRateMultiplier := usagelogFields[19].Descriptor()
|
||||
usagelogDescRateMultiplier := usagelogFields[20].Descriptor()
|
||||
// usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
|
||||
usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64)
|
||||
// usagelogDescBillingType is the schema descriptor for billing_type field.
|
||||
usagelogDescBillingType := usagelogFields[21].Descriptor()
|
||||
usagelogDescBillingType := usagelogFields[22].Descriptor()
|
||||
// usagelog.DefaultBillingType holds the default value on creation for the billing_type field.
|
||||
usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8)
|
||||
// usagelogDescStream is the schema descriptor for stream field.
|
||||
usagelogDescStream := usagelogFields[22].Descriptor()
|
||||
usagelogDescStream := usagelogFields[23].Descriptor()
|
||||
// usagelog.DefaultStream holds the default value on creation for the stream field.
|
||||
usagelog.DefaultStream = usagelogDescStream.Default.(bool)
|
||||
// usagelogDescUserAgent is the schema descriptor for user_agent field.
|
||||
usagelogDescUserAgent := usagelogFields[25].Descriptor()
|
||||
usagelogDescUserAgent := usagelogFields[26].Descriptor()
|
||||
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
|
||||
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
|
||||
// usagelogDescIPAddress is the schema descriptor for ip_address field.
|
||||
usagelogDescIPAddress := usagelogFields[26].Descriptor()
|
||||
usagelogDescIPAddress := usagelogFields[27].Descriptor()
|
||||
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
|
||||
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
|
||||
// usagelogDescImageCount is the schema descriptor for image_count field.
|
||||
usagelogDescImageCount := usagelogFields[27].Descriptor()
|
||||
usagelogDescImageCount := usagelogFields[28].Descriptor()
|
||||
// usagelog.DefaultImageCount holds the default value on creation for the image_count field.
|
||||
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
|
||||
// usagelogDescImageSize is the schema descriptor for image_size field.
|
||||
usagelogDescImageSize := usagelogFields[28].Descriptor()
|
||||
usagelogDescImageSize := usagelogFields[29].Descriptor()
|
||||
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
||||
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
|
||||
// usagelogDescMediaType is the schema descriptor for media_type field.
|
||||
usagelogDescMediaType := usagelogFields[29].Descriptor()
|
||||
usagelogDescMediaType := usagelogFields[30].Descriptor()
|
||||
// usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
|
||||
usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
|
||||
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
|
||||
usagelogDescCacheTTLOverridden := usagelogFields[30].Descriptor()
|
||||
usagelogDescCacheTTLOverridden := usagelogFields[31].Descriptor()
|
||||
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
|
||||
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
|
||||
// usagelogDescCreatedAt is the schema descriptor for created_at field.
|
||||
usagelogDescCreatedAt := usagelogFields[31].Descriptor()
|
||||
usagelogDescCreatedAt := usagelogFields[32].Descriptor()
|
||||
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
|
||||
userMixin := schema.User{}.Mixin()
|
||||
|
||||
@@ -41,6 +41,12 @@ func (UsageLog) Fields() []ent.Field {
|
||||
field.String("model").
|
||||
MaxLen(100).
|
||||
NotEmpty(),
|
||||
// UpstreamModel stores the actual upstream model name when model mapping
|
||||
// is applied. NULL means no mapping — the requested model was used as-is.
|
||||
field.String("upstream_model").
|
||||
MaxLen(100).
|
||||
Optional().
|
||||
Nillable(),
|
||||
field.Int64("group_id").
|
||||
Optional().
|
||||
Nillable(),
|
||||
|
||||
@@ -32,6 +32,8 @@ type UsageLog struct {
|
||||
RequestID string `json:"request_id,omitempty"`
|
||||
// Model holds the value of the "model" field.
|
||||
Model string `json:"model,omitempty"`
|
||||
// UpstreamModel holds the value of the "upstream_model" field.
|
||||
UpstreamModel *string `json:"upstream_model,omitempty"`
|
||||
// GroupID holds the value of the "group_id" field.
|
||||
GroupID *int64 `json:"group_id,omitempty"`
|
||||
// SubscriptionID holds the value of the "subscription_id" field.
|
||||
@@ -175,7 +177,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
|
||||
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUpstreamModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
|
||||
values[i] = new(sql.NullString)
|
||||
case usagelog.FieldCreatedAt:
|
||||
values[i] = new(sql.NullTime)
|
||||
@@ -230,6 +232,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.Model = value.String
|
||||
}
|
||||
case usagelog.FieldUpstreamModel:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field upstream_model", values[i])
|
||||
} else if value.Valid {
|
||||
_m.UpstreamModel = new(string)
|
||||
*_m.UpstreamModel = value.String
|
||||
}
|
||||
case usagelog.FieldGroupID:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field group_id", values[i])
|
||||
@@ -477,6 +486,11 @@ func (_m *UsageLog) String() string {
|
||||
builder.WriteString("model=")
|
||||
builder.WriteString(_m.Model)
|
||||
builder.WriteString(", ")
|
||||
if v := _m.UpstreamModel; v != nil {
|
||||
builder.WriteString("upstream_model=")
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.GroupID; v != nil {
|
||||
builder.WriteString("group_id=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
|
||||
@@ -24,6 +24,8 @@ const (
|
||||
FieldRequestID = "request_id"
|
||||
// FieldModel holds the string denoting the model field in the database.
|
||||
FieldModel = "model"
|
||||
// FieldUpstreamModel holds the string denoting the upstream_model field in the database.
|
||||
FieldUpstreamModel = "upstream_model"
|
||||
// FieldGroupID holds the string denoting the group_id field in the database.
|
||||
FieldGroupID = "group_id"
|
||||
// FieldSubscriptionID holds the string denoting the subscription_id field in the database.
|
||||
@@ -135,6 +137,7 @@ var Columns = []string{
|
||||
FieldAccountID,
|
||||
FieldRequestID,
|
||||
FieldModel,
|
||||
FieldUpstreamModel,
|
||||
FieldGroupID,
|
||||
FieldSubscriptionID,
|
||||
FieldInputTokens,
|
||||
@@ -179,6 +182,8 @@ var (
|
||||
RequestIDValidator func(string) error
|
||||
// ModelValidator is a validator for the "model" field. It is called by the builders before save.
|
||||
ModelValidator func(string) error
|
||||
// UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
|
||||
UpstreamModelValidator func(string) error
|
||||
// DefaultInputTokens holds the default value on creation for the "input_tokens" field.
|
||||
DefaultInputTokens int
|
||||
// DefaultOutputTokens holds the default value on creation for the "output_tokens" field.
|
||||
@@ -258,6 +263,11 @@ func ByModel(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldModel, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByUpstreamModel orders the results by the upstream_model field.
|
||||
func ByUpstreamModel(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldUpstreamModel, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByGroupID orders the results by the group_id field.
|
||||
func ByGroupID(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldGroupID, opts...).ToFunc()
|
||||
|
||||
@@ -80,6 +80,11 @@ func Model(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldModel, v))
|
||||
}
|
||||
|
||||
// UpstreamModel applies equality check predicate on the "upstream_model" field. It's identical to UpstreamModelEQ.
|
||||
func UpstreamModel(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v))
|
||||
}
|
||||
|
||||
// GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ.
|
||||
func GroupID(v int64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))
|
||||
@@ -405,6 +410,81 @@ func ModelContainsFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContainsFold(FieldModel, v))
|
||||
}
|
||||
|
||||
// UpstreamModelEQ applies the EQ predicate on the "upstream_model" field.
|
||||
func UpstreamModelEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v))
|
||||
}
|
||||
|
||||
// UpstreamModelNEQ applies the NEQ predicate on the "upstream_model" field.
|
||||
func UpstreamModelNEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNEQ(FieldUpstreamModel, v))
|
||||
}
|
||||
|
||||
// UpstreamModelIn applies the In predicate on the "upstream_model" field.
|
||||
func UpstreamModelIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIn(FieldUpstreamModel, vs...))
|
||||
}
|
||||
|
||||
// UpstreamModelNotIn applies the NotIn predicate on the "upstream_model" field.
|
||||
func UpstreamModelNotIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotIn(FieldUpstreamModel, vs...))
|
||||
}
|
||||
|
||||
// UpstreamModelGT applies the GT predicate on the "upstream_model" field.
|
||||
func UpstreamModelGT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGT(FieldUpstreamModel, v))
|
||||
}
|
||||
|
||||
// UpstreamModelGTE applies the GTE predicate on the "upstream_model" field.
|
||||
func UpstreamModelGTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGTE(FieldUpstreamModel, v))
|
||||
}
|
||||
|
||||
// UpstreamModelLT applies the LT predicate on the "upstream_model" field.
|
||||
func UpstreamModelLT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLT(FieldUpstreamModel, v))
|
||||
}
|
||||
|
||||
// UpstreamModelLTE applies the LTE predicate on the "upstream_model" field.
|
||||
func UpstreamModelLTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLTE(FieldUpstreamModel, v))
|
||||
}
|
||||
|
||||
// UpstreamModelContains applies the Contains predicate on the "upstream_model" field.
|
||||
func UpstreamModelContains(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContains(FieldUpstreamModel, v))
|
||||
}
|
||||
|
||||
// UpstreamModelHasPrefix applies the HasPrefix predicate on the "upstream_model" field.
|
||||
func UpstreamModelHasPrefix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasPrefix(FieldUpstreamModel, v))
|
||||
}
|
||||
|
||||
// UpstreamModelHasSuffix applies the HasSuffix predicate on the "upstream_model" field.
|
||||
func UpstreamModelHasSuffix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasSuffix(FieldUpstreamModel, v))
|
||||
}
|
||||
|
||||
// UpstreamModelIsNil applies the IsNil predicate on the "upstream_model" field.
|
||||
func UpstreamModelIsNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIsNull(FieldUpstreamModel))
|
||||
}
|
||||
|
||||
// UpstreamModelNotNil applies the NotNil predicate on the "upstream_model" field.
|
||||
func UpstreamModelNotNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotNull(FieldUpstreamModel))
|
||||
}
|
||||
|
||||
// UpstreamModelEqualFold applies the EqualFold predicate on the "upstream_model" field.
|
||||
func UpstreamModelEqualFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEqualFold(FieldUpstreamModel, v))
|
||||
}
|
||||
|
||||
// UpstreamModelContainsFold applies the ContainsFold predicate on the "upstream_model" field.
|
||||
func UpstreamModelContainsFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContainsFold(FieldUpstreamModel, v))
|
||||
}
|
||||
|
||||
// GroupIDEQ applies the EQ predicate on the "group_id" field.
|
||||
func GroupIDEQ(v int64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))
|
||||
|
||||
@@ -57,6 +57,20 @@ func (_c *UsageLogCreate) SetModel(v string) *UsageLogCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetUpstreamModel sets the "upstream_model" field.
|
||||
func (_c *UsageLogCreate) SetUpstreamModel(v string) *UsageLogCreate {
|
||||
_c.mutation.SetUpstreamModel(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil.
|
||||
func (_c *UsageLogCreate) SetNillableUpstreamModel(v *string) *UsageLogCreate {
|
||||
if v != nil {
|
||||
_c.SetUpstreamModel(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (_c *UsageLogCreate) SetGroupID(v int64) *UsageLogCreate {
|
||||
_c.mutation.SetGroupID(v)
|
||||
@@ -596,6 +610,11 @@ func (_c *UsageLogCreate) check() error {
|
||||
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _c.mutation.UpstreamModel(); ok {
|
||||
if err := usagelog.UpstreamModelValidator(v); err != nil {
|
||||
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _, ok := _c.mutation.InputTokens(); !ok {
|
||||
return &ValidationError{Name: "input_tokens", err: errors.New(`ent: missing required field "UsageLog.input_tokens"`)}
|
||||
}
|
||||
@@ -714,6 +733,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
|
||||
_node.Model = value
|
||||
}
|
||||
if value, ok := _c.mutation.UpstreamModel(); ok {
|
||||
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
|
||||
_node.UpstreamModel = &value
|
||||
}
|
||||
if value, ok := _c.mutation.InputTokens(); ok {
|
||||
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
|
||||
_node.InputTokens = value
|
||||
@@ -1011,6 +1034,24 @@ func (u *UsageLogUpsert) UpdateModel() *UsageLogUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetUpstreamModel sets the "upstream_model" field.
|
||||
func (u *UsageLogUpsert) SetUpstreamModel(v string) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldUpstreamModel, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsert) UpdateUpstreamModel() *UsageLogUpsert {
|
||||
u.SetExcluded(usagelog.FieldUpstreamModel)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearUpstreamModel clears the value of the "upstream_model" field.
|
||||
func (u *UsageLogUpsert) ClearUpstreamModel() *UsageLogUpsert {
|
||||
u.SetNull(usagelog.FieldUpstreamModel)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (u *UsageLogUpsert) SetGroupID(v int64) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldGroupID, v)
|
||||
@@ -1600,6 +1641,27 @@ func (u *UsageLogUpsertOne) UpdateModel() *UsageLogUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetUpstreamModel sets the "upstream_model" field.
|
||||
func (u *UsageLogUpsertOne) SetUpstreamModel(v string) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetUpstreamModel(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertOne) UpdateUpstreamModel() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateUpstreamModel()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearUpstreamModel clears the value of the "upstream_model" field.
|
||||
func (u *UsageLogUpsertOne) ClearUpstreamModel() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearUpstreamModel()
|
||||
})
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (u *UsageLogUpsertOne) SetGroupID(v int64) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
@@ -2434,6 +2496,27 @@ func (u *UsageLogUpsertBulk) UpdateModel() *UsageLogUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetUpstreamModel sets the "upstream_model" field.
|
||||
func (u *UsageLogUpsertBulk) SetUpstreamModel(v string) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetUpstreamModel(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertBulk) UpdateUpstreamModel() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateUpstreamModel()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearUpstreamModel clears the value of the "upstream_model" field.
|
||||
func (u *UsageLogUpsertBulk) ClearUpstreamModel() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearUpstreamModel()
|
||||
})
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (u *UsageLogUpsertBulk) SetGroupID(v int64) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
|
||||
@@ -102,6 +102,26 @@ func (_u *UsageLogUpdate) SetNillableModel(v *string) *UsageLogUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUpstreamModel sets the "upstream_model" field.
|
||||
func (_u *UsageLogUpdate) SetUpstreamModel(v string) *UsageLogUpdate {
|
||||
_u.mutation.SetUpstreamModel(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdate) SetNillableUpstreamModel(v *string) *UsageLogUpdate {
|
||||
if v != nil {
|
||||
_u.SetUpstreamModel(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearUpstreamModel clears the value of the "upstream_model" field.
|
||||
func (_u *UsageLogUpdate) ClearUpstreamModel() *UsageLogUpdate {
|
||||
_u.mutation.ClearUpstreamModel()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (_u *UsageLogUpdate) SetGroupID(v int64) *UsageLogUpdate {
|
||||
_u.mutation.SetGroupID(v)
|
||||
@@ -745,6 +765,11 @@ func (_u *UsageLogUpdate) check() error {
|
||||
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.UpstreamModel(); ok {
|
||||
if err := usagelog.UpstreamModelValidator(v); err != nil {
|
||||
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.UserAgent(); ok {
|
||||
if err := usagelog.UserAgentValidator(v); err != nil {
|
||||
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
||||
@@ -795,6 +820,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if value, ok := _u.mutation.Model(); ok {
|
||||
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.UpstreamModel(); ok {
|
||||
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.UpstreamModelCleared() {
|
||||
_spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.InputTokens(); ok {
|
||||
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
|
||||
}
|
||||
@@ -1177,6 +1208,26 @@ func (_u *UsageLogUpdateOne) SetNillableModel(v *string) *UsageLogUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUpstreamModel sets the "upstream_model" field.
|
||||
func (_u *UsageLogUpdateOne) SetUpstreamModel(v string) *UsageLogUpdateOne {
|
||||
_u.mutation.SetUpstreamModel(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdateOne) SetNillableUpstreamModel(v *string) *UsageLogUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetUpstreamModel(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearUpstreamModel clears the value of the "upstream_model" field.
|
||||
func (_u *UsageLogUpdateOne) ClearUpstreamModel() *UsageLogUpdateOne {
|
||||
_u.mutation.ClearUpstreamModel()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (_u *UsageLogUpdateOne) SetGroupID(v int64) *UsageLogUpdateOne {
|
||||
_u.mutation.SetGroupID(v)
|
||||
@@ -1833,6 +1884,11 @@ func (_u *UsageLogUpdateOne) check() error {
|
||||
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.UpstreamModel(); ok {
|
||||
if err := usagelog.UpstreamModelValidator(v); err != nil {
|
||||
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.UserAgent(); ok {
|
||||
if err := usagelog.UserAgentValidator(v); err != nil {
|
||||
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
||||
@@ -1900,6 +1956,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
|
||||
if value, ok := _u.mutation.Model(); ok {
|
||||
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.UpstreamModel(); ok {
|
||||
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.UpstreamModelCleared() {
|
||||
_spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.InputTokens(); ok {
|
||||
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ require (
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||
github.com/DouDOU-start/go-sora2api v1.1.0
|
||||
github.com/alitto/pond/v2 v2.6.2
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.2
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.3
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.10
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
|
||||
@@ -66,7 +66,7 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect
|
||||
github.com/aws/smithy-go v1.24.1 // indirect
|
||||
github.com/aws/smithy-go v1.24.2 // indirect
|
||||
github.com/bdandy/go-errors v1.2.2 // indirect
|
||||
github.com/bdandy/go-socks4 v1.2.3 // indirect
|
||||
github.com/bmatcuk/doublestar v1.3.4 // indirect
|
||||
|
||||
@@ -22,8 +22,8 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo
|
||||
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
||||
github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=
|
||||
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.10 h1:9DMthfO6XWZYLfzZglAgW5Fyou2nRI5CuV44sTedKBI=
|
||||
@@ -58,8 +58,8 @@ github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 h1:edCcNp9eGIUDUCrzoCu1jWA
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15/go.mod h1:lyRQKED9xWfgkYC/wmmYfv7iVIM68Z5OQ88ZdcV1QbU=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb83BbyggcUBVksN7c=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
|
||||
github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0=
|
||||
github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
|
||||
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
|
||||
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
|
||||
github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM=
|
||||
github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic=
|
||||
@@ -94,10 +94,6 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
||||
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
|
||||
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
|
||||
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
|
||||
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
||||
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
|
||||
@@ -234,8 +230,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
|
||||
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
||||
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
||||
@@ -269,8 +263,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||
@@ -322,8 +314,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
||||
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
||||
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
||||
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
|
||||
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
|
||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
|
||||
|
||||
@@ -934,9 +934,10 @@ type DashboardAggregationConfig struct {
|
||||
|
||||
// DashboardAggregationRetentionConfig 预聚合保留窗口
|
||||
type DashboardAggregationRetentionConfig struct {
|
||||
UsageLogsDays int `mapstructure:"usage_logs_days"`
|
||||
HourlyDays int `mapstructure:"hourly_days"`
|
||||
DailyDays int `mapstructure:"daily_days"`
|
||||
UsageLogsDays int `mapstructure:"usage_logs_days"`
|
||||
UsageBillingDedupDays int `mapstructure:"usage_billing_dedup_days"`
|
||||
HourlyDays int `mapstructure:"hourly_days"`
|
||||
DailyDays int `mapstructure:"daily_days"`
|
||||
}
|
||||
|
||||
// UsageCleanupConfig 使用记录清理任务配置
|
||||
@@ -1301,6 +1302,7 @@ func setDefaults() {
|
||||
viper.SetDefault("dashboard_aggregation.backfill_enabled", false)
|
||||
viper.SetDefault("dashboard_aggregation.backfill_max_days", 31)
|
||||
viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90)
|
||||
viper.SetDefault("dashboard_aggregation.retention.usage_billing_dedup_days", 365)
|
||||
viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180)
|
||||
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
|
||||
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
|
||||
@@ -1758,6 +1760,12 @@ func (c *Config) Validate() error {
|
||||
if c.DashboardAgg.Retention.UsageLogsDays <= 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive")
|
||||
}
|
||||
if c.DashboardAgg.Retention.UsageBillingDedupDays <= 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be positive")
|
||||
}
|
||||
if c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
|
||||
}
|
||||
if c.DashboardAgg.Retention.HourlyDays <= 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive")
|
||||
}
|
||||
@@ -1780,6 +1788,14 @@ func (c *Config) Validate() error {
|
||||
if c.DashboardAgg.Retention.UsageLogsDays < 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative")
|
||||
}
|
||||
if c.DashboardAgg.Retention.UsageBillingDedupDays < 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be non-negative")
|
||||
}
|
||||
if c.DashboardAgg.Retention.UsageBillingDedupDays > 0 &&
|
||||
c.DashboardAgg.Retention.UsageLogsDays > 0 &&
|
||||
c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
|
||||
}
|
||||
if c.DashboardAgg.Retention.HourlyDays < 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative")
|
||||
}
|
||||
|
||||
@@ -441,6 +441,9 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
|
||||
if cfg.DashboardAgg.Retention.UsageLogsDays != 90 {
|
||||
t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays)
|
||||
}
|
||||
if cfg.DashboardAgg.Retention.UsageBillingDedupDays != 365 {
|
||||
t.Fatalf("DashboardAgg.Retention.UsageBillingDedupDays = %d, want 365", cfg.DashboardAgg.Retention.UsageBillingDedupDays)
|
||||
}
|
||||
if cfg.DashboardAgg.Retention.HourlyDays != 180 {
|
||||
t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays)
|
||||
}
|
||||
@@ -1016,6 +1019,23 @@ func TestValidateConfigErrors(t *testing.T) {
|
||||
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 },
|
||||
wantErr: "dashboard_aggregation.retention.usage_logs_days",
|
||||
},
|
||||
{
|
||||
name: "dashboard aggregation dedup retention",
|
||||
mutate: func(c *Config) {
|
||||
c.DashboardAgg.Enabled = true
|
||||
c.DashboardAgg.Retention.UsageBillingDedupDays = 0
|
||||
},
|
||||
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
|
||||
},
|
||||
{
|
||||
name: "dashboard aggregation dedup retention smaller than usage logs",
|
||||
mutate: func(c *Config) {
|
||||
c.DashboardAgg.Enabled = true
|
||||
c.DashboardAgg.Retention.UsageLogsDays = 30
|
||||
c.DashboardAgg.Retention.UsageBillingDedupDays = 29
|
||||
},
|
||||
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
|
||||
},
|
||||
{
|
||||
name: "dashboard aggregation disabled interval",
|
||||
mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 },
|
||||
|
||||
@@ -31,6 +31,7 @@ const (
|
||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
|
||||
)
|
||||
|
||||
// Redeem type constants
|
||||
@@ -81,13 +82,15 @@ var DefaultAntigravityModelMapping = map[string]string{
|
||||
"claude-opus-4-5-20251101": "claude-opus-4-6-thinking", // 迁移旧模型
|
||||
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
|
||||
// Claude Haiku → Sonnet(无 Haiku 支持)
|
||||
"claude-haiku-4-5": "claude-sonnet-4-5",
|
||||
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
|
||||
"claude-haiku-4-5": "claude-sonnet-4-6",
|
||||
"claude-haiku-4-5-20251001": "claude-sonnet-4-6",
|
||||
// Gemini 2.5 白名单
|
||||
"gemini-2.5-flash": "gemini-2.5-flash",
|
||||
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
|
||||
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
|
||||
"gemini-2.5-pro": "gemini-2.5-pro",
|
||||
"gemini-2.5-flash": "gemini-2.5-flash",
|
||||
"gemini-2.5-flash-image": "gemini-2.5-flash-image",
|
||||
"gemini-2.5-flash-image-preview": "gemini-2.5-flash-image",
|
||||
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
|
||||
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
|
||||
"gemini-2.5-pro": "gemini-2.5-pro",
|
||||
// Gemini 3 白名单
|
||||
"gemini-3-flash": "gemini-3-flash",
|
||||
"gemini-3-pro-high": "gemini-3-pro-high",
|
||||
@@ -111,3 +114,27 @@ var DefaultAntigravityModelMapping = map[string]string{
|
||||
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
|
||||
"tab_flash_lite_preview": "tab_flash_lite_preview",
|
||||
}
|
||||
|
||||
// DefaultBedrockModelMapping is the default model mapping for the AWS Bedrock
// platform: it maps standard Anthropic model names to Bedrock model IDs.
//
// Note: the "us." prefix used here is only a default. ResolveBedrockModelID
// adjusts it to the region prefix matching the account's configured
// aws_region (e.g. eu., apac., jp.).
var DefaultBedrockModelMapping = map[string]string{
	// Claude Opus
	"claude-opus-4-6":          "us.anthropic.claude-opus-4-6-v1",
	"claude-opus-4-6-thinking": "us.anthropic.claude-opus-4-6-v1",
	"claude-opus-4-5-20251101": "us.anthropic.claude-opus-4-5-20251101-v1:0",
	"claude-opus-4-5-thinking": "us.anthropic.claude-opus-4-5-20251101-v1:0",
	"claude-opus-4-1":          "us.anthropic.claude-opus-4-1-20250805-v1:0",
	"claude-opus-4-20250514":   "us.anthropic.claude-opus-4-20250514-v1:0",

	// Claude Sonnet
	"claude-sonnet-4-6":          "us.anthropic.claude-sonnet-4-6",
	"claude-sonnet-4-6-thinking": "us.anthropic.claude-sonnet-4-6",
	"claude-sonnet-4-5":          "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
	"claude-sonnet-4-5-20250929": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
	"claude-sonnet-4-5-thinking": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
	"claude-sonnet-4-20250514":   "us.anthropic.claude-sonnet-4-20250514-v1:0",

	// Claude Haiku
	"claude-haiku-4-5":          "us.anthropic.claude-haiku-4-5-20251001-v1:0",
	"claude-haiku-4-5-20251001": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
}
|
||||
|
||||
@@ -6,6 +6,8 @@ func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T)
|
||||
t.Parallel()
|
||||
|
||||
cases := map[string]string{
|
||||
"gemini-2.5-flash-image": "gemini-2.5-flash-image",
|
||||
"gemini-2.5-flash-image-preview": "gemini-2.5-flash-image",
|
||||
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
|
||||
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
|
||||
"gemini-3-pro-image": "gemini-3.1-flash-image",
|
||||
|
||||
@@ -97,7 +97,7 @@ type CreateAccountRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Notes *string `json:"notes"`
|
||||
Platform string `json:"platform" binding:"required"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock"`
|
||||
Credentials map[string]any `json:"credentials" binding:"required"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
@@ -116,7 +116,7 @@ type CreateAccountRequest struct {
|
||||
type UpdateAccountRequest struct {
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
@@ -165,6 +165,8 @@ type AccountWithConcurrency struct {
|
||||
CurrentRPM *int `json:"current_rpm,omitempty"` // 当前分钟 RPM 计数
|
||||
}
|
||||
|
||||
// accountListGroupUngroupedQueryValue is the literal value of the "group"
// query parameter that AccountHandler.List maps to
// service.AccountListGroupUngrouped instead of parsing it as a numeric group ID.
const accountListGroupUngroupedQueryValue = "ungrouped"
|
||||
|
||||
func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency {
|
||||
item := AccountWithConcurrency{
|
||||
Account: dto.AccountFromService(account),
|
||||
@@ -226,7 +228,20 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
|
||||
var groupID int64
|
||||
if groupIDStr := c.Query("group"); groupIDStr != "" {
|
||||
groupID, _ = strconv.ParseInt(groupIDStr, 10, 64)
|
||||
if groupIDStr == accountListGroupUngroupedQueryValue {
|
||||
groupID = service.AccountListGroupUngrouped
|
||||
} else {
|
||||
parsedGroupID, parseErr := strconv.ParseInt(groupIDStr, 10, 64)
|
||||
if parseErr != nil {
|
||||
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_GROUP_FILTER", "invalid group filter"))
|
||||
return
|
||||
}
|
||||
if parsedGroupID < 0 {
|
||||
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_GROUP_FILTER", "invalid group filter"))
|
||||
return
|
||||
}
|
||||
groupID = parsedGroupID
|
||||
}
|
||||
}
|
||||
|
||||
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID)
|
||||
@@ -628,6 +643,7 @@ func (h *AccountHandler) Delete(c *gin.Context) {
|
||||
// TestAccountRequest represents the request body for testing an account.
// Both fields are forwarded by AccountHandler.Test to
// AccountTestService.TestAccountConnection.
|
||||
type TestAccountRequest struct {
|
||||
ModelID string `json:"model_id"` // passed as the model argument of TestAccountConnection
|
||||
Prompt string `json:"prompt"` // passed as the prompt argument of TestAccountConnection
|
||||
}
|
||||
|
||||
type SyncFromCRSRequest struct {
|
||||
@@ -658,7 +674,7 @@ func (h *AccountHandler) Test(c *gin.Context) {
|
||||
_ = c.ShouldBindJSON(&req)
|
||||
|
||||
// Use AccountTestService to test the account with SSE streaming
|
||||
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID); err != nil {
|
||||
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt); err != nil {
|
||||
// Error already sent via SSE, just log
|
||||
return
|
||||
}
|
||||
@@ -864,6 +880,9 @@ func (h *AccountHandler) refreshSingleAccount(ctx context.Context, account *serv
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI OAuth: 刷新成功后检查并设置 privacy_mode
|
||||
h.adminService.EnsureOpenAIPrivacy(ctx, updatedAccount)
|
||||
|
||||
return updatedAccount, "", nil
|
||||
}
|
||||
|
||||
@@ -1492,7 +1511,7 @@ func (h *OAuthHandler) SetupTokenCookieAuth(c *gin.Context) {
|
||||
}
|
||||
|
||||
// GetUsage handles getting account usage information
|
||||
// GET /api/v1/admin/accounts/:id/usage
|
||||
// GET /api/v1/admin/accounts/:id/usage?source=passive|active
|
||||
func (h *AccountHandler) GetUsage(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
@@ -1500,7 +1519,14 @@ func (h *AccountHandler) GetUsage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
usage, err := h.accountUsageService.GetUsage(c.Request.Context(), accountID)
|
||||
source := c.DefaultQuery("source", "active")
|
||||
|
||||
var usage *service.UsageInfo
|
||||
if source == "passive" {
|
||||
usage, err = h.accountUsageService.GetPassiveUsage(c.Request.Context(), accountID)
|
||||
} else {
|
||||
usage, err = h.accountUsageService.GetUsage(c.Request.Context(), accountID)
|
||||
}
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -1714,13 +1740,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
|
||||
// Handle OpenAI accounts
|
||||
if account.IsOpenAI() {
|
||||
// For OAuth accounts: return default OpenAI models
|
||||
if account.IsOAuth() {
|
||||
// OpenAI 自动透传会绕过常规模型改写,测试/模型列表也应回落到默认模型集。
|
||||
if account.IsOpenAIPassthroughEnabled() {
|
||||
response.Success(c, openai.DefaultModels)
|
||||
return
|
||||
}
|
||||
|
||||
// For API Key accounts: check model_mapping
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) == 0 {
|
||||
response.Success(c, openai.DefaultModels)
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type availableModelsAdminService struct {
|
||||
*stubAdminService
|
||||
account service.Account
|
||||
}
|
||||
|
||||
func (s *availableModelsAdminService) GetAccount(_ context.Context, id int64) (*service.Account, error) {
|
||||
if s.account.ID == id {
|
||||
acc := s.account
|
||||
return &acc, nil
|
||||
}
|
||||
return s.stubAdminService.GetAccount(context.Background(), id)
|
||||
}
|
||||
|
||||
func setupAvailableModelsRouter(adminSvc service.AdminService) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
router.GET("/api/v1/admin/accounts/:id/models", handler.GetAvailableModels)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_OpenAIOAuthUsesExplicitModelMapping(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 42,
|
||||
Name: "openai-oauth",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.1",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/42/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Len(t, resp.Data, 1)
|
||||
require.Equal(t, "gpt-5", resp.Data[0].ID)
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefaults(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 43,
|
||||
Name: "openai-oauth-passthrough",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.1",
|
||||
},
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"openai_passthrough": true,
|
||||
},
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/43/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.NotEmpty(t, resp.Data)
|
||||
require.NotEqual(t, "gpt-5", resp.Data[0].ID)
|
||||
}
|
||||
@@ -17,7 +17,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
|
||||
adminSvc := newStubAdminService()
|
||||
|
||||
userHandler := NewUserHandler(adminSvc, nil)
|
||||
groupHandler := NewGroupHandler(adminSvc)
|
||||
groupHandler := NewGroupHandler(adminSvc, nil, nil)
|
||||
proxyHandler := NewProxyHandler(adminSvc)
|
||||
redeemHandler := NewRedeemHandler(adminSvc, nil)
|
||||
|
||||
|
||||
@@ -175,6 +175,18 @@ func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, p
|
||||
return s.apiKeys, int64(len(s.apiKeys)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetGroupRateMultipliers(_ context.Context, _ int64) ([]service.UserGroupRateEntry, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ClearGroupRateMultipliers(_ context.Context, _ int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int64, _ []service.GroupRateMultiplierInput) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
|
||||
return s.accounts, int64(len(s.accounts)), nil
|
||||
}
|
||||
@@ -429,5 +441,9 @@ func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) EnsureOpenAIPrivacy(ctx context.Context, account *service.Account) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Ensure stub implements interface.
|
||||
var _ service.AdminService = (*stubAdminService)(nil)
|
||||
|
||||
205
backend/internal/handler/admin/backup_handler.go
Normal file
205
backend/internal/handler/admin/backup_handler.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type BackupHandler struct {
|
||||
backupService *service.BackupService
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
func NewBackupHandler(backupService *service.BackupService, userService *service.UserService) *BackupHandler {
|
||||
return &BackupHandler{
|
||||
backupService: backupService,
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
// ─── S3 配置 ───
|
||||
|
||||
func (h *BackupHandler) GetS3Config(c *gin.Context) {
|
||||
cfg, err := h.backupService.GetS3Config(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) UpdateS3Config(c *gin.Context) {
|
||||
var req service.BackupS3Config
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
cfg, err := h.backupService.UpdateS3Config(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) TestS3Connection(c *gin.Context) {
|
||||
var req service.BackupS3Config
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
err := h.backupService.TestS3Connection(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.Success(c, gin.H{"ok": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"ok": true, "message": "connection successful"})
|
||||
}
|
||||
|
||||
// ─── 定时备份 ───
|
||||
|
||||
func (h *BackupHandler) GetSchedule(c *gin.Context) {
|
||||
cfg, err := h.backupService.GetSchedule(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) UpdateSchedule(c *gin.Context) {
|
||||
var req service.BackupScheduleConfig
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
cfg, err := h.backupService.UpdateSchedule(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
// ─── 备份操作 ───
|
||||
|
||||
type CreateBackupRequest struct {
|
||||
ExpireDays *int `json:"expire_days"` // nil=使用默认值14,0=永不过期
|
||||
}
|
||||
|
||||
func (h *BackupHandler) CreateBackup(c *gin.Context) {
|
||||
var req CreateBackupRequest
|
||||
_ = c.ShouldBindJSON(&req) // 允许空 body
|
||||
|
||||
expireDays := 14 // 默认14天过期
|
||||
if req.ExpireDays != nil {
|
||||
expireDays = *req.ExpireDays
|
||||
}
|
||||
|
||||
record, err := h.backupService.StartBackup(c.Request.Context(), "manual", expireDays)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Accepted(c, record)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) ListBackups(c *gin.Context) {
|
||||
records, err := h.backupService.ListBackups(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if records == nil {
|
||||
records = []service.BackupRecord{}
|
||||
}
|
||||
response.Success(c, gin.H{"items": records})
|
||||
}
|
||||
|
||||
func (h *BackupHandler) GetBackup(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
record, err := h.backupService.GetBackupRecord(c.Request.Context(), backupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, record)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) DeleteBackup(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
if err := h.backupService.DeleteBackup(c.Request.Context(), backupID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
func (h *BackupHandler) GetDownloadURL(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
url, err := h.backupService.GetBackupDownloadURL(c.Request.Context(), backupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"url": url})
|
||||
}
|
||||
|
||||
// ─── 恢复操作(需要重新输入管理员密码) ───
|
||||
|
||||
type RestoreBackupRequest struct {
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
func (h *BackupHandler) RestoreBackup(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
var req RestoreBackupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "password is required for restore operation")
|
||||
return
|
||||
}
|
||||
|
||||
// 从上下文获取当前管理员用户 ID
|
||||
sub, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取管理员用户并验证密码
|
||||
user, err := h.userService.GetByID(c.Request.Context(), sub.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if !user.CheckPassword(req.Password) {
|
||||
response.BadRequest(c, "incorrect admin password")
|
||||
return
|
||||
}
|
||||
|
||||
record, err := h.backupService.StartRestore(c.Request.Context(), backupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Accepted(c, record)
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -249,11 +250,12 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||
trend, hit, err := h.getUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get usage trend")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
@@ -271,6 +273,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
|
||||
// Parse optional filter params
|
||||
var userID, apiKeyID, accountID, groupID int64
|
||||
modelSource := usagestats.ModelSourceRequested
|
||||
var requestType *int16
|
||||
var stream *bool
|
||||
var billingType *int8
|
||||
@@ -295,6 +298,13 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
groupID = id
|
||||
}
|
||||
}
|
||||
if rawModelSource := strings.TrimSpace(c.Query("model_source")); rawModelSource != "" {
|
||||
if !usagestats.IsValidModelSource(rawModelSource) {
|
||||
response.BadRequest(c, "Invalid model_source, use requested/upstream/mapping")
|
||||
return
|
||||
}
|
||||
modelSource = rawModelSource
|
||||
}
|
||||
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
||||
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
||||
if err != nil {
|
||||
@@ -321,11 +331,12 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, modelSource, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get model statistics")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"models": stats,
|
||||
@@ -391,11 +402,12 @@ func (h *DashboardHandler) GetGroupStats(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetGroupStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
stats, hit, err := h.getGroupStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get group statistics")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"groups": stats,
|
||||
@@ -416,11 +428,12 @@ func (h *DashboardHandler) GetAPIKeyUsageTrend(c *gin.Context) {
|
||||
limit = 5
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetAPIKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
trend, hit, err := h.getAPIKeyUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get API key usage trend")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
@@ -442,11 +455,12 @@ func (h *DashboardHandler) GetUserUsageTrend(c *gin.Context) {
|
||||
limit = 12
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
trend, hit, err := h.getUserUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user usage trend")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
@@ -461,9 +475,62 @@ type BatchUsersUsageRequest struct {
|
||||
UserIDs []int64 `json:"user_ids" binding:"required"`
|
||||
}
|
||||
|
||||
var dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
|
||||
var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second)
|
||||
var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second)
|
||||
|
||||
func parseRankingLimit(raw string) int {
|
||||
limit, err := strconv.Atoi(strings.TrimSpace(raw))
|
||||
if err != nil || limit <= 0 {
|
||||
return 12
|
||||
}
|
||||
if limit > 50 {
|
||||
return 50
|
||||
}
|
||||
return limit
|
||||
}
|
||||
|
||||
// GetUserSpendingRanking handles getting user spending ranking data.
|
||||
// GET /api/v1/admin/dashboard/users-ranking
|
||||
func (h *DashboardHandler) GetUserSpendingRanking(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
limit := parseRankingLimit(c.DefaultQuery("limit", "12"))
|
||||
|
||||
keyRaw, _ := json.Marshal(struct {
|
||||
Start string `json:"start"`
|
||||
End string `json:"end"`
|
||||
Limit int `json:"limit"`
|
||||
}{
|
||||
Start: startTime.UTC().Format(time.RFC3339),
|
||||
End: endTime.UTC().Format(time.RFC3339),
|
||||
Limit: limit,
|
||||
})
|
||||
cacheKey := string(keyRaw)
|
||||
if cached, ok := dashboardUsersRankingCache.Get(cacheKey); ok {
|
||||
c.Header("X-Snapshot-Cache", "hit")
|
||||
response.Success(c, cached.Payload)
|
||||
return
|
||||
}
|
||||
|
||||
ranking, err := h.dashboardService.GetUserSpendingRanking(c.Request.Context(), startTime, endTime, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user spending ranking")
|
||||
return
|
||||
}
|
||||
|
||||
payload := gin.H{
|
||||
"ranking": ranking.Ranking,
|
||||
"total_actual_cost": ranking.TotalActualCost,
|
||||
"total_requests": ranking.TotalRequests,
|
||||
"total_tokens": ranking.TotalTokens,
|
||||
"start_date": startTime.Format("2006-01-02"),
|
||||
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
}
|
||||
dashboardUsersRankingCache.Set(cacheKey, payload)
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, payload)
|
||||
}
|
||||
|
||||
// GetBatchUsersUsage handles getting usage stats for multiple users
|
||||
// POST /api/v1/admin/dashboard/users-usage
|
||||
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||
@@ -546,3 +613,47 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, payload)
|
||||
}
|
||||
|
||||
// GetUserBreakdown handles getting per-user usage breakdown within a dimension.
|
||||
// GET /api/v1/admin/dashboard/user-breakdown
|
||||
// Query params: start_date, end_date, group_id, model, endpoint, endpoint_type, limit
|
||||
func (h *DashboardHandler) GetUserBreakdown(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
|
||||
dim := usagestats.UserBreakdownDimension{}
|
||||
if v := c.Query("group_id"); v != "" {
|
||||
if id, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
dim.GroupID = id
|
||||
}
|
||||
}
|
||||
dim.Model = c.Query("model")
|
||||
rawModelSource := strings.TrimSpace(c.DefaultQuery("model_source", usagestats.ModelSourceRequested))
|
||||
if !usagestats.IsValidModelSource(rawModelSource) {
|
||||
response.BadRequest(c, "Invalid model_source, use requested/upstream/mapping")
|
||||
return
|
||||
}
|
||||
dim.ModelType = rawModelSource
|
||||
dim.Endpoint = c.Query("endpoint")
|
||||
dim.EndpointType = c.DefaultQuery("endpoint_type", "inbound")
|
||||
|
||||
limit := 50
|
||||
if v := c.Query("limit"); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= 200 {
|
||||
limit = n
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetUserBreakdownStats(
|
||||
c.Request.Context(), startTime, endTime, dim, limit,
|
||||
)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user breakdown stats")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"users": stats,
|
||||
"start_date": startTime.Format("2006-01-02"),
|
||||
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
})
|
||||
}
|
||||
|
||||
118
backend/internal/handler/admin/dashboard_handler_cache_test.go
Normal file
118
backend/internal/handler/admin/dashboard_handler_cache_test.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type dashboardUsageRepoCacheProbe struct {
|
||||
service.UsageLogRepository
|
||||
trendCalls atomic.Int32
|
||||
usersTrendCalls atomic.Int32
|
||||
}
|
||||
|
||||
func (r *dashboardUsageRepoCacheProbe) GetUsageTrendWithFilters(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
granularity string,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
model string,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.TrendDataPoint, error) {
|
||||
r.trendCalls.Add(1)
|
||||
return []usagestats.TrendDataPoint{{
|
||||
Date: "2026-03-11",
|
||||
Requests: 1,
|
||||
TotalTokens: 2,
|
||||
Cost: 3,
|
||||
ActualCost: 4,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
func (r *dashboardUsageRepoCacheProbe) GetUserUsageTrend(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
granularity string,
|
||||
limit int,
|
||||
) ([]usagestats.UserUsageTrendPoint, error) {
|
||||
r.usersTrendCalls.Add(1)
|
||||
return []usagestats.UserUsageTrendPoint{{
|
||||
Date: "2026-03-11",
|
||||
UserID: 1,
|
||||
Email: "cache@test.dev",
|
||||
Requests: 2,
|
||||
Tokens: 20,
|
||||
Cost: 2,
|
||||
ActualCost: 1,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
func resetDashboardReadCachesForTest() {
|
||||
dashboardTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardUsersTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardModelStatsCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardGroupStatsCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
|
||||
}
|
||||
|
||||
func TestDashboardHandler_GetUsageTrend_UsesCache(t *testing.T) {
|
||||
t.Cleanup(resetDashboardReadCachesForTest)
|
||||
resetDashboardReadCachesForTest()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
repo := &dashboardUsageRepoCacheProbe{}
|
||||
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||
handler := NewDashboardHandler(dashboardSvc, nil)
|
||||
router := gin.New()
|
||||
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
|
||||
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil)
|
||||
rec1 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec1, req1)
|
||||
require.Equal(t, http.StatusOK, rec1.Code)
|
||||
require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache"))
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil)
|
||||
rec2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec2, req2)
|
||||
require.Equal(t, http.StatusOK, rec2.Code)
|
||||
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
|
||||
require.Equal(t, int32(1), repo.trendCalls.Load())
|
||||
}
|
||||
|
||||
func TestDashboardHandler_GetUserUsageTrend_UsesCache(t *testing.T) {
|
||||
t.Cleanup(resetDashboardReadCachesForTest)
|
||||
resetDashboardReadCachesForTest()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
repo := &dashboardUsageRepoCacheProbe{}
|
||||
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||
handler := NewDashboardHandler(dashboardSvc, nil)
|
||||
router := gin.New()
|
||||
router.GET("/admin/dashboard/users-trend", handler.GetUserUsageTrend)
|
||||
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil)
|
||||
rec1 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec1, req1)
|
||||
require.Equal(t, http.StatusOK, rec1.Code)
|
||||
require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache"))
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil)
|
||||
rec2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec2, req2)
|
||||
require.Equal(t, http.StatusOK, rec2.Code)
|
||||
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
|
||||
require.Equal(t, int32(1), repo.usersTrendCalls.Load())
|
||||
}
|
||||
@@ -19,6 +19,9 @@ type dashboardUsageRepoCapture struct {
|
||||
trendStream *bool
|
||||
modelRequestType *int16
|
||||
modelStream *bool
|
||||
rankingLimit int
|
||||
ranking []usagestats.UserSpendingRankingItem
|
||||
rankingTotal float64
|
||||
}
|
||||
|
||||
func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters(
|
||||
@@ -49,6 +52,20 @@ func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters(
|
||||
return []usagestats.ModelStat{}, nil
|
||||
}
|
||||
|
||||
func (s *dashboardUsageRepoCapture) GetUserSpendingRanking(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
limit int,
|
||||
) (*usagestats.UserSpendingRankingResponse, error) {
|
||||
s.rankingLimit = limit
|
||||
return &usagestats.UserSpendingRankingResponse{
|
||||
Ranking: s.ranking,
|
||||
TotalActualCost: s.rankingTotal,
|
||||
TotalRequests: 44,
|
||||
TotalTokens: 1234,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||
@@ -56,6 +73,7 @@ func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Eng
|
||||
router := gin.New()
|
||||
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
|
||||
router.GET("/admin/dashboard/models", handler.GetModelStats)
|
||||
router.GET("/admin/dashboard/users-ranking", handler.GetUserSpendingRanking)
|
||||
return router
|
||||
}
|
||||
|
||||
@@ -130,3 +148,54 @@ func TestDashboardModelStatsInvalidStream(t *testing.T) {
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestDashboardModelStatsInvalidModelSource(t *testing.T) {
|
||||
repo := &dashboardUsageRepoCapture{}
|
||||
router := newDashboardRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?model_source=invalid", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestDashboardModelStatsValidModelSource(t *testing.T) {
|
||||
repo := &dashboardUsageRepoCapture{}
|
||||
router := newDashboardRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?model_source=upstream", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestDashboardUsersRankingLimitAndCache(t *testing.T) {
|
||||
dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
|
||||
repo := &dashboardUsageRepoCapture{
|
||||
ranking: []usagestats.UserSpendingRankingItem{
|
||||
{UserID: 7, Email: "rank@example.com", ActualCost: 10.5, Requests: 3, Tokens: 300},
|
||||
},
|
||||
rankingTotal: 88.8,
|
||||
}
|
||||
router := newDashboardRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, 50, repo.rankingLimit)
|
||||
require.Contains(t, rec.Body.String(), "\"total_actual_cost\":88.8")
|
||||
require.Contains(t, rec.Body.String(), "\"total_requests\":44")
|
||||
require.Contains(t, rec.Body.String(), "\"total_tokens\":1234")
|
||||
require.Equal(t, "miss", rec.Header().Get("X-Snapshot-Cache"))
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
|
||||
rec2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec2, req2)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec2.Code)
|
||||
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
|
||||
}
|
||||
|
||||
@@ -0,0 +1,229 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- mock repo ---
|
||||
|
||||
type userBreakdownRepoCapture struct {
|
||||
service.UsageLogRepository
|
||||
capturedDim usagestats.UserBreakdownDimension
|
||||
capturedLimit int
|
||||
result []usagestats.UserBreakdownItem
|
||||
}
|
||||
|
||||
func (r *userBreakdownRepoCapture) GetUserBreakdownStats(
|
||||
_ context.Context, _, _ time.Time,
|
||||
dim usagestats.UserBreakdownDimension, limit int,
|
||||
) ([]usagestats.UserBreakdownItem, error) {
|
||||
r.capturedDim = dim
|
||||
r.capturedLimit = limit
|
||||
if r.result != nil {
|
||||
return r.result, nil
|
||||
}
|
||||
return []usagestats.UserBreakdownItem{}, nil
|
||||
}
|
||||
|
||||
func newUserBreakdownRouter(repo *userBreakdownRepoCapture) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := service.NewDashboardService(repo, nil, nil, nil)
|
||||
h := NewDashboardHandler(svc, nil)
|
||||
router := gin.New()
|
||||
router.GET("/admin/dashboard/user-breakdown", h.GetUserBreakdown)
|
||||
return router
|
||||
}
|
||||
|
||||
// --- tests ---
|
||||
|
||||
func TestGetUserBreakdown_GroupIDFilter(t *testing.T) {
|
||||
repo := &userBreakdownRepoCapture{}
|
||||
router := newUserBreakdownRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=42", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
require.Equal(t, int64(42), repo.capturedDim.GroupID)
|
||||
require.Empty(t, repo.capturedDim.Model)
|
||||
require.Empty(t, repo.capturedDim.Endpoint)
|
||||
require.Equal(t, 50, repo.capturedLimit) // default limit
|
||||
}
|
||||
|
||||
func TestGetUserBreakdown_ModelFilter(t *testing.T) {
|
||||
repo := &userBreakdownRepoCapture{}
|
||||
router := newUserBreakdownRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=claude-opus-4-6", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
require.Equal(t, "claude-opus-4-6", repo.capturedDim.Model)
|
||||
require.Equal(t, usagestats.ModelSourceRequested, repo.capturedDim.ModelType)
|
||||
require.Equal(t, int64(0), repo.capturedDim.GroupID)
|
||||
}
|
||||
|
||||
func TestGetUserBreakdown_ModelSourceFilter(t *testing.T) {
|
||||
repo := &userBreakdownRepoCapture{}
|
||||
router := newUserBreakdownRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=claude-opus-4-6&model_source=upstream", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
require.Equal(t, usagestats.ModelSourceUpstream, repo.capturedDim.ModelType)
|
||||
}
|
||||
|
||||
func TestGetUserBreakdown_InvalidModelSource(t *testing.T) {
|
||||
repo := &userBreakdownRepoCapture{}
|
||||
router := newUserBreakdownRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model_source=foobar", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, w.Code)
|
||||
}
|
||||
|
||||
func TestGetUserBreakdown_EndpointFilter(t *testing.T) {
|
||||
repo := &userBreakdownRepoCapture{}
|
||||
router := newUserBreakdownRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&endpoint=/v1/messages&endpoint_type=upstream", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
require.Equal(t, "/v1/messages", repo.capturedDim.Endpoint)
|
||||
require.Equal(t, "upstream", repo.capturedDim.EndpointType)
|
||||
}
|
||||
|
||||
func TestGetUserBreakdown_DefaultEndpointType(t *testing.T) {
|
||||
repo := &userBreakdownRepoCapture{}
|
||||
router := newUserBreakdownRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&endpoint=/chat", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
require.Equal(t, "inbound", repo.capturedDim.EndpointType)
|
||||
}
|
||||
|
||||
func TestGetUserBreakdown_CustomLimit(t *testing.T) {
|
||||
repo := &userBreakdownRepoCapture{}
|
||||
router := newUserBreakdownRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=test&limit=100", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
require.Equal(t, 100, repo.capturedLimit)
|
||||
}
|
||||
|
||||
func TestGetUserBreakdown_LimitClamped(t *testing.T) {
|
||||
repo := &userBreakdownRepoCapture{}
|
||||
router := newUserBreakdownRouter(repo)
|
||||
|
||||
// limit > 200 should fall back to default 50
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=test&limit=999", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
require.Equal(t, 50, repo.capturedLimit)
|
||||
}
|
||||
|
||||
func TestGetUserBreakdown_ResponseFormat(t *testing.T) {
|
||||
repo := &userBreakdownRepoCapture{
|
||||
result: []usagestats.UserBreakdownItem{
|
||||
{UserID: 1, Email: "alice@test.com", Requests: 100, TotalTokens: 50000, Cost: 1.5, ActualCost: 1.2},
|
||||
{UserID: 2, Email: "bob@test.com", Requests: 50, TotalTokens: 25000, Cost: 0.8, ActualCost: 0.6},
|
||||
},
|
||||
}
|
||||
router := newUserBreakdownRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=1", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp struct {
|
||||
Code int `json:"code"`
|
||||
Data struct {
|
||||
Users []usagestats.UserBreakdownItem `json:"users"`
|
||||
StartDate string `json:"start_date"`
|
||||
EndDate string `json:"end_date"`
|
||||
} `json:"data"`
|
||||
}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, resp.Code)
|
||||
require.Len(t, resp.Data.Users, 2)
|
||||
require.Equal(t, int64(1), resp.Data.Users[0].UserID)
|
||||
require.Equal(t, "alice@test.com", resp.Data.Users[0].Email)
|
||||
require.Equal(t, int64(100), resp.Data.Users[0].Requests)
|
||||
require.InDelta(t, 1.2, resp.Data.Users[0].ActualCost, 0.001)
|
||||
require.Equal(t, "2026-03-01", resp.Data.StartDate)
|
||||
require.Equal(t, "2026-03-16", resp.Data.EndDate)
|
||||
}
|
||||
|
||||
func TestGetUserBreakdown_EmptyResult(t *testing.T) {
|
||||
repo := &userBreakdownRepoCapture{}
|
||||
router := newUserBreakdownRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=999", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp struct {
|
||||
Data struct {
|
||||
Users []usagestats.UserBreakdownItem `json:"users"`
|
||||
} `json:"data"`
|
||||
}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, resp.Data.Users)
|
||||
}
|
||||
|
||||
func TestGetUserBreakdown_NoFilters(t *testing.T) {
|
||||
repo := &userBreakdownRepoCapture{}
|
||||
router := newUserBreakdownRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
require.Equal(t, int64(0), repo.capturedDim.GroupID)
|
||||
require.Empty(t, repo.capturedDim.Model)
|
||||
require.Empty(t, repo.capturedDim.Endpoint)
|
||||
}
|
||||
203
backend/internal/handler/admin/dashboard_query_cache.go
Normal file
203
backend/internal/handler/admin/dashboard_query_cache.go
Normal file
@@ -0,0 +1,203 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
)
|
||||
|
||||
var (
|
||||
dashboardTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardModelStatsCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardGroupStatsCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardUsersTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second)
|
||||
)
|
||||
|
||||
// dashboardTrendCacheKey lists every input of a filtered usage-trend query.
// Its JSON encoding is used as the cache key, so two requests share an entry
// only when all fields encode identically.
type dashboardTrendCacheKey struct {
	// Time window, stored as pre-formatted strings.
	StartTime string `json:"start_time"`
	EndTime   string `json:"end_time"`
	// Granularity is the bucket size requested by the caller.
	Granularity string `json:"granularity"`
	// Entity filters.
	UserID    int64  `json:"user_id"`
	APIKeyID  int64  `json:"api_key_id"`
	AccountID int64  `json:"account_id"`
	GroupID   int64  `json:"group_id"`
	Model     string `json:"model"`
	// Optional filters; a nil pointer encodes as JSON null.
	RequestType *int16 `json:"request_type"`
	Stream      *bool  `json:"stream"`
	BillingType *int8  `json:"billing_type"`
}
|
||||
|
||||
// dashboardModelGroupCacheKey lists the inputs of the model-statistics and
// group-statistics queries; its JSON encoding is the cache key. ModelSource
// is omitted from the encoding when empty.
type dashboardModelGroupCacheKey struct {
	// Time window, stored as pre-formatted strings.
	StartTime string `json:"start_time"`
	EndTime   string `json:"end_time"`
	// Entity filters.
	UserID    int64 `json:"user_id"`
	APIKeyID  int64 `json:"api_key_id"`
	AccountID int64 `json:"account_id"`
	GroupID   int64 `json:"group_id"`
	// ModelSource is left empty by the group-statistics caller.
	ModelSource string `json:"model_source,omitempty"`
	// Optional filters; a nil pointer encodes as JSON null.
	RequestType *int16 `json:"request_type"`
	Stream      *bool  `json:"stream"`
	BillingType *int8  `json:"billing_type"`
}
|
||||
|
||||
// dashboardEntityTrendCacheKey lists the inputs of the per-user and
// per-API-key usage-trend queries; its JSON encoding is the cache key.
type dashboardEntityTrendCacheKey struct {
	// Time window, stored as pre-formatted strings.
	StartTime string `json:"start_time"`
	EndTime   string `json:"end_time"`
	// Granularity is the bucket size requested by the caller.
	Granularity string `json:"granularity"`
	// Limit is the limit value passed on to the underlying query.
	Limit int `json:"limit"`
}
|
||||
|
||||
// cacheStatusValue converts a cache-hit flag into the textual status
// "hit" or "miss".
func cacheStatusValue(hit bool) string {
	status := "miss"
	if hit {
		status = "hit"
	}
	return status
}
|
||||
|
||||
// mustMarshalDashboardCacheKey returns the JSON encoding of value for use as
// a cache key. Despite the "must" prefix it never panics: when encoding fails
// it returns the empty string, so every value that cannot be encoded maps to
// the same key.
func mustMarshalDashboardCacheKey(value any) string {
	if encoded, err := json.Marshal(value); err == nil {
		return string(encoded)
	}
	return ""
}
|
||||
|
||||
// snapshotPayloadAs asserts that a cached payload has dynamic type T.
// It returns the payload as T on success; otherwise it returns the zero value
// of T and an error that names the payload's actual type.
func snapshotPayloadAs[T any](payload any) (T, error) {
	if typed, ok := payload.(T); ok {
		return typed, nil
	}

	var zero T
	return zero, fmt.Errorf("unexpected cache payload type %T", payload)
}
|
||||
|
||||
func (h *DashboardHandler) getUsageTrendCached(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
granularity string,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
model string,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.TrendDataPoint, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardTrendCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
Granularity: granularity,
|
||||
UserID: userID,
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
Model: model,
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
})
|
||||
entry, hit, err := dashboardTrendCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
trend, err := snapshotPayloadAs[[]usagestats.TrendDataPoint](entry.Payload)
|
||||
return trend, hit, err
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) getModelStatsCached(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
modelSource string,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.ModelStat, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
UserID: userID,
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
ModelSource: usagestats.NormalizeModelSource(modelSource),
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
})
|
||||
entry, hit, err := dashboardModelStatsCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, modelSource)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
stats, err := snapshotPayloadAs[[]usagestats.ModelStat](entry.Payload)
|
||||
return stats, hit, err
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) getGroupStatsCached(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.GroupStat, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
UserID: userID,
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
})
|
||||
entry, hit, err := dashboardGroupStatsCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
stats, err := snapshotPayloadAs[[]usagestats.GroupStat](entry.Payload)
|
||||
return stats, hit, err
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) getAPIKeyUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
Granularity: granularity,
|
||||
Limit: limit,
|
||||
})
|
||||
entry, hit, err := dashboardAPIKeysTrendCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
trend, err := snapshotPayloadAs[[]usagestats.APIKeyUsageTrendPoint](entry.Payload)
|
||||
return trend, hit, err
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) getUserUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
Granularity: granularity,
|
||||
Limit: limit,
|
||||
})
|
||||
entry, hit, err := dashboardUsersTrendCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetUserUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
trend, err := snapshotPayloadAs[[]usagestats.UserUsageTrendPoint](entry.Payload)
|
||||
return trend, hit, err
|
||||
}
|
||||
@@ -1,7 +1,9 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -111,20 +113,45 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
})
|
||||
cacheKey := string(keyRaw)
|
||||
|
||||
if cached, ok := dashboardSnapshotV2Cache.Get(cacheKey); ok {
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
|
||||
c.Status(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", "hit")
|
||||
response.Success(c, cached.Payload)
|
||||
cached, hit, err := dashboardSnapshotV2Cache.GetOrLoad(cacheKey, func() (any, error) {
|
||||
return h.buildSnapshotV2Response(
|
||||
c.Request.Context(),
|
||||
startTime,
|
||||
endTime,
|
||||
granularity,
|
||||
filters,
|
||||
includeStats,
|
||||
includeTrend,
|
||||
includeModels,
|
||||
includeGroups,
|
||||
includeUsersTrend,
|
||||
usersTrendLimit,
|
||||
)
|
||||
})
|
||||
if err != nil {
|
||||
response.Error(c, 500, err.Error())
|
||||
return
|
||||
}
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
|
||||
c.Status(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
response.Success(c, cached.Payload)
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) buildSnapshotV2Response(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
granularity string,
|
||||
filters *dashboardSnapshotV2Filters,
|
||||
includeStats, includeTrend, includeModels, includeGroups, includeUsersTrend bool,
|
||||
usersTrendLimit int,
|
||||
) (*dashboardSnapshotV2Response, error) {
|
||||
resp := &dashboardSnapshotV2Response{
|
||||
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
StartDate: startTime.Format("2006-01-02"),
|
||||
@@ -133,10 +160,9 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
}
|
||||
|
||||
if includeStats {
|
||||
stats, err := h.dashboardService.GetDashboardStats(c.Request.Context())
|
||||
stats, err := h.dashboardService.GetDashboardStats(ctx)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get dashboard statistics")
|
||||
return
|
||||
return nil, errors.New("failed to get dashboard statistics")
|
||||
}
|
||||
resp.Stats = &dashboardSnapshotV2Stats{
|
||||
DashboardStats: *stats,
|
||||
@@ -145,8 +171,8 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
}
|
||||
|
||||
if includeTrend {
|
||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(
|
||||
c.Request.Context(),
|
||||
trend, _, err := h.getUsageTrendCached(
|
||||
ctx,
|
||||
startTime,
|
||||
endTime,
|
||||
granularity,
|
||||
@@ -160,35 +186,34 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
filters.BillingType,
|
||||
)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get usage trend")
|
||||
return
|
||||
return nil, errors.New("failed to get usage trend")
|
||||
}
|
||||
resp.Trend = trend
|
||||
}
|
||||
|
||||
if includeModels {
|
||||
models, err := h.dashboardService.GetModelStatsWithFilters(
|
||||
c.Request.Context(),
|
||||
models, _, err := h.getModelStatsCached(
|
||||
ctx,
|
||||
startTime,
|
||||
endTime,
|
||||
filters.UserID,
|
||||
filters.APIKeyID,
|
||||
filters.AccountID,
|
||||
filters.GroupID,
|
||||
usagestats.ModelSourceRequested,
|
||||
filters.RequestType,
|
||||
filters.Stream,
|
||||
filters.BillingType,
|
||||
)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get model statistics")
|
||||
return
|
||||
return nil, errors.New("failed to get model statistics")
|
||||
}
|
||||
resp.Models = models
|
||||
}
|
||||
|
||||
if includeGroups {
|
||||
groups, err := h.dashboardService.GetGroupStatsWithFilters(
|
||||
c.Request.Context(),
|
||||
groups, _, err := h.getGroupStatsCached(
|
||||
ctx,
|
||||
startTime,
|
||||
endTime,
|
||||
filters.UserID,
|
||||
@@ -200,34 +225,20 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
filters.BillingType,
|
||||
)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get group statistics")
|
||||
return
|
||||
return nil, errors.New("failed to get group statistics")
|
||||
}
|
||||
resp.Groups = groups
|
||||
}
|
||||
|
||||
if includeUsersTrend {
|
||||
usersTrend, err := h.dashboardService.GetUserUsageTrend(
|
||||
c.Request.Context(),
|
||||
startTime,
|
||||
endTime,
|
||||
granularity,
|
||||
usersTrendLimit,
|
||||
)
|
||||
usersTrend, _, err := h.getUserUsageTrendCached(ctx, startTime, endTime, granularity, usersTrendLimit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user usage trend")
|
||||
return
|
||||
return nil, errors.New("failed to get user usage trend")
|
||||
}
|
||||
resp.UsersTrend = usersTrend
|
||||
}
|
||||
|
||||
cached := dashboardSnapshotV2Cache.Set(cacheKey, resp)
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, resp)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func parseDashboardSnapshotV2Filters(c *gin.Context) (*dashboardSnapshotV2Filters, error) {
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -13,27 +17,80 @@ import (
|
||||
|
||||
// GroupHandler handles admin group management
|
||||
type GroupHandler struct {
|
||||
adminService service.AdminService
|
||||
adminService service.AdminService
|
||||
dashboardService *service.DashboardService
|
||||
groupCapacityService *service.GroupCapacityService
|
||||
}
|
||||
|
||||
// optionalLimitField is a JSON request field holding an optional numeric
// limit. It distinguishes three payload states:
//
//   - key absent:                 set == false
//   - key present, null or "":    set == true, value == nil
//   - key present, numeric:       set == true, value != nil
//
// Numbers may be sent either as JSON numbers or as strings containing a
// number.
type optionalLimitField struct {
	set   bool     // true once UnmarshalJSON has run, i.e. the key was present
	value *float64 // parsed limit; nil for JSON null or a blank string
}

// UnmarshalJSON implements json.Unmarshaler.
//
// Accepted inputs are JSON null, a JSON number, or a JSON string that is
// either blank (treated like null) or parses as a finite floating-point
// number. Any other input yields an error. The field is marked as set even
// when an error is returned.
func (f *optionalLimitField) UnmarshalJSON(data []byte) error {
	f.set = true

	trimmed := bytes.TrimSpace(data)
	if bytes.Equal(trimmed, []byte("null")) {
		f.value = nil
		return nil
	}

	// Plain JSON number. encoding/json cannot produce NaN or ±Inf here.
	var number float64
	if err := json.Unmarshal(trimmed, &number); err == nil {
		f.value = &number
		return nil
	}

	// Otherwise the value must be a JSON string wrapping a number.
	var text string
	if err := json.Unmarshal(trimmed, &text); err != nil {
		return fmt.Errorf("invalid limit value: %s", string(trimmed))
	}

	text = strings.TrimSpace(text)
	if text == "" {
		f.value = nil
		return nil
	}

	parsed, err := strconv.ParseFloat(text, 64)
	if err != nil {
		return fmt.Errorf("invalid numeric limit value %q: %w", text, err)
	}
	// strconv.ParseFloat accepts "NaN", "Inf" and "Infinity" without error.
	// Such values are meaningless as a limit, so reject them. NaN and ±Inf
	// are the only float64 values for which x-x is not zero.
	if parsed-parsed != 0 {
		return fmt.Errorf("invalid numeric limit value %q: must be a finite number", text)
	}

	f.value = &parsed
	return nil
}

// ToServiceInput converts the field into the pointer form passed to the
// service layer: nil when the key was absent, the parsed number when one was
// supplied, and a pointer to 0 when the key was present but null or blank.
func (f optionalLimitField) ToServiceInput() *float64 {
	if !f.set {
		return nil
	}
	if f.value != nil {
		return f.value
	}
	zero := 0.0
	return &zero
}
|
||||
|
||||
// NewGroupHandler creates a new admin group handler
|
||||
func NewGroupHandler(adminService service.AdminService) *GroupHandler {
|
||||
func NewGroupHandler(adminService service.AdminService, dashboardService *service.DashboardService, groupCapacityService *service.GroupCapacityService) *GroupHandler {
|
||||
return &GroupHandler{
|
||||
adminService: adminService,
|
||||
adminService: adminService,
|
||||
dashboardService: dashboardService,
|
||||
groupCapacityService: groupCapacityService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateGroupRequest represents create group request
|
||||
type CreateGroupRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD optionalLimitField `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
@@ -62,16 +119,16 @@ type CreateGroupRequest struct {
|
||||
|
||||
// UpdateGroupRequest represents update group request
|
||||
type UpdateGroupRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD optionalLimitField `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
@@ -191,9 +248,9 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
@@ -244,9 +301,9 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
IsExclusive: req.IsExclusive,
|
||||
Status: req.Status,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
@@ -311,6 +368,33 @@ func (h *GroupHandler) GetStats(c *gin.Context) {
|
||||
_ = groupID // TODO: implement actual stats
|
||||
}
|
||||
|
||||
// GetUsageSummary returns today's and cumulative cost for all groups.
|
||||
// GET /api/v1/admin/groups/usage-summary?timezone=Asia/Shanghai
|
||||
func (h *GroupHandler) GetUsageSummary(c *gin.Context) {
|
||||
userTZ := c.Query("timezone")
|
||||
now := timezone.NowInUserLocation(userTZ)
|
||||
todayStart := timezone.StartOfDayInUserLocation(now, userTZ)
|
||||
|
||||
results, err := h.dashboardService.GetGroupUsageSummary(c.Request.Context(), todayStart)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get group usage summary")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, results)
|
||||
}
|
||||
|
||||
// GetCapacitySummary returns aggregated capacity (concurrency/sessions/RPM) for all active groups.
|
||||
// GET /api/v1/admin/groups/capacity-summary
|
||||
func (h *GroupHandler) GetCapacitySummary(c *gin.Context) {
|
||||
results, err := h.groupCapacityService.GetAllGroupCapacity(c.Request.Context())
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get group capacity summary")
|
||||
return
|
||||
}
|
||||
response.Success(c, results)
|
||||
}
|
||||
|
||||
// GetGroupAPIKeys handles getting API keys in a group
|
||||
// GET /api/v1/admin/groups/:id/api-keys
|
||||
func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
||||
@@ -335,6 +419,72 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
||||
response.Paginated(c, outKeys, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetGroupRateMultipliers handles getting rate multipliers for users in a group
|
||||
// GET /api/v1/admin/groups/:id/rate-multipliers
|
||||
func (h *GroupHandler) GetGroupRateMultipliers(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
entries, err := h.adminService.GetGroupRateMultipliers(c.Request.Context(), groupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if entries == nil {
|
||||
entries = []service.UserGroupRateEntry{}
|
||||
}
|
||||
response.Success(c, entries)
|
||||
}
|
||||
|
||||
// ClearGroupRateMultipliers handles clearing all rate multipliers for a group
|
||||
// DELETE /api/v1/admin/groups/:id/rate-multipliers
|
||||
func (h *GroupHandler) ClearGroupRateMultipliers(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.adminService.ClearGroupRateMultipliers(c.Request.Context(), groupID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Rate multipliers cleared successfully"})
|
||||
}
|
||||
|
||||
// BatchSetGroupRateMultipliersRequest represents batch set rate multipliers request
|
||||
type BatchSetGroupRateMultipliersRequest struct {
|
||||
Entries []service.GroupRateMultiplierInput `json:"entries" binding:"required"`
|
||||
}
|
||||
|
||||
// BatchSetGroupRateMultipliers handles batch setting rate multipliers for a group
|
||||
// PUT /api/v1/admin/groups/:id/rate-multipliers
|
||||
func (h *GroupHandler) BatchSetGroupRateMultipliers(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req BatchSetGroupRateMultipliersRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.adminService.BatchSetGroupRateMultipliers(c.Request.Context(), groupID, req.Entries); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Rate multipliers updated successfully"})
|
||||
}
|
||||
|
||||
// UpdateSortOrderRequest represents the request to update group sort orders
|
||||
type UpdateSortOrderRequest struct {
|
||||
Updates []struct {
|
||||
|
||||
@@ -289,6 +289,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||
Platform: platform,
|
||||
Type: "oauth",
|
||||
Credentials: credentials,
|
||||
Extra: nil,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
|
||||
@@ -23,6 +23,13 @@ var validOpsAlertMetricTypes = []string{
|
||||
"cpu_usage_percent",
|
||||
"memory_usage_percent",
|
||||
"concurrency_queue_depth",
|
||||
"group_available_accounts",
|
||||
"group_available_ratio",
|
||||
"group_rate_limit_ratio",
|
||||
"account_rate_limited_count",
|
||||
"account_error_count",
|
||||
"account_error_ratio",
|
||||
"overload_account_count",
|
||||
}
|
||||
|
||||
var validOpsAlertMetricTypeSet = func() map[string]struct{} {
|
||||
@@ -82,7 +89,10 @@ func isPercentOrRateMetric(metricType string) bool {
|
||||
"error_rate",
|
||||
"upstream_error_rate",
|
||||
"cpu_usage_percent",
|
||||
"memory_usage_percent":
|
||||
"memory_usage_percent",
|
||||
"group_available_ratio",
|
||||
"group_rate_limit_ratio",
|
||||
"account_error_ratio":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
|
||||
@@ -41,12 +41,15 @@ type GenerateRedeemCodesRequest struct {
|
||||
}
|
||||
|
||||
// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
|
||||
// Type 为 omitempty 而非 required 是为了向后兼容旧版调用方(不传 type 时默认 balance)。
|
||||
type CreateAndRedeemCodeRequest struct {
|
||||
Code string `json:"code" binding:"required,min=3,max=128"`
|
||||
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
|
||||
Value float64 `json:"value" binding:"required,gt=0"`
|
||||
UserID int64 `json:"user_id" binding:"required,gt=0"`
|
||||
Notes string `json:"notes"`
|
||||
Code string `json:"code" binding:"required,min=3,max=128"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // 不传时默认 balance(向后兼容)
|
||||
Value float64 `json:"value" binding:"required,gt=0"`
|
||||
UserID int64 `json:"user_id" binding:"required,gt=0"`
|
||||
GroupID *int64 `json:"group_id"` // subscription 类型必填
|
||||
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // subscription 类型必填,>0
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
// List handles listing all redeem codes with pagination
|
||||
@@ -136,6 +139,22 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
req.Code = strings.TrimSpace(req.Code)
|
||||
// 向后兼容:旧版调用方(如 Sub2ApiPay)不传 type 字段,默认当作 balance 充值处理。
|
||||
// 请勿删除此默认值逻辑,否则会导致旧版调用方 400 报错。
|
||||
if req.Type == "" {
|
||||
req.Type = "balance"
|
||||
}
|
||||
|
||||
if req.Type == "subscription" {
|
||||
if req.GroupID == nil {
|
||||
response.BadRequest(c, "group_id is required for subscription type")
|
||||
return
|
||||
}
|
||||
if req.ValidityDays <= 0 {
|
||||
response.BadRequest(c, "validity_days must be greater than 0 for subscription type")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
existing, err := h.redeemService.GetByCode(ctx, req.Code)
|
||||
@@ -147,11 +166,13 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
|
||||
}
|
||||
|
||||
createErr := h.redeemService.CreateCode(ctx, &service.RedeemCode{
|
||||
Code: req.Code,
|
||||
Type: req.Type,
|
||||
Value: req.Value,
|
||||
Status: service.StatusUnused,
|
||||
Notes: req.Notes,
|
||||
Code: req.Code,
|
||||
Type: req.Type,
|
||||
Value: req.Value,
|
||||
Status: service.StatusUnused,
|
||||
Notes: req.Notes,
|
||||
GroupID: req.GroupID,
|
||||
ValidityDays: req.ValidityDays,
|
||||
})
|
||||
if createErr != nil {
|
||||
// Unique code race: if code now exists, use idempotent semantics by used_by.
|
||||
|
||||
135
backend/internal/handler/admin/redeem_handler_test.go
Normal file
135
backend/internal/handler/admin/redeem_handler_test.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// newCreateAndRedeemHandler creates a RedeemHandler with a non-nil (but minimal)
|
||||
// RedeemService so that CreateAndRedeem's nil guard passes and we can test the
|
||||
// parameter-validation layer that runs before any service call.
|
||||
func newCreateAndRedeemHandler() *RedeemHandler {
|
||||
return &RedeemHandler{
|
||||
adminService: newStubAdminService(),
|
||||
redeemService: &service.RedeemService{}, // non-nil to pass nil guard
|
||||
}
|
||||
}
|
||||
|
||||
// postCreateAndRedeemValidation calls CreateAndRedeem and returns the response
|
||||
// status code. For cases that pass validation and proceed into the service layer,
|
||||
// a panic may occur (because RedeemService internals are nil); this is expected
|
||||
// and treated as "validation passed" (returns 0 to indicate panic).
|
||||
func postCreateAndRedeemValidation(t *testing.T, handler *RedeemHandler, body any) (code int) {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
jsonBytes, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
c.Request, _ = http.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/create-and-redeem", bytes.NewReader(jsonBytes))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Panic means we passed validation and entered service layer (expected for minimal stub).
|
||||
code = 0
|
||||
}
|
||||
}()
|
||||
handler.CreateAndRedeem(c)
|
||||
return w.Code
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_TypeDefaultsToBalance(t *testing.T) {
|
||||
// 不传 type 字段时应默认 balance,不触发 subscription 校验。
|
||||
// 验证通过后进入 service 层会 panic(返回 0),说明默认值生效。
|
||||
h := newCreateAndRedeemHandler()
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-balance-default",
|
||||
"value": 10.0,
|
||||
"user_id": 1,
|
||||
})
|
||||
|
||||
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||
"omitting type should default to balance and pass validation")
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_SubscriptionRequiresGroupID(t *testing.T) {
|
||||
h := newCreateAndRedeemHandler()
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-sub-no-group",
|
||||
"type": "subscription",
|
||||
"value": 29.9,
|
||||
"user_id": 1,
|
||||
"validity_days": 30,
|
||||
// group_id 缺失
|
||||
})
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, code)
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_SubscriptionRequiresPositiveValidityDays(t *testing.T) {
|
||||
groupID := int64(5)
|
||||
h := newCreateAndRedeemHandler()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
validityDays int
|
||||
}{
|
||||
{"zero", 0},
|
||||
{"negative", -1},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-sub-bad-days-" + tc.name,
|
||||
"type": "subscription",
|
||||
"value": 29.9,
|
||||
"user_id": 1,
|
||||
"group_id": groupID,
|
||||
"validity_days": tc.validityDays,
|
||||
})
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_SubscriptionValidParamsPassValidation(t *testing.T) {
|
||||
groupID := int64(5)
|
||||
h := newCreateAndRedeemHandler()
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-sub-valid",
|
||||
"type": "subscription",
|
||||
"value": 29.9,
|
||||
"user_id": 1,
|
||||
"group_id": groupID,
|
||||
"validity_days": 31,
|
||||
})
|
||||
|
||||
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||
"valid subscription params should pass validation")
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_BalanceIgnoresSubscriptionFields(t *testing.T) {
|
||||
h := newCreateAndRedeemHandler()
|
||||
// balance 类型不传 group_id 和 validity_days,不应报 400
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-balance-no-extras",
|
||||
"type": "balance",
|
||||
"value": 50.0,
|
||||
"user_id": 1,
|
||||
})
|
||||
|
||||
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||
"balance type should not require group_id or validity_days")
|
||||
}
|
||||
@@ -80,6 +80,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
FrontendURL: settings.FrontendURL,
|
||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||
TotpEnabled: settings.TotpEnabled,
|
||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||
@@ -125,6 +126,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
|
||||
MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -136,6 +138,7 @@ type UpdateSettingsRequest struct {
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
FrontendURL string `json:"frontend_url"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
|
||||
@@ -199,6 +202,9 @@ type UpdateSettingsRequest struct {
|
||||
|
||||
// 分组隔离
|
||||
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
||||
|
||||
// Backend Mode
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
}
|
||||
|
||||
// UpdateSettings 更新系统设置
|
||||
@@ -322,6 +328,15 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Frontend URL 验证
|
||||
req.FrontendURL = strings.TrimSpace(req.FrontendURL)
|
||||
if req.FrontendURL != "" {
|
||||
if err := config.ValidateAbsoluteHTTPURL(req.FrontendURL); err != nil {
|
||||
response.BadRequest(c, "Frontend URL must be an absolute http(s) URL")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 自定义菜单项验证
|
||||
const (
|
||||
maxCustomMenuItems = 20
|
||||
@@ -433,6 +448,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||
PasswordResetEnabled: req.PasswordResetEnabled,
|
||||
FrontendURL: req.FrontendURL,
|
||||
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
||||
TotpEnabled: req.TotpEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
@@ -473,6 +489,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: req.BackendModeEnabled,
|
||||
OpsMonitoringEnabled: func() bool {
|
||||
if req.OpsMonitoringEnabled != nil {
|
||||
return *req.OpsMonitoringEnabled
|
||||
@@ -526,6 +543,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
|
||||
FrontendURL: updatedSettings.FrontendURL,
|
||||
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
|
||||
TotpEnabled: updatedSettings.TotpEnabled,
|
||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||
@@ -571,6 +589,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
|
||||
MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: updatedSettings.BackendModeEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -608,6 +627,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.PasswordResetEnabled != after.PasswordResetEnabled {
|
||||
changed = append(changed, "password_reset_enabled")
|
||||
}
|
||||
if before.FrontendURL != after.FrontendURL {
|
||||
changed = append(changed, "frontend_url")
|
||||
}
|
||||
if before.TotpEnabled != after.TotpEnabled {
|
||||
changed = append(changed, "totp_enabled")
|
||||
}
|
||||
@@ -725,6 +747,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling {
|
||||
changed = append(changed, "allow_ungrouped_key_scheduling")
|
||||
}
|
||||
if before.BackendModeEnabled != after.BackendModeEnabled {
|
||||
changed = append(changed, "backend_mode_enabled")
|
||||
}
|
||||
if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled {
|
||||
changed = append(changed, "purchase_subscription_enabled")
|
||||
}
|
||||
@@ -952,6 +977,58 @@ func (h *SettingHandler) DeleteAdminAPIKey(c *gin.Context) {
|
||||
response.Success(c, gin.H{"message": "Admin API key deleted"})
|
||||
}
|
||||
|
||||
// GetOverloadCooldownSettings 获取529过载冷却配置
|
||||
// GET /api/v1/admin/settings/overload-cooldown
|
||||
func (h *SettingHandler) GetOverloadCooldownSettings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetOverloadCooldownSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.OverloadCooldownSettings{
|
||||
Enabled: settings.Enabled,
|
||||
CooldownMinutes: settings.CooldownMinutes,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateOverloadCooldownSettingsRequest 更新529过载冷却配置请求
|
||||
type UpdateOverloadCooldownSettingsRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
CooldownMinutes int `json:"cooldown_minutes"`
|
||||
}
|
||||
|
||||
// UpdateOverloadCooldownSettings 更新529过载冷却配置
|
||||
// PUT /api/v1/admin/settings/overload-cooldown
|
||||
func (h *SettingHandler) UpdateOverloadCooldownSettings(c *gin.Context) {
|
||||
var req UpdateOverloadCooldownSettingsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
settings := &service.OverloadCooldownSettings{
|
||||
Enabled: req.Enabled,
|
||||
CooldownMinutes: req.CooldownMinutes,
|
||||
}
|
||||
|
||||
if err := h.settingService.SetOverloadCooldownSettings(c.Request.Context(), settings); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
updatedSettings, err := h.settingService.GetOverloadCooldownSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.OverloadCooldownSettings{
|
||||
Enabled: updatedSettings.Enabled,
|
||||
CooldownMinutes: updatedSettings.CooldownMinutes,
|
||||
})
|
||||
}
|
||||
|
||||
// GetStreamTimeoutSettings 获取流超时处理配置
|
||||
// GET /api/v1/admin/settings/stream-timeout
|
||||
func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
type snapshotCacheEntry struct {
|
||||
@@ -19,6 +21,12 @@ type snapshotCache struct {
|
||||
mu sync.RWMutex
|
||||
ttl time.Duration
|
||||
items map[string]snapshotCacheEntry
|
||||
sf singleflight.Group
|
||||
}
|
||||
|
||||
type snapshotCacheLoadResult struct {
|
||||
Entry snapshotCacheEntry
|
||||
Hit bool
|
||||
}
|
||||
|
||||
func newSnapshotCache(ttl time.Duration) *snapshotCache {
|
||||
@@ -70,6 +78,41 @@ func (c *snapshotCache) Set(key string, payload any) snapshotCacheEntry {
|
||||
return entry
|
||||
}
|
||||
|
||||
func (c *snapshotCache) GetOrLoad(key string, load func() (any, error)) (snapshotCacheEntry, bool, error) {
|
||||
if load == nil {
|
||||
return snapshotCacheEntry{}, false, nil
|
||||
}
|
||||
if entry, ok := c.Get(key); ok {
|
||||
return entry, true, nil
|
||||
}
|
||||
if c == nil || key == "" {
|
||||
payload, err := load()
|
||||
if err != nil {
|
||||
return snapshotCacheEntry{}, false, err
|
||||
}
|
||||
return c.Set(key, payload), false, nil
|
||||
}
|
||||
|
||||
value, err, _ := c.sf.Do(key, func() (any, error) {
|
||||
if entry, ok := c.Get(key); ok {
|
||||
return snapshotCacheLoadResult{Entry: entry, Hit: true}, nil
|
||||
}
|
||||
payload, err := load()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return snapshotCacheLoadResult{Entry: c.Set(key, payload), Hit: false}, nil
|
||||
})
|
||||
if err != nil {
|
||||
return snapshotCacheEntry{}, false, err
|
||||
}
|
||||
result, ok := value.(snapshotCacheLoadResult)
|
||||
if !ok {
|
||||
return snapshotCacheEntry{}, false, nil
|
||||
}
|
||||
return result.Entry, result.Hit, nil
|
||||
}
|
||||
|
||||
func buildETagFromAny(payload any) string {
|
||||
raw, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -95,6 +97,61 @@ func TestBuildETagFromAny_UnmarshalablePayload(t *testing.T) {
|
||||
require.Empty(t, etag)
|
||||
}
|
||||
|
||||
func TestSnapshotCache_GetOrLoad_MissThenHit(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
var loads atomic.Int32
|
||||
|
||||
entry, hit, err := c.GetOrLoad("key1", func() (any, error) {
|
||||
loads.Add(1)
|
||||
return map[string]string{"hello": "world"}, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, hit)
|
||||
require.NotEmpty(t, entry.ETag)
|
||||
require.Equal(t, int32(1), loads.Load())
|
||||
|
||||
entry2, hit, err := c.GetOrLoad("key1", func() (any, error) {
|
||||
loads.Add(1)
|
||||
return map[string]string{"unexpected": "value"}, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, hit)
|
||||
require.Equal(t, entry.ETag, entry2.ETag)
|
||||
require.Equal(t, int32(1), loads.Load())
|
||||
}
|
||||
|
||||
func TestSnapshotCache_GetOrLoad_ConcurrentSingleflight(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
var loads atomic.Int32
|
||||
start := make(chan struct{})
|
||||
const callers = 8
|
||||
errCh := make(chan error, callers)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(callers)
|
||||
for range callers {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
_, _, err := c.GetOrLoad("shared", func() (any, error) {
|
||||
loads.Add(1)
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
return "value", nil
|
||||
})
|
||||
errCh <- err
|
||||
}()
|
||||
}
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errCh)
|
||||
|
||||
for err := range errCh {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
require.Equal(t, int32(1), loads.Load())
|
||||
}
|
||||
|
||||
func TestParseBoolQueryWithDefault(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -77,12 +77,13 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
status := c.Query("status")
|
||||
platform := c.Query("platform")
|
||||
|
||||
// Parse sorting parameters
|
||||
sortBy := c.DefaultQuery("sort_by", "created_at")
|
||||
sortOrder := c.DefaultQuery("sort_order", "desc")
|
||||
|
||||
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder)
|
||||
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, platform, sortBy, sortOrder)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -216,6 +217,38 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// ResetSubscriptionQuotaRequest represents the reset quota request
|
||||
type ResetSubscriptionQuotaRequest struct {
|
||||
Daily bool `json:"daily"`
|
||||
Weekly bool `json:"weekly"`
|
||||
Monthly bool `json:"monthly"`
|
||||
}
|
||||
|
||||
// ResetQuota resets daily, weekly, and/or monthly usage for a subscription.
|
||||
// POST /api/v1/admin/subscriptions/:id/reset-quota
|
||||
func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
|
||||
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid subscription ID")
|
||||
return
|
||||
}
|
||||
var req ResetSubscriptionQuotaRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if !req.Daily && !req.Weekly && !req.Monthly {
|
||||
response.BadRequest(c, "At least one of 'daily', 'weekly', or 'monthly' must be true")
|
||||
return
|
||||
}
|
||||
sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly, req.Monthly)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, dto.UserSubscriptionFromServiceAdmin(sub))
|
||||
}
|
||||
|
||||
// Revoke handles revoking a subscription
|
||||
// DELETE /api/v1/admin/subscriptions/:id
|
||||
func (h *SubscriptionHandler) Revoke(c *gin.Context) {
|
||||
|
||||
@@ -159,8 +159,8 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
// Set end time to end of day
|
||||
t = t.Add(24*time.Hour - time.Nanosecond)
|
||||
// Use half-open range [start, end), move to next calendar day start (DST-safe).
|
||||
t = t.AddDate(0, 0, 1)
|
||||
endTime = &t
|
||||
}
|
||||
|
||||
@@ -285,7 +285,8 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
|
||||
// 与 SQL 条件 created_at < end 对齐,使用次日 00:00 作为上边界(DST-safe)。
|
||||
endTime = endTime.AddDate(0, 0, 1)
|
||||
} else {
|
||||
period := c.DefaultQuery("period", "today")
|
||||
switch period {
|
||||
|
||||
@@ -194,6 +194,12 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Backend mode: only admin can login
|
||||
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
|
||||
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||
return
|
||||
}
|
||||
|
||||
h.respondWithTokenPair(c, user)
|
||||
}
|
||||
|
||||
@@ -250,16 +256,22 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the login session
|
||||
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
||||
|
||||
// Get the user
|
||||
// Get the user (before session deletion so we can check backend mode)
|
||||
user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Backend mode: only admin can login (check BEFORE deleting session)
|
||||
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
|
||||
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the login session (only after all checks pass)
|
||||
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
||||
|
||||
h.respondWithTokenPair(c, user)
|
||||
}
|
||||
|
||||
@@ -447,9 +459,9 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
frontendBaseURL := strings.TrimSpace(h.cfg.Server.FrontendURL)
|
||||
frontendBaseURL := strings.TrimSpace(h.settingSvc.GetFrontendURL(c.Request.Context()))
|
||||
if frontendBaseURL == "" {
|
||||
slog.Error("server.frontend_url not configured; cannot build password reset link")
|
||||
slog.Error("frontend_url not configured in settings or config; cannot build password reset link")
|
||||
response.InternalError(c, "Password reset is not configured")
|
||||
return
|
||||
}
|
||||
@@ -522,16 +534,22 @@ func (h *AuthHandler) RefreshToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
|
||||
result, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Backend mode: block non-admin token refresh
|
||||
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && result.UserRole != "admin" {
|
||||
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, RefreshTokenResponse{
|
||||
AccessToken: tokenPair.AccessToken,
|
||||
RefreshToken: tokenPair.RefreshToken,
|
||||
ExpiresIn: tokenPair.ExpiresIn,
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
ExpiresIn: result.ExpiresIn,
|
||||
TokenType: "Bearer",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -135,14 +135,16 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
||||
return nil
|
||||
}
|
||||
out := &AdminGroup{
|
||||
Group: groupFromServiceBase(g),
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
MCPXMLInject: g.MCPXMLInject,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
AccountCount: g.AccountCount,
|
||||
SortOrder: g.SortOrder,
|
||||
Group: groupFromServiceBase(g),
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
MCPXMLInject: g.MCPXMLInject,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
AccountCount: g.AccountCount,
|
||||
ActiveAccountCount: g.ActiveAccountCount,
|
||||
RateLimitedAccountCount: g.RateLimitedAccountCount,
|
||||
SortOrder: g.SortOrder,
|
||||
}
|
||||
if len(g.AccountGroups) > 0 {
|
||||
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
||||
@@ -264,8 +266,8 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
}
|
||||
}
|
||||
|
||||
// 提取 API Key 账号配额限制(仅 apikey 类型有效)
|
||||
if a.Type == service.AccountTypeAPIKey {
|
||||
// 提取账号配额限制(apikey / bedrock 类型有效)
|
||||
if a.IsAPIKeyOrBedrock() {
|
||||
if limit := a.GetQuotaLimit(); limit > 0 {
|
||||
out.QuotaLimit = &limit
|
||||
used := a.GetQuotaUsed()
|
||||
@@ -281,6 +283,31 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
used := a.GetQuotaWeeklyUsed()
|
||||
out.QuotaWeeklyUsed = &used
|
||||
}
|
||||
// 固定时间重置配置
|
||||
if mode := a.GetQuotaDailyResetMode(); mode == "fixed" {
|
||||
out.QuotaDailyResetMode = &mode
|
||||
hour := a.GetQuotaDailyResetHour()
|
||||
out.QuotaDailyResetHour = &hour
|
||||
}
|
||||
if mode := a.GetQuotaWeeklyResetMode(); mode == "fixed" {
|
||||
out.QuotaWeeklyResetMode = &mode
|
||||
day := a.GetQuotaWeeklyResetDay()
|
||||
out.QuotaWeeklyResetDay = &day
|
||||
hour := a.GetQuotaWeeklyResetHour()
|
||||
out.QuotaWeeklyResetHour = &hour
|
||||
}
|
||||
if a.GetQuotaDailyResetMode() == "fixed" || a.GetQuotaWeeklyResetMode() == "fixed" {
|
||||
tz := a.GetQuotaResetTimezone()
|
||||
out.QuotaResetTimezone = &tz
|
||||
}
|
||||
if a.Extra != nil {
|
||||
if v, ok := a.Extra["quota_daily_reset_at"].(string); ok && v != "" {
|
||||
out.QuotaDailyResetAt = &v
|
||||
}
|
||||
if v, ok := a.Extra["quota_weekly_reset_at"].(string); ok && v != "" {
|
||||
out.QuotaWeeklyResetAt = &v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
@@ -496,8 +523,11 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
AccountID: l.AccountID,
|
||||
RequestID: l.RequestID,
|
||||
Model: l.Model,
|
||||
UpstreamModel: l.UpstreamModel,
|
||||
ServiceTier: l.ServiceTier,
|
||||
ReasoningEffort: l.ReasoningEffort,
|
||||
InboundEndpoint: l.InboundEndpoint,
|
||||
UpstreamEndpoint: l.UpstreamEndpoint,
|
||||
GroupID: l.GroupID,
|
||||
SubscriptionID: l.SubscriptionID,
|
||||
InputTokens: l.InputTokens,
|
||||
|
||||
@@ -76,10 +76,14 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serviceTier := "priority"
|
||||
inboundEndpoint := "/v1/chat/completions"
|
||||
upstreamEndpoint := "/v1/responses"
|
||||
log := &service.UsageLog{
|
||||
RequestID: "req_3",
|
||||
Model: "gpt-5.4",
|
||||
ServiceTier: &serviceTier,
|
||||
InboundEndpoint: &inboundEndpoint,
|
||||
UpstreamEndpoint: &upstreamEndpoint,
|
||||
AccountRateMultiplier: f64Ptr(1.5),
|
||||
}
|
||||
|
||||
@@ -88,8 +92,16 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
|
||||
|
||||
require.NotNil(t, userDTO.ServiceTier)
|
||||
require.Equal(t, serviceTier, *userDTO.ServiceTier)
|
||||
require.NotNil(t, userDTO.InboundEndpoint)
|
||||
require.Equal(t, inboundEndpoint, *userDTO.InboundEndpoint)
|
||||
require.NotNil(t, userDTO.UpstreamEndpoint)
|
||||
require.Equal(t, upstreamEndpoint, *userDTO.UpstreamEndpoint)
|
||||
require.NotNil(t, adminDTO.ServiceTier)
|
||||
require.Equal(t, serviceTier, *adminDTO.ServiceTier)
|
||||
require.NotNil(t, adminDTO.InboundEndpoint)
|
||||
require.Equal(t, inboundEndpoint, *adminDTO.InboundEndpoint)
|
||||
require.NotNil(t, adminDTO.UpstreamEndpoint)
|
||||
require.Equal(t, upstreamEndpoint, *adminDTO.UpstreamEndpoint)
|
||||
require.NotNil(t, adminDTO.AccountRateMultiplier)
|
||||
require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ type SystemSettings struct {
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
FrontendURL string `json:"frontend_url"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
||||
@@ -81,6 +82,9 @@ type SystemSettings struct {
|
||||
|
||||
// 分组隔离
|
||||
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
||||
|
||||
// Backend Mode
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
}
|
||||
|
||||
type DefaultSubscriptionSetting struct {
|
||||
@@ -111,6 +115,7 @@ type PublicSettings struct {
|
||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
@@ -152,6 +157,12 @@ type ListSoraS3ProfilesResponse struct {
|
||||
Items []SoraS3Profile `json:"items"`
|
||||
}
|
||||
|
||||
// OverloadCooldownSettings 529过载冷却配置 DTO
|
||||
type OverloadCooldownSettings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
CooldownMinutes int `json:"cooldown_minutes"`
|
||||
}
|
||||
|
||||
// StreamTimeoutSettings 流超时处理配置 DTO
|
||||
type StreamTimeoutSettings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
@@ -122,9 +122,11 @@ type AdminGroup struct {
|
||||
DefaultMappedModel string `json:"default_mapped_model"`
|
||||
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
AccountCount int64 `json:"account_count,omitempty"`
|
||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
AccountCount int64 `json:"account_count,omitempty"`
|
||||
ActiveAccountCount int64 `json:"active_account_count,omitempty"`
|
||||
RateLimitedAccountCount int64 `json:"rate_limited_account_count,omitempty"`
|
||||
|
||||
// 分组排序
|
||||
SortOrder int `json:"sort_order"`
|
||||
@@ -203,6 +205,16 @@ type Account struct {
|
||||
QuotaWeeklyLimit *float64 `json:"quota_weekly_limit,omitempty"`
|
||||
QuotaWeeklyUsed *float64 `json:"quota_weekly_used,omitempty"`
|
||||
|
||||
// 配额固定时间重置配置
|
||||
QuotaDailyResetMode *string `json:"quota_daily_reset_mode,omitempty"`
|
||||
QuotaDailyResetHour *int `json:"quota_daily_reset_hour,omitempty"`
|
||||
QuotaWeeklyResetMode *string `json:"quota_weekly_reset_mode,omitempty"`
|
||||
QuotaWeeklyResetDay *int `json:"quota_weekly_reset_day,omitempty"`
|
||||
QuotaWeeklyResetHour *int `json:"quota_weekly_reset_hour,omitempty"`
|
||||
QuotaResetTimezone *string `json:"quota_reset_timezone,omitempty"`
|
||||
QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"`
|
||||
QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"`
|
||||
|
||||
Proxy *Proxy `json:"proxy,omitempty"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
|
||||
@@ -322,11 +334,18 @@ type UsageLog struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
RequestID string `json:"request_id"`
|
||||
Model string `json:"model"`
|
||||
// UpstreamModel is the actual model sent to the upstream provider after mapping.
|
||||
// Omitted when no mapping was applied (requested model was used as-is).
|
||||
UpstreamModel *string `json:"upstream_model,omitempty"`
|
||||
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
|
||||
ServiceTier *string `json:"service_tier,omitempty"`
|
||||
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API).
|
||||
// nil means not provided / not applicable.
|
||||
// ReasoningEffort is the request's reasoning effort level.
|
||||
// OpenAI: "low"/"medium"/"high"/"xhigh"; Claude: "low"/"medium"/"high"/"max".
|
||||
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
||||
// InboundEndpoint is the client-facing API endpoint path, e.g. /v1/chat/completions.
|
||||
InboundEndpoint *string `json:"inbound_endpoint,omitempty"`
|
||||
// UpstreamEndpoint is the normalized upstream endpoint path, e.g. /v1/responses.
|
||||
UpstreamEndpoint *string `json:"upstream_endpoint,omitempty"`
|
||||
|
||||
GroupID *int64 `json:"group_id"`
|
||||
SubscriptionID *int64 `json:"subscription_id"`
|
||||
|
||||
174
backend/internal/handler/endpoint.go
Normal file
174
backend/internal/handler/endpoint.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// Canonical inbound / upstream endpoint paths.
|
||||
// All normalization and derivation reference this single set
|
||||
// of constants — add new paths HERE when a new API surface
|
||||
// is introduced.
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
const (
|
||||
EndpointMessages = "/v1/messages"
|
||||
EndpointChatCompletions = "/v1/chat/completions"
|
||||
EndpointResponses = "/v1/responses"
|
||||
EndpointGeminiModels = "/v1beta/models"
|
||||
)
|
||||
|
||||
// gin.Context keys used by the middleware and helpers below.
|
||||
const (
|
||||
ctxKeyInboundEndpoint = "_gateway_inbound_endpoint"
|
||||
)
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// Normalization functions
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
// NormalizeInboundEndpoint maps a raw request path (which may carry
|
||||
// prefixes like /antigravity, /openai, /sora) to its canonical form.
|
||||
//
|
||||
// "/antigravity/v1/messages" → "/v1/messages"
|
||||
// "/v1/chat/completions" → "/v1/chat/completions"
|
||||
// "/openai/v1/responses/foo" → "/v1/responses"
|
||||
// "/v1beta/models/gemini:gen" → "/v1beta/models"
|
||||
func NormalizeInboundEndpoint(path string) string {
|
||||
path = strings.TrimSpace(path)
|
||||
switch {
|
||||
case strings.Contains(path, EndpointChatCompletions):
|
||||
return EndpointChatCompletions
|
||||
case strings.Contains(path, EndpointMessages):
|
||||
return EndpointMessages
|
||||
case strings.Contains(path, EndpointResponses):
|
||||
return EndpointResponses
|
||||
case strings.Contains(path, EndpointGeminiModels):
|
||||
return EndpointGeminiModels
|
||||
default:
|
||||
return path
|
||||
}
|
||||
}
|
||||
|
||||
// DeriveUpstreamEndpoint determines the upstream endpoint from the
|
||||
// account platform and the normalized inbound endpoint.
|
||||
//
|
||||
// Platform-specific rules:
|
||||
// - OpenAI always forwards to /v1/responses (with optional subpath
|
||||
// such as /v1/responses/compact preserved from the raw URL).
|
||||
// - Anthropic → /v1/messages
|
||||
// - Gemini → /v1beta/models
|
||||
// - Sora → /v1/chat/completions
|
||||
// - Antigravity routes may target either Claude or Gemini, so the
|
||||
// inbound endpoint is used to distinguish.
|
||||
func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string {
|
||||
inbound = strings.TrimSpace(inbound)
|
||||
|
||||
switch platform {
|
||||
case service.PlatformOpenAI:
|
||||
// OpenAI forwards everything to the Responses API.
|
||||
// Preserve subresource suffix (e.g. /v1/responses/compact).
|
||||
if suffix := responsesSubpathSuffix(rawRequestPath); suffix != "" {
|
||||
return EndpointResponses + suffix
|
||||
}
|
||||
return EndpointResponses
|
||||
|
||||
case service.PlatformAnthropic:
|
||||
return EndpointMessages
|
||||
|
||||
case service.PlatformGemini:
|
||||
return EndpointGeminiModels
|
||||
|
||||
case service.PlatformSora:
|
||||
return EndpointChatCompletions
|
||||
|
||||
case service.PlatformAntigravity:
|
||||
// Antigravity accounts serve both Claude and Gemini.
|
||||
if inbound == EndpointGeminiModels {
|
||||
return EndpointGeminiModels
|
||||
}
|
||||
return EndpointMessages
|
||||
}
|
||||
|
||||
// Unknown platform — fall back to inbound.
|
||||
return inbound
|
||||
}
|
||||
|
||||
// responsesSubpathSuffix returns whatever follows the last "/responses"
// segment of a raw request path, for example
// "/openai/v1/responses/compact" → "/compact".
//
// Surrounding whitespace and trailing slashes are stripped first. The result
// is "" when the path has no "/responses" segment, when nothing follows it,
// or when what follows does not begin with "/" (e.g. "/responsesX").
func responsesSubpathSuffix(rawPath string) string {
	const marker = "/responses"

	cleaned := strings.TrimRight(strings.TrimSpace(rawPath), "/")
	pos := strings.LastIndex(cleaned, marker)
	if pos < 0 {
		return ""
	}

	// Trailing slashes are already gone, so the remainder is either empty,
	// a real "/subpath", or a non-path continuation of the segment name.
	rest := cleaned[pos+len(marker):]
	if strings.HasPrefix(rest, "/") {
		return rest
	}
	return ""
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// Middleware
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
// InboundEndpointMiddleware normalizes the request path and stores the
|
||||
// canonical inbound endpoint in gin.Context so that every handler in
|
||||
// the chain can read it via GetInboundEndpoint.
|
||||
//
|
||||
// Apply this middleware to all gateway route groups.
|
||||
func InboundEndpointMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
path := c.FullPath()
|
||||
if path == "" && c.Request != nil && c.Request.URL != nil {
|
||||
path = c.Request.URL.Path
|
||||
}
|
||||
c.Set(ctxKeyInboundEndpoint, NormalizeInboundEndpoint(path))
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// Context helpers — used by handlers before building
|
||||
// RecordUsageInput / RecordUsageLongContextInput.
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
// GetInboundEndpoint returns the canonical inbound endpoint stored by
|
||||
// InboundEndpointMiddleware. If the middleware did not run (e.g. in
|
||||
// tests), it falls back to normalizing c.FullPath() on the fly.
|
||||
func GetInboundEndpoint(c *gin.Context) string {
|
||||
if v, ok := c.Get(ctxKeyInboundEndpoint); ok {
|
||||
if s, ok := v.(string); ok && s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
// Fallback: normalize on the fly.
|
||||
path := ""
|
||||
if c != nil {
|
||||
path = c.FullPath()
|
||||
if path == "" && c.Request != nil && c.Request.URL != nil {
|
||||
path = c.Request.URL.Path
|
||||
}
|
||||
}
|
||||
return NormalizeInboundEndpoint(path)
|
||||
}
|
||||
|
||||
// GetUpstreamEndpoint derives the upstream endpoint from the context
|
||||
// and the account platform. Handlers call this after scheduling an
|
||||
// account, passing account.Platform.
|
||||
func GetUpstreamEndpoint(c *gin.Context, platform string) string {
|
||||
inbound := GetInboundEndpoint(c)
|
||||
rawPath := ""
|
||||
if c != nil && c.Request != nil && c.Request.URL != nil {
|
||||
rawPath = c.Request.URL.Path
|
||||
}
|
||||
return DeriveUpstreamEndpoint(inbound, rawPath, platform)
|
||||
}
|
||||
159
backend/internal/handler/endpoint_test.go
Normal file
159
backend/internal/handler/endpoint_test.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// init puts gin into test mode when this test file is loaded.
func init() {
	gin.SetMode(gin.TestMode)
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// NormalizeInboundEndpoint
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
func TestNormalizeInboundEndpoint(t *testing.T) {
|
||||
tests := []struct {
|
||||
path string
|
||||
want string
|
||||
}{
|
||||
// Direct canonical paths.
|
||||
{"/v1/messages", EndpointMessages},
|
||||
{"/v1/chat/completions", EndpointChatCompletions},
|
||||
{"/v1/responses", EndpointResponses},
|
||||
{"/v1beta/models", EndpointGeminiModels},
|
||||
|
||||
// Prefixed paths (antigravity, openai, sora).
|
||||
{"/antigravity/v1/messages", EndpointMessages},
|
||||
{"/openai/v1/responses", EndpointResponses},
|
||||
{"/openai/v1/responses/compact", EndpointResponses},
|
||||
{"/sora/v1/chat/completions", EndpointChatCompletions},
|
||||
{"/antigravity/v1beta/models/gemini:generateContent", EndpointGeminiModels},
|
||||
|
||||
// Gin route patterns with wildcards.
|
||||
{"/v1beta/models/*modelAction", EndpointGeminiModels},
|
||||
{"/v1/responses/*subpath", EndpointResponses},
|
||||
|
||||
// Unknown path is returned as-is.
|
||||
{"/v1/embeddings", "/v1/embeddings"},
|
||||
{"", ""},
|
||||
{" /v1/messages ", EndpointMessages},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.path, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, NormalizeInboundEndpoint(tt.path))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// DeriveUpstreamEndpoint
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
func TestDeriveUpstreamEndpoint(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inbound string
|
||||
rawPath string
|
||||
platform string
|
||||
want string
|
||||
}{
|
||||
// Anthropic.
|
||||
{"anthropic messages", EndpointMessages, "/v1/messages", service.PlatformAnthropic, EndpointMessages},
|
||||
|
||||
// Gemini.
|
||||
{"gemini models", EndpointGeminiModels, "/v1beta/models/gemini:gen", service.PlatformGemini, EndpointGeminiModels},
|
||||
|
||||
// Sora.
|
||||
{"sora completions", EndpointChatCompletions, "/sora/v1/chat/completions", service.PlatformSora, EndpointChatCompletions},
|
||||
|
||||
// OpenAI — always /v1/responses.
|
||||
{"openai responses root", EndpointResponses, "/v1/responses", service.PlatformOpenAI, EndpointResponses},
|
||||
{"openai responses compact", EndpointResponses, "/openai/v1/responses/compact", service.PlatformOpenAI, "/v1/responses/compact"},
|
||||
{"openai responses nested", EndpointResponses, "/openai/v1/responses/compact/detail", service.PlatformOpenAI, "/v1/responses/compact/detail"},
|
||||
{"openai from messages", EndpointMessages, "/v1/messages", service.PlatformOpenAI, EndpointResponses},
|
||||
{"openai from completions", EndpointChatCompletions, "/v1/chat/completions", service.PlatformOpenAI, EndpointResponses},
|
||||
|
||||
// Antigravity — uses inbound to pick Claude vs Gemini upstream.
|
||||
{"antigravity claude", EndpointMessages, "/antigravity/v1/messages", service.PlatformAntigravity, EndpointMessages},
|
||||
{"antigravity gemini", EndpointGeminiModels, "/antigravity/v1beta/models", service.PlatformAntigravity, EndpointGeminiModels},
|
||||
|
||||
// Unknown platform — passthrough.
|
||||
{"unknown platform", "/v1/embeddings", "/v1/embeddings", "unknown", "/v1/embeddings"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, DeriveUpstreamEndpoint(tt.inbound, tt.rawPath, tt.platform))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// responsesSubpathSuffix
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
func TestResponsesSubpathSuffix(t *testing.T) {
|
||||
tests := []struct {
|
||||
raw string
|
||||
want string
|
||||
}{
|
||||
{"/v1/responses", ""},
|
||||
{"/v1/responses/", ""},
|
||||
{"/v1/responses/compact", "/compact"},
|
||||
{"/openai/v1/responses/compact/detail", "/compact/detail"},
|
||||
{"/v1/messages", ""},
|
||||
{"", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.raw, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, responsesSubpathSuffix(tt.raw))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// InboundEndpointMiddleware + context helpers
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
func TestInboundEndpointMiddleware(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(InboundEndpointMiddleware())
|
||||
|
||||
var captured string
|
||||
router.POST("/v1/messages", func(c *gin.Context) {
|
||||
captured = GetInboundEndpoint(c)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, EndpointMessages, captured)
|
||||
}
|
||||
|
||||
func TestGetInboundEndpoint_FallbackWithoutMiddleware(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/antigravity/v1/messages", nil)
|
||||
|
||||
// Middleware did not run — fallback to normalizing c.Request.URL.Path.
|
||||
got := GetInboundEndpoint(c)
|
||||
require.Equal(t, EndpointMessages, got)
|
||||
}
|
||||
|
||||
func TestGetUpstreamEndpoint_FullFlow(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses/compact", nil)
|
||||
|
||||
// Simulate middleware.
|
||||
c.Set(ctxKeyInboundEndpoint, NormalizeInboundEndpoint(c.Request.URL.Path))
|
||||
|
||||
got := GetUpstreamEndpoint(c, service.PlatformOpenAI)
|
||||
require.Equal(t, "/v1/responses/compact", got)
|
||||
}
|
||||
@@ -391,6 +391,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if fs.SwitchCount > 0 {
|
||||
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||
}
|
||||
// 记录 Forward 前已写入字节数,Forward 后若增加则说明 SSE 内容已发,禁止 failover
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
|
||||
} else {
|
||||
@@ -402,6 +404,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
// 流式内容已写入客户端,无法撤销,禁止 failover 以防止流拼接腐化
|
||||
if c.Writer.Size() != writerSizeBeforeForward {
|
||||
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true)
|
||||
return
|
||||
}
|
||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
@@ -434,19 +441,29 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
if result.ReasoningEffort == nil {
|
||||
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
|
||||
}
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.messages"),
|
||||
@@ -635,6 +652,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if fs.SwitchCount > 0 {
|
||||
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||
}
|
||||
// 记录 Forward 前已写入字节数,Forward 后若增加则说明 SSE 内容已发,禁止 failover
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
|
||||
} else {
|
||||
@@ -704,6 +723,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
// 流式内容已写入客户端,无法撤销,禁止 failover 以防止流拼接腐化
|
||||
if c.Writer.Size() != writerSizeBeforeForward {
|
||||
h.handleFailoverExhausted(c, failoverErr, account.Platform, true)
|
||||
return
|
||||
}
|
||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
@@ -736,19 +760,29 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
if result.ReasoningEffort == nil {
|
||||
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
|
||||
}
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: currentAPIKey,
|
||||
User: currentAPIKey.User,
|
||||
Account: account,
|
||||
Subscription: currentSubscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: currentAPIKey,
|
||||
User: currentAPIKey.User,
|
||||
Account: account,
|
||||
Subscription: currentSubscription,
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.messages"),
|
||||
@@ -909,7 +943,7 @@ func (h *GatewayHandler) parseUsageDateRange(c *gin.Context) (time.Time, time.Ti
|
||||
}
|
||||
if s := c.Query("end_date"); s != "" {
|
||||
if t, err := timezone.ParseInLocation("2006-01-02", s); err == nil {
|
||||
endTime = t.Add(24*time.Hour - time.Second) // end of day
|
||||
endTime = t.AddDate(0, 0, 1) // half-open range upper bound
|
||||
}
|
||||
}
|
||||
return startTime, endTime
|
||||
@@ -1185,6 +1219,10 @@ func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *se
|
||||
}
|
||||
}
|
||||
|
||||
// 记录原始上游状态码,以便 ops 错误日志捕获真实的上游错误
|
||||
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
|
||||
|
||||
// 使用默认的错误映射
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
@@ -1193,6 +1231,7 @@ func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *se
|
||||
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
|
||||
func (h *GatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||
service.SetOpsUpstreamError(c, statusCode, errMsg, "")
|
||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
}
|
||||
|
||||
|
||||
122
backend/internal/handler/gateway_handler_stream_failover_test.go
Normal file
122
backend/internal/handler/gateway_handler_stream_failover_test.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// partialMessageStartSSE 模拟 handleStreamingResponse 已写入的首批 SSE 事件。
|
||||
const partialMessageStartSSE = "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_01\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-5\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":10,\"output_tokens\":1}}}\n\n" +
|
||||
"event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"
|
||||
|
||||
// TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten 验证:
|
||||
// 当 Forward 在返回 UpstreamFailoverError 前已向客户端写入 SSE 内容时,
|
||||
// 故障转移保护逻辑必须终止循环并发送 SSE 错误事件,而不是进行下一次 Forward。
|
||||
// 具体验证:
|
||||
// 1. c.Writer.Size() 检测条件正确触发(字节数已增加)
|
||||
// 2. handleFailoverExhausted 以 streamStarted=true 调用后,响应体以 SSE 错误事件结尾
|
||||
// 3. 响应体中只出现一个 message_start,不存在第二个(防止流拼接腐化)
|
||||
func TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
// 步骤 1:记录 Forward 前的 writer size(模拟 writerSizeBeforeForward := c.Writer.Size())
|
||||
sizeBeforeForward := c.Writer.Size()
|
||||
require.Equal(t, -1, sizeBeforeForward, "gin writer 初始 Size 应为 -1(未写入任何字节)")
|
||||
|
||||
// 步骤 2:模拟 Forward 已向客户端写入部分 SSE 内容(message_start + content_block_start)
|
||||
_, err := c.Writer.Write([]byte(partialMessageStartSSE))
|
||||
require.NoError(t, err)
|
||||
|
||||
// 步骤 3:验证守卫条件成立(c.Writer.Size() != sizeBeforeForward)
|
||||
require.NotEqual(t, sizeBeforeForward, c.Writer.Size(),
|
||||
"写入 SSE 内容后 writer size 必须增加,守卫条件应为 true")
|
||||
|
||||
// 步骤 4:模拟 UpstreamFailoverError(上游在流中途返回 403)
|
||||
failoverErr := &service.UpstreamFailoverError{
|
||||
StatusCode: http.StatusForbidden,
|
||||
ResponseBody: []byte(`{"error":{"type":"permission_error","message":"forbidden"}}`),
|
||||
}
|
||||
|
||||
// 步骤 5:守卫触发 → 调用 handleFailoverExhausted,streamStarted=true
|
||||
h := &GatewayHandler{}
|
||||
h.handleFailoverExhausted(c, failoverErr, service.PlatformAnthropic, true)
|
||||
|
||||
body := w.Body.String()
|
||||
|
||||
// 断言 A:响应体中包含最初写入的 message_start SSE 事件行
|
||||
require.Contains(t, body, "event: message_start", "响应体应包含已写入的 message_start SSE 事件")
|
||||
|
||||
// 断言 B:响应体以 SSE 错误事件结尾(data: {"type":"error",...}\n\n)
|
||||
require.True(t, strings.HasSuffix(strings.TrimRight(body, "\n"), "}"),
|
||||
"响应体应以 JSON 对象结尾(SSE error event 的 data 字段)")
|
||||
require.Contains(t, body, `"type":"error"`, "响应体末尾必须包含 SSE 错误事件")
|
||||
|
||||
// 断言 C:SSE event 行 "event: message_start" 只出现一次(防止双 message_start 拼接腐化)
|
||||
firstIdx := strings.Index(body, "event: message_start")
|
||||
lastIdx := strings.LastIndex(body, "event: message_start")
|
||||
assert.Equal(t, firstIdx, lastIdx,
|
||||
"响应体中 'event: message_start' 必须只出现一次,不得因 failover 拼接导致两次")
|
||||
}
|
||||
|
||||
// TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten 与上述测试相同,
|
||||
// 验证 Gemini 路径使用 service.PlatformGemini(而非 account.Platform)时行为一致。
|
||||
func TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.0-flash:streamGenerateContent", nil)
|
||||
|
||||
sizeBeforeForward := c.Writer.Size()
|
||||
|
||||
_, err := c.Writer.Write([]byte(partialMessageStartSSE))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotEqual(t, sizeBeforeForward, c.Writer.Size())
|
||||
|
||||
failoverErr := &service.UpstreamFailoverError{
|
||||
StatusCode: http.StatusForbidden,
|
||||
}
|
||||
|
||||
h := &GatewayHandler{}
|
||||
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true)
|
||||
|
||||
body := w.Body.String()
|
||||
|
||||
require.Contains(t, body, "event: message_start")
|
||||
require.Contains(t, body, `"type":"error"`)
|
||||
|
||||
firstIdx := strings.Index(body, "event: message_start")
|
||||
lastIdx := strings.LastIndex(body, "event: message_start")
|
||||
assert.Equal(t, firstIdx, lastIdx, "Gemini 路径不得出现双 message_start")
|
||||
}
|
||||
|
||||
// TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered 验证反向场景:
|
||||
// 当 Forward 返回 UpstreamFailoverError 时若未向客户端写入任何 SSE 内容,
|
||||
// 守卫条件(c.Writer.Size() != sizeBeforeForward)为 false,不应中止 failover。
|
||||
func TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
// 模拟 writerSizeBeforeForward:初始为 -1
|
||||
sizeBeforeForward := c.Writer.Size()
|
||||
|
||||
// Forward 未写入任何字节直接返回错误(例如 401 发生在连接建立前)
|
||||
// c.Writer.Size() 仍为 -1
|
||||
|
||||
// 守卫条件:sizeBeforeForward == c.Writer.Size() → 不触发
|
||||
guardTriggered := c.Writer.Size() != sizeBeforeForward
|
||||
require.False(t, guardTriggered,
|
||||
"未写入任何字节时,守卫条件必须为 false,应允许正常 failover 继续")
|
||||
}
|
||||
@@ -76,7 +76,7 @@ func (f *fakeGroupRepo) ListActiveByPlatform(context.Context, string) ([]service
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil }
|
||||
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, error) { return 0, nil }
|
||||
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, int64, error) { return 0, 0, nil }
|
||||
func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
@@ -139,6 +139,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
|
||||
nil, // accountRepo (not used: scheduler snapshot hit)
|
||||
&fakeGroupRepo{group: group},
|
||||
nil, // usageLogRepo
|
||||
nil, // usageBillingRepo
|
||||
nil, // userRepo
|
||||
nil, // userSubRepo
|
||||
nil, // userGroupRateRepo
|
||||
|
||||
@@ -136,7 +136,7 @@ func validClaudeCodeBodyJSON() []byte {
|
||||
return []byte(`{
|
||||
"model":"claude-3-5-sonnet-20241022",
|
||||
"system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}],
|
||||
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"}
|
||||
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}
|
||||
}`)
|
||||
}
|
||||
|
||||
@@ -190,7 +190,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing
|
||||
System: []any{
|
||||
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
||||
},
|
||||
MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123",
|
||||
MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
|
||||
}
|
||||
|
||||
// body 非法 JSON,如果函数复用 parsedReq 成功则仍应判定为 Claude Code。
|
||||
@@ -209,7 +209,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing
|
||||
"system": []any{
|
||||
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
||||
},
|
||||
"metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"},
|
||||
"metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"},
|
||||
})
|
||||
|
||||
SetClaudeCodeClientContext(c, []byte(`{invalid`), nil)
|
||||
|
||||
@@ -503,6 +503,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
||||
Result: result,
|
||||
@@ -510,8 +513,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
LongContextThreshold: 200000, // Gemini 200K 阈值
|
||||
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
@@ -587,6 +593,10 @@ func (h *GatewayHandler) handleGeminiFailoverExhausted(c *gin.Context, failoverE
|
||||
}
|
||||
}
|
||||
|
||||
// 记录原始上游状态码,以便 ops 错误日志捕获真实的上游错误
|
||||
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
|
||||
|
||||
// 使用默认的错误映射
|
||||
status, message := mapGeminiUpstreamError(statusCode)
|
||||
googleError(c, status, message)
|
||||
|
||||
@@ -12,6 +12,7 @@ type AdminHandlers struct {
|
||||
Account *admin.AccountHandler
|
||||
Announcement *admin.AnnouncementHandler
|
||||
DataManagement *admin.DataManagementHandler
|
||||
Backup *admin.BackupHandler
|
||||
OAuth *admin.OAuthHandler
|
||||
OpenAIOAuth *admin.OpenAIOAuthHandler
|
||||
GeminiOAuth *admin.GeminiOAuthHandler
|
||||
|
||||
286
backend/internal/handler/openai_chat_completions.go
Normal file
286
backend/internal/handler/openai_chat_completions.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ChatCompletions handles OpenAI Chat Completions API requests.
|
||||
// POST /v1/chat/completions
|
||||
func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
streamStarted := false
|
||||
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||
|
||||
requestStart := time.Now()
|
||||
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.openai_gateway.chat_completions",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
if !h.ensureResponsesDependencies(c, reqLog) {
|
||||
return
|
||||
}
|
||||
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
if len(body) == 0 {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
if !gjson.ValidBytes(body) {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
modelResult := gjson.GetBytes(body, "model")
|
||||
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return
|
||||
}
|
||||
reqModel := modelResult.String()
|
||||
reqStream := gjson.GetBytes(body, "stream").Bool()
|
||||
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
}
|
||||
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
||||
routingStart := time.Now()
|
||||
|
||||
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
|
||||
promptCacheKey := h.gatewayService.ExtractSessionID(c, body)
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int)
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
c.Set("openai_chat_completions_fallback_model", "")
|
||||
reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
"",
|
||||
sessionHash,
|
||||
reqModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai_chat_completions.account_select_failed",
|
||||
zap.Error(err),
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
defaultModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
if defaultModel != "" && defaultModel != reqModel {
|
||||
reqLog.Info("openai_chat_completions.fallback_to_default_model",
|
||||
zap.String("default_mapped_model", defaultModel),
|
||||
)
|
||||
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
"",
|
||||
sessionHash,
|
||||
defaultModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
if err == nil && selection != nil {
|
||||
c.Set("openai_chat_completions_fallback_model", defaultModel)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if lastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
|
||||
} else {
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
|
||||
reqLog.Debug("openai_chat_completions.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||
_ = scheduleDecision
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
defaultMappedModel := c.GetString("openai_chat_completions_fallback_model")
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||
responseLatencyMs := forwardDurationMs
|
||||
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
|
||||
}
|
||||
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
|
||||
if err == nil && result != nil && result.FirstTokenMs != nil {
|
||||
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// Pool mode: retry on the same account
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Warn("openai_chat_completions.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
if result != nil {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
} else {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
|
||||
}
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.chat_completions"),
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64("account_id", account.ID),
|
||||
).Error("openai_chat_completions.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
})
|
||||
reqLog.Debug("openai_chat_completions.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("switch_count", switchCount),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestOpenAIUpstreamEndpoint_ViaGetUpstreamEndpoint verifies that the
|
||||
// unified GetUpstreamEndpoint helper produces the same results as the
|
||||
// former normalizedOpenAIUpstreamEndpoint for OpenAI platform requests.
|
||||
func TestOpenAIUpstreamEndpoint_ViaGetUpstreamEndpoint(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "responses root maps to responses upstream",
|
||||
path: "/v1/responses",
|
||||
want: EndpointResponses,
|
||||
},
|
||||
{
|
||||
name: "responses compact keeps compact suffix",
|
||||
path: "/openai/v1/responses/compact",
|
||||
want: "/v1/responses/compact",
|
||||
},
|
||||
{
|
||||
name: "responses nested suffix preserved",
|
||||
path: "/openai/v1/responses/compact/detail",
|
||||
want: "/v1/responses/compact/detail",
|
||||
},
|
||||
{
|
||||
name: "non responses path uses platform fallback",
|
||||
path: "/v1/messages",
|
||||
want: EndpointResponses,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, tt.path, nil)
|
||||
|
||||
got := GetUpstreamEndpoint(c, service.PlatformOpenAI)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -352,18 +352,22 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.responses"),
|
||||
@@ -653,14 +657,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
defaultMappedModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultMappedModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
// 如果使用了降级模型调度,强制使用降级模型
|
||||
if fallbackModel := c.GetString("openai_messages_fallback_model"); fallbackModel != "" {
|
||||
defaultMappedModel = fallbackModel
|
||||
}
|
||||
// 仅在调度时实际触发了降级(原模型无可用账号、改用默认模型重试成功)时,
|
||||
// 才将降级模型传给 Forward 层做模型替换;否则保持用户请求的原始模型。
|
||||
defaultMappedModel := c.GetString("openai_messages_fallback_model")
|
||||
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
@@ -732,17 +731,21 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.messages"),
|
||||
@@ -1231,14 +1234,17 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
h.submitUsageRecordTask(func(taskCtx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
reqLog.Error("openai.websocket_record_usage_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
@@ -1429,6 +1435,10 @@ func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverE
|
||||
}
|
||||
}
|
||||
|
||||
// 记录原始上游状态码,以便 ops 错误日志捕获真实的上游错误
|
||||
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
|
||||
|
||||
// 使用默认的错误映射
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
@@ -1437,6 +1447,7 @@ func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverE
|
||||
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
|
||||
func (h *OpenAIGatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||
service.SetOpsUpstreamError(c, statusCode, errMsg, "")
|
||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
}
|
||||
|
||||
|
||||
@@ -26,11 +26,28 @@ const (
|
||||
opsStreamKey = "ops_stream"
|
||||
opsRequestBodyKey = "ops_request_body"
|
||||
opsAccountIDKey = "ops_account_id"
|
||||
|
||||
// 错误过滤匹配常量 — shouldSkipOpsErrorLog 和错误分类共用
|
||||
opsErrContextCanceled = "context canceled"
|
||||
opsErrNoAvailableAccounts = "no available accounts"
|
||||
opsErrInvalidAPIKey = "invalid_api_key"
|
||||
opsErrAPIKeyRequired = "api_key_required"
|
||||
opsErrInsufficientBalance = "insufficient balance"
|
||||
opsErrInsufficientAccountBalance = "insufficient account balance"
|
||||
opsErrInsufficientQuota = "insufficient_quota"
|
||||
|
||||
// 上游错误码常量 — 错误分类 (normalizeOpsErrorType / classifyOpsPhase / classifyOpsIsBusinessLimited)
|
||||
opsCodeInsufficientBalance = "INSUFFICIENT_BALANCE"
|
||||
opsCodeUsageLimitExceeded = "USAGE_LIMIT_EXCEEDED"
|
||||
opsCodeSubscriptionNotFound = "SUBSCRIPTION_NOT_FOUND"
|
||||
opsCodeSubscriptionInvalid = "SUBSCRIPTION_INVALID"
|
||||
opsCodeUserInactive = "USER_INACTIVE"
|
||||
)
|
||||
|
||||
const (
|
||||
opsErrorLogTimeout = 5 * time.Second
|
||||
opsErrorLogDrainTimeout = 10 * time.Second
|
||||
opsErrorLogBatchWindow = 200 * time.Millisecond
|
||||
|
||||
opsErrorLogMinWorkerCount = 4
|
||||
opsErrorLogMaxWorkerCount = 32
|
||||
@@ -38,6 +55,7 @@ const (
|
||||
opsErrorLogQueueSizePerWorker = 128
|
||||
opsErrorLogMinQueueSize = 256
|
||||
opsErrorLogMaxQueueSize = 8192
|
||||
opsErrorLogBatchSize = 32
|
||||
)
|
||||
|
||||
type opsErrorLogJob struct {
|
||||
@@ -82,27 +100,82 @@ func startOpsErrorLogWorkers() {
|
||||
for i := 0; i < workerCount; i++ {
|
||||
go func() {
|
||||
defer opsErrorLogWorkersWg.Done()
|
||||
for job := range opsErrorLogQueue {
|
||||
opsErrorLogQueueLen.Add(-1)
|
||||
if job.ops == nil || job.entry == nil {
|
||||
continue
|
||||
for {
|
||||
job, ok := <-opsErrorLogQueue
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack())
|
||||
opsErrorLogQueueLen.Add(-1)
|
||||
batch := make([]opsErrorLogJob, 0, opsErrorLogBatchSize)
|
||||
batch = append(batch, job)
|
||||
|
||||
timer := time.NewTimer(opsErrorLogBatchWindow)
|
||||
batchLoop:
|
||||
for len(batch) < opsErrorLogBatchSize {
|
||||
select {
|
||||
case nextJob, ok := <-opsErrorLogQueue:
|
||||
if !ok {
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
flushOpsErrorLogBatch(batch)
|
||||
return
|
||||
}
|
||||
}()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
|
||||
_ = job.ops.RecordError(ctx, job.entry, nil)
|
||||
cancel()
|
||||
opsErrorLogProcessed.Add(1)
|
||||
}()
|
||||
opsErrorLogQueueLen.Add(-1)
|
||||
batch = append(batch, nextJob)
|
||||
case <-timer.C:
|
||||
break batchLoop
|
||||
}
|
||||
}
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
flushOpsErrorLogBatch(batch)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func flushOpsErrorLogBatch(batch []opsErrorLogJob) {
|
||||
if len(batch) == 0 {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack())
|
||||
}
|
||||
}()
|
||||
|
||||
grouped := make(map[*service.OpsService][]*service.OpsInsertErrorLogInput, len(batch))
|
||||
var processed int64
|
||||
for _, job := range batch {
|
||||
if job.ops == nil || job.entry == nil {
|
||||
continue
|
||||
}
|
||||
grouped[job.ops] = append(grouped[job.ops], job.entry)
|
||||
processed++
|
||||
}
|
||||
if processed == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for opsSvc, entries := range grouped {
|
||||
if opsSvc == nil || len(entries) == 0 {
|
||||
continue
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
|
||||
_ = opsSvc.RecordErrorBatch(ctx, entries)
|
||||
cancel()
|
||||
}
|
||||
opsErrorLogProcessed.Add(processed)
|
||||
}
|
||||
|
||||
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput) {
|
||||
if ops == nil || entry == nil {
|
||||
return
|
||||
@@ -967,9 +1040,9 @@ func normalizeOpsErrorType(errType string, code string) string {
|
||||
return errType
|
||||
}
|
||||
switch strings.TrimSpace(code) {
|
||||
case "INSUFFICIENT_BALANCE":
|
||||
case opsCodeInsufficientBalance:
|
||||
return "billing_error"
|
||||
case "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
|
||||
case opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
|
||||
return "subscription_error"
|
||||
default:
|
||||
return "api_error"
|
||||
@@ -981,7 +1054,7 @@ func classifyOpsPhase(errType, message, code string) string {
|
||||
// Standardized phases: request|auth|routing|upstream|network|internal
|
||||
// Map billing/concurrency/response => request; scheduling => routing.
|
||||
switch strings.TrimSpace(code) {
|
||||
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
|
||||
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
|
||||
return "request"
|
||||
}
|
||||
|
||||
@@ -1000,7 +1073,7 @@ func classifyOpsPhase(errType, message, code string) string {
|
||||
case "upstream_error", "overloaded_error":
|
||||
return "upstream"
|
||||
case "api_error":
|
||||
if strings.Contains(msg, "no available accounts") {
|
||||
if strings.Contains(msg, opsErrNoAvailableAccounts) {
|
||||
return "routing"
|
||||
}
|
||||
return "internal"
|
||||
@@ -1046,7 +1119,7 @@ func classifyOpsIsRetryable(errType string, statusCode int) bool {
|
||||
|
||||
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
|
||||
switch strings.TrimSpace(code) {
|
||||
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID", "USER_INACTIVE":
|
||||
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid, opsCodeUserInactive:
|
||||
return true
|
||||
}
|
||||
if phase == "billing" || phase == "concurrency" {
|
||||
@@ -1140,21 +1213,30 @@ func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message
|
||||
|
||||
// Check if context canceled errors should be ignored (client disconnects)
|
||||
if settings.IgnoreContextCanceled {
|
||||
if strings.Contains(msgLower, "context canceled") || strings.Contains(bodyLower, "context canceled") {
|
||||
if strings.Contains(msgLower, opsErrContextCanceled) || strings.Contains(bodyLower, opsErrContextCanceled) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if "no available accounts" errors should be ignored
|
||||
if settings.IgnoreNoAvailableAccounts {
|
||||
if strings.Contains(msgLower, "no available accounts") || strings.Contains(bodyLower, "no available accounts") {
|
||||
if strings.Contains(msgLower, opsErrNoAvailableAccounts) || strings.Contains(bodyLower, opsErrNoAvailableAccounts) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if invalid/missing API key errors should be ignored (user misconfiguration)
|
||||
if settings.IgnoreInvalidApiKeyErrors {
|
||||
if strings.Contains(bodyLower, "invalid_api_key") || strings.Contains(bodyLower, "api_key_required") {
|
||||
if strings.Contains(bodyLower, opsErrInvalidAPIKey) || strings.Contains(bodyLower, opsErrAPIKeyRequired) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if insufficient balance errors should be ignored
|
||||
if settings.IgnoreInsufficientBalanceErrors {
|
||||
if strings.Contains(bodyLower, opsErrInsufficientBalance) || strings.Contains(bodyLower, opsErrInsufficientAccountBalance) ||
|
||||
strings.Contains(bodyLower, opsErrInsufficientQuota) ||
|
||||
strings.Contains(msgLower, opsErrInsufficientBalance) || strings.Contains(msgLower, opsErrInsufficientAccountBalance) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,6 +54,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
SoraClientEnabled: settings.SoraClientEnabled,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
Version: h.version,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2206,7 +2206,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac
|
||||
// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。
|
||||
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
|
||||
return service.NewGatewayService(
|
||||
accountRepo, nil, nil, nil, nil, nil, nil, nil,
|
||||
accountRepo, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -399,17 +399,23 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.sora_gateway.chat_completions"),
|
||||
@@ -478,6 +484,9 @@ func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, s
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) {
|
||||
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
|
||||
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody)
|
||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
}
|
||||
|
||||
@@ -273,8 +273,8 @@ func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform strin
|
||||
func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, nil
|
||||
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
|
||||
return 0, 0, nil
|
||||
}
|
||||
func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, nil
|
||||
@@ -334,15 +334,32 @@ func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTi
|
||||
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *stubUsageLogRepo) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||
return []usagestats.EndpointStat{}, nil
|
||||
}
|
||||
|
||||
func (s *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||
return []usagestats.EndpointStat{}, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -431,6 +448,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
testutil.StubGatewayCache{},
|
||||
cfg,
|
||||
nil,
|
||||
|
||||
@@ -114,8 +114,8 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
// Set end time to end of day
|
||||
t = t.Add(24*time.Hour - time.Nanosecond)
|
||||
// Use half-open range [start, end), move to next calendar day start (DST-safe).
|
||||
t = t.AddDate(0, 0, 1)
|
||||
endTime = &t
|
||||
}
|
||||
|
||||
@@ -227,8 +227,8 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
// 设置结束时间为当天结束
|
||||
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
|
||||
// 与 SQL 条件 created_at < end 对齐,使用次日 00:00 作为上边界(DST-safe)。
|
||||
endTime = endTime.AddDate(0, 0, 1)
|
||||
} else {
|
||||
// 使用 period 参数
|
||||
period := c.DefaultQuery("period", "today")
|
||||
|
||||
@@ -15,6 +15,7 @@ func ProvideAdminHandlers(
|
||||
accountHandler *admin.AccountHandler,
|
||||
announcementHandler *admin.AnnouncementHandler,
|
||||
dataManagementHandler *admin.DataManagementHandler,
|
||||
backupHandler *admin.BackupHandler,
|
||||
oauthHandler *admin.OAuthHandler,
|
||||
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
||||
geminiOAuthHandler *admin.GeminiOAuthHandler,
|
||||
@@ -39,6 +40,7 @@ func ProvideAdminHandlers(
|
||||
Account: accountHandler,
|
||||
Announcement: announcementHandler,
|
||||
DataManagement: dataManagementHandler,
|
||||
Backup: backupHandler,
|
||||
OAuth: oauthHandler,
|
||||
OpenAIOAuth: openaiOAuthHandler,
|
||||
GeminiOAuth: geminiOAuthHandler,
|
||||
@@ -128,6 +130,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewAccountHandler,
|
||||
admin.NewAnnouncementHandler,
|
||||
admin.NewDataManagementHandler,
|
||||
admin.NewBackupHandler,
|
||||
admin.NewOAuthHandler,
|
||||
admin.NewOpenAIOAuthHandler,
|
||||
admin.NewGeminiOAuthHandler,
|
||||
|
||||
@@ -159,6 +159,8 @@ var claudeModels = []modelDef{
|
||||
// Antigravity 支持的 Gemini 模型
|
||||
var geminiModels = []modelDef{
|
||||
{ID: "gemini-2.5-flash", DisplayName: "Gemini 2.5 Flash", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-2.5-flash-image", DisplayName: "Gemini 2.5 Flash Image", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-2.5-flash-image-preview", DisplayName: "Gemini 2.5 Flash Image Preview", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-2.5-flash-lite", DisplayName: "Gemini 2.5 Flash Lite", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-2.5-flash-thinking", DisplayName: "Gemini 2.5 Flash Thinking", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
|
||||
@@ -13,6 +13,8 @@ func TestDefaultModels_ContainsNewAndLegacyImageModels(t *testing.T) {
|
||||
|
||||
requiredIDs := []string{
|
||||
"claude-opus-4-6-thinking",
|
||||
"gemini-2.5-flash-image",
|
||||
"gemini-2.5-flash-image-preview",
|
||||
"gemini-3.1-flash-image",
|
||||
"gemini-3.1-flash-image-preview",
|
||||
"gemini-3-pro-image", // legacy compatibility
|
||||
|
||||
@@ -19,6 +19,16 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
||||
)
|
||||
|
||||
// ForbiddenError reports that the upstream answered with HTTP 403 Forbidden.
// It carries the status code and the raw response body of that answer.
type ForbiddenError struct {
	StatusCode int
	Body       string
}

// Error implements the error interface by rendering the status code and the
// response body into a single message.
func (fe *ForbiddenError) Error() string {
	const layout = "fetchAvailableModels 失败 (HTTP %d): %s"
	return fmt.Sprintf(layout, fe.StatusCode, fe.Body)
}
|
||||
|
||||
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
|
||||
func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
|
||||
// 构建 URL,流式请求添加 ?alt=sse 参数
|
||||
@@ -114,10 +124,68 @@ type IneligibleTier struct {
|
||||
type LoadCodeAssistResponse struct {
|
||||
CloudAICompanionProject string `json:"cloudaicompanionProject"`
|
||||
CurrentTier *TierInfo `json:"currentTier,omitempty"`
|
||||
PaidTier *TierInfo `json:"paidTier,omitempty"`
|
||||
PaidTier *PaidTierInfo `json:"paidTier,omitempty"`
|
||||
IneligibleTiers []*IneligibleTier `json:"ineligibleTiers,omitempty"`
|
||||
}
|
||||
|
||||
// PaidTierInfo 付费等级信息,包含 AI Credits 余额。
|
||||
type PaidTierInfo struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
AvailableCredits []AvailableCredit `json:"availableCredits,omitempty"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON 兼容 paidTier 既可能是字符串也可能是对象的情况。
|
||||
func (p *PaidTierInfo) UnmarshalJSON(data []byte) error {
|
||||
data = bytes.TrimSpace(data)
|
||||
if len(data) == 0 || string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
if data[0] == '"' {
|
||||
var id string
|
||||
if err := json.Unmarshal(data, &id); err != nil {
|
||||
return err
|
||||
}
|
||||
p.ID = id
|
||||
return nil
|
||||
}
|
||||
type alias PaidTierInfo
|
||||
var raw alias
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return err
|
||||
}
|
||||
*p = PaidTierInfo(raw)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AvailableCredit is a single AI Credits balance record.
type AvailableCredit struct {
	CreditType                  string `json:"creditType,omitempty"`
	CreditAmount                string `json:"creditAmount,omitempty"`
	MinimumCreditAmountForUsage string `json:"minimumCreditAmountForUsage,omitempty"`
}

// scanCreditFloat converts a textual credit value into a float64.
// An empty string, or text that cannot be scanned as a float, yields 0;
// scan errors are deliberately ignored (best-effort parsing).
func scanCreditFloat(text string) float64 {
	var parsed float64
	if text != "" {
		_, _ = fmt.Sscanf(text, "%f", &parsed)
	}
	return parsed
}

// GetAmount parses creditAmount into a floating-point number.
func (c *AvailableCredit) GetAmount() float64 {
	return scanCreditFloat(c.CreditAmount)
}

// GetMinimumAmount parses minimumCreditAmountForUsage into a floating-point
// number.
func (c *AvailableCredit) GetMinimumAmount() float64 {
	return scanCreditFloat(c.MinimumCreditAmountForUsage)
}
|
||||
|
||||
// OnboardUserRequest onboardUser 请求
|
||||
type OnboardUserRequest struct {
|
||||
TierID string `json:"tierId"`
|
||||
@@ -147,6 +215,14 @@ func (r *LoadCodeAssistResponse) GetTier() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetAvailableCredits 返回 paid tier 中的 AI Credits 余额列表。
|
||||
func (r *LoadCodeAssistResponse) GetAvailableCredits() []AvailableCredit {
|
||||
if r.PaidTier == nil {
|
||||
return nil
|
||||
}
|
||||
return r.PaidTier.AvailableCredits
|
||||
}
|
||||
|
||||
// Client Antigravity API 客户端
|
||||
type Client struct {
|
||||
httpClient *http.Client
|
||||
@@ -514,7 +590,20 @@ type ModelQuotaInfo struct {
|
||||
|
||||
// ModelInfo 模型信息
|
||||
type ModelInfo struct {
|
||||
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
|
||||
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
|
||||
DisplayName string `json:"displayName,omitempty"`
|
||||
SupportsImages *bool `json:"supportsImages,omitempty"`
|
||||
SupportsThinking *bool `json:"supportsThinking,omitempty"`
|
||||
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
|
||||
Recommended *bool `json:"recommended,omitempty"`
|
||||
MaxTokens *int `json:"maxTokens,omitempty"`
|
||||
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
|
||||
SupportedMimeTypes map[string]bool `json:"supportedMimeTypes,omitempty"`
|
||||
}
|
||||
|
||||
// DeprecatedModelInfo holds the forwarding information for a deprecated model.
|
||||
type DeprecatedModelInfo struct {
|
||||
NewModelID string `json:"newModelId"`
|
||||
}
|
||||
|
||||
// FetchAvailableModelsRequest fetchAvailableModels 请求
|
||||
@@ -524,7 +613,8 @@ type FetchAvailableModelsRequest struct {
|
||||
|
||||
// FetchAvailableModelsResponse fetchAvailableModels 响应
|
||||
type FetchAvailableModelsResponse struct {
|
||||
Models map[string]ModelInfo `json:"models"`
|
||||
Models map[string]ModelInfo `json:"models"`
|
||||
DeprecatedModelIDs map[string]DeprecatedModelInfo `json:"deprecatedModelIds,omitempty"`
|
||||
}
|
||||
|
||||
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
|
||||
@@ -573,6 +663,13 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusForbidden {
|
||||
return nil, nil, &ForbiddenError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Body: string(respBodyBytes),
|
||||
}
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
||||
}
|
||||
|
||||
@@ -190,7 +190,7 @@ func TestTierInfo_UnmarshalJSON_通过JSON嵌套结构(t *testing.T) {
|
||||
func TestGetTier_PaidTier优先(t *testing.T) {
|
||||
resp := &LoadCodeAssistResponse{
|
||||
CurrentTier: &TierInfo{ID: "free-tier"},
|
||||
PaidTier: &TierInfo{ID: "g1-pro-tier"},
|
||||
PaidTier: &PaidTierInfo{ID: "g1-pro-tier"},
|
||||
}
|
||||
if got := resp.GetTier(); got != "g1-pro-tier" {
|
||||
t.Errorf("应返回 paidTier: got %s", got)
|
||||
@@ -209,7 +209,7 @@ func TestGetTier_回退到CurrentTier(t *testing.T) {
|
||||
func TestGetTier_PaidTier为空ID(t *testing.T) {
|
||||
resp := &LoadCodeAssistResponse{
|
||||
CurrentTier: &TierInfo{ID: "free-tier"},
|
||||
PaidTier: &TierInfo{ID: ""},
|
||||
PaidTier: &PaidTierInfo{ID: ""},
|
||||
}
|
||||
// paidTier.ID 为空时应回退到 currentTier
|
||||
if got := resp.GetTier(); got != "free-tier" {
|
||||
@@ -217,6 +217,32 @@ func TestGetTier_PaidTier为空ID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAvailableCredits(t *testing.T) {
|
||||
resp := &LoadCodeAssistResponse{
|
||||
PaidTier: &PaidTierInfo{
|
||||
ID: "g1-pro-tier",
|
||||
AvailableCredits: []AvailableCredit{
|
||||
{
|
||||
CreditType: "GOOGLE_ONE_AI",
|
||||
CreditAmount: "25",
|
||||
MinimumCreditAmountForUsage: "5",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
credits := resp.GetAvailableCredits()
|
||||
if len(credits) != 1 {
|
||||
t.Fatalf("AI Credits 数量不匹配: got %d", len(credits))
|
||||
}
|
||||
if credits[0].GetAmount() != 25 {
|
||||
t.Errorf("CreditAmount 解析不正确: got %v", credits[0].GetAmount())
|
||||
}
|
||||
if credits[0].GetMinimumAmount() != 5 {
|
||||
t.Errorf("MinimumCreditAmountForUsage 解析不正确: got %v", credits[0].GetMinimumAmount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTier_两者都为nil(t *testing.T) {
|
||||
resp := &LoadCodeAssistResponse{}
|
||||
if got := resp.GetTier(); got != "" {
|
||||
|
||||
@@ -189,6 +189,5 @@ var DefaultStopSequences = []string{
|
||||
"<|user|>",
|
||||
"<|endoftext|>",
|
||||
"<|end_of_turn|>",
|
||||
"[DONE]",
|
||||
"\n\nHuman:",
|
||||
}
|
||||
|
||||
@@ -49,8 +49,8 @@ const (
|
||||
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
)
|
||||
|
||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.4
|
||||
var defaultUserAgentVersion = "1.20.4"
|
||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.5
|
||||
var defaultUserAgentVersion = "1.20.5"
|
||||
|
||||
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
||||
var defaultClientSecret = "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"
|
||||
|
||||
@@ -690,7 +690,7 @@ func TestConstants_值正确(t *testing.T) {
|
||||
if RedirectURI != "http://localhost:8085/callback" {
|
||||
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
|
||||
}
|
||||
if GetUserAgent() != "antigravity/1.20.4 windows/amd64" {
|
||||
if GetUserAgent() != "antigravity/1.20.5 windows/amd64" {
|
||||
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
|
||||
}
|
||||
if SessionTTL != 30*time.Minute {
|
||||
|
||||
@@ -275,21 +275,6 @@ func filterOpenCodePrompt(text string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// systemBlockFilterPrefixes 需要从 system 中过滤的文本前缀列表
|
||||
var systemBlockFilterPrefixes = []string{
|
||||
"x-anthropic-billing-header",
|
||||
}
|
||||
|
||||
// filterSystemBlockByPrefix 如果文本匹配过滤前缀,返回空字符串
|
||||
func filterSystemBlockByPrefix(text string) string {
|
||||
for _, prefix := range systemBlockFilterPrefixes {
|
||||
if strings.HasPrefix(text, prefix) {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
// buildSystemInstruction 构建 systemInstruction(与 Antigravity-Manager 保持一致)
|
||||
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent {
|
||||
var parts []GeminiPart
|
||||
@@ -306,8 +291,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
|
||||
if strings.Contains(sysStr, "You are Antigravity") {
|
||||
userHasAntigravityIdentity = true
|
||||
}
|
||||
// 过滤 OpenCode 默认提示词和黑名单前缀
|
||||
filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(sysStr))
|
||||
// 过滤 OpenCode 默认提示词
|
||||
filtered := filterOpenCodePrompt(sysStr)
|
||||
if filtered != "" {
|
||||
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
||||
}
|
||||
@@ -321,8 +306,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
|
||||
if strings.Contains(block.Text, "You are Antigravity") {
|
||||
userHasAntigravityIdentity = true
|
||||
}
|
||||
// 过滤 OpenCode 默认提示词和黑名单前缀
|
||||
filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(block.Text))
|
||||
// 过滤 OpenCode 默认提示词
|
||||
filtered := filterOpenCodePrompt(block.Text)
|
||||
if filtered != "" {
|
||||
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
||||
}
|
||||
|
||||
@@ -2,7 +2,10 @@ package antigravity
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
|
||||
@@ -349,3 +352,51 @@ func TestBuildGenerationConfig_ThinkingDynamicBudget(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransformClaudeToGeminiWithOptions_PreservesBillingHeaderSystemBlock(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
system json.RawMessage
|
||||
}{
|
||||
{
|
||||
name: "system array",
|
||||
system: json.RawMessage(`[{"type":"text","text":"x-anthropic-billing-header keep"}]`),
|
||||
},
|
||||
{
|
||||
name: "system string",
|
||||
system: json.RawMessage(`"x-anthropic-billing-header keep"`),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
claudeReq := &ClaudeRequest{
|
||||
Model: "claude-3-5-sonnet-latest",
|
||||
System: tt.system,
|
||||
Messages: []ClaudeMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: json.RawMessage(`[{"type":"text","text":"hello"}]`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "gemini-2.5-flash", DefaultTransformOptions())
|
||||
require.NoError(t, err)
|
||||
|
||||
var req V1InternalRequest
|
||||
require.NoError(t, json.Unmarshal(body, &req))
|
||||
require.NotNil(t, req.Request.SystemInstruction)
|
||||
|
||||
found := false
|
||||
for _, part := range req.Request.SystemInstruction.Parts {
|
||||
if strings.Contains(part.Text, "x-anthropic-billing-header keep") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.True(t, found, "转换后的 systemInstruction 应保留 x-anthropic-billing-header 内容")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -105,6 +105,7 @@ func TestAnthropicToResponses_ToolUse(t *testing.T) {
|
||||
assert.Equal(t, "assistant", items[1].Role)
|
||||
assert.Equal(t, "function_call", items[2].Type)
|
||||
assert.Equal(t, "fc_call_1", items[2].CallID)
|
||||
assert.Empty(t, items[2].ID)
|
||||
assert.Equal(t, "function_call_output", items[3].Type)
|
||||
assert.Equal(t, "fc_call_1", items[3].CallID)
|
||||
assert.Equal(t, "Sunny, 72°F", items[3].Output)
|
||||
@@ -1007,3 +1008,114 @@ func TestAnthropicToResponses_ImageEmptyMediaType(t *testing.T) {
|
||||
// Should default to image/png when media_type is empty.
|
||||
assert.Equal(t, "data:image/png;base64,iVBOR", parts[0].ImageURL)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// normalizeToolParameters tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNormalizeToolParameters(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input json.RawMessage
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "nil input",
|
||||
input: nil,
|
||||
expected: `{"type":"object","properties":{}}`,
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
input: json.RawMessage(``),
|
||||
expected: `{"type":"object","properties":{}}`,
|
||||
},
|
||||
{
|
||||
name: "null input",
|
||||
input: json.RawMessage(`null`),
|
||||
expected: `{"type":"object","properties":{}}`,
|
||||
},
|
||||
{
|
||||
name: "object without properties",
|
||||
input: json.RawMessage(`{"type":"object"}`),
|
||||
expected: `{"type":"object","properties":{}}`,
|
||||
},
|
||||
{
|
||||
name: "object with properties",
|
||||
input: json.RawMessage(`{"type":"object","properties":{"city":{"type":"string"}}}`),
|
||||
expected: `{"type":"object","properties":{"city":{"type":"string"}}}`,
|
||||
},
|
||||
{
|
||||
name: "non-object type",
|
||||
input: json.RawMessage(`{"type":"string"}`),
|
||||
expected: `{"type":"string"}`,
|
||||
},
|
||||
{
|
||||
name: "object with additional fields preserved",
|
||||
input: json.RawMessage(`{"type":"object","required":["name"]}`),
|
||||
expected: `{"type":"object","required":["name"],"properties":{}}`,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON passthrough",
|
||||
input: json.RawMessage(`not json`),
|
||||
expected: `not json`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := normalizeToolParameters(tt.input)
|
||||
if tt.name == "invalid JSON passthrough" {
|
||||
assert.Equal(t, tt.expected, string(result))
|
||||
} else {
|
||||
assert.JSONEq(t, tt.expected, string(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_ToolWithoutProperties(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Hello"`)},
|
||||
},
|
||||
Tools: []AnthropicTool{
|
||||
{Name: "mcp__pencil__get_style_guide_tags", Description: "Get style tags", InputSchema: json.RawMessage(`{"type":"object"}`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, resp.Tools, 1)
|
||||
assert.Equal(t, "function", resp.Tools[0].Type)
|
||||
assert.Equal(t, "mcp__pencil__get_style_guide_tags", resp.Tools[0].Name)
|
||||
|
||||
// Parameters must have "properties" field after normalization.
|
||||
var params map[string]json.RawMessage
|
||||
require.NoError(t, json.Unmarshal(resp.Tools[0].Parameters, ¶ms))
|
||||
assert.Contains(t, params, "properties")
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_ToolWithNilSchema(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Hello"`)},
|
||||
},
|
||||
Tools: []AnthropicTool{
|
||||
{Name: "simple_tool", Description: "A tool"},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, resp.Tools, 1)
|
||||
var params map[string]json.RawMessage
|
||||
require.NoError(t, json.Unmarshal(resp.Tools[0].Parameters, ¶ms))
|
||||
assert.JSONEq(t, `"object"`, string(params["type"]))
|
||||
assert.JSONEq(t, `{}`, string(params["properties"]))
|
||||
}
|
||||
|
||||
@@ -277,7 +277,6 @@ func anthropicAssistantToResponses(raw json.RawMessage) ([]ResponsesInputItem, e
|
||||
CallID: fcID,
|
||||
Name: b.Name,
|
||||
Arguments: args,
|
||||
ID: fcID,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -410,8 +409,41 @@ func convertAnthropicToolsToResponses(tools []AnthropicTool) []ResponsesTool {
|
||||
Type: "function",
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
Parameters: t.InputSchema,
|
||||
Parameters: normalizeToolParameters(t.InputSchema),
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// normalizeToolParameters ensures the tool parameter schema is valid for
|
||||
// OpenAI's Responses API, which requires "properties" on object schemas.
|
||||
//
|
||||
// - nil/empty → {"type":"object","properties":{}}
|
||||
// - type=object without properties → adds "properties": {}
|
||||
// - otherwise → returned unchanged
|
||||
func normalizeToolParameters(schema json.RawMessage) json.RawMessage {
|
||||
if len(schema) == 0 || string(schema) == "null" {
|
||||
return json.RawMessage(`{"type":"object","properties":{}}`)
|
||||
}
|
||||
|
||||
var m map[string]json.RawMessage
|
||||
if err := json.Unmarshal(schema, &m); err != nil {
|
||||
return schema
|
||||
}
|
||||
|
||||
typ := m["type"]
|
||||
if string(typ) != `"object"` {
|
||||
return schema
|
||||
}
|
||||
|
||||
if _, ok := m["properties"]; ok {
|
||||
return schema
|
||||
}
|
||||
|
||||
m["properties"] = json.RawMessage(`{}`)
|
||||
out, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return schema
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
810
backend/internal/pkg/apicompat/chatcompletions_responses_test.go
Normal file
810
backend/internal/pkg/apicompat/chatcompletions_responses_test.go
Normal file
@@ -0,0 +1,810 @@
|
||||
package apicompat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ChatCompletionsToResponses tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestChatCompletionsToResponses_BasicText(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Hello"`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gpt-4o", resp.Model)
|
||||
assert.True(t, resp.Stream) // always forced true
|
||||
assert.False(t, *resp.Store)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 1)
|
||||
assert.Equal(t, "user", items[0].Role)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_SystemMessage(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "system", Content: json.RawMessage(`"You are helpful."`)},
|
||||
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 2)
|
||||
assert.Equal(t, "system", items[0].Role)
|
||||
assert.Equal(t, "user", items[1].Role)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_ToolCalls(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Call the function"`)},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []ChatToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: ChatFunctionCall{
|
||||
Name: "ping",
|
||||
Arguments: `{"host":"example.com"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: "tool",
|
||||
ToolCallID: "call_1",
|
||||
Content: json.RawMessage(`"pong"`),
|
||||
},
|
||||
},
|
||||
Tools: []ChatTool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: &ChatFunction{
|
||||
Name: "ping",
|
||||
Description: "Ping a host",
|
||||
Parameters: json.RawMessage(`{"type":"object"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
// user + function_call + function_call_output = 3
|
||||
// (assistant message with empty content + tool_calls → only function_call items emitted)
|
||||
require.Len(t, items, 3)
|
||||
|
||||
// Check function_call item
|
||||
assert.Equal(t, "function_call", items[1].Type)
|
||||
assert.Equal(t, "call_1", items[1].CallID)
|
||||
assert.Empty(t, items[1].ID)
|
||||
assert.Equal(t, "ping", items[1].Name)
|
||||
|
||||
// Check function_call_output item
|
||||
assert.Equal(t, "function_call_output", items[2].Type)
|
||||
assert.Equal(t, "call_1", items[2].CallID)
|
||||
assert.Equal(t, "pong", items[2].Output)
|
||||
|
||||
// Check tools
|
||||
require.Len(t, resp.Tools, 1)
|
||||
assert.Equal(t, "function", resp.Tools[0].Type)
|
||||
assert.Equal(t, "ping", resp.Tools[0].Name)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_MaxTokens(t *testing.T) {
|
||||
t.Run("max_tokens", func(t *testing.T) {
|
||||
maxTokens := 100
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
MaxTokens: &maxTokens,
|
||||
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||
}
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.MaxOutputTokens)
|
||||
// Below minMaxOutputTokens (128), should be clamped
|
||||
assert.Equal(t, minMaxOutputTokens, *resp.MaxOutputTokens)
|
||||
})
|
||||
|
||||
t.Run("max_completion_tokens_preferred", func(t *testing.T) {
|
||||
maxTokens := 100
|
||||
maxCompletion := 500
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
MaxTokens: &maxTokens,
|
||||
MaxCompletionTokens: &maxCompletion,
|
||||
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||
}
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.MaxOutputTokens)
|
||||
assert.Equal(t, 500, *resp.MaxOutputTokens)
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_ReasoningEffort(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
ReasoningEffort: "high",
|
||||
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||
}
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_ImageURL(t *testing.T) {
|
||||
content := `[{"type":"text","text":"Describe this"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc123"}}]`
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(content)},
|
||||
},
|
||||
}
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 1)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
|
||||
require.Len(t, parts, 2)
|
||||
assert.Equal(t, "input_text", parts[0].Type)
|
||||
assert.Equal(t, "Describe this", parts[0].Text)
|
||||
assert.Equal(t, "input_image", parts[1].Type)
|
||||
assert.Equal(t, "data:image/png;base64,abc123", parts[1].ImageURL)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||
},
|
||||
Functions: []ChatFunction{
|
||||
{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: json.RawMessage(`{"type":"object"}`),
|
||||
},
|
||||
},
|
||||
FunctionCall: json.RawMessage(`{"name":"get_weather"}`),
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Tools, 1)
|
||||
assert.Equal(t, "function", resp.Tools[0].Type)
|
||||
assert.Equal(t, "get_weather", resp.Tools[0].Name)
|
||||
|
||||
// tool_choice should be converted
|
||||
require.NotNil(t, resp.ToolChoice)
|
||||
var tc map[string]any
|
||||
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
|
||||
assert.Equal(t, "function", tc["type"])
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_ServiceTier(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
ServiceTier: "flex",
|
||||
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||
}
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "flex", resp.ServiceTier)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_AssistantWithTextAndToolCalls(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Do something"`)},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: json.RawMessage(`"Let me call a function."`),
|
||||
ToolCalls: []ChatToolCall{
|
||||
{
|
||||
ID: "call_abc",
|
||||
Type: "function",
|
||||
Function: ChatFunctionCall{
|
||||
Name: "do_thing",
|
||||
Arguments: `{}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
// user + assistant message (with text) + function_call
|
||||
require.Len(t, items, 3)
|
||||
assert.Equal(t, "user", items[0].Role)
|
||||
assert.Equal(t, "assistant", items[1].Role)
|
||||
assert.Equal(t, "function_call", items[2].Type)
|
||||
assert.Empty(t, items[2].ID)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_AssistantArrayContentPreserved(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||
{Role: "assistant", Content: json.RawMessage(`[{"type":"text","text":"A"},{"type":"text","text":"B"}]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 2)
|
||||
assert.Equal(t, "assistant", items[1].Role)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[1].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "output_text", parts[0].Type)
|
||||
assert.Equal(t, "AB", parts[0].Text)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_AssistantThinkingTagPreserved(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||
{Role: "assistant", Content: json.RawMessage(`[{"type":"thinking","thinking":"internal plan"},{"type":"text","text":"final answer"}]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 2)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[1].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "output_text", parts[0].Type)
|
||||
assert.Contains(t, parts[0].Text, "<thinking>internal plan</thinking>")
|
||||
assert.Contains(t, parts[0].Text, "final answer")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ResponsesToChatCompletions tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestResponsesToChatCompletions_BasicText(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_123",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "message",
|
||||
Content: []ResponsesContentPart{
|
||||
{Type: "output_text", Text: "Hello, world!"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
TotalTokens: 15,
|
||||
},
|
||||
}
|
||||
|
||||
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||
assert.Equal(t, "chat.completion", chat.Object)
|
||||
assert.Equal(t, "gpt-4o", chat.Model)
|
||||
require.Len(t, chat.Choices, 1)
|
||||
assert.Equal(t, "stop", chat.Choices[0].FinishReason)
|
||||
|
||||
var content string
|
||||
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
|
||||
assert.Equal(t, "Hello, world!", content)
|
||||
|
||||
require.NotNil(t, chat.Usage)
|
||||
assert.Equal(t, 10, chat.Usage.PromptTokens)
|
||||
assert.Equal(t, 5, chat.Usage.CompletionTokens)
|
||||
assert.Equal(t, 15, chat.Usage.TotalTokens)
|
||||
}
|
||||
|
||||
func TestResponsesToChatCompletions_ToolCalls(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_456",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "function_call",
|
||||
CallID: "call_xyz",
|
||||
Name: "get_weather",
|
||||
Arguments: `{"city":"NYC"}`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||
require.Len(t, chat.Choices, 1)
|
||||
assert.Equal(t, "tool_calls", chat.Choices[0].FinishReason)
|
||||
|
||||
msg := chat.Choices[0].Message
|
||||
require.Len(t, msg.ToolCalls, 1)
|
||||
assert.Equal(t, "call_xyz", msg.ToolCalls[0].ID)
|
||||
assert.Equal(t, "function", msg.ToolCalls[0].Type)
|
||||
assert.Equal(t, "get_weather", msg.ToolCalls[0].Function.Name)
|
||||
assert.Equal(t, `{"city":"NYC"}`, msg.ToolCalls[0].Function.Arguments)
|
||||
}
|
||||
|
||||
func TestResponsesToChatCompletions_Reasoning(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_789",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "reasoning",
|
||||
Summary: []ResponsesSummary{
|
||||
{Type: "summary_text", Text: "I thought about it."},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "message",
|
||||
Content: []ResponsesContentPart{
|
||||
{Type: "output_text", Text: "The answer is 42."},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||
require.Len(t, chat.Choices, 1)
|
||||
|
||||
var content string
|
||||
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
|
||||
assert.Equal(t, "The answer is 42.", content)
|
||||
assert.Equal(t, "I thought about it.", chat.Choices[0].Message.ReasoningContent)
|
||||
}
|
||||
|
||||
func TestResponsesToChatCompletions_Incomplete(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_inc",
|
||||
Status: "incomplete",
|
||||
IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"},
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "message",
|
||||
Content: []ResponsesContentPart{
|
||||
{Type: "output_text", Text: "partial..."},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||
require.Len(t, chat.Choices, 1)
|
||||
assert.Equal(t, "length", chat.Choices[0].FinishReason)
|
||||
}
|
||||
|
||||
func TestResponsesToChatCompletions_CachedTokens(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_cache",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "message",
|
||||
Content: []ResponsesContentPart{{Type: "output_text", Text: "cached"}},
|
||||
},
|
||||
},
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 10,
|
||||
TotalTokens: 110,
|
||||
InputTokensDetails: &ResponsesInputTokensDetails{
|
||||
CachedTokens: 80,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||
require.NotNil(t, chat.Usage)
|
||||
require.NotNil(t, chat.Usage.PromptTokensDetails)
|
||||
assert.Equal(t, 80, chat.Usage.PromptTokensDetails.CachedTokens)
|
||||
}
|
||||
|
||||
func TestResponsesToChatCompletions_WebSearch(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_ws",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "web_search_call",
|
||||
Action: &WebSearchAction{Type: "search", Query: "test"},
|
||||
},
|
||||
{
|
||||
Type: "message",
|
||||
Content: []ResponsesContentPart{{Type: "output_text", Text: "search results"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||
require.Len(t, chat.Choices, 1)
|
||||
assert.Equal(t, "stop", chat.Choices[0].FinishReason)
|
||||
|
||||
var content string
|
||||
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
|
||||
assert.Equal(t, "search results", content)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Streaming: ResponsesEventToChatChunks tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestResponsesEventToChatChunks_TextDelta(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
|
||||
// response.created → role chunk
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.created",
|
||||
Response: &ResponsesResponse{
|
||||
ID: "resp_stream",
|
||||
},
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
assert.Equal(t, "assistant", chunks[0].Choices[0].Delta.Role)
|
||||
assert.True(t, state.SentRole)
|
||||
|
||||
// response.output_text.delta → content chunk
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.output_text.delta",
|
||||
Delta: "Hello",
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].Delta.Content)
|
||||
assert.Equal(t, "Hello", *chunks[0].Choices[0].Delta.Content)
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_ToolCallDelta(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.SentRole = true
|
||||
|
||||
// response.output_item.added (function_call) — output_index=1 (e.g. after a message item at 0)
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.output_item.added",
|
||||
OutputIndex: 1,
|
||||
Item: &ResponsesOutput{
|
||||
Type: "function_call",
|
||||
CallID: "call_1",
|
||||
Name: "get_weather",
|
||||
},
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.Len(t, chunks[0].Choices[0].Delta.ToolCalls, 1)
|
||||
tc := chunks[0].Choices[0].Delta.ToolCalls[0]
|
||||
assert.Equal(t, "call_1", tc.ID)
|
||||
assert.Equal(t, "get_weather", tc.Function.Name)
|
||||
require.NotNil(t, tc.Index)
|
||||
assert.Equal(t, 0, *tc.Index)
|
||||
|
||||
// response.function_call_arguments.delta — uses output_index (NOT call_id) to find tool
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.function_call_arguments.delta",
|
||||
OutputIndex: 1, // matches the output_index from output_item.added above
|
||||
Delta: `{"city":`,
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
tc = chunks[0].Choices[0].Delta.ToolCalls[0]
|
||||
require.NotNil(t, tc.Index)
|
||||
assert.Equal(t, 0, *tc.Index, "argument delta must use same index as the tool call")
|
||||
assert.Equal(t, `{"city":`, tc.Function.Arguments)
|
||||
|
||||
// Add a second function call at output_index=2
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.output_item.added",
|
||||
OutputIndex: 2,
|
||||
Item: &ResponsesOutput{
|
||||
Type: "function_call",
|
||||
CallID: "call_2",
|
||||
Name: "get_time",
|
||||
},
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
tc = chunks[0].Choices[0].Delta.ToolCalls[0]
|
||||
require.NotNil(t, tc.Index)
|
||||
assert.Equal(t, 1, *tc.Index, "second tool call should get index 1")
|
||||
|
||||
// Argument delta for second tool call
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.function_call_arguments.delta",
|
||||
OutputIndex: 2,
|
||||
Delta: `{"tz":"UTC"}`,
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
tc = chunks[0].Choices[0].Delta.ToolCalls[0]
|
||||
require.NotNil(t, tc.Index)
|
||||
assert.Equal(t, 1, *tc.Index, "second tool arg delta must use index 1")
|
||||
|
||||
// Argument delta for first tool call (interleaved)
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.function_call_arguments.delta",
|
||||
OutputIndex: 1,
|
||||
Delta: `"Tokyo"}`,
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
tc = chunks[0].Choices[0].Delta.ToolCalls[0]
|
||||
require.NotNil(t, tc.Index)
|
||||
assert.Equal(t, 0, *tc.Index, "first tool arg delta must still use index 0")
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_Completed(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.IncludeUsage = true
|
||||
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.completed",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "completed",
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 50,
|
||||
OutputTokens: 20,
|
||||
TotalTokens: 70,
|
||||
InputTokensDetails: &ResponsesInputTokensDetails{
|
||||
CachedTokens: 30,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, state)
|
||||
// finish chunk + usage chunk
|
||||
require.Len(t, chunks, 2)
|
||||
|
||||
// First chunk: finish_reason
|
||||
require.NotNil(t, chunks[0].Choices[0].FinishReason)
|
||||
assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason)
|
||||
|
||||
// Second chunk: usage
|
||||
require.NotNil(t, chunks[1].Usage)
|
||||
assert.Equal(t, 50, chunks[1].Usage.PromptTokens)
|
||||
assert.Equal(t, 20, chunks[1].Usage.CompletionTokens)
|
||||
assert.Equal(t, 70, chunks[1].Usage.TotalTokens)
|
||||
require.NotNil(t, chunks[1].Usage.PromptTokensDetails)
|
||||
assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens)
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.SawToolCall = true
|
||||
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.completed",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "completed",
|
||||
},
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].FinishReason)
|
||||
assert.Equal(t, "tool_calls", *chunks[0].Choices[0].FinishReason)
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_ReasoningDelta(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.SentRole = true
|
||||
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.reasoning_summary_text.delta",
|
||||
Delta: "Thinking...",
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent)
|
||||
assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.ReasoningContent)
|
||||
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.reasoning_summary_text.done",
|
||||
}, state)
|
||||
require.Len(t, chunks, 0)
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_ReasoningThenTextAutoCloseTag(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.SentRole = true
|
||||
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.reasoning_summary_text.delta",
|
||||
Delta: "plan",
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent)
|
||||
assert.Equal(t, "plan", *chunks[0].Choices[0].Delta.ReasoningContent)
|
||||
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.output_text.delta",
|
||||
Delta: "answer",
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].Delta.Content)
|
||||
assert.Equal(t, "answer", *chunks[0].Choices[0].Delta.Content)
|
||||
}
|
||||
|
||||
func TestFinalizeResponsesChatStream(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.IncludeUsage = true
|
||||
state.Usage = &ChatUsage{
|
||||
PromptTokens: 100,
|
||||
CompletionTokens: 50,
|
||||
TotalTokens: 150,
|
||||
}
|
||||
|
||||
chunks := FinalizeResponsesChatStream(state)
|
||||
require.Len(t, chunks, 2)
|
||||
|
||||
// Finish chunk
|
||||
require.NotNil(t, chunks[0].Choices[0].FinishReason)
|
||||
assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason)
|
||||
|
||||
// Usage chunk
|
||||
require.NotNil(t, chunks[1].Usage)
|
||||
assert.Equal(t, 100, chunks[1].Usage.PromptTokens)
|
||||
|
||||
// Idempotent: second call returns nil
|
||||
assert.Nil(t, FinalizeResponsesChatStream(state))
|
||||
}
|
||||
|
||||
func TestFinalizeResponsesChatStream_AfterCompleted(t *testing.T) {
|
||||
// If response.completed already emitted the finish chunk, FinalizeResponsesChatStream
|
||||
// must be a no-op (prevents double finish_reason being sent to the client).
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.IncludeUsage = true
|
||||
|
||||
// Simulate response.completed
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.completed",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "completed",
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
TotalTokens: 15,
|
||||
},
|
||||
},
|
||||
}, state)
|
||||
require.NotEmpty(t, chunks) // finish + usage chunks
|
||||
|
||||
// Now FinalizeResponsesChatStream should return nil — already finalized.
|
||||
assert.Nil(t, FinalizeResponsesChatStream(state))
|
||||
}
|
||||
|
||||
func TestChatChunkToSSE(t *testing.T) {
|
||||
chunk := ChatCompletionsChunk{
|
||||
ID: "chatcmpl-test",
|
||||
Object: "chat.completion.chunk",
|
||||
Created: 1700000000,
|
||||
Model: "gpt-4o",
|
||||
Choices: []ChatChunkChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Delta: ChatDelta{Role: "assistant"},
|
||||
FinishReason: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
sse, err := ChatChunkToSSE(chunk)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, sse, "data: ")
|
||||
assert.Contains(t, sse, "chatcmpl-test")
|
||||
assert.Contains(t, sse, "assistant")
|
||||
assert.True(t, len(sse) > 10)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Stream round-trip test
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestChatCompletionsStreamRoundTrip(t *testing.T) {
|
||||
// Simulate: client sends chat completions request, upstream returns Responses SSE events.
|
||||
// Verify that the streaming state machine produces correct chat completions chunks.
|
||||
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.IncludeUsage = true
|
||||
|
||||
var allChunks []ChatCompletionsChunk
|
||||
|
||||
// 1. response.created
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.created",
|
||||
Response: &ResponsesResponse{ID: "resp_rt"},
|
||||
}, state)
|
||||
allChunks = append(allChunks, chunks...)
|
||||
|
||||
// 2. text deltas
|
||||
for _, text := range []string{"Hello", ", ", "world", "!"} {
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.output_text.delta",
|
||||
Delta: text,
|
||||
}, state)
|
||||
allChunks = append(allChunks, chunks...)
|
||||
}
|
||||
|
||||
// 3. response.completed
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.completed",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "completed",
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 4,
|
||||
TotalTokens: 14,
|
||||
},
|
||||
},
|
||||
}, state)
|
||||
allChunks = append(allChunks, chunks...)
|
||||
|
||||
// Verify: role chunk + 4 text chunks + finish chunk + usage chunk = 7
|
||||
require.Len(t, allChunks, 7)
|
||||
|
||||
// First chunk has role
|
||||
assert.Equal(t, "assistant", allChunks[0].Choices[0].Delta.Role)
|
||||
|
||||
// Text chunks
|
||||
var fullText string
|
||||
for i := 1; i <= 4; i++ {
|
||||
require.NotNil(t, allChunks[i].Choices[0].Delta.Content)
|
||||
fullText += *allChunks[i].Choices[0].Delta.Content
|
||||
}
|
||||
assert.Equal(t, "Hello, world!", fullText)
|
||||
|
||||
// Finish chunk
|
||||
require.NotNil(t, allChunks[5].Choices[0].FinishReason)
|
||||
assert.Equal(t, "stop", *allChunks[5].Choices[0].FinishReason)
|
||||
|
||||
// Usage chunk
|
||||
require.NotNil(t, allChunks[6].Usage)
|
||||
assert.Equal(t, 10, allChunks[6].Usage.PromptTokens)
|
||||
assert.Equal(t, 4, allChunks[6].Usage.CompletionTokens)
|
||||
|
||||
// All chunks share the same ID
|
||||
for _, c := range allChunks {
|
||||
assert.Equal(t, "resp_rt", c.ID)
|
||||
}
|
||||
}
|
||||
385
backend/internal/pkg/apicompat/chatcompletions_to_responses.go
Normal file
385
backend/internal/pkg/apicompat/chatcompletions_to_responses.go
Normal file
@@ -0,0 +1,385 @@
|
||||
package apicompat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ChatCompletionsToResponses converts a Chat Completions request into a
|
||||
// Responses API request. The upstream always streams, so Stream is forced to
|
||||
// true. store is always false and reasoning.encrypted_content is always
|
||||
// included so that the response translator has full context.
|
||||
func ChatCompletionsToResponses(req *ChatCompletionsRequest) (*ResponsesRequest, error) {
|
||||
input, err := convertChatMessagesToResponsesInput(req.Messages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inputJSON, err := json.Marshal(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := &ResponsesRequest{
|
||||
Model: req.Model,
|
||||
Input: inputJSON,
|
||||
Temperature: req.Temperature,
|
||||
TopP: req.TopP,
|
||||
Stream: true, // upstream always streams
|
||||
Include: []string{"reasoning.encrypted_content"},
|
||||
ServiceTier: req.ServiceTier,
|
||||
}
|
||||
|
||||
storeFalse := false
|
||||
out.Store = &storeFalse
|
||||
|
||||
// max_tokens / max_completion_tokens → max_output_tokens, prefer max_completion_tokens
|
||||
maxTokens := 0
|
||||
if req.MaxTokens != nil {
|
||||
maxTokens = *req.MaxTokens
|
||||
}
|
||||
if req.MaxCompletionTokens != nil {
|
||||
maxTokens = *req.MaxCompletionTokens
|
||||
}
|
||||
if maxTokens > 0 {
|
||||
v := maxTokens
|
||||
if v < minMaxOutputTokens {
|
||||
v = minMaxOutputTokens
|
||||
}
|
||||
out.MaxOutputTokens = &v
|
||||
}
|
||||
|
||||
// reasoning_effort → reasoning.effort + reasoning.summary="auto"
|
||||
if req.ReasoningEffort != "" {
|
||||
out.Reasoning = &ResponsesReasoning{
|
||||
Effort: req.ReasoningEffort,
|
||||
Summary: "auto",
|
||||
}
|
||||
}
|
||||
|
||||
// tools[] and legacy functions[] → ResponsesTool[]
|
||||
if len(req.Tools) > 0 || len(req.Functions) > 0 {
|
||||
out.Tools = convertChatToolsToResponses(req.Tools, req.Functions)
|
||||
}
|
||||
|
||||
// tool_choice: already compatible format — pass through directly.
|
||||
// Legacy function_call needs mapping.
|
||||
if len(req.ToolChoice) > 0 {
|
||||
out.ToolChoice = req.ToolChoice
|
||||
} else if len(req.FunctionCall) > 0 {
|
||||
tc, err := convertChatFunctionCallToToolChoice(req.FunctionCall)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert function_call: %w", err)
|
||||
}
|
||||
out.ToolChoice = tc
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// convertChatMessagesToResponsesInput converts the Chat Completions messages
|
||||
// array into a Responses API input items array.
|
||||
func convertChatMessagesToResponsesInput(msgs []ChatMessage) ([]ResponsesInputItem, error) {
|
||||
var out []ResponsesInputItem
|
||||
for _, m := range msgs {
|
||||
items, err := chatMessageToResponsesItems(m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, items...)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// chatMessageToResponsesItems converts a single ChatMessage into one or more
|
||||
// ResponsesInputItem values.
|
||||
func chatMessageToResponsesItems(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
switch m.Role {
|
||||
case "system":
|
||||
return chatSystemToResponses(m)
|
||||
case "user":
|
||||
return chatUserToResponses(m)
|
||||
case "assistant":
|
||||
return chatAssistantToResponses(m)
|
||||
case "tool":
|
||||
return chatToolToResponses(m)
|
||||
case "function":
|
||||
return chatFunctionToResponses(m)
|
||||
default:
|
||||
return chatUserToResponses(m)
|
||||
}
|
||||
}
|
||||
|
||||
// chatSystemToResponses converts a system message.
|
||||
func chatSystemToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
text, err := parseChatContent(m.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
content, err := json.Marshal(text)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []ResponsesInputItem{{Role: "system", Content: content}}, nil
|
||||
}
|
||||
|
||||
// chatUserToResponses converts a user message, handling both plain strings and
|
||||
// multi-modal content arrays.
|
||||
func chatUserToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
// Try plain string first.
|
||||
var s string
|
||||
if err := json.Unmarshal(m.Content, &s); err == nil {
|
||||
content, _ := json.Marshal(s)
|
||||
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
|
||||
}
|
||||
|
||||
var parts []ChatContentPart
|
||||
if err := json.Unmarshal(m.Content, &parts); err != nil {
|
||||
return nil, fmt.Errorf("parse user content: %w", err)
|
||||
}
|
||||
|
||||
var responseParts []ResponsesContentPart
|
||||
for _, p := range parts {
|
||||
switch p.Type {
|
||||
case "text":
|
||||
if p.Text != "" {
|
||||
responseParts = append(responseParts, ResponsesContentPart{
|
||||
Type: "input_text",
|
||||
Text: p.Text,
|
||||
})
|
||||
}
|
||||
case "image_url":
|
||||
if p.ImageURL != nil && p.ImageURL.URL != "" {
|
||||
responseParts = append(responseParts, ResponsesContentPart{
|
||||
Type: "input_image",
|
||||
ImageURL: p.ImageURL.URL,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
content, err := json.Marshal(responseParts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
|
||||
}
|
||||
|
||||
// chatAssistantToResponses converts an assistant message. If there is both
|
||||
// text content and tool_calls, the text is emitted as an assistant message
|
||||
// first, then each tool_call becomes a function_call item. If the content is
|
||||
// empty/nil and there are tool_calls, only function_call items are emitted.
|
||||
func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
var items []ResponsesInputItem
|
||||
|
||||
// Emit assistant message with output_text if content is non-empty.
|
||||
if len(m.Content) > 0 {
|
||||
s, err := parseAssistantContent(m.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s != "" {
|
||||
parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
|
||||
partsJSON, err := json.Marshal(parts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON})
|
||||
}
|
||||
}
|
||||
|
||||
// Emit one function_call item per tool_call.
|
||||
for _, tc := range m.ToolCalls {
|
||||
args := tc.Function.Arguments
|
||||
if args == "" {
|
||||
args = "{}"
|
||||
}
|
||||
items = append(items, ResponsesInputItem{
|
||||
Type: "function_call",
|
||||
CallID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: args,
|
||||
})
|
||||
}
|
||||
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// parseAssistantContent returns assistant content as plain text.
|
||||
//
|
||||
// Supported formats:
|
||||
// - JSON string
|
||||
// - JSON array of typed parts (e.g. [{"type":"text","text":"..."}])
|
||||
//
|
||||
// For structured thinking/reasoning parts, it preserves semantics by wrapping
|
||||
// the text in explicit tags so downstream can still distinguish it from normal text.
|
||||
func parseAssistantContent(raw json.RawMessage) (string, error) {
|
||||
if len(raw) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
var parts []map[string]any
|
||||
if err := json.Unmarshal(raw, &parts); err != nil {
|
||||
// Keep compatibility with prior behavior: unsupported assistant content
|
||||
// formats are ignored instead of failing the whole request conversion.
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
write := func(v string) error {
|
||||
_, err := b.WriteString(v)
|
||||
return err
|
||||
}
|
||||
for _, p := range parts {
|
||||
typ, _ := p["type"].(string)
|
||||
text, _ := p["text"].(string)
|
||||
thinking, _ := p["thinking"].(string)
|
||||
|
||||
switch typ {
|
||||
case "thinking", "reasoning":
|
||||
if thinking != "" {
|
||||
if err := write("<thinking>"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := write(thinking); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := write("</thinking>"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
} else if text != "" {
|
||||
if err := write("<thinking>"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := write(text); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := write("</thinking>"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
default:
|
||||
if text != "" {
|
||||
if err := write(text); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
|
||||
// chatToolToResponses converts a tool result message (role=tool) into a
|
||||
// function_call_output item.
|
||||
func chatToolToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
output, err := parseChatContent(m.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if output == "" {
|
||||
output = "(empty)"
|
||||
}
|
||||
return []ResponsesInputItem{{
|
||||
Type: "function_call_output",
|
||||
CallID: m.ToolCallID,
|
||||
Output: output,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
// chatFunctionToResponses converts a legacy function result message
|
||||
// (role=function) into a function_call_output item. The Name field is used as
|
||||
// call_id since legacy function calls do not carry a separate call_id.
|
||||
func chatFunctionToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
output, err := parseChatContent(m.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if output == "" {
|
||||
output = "(empty)"
|
||||
}
|
||||
return []ResponsesInputItem{{
|
||||
Type: "function_call_output",
|
||||
CallID: m.Name,
|
||||
Output: output,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
// parseChatContent returns the string value of a ChatMessage Content field.
|
||||
// Content must be a JSON string. Returns "" if content is null or empty.
|
||||
func parseChatContent(raw json.RawMessage) (string, error) {
|
||||
if len(raw) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err != nil {
|
||||
return "", fmt.Errorf("parse content as string: %w", err)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// convertChatToolsToResponses maps Chat Completions tool definitions and legacy
|
||||
// function definitions to Responses API tool definitions.
|
||||
func convertChatToolsToResponses(tools []ChatTool, functions []ChatFunction) []ResponsesTool {
|
||||
var out []ResponsesTool
|
||||
|
||||
for _, t := range tools {
|
||||
if t.Type != "function" || t.Function == nil {
|
||||
continue
|
||||
}
|
||||
rt := ResponsesTool{
|
||||
Type: "function",
|
||||
Name: t.Function.Name,
|
||||
Description: t.Function.Description,
|
||||
Parameters: t.Function.Parameters,
|
||||
Strict: t.Function.Strict,
|
||||
}
|
||||
out = append(out, rt)
|
||||
}
|
||||
|
||||
// Legacy functions[] are treated as function-type tools.
|
||||
for _, f := range functions {
|
||||
rt := ResponsesTool{
|
||||
Type: "function",
|
||||
Name: f.Name,
|
||||
Description: f.Description,
|
||||
Parameters: f.Parameters,
|
||||
Strict: f.Strict,
|
||||
}
|
||||
out = append(out, rt)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// convertChatFunctionCallToToolChoice maps the legacy function_call field to a
|
||||
// Responses API tool_choice value.
|
||||
//
|
||||
// "auto" → "auto"
|
||||
// "none" → "none"
|
||||
// {"name":"X"} → {"type":"function","function":{"name":"X"}}
|
||||
func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage, error) {
|
||||
// Try string first ("auto", "none", etc.) — pass through as-is.
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return json.Marshal(s)
|
||||
}
|
||||
|
||||
// Object form: {"name":"X"}
|
||||
var obj struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &obj); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]string{"name": obj.Name},
|
||||
})
|
||||
}
|
||||
374
backend/internal/pkg/apicompat/responses_to_chatcompletions.go
Normal file
374
backend/internal/pkg/apicompat/responses_to_chatcompletions.go
Normal file
@@ -0,0 +1,374 @@
|
||||
package apicompat
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Non-streaming: ResponsesResponse → ChatCompletionsResponse
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ResponsesToChatCompletions converts a Responses API response into a Chat
|
||||
// Completions response. Text output items are concatenated into
|
||||
// choices[0].message.content; function_call items become tool_calls.
|
||||
func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatCompletionsResponse {
|
||||
id := resp.ID
|
||||
if id == "" {
|
||||
id = generateChatCmplID()
|
||||
}
|
||||
|
||||
out := &ChatCompletionsResponse{
|
||||
ID: id,
|
||||
Object: "chat.completion",
|
||||
Created: time.Now().Unix(),
|
||||
Model: model,
|
||||
}
|
||||
|
||||
var contentText string
|
||||
var reasoningText string
|
||||
var toolCalls []ChatToolCall
|
||||
|
||||
for _, item := range resp.Output {
|
||||
switch item.Type {
|
||||
case "message":
|
||||
for _, part := range item.Content {
|
||||
if part.Type == "output_text" && part.Text != "" {
|
||||
contentText += part.Text
|
||||
}
|
||||
}
|
||||
case "function_call":
|
||||
toolCalls = append(toolCalls, ChatToolCall{
|
||||
ID: item.CallID,
|
||||
Type: "function",
|
||||
Function: ChatFunctionCall{
|
||||
Name: item.Name,
|
||||
Arguments: item.Arguments,
|
||||
},
|
||||
})
|
||||
case "reasoning":
|
||||
for _, s := range item.Summary {
|
||||
if s.Type == "summary_text" && s.Text != "" {
|
||||
reasoningText += s.Text
|
||||
}
|
||||
}
|
||||
case "web_search_call":
|
||||
// silently consumed — results already incorporated into text output
|
||||
}
|
||||
}
|
||||
|
||||
msg := ChatMessage{Role: "assistant"}
|
||||
if len(toolCalls) > 0 {
|
||||
msg.ToolCalls = toolCalls
|
||||
}
|
||||
if contentText != "" {
|
||||
raw, _ := json.Marshal(contentText)
|
||||
msg.Content = raw
|
||||
}
|
||||
if reasoningText != "" {
|
||||
msg.ReasoningContent = reasoningText
|
||||
}
|
||||
|
||||
finishReason := responsesStatusToChatFinishReason(resp.Status, resp.IncompleteDetails, toolCalls)
|
||||
|
||||
out.Choices = []ChatChoice{{
|
||||
Index: 0,
|
||||
Message: msg,
|
||||
FinishReason: finishReason,
|
||||
}}
|
||||
|
||||
if resp.Usage != nil {
|
||||
usage := &ChatUsage{
|
||||
PromptTokens: resp.Usage.InputTokens,
|
||||
CompletionTokens: resp.Usage.OutputTokens,
|
||||
TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens,
|
||||
}
|
||||
if resp.Usage.InputTokensDetails != nil && resp.Usage.InputTokensDetails.CachedTokens > 0 {
|
||||
usage.PromptTokensDetails = &ChatTokenDetails{
|
||||
CachedTokens: resp.Usage.InputTokensDetails.CachedTokens,
|
||||
}
|
||||
}
|
||||
out.Usage = usage
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func responsesStatusToChatFinishReason(status string, details *ResponsesIncompleteDetails, toolCalls []ChatToolCall) string {
|
||||
switch status {
|
||||
case "incomplete":
|
||||
if details != nil && details.Reason == "max_output_tokens" {
|
||||
return "length"
|
||||
}
|
||||
return "stop"
|
||||
case "completed":
|
||||
if len(toolCalls) > 0 {
|
||||
return "tool_calls"
|
||||
}
|
||||
return "stop"
|
||||
default:
|
||||
return "stop"
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Streaming: ResponsesStreamEvent → []ChatCompletionsChunk (stateful converter)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ResponsesEventToChatState tracks state for converting a sequence of Responses
|
||||
// SSE events into Chat Completions SSE chunks.
|
||||
type ResponsesEventToChatState struct {
|
||||
ID string
|
||||
Model string
|
||||
Created int64
|
||||
SentRole bool
|
||||
SawToolCall bool
|
||||
SawText bool
|
||||
Finalized bool // true after finish chunk has been emitted
|
||||
NextToolCallIndex int // next sequential tool_call index to assign
|
||||
OutputIndexToToolIndex map[int]int // Responses output_index → Chat tool_calls index
|
||||
IncludeUsage bool
|
||||
Usage *ChatUsage
|
||||
}
|
||||
|
||||
// NewResponsesEventToChatState returns an initialised stream state.
|
||||
func NewResponsesEventToChatState() *ResponsesEventToChatState {
|
||||
return &ResponsesEventToChatState{
|
||||
ID: generateChatCmplID(),
|
||||
Created: time.Now().Unix(),
|
||||
OutputIndexToToolIndex: make(map[int]int),
|
||||
}
|
||||
}
|
||||
|
||||
// ResponsesEventToChatChunks converts a single Responses SSE event into zero
|
||||
// or more Chat Completions chunks, updating state as it goes.
|
||||
func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
switch evt.Type {
|
||||
case "response.created":
|
||||
return resToChatHandleCreated(evt, state)
|
||||
case "response.output_text.delta":
|
||||
return resToChatHandleTextDelta(evt, state)
|
||||
case "response.output_item.added":
|
||||
return resToChatHandleOutputItemAdded(evt, state)
|
||||
case "response.function_call_arguments.delta":
|
||||
return resToChatHandleFuncArgsDelta(evt, state)
|
||||
case "response.reasoning_summary_text.delta":
|
||||
return resToChatHandleReasoningDelta(evt, state)
|
||||
case "response.reasoning_summary_text.done":
|
||||
return nil
|
||||
case "response.completed", "response.incomplete", "response.failed":
|
||||
return resToChatHandleCompleted(evt, state)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// FinalizeResponsesChatStream emits a final chunk with finish_reason if the
|
||||
// stream ended without a proper completion event (e.g. upstream disconnect).
|
||||
// It is idempotent: if a completion event already emitted the finish chunk,
|
||||
// this returns nil.
|
||||
func FinalizeResponsesChatStream(state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
if state.Finalized {
|
||||
return nil
|
||||
}
|
||||
state.Finalized = true
|
||||
|
||||
finishReason := "stop"
|
||||
if state.SawToolCall {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
|
||||
chunks := []ChatCompletionsChunk{makeChatFinishChunk(state, finishReason)}
|
||||
|
||||
if state.IncludeUsage && state.Usage != nil {
|
||||
chunks = append(chunks, ChatCompletionsChunk{
|
||||
ID: state.ID,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: state.Created,
|
||||
Model: state.Model,
|
||||
Choices: []ChatChunkChoice{},
|
||||
Usage: state.Usage,
|
||||
})
|
||||
}
|
||||
|
||||
return chunks
|
||||
}
|
||||
|
||||
// ChatChunkToSSE formats a ChatCompletionsChunk as an SSE data line.
|
||||
func ChatChunkToSSE(chunk ChatCompletionsChunk) (string, error) {
|
||||
data, err := json.Marshal(chunk)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("data: %s\n\n", data), nil
|
||||
}
|
||||
|
||||
// --- internal handlers ---
|
||||
|
||||
func resToChatHandleCreated(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
if evt.Response != nil {
|
||||
if evt.Response.ID != "" {
|
||||
state.ID = evt.Response.ID
|
||||
}
|
||||
if state.Model == "" && evt.Response.Model != "" {
|
||||
state.Model = evt.Response.Model
|
||||
}
|
||||
}
|
||||
// Emit the role chunk.
|
||||
if state.SentRole {
|
||||
return nil
|
||||
}
|
||||
state.SentRole = true
|
||||
|
||||
role := "assistant"
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Role: role})}
|
||||
}
|
||||
|
||||
func resToChatHandleTextDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
if evt.Delta == "" {
|
||||
return nil
|
||||
}
|
||||
state.SawText = true
|
||||
content := evt.Delta
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})}
|
||||
}
|
||||
|
||||
func resToChatHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
if evt.Item == nil || evt.Item.Type != "function_call" {
|
||||
return nil
|
||||
}
|
||||
|
||||
state.SawToolCall = true
|
||||
idx := state.NextToolCallIndex
|
||||
state.OutputIndexToToolIndex[evt.OutputIndex] = idx
|
||||
state.NextToolCallIndex++
|
||||
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{
|
||||
ToolCalls: []ChatToolCall{{
|
||||
Index: &idx,
|
||||
ID: evt.Item.CallID,
|
||||
Type: "function",
|
||||
Function: ChatFunctionCall{
|
||||
Name: evt.Item.Name,
|
||||
},
|
||||
}},
|
||||
})}
|
||||
}
|
||||
|
||||
func resToChatHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
if evt.Delta == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
idx, ok := state.OutputIndexToToolIndex[evt.OutputIndex]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{
|
||||
ToolCalls: []ChatToolCall{{
|
||||
Index: &idx,
|
||||
Function: ChatFunctionCall{
|
||||
Arguments: evt.Delta,
|
||||
},
|
||||
}},
|
||||
})}
|
||||
}
|
||||
|
||||
func resToChatHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
if evt.Delta == "" {
|
||||
return nil
|
||||
}
|
||||
reasoning := evt.Delta
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{ReasoningContent: &reasoning})}
|
||||
}
|
||||
|
||||
func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
state.Finalized = true
|
||||
finishReason := "stop"
|
||||
|
||||
if evt.Response != nil {
|
||||
if evt.Response.Usage != nil {
|
||||
u := evt.Response.Usage
|
||||
usage := &ChatUsage{
|
||||
PromptTokens: u.InputTokens,
|
||||
CompletionTokens: u.OutputTokens,
|
||||
TotalTokens: u.InputTokens + u.OutputTokens,
|
||||
}
|
||||
if u.InputTokensDetails != nil && u.InputTokensDetails.CachedTokens > 0 {
|
||||
usage.PromptTokensDetails = &ChatTokenDetails{
|
||||
CachedTokens: u.InputTokensDetails.CachedTokens,
|
||||
}
|
||||
}
|
||||
state.Usage = usage
|
||||
}
|
||||
|
||||
switch evt.Response.Status {
|
||||
case "incomplete":
|
||||
if evt.Response.IncompleteDetails != nil && evt.Response.IncompleteDetails.Reason == "max_output_tokens" {
|
||||
finishReason = "length"
|
||||
}
|
||||
case "completed":
|
||||
if state.SawToolCall {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
}
|
||||
} else if state.SawToolCall {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
|
||||
var chunks []ChatCompletionsChunk
|
||||
chunks = append(chunks, makeChatFinishChunk(state, finishReason))
|
||||
|
||||
if state.IncludeUsage && state.Usage != nil {
|
||||
chunks = append(chunks, ChatCompletionsChunk{
|
||||
ID: state.ID,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: state.Created,
|
||||
Model: state.Model,
|
||||
Choices: []ChatChunkChoice{},
|
||||
Usage: state.Usage,
|
||||
})
|
||||
}
|
||||
|
||||
return chunks
|
||||
}
|
||||
|
||||
func makeChatDeltaChunk(state *ResponsesEventToChatState, delta ChatDelta) ChatCompletionsChunk {
|
||||
return ChatCompletionsChunk{
|
||||
ID: state.ID,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: state.Created,
|
||||
Model: state.Model,
|
||||
Choices: []ChatChunkChoice{{
|
||||
Index: 0,
|
||||
Delta: delta,
|
||||
FinishReason: nil,
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
func makeChatFinishChunk(state *ResponsesEventToChatState, finishReason string) ChatCompletionsChunk {
|
||||
empty := ""
|
||||
return ChatCompletionsChunk{
|
||||
ID: state.ID,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: state.Created,
|
||||
Model: state.Model,
|
||||
Choices: []ChatChunkChoice{{
|
||||
Index: 0,
|
||||
Delta: ChatDelta{Content: &empty},
|
||||
FinishReason: &finishReason,
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
// generateChatCmplID produces an ID of the form "chatcmpl-" followed by 24
// lowercase hex digits (12 random bytes). An error from the random source is
// ignored.
func generateChatCmplID() string {
	var buf [12]byte
	_, _ = rand.Read(buf[:])
	return "chatcmpl-" + hex.EncodeToString(buf[:])
}
|
||||
@@ -329,6 +329,150 @@ type ResponsesStreamEvent struct {
|
||||
SequenceNumber int `json:"sequence_number,omitempty"`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// OpenAI Chat Completions API types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ChatCompletionsRequest is the request body for POST /v1/chat/completions.
|
||||
type ChatCompletionsRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []ChatMessage `json:"messages"`
|
||||
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *ChatStreamOptions `json:"stream_options,omitempty"`
|
||||
Tools []ChatTool `json:"tools,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"` // "low" | "medium" | "high"
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
Stop json.RawMessage `json:"stop,omitempty"` // string or []string
|
||||
|
||||
// Legacy function calling (deprecated but still supported)
|
||||
Functions []ChatFunction `json:"functions,omitempty"`
|
||||
FunctionCall json.RawMessage `json:"function_call,omitempty"`
|
||||
}
|
||||
|
||||
// ChatStreamOptions holds the stream_options object of a request.
type ChatStreamOptions struct {
	// IncludeUsage maps to the "include_usage" flag.
	IncludeUsage bool `json:"include_usage,omitempty"`
}
|
||||
|
||||
// ChatMessage is a single message in the Chat Completions conversation.
|
||||
type ChatMessage struct {
|
||||
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
|
||||
// Legacy function calling
|
||||
FunctionCall *ChatFunctionCall `json:"function_call,omitempty"`
|
||||
}
|
||||
|
||||
// ChatContentPart is a typed content part in a multi-modal message.
|
||||
type ChatContentPart struct {
|
||||
Type string `json:"type"` // "text" | "image_url"
|
||||
Text string `json:"text,omitempty"`
|
||||
ImageURL *ChatImageURL `json:"image_url,omitempty"`
|
||||
}
|
||||
|
||||
// ChatImageURL is the payload of an image_url content part.
type ChatImageURL struct {
	URL string `json:"url"`
	// Detail is "auto" | "low" | "high".
	Detail string `json:"detail,omitempty"`
}
|
||||
|
||||
// ChatTool describes a tool available to the model.
|
||||
type ChatTool struct {
|
||||
Type string `json:"type"` // "function"
|
||||
Function *ChatFunction `json:"function,omitempty"`
|
||||
}
|
||||
|
||||
// ChatFunction is the definition of a callable function tool.
type ChatFunction struct {
	Name        string `json:"name"`
	Description string `json:"description,omitempty"`
	// Parameters is kept as raw JSON.
	Parameters json.RawMessage `json:"parameters,omitempty"`
	Strict     *bool           `json:"strict,omitempty"`
}
|
||||
|
||||
// ChatToolCall represents a tool call made by the assistant.
|
||||
// Index is only populated in streaming chunks (omitted in non-streaming responses).
|
||||
type ChatToolCall struct {
|
||||
Index *int `json:"index,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Type string `json:"type,omitempty"` // "function"
|
||||
Function ChatFunctionCall `json:"function"`
|
||||
}
|
||||
|
||||
// ChatFunctionCall holds the name of the called function and its arguments.
// Both fields are always serialised (neither tag has omitempty).
type ChatFunctionCall struct {
	Name      string `json:"name"`
	Arguments string `json:"arguments"`
}
|
||||
|
||||
// ChatCompletionsResponse is the non-streaming response from POST /v1/chat/completions.
|
||||
type ChatCompletionsResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"` // "chat.completion"
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []ChatChoice `json:"choices"`
|
||||
Usage *ChatUsage `json:"usage,omitempty"`
|
||||
SystemFingerprint string `json:"system_fingerprint,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
}
|
||||
|
||||
// ChatChoice is a single completion choice.
|
||||
type ChatChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message ChatMessage `json:"message"`
|
||||
FinishReason string `json:"finish_reason"` // "stop" | "length" | "tool_calls" | "content_filter"
|
||||
}
|
||||
|
||||
// ChatUsage holds token counts in Chat Completions format.
|
||||
type ChatUsage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
PromptTokensDetails *ChatTokenDetails `json:"prompt_tokens_details,omitempty"`
|
||||
}
|
||||
|
||||
// ChatTokenDetails breaks token usage down further.
type ChatTokenDetails struct {
	// CachedTokens is omitted from the JSON when zero.
	CachedTokens int `json:"cached_tokens,omitempty"`
}
|
||||
|
||||
// ChatCompletionsChunk is a single streaming chunk from POST /v1/chat/completions.
|
||||
type ChatCompletionsChunk struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"` // "chat.completion.chunk"
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []ChatChunkChoice `json:"choices"`
|
||||
Usage *ChatUsage `json:"usage,omitempty"`
|
||||
SystemFingerprint string `json:"system_fingerprint,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
}
|
||||
|
||||
// ChatChunkChoice is a single choice in a streaming chunk.
|
||||
type ChatChunkChoice struct {
|
||||
Index int `json:"index"`
|
||||
Delta ChatDelta `json:"delta"`
|
||||
FinishReason *string `json:"finish_reason"` // pointer: null when not final
|
||||
}
|
||||
|
||||
// ChatDelta carries incremental content in a streaming chunk.
|
||||
type ChatDelta struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
|
||||
ReasoningContent *string `json:"reasoning_content,omitempty"`
|
||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared constants
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@@ -18,10 +18,12 @@ func DefaultModels() []Model {
|
||||
return []Model{
|
||||
{Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.5-flash-image", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3.1-flash-image", SupportedGenerationMethods: methods},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
28
backend/internal/pkg/gemini/models_test.go
Normal file
28
backend/internal/pkg/gemini/models_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package gemini
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDefaultModels_ContainsImageModels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
models := DefaultModels()
|
||||
byName := make(map[string]Model, len(models))
|
||||
for _, model := range models {
|
||||
byName[model.Name] = model
|
||||
}
|
||||
|
||||
required := []string{
|
||||
"models/gemini-2.5-flash-image",
|
||||
"models/gemini-3.1-flash-image",
|
||||
}
|
||||
|
||||
for _, name := range required {
|
||||
model, ok := byName[name]
|
||||
if !ok {
|
||||
t.Fatalf("expected fallback model %q to exist", name)
|
||||
}
|
||||
if len(model.SupportedGenerationMethods) == 0 {
|
||||
t.Fatalf("expected fallback model %q to advertise generation methods", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -13,10 +13,12 @@ type Model struct {
|
||||
var DefaultModels = []Model{
|
||||
{ID: "gemini-2.0-flash", Type: "model", DisplayName: "Gemini 2.0 Flash", CreatedAt: ""},
|
||||
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
|
||||
{ID: "gemini-2.5-flash-image", Type: "model", DisplayName: "Gemini 2.5 Flash Image", CreatedAt: ""},
|
||||
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
|
||||
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
|
||||
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
|
||||
{ID: "gemini-3.1-pro-preview", Type: "model", DisplayName: "Gemini 3.1 Pro Preview", CreatedAt: ""},
|
||||
{ID: "gemini-3.1-flash-image", Type: "model", DisplayName: "Gemini 3.1 Flash Image", CreatedAt: ""},
|
||||
}
|
||||
|
||||
// DefaultTestModel is the default model to preselect in test flows.
|
||||
|
||||
23
backend/internal/pkg/geminicli/models_test.go
Normal file
23
backend/internal/pkg/geminicli/models_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package geminicli
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDefaultModels_ContainsImageModels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
byID := make(map[string]Model, len(DefaultModels))
|
||||
for _, model := range DefaultModels {
|
||||
byID[model.ID] = model
|
||||
}
|
||||
|
||||
required := []string{
|
||||
"gemini-2.5-flash-image",
|
||||
"gemini-3.1-flash-image",
|
||||
}
|
||||
|
||||
for _, id := range required {
|
||||
if _, ok := byID[id]; !ok {
|
||||
t.Fatalf("expected curated Gemini model %q to exist", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -47,6 +47,15 @@ func Created(c *gin.Context, data any) {
|
||||
})
|
||||
}
|
||||
|
||||
// Accepted 返回异步接受响应 (HTTP 202)
|
||||
func Accepted(c *gin.Context, data any) {
|
||||
c.JSON(http.StatusAccepted, Response{
|
||||
Code: 0,
|
||||
Message: "accepted",
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
// Error 返回错误响应
|
||||
func Error(c *gin.Context, statusCode int, message string) {
|
||||
c.JSON(statusCode, Response{
|
||||
|
||||
@@ -3,6 +3,28 @@ package usagestats
|
||||
|
||||
import "time"
|
||||
|
||||
// Recognised model-source identifiers.
const (
	ModelSourceRequested = "requested"
	ModelSourceUpstream  = "upstream"
	ModelSourceMapping   = "mapping"
)

// IsValidModelSource reports whether source is exactly one of the
// ModelSource* constants. The comparison is case-sensitive.
func IsValidModelSource(source string) bool {
	return source == ModelSourceRequested ||
		source == ModelSourceUpstream ||
		source == ModelSourceMapping
}

// NormalizeModelSource returns source unchanged when it is valid and falls
// back to ModelSourceRequested otherwise (including for the empty string).
func NormalizeModelSource(source string) string {
	if !IsValidModelSource(source) {
		return ModelSourceRequested
	}
	return source
}
|
||||
|
||||
// DashboardStats 仪表盘统计
|
||||
type DashboardStats struct {
|
||||
// 用户统计
|
||||
@@ -81,6 +103,22 @@ type ModelStat struct {
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// EndpointStat holds the usage statistics of one request endpoint.
type EndpointStat struct {
	Endpoint    string `json:"endpoint"`
	Requests    int64  `json:"requests"`
	TotalTokens int64  `json:"total_tokens"`
	// Cost is the standard billed amount.
	Cost float64 `json:"cost"`
	// ActualCost is the amount actually deducted.
	ActualCost float64 `json:"actual_cost"`
}
|
||||
|
||||
// GroupUsageSummary holds today's cost and the cumulative cost of one group.
type GroupUsageSummary struct {
	GroupID   int64   `json:"group_id"`
	TodayCost float64 `json:"today_cost"`
	TotalCost float64 `json:"total_cost"`
}
|
||||
|
||||
// GroupStat represents usage statistics for a single group
|
||||
type GroupStat struct {
|
||||
GroupID int64 `json:"group_id"`
|
||||
@@ -96,12 +134,49 @@ type UserUsageTrendPoint struct {
|
||||
Date string `json:"date"`
|
||||
UserID int64 `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// UserSpendingRankingItem is one row of the user spending ranking.
type UserSpendingRankingItem struct {
	UserID int64  `json:"user_id"`
	Email  string `json:"email"`
	// ActualCost is the amount actually deducted.
	ActualCost float64 `json:"actual_cost"`
	Requests   int64   `json:"requests"`
	Tokens     int64   `json:"tokens"`
}
|
||||
|
||||
// UserSpendingRankingResponse represents ranking rows plus total spend for the time range.
|
||||
type UserSpendingRankingResponse struct {
|
||||
Ranking []UserSpendingRankingItem `json:"ranking"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
}
|
||||
|
||||
// UserBreakdownItem is the per-user usage breakdown within one dimension
// (group, model, endpoint).
type UserBreakdownItem struct {
	UserID      int64  `json:"user_id"`
	Email       string `json:"email"`
	Requests    int64  `json:"requests"`
	TotalTokens int64  `json:"total_tokens"`
	// Cost is the standard billed amount.
	Cost float64 `json:"cost"`
	// ActualCost is the amount actually deducted.
	ActualCost float64 `json:"actual_cost"`
}
|
||||
|
||||
// UserBreakdownDimension selects the dimension a user breakdown is filtered by.
type UserBreakdownDimension struct {
	// GroupID filters by group_id (>0 to enable).
	GroupID int64
	// Model filters by model name (non-empty to enable).
	Model string
	// ModelType is "requested", "upstream", or "mapping".
	ModelType string
	// Endpoint filters by endpoint value (non-empty to enable).
	Endpoint string
	// EndpointType is "inbound", "upstream", or "path".
	EndpointType string
}
|
||||
|
||||
// APIKeyUsageTrendPoint represents API key usage trend data point
|
||||
type APIKeyUsageTrendPoint struct {
|
||||
Date string `json:"date"`
|
||||
@@ -163,15 +238,18 @@ type UsageLogFilters struct {
|
||||
|
||||
// UsageStats represents usage statistics
|
||||
type UsageStats struct {
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
TotalAccountCost *float64 `json:"total_account_cost,omitempty"`
|
||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
TotalAccountCost *float64 `json:"total_account_cost,omitempty"`
|
||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||
Endpoints []EndpointStat `json:"endpoints,omitempty"`
|
||||
UpstreamEndpoints []EndpointStat `json:"upstream_endpoints,omitempty"`
|
||||
EndpointPaths []EndpointStat `json:"endpoint_paths,omitempty"`
|
||||
}
|
||||
|
||||
// BatchUserUsageStats represents usage stats for a single user
|
||||
@@ -238,7 +316,9 @@ type AccountUsageSummary struct {
|
||||
|
||||
// AccountUsageStatsResponse represents the full usage statistics response for an account
|
||||
type AccountUsageStatsResponse struct {
|
||||
History []AccountUsageHistory `json:"history"`
|
||||
Summary AccountUsageSummary `json:"summary"`
|
||||
Models []ModelStat `json:"models"`
|
||||
History []AccountUsageHistory `json:"history"`
|
||||
Summary AccountUsageSummary `json:"summary"`
|
||||
Models []ModelStat `json:"models"`
|
||||
Endpoints []EndpointStat `json:"endpoints"`
|
||||
UpstreamEndpoints []EndpointStat `json:"upstream_endpoints"`
|
||||
}
|
||||
|
||||
47
backend/internal/pkg/usagestats/usage_log_types_test.go
Normal file
47
backend/internal/pkg/usagestats/usage_log_types_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package usagestats
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestIsValidModelSource(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
source string
|
||||
want bool
|
||||
}{
|
||||
{name: "requested", source: ModelSourceRequested, want: true},
|
||||
{name: "upstream", source: ModelSourceUpstream, want: true},
|
||||
{name: "mapping", source: ModelSourceMapping, want: true},
|
||||
{name: "invalid", source: "foobar", want: false},
|
||||
{name: "empty", source: "", want: false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := IsValidModelSource(tc.source); got != tc.want {
|
||||
t.Fatalf("IsValidModelSource(%q)=%v want %v", tc.source, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeModelSource(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
source string
|
||||
want string
|
||||
}{
|
||||
{name: "requested", source: ModelSourceRequested, want: ModelSourceRequested},
|
||||
{name: "upstream", source: ModelSourceUpstream, want: ModelSourceUpstream},
|
||||
{name: "mapping", source: ModelSourceMapping, want: ModelSourceMapping},
|
||||
{name: "invalid falls back", source: "foobar", want: ModelSourceRequested},
|
||||
{name: "empty falls back", source: "", want: ModelSourceRequested},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := NormalizeModelSource(tc.source); got != tc.want {
|
||||
t.Fatalf("NormalizeModelSource(%q)=%q want %q", tc.source, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
@@ -50,6 +51,19 @@ type accountRepository struct {
|
||||
schedulerCache service.SchedulerCache
|
||||
}
|
||||
|
||||
// Extra-field keys regarded as scheduler-neutral.
// NOTE(review): presumably consulted by isSchedulerNeutralExtraKey, whose body
// is not visible here — confirm how prefixes and exact keys are matched.
var (
	// schedulerNeutralExtraKeyPrefixes lists key prefixes.
	schedulerNeutralExtraKeyPrefixes = []string{
		"codex_primary_",
		"codex_secondary_",
		"codex_5h_",
		"codex_7d_",
		"passive_usage_",
	}

	// schedulerNeutralExtraKeys is a set of exact key names.
	schedulerNeutralExtraKeys = map[string]struct{}{
		"codex_usage_updated_at":     {},
		"session_window_utilization": {},
	}
)
|
||||
|
||||
// NewAccountRepository 创建账户仓储实例。
|
||||
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
|
||||
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
|
||||
@@ -384,9 +398,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
|
||||
}
|
||||
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
|
||||
r.syncSchedulerAccountSnapshot(ctx, account.ID)
|
||||
}
|
||||
// 普通账号编辑(如 model_mapping / credentials)也需要立即刷新单账号快照,
|
||||
// 否则网关在 outbox worker 延迟或异常时仍可能读到旧配置。
|
||||
r.syncSchedulerAccountSnapshot(ctx, account.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -460,7 +474,9 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
|
||||
if search != "" {
|
||||
q = q.Where(dbaccount.NameContainsFold(search))
|
||||
}
|
||||
if groupID > 0 {
|
||||
if groupID == service.AccountListGroupUngrouped {
|
||||
q = q.Where(dbaccount.Not(dbaccount.HasAccountGroups()))
|
||||
} else if groupID > 0 {
|
||||
q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID)))
|
||||
}
|
||||
|
||||
@@ -1185,12 +1201,48 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
|
||||
if affected == 0 {
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
|
||||
if shouldEnqueueSchedulerOutboxForExtraUpdates(updates) {
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
|
||||
}
|
||||
} else {
|
||||
// 观测型 extra 字段不需要触发 bucket 重建,但仍同步单账号快照,
|
||||
// 让 sticky session / GetAccount 命中缓存时也能读到最新数据,
|
||||
// 同时避免缓存局部 patch 覆盖掉并发写入的其它账号字段。
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func shouldEnqueueSchedulerOutboxForExtraUpdates(updates map[string]any) bool {
|
||||
if len(updates) == 0 {
|
||||
return false
|
||||
}
|
||||
for key := range updates {
|
||||
if isSchedulerNeutralExtraKey(key) {
|
||||
continue
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isSchedulerNeutralExtraKey(key string) bool {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
return false
|
||||
}
|
||||
if _, ok := schedulerNeutralExtraKeys[key]; ok {
|
||||
return true
|
||||
}
|
||||
for _, prefix := range schedulerNeutralExtraKeyPrefixes {
|
||||
if strings.HasPrefix(key, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
|
||||
if len(ids) == 0 {
|
||||
return 0, nil
|
||||
@@ -1678,8 +1730,96 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va
|
||||
// nowUTC is a SQL expression to generate a UTC RFC3339 timestamp string.
|
||||
const nowUTC = `to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"')`
|
||||
|
||||
// dailyExpiredExpr is a SQL expression that evaluates to TRUE when daily quota period has expired.
|
||||
// Supports both rolling (24h from start) and fixed (pre-computed reset_at) modes.
|
||||
const dailyExpiredExpr = `(
|
||||
CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed'
|
||||
THEN NOW() >= COALESCE((extra->>'quota_daily_reset_at')::timestamptz, '1970-01-01'::timestamptz)
|
||||
ELSE COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
END
|
||||
)`
|
||||
|
||||
// weeklyExpiredExpr is a SQL expression that evaluates to TRUE when weekly quota period has expired.
|
||||
const weeklyExpiredExpr = `(
|
||||
CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed'
|
||||
THEN NOW() >= COALESCE((extra->>'quota_weekly_reset_at')::timestamptz, '1970-01-01'::timestamptz)
|
||||
ELSE COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
END
|
||||
)`
|
||||
|
||||
// nextDailyResetAtExpr is a SQL expression to compute the next daily reset_at when a reset occurs.
|
||||
// For fixed mode: computes the next future reset time based on NOW(), timezone, and configured hour.
|
||||
// This correctly handles long-inactive accounts by jumping directly to the next valid reset point.
|
||||
const nextDailyResetAtExpr = `(
|
||||
CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed'
|
||||
THEN to_char((
|
||||
-- Compute today's reset point in the configured timezone, then pick next future one
|
||||
CASE WHEN NOW() >= (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
-- NOW() is at or past today's reset point → next reset is tomorrow
|
||||
THEN (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||
+ '1 day'::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
-- NOW() is before today's reset point → next reset is today
|
||||
ELSE (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
END
|
||||
) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
|
||||
ELSE NULL END
|
||||
)`
|
||||
|
||||
// nextWeeklyResetAtExpr is a SQL expression to compute the next weekly reset_at when a reset occurs.
|
||||
// For fixed mode: computes the next future reset time based on NOW(), timezone, configured day and hour.
|
||||
// This correctly handles long-inactive accounts by jumping directly to the next valid reset point.
|
||||
const nextWeeklyResetAtExpr = `(
|
||||
CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed'
|
||||
THEN to_char((
|
||||
-- Compute this week's reset point in the configured timezone
|
||||
-- Step 1: get today's date at reset hour in configured tz
|
||||
-- Step 2: compute days forward to target weekday
|
||||
-- Step 3: if same day but past reset hour, advance 7 days
|
||||
CASE
|
||||
WHEN (
|
||||
-- days_forward = (target_day - current_day + 7) % 7
|
||||
(COALESCE((extra->>'quota_weekly_reset_day')::int, 1)
|
||||
- EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int
|
||||
+ 7) % 7
|
||||
) = 0 AND NOW() >= (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
-- Same weekday and past reset hour → next week
|
||||
THEN (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||
+ '7 days'::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
ELSE (
|
||||
-- Advance to target weekday this week (or next if days_forward > 0)
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||
+ ((
|
||||
(COALESCE((extra->>'quota_weekly_reset_day')::int, 1)
|
||||
- EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int
|
||||
+ 7) % 7
|
||||
) || ' days')::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
END
|
||||
) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
|
||||
ELSE NULL END
|
||||
)`
|
||||
|
||||
// IncrementQuotaUsed 原子递增账号的配额用量(总/日/周三个维度)
|
||||
// 日/周额度在周期过期时自动重置为 0 再递增。
|
||||
// 支持滚动窗口(rolling)和固定时间(fixed)两种重置模式。
|
||||
func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
rows, err := r.sql.QueryContext(ctx,
|
||||
`UPDATE accounts SET extra = (
|
||||
@@ -1690,31 +1830,35 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|
||||
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_daily_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
CASE WHEN `+dailyExpiredExpr+`
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
|
||||
'quota_daily_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
CASE WHEN `+dailyExpiredExpr+`
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
|
||||
)
|
||||
-- 固定模式重置时更新下次重置时间
|
||||
|| CASE WHEN `+dailyExpiredExpr+` AND `+nextDailyResetAtExpr+` IS NOT NULL
|
||||
THEN jsonb_build_object('quota_daily_reset_at', `+nextDailyResetAtExpr+`)
|
||||
ELSE '{}'::jsonb END
|
||||
ELSE '{}'::jsonb END
|
||||
-- 周额度:仅在 quota_weekly_limit > 0 时处理
|
||||
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_weekly_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
CASE WHEN `+weeklyExpiredExpr+`
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
|
||||
'quota_weekly_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
CASE WHEN `+weeklyExpiredExpr+`
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
|
||||
)
|
||||
-- 固定模式重置时更新下次重置时间
|
||||
|| CASE WHEN `+weeklyExpiredExpr+` AND `+nextWeeklyResetAtExpr+` IS NOT NULL
|
||||
THEN jsonb_build_object('quota_weekly_reset_at', `+nextWeeklyResetAtExpr+`)
|
||||
ELSE '{}'::jsonb END
|
||||
ELSE '{}'::jsonb END
|
||||
), updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
@@ -1747,12 +1891,13 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|
||||
}
|
||||
|
||||
// ResetQuotaUsed 重置账号所有维度的配额用量为 0
|
||||
// 保留固定重置模式的配置字段(quota_daily_reset_mode 等),仅清零用量和窗口起始时间
|
||||
func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
_, err := r.sql.ExecContext(ctx,
|
||||
`UPDATE accounts SET extra = (
|
||||
COALESCE(extra, '{}'::jsonb)
|
||||
|| '{"quota_used": 0, "quota_daily_used": 0, "quota_weekly_used": 0}'::jsonb
|
||||
) - 'quota_daily_start' - 'quota_weekly_start', updated_at = NOW()
|
||||
) - 'quota_daily_start' - 'quota_weekly_start' - 'quota_daily_reset_at' - 'quota_weekly_reset_at', updated_at = NOW()
|
||||
WHERE id = $1 AND deleted_at IS NULL`,
|
||||
id)
|
||||
if err != nil {
|
||||
|
||||
@@ -23,6 +23,7 @@ type AccountRepoSuite struct {
|
||||
|
||||
type schedulerCacheRecorder struct {
|
||||
setAccounts []*service.Account
|
||||
accounts map[int64]*service.Account
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
|
||||
@@ -34,11 +35,20 @@ func (s *schedulerCacheRecorder) SetSnapshot(ctx context.Context, bucket service
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
|
||||
return nil, nil
|
||||
if s.accounts == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return s.accounts[accountID], nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) SetAccount(ctx context.Context, account *service.Account) error {
|
||||
s.setAccounts = append(s.setAccounts, account)
|
||||
if s.accounts == nil {
|
||||
s.accounts = make(map[int64]*service.Account)
|
||||
}
|
||||
if account != nil {
|
||||
s.accounts[account.ID] = account
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -132,6 +142,35 @@ func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnDisabled() {
|
||||
s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnCredentialsChange() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "sync-credentials-update",
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.1",
|
||||
},
|
||||
},
|
||||
})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
account.Credentials = map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.2",
|
||||
},
|
||||
}
|
||||
err := s.repo.Update(s.ctx, account)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
mapping, ok := cacheRecorder.setAccounts[0].Credentials["model_mapping"].(map[string]any)
|
||||
s.Require().True(ok)
|
||||
s.Require().Equal("gpt-5.2", mapping["gpt-5"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestDelete() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
|
||||
|
||||
@@ -175,6 +214,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
accType string
|
||||
status string
|
||||
search string
|
||||
groupID int64
|
||||
wantCount int
|
||||
validate func(accounts []service.Account)
|
||||
}{
|
||||
@@ -226,6 +266,21 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
s.Require().Contains(accounts[0].Name, "alpha")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter_by_ungrouped",
|
||||
setup: func(client *dbent.Client) {
|
||||
group := mustCreateGroup(s.T(), client, &service.Group{Name: "g-ungrouped"})
|
||||
grouped := mustCreateAccount(s.T(), client, &service.Account{Name: "grouped-account"})
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "ungrouped-account"})
|
||||
mustBindAccountToGroup(s.T(), client, grouped.ID, group.ID, 1)
|
||||
},
|
||||
groupID: service.AccountListGroupUngrouped,
|
||||
wantCount: 1,
|
||||
validate: func(accounts []service.Account) {
|
||||
s.Require().Equal("ungrouped-account", accounts[0].Name)
|
||||
s.Require().Empty(accounts[0].GroupIDs)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -238,7 +293,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
|
||||
tt.setup(client)
|
||||
|
||||
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, 0)
|
||||
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, tt.groupID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, tt.wantCount)
|
||||
if tt.validate != nil {
|
||||
@@ -623,6 +678,96 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
|
||||
s.Require().Equal("val", got.Extra["key"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_SchedulerNeutralSkipsOutboxAndSyncsFreshSnapshot() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc-extra-neutral",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Extra: map[string]any{"codex_usage_updated_at": "old"},
|
||||
})
|
||||
cacheRecorder := &schedulerCacheRecorder{
|
||||
accounts: map[int64]*service.Account{
|
||||
account.ID: {
|
||||
ID: account.ID,
|
||||
Platform: account.Platform,
|
||||
Status: service.StatusDisabled,
|
||||
Extra: map[string]any{
|
||||
"codex_usage_updated_at": "old",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
updates := map[string]any{
|
||||
"codex_usage_updated_at": "2026-03-11T10:00:00Z",
|
||||
"codex_5h_used_percent": 88.5,
|
||||
"session_window_utilization": 0.42,
|
||||
}
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, updates))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal("2026-03-11T10:00:00Z", got.Extra["codex_usage_updated_at"])
|
||||
s.Require().Equal(88.5, got.Extra["codex_5h_used_percent"])
|
||||
s.Require().Equal(0.42, got.Extra["session_window_utilization"])
|
||||
|
||||
var outboxCount int
|
||||
s.Require().NoError(scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &outboxCount))
|
||||
s.Require().Zero(outboxCount)
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().NotNil(cacheRecorder.accounts[account.ID])
|
||||
s.Require().Equal(service.StatusActive, cacheRecorder.accounts[account.ID].Status)
|
||||
s.Require().Equal("2026-03-11T10:00:00Z", cacheRecorder.accounts[account.ID].Extra["codex_usage_updated_at"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_ExhaustedCodexSnapshotSyncsSchedulerCache() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc-extra-codex-exhausted",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Extra: map[string]any{},
|
||||
})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
_, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox")
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{
|
||||
"codex_7d_used_percent": 100.0,
|
||||
"codex_7d_reset_at": "2026-03-12T13:00:00Z",
|
||||
"codex_7d_reset_after_seconds": 86400,
|
||||
}))
|
||||
|
||||
var count int
|
||||
err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(0, count)
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
s.Require().Equal(service.StatusActive, cacheRecorder.setAccounts[0].Status)
|
||||
s.Require().Equal(100.0, cacheRecorder.setAccounts[0].Extra["codex_7d_used_percent"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_SchedulerRelevantStillEnqueuesOutbox() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc-extra-mixed",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Extra: map[string]any{},
|
||||
})
|
||||
_, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox")
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{
|
||||
"mixed_scheduling": true,
|
||||
"codex_usage_updated_at": "2026-03-11T10:00:00Z",
|
||||
}))
|
||||
|
||||
var count int
|
||||
err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(1, count)
|
||||
}
|
||||
|
||||
// --- GetByCRSAccountID ---
|
||||
|
||||
func (s *AccountRepoSuite) TestGetByCRSAccountID() {
|
||||
|
||||
@@ -452,6 +452,32 @@ func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amo
|
||||
return updated.QuotaUsed, nil
|
||||
}
|
||||
|
||||
// IncrementQuotaUsedAndGetState atomically increments quota_used, conditionally marks the key
|
||||
// as quota_exhausted, and returns the latest quota state in one round trip.
|
||||
func (r *apiKeyRepository) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*service.APIKeyQuotaUsageState, error) {
|
||||
query := `
|
||||
UPDATE api_keys
|
||||
SET
|
||||
quota_used = quota_used + $1,
|
||||
status = CASE
|
||||
WHEN quota > 0 AND quota_used + $1 >= quota THEN $2
|
||||
ELSE status
|
||||
END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $3 AND deleted_at IS NULL
|
||||
RETURNING quota_used, quota, key, status
|
||||
`
|
||||
|
||||
state := &service.APIKeyQuotaUsageState{}
|
||||
if err := scanSingleRow(ctx, r.sql, query, []any{amount, service.StatusAPIKeyQuotaExhausted, id}, &state.QuotaUsed, &state.Quota, &state.Key, &state.Status); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
affected, err := r.client.APIKey.Update().
|
||||
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
||||
|
||||
@@ -417,6 +417,27 @@ func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() {
|
||||
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound")
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestIncrementQuotaUsedAndGetState() {
|
||||
user := s.mustCreateUser("quota-state@test.com")
|
||||
key := s.mustCreateApiKey(user.ID, "sk-quota-state", "QuotaState", nil)
|
||||
key.Quota = 3
|
||||
key.QuotaUsed = 1
|
||||
s.Require().NoError(s.repo.Update(s.ctx, key), "Update quota")
|
||||
|
||||
state, err := s.repo.IncrementQuotaUsedAndGetState(s.ctx, key.ID, 2.5)
|
||||
s.Require().NoError(err, "IncrementQuotaUsedAndGetState")
|
||||
s.Require().NotNil(state)
|
||||
s.Require().Equal(3.5, state.QuotaUsed)
|
||||
s.Require().Equal(3.0, state.Quota)
|
||||
s.Require().Equal(service.StatusAPIKeyQuotaExhausted, state.Status)
|
||||
s.Require().Equal(key.Key, state.Key)
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal(3.5, got.QuotaUsed)
|
||||
s.Require().Equal(service.StatusAPIKeyQuotaExhausted, got.Status)
|
||||
}
|
||||
|
||||
// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。
|
||||
// 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。
|
||||
func TestIncrementQuotaUsed_Concurrent(t *testing.T) {
|
||||
|
||||
98
backend/internal/repository/backup_pg_dumper.go
Normal file
98
backend/internal/repository/backup_pg_dumper.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// PgDumper implements service.DBDumper using pg_dump/psql
|
||||
type PgDumper struct {
|
||||
cfg *config.DatabaseConfig
|
||||
}
|
||||
|
||||
// NewPgDumper creates a new PgDumper
|
||||
func NewPgDumper(cfg *config.Config) service.DBDumper {
|
||||
return &PgDumper{cfg: &cfg.Database}
|
||||
}
|
||||
|
||||
// Dump executes pg_dump and returns a streaming reader of the output
|
||||
func (d *PgDumper) Dump(ctx context.Context) (io.ReadCloser, error) {
|
||||
args := []string{
|
||||
"-h", d.cfg.Host,
|
||||
"-p", fmt.Sprintf("%d", d.cfg.Port),
|
||||
"-U", d.cfg.User,
|
||||
"-d", d.cfg.DBName,
|
||||
"--no-owner",
|
||||
"--no-acl",
|
||||
"--clean",
|
||||
"--if-exists",
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "pg_dump", args...)
|
||||
if d.cfg.Password != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password)
|
||||
}
|
||||
if d.cfg.SSLMode != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode)
|
||||
}
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("start pg_dump: %w", err)
|
||||
}
|
||||
|
||||
// 返回一个 ReadCloser:读 stdout,关闭时等待进程退出
|
||||
return &cmdReadCloser{ReadCloser: stdout, cmd: cmd}, nil
|
||||
}
|
||||
|
||||
// Restore executes psql to restore from a streaming reader
|
||||
func (d *PgDumper) Restore(ctx context.Context, data io.Reader) error {
|
||||
args := []string{
|
||||
"-h", d.cfg.Host,
|
||||
"-p", fmt.Sprintf("%d", d.cfg.Port),
|
||||
"-U", d.cfg.User,
|
||||
"-d", d.cfg.DBName,
|
||||
"--single-transaction",
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "psql", args...)
|
||||
if d.cfg.Password != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password)
|
||||
}
|
||||
if d.cfg.SSLMode != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode)
|
||||
}
|
||||
|
||||
cmd.Stdin = data
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%v: %s", err, string(output))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cmdReadCloser wraps a command stdout pipe and waits for the process on Close
|
||||
type cmdReadCloser struct {
|
||||
io.ReadCloser
|
||||
cmd *exec.Cmd
|
||||
}
|
||||
|
||||
func (c *cmdReadCloser) Close() error {
|
||||
// Close the pipe first
|
||||
_ = c.ReadCloser.Close()
|
||||
// Wait for the process to exit
|
||||
if err := c.cmd.Wait(); err != nil {
|
||||
return fmt.Errorf("pg_dump exited with error: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
117
backend/internal/repository/backup_s3_store.go
Normal file
117
backend/internal/repository/backup_s3_store.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
|
||||
awsconfig "github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// S3BackupStore implements service.BackupObjectStore using AWS S3 compatible storage
|
||||
type S3BackupStore struct {
|
||||
client *s3.Client
|
||||
bucket string
|
||||
}
|
||||
|
||||
// NewS3BackupStoreFactory returns a BackupObjectStoreFactory that creates S3-backed stores
|
||||
func NewS3BackupStoreFactory() service.BackupObjectStoreFactory {
|
||||
return func(ctx context.Context, cfg *service.BackupS3Config) (service.BackupObjectStore, error) {
|
||||
region := cfg.Region
|
||||
if region == "" {
|
||||
region = "auto" // Cloudflare R2 默认 region
|
||||
}
|
||||
|
||||
awsCfg, err := awsconfig.LoadDefaultConfig(ctx,
|
||||
awsconfig.WithRegion(region),
|
||||
awsconfig.WithCredentialsProvider(
|
||||
credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load aws config: %w", err)
|
||||
}
|
||||
|
||||
client := s3.NewFromConfig(awsCfg, func(o *s3.Options) {
|
||||
if cfg.Endpoint != "" {
|
||||
o.BaseEndpoint = &cfg.Endpoint
|
||||
}
|
||||
if cfg.ForcePathStyle {
|
||||
o.UsePathStyle = true
|
||||
}
|
||||
o.APIOptions = append(o.APIOptions, v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware)
|
||||
o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
|
||||
})
|
||||
|
||||
return &S3BackupStore{client: client, bucket: cfg.Bucket}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) Upload(ctx context.Context, key string, body io.Reader, contentType string) (int64, error) {
|
||||
// 读取全部内容以获取大小(S3 PutObject 需要知道内容长度)
|
||||
// 注意:阿里云 OSS 不兼容 s3manager 分片上传的签名方式,因此使用 PutObject
|
||||
data, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("read body: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.client.PutObject(ctx, &s3.PutObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
Body: bytes.NewReader(data),
|
||||
ContentType: &contentType,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("S3 PutObject: %w", err)
|
||||
}
|
||||
return int64(len(data)), nil
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) Download(ctx context.Context, key string) (io.ReadCloser, error) {
|
||||
result, err := s.client.GetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("S3 GetObject: %w", err)
|
||||
}
|
||||
return result.Body, nil
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) Delete(ctx context.Context, key string) error {
|
||||
_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) PresignURL(ctx context.Context, key string, expiry time.Duration) (string, error) {
|
||||
presignClient := s3.NewPresignClient(s.client)
|
||||
result, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
}, s3.WithPresignExpires(expiry))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("presign url: %w", err)
|
||||
}
|
||||
return result.URL, nil
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) HeadBucket(ctx context.Context) error {
|
||||
_, err := s.client.HeadBucket(ctx, &s3.HeadBucketInput{
|
||||
Bucket: &s.bucket,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("S3 HeadBucket failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -20,6 +20,11 @@ const (
|
||||
billingCacheTTL = 5 * time.Minute
|
||||
billingCacheJitter = 30 * time.Second
|
||||
rateLimitCacheTTL = 7 * 24 * time.Hour // 7 days matches the longest window
|
||||
|
||||
// Rate limit window durations — must match service.RateLimitWindow* constants.
|
||||
rateLimitWindow5h = 5 * time.Hour
|
||||
rateLimitWindow1d = 24 * time.Hour
|
||||
rateLimitWindow7d = 7 * 24 * time.Hour
|
||||
)
|
||||
|
||||
// jitteredTTL 返回带随机抖动的 TTL,防止缓存雪崩
|
||||
@@ -90,17 +95,40 @@ var (
|
||||
return 1
|
||||
`)
|
||||
|
||||
// updateRateLimitUsageScript atomically increments all three rate limit usage counters.
|
||||
// Returns 0 if the key doesn't exist (cache miss), 1 on success.
|
||||
// updateRateLimitUsageScript atomically increments all three rate limit usage counters
|
||||
// with window expiration checking. If a window has expired, its usage is reset to cost
|
||||
// (instead of accumulated) and the window timestamp is updated, matching the DB-side
|
||||
// IncrementRateLimitUsage semantics.
|
||||
//
|
||||
// ARGV: [1]=cost, [2]=ttl_seconds, [3]=now_unix, [4]=window_5h_seconds, [5]=window_1d_seconds, [6]=window_7d_seconds
|
||||
updateRateLimitUsageScript = redis.NewScript(`
|
||||
local exists = redis.call('EXISTS', KEYS[1])
|
||||
if exists == 0 then
|
||||
return 0
|
||||
end
|
||||
local cost = tonumber(ARGV[1])
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_5h', cost)
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_1d', cost)
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_7d', cost)
|
||||
local now = tonumber(ARGV[3])
|
||||
local win5h = tonumber(ARGV[4])
|
||||
local win1d = tonumber(ARGV[5])
|
||||
local win7d = tonumber(ARGV[6])
|
||||
|
||||
-- Helper: check if window is expired and update usage + window accordingly
|
||||
-- Returns nothing, modifies the hash in-place.
|
||||
local function update_window(usage_field, window_field, window_duration)
|
||||
local w = tonumber(redis.call('HGET', KEYS[1], window_field) or 0)
|
||||
if w == 0 or (now - w) >= window_duration then
|
||||
-- Window expired or never started: reset usage to cost, start new window
|
||||
redis.call('HSET', KEYS[1], usage_field, tostring(cost))
|
||||
redis.call('HSET', KEYS[1], window_field, tostring(now))
|
||||
else
|
||||
-- Window still valid: accumulate
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], usage_field, cost)
|
||||
end
|
||||
end
|
||||
|
||||
update_window('usage_5h', 'window_5h', win5h)
|
||||
update_window('usage_1d', 'window_1d', win1d)
|
||||
update_window('usage_7d', 'window_7d', win7d)
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
`)
|
||||
@@ -280,7 +308,15 @@ func (c *billingCache) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data
|
||||
|
||||
func (c *billingCache) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
|
||||
key := billingRateLimitKey(keyID)
|
||||
_, err := updateRateLimitUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(rateLimitCacheTTL.Seconds())).Result()
|
||||
now := time.Now().Unix()
|
||||
_, err := updateRateLimitUsageScript.Run(ctx, c.rdb, []string{key},
|
||||
cost,
|
||||
int(rateLimitCacheTTL.Seconds()),
|
||||
now,
|
||||
int(rateLimitWindow5h.Seconds()),
|
||||
int(rateLimitWindow1d.Seconds()),
|
||||
int(rateLimitWindow7d.Seconds()),
|
||||
).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: update rate limit usage cache failed for api key %d: %v", keyID, err)
|
||||
return err
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user