mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-06 00:10:21 +08:00
Compare commits
292 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0236b97d49 | ||
|
|
26f6b1eeff | ||
|
|
dc447ccebe | ||
|
|
7ec29638f4 | ||
|
|
4c9562af20 | ||
|
|
71942fd322 | ||
|
|
550b979ac5 | ||
|
|
3878a5a46f | ||
|
|
e443a6a1ea | ||
|
|
963494ec6f | ||
|
|
525cdb8830 | ||
|
|
a6764e82f2 | ||
|
|
8027531d07 | ||
|
|
30706355a4 | ||
|
|
dfe99507b8 | ||
|
|
c1717c9a6c | ||
|
|
1fd1a58a7a | ||
|
|
fad07507be | ||
|
|
a20c211162 | ||
|
|
9f6ab6b817 | ||
|
|
bf3d6c0e6e | ||
|
|
241023f3fc | ||
|
|
1292c44b41 | ||
|
|
b4fce47049 | ||
|
|
e7780cd8c8 | ||
|
|
af96c8ea53 | ||
|
|
7d26b81075 | ||
|
|
b8ada63ac3 | ||
|
|
cfaac12af1 | ||
|
|
6028efd26c | ||
|
|
62a566ef2c | ||
|
|
94419f434c | ||
|
|
21f349c032 | ||
|
|
28e36f7925 | ||
|
|
6c02076333 | ||
|
|
7414bdf0e3 | ||
|
|
e6326b2929 | ||
|
|
17cdcebd04 | ||
|
|
a14babdc73 | ||
|
|
aadc6a763a | ||
|
|
f16af8bf88 | ||
|
|
5ceaef4500 | ||
|
|
1ac7219a92 | ||
|
|
d4cc9871c4 | ||
|
|
961c30e7c0 | ||
|
|
13e85b3147 | ||
|
|
50a3c7fa0b | ||
|
|
bd9d2671d7 | ||
|
|
62b40636e0 | ||
|
|
eeff451bc5 | ||
|
|
56fcb20f94 | ||
|
|
7134266acf | ||
|
|
2e4ac88ad9 | ||
|
|
51547fa216 | ||
|
|
2005fc97a8 | ||
|
|
0772d9250e | ||
|
|
aa6047c460 | ||
|
|
045cba78b4 | ||
|
|
8989d0d4b6 | ||
|
|
c521117b99 | ||
|
|
e0f52a8ab8 | ||
|
|
6c23fadf7e | ||
|
|
869952d113 | ||
|
|
07ab051ee4 | ||
|
|
f2d98fc0c7 | ||
|
|
2b41cec840 | ||
|
|
6cf77040e7 | ||
|
|
20b70bc5fd | ||
|
|
4905e7193a | ||
|
|
9c1f4b8e72 | ||
|
|
9857c17631 | ||
|
|
7e34bb946f | ||
|
|
47b748851b | ||
|
|
a6f99cf534 | ||
|
|
a120a6bc32 | ||
|
|
d557d1a190 | ||
|
|
e0286e5085 | ||
|
|
4b41e898a4 | ||
|
|
668e164793 | ||
|
|
fa2e6188d0 | ||
|
|
7fde9ebbc2 | ||
|
|
aef7c3b9bb | ||
|
|
a0b76bd608 | ||
|
|
c1fab7f8d8 | ||
|
|
f42c8f2abe | ||
|
|
aa5846b282 | ||
|
|
594a0ade38 | ||
|
|
d45cc23171 | ||
|
|
d795734352 | ||
|
|
4da9fdd1d5 | ||
|
|
6b218caa21 | ||
|
|
5c138007d0 | ||
|
|
1acfc46f46 | ||
|
|
fbffb08aae | ||
|
|
8640a62319 | ||
|
|
fa782e70a4 | ||
|
|
afd72abc6e | ||
|
|
71f72e167e | ||
|
|
6595c7601e | ||
|
|
67c0506290 | ||
|
|
6447be4534 | ||
|
|
3741617ebd | ||
|
|
ab4e8b2cf0 | ||
|
|
474165d7aa | ||
|
|
94e067a2e2 | ||
|
|
4293c89166 | ||
|
|
ec82c37da5 | ||
|
|
552a4b998a | ||
|
|
0d2061b268 | ||
|
|
8a260defc2 | ||
|
|
e14c87597a | ||
|
|
f3f19d35aa | ||
|
|
ced90e1d84 | ||
|
|
17e4033340 | ||
|
|
044d3a013d | ||
|
|
1fc9dd7b68 | ||
|
|
8147866c09 | ||
|
|
7bd1972f94 | ||
|
|
2c9dcfe27b | ||
|
|
1b79b0f3ff | ||
|
|
c637e6cf31 | ||
|
|
d3a9f5bb88 | ||
|
|
7eb0415a8a | ||
|
|
bdbc8fa08f | ||
|
|
63f3af0f94 | ||
|
|
686f890fbf | ||
|
|
220fbe6544 | ||
|
|
ae44a94325 | ||
|
|
3718d6dcd4 | ||
|
|
90b3838173 | ||
|
|
19d3ecc76f | ||
|
|
6fba4ebb13 | ||
|
|
c31974c913 | ||
|
|
6177fa5dd8 | ||
|
|
cfe72159d0 | ||
|
|
8321e4a647 | ||
|
|
3084330d0c | ||
|
|
b566649e79 | ||
|
|
10a6180e4a | ||
|
|
cbe9e78977 | ||
|
|
74145b1f39 | ||
|
|
359e56751b | ||
|
|
5899784aa4 | ||
|
|
9e8959c56d | ||
|
|
1bff2292a6 | ||
|
|
cf9247754e | ||
|
|
eefab15958 | ||
|
|
0e23732631 | ||
|
|
37c044fb4b | ||
|
|
6da5fa01b9 | ||
|
|
616930f9d3 | ||
|
|
b9c31fa7c4 | ||
|
|
17b339972c | ||
|
|
39f8bd91b9 | ||
|
|
aa4e37d085 | ||
|
|
f59b66b7d4 | ||
|
|
8f0ea7a02d | ||
|
|
a1dc00890e | ||
|
|
dfbcc363d1 | ||
|
|
1047f973d5 | ||
|
|
e32977dd73 | ||
|
|
b5f78ec1e8 | ||
|
|
e0f290fdc8 | ||
|
|
fc00a4e3b2 | ||
|
|
db1f6ded88 | ||
|
|
4644af2ccc | ||
|
|
2e3e8687e1 | ||
|
|
ca42a45802 | ||
|
|
9350ecb62b | ||
|
|
a4a026e8da | ||
|
|
342fd03e72 | ||
|
|
e3f1fd9b63 | ||
|
|
e4a4dfd038 | ||
|
|
a377e99088 | ||
|
|
1d3d7a3033 | ||
|
|
e7086cb3a3 | ||
|
|
4f2a97073e | ||
|
|
7407e3b45d | ||
|
|
01ef7340aa | ||
|
|
1c960d22c1 | ||
|
|
ece0606fed | ||
|
|
2666422b99 | ||
|
|
e6d59216d4 | ||
|
|
4e8615f276 | ||
|
|
91e4d95660 | ||
|
|
45456fa24c | ||
|
|
4588258d80 | ||
|
|
c12e48f966 | ||
|
|
ec8f50a658 | ||
|
|
99c9191784 | ||
|
|
6bb02d141f | ||
|
|
07bb2a5f3f | ||
|
|
417861a48e | ||
|
|
b7e878de64 | ||
|
|
05edb5514b | ||
|
|
e90ec847b6 | ||
|
|
6344fa2a86 | ||
|
|
7e288acc90 | ||
|
|
29b0e4a8a5 | ||
|
|
27ff222cfb | ||
|
|
11f7b83522 | ||
|
|
f7177be3b6 | ||
|
|
875b417fde | ||
|
|
2573107b32 | ||
|
|
5b85005945 | ||
|
|
1ee984478f | ||
|
|
fd693dc526 | ||
|
|
e73531ce9b | ||
|
|
53ad1645cf | ||
|
|
ecea13757b | ||
|
|
af9c4a7dd0 | ||
|
|
80d8d6c3bc | ||
|
|
d648811233 | ||
|
|
34695acb85 | ||
|
|
a63de12182 | ||
|
|
f16910d616 | ||
|
|
64b3f3cec1 | ||
|
|
6a685727d0 | ||
|
|
32d25f76fc | ||
|
|
69cafe8674 | ||
|
|
18ba8d9166 | ||
|
|
e97fd7e81c | ||
|
|
cdb64b0d33 | ||
|
|
8d4d3b03bb | ||
|
|
addefe79e1 | ||
|
|
b764d3b8f6 | ||
|
|
611fd884bd | ||
|
|
826090e099 | ||
|
|
7399de6ecc | ||
|
|
25cb5e7505 | ||
|
|
5c13ec3121 | ||
|
|
d8aff3a7e3 | ||
|
|
f44927b9f8 | ||
|
|
c0110cb5af | ||
|
|
1f8e1142a0 | ||
|
|
1e51de88d6 | ||
|
|
30995b5397 | ||
|
|
eb60f67054 | ||
|
|
78193ceec1 | ||
|
|
f0e08e7687 | ||
|
|
10b8259259 | ||
|
|
6826149a8f | ||
|
|
eb0b77bf4d | ||
|
|
9d81467937 | ||
|
|
fd8ccaf01a | ||
|
|
c9debc50b1 | ||
|
|
2b30e3b6d7 | ||
|
|
6e90ec6111 | ||
|
|
8dd38f4775 | ||
|
|
fbd73f248f | ||
|
|
3fcefe6c32 | ||
|
|
f740d2c291 | ||
|
|
bf6585a40f | ||
|
|
8c2dd7b3f0 | ||
|
|
4167c437a8 | ||
|
|
0ddaef3c9a | ||
|
|
2fc6aaf936 | ||
|
|
1c0519f1c7 | ||
|
|
6bbe7800be | ||
|
|
2694149489 | ||
|
|
a17ac50118 | ||
|
|
656a77d585 | ||
|
|
7455476c60 | ||
|
|
36cda57c81 | ||
|
|
9f1f203b84 | ||
|
|
b41a8ca93f | ||
|
|
e3cf0c0e10 | ||
|
|
de18bce9aa | ||
|
|
3cc407bc0e | ||
|
|
00a0a12138 | ||
|
|
b08767a4f9 | ||
|
|
ac6bde7a98 | ||
|
|
d2d41d68dd | ||
|
|
944b7f7617 | ||
|
|
53825eb073 | ||
|
|
1a7f49513f | ||
|
|
885a2ce7ef | ||
|
|
14ba80a0af | ||
|
|
5fa22fdf82 | ||
|
|
bcaae2eb91 | ||
|
|
767a41e263 | ||
|
|
252d6c5301 | ||
|
|
7a4e65ad4b | ||
|
|
a582aa89a9 | ||
|
|
acefa1da12 | ||
|
|
a88698f3fc | ||
|
|
ebc6755b33 | ||
|
|
c8eff34388 | ||
|
|
f19b03825b | ||
|
|
b43ee62947 | ||
|
|
106b20cdbf | ||
|
|
c069b3b1e8 |
7
.gitattributes
vendored
7
.gitattributes
vendored
@@ -4,6 +4,13 @@ backend/migrations/*.sql text eol=lf
|
|||||||
# Go 源代码文件
|
# Go 源代码文件
|
||||||
*.go text eol=lf
|
*.go text eol=lf
|
||||||
|
|
||||||
|
# 前端 源代码文件
|
||||||
|
*.ts text eol=lf
|
||||||
|
*.tsx text eol=lf
|
||||||
|
*.js text eol=lf
|
||||||
|
*.jsx text eol=lf
|
||||||
|
*.vue text eol=lf
|
||||||
|
|
||||||
# Shell 脚本
|
# Shell 脚本
|
||||||
*.sh text eol=lf
|
*.sh text eol=lf
|
||||||
|
|
||||||
|
|||||||
33
.github/workflows/release.yml
vendored
33
.github/workflows/release.yml
vendored
@@ -271,3 +271,36 @@ jobs:
|
|||||||
parse_mode: "Markdown",
|
parse_mode: "Markdown",
|
||||||
disable_web_page_preview: true
|
disable_web_page_preview: true
|
||||||
}')"
|
}')"
|
||||||
|
|
||||||
|
sync-version-file:
|
||||||
|
needs: [release]
|
||||||
|
if: ${{ needs.release.result == 'success' }}
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout default branch
|
||||||
|
uses: actions/checkout@v6
|
||||||
|
with:
|
||||||
|
ref: ${{ github.event.repository.default_branch }}
|
||||||
|
|
||||||
|
- name: Sync VERSION file to released tag
|
||||||
|
run: |
|
||||||
|
if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
|
||||||
|
VERSION=${{ github.event.inputs.tag }}
|
||||||
|
VERSION=${VERSION#v}
|
||||||
|
else
|
||||||
|
VERSION=${GITHUB_REF#refs/tags/v}
|
||||||
|
fi
|
||||||
|
|
||||||
|
CURRENT_VERSION=$(tr -d '\r\n' < backend/cmd/server/VERSION || true)
|
||||||
|
if [ "$CURRENT_VERSION" = "$VERSION" ]; then
|
||||||
|
echo "VERSION file already matches $VERSION"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "$VERSION" > backend/cmd/server/VERSION
|
||||||
|
|
||||||
|
git config user.name "github-actions[bot]"
|
||||||
|
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||||
|
git add backend/cmd/server/VERSION
|
||||||
|
git commit -m "chore: sync VERSION to ${VERSION} [skip ci]"
|
||||||
|
git push origin HEAD:${{ github.event.repository.default_branch }}
|
||||||
|
|||||||
@@ -47,6 +47,8 @@ dockers:
|
|||||||
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:latest"
|
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:latest"
|
||||||
dockerfile: Dockerfile.goreleaser
|
dockerfile: Dockerfile.goreleaser
|
||||||
use: buildx
|
use: buildx
|
||||||
|
extra_files:
|
||||||
|
- deploy/docker-entrypoint.sh
|
||||||
build_flag_templates:
|
build_flag_templates:
|
||||||
- "--platform=linux/amd64"
|
- "--platform=linux/amd64"
|
||||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||||
|
|||||||
@@ -63,6 +63,8 @@ dockers:
|
|||||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||||
dockerfile: Dockerfile.goreleaser
|
dockerfile: Dockerfile.goreleaser
|
||||||
use: buildx
|
use: buildx
|
||||||
|
extra_files:
|
||||||
|
- deploy/docker-entrypoint.sh
|
||||||
build_flag_templates:
|
build_flag_templates:
|
||||||
- "--platform=linux/amd64"
|
- "--platform=linux/amd64"
|
||||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||||
@@ -76,6 +78,8 @@ dockers:
|
|||||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||||
dockerfile: Dockerfile.goreleaser
|
dockerfile: Dockerfile.goreleaser
|
||||||
use: buildx
|
use: buildx
|
||||||
|
extra_files:
|
||||||
|
- deploy/docker-entrypoint.sh
|
||||||
build_flag_templates:
|
build_flag_templates:
|
||||||
- "--platform=linux/arm64"
|
- "--platform=linux/arm64"
|
||||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||||
@@ -89,6 +93,8 @@ dockers:
|
|||||||
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-amd64"
|
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-amd64"
|
||||||
dockerfile: Dockerfile.goreleaser
|
dockerfile: Dockerfile.goreleaser
|
||||||
use: buildx
|
use: buildx
|
||||||
|
extra_files:
|
||||||
|
- deploy/docker-entrypoint.sh
|
||||||
build_flag_templates:
|
build_flag_templates:
|
||||||
- "--platform=linux/amd64"
|
- "--platform=linux/amd64"
|
||||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||||
@@ -102,6 +108,8 @@ dockers:
|
|||||||
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-arm64"
|
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-arm64"
|
||||||
dockerfile: Dockerfile.goreleaser
|
dockerfile: Dockerfile.goreleaser
|
||||||
use: buildx
|
use: buildx
|
||||||
|
extra_files:
|
||||||
|
- deploy/docker-entrypoint.sh
|
||||||
build_flag_templates:
|
build_flag_templates:
|
||||||
- "--platform=linux/arm64"
|
- "--platform=linux/arm64"
|
||||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||||
|
|||||||
31
Dockerfile
31
Dockerfile
@@ -9,6 +9,7 @@
|
|||||||
ARG NODE_IMAGE=node:24-alpine
|
ARG NODE_IMAGE=node:24-alpine
|
||||||
ARG GOLANG_IMAGE=golang:1.26.1-alpine
|
ARG GOLANG_IMAGE=golang:1.26.1-alpine
|
||||||
ARG ALPINE_IMAGE=alpine:3.21
|
ARG ALPINE_IMAGE=alpine:3.21
|
||||||
|
ARG POSTGRES_IMAGE=postgres:18-alpine
|
||||||
ARG GOPROXY=https://goproxy.cn,direct
|
ARG GOPROXY=https://goproxy.cn,direct
|
||||||
ARG GOSUMDB=sum.golang.google.cn
|
ARG GOSUMDB=sum.golang.google.cn
|
||||||
|
|
||||||
@@ -73,7 +74,12 @@ RUN VERSION_VALUE="${VERSION}" && \
|
|||||||
./cmd/server
|
./cmd/server
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Stage 3: Final Runtime Image
|
# Stage 3: PostgreSQL Client (version-matched with docker-compose)
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
FROM ${POSTGRES_IMAGE} AS pg-client
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Stage 4: Final Runtime Image
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
FROM ${ALPINE_IMAGE}
|
FROM ${ALPINE_IMAGE}
|
||||||
|
|
||||||
@@ -86,8 +92,21 @@ LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
|
|||||||
RUN apk add --no-cache \
|
RUN apk add --no-cache \
|
||||||
ca-certificates \
|
ca-certificates \
|
||||||
tzdata \
|
tzdata \
|
||||||
|
su-exec \
|
||||||
|
libpq \
|
||||||
|
zstd-libs \
|
||||||
|
lz4-libs \
|
||||||
|
krb5-libs \
|
||||||
|
libldap \
|
||||||
|
libedit \
|
||||||
&& rm -rf /var/cache/apk/*
|
&& rm -rf /var/cache/apk/*
|
||||||
|
|
||||||
|
# Copy pg_dump and psql from the same postgres image used in docker-compose
|
||||||
|
# This ensures version consistency between backup tools and the database server
|
||||||
|
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
|
||||||
|
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
|
||||||
|
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
|
||||||
|
|
||||||
# Create non-root user
|
# Create non-root user
|
||||||
RUN addgroup -g 1000 sub2api && \
|
RUN addgroup -g 1000 sub2api && \
|
||||||
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
||||||
@@ -102,8 +121,9 @@ COPY --from=backend-builder --chown=sub2api:sub2api /app/backend/resources /app/
|
|||||||
# Create data directory
|
# Create data directory
|
||||||
RUN mkdir -p /app/data && chown sub2api:sub2api /app/data
|
RUN mkdir -p /app/data && chown sub2api:sub2api /app/data
|
||||||
|
|
||||||
# Switch to non-root user
|
# Copy entrypoint script (fixes volume permissions then drops to sub2api)
|
||||||
USER sub2api
|
COPY deploy/docker-entrypoint.sh /app/docker-entrypoint.sh
|
||||||
|
RUN chmod +x /app/docker-entrypoint.sh
|
||||||
|
|
||||||
# Expose port (can be overridden by SERVER_PORT env var)
|
# Expose port (can be overridden by SERVER_PORT env var)
|
||||||
EXPOSE 8080
|
EXPOSE 8080
|
||||||
@@ -112,5 +132,6 @@ EXPOSE 8080
|
|||||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
||||||
CMD wget -q -T 5 -O /dev/null http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
CMD wget -q -T 5 -O /dev/null http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
||||||
|
|
||||||
# Run the application
|
# Run the application (entrypoint fixes /app/data ownership then execs as sub2api)
|
||||||
ENTRYPOINT ["/app/sub2api"]
|
ENTRYPOINT ["/app/docker-entrypoint.sh"]
|
||||||
|
CMD ["/app/sub2api"]
|
||||||
|
|||||||
@@ -5,7 +5,12 @@
|
|||||||
# It only packages the pre-built binary, no compilation needed.
|
# It only packages the pre-built binary, no compilation needed.
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
FROM alpine:3.19
|
ARG ALPINE_IMAGE=alpine:3.21
|
||||||
|
ARG POSTGRES_IMAGE=postgres:18-alpine
|
||||||
|
|
||||||
|
FROM ${POSTGRES_IMAGE} AS pg-client
|
||||||
|
|
||||||
|
FROM ${ALPINE_IMAGE}
|
||||||
|
|
||||||
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
|
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
|
||||||
LABEL description="Sub2API - AI API Gateway Platform"
|
LABEL description="Sub2API - AI API Gateway Platform"
|
||||||
@@ -16,8 +21,21 @@ RUN apk add --no-cache \
|
|||||||
ca-certificates \
|
ca-certificates \
|
||||||
tzdata \
|
tzdata \
|
||||||
curl \
|
curl \
|
||||||
|
su-exec \
|
||||||
|
libpq \
|
||||||
|
zstd-libs \
|
||||||
|
lz4-libs \
|
||||||
|
krb5-libs \
|
||||||
|
libldap \
|
||||||
|
libedit \
|
||||||
&& rm -rf /var/cache/apk/*
|
&& rm -rf /var/cache/apk/*
|
||||||
|
|
||||||
|
# Copy pg_dump and psql from a version-matched PostgreSQL image so backup and
|
||||||
|
# restore work in the runtime container without requiring Docker socket access.
|
||||||
|
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
|
||||||
|
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
|
||||||
|
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
|
||||||
|
|
||||||
# Create non-root user
|
# Create non-root user
|
||||||
RUN addgroup -g 1000 sub2api && \
|
RUN addgroup -g 1000 sub2api && \
|
||||||
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
||||||
@@ -30,11 +48,15 @@ COPY sub2api /app/sub2api
|
|||||||
# Create data directory
|
# Create data directory
|
||||||
RUN mkdir -p /app/data && chown -R sub2api:sub2api /app
|
RUN mkdir -p /app/data && chown -R sub2api:sub2api /app
|
||||||
|
|
||||||
USER sub2api
|
# Copy entrypoint script (fixes volume permissions then drops to sub2api)
|
||||||
|
COPY deploy/docker-entrypoint.sh /app/docker-entrypoint.sh
|
||||||
|
RUN chmod +x /app/docker-entrypoint.sh
|
||||||
|
|
||||||
EXPOSE 8080
|
EXPOSE 8080
|
||||||
|
|
||||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
||||||
CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
||||||
|
|
||||||
ENTRYPOINT ["/app/sub2api"]
|
# Run the application (entrypoint fixes /app/data ownership then execs as sub2api)
|
||||||
|
ENTRYPOINT ["/app/docker-entrypoint.sh"]
|
||||||
|
CMD ["/app/sub2api"]
|
||||||
|
|||||||
68
README.md
68
README.md
@@ -8,27 +8,31 @@
|
|||||||
[](https://redis.io/)
|
[](https://redis.io/)
|
||||||
[](https://www.docker.com/)
|
[](https://www.docker.com/)
|
||||||
|
|
||||||
|
<a href="https://trendshift.io/repositories/21823" target="_blank"><img src="https://trendshift.io/api/badge/repositories/21823" alt="Wei-Shaw%2Fsub2api | Trendshift" width="250" height="55"/></a>
|
||||||
|
|
||||||
**AI API Gateway Platform for Subscription Quota Distribution**
|
**AI API Gateway Platform for Subscription Quota Distribution**
|
||||||
|
|
||||||
English | [中文](README_CN.md)
|
English | [中文](README_CN.md)
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
> **Sub2API officially uses only the domains `sub2api.org` and `pincc.ai`. Other websites using the Sub2API name may be third-party deployments or services and are not affiliated with this project. Please verify and exercise your own judgment.**
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Demo
|
## Demo
|
||||||
|
|
||||||
Try Sub2API online: **https://demo.sub2api.org/**
|
Try Sub2API online: **[https://demo.sub2api.org/](https://demo.sub2api.org/)**
|
||||||
|
|
||||||
Demo credentials (shared demo environment; **not** created automatically for self-hosted installs):
|
Demo credentials (shared demo environment; **not** created automatically for self-hosted installs):
|
||||||
|
|
||||||
| Email | Password |
|
| Email | Password |
|
||||||
|-------|----------|
|
|-------|----------|
|
||||||
| admin@sub2api.com | admin123 |
|
| admin@sub2api.org | admin123 |
|
||||||
|
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions (like Claude Code $200/month). Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding.
|
Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions. Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding.
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
@@ -39,6 +43,25 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
|
|||||||
- **Concurrency Control** - Per-user and per-account concurrency limits
|
- **Concurrency Control** - Per-user and per-account concurrency limits
|
||||||
- **Rate Limiting** - Configurable request and token rate limits
|
- **Rate Limiting** - Configurable request and token rate limits
|
||||||
- **Admin Dashboard** - Web interface for monitoring and management
|
- **Admin Dashboard** - Web interface for monitoring and management
|
||||||
|
- **External System Integration** - Embed external systems (e.g. payment, ticketing) via iframe to extend the admin dashboard
|
||||||
|
|
||||||
|
## Don't Want to Self-Host?
|
||||||
|
|
||||||
|
<table>
|
||||||
|
<tr>
|
||||||
|
<td width="180" align="center" valign="middle"><a href="https://shop.pincc.ai/"><img src="assets/partners/logos/pincc-logo.png" alt="pincc" width="120"></a></td>
|
||||||
|
<td valign="middle"><b><a href="https://shop.pincc.ai/">PinCC</a></b> is the official relay service built on Sub2API, offering stable access to Claude Code, Codex, Gemini and other popular models — ready to use, no deployment or maintenance required.</td>
|
||||||
|
</tr>
|
||||||
|
</table>
|
||||||
|
|
||||||
|
## Ecosystem
|
||||||
|
|
||||||
|
Community projects that extend or integrate with Sub2API:
|
||||||
|
|
||||||
|
| Project | Description | Features |
|
||||||
|
|---------|-------------|----------|
|
||||||
|
| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | Self-service payment system | Self-service top-up and subscription purchase; supports YiPay protocol, WeChat Pay, Alipay, Stripe; embeddable via iframe |
|
||||||
|
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | Mobile admin console | Cross-platform app (iOS/Android/Web) for user management, account management, monitoring dashboard, and multi-backend switching; built with Expo + React Native |
|
||||||
|
|
||||||
## Tech Stack
|
## Tech Stack
|
||||||
|
|
||||||
@@ -51,10 +74,15 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Documentation
|
## Nginx Reverse Proxy Note
|
||||||
|
|
||||||
- Dependency Security: `docs/dependency-security.md`
|
When using Nginx as a reverse proxy for Sub2API (or CRS) with Codex CLI, add the following to the `http` block in your Nginx configuration:
|
||||||
- Admin Payment Integration API: `docs/ADMIN_PAYMENT_INTEGRATION_API.md`
|
|
||||||
|
```nginx
|
||||||
|
underscores_in_headers on;
|
||||||
|
```
|
||||||
|
|
||||||
|
Nginx drops headers containing underscores by default (e.g. `session_id`), which breaks sticky session routing in multi-account setups.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -150,14 +178,14 @@ mkdir -p sub2api-deploy && cd sub2api-deploy
|
|||||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
|
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
|
||||||
|
|
||||||
# Start services
|
# Start services
|
||||||
docker-compose -f docker-compose.local.yml up -d
|
docker-compose up -d
|
||||||
|
|
||||||
# View logs
|
# View logs
|
||||||
docker-compose -f docker-compose.local.yml logs -f sub2api
|
docker-compose logs -f sub2api
|
||||||
```
|
```
|
||||||
|
|
||||||
**What the script does:**
|
**What the script does:**
|
||||||
- Downloads `docker-compose.local.yml` and `.env.example`
|
- Downloads `docker-compose.local.yml` (saved as `docker-compose.yml`) and `.env.example`
|
||||||
- Generates secure credentials (JWT_SECRET, TOTP_ENCRYPTION_KEY, POSTGRES_PASSWORD)
|
- Generates secure credentials (JWT_SECRET, TOTP_ENCRYPTION_KEY, POSTGRES_PASSWORD)
|
||||||
- Creates `.env` file with auto-generated secrets
|
- Creates `.env` file with auto-generated secrets
|
||||||
- Creates data directories (uses local directories for easy backup/migration)
|
- Creates data directories (uses local directories for easy backup/migration)
|
||||||
@@ -522,6 +550,28 @@ sub2api/
|
|||||||
└── install.sh # One-click installation script
|
└── install.sh # One-click installation script
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Disclaimer
|
||||||
|
|
||||||
|
> **Please read carefully before using this project:**
|
||||||
|
>
|
||||||
|
> :rotating_light: **Terms of Service Risk**: Using this project may violate Anthropic's Terms of Service. Please read Anthropic's user agreement carefully before use. All risks arising from the use of this project are borne solely by the user.
|
||||||
|
>
|
||||||
|
> :book: **Disclaimer**: This project is for technical learning and research purposes only. The author assumes no responsibility for account suspension, service interruption, or any other losses caused by the use of this project.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Star History
|
||||||
|
|
||||||
|
<a href="https://star-history.com/#Wei-Shaw/sub2api&Date">
|
||||||
|
<picture>
|
||||||
|
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date&theme=dark" />
|
||||||
|
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||||
|
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||||
|
</picture>
|
||||||
|
</a>
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
MIT License
|
MIT License
|
||||||
|
|||||||
71
README_CN.md
71
README_CN.md
@@ -8,27 +8,30 @@
|
|||||||
[](https://redis.io/)
|
[](https://redis.io/)
|
||||||
[](https://www.docker.com/)
|
[](https://www.docker.com/)
|
||||||
|
|
||||||
|
<a href="https://trendshift.io/repositories/21823" target="_blank"><img src="https://trendshift.io/api/badge/repositories/21823" alt="Wei-Shaw%2Fsub2api | Trendshift" width="250" height="55"/></a>
|
||||||
|
|
||||||
**AI API 网关平台 - 订阅配额分发管理**
|
**AI API 网关平台 - 订阅配额分发管理**
|
||||||
|
|
||||||
[English](README.md) | 中文
|
[English](README.md) | 中文
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
> **Sub2API 官方仅使用 `sub2api.org` 与 `pincc.ai` 两个域名。其他使用 Sub2API 名义的网站可能为第三方部署或服务,与本项目无关,请自行甄别。**
|
||||||
---
|
---
|
||||||
|
|
||||||
## 在线体验
|
## 在线体验
|
||||||
|
|
||||||
体验地址:**https://v2.pincc.ai/**
|
体验地址:**[https://demo.sub2api.org/](https://demo.sub2api.org/)**
|
||||||
|
|
||||||
演示账号(共享演示环境;自建部署不会自动创建该账号):
|
演示账号(共享演示环境;自建部署不会自动创建该账号):
|
||||||
|
|
||||||
| 邮箱 | 密码 |
|
| 邮箱 | 密码 |
|
||||||
|------|------|
|
|------|------|
|
||||||
| admin@sub2api.com | admin123 |
|
| admin@sub2api.org | admin123 |
|
||||||
|
|
||||||
## 项目概述
|
## 项目概述
|
||||||
|
|
||||||
Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(如 Claude Code $200/月)的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。
|
Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。
|
||||||
|
|
||||||
## 核心功能
|
## 核心功能
|
||||||
|
|
||||||
@@ -39,6 +42,25 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
|
|||||||
- **并发控制** - 用户级和账号级并发限制
|
- **并发控制** - 用户级和账号级并发限制
|
||||||
- **速率限制** - 可配置的请求和 Token 速率限制
|
- **速率限制** - 可配置的请求和 Token 速率限制
|
||||||
- **管理后台** - Web 界面进行监控和管理
|
- **管理后台** - Web 界面进行监控和管理
|
||||||
|
- **外部系统集成** - 支持通过 iframe 嵌入外部系统(如支付、工单等),扩展管理后台功能
|
||||||
|
|
||||||
|
## 不想自建?试试官方中转
|
||||||
|
|
||||||
|
<table>
|
||||||
|
<tr>
|
||||||
|
<td width="180" align="center" valign="middle"><a href="https://shop.pincc.ai/"><img src="assets/partners/logos/pincc-logo.png" alt="pincc" width="120"></a></td>
|
||||||
|
<td valign="middle"><b><a href="https://shop.pincc.ai/">PinCC</a></b> 是基于 Sub2API 搭建的官方中转服务,提供 Claude Code、Codex、Gemini 等主流模型的稳定中转,开箱即用,免去自建部署与运维烦恼。</td>
|
||||||
|
</tr>
|
||||||
|
</table>
|
||||||
|
|
||||||
|
## 生态项目
|
||||||
|
|
||||||
|
围绕 Sub2API 的社区扩展与集成项目:
|
||||||
|
|
||||||
|
| 项目 | 说明 | 功能 |
|
||||||
|
|------|------|------|
|
||||||
|
| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | 自助支付系统 | 用户自助充值、自助订阅购买;兼容易支付协议、微信官方支付、支付宝官方支付、Stripe;支持 iframe 嵌入管理后台 |
|
||||||
|
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | 移动端管理控制台 | 跨平台应用(iOS/Android/Web),支持用户管理、账号管理、监控看板、多后端切换;基于 Expo + React Native 构建 |
|
||||||
|
|
||||||
## 技术栈
|
## 技术栈
|
||||||
|
|
||||||
@@ -51,17 +73,18 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 文档
|
## Nginx 反向代理注意事项
|
||||||
|
|
||||||
- 依赖安全:`docs/dependency-security.md`
|
通过 Nginx 反向代理 Sub2API(或 CRS 服务)并搭配 Codex CLI 使用时,需要在 Nginx 配置的 `http` 块中添加:
|
||||||
|
|
||||||
|
```nginx
|
||||||
|
underscores_in_headers on;
|
||||||
|
```
|
||||||
|
|
||||||
|
Nginx 默认会丢弃名称中含下划线的请求头(如 `session_id`),这会导致多账号环境下的粘性会话功能失效。
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## OpenAI Responses 兼容注意事项
|
|
||||||
|
|
||||||
- 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id` 的 `tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`。
|
|
||||||
- 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。
|
|
||||||
|
|
||||||
## 部署方式
|
## 部署方式
|
||||||
|
|
||||||
### 方式一:脚本安装(推荐)
|
### 方式一:脚本安装(推荐)
|
||||||
@@ -154,14 +177,14 @@ mkdir -p sub2api-deploy && cd sub2api-deploy
|
|||||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
|
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
|
||||||
|
|
||||||
# 启动服务
|
# 启动服务
|
||||||
docker-compose -f docker-compose.local.yml up -d
|
docker-compose up -d
|
||||||
|
|
||||||
# 查看日志
|
# 查看日志
|
||||||
docker-compose -f docker-compose.local.yml logs -f sub2api
|
docker-compose logs -f sub2api
|
||||||
```
|
```
|
||||||
|
|
||||||
**脚本功能:**
|
**脚本功能:**
|
||||||
- 下载 `docker-compose.local.yml` 和 `.env.example`
|
- 下载 `docker-compose.local.yml`(本地保存为 `docker-compose.yml`)和 `.env.example`
|
||||||
- 自动生成安全凭证(JWT_SECRET、TOTP_ENCRYPTION_KEY、POSTGRES_PASSWORD)
|
- 自动生成安全凭证(JWT_SECRET、TOTP_ENCRYPTION_KEY、POSTGRES_PASSWORD)
|
||||||
- 创建 `.env` 文件并填充自动生成的密钥
|
- 创建 `.env` 文件并填充自动生成的密钥
|
||||||
- 创建数据目录(使用本地目录,便于备份和迁移)
|
- 创建数据目录(使用本地目录,便于备份和迁移)
|
||||||
@@ -588,6 +611,28 @@ sub2api/
|
|||||||
└── install.sh # 一键安装脚本
|
└── install.sh # 一键安装脚本
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## 免责声明
|
||||||
|
|
||||||
|
> **使用本项目前请仔细阅读:**
|
||||||
|
>
|
||||||
|
> :rotating_light: **服务条款风险**: 使用本项目可能违反 Anthropic 的服务条款。请在使用前仔细阅读 Anthropic 的用户协议,使用本项目的一切风险由用户自行承担。
|
||||||
|
>
|
||||||
|
> :book: **免责声明**: 本项目仅供技术学习和研究使用,作者不对因使用本项目导致的账户封禁、服务中断或其他损失承担任何责任。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Star History
|
||||||
|
|
||||||
|
<a href="https://star-history.com/#Wei-Shaw/sub2api&Date">
|
||||||
|
<picture>
|
||||||
|
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date&theme=dark" />
|
||||||
|
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||||
|
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||||
|
</picture>
|
||||||
|
</a>
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## 许可证
|
## 许可证
|
||||||
|
|
||||||
MIT License
|
MIT License
|
||||||
|
|||||||
BIN
assets/partners/logos/pincc-logo.png
Normal file
BIN
assets/partners/logos/pincc-logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 171 KiB |
@@ -33,7 +33,7 @@ func main() {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
userRepo := repository.NewUserRepository(client, sqlDB)
|
userRepo := repository.NewUserRepository(client, sqlDB)
|
||||||
authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|||||||
@@ -41,6 +41,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
// Server layer ProviderSet
|
// Server layer ProviderSet
|
||||||
server.ProviderSet,
|
server.ProviderSet,
|
||||||
|
|
||||||
|
// Privacy client factory for OpenAI training opt-out
|
||||||
|
providePrivacyClientFactory,
|
||||||
|
|
||||||
// BuildInfo provider
|
// BuildInfo provider
|
||||||
provideServiceBuildInfo,
|
provideServiceBuildInfo,
|
||||||
|
|
||||||
@@ -53,6 +56,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func providePrivacyClientFactory() service.PrivacyClientFactory {
|
||||||
|
return repository.CreatePrivacyReqClient
|
||||||
|
}
|
||||||
|
|
||||||
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||||
return service.BuildInfo{
|
return service.BuildInfo{
|
||||||
Version: buildInfo.Version,
|
Version: buildInfo.Version,
|
||||||
@@ -87,6 +94,7 @@ func provideCleanup(
|
|||||||
antigravityOAuth *service.AntigravityOAuthService,
|
antigravityOAuth *service.AntigravityOAuthService,
|
||||||
openAIGateway *service.OpenAIGatewayService,
|
openAIGateway *service.OpenAIGatewayService,
|
||||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||||
|
backupSvc *service.BackupService,
|
||||||
) func() {
|
) func() {
|
||||||
return func() {
|
return func() {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
@@ -223,6 +231,12 @@ func provideCleanup(
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}},
|
}},
|
||||||
|
{"BackupService", func() error {
|
||||||
|
if backupSvc != nil {
|
||||||
|
backupSvc.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
}
|
}
|
||||||
|
|
||||||
infraSteps := []cleanupStep{
|
infraSteps := []cleanupStep{
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
|
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
|
||||||
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
|
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
|
||||||
authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
|
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
|
||||||
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
|
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
|
||||||
redeemCache := repository.NewRedeemCache(redisClient)
|
redeemCache := repository.NewRedeemCache(redisClient)
|
||||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||||
@@ -81,6 +81,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
userHandler := handler.NewUserHandler(userService)
|
userHandler := handler.NewUserHandler(userService)
|
||||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||||
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
||||||
|
usageBillingRepository := repository.NewUsageBillingRepository(client, db)
|
||||||
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
|
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
|
||||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||||
redeemHandler := handler.NewRedeemHandler(redeemService)
|
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||||
@@ -104,11 +105,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
proxyRepository := repository.NewProxyRepository(client, db)
|
proxyRepository := repository.NewProxyRepository(client, db)
|
||||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||||
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService)
|
privacyClientFactory := providePrivacyClientFactory()
|
||||||
|
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
|
||||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
||||||
groupHandler := admin.NewGroupHandler(adminService)
|
|
||||||
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
||||||
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
||||||
openAIOAuthClient := repository.NewOpenAIOAuthClient()
|
openAIOAuthClient := repository.NewOpenAIOAuthClient()
|
||||||
@@ -122,6 +123,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
|
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
|
||||||
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
|
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
|
||||||
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
|
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
|
||||||
|
oauthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
||||||
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
|
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
|
||||||
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
|
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
|
||||||
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||||
@@ -130,20 +132,26 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
usageCache := service.NewUsageCache()
|
usageCache := service.NewUsageCache()
|
||||||
identityCache := repository.NewIdentityCache(redisClient)
|
identityCache := repository.NewIdentityCache(redisClient)
|
||||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache)
|
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache)
|
||||||
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
|
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oauthRefreshAPI)
|
||||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||||
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
|
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI)
|
||||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
|
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
|
||||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
|
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
|
||||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||||
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
|
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
|
||||||
rpmCache := repository.NewRPMCache(redisClient)
|
rpmCache := repository.NewRPMCache(redisClient)
|
||||||
|
groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
|
||||||
|
groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
|
||||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
|
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
|
||||||
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
|
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
|
||||||
dataManagementService := service.NewDataManagementService()
|
dataManagementService := service.NewDataManagementService()
|
||||||
dataManagementHandler := admin.NewDataManagementHandler(dataManagementService)
|
dataManagementHandler := admin.NewDataManagementHandler(dataManagementService)
|
||||||
|
backupObjectStoreFactory := repository.NewS3BackupStoreFactory()
|
||||||
|
dbDumper := repository.NewPgDumper(configConfig)
|
||||||
|
backupService := service.ProvideBackupService(settingRepository, configConfig, secretEncryptor, backupObjectStoreFactory, dbDumper)
|
||||||
|
backupHandler := admin.NewBackupHandler(backupService, userService)
|
||||||
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
||||||
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
||||||
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
|
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
|
||||||
@@ -160,11 +168,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
billingService := service.NewBillingService(configConfig, pricingService)
|
billingService := service.NewBillingService(configConfig, pricingService)
|
||||||
identityService := service.NewIdentityService(identityCache)
|
identityService := service.NewIdentityService(identityCache)
|
||||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||||
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI)
|
||||||
digestSessionStore := service.NewDigestSessionStore()
|
digestSessionStore := service.NewDigestSessionStore()
|
||||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
|
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
|
||||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI)
|
||||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||||
@@ -199,7 +207,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
|
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
|
||||||
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
|
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
|
||||||
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
|
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
|
||||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
|
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
|
||||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||||
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
||||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||||
@@ -226,11 +234,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
||||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||||
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
|
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
|
||||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache)
|
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI)
|
||||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService)
|
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService)
|
||||||
application := &Application{
|
application := &Application{
|
||||||
Server: httpServer,
|
Server: httpServer,
|
||||||
Cleanup: v,
|
Cleanup: v,
|
||||||
@@ -245,6 +253,10 @@ type Application struct {
|
|||||||
Cleanup func()
|
Cleanup func()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func providePrivacyClientFactory() service.PrivacyClientFactory {
|
||||||
|
return repository.CreatePrivacyReqClient
|
||||||
|
}
|
||||||
|
|
||||||
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||||
return service.BuildInfo{
|
return service.BuildInfo{
|
||||||
Version: buildInfo.Version,
|
Version: buildInfo.Version,
|
||||||
@@ -279,6 +291,7 @@ func provideCleanup(
|
|||||||
antigravityOAuth *service.AntigravityOAuthService,
|
antigravityOAuth *service.AntigravityOAuthService,
|
||||||
openAIGateway *service.OpenAIGatewayService,
|
openAIGateway *service.OpenAIGatewayService,
|
||||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||||
|
backupSvc *service.BackupService,
|
||||||
) func() {
|
) func() {
|
||||||
return func() {
|
return func() {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
@@ -414,6 +427,12 @@ func provideCleanup(
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}},
|
}},
|
||||||
|
{"BackupService", func() error {
|
||||||
|
if backupSvc != nil {
|
||||||
|
backupSvc.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
}
|
}
|
||||||
|
|
||||||
infraSteps := []cleanupStep{
|
infraSteps := []cleanupStep{
|
||||||
|
|||||||
@@ -75,6 +75,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
|||||||
antigravityOAuthSvc,
|
antigravityOAuthSvc,
|
||||||
nil, // openAIGateway
|
nil, // openAIGateway
|
||||||
nil, // scheduledTestRunner
|
nil, // scheduledTestRunner
|
||||||
|
nil, // backupSvc
|
||||||
)
|
)
|
||||||
|
|
||||||
require.NotPanics(t, func() {
|
require.NotPanics(t, func() {
|
||||||
|
|||||||
@@ -716,6 +716,7 @@ var (
|
|||||||
{Name: "id", Type: field.TypeInt64, Increment: true},
|
{Name: "id", Type: field.TypeInt64, Increment: true},
|
||||||
{Name: "request_id", Type: field.TypeString, Size: 64},
|
{Name: "request_id", Type: field.TypeString, Size: 64},
|
||||||
{Name: "model", Type: field.TypeString, Size: 100},
|
{Name: "model", Type: field.TypeString, Size: 100},
|
||||||
|
{Name: "upstream_model", Type: field.TypeString, Nullable: true, Size: 100},
|
||||||
{Name: "input_tokens", Type: field.TypeInt, Default: 0},
|
{Name: "input_tokens", Type: field.TypeInt, Default: 0},
|
||||||
{Name: "output_tokens", Type: field.TypeInt, Default: 0},
|
{Name: "output_tokens", Type: field.TypeInt, Default: 0},
|
||||||
{Name: "cache_creation_tokens", Type: field.TypeInt, Default: 0},
|
{Name: "cache_creation_tokens", Type: field.TypeInt, Default: 0},
|
||||||
@@ -755,31 +756,31 @@ var (
|
|||||||
ForeignKeys: []*schema.ForeignKey{
|
ForeignKeys: []*schema.ForeignKey{
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_api_keys_usage_logs",
|
Symbol: "usage_logs_api_keys_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||||
RefColumns: []*schema.Column{APIKeysColumns[0]},
|
RefColumns: []*schema.Column{APIKeysColumns[0]},
|
||||||
OnDelete: schema.NoAction,
|
OnDelete: schema.NoAction,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_accounts_usage_logs",
|
Symbol: "usage_logs_accounts_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||||
RefColumns: []*schema.Column{AccountsColumns[0]},
|
RefColumns: []*schema.Column{AccountsColumns[0]},
|
||||||
OnDelete: schema.NoAction,
|
OnDelete: schema.NoAction,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_groups_usage_logs",
|
Symbol: "usage_logs_groups_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[30]},
|
Columns: []*schema.Column{UsageLogsColumns[31]},
|
||||||
RefColumns: []*schema.Column{GroupsColumns[0]},
|
RefColumns: []*schema.Column{GroupsColumns[0]},
|
||||||
OnDelete: schema.SetNull,
|
OnDelete: schema.SetNull,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_users_usage_logs",
|
Symbol: "usage_logs_users_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[31]},
|
Columns: []*schema.Column{UsageLogsColumns[32]},
|
||||||
RefColumns: []*schema.Column{UsersColumns[0]},
|
RefColumns: []*schema.Column{UsersColumns[0]},
|
||||||
OnDelete: schema.NoAction,
|
OnDelete: schema.NoAction,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_user_subscriptions_usage_logs",
|
Symbol: "usage_logs_user_subscriptions_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[32]},
|
Columns: []*schema.Column{UsageLogsColumns[33]},
|
||||||
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
|
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
|
||||||
OnDelete: schema.SetNull,
|
OnDelete: schema.SetNull,
|
||||||
},
|
},
|
||||||
@@ -788,32 +789,32 @@ var (
|
|||||||
{
|
{
|
||||||
Name: "usagelog_user_id",
|
Name: "usagelog_user_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[31]},
|
Columns: []*schema.Column{UsageLogsColumns[32]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_api_key_id",
|
Name: "usagelog_api_key_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_account_id",
|
Name: "usagelog_account_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_group_id",
|
Name: "usagelog_group_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[30]},
|
Columns: []*schema.Column{UsageLogsColumns[31]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_subscription_id",
|
Name: "usagelog_subscription_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[32]},
|
Columns: []*schema.Column{UsageLogsColumns[33]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_created_at",
|
Name: "usagelog_created_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[27]},
|
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_model",
|
Name: "usagelog_model",
|
||||||
@@ -828,17 +829,17 @@ var (
|
|||||||
{
|
{
|
||||||
Name: "usagelog_user_id_created_at",
|
Name: "usagelog_user_id_created_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[27]},
|
Columns: []*schema.Column{UsageLogsColumns[32], UsageLogsColumns[28]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_api_key_id_created_at",
|
Name: "usagelog_api_key_id_created_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[27]},
|
Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[28]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_group_id_created_at",
|
Name: "usagelog_group_id_created_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[27]},
|
Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[28]},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18239,6 +18239,7 @@ type UsageLogMutation struct {
|
|||||||
id *int64
|
id *int64
|
||||||
request_id *string
|
request_id *string
|
||||||
model *string
|
model *string
|
||||||
|
upstream_model *string
|
||||||
input_tokens *int
|
input_tokens *int
|
||||||
addinput_tokens *int
|
addinput_tokens *int
|
||||||
output_tokens *int
|
output_tokens *int
|
||||||
@@ -18576,6 +18577,55 @@ func (m *UsageLogMutation) ResetModel() {
|
|||||||
m.model = nil
|
m.model = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetUpstreamModel sets the "upstream_model" field.
|
||||||
|
func (m *UsageLogMutation) SetUpstreamModel(s string) {
|
||||||
|
m.upstream_model = &s
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModel returns the value of the "upstream_model" field in the mutation.
|
||||||
|
func (m *UsageLogMutation) UpstreamModel() (r string, exists bool) {
|
||||||
|
v := m.upstream_model
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldUpstreamModel returns the old "upstream_model" field's value of the UsageLog entity.
|
||||||
|
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *UsageLogMutation) OldUpstreamModel(ctx context.Context) (v *string, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldUpstreamModel is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldUpstreamModel requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldUpstreamModel: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.UpstreamModel, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearUpstreamModel clears the value of the "upstream_model" field.
|
||||||
|
func (m *UsageLogMutation) ClearUpstreamModel() {
|
||||||
|
m.upstream_model = nil
|
||||||
|
m.clearedFields[usagelog.FieldUpstreamModel] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelCleared returns if the "upstream_model" field was cleared in this mutation.
|
||||||
|
func (m *UsageLogMutation) UpstreamModelCleared() bool {
|
||||||
|
_, ok := m.clearedFields[usagelog.FieldUpstreamModel]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetUpstreamModel resets all changes to the "upstream_model" field.
|
||||||
|
func (m *UsageLogMutation) ResetUpstreamModel() {
|
||||||
|
m.upstream_model = nil
|
||||||
|
delete(m.clearedFields, usagelog.FieldUpstreamModel)
|
||||||
|
}
|
||||||
|
|
||||||
// SetGroupID sets the "group_id" field.
|
// SetGroupID sets the "group_id" field.
|
||||||
func (m *UsageLogMutation) SetGroupID(i int64) {
|
func (m *UsageLogMutation) SetGroupID(i int64) {
|
||||||
m.group = &i
|
m.group = &i
|
||||||
@@ -20197,7 +20247,7 @@ func (m *UsageLogMutation) Type() string {
|
|||||||
// order to get all numeric fields that were incremented/decremented, call
|
// order to get all numeric fields that were incremented/decremented, call
|
||||||
// AddedFields().
|
// AddedFields().
|
||||||
func (m *UsageLogMutation) Fields() []string {
|
func (m *UsageLogMutation) Fields() []string {
|
||||||
fields := make([]string, 0, 32)
|
fields := make([]string, 0, 33)
|
||||||
if m.user != nil {
|
if m.user != nil {
|
||||||
fields = append(fields, usagelog.FieldUserID)
|
fields = append(fields, usagelog.FieldUserID)
|
||||||
}
|
}
|
||||||
@@ -20213,6 +20263,9 @@ func (m *UsageLogMutation) Fields() []string {
|
|||||||
if m.model != nil {
|
if m.model != nil {
|
||||||
fields = append(fields, usagelog.FieldModel)
|
fields = append(fields, usagelog.FieldModel)
|
||||||
}
|
}
|
||||||
|
if m.upstream_model != nil {
|
||||||
|
fields = append(fields, usagelog.FieldUpstreamModel)
|
||||||
|
}
|
||||||
if m.group != nil {
|
if m.group != nil {
|
||||||
fields = append(fields, usagelog.FieldGroupID)
|
fields = append(fields, usagelog.FieldGroupID)
|
||||||
}
|
}
|
||||||
@@ -20312,6 +20365,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
|
|||||||
return m.RequestID()
|
return m.RequestID()
|
||||||
case usagelog.FieldModel:
|
case usagelog.FieldModel:
|
||||||
return m.Model()
|
return m.Model()
|
||||||
|
case usagelog.FieldUpstreamModel:
|
||||||
|
return m.UpstreamModel()
|
||||||
case usagelog.FieldGroupID:
|
case usagelog.FieldGroupID:
|
||||||
return m.GroupID()
|
return m.GroupID()
|
||||||
case usagelog.FieldSubscriptionID:
|
case usagelog.FieldSubscriptionID:
|
||||||
@@ -20385,6 +20440,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
|
|||||||
return m.OldRequestID(ctx)
|
return m.OldRequestID(ctx)
|
||||||
case usagelog.FieldModel:
|
case usagelog.FieldModel:
|
||||||
return m.OldModel(ctx)
|
return m.OldModel(ctx)
|
||||||
|
case usagelog.FieldUpstreamModel:
|
||||||
|
return m.OldUpstreamModel(ctx)
|
||||||
case usagelog.FieldGroupID:
|
case usagelog.FieldGroupID:
|
||||||
return m.OldGroupID(ctx)
|
return m.OldGroupID(ctx)
|
||||||
case usagelog.FieldSubscriptionID:
|
case usagelog.FieldSubscriptionID:
|
||||||
@@ -20483,6 +20540,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.SetModel(v)
|
m.SetModel(v)
|
||||||
return nil
|
return nil
|
||||||
|
case usagelog.FieldUpstreamModel:
|
||||||
|
v, ok := value.(string)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetUpstreamModel(v)
|
||||||
|
return nil
|
||||||
case usagelog.FieldGroupID:
|
case usagelog.FieldGroupID:
|
||||||
v, ok := value.(int64)
|
v, ok := value.(int64)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -20921,6 +20985,9 @@ func (m *UsageLogMutation) AddField(name string, value ent.Value) error {
|
|||||||
// mutation.
|
// mutation.
|
||||||
func (m *UsageLogMutation) ClearedFields() []string {
|
func (m *UsageLogMutation) ClearedFields() []string {
|
||||||
var fields []string
|
var fields []string
|
||||||
|
if m.FieldCleared(usagelog.FieldUpstreamModel) {
|
||||||
|
fields = append(fields, usagelog.FieldUpstreamModel)
|
||||||
|
}
|
||||||
if m.FieldCleared(usagelog.FieldGroupID) {
|
if m.FieldCleared(usagelog.FieldGroupID) {
|
||||||
fields = append(fields, usagelog.FieldGroupID)
|
fields = append(fields, usagelog.FieldGroupID)
|
||||||
}
|
}
|
||||||
@@ -20962,6 +21029,9 @@ func (m *UsageLogMutation) FieldCleared(name string) bool {
|
|||||||
// error if the field is not defined in the schema.
|
// error if the field is not defined in the schema.
|
||||||
func (m *UsageLogMutation) ClearField(name string) error {
|
func (m *UsageLogMutation) ClearField(name string) error {
|
||||||
switch name {
|
switch name {
|
||||||
|
case usagelog.FieldUpstreamModel:
|
||||||
|
m.ClearUpstreamModel()
|
||||||
|
return nil
|
||||||
case usagelog.FieldGroupID:
|
case usagelog.FieldGroupID:
|
||||||
m.ClearGroupID()
|
m.ClearGroupID()
|
||||||
return nil
|
return nil
|
||||||
@@ -21012,6 +21082,9 @@ func (m *UsageLogMutation) ResetField(name string) error {
|
|||||||
case usagelog.FieldModel:
|
case usagelog.FieldModel:
|
||||||
m.ResetModel()
|
m.ResetModel()
|
||||||
return nil
|
return nil
|
||||||
|
case usagelog.FieldUpstreamModel:
|
||||||
|
m.ResetUpstreamModel()
|
||||||
|
return nil
|
||||||
case usagelog.FieldGroupID:
|
case usagelog.FieldGroupID:
|
||||||
m.ResetGroupID()
|
m.ResetGroupID()
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -821,92 +821,96 @@ func init() {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
// usagelogDescUpstreamModel is the schema descriptor for upstream_model field.
|
||||||
|
usagelogDescUpstreamModel := usagelogFields[5].Descriptor()
|
||||||
|
// usagelog.UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
|
||||||
|
usagelog.UpstreamModelValidator = usagelogDescUpstreamModel.Validators[0].(func(string) error)
|
||||||
// usagelogDescInputTokens is the schema descriptor for input_tokens field.
|
// usagelogDescInputTokens is the schema descriptor for input_tokens field.
|
||||||
usagelogDescInputTokens := usagelogFields[7].Descriptor()
|
usagelogDescInputTokens := usagelogFields[8].Descriptor()
|
||||||
// usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field.
|
// usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field.
|
||||||
usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int)
|
usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int)
|
||||||
// usagelogDescOutputTokens is the schema descriptor for output_tokens field.
|
// usagelogDescOutputTokens is the schema descriptor for output_tokens field.
|
||||||
usagelogDescOutputTokens := usagelogFields[8].Descriptor()
|
usagelogDescOutputTokens := usagelogFields[9].Descriptor()
|
||||||
// usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field.
|
// usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field.
|
||||||
usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int)
|
usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int)
|
||||||
// usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field.
|
// usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field.
|
||||||
usagelogDescCacheCreationTokens := usagelogFields[9].Descriptor()
|
usagelogDescCacheCreationTokens := usagelogFields[10].Descriptor()
|
||||||
// usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field.
|
// usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field.
|
||||||
usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int)
|
usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int)
|
||||||
// usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field.
|
// usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field.
|
||||||
usagelogDescCacheReadTokens := usagelogFields[10].Descriptor()
|
usagelogDescCacheReadTokens := usagelogFields[11].Descriptor()
|
||||||
// usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field.
|
// usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field.
|
||||||
usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int)
|
usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int)
|
||||||
// usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field.
|
// usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field.
|
||||||
usagelogDescCacheCreation5mTokens := usagelogFields[11].Descriptor()
|
usagelogDescCacheCreation5mTokens := usagelogFields[12].Descriptor()
|
||||||
// usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field.
|
// usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field.
|
||||||
usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int)
|
usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int)
|
||||||
// usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field.
|
// usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field.
|
||||||
usagelogDescCacheCreation1hTokens := usagelogFields[12].Descriptor()
|
usagelogDescCacheCreation1hTokens := usagelogFields[13].Descriptor()
|
||||||
// usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field.
|
// usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field.
|
||||||
usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int)
|
usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int)
|
||||||
// usagelogDescInputCost is the schema descriptor for input_cost field.
|
// usagelogDescInputCost is the schema descriptor for input_cost field.
|
||||||
usagelogDescInputCost := usagelogFields[13].Descriptor()
|
usagelogDescInputCost := usagelogFields[14].Descriptor()
|
||||||
// usagelog.DefaultInputCost holds the default value on creation for the input_cost field.
|
// usagelog.DefaultInputCost holds the default value on creation for the input_cost field.
|
||||||
usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64)
|
usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64)
|
||||||
// usagelogDescOutputCost is the schema descriptor for output_cost field.
|
// usagelogDescOutputCost is the schema descriptor for output_cost field.
|
||||||
usagelogDescOutputCost := usagelogFields[14].Descriptor()
|
usagelogDescOutputCost := usagelogFields[15].Descriptor()
|
||||||
// usagelog.DefaultOutputCost holds the default value on creation for the output_cost field.
|
// usagelog.DefaultOutputCost holds the default value on creation for the output_cost field.
|
||||||
usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64)
|
usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64)
|
||||||
// usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field.
|
// usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field.
|
||||||
usagelogDescCacheCreationCost := usagelogFields[15].Descriptor()
|
usagelogDescCacheCreationCost := usagelogFields[16].Descriptor()
|
||||||
// usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field.
|
// usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field.
|
||||||
usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64)
|
usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64)
|
||||||
// usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field.
|
// usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field.
|
||||||
usagelogDescCacheReadCost := usagelogFields[16].Descriptor()
|
usagelogDescCacheReadCost := usagelogFields[17].Descriptor()
|
||||||
// usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field.
|
// usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field.
|
||||||
usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64)
|
usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64)
|
||||||
// usagelogDescTotalCost is the schema descriptor for total_cost field.
|
// usagelogDescTotalCost is the schema descriptor for total_cost field.
|
||||||
usagelogDescTotalCost := usagelogFields[17].Descriptor()
|
usagelogDescTotalCost := usagelogFields[18].Descriptor()
|
||||||
// usagelog.DefaultTotalCost holds the default value on creation for the total_cost field.
|
// usagelog.DefaultTotalCost holds the default value on creation for the total_cost field.
|
||||||
usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64)
|
usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64)
|
||||||
// usagelogDescActualCost is the schema descriptor for actual_cost field.
|
// usagelogDescActualCost is the schema descriptor for actual_cost field.
|
||||||
usagelogDescActualCost := usagelogFields[18].Descriptor()
|
usagelogDescActualCost := usagelogFields[19].Descriptor()
|
||||||
// usagelog.DefaultActualCost holds the default value on creation for the actual_cost field.
|
// usagelog.DefaultActualCost holds the default value on creation for the actual_cost field.
|
||||||
usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64)
|
usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64)
|
||||||
// usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field.
|
// usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field.
|
||||||
usagelogDescRateMultiplier := usagelogFields[19].Descriptor()
|
usagelogDescRateMultiplier := usagelogFields[20].Descriptor()
|
||||||
// usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
|
// usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
|
||||||
usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64)
|
usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64)
|
||||||
// usagelogDescBillingType is the schema descriptor for billing_type field.
|
// usagelogDescBillingType is the schema descriptor for billing_type field.
|
||||||
usagelogDescBillingType := usagelogFields[21].Descriptor()
|
usagelogDescBillingType := usagelogFields[22].Descriptor()
|
||||||
// usagelog.DefaultBillingType holds the default value on creation for the billing_type field.
|
// usagelog.DefaultBillingType holds the default value on creation for the billing_type field.
|
||||||
usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8)
|
usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8)
|
||||||
// usagelogDescStream is the schema descriptor for stream field.
|
// usagelogDescStream is the schema descriptor for stream field.
|
||||||
usagelogDescStream := usagelogFields[22].Descriptor()
|
usagelogDescStream := usagelogFields[23].Descriptor()
|
||||||
// usagelog.DefaultStream holds the default value on creation for the stream field.
|
// usagelog.DefaultStream holds the default value on creation for the stream field.
|
||||||
usagelog.DefaultStream = usagelogDescStream.Default.(bool)
|
usagelog.DefaultStream = usagelogDescStream.Default.(bool)
|
||||||
// usagelogDescUserAgent is the schema descriptor for user_agent field.
|
// usagelogDescUserAgent is the schema descriptor for user_agent field.
|
||||||
usagelogDescUserAgent := usagelogFields[25].Descriptor()
|
usagelogDescUserAgent := usagelogFields[26].Descriptor()
|
||||||
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
|
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
|
||||||
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
|
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
|
||||||
// usagelogDescIPAddress is the schema descriptor for ip_address field.
|
// usagelogDescIPAddress is the schema descriptor for ip_address field.
|
||||||
usagelogDescIPAddress := usagelogFields[26].Descriptor()
|
usagelogDescIPAddress := usagelogFields[27].Descriptor()
|
||||||
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
|
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
|
||||||
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
|
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
|
||||||
// usagelogDescImageCount is the schema descriptor for image_count field.
|
// usagelogDescImageCount is the schema descriptor for image_count field.
|
||||||
usagelogDescImageCount := usagelogFields[27].Descriptor()
|
usagelogDescImageCount := usagelogFields[28].Descriptor()
|
||||||
// usagelog.DefaultImageCount holds the default value on creation for the image_count field.
|
// usagelog.DefaultImageCount holds the default value on creation for the image_count field.
|
||||||
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
|
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
|
||||||
// usagelogDescImageSize is the schema descriptor for image_size field.
|
// usagelogDescImageSize is the schema descriptor for image_size field.
|
||||||
usagelogDescImageSize := usagelogFields[28].Descriptor()
|
usagelogDescImageSize := usagelogFields[29].Descriptor()
|
||||||
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
||||||
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
|
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
|
||||||
// usagelogDescMediaType is the schema descriptor for media_type field.
|
// usagelogDescMediaType is the schema descriptor for media_type field.
|
||||||
usagelogDescMediaType := usagelogFields[29].Descriptor()
|
usagelogDescMediaType := usagelogFields[30].Descriptor()
|
||||||
// usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
|
// usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
|
||||||
usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
|
usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
|
||||||
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
|
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
|
||||||
usagelogDescCacheTTLOverridden := usagelogFields[30].Descriptor()
|
usagelogDescCacheTTLOverridden := usagelogFields[31].Descriptor()
|
||||||
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
|
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
|
||||||
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
|
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
|
||||||
// usagelogDescCreatedAt is the schema descriptor for created_at field.
|
// usagelogDescCreatedAt is the schema descriptor for created_at field.
|
||||||
usagelogDescCreatedAt := usagelogFields[31].Descriptor()
|
usagelogDescCreatedAt := usagelogFields[32].Descriptor()
|
||||||
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
|
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||||
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
|
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
|
||||||
userMixin := schema.User{}.Mixin()
|
userMixin := schema.User{}.Mixin()
|
||||||
|
|||||||
@@ -41,6 +41,12 @@ func (UsageLog) Fields() []ent.Field {
|
|||||||
field.String("model").
|
field.String("model").
|
||||||
MaxLen(100).
|
MaxLen(100).
|
||||||
NotEmpty(),
|
NotEmpty(),
|
||||||
|
// UpstreamModel stores the actual upstream model name when model mapping
|
||||||
|
// is applied. NULL means no mapping — the requested model was used as-is.
|
||||||
|
field.String("upstream_model").
|
||||||
|
MaxLen(100).
|
||||||
|
Optional().
|
||||||
|
Nillable(),
|
||||||
field.Int64("group_id").
|
field.Int64("group_id").
|
||||||
Optional().
|
Optional().
|
||||||
Nillable(),
|
Nillable(),
|
||||||
|
|||||||
@@ -32,6 +32,8 @@ type UsageLog struct {
|
|||||||
RequestID string `json:"request_id,omitempty"`
|
RequestID string `json:"request_id,omitempty"`
|
||||||
// Model holds the value of the "model" field.
|
// Model holds the value of the "model" field.
|
||||||
Model string `json:"model,omitempty"`
|
Model string `json:"model,omitempty"`
|
||||||
|
// UpstreamModel holds the value of the "upstream_model" field.
|
||||||
|
UpstreamModel *string `json:"upstream_model,omitempty"`
|
||||||
// GroupID holds the value of the "group_id" field.
|
// GroupID holds the value of the "group_id" field.
|
||||||
GroupID *int64 `json:"group_id,omitempty"`
|
GroupID *int64 `json:"group_id,omitempty"`
|
||||||
// SubscriptionID holds the value of the "subscription_id" field.
|
// SubscriptionID holds the value of the "subscription_id" field.
|
||||||
@@ -175,7 +177,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
|
|||||||
values[i] = new(sql.NullFloat64)
|
values[i] = new(sql.NullFloat64)
|
||||||
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
|
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
|
||||||
values[i] = new(sql.NullInt64)
|
values[i] = new(sql.NullInt64)
|
||||||
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
|
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUpstreamModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
|
||||||
values[i] = new(sql.NullString)
|
values[i] = new(sql.NullString)
|
||||||
case usagelog.FieldCreatedAt:
|
case usagelog.FieldCreatedAt:
|
||||||
values[i] = new(sql.NullTime)
|
values[i] = new(sql.NullTime)
|
||||||
@@ -230,6 +232,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
|
|||||||
} else if value.Valid {
|
} else if value.Valid {
|
||||||
_m.Model = value.String
|
_m.Model = value.String
|
||||||
}
|
}
|
||||||
|
case usagelog.FieldUpstreamModel:
|
||||||
|
if value, ok := values[i].(*sql.NullString); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field upstream_model", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.UpstreamModel = new(string)
|
||||||
|
*_m.UpstreamModel = value.String
|
||||||
|
}
|
||||||
case usagelog.FieldGroupID:
|
case usagelog.FieldGroupID:
|
||||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||||
return fmt.Errorf("unexpected type %T for field group_id", values[i])
|
return fmt.Errorf("unexpected type %T for field group_id", values[i])
|
||||||
@@ -477,6 +486,11 @@ func (_m *UsageLog) String() string {
|
|||||||
builder.WriteString("model=")
|
builder.WriteString("model=")
|
||||||
builder.WriteString(_m.Model)
|
builder.WriteString(_m.Model)
|
||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
|
if v := _m.UpstreamModel; v != nil {
|
||||||
|
builder.WriteString("upstream_model=")
|
||||||
|
builder.WriteString(*v)
|
||||||
|
}
|
||||||
|
builder.WriteString(", ")
|
||||||
if v := _m.GroupID; v != nil {
|
if v := _m.GroupID; v != nil {
|
||||||
builder.WriteString("group_id=")
|
builder.WriteString("group_id=")
|
||||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ const (
|
|||||||
FieldRequestID = "request_id"
|
FieldRequestID = "request_id"
|
||||||
// FieldModel holds the string denoting the model field in the database.
|
// FieldModel holds the string denoting the model field in the database.
|
||||||
FieldModel = "model"
|
FieldModel = "model"
|
||||||
|
// FieldUpstreamModel holds the string denoting the upstream_model field in the database.
|
||||||
|
FieldUpstreamModel = "upstream_model"
|
||||||
// FieldGroupID holds the string denoting the group_id field in the database.
|
// FieldGroupID holds the string denoting the group_id field in the database.
|
||||||
FieldGroupID = "group_id"
|
FieldGroupID = "group_id"
|
||||||
// FieldSubscriptionID holds the string denoting the subscription_id field in the database.
|
// FieldSubscriptionID holds the string denoting the subscription_id field in the database.
|
||||||
@@ -135,6 +137,7 @@ var Columns = []string{
|
|||||||
FieldAccountID,
|
FieldAccountID,
|
||||||
FieldRequestID,
|
FieldRequestID,
|
||||||
FieldModel,
|
FieldModel,
|
||||||
|
FieldUpstreamModel,
|
||||||
FieldGroupID,
|
FieldGroupID,
|
||||||
FieldSubscriptionID,
|
FieldSubscriptionID,
|
||||||
FieldInputTokens,
|
FieldInputTokens,
|
||||||
@@ -179,6 +182,8 @@ var (
|
|||||||
RequestIDValidator func(string) error
|
RequestIDValidator func(string) error
|
||||||
// ModelValidator is a validator for the "model" field. It is called by the builders before save.
|
// ModelValidator is a validator for the "model" field. It is called by the builders before save.
|
||||||
ModelValidator func(string) error
|
ModelValidator func(string) error
|
||||||
|
// UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
|
||||||
|
UpstreamModelValidator func(string) error
|
||||||
// DefaultInputTokens holds the default value on creation for the "input_tokens" field.
|
// DefaultInputTokens holds the default value on creation for the "input_tokens" field.
|
||||||
DefaultInputTokens int
|
DefaultInputTokens int
|
||||||
// DefaultOutputTokens holds the default value on creation for the "output_tokens" field.
|
// DefaultOutputTokens holds the default value on creation for the "output_tokens" field.
|
||||||
@@ -258,6 +263,11 @@ func ByModel(opts ...sql.OrderTermOption) OrderOption {
|
|||||||
return sql.OrderByField(FieldModel, opts...).ToFunc()
|
return sql.OrderByField(FieldModel, opts...).ToFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ByUpstreamModel orders the results by the upstream_model field.
|
||||||
|
func ByUpstreamModel(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldUpstreamModel, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
// ByGroupID orders the results by the group_id field.
|
// ByGroupID orders the results by the group_id field.
|
||||||
func ByGroupID(opts ...sql.OrderTermOption) OrderOption {
|
func ByGroupID(opts ...sql.OrderTermOption) OrderOption {
|
||||||
return sql.OrderByField(FieldGroupID, opts...).ToFunc()
|
return sql.OrderByField(FieldGroupID, opts...).ToFunc()
|
||||||
|
|||||||
@@ -80,6 +80,11 @@ func Model(v string) predicate.UsageLog {
|
|||||||
return predicate.UsageLog(sql.FieldEQ(FieldModel, v))
|
return predicate.UsageLog(sql.FieldEQ(FieldModel, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpstreamModel applies equality check predicate on the "upstream_model" field. It's identical to UpstreamModelEQ.
|
||||||
|
func UpstreamModel(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
// GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ.
|
// GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ.
|
||||||
func GroupID(v int64) predicate.UsageLog {
|
func GroupID(v int64) predicate.UsageLog {
|
||||||
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))
|
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))
|
||||||
@@ -405,6 +410,81 @@ func ModelContainsFold(v string) predicate.UsageLog {
|
|||||||
return predicate.UsageLog(sql.FieldContainsFold(FieldModel, v))
|
return predicate.UsageLog(sql.FieldContainsFold(FieldModel, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpstreamModelEQ applies the EQ predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelEQ(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelNEQ applies the NEQ predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelNEQ(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldNEQ(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelIn applies the In predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelIn(vs ...string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldIn(FieldUpstreamModel, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelNotIn applies the NotIn predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelNotIn(vs ...string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldNotIn(FieldUpstreamModel, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelGT applies the GT predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelGT(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldGT(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelGTE applies the GTE predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelGTE(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldGTE(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelLT applies the LT predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelLT(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldLT(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelLTE applies the LTE predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelLTE(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldLTE(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelContains applies the Contains predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelContains(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldContains(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelHasPrefix applies the HasPrefix predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelHasPrefix(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldHasPrefix(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelHasSuffix applies the HasSuffix predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelHasSuffix(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldHasSuffix(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelIsNil applies the IsNil predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelIsNil() predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldIsNull(FieldUpstreamModel))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelNotNil applies the NotNil predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelNotNil() predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldNotNull(FieldUpstreamModel))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelEqualFold applies the EqualFold predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelEqualFold(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldEqualFold(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelContainsFold applies the ContainsFold predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelContainsFold(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldContainsFold(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
// GroupIDEQ applies the EQ predicate on the "group_id" field.
|
// GroupIDEQ applies the EQ predicate on the "group_id" field.
|
||||||
func GroupIDEQ(v int64) predicate.UsageLog {
|
func GroupIDEQ(v int64) predicate.UsageLog {
|
||||||
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))
|
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))
|
||||||
|
|||||||
@@ -57,6 +57,20 @@ func (_c *UsageLogCreate) SetModel(v string) *UsageLogCreate {
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetUpstreamModel sets the "upstream_model" field.
|
||||||
|
func (_c *UsageLogCreate) SetUpstreamModel(v string) *UsageLogCreate {
|
||||||
|
_c.mutation.SetUpstreamModel(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil.
|
||||||
|
func (_c *UsageLogCreate) SetNillableUpstreamModel(v *string) *UsageLogCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetUpstreamModel(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
// SetGroupID sets the "group_id" field.
|
// SetGroupID sets the "group_id" field.
|
||||||
func (_c *UsageLogCreate) SetGroupID(v int64) *UsageLogCreate {
|
func (_c *UsageLogCreate) SetGroupID(v int64) *UsageLogCreate {
|
||||||
_c.mutation.SetGroupID(v)
|
_c.mutation.SetGroupID(v)
|
||||||
@@ -596,6 +610,11 @@ func (_c *UsageLogCreate) check() error {
|
|||||||
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
|
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if v, ok := _c.mutation.UpstreamModel(); ok {
|
||||||
|
if err := usagelog.UpstreamModelValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
if _, ok := _c.mutation.InputTokens(); !ok {
|
if _, ok := _c.mutation.InputTokens(); !ok {
|
||||||
return &ValidationError{Name: "input_tokens", err: errors.New(`ent: missing required field "UsageLog.input_tokens"`)}
|
return &ValidationError{Name: "input_tokens", err: errors.New(`ent: missing required field "UsageLog.input_tokens"`)}
|
||||||
}
|
}
|
||||||
@@ -714,6 +733,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
|
|||||||
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
|
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
|
||||||
_node.Model = value
|
_node.Model = value
|
||||||
}
|
}
|
||||||
|
if value, ok := _c.mutation.UpstreamModel(); ok {
|
||||||
|
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
|
||||||
|
_node.UpstreamModel = &value
|
||||||
|
}
|
||||||
if value, ok := _c.mutation.InputTokens(); ok {
|
if value, ok := _c.mutation.InputTokens(); ok {
|
||||||
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
|
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
|
||||||
_node.InputTokens = value
|
_node.InputTokens = value
|
||||||
@@ -1011,6 +1034,24 @@ func (u *UsageLogUpsert) UpdateModel() *UsageLogUpsert {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetUpstreamModel sets the "upstream_model" field.
|
||||||
|
func (u *UsageLogUpsert) SetUpstreamModel(v string) *UsageLogUpsert {
|
||||||
|
u.Set(usagelog.FieldUpstreamModel, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create.
|
||||||
|
func (u *UsageLogUpsert) UpdateUpstreamModel() *UsageLogUpsert {
|
||||||
|
u.SetExcluded(usagelog.FieldUpstreamModel)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearUpstreamModel clears the value of the "upstream_model" field.
|
||||||
|
func (u *UsageLogUpsert) ClearUpstreamModel() *UsageLogUpsert {
|
||||||
|
u.SetNull(usagelog.FieldUpstreamModel)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
// SetGroupID sets the "group_id" field.
|
// SetGroupID sets the "group_id" field.
|
||||||
func (u *UsageLogUpsert) SetGroupID(v int64) *UsageLogUpsert {
|
func (u *UsageLogUpsert) SetGroupID(v int64) *UsageLogUpsert {
|
||||||
u.Set(usagelog.FieldGroupID, v)
|
u.Set(usagelog.FieldGroupID, v)
|
||||||
@@ -1600,6 +1641,27 @@ func (u *UsageLogUpsertOne) UpdateModel() *UsageLogUpsertOne {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetUpstreamModel sets the "upstream_model" field.
|
||||||
|
func (u *UsageLogUpsertOne) SetUpstreamModel(v string) *UsageLogUpsertOne {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.SetUpstreamModel(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create.
|
||||||
|
func (u *UsageLogUpsertOne) UpdateUpstreamModel() *UsageLogUpsertOne {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.UpdateUpstreamModel()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearUpstreamModel clears the value of the "upstream_model" field.
|
||||||
|
func (u *UsageLogUpsertOne) ClearUpstreamModel() *UsageLogUpsertOne {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.ClearUpstreamModel()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// SetGroupID sets the "group_id" field.
|
// SetGroupID sets the "group_id" field.
|
||||||
func (u *UsageLogUpsertOne) SetGroupID(v int64) *UsageLogUpsertOne {
|
func (u *UsageLogUpsertOne) SetGroupID(v int64) *UsageLogUpsertOne {
|
||||||
return u.Update(func(s *UsageLogUpsert) {
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
@@ -2434,6 +2496,27 @@ func (u *UsageLogUpsertBulk) UpdateModel() *UsageLogUpsertBulk {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetUpstreamModel sets the "upstream_model" field.
|
||||||
|
func (u *UsageLogUpsertBulk) SetUpstreamModel(v string) *UsageLogUpsertBulk {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.SetUpstreamModel(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create.
|
||||||
|
func (u *UsageLogUpsertBulk) UpdateUpstreamModel() *UsageLogUpsertBulk {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.UpdateUpstreamModel()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearUpstreamModel clears the value of the "upstream_model" field.
|
||||||
|
func (u *UsageLogUpsertBulk) ClearUpstreamModel() *UsageLogUpsertBulk {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.ClearUpstreamModel()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// SetGroupID sets the "group_id" field.
|
// SetGroupID sets the "group_id" field.
|
||||||
func (u *UsageLogUpsertBulk) SetGroupID(v int64) *UsageLogUpsertBulk {
|
func (u *UsageLogUpsertBulk) SetGroupID(v int64) *UsageLogUpsertBulk {
|
||||||
return u.Update(func(s *UsageLogUpsert) {
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
|||||||
@@ -102,6 +102,26 @@ func (_u *UsageLogUpdate) SetNillableModel(v *string) *UsageLogUpdate {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetUpstreamModel sets the "upstream_model" field.
|
||||||
|
func (_u *UsageLogUpdate) SetUpstreamModel(v string) *UsageLogUpdate {
|
||||||
|
_u.mutation.SetUpstreamModel(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil.
|
||||||
|
func (_u *UsageLogUpdate) SetNillableUpstreamModel(v *string) *UsageLogUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetUpstreamModel(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearUpstreamModel clears the value of the "upstream_model" field.
|
||||||
|
func (_u *UsageLogUpdate) ClearUpstreamModel() *UsageLogUpdate {
|
||||||
|
_u.mutation.ClearUpstreamModel()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetGroupID sets the "group_id" field.
|
// SetGroupID sets the "group_id" field.
|
||||||
func (_u *UsageLogUpdate) SetGroupID(v int64) *UsageLogUpdate {
|
func (_u *UsageLogUpdate) SetGroupID(v int64) *UsageLogUpdate {
|
||||||
_u.mutation.SetGroupID(v)
|
_u.mutation.SetGroupID(v)
|
||||||
@@ -745,6 +765,11 @@ func (_u *UsageLogUpdate) check() error {
|
|||||||
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
|
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if v, ok := _u.mutation.UpstreamModel(); ok {
|
||||||
|
if err := usagelog.UpstreamModelValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
if v, ok := _u.mutation.UserAgent(); ok {
|
if v, ok := _u.mutation.UserAgent(); ok {
|
||||||
if err := usagelog.UserAgentValidator(v); err != nil {
|
if err := usagelog.UserAgentValidator(v); err != nil {
|
||||||
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
||||||
@@ -795,6 +820,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
if value, ok := _u.mutation.Model(); ok {
|
if value, ok := _u.mutation.Model(); ok {
|
||||||
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
|
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.UpstreamModel(); ok {
|
||||||
|
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.UpstreamModelCleared() {
|
||||||
|
_spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString)
|
||||||
|
}
|
||||||
if value, ok := _u.mutation.InputTokens(); ok {
|
if value, ok := _u.mutation.InputTokens(); ok {
|
||||||
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
|
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
|
||||||
}
|
}
|
||||||
@@ -1177,6 +1208,26 @@ func (_u *UsageLogUpdateOne) SetNillableModel(v *string) *UsageLogUpdateOne {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetUpstreamModel sets the "upstream_model" field.
|
||||||
|
func (_u *UsageLogUpdateOne) SetUpstreamModel(v string) *UsageLogUpdateOne {
|
||||||
|
_u.mutation.SetUpstreamModel(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil.
|
||||||
|
func (_u *UsageLogUpdateOne) SetNillableUpstreamModel(v *string) *UsageLogUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetUpstreamModel(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearUpstreamModel clears the value of the "upstream_model" field.
|
||||||
|
func (_u *UsageLogUpdateOne) ClearUpstreamModel() *UsageLogUpdateOne {
|
||||||
|
_u.mutation.ClearUpstreamModel()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetGroupID sets the "group_id" field.
|
// SetGroupID sets the "group_id" field.
|
||||||
func (_u *UsageLogUpdateOne) SetGroupID(v int64) *UsageLogUpdateOne {
|
func (_u *UsageLogUpdateOne) SetGroupID(v int64) *UsageLogUpdateOne {
|
||||||
_u.mutation.SetGroupID(v)
|
_u.mutation.SetGroupID(v)
|
||||||
@@ -1833,6 +1884,11 @@ func (_u *UsageLogUpdateOne) check() error {
|
|||||||
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
|
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if v, ok := _u.mutation.UpstreamModel(); ok {
|
||||||
|
if err := usagelog.UpstreamModelValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
if v, ok := _u.mutation.UserAgent(); ok {
|
if v, ok := _u.mutation.UserAgent(); ok {
|
||||||
if err := usagelog.UserAgentValidator(v); err != nil {
|
if err := usagelog.UserAgentValidator(v); err != nil {
|
||||||
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
||||||
@@ -1900,6 +1956,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
|
|||||||
if value, ok := _u.mutation.Model(); ok {
|
if value, ok := _u.mutation.Model(); ok {
|
||||||
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
|
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.UpstreamModel(); ok {
|
||||||
|
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.UpstreamModelCleared() {
|
||||||
|
_spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString)
|
||||||
|
}
|
||||||
if value, ok := _u.mutation.InputTokens(); ok {
|
if value, ok := _u.mutation.InputTokens(); ok {
|
||||||
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
|
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ require (
|
|||||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||||
github.com/DouDOU-start/go-sora2api v1.1.0
|
github.com/DouDOU-start/go-sora2api v1.1.0
|
||||||
github.com/alitto/pond/v2 v2.6.2
|
github.com/alitto/pond/v2 v2.6.2
|
||||||
github.com/aws/aws-sdk-go-v2 v1.41.2
|
github.com/aws/aws-sdk-go-v2 v1.41.3
|
||||||
github.com/aws/aws-sdk-go-v2/config v1.32.10
|
github.com/aws/aws-sdk-go-v2/config v1.32.10
|
||||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
|
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
|
||||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
|
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
|
||||||
@@ -66,7 +66,7 @@ require (
|
|||||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect
|
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect
|
||||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect
|
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect
|
||||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect
|
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect
|
||||||
github.com/aws/smithy-go v1.24.1 // indirect
|
github.com/aws/smithy-go v1.24.2 // indirect
|
||||||
github.com/bdandy/go-errors v1.2.2 // indirect
|
github.com/bdandy/go-errors v1.2.2 // indirect
|
||||||
github.com/bdandy/go-socks4 v1.2.3 // indirect
|
github.com/bdandy/go-socks4 v1.2.3 // indirect
|
||||||
github.com/bmatcuk/doublestar v1.3.4 // indirect
|
github.com/bmatcuk/doublestar v1.3.4 // indirect
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo
|
|||||||
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
||||||
github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=
|
github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=
|
||||||
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
|
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
|
||||||
github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls=
|
github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA=
|
||||||
github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4=
|
github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
|
||||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
|
||||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c=
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c=
|
||||||
github.com/aws/aws-sdk-go-v2/config v1.32.10 h1:9DMthfO6XWZYLfzZglAgW5Fyou2nRI5CuV44sTedKBI=
|
github.com/aws/aws-sdk-go-v2/config v1.32.10 h1:9DMthfO6XWZYLfzZglAgW5Fyou2nRI5CuV44sTedKBI=
|
||||||
@@ -58,8 +58,8 @@ github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 h1:edCcNp9eGIUDUCrzoCu1jWA
|
|||||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15/go.mod h1:lyRQKED9xWfgkYC/wmmYfv7iVIM68Z5OQ88ZdcV1QbU=
|
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15/go.mod h1:lyRQKED9xWfgkYC/wmmYfv7iVIM68Z5OQ88ZdcV1QbU=
|
||||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb83BbyggcUBVksN7c=
|
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb83BbyggcUBVksN7c=
|
||||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
|
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
|
||||||
github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0=
|
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
|
||||||
github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
|
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||||
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
|
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
|
||||||
github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM=
|
github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM=
|
||||||
github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic=
|
github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic=
|
||||||
@@ -94,10 +94,6 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL
|
|||||||
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
||||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
||||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
||||||
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
|
|
||||||
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
|
|
||||||
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
|
|
||||||
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
|
|
||||||
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
|
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
|
||||||
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
||||||
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
|
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
|
||||||
@@ -234,8 +230,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
|
|||||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
|
|
||||||
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
|
||||||
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
||||||
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||||
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
||||||
@@ -269,8 +263,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
|||||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
|
||||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
|
||||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||||
@@ -322,8 +314,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
|||||||
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
||||||
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
||||||
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||||
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
|
|
||||||
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
|
|
||||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||||
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
|
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
|
||||||
|
|||||||
@@ -934,9 +934,10 @@ type DashboardAggregationConfig struct {
|
|||||||
|
|
||||||
// DashboardAggregationRetentionConfig 预聚合保留窗口
|
// DashboardAggregationRetentionConfig 预聚合保留窗口
|
||||||
type DashboardAggregationRetentionConfig struct {
|
type DashboardAggregationRetentionConfig struct {
|
||||||
UsageLogsDays int `mapstructure:"usage_logs_days"`
|
UsageLogsDays int `mapstructure:"usage_logs_days"`
|
||||||
HourlyDays int `mapstructure:"hourly_days"`
|
UsageBillingDedupDays int `mapstructure:"usage_billing_dedup_days"`
|
||||||
DailyDays int `mapstructure:"daily_days"`
|
HourlyDays int `mapstructure:"hourly_days"`
|
||||||
|
DailyDays int `mapstructure:"daily_days"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UsageCleanupConfig 使用记录清理任务配置
|
// UsageCleanupConfig 使用记录清理任务配置
|
||||||
@@ -1301,6 +1302,7 @@ func setDefaults() {
|
|||||||
viper.SetDefault("dashboard_aggregation.backfill_enabled", false)
|
viper.SetDefault("dashboard_aggregation.backfill_enabled", false)
|
||||||
viper.SetDefault("dashboard_aggregation.backfill_max_days", 31)
|
viper.SetDefault("dashboard_aggregation.backfill_max_days", 31)
|
||||||
viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90)
|
viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90)
|
||||||
|
viper.SetDefault("dashboard_aggregation.retention.usage_billing_dedup_days", 365)
|
||||||
viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180)
|
viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180)
|
||||||
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
|
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
|
||||||
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
|
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
|
||||||
@@ -1758,6 +1760,12 @@ func (c *Config) Validate() error {
|
|||||||
if c.DashboardAgg.Retention.UsageLogsDays <= 0 {
|
if c.DashboardAgg.Retention.UsageLogsDays <= 0 {
|
||||||
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive")
|
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive")
|
||||||
}
|
}
|
||||||
|
if c.DashboardAgg.Retention.UsageBillingDedupDays <= 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be positive")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
|
||||||
|
}
|
||||||
if c.DashboardAgg.Retention.HourlyDays <= 0 {
|
if c.DashboardAgg.Retention.HourlyDays <= 0 {
|
||||||
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive")
|
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive")
|
||||||
}
|
}
|
||||||
@@ -1780,6 +1788,14 @@ func (c *Config) Validate() error {
|
|||||||
if c.DashboardAgg.Retention.UsageLogsDays < 0 {
|
if c.DashboardAgg.Retention.UsageLogsDays < 0 {
|
||||||
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative")
|
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative")
|
||||||
}
|
}
|
||||||
|
if c.DashboardAgg.Retention.UsageBillingDedupDays < 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be non-negative")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.Retention.UsageBillingDedupDays > 0 &&
|
||||||
|
c.DashboardAgg.Retention.UsageLogsDays > 0 &&
|
||||||
|
c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
|
||||||
|
}
|
||||||
if c.DashboardAgg.Retention.HourlyDays < 0 {
|
if c.DashboardAgg.Retention.HourlyDays < 0 {
|
||||||
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative")
|
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -441,6 +441,9 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
|
|||||||
if cfg.DashboardAgg.Retention.UsageLogsDays != 90 {
|
if cfg.DashboardAgg.Retention.UsageLogsDays != 90 {
|
||||||
t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays)
|
t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays)
|
||||||
}
|
}
|
||||||
|
if cfg.DashboardAgg.Retention.UsageBillingDedupDays != 365 {
|
||||||
|
t.Fatalf("DashboardAgg.Retention.UsageBillingDedupDays = %d, want 365", cfg.DashboardAgg.Retention.UsageBillingDedupDays)
|
||||||
|
}
|
||||||
if cfg.DashboardAgg.Retention.HourlyDays != 180 {
|
if cfg.DashboardAgg.Retention.HourlyDays != 180 {
|
||||||
t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays)
|
t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays)
|
||||||
}
|
}
|
||||||
@@ -1016,6 +1019,23 @@ func TestValidateConfigErrors(t *testing.T) {
|
|||||||
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 },
|
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 },
|
||||||
wantErr: "dashboard_aggregation.retention.usage_logs_days",
|
wantErr: "dashboard_aggregation.retention.usage_logs_days",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "dashboard aggregation dedup retention",
|
||||||
|
mutate: func(c *Config) {
|
||||||
|
c.DashboardAgg.Enabled = true
|
||||||
|
c.DashboardAgg.Retention.UsageBillingDedupDays = 0
|
||||||
|
},
|
||||||
|
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dashboard aggregation dedup retention smaller than usage logs",
|
||||||
|
mutate: func(c *Config) {
|
||||||
|
c.DashboardAgg.Enabled = true
|
||||||
|
c.DashboardAgg.Retention.UsageLogsDays = 30
|
||||||
|
c.DashboardAgg.Retention.UsageBillingDedupDays = 29
|
||||||
|
},
|
||||||
|
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "dashboard aggregation disabled interval",
|
name: "dashboard aggregation disabled interval",
|
||||||
mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 },
|
mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 },
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ const (
|
|||||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||||
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||||
|
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
|
||||||
)
|
)
|
||||||
|
|
||||||
// Redeem type constants
|
// Redeem type constants
|
||||||
@@ -81,13 +82,15 @@ var DefaultAntigravityModelMapping = map[string]string{
|
|||||||
"claude-opus-4-5-20251101": "claude-opus-4-6-thinking", // 迁移旧模型
|
"claude-opus-4-5-20251101": "claude-opus-4-6-thinking", // 迁移旧模型
|
||||||
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
|
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
|
||||||
// Claude Haiku → Sonnet(无 Haiku 支持)
|
// Claude Haiku → Sonnet(无 Haiku 支持)
|
||||||
"claude-haiku-4-5": "claude-sonnet-4-5",
|
"claude-haiku-4-5": "claude-sonnet-4-6",
|
||||||
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
|
"claude-haiku-4-5-20251001": "claude-sonnet-4-6",
|
||||||
// Gemini 2.5 白名单
|
// Gemini 2.5 白名单
|
||||||
"gemini-2.5-flash": "gemini-2.5-flash",
|
"gemini-2.5-flash": "gemini-2.5-flash",
|
||||||
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
|
"gemini-2.5-flash-image": "gemini-2.5-flash-image",
|
||||||
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
|
"gemini-2.5-flash-image-preview": "gemini-2.5-flash-image",
|
||||||
"gemini-2.5-pro": "gemini-2.5-pro",
|
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
|
||||||
|
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
|
||||||
|
"gemini-2.5-pro": "gemini-2.5-pro",
|
||||||
// Gemini 3 白名单
|
// Gemini 3 白名单
|
||||||
"gemini-3-flash": "gemini-3-flash",
|
"gemini-3-flash": "gemini-3-flash",
|
||||||
"gemini-3-pro-high": "gemini-3-pro-high",
|
"gemini-3-pro-high": "gemini-3-pro-high",
|
||||||
@@ -111,3 +114,27 @@ var DefaultAntigravityModelMapping = map[string]string{
|
|||||||
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
|
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
|
||||||
"tab_flash_lite_preview": "tab_flash_lite_preview",
|
"tab_flash_lite_preview": "tab_flash_lite_preview",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DefaultBedrockModelMapping 是 AWS Bedrock 平台的默认模型映射
|
||||||
|
// 将 Anthropic 标准模型名映射到 Bedrock 模型 ID
|
||||||
|
// 注意:此处的 "us." 前缀仅为默认值,ResolveBedrockModelID 会根据账号配置的
|
||||||
|
// aws_region 自动调整为匹配的区域前缀(如 eu.、apac.、jp. 等)
|
||||||
|
var DefaultBedrockModelMapping = map[string]string{
|
||||||
|
// Claude Opus
|
||||||
|
"claude-opus-4-6-thinking": "us.anthropic.claude-opus-4-6-v1",
|
||||||
|
"claude-opus-4-6": "us.anthropic.claude-opus-4-6-v1",
|
||||||
|
"claude-opus-4-5-thinking": "us.anthropic.claude-opus-4-5-20251101-v1:0",
|
||||||
|
"claude-opus-4-5-20251101": "us.anthropic.claude-opus-4-5-20251101-v1:0",
|
||||||
|
"claude-opus-4-1": "us.anthropic.claude-opus-4-1-20250805-v1:0",
|
||||||
|
"claude-opus-4-20250514": "us.anthropic.claude-opus-4-20250514-v1:0",
|
||||||
|
// Claude Sonnet
|
||||||
|
"claude-sonnet-4-6-thinking": "us.anthropic.claude-sonnet-4-6",
|
||||||
|
"claude-sonnet-4-6": "us.anthropic.claude-sonnet-4-6",
|
||||||
|
"claude-sonnet-4-5": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||||
|
"claude-sonnet-4-5-thinking": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||||
|
"claude-sonnet-4-5-20250929": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||||
|
"claude-sonnet-4-20250514": "us.anthropic.claude-sonnet-4-20250514-v1:0",
|
||||||
|
// Claude Haiku
|
||||||
|
"claude-haiku-4-5": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||||
|
"claude-haiku-4-5-20251001": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T)
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
cases := map[string]string{
|
cases := map[string]string{
|
||||||
|
"gemini-2.5-flash-image": "gemini-2.5-flash-image",
|
||||||
|
"gemini-2.5-flash-image-preview": "gemini-2.5-flash-image",
|
||||||
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
|
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
|
||||||
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
|
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
|
||||||
"gemini-3-pro-image": "gemini-3.1-flash-image",
|
"gemini-3-pro-image": "gemini-3.1-flash-image",
|
||||||
|
|||||||
@@ -8,6 +8,9 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"log/slog"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -292,6 +295,8 @@ func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enrichCredentialsFromIDToken(&item)
|
||||||
|
|
||||||
accountInput := &service.CreateAccountInput{
|
accountInput := &service.CreateAccountInput{
|
||||||
Name: item.Name,
|
Name: item.Name,
|
||||||
Notes: item.Notes,
|
Notes: item.Notes,
|
||||||
@@ -535,6 +540,57 @@ func defaultProxyName(name string) string {
|
|||||||
return name
|
return name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// enrichCredentialsFromIDToken performs best-effort extraction of user info fields
|
||||||
|
// (email, plan_type, chatgpt_account_id, etc.) from id_token in credentials.
|
||||||
|
// Only applies to OpenAI/Sora OAuth accounts. Skips expired token errors silently.
|
||||||
|
// Existing credential values are never overwritten — only missing fields are filled.
|
||||||
|
func enrichCredentialsFromIDToken(item *DataAccount) {
|
||||||
|
if item.Credentials == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Only enrich OpenAI/Sora OAuth accounts
|
||||||
|
platform := strings.ToLower(strings.TrimSpace(item.Platform))
|
||||||
|
if platform != service.PlatformOpenAI && platform != service.PlatformSora {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.ToLower(strings.TrimSpace(item.Type)) != service.AccountTypeOAuth {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
idToken, _ := item.Credentials["id_token"].(string)
|
||||||
|
if strings.TrimSpace(idToken) == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeIDToken skips expiry validation — safe for imported data
|
||||||
|
claims, err := openai.DecodeIDToken(idToken)
|
||||||
|
if err != nil {
|
||||||
|
slog.Debug("import_enrich_id_token_decode_failed", "account", item.Name, "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userInfo := claims.GetUserInfo()
|
||||||
|
if userInfo == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fill missing fields only (never overwrite existing values)
|
||||||
|
setIfMissing := func(key, value string) {
|
||||||
|
if value == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if existing, _ := item.Credentials[key].(string); existing == "" {
|
||||||
|
item.Credentials[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
setIfMissing("email", userInfo.Email)
|
||||||
|
setIfMissing("plan_type", userInfo.PlanType)
|
||||||
|
setIfMissing("chatgpt_account_id", userInfo.ChatGPTAccountID)
|
||||||
|
setIfMissing("chatgpt_user_id", userInfo.ChatGPTUserID)
|
||||||
|
setIfMissing("organization_id", userInfo.OrganizationID)
|
||||||
|
}
|
||||||
|
|
||||||
func normalizeProxyStatus(status string) string {
|
func normalizeProxyStatus(status string) string {
|
||||||
normalized := strings.TrimSpace(strings.ToLower(status))
|
normalized := strings.TrimSpace(strings.ToLower(status))
|
||||||
switch normalized {
|
switch normalized {
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -18,6 +19,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
@@ -95,7 +97,7 @@ type CreateAccountRequest struct {
|
|||||||
Name string `json:"name" binding:"required"`
|
Name string `json:"name" binding:"required"`
|
||||||
Notes *string `json:"notes"`
|
Notes *string `json:"notes"`
|
||||||
Platform string `json:"platform" binding:"required"`
|
Platform string `json:"platform" binding:"required"`
|
||||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream"`
|
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock"`
|
||||||
Credentials map[string]any `json:"credentials" binding:"required"`
|
Credentials map[string]any `json:"credentials" binding:"required"`
|
||||||
Extra map[string]any `json:"extra"`
|
Extra map[string]any `json:"extra"`
|
||||||
ProxyID *int64 `json:"proxy_id"`
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
@@ -114,7 +116,7 @@ type CreateAccountRequest struct {
|
|||||||
type UpdateAccountRequest struct {
|
type UpdateAccountRequest struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Notes *string `json:"notes"`
|
Notes *string `json:"notes"`
|
||||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream"`
|
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock"`
|
||||||
Credentials map[string]any `json:"credentials"`
|
Credentials map[string]any `json:"credentials"`
|
||||||
Extra map[string]any `json:"extra"`
|
Extra map[string]any `json:"extra"`
|
||||||
ProxyID *int64 `json:"proxy_id"`
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
@@ -163,6 +165,8 @@ type AccountWithConcurrency struct {
|
|||||||
CurrentRPM *int `json:"current_rpm,omitempty"` // 当前分钟 RPM 计数
|
CurrentRPM *int `json:"current_rpm,omitempty"` // 当前分钟 RPM 计数
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const accountListGroupUngroupedQueryValue = "ungrouped"
|
||||||
|
|
||||||
func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency {
|
func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency {
|
||||||
item := AccountWithConcurrency{
|
item := AccountWithConcurrency{
|
||||||
Account: dto.AccountFromService(account),
|
Account: dto.AccountFromService(account),
|
||||||
@@ -224,7 +228,20 @@ func (h *AccountHandler) List(c *gin.Context) {
|
|||||||
|
|
||||||
var groupID int64
|
var groupID int64
|
||||||
if groupIDStr := c.Query("group"); groupIDStr != "" {
|
if groupIDStr := c.Query("group"); groupIDStr != "" {
|
||||||
groupID, _ = strconv.ParseInt(groupIDStr, 10, 64)
|
if groupIDStr == accountListGroupUngroupedQueryValue {
|
||||||
|
groupID = service.AccountListGroupUngrouped
|
||||||
|
} else {
|
||||||
|
parsedGroupID, parseErr := strconv.ParseInt(groupIDStr, 10, 64)
|
||||||
|
if parseErr != nil {
|
||||||
|
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_GROUP_FILTER", "invalid group filter"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if parsedGroupID < 0 {
|
||||||
|
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_GROUP_FILTER", "invalid group filter"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
groupID = parsedGroupID
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID)
|
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID)
|
||||||
@@ -626,6 +643,7 @@ func (h *AccountHandler) Delete(c *gin.Context) {
|
|||||||
// TestAccountRequest represents the request body for testing an account
|
// TestAccountRequest represents the request body for testing an account
|
||||||
type TestAccountRequest struct {
|
type TestAccountRequest struct {
|
||||||
ModelID string `json:"model_id"`
|
ModelID string `json:"model_id"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type SyncFromCRSRequest struct {
|
type SyncFromCRSRequest struct {
|
||||||
@@ -656,7 +674,7 @@ func (h *AccountHandler) Test(c *gin.Context) {
|
|||||||
_ = c.ShouldBindJSON(&req)
|
_ = c.ShouldBindJSON(&req)
|
||||||
|
|
||||||
// Use AccountTestService to test the account with SSE streaming
|
// Use AccountTestService to test the account with SSE streaming
|
||||||
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID); err != nil {
|
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt); err != nil {
|
||||||
// Error already sent via SSE, just log
|
// Error already sent via SSE, just log
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -751,52 +769,31 @@ func (h *AccountHandler) PreviewFromCRS(c *gin.Context) {
|
|||||||
response.Success(c, result)
|
response.Success(c, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Refresh handles refreshing account credentials
|
// refreshSingleAccount refreshes credentials for a single OAuth account.
|
||||||
// POST /api/v1/admin/accounts/:id/refresh
|
// Returns (updatedAccount, warning, error) where warning is used for Antigravity ProjectIDMissing scenario.
|
||||||
func (h *AccountHandler) Refresh(c *gin.Context) {
|
func (h *AccountHandler) refreshSingleAccount(ctx context.Context, account *service.Account) (*service.Account, string, error) {
|
||||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
response.BadRequest(c, "Invalid account ID")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get account
|
|
||||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
|
||||||
if err != nil {
|
|
||||||
response.NotFound(c, "Account not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only refresh OAuth-based accounts (oauth and setup-token)
|
|
||||||
if !account.IsOAuth() {
|
if !account.IsOAuth() {
|
||||||
response.BadRequest(c, "Cannot refresh non-OAuth account credentials")
|
return nil, "", infraerrors.BadRequest("NOT_OAUTH", "cannot refresh non-OAuth account")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var newCredentials map[string]any
|
var newCredentials map[string]any
|
||||||
|
|
||||||
if account.IsOpenAI() {
|
if account.IsOpenAI() {
|
||||||
// Use OpenAI OAuth service to refresh token
|
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(ctx, account)
|
||||||
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
return nil, "", err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build new credentials from token info
|
|
||||||
newCredentials = h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
newCredentials = h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||||
|
|
||||||
// Preserve non-token settings from existing credentials
|
|
||||||
for k, v := range account.Credentials {
|
for k, v := range account.Credentials {
|
||||||
if _, exists := newCredentials[k]; !exists {
|
if _, exists := newCredentials[k]; !exists {
|
||||||
newCredentials[k] = v
|
newCredentials[k] = v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if account.Platform == service.PlatformGemini {
|
} else if account.Platform == service.PlatformGemini {
|
||||||
tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(ctx, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
|
return nil, "", fmt.Errorf("failed to refresh credentials: %w", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
newCredentials = h.geminiOAuthService.BuildAccountCredentials(tokenInfo)
|
newCredentials = h.geminiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||||
@@ -806,10 +803,9 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if account.Platform == service.PlatformAntigravity {
|
} else if account.Platform == service.PlatformAntigravity {
|
||||||
tokenInfo, err := h.antigravityOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
tokenInfo, err := h.antigravityOAuthService.RefreshAccountToken(ctx, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
return nil, "", err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
newCredentials = h.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
newCredentials = h.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||||
@@ -828,37 +824,27 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 如果 project_id 获取失败,更新凭证但不标记为 error
|
// 如果 project_id 获取失败,更新凭证但不标记为 error
|
||||||
// LoadCodeAssist 失败可能是临时网络问题,给它机会在下次自动刷新时重试
|
|
||||||
if tokenInfo.ProjectIDMissing {
|
if tokenInfo.ProjectIDMissing {
|
||||||
// 先更新凭证(token 本身刷新成功了)
|
updatedAccount, updateErr := h.adminService.UpdateAccount(ctx, account.ID, &service.UpdateAccountInput{
|
||||||
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
|
||||||
Credentials: newCredentials,
|
Credentials: newCredentials,
|
||||||
})
|
})
|
||||||
if updateErr != nil {
|
if updateErr != nil {
|
||||||
response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
|
return nil, "", fmt.Errorf("failed to update credentials: %w", updateErr)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
// 不标记为 error,只返回警告信息
|
return updatedAccount, "missing_project_id_temporary", nil
|
||||||
response.Success(c, gin.H{
|
|
||||||
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
|
|
||||||
"warning": "missing_project_id_temporary",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 成功获取到 project_id,如果之前是 missing_project_id 错误则清除
|
// 成功获取到 project_id,如果之前是 missing_project_id 错误则清除
|
||||||
if account.Status == service.StatusError && strings.Contains(account.ErrorMessage, "missing_project_id:") {
|
if account.Status == service.StatusError && strings.Contains(account.ErrorMessage, "missing_project_id:") {
|
||||||
if _, clearErr := h.adminService.ClearAccountError(c.Request.Context(), accountID); clearErr != nil {
|
if _, clearErr := h.adminService.ClearAccountError(ctx, account.ID); clearErr != nil {
|
||||||
response.InternalError(c, "Failed to clear account error: "+clearErr.Error())
|
return nil, "", fmt.Errorf("failed to clear account error: %w", clearErr)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Use Anthropic/Claude OAuth service to refresh token
|
// Use Anthropic/Claude OAuth service to refresh token
|
||||||
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
|
tokenInfo, err := h.oauthService.RefreshAccountToken(ctx, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
return nil, "", err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
|
// Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
|
||||||
@@ -880,20 +866,54 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
updatedAccount, err := h.adminService.UpdateAccount(ctx, account.ID, &service.UpdateAccountInput{
|
||||||
Credentials: newCredentials,
|
Credentials: newCredentials,
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 刷新成功后,清除 token 缓存,确保下次请求使用新 token
|
||||||
|
if h.tokenCacheInvalidator != nil {
|
||||||
|
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(ctx, updatedAccount); invalidateErr != nil {
|
||||||
|
log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", updatedAccount.ID, invalidateErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenAI OAuth: 刷新成功后检查并设置 privacy_mode
|
||||||
|
h.adminService.EnsureOpenAIPrivacy(ctx, updatedAccount)
|
||||||
|
|
||||||
|
return updatedAccount, "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh handles refreshing account credentials
|
||||||
|
// POST /api/v1/admin/accounts/:id/refresh
|
||||||
|
func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||||
|
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid account ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get account
|
||||||
|
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||||
|
if err != nil {
|
||||||
|
response.NotFound(c, "Account not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedAccount, warning, err := h.refreshSingleAccount(c.Request.Context(), account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 刷新成功后,清除 token 缓存,确保下次请求使用新 token
|
if warning == "missing_project_id_temporary" {
|
||||||
if h.tokenCacheInvalidator != nil {
|
response.Success(c, gin.H{
|
||||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), updatedAccount); invalidateErr != nil {
|
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
|
||||||
// 缓存失效失败只记录日志,不影响主流程
|
"warning": "missing_project_id_temporary",
|
||||||
_ = c.Error(invalidateErr)
|
})
|
||||||
}
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updatedAccount))
|
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updatedAccount))
|
||||||
@@ -949,14 +969,175 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
|
|||||||
// 这解决了管理员重置账号状态后,旧的失效 token 仍在缓存中导致立即再次 401 的问题
|
// 这解决了管理员重置账号状态后,旧的失效 token 仍在缓存中导致立即再次 401 的问题
|
||||||
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
|
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
|
||||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), account); invalidateErr != nil {
|
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), account); invalidateErr != nil {
|
||||||
// 缓存失效失败只记录日志,不影响主流程
|
log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", accountID, invalidateErr)
|
||||||
_ = c.Error(invalidateErr)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BatchClearError handles batch clearing account errors
|
||||||
|
// POST /api/v1/admin/accounts/batch-clear-error
|
||||||
|
func (h *AccountHandler) BatchClearError(c *gin.Context) {
|
||||||
|
var req struct {
|
||||||
|
AccountIDs []int64 `json:"account_ids"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(req.AccountIDs) == 0 {
|
||||||
|
response.BadRequest(c, "account_ids is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
|
const maxConcurrency = 10
|
||||||
|
g, gctx := errgroup.WithContext(ctx)
|
||||||
|
g.SetLimit(maxConcurrency)
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
var successCount, failedCount int
|
||||||
|
var errors []gin.H
|
||||||
|
|
||||||
|
// 注意:所有 goroutine 必须 return nil,避免 errgroup cancel 其他并发任务
|
||||||
|
for _, id := range req.AccountIDs {
|
||||||
|
accountID := id // 闭包捕获
|
||||||
|
g.Go(func() error {
|
||||||
|
account, err := h.adminService.ClearAccountError(gctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
mu.Lock()
|
||||||
|
failedCount++
|
||||||
|
errors = append(errors, gin.H{
|
||||||
|
"account_id": accountID,
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清除错误后,同时清除 token 缓存
|
||||||
|
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
|
||||||
|
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(gctx, account); invalidateErr != nil {
|
||||||
|
log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", accountID, invalidateErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
successCount++
|
||||||
|
mu.Unlock()
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := g.Wait(); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"total": len(req.AccountIDs),
|
||||||
|
"success": successCount,
|
||||||
|
"failed": failedCount,
|
||||||
|
"errors": errors,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchRefresh handles batch refreshing account credentials
|
||||||
|
// POST /api/v1/admin/accounts/batch-refresh
|
||||||
|
func (h *AccountHandler) BatchRefresh(c *gin.Context) {
|
||||||
|
var req struct {
|
||||||
|
AccountIDs []int64 `json:"account_ids"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(req.AccountIDs) == 0 {
|
||||||
|
response.BadRequest(c, "account_ids is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
|
accounts, err := h.adminService.GetAccountsByIDs(ctx, req.AccountIDs)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 建立已获取账号的 ID 集合,检测缺失的 ID
|
||||||
|
foundIDs := make(map[int64]bool, len(accounts))
|
||||||
|
for _, acc := range accounts {
|
||||||
|
if acc != nil {
|
||||||
|
foundIDs[acc.ID] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const maxConcurrency = 10
|
||||||
|
g, gctx := errgroup.WithContext(ctx)
|
||||||
|
g.SetLimit(maxConcurrency)
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
var successCount, failedCount int
|
||||||
|
var errors []gin.H
|
||||||
|
var warnings []gin.H
|
||||||
|
|
||||||
|
// 将不存在的账号 ID 标记为失败
|
||||||
|
for _, id := range req.AccountIDs {
|
||||||
|
if !foundIDs[id] {
|
||||||
|
failedCount++
|
||||||
|
errors = append(errors, gin.H{
|
||||||
|
"account_id": id,
|
||||||
|
"error": "account not found",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 注意:所有 goroutine 必须 return nil,避免 errgroup cancel 其他并发任务
|
||||||
|
for _, account := range accounts {
|
||||||
|
acc := account // 闭包捕获
|
||||||
|
if acc == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
g.Go(func() error {
|
||||||
|
_, warning, err := h.refreshSingleAccount(gctx, acc)
|
||||||
|
mu.Lock()
|
||||||
|
if err != nil {
|
||||||
|
failedCount++
|
||||||
|
errors = append(errors, gin.H{
|
||||||
|
"account_id": acc.ID,
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
successCount++
|
||||||
|
if warning != "" {
|
||||||
|
warnings = append(warnings, gin.H{
|
||||||
|
"account_id": acc.ID,
|
||||||
|
"warning": warning,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mu.Unlock()
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := g.Wait(); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"total": len(req.AccountIDs),
|
||||||
|
"success": successCount,
|
||||||
|
"failed": failedCount,
|
||||||
|
"errors": errors,
|
||||||
|
"warnings": warnings,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// BatchCreate handles batch creating accounts
|
// BatchCreate handles batch creating accounts
|
||||||
// POST /api/v1/admin/accounts/batch
|
// POST /api/v1/admin/accounts/batch
|
||||||
func (h *AccountHandler) BatchCreate(c *gin.Context) {
|
func (h *AccountHandler) BatchCreate(c *gin.Context) {
|
||||||
@@ -1330,7 +1511,7 @@ func (h *OAuthHandler) SetupTokenCookieAuth(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetUsage handles getting account usage information
|
// GetUsage handles getting account usage information
|
||||||
// GET /api/v1/admin/accounts/:id/usage
|
// GET /api/v1/admin/accounts/:id/usage?source=passive|active
|
||||||
func (h *AccountHandler) GetUsage(c *gin.Context) {
|
func (h *AccountHandler) GetUsage(c *gin.Context) {
|
||||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1338,7 +1519,14 @@ func (h *AccountHandler) GetUsage(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
usage, err := h.accountUsageService.GetUsage(c.Request.Context(), accountID)
|
source := c.DefaultQuery("source", "active")
|
||||||
|
|
||||||
|
var usage *service.UsageInfo
|
||||||
|
if source == "passive" {
|
||||||
|
usage, err = h.accountUsageService.GetPassiveUsage(c.Request.Context(), accountID)
|
||||||
|
} else {
|
||||||
|
usage, err = h.accountUsageService.GetUsage(c.Request.Context(), accountID)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
@@ -1552,13 +1740,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
|||||||
|
|
||||||
// Handle OpenAI accounts
|
// Handle OpenAI accounts
|
||||||
if account.IsOpenAI() {
|
if account.IsOpenAI() {
|
||||||
// For OAuth accounts: return default OpenAI models
|
// OpenAI 自动透传会绕过常规模型改写,测试/模型列表也应回落到默认模型集。
|
||||||
if account.IsOAuth() {
|
if account.IsOpenAIPassthroughEnabled() {
|
||||||
response.Success(c, openai.DefaultModels)
|
response.Success(c, openai.DefaultModels)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// For API Key accounts: check model_mapping
|
|
||||||
mapping := account.GetModelMapping()
|
mapping := account.GetModelMapping()
|
||||||
if len(mapping) == 0 {
|
if len(mapping) == 0 {
|
||||||
response.Success(c, openai.DefaultModels)
|
response.Success(c, openai.DefaultModels)
|
||||||
|
|||||||
@@ -0,0 +1,105 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type availableModelsAdminService struct {
|
||||||
|
*stubAdminService
|
||||||
|
account service.Account
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *availableModelsAdminService) GetAccount(_ context.Context, id int64) (*service.Account, error) {
|
||||||
|
if s.account.ID == id {
|
||||||
|
acc := s.account
|
||||||
|
return &acc, nil
|
||||||
|
}
|
||||||
|
return s.stubAdminService.GetAccount(context.Background(), id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupAvailableModelsRouter(adminSvc service.AdminService) *gin.Engine {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
router := gin.New()
|
||||||
|
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
|
router.GET("/api/v1/admin/accounts/:id/models", handler.GetAvailableModels)
|
||||||
|
return router
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountHandlerGetAvailableModels_OpenAIOAuthUsesExplicitModelMapping(t *testing.T) {
|
||||||
|
svc := &availableModelsAdminService{
|
||||||
|
stubAdminService: newStubAdminService(),
|
||||||
|
account: service.Account{
|
||||||
|
ID: 42,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: service.PlatformOpenAI,
|
||||||
|
Type: service.AccountTypeOAuth,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"gpt-5": "gpt-5.1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
router := setupAvailableModelsRouter(svc)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/42/models", nil)
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
|
||||||
|
var resp struct {
|
||||||
|
Data []struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||||
|
require.Len(t, resp.Data, 1)
|
||||||
|
require.Equal(t, "gpt-5", resp.Data[0].ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefaults(t *testing.T) {
|
||||||
|
svc := &availableModelsAdminService{
|
||||||
|
stubAdminService: newStubAdminService(),
|
||||||
|
account: service.Account{
|
||||||
|
ID: 43,
|
||||||
|
Name: "openai-oauth-passthrough",
|
||||||
|
Platform: service.PlatformOpenAI,
|
||||||
|
Type: service.AccountTypeOAuth,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"gpt-5": "gpt-5.1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Extra: map[string]any{
|
||||||
|
"openai_passthrough": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
router := setupAvailableModelsRouter(svc)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/43/models", nil)
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
|
||||||
|
var resp struct {
|
||||||
|
Data []struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||||
|
require.NotEmpty(t, resp.Data)
|
||||||
|
require.NotEqual(t, "gpt-5", resp.Data[0].ID)
|
||||||
|
}
|
||||||
@@ -17,7 +17,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
|
|||||||
adminSvc := newStubAdminService()
|
adminSvc := newStubAdminService()
|
||||||
|
|
||||||
userHandler := NewUserHandler(adminSvc, nil)
|
userHandler := NewUserHandler(adminSvc, nil)
|
||||||
groupHandler := NewGroupHandler(adminSvc)
|
groupHandler := NewGroupHandler(adminSvc, nil, nil)
|
||||||
proxyHandler := NewProxyHandler(adminSvc)
|
proxyHandler := NewProxyHandler(adminSvc)
|
||||||
redeemHandler := NewRedeemHandler(adminSvc, nil)
|
redeemHandler := NewRedeemHandler(adminSvc, nil)
|
||||||
|
|
||||||
|
|||||||
@@ -175,6 +175,18 @@ func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, p
|
|||||||
return s.apiKeys, int64(len(s.apiKeys)), nil
|
return s.apiKeys, int64(len(s.apiKeys)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stubAdminService) GetGroupRateMultipliers(_ context.Context, _ int64) ([]service.UserGroupRateEntry, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAdminService) ClearGroupRateMultipliers(_ context.Context, _ int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int64, _ []service.GroupRateMultiplierInput) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
|
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
|
||||||
return s.accounts, int64(len(s.accounts)), nil
|
return s.accounts, int64(len(s.accounts)), nil
|
||||||
}
|
}
|
||||||
@@ -429,5 +441,9 @@ func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) erro
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stubAdminService) EnsureOpenAIPrivacy(ctx context.Context, account *service.Account) string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// Ensure stub implements interface.
|
// Ensure stub implements interface.
|
||||||
var _ service.AdminService = (*stubAdminService)(nil)
|
var _ service.AdminService = (*stubAdminService)(nil)
|
||||||
|
|||||||
205
backend/internal/handler/admin/backup_handler.go
Normal file
205
backend/internal/handler/admin/backup_handler.go
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BackupHandler struct {
|
||||||
|
backupService *service.BackupService
|
||||||
|
userService *service.UserService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBackupHandler(backupService *service.BackupService, userService *service.UserService) *BackupHandler {
|
||||||
|
return &BackupHandler{
|
||||||
|
backupService: backupService,
|
||||||
|
userService: userService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── S3 配置 ───
|
||||||
|
|
||||||
|
func (h *BackupHandler) GetS3Config(c *gin.Context) {
|
||||||
|
cfg, err := h.backupService.GetS3Config(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BackupHandler) UpdateS3Config(c *gin.Context) {
|
||||||
|
var req service.BackupS3Config
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg, err := h.backupService.UpdateS3Config(c.Request.Context(), req)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BackupHandler) TestS3Connection(c *gin.Context) {
|
||||||
|
var req service.BackupS3Config
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err := h.backupService.TestS3Connection(c.Request.Context(), req)
|
||||||
|
if err != nil {
|
||||||
|
response.Success(c, gin.H{"ok": false, "message": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{"ok": true, "message": "connection successful"})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── 定时备份 ───
|
||||||
|
|
||||||
|
func (h *BackupHandler) GetSchedule(c *gin.Context) {
|
||||||
|
cfg, err := h.backupService.GetSchedule(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BackupHandler) UpdateSchedule(c *gin.Context) {
|
||||||
|
var req service.BackupScheduleConfig
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg, err := h.backupService.UpdateSchedule(c.Request.Context(), req)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── 备份操作 ───
|
||||||
|
|
||||||
|
type CreateBackupRequest struct {
|
||||||
|
ExpireDays *int `json:"expire_days"` // nil=使用默认值14,0=永不过期
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BackupHandler) CreateBackup(c *gin.Context) {
|
||||||
|
var req CreateBackupRequest
|
||||||
|
_ = c.ShouldBindJSON(&req) // 允许空 body
|
||||||
|
|
||||||
|
expireDays := 14 // 默认14天过期
|
||||||
|
if req.ExpireDays != nil {
|
||||||
|
expireDays = *req.ExpireDays
|
||||||
|
}
|
||||||
|
|
||||||
|
record, err := h.backupService.StartBackup(c.Request.Context(), "manual", expireDays)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Accepted(c, record)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BackupHandler) ListBackups(c *gin.Context) {
|
||||||
|
records, err := h.backupService.ListBackups(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if records == nil {
|
||||||
|
records = []service.BackupRecord{}
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{"items": records})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BackupHandler) GetBackup(c *gin.Context) {
|
||||||
|
backupID := c.Param("id")
|
||||||
|
if backupID == "" {
|
||||||
|
response.BadRequest(c, "backup ID is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
record, err := h.backupService.GetBackupRecord(c.Request.Context(), backupID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, record)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BackupHandler) DeleteBackup(c *gin.Context) {
|
||||||
|
backupID := c.Param("id")
|
||||||
|
if backupID == "" {
|
||||||
|
response.BadRequest(c, "backup ID is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.backupService.DeleteBackup(c.Request.Context(), backupID); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{"deleted": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BackupHandler) GetDownloadURL(c *gin.Context) {
|
||||||
|
backupID := c.Param("id")
|
||||||
|
if backupID == "" {
|
||||||
|
response.BadRequest(c, "backup ID is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
url, err := h.backupService.GetBackupDownloadURL(c.Request.Context(), backupID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{"url": url})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── 恢复操作(需要重新输入管理员密码) ───
|
||||||
|
|
||||||
|
type RestoreBackupRequest struct {
|
||||||
|
Password string `json:"password" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BackupHandler) RestoreBackup(c *gin.Context) {
|
||||||
|
backupID := c.Param("id")
|
||||||
|
if backupID == "" {
|
||||||
|
response.BadRequest(c, "backup ID is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req RestoreBackupRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "password is required for restore operation")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从上下文获取当前管理员用户 ID
|
||||||
|
sub, ok := middleware.GetAuthSubjectFromContext(c)
|
||||||
|
if !ok {
|
||||||
|
response.Unauthorized(c, "unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取管理员用户并验证密码
|
||||||
|
user, err := h.userService.GetByID(c.Request.Context(), sub.UserID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !user.CheckPassword(req.Password) {
|
||||||
|
response.BadRequest(c, "incorrect admin password")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
record, err := h.backupService.StartRestore(c.Request.Context(), backupID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Accepted(c, record)
|
||||||
|
}
|
||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -249,11 +250,12 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
trend, hit, err := h.getUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get usage trend")
|
response.Error(c, 500, "Failed to get usage trend")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||||
|
|
||||||
response.Success(c, gin.H{
|
response.Success(c, gin.H{
|
||||||
"trend": trend,
|
"trend": trend,
|
||||||
@@ -271,6 +273,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
|||||||
|
|
||||||
// Parse optional filter params
|
// Parse optional filter params
|
||||||
var userID, apiKeyID, accountID, groupID int64
|
var userID, apiKeyID, accountID, groupID int64
|
||||||
|
modelSource := usagestats.ModelSourceRequested
|
||||||
var requestType *int16
|
var requestType *int16
|
||||||
var stream *bool
|
var stream *bool
|
||||||
var billingType *int8
|
var billingType *int8
|
||||||
@@ -295,6 +298,13 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
|||||||
groupID = id
|
groupID = id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if rawModelSource := strings.TrimSpace(c.Query("model_source")); rawModelSource != "" {
|
||||||
|
if !usagestats.IsValidModelSource(rawModelSource) {
|
||||||
|
response.BadRequest(c, "Invalid model_source, use requested/upstream/mapping")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
modelSource = rawModelSource
|
||||||
|
}
|
||||||
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
||||||
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -321,11 +331,12 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, modelSource, requestType, stream, billingType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get model statistics")
|
response.Error(c, 500, "Failed to get model statistics")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||||
|
|
||||||
response.Success(c, gin.H{
|
response.Success(c, gin.H{
|
||||||
"models": stats,
|
"models": stats,
|
||||||
@@ -391,11 +402,12 @@ func (h *DashboardHandler) GetGroupStats(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
stats, err := h.dashboardService.GetGroupStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
stats, hit, err := h.getGroupStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get group statistics")
|
response.Error(c, 500, "Failed to get group statistics")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||||
|
|
||||||
response.Success(c, gin.H{
|
response.Success(c, gin.H{
|
||||||
"groups": stats,
|
"groups": stats,
|
||||||
@@ -416,11 +428,12 @@ func (h *DashboardHandler) GetAPIKeyUsageTrend(c *gin.Context) {
|
|||||||
limit = 5
|
limit = 5
|
||||||
}
|
}
|
||||||
|
|
||||||
trend, err := h.dashboardService.GetAPIKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
trend, hit, err := h.getAPIKeyUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get API key usage trend")
|
response.Error(c, 500, "Failed to get API key usage trend")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||||
|
|
||||||
response.Success(c, gin.H{
|
response.Success(c, gin.H{
|
||||||
"trend": trend,
|
"trend": trend,
|
||||||
@@ -442,11 +455,12 @@ func (h *DashboardHandler) GetUserUsageTrend(c *gin.Context) {
|
|||||||
limit = 12
|
limit = 12
|
||||||
}
|
}
|
||||||
|
|
||||||
trend, err := h.dashboardService.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
trend, hit, err := h.getUserUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get user usage trend")
|
response.Error(c, 500, "Failed to get user usage trend")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||||
|
|
||||||
response.Success(c, gin.H{
|
response.Success(c, gin.H{
|
||||||
"trend": trend,
|
"trend": trend,
|
||||||
@@ -461,9 +475,62 @@ type BatchUsersUsageRequest struct {
|
|||||||
UserIDs []int64 `json:"user_ids" binding:"required"`
|
UserIDs []int64 `json:"user_ids" binding:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
|
||||||
var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second)
|
var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second)
|
||||||
var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second)
|
var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second)
|
||||||
|
|
||||||
|
func parseRankingLimit(raw string) int {
|
||||||
|
limit, err := strconv.Atoi(strings.TrimSpace(raw))
|
||||||
|
if err != nil || limit <= 0 {
|
||||||
|
return 12
|
||||||
|
}
|
||||||
|
if limit > 50 {
|
||||||
|
return 50
|
||||||
|
}
|
||||||
|
return limit
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserSpendingRanking handles getting user spending ranking data.
|
||||||
|
// GET /api/v1/admin/dashboard/users-ranking
|
||||||
|
func (h *DashboardHandler) GetUserSpendingRanking(c *gin.Context) {
|
||||||
|
startTime, endTime := parseTimeRange(c)
|
||||||
|
limit := parseRankingLimit(c.DefaultQuery("limit", "12"))
|
||||||
|
|
||||||
|
keyRaw, _ := json.Marshal(struct {
|
||||||
|
Start string `json:"start"`
|
||||||
|
End string `json:"end"`
|
||||||
|
Limit int `json:"limit"`
|
||||||
|
}{
|
||||||
|
Start: startTime.UTC().Format(time.RFC3339),
|
||||||
|
End: endTime.UTC().Format(time.RFC3339),
|
||||||
|
Limit: limit,
|
||||||
|
})
|
||||||
|
cacheKey := string(keyRaw)
|
||||||
|
if cached, ok := dashboardUsersRankingCache.Get(cacheKey); ok {
|
||||||
|
c.Header("X-Snapshot-Cache", "hit")
|
||||||
|
response.Success(c, cached.Payload)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ranking, err := h.dashboardService.GetUserSpendingRanking(c.Request.Context(), startTime, endTime, limit)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, 500, "Failed to get user spending ranking")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := gin.H{
|
||||||
|
"ranking": ranking.Ranking,
|
||||||
|
"total_actual_cost": ranking.TotalActualCost,
|
||||||
|
"total_requests": ranking.TotalRequests,
|
||||||
|
"total_tokens": ranking.TotalTokens,
|
||||||
|
"start_date": startTime.Format("2006-01-02"),
|
||||||
|
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||||
|
}
|
||||||
|
dashboardUsersRankingCache.Set(cacheKey, payload)
|
||||||
|
c.Header("X-Snapshot-Cache", "miss")
|
||||||
|
response.Success(c, payload)
|
||||||
|
}
|
||||||
|
|
||||||
// GetBatchUsersUsage handles getting usage stats for multiple users
|
// GetBatchUsersUsage handles getting usage stats for multiple users
|
||||||
// POST /api/v1/admin/dashboard/users-usage
|
// POST /api/v1/admin/dashboard/users-usage
|
||||||
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||||
@@ -546,3 +613,47 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
|
|||||||
c.Header("X-Snapshot-Cache", "miss")
|
c.Header("X-Snapshot-Cache", "miss")
|
||||||
response.Success(c, payload)
|
response.Success(c, payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUserBreakdown handles getting per-user usage breakdown within a dimension.
|
||||||
|
// GET /api/v1/admin/dashboard/user-breakdown
|
||||||
|
// Query params: start_date, end_date, group_id, model, endpoint, endpoint_type, limit
|
||||||
|
func (h *DashboardHandler) GetUserBreakdown(c *gin.Context) {
|
||||||
|
startTime, endTime := parseTimeRange(c)
|
||||||
|
|
||||||
|
dim := usagestats.UserBreakdownDimension{}
|
||||||
|
if v := c.Query("group_id"); v != "" {
|
||||||
|
if id, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||||
|
dim.GroupID = id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dim.Model = c.Query("model")
|
||||||
|
rawModelSource := strings.TrimSpace(c.DefaultQuery("model_source", usagestats.ModelSourceRequested))
|
||||||
|
if !usagestats.IsValidModelSource(rawModelSource) {
|
||||||
|
response.BadRequest(c, "Invalid model_source, use requested/upstream/mapping")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
dim.ModelType = rawModelSource
|
||||||
|
dim.Endpoint = c.Query("endpoint")
|
||||||
|
dim.EndpointType = c.DefaultQuery("endpoint_type", "inbound")
|
||||||
|
|
||||||
|
limit := 50
|
||||||
|
if v := c.Query("limit"); v != "" {
|
||||||
|
if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= 200 {
|
||||||
|
limit = n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
stats, err := h.dashboardService.GetUserBreakdownStats(
|
||||||
|
c.Request.Context(), startTime, endTime, dim, limit,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, 500, "Failed to get user breakdown stats")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"users": stats,
|
||||||
|
"start_date": startTime.Format("2006-01-02"),
|
||||||
|
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
118
backend/internal/handler/admin/dashboard_handler_cache_test.go
Normal file
118
backend/internal/handler/admin/dashboard_handler_cache_test.go
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type dashboardUsageRepoCacheProbe struct {
|
||||||
|
service.UsageLogRepository
|
||||||
|
trendCalls atomic.Int32
|
||||||
|
usersTrendCalls atomic.Int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardUsageRepoCacheProbe) GetUsageTrendWithFilters(
|
||||||
|
ctx context.Context,
|
||||||
|
startTime, endTime time.Time,
|
||||||
|
granularity string,
|
||||||
|
userID, apiKeyID, accountID, groupID int64,
|
||||||
|
model string,
|
||||||
|
requestType *int16,
|
||||||
|
stream *bool,
|
||||||
|
billingType *int8,
|
||||||
|
) ([]usagestats.TrendDataPoint, error) {
|
||||||
|
r.trendCalls.Add(1)
|
||||||
|
return []usagestats.TrendDataPoint{{
|
||||||
|
Date: "2026-03-11",
|
||||||
|
Requests: 1,
|
||||||
|
TotalTokens: 2,
|
||||||
|
Cost: 3,
|
||||||
|
ActualCost: 4,
|
||||||
|
}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardUsageRepoCacheProbe) GetUserUsageTrend(
|
||||||
|
ctx context.Context,
|
||||||
|
startTime, endTime time.Time,
|
||||||
|
granularity string,
|
||||||
|
limit int,
|
||||||
|
) ([]usagestats.UserUsageTrendPoint, error) {
|
||||||
|
r.usersTrendCalls.Add(1)
|
||||||
|
return []usagestats.UserUsageTrendPoint{{
|
||||||
|
Date: "2026-03-11",
|
||||||
|
UserID: 1,
|
||||||
|
Email: "cache@test.dev",
|
||||||
|
Requests: 2,
|
||||||
|
Tokens: 20,
|
||||||
|
Cost: 2,
|
||||||
|
ActualCost: 1,
|
||||||
|
}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func resetDashboardReadCachesForTest() {
|
||||||
|
dashboardTrendCache = newSnapshotCache(30 * time.Second)
|
||||||
|
dashboardUsersTrendCache = newSnapshotCache(30 * time.Second)
|
||||||
|
dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second)
|
||||||
|
dashboardModelStatsCache = newSnapshotCache(30 * time.Second)
|
||||||
|
dashboardGroupStatsCache = newSnapshotCache(30 * time.Second)
|
||||||
|
dashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardHandler_GetUsageTrend_UsesCache(t *testing.T) {
|
||||||
|
t.Cleanup(resetDashboardReadCachesForTest)
|
||||||
|
resetDashboardReadCachesForTest()
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
repo := &dashboardUsageRepoCacheProbe{}
|
||||||
|
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||||
|
handler := NewDashboardHandler(dashboardSvc, nil)
|
||||||
|
router := gin.New()
|
||||||
|
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
|
||||||
|
|
||||||
|
req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil)
|
||||||
|
rec1 := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec1, req1)
|
||||||
|
require.Equal(t, http.StatusOK, rec1.Code)
|
||||||
|
require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache"))
|
||||||
|
|
||||||
|
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil)
|
||||||
|
rec2 := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec2, req2)
|
||||||
|
require.Equal(t, http.StatusOK, rec2.Code)
|
||||||
|
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
|
||||||
|
require.Equal(t, int32(1), repo.trendCalls.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardHandler_GetUserUsageTrend_UsesCache(t *testing.T) {
|
||||||
|
t.Cleanup(resetDashboardReadCachesForTest)
|
||||||
|
resetDashboardReadCachesForTest()
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
repo := &dashboardUsageRepoCacheProbe{}
|
||||||
|
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||||
|
handler := NewDashboardHandler(dashboardSvc, nil)
|
||||||
|
router := gin.New()
|
||||||
|
router.GET("/admin/dashboard/users-trend", handler.GetUserUsageTrend)
|
||||||
|
|
||||||
|
req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil)
|
||||||
|
rec1 := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec1, req1)
|
||||||
|
require.Equal(t, http.StatusOK, rec1.Code)
|
||||||
|
require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache"))
|
||||||
|
|
||||||
|
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil)
|
||||||
|
rec2 := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec2, req2)
|
||||||
|
require.Equal(t, http.StatusOK, rec2.Code)
|
||||||
|
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
|
||||||
|
require.Equal(t, int32(1), repo.usersTrendCalls.Load())
|
||||||
|
}
|
||||||
@@ -19,6 +19,9 @@ type dashboardUsageRepoCapture struct {
|
|||||||
trendStream *bool
|
trendStream *bool
|
||||||
modelRequestType *int16
|
modelRequestType *int16
|
||||||
modelStream *bool
|
modelStream *bool
|
||||||
|
rankingLimit int
|
||||||
|
ranking []usagestats.UserSpendingRankingItem
|
||||||
|
rankingTotal float64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters(
|
func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters(
|
||||||
@@ -49,6 +52,20 @@ func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters(
|
|||||||
return []usagestats.ModelStat{}, nil
|
return []usagestats.ModelStat{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *dashboardUsageRepoCapture) GetUserSpendingRanking(
|
||||||
|
ctx context.Context,
|
||||||
|
startTime, endTime time.Time,
|
||||||
|
limit int,
|
||||||
|
) (*usagestats.UserSpendingRankingResponse, error) {
|
||||||
|
s.rankingLimit = limit
|
||||||
|
return &usagestats.UserSpendingRankingResponse{
|
||||||
|
Ranking: s.ranking,
|
||||||
|
TotalActualCost: s.rankingTotal,
|
||||||
|
TotalRequests: 44,
|
||||||
|
TotalTokens: 1234,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine {
|
func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||||
@@ -56,6 +73,7 @@ func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Eng
|
|||||||
router := gin.New()
|
router := gin.New()
|
||||||
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
|
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
|
||||||
router.GET("/admin/dashboard/models", handler.GetModelStats)
|
router.GET("/admin/dashboard/models", handler.GetModelStats)
|
||||||
|
router.GET("/admin/dashboard/users-ranking", handler.GetUserSpendingRanking)
|
||||||
return router
|
return router
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -130,3 +148,54 @@ func TestDashboardModelStatsInvalidStream(t *testing.T) {
|
|||||||
|
|
||||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDashboardModelStatsInvalidModelSource(t *testing.T) {
|
||||||
|
repo := &dashboardUsageRepoCapture{}
|
||||||
|
router := newDashboardRequestTypeTestRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?model_source=invalid", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardModelStatsValidModelSource(t *testing.T) {
|
||||||
|
repo := &dashboardUsageRepoCapture{}
|
||||||
|
router := newDashboardRequestTypeTestRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?model_source=upstream", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardUsersRankingLimitAndCache(t *testing.T) {
|
||||||
|
dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
|
||||||
|
repo := &dashboardUsageRepoCapture{
|
||||||
|
ranking: []usagestats.UserSpendingRankingItem{
|
||||||
|
{UserID: 7, Email: "rank@example.com", ActualCost: 10.5, Requests: 3, Tokens: 300},
|
||||||
|
},
|
||||||
|
rankingTotal: 88.8,
|
||||||
|
}
|
||||||
|
router := newDashboardRequestTypeTestRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
require.Equal(t, 50, repo.rankingLimit)
|
||||||
|
require.Contains(t, rec.Body.String(), "\"total_actual_cost\":88.8")
|
||||||
|
require.Contains(t, rec.Body.String(), "\"total_requests\":44")
|
||||||
|
require.Contains(t, rec.Body.String(), "\"total_tokens\":1234")
|
||||||
|
require.Equal(t, "miss", rec.Header().Get("X-Snapshot-Cache"))
|
||||||
|
|
||||||
|
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
|
||||||
|
rec2 := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec2, req2)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec2.Code)
|
||||||
|
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,229 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- mock repo ---
|
||||||
|
|
||||||
|
type userBreakdownRepoCapture struct {
|
||||||
|
service.UsageLogRepository
|
||||||
|
capturedDim usagestats.UserBreakdownDimension
|
||||||
|
capturedLimit int
|
||||||
|
result []usagestats.UserBreakdownItem
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *userBreakdownRepoCapture) GetUserBreakdownStats(
|
||||||
|
_ context.Context, _, _ time.Time,
|
||||||
|
dim usagestats.UserBreakdownDimension, limit int,
|
||||||
|
) ([]usagestats.UserBreakdownItem, error) {
|
||||||
|
r.capturedDim = dim
|
||||||
|
r.capturedLimit = limit
|
||||||
|
if r.result != nil {
|
||||||
|
return r.result, nil
|
||||||
|
}
|
||||||
|
return []usagestats.UserBreakdownItem{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUserBreakdownRouter(repo *userBreakdownRepoCapture) *gin.Engine {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := service.NewDashboardService(repo, nil, nil, nil)
|
||||||
|
h := NewDashboardHandler(svc, nil)
|
||||||
|
router := gin.New()
|
||||||
|
router.GET("/admin/dashboard/user-breakdown", h.GetUserBreakdown)
|
||||||
|
return router
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- tests ---
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_GroupIDFilter(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=42", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, int64(42), repo.capturedDim.GroupID)
|
||||||
|
require.Empty(t, repo.capturedDim.Model)
|
||||||
|
require.Empty(t, repo.capturedDim.Endpoint)
|
||||||
|
require.Equal(t, 50, repo.capturedLimit) // default limit
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_ModelFilter(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=claude-opus-4-6", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, "claude-opus-4-6", repo.capturedDim.Model)
|
||||||
|
require.Equal(t, usagestats.ModelSourceRequested, repo.capturedDim.ModelType)
|
||||||
|
require.Equal(t, int64(0), repo.capturedDim.GroupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_ModelSourceFilter(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=claude-opus-4-6&model_source=upstream", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, usagestats.ModelSourceUpstream, repo.capturedDim.ModelType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_InvalidModelSource(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model_source=foobar", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadRequest, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_EndpointFilter(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&endpoint=/v1/messages&endpoint_type=upstream", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, "/v1/messages", repo.capturedDim.Endpoint)
|
||||||
|
require.Equal(t, "upstream", repo.capturedDim.EndpointType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_DefaultEndpointType(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&endpoint=/chat", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, "inbound", repo.capturedDim.EndpointType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_CustomLimit(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=test&limit=100", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, 100, repo.capturedLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_LimitClamped(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
// limit > 200 should fall back to default 50
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=test&limit=999", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, 50, repo.capturedLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_ResponseFormat(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{
|
||||||
|
result: []usagestats.UserBreakdownItem{
|
||||||
|
{UserID: 1, Email: "alice@test.com", Requests: 100, TotalTokens: 50000, Cost: 1.5, ActualCost: 1.2},
|
||||||
|
{UserID: 2, Email: "bob@test.com", Requests: 50, TotalTokens: 25000, Cost: 0.8, ActualCost: 0.6},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=1", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
var resp struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Data struct {
|
||||||
|
Users []usagestats.UserBreakdownItem `json:"users"`
|
||||||
|
StartDate string `json:"start_date"`
|
||||||
|
EndDate string `json:"end_date"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
err := json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 0, resp.Code)
|
||||||
|
require.Len(t, resp.Data.Users, 2)
|
||||||
|
require.Equal(t, int64(1), resp.Data.Users[0].UserID)
|
||||||
|
require.Equal(t, "alice@test.com", resp.Data.Users[0].Email)
|
||||||
|
require.Equal(t, int64(100), resp.Data.Users[0].Requests)
|
||||||
|
require.InDelta(t, 1.2, resp.Data.Users[0].ActualCost, 0.001)
|
||||||
|
require.Equal(t, "2026-03-01", resp.Data.StartDate)
|
||||||
|
require.Equal(t, "2026-03-16", resp.Data.EndDate)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_EmptyResult(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=999", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
var resp struct {
|
||||||
|
Data struct {
|
||||||
|
Users []usagestats.UserBreakdownItem `json:"users"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
err := json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Empty(t, resp.Data.Users)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_NoFilters(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, int64(0), repo.capturedDim.GroupID)
|
||||||
|
require.Empty(t, repo.capturedDim.Model)
|
||||||
|
require.Empty(t, repo.capturedDim.Endpoint)
|
||||||
|
}
|
||||||
203
backend/internal/handler/admin/dashboard_query_cache.go
Normal file
203
backend/internal/handler/admin/dashboard_query_cache.go
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
dashboardTrendCache = newSnapshotCache(30 * time.Second)
|
||||||
|
dashboardModelStatsCache = newSnapshotCache(30 * time.Second)
|
||||||
|
dashboardGroupStatsCache = newSnapshotCache(30 * time.Second)
|
||||||
|
dashboardUsersTrendCache = newSnapshotCache(30 * time.Second)
|
||||||
|
dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second)
|
||||||
|
)
|
||||||
|
|
||||||
|
type dashboardTrendCacheKey struct {
|
||||||
|
StartTime string `json:"start_time"`
|
||||||
|
EndTime string `json:"end_time"`
|
||||||
|
Granularity string `json:"granularity"`
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
APIKeyID int64 `json:"api_key_id"`
|
||||||
|
AccountID int64 `json:"account_id"`
|
||||||
|
GroupID int64 `json:"group_id"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
RequestType *int16 `json:"request_type"`
|
||||||
|
Stream *bool `json:"stream"`
|
||||||
|
BillingType *int8 `json:"billing_type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type dashboardModelGroupCacheKey struct {
|
||||||
|
StartTime string `json:"start_time"`
|
||||||
|
EndTime string `json:"end_time"`
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
APIKeyID int64 `json:"api_key_id"`
|
||||||
|
AccountID int64 `json:"account_id"`
|
||||||
|
GroupID int64 `json:"group_id"`
|
||||||
|
ModelSource string `json:"model_source,omitempty"`
|
||||||
|
RequestType *int16 `json:"request_type"`
|
||||||
|
Stream *bool `json:"stream"`
|
||||||
|
BillingType *int8 `json:"billing_type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type dashboardEntityTrendCacheKey struct {
|
||||||
|
StartTime string `json:"start_time"`
|
||||||
|
EndTime string `json:"end_time"`
|
||||||
|
Granularity string `json:"granularity"`
|
||||||
|
Limit int `json:"limit"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func cacheStatusValue(hit bool) string {
|
||||||
|
if hit {
|
||||||
|
return "hit"
|
||||||
|
}
|
||||||
|
return "miss"
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustMarshalDashboardCacheKey(value any) string {
|
||||||
|
raw, err := json.Marshal(value)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return string(raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
func snapshotPayloadAs[T any](payload any) (T, error) {
|
||||||
|
typed, ok := payload.(T)
|
||||||
|
if !ok {
|
||||||
|
var zero T
|
||||||
|
return zero, fmt.Errorf("unexpected cache payload type %T", payload)
|
||||||
|
}
|
||||||
|
return typed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DashboardHandler) getUsageTrendCached(
|
||||||
|
ctx context.Context,
|
||||||
|
startTime, endTime time.Time,
|
||||||
|
granularity string,
|
||||||
|
userID, apiKeyID, accountID, groupID int64,
|
||||||
|
model string,
|
||||||
|
requestType *int16,
|
||||||
|
stream *bool,
|
||||||
|
billingType *int8,
|
||||||
|
) ([]usagestats.TrendDataPoint, bool, error) {
|
||||||
|
key := mustMarshalDashboardCacheKey(dashboardTrendCacheKey{
|
||||||
|
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||||
|
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||||
|
Granularity: granularity,
|
||||||
|
UserID: userID,
|
||||||
|
APIKeyID: apiKeyID,
|
||||||
|
AccountID: accountID,
|
||||||
|
GroupID: groupID,
|
||||||
|
Model: model,
|
||||||
|
RequestType: requestType,
|
||||||
|
Stream: stream,
|
||||||
|
BillingType: billingType,
|
||||||
|
})
|
||||||
|
entry, hit, err := dashboardTrendCache.GetOrLoad(key, func() (any, error) {
|
||||||
|
return h.dashboardService.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, hit, err
|
||||||
|
}
|
||||||
|
trend, err := snapshotPayloadAs[[]usagestats.TrendDataPoint](entry.Payload)
|
||||||
|
return trend, hit, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DashboardHandler) getModelStatsCached(
|
||||||
|
ctx context.Context,
|
||||||
|
startTime, endTime time.Time,
|
||||||
|
userID, apiKeyID, accountID, groupID int64,
|
||||||
|
modelSource string,
|
||||||
|
requestType *int16,
|
||||||
|
stream *bool,
|
||||||
|
billingType *int8,
|
||||||
|
) ([]usagestats.ModelStat, bool, error) {
|
||||||
|
key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{
|
||||||
|
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||||
|
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||||
|
UserID: userID,
|
||||||
|
APIKeyID: apiKeyID,
|
||||||
|
AccountID: accountID,
|
||||||
|
GroupID: groupID,
|
||||||
|
ModelSource: usagestats.NormalizeModelSource(modelSource),
|
||||||
|
RequestType: requestType,
|
||||||
|
Stream: stream,
|
||||||
|
BillingType: billingType,
|
||||||
|
})
|
||||||
|
entry, hit, err := dashboardModelStatsCache.GetOrLoad(key, func() (any, error) {
|
||||||
|
return h.dashboardService.GetModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, modelSource)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, hit, err
|
||||||
|
}
|
||||||
|
stats, err := snapshotPayloadAs[[]usagestats.ModelStat](entry.Payload)
|
||||||
|
return stats, hit, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DashboardHandler) getGroupStatsCached(
|
||||||
|
ctx context.Context,
|
||||||
|
startTime, endTime time.Time,
|
||||||
|
userID, apiKeyID, accountID, groupID int64,
|
||||||
|
requestType *int16,
|
||||||
|
stream *bool,
|
||||||
|
billingType *int8,
|
||||||
|
) ([]usagestats.GroupStat, bool, error) {
|
||||||
|
key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{
|
||||||
|
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||||
|
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||||
|
UserID: userID,
|
||||||
|
APIKeyID: apiKeyID,
|
||||||
|
AccountID: accountID,
|
||||||
|
GroupID: groupID,
|
||||||
|
RequestType: requestType,
|
||||||
|
Stream: stream,
|
||||||
|
BillingType: billingType,
|
||||||
|
})
|
||||||
|
entry, hit, err := dashboardGroupStatsCache.GetOrLoad(key, func() (any, error) {
|
||||||
|
return h.dashboardService.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, hit, err
|
||||||
|
}
|
||||||
|
stats, err := snapshotPayloadAs[[]usagestats.GroupStat](entry.Payload)
|
||||||
|
return stats, hit, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DashboardHandler) getAPIKeyUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, bool, error) {
|
||||||
|
key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{
|
||||||
|
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||||
|
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||||
|
Granularity: granularity,
|
||||||
|
Limit: limit,
|
||||||
|
})
|
||||||
|
entry, hit, err := dashboardAPIKeysTrendCache.GetOrLoad(key, func() (any, error) {
|
||||||
|
return h.dashboardService.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, hit, err
|
||||||
|
}
|
||||||
|
trend, err := snapshotPayloadAs[[]usagestats.APIKeyUsageTrendPoint](entry.Payload)
|
||||||
|
return trend, hit, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DashboardHandler) getUserUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, bool, error) {
|
||||||
|
key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{
|
||||||
|
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||||
|
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||||
|
Granularity: granularity,
|
||||||
|
Limit: limit,
|
||||||
|
})
|
||||||
|
entry, hit, err := dashboardUsersTrendCache.GetOrLoad(key, func() (any, error) {
|
||||||
|
return h.dashboardService.GetUserUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, hit, err
|
||||||
|
}
|
||||||
|
trend, err := snapshotPayloadAs[[]usagestats.UserUsageTrendPoint](entry.Payload)
|
||||||
|
return trend, hit, err
|
||||||
|
}
|
||||||
@@ -1,7 +1,9 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -111,20 +113,45 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
cacheKey := string(keyRaw)
|
cacheKey := string(keyRaw)
|
||||||
|
|
||||||
if cached, ok := dashboardSnapshotV2Cache.Get(cacheKey); ok {
|
cached, hit, err := dashboardSnapshotV2Cache.GetOrLoad(cacheKey, func() (any, error) {
|
||||||
if cached.ETag != "" {
|
return h.buildSnapshotV2Response(
|
||||||
c.Header("ETag", cached.ETag)
|
c.Request.Context(),
|
||||||
c.Header("Vary", "If-None-Match")
|
startTime,
|
||||||
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
|
endTime,
|
||||||
c.Status(http.StatusNotModified)
|
granularity,
|
||||||
return
|
filters,
|
||||||
}
|
includeStats,
|
||||||
}
|
includeTrend,
|
||||||
c.Header("X-Snapshot-Cache", "hit")
|
includeModels,
|
||||||
response.Success(c, cached.Payload)
|
includeGroups,
|
||||||
|
includeUsersTrend,
|
||||||
|
usersTrendLimit,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, 500, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if cached.ETag != "" {
|
||||||
|
c.Header("ETag", cached.ETag)
|
||||||
|
c.Header("Vary", "If-None-Match")
|
||||||
|
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
|
||||||
|
c.Status(http.StatusNotModified)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||||
|
response.Success(c, cached.Payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DashboardHandler) buildSnapshotV2Response(
|
||||||
|
ctx context.Context,
|
||||||
|
startTime, endTime time.Time,
|
||||||
|
granularity string,
|
||||||
|
filters *dashboardSnapshotV2Filters,
|
||||||
|
includeStats, includeTrend, includeModels, includeGroups, includeUsersTrend bool,
|
||||||
|
usersTrendLimit int,
|
||||||
|
) (*dashboardSnapshotV2Response, error) {
|
||||||
resp := &dashboardSnapshotV2Response{
|
resp := &dashboardSnapshotV2Response{
|
||||||
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
|
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
|
||||||
StartDate: startTime.Format("2006-01-02"),
|
StartDate: startTime.Format("2006-01-02"),
|
||||||
@@ -133,10 +160,9 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if includeStats {
|
if includeStats {
|
||||||
stats, err := h.dashboardService.GetDashboardStats(c.Request.Context())
|
stats, err := h.dashboardService.GetDashboardStats(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get dashboard statistics")
|
return nil, errors.New("failed to get dashboard statistics")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
resp.Stats = &dashboardSnapshotV2Stats{
|
resp.Stats = &dashboardSnapshotV2Stats{
|
||||||
DashboardStats: *stats,
|
DashboardStats: *stats,
|
||||||
@@ -145,8 +171,8 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if includeTrend {
|
if includeTrend {
|
||||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(
|
trend, _, err := h.getUsageTrendCached(
|
||||||
c.Request.Context(),
|
ctx,
|
||||||
startTime,
|
startTime,
|
||||||
endTime,
|
endTime,
|
||||||
granularity,
|
granularity,
|
||||||
@@ -160,35 +186,34 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
|||||||
filters.BillingType,
|
filters.BillingType,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get usage trend")
|
return nil, errors.New("failed to get usage trend")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
resp.Trend = trend
|
resp.Trend = trend
|
||||||
}
|
}
|
||||||
|
|
||||||
if includeModels {
|
if includeModels {
|
||||||
models, err := h.dashboardService.GetModelStatsWithFilters(
|
models, _, err := h.getModelStatsCached(
|
||||||
c.Request.Context(),
|
ctx,
|
||||||
startTime,
|
startTime,
|
||||||
endTime,
|
endTime,
|
||||||
filters.UserID,
|
filters.UserID,
|
||||||
filters.APIKeyID,
|
filters.APIKeyID,
|
||||||
filters.AccountID,
|
filters.AccountID,
|
||||||
filters.GroupID,
|
filters.GroupID,
|
||||||
|
usagestats.ModelSourceRequested,
|
||||||
filters.RequestType,
|
filters.RequestType,
|
||||||
filters.Stream,
|
filters.Stream,
|
||||||
filters.BillingType,
|
filters.BillingType,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get model statistics")
|
return nil, errors.New("failed to get model statistics")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
resp.Models = models
|
resp.Models = models
|
||||||
}
|
}
|
||||||
|
|
||||||
if includeGroups {
|
if includeGroups {
|
||||||
groups, err := h.dashboardService.GetGroupStatsWithFilters(
|
groups, _, err := h.getGroupStatsCached(
|
||||||
c.Request.Context(),
|
ctx,
|
||||||
startTime,
|
startTime,
|
||||||
endTime,
|
endTime,
|
||||||
filters.UserID,
|
filters.UserID,
|
||||||
@@ -200,34 +225,20 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
|||||||
filters.BillingType,
|
filters.BillingType,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get group statistics")
|
return nil, errors.New("failed to get group statistics")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
resp.Groups = groups
|
resp.Groups = groups
|
||||||
}
|
}
|
||||||
|
|
||||||
if includeUsersTrend {
|
if includeUsersTrend {
|
||||||
usersTrend, err := h.dashboardService.GetUserUsageTrend(
|
usersTrend, _, err := h.getUserUsageTrendCached(ctx, startTime, endTime, granularity, usersTrendLimit)
|
||||||
c.Request.Context(),
|
|
||||||
startTime,
|
|
||||||
endTime,
|
|
||||||
granularity,
|
|
||||||
usersTrendLimit,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get user usage trend")
|
return nil, errors.New("failed to get user usage trend")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
resp.UsersTrend = usersTrend
|
resp.UsersTrend = usersTrend
|
||||||
}
|
}
|
||||||
|
|
||||||
cached := dashboardSnapshotV2Cache.Set(cacheKey, resp)
|
return resp, nil
|
||||||
if cached.ETag != "" {
|
|
||||||
c.Header("ETag", cached.ETag)
|
|
||||||
c.Header("Vary", "If-None-Match")
|
|
||||||
}
|
|
||||||
c.Header("X-Snapshot-Cache", "miss")
|
|
||||||
response.Success(c, resp)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseDashboardSnapshotV2Filters(c *gin.Context) (*dashboardSnapshotV2Filters, error) {
|
func parseDashboardSnapshotV2Filters(c *gin.Context) (*dashboardSnapshotV2Filters, error) {
|
||||||
|
|||||||
@@ -1,11 +1,15 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -13,27 +17,80 @@ import (
|
|||||||
|
|
||||||
// GroupHandler handles admin group management
|
// GroupHandler handles admin group management
|
||||||
type GroupHandler struct {
|
type GroupHandler struct {
|
||||||
adminService service.AdminService
|
adminService service.AdminService
|
||||||
|
dashboardService *service.DashboardService
|
||||||
|
groupCapacityService *service.GroupCapacityService
|
||||||
|
}
|
||||||
|
|
||||||
|
type optionalLimitField struct {
|
||||||
|
set bool
|
||||||
|
value *float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *optionalLimitField) UnmarshalJSON(data []byte) error {
|
||||||
|
f.set = true
|
||||||
|
|
||||||
|
trimmed := bytes.TrimSpace(data)
|
||||||
|
if bytes.Equal(trimmed, []byte("null")) {
|
||||||
|
f.value = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var number float64
|
||||||
|
if err := json.Unmarshal(trimmed, &number); err == nil {
|
||||||
|
f.value = &number
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var text string
|
||||||
|
if err := json.Unmarshal(trimmed, &text); err == nil {
|
||||||
|
text = strings.TrimSpace(text)
|
||||||
|
if text == "" {
|
||||||
|
f.value = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
number, err = strconv.ParseFloat(text, 64)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid numeric limit value %q: %w", text, err)
|
||||||
|
}
|
||||||
|
f.value = &number
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("invalid limit value: %s", string(trimmed))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f optionalLimitField) ToServiceInput() *float64 {
|
||||||
|
if !f.set {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if f.value != nil {
|
||||||
|
return f.value
|
||||||
|
}
|
||||||
|
zero := 0.0
|
||||||
|
return &zero
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGroupHandler creates a new admin group handler
|
// NewGroupHandler creates a new admin group handler
|
||||||
func NewGroupHandler(adminService service.AdminService) *GroupHandler {
|
func NewGroupHandler(adminService service.AdminService, dashboardService *service.DashboardService, groupCapacityService *service.GroupCapacityService) *GroupHandler {
|
||||||
return &GroupHandler{
|
return &GroupHandler{
|
||||||
adminService: adminService,
|
adminService: adminService,
|
||||||
|
dashboardService: dashboardService,
|
||||||
|
groupCapacityService: groupCapacityService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateGroupRequest represents create group request
|
// CreateGroupRequest represents create group request
|
||||||
type CreateGroupRequest struct {
|
type CreateGroupRequest struct {
|
||||||
Name string `json:"name" binding:"required"`
|
Name string `json:"name" binding:"required"`
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||||
RateMultiplier float64 `json:"rate_multiplier"`
|
RateMultiplier float64 `json:"rate_multiplier"`
|
||||||
IsExclusive bool `json:"is_exclusive"`
|
IsExclusive bool `json:"is_exclusive"`
|
||||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
DailyLimitUSD optionalLimitField `json:"daily_limit_usd"`
|
||||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
|
||||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
|
||||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||||
@@ -62,16 +119,16 @@ type CreateGroupRequest struct {
|
|||||||
|
|
||||||
// UpdateGroupRequest represents update group request
|
// UpdateGroupRequest represents update group request
|
||||||
type UpdateGroupRequest struct {
|
type UpdateGroupRequest struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||||
IsExclusive *bool `json:"is_exclusive"`
|
IsExclusive *bool `json:"is_exclusive"`
|
||||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
DailyLimitUSD optionalLimitField `json:"daily_limit_usd"`
|
||||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
|
||||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
|
||||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||||
@@ -191,9 +248,9 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
|||||||
RateMultiplier: req.RateMultiplier,
|
RateMultiplier: req.RateMultiplier,
|
||||||
IsExclusive: req.IsExclusive,
|
IsExclusive: req.IsExclusive,
|
||||||
SubscriptionType: req.SubscriptionType,
|
SubscriptionType: req.SubscriptionType,
|
||||||
DailyLimitUSD: req.DailyLimitUSD,
|
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
|
||||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
|
||||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
|
||||||
ImagePrice1K: req.ImagePrice1K,
|
ImagePrice1K: req.ImagePrice1K,
|
||||||
ImagePrice2K: req.ImagePrice2K,
|
ImagePrice2K: req.ImagePrice2K,
|
||||||
ImagePrice4K: req.ImagePrice4K,
|
ImagePrice4K: req.ImagePrice4K,
|
||||||
@@ -244,9 +301,9 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
|||||||
IsExclusive: req.IsExclusive,
|
IsExclusive: req.IsExclusive,
|
||||||
Status: req.Status,
|
Status: req.Status,
|
||||||
SubscriptionType: req.SubscriptionType,
|
SubscriptionType: req.SubscriptionType,
|
||||||
DailyLimitUSD: req.DailyLimitUSD,
|
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
|
||||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
|
||||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
|
||||||
ImagePrice1K: req.ImagePrice1K,
|
ImagePrice1K: req.ImagePrice1K,
|
||||||
ImagePrice2K: req.ImagePrice2K,
|
ImagePrice2K: req.ImagePrice2K,
|
||||||
ImagePrice4K: req.ImagePrice4K,
|
ImagePrice4K: req.ImagePrice4K,
|
||||||
@@ -311,6 +368,33 @@ func (h *GroupHandler) GetStats(c *gin.Context) {
|
|||||||
_ = groupID // TODO: implement actual stats
|
_ = groupID // TODO: implement actual stats
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUsageSummary returns today's and cumulative cost for all groups.
|
||||||
|
// GET /api/v1/admin/groups/usage-summary?timezone=Asia/Shanghai
|
||||||
|
func (h *GroupHandler) GetUsageSummary(c *gin.Context) {
|
||||||
|
userTZ := c.Query("timezone")
|
||||||
|
now := timezone.NowInUserLocation(userTZ)
|
||||||
|
todayStart := timezone.StartOfDayInUserLocation(now, userTZ)
|
||||||
|
|
||||||
|
results, err := h.dashboardService.GetGroupUsageSummary(c.Request.Context(), todayStart)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, 500, "Failed to get group usage summary")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, results)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCapacitySummary returns aggregated capacity (concurrency/sessions/RPM) for all active groups.
|
||||||
|
// GET /api/v1/admin/groups/capacity-summary
|
||||||
|
func (h *GroupHandler) GetCapacitySummary(c *gin.Context) {
|
||||||
|
results, err := h.groupCapacityService.GetAllGroupCapacity(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, 500, "Failed to get group capacity summary")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, results)
|
||||||
|
}
|
||||||
|
|
||||||
// GetGroupAPIKeys handles getting API keys in a group
|
// GetGroupAPIKeys handles getting API keys in a group
|
||||||
// GET /api/v1/admin/groups/:id/api-keys
|
// GET /api/v1/admin/groups/:id/api-keys
|
||||||
func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
||||||
@@ -335,6 +419,72 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
|||||||
response.Paginated(c, outKeys, total, page, pageSize)
|
response.Paginated(c, outKeys, total, page, pageSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetGroupRateMultipliers handles getting rate multipliers for users in a group
|
||||||
|
// GET /api/v1/admin/groups/:id/rate-multipliers
|
||||||
|
func (h *GroupHandler) GetGroupRateMultipliers(c *gin.Context) {
|
||||||
|
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid group ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err := h.adminService.GetGroupRateMultipliers(c.Request.Context(), groupID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if entries == nil {
|
||||||
|
entries = []service.UserGroupRateEntry{}
|
||||||
|
}
|
||||||
|
response.Success(c, entries)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearGroupRateMultipliers handles clearing all rate multipliers for a group
|
||||||
|
// DELETE /api/v1/admin/groups/:id/rate-multipliers
|
||||||
|
func (h *GroupHandler) ClearGroupRateMultipliers(c *gin.Context) {
|
||||||
|
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid group ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.adminService.ClearGroupRateMultipliers(c.Request.Context(), groupID); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"message": "Rate multipliers cleared successfully"})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchSetGroupRateMultipliersRequest represents batch set rate multipliers request
|
||||||
|
type BatchSetGroupRateMultipliersRequest struct {
|
||||||
|
Entries []service.GroupRateMultiplierInput `json:"entries" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchSetGroupRateMultipliers handles batch setting rate multipliers for a group
|
||||||
|
// PUT /api/v1/admin/groups/:id/rate-multipliers
|
||||||
|
func (h *GroupHandler) BatchSetGroupRateMultipliers(c *gin.Context) {
|
||||||
|
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid group ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req BatchSetGroupRateMultipliersRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.adminService.BatchSetGroupRateMultipliers(c.Request.Context(), groupID, req.Entries); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"message": "Rate multipliers updated successfully"})
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateSortOrderRequest represents the request to update group sort orders
|
// UpdateSortOrderRequest represents the request to update group sort orders
|
||||||
type UpdateSortOrderRequest struct {
|
type UpdateSortOrderRequest struct {
|
||||||
Updates []struct {
|
Updates []struct {
|
||||||
|
|||||||
@@ -289,6 +289,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
|||||||
Platform: platform,
|
Platform: platform,
|
||||||
Type: "oauth",
|
Type: "oauth",
|
||||||
Credentials: credentials,
|
Credentials: credentials,
|
||||||
|
Extra: nil,
|
||||||
ProxyID: req.ProxyID,
|
ProxyID: req.ProxyID,
|
||||||
Concurrency: req.Concurrency,
|
Concurrency: req.Concurrency,
|
||||||
Priority: req.Priority,
|
Priority: req.Priority,
|
||||||
|
|||||||
@@ -23,6 +23,13 @@ var validOpsAlertMetricTypes = []string{
|
|||||||
"cpu_usage_percent",
|
"cpu_usage_percent",
|
||||||
"memory_usage_percent",
|
"memory_usage_percent",
|
||||||
"concurrency_queue_depth",
|
"concurrency_queue_depth",
|
||||||
|
"group_available_accounts",
|
||||||
|
"group_available_ratio",
|
||||||
|
"group_rate_limit_ratio",
|
||||||
|
"account_rate_limited_count",
|
||||||
|
"account_error_count",
|
||||||
|
"account_error_ratio",
|
||||||
|
"overload_account_count",
|
||||||
}
|
}
|
||||||
|
|
||||||
var validOpsAlertMetricTypeSet = func() map[string]struct{} {
|
var validOpsAlertMetricTypeSet = func() map[string]struct{} {
|
||||||
@@ -82,7 +89,10 @@ func isPercentOrRateMetric(metricType string) bool {
|
|||||||
"error_rate",
|
"error_rate",
|
||||||
"upstream_error_rate",
|
"upstream_error_rate",
|
||||||
"cpu_usage_percent",
|
"cpu_usage_percent",
|
||||||
"memory_usage_percent":
|
"memory_usage_percent",
|
||||||
|
"group_available_ratio",
|
||||||
|
"group_rate_limit_ratio",
|
||||||
|
"account_error_ratio":
|
||||||
return true
|
return true
|
||||||
default:
|
default:
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -41,12 +41,15 @@ type GenerateRedeemCodesRequest struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
|
// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
|
||||||
|
// Type 为 omitempty 而非 required 是为了向后兼容旧版调用方(不传 type 时默认 balance)。
|
||||||
type CreateAndRedeemCodeRequest struct {
|
type CreateAndRedeemCodeRequest struct {
|
||||||
Code string `json:"code" binding:"required,min=3,max=128"`
|
Code string `json:"code" binding:"required,min=3,max=128"`
|
||||||
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
|
Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // 不传时默认 balance(向后兼容)
|
||||||
Value float64 `json:"value" binding:"required,gt=0"`
|
Value float64 `json:"value" binding:"required,gt=0"`
|
||||||
UserID int64 `json:"user_id" binding:"required,gt=0"`
|
UserID int64 `json:"user_id" binding:"required,gt=0"`
|
||||||
Notes string `json:"notes"`
|
GroupID *int64 `json:"group_id"` // subscription 类型必填
|
||||||
|
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // subscription 类型必填,>0
|
||||||
|
Notes string `json:"notes"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// List handles listing all redeem codes with pagination
|
// List handles listing all redeem codes with pagination
|
||||||
@@ -136,6 +139,22 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
req.Code = strings.TrimSpace(req.Code)
|
req.Code = strings.TrimSpace(req.Code)
|
||||||
|
// 向后兼容:旧版调用方(如 Sub2ApiPay)不传 type 字段,默认当作 balance 充值处理。
|
||||||
|
// 请勿删除此默认值逻辑,否则会导致旧版调用方 400 报错。
|
||||||
|
if req.Type == "" {
|
||||||
|
req.Type = "balance"
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Type == "subscription" {
|
||||||
|
if req.GroupID == nil {
|
||||||
|
response.BadRequest(c, "group_id is required for subscription type")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.ValidityDays <= 0 {
|
||||||
|
response.BadRequest(c, "validity_days must be greater than 0 for subscription type")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||||
existing, err := h.redeemService.GetByCode(ctx, req.Code)
|
existing, err := h.redeemService.GetByCode(ctx, req.Code)
|
||||||
@@ -147,11 +166,13 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
createErr := h.redeemService.CreateCode(ctx, &service.RedeemCode{
|
createErr := h.redeemService.CreateCode(ctx, &service.RedeemCode{
|
||||||
Code: req.Code,
|
Code: req.Code,
|
||||||
Type: req.Type,
|
Type: req.Type,
|
||||||
Value: req.Value,
|
Value: req.Value,
|
||||||
Status: service.StatusUnused,
|
Status: service.StatusUnused,
|
||||||
Notes: req.Notes,
|
Notes: req.Notes,
|
||||||
|
GroupID: req.GroupID,
|
||||||
|
ValidityDays: req.ValidityDays,
|
||||||
})
|
})
|
||||||
if createErr != nil {
|
if createErr != nil {
|
||||||
// Unique code race: if code now exists, use idempotent semantics by used_by.
|
// Unique code race: if code now exists, use idempotent semantics by used_by.
|
||||||
|
|||||||
135
backend/internal/handler/admin/redeem_handler_test.go
Normal file
135
backend/internal/handler/admin/redeem_handler_test.go
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newCreateAndRedeemHandler creates a RedeemHandler with a non-nil (but minimal)
|
||||||
|
// RedeemService so that CreateAndRedeem's nil guard passes and we can test the
|
||||||
|
// parameter-validation layer that runs before any service call.
|
||||||
|
func newCreateAndRedeemHandler() *RedeemHandler {
|
||||||
|
return &RedeemHandler{
|
||||||
|
adminService: newStubAdminService(),
|
||||||
|
redeemService: &service.RedeemService{}, // non-nil to pass nil guard
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// postCreateAndRedeemValidation calls CreateAndRedeem and returns the response
|
||||||
|
// status code. For cases that pass validation and proceed into the service layer,
|
||||||
|
// a panic may occur (because RedeemService internals are nil); this is expected
|
||||||
|
// and treated as "validation passed" (returns 0 to indicate panic).
|
||||||
|
func postCreateAndRedeemValidation(t *testing.T, handler *RedeemHandler, body any) (code int) {
|
||||||
|
t.Helper()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
|
jsonBytes, err := json.Marshal(body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
c.Request, _ = http.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/create-and-redeem", bytes.NewReader(jsonBytes))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
// Panic means we passed validation and entered service layer (expected for minimal stub).
|
||||||
|
code = 0
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
handler.CreateAndRedeem(c)
|
||||||
|
return w.Code
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateAndRedeem_TypeDefaultsToBalance(t *testing.T) {
|
||||||
|
// 不传 type 字段时应默认 balance,不触发 subscription 校验。
|
||||||
|
// 验证通过后进入 service 层会 panic(返回 0),说明默认值生效。
|
||||||
|
h := newCreateAndRedeemHandler()
|
||||||
|
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||||
|
"code": "test-balance-default",
|
||||||
|
"value": 10.0,
|
||||||
|
"user_id": 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||||
|
"omitting type should default to balance and pass validation")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateAndRedeem_SubscriptionRequiresGroupID(t *testing.T) {
|
||||||
|
h := newCreateAndRedeemHandler()
|
||||||
|
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||||
|
"code": "test-sub-no-group",
|
||||||
|
"type": "subscription",
|
||||||
|
"value": 29.9,
|
||||||
|
"user_id": 1,
|
||||||
|
"validity_days": 30,
|
||||||
|
// group_id 缺失
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusBadRequest, code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateAndRedeem_SubscriptionRequiresPositiveValidityDays(t *testing.T) {
|
||||||
|
groupID := int64(5)
|
||||||
|
h := newCreateAndRedeemHandler()
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
validityDays int
|
||||||
|
}{
|
||||||
|
{"zero", 0},
|
||||||
|
{"negative", -1},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||||
|
"code": "test-sub-bad-days-" + tc.name,
|
||||||
|
"type": "subscription",
|
||||||
|
"value": 29.9,
|
||||||
|
"user_id": 1,
|
||||||
|
"group_id": groupID,
|
||||||
|
"validity_days": tc.validityDays,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusBadRequest, code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateAndRedeem_SubscriptionValidParamsPassValidation(t *testing.T) {
|
||||||
|
groupID := int64(5)
|
||||||
|
h := newCreateAndRedeemHandler()
|
||||||
|
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||||
|
"code": "test-sub-valid",
|
||||||
|
"type": "subscription",
|
||||||
|
"value": 29.9,
|
||||||
|
"user_id": 1,
|
||||||
|
"group_id": groupID,
|
||||||
|
"validity_days": 31,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||||
|
"valid subscription params should pass validation")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateAndRedeem_BalanceIgnoresSubscriptionFields(t *testing.T) {
|
||||||
|
h := newCreateAndRedeemHandler()
|
||||||
|
// balance 类型不传 group_id 和 validity_days,不应报 400
|
||||||
|
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||||
|
"code": "test-balance-no-extras",
|
||||||
|
"type": "balance",
|
||||||
|
"value": 50.0,
|
||||||
|
"user_id": 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||||
|
"balance type should not require group_id or validity_days")
|
||||||
|
}
|
||||||
@@ -80,6 +80,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
|||||||
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
|
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
|
||||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||||
|
FrontendURL: settings.FrontendURL,
|
||||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||||
TotpEnabled: settings.TotpEnabled,
|
TotpEnabled: settings.TotpEnabled,
|
||||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||||
@@ -125,6 +126,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
|||||||
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
|
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
|
||||||
MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
|
MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
|
||||||
AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
|
AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
|
||||||
|
BackendModeEnabled: settings.BackendModeEnabled,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,6 +138,7 @@ type UpdateSettingsRequest struct {
|
|||||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||||
|
FrontendURL string `json:"frontend_url"`
|
||||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||||
|
|
||||||
@@ -199,6 +202,9 @@ type UpdateSettingsRequest struct {
|
|||||||
|
|
||||||
// 分组隔离
|
// 分组隔离
|
||||||
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
||||||
|
|
||||||
|
// Backend Mode
|
||||||
|
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateSettings 更新系统设置
|
// UpdateSettings 更新系统设置
|
||||||
@@ -322,6 +328,15 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Frontend URL 验证
|
||||||
|
req.FrontendURL = strings.TrimSpace(req.FrontendURL)
|
||||||
|
if req.FrontendURL != "" {
|
||||||
|
if err := config.ValidateAbsoluteHTTPURL(req.FrontendURL); err != nil {
|
||||||
|
response.BadRequest(c, "Frontend URL must be an absolute http(s) URL")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 自定义菜单项验证
|
// 自定义菜单项验证
|
||||||
const (
|
const (
|
||||||
maxCustomMenuItems = 20
|
maxCustomMenuItems = 20
|
||||||
@@ -433,6 +448,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
|
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
|
||||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||||
PasswordResetEnabled: req.PasswordResetEnabled,
|
PasswordResetEnabled: req.PasswordResetEnabled,
|
||||||
|
FrontendURL: req.FrontendURL,
|
||||||
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
||||||
TotpEnabled: req.TotpEnabled,
|
TotpEnabled: req.TotpEnabled,
|
||||||
SMTPHost: req.SMTPHost,
|
SMTPHost: req.SMTPHost,
|
||||||
@@ -473,6 +489,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||||
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
|
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
|
||||||
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
|
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
|
||||||
|
BackendModeEnabled: req.BackendModeEnabled,
|
||||||
OpsMonitoringEnabled: func() bool {
|
OpsMonitoringEnabled: func() bool {
|
||||||
if req.OpsMonitoringEnabled != nil {
|
if req.OpsMonitoringEnabled != nil {
|
||||||
return *req.OpsMonitoringEnabled
|
return *req.OpsMonitoringEnabled
|
||||||
@@ -526,6 +543,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
|
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
|
||||||
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
|
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
|
||||||
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
|
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
|
||||||
|
FrontendURL: updatedSettings.FrontendURL,
|
||||||
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
|
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
|
||||||
TotpEnabled: updatedSettings.TotpEnabled,
|
TotpEnabled: updatedSettings.TotpEnabled,
|
||||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||||
@@ -571,6 +589,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
|
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
|
||||||
MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
|
MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
|
||||||
AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
|
AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
|
||||||
|
BackendModeEnabled: updatedSettings.BackendModeEnabled,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -608,6 +627,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
|||||||
if before.PasswordResetEnabled != after.PasswordResetEnabled {
|
if before.PasswordResetEnabled != after.PasswordResetEnabled {
|
||||||
changed = append(changed, "password_reset_enabled")
|
changed = append(changed, "password_reset_enabled")
|
||||||
}
|
}
|
||||||
|
if before.FrontendURL != after.FrontendURL {
|
||||||
|
changed = append(changed, "frontend_url")
|
||||||
|
}
|
||||||
if before.TotpEnabled != after.TotpEnabled {
|
if before.TotpEnabled != after.TotpEnabled {
|
||||||
changed = append(changed, "totp_enabled")
|
changed = append(changed, "totp_enabled")
|
||||||
}
|
}
|
||||||
@@ -725,6 +747,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
|||||||
if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling {
|
if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling {
|
||||||
changed = append(changed, "allow_ungrouped_key_scheduling")
|
changed = append(changed, "allow_ungrouped_key_scheduling")
|
||||||
}
|
}
|
||||||
|
if before.BackendModeEnabled != after.BackendModeEnabled {
|
||||||
|
changed = append(changed, "backend_mode_enabled")
|
||||||
|
}
|
||||||
if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled {
|
if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled {
|
||||||
changed = append(changed, "purchase_subscription_enabled")
|
changed = append(changed, "purchase_subscription_enabled")
|
||||||
}
|
}
|
||||||
@@ -952,6 +977,58 @@ func (h *SettingHandler) DeleteAdminAPIKey(c *gin.Context) {
|
|||||||
response.Success(c, gin.H{"message": "Admin API key deleted"})
|
response.Success(c, gin.H{"message": "Admin API key deleted"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetOverloadCooldownSettings 获取529过载冷却配置
|
||||||
|
// GET /api/v1/admin/settings/overload-cooldown
|
||||||
|
func (h *SettingHandler) GetOverloadCooldownSettings(c *gin.Context) {
|
||||||
|
settings, err := h.settingService.GetOverloadCooldownSettings(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, dto.OverloadCooldownSettings{
|
||||||
|
Enabled: settings.Enabled,
|
||||||
|
CooldownMinutes: settings.CooldownMinutes,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateOverloadCooldownSettingsRequest 更新529过载冷却配置请求
|
||||||
|
type UpdateOverloadCooldownSettingsRequest struct {
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
CooldownMinutes int `json:"cooldown_minutes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateOverloadCooldownSettings 更新529过载冷却配置
|
||||||
|
// PUT /api/v1/admin/settings/overload-cooldown
|
||||||
|
func (h *SettingHandler) UpdateOverloadCooldownSettings(c *gin.Context) {
|
||||||
|
var req UpdateOverloadCooldownSettingsRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
settings := &service.OverloadCooldownSettings{
|
||||||
|
Enabled: req.Enabled,
|
||||||
|
CooldownMinutes: req.CooldownMinutes,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.settingService.SetOverloadCooldownSettings(c.Request.Context(), settings); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedSettings, err := h.settingService.GetOverloadCooldownSettings(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, dto.OverloadCooldownSettings{
|
||||||
|
Enabled: updatedSettings.Enabled,
|
||||||
|
CooldownMinutes: updatedSettings.CooldownMinutes,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// GetStreamTimeoutSettings 获取流超时处理配置
|
// GetStreamTimeoutSettings 获取流超时处理配置
|
||||||
// GET /api/v1/admin/settings/stream-timeout
|
// GET /api/v1/admin/settings/stream-timeout
|
||||||
func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {
|
func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {
|
||||||
@@ -1405,6 +1482,61 @@ func (h *SettingHandler) UpdateRectifierSettings(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetBetaPolicySettings 获取 Beta 策略配置
|
||||||
|
// GET /api/v1/admin/settings/beta-policy
|
||||||
|
func (h *SettingHandler) GetBetaPolicySettings(c *gin.Context) {
|
||||||
|
settings, err := h.settingService.GetBetaPolicySettings(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rules := make([]dto.BetaPolicyRule, len(settings.Rules))
|
||||||
|
for i, r := range settings.Rules {
|
||||||
|
rules[i] = dto.BetaPolicyRule(r)
|
||||||
|
}
|
||||||
|
response.Success(c, dto.BetaPolicySettings{Rules: rules})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateBetaPolicySettingsRequest 更新 Beta 策略配置请求
|
||||||
|
type UpdateBetaPolicySettingsRequest struct {
|
||||||
|
Rules []dto.BetaPolicyRule `json:"rules"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateBetaPolicySettings 更新 Beta 策略配置
|
||||||
|
// PUT /api/v1/admin/settings/beta-policy
|
||||||
|
func (h *SettingHandler) UpdateBetaPolicySettings(c *gin.Context) {
|
||||||
|
var req UpdateBetaPolicySettingsRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rules := make([]service.BetaPolicyRule, len(req.Rules))
|
||||||
|
for i, r := range req.Rules {
|
||||||
|
rules[i] = service.BetaPolicyRule(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
settings := &service.BetaPolicySettings{Rules: rules}
|
||||||
|
if err := h.settingService.SetBetaPolicySettings(c.Request.Context(), settings); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-fetch to return updated settings
|
||||||
|
updated, err := h.settingService.GetBetaPolicySettings(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
outRules := make([]dto.BetaPolicyRule, len(updated.Rules))
|
||||||
|
for i, r := range updated.Rules {
|
||||||
|
outRules[i] = dto.BetaPolicyRule(r)
|
||||||
|
}
|
||||||
|
response.Success(c, dto.BetaPolicySettings{Rules: outRules})
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateStreamTimeoutSettingsRequest 更新流超时配置请求
|
// UpdateStreamTimeoutSettingsRequest 更新流超时配置请求
|
||||||
type UpdateStreamTimeoutSettingsRequest struct {
|
type UpdateStreamTimeoutSettingsRequest struct {
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/sync/singleflight"
|
||||||
)
|
)
|
||||||
|
|
||||||
type snapshotCacheEntry struct {
|
type snapshotCacheEntry struct {
|
||||||
@@ -19,6 +21,12 @@ type snapshotCache struct {
|
|||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
ttl time.Duration
|
ttl time.Duration
|
||||||
items map[string]snapshotCacheEntry
|
items map[string]snapshotCacheEntry
|
||||||
|
sf singleflight.Group
|
||||||
|
}
|
||||||
|
|
||||||
|
type snapshotCacheLoadResult struct {
|
||||||
|
Entry snapshotCacheEntry
|
||||||
|
Hit bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSnapshotCache(ttl time.Duration) *snapshotCache {
|
func newSnapshotCache(ttl time.Duration) *snapshotCache {
|
||||||
@@ -70,6 +78,41 @@ func (c *snapshotCache) Set(key string, payload any) snapshotCacheEntry {
|
|||||||
return entry
|
return entry
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *snapshotCache) GetOrLoad(key string, load func() (any, error)) (snapshotCacheEntry, bool, error) {
|
||||||
|
if load == nil {
|
||||||
|
return snapshotCacheEntry{}, false, nil
|
||||||
|
}
|
||||||
|
if entry, ok := c.Get(key); ok {
|
||||||
|
return entry, true, nil
|
||||||
|
}
|
||||||
|
if c == nil || key == "" {
|
||||||
|
payload, err := load()
|
||||||
|
if err != nil {
|
||||||
|
return snapshotCacheEntry{}, false, err
|
||||||
|
}
|
||||||
|
return c.Set(key, payload), false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
value, err, _ := c.sf.Do(key, func() (any, error) {
|
||||||
|
if entry, ok := c.Get(key); ok {
|
||||||
|
return snapshotCacheLoadResult{Entry: entry, Hit: true}, nil
|
||||||
|
}
|
||||||
|
payload, err := load()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return snapshotCacheLoadResult{Entry: c.Set(key, payload), Hit: false}, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return snapshotCacheEntry{}, false, err
|
||||||
|
}
|
||||||
|
result, ok := value.(snapshotCacheLoadResult)
|
||||||
|
if !ok {
|
||||||
|
return snapshotCacheEntry{}, false, nil
|
||||||
|
}
|
||||||
|
return result.Entry, result.Hit, nil
|
||||||
|
}
|
||||||
|
|
||||||
func buildETagFromAny(payload any) string {
|
func buildETagFromAny(payload any) string {
|
||||||
raw, err := json.Marshal(payload)
|
raw, err := json.Marshal(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -3,6 +3,8 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -95,6 +97,61 @@ func TestBuildETagFromAny_UnmarshalablePayload(t *testing.T) {
|
|||||||
require.Empty(t, etag)
|
require.Empty(t, etag)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSnapshotCache_GetOrLoad_MissThenHit(t *testing.T) {
|
||||||
|
c := newSnapshotCache(5 * time.Second)
|
||||||
|
var loads atomic.Int32
|
||||||
|
|
||||||
|
entry, hit, err := c.GetOrLoad("key1", func() (any, error) {
|
||||||
|
loads.Add(1)
|
||||||
|
return map[string]string{"hello": "world"}, nil
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, hit)
|
||||||
|
require.NotEmpty(t, entry.ETag)
|
||||||
|
require.Equal(t, int32(1), loads.Load())
|
||||||
|
|
||||||
|
entry2, hit, err := c.GetOrLoad("key1", func() (any, error) {
|
||||||
|
loads.Add(1)
|
||||||
|
return map[string]string{"unexpected": "value"}, nil
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, hit)
|
||||||
|
require.Equal(t, entry.ETag, entry2.ETag)
|
||||||
|
require.Equal(t, int32(1), loads.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSnapshotCache_GetOrLoad_ConcurrentSingleflight(t *testing.T) {
|
||||||
|
c := newSnapshotCache(5 * time.Second)
|
||||||
|
var loads atomic.Int32
|
||||||
|
start := make(chan struct{})
|
||||||
|
const callers = 8
|
||||||
|
errCh := make(chan error, callers)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(callers)
|
||||||
|
for range callers {
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
<-start
|
||||||
|
_, _, err := c.GetOrLoad("shared", func() (any, error) {
|
||||||
|
loads.Add(1)
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
return "value", nil
|
||||||
|
})
|
||||||
|
errCh <- err
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
close(start)
|
||||||
|
wg.Wait()
|
||||||
|
close(errCh)
|
||||||
|
|
||||||
|
for err := range errCh {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, int32(1), loads.Load())
|
||||||
|
}
|
||||||
|
|
||||||
func TestParseBoolQueryWithDefault(t *testing.T) {
|
func TestParseBoolQueryWithDefault(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -77,12 +77,13 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
status := c.Query("status")
|
status := c.Query("status")
|
||||||
|
platform := c.Query("platform")
|
||||||
|
|
||||||
// Parse sorting parameters
|
// Parse sorting parameters
|
||||||
sortBy := c.DefaultQuery("sort_by", "created_at")
|
sortBy := c.DefaultQuery("sort_by", "created_at")
|
||||||
sortOrder := c.DefaultQuery("sort_order", "desc")
|
sortOrder := c.DefaultQuery("sort_order", "desc")
|
||||||
|
|
||||||
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder)
|
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, platform, sortBy, sortOrder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
@@ -216,6 +217,38 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ResetSubscriptionQuotaRequest represents the reset quota request
|
||||||
|
type ResetSubscriptionQuotaRequest struct {
|
||||||
|
Daily bool `json:"daily"`
|
||||||
|
Weekly bool `json:"weekly"`
|
||||||
|
Monthly bool `json:"monthly"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetQuota resets daily, weekly, and/or monthly usage for a subscription.
|
||||||
|
// POST /api/v1/admin/subscriptions/:id/reset-quota
|
||||||
|
func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
|
||||||
|
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid subscription ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var req ResetSubscriptionQuotaRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !req.Daily && !req.Weekly && !req.Monthly {
|
||||||
|
response.BadRequest(c, "At least one of 'daily', 'weekly', or 'monthly' must be true")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly, req.Monthly)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, dto.UserSubscriptionFromServiceAdmin(sub))
|
||||||
|
}
|
||||||
|
|
||||||
// Revoke handles revoking a subscription
|
// Revoke handles revoking a subscription
|
||||||
// DELETE /api/v1/admin/subscriptions/:id
|
// DELETE /api/v1/admin/subscriptions/:id
|
||||||
func (h *SubscriptionHandler) Revoke(c *gin.Context) {
|
func (h *SubscriptionHandler) Revoke(c *gin.Context) {
|
||||||
|
|||||||
@@ -159,8 +159,8 @@ func (h *UsageHandler) List(c *gin.Context) {
|
|||||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Set end time to end of day
|
// Use half-open range [start, end), move to next calendar day start (DST-safe).
|
||||||
t = t.Add(24*time.Hour - time.Nanosecond)
|
t = t.AddDate(0, 0, 1)
|
||||||
endTime = &t
|
endTime = &t
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -285,7 +285,8 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
|||||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
|
// 与 SQL 条件 created_at < end 对齐,使用次日 00:00 作为上边界(DST-safe)。
|
||||||
|
endTime = endTime.AddDate(0, 0, 1)
|
||||||
} else {
|
} else {
|
||||||
period := c.DefaultQuery("period", "today")
|
period := c.DefaultQuery("period", "today")
|
||||||
switch period {
|
switch period {
|
||||||
|
|||||||
@@ -194,6 +194,12 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Backend mode: only admin can login
|
||||||
|
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
|
||||||
|
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
h.respondWithTokenPair(c, user)
|
h.respondWithTokenPair(c, user)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -250,16 +256,22 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete the login session
|
// Get the user (before session deletion so we can check backend mode)
|
||||||
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
|
||||||
|
|
||||||
// Get the user
|
|
||||||
user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
|
user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Backend mode: only admin can login (check BEFORE deleting session)
|
||||||
|
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
|
||||||
|
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete the login session (only after all checks pass)
|
||||||
|
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
||||||
|
|
||||||
h.respondWithTokenPair(c, user)
|
h.respondWithTokenPair(c, user)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -447,9 +459,9 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
frontendBaseURL := strings.TrimSpace(h.cfg.Server.FrontendURL)
|
frontendBaseURL := strings.TrimSpace(h.settingSvc.GetFrontendURL(c.Request.Context()))
|
||||||
if frontendBaseURL == "" {
|
if frontendBaseURL == "" {
|
||||||
slog.Error("server.frontend_url not configured; cannot build password reset link")
|
slog.Error("frontend_url not configured in settings or config; cannot build password reset link")
|
||||||
response.InternalError(c, "Password reset is not configured")
|
response.InternalError(c, "Password reset is not configured")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -522,16 +534,22 @@ func (h *AuthHandler) RefreshToken(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenPair, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
|
result, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Backend mode: block non-admin token refresh
|
||||||
|
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && result.UserRole != "admin" {
|
||||||
|
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
response.Success(c, RefreshTokenResponse{
|
response.Success(c, RefreshTokenResponse{
|
||||||
AccessToken: tokenPair.AccessToken,
|
AccessToken: result.AccessToken,
|
||||||
RefreshToken: tokenPair.RefreshToken,
|
RefreshToken: result.RefreshToken,
|
||||||
ExpiresIn: tokenPair.ExpiresIn,
|
ExpiresIn: result.ExpiresIn,
|
||||||
TokenType: "Bearer",
|
TokenType: "Bearer",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -211,8 +211,22 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
|||||||
email = linuxDoSyntheticEmail(subject)
|
email = linuxDoSyntheticEmail(subject)
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username)
|
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
|
||||||
|
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, service.ErrOAuthInvitationRequired) {
|
||||||
|
pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username)
|
||||||
|
if tokenErr != nil {
|
||||||
|
redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fragment := url.Values{}
|
||||||
|
fragment.Set("error", "invitation_required")
|
||||||
|
fragment.Set("pending_oauth_token", pendingToken)
|
||||||
|
fragment.Set("redirect", redirectTo)
|
||||||
|
redirectWithFragment(c, frontendCallback, fragment)
|
||||||
|
return
|
||||||
|
}
|
||||||
// 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
|
// 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
|
||||||
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
|
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
|
||||||
return
|
return
|
||||||
@@ -227,6 +241,41 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
|||||||
redirectWithFragment(c, frontendCallback, fragment)
|
redirectWithFragment(c, frontendCallback, fragment)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type completeLinuxDoOAuthRequest struct {
|
||||||
|
PendingOAuthToken string `json:"pending_oauth_token" binding:"required"`
|
||||||
|
InvitationCode string `json:"invitation_code" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating
|
||||||
|
// the invitation code and creating the user account.
|
||||||
|
// POST /api/v1/auth/oauth/linuxdo/complete-registration
|
||||||
|
func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
|
||||||
|
var req completeLinuxDoOAuthRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"access_token": tokenPair.AccessToken,
|
||||||
|
"refresh_token": tokenPair.RefreshToken,
|
||||||
|
"expires_in": tokenPair.ExpiresIn,
|
||||||
|
"token_type": "Bearer",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (h *AuthHandler) getLinuxDoOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) {
|
func (h *AuthHandler) getLinuxDoOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) {
|
||||||
if h != nil && h.settingSvc != nil {
|
if h != nil && h.settingSvc != nil {
|
||||||
return h.settingSvc.GetLinuxDoConnectOAuthConfig(ctx)
|
return h.settingSvc.GetLinuxDoConnectOAuthConfig(ctx)
|
||||||
|
|||||||
@@ -135,14 +135,16 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
out := &AdminGroup{
|
out := &AdminGroup{
|
||||||
Group: groupFromServiceBase(g),
|
Group: groupFromServiceBase(g),
|
||||||
ModelRouting: g.ModelRouting,
|
ModelRouting: g.ModelRouting,
|
||||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||||
MCPXMLInject: g.MCPXMLInject,
|
MCPXMLInject: g.MCPXMLInject,
|
||||||
DefaultMappedModel: g.DefaultMappedModel,
|
DefaultMappedModel: g.DefaultMappedModel,
|
||||||
SupportedModelScopes: g.SupportedModelScopes,
|
SupportedModelScopes: g.SupportedModelScopes,
|
||||||
AccountCount: g.AccountCount,
|
AccountCount: g.AccountCount,
|
||||||
SortOrder: g.SortOrder,
|
ActiveAccountCount: g.ActiveAccountCount,
|
||||||
|
RateLimitedAccountCount: g.RateLimitedAccountCount,
|
||||||
|
SortOrder: g.SortOrder,
|
||||||
}
|
}
|
||||||
if len(g.AccountGroups) > 0 {
|
if len(g.AccountGroups) > 0 {
|
||||||
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
||||||
@@ -264,8 +266,8 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 提取 API Key 账号配额限制(仅 apikey 类型有效)
|
// 提取账号配额限制(apikey / bedrock 类型有效)
|
||||||
if a.Type == service.AccountTypeAPIKey {
|
if a.IsAPIKeyOrBedrock() {
|
||||||
if limit := a.GetQuotaLimit(); limit > 0 {
|
if limit := a.GetQuotaLimit(); limit > 0 {
|
||||||
out.QuotaLimit = &limit
|
out.QuotaLimit = &limit
|
||||||
used := a.GetQuotaUsed()
|
used := a.GetQuotaUsed()
|
||||||
@@ -281,6 +283,31 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
|||||||
used := a.GetQuotaWeeklyUsed()
|
used := a.GetQuotaWeeklyUsed()
|
||||||
out.QuotaWeeklyUsed = &used
|
out.QuotaWeeklyUsed = &used
|
||||||
}
|
}
|
||||||
|
// 固定时间重置配置
|
||||||
|
if mode := a.GetQuotaDailyResetMode(); mode == "fixed" {
|
||||||
|
out.QuotaDailyResetMode = &mode
|
||||||
|
hour := a.GetQuotaDailyResetHour()
|
||||||
|
out.QuotaDailyResetHour = &hour
|
||||||
|
}
|
||||||
|
if mode := a.GetQuotaWeeklyResetMode(); mode == "fixed" {
|
||||||
|
out.QuotaWeeklyResetMode = &mode
|
||||||
|
day := a.GetQuotaWeeklyResetDay()
|
||||||
|
out.QuotaWeeklyResetDay = &day
|
||||||
|
hour := a.GetQuotaWeeklyResetHour()
|
||||||
|
out.QuotaWeeklyResetHour = &hour
|
||||||
|
}
|
||||||
|
if a.GetQuotaDailyResetMode() == "fixed" || a.GetQuotaWeeklyResetMode() == "fixed" {
|
||||||
|
tz := a.GetQuotaResetTimezone()
|
||||||
|
out.QuotaResetTimezone = &tz
|
||||||
|
}
|
||||||
|
if a.Extra != nil {
|
||||||
|
if v, ok := a.Extra["quota_daily_reset_at"].(string); ok && v != "" {
|
||||||
|
out.QuotaDailyResetAt = &v
|
||||||
|
}
|
||||||
|
if v, ok := a.Extra["quota_weekly_reset_at"].(string); ok && v != "" {
|
||||||
|
out.QuotaWeeklyResetAt = &v
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return out
|
return out
|
||||||
@@ -496,8 +523,11 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
|||||||
AccountID: l.AccountID,
|
AccountID: l.AccountID,
|
||||||
RequestID: l.RequestID,
|
RequestID: l.RequestID,
|
||||||
Model: l.Model,
|
Model: l.Model,
|
||||||
|
UpstreamModel: l.UpstreamModel,
|
||||||
ServiceTier: l.ServiceTier,
|
ServiceTier: l.ServiceTier,
|
||||||
ReasoningEffort: l.ReasoningEffort,
|
ReasoningEffort: l.ReasoningEffort,
|
||||||
|
InboundEndpoint: l.InboundEndpoint,
|
||||||
|
UpstreamEndpoint: l.UpstreamEndpoint,
|
||||||
GroupID: l.GroupID,
|
GroupID: l.GroupID,
|
||||||
SubscriptionID: l.SubscriptionID,
|
SubscriptionID: l.SubscriptionID,
|
||||||
InputTokens: l.InputTokens,
|
InputTokens: l.InputTokens,
|
||||||
|
|||||||
@@ -76,10 +76,14 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
serviceTier := "priority"
|
serviceTier := "priority"
|
||||||
|
inboundEndpoint := "/v1/chat/completions"
|
||||||
|
upstreamEndpoint := "/v1/responses"
|
||||||
log := &service.UsageLog{
|
log := &service.UsageLog{
|
||||||
RequestID: "req_3",
|
RequestID: "req_3",
|
||||||
Model: "gpt-5.4",
|
Model: "gpt-5.4",
|
||||||
ServiceTier: &serviceTier,
|
ServiceTier: &serviceTier,
|
||||||
|
InboundEndpoint: &inboundEndpoint,
|
||||||
|
UpstreamEndpoint: &upstreamEndpoint,
|
||||||
AccountRateMultiplier: f64Ptr(1.5),
|
AccountRateMultiplier: f64Ptr(1.5),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -88,8 +92,16 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
|
|||||||
|
|
||||||
require.NotNil(t, userDTO.ServiceTier)
|
require.NotNil(t, userDTO.ServiceTier)
|
||||||
require.Equal(t, serviceTier, *userDTO.ServiceTier)
|
require.Equal(t, serviceTier, *userDTO.ServiceTier)
|
||||||
|
require.NotNil(t, userDTO.InboundEndpoint)
|
||||||
|
require.Equal(t, inboundEndpoint, *userDTO.InboundEndpoint)
|
||||||
|
require.NotNil(t, userDTO.UpstreamEndpoint)
|
||||||
|
require.Equal(t, upstreamEndpoint, *userDTO.UpstreamEndpoint)
|
||||||
require.NotNil(t, adminDTO.ServiceTier)
|
require.NotNil(t, adminDTO.ServiceTier)
|
||||||
require.Equal(t, serviceTier, *adminDTO.ServiceTier)
|
require.Equal(t, serviceTier, *adminDTO.ServiceTier)
|
||||||
|
require.NotNil(t, adminDTO.InboundEndpoint)
|
||||||
|
require.Equal(t, inboundEndpoint, *adminDTO.InboundEndpoint)
|
||||||
|
require.NotNil(t, adminDTO.UpstreamEndpoint)
|
||||||
|
require.Equal(t, upstreamEndpoint, *adminDTO.UpstreamEndpoint)
|
||||||
require.NotNil(t, adminDTO.AccountRateMultiplier)
|
require.NotNil(t, adminDTO.AccountRateMultiplier)
|
||||||
require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12)
|
require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ type SystemSettings struct {
|
|||||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||||
|
FrontendURL string `json:"frontend_url"`
|
||||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||||
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
||||||
@@ -81,6 +82,9 @@ type SystemSettings struct {
|
|||||||
|
|
||||||
// 分组隔离
|
// 分组隔离
|
||||||
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
||||||
|
|
||||||
|
// Backend Mode
|
||||||
|
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type DefaultSubscriptionSetting struct {
|
type DefaultSubscriptionSetting struct {
|
||||||
@@ -111,6 +115,7 @@ type PublicSettings struct {
|
|||||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||||
|
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||||
Version string `json:"version"`
|
Version string `json:"version"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -152,6 +157,12 @@ type ListSoraS3ProfilesResponse struct {
|
|||||||
Items []SoraS3Profile `json:"items"`
|
Items []SoraS3Profile `json:"items"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OverloadCooldownSettings 529过载冷却配置 DTO
|
||||||
|
type OverloadCooldownSettings struct {
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
CooldownMinutes int `json:"cooldown_minutes"`
|
||||||
|
}
|
||||||
|
|
||||||
// StreamTimeoutSettings 流超时处理配置 DTO
|
// StreamTimeoutSettings 流超时处理配置 DTO
|
||||||
type StreamTimeoutSettings struct {
|
type StreamTimeoutSettings struct {
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
@@ -168,6 +179,19 @@ type RectifierSettings struct {
|
|||||||
ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"`
|
ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BetaPolicyRule Beta 策略规则 DTO
|
||||||
|
type BetaPolicyRule struct {
|
||||||
|
BetaToken string `json:"beta_token"`
|
||||||
|
Action string `json:"action"`
|
||||||
|
Scope string `json:"scope"`
|
||||||
|
ErrorMessage string `json:"error_message,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BetaPolicySettings Beta 策略配置 DTO
|
||||||
|
type BetaPolicySettings struct {
|
||||||
|
Rules []BetaPolicyRule `json:"rules"`
|
||||||
|
}
|
||||||
|
|
||||||
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
|
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
|
||||||
// Returns empty slice on empty/invalid input.
|
// Returns empty slice on empty/invalid input.
|
||||||
func ParseCustomMenuItems(raw string) []CustomMenuItem {
|
func ParseCustomMenuItems(raw string) []CustomMenuItem {
|
||||||
|
|||||||
@@ -122,9 +122,11 @@ type AdminGroup struct {
|
|||||||
DefaultMappedModel string `json:"default_mapped_model"`
|
DefaultMappedModel string `json:"default_mapped_model"`
|
||||||
|
|
||||||
// 支持的模型系列(仅 antigravity 平台使用)
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||||
AccountCount int64 `json:"account_count,omitempty"`
|
AccountCount int64 `json:"account_count,omitempty"`
|
||||||
|
ActiveAccountCount int64 `json:"active_account_count,omitempty"`
|
||||||
|
RateLimitedAccountCount int64 `json:"rate_limited_account_count,omitempty"`
|
||||||
|
|
||||||
// 分组排序
|
// 分组排序
|
||||||
SortOrder int `json:"sort_order"`
|
SortOrder int `json:"sort_order"`
|
||||||
@@ -203,6 +205,16 @@ type Account struct {
|
|||||||
QuotaWeeklyLimit *float64 `json:"quota_weekly_limit,omitempty"`
|
QuotaWeeklyLimit *float64 `json:"quota_weekly_limit,omitempty"`
|
||||||
QuotaWeeklyUsed *float64 `json:"quota_weekly_used,omitempty"`
|
QuotaWeeklyUsed *float64 `json:"quota_weekly_used,omitempty"`
|
||||||
|
|
||||||
|
// 配额固定时间重置配置
|
||||||
|
QuotaDailyResetMode *string `json:"quota_daily_reset_mode,omitempty"`
|
||||||
|
QuotaDailyResetHour *int `json:"quota_daily_reset_hour,omitempty"`
|
||||||
|
QuotaWeeklyResetMode *string `json:"quota_weekly_reset_mode,omitempty"`
|
||||||
|
QuotaWeeklyResetDay *int `json:"quota_weekly_reset_day,omitempty"`
|
||||||
|
QuotaWeeklyResetHour *int `json:"quota_weekly_reset_hour,omitempty"`
|
||||||
|
QuotaResetTimezone *string `json:"quota_reset_timezone,omitempty"`
|
||||||
|
QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"`
|
||||||
|
QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"`
|
||||||
|
|
||||||
Proxy *Proxy `json:"proxy,omitempty"`
|
Proxy *Proxy `json:"proxy,omitempty"`
|
||||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||||
|
|
||||||
@@ -322,11 +334,18 @@ type UsageLog struct {
|
|||||||
AccountID int64 `json:"account_id"`
|
AccountID int64 `json:"account_id"`
|
||||||
RequestID string `json:"request_id"`
|
RequestID string `json:"request_id"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
|
// UpstreamModel is the actual model sent to the upstream provider after mapping.
|
||||||
|
// Omitted when no mapping was applied (requested model was used as-is).
|
||||||
|
UpstreamModel *string `json:"upstream_model,omitempty"`
|
||||||
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
|
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
|
||||||
ServiceTier *string `json:"service_tier,omitempty"`
|
ServiceTier *string `json:"service_tier,omitempty"`
|
||||||
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API).
|
// ReasoningEffort is the request's reasoning effort level.
|
||||||
// nil means not provided / not applicable.
|
// OpenAI: "low"/"medium"/"high"/"xhigh"; Claude: "low"/"medium"/"high"/"max".
|
||||||
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
||||||
|
// InboundEndpoint is the client-facing API endpoint path, e.g. /v1/chat/completions.
|
||||||
|
InboundEndpoint *string `json:"inbound_endpoint,omitempty"`
|
||||||
|
// UpstreamEndpoint is the normalized upstream endpoint path, e.g. /v1/responses.
|
||||||
|
UpstreamEndpoint *string `json:"upstream_endpoint,omitempty"`
|
||||||
|
|
||||||
GroupID *int64 `json:"group_id"`
|
GroupID *int64 `json:"group_id"`
|
||||||
SubscriptionID *int64 `json:"subscription_id"`
|
SubscriptionID *int64 `json:"subscription_id"`
|
||||||
|
|||||||
174
backend/internal/handler/endpoint.go
Normal file
174
backend/internal/handler/endpoint.go
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
// Canonical inbound / upstream endpoint paths.
|
||||||
|
// All normalization and derivation reference this single set
|
||||||
|
// of constants — add new paths HERE when a new API surface
|
||||||
|
// is introduced.
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
const (
|
||||||
|
EndpointMessages = "/v1/messages"
|
||||||
|
EndpointChatCompletions = "/v1/chat/completions"
|
||||||
|
EndpointResponses = "/v1/responses"
|
||||||
|
EndpointGeminiModels = "/v1beta/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
// gin.Context keys used by the middleware and helpers below.
|
||||||
|
const (
|
||||||
|
ctxKeyInboundEndpoint = "_gateway_inbound_endpoint"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
// Normalization functions
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// NormalizeInboundEndpoint maps a raw request path (which may carry
|
||||||
|
// prefixes like /antigravity, /openai, /sora) to its canonical form.
|
||||||
|
//
|
||||||
|
// "/antigravity/v1/messages" → "/v1/messages"
|
||||||
|
// "/v1/chat/completions" → "/v1/chat/completions"
|
||||||
|
// "/openai/v1/responses/foo" → "/v1/responses"
|
||||||
|
// "/v1beta/models/gemini:gen" → "/v1beta/models"
|
||||||
|
func NormalizeInboundEndpoint(path string) string {
|
||||||
|
path = strings.TrimSpace(path)
|
||||||
|
switch {
|
||||||
|
case strings.Contains(path, EndpointChatCompletions):
|
||||||
|
return EndpointChatCompletions
|
||||||
|
case strings.Contains(path, EndpointMessages):
|
||||||
|
return EndpointMessages
|
||||||
|
case strings.Contains(path, EndpointResponses):
|
||||||
|
return EndpointResponses
|
||||||
|
case strings.Contains(path, EndpointGeminiModels):
|
||||||
|
return EndpointGeminiModels
|
||||||
|
default:
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeriveUpstreamEndpoint determines the upstream endpoint from the
|
||||||
|
// account platform and the normalized inbound endpoint.
|
||||||
|
//
|
||||||
|
// Platform-specific rules:
|
||||||
|
// - OpenAI always forwards to /v1/responses (with optional subpath
|
||||||
|
// such as /v1/responses/compact preserved from the raw URL).
|
||||||
|
// - Anthropic → /v1/messages
|
||||||
|
// - Gemini → /v1beta/models
|
||||||
|
// - Sora → /v1/chat/completions
|
||||||
|
// - Antigravity routes may target either Claude or Gemini, so the
|
||||||
|
// inbound endpoint is used to distinguish.
|
||||||
|
func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string {
|
||||||
|
inbound = strings.TrimSpace(inbound)
|
||||||
|
|
||||||
|
switch platform {
|
||||||
|
case service.PlatformOpenAI:
|
||||||
|
// OpenAI forwards everything to the Responses API.
|
||||||
|
// Preserve subresource suffix (e.g. /v1/responses/compact).
|
||||||
|
if suffix := responsesSubpathSuffix(rawRequestPath); suffix != "" {
|
||||||
|
return EndpointResponses + suffix
|
||||||
|
}
|
||||||
|
return EndpointResponses
|
||||||
|
|
||||||
|
case service.PlatformAnthropic:
|
||||||
|
return EndpointMessages
|
||||||
|
|
||||||
|
case service.PlatformGemini:
|
||||||
|
return EndpointGeminiModels
|
||||||
|
|
||||||
|
case service.PlatformSora:
|
||||||
|
return EndpointChatCompletions
|
||||||
|
|
||||||
|
case service.PlatformAntigravity:
|
||||||
|
// Antigravity accounts serve both Claude and Gemini.
|
||||||
|
if inbound == EndpointGeminiModels {
|
||||||
|
return EndpointGeminiModels
|
||||||
|
}
|
||||||
|
return EndpointMessages
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unknown platform — fall back to inbound.
|
||||||
|
return inbound
|
||||||
|
}
|
||||||
|
|
||||||
|
// responsesSubpathSuffix extracts the part after "/responses" in a raw
|
||||||
|
// request path, e.g. "/openai/v1/responses/compact" → "/compact".
|
||||||
|
// Returns "" when there is no meaningful suffix.
|
||||||
|
func responsesSubpathSuffix(rawPath string) string {
|
||||||
|
trimmed := strings.TrimRight(strings.TrimSpace(rawPath), "/")
|
||||||
|
idx := strings.LastIndex(trimmed, "/responses")
|
||||||
|
if idx < 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
suffix := trimmed[idx+len("/responses"):]
|
||||||
|
if suffix == "" || suffix == "/" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(suffix, "/") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return suffix
|
||||||
|
}
|
||||||
|
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
// Middleware
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// InboundEndpointMiddleware normalizes the request path and stores the
|
||||||
|
// canonical inbound endpoint in gin.Context so that every handler in
|
||||||
|
// the chain can read it via GetInboundEndpoint.
|
||||||
|
//
|
||||||
|
// Apply this middleware to all gateway route groups.
|
||||||
|
func InboundEndpointMiddleware() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
path := c.FullPath()
|
||||||
|
if path == "" && c.Request != nil && c.Request.URL != nil {
|
||||||
|
path = c.Request.URL.Path
|
||||||
|
}
|
||||||
|
c.Set(ctxKeyInboundEndpoint, NormalizeInboundEndpoint(path))
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
// Context helpers — used by handlers before building
|
||||||
|
// RecordUsageInput / RecordUsageLongContextInput.
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// GetInboundEndpoint returns the canonical inbound endpoint stored by
|
||||||
|
// InboundEndpointMiddleware. If the middleware did not run (e.g. in
|
||||||
|
// tests), it falls back to normalizing c.FullPath() on the fly.
|
||||||
|
func GetInboundEndpoint(c *gin.Context) string {
|
||||||
|
if v, ok := c.Get(ctxKeyInboundEndpoint); ok {
|
||||||
|
if s, ok := v.(string); ok && s != "" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Fallback: normalize on the fly.
|
||||||
|
path := ""
|
||||||
|
if c != nil {
|
||||||
|
path = c.FullPath()
|
||||||
|
if path == "" && c.Request != nil && c.Request.URL != nil {
|
||||||
|
path = c.Request.URL.Path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return NormalizeInboundEndpoint(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUpstreamEndpoint derives the upstream endpoint from the context
|
||||||
|
// and the account platform. Handlers call this after scheduling an
|
||||||
|
// account, passing account.Platform.
|
||||||
|
func GetUpstreamEndpoint(c *gin.Context, platform string) string {
|
||||||
|
inbound := GetInboundEndpoint(c)
|
||||||
|
rawPath := ""
|
||||||
|
if c != nil && c.Request != nil && c.Request.URL != nil {
|
||||||
|
rawPath = c.Request.URL.Path
|
||||||
|
}
|
||||||
|
return DeriveUpstreamEndpoint(inbound, rawPath, platform)
|
||||||
|
}
|
||||||
159
backend/internal/handler/endpoint_test.go
Normal file
159
backend/internal/handler/endpoint_test.go
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() { gin.SetMode(gin.TestMode) }
|
||||||
|
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
// NormalizeInboundEndpoint
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestNormalizeInboundEndpoint(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
path string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// Direct canonical paths.
|
||||||
|
{"/v1/messages", EndpointMessages},
|
||||||
|
{"/v1/chat/completions", EndpointChatCompletions},
|
||||||
|
{"/v1/responses", EndpointResponses},
|
||||||
|
{"/v1beta/models", EndpointGeminiModels},
|
||||||
|
|
||||||
|
// Prefixed paths (antigravity, openai, sora).
|
||||||
|
{"/antigravity/v1/messages", EndpointMessages},
|
||||||
|
{"/openai/v1/responses", EndpointResponses},
|
||||||
|
{"/openai/v1/responses/compact", EndpointResponses},
|
||||||
|
{"/sora/v1/chat/completions", EndpointChatCompletions},
|
||||||
|
{"/antigravity/v1beta/models/gemini:generateContent", EndpointGeminiModels},
|
||||||
|
|
||||||
|
// Gin route patterns with wildcards.
|
||||||
|
{"/v1beta/models/*modelAction", EndpointGeminiModels},
|
||||||
|
{"/v1/responses/*subpath", EndpointResponses},
|
||||||
|
|
||||||
|
// Unknown path is returned as-is.
|
||||||
|
{"/v1/embeddings", "/v1/embeddings"},
|
||||||
|
{"", ""},
|
||||||
|
{" /v1/messages ", EndpointMessages},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.path, func(t *testing.T) {
|
||||||
|
require.Equal(t, tt.want, NormalizeInboundEndpoint(tt.path))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
// DeriveUpstreamEndpoint
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestDeriveUpstreamEndpoint(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
inbound string
|
||||||
|
rawPath string
|
||||||
|
platform string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// Anthropic.
|
||||||
|
{"anthropic messages", EndpointMessages, "/v1/messages", service.PlatformAnthropic, EndpointMessages},
|
||||||
|
|
||||||
|
// Gemini.
|
||||||
|
{"gemini models", EndpointGeminiModels, "/v1beta/models/gemini:gen", service.PlatformGemini, EndpointGeminiModels},
|
||||||
|
|
||||||
|
// Sora.
|
||||||
|
{"sora completions", EndpointChatCompletions, "/sora/v1/chat/completions", service.PlatformSora, EndpointChatCompletions},
|
||||||
|
|
||||||
|
// OpenAI — always /v1/responses.
|
||||||
|
{"openai responses root", EndpointResponses, "/v1/responses", service.PlatformOpenAI, EndpointResponses},
|
||||||
|
{"openai responses compact", EndpointResponses, "/openai/v1/responses/compact", service.PlatformOpenAI, "/v1/responses/compact"},
|
||||||
|
{"openai responses nested", EndpointResponses, "/openai/v1/responses/compact/detail", service.PlatformOpenAI, "/v1/responses/compact/detail"},
|
||||||
|
{"openai from messages", EndpointMessages, "/v1/messages", service.PlatformOpenAI, EndpointResponses},
|
||||||
|
{"openai from completions", EndpointChatCompletions, "/v1/chat/completions", service.PlatformOpenAI, EndpointResponses},
|
||||||
|
|
||||||
|
// Antigravity — uses inbound to pick Claude vs Gemini upstream.
|
||||||
|
{"antigravity claude", EndpointMessages, "/antigravity/v1/messages", service.PlatformAntigravity, EndpointMessages},
|
||||||
|
{"antigravity gemini", EndpointGeminiModels, "/antigravity/v1beta/models", service.PlatformAntigravity, EndpointGeminiModels},
|
||||||
|
|
||||||
|
// Unknown platform — passthrough.
|
||||||
|
{"unknown platform", "/v1/embeddings", "/v1/embeddings", "unknown", "/v1/embeddings"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
require.Equal(t, tt.want, DeriveUpstreamEndpoint(tt.inbound, tt.rawPath, tt.platform))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
// responsesSubpathSuffix
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestResponsesSubpathSuffix(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
raw string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"/v1/responses", ""},
|
||||||
|
{"/v1/responses/", ""},
|
||||||
|
{"/v1/responses/compact", "/compact"},
|
||||||
|
{"/openai/v1/responses/compact/detail", "/compact/detail"},
|
||||||
|
{"/v1/messages", ""},
|
||||||
|
{"", ""},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.raw, func(t *testing.T) {
|
||||||
|
require.Equal(t, tt.want, responsesSubpathSuffix(tt.raw))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
// InboundEndpointMiddleware + context helpers
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestInboundEndpointMiddleware(t *testing.T) {
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(InboundEndpointMiddleware())
|
||||||
|
|
||||||
|
var captured string
|
||||||
|
router.POST("/v1/messages", func(c *gin.Context) {
|
||||||
|
captured = GetInboundEndpoint(c)
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, EndpointMessages, captured)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetInboundEndpoint_FallbackWithoutMiddleware(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/antigravity/v1/messages", nil)
|
||||||
|
|
||||||
|
// Middleware did not run — fallback to normalizing c.Request.URL.Path.
|
||||||
|
got := GetInboundEndpoint(c)
|
||||||
|
require.Equal(t, EndpointMessages, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUpstreamEndpoint_FullFlow(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses/compact", nil)
|
||||||
|
|
||||||
|
// Simulate middleware.
|
||||||
|
c.Set(ctxKeyInboundEndpoint, NormalizeInboundEndpoint(c.Request.URL.Path))
|
||||||
|
|
||||||
|
got := GetUpstreamEndpoint(c, service.PlatformOpenAI)
|
||||||
|
require.Equal(t, "/v1/responses/compact", got)
|
||||||
|
}
|
||||||
@@ -391,6 +391,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
if fs.SwitchCount > 0 {
|
if fs.SwitchCount > 0 {
|
||||||
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||||
}
|
}
|
||||||
|
// 记录 Forward 前已写入字节数,Forward 后若增加则说明 SSE 内容已发,禁止 failover
|
||||||
|
writerSizeBeforeForward := c.Writer.Size()
|
||||||
if account.Platform == service.PlatformAntigravity {
|
if account.Platform == service.PlatformAntigravity {
|
||||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
|
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
|
||||||
} else {
|
} else {
|
||||||
@@ -402,6 +404,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
|
// 流式内容已写入客户端,无法撤销,禁止 failover 以防止流拼接腐化
|
||||||
|
if c.Writer.Size() != writerSizeBeforeForward {
|
||||||
|
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true)
|
||||||
|
return
|
||||||
|
}
|
||||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||||
switch action {
|
switch action {
|
||||||
case FailoverContinue:
|
case FailoverContinue:
|
||||||
@@ -434,19 +441,29 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
clientIP := ip.GetClientIP(c)
|
clientIP := ip.GetClientIP(c)
|
||||||
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||||
|
inboundEndpoint := GetInboundEndpoint(c)
|
||||||
|
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||||
|
|
||||||
|
if result.ReasoningEffort == nil {
|
||||||
|
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
|
||||||
|
}
|
||||||
|
|
||||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
APIKey: apiKey,
|
APIKey: apiKey,
|
||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: userAgent,
|
InboundEndpoint: inboundEndpoint,
|
||||||
IPAddress: clientIP,
|
UpstreamEndpoint: upstreamEndpoint,
|
||||||
ForceCacheBilling: fs.ForceCacheBilling,
|
UserAgent: userAgent,
|
||||||
APIKeyService: h.apiKeyService,
|
IPAddress: clientIP,
|
||||||
|
RequestPayloadHash: requestPayloadHash,
|
||||||
|
ForceCacheBilling: fs.ForceCacheBilling,
|
||||||
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
zap.String("component", "handler.gateway.messages"),
|
zap.String("component", "handler.gateway.messages"),
|
||||||
@@ -635,6 +652,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
if fs.SwitchCount > 0 {
|
if fs.SwitchCount > 0 {
|
||||||
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||||
}
|
}
|
||||||
|
// 记录 Forward 前已写入字节数,Forward 后若增加则说明 SSE 内容已发,禁止 failover
|
||||||
|
writerSizeBeforeForward := c.Writer.Size()
|
||||||
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||||
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
|
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
|
||||||
} else {
|
} else {
|
||||||
@@ -652,6 +671,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
accountReleaseFunc()
|
accountReleaseFunc()
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// Beta policy block: return 400 immediately, no failover
|
||||||
|
var betaBlockedErr *service.BetaBlockedError
|
||||||
|
if errors.As(err, &betaBlockedErr) {
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", betaBlockedErr.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var promptTooLongErr *service.PromptTooLongError
|
var promptTooLongErr *service.PromptTooLongError
|
||||||
if errors.As(err, &promptTooLongErr) {
|
if errors.As(err, &promptTooLongErr) {
|
||||||
reqLog.Warn("gateway.prompt_too_long_from_antigravity",
|
reqLog.Warn("gateway.prompt_too_long_from_antigravity",
|
||||||
@@ -697,6 +723,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
|
// 流式内容已写入客户端,无法撤销,禁止 failover 以防止流拼接腐化
|
||||||
|
if c.Writer.Size() != writerSizeBeforeForward {
|
||||||
|
h.handleFailoverExhausted(c, failoverErr, account.Platform, true)
|
||||||
|
return
|
||||||
|
}
|
||||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||||
switch action {
|
switch action {
|
||||||
case FailoverContinue:
|
case FailoverContinue:
|
||||||
@@ -729,19 +760,29 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
clientIP := ip.GetClientIP(c)
|
clientIP := ip.GetClientIP(c)
|
||||||
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||||
|
inboundEndpoint := GetInboundEndpoint(c)
|
||||||
|
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||||
|
|
||||||
|
if result.ReasoningEffort == nil {
|
||||||
|
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
|
||||||
|
}
|
||||||
|
|
||||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
APIKey: currentAPIKey,
|
APIKey: currentAPIKey,
|
||||||
User: currentAPIKey.User,
|
User: currentAPIKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: currentSubscription,
|
Subscription: currentSubscription,
|
||||||
UserAgent: userAgent,
|
InboundEndpoint: inboundEndpoint,
|
||||||
IPAddress: clientIP,
|
UpstreamEndpoint: upstreamEndpoint,
|
||||||
ForceCacheBilling: fs.ForceCacheBilling,
|
UserAgent: userAgent,
|
||||||
APIKeyService: h.apiKeyService,
|
IPAddress: clientIP,
|
||||||
|
RequestPayloadHash: requestPayloadHash,
|
||||||
|
ForceCacheBilling: fs.ForceCacheBilling,
|
||||||
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
zap.String("component", "handler.gateway.messages"),
|
zap.String("component", "handler.gateway.messages"),
|
||||||
@@ -902,7 +943,7 @@ func (h *GatewayHandler) parseUsageDateRange(c *gin.Context) (time.Time, time.Ti
|
|||||||
}
|
}
|
||||||
if s := c.Query("end_date"); s != "" {
|
if s := c.Query("end_date"); s != "" {
|
||||||
if t, err := timezone.ParseInLocation("2006-01-02", s); err == nil {
|
if t, err := timezone.ParseInLocation("2006-01-02", s); err == nil {
|
||||||
endTime = t.Add(24*time.Hour - time.Second) // end of day
|
endTime = t.AddDate(0, 0, 1) // half-open range upper bound
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return startTime, endTime
|
return startTime, endTime
|
||||||
@@ -1178,6 +1219,10 @@ func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *se
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 记录原始上游状态码,以便 ops 错误日志捕获真实的上游错误
|
||||||
|
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||||
|
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
|
||||||
|
|
||||||
// 使用默认的错误映射
|
// 使用默认的错误映射
|
||||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||||
@@ -1186,6 +1231,7 @@ func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *se
|
|||||||
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
|
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
|
||||||
func (h *GatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
|
func (h *GatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
|
||||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||||
|
service.SetOpsUpstreamError(c, statusCode, errMsg, "")
|
||||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
122
backend/internal/handler/gateway_handler_stream_failover_test.go
Normal file
122
backend/internal/handler/gateway_handler_stream_failover_test.go
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// partialMessageStartSSE 模拟 handleStreamingResponse 已写入的首批 SSE 事件。
|
||||||
|
const partialMessageStartSSE = "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_01\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-5\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":10,\"output_tokens\":1}}}\n\n" +
|
||||||
|
"event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"
|
||||||
|
|
||||||
|
// TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten 验证:
|
||||||
|
// 当 Forward 在返回 UpstreamFailoverError 前已向客户端写入 SSE 内容时,
|
||||||
|
// 故障转移保护逻辑必须终止循环并发送 SSE 错误事件,而不是进行下一次 Forward。
|
||||||
|
// 具体验证:
|
||||||
|
// 1. c.Writer.Size() 检测条件正确触发(字节数已增加)
|
||||||
|
// 2. handleFailoverExhausted 以 streamStarted=true 调用后,响应体以 SSE 错误事件结尾
|
||||||
|
// 3. 响应体中只出现一个 message_start,不存在第二个(防止流拼接腐化)
|
||||||
|
func TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
|
||||||
|
// 步骤 1:记录 Forward 前的 writer size(模拟 writerSizeBeforeForward := c.Writer.Size())
|
||||||
|
sizeBeforeForward := c.Writer.Size()
|
||||||
|
require.Equal(t, -1, sizeBeforeForward, "gin writer 初始 Size 应为 -1(未写入任何字节)")
|
||||||
|
|
||||||
|
// 步骤 2:模拟 Forward 已向客户端写入部分 SSE 内容(message_start + content_block_start)
|
||||||
|
_, err := c.Writer.Write([]byte(partialMessageStartSSE))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 步骤 3:验证守卫条件成立(c.Writer.Size() != sizeBeforeForward)
|
||||||
|
require.NotEqual(t, sizeBeforeForward, c.Writer.Size(),
|
||||||
|
"写入 SSE 内容后 writer size 必须增加,守卫条件应为 true")
|
||||||
|
|
||||||
|
// 步骤 4:模拟 UpstreamFailoverError(上游在流中途返回 403)
|
||||||
|
failoverErr := &service.UpstreamFailoverError{
|
||||||
|
StatusCode: http.StatusForbidden,
|
||||||
|
ResponseBody: []byte(`{"error":{"type":"permission_error","message":"forbidden"}}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 步骤 5:守卫触发 → 调用 handleFailoverExhausted,streamStarted=true
|
||||||
|
h := &GatewayHandler{}
|
||||||
|
h.handleFailoverExhausted(c, failoverErr, service.PlatformAnthropic, true)
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
|
||||||
|
// 断言 A:响应体中包含最初写入的 message_start SSE 事件行
|
||||||
|
require.Contains(t, body, "event: message_start", "响应体应包含已写入的 message_start SSE 事件")
|
||||||
|
|
||||||
|
// 断言 B:响应体以 SSE 错误事件结尾(data: {"type":"error",...}\n\n)
|
||||||
|
require.True(t, strings.HasSuffix(strings.TrimRight(body, "\n"), "}"),
|
||||||
|
"响应体应以 JSON 对象结尾(SSE error event 的 data 字段)")
|
||||||
|
require.Contains(t, body, `"type":"error"`, "响应体末尾必须包含 SSE 错误事件")
|
||||||
|
|
||||||
|
// 断言 C:SSE event 行 "event: message_start" 只出现一次(防止双 message_start 拼接腐化)
|
||||||
|
firstIdx := strings.Index(body, "event: message_start")
|
||||||
|
lastIdx := strings.LastIndex(body, "event: message_start")
|
||||||
|
assert.Equal(t, firstIdx, lastIdx,
|
||||||
|
"响应体中 'event: message_start' 必须只出现一次,不得因 failover 拼接导致两次")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten 与上述测试相同,
|
||||||
|
// 验证 Gemini 路径使用 service.PlatformGemini(而非 account.Platform)时行为一致。
|
||||||
|
func TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.0-flash:streamGenerateContent", nil)
|
||||||
|
|
||||||
|
sizeBeforeForward := c.Writer.Size()
|
||||||
|
|
||||||
|
_, err := c.Writer.Write([]byte(partialMessageStartSSE))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NotEqual(t, sizeBeforeForward, c.Writer.Size())
|
||||||
|
|
||||||
|
failoverErr := &service.UpstreamFailoverError{
|
||||||
|
StatusCode: http.StatusForbidden,
|
||||||
|
}
|
||||||
|
|
||||||
|
h := &GatewayHandler{}
|
||||||
|
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true)
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
|
||||||
|
require.Contains(t, body, "event: message_start")
|
||||||
|
require.Contains(t, body, `"type":"error"`)
|
||||||
|
|
||||||
|
firstIdx := strings.Index(body, "event: message_start")
|
||||||
|
lastIdx := strings.LastIndex(body, "event: message_start")
|
||||||
|
assert.Equal(t, firstIdx, lastIdx, "Gemini 路径不得出现双 message_start")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered 验证反向场景:
|
||||||
|
// 当 Forward 返回 UpstreamFailoverError 时若未向客户端写入任何 SSE 内容,
|
||||||
|
// 守卫条件(c.Writer.Size() != sizeBeforeForward)为 false,不应中止 failover。
|
||||||
|
func TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
|
||||||
|
// 模拟 writerSizeBeforeForward:初始为 -1
|
||||||
|
sizeBeforeForward := c.Writer.Size()
|
||||||
|
|
||||||
|
// Forward 未写入任何字节直接返回错误(例如 401 发生在连接建立前)
|
||||||
|
// c.Writer.Size() 仍为 -1
|
||||||
|
|
||||||
|
// 守卫条件:sizeBeforeForward == c.Writer.Size() → 不触发
|
||||||
|
guardTriggered := c.Writer.Size() != sizeBeforeForward
|
||||||
|
require.False(t, guardTriggered,
|
||||||
|
"未写入任何字节时,守卫条件必须为 false,应允许正常 failover 继续")
|
||||||
|
}
|
||||||
@@ -76,7 +76,7 @@ func (f *fakeGroupRepo) ListActiveByPlatform(context.Context, string) ([]service
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil }
|
func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil }
|
||||||
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, error) { return 0, nil }
|
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, int64, error) { return 0, 0, nil }
|
||||||
func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
|
func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
@@ -127,6 +127,7 @@ func (f *fakeConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, acc
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil }
|
func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil }
|
||||||
|
func (f *fakeConcurrencyCache) CleanupStaleProcessSlots(context.Context, string) error { return nil }
|
||||||
|
|
||||||
func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) {
|
func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
@@ -138,6 +139,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
|
|||||||
nil, // accountRepo (not used: scheduler snapshot hit)
|
nil, // accountRepo (not used: scheduler snapshot hit)
|
||||||
&fakeGroupRepo{group: group},
|
&fakeGroupRepo{group: group},
|
||||||
nil, // usageLogRepo
|
nil, // usageLogRepo
|
||||||
|
nil, // usageBillingRepo
|
||||||
nil, // userRepo
|
nil, // userRepo
|
||||||
nil, // userSubRepo
|
nil, // userSubRepo
|
||||||
nil, // userGroupRateRepo
|
nil, // userGroupRateRepo
|
||||||
|
|||||||
@@ -89,6 +89,10 @@ func (m *concurrencyCacheMock) CleanupExpiredAccountSlots(ctx context.Context, a
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *concurrencyCacheMock) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) {
|
func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) {
|
||||||
cache := &concurrencyCacheMock{
|
cache := &concurrencyCacheMock{
|
||||||
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||||
|
|||||||
@@ -120,6 +120,10 @@ func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Cont
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *helperConcurrencyCacheStub) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) {
|
func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
@@ -132,7 +136,7 @@ func validClaudeCodeBodyJSON() []byte {
|
|||||||
return []byte(`{
|
return []byte(`{
|
||||||
"model":"claude-3-5-sonnet-20241022",
|
"model":"claude-3-5-sonnet-20241022",
|
||||||
"system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}],
|
"system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}],
|
||||||
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"}
|
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}
|
||||||
}`)
|
}`)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -186,7 +190,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing
|
|||||||
System: []any{
|
System: []any{
|
||||||
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
||||||
},
|
},
|
||||||
MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123",
|
MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
|
||||||
}
|
}
|
||||||
|
|
||||||
// body 非法 JSON,如果函数复用 parsedReq 成功则仍应判定为 Claude Code。
|
// body 非法 JSON,如果函数复用 parsedReq 成功则仍应判定为 Claude Code。
|
||||||
@@ -205,7 +209,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing
|
|||||||
"system": []any{
|
"system": []any{
|
||||||
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
||||||
},
|
},
|
||||||
"metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"},
|
"metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"},
|
||||||
})
|
})
|
||||||
|
|
||||||
SetClaudeCodeClientContext(c, []byte(`{invalid`), nil)
|
SetClaudeCodeClientContext(c, []byte(`{invalid`), nil)
|
||||||
|
|||||||
@@ -503,6 +503,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||||
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||||
|
inboundEndpoint := GetInboundEndpoint(c)
|
||||||
|
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
@@ -510,8 +513,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
|
InboundEndpoint: inboundEndpoint,
|
||||||
|
UpstreamEndpoint: upstreamEndpoint,
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
|
RequestPayloadHash: requestPayloadHash,
|
||||||
LongContextThreshold: 200000, // Gemini 200K 阈值
|
LongContextThreshold: 200000, // Gemini 200K 阈值
|
||||||
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
||||||
ForceCacheBilling: fs.ForceCacheBilling,
|
ForceCacheBilling: fs.ForceCacheBilling,
|
||||||
@@ -587,6 +593,10 @@ func (h *GatewayHandler) handleGeminiFailoverExhausted(c *gin.Context, failoverE
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 记录原始上游状态码,以便 ops 错误日志捕获真实的上游错误
|
||||||
|
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||||
|
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
|
||||||
|
|
||||||
// 使用默认的错误映射
|
// 使用默认的错误映射
|
||||||
status, message := mapGeminiUpstreamError(statusCode)
|
status, message := mapGeminiUpstreamError(statusCode)
|
||||||
googleError(c, status, message)
|
googleError(c, status, message)
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ type AdminHandlers struct {
|
|||||||
Account *admin.AccountHandler
|
Account *admin.AccountHandler
|
||||||
Announcement *admin.AnnouncementHandler
|
Announcement *admin.AnnouncementHandler
|
||||||
DataManagement *admin.DataManagementHandler
|
DataManagement *admin.DataManagementHandler
|
||||||
|
Backup *admin.BackupHandler
|
||||||
OAuth *admin.OAuthHandler
|
OAuth *admin.OAuthHandler
|
||||||
OpenAIOAuth *admin.OpenAIOAuthHandler
|
OpenAIOAuth *admin.OpenAIOAuthHandler
|
||||||
GeminiOAuth *admin.GeminiOAuthHandler
|
GeminiOAuth *admin.GeminiOAuthHandler
|
||||||
|
|||||||
286
backend/internal/handler/openai_chat_completions.go
Normal file
286
backend/internal/handler/openai_chat_completions.go
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ChatCompletions handles OpenAI Chat Completions API requests.
|
||||||
|
// POST /v1/chat/completions
|
||||||
|
func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||||
|
streamStarted := false
|
||||||
|
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||||
|
|
||||||
|
requestStart := time.Now()
|
||||||
|
|
||||||
|
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||||
|
if !ok {
|
||||||
|
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||||
|
if !ok {
|
||||||
|
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
reqLog := requestLogger(
|
||||||
|
c,
|
||||||
|
"handler.openai_gateway.chat_completions",
|
||||||
|
zap.Int64("user_id", subject.UserID),
|
||||||
|
zap.Int64("api_key_id", apiKey.ID),
|
||||||
|
zap.Any("group_id", apiKey.GroupID),
|
||||||
|
)
|
||||||
|
|
||||||
|
if !h.ensureResponsesDependencies(c, reqLog) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||||
|
if err != nil {
|
||||||
|
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||||
|
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(body) == 0 {
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !gjson.ValidBytes(body) {
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
modelResult := gjson.GetBytes(body, "model")
|
||||||
|
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
reqModel := modelResult.String()
|
||||||
|
reqStream := gjson.GetBytes(body, "stream").Bool()
|
||||||
|
|
||||||
|
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||||
|
|
||||||
|
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||||
|
|
||||||
|
if h.errorPassthroughService != nil {
|
||||||
|
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||||
|
}
|
||||||
|
|
||||||
|
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||||
|
|
||||||
|
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
||||||
|
routingStart := time.Now()
|
||||||
|
|
||||||
|
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
|
||||||
|
if !acquired {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if userReleaseFunc != nil {
|
||||||
|
defer userReleaseFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||||
|
reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err))
|
||||||
|
status, code, message := billingErrorDetails(err)
|
||||||
|
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
|
||||||
|
promptCacheKey := h.gatewayService.ExtractSessionID(c, body)
|
||||||
|
|
||||||
|
maxAccountSwitches := h.maxAccountSwitches
|
||||||
|
switchCount := 0
|
||||||
|
failedAccountIDs := make(map[int64]struct{})
|
||||||
|
sameAccountRetryCount := make(map[int64]int)
|
||||||
|
var lastFailoverErr *service.UpstreamFailoverError
|
||||||
|
|
||||||
|
for {
|
||||||
|
c.Set("openai_chat_completions_fallback_model", "")
|
||||||
|
reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||||
|
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||||
|
c.Request.Context(),
|
||||||
|
apiKey.GroupID,
|
||||||
|
"",
|
||||||
|
sessionHash,
|
||||||
|
reqModel,
|
||||||
|
failedAccountIDs,
|
||||||
|
service.OpenAIUpstreamTransportAny,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
reqLog.Warn("openai_chat_completions.account_select_failed",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||||
|
)
|
||||||
|
if len(failedAccountIDs) == 0 {
|
||||||
|
defaultModel := ""
|
||||||
|
if apiKey.Group != nil {
|
||||||
|
defaultModel = apiKey.Group.DefaultMappedModel
|
||||||
|
}
|
||||||
|
if defaultModel != "" && defaultModel != reqModel {
|
||||||
|
reqLog.Info("openai_chat_completions.fallback_to_default_model",
|
||||||
|
zap.String("default_mapped_model", defaultModel),
|
||||||
|
)
|
||||||
|
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
|
||||||
|
c.Request.Context(),
|
||||||
|
apiKey.GroupID,
|
||||||
|
"",
|
||||||
|
sessionHash,
|
||||||
|
defaultModel,
|
||||||
|
failedAccountIDs,
|
||||||
|
service.OpenAIUpstreamTransportAny,
|
||||||
|
)
|
||||||
|
if err == nil && selection != nil {
|
||||||
|
c.Set("openai_chat_completions_fallback_model", defaultModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if lastFailoverErr != nil {
|
||||||
|
h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
|
||||||
|
} else {
|
||||||
|
h.handleStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if selection == nil || selection.Account == nil {
|
||||||
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
account := selection.Account
|
||||||
|
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
|
||||||
|
reqLog.Debug("openai_chat_completions.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||||
|
_ = scheduleDecision
|
||||||
|
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||||
|
|
||||||
|
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
|
||||||
|
if !acquired {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||||
|
forwardStart := time.Now()
|
||||||
|
|
||||||
|
defaultMappedModel := c.GetString("openai_chat_completions_fallback_model")
|
||||||
|
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||||
|
|
||||||
|
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||||
|
if accountReleaseFunc != nil {
|
||||||
|
accountReleaseFunc()
|
||||||
|
}
|
||||||
|
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||||
|
responseLatencyMs := forwardDurationMs
|
||||||
|
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||||
|
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
|
||||||
|
}
|
||||||
|
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
|
||||||
|
if err == nil && result != nil && result.FirstTokenMs != nil {
|
||||||
|
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
var failoverErr *service.UpstreamFailoverError
|
||||||
|
if errors.As(err, &failoverErr) {
|
||||||
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||||
|
// Pool mode: retry on the same account
|
||||||
|
if failoverErr.RetryableOnSameAccount {
|
||||||
|
retryLimit := account.GetPoolModeRetryCount()
|
||||||
|
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||||
|
sameAccountRetryCount[account.ID]++
|
||||||
|
reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||||
|
zap.Int("retry_limit", retryLimit),
|
||||||
|
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||||
|
)
|
||||||
|
select {
|
||||||
|
case <-c.Request.Context().Done():
|
||||||
|
return
|
||||||
|
case <-time.After(sameAccountRetryDelay):
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||||
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
|
lastFailoverErr = failoverErr
|
||||||
|
if switchCount >= maxAccountSwitches {
|
||||||
|
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switchCount++
|
||||||
|
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||||
|
zap.Int("switch_count", switchCount),
|
||||||
|
zap.Int("max_switches", maxAccountSwitches),
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||||
|
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||||
|
reqLog.Warn("openai_chat_completions.forward_failed",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if result != nil {
|
||||||
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||||
|
} else {
|
||||||
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
userAgent := c.GetHeader("User-Agent")
|
||||||
|
clientIP := ip.GetClientIP(c)
|
||||||
|
|
||||||
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
|
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||||
|
Result: result,
|
||||||
|
APIKey: apiKey,
|
||||||
|
User: apiKey.User,
|
||||||
|
Account: account,
|
||||||
|
Subscription: subscription,
|
||||||
|
InboundEndpoint: GetInboundEndpoint(c),
|
||||||
|
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||||
|
UserAgent: userAgent,
|
||||||
|
IPAddress: clientIP,
|
||||||
|
APIKeyService: h.apiKeyService,
|
||||||
|
}); err != nil {
|
||||||
|
logger.L().With(
|
||||||
|
zap.String("component", "handler.openai_gateway.chat_completions"),
|
||||||
|
zap.Int64("user_id", subject.UserID),
|
||||||
|
zap.Int64("api_key_id", apiKey.ID),
|
||||||
|
zap.Any("group_id", apiKey.GroupID),
|
||||||
|
zap.String("model", reqModel),
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
).Error("openai_chat_completions.record_usage_failed", zap.Error(err))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
reqLog.Debug("openai_chat_completions.request_completed",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Int("switch_count", switchCount),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,56 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestOpenAIUpstreamEndpoint_ViaGetUpstreamEndpoint verifies that the
|
||||||
|
// unified GetUpstreamEndpoint helper produces the same results as the
|
||||||
|
// former normalizedOpenAIUpstreamEndpoint for OpenAI platform requests.
|
||||||
|
func TestOpenAIUpstreamEndpoint_ViaGetUpstreamEndpoint(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
path string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "responses root maps to responses upstream",
|
||||||
|
path: "/v1/responses",
|
||||||
|
want: EndpointResponses,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "responses compact keeps compact suffix",
|
||||||
|
path: "/openai/v1/responses/compact",
|
||||||
|
want: "/v1/responses/compact",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "responses nested suffix preserved",
|
||||||
|
path: "/openai/v1/responses/compact/detail",
|
||||||
|
want: "/v1/responses/compact/detail",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non responses path uses platform fallback",
|
||||||
|
path: "/v1/messages",
|
||||||
|
want: EndpointResponses,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, tt.path, nil)
|
||||||
|
|
||||||
|
got := GetUpstreamEndpoint(c, service.PlatformOpenAI)
|
||||||
|
require.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -352,18 +352,22 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
clientIP := ip.GetClientIP(c)
|
clientIP := ip.GetClientIP(c)
|
||||||
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||||
|
|
||||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
APIKey: apiKey,
|
APIKey: apiKey,
|
||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: userAgent,
|
InboundEndpoint: GetInboundEndpoint(c),
|
||||||
IPAddress: clientIP,
|
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||||
APIKeyService: h.apiKeyService,
|
UserAgent: userAgent,
|
||||||
|
IPAddress: clientIP,
|
||||||
|
RequestPayloadHash: requestPayloadHash,
|
||||||
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
zap.String("component", "handler.openai_gateway.responses"),
|
zap.String("component", "handler.openai_gateway.responses"),
|
||||||
@@ -653,14 +657,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
|||||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||||
forwardStart := time.Now()
|
forwardStart := time.Now()
|
||||||
|
|
||||||
defaultMappedModel := ""
|
// 仅在调度时实际触发了降级(原模型无可用账号、改用默认模型重试成功)时,
|
||||||
if apiKey.Group != nil {
|
// 才将降级模型传给 Forward 层做模型替换;否则保持用户请求的原始模型。
|
||||||
defaultMappedModel = apiKey.Group.DefaultMappedModel
|
defaultMappedModel := c.GetString("openai_messages_fallback_model")
|
||||||
}
|
|
||||||
// 如果使用了降级模型调度,强制使用降级模型
|
|
||||||
if fallbackModel := c.GetString("openai_messages_fallback_model"); fallbackModel != "" {
|
|
||||||
defaultMappedModel = fallbackModel
|
|
||||||
}
|
|
||||||
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||||
|
|
||||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||||
@@ -732,17 +731,21 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
|||||||
|
|
||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
clientIP := ip.GetClientIP(c)
|
clientIP := ip.GetClientIP(c)
|
||||||
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||||
|
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
APIKey: apiKey,
|
APIKey: apiKey,
|
||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: userAgent,
|
InboundEndpoint: GetInboundEndpoint(c),
|
||||||
IPAddress: clientIP,
|
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||||
APIKeyService: h.apiKeyService,
|
UserAgent: userAgent,
|
||||||
|
IPAddress: clientIP,
|
||||||
|
RequestPayloadHash: requestPayloadHash,
|
||||||
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
zap.String("component", "handler.openai_gateway.messages"),
|
zap.String("component", "handler.openai_gateway.messages"),
|
||||||
@@ -1231,14 +1234,17 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
|||||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||||
h.submitUsageRecordTask(func(taskCtx context.Context) {
|
h.submitUsageRecordTask(func(taskCtx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
APIKey: apiKey,
|
APIKey: apiKey,
|
||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: userAgent,
|
InboundEndpoint: GetInboundEndpoint(c),
|
||||||
IPAddress: clientIP,
|
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||||
APIKeyService: h.apiKeyService,
|
UserAgent: userAgent,
|
||||||
|
IPAddress: clientIP,
|
||||||
|
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
|
||||||
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
reqLog.Error("openai.websocket_record_usage_failed",
|
reqLog.Error("openai.websocket_record_usage_failed",
|
||||||
zap.Int64("account_id", account.ID),
|
zap.Int64("account_id", account.ID),
|
||||||
@@ -1429,6 +1435,10 @@ func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverE
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 记录原始上游状态码,以便 ops 错误日志捕获真实的上游错误
|
||||||
|
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||||
|
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
|
||||||
|
|
||||||
// 使用默认的错误映射
|
// 使用默认的错误映射
|
||||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||||
@@ -1437,6 +1447,7 @@ func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverE
|
|||||||
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
|
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
|
||||||
func (h *OpenAIGatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
|
func (h *OpenAIGatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
|
||||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||||
|
service.SetOpsUpstreamError(c, statusCode, errMsg, "")
|
||||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -26,11 +26,28 @@ const (
|
|||||||
opsStreamKey = "ops_stream"
|
opsStreamKey = "ops_stream"
|
||||||
opsRequestBodyKey = "ops_request_body"
|
opsRequestBodyKey = "ops_request_body"
|
||||||
opsAccountIDKey = "ops_account_id"
|
opsAccountIDKey = "ops_account_id"
|
||||||
|
|
||||||
|
// 错误过滤匹配常量 — shouldSkipOpsErrorLog 和错误分类共用
|
||||||
|
opsErrContextCanceled = "context canceled"
|
||||||
|
opsErrNoAvailableAccounts = "no available accounts"
|
||||||
|
opsErrInvalidAPIKey = "invalid_api_key"
|
||||||
|
opsErrAPIKeyRequired = "api_key_required"
|
||||||
|
opsErrInsufficientBalance = "insufficient balance"
|
||||||
|
opsErrInsufficientAccountBalance = "insufficient account balance"
|
||||||
|
opsErrInsufficientQuota = "insufficient_quota"
|
||||||
|
|
||||||
|
// 上游错误码常量 — 错误分类 (normalizeOpsErrorType / classifyOpsPhase / classifyOpsIsBusinessLimited)
|
||||||
|
opsCodeInsufficientBalance = "INSUFFICIENT_BALANCE"
|
||||||
|
opsCodeUsageLimitExceeded = "USAGE_LIMIT_EXCEEDED"
|
||||||
|
opsCodeSubscriptionNotFound = "SUBSCRIPTION_NOT_FOUND"
|
||||||
|
opsCodeSubscriptionInvalid = "SUBSCRIPTION_INVALID"
|
||||||
|
opsCodeUserInactive = "USER_INACTIVE"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
opsErrorLogTimeout = 5 * time.Second
|
opsErrorLogTimeout = 5 * time.Second
|
||||||
opsErrorLogDrainTimeout = 10 * time.Second
|
opsErrorLogDrainTimeout = 10 * time.Second
|
||||||
|
opsErrorLogBatchWindow = 200 * time.Millisecond
|
||||||
|
|
||||||
opsErrorLogMinWorkerCount = 4
|
opsErrorLogMinWorkerCount = 4
|
||||||
opsErrorLogMaxWorkerCount = 32
|
opsErrorLogMaxWorkerCount = 32
|
||||||
@@ -38,6 +55,7 @@ const (
|
|||||||
opsErrorLogQueueSizePerWorker = 128
|
opsErrorLogQueueSizePerWorker = 128
|
||||||
opsErrorLogMinQueueSize = 256
|
opsErrorLogMinQueueSize = 256
|
||||||
opsErrorLogMaxQueueSize = 8192
|
opsErrorLogMaxQueueSize = 8192
|
||||||
|
opsErrorLogBatchSize = 32
|
||||||
)
|
)
|
||||||
|
|
||||||
type opsErrorLogJob struct {
|
type opsErrorLogJob struct {
|
||||||
@@ -82,27 +100,82 @@ func startOpsErrorLogWorkers() {
|
|||||||
for i := 0; i < workerCount; i++ {
|
for i := 0; i < workerCount; i++ {
|
||||||
go func() {
|
go func() {
|
||||||
defer opsErrorLogWorkersWg.Done()
|
defer opsErrorLogWorkersWg.Done()
|
||||||
for job := range opsErrorLogQueue {
|
for {
|
||||||
opsErrorLogQueueLen.Add(-1)
|
job, ok := <-opsErrorLogQueue
|
||||||
if job.ops == nil || job.entry == nil {
|
if !ok {
|
||||||
continue
|
return
|
||||||
}
|
}
|
||||||
func() {
|
opsErrorLogQueueLen.Add(-1)
|
||||||
defer func() {
|
batch := make([]opsErrorLogJob, 0, opsErrorLogBatchSize)
|
||||||
if r := recover(); r != nil {
|
batch = append(batch, job)
|
||||||
log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack())
|
|
||||||
|
timer := time.NewTimer(opsErrorLogBatchWindow)
|
||||||
|
batchLoop:
|
||||||
|
for len(batch) < opsErrorLogBatchSize {
|
||||||
|
select {
|
||||||
|
case nextJob, ok := <-opsErrorLogQueue:
|
||||||
|
if !ok {
|
||||||
|
if !timer.Stop() {
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
flushOpsErrorLogBatch(batch)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}()
|
opsErrorLogQueueLen.Add(-1)
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
|
batch = append(batch, nextJob)
|
||||||
_ = job.ops.RecordError(ctx, job.entry, nil)
|
case <-timer.C:
|
||||||
cancel()
|
break batchLoop
|
||||||
opsErrorLogProcessed.Add(1)
|
}
|
||||||
}()
|
}
|
||||||
|
if !timer.Stop() {
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
flushOpsErrorLogBatch(batch)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func flushOpsErrorLogBatch(batch []opsErrorLogJob) {
|
||||||
|
if len(batch) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
grouped := make(map[*service.OpsService][]*service.OpsInsertErrorLogInput, len(batch))
|
||||||
|
var processed int64
|
||||||
|
for _, job := range batch {
|
||||||
|
if job.ops == nil || job.entry == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
grouped[job.ops] = append(grouped[job.ops], job.entry)
|
||||||
|
processed++
|
||||||
|
}
|
||||||
|
if processed == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for opsSvc, entries := range grouped {
|
||||||
|
if opsSvc == nil || len(entries) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
|
||||||
|
_ = opsSvc.RecordErrorBatch(ctx, entries)
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
opsErrorLogProcessed.Add(processed)
|
||||||
|
}
|
||||||
|
|
||||||
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput) {
|
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput) {
|
||||||
if ops == nil || entry == nil {
|
if ops == nil || entry == nil {
|
||||||
return
|
return
|
||||||
@@ -967,9 +1040,9 @@ func normalizeOpsErrorType(errType string, code string) string {
|
|||||||
return errType
|
return errType
|
||||||
}
|
}
|
||||||
switch strings.TrimSpace(code) {
|
switch strings.TrimSpace(code) {
|
||||||
case "INSUFFICIENT_BALANCE":
|
case opsCodeInsufficientBalance:
|
||||||
return "billing_error"
|
return "billing_error"
|
||||||
case "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
|
case opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
|
||||||
return "subscription_error"
|
return "subscription_error"
|
||||||
default:
|
default:
|
||||||
return "api_error"
|
return "api_error"
|
||||||
@@ -981,7 +1054,7 @@ func classifyOpsPhase(errType, message, code string) string {
|
|||||||
// Standardized phases: request|auth|routing|upstream|network|internal
|
// Standardized phases: request|auth|routing|upstream|network|internal
|
||||||
// Map billing/concurrency/response => request; scheduling => routing.
|
// Map billing/concurrency/response => request; scheduling => routing.
|
||||||
switch strings.TrimSpace(code) {
|
switch strings.TrimSpace(code) {
|
||||||
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
|
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
|
||||||
return "request"
|
return "request"
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1000,7 +1073,7 @@ func classifyOpsPhase(errType, message, code string) string {
|
|||||||
case "upstream_error", "overloaded_error":
|
case "upstream_error", "overloaded_error":
|
||||||
return "upstream"
|
return "upstream"
|
||||||
case "api_error":
|
case "api_error":
|
||||||
if strings.Contains(msg, "no available accounts") {
|
if strings.Contains(msg, opsErrNoAvailableAccounts) {
|
||||||
return "routing"
|
return "routing"
|
||||||
}
|
}
|
||||||
return "internal"
|
return "internal"
|
||||||
@@ -1046,7 +1119,7 @@ func classifyOpsIsRetryable(errType string, statusCode int) bool {
|
|||||||
|
|
||||||
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
|
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
|
||||||
switch strings.TrimSpace(code) {
|
switch strings.TrimSpace(code) {
|
||||||
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID", "USER_INACTIVE":
|
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid, opsCodeUserInactive:
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if phase == "billing" || phase == "concurrency" {
|
if phase == "billing" || phase == "concurrency" {
|
||||||
@@ -1140,21 +1213,30 @@ func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message
|
|||||||
|
|
||||||
// Check if context canceled errors should be ignored (client disconnects)
|
// Check if context canceled errors should be ignored (client disconnects)
|
||||||
if settings.IgnoreContextCanceled {
|
if settings.IgnoreContextCanceled {
|
||||||
if strings.Contains(msgLower, "context canceled") || strings.Contains(bodyLower, "context canceled") {
|
if strings.Contains(msgLower, opsErrContextCanceled) || strings.Contains(bodyLower, opsErrContextCanceled) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if "no available accounts" errors should be ignored
|
// Check if "no available accounts" errors should be ignored
|
||||||
if settings.IgnoreNoAvailableAccounts {
|
if settings.IgnoreNoAvailableAccounts {
|
||||||
if strings.Contains(msgLower, "no available accounts") || strings.Contains(bodyLower, "no available accounts") {
|
if strings.Contains(msgLower, opsErrNoAvailableAccounts) || strings.Contains(bodyLower, opsErrNoAvailableAccounts) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if invalid/missing API key errors should be ignored (user misconfiguration)
|
// Check if invalid/missing API key errors should be ignored (user misconfiguration)
|
||||||
if settings.IgnoreInvalidApiKeyErrors {
|
if settings.IgnoreInvalidApiKeyErrors {
|
||||||
if strings.Contains(bodyLower, "invalid_api_key") || strings.Contains(bodyLower, "api_key_required") {
|
if strings.Contains(bodyLower, opsErrInvalidAPIKey) || strings.Contains(bodyLower, opsErrAPIKeyRequired) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if insufficient balance errors should be ignored
|
||||||
|
if settings.IgnoreInsufficientBalanceErrors {
|
||||||
|
if strings.Contains(bodyLower, opsErrInsufficientBalance) || strings.Contains(bodyLower, opsErrInsufficientAccountBalance) ||
|
||||||
|
strings.Contains(bodyLower, opsErrInsufficientQuota) ||
|
||||||
|
strings.Contains(msgLower, opsErrInsufficientBalance) || strings.Contains(msgLower, opsErrInsufficientAccountBalance) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
|||||||
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
|
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
|
||||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||||
SoraClientEnabled: settings.SoraClientEnabled,
|
SoraClientEnabled: settings.SoraClientEnabled,
|
||||||
|
BackendModeEnabled: settings.BackendModeEnabled,
|
||||||
Version: h.version,
|
Version: h.version,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2206,7 +2206,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac
|
|||||||
// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。
|
// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。
|
||||||
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
|
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
|
||||||
return service.NewGatewayService(
|
return service.NewGatewayService(
|
||||||
accountRepo, nil, nil, nil, nil, nil, nil, nil,
|
accountRepo, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -399,17 +399,23 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
|
|
||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
clientIP := ip.GetClientIP(c)
|
clientIP := ip.GetClientIP(c)
|
||||||
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||||
|
inboundEndpoint := GetInboundEndpoint(c)
|
||||||
|
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||||
|
|
||||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
APIKey: apiKey,
|
APIKey: apiKey,
|
||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: userAgent,
|
InboundEndpoint: inboundEndpoint,
|
||||||
IPAddress: clientIP,
|
UpstreamEndpoint: upstreamEndpoint,
|
||||||
|
UserAgent: userAgent,
|
||||||
|
IPAddress: clientIP,
|
||||||
|
RequestPayloadHash: requestPayloadHash,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
zap.String("component", "handler.sora_gateway.chat_completions"),
|
zap.String("component", "handler.sora_gateway.chat_completions"),
|
||||||
@@ -478,6 +484,9 @@ func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, s
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) {
|
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) {
|
||||||
|
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||||
|
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
|
||||||
|
|
||||||
status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody)
|
status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody)
|
||||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -273,8 +273,8 @@ func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform strin
|
|||||||
func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
|
func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
|
||||||
return 0, nil
|
return 0, 0, nil
|
||||||
}
|
}
|
||||||
func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
@@ -334,15 +334,32 @@ func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTi
|
|||||||
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stubUsageLogRepo) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||||
|
return []usagestats.EndpointStat{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||||
|
return []usagestats.EndpointStat{}, nil
|
||||||
|
}
|
||||||
func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
|
func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@@ -431,6 +448,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
|
|||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
testutil.StubGatewayCache{},
|
testutil.StubGatewayCache{},
|
||||||
cfg,
|
cfg,
|
||||||
nil,
|
nil,
|
||||||
|
|||||||
@@ -114,8 +114,8 @@ func (h *UsageHandler) List(c *gin.Context) {
|
|||||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Set end time to end of day
|
// Use half-open range [start, end), move to next calendar day start (DST-safe).
|
||||||
t = t.Add(24*time.Hour - time.Nanosecond)
|
t = t.AddDate(0, 0, 1)
|
||||||
endTime = &t
|
endTime = &t
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -227,8 +227,8 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
|||||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 设置结束时间为当天结束
|
// 与 SQL 条件 created_at < end 对齐,使用次日 00:00 作为上边界(DST-safe)。
|
||||||
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
|
endTime = endTime.AddDate(0, 0, 1)
|
||||||
} else {
|
} else {
|
||||||
// 使用 period 参数
|
// 使用 period 参数
|
||||||
period := c.DefaultQuery("period", "today")
|
period := c.DefaultQuery("period", "today")
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ func ProvideAdminHandlers(
|
|||||||
accountHandler *admin.AccountHandler,
|
accountHandler *admin.AccountHandler,
|
||||||
announcementHandler *admin.AnnouncementHandler,
|
announcementHandler *admin.AnnouncementHandler,
|
||||||
dataManagementHandler *admin.DataManagementHandler,
|
dataManagementHandler *admin.DataManagementHandler,
|
||||||
|
backupHandler *admin.BackupHandler,
|
||||||
oauthHandler *admin.OAuthHandler,
|
oauthHandler *admin.OAuthHandler,
|
||||||
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
||||||
geminiOAuthHandler *admin.GeminiOAuthHandler,
|
geminiOAuthHandler *admin.GeminiOAuthHandler,
|
||||||
@@ -39,6 +40,7 @@ func ProvideAdminHandlers(
|
|||||||
Account: accountHandler,
|
Account: accountHandler,
|
||||||
Announcement: announcementHandler,
|
Announcement: announcementHandler,
|
||||||
DataManagement: dataManagementHandler,
|
DataManagement: dataManagementHandler,
|
||||||
|
Backup: backupHandler,
|
||||||
OAuth: oauthHandler,
|
OAuth: oauthHandler,
|
||||||
OpenAIOAuth: openaiOAuthHandler,
|
OpenAIOAuth: openaiOAuthHandler,
|
||||||
GeminiOAuth: geminiOAuthHandler,
|
GeminiOAuth: geminiOAuthHandler,
|
||||||
@@ -128,6 +130,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
admin.NewAccountHandler,
|
admin.NewAccountHandler,
|
||||||
admin.NewAnnouncementHandler,
|
admin.NewAnnouncementHandler,
|
||||||
admin.NewDataManagementHandler,
|
admin.NewDataManagementHandler,
|
||||||
|
admin.NewBackupHandler,
|
||||||
admin.NewOAuthHandler,
|
admin.NewOAuthHandler,
|
||||||
admin.NewOpenAIOAuthHandler,
|
admin.NewOpenAIOAuthHandler,
|
||||||
admin.NewGeminiOAuthHandler,
|
admin.NewGeminiOAuthHandler,
|
||||||
|
|||||||
@@ -159,6 +159,8 @@ var claudeModels = []modelDef{
|
|||||||
// Antigravity 支持的 Gemini 模型
|
// Antigravity 支持的 Gemini 模型
|
||||||
var geminiModels = []modelDef{
|
var geminiModels = []modelDef{
|
||||||
{ID: "gemini-2.5-flash", DisplayName: "Gemini 2.5 Flash", CreatedAt: "2025-01-01T00:00:00Z"},
|
{ID: "gemini-2.5-flash", DisplayName: "Gemini 2.5 Flash", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||||
|
{ID: "gemini-2.5-flash-image", DisplayName: "Gemini 2.5 Flash Image", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||||
|
{ID: "gemini-2.5-flash-image-preview", DisplayName: "Gemini 2.5 Flash Image Preview", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||||
{ID: "gemini-2.5-flash-lite", DisplayName: "Gemini 2.5 Flash Lite", CreatedAt: "2025-01-01T00:00:00Z"},
|
{ID: "gemini-2.5-flash-lite", DisplayName: "Gemini 2.5 Flash Lite", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||||
{ID: "gemini-2.5-flash-thinking", DisplayName: "Gemini 2.5 Flash Thinking", CreatedAt: "2025-01-01T00:00:00Z"},
|
{ID: "gemini-2.5-flash-thinking", DisplayName: "Gemini 2.5 Flash Thinking", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||||
{ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"},
|
{ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ func TestDefaultModels_ContainsNewAndLegacyImageModels(t *testing.T) {
|
|||||||
|
|
||||||
requiredIDs := []string{
|
requiredIDs := []string{
|
||||||
"claude-opus-4-6-thinking",
|
"claude-opus-4-6-thinking",
|
||||||
|
"gemini-2.5-flash-image",
|
||||||
|
"gemini-2.5-flash-image-preview",
|
||||||
"gemini-3.1-flash-image",
|
"gemini-3.1-flash-image",
|
||||||
"gemini-3.1-flash-image-preview",
|
"gemini-3.1-flash-image-preview",
|
||||||
"gemini-3-pro-image", // legacy compatibility
|
"gemini-3-pro-image", // legacy compatibility
|
||||||
|
|||||||
@@ -19,6 +19,16 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ForbiddenError 表示上游返回 403 Forbidden
|
||||||
|
type ForbiddenError struct {
|
||||||
|
StatusCode int
|
||||||
|
Body string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ForbiddenError) Error() string {
|
||||||
|
return fmt.Sprintf("fetchAvailableModels 失败 (HTTP %d): %s", e.StatusCode, e.Body)
|
||||||
|
}
|
||||||
|
|
||||||
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
|
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
|
||||||
func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
|
func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
|
||||||
// 构建 URL,流式请求添加 ?alt=sse 参数
|
// 构建 URL,流式请求添加 ?alt=sse 参数
|
||||||
@@ -114,10 +124,68 @@ type IneligibleTier struct {
|
|||||||
type LoadCodeAssistResponse struct {
|
type LoadCodeAssistResponse struct {
|
||||||
CloudAICompanionProject string `json:"cloudaicompanionProject"`
|
CloudAICompanionProject string `json:"cloudaicompanionProject"`
|
||||||
CurrentTier *TierInfo `json:"currentTier,omitempty"`
|
CurrentTier *TierInfo `json:"currentTier,omitempty"`
|
||||||
PaidTier *TierInfo `json:"paidTier,omitempty"`
|
PaidTier *PaidTierInfo `json:"paidTier,omitempty"`
|
||||||
IneligibleTiers []*IneligibleTier `json:"ineligibleTiers,omitempty"`
|
IneligibleTiers []*IneligibleTier `json:"ineligibleTiers,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PaidTierInfo 付费等级信息,包含 AI Credits 余额。
|
||||||
|
type PaidTierInfo struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
AvailableCredits []AvailableCredit `json:"availableCredits,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON 兼容 paidTier 既可能是字符串也可能是对象的情况。
|
||||||
|
func (p *PaidTierInfo) UnmarshalJSON(data []byte) error {
|
||||||
|
data = bytes.TrimSpace(data)
|
||||||
|
if len(data) == 0 || string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if data[0] == '"' {
|
||||||
|
var id string
|
||||||
|
if err := json.Unmarshal(data, &id); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
p.ID = id
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
type alias PaidTierInfo
|
||||||
|
var raw alias
|
||||||
|
if err := json.Unmarshal(data, &raw); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*p = PaidTierInfo(raw)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AvailableCredit 表示一条 AI Credits 余额记录。
|
||||||
|
type AvailableCredit struct {
|
||||||
|
CreditType string `json:"creditType,omitempty"`
|
||||||
|
CreditAmount string `json:"creditAmount,omitempty"`
|
||||||
|
MinimumCreditAmountForUsage string `json:"minimumCreditAmountForUsage,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAmount 将 creditAmount 解析为浮点数。
|
||||||
|
func (c *AvailableCredit) GetAmount() float64 {
|
||||||
|
if c.CreditAmount == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
var value float64
|
||||||
|
_, _ = fmt.Sscanf(c.CreditAmount, "%f", &value)
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMinimumAmount 将 minimumCreditAmountForUsage 解析为浮点数。
|
||||||
|
func (c *AvailableCredit) GetMinimumAmount() float64 {
|
||||||
|
if c.MinimumCreditAmountForUsage == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
var value float64
|
||||||
|
_, _ = fmt.Sscanf(c.MinimumCreditAmountForUsage, "%f", &value)
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
// OnboardUserRequest onboardUser 请求
|
// OnboardUserRequest onboardUser 请求
|
||||||
type OnboardUserRequest struct {
|
type OnboardUserRequest struct {
|
||||||
TierID string `json:"tierId"`
|
TierID string `json:"tierId"`
|
||||||
@@ -147,6 +215,14 @@ func (r *LoadCodeAssistResponse) GetTier() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAvailableCredits 返回 paid tier 中的 AI Credits 余额列表。
|
||||||
|
func (r *LoadCodeAssistResponse) GetAvailableCredits() []AvailableCredit {
|
||||||
|
if r.PaidTier == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return r.PaidTier.AvailableCredits
|
||||||
|
}
|
||||||
|
|
||||||
// Client Antigravity API 客户端
|
// Client Antigravity API 客户端
|
||||||
type Client struct {
|
type Client struct {
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
@@ -514,7 +590,20 @@ type ModelQuotaInfo struct {
|
|||||||
|
|
||||||
// ModelInfo 模型信息
|
// ModelInfo 模型信息
|
||||||
type ModelInfo struct {
|
type ModelInfo struct {
|
||||||
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
|
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
|
||||||
|
DisplayName string `json:"displayName,omitempty"`
|
||||||
|
SupportsImages *bool `json:"supportsImages,omitempty"`
|
||||||
|
SupportsThinking *bool `json:"supportsThinking,omitempty"`
|
||||||
|
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
|
||||||
|
Recommended *bool `json:"recommended,omitempty"`
|
||||||
|
MaxTokens *int `json:"maxTokens,omitempty"`
|
||||||
|
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
|
||||||
|
SupportedMimeTypes map[string]bool `json:"supportedMimeTypes,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeprecatedModelInfo 废弃模型转发信息
|
||||||
|
type DeprecatedModelInfo struct {
|
||||||
|
NewModelID string `json:"newModelId"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchAvailableModelsRequest fetchAvailableModels 请求
|
// FetchAvailableModelsRequest fetchAvailableModels 请求
|
||||||
@@ -524,7 +613,8 @@ type FetchAvailableModelsRequest struct {
|
|||||||
|
|
||||||
// FetchAvailableModelsResponse fetchAvailableModels 响应
|
// FetchAvailableModelsResponse fetchAvailableModels 响应
|
||||||
type FetchAvailableModelsResponse struct {
|
type FetchAvailableModelsResponse struct {
|
||||||
Models map[string]ModelInfo `json:"models"`
|
Models map[string]ModelInfo `json:"models"`
|
||||||
|
DeprecatedModelIDs map[string]DeprecatedModelInfo `json:"deprecatedModelIds,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
|
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
|
||||||
@@ -573,6 +663,13 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusForbidden {
|
||||||
|
return nil, nil, &ForbiddenError{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
Body: string(respBodyBytes),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -190,7 +190,7 @@ func TestTierInfo_UnmarshalJSON_通过JSON嵌套结构(t *testing.T) {
|
|||||||
func TestGetTier_PaidTier优先(t *testing.T) {
|
func TestGetTier_PaidTier优先(t *testing.T) {
|
||||||
resp := &LoadCodeAssistResponse{
|
resp := &LoadCodeAssistResponse{
|
||||||
CurrentTier: &TierInfo{ID: "free-tier"},
|
CurrentTier: &TierInfo{ID: "free-tier"},
|
||||||
PaidTier: &TierInfo{ID: "g1-pro-tier"},
|
PaidTier: &PaidTierInfo{ID: "g1-pro-tier"},
|
||||||
}
|
}
|
||||||
if got := resp.GetTier(); got != "g1-pro-tier" {
|
if got := resp.GetTier(); got != "g1-pro-tier" {
|
||||||
t.Errorf("应返回 paidTier: got %s", got)
|
t.Errorf("应返回 paidTier: got %s", got)
|
||||||
@@ -209,7 +209,7 @@ func TestGetTier_回退到CurrentTier(t *testing.T) {
|
|||||||
func TestGetTier_PaidTier为空ID(t *testing.T) {
|
func TestGetTier_PaidTier为空ID(t *testing.T) {
|
||||||
resp := &LoadCodeAssistResponse{
|
resp := &LoadCodeAssistResponse{
|
||||||
CurrentTier: &TierInfo{ID: "free-tier"},
|
CurrentTier: &TierInfo{ID: "free-tier"},
|
||||||
PaidTier: &TierInfo{ID: ""},
|
PaidTier: &PaidTierInfo{ID: ""},
|
||||||
}
|
}
|
||||||
// paidTier.ID 为空时应回退到 currentTier
|
// paidTier.ID 为空时应回退到 currentTier
|
||||||
if got := resp.GetTier(); got != "free-tier" {
|
if got := resp.GetTier(); got != "free-tier" {
|
||||||
@@ -217,6 +217,32 @@ func TestGetTier_PaidTier为空ID(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetAvailableCredits(t *testing.T) {
|
||||||
|
resp := &LoadCodeAssistResponse{
|
||||||
|
PaidTier: &PaidTierInfo{
|
||||||
|
ID: "g1-pro-tier",
|
||||||
|
AvailableCredits: []AvailableCredit{
|
||||||
|
{
|
||||||
|
CreditType: "GOOGLE_ONE_AI",
|
||||||
|
CreditAmount: "25",
|
||||||
|
MinimumCreditAmountForUsage: "5",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
credits := resp.GetAvailableCredits()
|
||||||
|
if len(credits) != 1 {
|
||||||
|
t.Fatalf("AI Credits 数量不匹配: got %d", len(credits))
|
||||||
|
}
|
||||||
|
if credits[0].GetAmount() != 25 {
|
||||||
|
t.Errorf("CreditAmount 解析不正确: got %v", credits[0].GetAmount())
|
||||||
|
}
|
||||||
|
if credits[0].GetMinimumAmount() != 5 {
|
||||||
|
t.Errorf("MinimumCreditAmountForUsage 解析不正确: got %v", credits[0].GetMinimumAmount())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetTier_两者都为nil(t *testing.T) {
|
func TestGetTier_两者都为nil(t *testing.T) {
|
||||||
resp := &LoadCodeAssistResponse{}
|
resp := &LoadCodeAssistResponse{}
|
||||||
if got := resp.GetTier(); got != "" {
|
if got := resp.GetTier(); got != "" {
|
||||||
|
|||||||
@@ -189,6 +189,5 @@ var DefaultStopSequences = []string{
|
|||||||
"<|user|>",
|
"<|user|>",
|
||||||
"<|endoftext|>",
|
"<|endoftext|>",
|
||||||
"<|end_of_turn|>",
|
"<|end_of_turn|>",
|
||||||
"[DONE]",
|
|
||||||
"\n\nHuman:",
|
"\n\nHuman:",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,8 +49,8 @@ const (
|
|||||||
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||||
)
|
)
|
||||||
|
|
||||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.4
|
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.5
|
||||||
var defaultUserAgentVersion = "1.20.4"
|
var defaultUserAgentVersion = "1.20.5"
|
||||||
|
|
||||||
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
||||||
var defaultClientSecret = "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"
|
var defaultClientSecret = "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"
|
||||||
|
|||||||
@@ -690,7 +690,7 @@ func TestConstants_值正确(t *testing.T) {
|
|||||||
if RedirectURI != "http://localhost:8085/callback" {
|
if RedirectURI != "http://localhost:8085/callback" {
|
||||||
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
|
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
|
||||||
}
|
}
|
||||||
if GetUserAgent() != "antigravity/1.20.4 windows/amd64" {
|
if GetUserAgent() != "antigravity/1.20.5 windows/amd64" {
|
||||||
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
|
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
|
||||||
}
|
}
|
||||||
if SessionTTL != 30*time.Minute {
|
if SessionTTL != 30*time.Minute {
|
||||||
|
|||||||
@@ -275,21 +275,6 @@ func filterOpenCodePrompt(text string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// systemBlockFilterPrefixes 需要从 system 中过滤的文本前缀列表
|
|
||||||
var systemBlockFilterPrefixes = []string{
|
|
||||||
"x-anthropic-billing-header",
|
|
||||||
}
|
|
||||||
|
|
||||||
// filterSystemBlockByPrefix 如果文本匹配过滤前缀,返回空字符串
|
|
||||||
func filterSystemBlockByPrefix(text string) string {
|
|
||||||
for _, prefix := range systemBlockFilterPrefixes {
|
|
||||||
if strings.HasPrefix(text, prefix) {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return text
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildSystemInstruction 构建 systemInstruction(与 Antigravity-Manager 保持一致)
|
// buildSystemInstruction 构建 systemInstruction(与 Antigravity-Manager 保持一致)
|
||||||
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent {
|
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent {
|
||||||
var parts []GeminiPart
|
var parts []GeminiPart
|
||||||
@@ -306,8 +291,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
|
|||||||
if strings.Contains(sysStr, "You are Antigravity") {
|
if strings.Contains(sysStr, "You are Antigravity") {
|
||||||
userHasAntigravityIdentity = true
|
userHasAntigravityIdentity = true
|
||||||
}
|
}
|
||||||
// 过滤 OpenCode 默认提示词和黑名单前缀
|
// 过滤 OpenCode 默认提示词
|
||||||
filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(sysStr))
|
filtered := filterOpenCodePrompt(sysStr)
|
||||||
if filtered != "" {
|
if filtered != "" {
|
||||||
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
||||||
}
|
}
|
||||||
@@ -321,8 +306,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
|
|||||||
if strings.Contains(block.Text, "You are Antigravity") {
|
if strings.Contains(block.Text, "You are Antigravity") {
|
||||||
userHasAntigravityIdentity = true
|
userHasAntigravityIdentity = true
|
||||||
}
|
}
|
||||||
// 过滤 OpenCode 默认提示词和黑名单前缀
|
// 过滤 OpenCode 默认提示词
|
||||||
filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(block.Text))
|
filtered := filterOpenCodePrompt(block.Text)
|
||||||
if filtered != "" {
|
if filtered != "" {
|
||||||
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,10 @@ package antigravity
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
|
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
|
||||||
@@ -349,3 +352,51 @@ func TestBuildGenerationConfig_ThinkingDynamicBudget(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTransformClaudeToGeminiWithOptions_PreservesBillingHeaderSystemBlock(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
system json.RawMessage
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "system array",
|
||||||
|
system: json.RawMessage(`[{"type":"text","text":"x-anthropic-billing-header keep"}]`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "system string",
|
||||||
|
system: json.RawMessage(`"x-anthropic-billing-header keep"`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
claudeReq := &ClaudeRequest{
|
||||||
|
Model: "claude-3-5-sonnet-latest",
|
||||||
|
System: tt.system,
|
||||||
|
Messages: []ClaudeMessage{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: json.RawMessage(`[{"type":"text","text":"hello"}]`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "gemini-2.5-flash", DefaultTransformOptions())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var req V1InternalRequest
|
||||||
|
require.NoError(t, json.Unmarshal(body, &req))
|
||||||
|
require.NotNil(t, req.Request.SystemInstruction)
|
||||||
|
|
||||||
|
found := false
|
||||||
|
for _, part := range req.Request.SystemInstruction.Parts {
|
||||||
|
if strings.Contains(part.Text, "x-anthropic-billing-header keep") {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
require.True(t, found, "转换后的 systemInstruction 应保留 x-anthropic-billing-header 内容")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -105,6 +105,7 @@ func TestAnthropicToResponses_ToolUse(t *testing.T) {
|
|||||||
assert.Equal(t, "assistant", items[1].Role)
|
assert.Equal(t, "assistant", items[1].Role)
|
||||||
assert.Equal(t, "function_call", items[2].Type)
|
assert.Equal(t, "function_call", items[2].Type)
|
||||||
assert.Equal(t, "fc_call_1", items[2].CallID)
|
assert.Equal(t, "fc_call_1", items[2].CallID)
|
||||||
|
assert.Empty(t, items[2].ID)
|
||||||
assert.Equal(t, "function_call_output", items[3].Type)
|
assert.Equal(t, "function_call_output", items[3].Type)
|
||||||
assert.Equal(t, "fc_call_1", items[3].CallID)
|
assert.Equal(t, "fc_call_1", items[3].CallID)
|
||||||
assert.Equal(t, "Sunny, 72°F", items[3].Output)
|
assert.Equal(t, "Sunny, 72°F", items[3].Output)
|
||||||
@@ -1007,3 +1008,114 @@ func TestAnthropicToResponses_ImageEmptyMediaType(t *testing.T) {
|
|||||||
// Should default to image/png when media_type is empty.
|
// Should default to image/png when media_type is empty.
|
||||||
assert.Equal(t, "data:image/png;base64,iVBOR", parts[0].ImageURL)
|
assert.Equal(t, "data:image/png;base64,iVBOR", parts[0].ImageURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// normalizeToolParameters tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestNormalizeToolParameters(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input json.RawMessage
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil input",
|
||||||
|
input: nil,
|
||||||
|
expected: `{"type":"object","properties":{}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty input",
|
||||||
|
input: json.RawMessage(``),
|
||||||
|
expected: `{"type":"object","properties":{}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "null input",
|
||||||
|
input: json.RawMessage(`null`),
|
||||||
|
expected: `{"type":"object","properties":{}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "object without properties",
|
||||||
|
input: json.RawMessage(`{"type":"object"}`),
|
||||||
|
expected: `{"type":"object","properties":{}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "object with properties",
|
||||||
|
input: json.RawMessage(`{"type":"object","properties":{"city":{"type":"string"}}}`),
|
||||||
|
expected: `{"type":"object","properties":{"city":{"type":"string"}}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-object type",
|
||||||
|
input: json.RawMessage(`{"type":"string"}`),
|
||||||
|
expected: `{"type":"string"}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "object with additional fields preserved",
|
||||||
|
input: json.RawMessage(`{"type":"object","required":["name"]}`),
|
||||||
|
expected: `{"type":"object","required":["name"],"properties":{}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid JSON passthrough",
|
||||||
|
input: json.RawMessage(`not json`),
|
||||||
|
expected: `not json`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := normalizeToolParameters(tt.input)
|
||||||
|
if tt.name == "invalid JSON passthrough" {
|
||||||
|
assert.Equal(t, tt.expected, string(result))
|
||||||
|
} else {
|
||||||
|
assert.JSONEq(t, tt.expected, string(result))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnthropicToResponses_ToolWithoutProperties(t *testing.T) {
|
||||||
|
req := &AnthropicRequest{
|
||||||
|
Model: "gpt-5.2",
|
||||||
|
MaxTokens: 1024,
|
||||||
|
Messages: []AnthropicMessage{
|
||||||
|
{Role: "user", Content: json.RawMessage(`"Hello"`)},
|
||||||
|
},
|
||||||
|
Tools: []AnthropicTool{
|
||||||
|
{Name: "mcp__pencil__get_style_guide_tags", Description: "Get style tags", InputSchema: json.RawMessage(`{"type":"object"}`)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := AnthropicToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Len(t, resp.Tools, 1)
|
||||||
|
assert.Equal(t, "function", resp.Tools[0].Type)
|
||||||
|
assert.Equal(t, "mcp__pencil__get_style_guide_tags", resp.Tools[0].Name)
|
||||||
|
|
||||||
|
// Parameters must have "properties" field after normalization.
|
||||||
|
var params map[string]json.RawMessage
|
||||||
|
require.NoError(t, json.Unmarshal(resp.Tools[0].Parameters, ¶ms))
|
||||||
|
assert.Contains(t, params, "properties")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnthropicToResponses_ToolWithNilSchema(t *testing.T) {
|
||||||
|
req := &AnthropicRequest{
|
||||||
|
Model: "gpt-5.2",
|
||||||
|
MaxTokens: 1024,
|
||||||
|
Messages: []AnthropicMessage{
|
||||||
|
{Role: "user", Content: json.RawMessage(`"Hello"`)},
|
||||||
|
},
|
||||||
|
Tools: []AnthropicTool{
|
||||||
|
{Name: "simple_tool", Description: "A tool"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := AnthropicToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Len(t, resp.Tools, 1)
|
||||||
|
var params map[string]json.RawMessage
|
||||||
|
require.NoError(t, json.Unmarshal(resp.Tools[0].Parameters, ¶ms))
|
||||||
|
assert.JSONEq(t, `"object"`, string(params["type"]))
|
||||||
|
assert.JSONEq(t, `{}`, string(params["properties"]))
|
||||||
|
}
|
||||||
|
|||||||
@@ -277,7 +277,6 @@ func anthropicAssistantToResponses(raw json.RawMessage) ([]ResponsesInputItem, e
|
|||||||
CallID: fcID,
|
CallID: fcID,
|
||||||
Name: b.Name,
|
Name: b.Name,
|
||||||
Arguments: args,
|
Arguments: args,
|
||||||
ID: fcID,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -410,8 +409,41 @@ func convertAnthropicToolsToResponses(tools []AnthropicTool) []ResponsesTool {
|
|||||||
Type: "function",
|
Type: "function",
|
||||||
Name: t.Name,
|
Name: t.Name,
|
||||||
Description: t.Description,
|
Description: t.Description,
|
||||||
Parameters: t.InputSchema,
|
Parameters: normalizeToolParameters(t.InputSchema),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// normalizeToolParameters ensures the tool parameter schema is valid for
|
||||||
|
// OpenAI's Responses API, which requires "properties" on object schemas.
|
||||||
|
//
|
||||||
|
// - nil/empty → {"type":"object","properties":{}}
|
||||||
|
// - type=object without properties → adds "properties": {}
|
||||||
|
// - otherwise → returned unchanged
|
||||||
|
func normalizeToolParameters(schema json.RawMessage) json.RawMessage {
|
||||||
|
if len(schema) == 0 || string(schema) == "null" {
|
||||||
|
return json.RawMessage(`{"type":"object","properties":{}}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
var m map[string]json.RawMessage
|
||||||
|
if err := json.Unmarshal(schema, &m); err != nil {
|
||||||
|
return schema
|
||||||
|
}
|
||||||
|
|
||||||
|
typ := m["type"]
|
||||||
|
if string(typ) != `"object"` {
|
||||||
|
return schema
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := m["properties"]; ok {
|
||||||
|
return schema
|
||||||
|
}
|
||||||
|
|
||||||
|
m["properties"] = json.RawMessage(`{}`)
|
||||||
|
out, err := json.Marshal(m)
|
||||||
|
if err != nil {
|
||||||
|
return schema
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|||||||
810
backend/internal/pkg/apicompat/chatcompletions_responses_test.go
Normal file
810
backend/internal/pkg/apicompat/chatcompletions_responses_test.go
Normal file
@@ -0,0 +1,810 @@
|
|||||||
|
package apicompat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// ChatCompletionsToResponses tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_BasicText(t *testing.T) {
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
Messages: []ChatMessage{
|
||||||
|
{Role: "user", Content: json.RawMessage(`"Hello"`)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "gpt-4o", resp.Model)
|
||||||
|
assert.True(t, resp.Stream) // always forced true
|
||||||
|
assert.False(t, *resp.Store)
|
||||||
|
|
||||||
|
var items []ResponsesInputItem
|
||||||
|
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||||
|
require.Len(t, items, 1)
|
||||||
|
assert.Equal(t, "user", items[0].Role)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_SystemMessage(t *testing.T) {
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
Messages: []ChatMessage{
|
||||||
|
{Role: "system", Content: json.RawMessage(`"You are helpful."`)},
|
||||||
|
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var items []ResponsesInputItem
|
||||||
|
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||||
|
require.Len(t, items, 2)
|
||||||
|
assert.Equal(t, "system", items[0].Role)
|
||||||
|
assert.Equal(t, "user", items[1].Role)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_ToolCalls(t *testing.T) {
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
Messages: []ChatMessage{
|
||||||
|
{Role: "user", Content: json.RawMessage(`"Call the function"`)},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []ChatToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_1",
|
||||||
|
Type: "function",
|
||||||
|
Function: ChatFunctionCall{
|
||||||
|
Name: "ping",
|
||||||
|
Arguments: `{"host":"example.com"}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "tool",
|
||||||
|
ToolCallID: "call_1",
|
||||||
|
Content: json.RawMessage(`"pong"`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Tools: []ChatTool{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: &ChatFunction{
|
||||||
|
Name: "ping",
|
||||||
|
Description: "Ping a host",
|
||||||
|
Parameters: json.RawMessage(`{"type":"object"}`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var items []ResponsesInputItem
|
||||||
|
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||||
|
// user + function_call + function_call_output = 3
|
||||||
|
// (assistant message with empty content + tool_calls → only function_call items emitted)
|
||||||
|
require.Len(t, items, 3)
|
||||||
|
|
||||||
|
// Check function_call item
|
||||||
|
assert.Equal(t, "function_call", items[1].Type)
|
||||||
|
assert.Equal(t, "call_1", items[1].CallID)
|
||||||
|
assert.Empty(t, items[1].ID)
|
||||||
|
assert.Equal(t, "ping", items[1].Name)
|
||||||
|
|
||||||
|
// Check function_call_output item
|
||||||
|
assert.Equal(t, "function_call_output", items[2].Type)
|
||||||
|
assert.Equal(t, "call_1", items[2].CallID)
|
||||||
|
assert.Equal(t, "pong", items[2].Output)
|
||||||
|
|
||||||
|
// Check tools
|
||||||
|
require.Len(t, resp.Tools, 1)
|
||||||
|
assert.Equal(t, "function", resp.Tools[0].Type)
|
||||||
|
assert.Equal(t, "ping", resp.Tools[0].Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_MaxTokens(t *testing.T) {
|
||||||
|
t.Run("max_tokens", func(t *testing.T) {
|
||||||
|
maxTokens := 100
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
MaxTokens: &maxTokens,
|
||||||
|
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||||
|
}
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resp.MaxOutputTokens)
|
||||||
|
// Below minMaxOutputTokens (128), should be clamped
|
||||||
|
assert.Equal(t, minMaxOutputTokens, *resp.MaxOutputTokens)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("max_completion_tokens_preferred", func(t *testing.T) {
|
||||||
|
maxTokens := 100
|
||||||
|
maxCompletion := 500
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
MaxTokens: &maxTokens,
|
||||||
|
MaxCompletionTokens: &maxCompletion,
|
||||||
|
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||||
|
}
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resp.MaxOutputTokens)
|
||||||
|
assert.Equal(t, 500, *resp.MaxOutputTokens)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_ReasoningEffort(t *testing.T) {
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
ReasoningEffort: "high",
|
||||||
|
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||||
|
}
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resp.Reasoning)
|
||||||
|
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||||
|
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_ImageURL(t *testing.T) {
|
||||||
|
content := `[{"type":"text","text":"Describe this"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc123"}}]`
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
Messages: []ChatMessage{
|
||||||
|
{Role: "user", Content: json.RawMessage(content)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var items []ResponsesInputItem
|
||||||
|
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||||
|
require.Len(t, items, 1)
|
||||||
|
|
||||||
|
var parts []ResponsesContentPart
|
||||||
|
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
|
||||||
|
require.Len(t, parts, 2)
|
||||||
|
assert.Equal(t, "input_text", parts[0].Type)
|
||||||
|
assert.Equal(t, "Describe this", parts[0].Text)
|
||||||
|
assert.Equal(t, "input_image", parts[1].Type)
|
||||||
|
assert.Equal(t, "data:image/png;base64,abc123", parts[1].ImageURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) {
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
Messages: []ChatMessage{
|
||||||
|
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||||
|
},
|
||||||
|
Functions: []ChatFunction{
|
||||||
|
{
|
||||||
|
Name: "get_weather",
|
||||||
|
Description: "Get weather",
|
||||||
|
Parameters: json.RawMessage(`{"type":"object"}`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FunctionCall: json.RawMessage(`{"name":"get_weather"}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, resp.Tools, 1)
|
||||||
|
assert.Equal(t, "function", resp.Tools[0].Type)
|
||||||
|
assert.Equal(t, "get_weather", resp.Tools[0].Name)
|
||||||
|
|
||||||
|
// tool_choice should be converted
|
||||||
|
require.NotNil(t, resp.ToolChoice)
|
||||||
|
var tc map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
|
||||||
|
assert.Equal(t, "function", tc["type"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_ServiceTier(t *testing.T) {
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
ServiceTier: "flex",
|
||||||
|
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||||
|
}
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "flex", resp.ServiceTier)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_AssistantWithTextAndToolCalls(t *testing.T) {
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
Messages: []ChatMessage{
|
||||||
|
{Role: "user", Content: json.RawMessage(`"Do something"`)},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: json.RawMessage(`"Let me call a function."`),
|
||||||
|
ToolCalls: []ChatToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_abc",
|
||||||
|
Type: "function",
|
||||||
|
Function: ChatFunctionCall{
|
||||||
|
Name: "do_thing",
|
||||||
|
Arguments: `{}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var items []ResponsesInputItem
|
||||||
|
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||||
|
// user + assistant message (with text) + function_call
|
||||||
|
require.Len(t, items, 3)
|
||||||
|
assert.Equal(t, "user", items[0].Role)
|
||||||
|
assert.Equal(t, "assistant", items[1].Role)
|
||||||
|
assert.Equal(t, "function_call", items[2].Type)
|
||||||
|
assert.Empty(t, items[2].ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_AssistantArrayContentPreserved(t *testing.T) {
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
Messages: []ChatMessage{
|
||||||
|
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||||
|
{Role: "assistant", Content: json.RawMessage(`[{"type":"text","text":"A"},{"type":"text","text":"B"}]`)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var items []ResponsesInputItem
|
||||||
|
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||||
|
require.Len(t, items, 2)
|
||||||
|
assert.Equal(t, "assistant", items[1].Role)
|
||||||
|
|
||||||
|
var parts []ResponsesContentPart
|
||||||
|
require.NoError(t, json.Unmarshal(items[1].Content, &parts))
|
||||||
|
require.Len(t, parts, 1)
|
||||||
|
assert.Equal(t, "output_text", parts[0].Type)
|
||||||
|
assert.Equal(t, "AB", parts[0].Text)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_AssistantThinkingTagPreserved(t *testing.T) {
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
Messages: []ChatMessage{
|
||||||
|
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||||
|
{Role: "assistant", Content: json.RawMessage(`[{"type":"thinking","thinking":"internal plan"},{"type":"text","text":"final answer"}]`)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var items []ResponsesInputItem
|
||||||
|
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||||
|
require.Len(t, items, 2)
|
||||||
|
|
||||||
|
var parts []ResponsesContentPart
|
||||||
|
require.NoError(t, json.Unmarshal(items[1].Content, &parts))
|
||||||
|
require.Len(t, parts, 1)
|
||||||
|
assert.Equal(t, "output_text", parts[0].Type)
|
||||||
|
assert.Contains(t, parts[0].Text, "<thinking>internal plan</thinking>")
|
||||||
|
assert.Contains(t, parts[0].Text, "final answer")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// ResponsesToChatCompletions tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestResponsesToChatCompletions_BasicText(t *testing.T) {
|
||||||
|
resp := &ResponsesResponse{
|
||||||
|
ID: "resp_123",
|
||||||
|
Status: "completed",
|
||||||
|
Output: []ResponsesOutput{
|
||||||
|
{
|
||||||
|
Type: "message",
|
||||||
|
Content: []ResponsesContentPart{
|
||||||
|
{Type: "output_text", Text: "Hello, world!"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Usage: &ResponsesUsage{
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 5,
|
||||||
|
TotalTokens: 15,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||||
|
assert.Equal(t, "chat.completion", chat.Object)
|
||||||
|
assert.Equal(t, "gpt-4o", chat.Model)
|
||||||
|
require.Len(t, chat.Choices, 1)
|
||||||
|
assert.Equal(t, "stop", chat.Choices[0].FinishReason)
|
||||||
|
|
||||||
|
var content string
|
||||||
|
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
|
||||||
|
assert.Equal(t, "Hello, world!", content)
|
||||||
|
|
||||||
|
require.NotNil(t, chat.Usage)
|
||||||
|
assert.Equal(t, 10, chat.Usage.PromptTokens)
|
||||||
|
assert.Equal(t, 5, chat.Usage.CompletionTokens)
|
||||||
|
assert.Equal(t, 15, chat.Usage.TotalTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponsesToChatCompletions_ToolCalls(t *testing.T) {
|
||||||
|
resp := &ResponsesResponse{
|
||||||
|
ID: "resp_456",
|
||||||
|
Status: "completed",
|
||||||
|
Output: []ResponsesOutput{
|
||||||
|
{
|
||||||
|
Type: "function_call",
|
||||||
|
CallID: "call_xyz",
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: `{"city":"NYC"}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||||
|
require.Len(t, chat.Choices, 1)
|
||||||
|
assert.Equal(t, "tool_calls", chat.Choices[0].FinishReason)
|
||||||
|
|
||||||
|
msg := chat.Choices[0].Message
|
||||||
|
require.Len(t, msg.ToolCalls, 1)
|
||||||
|
assert.Equal(t, "call_xyz", msg.ToolCalls[0].ID)
|
||||||
|
assert.Equal(t, "function", msg.ToolCalls[0].Type)
|
||||||
|
assert.Equal(t, "get_weather", msg.ToolCalls[0].Function.Name)
|
||||||
|
assert.Equal(t, `{"city":"NYC"}`, msg.ToolCalls[0].Function.Arguments)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponsesToChatCompletions_Reasoning(t *testing.T) {
|
||||||
|
resp := &ResponsesResponse{
|
||||||
|
ID: "resp_789",
|
||||||
|
Status: "completed",
|
||||||
|
Output: []ResponsesOutput{
|
||||||
|
{
|
||||||
|
Type: "reasoning",
|
||||||
|
Summary: []ResponsesSummary{
|
||||||
|
{Type: "summary_text", Text: "I thought about it."},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: "message",
|
||||||
|
Content: []ResponsesContentPart{
|
||||||
|
{Type: "output_text", Text: "The answer is 42."},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||||
|
require.Len(t, chat.Choices, 1)
|
||||||
|
|
||||||
|
var content string
|
||||||
|
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
|
||||||
|
assert.Equal(t, "The answer is 42.", content)
|
||||||
|
assert.Equal(t, "I thought about it.", chat.Choices[0].Message.ReasoningContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestResponsesToChatCompletions_Incomplete verifies that an "incomplete"
// status with reason max_output_tokens maps to finish_reason "length".
func TestResponsesToChatCompletions_Incomplete(t *testing.T) {
	resp := &ResponsesResponse{
		ID:                "resp_inc",
		Status:            "incomplete",
		IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"},
		Output: []ResponsesOutput{
			{
				Type: "message",
				Content: []ResponsesContentPart{
					{Type: "output_text", Text: "partial..."},
				},
			},
		},
	}

	chat := ResponsesToChatCompletions(resp, "gpt-4o")
	require.Len(t, chat.Choices, 1)
	assert.Equal(t, "length", chat.Choices[0].FinishReason)
}
|
||||||
|
|
||||||
|
// TestResponsesToChatCompletions_CachedTokens verifies that
// usage.input_tokens_details.cached_tokens is carried over into
// usage.prompt_tokens_details.cached_tokens on the chat response.
func TestResponsesToChatCompletions_CachedTokens(t *testing.T) {
	resp := &ResponsesResponse{
		ID:     "resp_cache",
		Status: "completed",
		Output: []ResponsesOutput{
			{
				Type:    "message",
				Content: []ResponsesContentPart{{Type: "output_text", Text: "cached"}},
			},
		},
		Usage: &ResponsesUsage{
			InputTokens:  100,
			OutputTokens: 10,
			TotalTokens:  110,
			InputTokensDetails: &ResponsesInputTokensDetails{
				CachedTokens: 80,
			},
		},
	}

	chat := ResponsesToChatCompletions(resp, "gpt-4o")
	require.NotNil(t, chat.Usage)
	require.NotNil(t, chat.Usage.PromptTokensDetails)
	assert.Equal(t, 80, chat.Usage.PromptTokensDetails.CachedTokens)
}
|
||||||
|
|
||||||
|
// TestResponsesToChatCompletions_WebSearch verifies that web_search_call
// output items are silently consumed (no tool_call surfaces) and that the
// accompanying message text is still returned with finish_reason "stop".
func TestResponsesToChatCompletions_WebSearch(t *testing.T) {
	resp := &ResponsesResponse{
		ID:     "resp_ws",
		Status: "completed",
		Output: []ResponsesOutput{
			{
				Type:   "web_search_call",
				Action: &WebSearchAction{Type: "search", Query: "test"},
			},
			{
				Type:    "message",
				Content: []ResponsesContentPart{{Type: "output_text", Text: "search results"}},
			},
		},
	}

	chat := ResponsesToChatCompletions(resp, "gpt-4o")
	require.Len(t, chat.Choices, 1)
	assert.Equal(t, "stop", chat.Choices[0].FinishReason)

	var content string
	require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
	assert.Equal(t, "search results", content)
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Streaming: ResponsesEventToChatChunks tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// TestResponsesEventToChatChunks_TextDelta verifies the basic streaming
// sequence: response.created produces the initial role chunk (and sets
// state.SentRole), then each output_text.delta produces one content chunk.
func TestResponsesEventToChatChunks_TextDelta(t *testing.T) {
	state := NewResponsesEventToChatState()
	state.Model = "gpt-4o"

	// response.created → role chunk
	chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type: "response.created",
		Response: &ResponsesResponse{
			ID: "resp_stream",
		},
	}, state)
	require.Len(t, chunks, 1)
	assert.Equal(t, "assistant", chunks[0].Choices[0].Delta.Role)
	assert.True(t, state.SentRole)

	// response.output_text.delta → content chunk
	chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type:  "response.output_text.delta",
		Delta: "Hello",
	}, state)
	require.Len(t, chunks, 1)
	require.NotNil(t, chunks[0].Choices[0].Delta.Content)
	assert.Equal(t, "Hello", *chunks[0].Choices[0].Delta.Content)
}
|
||||||
|
|
||||||
|
// TestResponsesEventToChatChunks_ToolCallDelta verifies tool-call streaming:
// output_item.added events assign sequential chat tool-call indices (0, 1, …)
// regardless of the upstream output_index, and subsequent
// function_call_arguments.delta events are routed to the right index by
// output_index — including when deltas for different calls interleave.
func TestResponsesEventToChatChunks_ToolCallDelta(t *testing.T) {
	state := NewResponsesEventToChatState()
	state.Model = "gpt-4o"
	state.SentRole = true

	// response.output_item.added (function_call) — output_index=1 (e.g. after a message item at 0)
	chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type:        "response.output_item.added",
		OutputIndex: 1,
		Item: &ResponsesOutput{
			Type:   "function_call",
			CallID: "call_1",
			Name:   "get_weather",
		},
	}, state)
	require.Len(t, chunks, 1)
	require.Len(t, chunks[0].Choices[0].Delta.ToolCalls, 1)
	tc := chunks[0].Choices[0].Delta.ToolCalls[0]
	assert.Equal(t, "call_1", tc.ID)
	assert.Equal(t, "get_weather", tc.Function.Name)
	require.NotNil(t, tc.Index)
	// First tool call gets chat-side index 0 even though output_index is 1.
	assert.Equal(t, 0, *tc.Index)

	// response.function_call_arguments.delta — uses output_index (NOT call_id) to find tool
	chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type:        "response.function_call_arguments.delta",
		OutputIndex: 1, // matches the output_index from output_item.added above
		Delta:       `{"city":`,
	}, state)
	require.Len(t, chunks, 1)
	tc = chunks[0].Choices[0].Delta.ToolCalls[0]
	require.NotNil(t, tc.Index)
	assert.Equal(t, 0, *tc.Index, "argument delta must use same index as the tool call")
	assert.Equal(t, `{"city":`, tc.Function.Arguments)

	// Add a second function call at output_index=2
	chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type:        "response.output_item.added",
		OutputIndex: 2,
		Item: &ResponsesOutput{
			Type:   "function_call",
			CallID: "call_2",
			Name:   "get_time",
		},
	}, state)
	require.Len(t, chunks, 1)
	tc = chunks[0].Choices[0].Delta.ToolCalls[0]
	require.NotNil(t, tc.Index)
	assert.Equal(t, 1, *tc.Index, "second tool call should get index 1")

	// Argument delta for second tool call
	chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type:        "response.function_call_arguments.delta",
		OutputIndex: 2,
		Delta:       `{"tz":"UTC"}`,
	}, state)
	require.Len(t, chunks, 1)
	tc = chunks[0].Choices[0].Delta.ToolCalls[0]
	require.NotNil(t, tc.Index)
	assert.Equal(t, 1, *tc.Index, "second tool arg delta must use index 1")

	// Argument delta for first tool call (interleaved)
	chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type:        "response.function_call_arguments.delta",
		OutputIndex: 1,
		Delta:       `"Tokyo"}`,
	}, state)
	require.Len(t, chunks, 1)
	tc = chunks[0].Choices[0].Delta.ToolCalls[0]
	require.NotNil(t, tc.Index)
	assert.Equal(t, 0, *tc.Index, "first tool arg delta must still use index 0")
}
|
||||||
|
|
||||||
|
// TestResponsesEventToChatChunks_Completed verifies that response.completed
// (with IncludeUsage set) emits exactly two chunks: a finish_reason="stop"
// chunk followed by a usage chunk carrying token counts and cached-token
// details.
func TestResponsesEventToChatChunks_Completed(t *testing.T) {
	state := NewResponsesEventToChatState()
	state.Model = "gpt-4o"
	state.IncludeUsage = true

	chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type: "response.completed",
		Response: &ResponsesResponse{
			Status: "completed",
			Usage: &ResponsesUsage{
				InputTokens:  50,
				OutputTokens: 20,
				TotalTokens:  70,
				InputTokensDetails: &ResponsesInputTokensDetails{
					CachedTokens: 30,
				},
			},
		},
	}, state)
	// finish chunk + usage chunk
	require.Len(t, chunks, 2)

	// First chunk: finish_reason
	require.NotNil(t, chunks[0].Choices[0].FinishReason)
	assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason)

	// Second chunk: usage
	require.NotNil(t, chunks[1].Usage)
	assert.Equal(t, 50, chunks[1].Usage.PromptTokens)
	assert.Equal(t, 20, chunks[1].Usage.CompletionTokens)
	assert.Equal(t, 70, chunks[1].Usage.TotalTokens)
	require.NotNil(t, chunks[1].Usage.PromptTokensDetails)
	assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens)
}
|
||||||
|
|
||||||
|
// TestResponsesEventToChatChunks_CompletedWithToolCalls verifies that when
// the stream state recorded a tool call (SawToolCall), response.completed
// finishes with "tool_calls" instead of "stop".
func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) {
	state := NewResponsesEventToChatState()
	state.Model = "gpt-4o"
	state.SawToolCall = true

	chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type: "response.completed",
		Response: &ResponsesResponse{
			Status: "completed",
		},
	}, state)
	require.Len(t, chunks, 1)
	require.NotNil(t, chunks[0].Choices[0].FinishReason)
	assert.Equal(t, "tool_calls", *chunks[0].Choices[0].FinishReason)
}
|
||||||
|
|
||||||
|
// TestResponsesEventToChatChunks_ReasoningDelta verifies that
// reasoning_summary_text.delta events surface as delta.reasoning_content
// chunks, and that the corresponding .done event emits no chunk at all.
func TestResponsesEventToChatChunks_ReasoningDelta(t *testing.T) {
	state := NewResponsesEventToChatState()
	state.Model = "gpt-4o"
	state.SentRole = true

	chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type:  "response.reasoning_summary_text.delta",
		Delta: "Thinking...",
	}, state)
	require.Len(t, chunks, 1)
	require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent)
	assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.ReasoningContent)

	// The .done marker carries no text and must not produce a chunk.
	chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type: "response.reasoning_summary_text.done",
	}, state)
	require.Len(t, chunks, 0)
}
|
||||||
|
|
||||||
|
// TestResponsesEventToChatChunks_ReasoningThenTextAutoCloseTag verifies the
// transition from reasoning deltas to normal text deltas: the reasoning
// delta arrives as reasoning_content and the following text delta arrives
// as plain content, each as a single chunk.
func TestResponsesEventToChatChunks_ReasoningThenTextAutoCloseTag(t *testing.T) {
	state := NewResponsesEventToChatState()
	state.Model = "gpt-4o"
	state.SentRole = true

	chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type:  "response.reasoning_summary_text.delta",
		Delta: "plan",
	}, state)
	require.Len(t, chunks, 1)
	require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent)
	assert.Equal(t, "plan", *chunks[0].Choices[0].Delta.ReasoningContent)

	chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type:  "response.output_text.delta",
		Delta: "answer",
	}, state)
	require.Len(t, chunks, 1)
	require.NotNil(t, chunks[0].Choices[0].Delta.Content)
	assert.Equal(t, "answer", *chunks[0].Choices[0].Delta.Content)
}
|
||||||
|
|
||||||
|
// TestFinalizeResponsesChatStream verifies that finalizing an un-finished
// stream emits a finish chunk ("stop") plus a usage chunk, and that the
// operation is idempotent — a second call returns nil.
func TestFinalizeResponsesChatStream(t *testing.T) {
	state := NewResponsesEventToChatState()
	state.Model = "gpt-4o"
	state.IncludeUsage = true
	state.Usage = &ChatUsage{
		PromptTokens:     100,
		CompletionTokens: 50,
		TotalTokens:      150,
	}

	chunks := FinalizeResponsesChatStream(state)
	require.Len(t, chunks, 2)

	// Finish chunk
	require.NotNil(t, chunks[0].Choices[0].FinishReason)
	assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason)

	// Usage chunk
	require.NotNil(t, chunks[1].Usage)
	assert.Equal(t, 100, chunks[1].Usage.PromptTokens)

	// Idempotent: second call returns nil
	assert.Nil(t, FinalizeResponsesChatStream(state))
}
|
||||||
|
|
||||||
|
// TestFinalizeResponsesChatStream_AfterCompleted verifies that
// FinalizeResponsesChatStream is a no-op after response.completed already
// emitted the finish chunk — otherwise the client would receive a second
// finish_reason.
func TestFinalizeResponsesChatStream_AfterCompleted(t *testing.T) {
	// If response.completed already emitted the finish chunk, FinalizeResponsesChatStream
	// must be a no-op (prevents double finish_reason being sent to the client).
	state := NewResponsesEventToChatState()
	state.Model = "gpt-4o"
	state.IncludeUsage = true

	// Simulate response.completed
	chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type: "response.completed",
		Response: &ResponsesResponse{
			Status: "completed",
			Usage: &ResponsesUsage{
				InputTokens:  10,
				OutputTokens: 5,
				TotalTokens:  15,
			},
		},
	}, state)
	require.NotEmpty(t, chunks) // finish + usage chunks

	// Now FinalizeResponsesChatStream should return nil — already finalized.
	assert.Nil(t, FinalizeResponsesChatStream(state))
}
|
||||||
|
|
||||||
|
// TestChatChunkToSSE verifies that a chunk serializes into an SSE frame:
// "data: " prefix present and the chunk's JSON payload embedded.
func TestChatChunkToSSE(t *testing.T) {
	chunk := ChatCompletionsChunk{
		ID:      "chatcmpl-test",
		Object:  "chat.completion.chunk",
		Created: 1700000000,
		Model:   "gpt-4o",
		Choices: []ChatChunkChoice{
			{
				Index:        0,
				Delta:        ChatDelta{Role: "assistant"},
				FinishReason: nil,
			},
		},
	}

	sse, err := ChatChunkToSSE(chunk)
	require.NoError(t, err)
	assert.Contains(t, sse, "data: ")
	assert.Contains(t, sse, "chatcmpl-test")
	assert.Contains(t, sse, "assistant")
	assert.True(t, len(sse) > 10)
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Stream round-trip test
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// TestChatCompletionsStreamRoundTrip drives the full streaming state machine
// end-to-end: response.created, four text deltas, then response.completed,
// and checks the resulting chat chunks (role, concatenated text, finish
// reason, usage, and a shared chunk ID).
func TestChatCompletionsStreamRoundTrip(t *testing.T) {
	// Simulate: client sends chat completions request, upstream returns Responses SSE events.
	// Verify that the streaming state machine produces correct chat completions chunks.

	state := NewResponsesEventToChatState()
	state.Model = "gpt-4o"
	state.IncludeUsage = true

	var allChunks []ChatCompletionsChunk

	// 1. response.created
	chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type:     "response.created",
		Response: &ResponsesResponse{ID: "resp_rt"},
	}, state)
	allChunks = append(allChunks, chunks...)

	// 2. text deltas
	for _, text := range []string{"Hello", ", ", "world", "!"} {
		chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
			Type:  "response.output_text.delta",
			Delta: text,
		}, state)
		allChunks = append(allChunks, chunks...)
	}

	// 3. response.completed
	chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type: "response.completed",
		Response: &ResponsesResponse{
			Status: "completed",
			Usage: &ResponsesUsage{
				InputTokens:  10,
				OutputTokens: 4,
				TotalTokens:  14,
			},
		},
	}, state)
	allChunks = append(allChunks, chunks...)

	// Verify: role chunk + 4 text chunks + finish chunk + usage chunk = 7
	require.Len(t, allChunks, 7)

	// First chunk has role
	assert.Equal(t, "assistant", allChunks[0].Choices[0].Delta.Role)

	// Text chunks
	var fullText string
	for i := 1; i <= 4; i++ {
		require.NotNil(t, allChunks[i].Choices[0].Delta.Content)
		fullText += *allChunks[i].Choices[0].Delta.Content
	}
	assert.Equal(t, "Hello, world!", fullText)

	// Finish chunk
	require.NotNil(t, allChunks[5].Choices[0].FinishReason)
	assert.Equal(t, "stop", *allChunks[5].Choices[0].FinishReason)

	// Usage chunk
	require.NotNil(t, allChunks[6].Usage)
	assert.Equal(t, 10, allChunks[6].Usage.PromptTokens)
	assert.Equal(t, 4, allChunks[6].Usage.CompletionTokens)

	// All chunks share the same ID
	for _, c := range allChunks {
		assert.Equal(t, "resp_rt", c.ID)
	}
}
|
||||||
385
backend/internal/pkg/apicompat/chatcompletions_to_responses.go
Normal file
385
backend/internal/pkg/apicompat/chatcompletions_to_responses.go
Normal file
@@ -0,0 +1,385 @@
|
|||||||
|
package apicompat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ChatCompletionsToResponses converts a Chat Completions request into a
// Responses API request. The upstream always streams, so Stream is forced to
// true. store is always false and reasoning.encrypted_content is always
// included so that the response translator has full context.
//
// Field mapping performed here:
//   - messages                → input items (via convertChatMessagesToResponsesInput)
//   - max_tokens /
//     max_completion_tokens   → max_output_tokens (max_completion_tokens wins,
//     clamped up to minMaxOutputTokens)
//   - reasoning_effort        → reasoning.effort with summary="auto"
//   - tools[] / functions[]   → tools
//   - tool_choice /
//     function_call           → tool_choice
//
// Returns an error only if message content or function_call cannot be parsed.
func ChatCompletionsToResponses(req *ChatCompletionsRequest) (*ResponsesRequest, error) {
	input, err := convertChatMessagesToResponsesInput(req.Messages)
	if err != nil {
		return nil, err
	}

	inputJSON, err := json.Marshal(input)
	if err != nil {
		return nil, err
	}

	out := &ResponsesRequest{
		Model:       req.Model,
		Input:       inputJSON,
		Temperature: req.Temperature,
		TopP:        req.TopP,
		Stream:      true, // upstream always streams
		Include:     []string{"reasoning.encrypted_content"},
		ServiceTier: req.ServiceTier,
	}

	// Never persist conversations upstream.
	storeFalse := false
	out.Store = &storeFalse

	// max_tokens / max_completion_tokens → max_output_tokens, prefer max_completion_tokens
	maxTokens := 0
	if req.MaxTokens != nil {
		maxTokens = *req.MaxTokens
	}
	if req.MaxCompletionTokens != nil {
		maxTokens = *req.MaxCompletionTokens
	}
	if maxTokens > 0 {
		v := maxTokens
		// Clamp up to the project-wide floor so upstream does not reject
		// very small limits.
		if v < minMaxOutputTokens {
			v = minMaxOutputTokens
		}
		out.MaxOutputTokens = &v
	}

	// reasoning_effort → reasoning.effort + reasoning.summary="auto"
	if req.ReasoningEffort != "" {
		out.Reasoning = &ResponsesReasoning{
			Effort:  req.ReasoningEffort,
			Summary: "auto",
		}
	}

	// tools[] and legacy functions[] → ResponsesTool[]
	if len(req.Tools) > 0 || len(req.Functions) > 0 {
		out.Tools = convertChatToolsToResponses(req.Tools, req.Functions)
	}

	// tool_choice: already compatible format — pass through directly.
	// Legacy function_call needs mapping.
	if len(req.ToolChoice) > 0 {
		out.ToolChoice = req.ToolChoice
	} else if len(req.FunctionCall) > 0 {
		tc, err := convertChatFunctionCallToToolChoice(req.FunctionCall)
		if err != nil {
			return nil, fmt.Errorf("convert function_call: %w", err)
		}
		out.ToolChoice = tc
	}

	return out, nil
}
|
||||||
|
|
||||||
|
// convertChatMessagesToResponsesInput converts the Chat Completions messages
|
||||||
|
// array into a Responses API input items array.
|
||||||
|
func convertChatMessagesToResponsesInput(msgs []ChatMessage) ([]ResponsesInputItem, error) {
|
||||||
|
var out []ResponsesInputItem
|
||||||
|
for _, m := range msgs {
|
||||||
|
items, err := chatMessageToResponsesItems(m)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
out = append(out, items...)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// chatMessageToResponsesItems converts a single ChatMessage into one or more
|
||||||
|
// ResponsesInputItem values.
|
||||||
|
func chatMessageToResponsesItems(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||||
|
switch m.Role {
|
||||||
|
case "system":
|
||||||
|
return chatSystemToResponses(m)
|
||||||
|
case "user":
|
||||||
|
return chatUserToResponses(m)
|
||||||
|
case "assistant":
|
||||||
|
return chatAssistantToResponses(m)
|
||||||
|
case "tool":
|
||||||
|
return chatToolToResponses(m)
|
||||||
|
case "function":
|
||||||
|
return chatFunctionToResponses(m)
|
||||||
|
default:
|
||||||
|
return chatUserToResponses(m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// chatSystemToResponses converts a system message.
|
||||||
|
func chatSystemToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||||
|
text, err := parseChatContent(m.Content)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
content, err := json.Marshal(text)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return []ResponsesInputItem{{Role: "system", Content: content}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// chatUserToResponses converts a user message, handling both plain strings and
|
||||||
|
// multi-modal content arrays.
|
||||||
|
func chatUserToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||||
|
// Try plain string first.
|
||||||
|
var s string
|
||||||
|
if err := json.Unmarshal(m.Content, &s); err == nil {
|
||||||
|
content, _ := json.Marshal(s)
|
||||||
|
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var parts []ChatContentPart
|
||||||
|
if err := json.Unmarshal(m.Content, &parts); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse user content: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var responseParts []ResponsesContentPart
|
||||||
|
for _, p := range parts {
|
||||||
|
switch p.Type {
|
||||||
|
case "text":
|
||||||
|
if p.Text != "" {
|
||||||
|
responseParts = append(responseParts, ResponsesContentPart{
|
||||||
|
Type: "input_text",
|
||||||
|
Text: p.Text,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
case "image_url":
|
||||||
|
if p.ImageURL != nil && p.ImageURL.URL != "" {
|
||||||
|
responseParts = append(responseParts, ResponsesContentPart{
|
||||||
|
Type: "input_image",
|
||||||
|
ImageURL: p.ImageURL.URL,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := json.Marshal(responseParts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// chatAssistantToResponses converts an assistant message. If there is both
|
||||||
|
// text content and tool_calls, the text is emitted as an assistant message
|
||||||
|
// first, then each tool_call becomes a function_call item. If the content is
|
||||||
|
// empty/nil and there are tool_calls, only function_call items are emitted.
|
||||||
|
func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||||
|
var items []ResponsesInputItem
|
||||||
|
|
||||||
|
// Emit assistant message with output_text if content is non-empty.
|
||||||
|
if len(m.Content) > 0 {
|
||||||
|
s, err := parseAssistantContent(m.Content)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if s != "" {
|
||||||
|
parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
|
||||||
|
partsJSON, err := json.Marshal(parts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emit one function_call item per tool_call.
|
||||||
|
for _, tc := range m.ToolCalls {
|
||||||
|
args := tc.Function.Arguments
|
||||||
|
if args == "" {
|
||||||
|
args = "{}"
|
||||||
|
}
|
||||||
|
items = append(items, ResponsesInputItem{
|
||||||
|
Type: "function_call",
|
||||||
|
CallID: tc.ID,
|
||||||
|
Name: tc.Function.Name,
|
||||||
|
Arguments: args,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return items, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseAssistantContent returns assistant content as plain text.
|
||||||
|
//
|
||||||
|
// Supported formats:
|
||||||
|
// - JSON string
|
||||||
|
// - JSON array of typed parts (e.g. [{"type":"text","text":"..."}])
|
||||||
|
//
|
||||||
|
// For structured thinking/reasoning parts, it preserves semantics by wrapping
|
||||||
|
// the text in explicit tags so downstream can still distinguish it from normal text.
|
||||||
|
func parseAssistantContent(raw json.RawMessage) (string, error) {
|
||||||
|
if len(raw) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var s string
|
||||||
|
if err := json.Unmarshal(raw, &s); err == nil {
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var parts []map[string]any
|
||||||
|
if err := json.Unmarshal(raw, &parts); err != nil {
|
||||||
|
// Keep compatibility with prior behavior: unsupported assistant content
|
||||||
|
// formats are ignored instead of failing the whole request conversion.
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
write := func(v string) error {
|
||||||
|
_, err := b.WriteString(v)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, p := range parts {
|
||||||
|
typ, _ := p["type"].(string)
|
||||||
|
text, _ := p["text"].(string)
|
||||||
|
thinking, _ := p["thinking"].(string)
|
||||||
|
|
||||||
|
switch typ {
|
||||||
|
case "thinking", "reasoning":
|
||||||
|
if thinking != "" {
|
||||||
|
if err := write("<thinking>"); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := write(thinking); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := write("</thinking>"); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
} else if text != "" {
|
||||||
|
if err := write("<thinking>"); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := write(text); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := write("</thinking>"); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if text != "" {
|
||||||
|
if err := write(text); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// chatToolToResponses converts a tool result message (role=tool) into a
|
||||||
|
// function_call_output item.
|
||||||
|
func chatToolToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||||
|
output, err := parseChatContent(m.Content)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if output == "" {
|
||||||
|
output = "(empty)"
|
||||||
|
}
|
||||||
|
return []ResponsesInputItem{{
|
||||||
|
Type: "function_call_output",
|
||||||
|
CallID: m.ToolCallID,
|
||||||
|
Output: output,
|
||||||
|
}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// chatFunctionToResponses converts a legacy function result message
|
||||||
|
// (role=function) into a function_call_output item. The Name field is used as
|
||||||
|
// call_id since legacy function calls do not carry a separate call_id.
|
||||||
|
func chatFunctionToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||||
|
output, err := parseChatContent(m.Content)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if output == "" {
|
||||||
|
output = "(empty)"
|
||||||
|
}
|
||||||
|
return []ResponsesInputItem{{
|
||||||
|
Type: "function_call_output",
|
||||||
|
CallID: m.Name,
|
||||||
|
Output: output,
|
||||||
|
}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseChatContent returns the string value of a ChatMessage Content field.
|
||||||
|
// Content must be a JSON string. Returns "" if content is null or empty.
|
||||||
|
func parseChatContent(raw json.RawMessage) (string, error) {
|
||||||
|
if len(raw) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
var s string
|
||||||
|
if err := json.Unmarshal(raw, &s); err != nil {
|
||||||
|
return "", fmt.Errorf("parse content as string: %w", err)
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertChatToolsToResponses maps Chat Completions tool definitions and legacy
|
||||||
|
// function definitions to Responses API tool definitions.
|
||||||
|
func convertChatToolsToResponses(tools []ChatTool, functions []ChatFunction) []ResponsesTool {
|
||||||
|
var out []ResponsesTool
|
||||||
|
|
||||||
|
for _, t := range tools {
|
||||||
|
if t.Type != "function" || t.Function == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rt := ResponsesTool{
|
||||||
|
Type: "function",
|
||||||
|
Name: t.Function.Name,
|
||||||
|
Description: t.Function.Description,
|
||||||
|
Parameters: t.Function.Parameters,
|
||||||
|
Strict: t.Function.Strict,
|
||||||
|
}
|
||||||
|
out = append(out, rt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Legacy functions[] are treated as function-type tools.
|
||||||
|
for _, f := range functions {
|
||||||
|
rt := ResponsesTool{
|
||||||
|
Type: "function",
|
||||||
|
Name: f.Name,
|
||||||
|
Description: f.Description,
|
||||||
|
Parameters: f.Parameters,
|
||||||
|
Strict: f.Strict,
|
||||||
|
}
|
||||||
|
out = append(out, rt)
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertChatFunctionCallToToolChoice maps the legacy function_call field to a
|
||||||
|
// Responses API tool_choice value.
|
||||||
|
//
|
||||||
|
// "auto" → "auto"
|
||||||
|
// "none" → "none"
|
||||||
|
// {"name":"X"} → {"type":"function","function":{"name":"X"}}
|
||||||
|
func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage, error) {
|
||||||
|
// Try string first ("auto", "none", etc.) — pass through as-is.
|
||||||
|
var s string
|
||||||
|
if err := json.Unmarshal(raw, &s); err == nil {
|
||||||
|
return json.Marshal(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Object form: {"name":"X"}
|
||||||
|
var obj struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(raw, &obj); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return json.Marshal(map[string]any{
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]string{"name": obj.Name},
|
||||||
|
})
|
||||||
|
}
|
||||||
374
backend/internal/pkg/apicompat/responses_to_chatcompletions.go
Normal file
374
backend/internal/pkg/apicompat/responses_to_chatcompletions.go
Normal file
@@ -0,0 +1,374 @@
|
|||||||
|
package apicompat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Non-streaming: ResponsesResponse → ChatCompletionsResponse
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// ResponsesToChatCompletions converts a Responses API response into a Chat
|
||||||
|
// Completions response. Text output items are concatenated into
|
||||||
|
// choices[0].message.content; function_call items become tool_calls.
|
||||||
|
func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatCompletionsResponse {
|
||||||
|
id := resp.ID
|
||||||
|
if id == "" {
|
||||||
|
id = generateChatCmplID()
|
||||||
|
}
|
||||||
|
|
||||||
|
out := &ChatCompletionsResponse{
|
||||||
|
ID: id,
|
||||||
|
Object: "chat.completion",
|
||||||
|
Created: time.Now().Unix(),
|
||||||
|
Model: model,
|
||||||
|
}
|
||||||
|
|
||||||
|
var contentText string
|
||||||
|
var reasoningText string
|
||||||
|
var toolCalls []ChatToolCall
|
||||||
|
|
||||||
|
for _, item := range resp.Output {
|
||||||
|
switch item.Type {
|
||||||
|
case "message":
|
||||||
|
for _, part := range item.Content {
|
||||||
|
if part.Type == "output_text" && part.Text != "" {
|
||||||
|
contentText += part.Text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "function_call":
|
||||||
|
toolCalls = append(toolCalls, ChatToolCall{
|
||||||
|
ID: item.CallID,
|
||||||
|
Type: "function",
|
||||||
|
Function: ChatFunctionCall{
|
||||||
|
Name: item.Name,
|
||||||
|
Arguments: item.Arguments,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
case "reasoning":
|
||||||
|
for _, s := range item.Summary {
|
||||||
|
if s.Type == "summary_text" && s.Text != "" {
|
||||||
|
reasoningText += s.Text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "web_search_call":
|
||||||
|
// silently consumed — results already incorporated into text output
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := ChatMessage{Role: "assistant"}
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
msg.ToolCalls = toolCalls
|
||||||
|
}
|
||||||
|
if contentText != "" {
|
||||||
|
raw, _ := json.Marshal(contentText)
|
||||||
|
msg.Content = raw
|
||||||
|
}
|
||||||
|
if reasoningText != "" {
|
||||||
|
msg.ReasoningContent = reasoningText
|
||||||
|
}
|
||||||
|
|
||||||
|
finishReason := responsesStatusToChatFinishReason(resp.Status, resp.IncompleteDetails, toolCalls)
|
||||||
|
|
||||||
|
out.Choices = []ChatChoice{{
|
||||||
|
Index: 0,
|
||||||
|
Message: msg,
|
||||||
|
FinishReason: finishReason,
|
||||||
|
}}
|
||||||
|
|
||||||
|
if resp.Usage != nil {
|
||||||
|
usage := &ChatUsage{
|
||||||
|
PromptTokens: resp.Usage.InputTokens,
|
||||||
|
CompletionTokens: resp.Usage.OutputTokens,
|
||||||
|
TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens,
|
||||||
|
}
|
||||||
|
if resp.Usage.InputTokensDetails != nil && resp.Usage.InputTokensDetails.CachedTokens > 0 {
|
||||||
|
usage.PromptTokensDetails = &ChatTokenDetails{
|
||||||
|
CachedTokens: resp.Usage.InputTokensDetails.CachedTokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out.Usage = usage
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func responsesStatusToChatFinishReason(status string, details *ResponsesIncompleteDetails, toolCalls []ChatToolCall) string {
|
||||||
|
switch status {
|
||||||
|
case "incomplete":
|
||||||
|
if details != nil && details.Reason == "max_output_tokens" {
|
||||||
|
return "length"
|
||||||
|
}
|
||||||
|
return "stop"
|
||||||
|
case "completed":
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
return "tool_calls"
|
||||||
|
}
|
||||||
|
return "stop"
|
||||||
|
default:
|
||||||
|
return "stop"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Streaming: ResponsesStreamEvent → []ChatCompletionsChunk (stateful converter)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// ResponsesEventToChatState tracks state for converting a sequence of Responses
|
||||||
|
// SSE events into Chat Completions SSE chunks.
|
||||||
|
type ResponsesEventToChatState struct {
|
||||||
|
ID string
|
||||||
|
Model string
|
||||||
|
Created int64
|
||||||
|
SentRole bool
|
||||||
|
SawToolCall bool
|
||||||
|
SawText bool
|
||||||
|
Finalized bool // true after finish chunk has been emitted
|
||||||
|
NextToolCallIndex int // next sequential tool_call index to assign
|
||||||
|
OutputIndexToToolIndex map[int]int // Responses output_index → Chat tool_calls index
|
||||||
|
IncludeUsage bool
|
||||||
|
Usage *ChatUsage
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewResponsesEventToChatState returns an initialised stream state.
|
||||||
|
func NewResponsesEventToChatState() *ResponsesEventToChatState {
|
||||||
|
return &ResponsesEventToChatState{
|
||||||
|
ID: generateChatCmplID(),
|
||||||
|
Created: time.Now().Unix(),
|
||||||
|
OutputIndexToToolIndex: make(map[int]int),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponsesEventToChatChunks converts a single Responses SSE event into zero
|
||||||
|
// or more Chat Completions chunks, updating state as it goes.
|
||||||
|
func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||||
|
switch evt.Type {
|
||||||
|
case "response.created":
|
||||||
|
return resToChatHandleCreated(evt, state)
|
||||||
|
case "response.output_text.delta":
|
||||||
|
return resToChatHandleTextDelta(evt, state)
|
||||||
|
case "response.output_item.added":
|
||||||
|
return resToChatHandleOutputItemAdded(evt, state)
|
||||||
|
case "response.function_call_arguments.delta":
|
||||||
|
return resToChatHandleFuncArgsDelta(evt, state)
|
||||||
|
case "response.reasoning_summary_text.delta":
|
||||||
|
return resToChatHandleReasoningDelta(evt, state)
|
||||||
|
case "response.reasoning_summary_text.done":
|
||||||
|
return nil
|
||||||
|
case "response.completed", "response.incomplete", "response.failed":
|
||||||
|
return resToChatHandleCompleted(evt, state)
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FinalizeResponsesChatStream emits a final chunk with finish_reason if the
|
||||||
|
// stream ended without a proper completion event (e.g. upstream disconnect).
|
||||||
|
// It is idempotent: if a completion event already emitted the finish chunk,
|
||||||
|
// this returns nil.
|
||||||
|
func FinalizeResponsesChatStream(state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||||
|
if state.Finalized {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
state.Finalized = true
|
||||||
|
|
||||||
|
finishReason := "stop"
|
||||||
|
if state.SawToolCall {
|
||||||
|
finishReason = "tool_calls"
|
||||||
|
}
|
||||||
|
|
||||||
|
chunks := []ChatCompletionsChunk{makeChatFinishChunk(state, finishReason)}
|
||||||
|
|
||||||
|
if state.IncludeUsage && state.Usage != nil {
|
||||||
|
chunks = append(chunks, ChatCompletionsChunk{
|
||||||
|
ID: state.ID,
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: state.Created,
|
||||||
|
Model: state.Model,
|
||||||
|
Choices: []ChatChunkChoice{},
|
||||||
|
Usage: state.Usage,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatChunkToSSE formats a ChatCompletionsChunk as an SSE data line.
|
||||||
|
func ChatChunkToSSE(chunk ChatCompletionsChunk) (string, error) {
|
||||||
|
data, err := json.Marshal(chunk)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("data: %s\n\n", data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- internal handlers ---
|
||||||
|
|
||||||
|
func resToChatHandleCreated(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||||
|
if evt.Response != nil {
|
||||||
|
if evt.Response.ID != "" {
|
||||||
|
state.ID = evt.Response.ID
|
||||||
|
}
|
||||||
|
if state.Model == "" && evt.Response.Model != "" {
|
||||||
|
state.Model = evt.Response.Model
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Emit the role chunk.
|
||||||
|
if state.SentRole {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
state.SentRole = true
|
||||||
|
|
||||||
|
role := "assistant"
|
||||||
|
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Role: role})}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resToChatHandleTextDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||||
|
if evt.Delta == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
state.SawText = true
|
||||||
|
content := evt.Delta
|
||||||
|
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resToChatHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||||
|
if evt.Item == nil || evt.Item.Type != "function_call" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
state.SawToolCall = true
|
||||||
|
idx := state.NextToolCallIndex
|
||||||
|
state.OutputIndexToToolIndex[evt.OutputIndex] = idx
|
||||||
|
state.NextToolCallIndex++
|
||||||
|
|
||||||
|
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{
|
||||||
|
ToolCalls: []ChatToolCall{{
|
||||||
|
Index: &idx,
|
||||||
|
ID: evt.Item.CallID,
|
||||||
|
Type: "function",
|
||||||
|
Function: ChatFunctionCall{
|
||||||
|
Name: evt.Item.Name,
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
})}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resToChatHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||||
|
if evt.Delta == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
idx, ok := state.OutputIndexToToolIndex[evt.OutputIndex]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{
|
||||||
|
ToolCalls: []ChatToolCall{{
|
||||||
|
Index: &idx,
|
||||||
|
Function: ChatFunctionCall{
|
||||||
|
Arguments: evt.Delta,
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
})}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resToChatHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||||
|
if evt.Delta == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
reasoning := evt.Delta
|
||||||
|
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{ReasoningContent: &reasoning})}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||||
|
state.Finalized = true
|
||||||
|
finishReason := "stop"
|
||||||
|
|
||||||
|
if evt.Response != nil {
|
||||||
|
if evt.Response.Usage != nil {
|
||||||
|
u := evt.Response.Usage
|
||||||
|
usage := &ChatUsage{
|
||||||
|
PromptTokens: u.InputTokens,
|
||||||
|
CompletionTokens: u.OutputTokens,
|
||||||
|
TotalTokens: u.InputTokens + u.OutputTokens,
|
||||||
|
}
|
||||||
|
if u.InputTokensDetails != nil && u.InputTokensDetails.CachedTokens > 0 {
|
||||||
|
usage.PromptTokensDetails = &ChatTokenDetails{
|
||||||
|
CachedTokens: u.InputTokensDetails.CachedTokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state.Usage = usage
|
||||||
|
}
|
||||||
|
|
||||||
|
switch evt.Response.Status {
|
||||||
|
case "incomplete":
|
||||||
|
if evt.Response.IncompleteDetails != nil && evt.Response.IncompleteDetails.Reason == "max_output_tokens" {
|
||||||
|
finishReason = "length"
|
||||||
|
}
|
||||||
|
case "completed":
|
||||||
|
if state.SawToolCall {
|
||||||
|
finishReason = "tool_calls"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if state.SawToolCall {
|
||||||
|
finishReason = "tool_calls"
|
||||||
|
}
|
||||||
|
|
||||||
|
var chunks []ChatCompletionsChunk
|
||||||
|
chunks = append(chunks, makeChatFinishChunk(state, finishReason))
|
||||||
|
|
||||||
|
if state.IncludeUsage && state.Usage != nil {
|
||||||
|
chunks = append(chunks, ChatCompletionsChunk{
|
||||||
|
ID: state.ID,
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: state.Created,
|
||||||
|
Model: state.Model,
|
||||||
|
Choices: []ChatChunkChoice{},
|
||||||
|
Usage: state.Usage,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeChatDeltaChunk(state *ResponsesEventToChatState, delta ChatDelta) ChatCompletionsChunk {
|
||||||
|
return ChatCompletionsChunk{
|
||||||
|
ID: state.ID,
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: state.Created,
|
||||||
|
Model: state.Model,
|
||||||
|
Choices: []ChatChunkChoice{{
|
||||||
|
Index: 0,
|
||||||
|
Delta: delta,
|
||||||
|
FinishReason: nil,
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeChatFinishChunk(state *ResponsesEventToChatState, finishReason string) ChatCompletionsChunk {
|
||||||
|
empty := ""
|
||||||
|
return ChatCompletionsChunk{
|
||||||
|
ID: state.ID,
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: state.Created,
|
||||||
|
Model: state.Model,
|
||||||
|
Choices: []ChatChunkChoice{{
|
||||||
|
Index: 0,
|
||||||
|
Delta: ChatDelta{Content: &empty},
|
||||||
|
FinishReason: &finishReason,
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateChatCmplID returns a "chatcmpl-" prefixed random hex ID.
|
||||||
|
func generateChatCmplID() string {
|
||||||
|
b := make([]byte, 12)
|
||||||
|
_, _ = rand.Read(b)
|
||||||
|
return "chatcmpl-" + hex.EncodeToString(b)
|
||||||
|
}
|
||||||
@@ -329,6 +329,150 @@ type ResponsesStreamEvent struct {
|
|||||||
SequenceNumber int `json:"sequence_number,omitempty"`
|
SequenceNumber int `json:"sequence_number,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// OpenAI Chat Completions API types
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// ChatCompletionsRequest is the request body for POST /v1/chat/completions.
|
||||||
|
type ChatCompletionsRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Messages []ChatMessage `json:"messages"`
|
||||||
|
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||||
|
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
|
||||||
|
Temperature *float64 `json:"temperature,omitempty"`
|
||||||
|
TopP *float64 `json:"top_p,omitempty"`
|
||||||
|
Stream bool `json:"stream,omitempty"`
|
||||||
|
StreamOptions *ChatStreamOptions `json:"stream_options,omitempty"`
|
||||||
|
Tools []ChatTool `json:"tools,omitempty"`
|
||||||
|
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||||
|
ReasoningEffort string `json:"reasoning_effort,omitempty"` // "low" | "medium" | "high"
|
||||||
|
ServiceTier string `json:"service_tier,omitempty"`
|
||||||
|
Stop json.RawMessage `json:"stop,omitempty"` // string or []string
|
||||||
|
|
||||||
|
// Legacy function calling (deprecated but still supported)
|
||||||
|
Functions []ChatFunction `json:"functions,omitempty"`
|
||||||
|
FunctionCall json.RawMessage `json:"function_call,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatStreamOptions configures streaming behavior.
|
||||||
|
type ChatStreamOptions struct {
|
||||||
|
IncludeUsage bool `json:"include_usage,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatMessage is a single message in the Chat Completions conversation.
|
||||||
|
type ChatMessage struct {
|
||||||
|
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
|
||||||
|
Content json.RawMessage `json:"content,omitempty"`
|
||||||
|
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||||
|
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||||
|
|
||||||
|
// Legacy function calling
|
||||||
|
FunctionCall *ChatFunctionCall `json:"function_call,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatContentPart is a typed content part in a multi-modal message.
|
||||||
|
type ChatContentPart struct {
|
||||||
|
Type string `json:"type"` // "text" | "image_url"
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
ImageURL *ChatImageURL `json:"image_url,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatImageURL contains the URL for an image content part.
|
||||||
|
type ChatImageURL struct {
|
||||||
|
URL string `json:"url"`
|
||||||
|
Detail string `json:"detail,omitempty"` // "auto" | "low" | "high"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatTool describes a tool available to the model.
|
||||||
|
type ChatTool struct {
|
||||||
|
Type string `json:"type"` // "function"
|
||||||
|
Function *ChatFunction `json:"function,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatFunction describes a function tool definition.
|
||||||
|
type ChatFunction struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
Parameters json.RawMessage `json:"parameters,omitempty"`
|
||||||
|
Strict *bool `json:"strict,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatToolCall represents a tool call made by the assistant.
|
||||||
|
// Index is only populated in streaming chunks (omitted in non-streaming responses).
|
||||||
|
type ChatToolCall struct {
|
||||||
|
Index *int `json:"index,omitempty"`
|
||||||
|
ID string `json:"id,omitempty"`
|
||||||
|
Type string `json:"type,omitempty"` // "function"
|
||||||
|
Function ChatFunctionCall `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatFunctionCall contains the function name and arguments.
|
||||||
|
type ChatFunctionCall struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments string `json:"arguments"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatCompletionsResponse is the non-streaming response from POST /v1/chat/completions.
|
||||||
|
type ChatCompletionsResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"` // "chat.completion"
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Choices []ChatChoice `json:"choices"`
|
||||||
|
Usage *ChatUsage `json:"usage,omitempty"`
|
||||||
|
SystemFingerprint string `json:"system_fingerprint,omitempty"`
|
||||||
|
ServiceTier string `json:"service_tier,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatChoice is a single completion choice.
|
||||||
|
type ChatChoice struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
Message ChatMessage `json:"message"`
|
||||||
|
FinishReason string `json:"finish_reason"` // "stop" | "length" | "tool_calls" | "content_filter"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatUsage holds token counts in Chat Completions format.
|
||||||
|
type ChatUsage struct {
|
||||||
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
|
CompletionTokens int `json:"completion_tokens"`
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
PromptTokensDetails *ChatTokenDetails `json:"prompt_tokens_details,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatTokenDetails provides a breakdown of token usage.
|
||||||
|
type ChatTokenDetails struct {
|
||||||
|
CachedTokens int `json:"cached_tokens,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatCompletionsChunk is a single streaming chunk from POST /v1/chat/completions.
|
||||||
|
type ChatCompletionsChunk struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"` // "chat.completion.chunk"
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Choices []ChatChunkChoice `json:"choices"`
|
||||||
|
Usage *ChatUsage `json:"usage,omitempty"`
|
||||||
|
SystemFingerprint string `json:"system_fingerprint,omitempty"`
|
||||||
|
ServiceTier string `json:"service_tier,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatChunkChoice is a single choice in a streaming chunk.
|
||||||
|
type ChatChunkChoice struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
Delta ChatDelta `json:"delta"`
|
||||||
|
FinishReason *string `json:"finish_reason"` // pointer: null when not final
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatDelta carries incremental content in a streaming chunk.
|
||||||
|
type ChatDelta struct {
|
||||||
|
Role string `json:"role,omitempty"`
|
||||||
|
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
|
||||||
|
ReasoningContent *string `json:"reasoning_content,omitempty"`
|
||||||
|
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
// Shared constants
|
// Shared constants
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ const (
|
|||||||
|
|
||||||
// DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。
|
// DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。
|
||||||
// 这些 token 是客户端特有的,不应透传给上游 API。
|
// 这些 token 是客户端特有的,不应透传给上游 API。
|
||||||
var DroppedBetas = []string{BetaFastMode}
|
var DroppedBetas = []string{}
|
||||||
|
|
||||||
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
||||||
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||||
|
|||||||
@@ -18,10 +18,12 @@ func DefaultModels() []Model {
|
|||||||
return []Model{
|
return []Model{
|
||||||
{Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods},
|
{Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods},
|
||||||
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
|
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
|
||||||
|
{Name: "models/gemini-2.5-flash-image", SupportedGenerationMethods: methods},
|
||||||
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
|
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
|
||||||
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
|
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
|
||||||
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
|
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
|
||||||
{Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods},
|
{Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods},
|
||||||
|
{Name: "models/gemini-3.1-flash-image", SupportedGenerationMethods: methods},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
28
backend/internal/pkg/gemini/models_test.go
Normal file
28
backend/internal/pkg/gemini/models_test.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package gemini
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestDefaultModels_ContainsImageModels(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
models := DefaultModels()
|
||||||
|
byName := make(map[string]Model, len(models))
|
||||||
|
for _, model := range models {
|
||||||
|
byName[model.Name] = model
|
||||||
|
}
|
||||||
|
|
||||||
|
required := []string{
|
||||||
|
"models/gemini-2.5-flash-image",
|
||||||
|
"models/gemini-3.1-flash-image",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, name := range required {
|
||||||
|
model, ok := byName[name]
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected fallback model %q to exist", name)
|
||||||
|
}
|
||||||
|
if len(model.SupportedGenerationMethods) == 0 {
|
||||||
|
t.Fatalf("expected fallback model %q to advertise generation methods", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -13,10 +13,12 @@ type Model struct {
|
|||||||
var DefaultModels = []Model{
|
var DefaultModels = []Model{
|
||||||
{ID: "gemini-2.0-flash", Type: "model", DisplayName: "Gemini 2.0 Flash", CreatedAt: ""},
|
{ID: "gemini-2.0-flash", Type: "model", DisplayName: "Gemini 2.0 Flash", CreatedAt: ""},
|
||||||
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
|
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
|
||||||
|
{ID: "gemini-2.5-flash-image", Type: "model", DisplayName: "Gemini 2.5 Flash Image", CreatedAt: ""},
|
||||||
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
|
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
|
||||||
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
|
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
|
||||||
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
|
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
|
||||||
{ID: "gemini-3.1-pro-preview", Type: "model", DisplayName: "Gemini 3.1 Pro Preview", CreatedAt: ""},
|
{ID: "gemini-3.1-pro-preview", Type: "model", DisplayName: "Gemini 3.1 Pro Preview", CreatedAt: ""},
|
||||||
|
{ID: "gemini-3.1-flash-image", Type: "model", DisplayName: "Gemini 3.1 Flash Image", CreatedAt: ""},
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultTestModel is the default model to preselect in test flows.
|
// DefaultTestModel is the default model to preselect in test flows.
|
||||||
|
|||||||
23
backend/internal/pkg/geminicli/models_test.go
Normal file
23
backend/internal/pkg/geminicli/models_test.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package geminicli
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestDefaultModels_ContainsImageModels(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
byID := make(map[string]Model, len(DefaultModels))
|
||||||
|
for _, model := range DefaultModels {
|
||||||
|
byID[model.ID] = model
|
||||||
|
}
|
||||||
|
|
||||||
|
required := []string{
|
||||||
|
"gemini-2.5-flash-image",
|
||||||
|
"gemini-3.1-flash-image",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, id := range required {
|
||||||
|
if _, ok := byID[id]; !ok {
|
||||||
|
t.Fatalf("expected curated Gemini model %q to exist", id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -268,6 +268,7 @@ type IDTokenClaims struct {
|
|||||||
type OpenAIAuthClaims struct {
|
type OpenAIAuthClaims struct {
|
||||||
ChatGPTAccountID string `json:"chatgpt_account_id"`
|
ChatGPTAccountID string `json:"chatgpt_account_id"`
|
||||||
ChatGPTUserID string `json:"chatgpt_user_id"`
|
ChatGPTUserID string `json:"chatgpt_user_id"`
|
||||||
|
ChatGPTPlanType string `json:"chatgpt_plan_type"`
|
||||||
UserID string `json:"user_id"`
|
UserID string `json:"user_id"`
|
||||||
Organizations []OrganizationClaim `json:"organizations"`
|
Organizations []OrganizationClaim `json:"organizations"`
|
||||||
}
|
}
|
||||||
@@ -325,12 +326,9 @@ func (r *RefreshTokenRequest) ToFormData() string {
|
|||||||
return params.Encode()
|
return params.Encode()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseIDToken parses the ID Token JWT and extracts claims.
|
// DecodeIDToken decodes the ID Token JWT payload without validating expiration.
|
||||||
// 注意:当前仅解码 payload 并校验 exp,未验证 JWT 签名。
|
// Use this for best-effort extraction (e.g., during data import) where the token may be expired.
|
||||||
// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名:
|
func DecodeIDToken(idToken string) (*IDTokenClaims, error) {
|
||||||
//
|
|
||||||
// https://auth.openai.com/.well-known/jwks.json
|
|
||||||
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
|
||||||
parts := strings.Split(idToken, ".")
|
parts := strings.Split(idToken, ".")
|
||||||
if len(parts) != 3 {
|
if len(parts) != 3 {
|
||||||
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
|
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
|
||||||
@@ -360,6 +358,20 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
|||||||
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
|
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return &claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseIDToken parses the ID Token JWT and extracts claims.
|
||||||
|
// 注意:当前仅解码 payload 并校验 exp,未验证 JWT 签名。
|
||||||
|
// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名:
|
||||||
|
//
|
||||||
|
// https://auth.openai.com/.well-known/jwks.json
|
||||||
|
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||||
|
claims, err := DecodeIDToken(idToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// 校验 ID Token 是否已过期(允许 2 分钟时钟偏差,防止因服务器时钟略有差异误判刚颁发的令牌)
|
// 校验 ID Token 是否已过期(允许 2 分钟时钟偏差,防止因服务器时钟略有差异误判刚颁发的令牌)
|
||||||
const clockSkewTolerance = 120 // 秒
|
const clockSkewTolerance = 120 // 秒
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
@@ -367,7 +379,7 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
|||||||
return nil, fmt.Errorf("id_token has expired (exp: %d, now: %d, skew_tolerance: %ds)", claims.Exp, now, clockSkewTolerance)
|
return nil, fmt.Errorf("id_token has expired (exp: %d, now: %d, skew_tolerance: %ds)", claims.Exp, now, clockSkewTolerance)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &claims, nil
|
return claims, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UserInfo represents user information extracted from ID Token claims.
|
// UserInfo represents user information extracted from ID Token claims.
|
||||||
@@ -375,6 +387,7 @@ type UserInfo struct {
|
|||||||
Email string
|
Email string
|
||||||
ChatGPTAccountID string
|
ChatGPTAccountID string
|
||||||
ChatGPTUserID string
|
ChatGPTUserID string
|
||||||
|
PlanType string
|
||||||
UserID string
|
UserID string
|
||||||
OrganizationID string
|
OrganizationID string
|
||||||
Organizations []OrganizationClaim
|
Organizations []OrganizationClaim
|
||||||
@@ -389,6 +402,7 @@ func (c *IDTokenClaims) GetUserInfo() *UserInfo {
|
|||||||
if c.OpenAIAuth != nil {
|
if c.OpenAIAuth != nil {
|
||||||
info.ChatGPTAccountID = c.OpenAIAuth.ChatGPTAccountID
|
info.ChatGPTAccountID = c.OpenAIAuth.ChatGPTAccountID
|
||||||
info.ChatGPTUserID = c.OpenAIAuth.ChatGPTUserID
|
info.ChatGPTUserID = c.OpenAIAuth.ChatGPTUserID
|
||||||
|
info.PlanType = c.OpenAIAuth.ChatGPTPlanType
|
||||||
info.UserID = c.OpenAIAuth.UserID
|
info.UserID = c.OpenAIAuth.UserID
|
||||||
info.Organizations = c.OpenAIAuth.Organizations
|
info.Organizations = c.OpenAIAuth.Organizations
|
||||||
|
|
||||||
|
|||||||
@@ -47,6 +47,15 @@ func Created(c *gin.Context, data any) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Accepted 返回异步接受响应 (HTTP 202)
|
||||||
|
func Accepted(c *gin.Context, data any) {
|
||||||
|
c.JSON(http.StatusAccepted, Response{
|
||||||
|
Code: 0,
|
||||||
|
Message: "accepted",
|
||||||
|
Data: data,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Error 返回错误响应
|
// Error 返回错误响应
|
||||||
func Error(c *gin.Context, statusCode int, message string) {
|
func Error(c *gin.Context, statusCode int, message string) {
|
||||||
c.JSON(statusCode, Response{
|
c.JSON(statusCode, Response{
|
||||||
|
|||||||
@@ -3,6 +3,28 @@ package usagestats
|
|||||||
|
|
||||||
import "time"
|
import "time"
|
||||||
|
|
||||||
|
const (
|
||||||
|
ModelSourceRequested = "requested"
|
||||||
|
ModelSourceUpstream = "upstream"
|
||||||
|
ModelSourceMapping = "mapping"
|
||||||
|
)
|
||||||
|
|
||||||
|
func IsValidModelSource(source string) bool {
|
||||||
|
switch source {
|
||||||
|
case ModelSourceRequested, ModelSourceUpstream, ModelSourceMapping:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NormalizeModelSource(source string) string {
|
||||||
|
if IsValidModelSource(source) {
|
||||||
|
return source
|
||||||
|
}
|
||||||
|
return ModelSourceRequested
|
||||||
|
}
|
||||||
|
|
||||||
// DashboardStats 仪表盘统计
|
// DashboardStats 仪表盘统计
|
||||||
type DashboardStats struct {
|
type DashboardStats struct {
|
||||||
// 用户统计
|
// 用户统计
|
||||||
@@ -81,6 +103,22 @@ type ModelStat struct {
|
|||||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EndpointStat represents usage statistics for a single request endpoint.
|
||||||
|
type EndpointStat struct {
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
Requests int64 `json:"requests"`
|
||||||
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
|
Cost float64 `json:"cost"` // 标准计费
|
||||||
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupUsageSummary represents today's and cumulative cost for a single group.
|
||||||
|
type GroupUsageSummary struct {
|
||||||
|
GroupID int64 `json:"group_id"`
|
||||||
|
TodayCost float64 `json:"today_cost"`
|
||||||
|
TotalCost float64 `json:"total_cost"`
|
||||||
|
}
|
||||||
|
|
||||||
// GroupStat represents usage statistics for a single group
|
// GroupStat represents usage statistics for a single group
|
||||||
type GroupStat struct {
|
type GroupStat struct {
|
||||||
GroupID int64 `json:"group_id"`
|
GroupID int64 `json:"group_id"`
|
||||||
@@ -96,12 +134,49 @@ type UserUsageTrendPoint struct {
|
|||||||
Date string `json:"date"`
|
Date string `json:"date"`
|
||||||
UserID int64 `json:"user_id"`
|
UserID int64 `json:"user_id"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
|
Username string `json:"username"`
|
||||||
Requests int64 `json:"requests"`
|
Requests int64 `json:"requests"`
|
||||||
Tokens int64 `json:"tokens"`
|
Tokens int64 `json:"tokens"`
|
||||||
Cost float64 `json:"cost"` // 标准计费
|
Cost float64 `json:"cost"` // 标准计费
|
||||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UserSpendingRankingItem represents a user spending ranking row.
|
||||||
|
type UserSpendingRankingItem struct {
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
|
Requests int64 `json:"requests"`
|
||||||
|
Tokens int64 `json:"tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserSpendingRankingResponse represents ranking rows plus total spend for the time range.
|
||||||
|
type UserSpendingRankingResponse struct {
|
||||||
|
Ranking []UserSpendingRankingItem `json:"ranking"`
|
||||||
|
TotalActualCost float64 `json:"total_actual_cost"`
|
||||||
|
TotalRequests int64 `json:"total_requests"`
|
||||||
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserBreakdownItem represents per-user usage breakdown within a dimension (group, model, endpoint).
|
||||||
|
type UserBreakdownItem struct {
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
Requests int64 `json:"requests"`
|
||||||
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
|
Cost float64 `json:"cost"` // 标准计费
|
||||||
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserBreakdownDimension specifies the dimension to filter for user breakdown.
|
||||||
|
type UserBreakdownDimension struct {
|
||||||
|
GroupID int64 // filter by group_id (>0 to enable)
|
||||||
|
Model string // filter by model name (non-empty to enable)
|
||||||
|
ModelType string // "requested", "upstream", or "mapping"
|
||||||
|
Endpoint string // filter by endpoint value (non-empty to enable)
|
||||||
|
EndpointType string // "inbound", "upstream", or "path"
|
||||||
|
}
|
||||||
|
|
||||||
// APIKeyUsageTrendPoint represents API key usage trend data point
|
// APIKeyUsageTrendPoint represents API key usage trend data point
|
||||||
type APIKeyUsageTrendPoint struct {
|
type APIKeyUsageTrendPoint struct {
|
||||||
Date string `json:"date"`
|
Date string `json:"date"`
|
||||||
@@ -163,15 +238,18 @@ type UsageLogFilters struct {
|
|||||||
|
|
||||||
// UsageStats represents usage statistics
|
// UsageStats represents usage statistics
|
||||||
type UsageStats struct {
|
type UsageStats struct {
|
||||||
TotalRequests int64 `json:"total_requests"`
|
TotalRequests int64 `json:"total_requests"`
|
||||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||||
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
||||||
TotalTokens int64 `json:"total_tokens"`
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
TotalCost float64 `json:"total_cost"`
|
TotalCost float64 `json:"total_cost"`
|
||||||
TotalActualCost float64 `json:"total_actual_cost"`
|
TotalActualCost float64 `json:"total_actual_cost"`
|
||||||
TotalAccountCost *float64 `json:"total_account_cost,omitempty"`
|
TotalAccountCost *float64 `json:"total_account_cost,omitempty"`
|
||||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||||
|
Endpoints []EndpointStat `json:"endpoints,omitempty"`
|
||||||
|
UpstreamEndpoints []EndpointStat `json:"upstream_endpoints,omitempty"`
|
||||||
|
EndpointPaths []EndpointStat `json:"endpoint_paths,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// BatchUserUsageStats represents usage stats for a single user
|
// BatchUserUsageStats represents usage stats for a single user
|
||||||
@@ -238,7 +316,9 @@ type AccountUsageSummary struct {
|
|||||||
|
|
||||||
// AccountUsageStatsResponse represents the full usage statistics response for an account
|
// AccountUsageStatsResponse represents the full usage statistics response for an account
|
||||||
type AccountUsageStatsResponse struct {
|
type AccountUsageStatsResponse struct {
|
||||||
History []AccountUsageHistory `json:"history"`
|
History []AccountUsageHistory `json:"history"`
|
||||||
Summary AccountUsageSummary `json:"summary"`
|
Summary AccountUsageSummary `json:"summary"`
|
||||||
Models []ModelStat `json:"models"`
|
Models []ModelStat `json:"models"`
|
||||||
|
Endpoints []EndpointStat `json:"endpoints"`
|
||||||
|
UpstreamEndpoints []EndpointStat `json:"upstream_endpoints"`
|
||||||
}
|
}
|
||||||
|
|||||||
47
backend/internal/pkg/usagestats/usage_log_types_test.go
Normal file
47
backend/internal/pkg/usagestats/usage_log_types_test.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package usagestats
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestIsValidModelSource(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
source string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{name: "requested", source: ModelSourceRequested, want: true},
|
||||||
|
{name: "upstream", source: ModelSourceUpstream, want: true},
|
||||||
|
{name: "mapping", source: ModelSourceMapping, want: true},
|
||||||
|
{name: "invalid", source: "foobar", want: false},
|
||||||
|
{name: "empty", source: "", want: false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
if got := IsValidModelSource(tc.source); got != tc.want {
|
||||||
|
t.Fatalf("IsValidModelSource(%q)=%v want %v", tc.source, got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeModelSource(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
source string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{name: "requested", source: ModelSourceRequested, want: ModelSourceRequested},
|
||||||
|
{name: "upstream", source: ModelSourceUpstream, want: ModelSourceUpstream},
|
||||||
|
{name: "mapping", source: ModelSourceMapping, want: ModelSourceMapping},
|
||||||
|
{name: "invalid falls back", source: "foobar", want: ModelSourceRequested},
|
||||||
|
{name: "empty falls back", source: "", want: ModelSourceRequested},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
if got := NormalizeModelSource(tc.source); got != tc.want {
|
||||||
|
t.Fatalf("NormalizeModelSource(%q)=%q want %q", tc.source, got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
@@ -50,6 +51,19 @@ type accountRepository struct {
|
|||||||
schedulerCache service.SchedulerCache
|
schedulerCache service.SchedulerCache
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var schedulerNeutralExtraKeyPrefixes = []string{
|
||||||
|
"codex_primary_",
|
||||||
|
"codex_secondary_",
|
||||||
|
"codex_5h_",
|
||||||
|
"codex_7d_",
|
||||||
|
"passive_usage_",
|
||||||
|
}
|
||||||
|
|
||||||
|
var schedulerNeutralExtraKeys = map[string]struct{}{
|
||||||
|
"codex_usage_updated_at": {},
|
||||||
|
"session_window_utilization": {},
|
||||||
|
}
|
||||||
|
|
||||||
// NewAccountRepository 创建账户仓储实例。
|
// NewAccountRepository 创建账户仓储实例。
|
||||||
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
|
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
|
||||||
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
|
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
|
||||||
@@ -384,9 +398,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
|||||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
||||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
|
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
|
||||||
}
|
}
|
||||||
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
|
// 普通账号编辑(如 model_mapping / credentials)也需要立即刷新单账号快照,
|
||||||
r.syncSchedulerAccountSnapshot(ctx, account.ID)
|
// 否则网关在 outbox worker 延迟或异常时仍可能读到旧配置。
|
||||||
}
|
r.syncSchedulerAccountSnapshot(ctx, account.ID)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -460,7 +474,9 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
|
|||||||
if search != "" {
|
if search != "" {
|
||||||
q = q.Where(dbaccount.NameContainsFold(search))
|
q = q.Where(dbaccount.NameContainsFold(search))
|
||||||
}
|
}
|
||||||
if groupID > 0 {
|
if groupID == service.AccountListGroupUngrouped {
|
||||||
|
q = q.Where(dbaccount.Not(dbaccount.HasAccountGroups()))
|
||||||
|
} else if groupID > 0 {
|
||||||
q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID)))
|
q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID)))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1185,12 +1201,48 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
|
|||||||
if affected == 0 {
|
if affected == 0 {
|
||||||
return service.ErrAccountNotFound
|
return service.ErrAccountNotFound
|
||||||
}
|
}
|
||||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
if shouldEnqueueSchedulerOutboxForExtraUpdates(updates) {
|
||||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||||
|
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 观测型 extra 字段不需要触发 bucket 重建,但仍同步单账号快照,
|
||||||
|
// 让 sticky session / GetAccount 命中缓存时也能读到最新数据,
|
||||||
|
// 同时避免缓存局部 patch 覆盖掉并发写入的其它账号字段。
|
||||||
|
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shouldEnqueueSchedulerOutboxForExtraUpdates(updates map[string]any) bool {
|
||||||
|
if len(updates) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for key := range updates {
|
||||||
|
if isSchedulerNeutralExtraKey(key) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSchedulerNeutralExtraKey(key string) bool {
|
||||||
|
key = strings.TrimSpace(key)
|
||||||
|
if key == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if _, ok := schedulerNeutralExtraKeys[key]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for _, prefix := range schedulerNeutralExtraKeyPrefixes {
|
||||||
|
if strings.HasPrefix(key, prefix) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
|
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
|
||||||
if len(ids) == 0 {
|
if len(ids) == 0 {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
@@ -1678,8 +1730,96 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va
|
|||||||
// nowUTC is a SQL expression to generate a UTC RFC3339 timestamp string.
|
// nowUTC is a SQL expression to generate a UTC RFC3339 timestamp string.
|
||||||
const nowUTC = `to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"')`
|
const nowUTC = `to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"')`
|
||||||
|
|
||||||
|
// dailyExpiredExpr is a SQL expression that evaluates to TRUE when daily quota period has expired.
|
||||||
|
// Supports both rolling (24h from start) and fixed (pre-computed reset_at) modes.
|
||||||
|
const dailyExpiredExpr = `(
|
||||||
|
CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed'
|
||||||
|
THEN NOW() >= COALESCE((extra->>'quota_daily_reset_at')::timestamptz, '1970-01-01'::timestamptz)
|
||||||
|
ELSE COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||||
|
+ '24 hours'::interval <= NOW()
|
||||||
|
END
|
||||||
|
)`
|
||||||
|
|
||||||
|
// weeklyExpiredExpr is a SQL expression that evaluates to TRUE when weekly quota period has expired.
|
||||||
|
const weeklyExpiredExpr = `(
|
||||||
|
CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed'
|
||||||
|
THEN NOW() >= COALESCE((extra->>'quota_weekly_reset_at')::timestamptz, '1970-01-01'::timestamptz)
|
||||||
|
ELSE COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||||
|
+ '168 hours'::interval <= NOW()
|
||||||
|
END
|
||||||
|
)`
|
||||||
|
|
||||||
|
// nextDailyResetAtExpr is a SQL expression to compute the next daily reset_at when a reset occurs.
|
||||||
|
// For fixed mode: computes the next future reset time based on NOW(), timezone, and configured hour.
|
||||||
|
// This correctly handles long-inactive accounts by jumping directly to the next valid reset point.
|
||||||
|
const nextDailyResetAtExpr = `(
|
||||||
|
CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed'
|
||||||
|
THEN to_char((
|
||||||
|
-- Compute today's reset point in the configured timezone, then pick next future one
|
||||||
|
CASE WHEN NOW() >= (
|
||||||
|
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||||
|
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||||
|
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||||
|
-- NOW() is at or past today's reset point → next reset is tomorrow
|
||||||
|
THEN (
|
||||||
|
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||||
|
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||||
|
+ '1 day'::interval
|
||||||
|
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||||
|
-- NOW() is before today's reset point → next reset is today
|
||||||
|
ELSE (
|
||||||
|
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||||
|
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||||
|
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||||
|
END
|
||||||
|
) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
|
||||||
|
ELSE NULL END
|
||||||
|
)`
|
||||||
|
|
||||||
|
// nextWeeklyResetAtExpr is a SQL expression to compute the next weekly reset_at when a reset occurs.
|
||||||
|
// For fixed mode: computes the next future reset time based on NOW(), timezone, configured day and hour.
|
||||||
|
// This correctly handles long-inactive accounts by jumping directly to the next valid reset point.
|
||||||
|
const nextWeeklyResetAtExpr = `(
|
||||||
|
CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed'
|
||||||
|
THEN to_char((
|
||||||
|
-- Compute this week's reset point in the configured timezone
|
||||||
|
-- Step 1: get today's date at reset hour in configured tz
|
||||||
|
-- Step 2: compute days forward to target weekday
|
||||||
|
-- Step 3: if same day but past reset hour, advance 7 days
|
||||||
|
CASE
|
||||||
|
WHEN (
|
||||||
|
-- days_forward = (target_day - current_day + 7) % 7
|
||||||
|
(COALESCE((extra->>'quota_weekly_reset_day')::int, 1)
|
||||||
|
- EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int
|
||||||
|
+ 7) % 7
|
||||||
|
) = 0 AND NOW() >= (
|
||||||
|
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||||
|
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||||
|
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||||
|
-- Same weekday and past reset hour → next week
|
||||||
|
THEN (
|
||||||
|
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||||
|
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||||
|
+ '7 days'::interval
|
||||||
|
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||||
|
ELSE (
|
||||||
|
-- Advance to target weekday this week (or next if days_forward > 0)
|
||||||
|
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||||
|
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||||
|
+ ((
|
||||||
|
(COALESCE((extra->>'quota_weekly_reset_day')::int, 1)
|
||||||
|
- EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int
|
||||||
|
+ 7) % 7
|
||||||
|
) || ' days')::interval
|
||||||
|
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||||
|
END
|
||||||
|
) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
|
||||||
|
ELSE NULL END
|
||||||
|
)`
|
||||||
|
|
||||||
// IncrementQuotaUsed 原子递增账号的配额用量(总/日/周三个维度)
|
// IncrementQuotaUsed 原子递增账号的配额用量(总/日/周三个维度)
|
||||||
// 日/周额度在周期过期时自动重置为 0 再递增。
|
// 日/周额度在周期过期时自动重置为 0 再递增。
|
||||||
|
// 支持滚动窗口(rolling)和固定时间(fixed)两种重置模式。
|
||||||
func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||||
rows, err := r.sql.QueryContext(ctx,
|
rows, err := r.sql.QueryContext(ctx,
|
||||||
`UPDATE accounts SET extra = (
|
`UPDATE accounts SET extra = (
|
||||||
@@ -1690,31 +1830,35 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|
|||||||
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
|
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
|
||||||
jsonb_build_object(
|
jsonb_build_object(
|
||||||
'quota_daily_used',
|
'quota_daily_used',
|
||||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
CASE WHEN `+dailyExpiredExpr+`
|
||||||
+ '24 hours'::interval <= NOW()
|
|
||||||
THEN $1
|
THEN $1
|
||||||
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
|
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
|
||||||
'quota_daily_start',
|
'quota_daily_start',
|
||||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
CASE WHEN `+dailyExpiredExpr+`
|
||||||
+ '24 hours'::interval <= NOW()
|
|
||||||
THEN `+nowUTC+`
|
THEN `+nowUTC+`
|
||||||
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
|
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
|
||||||
)
|
)
|
||||||
|
-- 固定模式重置时更新下次重置时间
|
||||||
|
|| CASE WHEN `+dailyExpiredExpr+` AND `+nextDailyResetAtExpr+` IS NOT NULL
|
||||||
|
THEN jsonb_build_object('quota_daily_reset_at', `+nextDailyResetAtExpr+`)
|
||||||
|
ELSE '{}'::jsonb END
|
||||||
ELSE '{}'::jsonb END
|
ELSE '{}'::jsonb END
|
||||||
-- 周额度:仅在 quota_weekly_limit > 0 时处理
|
-- 周额度:仅在 quota_weekly_limit > 0 时处理
|
||||||
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
|
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
|
||||||
jsonb_build_object(
|
jsonb_build_object(
|
||||||
'quota_weekly_used',
|
'quota_weekly_used',
|
||||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
CASE WHEN `+weeklyExpiredExpr+`
|
||||||
+ '168 hours'::interval <= NOW()
|
|
||||||
THEN $1
|
THEN $1
|
||||||
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
|
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
|
||||||
'quota_weekly_start',
|
'quota_weekly_start',
|
||||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
CASE WHEN `+weeklyExpiredExpr+`
|
||||||
+ '168 hours'::interval <= NOW()
|
|
||||||
THEN `+nowUTC+`
|
THEN `+nowUTC+`
|
||||||
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
|
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
|
||||||
)
|
)
|
||||||
|
-- 固定模式重置时更新下次重置时间
|
||||||
|
|| CASE WHEN `+weeklyExpiredExpr+` AND `+nextWeeklyResetAtExpr+` IS NOT NULL
|
||||||
|
THEN jsonb_build_object('quota_weekly_reset_at', `+nextWeeklyResetAtExpr+`)
|
||||||
|
ELSE '{}'::jsonb END
|
||||||
ELSE '{}'::jsonb END
|
ELSE '{}'::jsonb END
|
||||||
), updated_at = NOW()
|
), updated_at = NOW()
|
||||||
WHERE id = $2 AND deleted_at IS NULL
|
WHERE id = $2 AND deleted_at IS NULL
|
||||||
@@ -1747,12 +1891,13 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ResetQuotaUsed 重置账号所有维度的配额用量为 0
|
// ResetQuotaUsed 重置账号所有维度的配额用量为 0
|
||||||
|
// 保留固定重置模式的配置字段(quota_daily_reset_mode 等),仅清零用量和窗口起始时间
|
||||||
func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error {
|
func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||||
_, err := r.sql.ExecContext(ctx,
|
_, err := r.sql.ExecContext(ctx,
|
||||||
`UPDATE accounts SET extra = (
|
`UPDATE accounts SET extra = (
|
||||||
COALESCE(extra, '{}'::jsonb)
|
COALESCE(extra, '{}'::jsonb)
|
||||||
|| '{"quota_used": 0, "quota_daily_used": 0, "quota_weekly_used": 0}'::jsonb
|
|| '{"quota_used": 0, "quota_daily_used": 0, "quota_weekly_used": 0}'::jsonb
|
||||||
) - 'quota_daily_start' - 'quota_weekly_start', updated_at = NOW()
|
) - 'quota_daily_start' - 'quota_weekly_start' - 'quota_daily_reset_at' - 'quota_weekly_reset_at', updated_at = NOW()
|
||||||
WHERE id = $1 AND deleted_at IS NULL`,
|
WHERE id = $1 AND deleted_at IS NULL`,
|
||||||
id)
|
id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user