mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-06 00:10:21 +08:00
Compare commits
229 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0236b97d49 | ||
|
|
26f6b1eeff | ||
|
|
dc447ccebe | ||
|
|
7ec29638f4 | ||
|
|
4c9562af20 | ||
|
|
71942fd322 | ||
|
|
550b979ac5 | ||
|
|
3878a5a46f | ||
|
|
e443a6a1ea | ||
|
|
963494ec6f | ||
|
|
525cdb8830 | ||
|
|
a6764e82f2 | ||
|
|
8027531d07 | ||
|
|
30706355a4 | ||
|
|
dfe99507b8 | ||
|
|
c1717c9a6c | ||
|
|
1fd1a58a7a | ||
|
|
fad07507be | ||
|
|
a20c211162 | ||
|
|
9f6ab6b817 | ||
|
|
bf3d6c0e6e | ||
|
|
241023f3fc | ||
|
|
1292c44b41 | ||
|
|
b4fce47049 | ||
|
|
e7780cd8c8 | ||
|
|
af96c8ea53 | ||
|
|
7d26b81075 | ||
|
|
b8ada63ac3 | ||
|
|
cfaac12af1 | ||
|
|
6028efd26c | ||
|
|
62a566ef2c | ||
|
|
94419f434c | ||
|
|
21f349c032 | ||
|
|
28e36f7925 | ||
|
|
6c02076333 | ||
|
|
7414bdf0e3 | ||
|
|
e6326b2929 | ||
|
|
17cdcebd04 | ||
|
|
a14babdc73 | ||
|
|
aadc6a763a | ||
|
|
f16af8bf88 | ||
|
|
5ceaef4500 | ||
|
|
1ac7219a92 | ||
|
|
d4cc9871c4 | ||
|
|
961c30e7c0 | ||
|
|
13e85b3147 | ||
|
|
50a3c7fa0b | ||
|
|
bd9d2671d7 | ||
|
|
62b40636e0 | ||
|
|
eeff451bc5 | ||
|
|
56fcb20f94 | ||
|
|
7134266acf | ||
|
|
2e4ac88ad9 | ||
|
|
51547fa216 | ||
|
|
2005fc97a8 | ||
|
|
0772d9250e | ||
|
|
aa6047c460 | ||
|
|
045cba78b4 | ||
|
|
8989d0d4b6 | ||
|
|
c521117b99 | ||
|
|
e0f52a8ab8 | ||
|
|
6c23fadf7e | ||
|
|
869952d113 | ||
|
|
07ab051ee4 | ||
|
|
f2d98fc0c7 | ||
|
|
2b41cec840 | ||
|
|
6cf77040e7 | ||
|
|
20b70bc5fd | ||
|
|
4905e7193a | ||
|
|
9c1f4b8e72 | ||
|
|
9857c17631 | ||
|
|
7e34bb946f | ||
|
|
47b748851b | ||
|
|
a6f99cf534 | ||
|
|
a120a6bc32 | ||
|
|
d557d1a190 | ||
|
|
e0286e5085 | ||
|
|
4b41e898a4 | ||
|
|
668e164793 | ||
|
|
fa2e6188d0 | ||
|
|
7fde9ebbc2 | ||
|
|
aef7c3b9bb | ||
|
|
a0b76bd608 | ||
|
|
c1fab7f8d8 | ||
|
|
f42c8f2abe | ||
|
|
aa5846b282 | ||
|
|
594a0ade38 | ||
|
|
d45cc23171 | ||
|
|
d795734352 | ||
|
|
4da9fdd1d5 | ||
|
|
6b218caa21 | ||
|
|
5c138007d0 | ||
|
|
1acfc46f46 | ||
|
|
fbffb08aae | ||
|
|
8640a62319 | ||
|
|
fa782e70a4 | ||
|
|
afd72abc6e | ||
|
|
71f72e167e | ||
|
|
6595c7601e | ||
|
|
67c0506290 | ||
|
|
6447be4534 | ||
|
|
3741617ebd | ||
|
|
ab4e8b2cf0 | ||
|
|
474165d7aa | ||
|
|
94e067a2e2 | ||
|
|
4293c89166 | ||
|
|
ec82c37da5 | ||
|
|
552a4b998a | ||
|
|
0d2061b268 | ||
|
|
8a260defc2 | ||
|
|
e14c87597a | ||
|
|
f3f19d35aa | ||
|
|
ced90e1d84 | ||
|
|
17e4033340 | ||
|
|
044d3a013d | ||
|
|
1fc9dd7b68 | ||
|
|
8147866c09 | ||
|
|
7bd1972f94 | ||
|
|
2c9dcfe27b | ||
|
|
1b79b0f3ff | ||
|
|
c637e6cf31 | ||
|
|
d3a9f5bb88 | ||
|
|
7eb0415a8a | ||
|
|
bdbc8fa08f | ||
|
|
63f3af0f94 | ||
|
|
686f890fbf | ||
|
|
220fbe6544 | ||
|
|
ae44a94325 | ||
|
|
3718d6dcd4 | ||
|
|
90b3838173 | ||
|
|
19d3ecc76f | ||
|
|
6fba4ebb13 | ||
|
|
c31974c913 | ||
|
|
6177fa5dd8 | ||
|
|
cfe72159d0 | ||
|
|
8321e4a647 | ||
|
|
3084330d0c | ||
|
|
b566649e79 | ||
|
|
10a6180e4a | ||
|
|
cbe9e78977 | ||
|
|
74145b1f39 | ||
|
|
359e56751b | ||
|
|
5899784aa4 | ||
|
|
9e8959c56d | ||
|
|
1bff2292a6 | ||
|
|
cf9247754e | ||
|
|
eefab15958 | ||
|
|
0e23732631 | ||
|
|
37c044fb4b | ||
|
|
6da5fa01b9 | ||
|
|
616930f9d3 | ||
|
|
b9c31fa7c4 | ||
|
|
17b339972c | ||
|
|
39f8bd91b9 | ||
|
|
aa4e37d085 | ||
|
|
f59b66b7d4 | ||
|
|
8f0ea7a02d | ||
|
|
a1dc00890e | ||
|
|
dfbcc363d1 | ||
|
|
1047f973d5 | ||
|
|
e32977dd73 | ||
|
|
b5f78ec1e8 | ||
|
|
e0f290fdc8 | ||
|
|
fc00a4e3b2 | ||
|
|
db1f6ded88 | ||
|
|
4644af2ccc | ||
|
|
2e3e8687e1 | ||
|
|
ca42a45802 | ||
|
|
9350ecb62b | ||
|
|
a4a026e8da | ||
|
|
342fd03e72 | ||
|
|
e3f1fd9b63 | ||
|
|
e4a4dfd038 | ||
|
|
a377e99088 | ||
|
|
1d3d7a3033 | ||
|
|
e7086cb3a3 | ||
|
|
4f2a97073e | ||
|
|
7407e3b45d | ||
|
|
01ef7340aa | ||
|
|
1c960d22c1 | ||
|
|
ece0606fed | ||
|
|
2666422b99 | ||
|
|
e6d59216d4 | ||
|
|
4e8615f276 | ||
|
|
91e4d95660 | ||
|
|
45456fa24c | ||
|
|
4588258d80 | ||
|
|
c12e48f966 | ||
|
|
ec8f50a658 | ||
|
|
99c9191784 | ||
|
|
6bb02d141f | ||
|
|
07bb2a5f3f | ||
|
|
417861a48e | ||
|
|
b7e878de64 | ||
|
|
05edb5514b | ||
|
|
e90ec847b6 | ||
|
|
6344fa2a86 | ||
|
|
7e288acc90 | ||
|
|
29b0e4a8a5 | ||
|
|
27ff222cfb | ||
|
|
11f7b83522 | ||
|
|
f7177be3b6 | ||
|
|
875b417fde | ||
|
|
2573107b32 | ||
|
|
5b85005945 | ||
|
|
1ee984478f | ||
|
|
fd693dc526 | ||
|
|
e73531ce9b | ||
|
|
53ad1645cf | ||
|
|
ecea13757b | ||
|
|
af9c4a7dd0 | ||
|
|
80d8d6c3bc | ||
|
|
d648811233 | ||
|
|
34695acb85 | ||
|
|
a63de12182 | ||
|
|
f16910d616 | ||
|
|
64b3f3cec1 | ||
|
|
6a685727d0 | ||
|
|
32d25f76fc | ||
|
|
69cafe8674 | ||
|
|
18ba8d9166 | ||
|
|
e97fd7e81c | ||
|
|
cdb64b0d33 | ||
|
|
8d4d3b03bb | ||
|
|
addefe79e1 | ||
|
|
b764d3b8f6 | ||
|
|
611fd884bd | ||
|
|
6826149a8f | ||
|
|
c9debc50b1 |
7
.gitattributes
vendored
7
.gitattributes
vendored
@@ -4,6 +4,13 @@ backend/migrations/*.sql text eol=lf
|
|||||||
# Go 源代码文件
|
# Go 源代码文件
|
||||||
*.go text eol=lf
|
*.go text eol=lf
|
||||||
|
|
||||||
|
# 前端 源代码文件
|
||||||
|
*.ts text eol=lf
|
||||||
|
*.tsx text eol=lf
|
||||||
|
*.js text eol=lf
|
||||||
|
*.jsx text eol=lf
|
||||||
|
*.vue text eol=lf
|
||||||
|
|
||||||
# Shell 脚本
|
# Shell 脚本
|
||||||
*.sh text eol=lf
|
*.sh text eol=lf
|
||||||
|
|
||||||
|
|||||||
2
.github/workflows/backend-ci.yml
vendored
2
.github/workflows/backend-ci.yml
vendored
@@ -17,7 +17,6 @@ jobs:
|
|||||||
go-version-file: backend/go.mod
|
go-version-file: backend/go.mod
|
||||||
check-latest: false
|
check-latest: false
|
||||||
cache: true
|
cache: true
|
||||||
cache-dependency-path: backend/go.sum
|
|
||||||
- name: Verify Go version
|
- name: Verify Go version
|
||||||
run: |
|
run: |
|
||||||
go version | grep -q 'go1.26.1'
|
go version | grep -q 'go1.26.1'
|
||||||
@@ -37,7 +36,6 @@ jobs:
|
|||||||
go-version-file: backend/go.mod
|
go-version-file: backend/go.mod
|
||||||
check-latest: false
|
check-latest: false
|
||||||
cache: true
|
cache: true
|
||||||
cache-dependency-path: backend/go.sum
|
|
||||||
- name: Verify Go version
|
- name: Verify Go version
|
||||||
run: |
|
run: |
|
||||||
go version | grep -q 'go1.26.1'
|
go version | grep -q 'go1.26.1'
|
||||||
|
|||||||
33
.github/workflows/release.yml
vendored
33
.github/workflows/release.yml
vendored
@@ -271,3 +271,36 @@ jobs:
|
|||||||
parse_mode: "Markdown",
|
parse_mode: "Markdown",
|
||||||
disable_web_page_preview: true
|
disable_web_page_preview: true
|
||||||
}')"
|
}')"
|
||||||
|
|
||||||
|
sync-version-file:
|
||||||
|
needs: [release]
|
||||||
|
if: ${{ needs.release.result == 'success' }}
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout default branch
|
||||||
|
uses: actions/checkout@v6
|
||||||
|
with:
|
||||||
|
ref: ${{ github.event.repository.default_branch }}
|
||||||
|
|
||||||
|
- name: Sync VERSION file to released tag
|
||||||
|
run: |
|
||||||
|
if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
|
||||||
|
VERSION=${{ github.event.inputs.tag }}
|
||||||
|
VERSION=${VERSION#v}
|
||||||
|
else
|
||||||
|
VERSION=${GITHUB_REF#refs/tags/v}
|
||||||
|
fi
|
||||||
|
|
||||||
|
CURRENT_VERSION=$(tr -d '\r\n' < backend/cmd/server/VERSION || true)
|
||||||
|
if [ "$CURRENT_VERSION" = "$VERSION" ]; then
|
||||||
|
echo "VERSION file already matches $VERSION"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "$VERSION" > backend/cmd/server/VERSION
|
||||||
|
|
||||||
|
git config user.name "github-actions[bot]"
|
||||||
|
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||||
|
git add backend/cmd/server/VERSION
|
||||||
|
git commit -m "chore: sync VERSION to ${VERSION} [skip ci]"
|
||||||
|
git push origin HEAD:${{ github.event.repository.default_branch }}
|
||||||
|
|||||||
10
.gitignore
vendored
10
.gitignore
vendored
@@ -78,7 +78,6 @@ Desktop.ini
|
|||||||
# ===================
|
# ===================
|
||||||
tmp/
|
tmp/
|
||||||
temp/
|
temp/
|
||||||
logs/
|
|
||||||
*.tmp
|
*.tmp
|
||||||
*.temp
|
*.temp
|
||||||
*.log
|
*.log
|
||||||
@@ -129,15 +128,8 @@ deploy/docker-compose.override.yml
|
|||||||
vite.config.js
|
vite.config.js
|
||||||
docs/*
|
docs/*
|
||||||
.serena/
|
.serena/
|
||||||
|
|
||||||
# ===================
|
|
||||||
# 压测工具
|
|
||||||
# ===================
|
|
||||||
tools/loadtest/
|
|
||||||
# Antigravity Manager
|
|
||||||
Antigravity-Manager/
|
|
||||||
antigravity_projectid_fix.patch
|
|
||||||
.codex/
|
.codex/
|
||||||
frontend/coverage/
|
frontend/coverage/
|
||||||
aicodex
|
aicodex
|
||||||
output/
|
output/
|
||||||
|
|
||||||
|
|||||||
@@ -47,6 +47,8 @@ dockers:
|
|||||||
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:latest"
|
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:latest"
|
||||||
dockerfile: Dockerfile.goreleaser
|
dockerfile: Dockerfile.goreleaser
|
||||||
use: buildx
|
use: buildx
|
||||||
|
extra_files:
|
||||||
|
- deploy/docker-entrypoint.sh
|
||||||
build_flag_templates:
|
build_flag_templates:
|
||||||
- "--platform=linux/amd64"
|
- "--platform=linux/amd64"
|
||||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||||
|
|||||||
@@ -63,6 +63,8 @@ dockers:
|
|||||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||||
dockerfile: Dockerfile.goreleaser
|
dockerfile: Dockerfile.goreleaser
|
||||||
use: buildx
|
use: buildx
|
||||||
|
extra_files:
|
||||||
|
- deploy/docker-entrypoint.sh
|
||||||
build_flag_templates:
|
build_flag_templates:
|
||||||
- "--platform=linux/amd64"
|
- "--platform=linux/amd64"
|
||||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||||
@@ -76,6 +78,8 @@ dockers:
|
|||||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||||
dockerfile: Dockerfile.goreleaser
|
dockerfile: Dockerfile.goreleaser
|
||||||
use: buildx
|
use: buildx
|
||||||
|
extra_files:
|
||||||
|
- deploy/docker-entrypoint.sh
|
||||||
build_flag_templates:
|
build_flag_templates:
|
||||||
- "--platform=linux/arm64"
|
- "--platform=linux/arm64"
|
||||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||||
@@ -89,6 +93,8 @@ dockers:
|
|||||||
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-amd64"
|
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-amd64"
|
||||||
dockerfile: Dockerfile.goreleaser
|
dockerfile: Dockerfile.goreleaser
|
||||||
use: buildx
|
use: buildx
|
||||||
|
extra_files:
|
||||||
|
- deploy/docker-entrypoint.sh
|
||||||
build_flag_templates:
|
build_flag_templates:
|
||||||
- "--platform=linux/amd64"
|
- "--platform=linux/amd64"
|
||||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||||
@@ -102,6 +108,8 @@ dockers:
|
|||||||
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-arm64"
|
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-arm64"
|
||||||
dockerfile: Dockerfile.goreleaser
|
dockerfile: Dockerfile.goreleaser
|
||||||
use: buildx
|
use: buildx
|
||||||
|
extra_files:
|
||||||
|
- deploy/docker-entrypoint.sh
|
||||||
build_flag_templates:
|
build_flag_templates:
|
||||||
- "--platform=linux/arm64"
|
- "--platform=linux/arm64"
|
||||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||||
|
|||||||
31
Dockerfile
31
Dockerfile
@@ -9,6 +9,7 @@
|
|||||||
ARG NODE_IMAGE=node:24-alpine
|
ARG NODE_IMAGE=node:24-alpine
|
||||||
ARG GOLANG_IMAGE=golang:1.26.1-alpine
|
ARG GOLANG_IMAGE=golang:1.26.1-alpine
|
||||||
ARG ALPINE_IMAGE=alpine:3.21
|
ARG ALPINE_IMAGE=alpine:3.21
|
||||||
|
ARG POSTGRES_IMAGE=postgres:18-alpine
|
||||||
ARG GOPROXY=https://goproxy.cn,direct
|
ARG GOPROXY=https://goproxy.cn,direct
|
||||||
ARG GOSUMDB=sum.golang.google.cn
|
ARG GOSUMDB=sum.golang.google.cn
|
||||||
|
|
||||||
@@ -73,7 +74,12 @@ RUN VERSION_VALUE="${VERSION}" && \
|
|||||||
./cmd/server
|
./cmd/server
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Stage 3: Final Runtime Image
|
# Stage 3: PostgreSQL Client (version-matched with docker-compose)
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
FROM ${POSTGRES_IMAGE} AS pg-client
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Stage 4: Final Runtime Image
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
FROM ${ALPINE_IMAGE}
|
FROM ${ALPINE_IMAGE}
|
||||||
|
|
||||||
@@ -86,8 +92,21 @@ LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
|
|||||||
RUN apk add --no-cache \
|
RUN apk add --no-cache \
|
||||||
ca-certificates \
|
ca-certificates \
|
||||||
tzdata \
|
tzdata \
|
||||||
|
su-exec \
|
||||||
|
libpq \
|
||||||
|
zstd-libs \
|
||||||
|
lz4-libs \
|
||||||
|
krb5-libs \
|
||||||
|
libldap \
|
||||||
|
libedit \
|
||||||
&& rm -rf /var/cache/apk/*
|
&& rm -rf /var/cache/apk/*
|
||||||
|
|
||||||
|
# Copy pg_dump and psql from the same postgres image used in docker-compose
|
||||||
|
# This ensures version consistency between backup tools and the database server
|
||||||
|
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
|
||||||
|
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
|
||||||
|
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
|
||||||
|
|
||||||
# Create non-root user
|
# Create non-root user
|
||||||
RUN addgroup -g 1000 sub2api && \
|
RUN addgroup -g 1000 sub2api && \
|
||||||
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
||||||
@@ -102,8 +121,9 @@ COPY --from=backend-builder --chown=sub2api:sub2api /app/backend/resources /app/
|
|||||||
# Create data directory
|
# Create data directory
|
||||||
RUN mkdir -p /app/data && chown sub2api:sub2api /app/data
|
RUN mkdir -p /app/data && chown sub2api:sub2api /app/data
|
||||||
|
|
||||||
# Switch to non-root user
|
# Copy entrypoint script (fixes volume permissions then drops to sub2api)
|
||||||
USER sub2api
|
COPY deploy/docker-entrypoint.sh /app/docker-entrypoint.sh
|
||||||
|
RUN chmod +x /app/docker-entrypoint.sh
|
||||||
|
|
||||||
# Expose port (can be overridden by SERVER_PORT env var)
|
# Expose port (can be overridden by SERVER_PORT env var)
|
||||||
EXPOSE 8080
|
EXPOSE 8080
|
||||||
@@ -112,5 +132,6 @@ EXPOSE 8080
|
|||||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
||||||
CMD wget -q -T 5 -O /dev/null http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
CMD wget -q -T 5 -O /dev/null http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
||||||
|
|
||||||
# Run the application
|
# Run the application (entrypoint fixes /app/data ownership then execs as sub2api)
|
||||||
ENTRYPOINT ["/app/sub2api"]
|
ENTRYPOINT ["/app/docker-entrypoint.sh"]
|
||||||
|
CMD ["/app/sub2api"]
|
||||||
|
|||||||
@@ -5,7 +5,12 @@
|
|||||||
# It only packages the pre-built binary, no compilation needed.
|
# It only packages the pre-built binary, no compilation needed.
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
FROM alpine:3.19
|
ARG ALPINE_IMAGE=alpine:3.21
|
||||||
|
ARG POSTGRES_IMAGE=postgres:18-alpine
|
||||||
|
|
||||||
|
FROM ${POSTGRES_IMAGE} AS pg-client
|
||||||
|
|
||||||
|
FROM ${ALPINE_IMAGE}
|
||||||
|
|
||||||
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
|
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
|
||||||
LABEL description="Sub2API - AI API Gateway Platform"
|
LABEL description="Sub2API - AI API Gateway Platform"
|
||||||
@@ -16,8 +21,21 @@ RUN apk add --no-cache \
|
|||||||
ca-certificates \
|
ca-certificates \
|
||||||
tzdata \
|
tzdata \
|
||||||
curl \
|
curl \
|
||||||
|
su-exec \
|
||||||
|
libpq \
|
||||||
|
zstd-libs \
|
||||||
|
lz4-libs \
|
||||||
|
krb5-libs \
|
||||||
|
libldap \
|
||||||
|
libedit \
|
||||||
&& rm -rf /var/cache/apk/*
|
&& rm -rf /var/cache/apk/*
|
||||||
|
|
||||||
|
# Copy pg_dump and psql from a version-matched PostgreSQL image so backup and
|
||||||
|
# restore work in the runtime container without requiring Docker socket access.
|
||||||
|
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
|
||||||
|
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
|
||||||
|
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
|
||||||
|
|
||||||
# Create non-root user
|
# Create non-root user
|
||||||
RUN addgroup -g 1000 sub2api && \
|
RUN addgroup -g 1000 sub2api && \
|
||||||
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
||||||
@@ -30,11 +48,15 @@ COPY sub2api /app/sub2api
|
|||||||
# Create data directory
|
# Create data directory
|
||||||
RUN mkdir -p /app/data && chown -R sub2api:sub2api /app
|
RUN mkdir -p /app/data && chown -R sub2api:sub2api /app
|
||||||
|
|
||||||
USER sub2api
|
# Copy entrypoint script (fixes volume permissions then drops to sub2api)
|
||||||
|
COPY deploy/docker-entrypoint.sh /app/docker-entrypoint.sh
|
||||||
|
RUN chmod +x /app/docker-entrypoint.sh
|
||||||
|
|
||||||
EXPOSE 8080
|
EXPOSE 8080
|
||||||
|
|
||||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
||||||
CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
||||||
|
|
||||||
ENTRYPOINT ["/app/sub2api"]
|
# Run the application (entrypoint fixes /app/data ownership then execs as sub2api)
|
||||||
|
ENTRYPOINT ["/app/docker-entrypoint.sh"]
|
||||||
|
CMD ["/app/sub2api"]
|
||||||
|
|||||||
40
README.md
40
README.md
@@ -8,27 +8,31 @@
|
|||||||
[](https://redis.io/)
|
[](https://redis.io/)
|
||||||
[](https://www.docker.com/)
|
[](https://www.docker.com/)
|
||||||
|
|
||||||
|
<a href="https://trendshift.io/repositories/21823" target="_blank"><img src="https://trendshift.io/api/badge/repositories/21823" alt="Wei-Shaw%2Fsub2api | Trendshift" width="250" height="55"/></a>
|
||||||
|
|
||||||
**AI API Gateway Platform for Subscription Quota Distribution**
|
**AI API Gateway Platform for Subscription Quota Distribution**
|
||||||
|
|
||||||
English | [中文](README_CN.md)
|
English | [中文](README_CN.md)
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
> **Sub2API officially uses only the domains `sub2api.org` and `pincc.ai`. Other websites using the Sub2API name may be third-party deployments or services and are not affiliated with this project. Please verify and exercise your own judgment.**
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Demo
|
## Demo
|
||||||
|
|
||||||
Try Sub2API online: **https://demo.sub2api.org/**
|
Try Sub2API online: **[https://demo.sub2api.org/](https://demo.sub2api.org/)**
|
||||||
|
|
||||||
Demo credentials (shared demo environment; **not** created automatically for self-hosted installs):
|
Demo credentials (shared demo environment; **not** created automatically for self-hosted installs):
|
||||||
|
|
||||||
| Email | Password |
|
| Email | Password |
|
||||||
|-------|----------|
|
|-------|----------|
|
||||||
| admin@sub2api.com | admin123 |
|
| admin@sub2api.org | admin123 |
|
||||||
|
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions (like Claude Code $200/month). Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding.
|
Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions. Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding.
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
@@ -39,6 +43,25 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
|
|||||||
- **Concurrency Control** - Per-user and per-account concurrency limits
|
- **Concurrency Control** - Per-user and per-account concurrency limits
|
||||||
- **Rate Limiting** - Configurable request and token rate limits
|
- **Rate Limiting** - Configurable request and token rate limits
|
||||||
- **Admin Dashboard** - Web interface for monitoring and management
|
- **Admin Dashboard** - Web interface for monitoring and management
|
||||||
|
- **External System Integration** - Embed external systems (e.g. payment, ticketing) via iframe to extend the admin dashboard
|
||||||
|
|
||||||
|
## Don't Want to Self-Host?
|
||||||
|
|
||||||
|
<table>
|
||||||
|
<tr>
|
||||||
|
<td width="180" align="center" valign="middle"><a href="https://shop.pincc.ai/"><img src="assets/partners/logos/pincc-logo.png" alt="pincc" width="120"></a></td>
|
||||||
|
<td valign="middle"><b><a href="https://shop.pincc.ai/">PinCC</a></b> is the official relay service built on Sub2API, offering stable access to Claude Code, Codex, Gemini and other popular models — ready to use, no deployment or maintenance required.</td>
|
||||||
|
</tr>
|
||||||
|
</table>
|
||||||
|
|
||||||
|
## Ecosystem
|
||||||
|
|
||||||
|
Community projects that extend or integrate with Sub2API:
|
||||||
|
|
||||||
|
| Project | Description | Features |
|
||||||
|
|---------|-------------|----------|
|
||||||
|
| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | Self-service payment system | Self-service top-up and subscription purchase; supports YiPay protocol, WeChat Pay, Alipay, Stripe; embeddable via iframe |
|
||||||
|
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | Mobile admin console | Cross-platform app (iOS/Android/Web) for user management, account management, monitoring dashboard, and multi-backend switching; built with Expo + React Native |
|
||||||
|
|
||||||
## Tech Stack
|
## Tech Stack
|
||||||
|
|
||||||
@@ -51,10 +74,15 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Documentation
|
## Nginx Reverse Proxy Note
|
||||||
|
|
||||||
- Dependency Security: `docs/dependency-security.md`
|
When using Nginx as a reverse proxy for Sub2API (or CRS) with Codex CLI, add the following to the `http` block in your Nginx configuration:
|
||||||
- Admin Payment Integration API: `docs/ADMIN_PAYMENT_INTEGRATION_API.md`
|
|
||||||
|
```nginx
|
||||||
|
underscores_in_headers on;
|
||||||
|
```
|
||||||
|
|
||||||
|
Nginx drops headers containing underscores by default (e.g. `session_id`), which breaks sticky session routing in multi-account setups.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
43
README_CN.md
43
README_CN.md
@@ -8,27 +8,30 @@
|
|||||||
[](https://redis.io/)
|
[](https://redis.io/)
|
||||||
[](https://www.docker.com/)
|
[](https://www.docker.com/)
|
||||||
|
|
||||||
|
<a href="https://trendshift.io/repositories/21823" target="_blank"><img src="https://trendshift.io/api/badge/repositories/21823" alt="Wei-Shaw%2Fsub2api | Trendshift" width="250" height="55"/></a>
|
||||||
|
|
||||||
**AI API 网关平台 - 订阅配额分发管理**
|
**AI API 网关平台 - 订阅配额分发管理**
|
||||||
|
|
||||||
[English](README.md) | 中文
|
[English](README.md) | 中文
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
> **Sub2API 官方仅使用 `sub2api.org` 与 `pincc.ai` 两个域名。其他使用 Sub2API 名义的网站可能为第三方部署或服务,与本项目无关,请自行甄别。**
|
||||||
---
|
---
|
||||||
|
|
||||||
## 在线体验
|
## 在线体验
|
||||||
|
|
||||||
体验地址:**https://v2.pincc.ai/**
|
体验地址:**[https://demo.sub2api.org/](https://demo.sub2api.org/)**
|
||||||
|
|
||||||
演示账号(共享演示环境;自建部署不会自动创建该账号):
|
演示账号(共享演示环境;自建部署不会自动创建该账号):
|
||||||
|
|
||||||
| 邮箱 | 密码 |
|
| 邮箱 | 密码 |
|
||||||
|------|------|
|
|------|------|
|
||||||
| admin@sub2api.com | admin123 |
|
| admin@sub2api.org | admin123 |
|
||||||
|
|
||||||
## 项目概述
|
## 项目概述
|
||||||
|
|
||||||
Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(如 Claude Code $200/月)的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。
|
Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。
|
||||||
|
|
||||||
## 核心功能
|
## 核心功能
|
||||||
|
|
||||||
@@ -39,6 +42,25 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
|
|||||||
- **并发控制** - 用户级和账号级并发限制
|
- **并发控制** - 用户级和账号级并发限制
|
||||||
- **速率限制** - 可配置的请求和 Token 速率限制
|
- **速率限制** - 可配置的请求和 Token 速率限制
|
||||||
- **管理后台** - Web 界面进行监控和管理
|
- **管理后台** - Web 界面进行监控和管理
|
||||||
|
- **外部系统集成** - 支持通过 iframe 嵌入外部系统(如支付、工单等),扩展管理后台功能
|
||||||
|
|
||||||
|
## 不想自建?试试官方中转
|
||||||
|
|
||||||
|
<table>
|
||||||
|
<tr>
|
||||||
|
<td width="180" align="center" valign="middle"><a href="https://shop.pincc.ai/"><img src="assets/partners/logos/pincc-logo.png" alt="pincc" width="120"></a></td>
|
||||||
|
<td valign="middle"><b><a href="https://shop.pincc.ai/">PinCC</a></b> 是基于 Sub2API 搭建的官方中转服务,提供 Claude Code、Codex、Gemini 等主流模型的稳定中转,开箱即用,免去自建部署与运维烦恼。</td>
|
||||||
|
</tr>
|
||||||
|
</table>
|
||||||
|
|
||||||
|
## 生态项目
|
||||||
|
|
||||||
|
围绕 Sub2API 的社区扩展与集成项目:
|
||||||
|
|
||||||
|
| 项目 | 说明 | 功能 |
|
||||||
|
|------|------|------|
|
||||||
|
| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | 自助支付系统 | 用户自助充值、自助订阅购买;兼容易支付协议、微信官方支付、支付宝官方支付、Stripe;支持 iframe 嵌入管理后台 |
|
||||||
|
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | 移动端管理控制台 | 跨平台应用(iOS/Android/Web),支持用户管理、账号管理、监控看板、多后端切换;基于 Expo + React Native 构建 |
|
||||||
|
|
||||||
## 技术栈
|
## 技术栈
|
||||||
|
|
||||||
@@ -51,17 +73,18 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 文档
|
## Nginx 反向代理注意事项
|
||||||
|
|
||||||
- 依赖安全:`docs/dependency-security.md`
|
通过 Nginx 反向代理 Sub2API(或 CRS 服务)并搭配 Codex CLI 使用时,需要在 Nginx 配置的 `http` 块中添加:
|
||||||
|
|
||||||
|
```nginx
|
||||||
|
underscores_in_headers on;
|
||||||
|
```
|
||||||
|
|
||||||
|
Nginx 默认会丢弃名称中含下划线的请求头(如 `session_id`),这会导致多账号环境下的粘性会话功能失效。
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## OpenAI Responses 兼容注意事项
|
|
||||||
|
|
||||||
- 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id` 的 `tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`。
|
|
||||||
- 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。
|
|
||||||
|
|
||||||
## 部署方式
|
## 部署方式
|
||||||
|
|
||||||
### 方式一:脚本安装(推荐)
|
### 方式一:脚本安装(推荐)
|
||||||
|
|||||||
BIN
assets/partners/logos/pincc-logo.png
Normal file
BIN
assets/partners/logos/pincc-logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 171 KiB |
@@ -1 +1 @@
|
|||||||
0.1.96.1
|
0.1.88
|
||||||
@@ -41,6 +41,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
// Server layer ProviderSet
|
// Server layer ProviderSet
|
||||||
server.ProviderSet,
|
server.ProviderSet,
|
||||||
|
|
||||||
|
// Privacy client factory for OpenAI training opt-out
|
||||||
|
providePrivacyClientFactory,
|
||||||
|
|
||||||
// BuildInfo provider
|
// BuildInfo provider
|
||||||
provideServiceBuildInfo,
|
provideServiceBuildInfo,
|
||||||
|
|
||||||
@@ -53,6 +56,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func providePrivacyClientFactory() service.PrivacyClientFactory {
|
||||||
|
return repository.CreatePrivacyReqClient
|
||||||
|
}
|
||||||
|
|
||||||
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||||
return service.BuildInfo{
|
return service.BuildInfo{
|
||||||
Version: buildInfo.Version,
|
Version: buildInfo.Version,
|
||||||
@@ -87,6 +94,7 @@ func provideCleanup(
|
|||||||
antigravityOAuth *service.AntigravityOAuthService,
|
antigravityOAuth *service.AntigravityOAuthService,
|
||||||
openAIGateway *service.OpenAIGatewayService,
|
openAIGateway *service.OpenAIGatewayService,
|
||||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||||
|
backupSvc *service.BackupService,
|
||||||
) func() {
|
) func() {
|
||||||
return func() {
|
return func() {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
@@ -223,6 +231,12 @@ func provideCleanup(
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}},
|
}},
|
||||||
|
{"BackupService", func() error {
|
||||||
|
if backupSvc != nil {
|
||||||
|
backupSvc.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
}
|
}
|
||||||
|
|
||||||
infraSteps := []cleanupStep{
|
infraSteps := []cleanupStep{
|
||||||
|
|||||||
@@ -81,6 +81,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
userHandler := handler.NewUserHandler(userService)
|
userHandler := handler.NewUserHandler(userService)
|
||||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||||
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
||||||
|
usageBillingRepository := repository.NewUsageBillingRepository(client, db)
|
||||||
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
|
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
|
||||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||||
redeemHandler := handler.NewRedeemHandler(redeemService)
|
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||||
@@ -104,11 +105,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
proxyRepository := repository.NewProxyRepository(client, db)
|
proxyRepository := repository.NewProxyRepository(client, db)
|
||||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||||
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository)
|
privacyClientFactory := providePrivacyClientFactory()
|
||||||
|
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
|
||||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
||||||
groupHandler := admin.NewGroupHandler(adminService)
|
|
||||||
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
||||||
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
||||||
openAIOAuthClient := repository.NewOpenAIOAuthClient()
|
openAIOAuthClient := repository.NewOpenAIOAuthClient()
|
||||||
@@ -122,6 +123,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
|
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
|
||||||
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
|
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
|
||||||
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
|
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
|
||||||
|
oauthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
||||||
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
|
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
|
||||||
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
|
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
|
||||||
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||||
@@ -130,20 +132,26 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
usageCache := service.NewUsageCache()
|
usageCache := service.NewUsageCache()
|
||||||
identityCache := repository.NewIdentityCache(redisClient)
|
identityCache := repository.NewIdentityCache(redisClient)
|
||||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache)
|
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache)
|
||||||
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
|
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oauthRefreshAPI)
|
||||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||||
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
|
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI)
|
||||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
|
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
|
||||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
|
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
|
||||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||||
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
|
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
|
||||||
rpmCache := repository.NewRPMCache(redisClient)
|
rpmCache := repository.NewRPMCache(redisClient)
|
||||||
|
groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
|
||||||
|
groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
|
||||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
|
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
|
||||||
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
|
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
|
||||||
dataManagementService := service.NewDataManagementService()
|
dataManagementService := service.NewDataManagementService()
|
||||||
dataManagementHandler := admin.NewDataManagementHandler(dataManagementService)
|
dataManagementHandler := admin.NewDataManagementHandler(dataManagementService)
|
||||||
|
backupObjectStoreFactory := repository.NewS3BackupStoreFactory()
|
||||||
|
dbDumper := repository.NewPgDumper(configConfig)
|
||||||
|
backupService := service.ProvideBackupService(settingRepository, configConfig, secretEncryptor, backupObjectStoreFactory, dbDumper)
|
||||||
|
backupHandler := admin.NewBackupHandler(backupService, userService)
|
||||||
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
||||||
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
||||||
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
|
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
|
||||||
@@ -160,11 +168,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
billingService := service.NewBillingService(configConfig, pricingService)
|
billingService := service.NewBillingService(configConfig, pricingService)
|
||||||
identityService := service.NewIdentityService(identityCache)
|
identityService := service.NewIdentityService(identityCache)
|
||||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||||
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI)
|
||||||
digestSessionStore := service.NewDigestSessionStore()
|
digestSessionStore := service.NewDigestSessionStore()
|
||||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
|
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
|
||||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI)
|
||||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||||
@@ -199,7 +207,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
|
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
|
||||||
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
|
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
|
||||||
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
|
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
|
||||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
|
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
|
||||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||||
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
||||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||||
@@ -226,11 +234,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
||||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||||
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
|
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
|
||||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache)
|
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI)
|
||||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService)
|
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService)
|
||||||
application := &Application{
|
application := &Application{
|
||||||
Server: httpServer,
|
Server: httpServer,
|
||||||
Cleanup: v,
|
Cleanup: v,
|
||||||
@@ -245,6 +253,10 @@ type Application struct {
|
|||||||
Cleanup func()
|
Cleanup func()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func providePrivacyClientFactory() service.PrivacyClientFactory {
|
||||||
|
return repository.CreatePrivacyReqClient
|
||||||
|
}
|
||||||
|
|
||||||
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||||
return service.BuildInfo{
|
return service.BuildInfo{
|
||||||
Version: buildInfo.Version,
|
Version: buildInfo.Version,
|
||||||
@@ -279,6 +291,7 @@ func provideCleanup(
|
|||||||
antigravityOAuth *service.AntigravityOAuthService,
|
antigravityOAuth *service.AntigravityOAuthService,
|
||||||
openAIGateway *service.OpenAIGatewayService,
|
openAIGateway *service.OpenAIGatewayService,
|
||||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||||
|
backupSvc *service.BackupService,
|
||||||
) func() {
|
) func() {
|
||||||
return func() {
|
return func() {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
@@ -414,6 +427,12 @@ func provideCleanup(
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}},
|
}},
|
||||||
|
{"BackupService", func() error {
|
||||||
|
if backupSvc != nil {
|
||||||
|
backupSvc.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
}
|
}
|
||||||
|
|
||||||
infraSteps := []cleanupStep{
|
infraSteps := []cleanupStep{
|
||||||
|
|||||||
@@ -75,6 +75,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
|||||||
antigravityOAuthSvc,
|
antigravityOAuthSvc,
|
||||||
nil, // openAIGateway
|
nil, // openAIGateway
|
||||||
nil, // scheduledTestRunner
|
nil, // scheduledTestRunner
|
||||||
|
nil, // backupSvc
|
||||||
)
|
)
|
||||||
|
|
||||||
require.NotPanics(t, func() {
|
require.NotPanics(t, func() {
|
||||||
|
|||||||
@@ -62,28 +62,26 @@ type Group struct {
|
|||||||
SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"`
|
SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"`
|
||||||
// SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field.
|
// SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field.
|
||||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"`
|
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"`
|
||||||
// allow Claude Code client only
|
// 是否仅允许 Claude Code 客户端
|
||||||
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
|
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
|
||||||
// fallback group for non-Claude-Code requests
|
// 非 Claude Code 请求降级使用的分组 ID
|
||||||
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||||
// fallback group for invalid request
|
// 无效请求兜底使用的分组 ID
|
||||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
|
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
|
||||||
// model routing config: pattern -> account ids
|
// 模型路由配置:模型模式 -> 优先账号ID列表
|
||||||
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
||||||
// whether model routing is enabled
|
// 是否启用模型路由配置
|
||||||
ModelRoutingEnabled bool `json:"model_routing_enabled,omitempty"`
|
ModelRoutingEnabled bool `json:"model_routing_enabled,omitempty"`
|
||||||
// whether MCP XML prompt injection is enabled
|
// 是否注入 MCP XML 调用协议提示词(仅 antigravity 平台)
|
||||||
McpXMLInject bool `json:"mcp_xml_inject,omitempty"`
|
McpXMLInject bool `json:"mcp_xml_inject,omitempty"`
|
||||||
// supported model scopes: claude, gemini_text, gemini_image
|
// 支持的模型系列:claude, gemini_text, gemini_image
|
||||||
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
|
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
|
||||||
// group display order, lower comes first
|
// 分组显示排序,数值越小越靠前
|
||||||
SortOrder int `json:"sort_order,omitempty"`
|
SortOrder int `json:"sort_order,omitempty"`
|
||||||
// 是否允许 /v1/messages 调度到此 OpenAI 分组
|
// 是否允许 /v1/messages 调度到此 OpenAI 分组
|
||||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch,omitempty"`
|
AllowMessagesDispatch bool `json:"allow_messages_dispatch,omitempty"`
|
||||||
// 默认映射模型 ID,当账号级映射找不到时使用此值
|
// 默认映射模型 ID,当账号级映射找不到时使用此值
|
||||||
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
|
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
|
||||||
// simulate claude usage as claude-max style (1h cache write)
|
|
||||||
SimulateClaudeMaxEnabled bool `json:"simulate_claude_max_enabled,omitempty"`
|
|
||||||
// Edges holds the relations/edges for other nodes in the graph.
|
// Edges holds the relations/edges for other nodes in the graph.
|
||||||
// The values are being populated by the GroupQuery when eager-loading is set.
|
// The values are being populated by the GroupQuery when eager-loading is set.
|
||||||
Edges GroupEdges `json:"edges"`
|
Edges GroupEdges `json:"edges"`
|
||||||
@@ -192,7 +190,7 @@ func (*Group) scanValues(columns []string) ([]any, error) {
|
|||||||
switch columns[i] {
|
switch columns[i] {
|
||||||
case group.FieldModelRouting, group.FieldSupportedModelScopes:
|
case group.FieldModelRouting, group.FieldSupportedModelScopes:
|
||||||
values[i] = new([]byte)
|
values[i] = new([]byte)
|
||||||
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldSimulateClaudeMaxEnabled:
|
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch:
|
||||||
values[i] = new(sql.NullBool)
|
values[i] = new(sql.NullBool)
|
||||||
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd:
|
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd:
|
||||||
values[i] = new(sql.NullFloat64)
|
values[i] = new(sql.NullFloat64)
|
||||||
@@ -433,12 +431,6 @@ func (_m *Group) assignValues(columns []string, values []any) error {
|
|||||||
} else if value.Valid {
|
} else if value.Valid {
|
||||||
_m.DefaultMappedModel = value.String
|
_m.DefaultMappedModel = value.String
|
||||||
}
|
}
|
||||||
case group.FieldSimulateClaudeMaxEnabled:
|
|
||||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
|
||||||
return fmt.Errorf("unexpected type %T for field simulate_claude_max_enabled", values[i])
|
|
||||||
} else if value.Valid {
|
|
||||||
_m.SimulateClaudeMaxEnabled = value.Bool
|
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
_m.selectValues.Set(columns[i], values[i])
|
_m.selectValues.Set(columns[i], values[i])
|
||||||
}
|
}
|
||||||
@@ -638,9 +630,6 @@ func (_m *Group) String() string {
|
|||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
builder.WriteString("default_mapped_model=")
|
builder.WriteString("default_mapped_model=")
|
||||||
builder.WriteString(_m.DefaultMappedModel)
|
builder.WriteString(_m.DefaultMappedModel)
|
||||||
builder.WriteString(", ")
|
|
||||||
builder.WriteString("simulate_claude_max_enabled=")
|
|
||||||
builder.WriteString(fmt.Sprintf("%v", _m.SimulateClaudeMaxEnabled))
|
|
||||||
builder.WriteByte(')')
|
builder.WriteByte(')')
|
||||||
return builder.String()
|
return builder.String()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -79,8 +79,6 @@ const (
|
|||||||
FieldAllowMessagesDispatch = "allow_messages_dispatch"
|
FieldAllowMessagesDispatch = "allow_messages_dispatch"
|
||||||
// FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database.
|
// FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database.
|
||||||
FieldDefaultMappedModel = "default_mapped_model"
|
FieldDefaultMappedModel = "default_mapped_model"
|
||||||
// FieldSimulateClaudeMaxEnabled holds the string denoting the simulate_claude_max_enabled field in the database.
|
|
||||||
FieldSimulateClaudeMaxEnabled = "simulate_claude_max_enabled"
|
|
||||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||||
EdgeAPIKeys = "api_keys"
|
EdgeAPIKeys = "api_keys"
|
||||||
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
||||||
@@ -188,7 +186,6 @@ var Columns = []string{
|
|||||||
FieldSortOrder,
|
FieldSortOrder,
|
||||||
FieldAllowMessagesDispatch,
|
FieldAllowMessagesDispatch,
|
||||||
FieldDefaultMappedModel,
|
FieldDefaultMappedModel,
|
||||||
FieldSimulateClaudeMaxEnabled,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -262,8 +259,6 @@ var (
|
|||||||
DefaultDefaultMappedModel string
|
DefaultDefaultMappedModel string
|
||||||
// DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
// DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
||||||
DefaultMappedModelValidator func(string) error
|
DefaultMappedModelValidator func(string) error
|
||||||
// DefaultSimulateClaudeMaxEnabled holds the default value on creation for the "simulate_claude_max_enabled" field.
|
|
||||||
DefaultSimulateClaudeMaxEnabled bool
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// OrderOption defines the ordering options for the Group queries.
|
// OrderOption defines the ordering options for the Group queries.
|
||||||
@@ -424,11 +419,6 @@ func ByDefaultMappedModel(opts ...sql.OrderTermOption) OrderOption {
|
|||||||
return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc()
|
return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
// BySimulateClaudeMaxEnabled orders the results by the simulate_claude_max_enabled field.
|
|
||||||
func BySimulateClaudeMaxEnabled(opts ...sql.OrderTermOption) OrderOption {
|
|
||||||
return sql.OrderByField(FieldSimulateClaudeMaxEnabled, opts...).ToFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ByAPIKeysCount orders the results by api_keys count.
|
// ByAPIKeysCount orders the results by api_keys count.
|
||||||
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
||||||
return func(s *sql.Selector) {
|
return func(s *sql.Selector) {
|
||||||
|
|||||||
@@ -205,11 +205,6 @@ func DefaultMappedModel(v string) predicate.Group {
|
|||||||
return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v))
|
return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
// SimulateClaudeMaxEnabled applies equality check predicate on the "simulate_claude_max_enabled" field. It's identical to SimulateClaudeMaxEnabledEQ.
|
|
||||||
func SimulateClaudeMaxEnabled(v bool) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldEQ(FieldSimulateClaudeMaxEnabled, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||||
func CreatedAtEQ(v time.Time) predicate.Group {
|
func CreatedAtEQ(v time.Time) predicate.Group {
|
||||||
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
|
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
|
||||||
@@ -1560,16 +1555,6 @@ func DefaultMappedModelContainsFold(v string) predicate.Group {
|
|||||||
return predicate.Group(sql.FieldContainsFold(FieldDefaultMappedModel, v))
|
return predicate.Group(sql.FieldContainsFold(FieldDefaultMappedModel, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
// SimulateClaudeMaxEnabledEQ applies the EQ predicate on the "simulate_claude_max_enabled" field.
|
|
||||||
func SimulateClaudeMaxEnabledEQ(v bool) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldEQ(FieldSimulateClaudeMaxEnabled, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SimulateClaudeMaxEnabledNEQ applies the NEQ predicate on the "simulate_claude_max_enabled" field.
|
|
||||||
func SimulateClaudeMaxEnabledNEQ(v bool) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldNEQ(FieldSimulateClaudeMaxEnabled, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
||||||
func HasAPIKeys() predicate.Group {
|
func HasAPIKeys() predicate.Group {
|
||||||
return predicate.Group(func(s *sql.Selector) {
|
return predicate.Group(func(s *sql.Selector) {
|
||||||
|
|||||||
@@ -452,20 +452,6 @@ func (_c *GroupCreate) SetNillableDefaultMappedModel(v *string) *GroupCreate {
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
|
||||||
func (_c *GroupCreate) SetSimulateClaudeMaxEnabled(v bool) *GroupCreate {
|
|
||||||
_c.mutation.SetSimulateClaudeMaxEnabled(v)
|
|
||||||
return _c
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field if the given value is not nil.
|
|
||||||
func (_c *GroupCreate) SetNillableSimulateClaudeMaxEnabled(v *bool) *GroupCreate {
|
|
||||||
if v != nil {
|
|
||||||
_c.SetSimulateClaudeMaxEnabled(*v)
|
|
||||||
}
|
|
||||||
return _c
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
|
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
|
||||||
_c.mutation.AddAPIKeyIDs(ids...)
|
_c.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -663,10 +649,6 @@ func (_c *GroupCreate) defaults() error {
|
|||||||
v := group.DefaultDefaultMappedModel
|
v := group.DefaultDefaultMappedModel
|
||||||
_c.mutation.SetDefaultMappedModel(v)
|
_c.mutation.SetDefaultMappedModel(v)
|
||||||
}
|
}
|
||||||
if _, ok := _c.mutation.SimulateClaudeMaxEnabled(); !ok {
|
|
||||||
v := group.DefaultSimulateClaudeMaxEnabled
|
|
||||||
_c.mutation.SetSimulateClaudeMaxEnabled(v)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -748,9 +730,6 @@ func (_c *GroupCreate) check() error {
|
|||||||
return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)}
|
return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if _, ok := _c.mutation.SimulateClaudeMaxEnabled(); !ok {
|
|
||||||
return &ValidationError{Name: "simulate_claude_max_enabled", err: errors.New(`ent: missing required field "Group.simulate_claude_max_enabled"`)}
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -906,10 +885,6 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
|
|||||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||||
_node.DefaultMappedModel = value
|
_node.DefaultMappedModel = value
|
||||||
}
|
}
|
||||||
if value, ok := _c.mutation.SimulateClaudeMaxEnabled(); ok {
|
|
||||||
_spec.SetField(group.FieldSimulateClaudeMaxEnabled, field.TypeBool, value)
|
|
||||||
_node.SimulateClaudeMaxEnabled = value
|
|
||||||
}
|
|
||||||
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
@@ -1624,18 +1599,6 @@ func (u *GroupUpsert) UpdateDefaultMappedModel() *GroupUpsert {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
|
||||||
func (u *GroupUpsert) SetSimulateClaudeMaxEnabled(v bool) *GroupUpsert {
|
|
||||||
u.Set(group.FieldSimulateClaudeMaxEnabled, v)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field to the value that was provided on create.
|
|
||||||
func (u *GroupUpsert) UpdateSimulateClaudeMaxEnabled() *GroupUpsert {
|
|
||||||
u.SetExcluded(group.FieldSimulateClaudeMaxEnabled)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||||
// Using this option is equivalent to using:
|
// Using this option is equivalent to using:
|
||||||
//
|
//
|
||||||
@@ -2332,20 +2295,6 @@ func (u *GroupUpsertOne) UpdateDefaultMappedModel() *GroupUpsertOne {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
|
||||||
func (u *GroupUpsertOne) SetSimulateClaudeMaxEnabled(v bool) *GroupUpsertOne {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.SetSimulateClaudeMaxEnabled(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field to the value that was provided on create.
|
|
||||||
func (u *GroupUpsertOne) UpdateSimulateClaudeMaxEnabled() *GroupUpsertOne {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.UpdateSimulateClaudeMaxEnabled()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
|
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
|
||||||
if len(u.create.conflict) == 0 {
|
if len(u.create.conflict) == 0 {
|
||||||
@@ -3208,20 +3157,6 @@ func (u *GroupUpsertBulk) UpdateDefaultMappedModel() *GroupUpsertBulk {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
|
||||||
func (u *GroupUpsertBulk) SetSimulateClaudeMaxEnabled(v bool) *GroupUpsertBulk {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.SetSimulateClaudeMaxEnabled(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field to the value that was provided on create.
|
|
||||||
func (u *GroupUpsertBulk) UpdateSimulateClaudeMaxEnabled() *GroupUpsertBulk {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.UpdateSimulateClaudeMaxEnabled()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
|
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
|
||||||
if u.create.err != nil {
|
if u.create.err != nil {
|
||||||
|
|||||||
@@ -653,20 +653,6 @@ func (_u *GroupUpdate) SetNillableDefaultMappedModel(v *string) *GroupUpdate {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
|
||||||
func (_u *GroupUpdate) SetSimulateClaudeMaxEnabled(v bool) *GroupUpdate {
|
|
||||||
_u.mutation.SetSimulateClaudeMaxEnabled(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field if the given value is not nil.
|
|
||||||
func (_u *GroupUpdate) SetNillableSimulateClaudeMaxEnabled(v *bool) *GroupUpdate {
|
|
||||||
if v != nil {
|
|
||||||
_u.SetSimulateClaudeMaxEnabled(*v)
|
|
||||||
}
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
|
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
|
||||||
_u.mutation.AddAPIKeyIDs(ids...)
|
_u.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -1163,9 +1149,6 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
if value, ok := _u.mutation.DefaultMappedModel(); ok {
|
if value, ok := _u.mutation.DefaultMappedModel(); ok {
|
||||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||||
}
|
}
|
||||||
if value, ok := _u.mutation.SimulateClaudeMaxEnabled(); ok {
|
|
||||||
_spec.SetField(group.FieldSimulateClaudeMaxEnabled, field.TypeBool, value)
|
|
||||||
}
|
|
||||||
if _u.mutation.APIKeysCleared() {
|
if _u.mutation.APIKeysCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
@@ -2098,20 +2081,6 @@ func (_u *GroupUpdateOne) SetNillableDefaultMappedModel(v *string) *GroupUpdateO
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
|
||||||
func (_u *GroupUpdateOne) SetSimulateClaudeMaxEnabled(v bool) *GroupUpdateOne {
|
|
||||||
_u.mutation.SetSimulateClaudeMaxEnabled(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field if the given value is not nil.
|
|
||||||
func (_u *GroupUpdateOne) SetNillableSimulateClaudeMaxEnabled(v *bool) *GroupUpdateOne {
|
|
||||||
if v != nil {
|
|
||||||
_u.SetSimulateClaudeMaxEnabled(*v)
|
|
||||||
}
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
|
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
|
||||||
_u.mutation.AddAPIKeyIDs(ids...)
|
_u.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -2638,9 +2607,6 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
|
|||||||
if value, ok := _u.mutation.DefaultMappedModel(); ok {
|
if value, ok := _u.mutation.DefaultMappedModel(); ok {
|
||||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||||
}
|
}
|
||||||
if value, ok := _u.mutation.SimulateClaudeMaxEnabled(); ok {
|
|
||||||
_spec.SetField(group.FieldSimulateClaudeMaxEnabled, field.TypeBool, value)
|
|
||||||
}
|
|
||||||
if _u.mutation.APIKeysCleared() {
|
if _u.mutation.APIKeysCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
|
|||||||
@@ -410,7 +410,6 @@ var (
|
|||||||
{Name: "sort_order", Type: field.TypeInt, Default: 0},
|
{Name: "sort_order", Type: field.TypeInt, Default: 0},
|
||||||
{Name: "allow_messages_dispatch", Type: field.TypeBool, Default: false},
|
{Name: "allow_messages_dispatch", Type: field.TypeBool, Default: false},
|
||||||
{Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""},
|
{Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""},
|
||||||
{Name: "simulate_claude_max_enabled", Type: field.TypeBool, Default: false},
|
|
||||||
}
|
}
|
||||||
// GroupsTable holds the schema information for the "groups" table.
|
// GroupsTable holds the schema information for the "groups" table.
|
||||||
GroupsTable = &schema.Table{
|
GroupsTable = &schema.Table{
|
||||||
@@ -717,6 +716,7 @@ var (
|
|||||||
{Name: "id", Type: field.TypeInt64, Increment: true},
|
{Name: "id", Type: field.TypeInt64, Increment: true},
|
||||||
{Name: "request_id", Type: field.TypeString, Size: 64},
|
{Name: "request_id", Type: field.TypeString, Size: 64},
|
||||||
{Name: "model", Type: field.TypeString, Size: 100},
|
{Name: "model", Type: field.TypeString, Size: 100},
|
||||||
|
{Name: "upstream_model", Type: field.TypeString, Nullable: true, Size: 100},
|
||||||
{Name: "input_tokens", Type: field.TypeInt, Default: 0},
|
{Name: "input_tokens", Type: field.TypeInt, Default: 0},
|
||||||
{Name: "output_tokens", Type: field.TypeInt, Default: 0},
|
{Name: "output_tokens", Type: field.TypeInt, Default: 0},
|
||||||
{Name: "cache_creation_tokens", Type: field.TypeInt, Default: 0},
|
{Name: "cache_creation_tokens", Type: field.TypeInt, Default: 0},
|
||||||
@@ -756,31 +756,31 @@ var (
|
|||||||
ForeignKeys: []*schema.ForeignKey{
|
ForeignKeys: []*schema.ForeignKey{
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_api_keys_usage_logs",
|
Symbol: "usage_logs_api_keys_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||||
RefColumns: []*schema.Column{APIKeysColumns[0]},
|
RefColumns: []*schema.Column{APIKeysColumns[0]},
|
||||||
OnDelete: schema.NoAction,
|
OnDelete: schema.NoAction,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_accounts_usage_logs",
|
Symbol: "usage_logs_accounts_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||||
RefColumns: []*schema.Column{AccountsColumns[0]},
|
RefColumns: []*schema.Column{AccountsColumns[0]},
|
||||||
OnDelete: schema.NoAction,
|
OnDelete: schema.NoAction,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_groups_usage_logs",
|
Symbol: "usage_logs_groups_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[30]},
|
Columns: []*schema.Column{UsageLogsColumns[31]},
|
||||||
RefColumns: []*schema.Column{GroupsColumns[0]},
|
RefColumns: []*schema.Column{GroupsColumns[0]},
|
||||||
OnDelete: schema.SetNull,
|
OnDelete: schema.SetNull,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_users_usage_logs",
|
Symbol: "usage_logs_users_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[31]},
|
Columns: []*schema.Column{UsageLogsColumns[32]},
|
||||||
RefColumns: []*schema.Column{UsersColumns[0]},
|
RefColumns: []*schema.Column{UsersColumns[0]},
|
||||||
OnDelete: schema.NoAction,
|
OnDelete: schema.NoAction,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_user_subscriptions_usage_logs",
|
Symbol: "usage_logs_user_subscriptions_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[32]},
|
Columns: []*schema.Column{UsageLogsColumns[33]},
|
||||||
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
|
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
|
||||||
OnDelete: schema.SetNull,
|
OnDelete: schema.SetNull,
|
||||||
},
|
},
|
||||||
@@ -789,32 +789,32 @@ var (
|
|||||||
{
|
{
|
||||||
Name: "usagelog_user_id",
|
Name: "usagelog_user_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[31]},
|
Columns: []*schema.Column{UsageLogsColumns[32]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_api_key_id",
|
Name: "usagelog_api_key_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_account_id",
|
Name: "usagelog_account_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_group_id",
|
Name: "usagelog_group_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[30]},
|
Columns: []*schema.Column{UsageLogsColumns[31]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_subscription_id",
|
Name: "usagelog_subscription_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[32]},
|
Columns: []*schema.Column{UsageLogsColumns[33]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_created_at",
|
Name: "usagelog_created_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[27]},
|
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_model",
|
Name: "usagelog_model",
|
||||||
@@ -829,17 +829,17 @@ var (
|
|||||||
{
|
{
|
||||||
Name: "usagelog_user_id_created_at",
|
Name: "usagelog_user_id_created_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[27]},
|
Columns: []*schema.Column{UsageLogsColumns[32], UsageLogsColumns[28]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_api_key_id_created_at",
|
Name: "usagelog_api_key_id_created_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[27]},
|
Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[28]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_group_id_created_at",
|
Name: "usagelog_group_id_created_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[27]},
|
Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[28]},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8252,7 +8252,6 @@ type GroupMutation struct {
|
|||||||
addsort_order *int
|
addsort_order *int
|
||||||
allow_messages_dispatch *bool
|
allow_messages_dispatch *bool
|
||||||
default_mapped_model *string
|
default_mapped_model *string
|
||||||
simulate_claude_max_enabled *bool
|
|
||||||
clearedFields map[string]struct{}
|
clearedFields map[string]struct{}
|
||||||
api_keys map[int64]struct{}
|
api_keys map[int64]struct{}
|
||||||
removedapi_keys map[int64]struct{}
|
removedapi_keys map[int64]struct{}
|
||||||
@@ -10069,42 +10068,6 @@ func (m *GroupMutation) ResetDefaultMappedModel() {
|
|||||||
m.default_mapped_model = nil
|
m.default_mapped_model = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
|
||||||
func (m *GroupMutation) SetSimulateClaudeMaxEnabled(b bool) {
|
|
||||||
m.simulate_claude_max_enabled = &b
|
|
||||||
}
|
|
||||||
|
|
||||||
// SimulateClaudeMaxEnabled returns the value of the "simulate_claude_max_enabled" field in the mutation.
|
|
||||||
func (m *GroupMutation) SimulateClaudeMaxEnabled() (r bool, exists bool) {
|
|
||||||
v := m.simulate_claude_max_enabled
|
|
||||||
if v == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return *v, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// OldSimulateClaudeMaxEnabled returns the old "simulate_claude_max_enabled" field's value of the Group entity.
|
|
||||||
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
|
||||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
|
||||||
func (m *GroupMutation) OldSimulateClaudeMaxEnabled(ctx context.Context) (v bool, err error) {
|
|
||||||
if !m.op.Is(OpUpdateOne) {
|
|
||||||
return v, errors.New("OldSimulateClaudeMaxEnabled is only allowed on UpdateOne operations")
|
|
||||||
}
|
|
||||||
if m.id == nil || m.oldValue == nil {
|
|
||||||
return v, errors.New("OldSimulateClaudeMaxEnabled requires an ID field in the mutation")
|
|
||||||
}
|
|
||||||
oldValue, err := m.oldValue(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return v, fmt.Errorf("querying old value for OldSimulateClaudeMaxEnabled: %w", err)
|
|
||||||
}
|
|
||||||
return oldValue.SimulateClaudeMaxEnabled, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetSimulateClaudeMaxEnabled resets all changes to the "simulate_claude_max_enabled" field.
|
|
||||||
func (m *GroupMutation) ResetSimulateClaudeMaxEnabled() {
|
|
||||||
m.simulate_claude_max_enabled = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
||||||
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
|
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
|
||||||
if m.api_keys == nil {
|
if m.api_keys == nil {
|
||||||
@@ -10463,7 +10426,7 @@ func (m *GroupMutation) Type() string {
|
|||||||
// order to get all numeric fields that were incremented/decremented, call
|
// order to get all numeric fields that were incremented/decremented, call
|
||||||
// AddedFields().
|
// AddedFields().
|
||||||
func (m *GroupMutation) Fields() []string {
|
func (m *GroupMutation) Fields() []string {
|
||||||
fields := make([]string, 0, 33)
|
fields := make([]string, 0, 32)
|
||||||
if m.created_at != nil {
|
if m.created_at != nil {
|
||||||
fields = append(fields, group.FieldCreatedAt)
|
fields = append(fields, group.FieldCreatedAt)
|
||||||
}
|
}
|
||||||
@@ -10560,9 +10523,6 @@ func (m *GroupMutation) Fields() []string {
|
|||||||
if m.default_mapped_model != nil {
|
if m.default_mapped_model != nil {
|
||||||
fields = append(fields, group.FieldDefaultMappedModel)
|
fields = append(fields, group.FieldDefaultMappedModel)
|
||||||
}
|
}
|
||||||
if m.simulate_claude_max_enabled != nil {
|
|
||||||
fields = append(fields, group.FieldSimulateClaudeMaxEnabled)
|
|
||||||
}
|
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -10635,8 +10595,6 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
|
|||||||
return m.AllowMessagesDispatch()
|
return m.AllowMessagesDispatch()
|
||||||
case group.FieldDefaultMappedModel:
|
case group.FieldDefaultMappedModel:
|
||||||
return m.DefaultMappedModel()
|
return m.DefaultMappedModel()
|
||||||
case group.FieldSimulateClaudeMaxEnabled:
|
|
||||||
return m.SimulateClaudeMaxEnabled()
|
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@@ -10710,8 +10668,6 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
|
|||||||
return m.OldAllowMessagesDispatch(ctx)
|
return m.OldAllowMessagesDispatch(ctx)
|
||||||
case group.FieldDefaultMappedModel:
|
case group.FieldDefaultMappedModel:
|
||||||
return m.OldDefaultMappedModel(ctx)
|
return m.OldDefaultMappedModel(ctx)
|
||||||
case group.FieldSimulateClaudeMaxEnabled:
|
|
||||||
return m.OldSimulateClaudeMaxEnabled(ctx)
|
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unknown Group field %s", name)
|
return nil, fmt.Errorf("unknown Group field %s", name)
|
||||||
}
|
}
|
||||||
@@ -10945,13 +10901,6 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.SetDefaultMappedModel(v)
|
m.SetDefaultMappedModel(v)
|
||||||
return nil
|
return nil
|
||||||
case group.FieldSimulateClaudeMaxEnabled:
|
|
||||||
v, ok := value.(bool)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
|
||||||
}
|
|
||||||
m.SetSimulateClaudeMaxEnabled(v)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown Group field %s", name)
|
return fmt.Errorf("unknown Group field %s", name)
|
||||||
}
|
}
|
||||||
@@ -11385,9 +11334,6 @@ func (m *GroupMutation) ResetField(name string) error {
|
|||||||
case group.FieldDefaultMappedModel:
|
case group.FieldDefaultMappedModel:
|
||||||
m.ResetDefaultMappedModel()
|
m.ResetDefaultMappedModel()
|
||||||
return nil
|
return nil
|
||||||
case group.FieldSimulateClaudeMaxEnabled:
|
|
||||||
m.ResetSimulateClaudeMaxEnabled()
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown Group field %s", name)
|
return fmt.Errorf("unknown Group field %s", name)
|
||||||
}
|
}
|
||||||
@@ -18293,6 +18239,7 @@ type UsageLogMutation struct {
|
|||||||
id *int64
|
id *int64
|
||||||
request_id *string
|
request_id *string
|
||||||
model *string
|
model *string
|
||||||
|
upstream_model *string
|
||||||
input_tokens *int
|
input_tokens *int
|
||||||
addinput_tokens *int
|
addinput_tokens *int
|
||||||
output_tokens *int
|
output_tokens *int
|
||||||
@@ -18630,6 +18577,55 @@ func (m *UsageLogMutation) ResetModel() {
|
|||||||
m.model = nil
|
m.model = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetUpstreamModel sets the "upstream_model" field.
|
||||||
|
func (m *UsageLogMutation) SetUpstreamModel(s string) {
|
||||||
|
m.upstream_model = &s
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModel returns the value of the "upstream_model" field in the mutation.
|
||||||
|
func (m *UsageLogMutation) UpstreamModel() (r string, exists bool) {
|
||||||
|
v := m.upstream_model
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldUpstreamModel returns the old "upstream_model" field's value of the UsageLog entity.
|
||||||
|
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *UsageLogMutation) OldUpstreamModel(ctx context.Context) (v *string, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldUpstreamModel is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldUpstreamModel requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldUpstreamModel: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.UpstreamModel, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearUpstreamModel clears the value of the "upstream_model" field.
|
||||||
|
func (m *UsageLogMutation) ClearUpstreamModel() {
|
||||||
|
m.upstream_model = nil
|
||||||
|
m.clearedFields[usagelog.FieldUpstreamModel] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelCleared returns if the "upstream_model" field was cleared in this mutation.
|
||||||
|
func (m *UsageLogMutation) UpstreamModelCleared() bool {
|
||||||
|
_, ok := m.clearedFields[usagelog.FieldUpstreamModel]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetUpstreamModel resets all changes to the "upstream_model" field.
|
||||||
|
func (m *UsageLogMutation) ResetUpstreamModel() {
|
||||||
|
m.upstream_model = nil
|
||||||
|
delete(m.clearedFields, usagelog.FieldUpstreamModel)
|
||||||
|
}
|
||||||
|
|
||||||
// SetGroupID sets the "group_id" field.
|
// SetGroupID sets the "group_id" field.
|
||||||
func (m *UsageLogMutation) SetGroupID(i int64) {
|
func (m *UsageLogMutation) SetGroupID(i int64) {
|
||||||
m.group = &i
|
m.group = &i
|
||||||
@@ -20251,7 +20247,7 @@ func (m *UsageLogMutation) Type() string {
|
|||||||
// order to get all numeric fields that were incremented/decremented, call
|
// order to get all numeric fields that were incremented/decremented, call
|
||||||
// AddedFields().
|
// AddedFields().
|
||||||
func (m *UsageLogMutation) Fields() []string {
|
func (m *UsageLogMutation) Fields() []string {
|
||||||
fields := make([]string, 0, 32)
|
fields := make([]string, 0, 33)
|
||||||
if m.user != nil {
|
if m.user != nil {
|
||||||
fields = append(fields, usagelog.FieldUserID)
|
fields = append(fields, usagelog.FieldUserID)
|
||||||
}
|
}
|
||||||
@@ -20267,6 +20263,9 @@ func (m *UsageLogMutation) Fields() []string {
|
|||||||
if m.model != nil {
|
if m.model != nil {
|
||||||
fields = append(fields, usagelog.FieldModel)
|
fields = append(fields, usagelog.FieldModel)
|
||||||
}
|
}
|
||||||
|
if m.upstream_model != nil {
|
||||||
|
fields = append(fields, usagelog.FieldUpstreamModel)
|
||||||
|
}
|
||||||
if m.group != nil {
|
if m.group != nil {
|
||||||
fields = append(fields, usagelog.FieldGroupID)
|
fields = append(fields, usagelog.FieldGroupID)
|
||||||
}
|
}
|
||||||
@@ -20366,6 +20365,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
|
|||||||
return m.RequestID()
|
return m.RequestID()
|
||||||
case usagelog.FieldModel:
|
case usagelog.FieldModel:
|
||||||
return m.Model()
|
return m.Model()
|
||||||
|
case usagelog.FieldUpstreamModel:
|
||||||
|
return m.UpstreamModel()
|
||||||
case usagelog.FieldGroupID:
|
case usagelog.FieldGroupID:
|
||||||
return m.GroupID()
|
return m.GroupID()
|
||||||
case usagelog.FieldSubscriptionID:
|
case usagelog.FieldSubscriptionID:
|
||||||
@@ -20439,6 +20440,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
|
|||||||
return m.OldRequestID(ctx)
|
return m.OldRequestID(ctx)
|
||||||
case usagelog.FieldModel:
|
case usagelog.FieldModel:
|
||||||
return m.OldModel(ctx)
|
return m.OldModel(ctx)
|
||||||
|
case usagelog.FieldUpstreamModel:
|
||||||
|
return m.OldUpstreamModel(ctx)
|
||||||
case usagelog.FieldGroupID:
|
case usagelog.FieldGroupID:
|
||||||
return m.OldGroupID(ctx)
|
return m.OldGroupID(ctx)
|
||||||
case usagelog.FieldSubscriptionID:
|
case usagelog.FieldSubscriptionID:
|
||||||
@@ -20537,6 +20540,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.SetModel(v)
|
m.SetModel(v)
|
||||||
return nil
|
return nil
|
||||||
|
case usagelog.FieldUpstreamModel:
|
||||||
|
v, ok := value.(string)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetUpstreamModel(v)
|
||||||
|
return nil
|
||||||
case usagelog.FieldGroupID:
|
case usagelog.FieldGroupID:
|
||||||
v, ok := value.(int64)
|
v, ok := value.(int64)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -20975,6 +20985,9 @@ func (m *UsageLogMutation) AddField(name string, value ent.Value) error {
|
|||||||
// mutation.
|
// mutation.
|
||||||
func (m *UsageLogMutation) ClearedFields() []string {
|
func (m *UsageLogMutation) ClearedFields() []string {
|
||||||
var fields []string
|
var fields []string
|
||||||
|
if m.FieldCleared(usagelog.FieldUpstreamModel) {
|
||||||
|
fields = append(fields, usagelog.FieldUpstreamModel)
|
||||||
|
}
|
||||||
if m.FieldCleared(usagelog.FieldGroupID) {
|
if m.FieldCleared(usagelog.FieldGroupID) {
|
||||||
fields = append(fields, usagelog.FieldGroupID)
|
fields = append(fields, usagelog.FieldGroupID)
|
||||||
}
|
}
|
||||||
@@ -21016,6 +21029,9 @@ func (m *UsageLogMutation) FieldCleared(name string) bool {
|
|||||||
// error if the field is not defined in the schema.
|
// error if the field is not defined in the schema.
|
||||||
func (m *UsageLogMutation) ClearField(name string) error {
|
func (m *UsageLogMutation) ClearField(name string) error {
|
||||||
switch name {
|
switch name {
|
||||||
|
case usagelog.FieldUpstreamModel:
|
||||||
|
m.ClearUpstreamModel()
|
||||||
|
return nil
|
||||||
case usagelog.FieldGroupID:
|
case usagelog.FieldGroupID:
|
||||||
m.ClearGroupID()
|
m.ClearGroupID()
|
||||||
return nil
|
return nil
|
||||||
@@ -21066,6 +21082,9 @@ func (m *UsageLogMutation) ResetField(name string) error {
|
|||||||
case usagelog.FieldModel:
|
case usagelog.FieldModel:
|
||||||
m.ResetModel()
|
m.ResetModel()
|
||||||
return nil
|
return nil
|
||||||
|
case usagelog.FieldUpstreamModel:
|
||||||
|
m.ResetUpstreamModel()
|
||||||
|
return nil
|
||||||
case usagelog.FieldGroupID:
|
case usagelog.FieldGroupID:
|
||||||
m.ResetGroupID()
|
m.ResetGroupID()
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -463,10 +463,6 @@ func init() {
|
|||||||
group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string)
|
group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string)
|
||||||
// group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
// group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
||||||
group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error)
|
group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error)
|
||||||
// groupDescSimulateClaudeMaxEnabled is the schema descriptor for simulate_claude_max_enabled field.
|
|
||||||
groupDescSimulateClaudeMaxEnabled := groupFields[29].Descriptor()
|
|
||||||
// group.DefaultSimulateClaudeMaxEnabled holds the default value on creation for the simulate_claude_max_enabled field.
|
|
||||||
group.DefaultSimulateClaudeMaxEnabled = groupDescSimulateClaudeMaxEnabled.Default.(bool)
|
|
||||||
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
|
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
|
||||||
idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields()
|
idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields()
|
||||||
_ = idempotencyrecordMixinFields0
|
_ = idempotencyrecordMixinFields0
|
||||||
@@ -825,92 +821,96 @@ func init() {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
// usagelogDescUpstreamModel is the schema descriptor for upstream_model field.
|
||||||
|
usagelogDescUpstreamModel := usagelogFields[5].Descriptor()
|
||||||
|
// usagelog.UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
|
||||||
|
usagelog.UpstreamModelValidator = usagelogDescUpstreamModel.Validators[0].(func(string) error)
|
||||||
// usagelogDescInputTokens is the schema descriptor for input_tokens field.
|
// usagelogDescInputTokens is the schema descriptor for input_tokens field.
|
||||||
usagelogDescInputTokens := usagelogFields[7].Descriptor()
|
usagelogDescInputTokens := usagelogFields[8].Descriptor()
|
||||||
// usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field.
|
// usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field.
|
||||||
usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int)
|
usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int)
|
||||||
// usagelogDescOutputTokens is the schema descriptor for output_tokens field.
|
// usagelogDescOutputTokens is the schema descriptor for output_tokens field.
|
||||||
usagelogDescOutputTokens := usagelogFields[8].Descriptor()
|
usagelogDescOutputTokens := usagelogFields[9].Descriptor()
|
||||||
// usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field.
|
// usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field.
|
||||||
usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int)
|
usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int)
|
||||||
// usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field.
|
// usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field.
|
||||||
usagelogDescCacheCreationTokens := usagelogFields[9].Descriptor()
|
usagelogDescCacheCreationTokens := usagelogFields[10].Descriptor()
|
||||||
// usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field.
|
// usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field.
|
||||||
usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int)
|
usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int)
|
||||||
// usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field.
|
// usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field.
|
||||||
usagelogDescCacheReadTokens := usagelogFields[10].Descriptor()
|
usagelogDescCacheReadTokens := usagelogFields[11].Descriptor()
|
||||||
// usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field.
|
// usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field.
|
||||||
usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int)
|
usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int)
|
||||||
// usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field.
|
// usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field.
|
||||||
usagelogDescCacheCreation5mTokens := usagelogFields[11].Descriptor()
|
usagelogDescCacheCreation5mTokens := usagelogFields[12].Descriptor()
|
||||||
// usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field.
|
// usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field.
|
||||||
usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int)
|
usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int)
|
||||||
// usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field.
|
// usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field.
|
||||||
usagelogDescCacheCreation1hTokens := usagelogFields[12].Descriptor()
|
usagelogDescCacheCreation1hTokens := usagelogFields[13].Descriptor()
|
||||||
// usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field.
|
// usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field.
|
||||||
usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int)
|
usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int)
|
||||||
// usagelogDescInputCost is the schema descriptor for input_cost field.
|
// usagelogDescInputCost is the schema descriptor for input_cost field.
|
||||||
usagelogDescInputCost := usagelogFields[13].Descriptor()
|
usagelogDescInputCost := usagelogFields[14].Descriptor()
|
||||||
// usagelog.DefaultInputCost holds the default value on creation for the input_cost field.
|
// usagelog.DefaultInputCost holds the default value on creation for the input_cost field.
|
||||||
usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64)
|
usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64)
|
||||||
// usagelogDescOutputCost is the schema descriptor for output_cost field.
|
// usagelogDescOutputCost is the schema descriptor for output_cost field.
|
||||||
usagelogDescOutputCost := usagelogFields[14].Descriptor()
|
usagelogDescOutputCost := usagelogFields[15].Descriptor()
|
||||||
// usagelog.DefaultOutputCost holds the default value on creation for the output_cost field.
|
// usagelog.DefaultOutputCost holds the default value on creation for the output_cost field.
|
||||||
usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64)
|
usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64)
|
||||||
// usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field.
|
// usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field.
|
||||||
usagelogDescCacheCreationCost := usagelogFields[15].Descriptor()
|
usagelogDescCacheCreationCost := usagelogFields[16].Descriptor()
|
||||||
// usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field.
|
// usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field.
|
||||||
usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64)
|
usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64)
|
||||||
// usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field.
|
// usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field.
|
||||||
usagelogDescCacheReadCost := usagelogFields[16].Descriptor()
|
usagelogDescCacheReadCost := usagelogFields[17].Descriptor()
|
||||||
// usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field.
|
// usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field.
|
||||||
usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64)
|
usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64)
|
||||||
// usagelogDescTotalCost is the schema descriptor for total_cost field.
|
// usagelogDescTotalCost is the schema descriptor for total_cost field.
|
||||||
usagelogDescTotalCost := usagelogFields[17].Descriptor()
|
usagelogDescTotalCost := usagelogFields[18].Descriptor()
|
||||||
// usagelog.DefaultTotalCost holds the default value on creation for the total_cost field.
|
// usagelog.DefaultTotalCost holds the default value on creation for the total_cost field.
|
||||||
usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64)
|
usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64)
|
||||||
// usagelogDescActualCost is the schema descriptor for actual_cost field.
|
// usagelogDescActualCost is the schema descriptor for actual_cost field.
|
||||||
usagelogDescActualCost := usagelogFields[18].Descriptor()
|
usagelogDescActualCost := usagelogFields[19].Descriptor()
|
||||||
// usagelog.DefaultActualCost holds the default value on creation for the actual_cost field.
|
// usagelog.DefaultActualCost holds the default value on creation for the actual_cost field.
|
||||||
usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64)
|
usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64)
|
||||||
// usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field.
|
// usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field.
|
||||||
usagelogDescRateMultiplier := usagelogFields[19].Descriptor()
|
usagelogDescRateMultiplier := usagelogFields[20].Descriptor()
|
||||||
// usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
|
// usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
|
||||||
usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64)
|
usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64)
|
||||||
// usagelogDescBillingType is the schema descriptor for billing_type field.
|
// usagelogDescBillingType is the schema descriptor for billing_type field.
|
||||||
usagelogDescBillingType := usagelogFields[21].Descriptor()
|
usagelogDescBillingType := usagelogFields[22].Descriptor()
|
||||||
// usagelog.DefaultBillingType holds the default value on creation for the billing_type field.
|
// usagelog.DefaultBillingType holds the default value on creation for the billing_type field.
|
||||||
usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8)
|
usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8)
|
||||||
// usagelogDescStream is the schema descriptor for stream field.
|
// usagelogDescStream is the schema descriptor for stream field.
|
||||||
usagelogDescStream := usagelogFields[22].Descriptor()
|
usagelogDescStream := usagelogFields[23].Descriptor()
|
||||||
// usagelog.DefaultStream holds the default value on creation for the stream field.
|
// usagelog.DefaultStream holds the default value on creation for the stream field.
|
||||||
usagelog.DefaultStream = usagelogDescStream.Default.(bool)
|
usagelog.DefaultStream = usagelogDescStream.Default.(bool)
|
||||||
// usagelogDescUserAgent is the schema descriptor for user_agent field.
|
// usagelogDescUserAgent is the schema descriptor for user_agent field.
|
||||||
usagelogDescUserAgent := usagelogFields[25].Descriptor()
|
usagelogDescUserAgent := usagelogFields[26].Descriptor()
|
||||||
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
|
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
|
||||||
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
|
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
|
||||||
// usagelogDescIPAddress is the schema descriptor for ip_address field.
|
// usagelogDescIPAddress is the schema descriptor for ip_address field.
|
||||||
usagelogDescIPAddress := usagelogFields[26].Descriptor()
|
usagelogDescIPAddress := usagelogFields[27].Descriptor()
|
||||||
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
|
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
|
||||||
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
|
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
|
||||||
// usagelogDescImageCount is the schema descriptor for image_count field.
|
// usagelogDescImageCount is the schema descriptor for image_count field.
|
||||||
usagelogDescImageCount := usagelogFields[27].Descriptor()
|
usagelogDescImageCount := usagelogFields[28].Descriptor()
|
||||||
// usagelog.DefaultImageCount holds the default value on creation for the image_count field.
|
// usagelog.DefaultImageCount holds the default value on creation for the image_count field.
|
||||||
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
|
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
|
||||||
// usagelogDescImageSize is the schema descriptor for image_size field.
|
// usagelogDescImageSize is the schema descriptor for image_size field.
|
||||||
usagelogDescImageSize := usagelogFields[28].Descriptor()
|
usagelogDescImageSize := usagelogFields[29].Descriptor()
|
||||||
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
||||||
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
|
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
|
||||||
// usagelogDescMediaType is the schema descriptor for media_type field.
|
// usagelogDescMediaType is the schema descriptor for media_type field.
|
||||||
usagelogDescMediaType := usagelogFields[29].Descriptor()
|
usagelogDescMediaType := usagelogFields[30].Descriptor()
|
||||||
// usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
|
// usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
|
||||||
usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
|
usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
|
||||||
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
|
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
|
||||||
usagelogDescCacheTTLOverridden := usagelogFields[30].Descriptor()
|
usagelogDescCacheTTLOverridden := usagelogFields[31].Descriptor()
|
||||||
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
|
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
|
||||||
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
|
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
|
||||||
// usagelogDescCreatedAt is the schema descriptor for created_at field.
|
// usagelogDescCreatedAt is the schema descriptor for created_at field.
|
||||||
usagelogDescCreatedAt := usagelogFields[31].Descriptor()
|
usagelogDescCreatedAt := usagelogFields[32].Descriptor()
|
||||||
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
|
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||||
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
|
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
|
||||||
userMixin := schema.User{}.Mixin()
|
userMixin := schema.User{}.Mixin()
|
||||||
|
|||||||
@@ -33,6 +33,8 @@ func (Group) Mixin() []ent.Mixin {
|
|||||||
|
|
||||||
func (Group) Fields() []ent.Field {
|
func (Group) Fields() []ent.Field {
|
||||||
return []ent.Field{
|
return []ent.Field{
|
||||||
|
// 唯一约束通过部分索引实现(WHERE deleted_at IS NULL),支持软删除后重用
|
||||||
|
// 见迁移文件 016_soft_delete_partial_unique_indexes.sql
|
||||||
field.String("name").
|
field.String("name").
|
||||||
MaxLen(100).
|
MaxLen(100).
|
||||||
NotEmpty(),
|
NotEmpty(),
|
||||||
@@ -49,6 +51,7 @@ func (Group) Fields() []ent.Field {
|
|||||||
MaxLen(20).
|
MaxLen(20).
|
||||||
Default(domain.StatusActive),
|
Default(domain.StatusActive),
|
||||||
|
|
||||||
|
// Subscription-related fields (added by migration 003)
|
||||||
field.String("platform").
|
field.String("platform").
|
||||||
MaxLen(50).
|
MaxLen(50).
|
||||||
Default(domain.PlatformAnthropic),
|
Default(domain.PlatformAnthropic),
|
||||||
@@ -70,6 +73,7 @@ func (Group) Fields() []ent.Field {
|
|||||||
field.Int("default_validity_days").
|
field.Int("default_validity_days").
|
||||||
Default(30),
|
Default(30),
|
||||||
|
|
||||||
|
// 图片生成计费配置(antigravity 和 gemini 平台使用)
|
||||||
field.Float("image_price_1k").
|
field.Float("image_price_1k").
|
||||||
Optional().
|
Optional().
|
||||||
Nillable().
|
Nillable().
|
||||||
@@ -83,6 +87,7 @@ func (Group) Fields() []ent.Field {
|
|||||||
Nillable().
|
Nillable().
|
||||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
|
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
|
||||||
|
|
||||||
|
// Sora 按次计费配置(阶段 1)
|
||||||
field.Float("sora_image_price_360").
|
field.Float("sora_image_price_360").
|
||||||
Optional().
|
Optional().
|
||||||
Nillable().
|
Nillable().
|
||||||
@@ -104,38 +109,45 @@ func (Group) Fields() []ent.Field {
|
|||||||
field.Int64("sora_storage_quota_bytes").
|
field.Int64("sora_storage_quota_bytes").
|
||||||
Default(0),
|
Default(0),
|
||||||
|
|
||||||
|
// Claude Code 客户端限制 (added by migration 029)
|
||||||
field.Bool("claude_code_only").
|
field.Bool("claude_code_only").
|
||||||
Default(false).
|
Default(false).
|
||||||
Comment("allow Claude Code client only"),
|
Comment("是否仅允许 Claude Code 客户端"),
|
||||||
field.Int64("fallback_group_id").
|
field.Int64("fallback_group_id").
|
||||||
Optional().
|
Optional().
|
||||||
Nillable().
|
Nillable().
|
||||||
Comment("fallback group for non-Claude-Code requests"),
|
Comment("非 Claude Code 请求降级使用的分组 ID"),
|
||||||
field.Int64("fallback_group_id_on_invalid_request").
|
field.Int64("fallback_group_id_on_invalid_request").
|
||||||
Optional().
|
Optional().
|
||||||
Nillable().
|
Nillable().
|
||||||
Comment("fallback group for invalid request"),
|
Comment("无效请求兜底使用的分组 ID"),
|
||||||
|
|
||||||
|
// 模型路由配置 (added by migration 040)
|
||||||
field.JSON("model_routing", map[string][]int64{}).
|
field.JSON("model_routing", map[string][]int64{}).
|
||||||
Optional().
|
Optional().
|
||||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
||||||
Comment("model routing config: pattern -> account ids"),
|
Comment("模型路由配置:模型模式 -> 优先账号ID列表"),
|
||||||
|
|
||||||
|
// 模型路由开关 (added by migration 041)
|
||||||
field.Bool("model_routing_enabled").
|
field.Bool("model_routing_enabled").
|
||||||
Default(false).
|
Default(false).
|
||||||
Comment("whether model routing is enabled"),
|
Comment("是否启用模型路由配置"),
|
||||||
|
|
||||||
|
// MCP XML 协议注入开关 (added by migration 042)
|
||||||
field.Bool("mcp_xml_inject").
|
field.Bool("mcp_xml_inject").
|
||||||
Default(true).
|
Default(true).
|
||||||
Comment("whether MCP XML prompt injection is enabled"),
|
Comment("是否注入 MCP XML 调用协议提示词(仅 antigravity 平台)"),
|
||||||
|
|
||||||
|
// 支持的模型系列 (added by migration 046)
|
||||||
field.JSON("supported_model_scopes", []string{}).
|
field.JSON("supported_model_scopes", []string{}).
|
||||||
Default([]string{"claude", "gemini_text", "gemini_image"}).
|
Default([]string{"claude", "gemini_text", "gemini_image"}).
|
||||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
||||||
Comment("supported model scopes: claude, gemini_text, gemini_image"),
|
Comment("支持的模型系列:claude, gemini_text, gemini_image"),
|
||||||
|
|
||||||
|
// 分组排序 (added by migration 052)
|
||||||
field.Int("sort_order").
|
field.Int("sort_order").
|
||||||
Default(0).
|
Default(0).
|
||||||
Comment("group display order, lower comes first"),
|
Comment("分组显示排序,数值越小越靠前"),
|
||||||
|
|
||||||
// OpenAI Messages 调度配置 (added by migration 069)
|
// OpenAI Messages 调度配置 (added by migration 069)
|
||||||
field.Bool("allow_messages_dispatch").
|
field.Bool("allow_messages_dispatch").
|
||||||
@@ -145,9 +157,6 @@ func (Group) Fields() []ent.Field {
|
|||||||
MaxLen(100).
|
MaxLen(100).
|
||||||
Default("").
|
Default("").
|
||||||
Comment("默认映射模型 ID,当账号级映射找不到时使用此值"),
|
Comment("默认映射模型 ID,当账号级映射找不到时使用此值"),
|
||||||
field.Bool("simulate_claude_max_enabled").
|
|
||||||
Default(false).
|
|
||||||
Comment("simulate claude usage as claude-max style (1h cache write)"),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -163,11 +172,14 @@ func (Group) Edges() []ent.Edge {
|
|||||||
edge.From("allowed_users", User.Type).
|
edge.From("allowed_users", User.Type).
|
||||||
Ref("allowed_groups").
|
Ref("allowed_groups").
|
||||||
Through("user_allowed_groups", UserAllowedGroup.Type),
|
Through("user_allowed_groups", UserAllowedGroup.Type),
|
||||||
|
// 注意:fallback_group_id 直接作为字段使用,不定义 edge
|
||||||
|
// 这样允许多个分组指向同一个降级分组(M2O 关系)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (Group) Indexes() []ent.Index {
|
func (Group) Indexes() []ent.Index {
|
||||||
return []ent.Index{
|
return []ent.Index{
|
||||||
|
// name 字段已在 Fields() 中声明 Unique(),无需重复索引
|
||||||
index.Fields("status"),
|
index.Fields("status"),
|
||||||
index.Fields("platform"),
|
index.Fields("platform"),
|
||||||
index.Fields("subscription_type"),
|
index.Fields("subscription_type"),
|
||||||
|
|||||||
@@ -41,6 +41,12 @@ func (UsageLog) Fields() []ent.Field {
|
|||||||
field.String("model").
|
field.String("model").
|
||||||
MaxLen(100).
|
MaxLen(100).
|
||||||
NotEmpty(),
|
NotEmpty(),
|
||||||
|
// UpstreamModel stores the actual upstream model name when model mapping
|
||||||
|
// is applied. NULL means no mapping — the requested model was used as-is.
|
||||||
|
field.String("upstream_model").
|
||||||
|
MaxLen(100).
|
||||||
|
Optional().
|
||||||
|
Nillable(),
|
||||||
field.Int64("group_id").
|
field.Int64("group_id").
|
||||||
Optional().
|
Optional().
|
||||||
Nillable(),
|
Nillable(),
|
||||||
|
|||||||
@@ -32,6 +32,8 @@ type UsageLog struct {
|
|||||||
RequestID string `json:"request_id,omitempty"`
|
RequestID string `json:"request_id,omitempty"`
|
||||||
// Model holds the value of the "model" field.
|
// Model holds the value of the "model" field.
|
||||||
Model string `json:"model,omitempty"`
|
Model string `json:"model,omitempty"`
|
||||||
|
// UpstreamModel holds the value of the "upstream_model" field.
|
||||||
|
UpstreamModel *string `json:"upstream_model,omitempty"`
|
||||||
// GroupID holds the value of the "group_id" field.
|
// GroupID holds the value of the "group_id" field.
|
||||||
GroupID *int64 `json:"group_id,omitempty"`
|
GroupID *int64 `json:"group_id,omitempty"`
|
||||||
// SubscriptionID holds the value of the "subscription_id" field.
|
// SubscriptionID holds the value of the "subscription_id" field.
|
||||||
@@ -175,7 +177,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
|
|||||||
values[i] = new(sql.NullFloat64)
|
values[i] = new(sql.NullFloat64)
|
||||||
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
|
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
|
||||||
values[i] = new(sql.NullInt64)
|
values[i] = new(sql.NullInt64)
|
||||||
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
|
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUpstreamModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
|
||||||
values[i] = new(sql.NullString)
|
values[i] = new(sql.NullString)
|
||||||
case usagelog.FieldCreatedAt:
|
case usagelog.FieldCreatedAt:
|
||||||
values[i] = new(sql.NullTime)
|
values[i] = new(sql.NullTime)
|
||||||
@@ -230,6 +232,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
|
|||||||
} else if value.Valid {
|
} else if value.Valid {
|
||||||
_m.Model = value.String
|
_m.Model = value.String
|
||||||
}
|
}
|
||||||
|
case usagelog.FieldUpstreamModel:
|
||||||
|
if value, ok := values[i].(*sql.NullString); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field upstream_model", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.UpstreamModel = new(string)
|
||||||
|
*_m.UpstreamModel = value.String
|
||||||
|
}
|
||||||
case usagelog.FieldGroupID:
|
case usagelog.FieldGroupID:
|
||||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||||
return fmt.Errorf("unexpected type %T for field group_id", values[i])
|
return fmt.Errorf("unexpected type %T for field group_id", values[i])
|
||||||
@@ -477,6 +486,11 @@ func (_m *UsageLog) String() string {
|
|||||||
builder.WriteString("model=")
|
builder.WriteString("model=")
|
||||||
builder.WriteString(_m.Model)
|
builder.WriteString(_m.Model)
|
||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
|
if v := _m.UpstreamModel; v != nil {
|
||||||
|
builder.WriteString("upstream_model=")
|
||||||
|
builder.WriteString(*v)
|
||||||
|
}
|
||||||
|
builder.WriteString(", ")
|
||||||
if v := _m.GroupID; v != nil {
|
if v := _m.GroupID; v != nil {
|
||||||
builder.WriteString("group_id=")
|
builder.WriteString("group_id=")
|
||||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ const (
|
|||||||
FieldRequestID = "request_id"
|
FieldRequestID = "request_id"
|
||||||
// FieldModel holds the string denoting the model field in the database.
|
// FieldModel holds the string denoting the model field in the database.
|
||||||
FieldModel = "model"
|
FieldModel = "model"
|
||||||
|
// FieldUpstreamModel holds the string denoting the upstream_model field in the database.
|
||||||
|
FieldUpstreamModel = "upstream_model"
|
||||||
// FieldGroupID holds the string denoting the group_id field in the database.
|
// FieldGroupID holds the string denoting the group_id field in the database.
|
||||||
FieldGroupID = "group_id"
|
FieldGroupID = "group_id"
|
||||||
// FieldSubscriptionID holds the string denoting the subscription_id field in the database.
|
// FieldSubscriptionID holds the string denoting the subscription_id field in the database.
|
||||||
@@ -135,6 +137,7 @@ var Columns = []string{
|
|||||||
FieldAccountID,
|
FieldAccountID,
|
||||||
FieldRequestID,
|
FieldRequestID,
|
||||||
FieldModel,
|
FieldModel,
|
||||||
|
FieldUpstreamModel,
|
||||||
FieldGroupID,
|
FieldGroupID,
|
||||||
FieldSubscriptionID,
|
FieldSubscriptionID,
|
||||||
FieldInputTokens,
|
FieldInputTokens,
|
||||||
@@ -179,6 +182,8 @@ var (
|
|||||||
RequestIDValidator func(string) error
|
RequestIDValidator func(string) error
|
||||||
// ModelValidator is a validator for the "model" field. It is called by the builders before save.
|
// ModelValidator is a validator for the "model" field. It is called by the builders before save.
|
||||||
ModelValidator func(string) error
|
ModelValidator func(string) error
|
||||||
|
// UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
|
||||||
|
UpstreamModelValidator func(string) error
|
||||||
// DefaultInputTokens holds the default value on creation for the "input_tokens" field.
|
// DefaultInputTokens holds the default value on creation for the "input_tokens" field.
|
||||||
DefaultInputTokens int
|
DefaultInputTokens int
|
||||||
// DefaultOutputTokens holds the default value on creation for the "output_tokens" field.
|
// DefaultOutputTokens holds the default value on creation for the "output_tokens" field.
|
||||||
@@ -258,6 +263,11 @@ func ByModel(opts ...sql.OrderTermOption) OrderOption {
|
|||||||
return sql.OrderByField(FieldModel, opts...).ToFunc()
|
return sql.OrderByField(FieldModel, opts...).ToFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ByUpstreamModel orders the results by the upstream_model field.
|
||||||
|
func ByUpstreamModel(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldUpstreamModel, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
// ByGroupID orders the results by the group_id field.
|
// ByGroupID orders the results by the group_id field.
|
||||||
func ByGroupID(opts ...sql.OrderTermOption) OrderOption {
|
func ByGroupID(opts ...sql.OrderTermOption) OrderOption {
|
||||||
return sql.OrderByField(FieldGroupID, opts...).ToFunc()
|
return sql.OrderByField(FieldGroupID, opts...).ToFunc()
|
||||||
|
|||||||
@@ -80,6 +80,11 @@ func Model(v string) predicate.UsageLog {
|
|||||||
return predicate.UsageLog(sql.FieldEQ(FieldModel, v))
|
return predicate.UsageLog(sql.FieldEQ(FieldModel, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpstreamModel applies equality check predicate on the "upstream_model" field. It's identical to UpstreamModelEQ.
|
||||||
|
func UpstreamModel(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
// GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ.
|
// GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ.
|
||||||
func GroupID(v int64) predicate.UsageLog {
|
func GroupID(v int64) predicate.UsageLog {
|
||||||
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))
|
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))
|
||||||
@@ -405,6 +410,81 @@ func ModelContainsFold(v string) predicate.UsageLog {
|
|||||||
return predicate.UsageLog(sql.FieldContainsFold(FieldModel, v))
|
return predicate.UsageLog(sql.FieldContainsFold(FieldModel, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpstreamModelEQ applies the EQ predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelEQ(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelNEQ applies the NEQ predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelNEQ(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldNEQ(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelIn applies the In predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelIn(vs ...string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldIn(FieldUpstreamModel, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelNotIn applies the NotIn predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelNotIn(vs ...string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldNotIn(FieldUpstreamModel, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelGT applies the GT predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelGT(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldGT(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelGTE applies the GTE predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelGTE(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldGTE(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelLT applies the LT predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelLT(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldLT(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelLTE applies the LTE predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelLTE(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldLTE(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelContains applies the Contains predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelContains(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldContains(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelHasPrefix applies the HasPrefix predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelHasPrefix(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldHasPrefix(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelHasSuffix applies the HasSuffix predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelHasSuffix(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldHasSuffix(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelIsNil applies the IsNil predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelIsNil() predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldIsNull(FieldUpstreamModel))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelNotNil applies the NotNil predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelNotNil() predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldNotNull(FieldUpstreamModel))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelEqualFold applies the EqualFold predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelEqualFold(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldEqualFold(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamModelContainsFold applies the ContainsFold predicate on the "upstream_model" field.
|
||||||
|
func UpstreamModelContainsFold(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldContainsFold(FieldUpstreamModel, v))
|
||||||
|
}
|
||||||
|
|
||||||
// GroupIDEQ applies the EQ predicate on the "group_id" field.
|
// GroupIDEQ applies the EQ predicate on the "group_id" field.
|
||||||
func GroupIDEQ(v int64) predicate.UsageLog {
|
func GroupIDEQ(v int64) predicate.UsageLog {
|
||||||
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))
|
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))
|
||||||
|
|||||||
@@ -57,6 +57,20 @@ func (_c *UsageLogCreate) SetModel(v string) *UsageLogCreate {
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetUpstreamModel sets the "upstream_model" field.
|
||||||
|
func (_c *UsageLogCreate) SetUpstreamModel(v string) *UsageLogCreate {
|
||||||
|
_c.mutation.SetUpstreamModel(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil.
|
||||||
|
func (_c *UsageLogCreate) SetNillableUpstreamModel(v *string) *UsageLogCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetUpstreamModel(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
// SetGroupID sets the "group_id" field.
|
// SetGroupID sets the "group_id" field.
|
||||||
func (_c *UsageLogCreate) SetGroupID(v int64) *UsageLogCreate {
|
func (_c *UsageLogCreate) SetGroupID(v int64) *UsageLogCreate {
|
||||||
_c.mutation.SetGroupID(v)
|
_c.mutation.SetGroupID(v)
|
||||||
@@ -596,6 +610,11 @@ func (_c *UsageLogCreate) check() error {
|
|||||||
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
|
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if v, ok := _c.mutation.UpstreamModel(); ok {
|
||||||
|
if err := usagelog.UpstreamModelValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
if _, ok := _c.mutation.InputTokens(); !ok {
|
if _, ok := _c.mutation.InputTokens(); !ok {
|
||||||
return &ValidationError{Name: "input_tokens", err: errors.New(`ent: missing required field "UsageLog.input_tokens"`)}
|
return &ValidationError{Name: "input_tokens", err: errors.New(`ent: missing required field "UsageLog.input_tokens"`)}
|
||||||
}
|
}
|
||||||
@@ -714,6 +733,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
|
|||||||
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
|
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
|
||||||
_node.Model = value
|
_node.Model = value
|
||||||
}
|
}
|
||||||
|
if value, ok := _c.mutation.UpstreamModel(); ok {
|
||||||
|
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
|
||||||
|
_node.UpstreamModel = &value
|
||||||
|
}
|
||||||
if value, ok := _c.mutation.InputTokens(); ok {
|
if value, ok := _c.mutation.InputTokens(); ok {
|
||||||
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
|
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
|
||||||
_node.InputTokens = value
|
_node.InputTokens = value
|
||||||
@@ -1011,6 +1034,24 @@ func (u *UsageLogUpsert) UpdateModel() *UsageLogUpsert {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetUpstreamModel sets the "upstream_model" field.
|
||||||
|
func (u *UsageLogUpsert) SetUpstreamModel(v string) *UsageLogUpsert {
|
||||||
|
u.Set(usagelog.FieldUpstreamModel, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create.
|
||||||
|
func (u *UsageLogUpsert) UpdateUpstreamModel() *UsageLogUpsert {
|
||||||
|
u.SetExcluded(usagelog.FieldUpstreamModel)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearUpstreamModel clears the value of the "upstream_model" field.
|
||||||
|
func (u *UsageLogUpsert) ClearUpstreamModel() *UsageLogUpsert {
|
||||||
|
u.SetNull(usagelog.FieldUpstreamModel)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
// SetGroupID sets the "group_id" field.
|
// SetGroupID sets the "group_id" field.
|
||||||
func (u *UsageLogUpsert) SetGroupID(v int64) *UsageLogUpsert {
|
func (u *UsageLogUpsert) SetGroupID(v int64) *UsageLogUpsert {
|
||||||
u.Set(usagelog.FieldGroupID, v)
|
u.Set(usagelog.FieldGroupID, v)
|
||||||
@@ -1600,6 +1641,27 @@ func (u *UsageLogUpsertOne) UpdateModel() *UsageLogUpsertOne {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetUpstreamModel sets the "upstream_model" field.
|
||||||
|
func (u *UsageLogUpsertOne) SetUpstreamModel(v string) *UsageLogUpsertOne {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.SetUpstreamModel(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create.
|
||||||
|
func (u *UsageLogUpsertOne) UpdateUpstreamModel() *UsageLogUpsertOne {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.UpdateUpstreamModel()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearUpstreamModel clears the value of the "upstream_model" field.
|
||||||
|
func (u *UsageLogUpsertOne) ClearUpstreamModel() *UsageLogUpsertOne {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.ClearUpstreamModel()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// SetGroupID sets the "group_id" field.
|
// SetGroupID sets the "group_id" field.
|
||||||
func (u *UsageLogUpsertOne) SetGroupID(v int64) *UsageLogUpsertOne {
|
func (u *UsageLogUpsertOne) SetGroupID(v int64) *UsageLogUpsertOne {
|
||||||
return u.Update(func(s *UsageLogUpsert) {
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
@@ -2434,6 +2496,27 @@ func (u *UsageLogUpsertBulk) UpdateModel() *UsageLogUpsertBulk {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetUpstreamModel sets the "upstream_model" field.
|
||||||
|
func (u *UsageLogUpsertBulk) SetUpstreamModel(v string) *UsageLogUpsertBulk {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.SetUpstreamModel(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create.
|
||||||
|
func (u *UsageLogUpsertBulk) UpdateUpstreamModel() *UsageLogUpsertBulk {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.UpdateUpstreamModel()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearUpstreamModel clears the value of the "upstream_model" field.
|
||||||
|
func (u *UsageLogUpsertBulk) ClearUpstreamModel() *UsageLogUpsertBulk {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.ClearUpstreamModel()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// SetGroupID sets the "group_id" field.
|
// SetGroupID sets the "group_id" field.
|
||||||
func (u *UsageLogUpsertBulk) SetGroupID(v int64) *UsageLogUpsertBulk {
|
func (u *UsageLogUpsertBulk) SetGroupID(v int64) *UsageLogUpsertBulk {
|
||||||
return u.Update(func(s *UsageLogUpsert) {
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
|||||||
@@ -102,6 +102,26 @@ func (_u *UsageLogUpdate) SetNillableModel(v *string) *UsageLogUpdate {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetUpstreamModel sets the "upstream_model" field.
|
||||||
|
func (_u *UsageLogUpdate) SetUpstreamModel(v string) *UsageLogUpdate {
|
||||||
|
_u.mutation.SetUpstreamModel(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil.
|
||||||
|
func (_u *UsageLogUpdate) SetNillableUpstreamModel(v *string) *UsageLogUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetUpstreamModel(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearUpstreamModel clears the value of the "upstream_model" field.
|
||||||
|
func (_u *UsageLogUpdate) ClearUpstreamModel() *UsageLogUpdate {
|
||||||
|
_u.mutation.ClearUpstreamModel()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetGroupID sets the "group_id" field.
|
// SetGroupID sets the "group_id" field.
|
||||||
func (_u *UsageLogUpdate) SetGroupID(v int64) *UsageLogUpdate {
|
func (_u *UsageLogUpdate) SetGroupID(v int64) *UsageLogUpdate {
|
||||||
_u.mutation.SetGroupID(v)
|
_u.mutation.SetGroupID(v)
|
||||||
@@ -745,6 +765,11 @@ func (_u *UsageLogUpdate) check() error {
|
|||||||
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
|
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if v, ok := _u.mutation.UpstreamModel(); ok {
|
||||||
|
if err := usagelog.UpstreamModelValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
if v, ok := _u.mutation.UserAgent(); ok {
|
if v, ok := _u.mutation.UserAgent(); ok {
|
||||||
if err := usagelog.UserAgentValidator(v); err != nil {
|
if err := usagelog.UserAgentValidator(v); err != nil {
|
||||||
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
||||||
@@ -795,6 +820,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
if value, ok := _u.mutation.Model(); ok {
|
if value, ok := _u.mutation.Model(); ok {
|
||||||
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
|
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.UpstreamModel(); ok {
|
||||||
|
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.UpstreamModelCleared() {
|
||||||
|
_spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString)
|
||||||
|
}
|
||||||
if value, ok := _u.mutation.InputTokens(); ok {
|
if value, ok := _u.mutation.InputTokens(); ok {
|
||||||
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
|
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
|
||||||
}
|
}
|
||||||
@@ -1177,6 +1208,26 @@ func (_u *UsageLogUpdateOne) SetNillableModel(v *string) *UsageLogUpdateOne {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetUpstreamModel sets the "upstream_model" field.
|
||||||
|
func (_u *UsageLogUpdateOne) SetUpstreamModel(v string) *UsageLogUpdateOne {
|
||||||
|
_u.mutation.SetUpstreamModel(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil.
|
||||||
|
func (_u *UsageLogUpdateOne) SetNillableUpstreamModel(v *string) *UsageLogUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetUpstreamModel(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearUpstreamModel clears the value of the "upstream_model" field.
|
||||||
|
func (_u *UsageLogUpdateOne) ClearUpstreamModel() *UsageLogUpdateOne {
|
||||||
|
_u.mutation.ClearUpstreamModel()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetGroupID sets the "group_id" field.
|
// SetGroupID sets the "group_id" field.
|
||||||
func (_u *UsageLogUpdateOne) SetGroupID(v int64) *UsageLogUpdateOne {
|
func (_u *UsageLogUpdateOne) SetGroupID(v int64) *UsageLogUpdateOne {
|
||||||
_u.mutation.SetGroupID(v)
|
_u.mutation.SetGroupID(v)
|
||||||
@@ -1833,6 +1884,11 @@ func (_u *UsageLogUpdateOne) check() error {
|
|||||||
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
|
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if v, ok := _u.mutation.UpstreamModel(); ok {
|
||||||
|
if err := usagelog.UpstreamModelValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
if v, ok := _u.mutation.UserAgent(); ok {
|
if v, ok := _u.mutation.UserAgent(); ok {
|
||||||
if err := usagelog.UserAgentValidator(v); err != nil {
|
if err := usagelog.UserAgentValidator(v); err != nil {
|
||||||
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
||||||
@@ -1900,6 +1956,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
|
|||||||
if value, ok := _u.mutation.Model(); ok {
|
if value, ok := _u.mutation.Model(); ok {
|
||||||
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
|
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.UpstreamModel(); ok {
|
||||||
|
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.UpstreamModelCleared() {
|
||||||
|
_spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString)
|
||||||
|
}
|
||||||
if value, ok := _u.mutation.InputTokens(); ok {
|
if value, ok := _u.mutation.InputTokens(); ok {
|
||||||
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
|
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ require (
|
|||||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||||
github.com/DouDOU-start/go-sora2api v1.1.0
|
github.com/DouDOU-start/go-sora2api v1.1.0
|
||||||
github.com/alitto/pond/v2 v2.6.2
|
github.com/alitto/pond/v2 v2.6.2
|
||||||
github.com/aws/aws-sdk-go-v2 v1.41.2
|
github.com/aws/aws-sdk-go-v2 v1.41.3
|
||||||
github.com/aws/aws-sdk-go-v2/config v1.32.10
|
github.com/aws/aws-sdk-go-v2/config v1.32.10
|
||||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
|
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
|
||||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
|
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
|
||||||
@@ -66,7 +66,7 @@ require (
|
|||||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect
|
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect
|
||||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect
|
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect
|
||||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect
|
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect
|
||||||
github.com/aws/smithy-go v1.24.1 // indirect
|
github.com/aws/smithy-go v1.24.2 // indirect
|
||||||
github.com/bdandy/go-errors v1.2.2 // indirect
|
github.com/bdandy/go-errors v1.2.2 // indirect
|
||||||
github.com/bdandy/go-socks4 v1.2.3 // indirect
|
github.com/bdandy/go-socks4 v1.2.3 // indirect
|
||||||
github.com/bmatcuk/doublestar v1.3.4 // indirect
|
github.com/bmatcuk/doublestar v1.3.4 // indirect
|
||||||
@@ -87,7 +87,6 @@ require (
|
|||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
github.com/distribution/reference v0.6.0 // indirect
|
github.com/distribution/reference v0.6.0 // indirect
|
||||||
github.com/dlclark/regexp2 v1.10.0 // indirect
|
|
||||||
github.com/docker/docker v28.5.1+incompatible // indirect
|
github.com/docker/docker v28.5.1+incompatible // indirect
|
||||||
github.com/docker/go-connections v0.6.0 // indirect
|
github.com/docker/go-connections v0.6.0 // indirect
|
||||||
github.com/docker/go-units v0.5.0 // indirect
|
github.com/docker/go-units v0.5.0 // indirect
|
||||||
@@ -138,8 +137,6 @@ require (
|
|||||||
github.com/opencontainers/image-spec v1.1.1 // indirect
|
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||||
github.com/pkg/errors v0.9.1 // indirect
|
github.com/pkg/errors v0.9.1 // indirect
|
||||||
github.com/pkoukk/tiktoken-go v0.1.8 // indirect
|
|
||||||
github.com/pkoukk/tiktoken-go-loader v0.0.2 // indirect
|
|
||||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||||
github.com/quic-go/qpack v0.6.0 // indirect
|
github.com/quic-go/qpack v0.6.0 // indirect
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo
|
|||||||
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
||||||
github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=
|
github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=
|
||||||
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
|
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
|
||||||
github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls=
|
github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA=
|
||||||
github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4=
|
github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
|
||||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
|
||||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c=
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c=
|
||||||
github.com/aws/aws-sdk-go-v2/config v1.32.10 h1:9DMthfO6XWZYLfzZglAgW5Fyou2nRI5CuV44sTedKBI=
|
github.com/aws/aws-sdk-go-v2/config v1.32.10 h1:9DMthfO6XWZYLfzZglAgW5Fyou2nRI5CuV44sTedKBI=
|
||||||
@@ -58,8 +58,8 @@ github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 h1:edCcNp9eGIUDUCrzoCu1jWA
|
|||||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15/go.mod h1:lyRQKED9xWfgkYC/wmmYfv7iVIM68Z5OQ88ZdcV1QbU=
|
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15/go.mod h1:lyRQKED9xWfgkYC/wmmYfv7iVIM68Z5OQ88ZdcV1QbU=
|
||||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb83BbyggcUBVksN7c=
|
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb83BbyggcUBVksN7c=
|
||||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
|
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
|
||||||
github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0=
|
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
|
||||||
github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
|
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||||
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
|
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
|
||||||
github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM=
|
github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM=
|
||||||
github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic=
|
github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic=
|
||||||
@@ -94,10 +94,6 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL
|
|||||||
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
||||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
||||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
||||||
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
|
|
||||||
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
|
|
||||||
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
|
|
||||||
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
|
|
||||||
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
|
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
|
||||||
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
||||||
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
|
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
|
||||||
@@ -124,8 +120,6 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
|
|||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||||
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||||
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
|
|
||||||
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
|
||||||
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
|
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
|
||||||
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||||
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
|
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
|
||||||
@@ -182,7 +176,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
|
|||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||||
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
|
|
||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
|
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
|
||||||
@@ -202,8 +195,6 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
|
|||||||
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
|
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
|
||||||
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
|
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
|
||||||
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
|
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
|
||||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
|
||||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
|
||||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||||
@@ -239,8 +230,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
|
|||||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
|
|
||||||
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
|
||||||
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
||||||
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||||
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
||||||
@@ -274,8 +263,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
|||||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
|
||||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
|
||||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||||
@@ -286,10 +273,6 @@ github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6
|
|||||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo=
|
|
||||||
github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
|
|
||||||
github.com/pkoukk/tiktoken-go-loader v0.0.2 h1:LUKws63GV3pVHwH1srkBplBv+7URgmOmhSkRxsIvsK4=
|
|
||||||
github.com/pkoukk/tiktoken-go-loader v0.0.2/go.mod h1:4mIkYyZooFlnenDlormIo6cd5wrlUKNr97wp9nGgEKo=
|
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
@@ -331,8 +314,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
|||||||
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
||||||
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
||||||
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||||
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
|
|
||||||
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
|
|
||||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||||
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
|
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
|
||||||
@@ -342,8 +323,6 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS
|
|||||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
|
||||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
|
||||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
@@ -352,6 +331,8 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
|
|||||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||||
github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 h1:YqAladjX7xpA6BM04leXMWAEjS0mTZ5kUU9KRBriQJc=
|
github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 h1:YqAladjX7xpA6BM04leXMWAEjS0mTZ5kUU9KRBriQJc=
|
||||||
@@ -441,11 +422,11 @@ golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||||||
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
|
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
|
||||||
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
|
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
|
||||||
|
|||||||
@@ -934,9 +934,10 @@ type DashboardAggregationConfig struct {
|
|||||||
|
|
||||||
// DashboardAggregationRetentionConfig 预聚合保留窗口
|
// DashboardAggregationRetentionConfig 预聚合保留窗口
|
||||||
type DashboardAggregationRetentionConfig struct {
|
type DashboardAggregationRetentionConfig struct {
|
||||||
UsageLogsDays int `mapstructure:"usage_logs_days"`
|
UsageLogsDays int `mapstructure:"usage_logs_days"`
|
||||||
HourlyDays int `mapstructure:"hourly_days"`
|
UsageBillingDedupDays int `mapstructure:"usage_billing_dedup_days"`
|
||||||
DailyDays int `mapstructure:"daily_days"`
|
HourlyDays int `mapstructure:"hourly_days"`
|
||||||
|
DailyDays int `mapstructure:"daily_days"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UsageCleanupConfig 使用记录清理任务配置
|
// UsageCleanupConfig 使用记录清理任务配置
|
||||||
@@ -1301,6 +1302,7 @@ func setDefaults() {
|
|||||||
viper.SetDefault("dashboard_aggregation.backfill_enabled", false)
|
viper.SetDefault("dashboard_aggregation.backfill_enabled", false)
|
||||||
viper.SetDefault("dashboard_aggregation.backfill_max_days", 31)
|
viper.SetDefault("dashboard_aggregation.backfill_max_days", 31)
|
||||||
viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90)
|
viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90)
|
||||||
|
viper.SetDefault("dashboard_aggregation.retention.usage_billing_dedup_days", 365)
|
||||||
viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180)
|
viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180)
|
||||||
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
|
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
|
||||||
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
|
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
|
||||||
@@ -1758,6 +1760,12 @@ func (c *Config) Validate() error {
|
|||||||
if c.DashboardAgg.Retention.UsageLogsDays <= 0 {
|
if c.DashboardAgg.Retention.UsageLogsDays <= 0 {
|
||||||
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive")
|
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive")
|
||||||
}
|
}
|
||||||
|
if c.DashboardAgg.Retention.UsageBillingDedupDays <= 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be positive")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
|
||||||
|
}
|
||||||
if c.DashboardAgg.Retention.HourlyDays <= 0 {
|
if c.DashboardAgg.Retention.HourlyDays <= 0 {
|
||||||
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive")
|
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive")
|
||||||
}
|
}
|
||||||
@@ -1780,6 +1788,14 @@ func (c *Config) Validate() error {
|
|||||||
if c.DashboardAgg.Retention.UsageLogsDays < 0 {
|
if c.DashboardAgg.Retention.UsageLogsDays < 0 {
|
||||||
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative")
|
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative")
|
||||||
}
|
}
|
||||||
|
if c.DashboardAgg.Retention.UsageBillingDedupDays < 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be non-negative")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.Retention.UsageBillingDedupDays > 0 &&
|
||||||
|
c.DashboardAgg.Retention.UsageLogsDays > 0 &&
|
||||||
|
c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
|
||||||
|
}
|
||||||
if c.DashboardAgg.Retention.HourlyDays < 0 {
|
if c.DashboardAgg.Retention.HourlyDays < 0 {
|
||||||
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative")
|
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -441,6 +441,9 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
|
|||||||
if cfg.DashboardAgg.Retention.UsageLogsDays != 90 {
|
if cfg.DashboardAgg.Retention.UsageLogsDays != 90 {
|
||||||
t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays)
|
t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays)
|
||||||
}
|
}
|
||||||
|
if cfg.DashboardAgg.Retention.UsageBillingDedupDays != 365 {
|
||||||
|
t.Fatalf("DashboardAgg.Retention.UsageBillingDedupDays = %d, want 365", cfg.DashboardAgg.Retention.UsageBillingDedupDays)
|
||||||
|
}
|
||||||
if cfg.DashboardAgg.Retention.HourlyDays != 180 {
|
if cfg.DashboardAgg.Retention.HourlyDays != 180 {
|
||||||
t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays)
|
t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays)
|
||||||
}
|
}
|
||||||
@@ -1016,6 +1019,23 @@ func TestValidateConfigErrors(t *testing.T) {
|
|||||||
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 },
|
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 },
|
||||||
wantErr: "dashboard_aggregation.retention.usage_logs_days",
|
wantErr: "dashboard_aggregation.retention.usage_logs_days",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "dashboard aggregation dedup retention",
|
||||||
|
mutate: func(c *Config) {
|
||||||
|
c.DashboardAgg.Enabled = true
|
||||||
|
c.DashboardAgg.Retention.UsageBillingDedupDays = 0
|
||||||
|
},
|
||||||
|
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dashboard aggregation dedup retention smaller than usage logs",
|
||||||
|
mutate: func(c *Config) {
|
||||||
|
c.DashboardAgg.Enabled = true
|
||||||
|
c.DashboardAgg.Retention.UsageLogsDays = 30
|
||||||
|
c.DashboardAgg.Retention.UsageBillingDedupDays = 29
|
||||||
|
},
|
||||||
|
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "dashboard aggregation disabled interval",
|
name: "dashboard aggregation disabled interval",
|
||||||
mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 },
|
mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 },
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ const (
|
|||||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||||
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||||
|
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
|
||||||
)
|
)
|
||||||
|
|
||||||
// Redeem type constants
|
// Redeem type constants
|
||||||
@@ -81,8 +82,8 @@ var DefaultAntigravityModelMapping = map[string]string{
|
|||||||
"claude-opus-4-5-20251101": "claude-opus-4-6-thinking", // 迁移旧模型
|
"claude-opus-4-5-20251101": "claude-opus-4-6-thinking", // 迁移旧模型
|
||||||
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
|
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
|
||||||
// Claude Haiku → Sonnet(无 Haiku 支持)
|
// Claude Haiku → Sonnet(无 Haiku 支持)
|
||||||
"claude-haiku-4-5": "claude-sonnet-4-5",
|
"claude-haiku-4-5": "claude-sonnet-4-6",
|
||||||
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
|
"claude-haiku-4-5-20251001": "claude-sonnet-4-6",
|
||||||
// Gemini 2.5 白名单
|
// Gemini 2.5 白名单
|
||||||
"gemini-2.5-flash": "gemini-2.5-flash",
|
"gemini-2.5-flash": "gemini-2.5-flash",
|
||||||
"gemini-2.5-flash-image": "gemini-2.5-flash-image",
|
"gemini-2.5-flash-image": "gemini-2.5-flash-image",
|
||||||
@@ -113,3 +114,27 @@ var DefaultAntigravityModelMapping = map[string]string{
|
|||||||
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
|
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
|
||||||
"tab_flash_lite_preview": "tab_flash_lite_preview",
|
"tab_flash_lite_preview": "tab_flash_lite_preview",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DefaultBedrockModelMapping 是 AWS Bedrock 平台的默认模型映射
|
||||||
|
// 将 Anthropic 标准模型名映射到 Bedrock 模型 ID
|
||||||
|
// 注意:此处的 "us." 前缀仅为默认值,ResolveBedrockModelID 会根据账号配置的
|
||||||
|
// aws_region 自动调整为匹配的区域前缀(如 eu.、apac.、jp. 等)
|
||||||
|
var DefaultBedrockModelMapping = map[string]string{
|
||||||
|
// Claude Opus
|
||||||
|
"claude-opus-4-6-thinking": "us.anthropic.claude-opus-4-6-v1",
|
||||||
|
"claude-opus-4-6": "us.anthropic.claude-opus-4-6-v1",
|
||||||
|
"claude-opus-4-5-thinking": "us.anthropic.claude-opus-4-5-20251101-v1:0",
|
||||||
|
"claude-opus-4-5-20251101": "us.anthropic.claude-opus-4-5-20251101-v1:0",
|
||||||
|
"claude-opus-4-1": "us.anthropic.claude-opus-4-1-20250805-v1:0",
|
||||||
|
"claude-opus-4-20250514": "us.anthropic.claude-opus-4-20250514-v1:0",
|
||||||
|
// Claude Sonnet
|
||||||
|
"claude-sonnet-4-6-thinking": "us.anthropic.claude-sonnet-4-6",
|
||||||
|
"claude-sonnet-4-6": "us.anthropic.claude-sonnet-4-6",
|
||||||
|
"claude-sonnet-4-5": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||||
|
"claude-sonnet-4-5-thinking": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||||
|
"claude-sonnet-4-5-20250929": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||||
|
"claude-sonnet-4-20250514": "us.anthropic.claude-sonnet-4-20250514-v1:0",
|
||||||
|
// Claude Haiku
|
||||||
|
"claude-haiku-4-5": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||||
|
"claude-haiku-4-5-20251001": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||||
|
}
|
||||||
|
|||||||
@@ -97,7 +97,7 @@ type CreateAccountRequest struct {
|
|||||||
Name string `json:"name" binding:"required"`
|
Name string `json:"name" binding:"required"`
|
||||||
Notes *string `json:"notes"`
|
Notes *string `json:"notes"`
|
||||||
Platform string `json:"platform" binding:"required"`
|
Platform string `json:"platform" binding:"required"`
|
||||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream"`
|
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock"`
|
||||||
Credentials map[string]any `json:"credentials" binding:"required"`
|
Credentials map[string]any `json:"credentials" binding:"required"`
|
||||||
Extra map[string]any `json:"extra"`
|
Extra map[string]any `json:"extra"`
|
||||||
ProxyID *int64 `json:"proxy_id"`
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
@@ -116,7 +116,7 @@ type CreateAccountRequest struct {
|
|||||||
type UpdateAccountRequest struct {
|
type UpdateAccountRequest struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Notes *string `json:"notes"`
|
Notes *string `json:"notes"`
|
||||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream"`
|
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock"`
|
||||||
Credentials map[string]any `json:"credentials"`
|
Credentials map[string]any `json:"credentials"`
|
||||||
Extra map[string]any `json:"extra"`
|
Extra map[string]any `json:"extra"`
|
||||||
ProxyID *int64 `json:"proxy_id"`
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
@@ -165,6 +165,8 @@ type AccountWithConcurrency struct {
|
|||||||
CurrentRPM *int `json:"current_rpm,omitempty"` // 当前分钟 RPM 计数
|
CurrentRPM *int `json:"current_rpm,omitempty"` // 当前分钟 RPM 计数
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const accountListGroupUngroupedQueryValue = "ungrouped"
|
||||||
|
|
||||||
func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency {
|
func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency {
|
||||||
item := AccountWithConcurrency{
|
item := AccountWithConcurrency{
|
||||||
Account: dto.AccountFromService(account),
|
Account: dto.AccountFromService(account),
|
||||||
@@ -226,7 +228,20 @@ func (h *AccountHandler) List(c *gin.Context) {
|
|||||||
|
|
||||||
var groupID int64
|
var groupID int64
|
||||||
if groupIDStr := c.Query("group"); groupIDStr != "" {
|
if groupIDStr := c.Query("group"); groupIDStr != "" {
|
||||||
groupID, _ = strconv.ParseInt(groupIDStr, 10, 64)
|
if groupIDStr == accountListGroupUngroupedQueryValue {
|
||||||
|
groupID = service.AccountListGroupUngrouped
|
||||||
|
} else {
|
||||||
|
parsedGroupID, parseErr := strconv.ParseInt(groupIDStr, 10, 64)
|
||||||
|
if parseErr != nil {
|
||||||
|
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_GROUP_FILTER", "invalid group filter"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if parsedGroupID < 0 {
|
||||||
|
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_GROUP_FILTER", "invalid group filter"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
groupID = parsedGroupID
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID)
|
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID)
|
||||||
@@ -865,6 +880,9 @@ func (h *AccountHandler) refreshSingleAccount(ctx context.Context, account *serv
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OpenAI OAuth: 刷新成功后检查并设置 privacy_mode
|
||||||
|
h.adminService.EnsureOpenAIPrivacy(ctx, updatedAccount)
|
||||||
|
|
||||||
return updatedAccount, "", nil
|
return updatedAccount, "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1338,12 +1356,6 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
|||||||
c.JSON(409, gin.H{
|
c.JSON(409, gin.H{
|
||||||
"error": "mixed_channel_warning",
|
"error": "mixed_channel_warning",
|
||||||
"message": mixedErr.Error(),
|
"message": mixedErr.Error(),
|
||||||
"details": gin.H{
|
|
||||||
"group_id": mixedErr.GroupID,
|
|
||||||
"group_name": mixedErr.GroupName,
|
|
||||||
"current_platform": mixedErr.CurrentPlatform,
|
|
||||||
"other_platform": mixedErr.OtherPlatform,
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1499,7 +1511,7 @@ func (h *OAuthHandler) SetupTokenCookieAuth(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetUsage handles getting account usage information
|
// GetUsage handles getting account usage information
|
||||||
// GET /api/v1/admin/accounts/:id/usage
|
// GET /api/v1/admin/accounts/:id/usage?source=passive|active
|
||||||
func (h *AccountHandler) GetUsage(c *gin.Context) {
|
func (h *AccountHandler) GetUsage(c *gin.Context) {
|
||||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1507,7 +1519,14 @@ func (h *AccountHandler) GetUsage(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
usage, err := h.accountUsageService.GetUsage(c.Request.Context(), accountID)
|
source := c.DefaultQuery("source", "active")
|
||||||
|
|
||||||
|
var usage *service.UsageInfo
|
||||||
|
if source == "passive" {
|
||||||
|
usage, err = h.accountUsageService.GetPassiveUsage(c.Request.Context(), accountID)
|
||||||
|
} else {
|
||||||
|
usage, err = h.accountUsageService.GetUsage(c.Request.Context(), accountID)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
@@ -1721,13 +1740,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
|||||||
|
|
||||||
// Handle OpenAI accounts
|
// Handle OpenAI accounts
|
||||||
if account.IsOpenAI() {
|
if account.IsOpenAI() {
|
||||||
// For OAuth accounts: return default OpenAI models
|
// OpenAI 自动透传会绕过常规模型改写,测试/模型列表也应回落到默认模型集。
|
||||||
if account.IsOAuth() {
|
if account.IsOpenAIPassthroughEnabled() {
|
||||||
response.Success(c, openai.DefaultModels)
|
response.Success(c, openai.DefaultModels)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// For API Key accounts: check model_mapping
|
|
||||||
mapping := account.GetModelMapping()
|
mapping := account.GetModelMapping()
|
||||||
if len(mapping) == 0 {
|
if len(mapping) == 0 {
|
||||||
response.Success(c, openai.DefaultModels)
|
response.Success(c, openai.DefaultModels)
|
||||||
|
|||||||
@@ -0,0 +1,105 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type availableModelsAdminService struct {
|
||||||
|
*stubAdminService
|
||||||
|
account service.Account
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *availableModelsAdminService) GetAccount(_ context.Context, id int64) (*service.Account, error) {
|
||||||
|
if s.account.ID == id {
|
||||||
|
acc := s.account
|
||||||
|
return &acc, nil
|
||||||
|
}
|
||||||
|
return s.stubAdminService.GetAccount(context.Background(), id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupAvailableModelsRouter(adminSvc service.AdminService) *gin.Engine {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
router := gin.New()
|
||||||
|
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
|
router.GET("/api/v1/admin/accounts/:id/models", handler.GetAvailableModels)
|
||||||
|
return router
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountHandlerGetAvailableModels_OpenAIOAuthUsesExplicitModelMapping(t *testing.T) {
|
||||||
|
svc := &availableModelsAdminService{
|
||||||
|
stubAdminService: newStubAdminService(),
|
||||||
|
account: service.Account{
|
||||||
|
ID: 42,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: service.PlatformOpenAI,
|
||||||
|
Type: service.AccountTypeOAuth,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"gpt-5": "gpt-5.1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
router := setupAvailableModelsRouter(svc)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/42/models", nil)
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
|
||||||
|
var resp struct {
|
||||||
|
Data []struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||||
|
require.Len(t, resp.Data, 1)
|
||||||
|
require.Equal(t, "gpt-5", resp.Data[0].ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefaults(t *testing.T) {
|
||||||
|
svc := &availableModelsAdminService{
|
||||||
|
stubAdminService: newStubAdminService(),
|
||||||
|
account: service.Account{
|
||||||
|
ID: 43,
|
||||||
|
Name: "openai-oauth-passthrough",
|
||||||
|
Platform: service.PlatformOpenAI,
|
||||||
|
Type: service.AccountTypeOAuth,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"gpt-5": "gpt-5.1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Extra: map[string]any{
|
||||||
|
"openai_passthrough": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
router := setupAvailableModelsRouter(svc)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/43/models", nil)
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
|
||||||
|
var resp struct {
|
||||||
|
Data []struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||||
|
require.NotEmpty(t, resp.Data)
|
||||||
|
require.NotEqual(t, "gpt-5", resp.Data[0].ID)
|
||||||
|
}
|
||||||
@@ -111,7 +111,7 @@ func TestAccountHandlerCreateMixedChannelConflictSimplifiedResponse(t *testing.T
|
|||||||
var resp map[string]any
|
var resp map[string]any
|
||||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||||
require.Equal(t, "mixed_channel_warning", resp["error"])
|
require.Equal(t, "mixed_channel_warning", resp["error"])
|
||||||
require.Contains(t, resp["message"], "claude-max")
|
require.Contains(t, resp["message"], "mixed_channel_warning")
|
||||||
_, hasDetails := resp["details"]
|
_, hasDetails := resp["details"]
|
||||||
_, hasRequireConfirmation := resp["require_confirmation"]
|
_, hasRequireConfirmation := resp["require_confirmation"]
|
||||||
require.False(t, hasDetails)
|
require.False(t, hasDetails)
|
||||||
@@ -140,7 +140,7 @@ func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T
|
|||||||
var resp map[string]any
|
var resp map[string]any
|
||||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||||
require.Equal(t, "mixed_channel_warning", resp["error"])
|
require.Equal(t, "mixed_channel_warning", resp["error"])
|
||||||
require.Contains(t, resp["message"], "claude-max")
|
require.Contains(t, resp["message"], "mixed_channel_warning")
|
||||||
_, hasDetails := resp["details"]
|
_, hasDetails := resp["details"]
|
||||||
_, hasRequireConfirmation := resp["require_confirmation"]
|
_, hasRequireConfirmation := resp["require_confirmation"]
|
||||||
require.False(t, hasDetails)
|
require.False(t, hasDetails)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
|
|||||||
adminSvc := newStubAdminService()
|
adminSvc := newStubAdminService()
|
||||||
|
|
||||||
userHandler := NewUserHandler(adminSvc, nil)
|
userHandler := NewUserHandler(adminSvc, nil)
|
||||||
groupHandler := NewGroupHandler(adminSvc)
|
groupHandler := NewGroupHandler(adminSvc, nil, nil)
|
||||||
proxyHandler := NewProxyHandler(adminSvc)
|
proxyHandler := NewProxyHandler(adminSvc)
|
||||||
redeemHandler := NewRedeemHandler(adminSvc, nil)
|
redeemHandler := NewRedeemHandler(adminSvc, nil)
|
||||||
|
|
||||||
|
|||||||
@@ -179,6 +179,14 @@ func (s *stubAdminService) GetGroupRateMultipliers(_ context.Context, _ int64) (
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stubAdminService) ClearGroupRateMultipliers(_ context.Context, _ int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int64, _ []service.GroupRateMultiplierInput) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
|
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
|
||||||
return s.accounts, int64(len(s.accounts)), nil
|
return s.accounts, int64(len(s.accounts)), nil
|
||||||
}
|
}
|
||||||
@@ -433,5 +441,9 @@ func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) erro
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stubAdminService) EnsureOpenAIPrivacy(ctx context.Context, account *service.Account) string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// Ensure stub implements interface.
|
// Ensure stub implements interface.
|
||||||
var _ service.AdminService = (*stubAdminService)(nil)
|
var _ service.AdminService = (*stubAdminService)(nil)
|
||||||
|
|||||||
205
backend/internal/handler/admin/backup_handler.go
Normal file
205
backend/internal/handler/admin/backup_handler.go
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BackupHandler struct {
|
||||||
|
backupService *service.BackupService
|
||||||
|
userService *service.UserService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBackupHandler(backupService *service.BackupService, userService *service.UserService) *BackupHandler {
|
||||||
|
return &BackupHandler{
|
||||||
|
backupService: backupService,
|
||||||
|
userService: userService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── S3 配置 ───
|
||||||
|
|
||||||
|
func (h *BackupHandler) GetS3Config(c *gin.Context) {
|
||||||
|
cfg, err := h.backupService.GetS3Config(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BackupHandler) UpdateS3Config(c *gin.Context) {
|
||||||
|
var req service.BackupS3Config
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg, err := h.backupService.UpdateS3Config(c.Request.Context(), req)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BackupHandler) TestS3Connection(c *gin.Context) {
|
||||||
|
var req service.BackupS3Config
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err := h.backupService.TestS3Connection(c.Request.Context(), req)
|
||||||
|
if err != nil {
|
||||||
|
response.Success(c, gin.H{"ok": false, "message": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{"ok": true, "message": "connection successful"})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── 定时备份 ───
|
||||||
|
|
||||||
|
func (h *BackupHandler) GetSchedule(c *gin.Context) {
|
||||||
|
cfg, err := h.backupService.GetSchedule(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BackupHandler) UpdateSchedule(c *gin.Context) {
|
||||||
|
var req service.BackupScheduleConfig
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg, err := h.backupService.UpdateSchedule(c.Request.Context(), req)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── 备份操作 ───
|
||||||
|
|
||||||
|
type CreateBackupRequest struct {
|
||||||
|
ExpireDays *int `json:"expire_days"` // nil=使用默认值14,0=永不过期
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BackupHandler) CreateBackup(c *gin.Context) {
|
||||||
|
var req CreateBackupRequest
|
||||||
|
_ = c.ShouldBindJSON(&req) // 允许空 body
|
||||||
|
|
||||||
|
expireDays := 14 // 默认14天过期
|
||||||
|
if req.ExpireDays != nil {
|
||||||
|
expireDays = *req.ExpireDays
|
||||||
|
}
|
||||||
|
|
||||||
|
record, err := h.backupService.StartBackup(c.Request.Context(), "manual", expireDays)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Accepted(c, record)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BackupHandler) ListBackups(c *gin.Context) {
|
||||||
|
records, err := h.backupService.ListBackups(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if records == nil {
|
||||||
|
records = []service.BackupRecord{}
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{"items": records})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BackupHandler) GetBackup(c *gin.Context) {
|
||||||
|
backupID := c.Param("id")
|
||||||
|
if backupID == "" {
|
||||||
|
response.BadRequest(c, "backup ID is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
record, err := h.backupService.GetBackupRecord(c.Request.Context(), backupID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, record)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BackupHandler) DeleteBackup(c *gin.Context) {
|
||||||
|
backupID := c.Param("id")
|
||||||
|
if backupID == "" {
|
||||||
|
response.BadRequest(c, "backup ID is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.backupService.DeleteBackup(c.Request.Context(), backupID); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{"deleted": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BackupHandler) GetDownloadURL(c *gin.Context) {
|
||||||
|
backupID := c.Param("id")
|
||||||
|
if backupID == "" {
|
||||||
|
response.BadRequest(c, "backup ID is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
url, err := h.backupService.GetBackupDownloadURL(c.Request.Context(), backupID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{"url": url})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── 恢复操作(需要重新输入管理员密码) ───
|
||||||
|
|
||||||
|
type RestoreBackupRequest struct {
|
||||||
|
Password string `json:"password" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BackupHandler) RestoreBackup(c *gin.Context) {
|
||||||
|
backupID := c.Param("id")
|
||||||
|
if backupID == "" {
|
||||||
|
response.BadRequest(c, "backup ID is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req RestoreBackupRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "password is required for restore operation")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从上下文获取当前管理员用户 ID
|
||||||
|
sub, ok := middleware.GetAuthSubjectFromContext(c)
|
||||||
|
if !ok {
|
||||||
|
response.Unauthorized(c, "unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取管理员用户并验证密码
|
||||||
|
user, err := h.userService.GetByID(c.Request.Context(), sub.UserID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !user.CheckPassword(req.Password) {
|
||||||
|
response.BadRequest(c, "incorrect admin password")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
record, err := h.backupService.StartRestore(c.Request.Context(), backupID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Accepted(c, record)
|
||||||
|
}
|
||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -272,6 +273,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
|||||||
|
|
||||||
// Parse optional filter params
|
// Parse optional filter params
|
||||||
var userID, apiKeyID, accountID, groupID int64
|
var userID, apiKeyID, accountID, groupID int64
|
||||||
|
modelSource := usagestats.ModelSourceRequested
|
||||||
var requestType *int16
|
var requestType *int16
|
||||||
var stream *bool
|
var stream *bool
|
||||||
var billingType *int8
|
var billingType *int8
|
||||||
@@ -296,6 +298,13 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
|||||||
groupID = id
|
groupID = id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if rawModelSource := strings.TrimSpace(c.Query("model_source")); rawModelSource != "" {
|
||||||
|
if !usagestats.IsValidModelSource(rawModelSource) {
|
||||||
|
response.BadRequest(c, "Invalid model_source, use requested/upstream/mapping")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
modelSource = rawModelSource
|
||||||
|
}
|
||||||
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
||||||
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -322,7 +331,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, modelSource, requestType, stream, billingType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get model statistics")
|
response.Error(c, 500, "Failed to get model statistics")
|
||||||
return
|
return
|
||||||
@@ -466,9 +475,62 @@ type BatchUsersUsageRequest struct {
|
|||||||
UserIDs []int64 `json:"user_ids" binding:"required"`
|
UserIDs []int64 `json:"user_ids" binding:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
|
||||||
var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second)
|
var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second)
|
||||||
var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second)
|
var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second)
|
||||||
|
|
||||||
|
func parseRankingLimit(raw string) int {
|
||||||
|
limit, err := strconv.Atoi(strings.TrimSpace(raw))
|
||||||
|
if err != nil || limit <= 0 {
|
||||||
|
return 12
|
||||||
|
}
|
||||||
|
if limit > 50 {
|
||||||
|
return 50
|
||||||
|
}
|
||||||
|
return limit
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserSpendingRanking handles getting user spending ranking data.
|
||||||
|
// GET /api/v1/admin/dashboard/users-ranking
|
||||||
|
func (h *DashboardHandler) GetUserSpendingRanking(c *gin.Context) {
|
||||||
|
startTime, endTime := parseTimeRange(c)
|
||||||
|
limit := parseRankingLimit(c.DefaultQuery("limit", "12"))
|
||||||
|
|
||||||
|
keyRaw, _ := json.Marshal(struct {
|
||||||
|
Start string `json:"start"`
|
||||||
|
End string `json:"end"`
|
||||||
|
Limit int `json:"limit"`
|
||||||
|
}{
|
||||||
|
Start: startTime.UTC().Format(time.RFC3339),
|
||||||
|
End: endTime.UTC().Format(time.RFC3339),
|
||||||
|
Limit: limit,
|
||||||
|
})
|
||||||
|
cacheKey := string(keyRaw)
|
||||||
|
if cached, ok := dashboardUsersRankingCache.Get(cacheKey); ok {
|
||||||
|
c.Header("X-Snapshot-Cache", "hit")
|
||||||
|
response.Success(c, cached.Payload)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ranking, err := h.dashboardService.GetUserSpendingRanking(c.Request.Context(), startTime, endTime, limit)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, 500, "Failed to get user spending ranking")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := gin.H{
|
||||||
|
"ranking": ranking.Ranking,
|
||||||
|
"total_actual_cost": ranking.TotalActualCost,
|
||||||
|
"total_requests": ranking.TotalRequests,
|
||||||
|
"total_tokens": ranking.TotalTokens,
|
||||||
|
"start_date": startTime.Format("2006-01-02"),
|
||||||
|
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||||
|
}
|
||||||
|
dashboardUsersRankingCache.Set(cacheKey, payload)
|
||||||
|
c.Header("X-Snapshot-Cache", "miss")
|
||||||
|
response.Success(c, payload)
|
||||||
|
}
|
||||||
|
|
||||||
// GetBatchUsersUsage handles getting usage stats for multiple users
|
// GetBatchUsersUsage handles getting usage stats for multiple users
|
||||||
// POST /api/v1/admin/dashboard/users-usage
|
// POST /api/v1/admin/dashboard/users-usage
|
||||||
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||||
@@ -551,3 +613,47 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
|
|||||||
c.Header("X-Snapshot-Cache", "miss")
|
c.Header("X-Snapshot-Cache", "miss")
|
||||||
response.Success(c, payload)
|
response.Success(c, payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUserBreakdown handles getting per-user usage breakdown within a dimension.
|
||||||
|
// GET /api/v1/admin/dashboard/user-breakdown
|
||||||
|
// Query params: start_date, end_date, group_id, model, endpoint, endpoint_type, limit
|
||||||
|
func (h *DashboardHandler) GetUserBreakdown(c *gin.Context) {
|
||||||
|
startTime, endTime := parseTimeRange(c)
|
||||||
|
|
||||||
|
dim := usagestats.UserBreakdownDimension{}
|
||||||
|
if v := c.Query("group_id"); v != "" {
|
||||||
|
if id, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||||
|
dim.GroupID = id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dim.Model = c.Query("model")
|
||||||
|
rawModelSource := strings.TrimSpace(c.DefaultQuery("model_source", usagestats.ModelSourceRequested))
|
||||||
|
if !usagestats.IsValidModelSource(rawModelSource) {
|
||||||
|
response.BadRequest(c, "Invalid model_source, use requested/upstream/mapping")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
dim.ModelType = rawModelSource
|
||||||
|
dim.Endpoint = c.Query("endpoint")
|
||||||
|
dim.EndpointType = c.DefaultQuery("endpoint_type", "inbound")
|
||||||
|
|
||||||
|
limit := 50
|
||||||
|
if v := c.Query("limit"); v != "" {
|
||||||
|
if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= 200 {
|
||||||
|
limit = n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
stats, err := h.dashboardService.GetUserBreakdownStats(
|
||||||
|
c.Request.Context(), startTime, endTime, dim, limit,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, 500, "Failed to get user breakdown stats")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"users": stats,
|
||||||
|
"start_date": startTime.Format("2006-01-02"),
|
||||||
|
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -19,6 +19,9 @@ type dashboardUsageRepoCapture struct {
|
|||||||
trendStream *bool
|
trendStream *bool
|
||||||
modelRequestType *int16
|
modelRequestType *int16
|
||||||
modelStream *bool
|
modelStream *bool
|
||||||
|
rankingLimit int
|
||||||
|
ranking []usagestats.UserSpendingRankingItem
|
||||||
|
rankingTotal float64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters(
|
func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters(
|
||||||
@@ -49,6 +52,20 @@ func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters(
|
|||||||
return []usagestats.ModelStat{}, nil
|
return []usagestats.ModelStat{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *dashboardUsageRepoCapture) GetUserSpendingRanking(
|
||||||
|
ctx context.Context,
|
||||||
|
startTime, endTime time.Time,
|
||||||
|
limit int,
|
||||||
|
) (*usagestats.UserSpendingRankingResponse, error) {
|
||||||
|
s.rankingLimit = limit
|
||||||
|
return &usagestats.UserSpendingRankingResponse{
|
||||||
|
Ranking: s.ranking,
|
||||||
|
TotalActualCost: s.rankingTotal,
|
||||||
|
TotalRequests: 44,
|
||||||
|
TotalTokens: 1234,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine {
|
func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||||
@@ -56,6 +73,7 @@ func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Eng
|
|||||||
router := gin.New()
|
router := gin.New()
|
||||||
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
|
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
|
||||||
router.GET("/admin/dashboard/models", handler.GetModelStats)
|
router.GET("/admin/dashboard/models", handler.GetModelStats)
|
||||||
|
router.GET("/admin/dashboard/users-ranking", handler.GetUserSpendingRanking)
|
||||||
return router
|
return router
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -130,3 +148,54 @@ func TestDashboardModelStatsInvalidStream(t *testing.T) {
|
|||||||
|
|
||||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDashboardModelStatsInvalidModelSource(t *testing.T) {
|
||||||
|
repo := &dashboardUsageRepoCapture{}
|
||||||
|
router := newDashboardRequestTypeTestRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?model_source=invalid", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardModelStatsValidModelSource(t *testing.T) {
|
||||||
|
repo := &dashboardUsageRepoCapture{}
|
||||||
|
router := newDashboardRequestTypeTestRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?model_source=upstream", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardUsersRankingLimitAndCache(t *testing.T) {
|
||||||
|
dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
|
||||||
|
repo := &dashboardUsageRepoCapture{
|
||||||
|
ranking: []usagestats.UserSpendingRankingItem{
|
||||||
|
{UserID: 7, Email: "rank@example.com", ActualCost: 10.5, Requests: 3, Tokens: 300},
|
||||||
|
},
|
||||||
|
rankingTotal: 88.8,
|
||||||
|
}
|
||||||
|
router := newDashboardRequestTypeTestRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
require.Equal(t, 50, repo.rankingLimit)
|
||||||
|
require.Contains(t, rec.Body.String(), "\"total_actual_cost\":88.8")
|
||||||
|
require.Contains(t, rec.Body.String(), "\"total_requests\":44")
|
||||||
|
require.Contains(t, rec.Body.String(), "\"total_tokens\":1234")
|
||||||
|
require.Equal(t, "miss", rec.Header().Get("X-Snapshot-Cache"))
|
||||||
|
|
||||||
|
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
|
||||||
|
rec2 := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec2, req2)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec2.Code)
|
||||||
|
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,229 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- mock repo ---
|
||||||
|
|
||||||
|
type userBreakdownRepoCapture struct {
|
||||||
|
service.UsageLogRepository
|
||||||
|
capturedDim usagestats.UserBreakdownDimension
|
||||||
|
capturedLimit int
|
||||||
|
result []usagestats.UserBreakdownItem
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *userBreakdownRepoCapture) GetUserBreakdownStats(
|
||||||
|
_ context.Context, _, _ time.Time,
|
||||||
|
dim usagestats.UserBreakdownDimension, limit int,
|
||||||
|
) ([]usagestats.UserBreakdownItem, error) {
|
||||||
|
r.capturedDim = dim
|
||||||
|
r.capturedLimit = limit
|
||||||
|
if r.result != nil {
|
||||||
|
return r.result, nil
|
||||||
|
}
|
||||||
|
return []usagestats.UserBreakdownItem{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUserBreakdownRouter(repo *userBreakdownRepoCapture) *gin.Engine {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := service.NewDashboardService(repo, nil, nil, nil)
|
||||||
|
h := NewDashboardHandler(svc, nil)
|
||||||
|
router := gin.New()
|
||||||
|
router.GET("/admin/dashboard/user-breakdown", h.GetUserBreakdown)
|
||||||
|
return router
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- tests ---
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_GroupIDFilter(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=42", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, int64(42), repo.capturedDim.GroupID)
|
||||||
|
require.Empty(t, repo.capturedDim.Model)
|
||||||
|
require.Empty(t, repo.capturedDim.Endpoint)
|
||||||
|
require.Equal(t, 50, repo.capturedLimit) // default limit
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_ModelFilter(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=claude-opus-4-6", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, "claude-opus-4-6", repo.capturedDim.Model)
|
||||||
|
require.Equal(t, usagestats.ModelSourceRequested, repo.capturedDim.ModelType)
|
||||||
|
require.Equal(t, int64(0), repo.capturedDim.GroupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_ModelSourceFilter(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=claude-opus-4-6&model_source=upstream", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, usagestats.ModelSourceUpstream, repo.capturedDim.ModelType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_InvalidModelSource(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model_source=foobar", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadRequest, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_EndpointFilter(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&endpoint=/v1/messages&endpoint_type=upstream", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, "/v1/messages", repo.capturedDim.Endpoint)
|
||||||
|
require.Equal(t, "upstream", repo.capturedDim.EndpointType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_DefaultEndpointType(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&endpoint=/chat", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, "inbound", repo.capturedDim.EndpointType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_CustomLimit(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=test&limit=100", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, 100, repo.capturedLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_LimitClamped(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
// limit > 200 should fall back to default 50
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=test&limit=999", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, 50, repo.capturedLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_ResponseFormat(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{
|
||||||
|
result: []usagestats.UserBreakdownItem{
|
||||||
|
{UserID: 1, Email: "alice@test.com", Requests: 100, TotalTokens: 50000, Cost: 1.5, ActualCost: 1.2},
|
||||||
|
{UserID: 2, Email: "bob@test.com", Requests: 50, TotalTokens: 25000, Cost: 0.8, ActualCost: 0.6},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=1", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
var resp struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Data struct {
|
||||||
|
Users []usagestats.UserBreakdownItem `json:"users"`
|
||||||
|
StartDate string `json:"start_date"`
|
||||||
|
EndDate string `json:"end_date"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
err := json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 0, resp.Code)
|
||||||
|
require.Len(t, resp.Data.Users, 2)
|
||||||
|
require.Equal(t, int64(1), resp.Data.Users[0].UserID)
|
||||||
|
require.Equal(t, "alice@test.com", resp.Data.Users[0].Email)
|
||||||
|
require.Equal(t, int64(100), resp.Data.Users[0].Requests)
|
||||||
|
require.InDelta(t, 1.2, resp.Data.Users[0].ActualCost, 0.001)
|
||||||
|
require.Equal(t, "2026-03-01", resp.Data.StartDate)
|
||||||
|
require.Equal(t, "2026-03-16", resp.Data.EndDate)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_EmptyResult(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=999", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
var resp struct {
|
||||||
|
Data struct {
|
||||||
|
Users []usagestats.UserBreakdownItem `json:"users"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
err := json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Empty(t, resp.Data.Users)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserBreakdown_NoFilters(t *testing.T) {
|
||||||
|
repo := &userBreakdownRepoCapture{}
|
||||||
|
router := newUserBreakdownRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, int64(0), repo.capturedDim.GroupID)
|
||||||
|
require.Empty(t, repo.capturedDim.Model)
|
||||||
|
require.Empty(t, repo.capturedDim.Endpoint)
|
||||||
|
}
|
||||||
@@ -38,6 +38,7 @@ type dashboardModelGroupCacheKey struct {
|
|||||||
APIKeyID int64 `json:"api_key_id"`
|
APIKeyID int64 `json:"api_key_id"`
|
||||||
AccountID int64 `json:"account_id"`
|
AccountID int64 `json:"account_id"`
|
||||||
GroupID int64 `json:"group_id"`
|
GroupID int64 `json:"group_id"`
|
||||||
|
ModelSource string `json:"model_source,omitempty"`
|
||||||
RequestType *int16 `json:"request_type"`
|
RequestType *int16 `json:"request_type"`
|
||||||
Stream *bool `json:"stream"`
|
Stream *bool `json:"stream"`
|
||||||
BillingType *int8 `json:"billing_type"`
|
BillingType *int8 `json:"billing_type"`
|
||||||
@@ -111,6 +112,7 @@ func (h *DashboardHandler) getModelStatsCached(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
startTime, endTime time.Time,
|
startTime, endTime time.Time,
|
||||||
userID, apiKeyID, accountID, groupID int64,
|
userID, apiKeyID, accountID, groupID int64,
|
||||||
|
modelSource string,
|
||||||
requestType *int16,
|
requestType *int16,
|
||||||
stream *bool,
|
stream *bool,
|
||||||
billingType *int8,
|
billingType *int8,
|
||||||
@@ -122,12 +124,13 @@ func (h *DashboardHandler) getModelStatsCached(
|
|||||||
APIKeyID: apiKeyID,
|
APIKeyID: apiKeyID,
|
||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
GroupID: groupID,
|
GroupID: groupID,
|
||||||
|
ModelSource: usagestats.NormalizeModelSource(modelSource),
|
||||||
RequestType: requestType,
|
RequestType: requestType,
|
||||||
Stream: stream,
|
Stream: stream,
|
||||||
BillingType: billingType,
|
BillingType: billingType,
|
||||||
})
|
})
|
||||||
entry, hit, err := dashboardModelStatsCache.GetOrLoad(key, func() (any, error) {
|
entry, hit, err := dashboardModelStatsCache.GetOrLoad(key, func() (any, error) {
|
||||||
return h.dashboardService.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
return h.dashboardService.GetModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, modelSource)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, hit, err
|
return nil, hit, err
|
||||||
|
|||||||
@@ -200,6 +200,7 @@ func (h *DashboardHandler) buildSnapshotV2Response(
|
|||||||
filters.APIKeyID,
|
filters.APIKeyID,
|
||||||
filters.AccountID,
|
filters.AccountID,
|
||||||
filters.GroupID,
|
filters.GroupID,
|
||||||
|
usagestats.ModelSourceRequested,
|
||||||
filters.RequestType,
|
filters.RequestType,
|
||||||
filters.Stream,
|
filters.Stream,
|
||||||
filters.BillingType,
|
filters.BillingType,
|
||||||
|
|||||||
@@ -1,11 +1,15 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -13,27 +17,80 @@ import (
|
|||||||
|
|
||||||
// GroupHandler handles admin group management
|
// GroupHandler handles admin group management
|
||||||
type GroupHandler struct {
|
type GroupHandler struct {
|
||||||
adminService service.AdminService
|
adminService service.AdminService
|
||||||
|
dashboardService *service.DashboardService
|
||||||
|
groupCapacityService *service.GroupCapacityService
|
||||||
|
}
|
||||||
|
|
||||||
|
type optionalLimitField struct {
|
||||||
|
set bool
|
||||||
|
value *float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *optionalLimitField) UnmarshalJSON(data []byte) error {
|
||||||
|
f.set = true
|
||||||
|
|
||||||
|
trimmed := bytes.TrimSpace(data)
|
||||||
|
if bytes.Equal(trimmed, []byte("null")) {
|
||||||
|
f.value = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var number float64
|
||||||
|
if err := json.Unmarshal(trimmed, &number); err == nil {
|
||||||
|
f.value = &number
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var text string
|
||||||
|
if err := json.Unmarshal(trimmed, &text); err == nil {
|
||||||
|
text = strings.TrimSpace(text)
|
||||||
|
if text == "" {
|
||||||
|
f.value = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
number, err = strconv.ParseFloat(text, 64)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid numeric limit value %q: %w", text, err)
|
||||||
|
}
|
||||||
|
f.value = &number
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("invalid limit value: %s", string(trimmed))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f optionalLimitField) ToServiceInput() *float64 {
|
||||||
|
if !f.set {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if f.value != nil {
|
||||||
|
return f.value
|
||||||
|
}
|
||||||
|
zero := 0.0
|
||||||
|
return &zero
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGroupHandler creates a new admin group handler
|
// NewGroupHandler creates a new admin group handler
|
||||||
func NewGroupHandler(adminService service.AdminService) *GroupHandler {
|
func NewGroupHandler(adminService service.AdminService, dashboardService *service.DashboardService, groupCapacityService *service.GroupCapacityService) *GroupHandler {
|
||||||
return &GroupHandler{
|
return &GroupHandler{
|
||||||
adminService: adminService,
|
adminService: adminService,
|
||||||
|
dashboardService: dashboardService,
|
||||||
|
groupCapacityService: groupCapacityService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateGroupRequest represents create group request
|
// CreateGroupRequest represents create group request
|
||||||
type CreateGroupRequest struct {
|
type CreateGroupRequest struct {
|
||||||
Name string `json:"name" binding:"required"`
|
Name string `json:"name" binding:"required"`
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||||
RateMultiplier float64 `json:"rate_multiplier"`
|
RateMultiplier float64 `json:"rate_multiplier"`
|
||||||
IsExclusive bool `json:"is_exclusive"`
|
IsExclusive bool `json:"is_exclusive"`
|
||||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
DailyLimitUSD optionalLimitField `json:"daily_limit_usd"`
|
||||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
|
||||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
|
||||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||||
@@ -46,10 +103,9 @@ type CreateGroupRequest struct {
|
|||||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||||
// 模型路由配置(仅 anthropic 平台使用)
|
// 模型路由配置(仅 anthropic 平台使用)
|
||||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||||
SimulateClaudeMaxEnabled *bool `json:"simulate_claude_max_enabled"`
|
|
||||||
// 支持的模型系列(仅 antigravity 平台使用)
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||||
// Sora 存储配额
|
// Sora 存储配额
|
||||||
@@ -63,16 +119,16 @@ type CreateGroupRequest struct {
|
|||||||
|
|
||||||
// UpdateGroupRequest represents update group request
|
// UpdateGroupRequest represents update group request
|
||||||
type UpdateGroupRequest struct {
|
type UpdateGroupRequest struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||||
IsExclusive *bool `json:"is_exclusive"`
|
IsExclusive *bool `json:"is_exclusive"`
|
||||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
DailyLimitUSD optionalLimitField `json:"daily_limit_usd"`
|
||||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
|
||||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
|
||||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||||
@@ -85,10 +141,9 @@ type UpdateGroupRequest struct {
|
|||||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||||
// 模型路由配置(仅 anthropic 平台使用)
|
// 模型路由配置(仅 anthropic 平台使用)
|
||||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||||
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
|
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
|
||||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||||
SimulateClaudeMaxEnabled *bool `json:"simulate_claude_max_enabled"`
|
|
||||||
// 支持的模型系列(仅 antigravity 平台使用)
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
SupportedModelScopes *[]string `json:"supported_model_scopes"`
|
SupportedModelScopes *[]string `json:"supported_model_scopes"`
|
||||||
// Sora 存储配额
|
// Sora 存储配额
|
||||||
@@ -193,9 +248,9 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
|||||||
RateMultiplier: req.RateMultiplier,
|
RateMultiplier: req.RateMultiplier,
|
||||||
IsExclusive: req.IsExclusive,
|
IsExclusive: req.IsExclusive,
|
||||||
SubscriptionType: req.SubscriptionType,
|
SubscriptionType: req.SubscriptionType,
|
||||||
DailyLimitUSD: req.DailyLimitUSD,
|
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
|
||||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
|
||||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
|
||||||
ImagePrice1K: req.ImagePrice1K,
|
ImagePrice1K: req.ImagePrice1K,
|
||||||
ImagePrice2K: req.ImagePrice2K,
|
ImagePrice2K: req.ImagePrice2K,
|
||||||
ImagePrice4K: req.ImagePrice4K,
|
ImagePrice4K: req.ImagePrice4K,
|
||||||
@@ -209,7 +264,6 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
|||||||
ModelRouting: req.ModelRouting,
|
ModelRouting: req.ModelRouting,
|
||||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||||
MCPXMLInject: req.MCPXMLInject,
|
MCPXMLInject: req.MCPXMLInject,
|
||||||
SimulateClaudeMaxEnabled: req.SimulateClaudeMaxEnabled,
|
|
||||||
SupportedModelScopes: req.SupportedModelScopes,
|
SupportedModelScopes: req.SupportedModelScopes,
|
||||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||||
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
||||||
@@ -247,9 +301,9 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
|||||||
IsExclusive: req.IsExclusive,
|
IsExclusive: req.IsExclusive,
|
||||||
Status: req.Status,
|
Status: req.Status,
|
||||||
SubscriptionType: req.SubscriptionType,
|
SubscriptionType: req.SubscriptionType,
|
||||||
DailyLimitUSD: req.DailyLimitUSD,
|
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
|
||||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
|
||||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
|
||||||
ImagePrice1K: req.ImagePrice1K,
|
ImagePrice1K: req.ImagePrice1K,
|
||||||
ImagePrice2K: req.ImagePrice2K,
|
ImagePrice2K: req.ImagePrice2K,
|
||||||
ImagePrice4K: req.ImagePrice4K,
|
ImagePrice4K: req.ImagePrice4K,
|
||||||
@@ -263,7 +317,6 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
|||||||
ModelRouting: req.ModelRouting,
|
ModelRouting: req.ModelRouting,
|
||||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||||
MCPXMLInject: req.MCPXMLInject,
|
MCPXMLInject: req.MCPXMLInject,
|
||||||
SimulateClaudeMaxEnabled: req.SimulateClaudeMaxEnabled,
|
|
||||||
SupportedModelScopes: req.SupportedModelScopes,
|
SupportedModelScopes: req.SupportedModelScopes,
|
||||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||||
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
||||||
@@ -315,6 +368,33 @@ func (h *GroupHandler) GetStats(c *gin.Context) {
|
|||||||
_ = groupID // TODO: implement actual stats
|
_ = groupID // TODO: implement actual stats
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUsageSummary returns today's and cumulative cost for all groups.
|
||||||
|
// GET /api/v1/admin/groups/usage-summary?timezone=Asia/Shanghai
|
||||||
|
func (h *GroupHandler) GetUsageSummary(c *gin.Context) {
|
||||||
|
userTZ := c.Query("timezone")
|
||||||
|
now := timezone.NowInUserLocation(userTZ)
|
||||||
|
todayStart := timezone.StartOfDayInUserLocation(now, userTZ)
|
||||||
|
|
||||||
|
results, err := h.dashboardService.GetGroupUsageSummary(c.Request.Context(), todayStart)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, 500, "Failed to get group usage summary")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, results)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCapacitySummary returns aggregated capacity (concurrency/sessions/RPM) for all active groups.
|
||||||
|
// GET /api/v1/admin/groups/capacity-summary
|
||||||
|
func (h *GroupHandler) GetCapacitySummary(c *gin.Context) {
|
||||||
|
results, err := h.groupCapacityService.GetAllGroupCapacity(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, 500, "Failed to get group capacity summary")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, results)
|
||||||
|
}
|
||||||
|
|
||||||
// GetGroupAPIKeys handles getting API keys in a group
|
// GetGroupAPIKeys handles getting API keys in a group
|
||||||
// GET /api/v1/admin/groups/:id/api-keys
|
// GET /api/v1/admin/groups/:id/api-keys
|
||||||
func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
||||||
@@ -360,6 +440,51 @@ func (h *GroupHandler) GetGroupRateMultipliers(c *gin.Context) {
|
|||||||
response.Success(c, entries)
|
response.Success(c, entries)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ClearGroupRateMultipliers handles clearing all rate multipliers for a group
|
||||||
|
// DELETE /api/v1/admin/groups/:id/rate-multipliers
|
||||||
|
func (h *GroupHandler) ClearGroupRateMultipliers(c *gin.Context) {
|
||||||
|
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid group ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.adminService.ClearGroupRateMultipliers(c.Request.Context(), groupID); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"message": "Rate multipliers cleared successfully"})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchSetGroupRateMultipliersRequest represents batch set rate multipliers request
|
||||||
|
type BatchSetGroupRateMultipliersRequest struct {
|
||||||
|
Entries []service.GroupRateMultiplierInput `json:"entries" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchSetGroupRateMultipliers handles batch setting rate multipliers for a group
|
||||||
|
// PUT /api/v1/admin/groups/:id/rate-multipliers
|
||||||
|
func (h *GroupHandler) BatchSetGroupRateMultipliers(c *gin.Context) {
|
||||||
|
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid group ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req BatchSetGroupRateMultipliersRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.adminService.BatchSetGroupRateMultipliers(c.Request.Context(), groupID, req.Entries); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"message": "Rate multipliers updated successfully"})
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateSortOrderRequest represents the request to update group sort orders
|
// UpdateSortOrderRequest represents the request to update group sort orders
|
||||||
type UpdateSortOrderRequest struct {
|
type UpdateSortOrderRequest struct {
|
||||||
Updates []struct {
|
Updates []struct {
|
||||||
|
|||||||
@@ -289,6 +289,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
|||||||
Platform: platform,
|
Platform: platform,
|
||||||
Type: "oauth",
|
Type: "oauth",
|
||||||
Credentials: credentials,
|
Credentials: credentials,
|
||||||
|
Extra: nil,
|
||||||
ProxyID: req.ProxyID,
|
ProxyID: req.ProxyID,
|
||||||
Concurrency: req.Concurrency,
|
Concurrency: req.Concurrency,
|
||||||
Priority: req.Priority,
|
Priority: req.Priority,
|
||||||
|
|||||||
@@ -41,12 +41,15 @@ type GenerateRedeemCodesRequest struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
|
// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
|
||||||
|
// Type 为 omitempty 而非 required 是为了向后兼容旧版调用方(不传 type 时默认 balance)。
|
||||||
type CreateAndRedeemCodeRequest struct {
|
type CreateAndRedeemCodeRequest struct {
|
||||||
Code string `json:"code" binding:"required,min=3,max=128"`
|
Code string `json:"code" binding:"required,min=3,max=128"`
|
||||||
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
|
Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // 不传时默认 balance(向后兼容)
|
||||||
Value float64 `json:"value" binding:"required,gt=0"`
|
Value float64 `json:"value" binding:"required,gt=0"`
|
||||||
UserID int64 `json:"user_id" binding:"required,gt=0"`
|
UserID int64 `json:"user_id" binding:"required,gt=0"`
|
||||||
Notes string `json:"notes"`
|
GroupID *int64 `json:"group_id"` // subscription 类型必填
|
||||||
|
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // subscription 类型必填,>0
|
||||||
|
Notes string `json:"notes"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// List handles listing all redeem codes with pagination
|
// List handles listing all redeem codes with pagination
|
||||||
@@ -136,6 +139,22 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
req.Code = strings.TrimSpace(req.Code)
|
req.Code = strings.TrimSpace(req.Code)
|
||||||
|
// 向后兼容:旧版调用方(如 Sub2ApiPay)不传 type 字段,默认当作 balance 充值处理。
|
||||||
|
// 请勿删除此默认值逻辑,否则会导致旧版调用方 400 报错。
|
||||||
|
if req.Type == "" {
|
||||||
|
req.Type = "balance"
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Type == "subscription" {
|
||||||
|
if req.GroupID == nil {
|
||||||
|
response.BadRequest(c, "group_id is required for subscription type")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.ValidityDays <= 0 {
|
||||||
|
response.BadRequest(c, "validity_days must be greater than 0 for subscription type")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||||
existing, err := h.redeemService.GetByCode(ctx, req.Code)
|
existing, err := h.redeemService.GetByCode(ctx, req.Code)
|
||||||
@@ -147,11 +166,13 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
createErr := h.redeemService.CreateCode(ctx, &service.RedeemCode{
|
createErr := h.redeemService.CreateCode(ctx, &service.RedeemCode{
|
||||||
Code: req.Code,
|
Code: req.Code,
|
||||||
Type: req.Type,
|
Type: req.Type,
|
||||||
Value: req.Value,
|
Value: req.Value,
|
||||||
Status: service.StatusUnused,
|
Status: service.StatusUnused,
|
||||||
Notes: req.Notes,
|
Notes: req.Notes,
|
||||||
|
GroupID: req.GroupID,
|
||||||
|
ValidityDays: req.ValidityDays,
|
||||||
})
|
})
|
||||||
if createErr != nil {
|
if createErr != nil {
|
||||||
// Unique code race: if code now exists, use idempotent semantics by used_by.
|
// Unique code race: if code now exists, use idempotent semantics by used_by.
|
||||||
|
|||||||
135
backend/internal/handler/admin/redeem_handler_test.go
Normal file
135
backend/internal/handler/admin/redeem_handler_test.go
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newCreateAndRedeemHandler creates a RedeemHandler with a non-nil (but minimal)
|
||||||
|
// RedeemService so that CreateAndRedeem's nil guard passes and we can test the
|
||||||
|
// parameter-validation layer that runs before any service call.
|
||||||
|
func newCreateAndRedeemHandler() *RedeemHandler {
|
||||||
|
return &RedeemHandler{
|
||||||
|
adminService: newStubAdminService(),
|
||||||
|
redeemService: &service.RedeemService{}, // non-nil to pass nil guard
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// postCreateAndRedeemValidation calls CreateAndRedeem and returns the response
|
||||||
|
// status code. For cases that pass validation and proceed into the service layer,
|
||||||
|
// a panic may occur (because RedeemService internals are nil); this is expected
|
||||||
|
// and treated as "validation passed" (returns 0 to indicate panic).
|
||||||
|
func postCreateAndRedeemValidation(t *testing.T, handler *RedeemHandler, body any) (code int) {
|
||||||
|
t.Helper()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
|
jsonBytes, err := json.Marshal(body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
c.Request, _ = http.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/create-and-redeem", bytes.NewReader(jsonBytes))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
// Panic means we passed validation and entered service layer (expected for minimal stub).
|
||||||
|
code = 0
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
handler.CreateAndRedeem(c)
|
||||||
|
return w.Code
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateAndRedeem_TypeDefaultsToBalance(t *testing.T) {
|
||||||
|
// 不传 type 字段时应默认 balance,不触发 subscription 校验。
|
||||||
|
// 验证通过后进入 service 层会 panic(返回 0),说明默认值生效。
|
||||||
|
h := newCreateAndRedeemHandler()
|
||||||
|
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||||
|
"code": "test-balance-default",
|
||||||
|
"value": 10.0,
|
||||||
|
"user_id": 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||||
|
"omitting type should default to balance and pass validation")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateAndRedeem_SubscriptionRequiresGroupID(t *testing.T) {
|
||||||
|
h := newCreateAndRedeemHandler()
|
||||||
|
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||||
|
"code": "test-sub-no-group",
|
||||||
|
"type": "subscription",
|
||||||
|
"value": 29.9,
|
||||||
|
"user_id": 1,
|
||||||
|
"validity_days": 30,
|
||||||
|
// group_id 缺失
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusBadRequest, code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateAndRedeem_SubscriptionRequiresPositiveValidityDays(t *testing.T) {
|
||||||
|
groupID := int64(5)
|
||||||
|
h := newCreateAndRedeemHandler()
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
validityDays int
|
||||||
|
}{
|
||||||
|
{"zero", 0},
|
||||||
|
{"negative", -1},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||||
|
"code": "test-sub-bad-days-" + tc.name,
|
||||||
|
"type": "subscription",
|
||||||
|
"value": 29.9,
|
||||||
|
"user_id": 1,
|
||||||
|
"group_id": groupID,
|
||||||
|
"validity_days": tc.validityDays,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusBadRequest, code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateAndRedeem_SubscriptionValidParamsPassValidation(t *testing.T) {
|
||||||
|
groupID := int64(5)
|
||||||
|
h := newCreateAndRedeemHandler()
|
||||||
|
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||||
|
"code": "test-sub-valid",
|
||||||
|
"type": "subscription",
|
||||||
|
"value": 29.9,
|
||||||
|
"user_id": 1,
|
||||||
|
"group_id": groupID,
|
||||||
|
"validity_days": 31,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||||
|
"valid subscription params should pass validation")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateAndRedeem_BalanceIgnoresSubscriptionFields(t *testing.T) {
|
||||||
|
h := newCreateAndRedeemHandler()
|
||||||
|
// balance 类型不传 group_id 和 validity_days,不应报 400
|
||||||
|
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||||
|
"code": "test-balance-no-extras",
|
||||||
|
"type": "balance",
|
||||||
|
"value": 50.0,
|
||||||
|
"user_id": 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||||
|
"balance type should not require group_id or validity_days")
|
||||||
|
}
|
||||||
@@ -80,6 +80,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
|||||||
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
|
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
|
||||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||||
|
FrontendURL: settings.FrontendURL,
|
||||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||||
TotpEnabled: settings.TotpEnabled,
|
TotpEnabled: settings.TotpEnabled,
|
||||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||||
@@ -125,6 +126,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
|||||||
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
|
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
|
||||||
MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
|
MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
|
||||||
AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
|
AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
|
||||||
|
BackendModeEnabled: settings.BackendModeEnabled,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,6 +138,7 @@ type UpdateSettingsRequest struct {
|
|||||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||||
|
FrontendURL string `json:"frontend_url"`
|
||||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||||
|
|
||||||
@@ -199,6 +202,9 @@ type UpdateSettingsRequest struct {
|
|||||||
|
|
||||||
// 分组隔离
|
// 分组隔离
|
||||||
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
||||||
|
|
||||||
|
// Backend Mode
|
||||||
|
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateSettings 更新系统设置
|
// UpdateSettings 更新系统设置
|
||||||
@@ -322,6 +328,15 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Frontend URL 验证
|
||||||
|
req.FrontendURL = strings.TrimSpace(req.FrontendURL)
|
||||||
|
if req.FrontendURL != "" {
|
||||||
|
if err := config.ValidateAbsoluteHTTPURL(req.FrontendURL); err != nil {
|
||||||
|
response.BadRequest(c, "Frontend URL must be an absolute http(s) URL")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 自定义菜单项验证
|
// 自定义菜单项验证
|
||||||
const (
|
const (
|
||||||
maxCustomMenuItems = 20
|
maxCustomMenuItems = 20
|
||||||
@@ -433,6 +448,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
|
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
|
||||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||||
PasswordResetEnabled: req.PasswordResetEnabled,
|
PasswordResetEnabled: req.PasswordResetEnabled,
|
||||||
|
FrontendURL: req.FrontendURL,
|
||||||
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
||||||
TotpEnabled: req.TotpEnabled,
|
TotpEnabled: req.TotpEnabled,
|
||||||
SMTPHost: req.SMTPHost,
|
SMTPHost: req.SMTPHost,
|
||||||
@@ -473,6 +489,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||||
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
|
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
|
||||||
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
|
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
|
||||||
|
BackendModeEnabled: req.BackendModeEnabled,
|
||||||
OpsMonitoringEnabled: func() bool {
|
OpsMonitoringEnabled: func() bool {
|
||||||
if req.OpsMonitoringEnabled != nil {
|
if req.OpsMonitoringEnabled != nil {
|
||||||
return *req.OpsMonitoringEnabled
|
return *req.OpsMonitoringEnabled
|
||||||
@@ -526,6 +543,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
|
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
|
||||||
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
|
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
|
||||||
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
|
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
|
||||||
|
FrontendURL: updatedSettings.FrontendURL,
|
||||||
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
|
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
|
||||||
TotpEnabled: updatedSettings.TotpEnabled,
|
TotpEnabled: updatedSettings.TotpEnabled,
|
||||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||||
@@ -571,6 +589,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
|
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
|
||||||
MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
|
MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
|
||||||
AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
|
AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
|
||||||
|
BackendModeEnabled: updatedSettings.BackendModeEnabled,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -608,6 +627,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
|||||||
if before.PasswordResetEnabled != after.PasswordResetEnabled {
|
if before.PasswordResetEnabled != after.PasswordResetEnabled {
|
||||||
changed = append(changed, "password_reset_enabled")
|
changed = append(changed, "password_reset_enabled")
|
||||||
}
|
}
|
||||||
|
if before.FrontendURL != after.FrontendURL {
|
||||||
|
changed = append(changed, "frontend_url")
|
||||||
|
}
|
||||||
if before.TotpEnabled != after.TotpEnabled {
|
if before.TotpEnabled != after.TotpEnabled {
|
||||||
changed = append(changed, "totp_enabled")
|
changed = append(changed, "totp_enabled")
|
||||||
}
|
}
|
||||||
@@ -725,6 +747,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
|||||||
if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling {
|
if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling {
|
||||||
changed = append(changed, "allow_ungrouped_key_scheduling")
|
changed = append(changed, "allow_ungrouped_key_scheduling")
|
||||||
}
|
}
|
||||||
|
if before.BackendModeEnabled != after.BackendModeEnabled {
|
||||||
|
changed = append(changed, "backend_mode_enabled")
|
||||||
|
}
|
||||||
if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled {
|
if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled {
|
||||||
changed = append(changed, "purchase_subscription_enabled")
|
changed = append(changed, "purchase_subscription_enabled")
|
||||||
}
|
}
|
||||||
@@ -952,6 +977,58 @@ func (h *SettingHandler) DeleteAdminAPIKey(c *gin.Context) {
|
|||||||
response.Success(c, gin.H{"message": "Admin API key deleted"})
|
response.Success(c, gin.H{"message": "Admin API key deleted"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetOverloadCooldownSettings 获取529过载冷却配置
|
||||||
|
// GET /api/v1/admin/settings/overload-cooldown
|
||||||
|
func (h *SettingHandler) GetOverloadCooldownSettings(c *gin.Context) {
|
||||||
|
settings, err := h.settingService.GetOverloadCooldownSettings(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, dto.OverloadCooldownSettings{
|
||||||
|
Enabled: settings.Enabled,
|
||||||
|
CooldownMinutes: settings.CooldownMinutes,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateOverloadCooldownSettingsRequest 更新529过载冷却配置请求
|
||||||
|
type UpdateOverloadCooldownSettingsRequest struct {
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
CooldownMinutes int `json:"cooldown_minutes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateOverloadCooldownSettings 更新529过载冷却配置
|
||||||
|
// PUT /api/v1/admin/settings/overload-cooldown
|
||||||
|
func (h *SettingHandler) UpdateOverloadCooldownSettings(c *gin.Context) {
|
||||||
|
var req UpdateOverloadCooldownSettingsRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
settings := &service.OverloadCooldownSettings{
|
||||||
|
Enabled: req.Enabled,
|
||||||
|
CooldownMinutes: req.CooldownMinutes,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.settingService.SetOverloadCooldownSettings(c.Request.Context(), settings); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedSettings, err := h.settingService.GetOverloadCooldownSettings(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, dto.OverloadCooldownSettings{
|
||||||
|
Enabled: updatedSettings.Enabled,
|
||||||
|
CooldownMinutes: updatedSettings.CooldownMinutes,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// GetStreamTimeoutSettings 获取流超时处理配置
|
// GetStreamTimeoutSettings 获取流超时处理配置
|
||||||
// GET /api/v1/admin/settings/stream-timeout
|
// GET /api/v1/admin/settings/stream-timeout
|
||||||
func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {
|
func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {
|
||||||
|
|||||||
@@ -77,12 +77,13 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
status := c.Query("status")
|
status := c.Query("status")
|
||||||
|
platform := c.Query("platform")
|
||||||
|
|
||||||
// Parse sorting parameters
|
// Parse sorting parameters
|
||||||
sortBy := c.DefaultQuery("sort_by", "created_at")
|
sortBy := c.DefaultQuery("sort_by", "created_at")
|
||||||
sortOrder := c.DefaultQuery("sort_order", "desc")
|
sortOrder := c.DefaultQuery("sort_order", "desc")
|
||||||
|
|
||||||
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder)
|
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, platform, sortBy, sortOrder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
@@ -218,11 +219,12 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
|||||||
|
|
||||||
// ResetSubscriptionQuotaRequest represents the reset quota request
|
// ResetSubscriptionQuotaRequest represents the reset quota request
|
||||||
type ResetSubscriptionQuotaRequest struct {
|
type ResetSubscriptionQuotaRequest struct {
|
||||||
Daily bool `json:"daily"`
|
Daily bool `json:"daily"`
|
||||||
Weekly bool `json:"weekly"`
|
Weekly bool `json:"weekly"`
|
||||||
|
Monthly bool `json:"monthly"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResetQuota resets daily and/or weekly usage for a subscription.
|
// ResetQuota resets daily, weekly, and/or monthly usage for a subscription.
|
||||||
// POST /api/v1/admin/subscriptions/:id/reset-quota
|
// POST /api/v1/admin/subscriptions/:id/reset-quota
|
||||||
func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
|
func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
|
||||||
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
@@ -235,11 +237,11 @@ func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
|
|||||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !req.Daily && !req.Weekly {
|
if !req.Daily && !req.Weekly && !req.Monthly {
|
||||||
response.BadRequest(c, "At least one of 'daily' or 'weekly' must be true")
|
response.BadRequest(c, "At least one of 'daily', 'weekly', or 'monthly' must be true")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly)
|
sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly, req.Monthly)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -159,8 +159,8 @@ func (h *UsageHandler) List(c *gin.Context) {
|
|||||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Set end time to end of day
|
// Use half-open range [start, end), move to next calendar day start (DST-safe).
|
||||||
t = t.Add(24*time.Hour - time.Nanosecond)
|
t = t.AddDate(0, 0, 1)
|
||||||
endTime = &t
|
endTime = &t
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -285,7 +285,8 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
|||||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
|
// 与 SQL 条件 created_at < end 对齐,使用次日 00:00 作为上边界(DST-safe)。
|
||||||
|
endTime = endTime.AddDate(0, 0, 1)
|
||||||
} else {
|
} else {
|
||||||
period := c.DefaultQuery("period", "today")
|
period := c.DefaultQuery("period", "today")
|
||||||
switch period {
|
switch period {
|
||||||
|
|||||||
@@ -194,6 +194,12 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Backend mode: only admin can login
|
||||||
|
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
|
||||||
|
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
h.respondWithTokenPair(c, user)
|
h.respondWithTokenPair(c, user)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -250,16 +256,22 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete the login session
|
// Get the user (before session deletion so we can check backend mode)
|
||||||
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
|
||||||
|
|
||||||
// Get the user
|
|
||||||
user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
|
user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Backend mode: only admin can login (check BEFORE deleting session)
|
||||||
|
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
|
||||||
|
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete the login session (only after all checks pass)
|
||||||
|
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
||||||
|
|
||||||
h.respondWithTokenPair(c, user)
|
h.respondWithTokenPair(c, user)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -447,9 +459,9 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
frontendBaseURL := strings.TrimSpace(h.cfg.Server.FrontendURL)
|
frontendBaseURL := strings.TrimSpace(h.settingSvc.GetFrontendURL(c.Request.Context()))
|
||||||
if frontendBaseURL == "" {
|
if frontendBaseURL == "" {
|
||||||
slog.Error("server.frontend_url not configured; cannot build password reset link")
|
slog.Error("frontend_url not configured in settings or config; cannot build password reset link")
|
||||||
response.InternalError(c, "Password reset is not configured")
|
response.InternalError(c, "Password reset is not configured")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -522,16 +534,22 @@ func (h *AuthHandler) RefreshToken(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenPair, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
|
result, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Backend mode: block non-admin token refresh
|
||||||
|
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && result.UserRole != "admin" {
|
||||||
|
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
response.Success(c, RefreshTokenResponse{
|
response.Success(c, RefreshTokenResponse{
|
||||||
AccessToken: tokenPair.AccessToken,
|
AccessToken: result.AccessToken,
|
||||||
RefreshToken: tokenPair.RefreshToken,
|
RefreshToken: result.RefreshToken,
|
||||||
ExpiresIn: tokenPair.ExpiresIn,
|
ExpiresIn: result.ExpiresIn,
|
||||||
TokenType: "Bearer",
|
TokenType: "Bearer",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -135,15 +135,16 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
out := &AdminGroup{
|
out := &AdminGroup{
|
||||||
Group: groupFromServiceBase(g),
|
Group: groupFromServiceBase(g),
|
||||||
ModelRouting: g.ModelRouting,
|
ModelRouting: g.ModelRouting,
|
||||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||||
MCPXMLInject: g.MCPXMLInject,
|
MCPXMLInject: g.MCPXMLInject,
|
||||||
DefaultMappedModel: g.DefaultMappedModel,
|
DefaultMappedModel: g.DefaultMappedModel,
|
||||||
SimulateClaudeMaxEnabled: g.SimulateClaudeMaxEnabled,
|
SupportedModelScopes: g.SupportedModelScopes,
|
||||||
SupportedModelScopes: g.SupportedModelScopes,
|
AccountCount: g.AccountCount,
|
||||||
AccountCount: g.AccountCount,
|
ActiveAccountCount: g.ActiveAccountCount,
|
||||||
SortOrder: g.SortOrder,
|
RateLimitedAccountCount: g.RateLimitedAccountCount,
|
||||||
|
SortOrder: g.SortOrder,
|
||||||
}
|
}
|
||||||
if len(g.AccountGroups) > 0 {
|
if len(g.AccountGroups) > 0 {
|
||||||
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
||||||
@@ -265,8 +266,8 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 提取 API Key 账号配额限制(仅 apikey 类型有效)
|
// 提取账号配额限制(apikey / bedrock 类型有效)
|
||||||
if a.Type == service.AccountTypeAPIKey {
|
if a.IsAPIKeyOrBedrock() {
|
||||||
if limit := a.GetQuotaLimit(); limit > 0 {
|
if limit := a.GetQuotaLimit(); limit > 0 {
|
||||||
out.QuotaLimit = &limit
|
out.QuotaLimit = &limit
|
||||||
used := a.GetQuotaUsed()
|
used := a.GetQuotaUsed()
|
||||||
@@ -282,6 +283,31 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
|||||||
used := a.GetQuotaWeeklyUsed()
|
used := a.GetQuotaWeeklyUsed()
|
||||||
out.QuotaWeeklyUsed = &used
|
out.QuotaWeeklyUsed = &used
|
||||||
}
|
}
|
||||||
|
// 固定时间重置配置
|
||||||
|
if mode := a.GetQuotaDailyResetMode(); mode == "fixed" {
|
||||||
|
out.QuotaDailyResetMode = &mode
|
||||||
|
hour := a.GetQuotaDailyResetHour()
|
||||||
|
out.QuotaDailyResetHour = &hour
|
||||||
|
}
|
||||||
|
if mode := a.GetQuotaWeeklyResetMode(); mode == "fixed" {
|
||||||
|
out.QuotaWeeklyResetMode = &mode
|
||||||
|
day := a.GetQuotaWeeklyResetDay()
|
||||||
|
out.QuotaWeeklyResetDay = &day
|
||||||
|
hour := a.GetQuotaWeeklyResetHour()
|
||||||
|
out.QuotaWeeklyResetHour = &hour
|
||||||
|
}
|
||||||
|
if a.GetQuotaDailyResetMode() == "fixed" || a.GetQuotaWeeklyResetMode() == "fixed" {
|
||||||
|
tz := a.GetQuotaResetTimezone()
|
||||||
|
out.QuotaResetTimezone = &tz
|
||||||
|
}
|
||||||
|
if a.Extra != nil {
|
||||||
|
if v, ok := a.Extra["quota_daily_reset_at"].(string); ok && v != "" {
|
||||||
|
out.QuotaDailyResetAt = &v
|
||||||
|
}
|
||||||
|
if v, ok := a.Extra["quota_weekly_reset_at"].(string); ok && v != "" {
|
||||||
|
out.QuotaWeeklyResetAt = &v
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return out
|
return out
|
||||||
@@ -497,8 +523,11 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
|||||||
AccountID: l.AccountID,
|
AccountID: l.AccountID,
|
||||||
RequestID: l.RequestID,
|
RequestID: l.RequestID,
|
||||||
Model: l.Model,
|
Model: l.Model,
|
||||||
|
UpstreamModel: l.UpstreamModel,
|
||||||
ServiceTier: l.ServiceTier,
|
ServiceTier: l.ServiceTier,
|
||||||
ReasoningEffort: l.ReasoningEffort,
|
ReasoningEffort: l.ReasoningEffort,
|
||||||
|
InboundEndpoint: l.InboundEndpoint,
|
||||||
|
UpstreamEndpoint: l.UpstreamEndpoint,
|
||||||
GroupID: l.GroupID,
|
GroupID: l.GroupID,
|
||||||
SubscriptionID: l.SubscriptionID,
|
SubscriptionID: l.SubscriptionID,
|
||||||
InputTokens: l.InputTokens,
|
InputTokens: l.InputTokens,
|
||||||
|
|||||||
@@ -76,10 +76,14 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
serviceTier := "priority"
|
serviceTier := "priority"
|
||||||
|
inboundEndpoint := "/v1/chat/completions"
|
||||||
|
upstreamEndpoint := "/v1/responses"
|
||||||
log := &service.UsageLog{
|
log := &service.UsageLog{
|
||||||
RequestID: "req_3",
|
RequestID: "req_3",
|
||||||
Model: "gpt-5.4",
|
Model: "gpt-5.4",
|
||||||
ServiceTier: &serviceTier,
|
ServiceTier: &serviceTier,
|
||||||
|
InboundEndpoint: &inboundEndpoint,
|
||||||
|
UpstreamEndpoint: &upstreamEndpoint,
|
||||||
AccountRateMultiplier: f64Ptr(1.5),
|
AccountRateMultiplier: f64Ptr(1.5),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -88,8 +92,16 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
|
|||||||
|
|
||||||
require.NotNil(t, userDTO.ServiceTier)
|
require.NotNil(t, userDTO.ServiceTier)
|
||||||
require.Equal(t, serviceTier, *userDTO.ServiceTier)
|
require.Equal(t, serviceTier, *userDTO.ServiceTier)
|
||||||
|
require.NotNil(t, userDTO.InboundEndpoint)
|
||||||
|
require.Equal(t, inboundEndpoint, *userDTO.InboundEndpoint)
|
||||||
|
require.NotNil(t, userDTO.UpstreamEndpoint)
|
||||||
|
require.Equal(t, upstreamEndpoint, *userDTO.UpstreamEndpoint)
|
||||||
require.NotNil(t, adminDTO.ServiceTier)
|
require.NotNil(t, adminDTO.ServiceTier)
|
||||||
require.Equal(t, serviceTier, *adminDTO.ServiceTier)
|
require.Equal(t, serviceTier, *adminDTO.ServiceTier)
|
||||||
|
require.NotNil(t, adminDTO.InboundEndpoint)
|
||||||
|
require.Equal(t, inboundEndpoint, *adminDTO.InboundEndpoint)
|
||||||
|
require.NotNil(t, adminDTO.UpstreamEndpoint)
|
||||||
|
require.Equal(t, upstreamEndpoint, *adminDTO.UpstreamEndpoint)
|
||||||
require.NotNil(t, adminDTO.AccountRateMultiplier)
|
require.NotNil(t, adminDTO.AccountRateMultiplier)
|
||||||
require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12)
|
require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ type SystemSettings struct {
|
|||||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||||
|
FrontendURL string `json:"frontend_url"`
|
||||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||||
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
||||||
@@ -81,6 +82,9 @@ type SystemSettings struct {
|
|||||||
|
|
||||||
// 分组隔离
|
// 分组隔离
|
||||||
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
||||||
|
|
||||||
|
// Backend Mode
|
||||||
|
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type DefaultSubscriptionSetting struct {
|
type DefaultSubscriptionSetting struct {
|
||||||
@@ -111,6 +115,7 @@ type PublicSettings struct {
|
|||||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||||
|
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||||
Version string `json:"version"`
|
Version string `json:"version"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -152,6 +157,12 @@ type ListSoraS3ProfilesResponse struct {
|
|||||||
Items []SoraS3Profile `json:"items"`
|
Items []SoraS3Profile `json:"items"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OverloadCooldownSettings 529过载冷却配置 DTO
|
||||||
|
type OverloadCooldownSettings struct {
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
CooldownMinutes int `json:"cooldown_minutes"`
|
||||||
|
}
|
||||||
|
|
||||||
// StreamTimeoutSettings 流超时处理配置 DTO
|
// StreamTimeoutSettings 流超时处理配置 DTO
|
||||||
type StreamTimeoutSettings struct {
|
type StreamTimeoutSettings struct {
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
|
|||||||
@@ -117,16 +117,16 @@ type AdminGroup struct {
|
|||||||
|
|
||||||
// MCP XML 协议注入(仅 antigravity 平台使用)
|
// MCP XML 协议注入(仅 antigravity 平台使用)
|
||||||
MCPXMLInject bool `json:"mcp_xml_inject"`
|
MCPXMLInject bool `json:"mcp_xml_inject"`
|
||||||
// Claude usage 模拟开关(仅管理员可见)
|
|
||||||
SimulateClaudeMaxEnabled bool `json:"simulate_claude_max_enabled"`
|
|
||||||
|
|
||||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||||
DefaultMappedModel string `json:"default_mapped_model"`
|
DefaultMappedModel string `json:"default_mapped_model"`
|
||||||
|
|
||||||
// 支持的模型系列(仅 antigravity 平台使用)
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||||
AccountCount int64 `json:"account_count,omitempty"`
|
AccountCount int64 `json:"account_count,omitempty"`
|
||||||
|
ActiveAccountCount int64 `json:"active_account_count,omitempty"`
|
||||||
|
RateLimitedAccountCount int64 `json:"rate_limited_account_count,omitempty"`
|
||||||
|
|
||||||
// 分组排序
|
// 分组排序
|
||||||
SortOrder int `json:"sort_order"`
|
SortOrder int `json:"sort_order"`
|
||||||
@@ -205,6 +205,16 @@ type Account struct {
|
|||||||
QuotaWeeklyLimit *float64 `json:"quota_weekly_limit,omitempty"`
|
QuotaWeeklyLimit *float64 `json:"quota_weekly_limit,omitempty"`
|
||||||
QuotaWeeklyUsed *float64 `json:"quota_weekly_used,omitempty"`
|
QuotaWeeklyUsed *float64 `json:"quota_weekly_used,omitempty"`
|
||||||
|
|
||||||
|
// 配额固定时间重置配置
|
||||||
|
QuotaDailyResetMode *string `json:"quota_daily_reset_mode,omitempty"`
|
||||||
|
QuotaDailyResetHour *int `json:"quota_daily_reset_hour,omitempty"`
|
||||||
|
QuotaWeeklyResetMode *string `json:"quota_weekly_reset_mode,omitempty"`
|
||||||
|
QuotaWeeklyResetDay *int `json:"quota_weekly_reset_day,omitempty"`
|
||||||
|
QuotaWeeklyResetHour *int `json:"quota_weekly_reset_hour,omitempty"`
|
||||||
|
QuotaResetTimezone *string `json:"quota_reset_timezone,omitempty"`
|
||||||
|
QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"`
|
||||||
|
QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"`
|
||||||
|
|
||||||
Proxy *Proxy `json:"proxy,omitempty"`
|
Proxy *Proxy `json:"proxy,omitempty"`
|
||||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||||
|
|
||||||
@@ -324,11 +334,18 @@ type UsageLog struct {
|
|||||||
AccountID int64 `json:"account_id"`
|
AccountID int64 `json:"account_id"`
|
||||||
RequestID string `json:"request_id"`
|
RequestID string `json:"request_id"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
|
// UpstreamModel is the actual model sent to the upstream provider after mapping.
|
||||||
|
// Omitted when no mapping was applied (requested model was used as-is).
|
||||||
|
UpstreamModel *string `json:"upstream_model,omitempty"`
|
||||||
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
|
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
|
||||||
ServiceTier *string `json:"service_tier,omitempty"`
|
ServiceTier *string `json:"service_tier,omitempty"`
|
||||||
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API).
|
// ReasoningEffort is the request's reasoning effort level.
|
||||||
// nil means not provided / not applicable.
|
// OpenAI: "low"/"medium"/"high"/"xhigh"; Claude: "low"/"medium"/"high"/"max".
|
||||||
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
||||||
|
// InboundEndpoint is the client-facing API endpoint path, e.g. /v1/chat/completions.
|
||||||
|
InboundEndpoint *string `json:"inbound_endpoint,omitempty"`
|
||||||
|
// UpstreamEndpoint is the normalized upstream endpoint path, e.g. /v1/responses.
|
||||||
|
UpstreamEndpoint *string `json:"upstream_endpoint,omitempty"`
|
||||||
|
|
||||||
GroupID *int64 `json:"group_id"`
|
GroupID *int64 `json:"group_id"`
|
||||||
SubscriptionID *int64 `json:"subscription_id"`
|
SubscriptionID *int64 `json:"subscription_id"`
|
||||||
|
|||||||
174
backend/internal/handler/endpoint.go
Normal file
174
backend/internal/handler/endpoint.go
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
// Canonical inbound / upstream endpoint paths.
|
||||||
|
// All normalization and derivation reference this single set
|
||||||
|
// of constants — add new paths HERE when a new API surface
|
||||||
|
// is introduced.
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
const (
|
||||||
|
EndpointMessages = "/v1/messages"
|
||||||
|
EndpointChatCompletions = "/v1/chat/completions"
|
||||||
|
EndpointResponses = "/v1/responses"
|
||||||
|
EndpointGeminiModels = "/v1beta/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
// gin.Context keys used by the middleware and helpers below.
|
||||||
|
const (
|
||||||
|
ctxKeyInboundEndpoint = "_gateway_inbound_endpoint"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
// Normalization functions
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// NormalizeInboundEndpoint maps a raw request path (which may carry
|
||||||
|
// prefixes like /antigravity, /openai, /sora) to its canonical form.
|
||||||
|
//
|
||||||
|
// "/antigravity/v1/messages" → "/v1/messages"
|
||||||
|
// "/v1/chat/completions" → "/v1/chat/completions"
|
||||||
|
// "/openai/v1/responses/foo" → "/v1/responses"
|
||||||
|
// "/v1beta/models/gemini:gen" → "/v1beta/models"
|
||||||
|
func NormalizeInboundEndpoint(path string) string {
|
||||||
|
path = strings.TrimSpace(path)
|
||||||
|
switch {
|
||||||
|
case strings.Contains(path, EndpointChatCompletions):
|
||||||
|
return EndpointChatCompletions
|
||||||
|
case strings.Contains(path, EndpointMessages):
|
||||||
|
return EndpointMessages
|
||||||
|
case strings.Contains(path, EndpointResponses):
|
||||||
|
return EndpointResponses
|
||||||
|
case strings.Contains(path, EndpointGeminiModels):
|
||||||
|
return EndpointGeminiModels
|
||||||
|
default:
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeriveUpstreamEndpoint determines the upstream endpoint from the
|
||||||
|
// account platform and the normalized inbound endpoint.
|
||||||
|
//
|
||||||
|
// Platform-specific rules:
|
||||||
|
// - OpenAI always forwards to /v1/responses (with optional subpath
|
||||||
|
// such as /v1/responses/compact preserved from the raw URL).
|
||||||
|
// - Anthropic → /v1/messages
|
||||||
|
// - Gemini → /v1beta/models
|
||||||
|
// - Sora → /v1/chat/completions
|
||||||
|
// - Antigravity routes may target either Claude or Gemini, so the
|
||||||
|
// inbound endpoint is used to distinguish.
|
||||||
|
func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string {
|
||||||
|
inbound = strings.TrimSpace(inbound)
|
||||||
|
|
||||||
|
switch platform {
|
||||||
|
case service.PlatformOpenAI:
|
||||||
|
// OpenAI forwards everything to the Responses API.
|
||||||
|
// Preserve subresource suffix (e.g. /v1/responses/compact).
|
||||||
|
if suffix := responsesSubpathSuffix(rawRequestPath); suffix != "" {
|
||||||
|
return EndpointResponses + suffix
|
||||||
|
}
|
||||||
|
return EndpointResponses
|
||||||
|
|
||||||
|
case service.PlatformAnthropic:
|
||||||
|
return EndpointMessages
|
||||||
|
|
||||||
|
case service.PlatformGemini:
|
||||||
|
return EndpointGeminiModels
|
||||||
|
|
||||||
|
case service.PlatformSora:
|
||||||
|
return EndpointChatCompletions
|
||||||
|
|
||||||
|
case service.PlatformAntigravity:
|
||||||
|
// Antigravity accounts serve both Claude and Gemini.
|
||||||
|
if inbound == EndpointGeminiModels {
|
||||||
|
return EndpointGeminiModels
|
||||||
|
}
|
||||||
|
return EndpointMessages
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unknown platform — fall back to inbound.
|
||||||
|
return inbound
|
||||||
|
}
|
||||||
|
|
||||||
|
// responsesSubpathSuffix extracts the part after "/responses" in a raw
|
||||||
|
// request path, e.g. "/openai/v1/responses/compact" → "/compact".
|
||||||
|
// Returns "" when there is no meaningful suffix.
|
||||||
|
func responsesSubpathSuffix(rawPath string) string {
|
||||||
|
trimmed := strings.TrimRight(strings.TrimSpace(rawPath), "/")
|
||||||
|
idx := strings.LastIndex(trimmed, "/responses")
|
||||||
|
if idx < 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
suffix := trimmed[idx+len("/responses"):]
|
||||||
|
if suffix == "" || suffix == "/" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(suffix, "/") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return suffix
|
||||||
|
}
|
||||||
|
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
// Middleware
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// InboundEndpointMiddleware normalizes the request path and stores the
|
||||||
|
// canonical inbound endpoint in gin.Context so that every handler in
|
||||||
|
// the chain can read it via GetInboundEndpoint.
|
||||||
|
//
|
||||||
|
// Apply this middleware to all gateway route groups.
|
||||||
|
func InboundEndpointMiddleware() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
path := c.FullPath()
|
||||||
|
if path == "" && c.Request != nil && c.Request.URL != nil {
|
||||||
|
path = c.Request.URL.Path
|
||||||
|
}
|
||||||
|
c.Set(ctxKeyInboundEndpoint, NormalizeInboundEndpoint(path))
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
// Context helpers — used by handlers before building
|
||||||
|
// RecordUsageInput / RecordUsageLongContextInput.
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// GetInboundEndpoint returns the canonical inbound endpoint stored by
|
||||||
|
// InboundEndpointMiddleware. If the middleware did not run (e.g. in
|
||||||
|
// tests), it falls back to normalizing c.FullPath() on the fly.
|
||||||
|
func GetInboundEndpoint(c *gin.Context) string {
|
||||||
|
if v, ok := c.Get(ctxKeyInboundEndpoint); ok {
|
||||||
|
if s, ok := v.(string); ok && s != "" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Fallback: normalize on the fly.
|
||||||
|
path := ""
|
||||||
|
if c != nil {
|
||||||
|
path = c.FullPath()
|
||||||
|
if path == "" && c.Request != nil && c.Request.URL != nil {
|
||||||
|
path = c.Request.URL.Path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return NormalizeInboundEndpoint(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUpstreamEndpoint derives the upstream endpoint from the context
|
||||||
|
// and the account platform. Handlers call this after scheduling an
|
||||||
|
// account, passing account.Platform.
|
||||||
|
func GetUpstreamEndpoint(c *gin.Context, platform string) string {
|
||||||
|
inbound := GetInboundEndpoint(c)
|
||||||
|
rawPath := ""
|
||||||
|
if c != nil && c.Request != nil && c.Request.URL != nil {
|
||||||
|
rawPath = c.Request.URL.Path
|
||||||
|
}
|
||||||
|
return DeriveUpstreamEndpoint(inbound, rawPath, platform)
|
||||||
|
}
|
||||||
159
backend/internal/handler/endpoint_test.go
Normal file
159
backend/internal/handler/endpoint_test.go
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() { gin.SetMode(gin.TestMode) }
|
||||||
|
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
// NormalizeInboundEndpoint
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestNormalizeInboundEndpoint(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
path string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// Direct canonical paths.
|
||||||
|
{"/v1/messages", EndpointMessages},
|
||||||
|
{"/v1/chat/completions", EndpointChatCompletions},
|
||||||
|
{"/v1/responses", EndpointResponses},
|
||||||
|
{"/v1beta/models", EndpointGeminiModels},
|
||||||
|
|
||||||
|
// Prefixed paths (antigravity, openai, sora).
|
||||||
|
{"/antigravity/v1/messages", EndpointMessages},
|
||||||
|
{"/openai/v1/responses", EndpointResponses},
|
||||||
|
{"/openai/v1/responses/compact", EndpointResponses},
|
||||||
|
{"/sora/v1/chat/completions", EndpointChatCompletions},
|
||||||
|
{"/antigravity/v1beta/models/gemini:generateContent", EndpointGeminiModels},
|
||||||
|
|
||||||
|
// Gin route patterns with wildcards.
|
||||||
|
{"/v1beta/models/*modelAction", EndpointGeminiModels},
|
||||||
|
{"/v1/responses/*subpath", EndpointResponses},
|
||||||
|
|
||||||
|
// Unknown path is returned as-is.
|
||||||
|
{"/v1/embeddings", "/v1/embeddings"},
|
||||||
|
{"", ""},
|
||||||
|
{" /v1/messages ", EndpointMessages},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.path, func(t *testing.T) {
|
||||||
|
require.Equal(t, tt.want, NormalizeInboundEndpoint(tt.path))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
// DeriveUpstreamEndpoint
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestDeriveUpstreamEndpoint(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
inbound string
|
||||||
|
rawPath string
|
||||||
|
platform string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// Anthropic.
|
||||||
|
{"anthropic messages", EndpointMessages, "/v1/messages", service.PlatformAnthropic, EndpointMessages},
|
||||||
|
|
||||||
|
// Gemini.
|
||||||
|
{"gemini models", EndpointGeminiModels, "/v1beta/models/gemini:gen", service.PlatformGemini, EndpointGeminiModels},
|
||||||
|
|
||||||
|
// Sora.
|
||||||
|
{"sora completions", EndpointChatCompletions, "/sora/v1/chat/completions", service.PlatformSora, EndpointChatCompletions},
|
||||||
|
|
||||||
|
// OpenAI — always /v1/responses.
|
||||||
|
{"openai responses root", EndpointResponses, "/v1/responses", service.PlatformOpenAI, EndpointResponses},
|
||||||
|
{"openai responses compact", EndpointResponses, "/openai/v1/responses/compact", service.PlatformOpenAI, "/v1/responses/compact"},
|
||||||
|
{"openai responses nested", EndpointResponses, "/openai/v1/responses/compact/detail", service.PlatformOpenAI, "/v1/responses/compact/detail"},
|
||||||
|
{"openai from messages", EndpointMessages, "/v1/messages", service.PlatformOpenAI, EndpointResponses},
|
||||||
|
{"openai from completions", EndpointChatCompletions, "/v1/chat/completions", service.PlatformOpenAI, EndpointResponses},
|
||||||
|
|
||||||
|
// Antigravity — uses inbound to pick Claude vs Gemini upstream.
|
||||||
|
{"antigravity claude", EndpointMessages, "/antigravity/v1/messages", service.PlatformAntigravity, EndpointMessages},
|
||||||
|
{"antigravity gemini", EndpointGeminiModels, "/antigravity/v1beta/models", service.PlatformAntigravity, EndpointGeminiModels},
|
||||||
|
|
||||||
|
// Unknown platform — passthrough.
|
||||||
|
{"unknown platform", "/v1/embeddings", "/v1/embeddings", "unknown", "/v1/embeddings"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
require.Equal(t, tt.want, DeriveUpstreamEndpoint(tt.inbound, tt.rawPath, tt.platform))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
// responsesSubpathSuffix
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestResponsesSubpathSuffix(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
raw string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"/v1/responses", ""},
|
||||||
|
{"/v1/responses/", ""},
|
||||||
|
{"/v1/responses/compact", "/compact"},
|
||||||
|
{"/openai/v1/responses/compact/detail", "/compact/detail"},
|
||||||
|
{"/v1/messages", ""},
|
||||||
|
{"", ""},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.raw, func(t *testing.T) {
|
||||||
|
require.Equal(t, tt.want, responsesSubpathSuffix(tt.raw))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
// InboundEndpointMiddleware + context helpers
|
||||||
|
// ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestInboundEndpointMiddleware(t *testing.T) {
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(InboundEndpointMiddleware())
|
||||||
|
|
||||||
|
var captured string
|
||||||
|
router.POST("/v1/messages", func(c *gin.Context) {
|
||||||
|
captured = GetInboundEndpoint(c)
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, EndpointMessages, captured)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetInboundEndpoint_FallbackWithoutMiddleware(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/antigravity/v1/messages", nil)
|
||||||
|
|
||||||
|
// Middleware did not run — fallback to normalizing c.Request.URL.Path.
|
||||||
|
got := GetInboundEndpoint(c)
|
||||||
|
require.Equal(t, EndpointMessages, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUpstreamEndpoint_FullFlow(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses/compact", nil)
|
||||||
|
|
||||||
|
// Simulate middleware.
|
||||||
|
c.Set(ctxKeyInboundEndpoint, NormalizeInboundEndpoint(c.Request.URL.Path))
|
||||||
|
|
||||||
|
got := GetUpstreamEndpoint(c, service.PlatformOpenAI)
|
||||||
|
require.Equal(t, "/v1/responses/compact", got)
|
||||||
|
}
|
||||||
@@ -391,6 +391,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
if fs.SwitchCount > 0 {
|
if fs.SwitchCount > 0 {
|
||||||
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||||
}
|
}
|
||||||
|
// 记录 Forward 前已写入字节数,Forward 后若增加则说明 SSE 内容已发,禁止 failover
|
||||||
|
writerSizeBeforeForward := c.Writer.Size()
|
||||||
if account.Platform == service.PlatformAntigravity {
|
if account.Platform == service.PlatformAntigravity {
|
||||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
|
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
|
||||||
} else {
|
} else {
|
||||||
@@ -402,6 +404,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
|
// 流式内容已写入客户端,无法撤销,禁止 failover 以防止流拼接腐化
|
||||||
|
if c.Writer.Size() != writerSizeBeforeForward {
|
||||||
|
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true)
|
||||||
|
return
|
||||||
|
}
|
||||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||||
switch action {
|
switch action {
|
||||||
case FailoverContinue:
|
case FailoverContinue:
|
||||||
@@ -434,20 +441,29 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
clientIP := ip.GetClientIP(c)
|
clientIP := ip.GetClientIP(c)
|
||||||
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||||
|
inboundEndpoint := GetInboundEndpoint(c)
|
||||||
|
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||||
|
|
||||||
|
if result.ReasoningEffort == nil {
|
||||||
|
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
|
||||||
|
}
|
||||||
|
|
||||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
ParsedRequest: parsedReq,
|
APIKey: apiKey,
|
||||||
APIKey: apiKey,
|
User: apiKey.User,
|
||||||
User: apiKey.User,
|
Account: account,
|
||||||
Account: account,
|
Subscription: subscription,
|
||||||
Subscription: subscription,
|
InboundEndpoint: inboundEndpoint,
|
||||||
UserAgent: userAgent,
|
UpstreamEndpoint: upstreamEndpoint,
|
||||||
IPAddress: clientIP,
|
UserAgent: userAgent,
|
||||||
ForceCacheBilling: fs.ForceCacheBilling,
|
IPAddress: clientIP,
|
||||||
APIKeyService: h.apiKeyService,
|
RequestPayloadHash: requestPayloadHash,
|
||||||
|
ForceCacheBilling: fs.ForceCacheBilling,
|
||||||
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
zap.String("component", "handler.gateway.messages"),
|
zap.String("component", "handler.gateway.messages"),
|
||||||
@@ -631,12 +647,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
// ===== 用户消息串行队列 END =====
|
// ===== 用户消息串行队列 END =====
|
||||||
|
|
||||||
// 转发请求 - 根据账号平台分流
|
// 转发请求 - 根据账号平台分流
|
||||||
c.Set("parsed_request", parsedReq)
|
|
||||||
var result *service.ForwardResult
|
var result *service.ForwardResult
|
||||||
requestCtx := c.Request.Context()
|
requestCtx := c.Request.Context()
|
||||||
if fs.SwitchCount > 0 {
|
if fs.SwitchCount > 0 {
|
||||||
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||||
}
|
}
|
||||||
|
// 记录 Forward 前已写入字节数,Forward 后若增加则说明 SSE 内容已发,禁止 failover
|
||||||
|
writerSizeBeforeForward := c.Writer.Size()
|
||||||
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||||
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
|
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
|
||||||
} else {
|
} else {
|
||||||
@@ -706,6 +723,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
|
// 流式内容已写入客户端,无法撤销,禁止 failover 以防止流拼接腐化
|
||||||
|
if c.Writer.Size() != writerSizeBeforeForward {
|
||||||
|
h.handleFailoverExhausted(c, failoverErr, account.Platform, true)
|
||||||
|
return
|
||||||
|
}
|
||||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||||
switch action {
|
switch action {
|
||||||
case FailoverContinue:
|
case FailoverContinue:
|
||||||
@@ -738,20 +760,29 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
clientIP := ip.GetClientIP(c)
|
clientIP := ip.GetClientIP(c)
|
||||||
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||||
|
inboundEndpoint := GetInboundEndpoint(c)
|
||||||
|
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||||
|
|
||||||
|
if result.ReasoningEffort == nil {
|
||||||
|
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
|
||||||
|
}
|
||||||
|
|
||||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
ParsedRequest: parsedReq,
|
APIKey: currentAPIKey,
|
||||||
APIKey: currentAPIKey,
|
User: currentAPIKey.User,
|
||||||
User: currentAPIKey.User,
|
Account: account,
|
||||||
Account: account,
|
Subscription: currentSubscription,
|
||||||
Subscription: currentSubscription,
|
InboundEndpoint: inboundEndpoint,
|
||||||
UserAgent: userAgent,
|
UpstreamEndpoint: upstreamEndpoint,
|
||||||
IPAddress: clientIP,
|
UserAgent: userAgent,
|
||||||
ForceCacheBilling: fs.ForceCacheBilling,
|
IPAddress: clientIP,
|
||||||
APIKeyService: h.apiKeyService,
|
RequestPayloadHash: requestPayloadHash,
|
||||||
|
ForceCacheBilling: fs.ForceCacheBilling,
|
||||||
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
zap.String("component", "handler.gateway.messages"),
|
zap.String("component", "handler.gateway.messages"),
|
||||||
@@ -912,7 +943,7 @@ func (h *GatewayHandler) parseUsageDateRange(c *gin.Context) (time.Time, time.Ti
|
|||||||
}
|
}
|
||||||
if s := c.Query("end_date"); s != "" {
|
if s := c.Query("end_date"); s != "" {
|
||||||
if t, err := timezone.ParseInLocation("2006-01-02", s); err == nil {
|
if t, err := timezone.ParseInLocation("2006-01-02", s); err == nil {
|
||||||
endTime = t.Add(24*time.Hour - time.Second) // end of day
|
endTime = t.AddDate(0, 0, 1) // half-open range upper bound
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return startTime, endTime
|
return startTime, endTime
|
||||||
@@ -1188,6 +1219,10 @@ func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *se
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 记录原始上游状态码,以便 ops 错误日志捕获真实的上游错误
|
||||||
|
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||||
|
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
|
||||||
|
|
||||||
// 使用默认的错误映射
|
// 使用默认的错误映射
|
||||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||||
@@ -1196,6 +1231,7 @@ func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *se
|
|||||||
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
|
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
|
||||||
func (h *GatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
|
func (h *GatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
|
||||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||||
|
service.SetOpsUpstreamError(c, statusCode, errMsg, "")
|
||||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
122
backend/internal/handler/gateway_handler_stream_failover_test.go
Normal file
122
backend/internal/handler/gateway_handler_stream_failover_test.go
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// partialMessageStartSSE 模拟 handleStreamingResponse 已写入的首批 SSE 事件。
|
||||||
|
const partialMessageStartSSE = "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_01\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-5\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":10,\"output_tokens\":1}}}\n\n" +
|
||||||
|
"event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"
|
||||||
|
|
||||||
|
// TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten 验证:
|
||||||
|
// 当 Forward 在返回 UpstreamFailoverError 前已向客户端写入 SSE 内容时,
|
||||||
|
// 故障转移保护逻辑必须终止循环并发送 SSE 错误事件,而不是进行下一次 Forward。
|
||||||
|
// 具体验证:
|
||||||
|
// 1. c.Writer.Size() 检测条件正确触发(字节数已增加)
|
||||||
|
// 2. handleFailoverExhausted 以 streamStarted=true 调用后,响应体以 SSE 错误事件结尾
|
||||||
|
// 3. 响应体中只出现一个 message_start,不存在第二个(防止流拼接腐化)
|
||||||
|
func TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
|
||||||
|
// 步骤 1:记录 Forward 前的 writer size(模拟 writerSizeBeforeForward := c.Writer.Size())
|
||||||
|
sizeBeforeForward := c.Writer.Size()
|
||||||
|
require.Equal(t, -1, sizeBeforeForward, "gin writer 初始 Size 应为 -1(未写入任何字节)")
|
||||||
|
|
||||||
|
// 步骤 2:模拟 Forward 已向客户端写入部分 SSE 内容(message_start + content_block_start)
|
||||||
|
_, err := c.Writer.Write([]byte(partialMessageStartSSE))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 步骤 3:验证守卫条件成立(c.Writer.Size() != sizeBeforeForward)
|
||||||
|
require.NotEqual(t, sizeBeforeForward, c.Writer.Size(),
|
||||||
|
"写入 SSE 内容后 writer size 必须增加,守卫条件应为 true")
|
||||||
|
|
||||||
|
// 步骤 4:模拟 UpstreamFailoverError(上游在流中途返回 403)
|
||||||
|
failoverErr := &service.UpstreamFailoverError{
|
||||||
|
StatusCode: http.StatusForbidden,
|
||||||
|
ResponseBody: []byte(`{"error":{"type":"permission_error","message":"forbidden"}}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 步骤 5:守卫触发 → 调用 handleFailoverExhausted,streamStarted=true
|
||||||
|
h := &GatewayHandler{}
|
||||||
|
h.handleFailoverExhausted(c, failoverErr, service.PlatformAnthropic, true)
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
|
||||||
|
// 断言 A:响应体中包含最初写入的 message_start SSE 事件行
|
||||||
|
require.Contains(t, body, "event: message_start", "响应体应包含已写入的 message_start SSE 事件")
|
||||||
|
|
||||||
|
// 断言 B:响应体以 SSE 错误事件结尾(data: {"type":"error",...}\n\n)
|
||||||
|
require.True(t, strings.HasSuffix(strings.TrimRight(body, "\n"), "}"),
|
||||||
|
"响应体应以 JSON 对象结尾(SSE error event 的 data 字段)")
|
||||||
|
require.Contains(t, body, `"type":"error"`, "响应体末尾必须包含 SSE 错误事件")
|
||||||
|
|
||||||
|
// 断言 C:SSE event 行 "event: message_start" 只出现一次(防止双 message_start 拼接腐化)
|
||||||
|
firstIdx := strings.Index(body, "event: message_start")
|
||||||
|
lastIdx := strings.LastIndex(body, "event: message_start")
|
||||||
|
assert.Equal(t, firstIdx, lastIdx,
|
||||||
|
"响应体中 'event: message_start' 必须只出现一次,不得因 failover 拼接导致两次")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten 与上述测试相同,
|
||||||
|
// 验证 Gemini 路径使用 service.PlatformGemini(而非 account.Platform)时行为一致。
|
||||||
|
func TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.0-flash:streamGenerateContent", nil)
|
||||||
|
|
||||||
|
sizeBeforeForward := c.Writer.Size()
|
||||||
|
|
||||||
|
_, err := c.Writer.Write([]byte(partialMessageStartSSE))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NotEqual(t, sizeBeforeForward, c.Writer.Size())
|
||||||
|
|
||||||
|
failoverErr := &service.UpstreamFailoverError{
|
||||||
|
StatusCode: http.StatusForbidden,
|
||||||
|
}
|
||||||
|
|
||||||
|
h := &GatewayHandler{}
|
||||||
|
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true)
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
|
||||||
|
require.Contains(t, body, "event: message_start")
|
||||||
|
require.Contains(t, body, `"type":"error"`)
|
||||||
|
|
||||||
|
firstIdx := strings.Index(body, "event: message_start")
|
||||||
|
lastIdx := strings.LastIndex(body, "event: message_start")
|
||||||
|
assert.Equal(t, firstIdx, lastIdx, "Gemini 路径不得出现双 message_start")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered 验证反向场景:
|
||||||
|
// 当 Forward 返回 UpstreamFailoverError 时若未向客户端写入任何 SSE 内容,
|
||||||
|
// 守卫条件(c.Writer.Size() != sizeBeforeForward)为 false,不应中止 failover。
|
||||||
|
func TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
|
||||||
|
// 模拟 writerSizeBeforeForward:初始为 -1
|
||||||
|
sizeBeforeForward := c.Writer.Size()
|
||||||
|
|
||||||
|
// Forward 未写入任何字节直接返回错误(例如 401 发生在连接建立前)
|
||||||
|
// c.Writer.Size() 仍为 -1
|
||||||
|
|
||||||
|
// 守卫条件:sizeBeforeForward == c.Writer.Size() → 不触发
|
||||||
|
guardTriggered := c.Writer.Size() != sizeBeforeForward
|
||||||
|
require.False(t, guardTriggered,
|
||||||
|
"未写入任何字节时,守卫条件必须为 false,应允许正常 failover 继续")
|
||||||
|
}
|
||||||
@@ -76,7 +76,7 @@ func (f *fakeGroupRepo) ListActiveByPlatform(context.Context, string) ([]service
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil }
|
func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil }
|
||||||
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, error) { return 0, nil }
|
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, int64, error) { return 0, 0, nil }
|
||||||
func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
|
func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
@@ -139,6 +139,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
|
|||||||
nil, // accountRepo (not used: scheduler snapshot hit)
|
nil, // accountRepo (not used: scheduler snapshot hit)
|
||||||
&fakeGroupRepo{group: group},
|
&fakeGroupRepo{group: group},
|
||||||
nil, // usageLogRepo
|
nil, // usageLogRepo
|
||||||
|
nil, // usageBillingRepo
|
||||||
nil, // userRepo
|
nil, // userRepo
|
||||||
nil, // userSubRepo
|
nil, // userSubRepo
|
||||||
nil, // userGroupRateRepo
|
nil, // userGroupRateRepo
|
||||||
|
|||||||
@@ -136,7 +136,7 @@ func validClaudeCodeBodyJSON() []byte {
|
|||||||
return []byte(`{
|
return []byte(`{
|
||||||
"model":"claude-3-5-sonnet-20241022",
|
"model":"claude-3-5-sonnet-20241022",
|
||||||
"system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}],
|
"system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}],
|
||||||
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"}
|
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}
|
||||||
}`)
|
}`)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -190,7 +190,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing
|
|||||||
System: []any{
|
System: []any{
|
||||||
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
||||||
},
|
},
|
||||||
MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123",
|
MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
|
||||||
}
|
}
|
||||||
|
|
||||||
// body 非法 JSON,如果函数复用 parsedReq 成功则仍应判定为 Claude Code。
|
// body 非法 JSON,如果函数复用 parsedReq 成功则仍应判定为 Claude Code。
|
||||||
@@ -209,7 +209,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing
|
|||||||
"system": []any{
|
"system": []any{
|
||||||
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
||||||
},
|
},
|
||||||
"metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"},
|
"metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"},
|
||||||
})
|
})
|
||||||
|
|
||||||
SetClaudeCodeClientContext(c, []byte(`{invalid`), nil)
|
SetClaudeCodeClientContext(c, []byte(`{invalid`), nil)
|
||||||
|
|||||||
@@ -503,6 +503,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||||
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||||
|
inboundEndpoint := GetInboundEndpoint(c)
|
||||||
|
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
@@ -510,8 +513,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
|
InboundEndpoint: inboundEndpoint,
|
||||||
|
UpstreamEndpoint: upstreamEndpoint,
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
|
RequestPayloadHash: requestPayloadHash,
|
||||||
LongContextThreshold: 200000, // Gemini 200K 阈值
|
LongContextThreshold: 200000, // Gemini 200K 阈值
|
||||||
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
||||||
ForceCacheBilling: fs.ForceCacheBilling,
|
ForceCacheBilling: fs.ForceCacheBilling,
|
||||||
@@ -587,6 +593,10 @@ func (h *GatewayHandler) handleGeminiFailoverExhausted(c *gin.Context, failoverE
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 记录原始上游状态码,以便 ops 错误日志捕获真实的上游错误
|
||||||
|
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||||
|
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
|
||||||
|
|
||||||
// 使用默认的错误映射
|
// 使用默认的错误映射
|
||||||
status, message := mapGeminiUpstreamError(statusCode)
|
status, message := mapGeminiUpstreamError(statusCode)
|
||||||
googleError(c, status, message)
|
googleError(c, status, message)
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ type AdminHandlers struct {
|
|||||||
Account *admin.AccountHandler
|
Account *admin.AccountHandler
|
||||||
Announcement *admin.AnnouncementHandler
|
Announcement *admin.AnnouncementHandler
|
||||||
DataManagement *admin.DataManagementHandler
|
DataManagement *admin.DataManagementHandler
|
||||||
|
Backup *admin.BackupHandler
|
||||||
OAuth *admin.OAuthHandler
|
OAuth *admin.OAuthHandler
|
||||||
OpenAIOAuth *admin.OpenAIOAuthHandler
|
OpenAIOAuth *admin.OpenAIOAuthHandler
|
||||||
GeminiOAuth *admin.GeminiOAuthHandler
|
GeminiOAuth *admin.GeminiOAuthHandler
|
||||||
|
|||||||
@@ -181,13 +181,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||||
forwardStart := time.Now()
|
forwardStart := time.Now()
|
||||||
|
|
||||||
defaultMappedModel := ""
|
defaultMappedModel := c.GetString("openai_chat_completions_fallback_model")
|
||||||
if apiKey.Group != nil {
|
|
||||||
defaultMappedModel = apiKey.Group.DefaultMappedModel
|
|
||||||
}
|
|
||||||
if fallbackModel := c.GetString("openai_chat_completions_fallback_model"); fallbackModel != "" {
|
|
||||||
defaultMappedModel = fallbackModel
|
|
||||||
}
|
|
||||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||||
|
|
||||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||||
@@ -262,14 +256,16 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
|
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
APIKey: apiKey,
|
APIKey: apiKey,
|
||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: userAgent,
|
InboundEndpoint: GetInboundEndpoint(c),
|
||||||
IPAddress: clientIP,
|
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||||
APIKeyService: h.apiKeyService,
|
UserAgent: userAgent,
|
||||||
|
IPAddress: clientIP,
|
||||||
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
zap.String("component", "handler.openai_gateway.chat_completions"),
|
zap.String("component", "handler.openai_gateway.chat_completions"),
|
||||||
|
|||||||
@@ -0,0 +1,56 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestOpenAIUpstreamEndpoint_ViaGetUpstreamEndpoint verifies that the
|
||||||
|
// unified GetUpstreamEndpoint helper produces the same results as the
|
||||||
|
// former normalizedOpenAIUpstreamEndpoint for OpenAI platform requests.
|
||||||
|
func TestOpenAIUpstreamEndpoint_ViaGetUpstreamEndpoint(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
path string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "responses root maps to responses upstream",
|
||||||
|
path: "/v1/responses",
|
||||||
|
want: EndpointResponses,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "responses compact keeps compact suffix",
|
||||||
|
path: "/openai/v1/responses/compact",
|
||||||
|
want: "/v1/responses/compact",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "responses nested suffix preserved",
|
||||||
|
path: "/openai/v1/responses/compact/detail",
|
||||||
|
want: "/v1/responses/compact/detail",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non responses path uses platform fallback",
|
||||||
|
path: "/v1/messages",
|
||||||
|
want: EndpointResponses,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, tt.path, nil)
|
||||||
|
|
||||||
|
got := GetUpstreamEndpoint(c, service.PlatformOpenAI)
|
||||||
|
require.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -352,18 +352,22 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
clientIP := ip.GetClientIP(c)
|
clientIP := ip.GetClientIP(c)
|
||||||
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||||
|
|
||||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
APIKey: apiKey,
|
APIKey: apiKey,
|
||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: userAgent,
|
InboundEndpoint: GetInboundEndpoint(c),
|
||||||
IPAddress: clientIP,
|
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||||
APIKeyService: h.apiKeyService,
|
UserAgent: userAgent,
|
||||||
|
IPAddress: clientIP,
|
||||||
|
RequestPayloadHash: requestPayloadHash,
|
||||||
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
zap.String("component", "handler.openai_gateway.responses"),
|
zap.String("component", "handler.openai_gateway.responses"),
|
||||||
@@ -653,14 +657,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
|||||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||||
forwardStart := time.Now()
|
forwardStart := time.Now()
|
||||||
|
|
||||||
defaultMappedModel := ""
|
// 仅在调度时实际触发了降级(原模型无可用账号、改用默认模型重试成功)时,
|
||||||
if apiKey.Group != nil {
|
// 才将降级模型传给 Forward 层做模型替换;否则保持用户请求的原始模型。
|
||||||
defaultMappedModel = apiKey.Group.DefaultMappedModel
|
defaultMappedModel := c.GetString("openai_messages_fallback_model")
|
||||||
}
|
|
||||||
// 如果使用了降级模型调度,强制使用降级模型
|
|
||||||
if fallbackModel := c.GetString("openai_messages_fallback_model"); fallbackModel != "" {
|
|
||||||
defaultMappedModel = fallbackModel
|
|
||||||
}
|
|
||||||
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||||
|
|
||||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||||
@@ -732,17 +731,21 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
|||||||
|
|
||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
clientIP := ip.GetClientIP(c)
|
clientIP := ip.GetClientIP(c)
|
||||||
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||||
|
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
APIKey: apiKey,
|
APIKey: apiKey,
|
||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: userAgent,
|
InboundEndpoint: GetInboundEndpoint(c),
|
||||||
IPAddress: clientIP,
|
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||||
APIKeyService: h.apiKeyService,
|
UserAgent: userAgent,
|
||||||
|
IPAddress: clientIP,
|
||||||
|
RequestPayloadHash: requestPayloadHash,
|
||||||
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
zap.String("component", "handler.openai_gateway.messages"),
|
zap.String("component", "handler.openai_gateway.messages"),
|
||||||
@@ -1231,14 +1234,17 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
|||||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||||
h.submitUsageRecordTask(func(taskCtx context.Context) {
|
h.submitUsageRecordTask(func(taskCtx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
APIKey: apiKey,
|
APIKey: apiKey,
|
||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: userAgent,
|
InboundEndpoint: GetInboundEndpoint(c),
|
||||||
IPAddress: clientIP,
|
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||||
APIKeyService: h.apiKeyService,
|
UserAgent: userAgent,
|
||||||
|
IPAddress: clientIP,
|
||||||
|
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
|
||||||
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
reqLog.Error("openai.websocket_record_usage_failed",
|
reqLog.Error("openai.websocket_record_usage_failed",
|
||||||
zap.Int64("account_id", account.ID),
|
zap.Int64("account_id", account.ID),
|
||||||
@@ -1429,6 +1435,10 @@ func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverE
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 记录原始上游状态码,以便 ops 错误日志捕获真实的上游错误
|
||||||
|
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||||
|
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
|
||||||
|
|
||||||
// 使用默认的错误映射
|
// 使用默认的错误映射
|
||||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||||
@@ -1437,6 +1447,7 @@ func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverE
|
|||||||
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
|
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
|
||||||
func (h *OpenAIGatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
|
func (h *OpenAIGatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
|
||||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||||
|
service.SetOpsUpstreamError(c, statusCode, errMsg, "")
|
||||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,22 @@ const (
|
|||||||
opsStreamKey = "ops_stream"
|
opsStreamKey = "ops_stream"
|
||||||
opsRequestBodyKey = "ops_request_body"
|
opsRequestBodyKey = "ops_request_body"
|
||||||
opsAccountIDKey = "ops_account_id"
|
opsAccountIDKey = "ops_account_id"
|
||||||
|
|
||||||
|
// 错误过滤匹配常量 — shouldSkipOpsErrorLog 和错误分类共用
|
||||||
|
opsErrContextCanceled = "context canceled"
|
||||||
|
opsErrNoAvailableAccounts = "no available accounts"
|
||||||
|
opsErrInvalidAPIKey = "invalid_api_key"
|
||||||
|
opsErrAPIKeyRequired = "api_key_required"
|
||||||
|
opsErrInsufficientBalance = "insufficient balance"
|
||||||
|
opsErrInsufficientAccountBalance = "insufficient account balance"
|
||||||
|
opsErrInsufficientQuota = "insufficient_quota"
|
||||||
|
|
||||||
|
// 上游错误码常量 — 错误分类 (normalizeOpsErrorType / classifyOpsPhase / classifyOpsIsBusinessLimited)
|
||||||
|
opsCodeInsufficientBalance = "INSUFFICIENT_BALANCE"
|
||||||
|
opsCodeUsageLimitExceeded = "USAGE_LIMIT_EXCEEDED"
|
||||||
|
opsCodeSubscriptionNotFound = "SUBSCRIPTION_NOT_FOUND"
|
||||||
|
opsCodeSubscriptionInvalid = "SUBSCRIPTION_INVALID"
|
||||||
|
opsCodeUserInactive = "USER_INACTIVE"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -1024,9 +1040,9 @@ func normalizeOpsErrorType(errType string, code string) string {
|
|||||||
return errType
|
return errType
|
||||||
}
|
}
|
||||||
switch strings.TrimSpace(code) {
|
switch strings.TrimSpace(code) {
|
||||||
case "INSUFFICIENT_BALANCE":
|
case opsCodeInsufficientBalance:
|
||||||
return "billing_error"
|
return "billing_error"
|
||||||
case "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
|
case opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
|
||||||
return "subscription_error"
|
return "subscription_error"
|
||||||
default:
|
default:
|
||||||
return "api_error"
|
return "api_error"
|
||||||
@@ -1038,7 +1054,7 @@ func classifyOpsPhase(errType, message, code string) string {
|
|||||||
// Standardized phases: request|auth|routing|upstream|network|internal
|
// Standardized phases: request|auth|routing|upstream|network|internal
|
||||||
// Map billing/concurrency/response => request; scheduling => routing.
|
// Map billing/concurrency/response => request; scheduling => routing.
|
||||||
switch strings.TrimSpace(code) {
|
switch strings.TrimSpace(code) {
|
||||||
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
|
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
|
||||||
return "request"
|
return "request"
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1057,7 +1073,7 @@ func classifyOpsPhase(errType, message, code string) string {
|
|||||||
case "upstream_error", "overloaded_error":
|
case "upstream_error", "overloaded_error":
|
||||||
return "upstream"
|
return "upstream"
|
||||||
case "api_error":
|
case "api_error":
|
||||||
if strings.Contains(msg, "no available accounts") {
|
if strings.Contains(msg, opsErrNoAvailableAccounts) {
|
||||||
return "routing"
|
return "routing"
|
||||||
}
|
}
|
||||||
return "internal"
|
return "internal"
|
||||||
@@ -1103,7 +1119,7 @@ func classifyOpsIsRetryable(errType string, statusCode int) bool {
|
|||||||
|
|
||||||
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
|
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
|
||||||
switch strings.TrimSpace(code) {
|
switch strings.TrimSpace(code) {
|
||||||
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID", "USER_INACTIVE":
|
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid, opsCodeUserInactive:
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if phase == "billing" || phase == "concurrency" {
|
if phase == "billing" || phase == "concurrency" {
|
||||||
@@ -1197,21 +1213,30 @@ func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message
|
|||||||
|
|
||||||
// Check if context canceled errors should be ignored (client disconnects)
|
// Check if context canceled errors should be ignored (client disconnects)
|
||||||
if settings.IgnoreContextCanceled {
|
if settings.IgnoreContextCanceled {
|
||||||
if strings.Contains(msgLower, "context canceled") || strings.Contains(bodyLower, "context canceled") {
|
if strings.Contains(msgLower, opsErrContextCanceled) || strings.Contains(bodyLower, opsErrContextCanceled) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if "no available accounts" errors should be ignored
|
// Check if "no available accounts" errors should be ignored
|
||||||
if settings.IgnoreNoAvailableAccounts {
|
if settings.IgnoreNoAvailableAccounts {
|
||||||
if strings.Contains(msgLower, "no available accounts") || strings.Contains(bodyLower, "no available accounts") {
|
if strings.Contains(msgLower, opsErrNoAvailableAccounts) || strings.Contains(bodyLower, opsErrNoAvailableAccounts) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if invalid/missing API key errors should be ignored (user misconfiguration)
|
// Check if invalid/missing API key errors should be ignored (user misconfiguration)
|
||||||
if settings.IgnoreInvalidApiKeyErrors {
|
if settings.IgnoreInvalidApiKeyErrors {
|
||||||
if strings.Contains(bodyLower, "invalid_api_key") || strings.Contains(bodyLower, "api_key_required") {
|
if strings.Contains(bodyLower, opsErrInvalidAPIKey) || strings.Contains(bodyLower, opsErrAPIKeyRequired) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if insufficient balance errors should be ignored
|
||||||
|
if settings.IgnoreInsufficientBalanceErrors {
|
||||||
|
if strings.Contains(bodyLower, opsErrInsufficientBalance) || strings.Contains(bodyLower, opsErrInsufficientAccountBalance) ||
|
||||||
|
strings.Contains(bodyLower, opsErrInsufficientQuota) ||
|
||||||
|
strings.Contains(msgLower, opsErrInsufficientBalance) || strings.Contains(msgLower, opsErrInsufficientAccountBalance) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
|||||||
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
|
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
|
||||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||||
SoraClientEnabled: settings.SoraClientEnabled,
|
SoraClientEnabled: settings.SoraClientEnabled,
|
||||||
|
BackendModeEnabled: settings.BackendModeEnabled,
|
||||||
Version: h.version,
|
Version: h.version,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2206,7 +2206,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac
|
|||||||
// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。
|
// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。
|
||||||
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
|
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
|
||||||
return service.NewGatewayService(
|
return service.NewGatewayService(
|
||||||
accountRepo, nil, nil, nil, nil, nil, nil, nil,
|
accountRepo, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -399,17 +399,23 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
|
|
||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
clientIP := ip.GetClientIP(c)
|
clientIP := ip.GetClientIP(c)
|
||||||
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||||
|
inboundEndpoint := GetInboundEndpoint(c)
|
||||||
|
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||||
|
|
||||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
APIKey: apiKey,
|
APIKey: apiKey,
|
||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: userAgent,
|
InboundEndpoint: inboundEndpoint,
|
||||||
IPAddress: clientIP,
|
UpstreamEndpoint: upstreamEndpoint,
|
||||||
|
UserAgent: userAgent,
|
||||||
|
IPAddress: clientIP,
|
||||||
|
RequestPayloadHash: requestPayloadHash,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
zap.String("component", "handler.sora_gateway.chat_completions"),
|
zap.String("component", "handler.sora_gateway.chat_completions"),
|
||||||
@@ -478,6 +484,9 @@ func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, s
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) {
|
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) {
|
||||||
|
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||||
|
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
|
||||||
|
|
||||||
status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody)
|
status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody)
|
||||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -273,8 +273,8 @@ func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform strin
|
|||||||
func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
|
func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
|
||||||
return 0, nil
|
return 0, 0, nil
|
||||||
}
|
}
|
||||||
func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
@@ -334,15 +334,32 @@ func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTi
|
|||||||
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stubUsageLogRepo) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||||
|
return []usagestats.EndpointStat{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||||
|
return []usagestats.EndpointStat{}, nil
|
||||||
|
}
|
||||||
func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
|
func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@@ -431,6 +448,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
|
|||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
testutil.StubGatewayCache{},
|
testutil.StubGatewayCache{},
|
||||||
cfg,
|
cfg,
|
||||||
nil,
|
nil,
|
||||||
|
|||||||
@@ -114,8 +114,8 @@ func (h *UsageHandler) List(c *gin.Context) {
|
|||||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Set end time to end of day
|
// Use half-open range [start, end), move to next calendar day start (DST-safe).
|
||||||
t = t.Add(24*time.Hour - time.Nanosecond)
|
t = t.AddDate(0, 0, 1)
|
||||||
endTime = &t
|
endTime = &t
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -227,8 +227,8 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
|||||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 设置结束时间为当天结束
|
// 与 SQL 条件 created_at < end 对齐,使用次日 00:00 作为上边界(DST-safe)。
|
||||||
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
|
endTime = endTime.AddDate(0, 0, 1)
|
||||||
} else {
|
} else {
|
||||||
// 使用 period 参数
|
// 使用 period 参数
|
||||||
period := c.DefaultQuery("period", "today")
|
period := c.DefaultQuery("period", "today")
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ func ProvideAdminHandlers(
|
|||||||
accountHandler *admin.AccountHandler,
|
accountHandler *admin.AccountHandler,
|
||||||
announcementHandler *admin.AnnouncementHandler,
|
announcementHandler *admin.AnnouncementHandler,
|
||||||
dataManagementHandler *admin.DataManagementHandler,
|
dataManagementHandler *admin.DataManagementHandler,
|
||||||
|
backupHandler *admin.BackupHandler,
|
||||||
oauthHandler *admin.OAuthHandler,
|
oauthHandler *admin.OAuthHandler,
|
||||||
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
||||||
geminiOAuthHandler *admin.GeminiOAuthHandler,
|
geminiOAuthHandler *admin.GeminiOAuthHandler,
|
||||||
@@ -39,6 +40,7 @@ func ProvideAdminHandlers(
|
|||||||
Account: accountHandler,
|
Account: accountHandler,
|
||||||
Announcement: announcementHandler,
|
Announcement: announcementHandler,
|
||||||
DataManagement: dataManagementHandler,
|
DataManagement: dataManagementHandler,
|
||||||
|
Backup: backupHandler,
|
||||||
OAuth: oauthHandler,
|
OAuth: oauthHandler,
|
||||||
OpenAIOAuth: openaiOAuthHandler,
|
OpenAIOAuth: openaiOAuthHandler,
|
||||||
GeminiOAuth: geminiOAuthHandler,
|
GeminiOAuth: geminiOAuthHandler,
|
||||||
@@ -128,6 +130,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
admin.NewAccountHandler,
|
admin.NewAccountHandler,
|
||||||
admin.NewAnnouncementHandler,
|
admin.NewAnnouncementHandler,
|
||||||
admin.NewDataManagementHandler,
|
admin.NewDataManagementHandler,
|
||||||
|
admin.NewBackupHandler,
|
||||||
admin.NewOAuthHandler,
|
admin.NewOAuthHandler,
|
||||||
admin.NewOpenAIOAuthHandler,
|
admin.NewOpenAIOAuthHandler,
|
||||||
admin.NewGeminiOAuthHandler,
|
admin.NewGeminiOAuthHandler,
|
||||||
|
|||||||
@@ -19,6 +19,16 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ForbiddenError 表示上游返回 403 Forbidden
|
||||||
|
type ForbiddenError struct {
|
||||||
|
StatusCode int
|
||||||
|
Body string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ForbiddenError) Error() string {
|
||||||
|
return fmt.Sprintf("fetchAvailableModels 失败 (HTTP %d): %s", e.StatusCode, e.Body)
|
||||||
|
}
|
||||||
|
|
||||||
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
|
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
|
||||||
func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
|
func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
|
||||||
// 构建 URL,流式请求添加 ?alt=sse 参数
|
// 构建 URL,流式请求添加 ?alt=sse 参数
|
||||||
@@ -114,10 +124,68 @@ type IneligibleTier struct {
|
|||||||
type LoadCodeAssistResponse struct {
|
type LoadCodeAssistResponse struct {
|
||||||
CloudAICompanionProject string `json:"cloudaicompanionProject"`
|
CloudAICompanionProject string `json:"cloudaicompanionProject"`
|
||||||
CurrentTier *TierInfo `json:"currentTier,omitempty"`
|
CurrentTier *TierInfo `json:"currentTier,omitempty"`
|
||||||
PaidTier *TierInfo `json:"paidTier,omitempty"`
|
PaidTier *PaidTierInfo `json:"paidTier,omitempty"`
|
||||||
IneligibleTiers []*IneligibleTier `json:"ineligibleTiers,omitempty"`
|
IneligibleTiers []*IneligibleTier `json:"ineligibleTiers,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PaidTierInfo 付费等级信息,包含 AI Credits 余额。
|
||||||
|
type PaidTierInfo struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
AvailableCredits []AvailableCredit `json:"availableCredits,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON 兼容 paidTier 既可能是字符串也可能是对象的情况。
|
||||||
|
func (p *PaidTierInfo) UnmarshalJSON(data []byte) error {
|
||||||
|
data = bytes.TrimSpace(data)
|
||||||
|
if len(data) == 0 || string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if data[0] == '"' {
|
||||||
|
var id string
|
||||||
|
if err := json.Unmarshal(data, &id); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
p.ID = id
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
type alias PaidTierInfo
|
||||||
|
var raw alias
|
||||||
|
if err := json.Unmarshal(data, &raw); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*p = PaidTierInfo(raw)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AvailableCredit 表示一条 AI Credits 余额记录。
|
||||||
|
type AvailableCredit struct {
|
||||||
|
CreditType string `json:"creditType,omitempty"`
|
||||||
|
CreditAmount string `json:"creditAmount,omitempty"`
|
||||||
|
MinimumCreditAmountForUsage string `json:"minimumCreditAmountForUsage,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAmount 将 creditAmount 解析为浮点数。
|
||||||
|
func (c *AvailableCredit) GetAmount() float64 {
|
||||||
|
if c.CreditAmount == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
var value float64
|
||||||
|
_, _ = fmt.Sscanf(c.CreditAmount, "%f", &value)
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMinimumAmount 将 minimumCreditAmountForUsage 解析为浮点数。
|
||||||
|
func (c *AvailableCredit) GetMinimumAmount() float64 {
|
||||||
|
if c.MinimumCreditAmountForUsage == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
var value float64
|
||||||
|
_, _ = fmt.Sscanf(c.MinimumCreditAmountForUsage, "%f", &value)
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
// OnboardUserRequest onboardUser 请求
|
// OnboardUserRequest onboardUser 请求
|
||||||
type OnboardUserRequest struct {
|
type OnboardUserRequest struct {
|
||||||
TierID string `json:"tierId"`
|
TierID string `json:"tierId"`
|
||||||
@@ -147,6 +215,14 @@ func (r *LoadCodeAssistResponse) GetTier() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAvailableCredits 返回 paid tier 中的 AI Credits 余额列表。
|
||||||
|
func (r *LoadCodeAssistResponse) GetAvailableCredits() []AvailableCredit {
|
||||||
|
if r.PaidTier == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return r.PaidTier.AvailableCredits
|
||||||
|
}
|
||||||
|
|
||||||
// Client Antigravity API 客户端
|
// Client Antigravity API 客户端
|
||||||
type Client struct {
|
type Client struct {
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
@@ -514,7 +590,20 @@ type ModelQuotaInfo struct {
|
|||||||
|
|
||||||
// ModelInfo 模型信息
|
// ModelInfo 模型信息
|
||||||
type ModelInfo struct {
|
type ModelInfo struct {
|
||||||
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
|
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
|
||||||
|
DisplayName string `json:"displayName,omitempty"`
|
||||||
|
SupportsImages *bool `json:"supportsImages,omitempty"`
|
||||||
|
SupportsThinking *bool `json:"supportsThinking,omitempty"`
|
||||||
|
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
|
||||||
|
Recommended *bool `json:"recommended,omitempty"`
|
||||||
|
MaxTokens *int `json:"maxTokens,omitempty"`
|
||||||
|
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
|
||||||
|
SupportedMimeTypes map[string]bool `json:"supportedMimeTypes,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeprecatedModelInfo 废弃模型转发信息
|
||||||
|
type DeprecatedModelInfo struct {
|
||||||
|
NewModelID string `json:"newModelId"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchAvailableModelsRequest fetchAvailableModels 请求
|
// FetchAvailableModelsRequest fetchAvailableModels 请求
|
||||||
@@ -524,7 +613,8 @@ type FetchAvailableModelsRequest struct {
|
|||||||
|
|
||||||
// FetchAvailableModelsResponse fetchAvailableModels 响应
|
// FetchAvailableModelsResponse fetchAvailableModels 响应
|
||||||
type FetchAvailableModelsResponse struct {
|
type FetchAvailableModelsResponse struct {
|
||||||
Models map[string]ModelInfo `json:"models"`
|
Models map[string]ModelInfo `json:"models"`
|
||||||
|
DeprecatedModelIDs map[string]DeprecatedModelInfo `json:"deprecatedModelIds,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
|
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
|
||||||
@@ -573,6 +663,13 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusForbidden {
|
||||||
|
return nil, nil, &ForbiddenError{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
Body: string(respBodyBytes),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -190,7 +190,7 @@ func TestTierInfo_UnmarshalJSON_通过JSON嵌套结构(t *testing.T) {
|
|||||||
func TestGetTier_PaidTier优先(t *testing.T) {
|
func TestGetTier_PaidTier优先(t *testing.T) {
|
||||||
resp := &LoadCodeAssistResponse{
|
resp := &LoadCodeAssistResponse{
|
||||||
CurrentTier: &TierInfo{ID: "free-tier"},
|
CurrentTier: &TierInfo{ID: "free-tier"},
|
||||||
PaidTier: &TierInfo{ID: "g1-pro-tier"},
|
PaidTier: &PaidTierInfo{ID: "g1-pro-tier"},
|
||||||
}
|
}
|
||||||
if got := resp.GetTier(); got != "g1-pro-tier" {
|
if got := resp.GetTier(); got != "g1-pro-tier" {
|
||||||
t.Errorf("应返回 paidTier: got %s", got)
|
t.Errorf("应返回 paidTier: got %s", got)
|
||||||
@@ -209,7 +209,7 @@ func TestGetTier_回退到CurrentTier(t *testing.T) {
|
|||||||
func TestGetTier_PaidTier为空ID(t *testing.T) {
|
func TestGetTier_PaidTier为空ID(t *testing.T) {
|
||||||
resp := &LoadCodeAssistResponse{
|
resp := &LoadCodeAssistResponse{
|
||||||
CurrentTier: &TierInfo{ID: "free-tier"},
|
CurrentTier: &TierInfo{ID: "free-tier"},
|
||||||
PaidTier: &TierInfo{ID: ""},
|
PaidTier: &PaidTierInfo{ID: ""},
|
||||||
}
|
}
|
||||||
// paidTier.ID 为空时应回退到 currentTier
|
// paidTier.ID 为空时应回退到 currentTier
|
||||||
if got := resp.GetTier(); got != "free-tier" {
|
if got := resp.GetTier(); got != "free-tier" {
|
||||||
@@ -217,6 +217,32 @@ func TestGetTier_PaidTier为空ID(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetAvailableCredits(t *testing.T) {
|
||||||
|
resp := &LoadCodeAssistResponse{
|
||||||
|
PaidTier: &PaidTierInfo{
|
||||||
|
ID: "g1-pro-tier",
|
||||||
|
AvailableCredits: []AvailableCredit{
|
||||||
|
{
|
||||||
|
CreditType: "GOOGLE_ONE_AI",
|
||||||
|
CreditAmount: "25",
|
||||||
|
MinimumCreditAmountForUsage: "5",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
credits := resp.GetAvailableCredits()
|
||||||
|
if len(credits) != 1 {
|
||||||
|
t.Fatalf("AI Credits 数量不匹配: got %d", len(credits))
|
||||||
|
}
|
||||||
|
if credits[0].GetAmount() != 25 {
|
||||||
|
t.Errorf("CreditAmount 解析不正确: got %v", credits[0].GetAmount())
|
||||||
|
}
|
||||||
|
if credits[0].GetMinimumAmount() != 5 {
|
||||||
|
t.Errorf("MinimumCreditAmountForUsage 解析不正确: got %v", credits[0].GetMinimumAmount())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetTier_两者都为nil(t *testing.T) {
|
func TestGetTier_两者都为nil(t *testing.T) {
|
||||||
resp := &LoadCodeAssistResponse{}
|
resp := &LoadCodeAssistResponse{}
|
||||||
if got := resp.GetTier(); got != "" {
|
if got := resp.GetTier(); got != "" {
|
||||||
|
|||||||
@@ -189,6 +189,5 @@ var DefaultStopSequences = []string{
|
|||||||
"<|user|>",
|
"<|user|>",
|
||||||
"<|endoftext|>",
|
"<|endoftext|>",
|
||||||
"<|end_of_turn|>",
|
"<|end_of_turn|>",
|
||||||
"[DONE]",
|
|
||||||
"\n\nHuman:",
|
"\n\nHuman:",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,8 +49,8 @@ const (
|
|||||||
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||||
)
|
)
|
||||||
|
|
||||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.4
|
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.5
|
||||||
var defaultUserAgentVersion = "1.20.4"
|
var defaultUserAgentVersion = "1.20.5"
|
||||||
|
|
||||||
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
||||||
var defaultClientSecret = "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"
|
var defaultClientSecret = "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"
|
||||||
|
|||||||
@@ -690,7 +690,7 @@ func TestConstants_值正确(t *testing.T) {
|
|||||||
if RedirectURI != "http://localhost:8085/callback" {
|
if RedirectURI != "http://localhost:8085/callback" {
|
||||||
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
|
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
|
||||||
}
|
}
|
||||||
if GetUserAgent() != "antigravity/1.20.4 windows/amd64" {
|
if GetUserAgent() != "antigravity/1.20.5 windows/amd64" {
|
||||||
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
|
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
|
||||||
}
|
}
|
||||||
if SessionTTL != 30*time.Minute {
|
if SessionTTL != 30*time.Minute {
|
||||||
|
|||||||
@@ -275,21 +275,6 @@ func filterOpenCodePrompt(text string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// systemBlockFilterPrefixes 需要从 system 中过滤的文本前缀列表
|
|
||||||
var systemBlockFilterPrefixes = []string{
|
|
||||||
"x-anthropic-billing-header",
|
|
||||||
}
|
|
||||||
|
|
||||||
// filterSystemBlockByPrefix 如果文本匹配过滤前缀,返回空字符串
|
|
||||||
func filterSystemBlockByPrefix(text string) string {
|
|
||||||
for _, prefix := range systemBlockFilterPrefixes {
|
|
||||||
if strings.HasPrefix(text, prefix) {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return text
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildSystemInstruction 构建 systemInstruction(与 Antigravity-Manager 保持一致)
|
// buildSystemInstruction 构建 systemInstruction(与 Antigravity-Manager 保持一致)
|
||||||
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent {
|
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent {
|
||||||
var parts []GeminiPart
|
var parts []GeminiPart
|
||||||
@@ -306,8 +291,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
|
|||||||
if strings.Contains(sysStr, "You are Antigravity") {
|
if strings.Contains(sysStr, "You are Antigravity") {
|
||||||
userHasAntigravityIdentity = true
|
userHasAntigravityIdentity = true
|
||||||
}
|
}
|
||||||
// 过滤 OpenCode 默认提示词和黑名单前缀
|
// 过滤 OpenCode 默认提示词
|
||||||
filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(sysStr))
|
filtered := filterOpenCodePrompt(sysStr)
|
||||||
if filtered != "" {
|
if filtered != "" {
|
||||||
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
||||||
}
|
}
|
||||||
@@ -321,8 +306,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
|
|||||||
if strings.Contains(block.Text, "You are Antigravity") {
|
if strings.Contains(block.Text, "You are Antigravity") {
|
||||||
userHasAntigravityIdentity = true
|
userHasAntigravityIdentity = true
|
||||||
}
|
}
|
||||||
// 过滤 OpenCode 默认提示词和黑名单前缀
|
// 过滤 OpenCode 默认提示词
|
||||||
filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(block.Text))
|
filtered := filterOpenCodePrompt(block.Text)
|
||||||
if filtered != "" {
|
if filtered != "" {
|
||||||
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,10 @@ package antigravity
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
|
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
|
||||||
@@ -349,3 +352,51 @@ func TestBuildGenerationConfig_ThinkingDynamicBudget(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTransformClaudeToGeminiWithOptions_PreservesBillingHeaderSystemBlock(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
system json.RawMessage
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "system array",
|
||||||
|
system: json.RawMessage(`[{"type":"text","text":"x-anthropic-billing-header keep"}]`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "system string",
|
||||||
|
system: json.RawMessage(`"x-anthropic-billing-header keep"`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
claudeReq := &ClaudeRequest{
|
||||||
|
Model: "claude-3-5-sonnet-latest",
|
||||||
|
System: tt.system,
|
||||||
|
Messages: []ClaudeMessage{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: json.RawMessage(`[{"type":"text","text":"hello"}]`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "gemini-2.5-flash", DefaultTransformOptions())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var req V1InternalRequest
|
||||||
|
require.NoError(t, json.Unmarshal(body, &req))
|
||||||
|
require.NotNil(t, req.Request.SystemInstruction)
|
||||||
|
|
||||||
|
found := false
|
||||||
|
for _, part := range req.Request.SystemInstruction.Parts {
|
||||||
|
if strings.Contains(part.Text, "x-anthropic-billing-header keep") {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
require.True(t, found, "转换后的 systemInstruction 应保留 x-anthropic-billing-header 内容")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -18,9 +18,6 @@ const (
|
|||||||
BlockTypeFunction
|
BlockTypeFunction
|
||||||
)
|
)
|
||||||
|
|
||||||
// UsageMapHook is a callback that can modify usage data before it's emitted in SSE events.
|
|
||||||
type UsageMapHook func(usageMap map[string]any)
|
|
||||||
|
|
||||||
// StreamingProcessor 流式响应处理器
|
// StreamingProcessor 流式响应处理器
|
||||||
type StreamingProcessor struct {
|
type StreamingProcessor struct {
|
||||||
blockType BlockType
|
blockType BlockType
|
||||||
@@ -33,7 +30,6 @@ type StreamingProcessor struct {
|
|||||||
originalModel string
|
originalModel string
|
||||||
webSearchQueries []string
|
webSearchQueries []string
|
||||||
groundingChunks []GeminiGroundingChunk
|
groundingChunks []GeminiGroundingChunk
|
||||||
usageMapHook UsageMapHook
|
|
||||||
|
|
||||||
// 累计 usage
|
// 累计 usage
|
||||||
inputTokens int
|
inputTokens int
|
||||||
@@ -49,25 +45,6 @@ func NewStreamingProcessor(originalModel string) *StreamingProcessor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetUsageMapHook sets an optional hook that modifies usage maps before they are emitted.
|
|
||||||
func (p *StreamingProcessor) SetUsageMapHook(fn UsageMapHook) {
|
|
||||||
p.usageMapHook = fn
|
|
||||||
}
|
|
||||||
|
|
||||||
func usageToMap(u ClaudeUsage) map[string]any {
|
|
||||||
m := map[string]any{
|
|
||||||
"input_tokens": u.InputTokens,
|
|
||||||
"output_tokens": u.OutputTokens,
|
|
||||||
}
|
|
||||||
if u.CacheCreationInputTokens > 0 {
|
|
||||||
m["cache_creation_input_tokens"] = u.CacheCreationInputTokens
|
|
||||||
}
|
|
||||||
if u.CacheReadInputTokens > 0 {
|
|
||||||
m["cache_read_input_tokens"] = u.CacheReadInputTokens
|
|
||||||
}
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件
|
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件
|
||||||
func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
||||||
line = strings.TrimSpace(line)
|
line = strings.TrimSpace(line)
|
||||||
@@ -191,13 +168,6 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
|
|||||||
responseID = "msg_" + generateRandomID()
|
responseID = "msg_" + generateRandomID()
|
||||||
}
|
}
|
||||||
|
|
||||||
var usageValue any = usage
|
|
||||||
if p.usageMapHook != nil {
|
|
||||||
usageMap := usageToMap(usage)
|
|
||||||
p.usageMapHook(usageMap)
|
|
||||||
usageValue = usageMap
|
|
||||||
}
|
|
||||||
|
|
||||||
message := map[string]any{
|
message := map[string]any{
|
||||||
"id": responseID,
|
"id": responseID,
|
||||||
"type": "message",
|
"type": "message",
|
||||||
@@ -206,7 +176,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
|
|||||||
"model": p.originalModel,
|
"model": p.originalModel,
|
||||||
"stop_reason": nil,
|
"stop_reason": nil,
|
||||||
"stop_sequence": nil,
|
"stop_sequence": nil,
|
||||||
"usage": usageValue,
|
"usage": usage,
|
||||||
}
|
}
|
||||||
|
|
||||||
event := map[string]any{
|
event := map[string]any{
|
||||||
@@ -517,20 +487,13 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
|
|||||||
CacheReadInputTokens: p.cacheReadTokens,
|
CacheReadInputTokens: p.cacheReadTokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
var usageValue any = usage
|
|
||||||
if p.usageMapHook != nil {
|
|
||||||
usageMap := usageToMap(usage)
|
|
||||||
p.usageMapHook(usageMap)
|
|
||||||
usageValue = usageMap
|
|
||||||
}
|
|
||||||
|
|
||||||
deltaEvent := map[string]any{
|
deltaEvent := map[string]any{
|
||||||
"type": "message_delta",
|
"type": "message_delta",
|
||||||
"delta": map[string]any{
|
"delta": map[string]any{
|
||||||
"stop_reason": stopReason,
|
"stop_reason": stopReason,
|
||||||
"stop_sequence": nil,
|
"stop_sequence": nil,
|
||||||
},
|
},
|
||||||
"usage": usageValue,
|
"usage": usage,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, _ = result.Write(p.formatSSE("message_delta", deltaEvent))
|
_, _ = result.Write(p.formatSSE("message_delta", deltaEvent))
|
||||||
|
|||||||
@@ -105,6 +105,7 @@ func TestAnthropicToResponses_ToolUse(t *testing.T) {
|
|||||||
assert.Equal(t, "assistant", items[1].Role)
|
assert.Equal(t, "assistant", items[1].Role)
|
||||||
assert.Equal(t, "function_call", items[2].Type)
|
assert.Equal(t, "function_call", items[2].Type)
|
||||||
assert.Equal(t, "fc_call_1", items[2].CallID)
|
assert.Equal(t, "fc_call_1", items[2].CallID)
|
||||||
|
assert.Empty(t, items[2].ID)
|
||||||
assert.Equal(t, "function_call_output", items[3].Type)
|
assert.Equal(t, "function_call_output", items[3].Type)
|
||||||
assert.Equal(t, "fc_call_1", items[3].CallID)
|
assert.Equal(t, "fc_call_1", items[3].CallID)
|
||||||
assert.Equal(t, "Sunny, 72°F", items[3].Output)
|
assert.Equal(t, "Sunny, 72°F", items[3].Output)
|
||||||
@@ -1007,3 +1008,114 @@ func TestAnthropicToResponses_ImageEmptyMediaType(t *testing.T) {
|
|||||||
// Should default to image/png when media_type is empty.
|
// Should default to image/png when media_type is empty.
|
||||||
assert.Equal(t, "data:image/png;base64,iVBOR", parts[0].ImageURL)
|
assert.Equal(t, "data:image/png;base64,iVBOR", parts[0].ImageURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// normalizeToolParameters tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestNormalizeToolParameters(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input json.RawMessage
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil input",
|
||||||
|
input: nil,
|
||||||
|
expected: `{"type":"object","properties":{}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty input",
|
||||||
|
input: json.RawMessage(``),
|
||||||
|
expected: `{"type":"object","properties":{}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "null input",
|
||||||
|
input: json.RawMessage(`null`),
|
||||||
|
expected: `{"type":"object","properties":{}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "object without properties",
|
||||||
|
input: json.RawMessage(`{"type":"object"}`),
|
||||||
|
expected: `{"type":"object","properties":{}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "object with properties",
|
||||||
|
input: json.RawMessage(`{"type":"object","properties":{"city":{"type":"string"}}}`),
|
||||||
|
expected: `{"type":"object","properties":{"city":{"type":"string"}}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-object type",
|
||||||
|
input: json.RawMessage(`{"type":"string"}`),
|
||||||
|
expected: `{"type":"string"}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "object with additional fields preserved",
|
||||||
|
input: json.RawMessage(`{"type":"object","required":["name"]}`),
|
||||||
|
expected: `{"type":"object","required":["name"],"properties":{}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid JSON passthrough",
|
||||||
|
input: json.RawMessage(`not json`),
|
||||||
|
expected: `not json`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := normalizeToolParameters(tt.input)
|
||||||
|
if tt.name == "invalid JSON passthrough" {
|
||||||
|
assert.Equal(t, tt.expected, string(result))
|
||||||
|
} else {
|
||||||
|
assert.JSONEq(t, tt.expected, string(result))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnthropicToResponses_ToolWithoutProperties(t *testing.T) {
|
||||||
|
req := &AnthropicRequest{
|
||||||
|
Model: "gpt-5.2",
|
||||||
|
MaxTokens: 1024,
|
||||||
|
Messages: []AnthropicMessage{
|
||||||
|
{Role: "user", Content: json.RawMessage(`"Hello"`)},
|
||||||
|
},
|
||||||
|
Tools: []AnthropicTool{
|
||||||
|
{Name: "mcp__pencil__get_style_guide_tags", Description: "Get style tags", InputSchema: json.RawMessage(`{"type":"object"}`)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := AnthropicToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Len(t, resp.Tools, 1)
|
||||||
|
assert.Equal(t, "function", resp.Tools[0].Type)
|
||||||
|
assert.Equal(t, "mcp__pencil__get_style_guide_tags", resp.Tools[0].Name)
|
||||||
|
|
||||||
|
// Parameters must have "properties" field after normalization.
|
||||||
|
var params map[string]json.RawMessage
|
||||||
|
require.NoError(t, json.Unmarshal(resp.Tools[0].Parameters, ¶ms))
|
||||||
|
assert.Contains(t, params, "properties")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnthropicToResponses_ToolWithNilSchema(t *testing.T) {
|
||||||
|
req := &AnthropicRequest{
|
||||||
|
Model: "gpt-5.2",
|
||||||
|
MaxTokens: 1024,
|
||||||
|
Messages: []AnthropicMessage{
|
||||||
|
{Role: "user", Content: json.RawMessage(`"Hello"`)},
|
||||||
|
},
|
||||||
|
Tools: []AnthropicTool{
|
||||||
|
{Name: "simple_tool", Description: "A tool"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := AnthropicToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Len(t, resp.Tools, 1)
|
||||||
|
var params map[string]json.RawMessage
|
||||||
|
require.NoError(t, json.Unmarshal(resp.Tools[0].Parameters, ¶ms))
|
||||||
|
assert.JSONEq(t, `"object"`, string(params["type"]))
|
||||||
|
assert.JSONEq(t, `{}`, string(params["properties"]))
|
||||||
|
}
|
||||||
|
|||||||
@@ -277,7 +277,6 @@ func anthropicAssistantToResponses(raw json.RawMessage) ([]ResponsesInputItem, e
|
|||||||
CallID: fcID,
|
CallID: fcID,
|
||||||
Name: b.Name,
|
Name: b.Name,
|
||||||
Arguments: args,
|
Arguments: args,
|
||||||
ID: fcID,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -410,8 +409,41 @@ func convertAnthropicToolsToResponses(tools []AnthropicTool) []ResponsesTool {
|
|||||||
Type: "function",
|
Type: "function",
|
||||||
Name: t.Name,
|
Name: t.Name,
|
||||||
Description: t.Description,
|
Description: t.Description,
|
||||||
Parameters: t.InputSchema,
|
Parameters: normalizeToolParameters(t.InputSchema),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// normalizeToolParameters ensures the tool parameter schema is valid for
|
||||||
|
// OpenAI's Responses API, which requires "properties" on object schemas.
|
||||||
|
//
|
||||||
|
// - nil/empty → {"type":"object","properties":{}}
|
||||||
|
// - type=object without properties → adds "properties": {}
|
||||||
|
// - otherwise → returned unchanged
|
||||||
|
func normalizeToolParameters(schema json.RawMessage) json.RawMessage {
|
||||||
|
if len(schema) == 0 || string(schema) == "null" {
|
||||||
|
return json.RawMessage(`{"type":"object","properties":{}}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
var m map[string]json.RawMessage
|
||||||
|
if err := json.Unmarshal(schema, &m); err != nil {
|
||||||
|
return schema
|
||||||
|
}
|
||||||
|
|
||||||
|
typ := m["type"]
|
||||||
|
if string(typ) != `"object"` {
|
||||||
|
return schema
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := m["properties"]; ok {
|
||||||
|
return schema
|
||||||
|
}
|
||||||
|
|
||||||
|
m["properties"] = json.RawMessage(`{}`)
|
||||||
|
out, err := json.Marshal(m)
|
||||||
|
if err != nil {
|
||||||
|
return schema
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|||||||
@@ -99,6 +99,7 @@ func TestChatCompletionsToResponses_ToolCalls(t *testing.T) {
|
|||||||
// Check function_call item
|
// Check function_call item
|
||||||
assert.Equal(t, "function_call", items[1].Type)
|
assert.Equal(t, "function_call", items[1].Type)
|
||||||
assert.Equal(t, "call_1", items[1].CallID)
|
assert.Equal(t, "call_1", items[1].CallID)
|
||||||
|
assert.Empty(t, items[1].ID)
|
||||||
assert.Equal(t, "ping", items[1].Name)
|
assert.Equal(t, "ping", items[1].Name)
|
||||||
|
|
||||||
// Check function_call_output item
|
// Check function_call_output item
|
||||||
@@ -252,6 +253,55 @@ func TestChatCompletionsToResponses_AssistantWithTextAndToolCalls(t *testing.T)
|
|||||||
assert.Equal(t, "user", items[0].Role)
|
assert.Equal(t, "user", items[0].Role)
|
||||||
assert.Equal(t, "assistant", items[1].Role)
|
assert.Equal(t, "assistant", items[1].Role)
|
||||||
assert.Equal(t, "function_call", items[2].Type)
|
assert.Equal(t, "function_call", items[2].Type)
|
||||||
|
assert.Empty(t, items[2].ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_AssistantArrayContentPreserved(t *testing.T) {
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
Messages: []ChatMessage{
|
||||||
|
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||||
|
{Role: "assistant", Content: json.RawMessage(`[{"type":"text","text":"A"},{"type":"text","text":"B"}]`)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var items []ResponsesInputItem
|
||||||
|
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||||
|
require.Len(t, items, 2)
|
||||||
|
assert.Equal(t, "assistant", items[1].Role)
|
||||||
|
|
||||||
|
var parts []ResponsesContentPart
|
||||||
|
require.NoError(t, json.Unmarshal(items[1].Content, &parts))
|
||||||
|
require.Len(t, parts, 1)
|
||||||
|
assert.Equal(t, "output_text", parts[0].Type)
|
||||||
|
assert.Equal(t, "AB", parts[0].Text)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_AssistantThinkingTagPreserved(t *testing.T) {
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
Messages: []ChatMessage{
|
||||||
|
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||||
|
{Role: "assistant", Content: json.RawMessage(`[{"type":"thinking","thinking":"internal plan"},{"type":"text","text":"final answer"}]`)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var items []ResponsesInputItem
|
||||||
|
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||||
|
require.Len(t, items, 2)
|
||||||
|
|
||||||
|
var parts []ResponsesContentPart
|
||||||
|
require.NoError(t, json.Unmarshal(items[1].Content, &parts))
|
||||||
|
require.Len(t, parts, 1)
|
||||||
|
assert.Equal(t, "output_text", parts[0].Type)
|
||||||
|
assert.Contains(t, parts[0].Text, "<thinking>internal plan</thinking>")
|
||||||
|
assert.Contains(t, parts[0].Text, "final answer")
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
@@ -344,8 +394,8 @@ func TestResponsesToChatCompletions_Reasoning(t *testing.T) {
|
|||||||
|
|
||||||
var content string
|
var content string
|
||||||
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
|
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
|
||||||
// Reasoning summary is prepended to text
|
assert.Equal(t, "The answer is 42.", content)
|
||||||
assert.Equal(t, "I thought about it.The answer is 42.", content)
|
assert.Equal(t, "I thought about it.", chat.Choices[0].Message.ReasoningContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResponsesToChatCompletions_Incomplete(t *testing.T) {
|
func TestResponsesToChatCompletions_Incomplete(t *testing.T) {
|
||||||
@@ -582,8 +632,35 @@ func TestResponsesEventToChatChunks_ReasoningDelta(t *testing.T) {
|
|||||||
Delta: "Thinking...",
|
Delta: "Thinking...",
|
||||||
}, state)
|
}, state)
|
||||||
require.Len(t, chunks, 1)
|
require.Len(t, chunks, 1)
|
||||||
|
require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent)
|
||||||
|
assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.ReasoningContent)
|
||||||
|
|
||||||
|
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||||
|
Type: "response.reasoning_summary_text.done",
|
||||||
|
}, state)
|
||||||
|
require.Len(t, chunks, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponsesEventToChatChunks_ReasoningThenTextAutoCloseTag(t *testing.T) {
|
||||||
|
state := NewResponsesEventToChatState()
|
||||||
|
state.Model = "gpt-4o"
|
||||||
|
state.SentRole = true
|
||||||
|
|
||||||
|
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||||
|
Type: "response.reasoning_summary_text.delta",
|
||||||
|
Delta: "plan",
|
||||||
|
}, state)
|
||||||
|
require.Len(t, chunks, 1)
|
||||||
|
require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent)
|
||||||
|
assert.Equal(t, "plan", *chunks[0].Choices[0].Delta.ReasoningContent)
|
||||||
|
|
||||||
|
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||||
|
Type: "response.output_text.delta",
|
||||||
|
Delta: "answer",
|
||||||
|
}, state)
|
||||||
|
require.Len(t, chunks, 1)
|
||||||
require.NotNil(t, chunks[0].Choices[0].Delta.Content)
|
require.NotNil(t, chunks[0].Choices[0].Delta.Content)
|
||||||
assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.Content)
|
assert.Equal(t, "answer", *chunks[0].Choices[0].Delta.Content)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFinalizeResponsesChatStream(t *testing.T) {
|
func TestFinalizeResponsesChatStream(t *testing.T) {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package apicompat
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ChatCompletionsToResponses converts a Chat Completions request into a
|
// ChatCompletionsToResponses converts a Chat Completions request into a
|
||||||
@@ -174,8 +175,11 @@ func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
|||||||
|
|
||||||
// Emit assistant message with output_text if content is non-empty.
|
// Emit assistant message with output_text if content is non-empty.
|
||||||
if len(m.Content) > 0 {
|
if len(m.Content) > 0 {
|
||||||
var s string
|
s, err := parseAssistantContent(m.Content)
|
||||||
if err := json.Unmarshal(m.Content, &s); err == nil && s != "" {
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if s != "" {
|
||||||
parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
|
parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
|
||||||
partsJSON, err := json.Marshal(parts)
|
partsJSON, err := json.Marshal(parts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -196,13 +200,82 @@ func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
|||||||
CallID: tc.ID,
|
CallID: tc.ID,
|
||||||
Name: tc.Function.Name,
|
Name: tc.Function.Name,
|
||||||
Arguments: args,
|
Arguments: args,
|
||||||
ID: tc.ID,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return items, nil
|
return items, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseAssistantContent returns assistant content as plain text.
|
||||||
|
//
|
||||||
|
// Supported formats:
|
||||||
|
// - JSON string
|
||||||
|
// - JSON array of typed parts (e.g. [{"type":"text","text":"..."}])
|
||||||
|
//
|
||||||
|
// For structured thinking/reasoning parts, it preserves semantics by wrapping
|
||||||
|
// the text in explicit tags so downstream can still distinguish it from normal text.
|
||||||
|
func parseAssistantContent(raw json.RawMessage) (string, error) {
|
||||||
|
if len(raw) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var s string
|
||||||
|
if err := json.Unmarshal(raw, &s); err == nil {
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var parts []map[string]any
|
||||||
|
if err := json.Unmarshal(raw, &parts); err != nil {
|
||||||
|
// Keep compatibility with prior behavior: unsupported assistant content
|
||||||
|
// formats are ignored instead of failing the whole request conversion.
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
write := func(v string) error {
|
||||||
|
_, err := b.WriteString(v)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, p := range parts {
|
||||||
|
typ, _ := p["type"].(string)
|
||||||
|
text, _ := p["text"].(string)
|
||||||
|
thinking, _ := p["thinking"].(string)
|
||||||
|
|
||||||
|
switch typ {
|
||||||
|
case "thinking", "reasoning":
|
||||||
|
if thinking != "" {
|
||||||
|
if err := write("<thinking>"); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := write(thinking); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := write("</thinking>"); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
} else if text != "" {
|
||||||
|
if err := write("<thinking>"); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := write(text); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := write("</thinking>"); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if text != "" {
|
||||||
|
if err := write(text); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
// chatToolToResponses converts a tool result message (role=tool) into a
|
// chatToolToResponses converts a tool result message (role=tool) into a
|
||||||
// function_call_output item.
|
// function_call_output item.
|
||||||
func chatToolToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
func chatToolToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
|
|||||||
}
|
}
|
||||||
|
|
||||||
var contentText string
|
var contentText string
|
||||||
|
var reasoningText string
|
||||||
var toolCalls []ChatToolCall
|
var toolCalls []ChatToolCall
|
||||||
|
|
||||||
for _, item := range resp.Output {
|
for _, item := range resp.Output {
|
||||||
@@ -51,7 +52,7 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
|
|||||||
case "reasoning":
|
case "reasoning":
|
||||||
for _, s := range item.Summary {
|
for _, s := range item.Summary {
|
||||||
if s.Type == "summary_text" && s.Text != "" {
|
if s.Type == "summary_text" && s.Text != "" {
|
||||||
contentText += s.Text
|
reasoningText += s.Text
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case "web_search_call":
|
case "web_search_call":
|
||||||
@@ -67,6 +68,9 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
|
|||||||
raw, _ := json.Marshal(contentText)
|
raw, _ := json.Marshal(contentText)
|
||||||
msg.Content = raw
|
msg.Content = raw
|
||||||
}
|
}
|
||||||
|
if reasoningText != "" {
|
||||||
|
msg.ReasoningContent = reasoningText
|
||||||
|
}
|
||||||
|
|
||||||
finishReason := responsesStatusToChatFinishReason(resp.Status, resp.IncompleteDetails, toolCalls)
|
finishReason := responsesStatusToChatFinishReason(resp.Status, resp.IncompleteDetails, toolCalls)
|
||||||
|
|
||||||
@@ -153,6 +157,8 @@ func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEvent
|
|||||||
return resToChatHandleFuncArgsDelta(evt, state)
|
return resToChatHandleFuncArgsDelta(evt, state)
|
||||||
case "response.reasoning_summary_text.delta":
|
case "response.reasoning_summary_text.delta":
|
||||||
return resToChatHandleReasoningDelta(evt, state)
|
return resToChatHandleReasoningDelta(evt, state)
|
||||||
|
case "response.reasoning_summary_text.done":
|
||||||
|
return nil
|
||||||
case "response.completed", "response.incomplete", "response.failed":
|
case "response.completed", "response.incomplete", "response.failed":
|
||||||
return resToChatHandleCompleted(evt, state)
|
return resToChatHandleCompleted(evt, state)
|
||||||
default:
|
default:
|
||||||
@@ -276,8 +282,8 @@ func resToChatHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEv
|
|||||||
if evt.Delta == "" {
|
if evt.Delta == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
content := evt.Delta
|
reasoning := evt.Delta
|
||||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})}
|
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{ReasoningContent: &reasoning})}
|
||||||
}
|
}
|
||||||
|
|
||||||
func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||||
|
|||||||
@@ -361,11 +361,12 @@ type ChatStreamOptions struct {
|
|||||||
|
|
||||||
// ChatMessage is a single message in the Chat Completions conversation.
|
// ChatMessage is a single message in the Chat Completions conversation.
|
||||||
type ChatMessage struct {
|
type ChatMessage struct {
|
||||||
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
|
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
|
||||||
Content json.RawMessage `json:"content,omitempty"`
|
Content json.RawMessage `json:"content,omitempty"`
|
||||||
Name string `json:"name,omitempty"`
|
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
Name string `json:"name,omitempty"`
|
||||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||||
|
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||||
|
|
||||||
// Legacy function calling
|
// Legacy function calling
|
||||||
FunctionCall *ChatFunctionCall `json:"function_call,omitempty"`
|
FunctionCall *ChatFunctionCall `json:"function_call,omitempty"`
|
||||||
@@ -466,9 +467,10 @@ type ChatChunkChoice struct {
|
|||||||
|
|
||||||
// ChatDelta carries incremental content in a streaming chunk.
|
// ChatDelta carries incremental content in a streaming chunk.
|
||||||
type ChatDelta struct {
|
type ChatDelta struct {
|
||||||
Role string `json:"role,omitempty"`
|
Role string `json:"role,omitempty"`
|
||||||
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
|
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
|
||||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
ReasoningContent *string `json:"reasoning_content,omitempty"`
|
||||||
|
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -47,6 +47,15 @@ func Created(c *gin.Context, data any) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Accepted 返回异步接受响应 (HTTP 202)
|
||||||
|
func Accepted(c *gin.Context, data any) {
|
||||||
|
c.JSON(http.StatusAccepted, Response{
|
||||||
|
Code: 0,
|
||||||
|
Message: "accepted",
|
||||||
|
Data: data,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Error 返回错误响应
|
// Error 返回错误响应
|
||||||
func Error(c *gin.Context, statusCode int, message string) {
|
func Error(c *gin.Context, statusCode int, message string) {
|
||||||
c.JSON(statusCode, Response{
|
c.JSON(statusCode, Response{
|
||||||
|
|||||||
@@ -3,6 +3,28 @@ package usagestats
|
|||||||
|
|
||||||
import "time"
|
import "time"
|
||||||
|
|
||||||
|
const (
|
||||||
|
ModelSourceRequested = "requested"
|
||||||
|
ModelSourceUpstream = "upstream"
|
||||||
|
ModelSourceMapping = "mapping"
|
||||||
|
)
|
||||||
|
|
||||||
|
func IsValidModelSource(source string) bool {
|
||||||
|
switch source {
|
||||||
|
case ModelSourceRequested, ModelSourceUpstream, ModelSourceMapping:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NormalizeModelSource(source string) string {
|
||||||
|
if IsValidModelSource(source) {
|
||||||
|
return source
|
||||||
|
}
|
||||||
|
return ModelSourceRequested
|
||||||
|
}
|
||||||
|
|
||||||
// DashboardStats 仪表盘统计
|
// DashboardStats 仪表盘统计
|
||||||
type DashboardStats struct {
|
type DashboardStats struct {
|
||||||
// 用户统计
|
// 用户统计
|
||||||
@@ -81,6 +103,22 @@ type ModelStat struct {
|
|||||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EndpointStat represents usage statistics for a single request endpoint.
|
||||||
|
type EndpointStat struct {
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
Requests int64 `json:"requests"`
|
||||||
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
|
Cost float64 `json:"cost"` // 标准计费
|
||||||
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupUsageSummary represents today's and cumulative cost for a single group.
|
||||||
|
type GroupUsageSummary struct {
|
||||||
|
GroupID int64 `json:"group_id"`
|
||||||
|
TodayCost float64 `json:"today_cost"`
|
||||||
|
TotalCost float64 `json:"total_cost"`
|
||||||
|
}
|
||||||
|
|
||||||
// GroupStat represents usage statistics for a single group
|
// GroupStat represents usage statistics for a single group
|
||||||
type GroupStat struct {
|
type GroupStat struct {
|
||||||
GroupID int64 `json:"group_id"`
|
GroupID int64 `json:"group_id"`
|
||||||
@@ -96,12 +134,49 @@ type UserUsageTrendPoint struct {
|
|||||||
Date string `json:"date"`
|
Date string `json:"date"`
|
||||||
UserID int64 `json:"user_id"`
|
UserID int64 `json:"user_id"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
|
Username string `json:"username"`
|
||||||
Requests int64 `json:"requests"`
|
Requests int64 `json:"requests"`
|
||||||
Tokens int64 `json:"tokens"`
|
Tokens int64 `json:"tokens"`
|
||||||
Cost float64 `json:"cost"` // 标准计费
|
Cost float64 `json:"cost"` // 标准计费
|
||||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UserSpendingRankingItem represents a user spending ranking row.
|
||||||
|
type UserSpendingRankingItem struct {
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
|
Requests int64 `json:"requests"`
|
||||||
|
Tokens int64 `json:"tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserSpendingRankingResponse represents ranking rows plus total spend for the time range.
|
||||||
|
type UserSpendingRankingResponse struct {
|
||||||
|
Ranking []UserSpendingRankingItem `json:"ranking"`
|
||||||
|
TotalActualCost float64 `json:"total_actual_cost"`
|
||||||
|
TotalRequests int64 `json:"total_requests"`
|
||||||
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserBreakdownItem represents per-user usage breakdown within a dimension (group, model, endpoint).
|
||||||
|
type UserBreakdownItem struct {
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
Requests int64 `json:"requests"`
|
||||||
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
|
Cost float64 `json:"cost"` // 标准计费
|
||||||
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserBreakdownDimension specifies the dimension to filter for user breakdown.
|
||||||
|
type UserBreakdownDimension struct {
|
||||||
|
GroupID int64 // filter by group_id (>0 to enable)
|
||||||
|
Model string // filter by model name (non-empty to enable)
|
||||||
|
ModelType string // "requested", "upstream", or "mapping"
|
||||||
|
Endpoint string // filter by endpoint value (non-empty to enable)
|
||||||
|
EndpointType string // "inbound", "upstream", or "path"
|
||||||
|
}
|
||||||
|
|
||||||
// APIKeyUsageTrendPoint represents API key usage trend data point
|
// APIKeyUsageTrendPoint represents API key usage trend data point
|
||||||
type APIKeyUsageTrendPoint struct {
|
type APIKeyUsageTrendPoint struct {
|
||||||
Date string `json:"date"`
|
Date string `json:"date"`
|
||||||
@@ -163,15 +238,18 @@ type UsageLogFilters struct {
|
|||||||
|
|
||||||
// UsageStats represents usage statistics
|
// UsageStats represents usage statistics
|
||||||
type UsageStats struct {
|
type UsageStats struct {
|
||||||
TotalRequests int64 `json:"total_requests"`
|
TotalRequests int64 `json:"total_requests"`
|
||||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||||
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
||||||
TotalTokens int64 `json:"total_tokens"`
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
TotalCost float64 `json:"total_cost"`
|
TotalCost float64 `json:"total_cost"`
|
||||||
TotalActualCost float64 `json:"total_actual_cost"`
|
TotalActualCost float64 `json:"total_actual_cost"`
|
||||||
TotalAccountCost *float64 `json:"total_account_cost,omitempty"`
|
TotalAccountCost *float64 `json:"total_account_cost,omitempty"`
|
||||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||||
|
Endpoints []EndpointStat `json:"endpoints,omitempty"`
|
||||||
|
UpstreamEndpoints []EndpointStat `json:"upstream_endpoints,omitempty"`
|
||||||
|
EndpointPaths []EndpointStat `json:"endpoint_paths,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// BatchUserUsageStats represents usage stats for a single user
|
// BatchUserUsageStats represents usage stats for a single user
|
||||||
@@ -238,7 +316,9 @@ type AccountUsageSummary struct {
|
|||||||
|
|
||||||
// AccountUsageStatsResponse represents the full usage statistics response for an account
|
// AccountUsageStatsResponse represents the full usage statistics response for an account
|
||||||
type AccountUsageStatsResponse struct {
|
type AccountUsageStatsResponse struct {
|
||||||
History []AccountUsageHistory `json:"history"`
|
History []AccountUsageHistory `json:"history"`
|
||||||
Summary AccountUsageSummary `json:"summary"`
|
Summary AccountUsageSummary `json:"summary"`
|
||||||
Models []ModelStat `json:"models"`
|
Models []ModelStat `json:"models"`
|
||||||
|
Endpoints []EndpointStat `json:"endpoints"`
|
||||||
|
UpstreamEndpoints []EndpointStat `json:"upstream_endpoints"`
|
||||||
}
|
}
|
||||||
|
|||||||
47
backend/internal/pkg/usagestats/usage_log_types_test.go
Normal file
47
backend/internal/pkg/usagestats/usage_log_types_test.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package usagestats
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestIsValidModelSource(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
source string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{name: "requested", source: ModelSourceRequested, want: true},
|
||||||
|
{name: "upstream", source: ModelSourceUpstream, want: true},
|
||||||
|
{name: "mapping", source: ModelSourceMapping, want: true},
|
||||||
|
{name: "invalid", source: "foobar", want: false},
|
||||||
|
{name: "empty", source: "", want: false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
if got := IsValidModelSource(tc.source); got != tc.want {
|
||||||
|
t.Fatalf("IsValidModelSource(%q)=%v want %v", tc.source, got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeModelSource(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
source string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{name: "requested", source: ModelSourceRequested, want: ModelSourceRequested},
|
||||||
|
{name: "upstream", source: ModelSourceUpstream, want: ModelSourceUpstream},
|
||||||
|
{name: "mapping", source: ModelSourceMapping, want: ModelSourceMapping},
|
||||||
|
{name: "invalid falls back", source: "foobar", want: ModelSourceRequested},
|
||||||
|
{name: "empty falls back", source: "", want: ModelSourceRequested},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
if got := NormalizeModelSource(tc.source); got != tc.want {
|
||||||
|
t.Fatalf("NormalizeModelSource(%q)=%q want %q", tc.source, got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -56,6 +56,7 @@ var schedulerNeutralExtraKeyPrefixes = []string{
|
|||||||
"codex_secondary_",
|
"codex_secondary_",
|
||||||
"codex_5h_",
|
"codex_5h_",
|
||||||
"codex_7d_",
|
"codex_7d_",
|
||||||
|
"passive_usage_",
|
||||||
}
|
}
|
||||||
|
|
||||||
var schedulerNeutralExtraKeys = map[string]struct{}{
|
var schedulerNeutralExtraKeys = map[string]struct{}{
|
||||||
@@ -397,9 +398,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
|||||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
||||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
|
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
|
||||||
}
|
}
|
||||||
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
|
// 普通账号编辑(如 model_mapping / credentials)也需要立即刷新单账号快照,
|
||||||
r.syncSchedulerAccountSnapshot(ctx, account.ID)
|
// 否则网关在 outbox worker 延迟或异常时仍可能读到旧配置。
|
||||||
}
|
r.syncSchedulerAccountSnapshot(ctx, account.ID)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -473,7 +474,9 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
|
|||||||
if search != "" {
|
if search != "" {
|
||||||
q = q.Where(dbaccount.NameContainsFold(search))
|
q = q.Where(dbaccount.NameContainsFold(search))
|
||||||
}
|
}
|
||||||
if groupID > 0 {
|
if groupID == service.AccountListGroupUngrouped {
|
||||||
|
q = q.Where(dbaccount.Not(dbaccount.HasAccountGroups()))
|
||||||
|
} else if groupID > 0 {
|
||||||
q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID)))
|
q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID)))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1727,8 +1730,96 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va
|
|||||||
// nowUTC is a SQL expression to generate a UTC RFC3339 timestamp string.
|
// nowUTC is a SQL expression to generate a UTC RFC3339 timestamp string.
|
||||||
const nowUTC = `to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"')`
|
const nowUTC = `to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"')`
|
||||||
|
|
||||||
|
// dailyExpiredExpr is a SQL expression that evaluates to TRUE when daily quota period has expired.
|
||||||
|
// Supports both rolling (24h from start) and fixed (pre-computed reset_at) modes.
|
||||||
|
const dailyExpiredExpr = `(
|
||||||
|
CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed'
|
||||||
|
THEN NOW() >= COALESCE((extra->>'quota_daily_reset_at')::timestamptz, '1970-01-01'::timestamptz)
|
||||||
|
ELSE COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||||
|
+ '24 hours'::interval <= NOW()
|
||||||
|
END
|
||||||
|
)`
|
||||||
|
|
||||||
|
// weeklyExpiredExpr is a SQL expression that evaluates to TRUE when weekly quota period has expired.
|
||||||
|
const weeklyExpiredExpr = `(
|
||||||
|
CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed'
|
||||||
|
THEN NOW() >= COALESCE((extra->>'quota_weekly_reset_at')::timestamptz, '1970-01-01'::timestamptz)
|
||||||
|
ELSE COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||||
|
+ '168 hours'::interval <= NOW()
|
||||||
|
END
|
||||||
|
)`
|
||||||
|
|
||||||
|
// nextDailyResetAtExpr is a SQL expression to compute the next daily reset_at when a reset occurs.
|
||||||
|
// For fixed mode: computes the next future reset time based on NOW(), timezone, and configured hour.
|
||||||
|
// This correctly handles long-inactive accounts by jumping directly to the next valid reset point.
|
||||||
|
const nextDailyResetAtExpr = `(
|
||||||
|
CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed'
|
||||||
|
THEN to_char((
|
||||||
|
-- Compute today's reset point in the configured timezone, then pick next future one
|
||||||
|
CASE WHEN NOW() >= (
|
||||||
|
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||||
|
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||||
|
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||||
|
-- NOW() is at or past today's reset point → next reset is tomorrow
|
||||||
|
THEN (
|
||||||
|
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||||
|
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||||
|
+ '1 day'::interval
|
||||||
|
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||||
|
-- NOW() is before today's reset point → next reset is today
|
||||||
|
ELSE (
|
||||||
|
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||||
|
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||||
|
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||||
|
END
|
||||||
|
) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
|
||||||
|
ELSE NULL END
|
||||||
|
)`
|
||||||
|
|
||||||
|
// nextWeeklyResetAtExpr is a SQL expression to compute the next weekly reset_at when a reset occurs.
|
||||||
|
// For fixed mode: computes the next future reset time based on NOW(), timezone, configured day and hour.
|
||||||
|
// This correctly handles long-inactive accounts by jumping directly to the next valid reset point.
|
||||||
|
const nextWeeklyResetAtExpr = `(
|
||||||
|
CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed'
|
||||||
|
THEN to_char((
|
||||||
|
-- Compute this week's reset point in the configured timezone
|
||||||
|
-- Step 1: get today's date at reset hour in configured tz
|
||||||
|
-- Step 2: compute days forward to target weekday
|
||||||
|
-- Step 3: if same day but past reset hour, advance 7 days
|
||||||
|
CASE
|
||||||
|
WHEN (
|
||||||
|
-- days_forward = (target_day - current_day + 7) % 7
|
||||||
|
(COALESCE((extra->>'quota_weekly_reset_day')::int, 1)
|
||||||
|
- EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int
|
||||||
|
+ 7) % 7
|
||||||
|
) = 0 AND NOW() >= (
|
||||||
|
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||||
|
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||||
|
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||||
|
-- Same weekday and past reset hour → next week
|
||||||
|
THEN (
|
||||||
|
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||||
|
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||||
|
+ '7 days'::interval
|
||||||
|
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||||
|
ELSE (
|
||||||
|
-- Advance to target weekday this week (or next if days_forward > 0)
|
||||||
|
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||||
|
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||||
|
+ ((
|
||||||
|
(COALESCE((extra->>'quota_weekly_reset_day')::int, 1)
|
||||||
|
- EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int
|
||||||
|
+ 7) % 7
|
||||||
|
) || ' days')::interval
|
||||||
|
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||||
|
END
|
||||||
|
) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
|
||||||
|
ELSE NULL END
|
||||||
|
)`
|
||||||
|
|
||||||
// IncrementQuotaUsed 原子递增账号的配额用量(总/日/周三个维度)
|
// IncrementQuotaUsed 原子递增账号的配额用量(总/日/周三个维度)
|
||||||
// 日/周额度在周期过期时自动重置为 0 再递增。
|
// 日/周额度在周期过期时自动重置为 0 再递增。
|
||||||
|
// 支持滚动窗口(rolling)和固定时间(fixed)两种重置模式。
|
||||||
func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||||
rows, err := r.sql.QueryContext(ctx,
|
rows, err := r.sql.QueryContext(ctx,
|
||||||
`UPDATE accounts SET extra = (
|
`UPDATE accounts SET extra = (
|
||||||
@@ -1739,31 +1830,35 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|
|||||||
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
|
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
|
||||||
jsonb_build_object(
|
jsonb_build_object(
|
||||||
'quota_daily_used',
|
'quota_daily_used',
|
||||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
CASE WHEN `+dailyExpiredExpr+`
|
||||||
+ '24 hours'::interval <= NOW()
|
|
||||||
THEN $1
|
THEN $1
|
||||||
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
|
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
|
||||||
'quota_daily_start',
|
'quota_daily_start',
|
||||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
CASE WHEN `+dailyExpiredExpr+`
|
||||||
+ '24 hours'::interval <= NOW()
|
|
||||||
THEN `+nowUTC+`
|
THEN `+nowUTC+`
|
||||||
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
|
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
|
||||||
)
|
)
|
||||||
|
-- 固定模式重置时更新下次重置时间
|
||||||
|
|| CASE WHEN `+dailyExpiredExpr+` AND `+nextDailyResetAtExpr+` IS NOT NULL
|
||||||
|
THEN jsonb_build_object('quota_daily_reset_at', `+nextDailyResetAtExpr+`)
|
||||||
|
ELSE '{}'::jsonb END
|
||||||
ELSE '{}'::jsonb END
|
ELSE '{}'::jsonb END
|
||||||
-- 周额度:仅在 quota_weekly_limit > 0 时处理
|
-- 周额度:仅在 quota_weekly_limit > 0 时处理
|
||||||
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
|
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
|
||||||
jsonb_build_object(
|
jsonb_build_object(
|
||||||
'quota_weekly_used',
|
'quota_weekly_used',
|
||||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
CASE WHEN `+weeklyExpiredExpr+`
|
||||||
+ '168 hours'::interval <= NOW()
|
|
||||||
THEN $1
|
THEN $1
|
||||||
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
|
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
|
||||||
'quota_weekly_start',
|
'quota_weekly_start',
|
||||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
CASE WHEN `+weeklyExpiredExpr+`
|
||||||
+ '168 hours'::interval <= NOW()
|
|
||||||
THEN `+nowUTC+`
|
THEN `+nowUTC+`
|
||||||
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
|
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
|
||||||
)
|
)
|
||||||
|
-- 固定模式重置时更新下次重置时间
|
||||||
|
|| CASE WHEN `+weeklyExpiredExpr+` AND `+nextWeeklyResetAtExpr+` IS NOT NULL
|
||||||
|
THEN jsonb_build_object('quota_weekly_reset_at', `+nextWeeklyResetAtExpr+`)
|
||||||
|
ELSE '{}'::jsonb END
|
||||||
ELSE '{}'::jsonb END
|
ELSE '{}'::jsonb END
|
||||||
), updated_at = NOW()
|
), updated_at = NOW()
|
||||||
WHERE id = $2 AND deleted_at IS NULL
|
WHERE id = $2 AND deleted_at IS NULL
|
||||||
@@ -1796,12 +1891,13 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ResetQuotaUsed 重置账号所有维度的配额用量为 0
|
// ResetQuotaUsed 重置账号所有维度的配额用量为 0
|
||||||
|
// 保留固定重置模式的配置字段(quota_daily_reset_mode 等),仅清零用量和窗口起始时间
|
||||||
func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error {
|
func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||||
_, err := r.sql.ExecContext(ctx,
|
_, err := r.sql.ExecContext(ctx,
|
||||||
`UPDATE accounts SET extra = (
|
`UPDATE accounts SET extra = (
|
||||||
COALESCE(extra, '{}'::jsonb)
|
COALESCE(extra, '{}'::jsonb)
|
||||||
|| '{"quota_used": 0, "quota_daily_used": 0, "quota_weekly_used": 0}'::jsonb
|
|| '{"quota_used": 0, "quota_daily_used": 0, "quota_weekly_used": 0}'::jsonb
|
||||||
) - 'quota_daily_start' - 'quota_weekly_start', updated_at = NOW()
|
) - 'quota_daily_start' - 'quota_weekly_start' - 'quota_daily_reset_at' - 'quota_weekly_reset_at', updated_at = NOW()
|
||||||
WHERE id = $1 AND deleted_at IS NULL`,
|
WHERE id = $1 AND deleted_at IS NULL`,
|
||||||
id)
|
id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -142,6 +142,35 @@ func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnDisabled() {
|
|||||||
s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status)
|
s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnCredentialsChange() {
|
||||||
|
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||||
|
Name: "sync-credentials-update",
|
||||||
|
Status: service.StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"gpt-5": "gpt-5.1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
cacheRecorder := &schedulerCacheRecorder{}
|
||||||
|
s.repo.schedulerCache = cacheRecorder
|
||||||
|
|
||||||
|
account.Credentials = map[string]any{
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"gpt-5": "gpt-5.2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err := s.repo.Update(s.ctx, account)
|
||||||
|
s.Require().NoError(err, "Update")
|
||||||
|
|
||||||
|
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||||
|
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||||
|
mapping, ok := cacheRecorder.setAccounts[0].Credentials["model_mapping"].(map[string]any)
|
||||||
|
s.Require().True(ok)
|
||||||
|
s.Require().Equal("gpt-5.2", mapping["gpt-5"])
|
||||||
|
}
|
||||||
|
|
||||||
func (s *AccountRepoSuite) TestDelete() {
|
func (s *AccountRepoSuite) TestDelete() {
|
||||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
|
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
|
||||||
|
|
||||||
@@ -185,6 +214,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
|||||||
accType string
|
accType string
|
||||||
status string
|
status string
|
||||||
search string
|
search string
|
||||||
|
groupID int64
|
||||||
wantCount int
|
wantCount int
|
||||||
validate func(accounts []service.Account)
|
validate func(accounts []service.Account)
|
||||||
}{
|
}{
|
||||||
@@ -236,6 +266,21 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
|||||||
s.Require().Contains(accounts[0].Name, "alpha")
|
s.Require().Contains(accounts[0].Name, "alpha")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "filter_by_ungrouped",
|
||||||
|
setup: func(client *dbent.Client) {
|
||||||
|
group := mustCreateGroup(s.T(), client, &service.Group{Name: "g-ungrouped"})
|
||||||
|
grouped := mustCreateAccount(s.T(), client, &service.Account{Name: "grouped-account"})
|
||||||
|
mustCreateAccount(s.T(), client, &service.Account{Name: "ungrouped-account"})
|
||||||
|
mustBindAccountToGroup(s.T(), client, grouped.ID, group.ID, 1)
|
||||||
|
},
|
||||||
|
groupID: service.AccountListGroupUngrouped,
|
||||||
|
wantCount: 1,
|
||||||
|
validate: func(accounts []service.Account) {
|
||||||
|
s.Require().Equal("ungrouped-account", accounts[0].Name)
|
||||||
|
s.Require().Empty(accounts[0].GroupIDs)
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -248,7 +293,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
|||||||
|
|
||||||
tt.setup(client)
|
tt.setup(client)
|
||||||
|
|
||||||
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, 0)
|
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, tt.groupID)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Len(accounts, tt.wantCount)
|
s.Require().Len(accounts, tt.wantCount)
|
||||||
if tt.validate != nil {
|
if tt.validate != nil {
|
||||||
|
|||||||
@@ -164,7 +164,6 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
|||||||
group.FieldModelRoutingEnabled,
|
group.FieldModelRoutingEnabled,
|
||||||
group.FieldModelRouting,
|
group.FieldModelRouting,
|
||||||
group.FieldMcpXMLInject,
|
group.FieldMcpXMLInject,
|
||||||
group.FieldSimulateClaudeMaxEnabled,
|
|
||||||
group.FieldSupportedModelScopes,
|
group.FieldSupportedModelScopes,
|
||||||
group.FieldAllowMessagesDispatch,
|
group.FieldAllowMessagesDispatch,
|
||||||
group.FieldDefaultMappedModel,
|
group.FieldDefaultMappedModel,
|
||||||
@@ -646,7 +645,6 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
|||||||
ModelRouting: g.ModelRouting,
|
ModelRouting: g.ModelRouting,
|
||||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||||
MCPXMLInject: g.McpXMLInject,
|
MCPXMLInject: g.McpXMLInject,
|
||||||
SimulateClaudeMaxEnabled: g.SimulateClaudeMaxEnabled,
|
|
||||||
SupportedModelScopes: g.SupportedModelScopes,
|
SupportedModelScopes: g.SupportedModelScopes,
|
||||||
SortOrder: g.SortOrder,
|
SortOrder: g.SortOrder,
|
||||||
AllowMessagesDispatch: g.AllowMessagesDispatch,
|
AllowMessagesDispatch: g.AllowMessagesDispatch,
|
||||||
|
|||||||
98
backend/internal/repository/backup_pg_dumper.go
Normal file
98
backend/internal/repository/backup_pg_dumper.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os/exec"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PgDumper implements service.DBDumper using pg_dump/psql
|
||||||
|
type PgDumper struct {
|
||||||
|
cfg *config.DatabaseConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPgDumper creates a new PgDumper
|
||||||
|
func NewPgDumper(cfg *config.Config) service.DBDumper {
|
||||||
|
return &PgDumper{cfg: &cfg.Database}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dump executes pg_dump and returns a streaming reader of the output
|
||||||
|
func (d *PgDumper) Dump(ctx context.Context) (io.ReadCloser, error) {
|
||||||
|
args := []string{
|
||||||
|
"-h", d.cfg.Host,
|
||||||
|
"-p", fmt.Sprintf("%d", d.cfg.Port),
|
||||||
|
"-U", d.cfg.User,
|
||||||
|
"-d", d.cfg.DBName,
|
||||||
|
"--no-owner",
|
||||||
|
"--no-acl",
|
||||||
|
"--clean",
|
||||||
|
"--if-exists",
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, "pg_dump", args...)
|
||||||
|
if d.cfg.Password != "" {
|
||||||
|
cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password)
|
||||||
|
}
|
||||||
|
if d.cfg.SSLMode != "" {
|
||||||
|
cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
stdout, err := cmd.StdoutPipe()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create stdout pipe: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
return nil, fmt.Errorf("start pg_dump: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 返回一个 ReadCloser:读 stdout,关闭时等待进程退出
|
||||||
|
return &cmdReadCloser{ReadCloser: stdout, cmd: cmd}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore executes psql to restore from a streaming reader
|
||||||
|
func (d *PgDumper) Restore(ctx context.Context, data io.Reader) error {
|
||||||
|
args := []string{
|
||||||
|
"-h", d.cfg.Host,
|
||||||
|
"-p", fmt.Sprintf("%d", d.cfg.Port),
|
||||||
|
"-U", d.cfg.User,
|
||||||
|
"-d", d.cfg.DBName,
|
||||||
|
"--single-transaction",
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, "psql", args...)
|
||||||
|
if d.cfg.Password != "" {
|
||||||
|
cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password)
|
||||||
|
}
|
||||||
|
if d.cfg.SSLMode != "" {
|
||||||
|
cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Stdin = data
|
||||||
|
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%v: %s", err, string(output))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cmdReadCloser wraps a command stdout pipe and waits for the process on Close
|
||||||
|
type cmdReadCloser struct {
|
||||||
|
io.ReadCloser
|
||||||
|
cmd *exec.Cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cmdReadCloser) Close() error {
|
||||||
|
// Close the pipe first
|
||||||
|
_ = c.ReadCloser.Close()
|
||||||
|
// Wait for the process to exit
|
||||||
|
if err := c.cmd.Wait(); err != nil {
|
||||||
|
return fmt.Errorf("pg_dump exited with error: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
117
backend/internal/repository/backup_s3_store.go
Normal file
117
backend/internal/repository/backup_s3_store.go
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go-v2/aws"
|
||||||
|
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
|
||||||
|
awsconfig "github.com/aws/aws-sdk-go-v2/config"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// S3BackupStore implements service.BackupObjectStore using AWS S3 compatible storage
|
||||||
|
type S3BackupStore struct {
|
||||||
|
client *s3.Client
|
||||||
|
bucket string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewS3BackupStoreFactory returns a BackupObjectStoreFactory that creates S3-backed stores
|
||||||
|
func NewS3BackupStoreFactory() service.BackupObjectStoreFactory {
|
||||||
|
return func(ctx context.Context, cfg *service.BackupS3Config) (service.BackupObjectStore, error) {
|
||||||
|
region := cfg.Region
|
||||||
|
if region == "" {
|
||||||
|
region = "auto" // Cloudflare R2 默认 region
|
||||||
|
}
|
||||||
|
|
||||||
|
awsCfg, err := awsconfig.LoadDefaultConfig(ctx,
|
||||||
|
awsconfig.WithRegion(region),
|
||||||
|
awsconfig.WithCredentialsProvider(
|
||||||
|
credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("load aws config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := s3.NewFromConfig(awsCfg, func(o *s3.Options) {
|
||||||
|
if cfg.Endpoint != "" {
|
||||||
|
o.BaseEndpoint = &cfg.Endpoint
|
||||||
|
}
|
||||||
|
if cfg.ForcePathStyle {
|
||||||
|
o.UsePathStyle = true
|
||||||
|
}
|
||||||
|
o.APIOptions = append(o.APIOptions, v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware)
|
||||||
|
o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
|
||||||
|
})
|
||||||
|
|
||||||
|
return &S3BackupStore{client: client, bucket: cfg.Bucket}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *S3BackupStore) Upload(ctx context.Context, key string, body io.Reader, contentType string) (int64, error) {
|
||||||
|
// 读取全部内容以获取大小(S3 PutObject 需要知道内容长度)
|
||||||
|
// 注意:阿里云 OSS 不兼容 s3manager 分片上传的签名方式,因此使用 PutObject
|
||||||
|
data, err := io.ReadAll(body)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("read body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = s.client.PutObject(ctx, &s3.PutObjectInput{
|
||||||
|
Bucket: &s.bucket,
|
||||||
|
Key: &key,
|
||||||
|
Body: bytes.NewReader(data),
|
||||||
|
ContentType: &contentType,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("S3 PutObject: %w", err)
|
||||||
|
}
|
||||||
|
return int64(len(data)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *S3BackupStore) Download(ctx context.Context, key string) (io.ReadCloser, error) {
|
||||||
|
result, err := s.client.GetObject(ctx, &s3.GetObjectInput{
|
||||||
|
Bucket: &s.bucket,
|
||||||
|
Key: &key,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("S3 GetObject: %w", err)
|
||||||
|
}
|
||||||
|
return result.Body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *S3BackupStore) Delete(ctx context.Context, key string) error {
|
||||||
|
_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
||||||
|
Bucket: &s.bucket,
|
||||||
|
Key: &key,
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *S3BackupStore) PresignURL(ctx context.Context, key string, expiry time.Duration) (string, error) {
|
||||||
|
presignClient := s3.NewPresignClient(s.client)
|
||||||
|
result, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{
|
||||||
|
Bucket: &s.bucket,
|
||||||
|
Key: &key,
|
||||||
|
}, s3.WithPresignExpires(expiry))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("presign url: %w", err)
|
||||||
|
}
|
||||||
|
return result.URL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *S3BackupStore) HeadBucket(ctx context.Context) error {
|
||||||
|
_, err := s.client.HeadBucket(ctx, &s3.HeadBucketInput{
|
||||||
|
Bucket: &s.bucket,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("S3 HeadBucket failed: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user