mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-05-05 05:30:44 +08:00
Compare commits
501 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c71b1d63e5 | ||
|
|
a07296770c | ||
|
|
8154575d70 | ||
|
|
d757df8a4b | ||
|
|
c5688fef9a | ||
|
|
19655a15f1 | ||
|
|
f345b0f595 | ||
|
|
58707f8a2a | ||
|
|
c6089ccb33 | ||
|
|
f585a15eff | ||
|
|
08e69af572 | ||
|
|
294b4bcbac | ||
|
|
67008b5d15 | ||
|
|
a29f5a4849 | ||
|
|
1b1c08f7fb | ||
|
|
0c72be0403 | ||
|
|
b4bd89b96b | ||
|
|
5bb8b2add6 | ||
|
|
93b42ccfea | ||
|
|
ff86154a03 | ||
|
|
fcee67e317 | ||
|
|
155900e62f | ||
|
|
9c514c9808 | ||
|
|
62e80c602d | ||
|
|
dbb248df52 | ||
|
|
66779f1c5f | ||
|
|
2c856b67ca | ||
|
|
bf45581104 | ||
|
|
e88b2890d1 | ||
|
|
1b5ae71d1f | ||
|
|
d4ff835bf1 | ||
|
|
e27b0adbc8 | ||
|
|
e59fa8637a | ||
|
|
58f758c816 | ||
|
|
feb6999d9a | ||
|
|
71f61bbc47 | ||
|
|
6d3ea64a35 | ||
|
|
1fca2bfab1 | ||
|
|
ce41afb756 | ||
|
|
b4a42a640d | ||
|
|
58b26cb4c8 | ||
|
|
b453c32743 | ||
|
|
3cd398b098 | ||
|
|
d3127b8eb1 | ||
|
|
6de1d0cb33 | ||
|
|
6c718578a5 | ||
|
|
0d241d52eb | ||
|
|
212eaa3a05 | ||
|
|
f3ab3fe5e2 | ||
|
|
b8c56ff940 | ||
|
|
38da737e6c | ||
|
|
1b2ea7a1df | ||
|
|
a9e5fc8539 | ||
|
|
9b213115e7 | ||
|
|
5534347328 | ||
|
|
2355029dc1 | ||
|
|
8d25335b01 | ||
|
|
c0b5900a37 | ||
|
|
35a9290528 | ||
|
|
c9145ad4d8 | ||
|
|
3851628a43 | ||
|
|
d72ac92694 | ||
|
|
2555951be4 | ||
|
|
669bff78c4 | ||
|
|
c90d1f2527 | ||
|
|
40cebc250f | ||
|
|
ddd495fb48 | ||
|
|
58f2044637 | ||
|
|
dfe3fdc1cc | ||
|
|
705131e172 | ||
|
|
88759407c7 | ||
|
|
6c99cc611c | ||
|
|
3457bcbfcd | ||
|
|
eb385457b2 | ||
|
|
4ea8b4cb4f | ||
|
|
91bdcf8994 | ||
|
|
8d03c52e15 | ||
|
|
0fbc9a44d3 | ||
|
|
632035aabd | ||
|
|
a51e0047b7 | ||
|
|
726730bb0e | ||
|
|
faff1771c4 | ||
|
|
b06cd06ec1 | ||
|
|
95751d8009 | ||
|
|
14e565a004 | ||
|
|
ce694701a9 | ||
|
|
12d03e4030 | ||
|
|
0b1ce6be8f | ||
|
|
28a6adaaa4 | ||
|
|
36990a0514 | ||
|
|
ebac0dc628 | ||
|
|
29d58f2414 | ||
|
|
dca0054e93 | ||
|
|
983fe58959 | ||
|
|
91c9b8d062 | ||
|
|
b384570de3 | ||
|
|
0507852a34 | ||
|
|
7b6ff135fb | ||
|
|
bf24de88ed | ||
|
|
ff6d4ab39a | ||
|
|
66fde7a2e6 | ||
|
|
e8efaa4cd9 | ||
|
|
00947d6492 | ||
|
|
cf70fb1b4e | ||
|
|
ef1a992cf0 | ||
|
|
1f6a73f0db | ||
|
|
f2e596f6ec | ||
|
|
b155bc564b | ||
|
|
055c48ab33 | ||
|
|
6663e1eda6 | ||
|
|
649afef512 | ||
|
|
4514f3fc11 | ||
|
|
095bef9554 | ||
|
|
83a16dec19 | ||
|
|
820c531814 | ||
|
|
1727b8df3b | ||
|
|
a025a15f5d | ||
|
|
72e5876c64 | ||
|
|
aeed2eb9ad | ||
|
|
46bc5ca73b | ||
|
|
0b3feb9d4c | ||
|
|
ca8692c747 | ||
|
|
318aa5e0d3 | ||
|
|
1dfd974432 | ||
|
|
cc396f59cf | ||
|
|
aa8b9cc508 | ||
|
|
6a2cf09ee0 | ||
|
|
c6fd88116b | ||
|
|
8f0dbdeaba | ||
|
|
007c09b84e | ||
|
|
73f3c068ef | ||
|
|
9a92fa4a60 | ||
|
|
576af710be | ||
|
|
b5642bd068 | ||
|
|
128f322252 | ||
|
|
17d7e57a2e | ||
|
|
50288e6b01 | ||
|
|
ab3e44e4bd | ||
|
|
61607990c8 | ||
|
|
b65275235f | ||
|
|
e298a71834 | ||
|
|
3f6fa1e3db | ||
|
|
f2c2abe628 | ||
|
|
ff5b467fbe | ||
|
|
8c10941142 | ||
|
|
f5764d8dc6 | ||
|
|
81ca4f12dd | ||
|
|
941c469ab9 | ||
|
|
8fcd819e6f | ||
|
|
9abdaed20c | ||
|
|
eb94342f78 | ||
|
|
d563eb2336 | ||
|
|
3ee6f085db | ||
|
|
7cca69a136 | ||
|
|
093a5a260e | ||
|
|
b6d46fd52f | ||
|
|
2c072c0ed6 | ||
|
|
1f39bf8a78 | ||
|
|
fdd8499ffc | ||
|
|
9398ea7af5 | ||
|
|
29dce1a59c | ||
|
|
c729ee425f | ||
|
|
c489f23810 | ||
|
|
47a544230a | ||
|
|
c13c81f09d | ||
|
|
20544a4447 | ||
|
|
b688ebeefa | ||
|
|
1854050df3 | ||
|
|
c7f4a649df | ||
|
|
ef5c8e6839 | ||
|
|
d571f300e5 | ||
|
|
ce96527dd9 | ||
|
|
f8b8b53985 | ||
|
|
b20e142249 | ||
|
|
7c6dc9dda8 | ||
|
|
5875571215 | ||
|
|
975e6b1563 | ||
|
|
f6fd7c83e3 | ||
|
|
c2965c0fb0 | ||
|
|
fdad55956e | ||
|
|
bb399e56b0 | ||
|
|
fa68cbad1b | ||
|
|
995ef1348a | ||
|
|
0f03393010 | ||
|
|
4b1ffc23f5 | ||
|
|
c7137dffa8 | ||
|
|
5a3375ce52 | ||
|
|
8e834fd9f5 | ||
|
|
02046744eb | ||
|
|
68d7ec9155 | ||
|
|
7537dce0f0 | ||
|
|
5f41b74707 | ||
|
|
25d961d4e0 | ||
|
|
08c4e514f8 | ||
|
|
91b1d812ce | ||
|
|
1f05d9f79d | ||
|
|
9f8cffe887 | ||
|
|
995bee143a | ||
|
|
f10e56be7e | ||
|
|
2f8e10db46 | ||
|
|
5418e15e63 | ||
|
|
bcf84cc153 | ||
|
|
ce8520c9e6 | ||
|
|
0b3928c33e | ||
|
|
73d72651b4 | ||
|
|
adbedd488c | ||
|
|
13b72f6bc2 | ||
|
|
c5aa96a3aa | ||
|
|
d927c0e45f | ||
|
|
31660c4c5f | ||
|
|
4321adab71 | ||
|
|
68f151f5c0 | ||
|
|
ecad083ffc | ||
|
|
fee43e8474 | ||
|
|
4838ab74b3 | ||
|
|
fef9259aaa | ||
|
|
ad7c10727a | ||
|
|
ccd42c1d1a | ||
|
|
bd8eadb75b | ||
|
|
70a9d0d3a2 | ||
|
|
7cd3824863 | ||
|
|
db9021f9c1 | ||
|
|
a2418c6040 | ||
|
|
1fb29d59b7 | ||
|
|
8c4a217f03 | ||
|
|
bda7c39e55 | ||
|
|
3583283ebb | ||
|
|
4feacf2213 | ||
|
|
73eb731881 | ||
|
|
186e36752d | ||
|
|
421728a985 | ||
|
|
39a5701184 | ||
|
|
27948c777e | ||
|
|
c64ed46d05 | ||
|
|
c64465ff7e | ||
|
|
095200bd16 | ||
|
|
2c667a159c | ||
|
|
bac408044f | ||
|
|
4edcfe1f7c | ||
|
|
9259dcb6f5 | ||
|
|
7ef933c7cf | ||
|
|
7d312822c1 | ||
|
|
1b3e5c6ea6 | ||
|
|
efe8401e92 | ||
|
|
0b845c2532 | ||
|
|
fe60412a17 | ||
|
|
5c39e6f2fb | ||
|
|
a225a241d7 | ||
|
|
553a486d17 | ||
|
|
c73374a221 | ||
|
|
94e26dee4f | ||
|
|
4617ef2bb8 | ||
|
|
8afa8c1091 | ||
|
|
578608d301 | ||
|
|
0d45d8669e | ||
|
|
73708da60d | ||
|
|
c810cad7c8 | ||
|
|
94bba415b1 | ||
|
|
4f7629a4cb | ||
|
|
4015f31f28 | ||
|
|
9dccbe1b07 | ||
|
|
9a88df7f28 | ||
|
|
a47f622e7e | ||
|
|
3529148455 | ||
|
|
01d8286bd9 | ||
|
|
21b6f2d593 | ||
|
|
528ff5d28c | ||
|
|
ba7d2aecbb | ||
|
|
0236b97d49 | ||
|
|
26f6b1eeff | ||
|
|
dc447ccebe | ||
|
|
7ec29638f4 | ||
|
|
4c9562af20 | ||
|
|
71942fd322 | ||
|
|
550b979ac5 | ||
|
|
3878a5a46f | ||
|
|
e443a6a1ea | ||
|
|
963494ec6f | ||
|
|
42d73118fd | ||
|
|
f2f819d70f | ||
|
|
525cdb8830 | ||
|
|
a6764e82f2 | ||
|
|
1de18b89dd | ||
|
|
882518c111 | ||
|
|
8027531d07 | ||
|
|
30706355a4 | ||
|
|
dfe99507b8 | ||
|
|
c1717c9a6c | ||
|
|
1fd1a58a7a | ||
|
|
fad07507be | ||
|
|
a20c211162 | ||
|
|
9f6ab6b817 | ||
|
|
bf3d6c0e6e | ||
|
|
241023f3fc | ||
|
|
1292c44b41 | ||
|
|
b4fce47049 | ||
|
|
e7780cd8c8 | ||
|
|
af96c8ea53 | ||
|
|
7d26b81075 | ||
|
|
b8ada63ac3 | ||
|
|
cfaac12af1 | ||
|
|
6028efd26c | ||
|
|
62a566ef2c | ||
|
|
94419f434c | ||
|
|
21f349c032 | ||
|
|
28e36f7925 | ||
|
|
6c02076333 | ||
|
|
7414bdf0e3 | ||
|
|
e6326b2929 | ||
|
|
17cdcebd04 | ||
|
|
a14babdc73 | ||
|
|
aadc6a763a | ||
|
|
f16af8bf88 | ||
|
|
5ceaef4500 | ||
|
|
1ac7219a92 | ||
|
|
d4cc9871c4 | ||
|
|
961c30e7c0 | ||
|
|
13e85b3147 | ||
|
|
50a3c7fa0b | ||
|
|
bd9d2671d7 | ||
|
|
62b40636e0 | ||
|
|
eeff451bc5 | ||
|
|
56fcb20f94 | ||
|
|
7134266acf | ||
|
|
2e4ac88ad9 | ||
|
|
51547fa216 | ||
|
|
2005fc97a8 | ||
|
|
0772d9250e | ||
|
|
aa6047c460 | ||
|
|
045cba78b4 | ||
|
|
8989d0d4b6 | ||
|
|
c521117b99 | ||
|
|
e0f52a8ab8 | ||
|
|
6c23fadf7e | ||
|
|
869952d113 | ||
|
|
07ab051ee4 | ||
|
|
f2d98fc0c7 | ||
|
|
2b41cec840 | ||
|
|
6cf77040e7 | ||
|
|
20b70bc5fd | ||
|
|
4905e7193a | ||
|
|
9c1f4b8e72 | ||
|
|
9857c17631 | ||
|
|
7e34bb946f | ||
|
|
47b748851b | ||
|
|
a6f99cf534 | ||
|
|
a120a6bc32 | ||
|
|
d557d1a190 | ||
|
|
e0286e5085 | ||
|
|
4b41e898a4 | ||
|
|
668e164793 | ||
|
|
fa2e6188d0 | ||
|
|
7fde9ebbc2 | ||
|
|
aef7c3b9bb | ||
|
|
a0b76bd608 | ||
|
|
c1fab7f8d8 | ||
|
|
f42c8f2abe | ||
|
|
aa5846b282 | ||
|
|
594a0ade38 | ||
|
|
d45cc23171 | ||
|
|
d795734352 | ||
|
|
4da9fdd1d5 | ||
|
|
6b218caa21 | ||
|
|
5c138007d0 | ||
|
|
1acfc46f46 | ||
|
|
fbffb08aae | ||
|
|
8640a62319 | ||
|
|
fa782e70a4 | ||
|
|
afd72abc6e | ||
|
|
71f72e167e | ||
|
|
6595c7601e | ||
|
|
67c0506290 | ||
|
|
6447be4534 | ||
|
|
3741617ebd | ||
|
|
ab4e8b2cf0 | ||
|
|
474165d7aa | ||
|
|
94e067a2e2 | ||
|
|
4293c89166 | ||
|
|
ec82c37da5 | ||
|
|
552a4b998a | ||
|
|
0d2061b268 | ||
|
|
8a260defc2 | ||
|
|
e14c87597a | ||
|
|
f3f19d35aa | ||
|
|
ced90e1d84 | ||
|
|
17e4033340 | ||
|
|
044d3a013d | ||
|
|
1fc9dd7b68 | ||
|
|
8147866c09 | ||
|
|
7bd1972f94 | ||
|
|
2c9dcfe27b | ||
|
|
1b79b0f3ff | ||
|
|
c637e6cf31 | ||
|
|
d3a9f5bb88 | ||
|
|
7eb0415a8a | ||
|
|
bdbc8fa08f | ||
|
|
63f3af0f94 | ||
|
|
686f890fbf | ||
|
|
220fbe6544 | ||
|
|
ae44a94325 | ||
|
|
3718d6dcd4 | ||
|
|
90b3838173 | ||
|
|
19d3ecc76f | ||
|
|
6fba4ebb13 | ||
|
|
c31974c913 | ||
|
|
6177fa5dd8 | ||
|
|
cfe72159d0 | ||
|
|
8321e4a647 | ||
|
|
3084330d0c | ||
|
|
b566649e79 | ||
|
|
10a6180e4a | ||
|
|
cbe9e78977 | ||
|
|
74145b1f39 | ||
|
|
359e56751b | ||
|
|
5899784aa4 | ||
|
|
9e8959c56d | ||
|
|
1bff2292a6 | ||
|
|
cf9247754e | ||
|
|
eefab15958 | ||
|
|
0e23732631 | ||
|
|
37c044fb4b | ||
|
|
6da5fa01b9 | ||
|
|
616930f9d3 | ||
|
|
b9c31fa7c4 | ||
|
|
17b339972c | ||
|
|
39f8bd91b9 | ||
|
|
aa4e37d085 | ||
|
|
f59b66b7d4 | ||
|
|
8f0ea7a02d | ||
|
|
a1dc00890e | ||
|
|
dfbcc363d1 | ||
|
|
1047f973d5 | ||
|
|
e32977dd73 | ||
|
|
b5f78ec1e8 | ||
|
|
e0f290fdc8 | ||
|
|
fc00a4e3b2 | ||
|
|
db1f6ded88 | ||
|
|
4644af2ccc | ||
|
|
2e3e8687e1 | ||
|
|
ca42a45802 | ||
|
|
9350ecb62b | ||
|
|
a4a026e8da | ||
|
|
342fd03e72 | ||
|
|
e3f1fd9b63 | ||
|
|
e4a4dfd038 | ||
|
|
a377e99088 | ||
|
|
1d3d7a3033 | ||
|
|
e7086cb3a3 | ||
|
|
4f2a97073e | ||
|
|
7407e3b45d | ||
|
|
01ef7340aa | ||
|
|
1c960d22c1 | ||
|
|
ece0606fed | ||
|
|
2666422b99 | ||
|
|
e6d59216d4 | ||
|
|
4e8615f276 | ||
|
|
91e4d95660 | ||
|
|
45456fa24c | ||
|
|
4588258d80 | ||
|
|
c12e48f966 | ||
|
|
ec8f50a658 | ||
|
|
99c9191784 | ||
|
|
6bb02d141f | ||
|
|
07bb2a5f3f | ||
|
|
417861a48e | ||
|
|
b7e878de64 | ||
|
|
05edb5514b | ||
|
|
e90ec847b6 | ||
|
|
6344fa2a86 | ||
|
|
7e288acc90 | ||
|
|
29b0e4a8a5 | ||
|
|
27ff222cfb | ||
|
|
11f7b83522 | ||
|
|
f7177be3b6 | ||
|
|
875b417fde | ||
|
|
2573107b32 | ||
|
|
5b85005945 | ||
|
|
1ee984478f | ||
|
|
fd693dc526 | ||
|
|
e73531ce9b | ||
|
|
53ad1645cf | ||
|
|
ecea13757b | ||
|
|
af9c4a7dd0 | ||
|
|
80d8d6c3bc | ||
|
|
d648811233 | ||
|
|
34695acb85 | ||
|
|
a63de12182 | ||
|
|
f16910d616 | ||
|
|
64b3f3cec1 | ||
|
|
6a685727d0 | ||
|
|
32d25f76fc | ||
|
|
69cafe8674 | ||
|
|
18ba8d9166 | ||
|
|
e97fd7e81c | ||
|
|
cdb64b0d33 | ||
|
|
8d4d3b03bb | ||
|
|
addefe79e1 | ||
|
|
b764d3b8f6 | ||
|
|
611fd884bd | ||
|
|
6826149a8f | ||
|
|
c9debc50b1 |
@@ -61,6 +61,9 @@ temp/
|
||||
deploy/install.sh
|
||||
deploy/sub2api.service
|
||||
deploy/sub2api-sudoers
|
||||
deploy/data/
|
||||
deploy/postgres_data/
|
||||
deploy/redis_data/
|
||||
|
||||
# GoReleaser
|
||||
.goreleaser.yaml
|
||||
|
||||
7
.gitattributes
vendored
7
.gitattributes
vendored
@@ -4,6 +4,13 @@ backend/migrations/*.sql text eol=lf
|
||||
# Go 源代码文件
|
||||
*.go text eol=lf
|
||||
|
||||
# 前端 源代码文件
|
||||
*.ts text eol=lf
|
||||
*.tsx text eol=lf
|
||||
*.js text eol=lf
|
||||
*.jsx text eol=lf
|
||||
*.vue text eol=lf
|
||||
|
||||
# Shell 脚本
|
||||
*.sh text eol=lf
|
||||
|
||||
|
||||
14
.github/audit-exceptions.yml
vendored
14
.github/audit-exceptions.yml
vendored
@@ -14,3 +14,17 @@ exceptions:
|
||||
mitigation: "Load only on export; restrict export permissions and data scope"
|
||||
expires_on: "2026-04-05"
|
||||
owner: "security@your-domain"
|
||||
- package: lodash
|
||||
advisory: "GHSA-r5fr-rjxr-66jc"
|
||||
severity: high
|
||||
reason: "lodash _.template not used with untrusted input; only internal admin UI templates"
|
||||
mitigation: "No user-controlled template strings; plan to migrate to lodash-es tree-shaken imports"
|
||||
expires_on: "2026-07-02"
|
||||
owner: "security@your-domain"
|
||||
- package: lodash-es
|
||||
advisory: "GHSA-r5fr-rjxr-66jc"
|
||||
severity: high
|
||||
reason: "lodash-es _.template not used with untrusted input; only internal admin UI templates"
|
||||
mitigation: "No user-controlled template strings; plan to migrate to native JS alternatives"
|
||||
expires_on: "2026-07-02"
|
||||
owner: "security@your-domain"
|
||||
|
||||
33
.github/workflows/release.yml
vendored
33
.github/workflows/release.yml
vendored
@@ -271,3 +271,36 @@ jobs:
|
||||
parse_mode: "Markdown",
|
||||
disable_web_page_preview: true
|
||||
}')"
|
||||
|
||||
sync-version-file:
|
||||
needs: [release]
|
||||
if: ${{ needs.release.result == 'success' }}
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout default branch
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
ref: ${{ github.event.repository.default_branch }}
|
||||
|
||||
- name: Sync VERSION file to released tag
|
||||
run: |
|
||||
if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
|
||||
VERSION=${{ github.event.inputs.tag }}
|
||||
VERSION=${VERSION#v}
|
||||
else
|
||||
VERSION=${GITHUB_REF#refs/tags/v}
|
||||
fi
|
||||
|
||||
CURRENT_VERSION=$(tr -d '\r\n' < backend/cmd/server/VERSION || true)
|
||||
if [ "$CURRENT_VERSION" = "$VERSION" ]; then
|
||||
echo "VERSION file already matches $VERSION"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "$VERSION" > backend/cmd/server/VERSION
|
||||
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||
git add backend/cmd/server/VERSION
|
||||
git commit -m "chore: sync VERSION to ${VERSION} [skip ci]"
|
||||
git push origin HEAD:${{ github.event.repository.default_branch }}
|
||||
|
||||
@@ -47,6 +47,8 @@ dockers:
|
||||
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:latest"
|
||||
dockerfile: Dockerfile.goreleaser
|
||||
use: buildx
|
||||
extra_files:
|
||||
- deploy/docker-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||
|
||||
@@ -63,6 +63,8 @@ dockers:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||
dockerfile: Dockerfile.goreleaser
|
||||
use: buildx
|
||||
extra_files:
|
||||
- deploy/docker-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||
@@ -76,6 +78,8 @@ dockers:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||
dockerfile: Dockerfile.goreleaser
|
||||
use: buildx
|
||||
extra_files:
|
||||
- deploy/docker-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||
@@ -89,6 +93,8 @@ dockers:
|
||||
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-amd64"
|
||||
dockerfile: Dockerfile.goreleaser
|
||||
use: buildx
|
||||
extra_files:
|
||||
- deploy/docker-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||
@@ -102,6 +108,8 @@ dockers:
|
||||
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-arm64"
|
||||
dockerfile: Dockerfile.goreleaser
|
||||
use: buildx
|
||||
extra_files:
|
||||
- deploy/docker-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||
|
||||
31
Dockerfile
31
Dockerfile
@@ -9,6 +9,7 @@
|
||||
ARG NODE_IMAGE=node:24-alpine
|
||||
ARG GOLANG_IMAGE=golang:1.26.1-alpine
|
||||
ARG ALPINE_IMAGE=alpine:3.21
|
||||
ARG POSTGRES_IMAGE=postgres:18-alpine
|
||||
ARG GOPROXY=https://goproxy.cn,direct
|
||||
ARG GOSUMDB=sum.golang.google.cn
|
||||
|
||||
@@ -73,7 +74,12 @@ RUN VERSION_VALUE="${VERSION}" && \
|
||||
./cmd/server
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Stage 3: Final Runtime Image
|
||||
# Stage 3: PostgreSQL Client (version-matched with docker-compose)
|
||||
# -----------------------------------------------------------------------------
|
||||
FROM ${POSTGRES_IMAGE} AS pg-client
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Stage 4: Final Runtime Image
|
||||
# -----------------------------------------------------------------------------
|
||||
FROM ${ALPINE_IMAGE}
|
||||
|
||||
@@ -86,8 +92,21 @@ LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
|
||||
RUN apk add --no-cache \
|
||||
ca-certificates \
|
||||
tzdata \
|
||||
su-exec \
|
||||
libpq \
|
||||
zstd-libs \
|
||||
lz4-libs \
|
||||
krb5-libs \
|
||||
libldap \
|
||||
libedit \
|
||||
&& rm -rf /var/cache/apk/*
|
||||
|
||||
# Copy pg_dump and psql from the same postgres image used in docker-compose
|
||||
# This ensures version consistency between backup tools and the database server
|
||||
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
|
||||
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
|
||||
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
|
||||
|
||||
# Create non-root user
|
||||
RUN addgroup -g 1000 sub2api && \
|
||||
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
||||
@@ -102,8 +121,9 @@ COPY --from=backend-builder --chown=sub2api:sub2api /app/backend/resources /app/
|
||||
# Create data directory
|
||||
RUN mkdir -p /app/data && chown sub2api:sub2api /app/data
|
||||
|
||||
# Switch to non-root user
|
||||
USER sub2api
|
||||
# Copy entrypoint script (fixes volume permissions then drops to sub2api)
|
||||
COPY deploy/docker-entrypoint.sh /app/docker-entrypoint.sh
|
||||
RUN chmod +x /app/docker-entrypoint.sh
|
||||
|
||||
# Expose port (can be overridden by SERVER_PORT env var)
|
||||
EXPOSE 8080
|
||||
@@ -112,5 +132,6 @@ EXPOSE 8080
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
||||
CMD wget -q -T 5 -O /dev/null http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
||||
|
||||
# Run the application
|
||||
ENTRYPOINT ["/app/sub2api"]
|
||||
# Run the application (entrypoint fixes /app/data ownership then execs as sub2api)
|
||||
ENTRYPOINT ["/app/docker-entrypoint.sh"]
|
||||
CMD ["/app/sub2api"]
|
||||
|
||||
@@ -5,7 +5,12 @@
|
||||
# It only packages the pre-built binary, no compilation needed.
|
||||
# =============================================================================
|
||||
|
||||
FROM alpine:3.19
|
||||
ARG ALPINE_IMAGE=alpine:3.21
|
||||
ARG POSTGRES_IMAGE=postgres:18-alpine
|
||||
|
||||
FROM ${POSTGRES_IMAGE} AS pg-client
|
||||
|
||||
FROM ${ALPINE_IMAGE}
|
||||
|
||||
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
|
||||
LABEL description="Sub2API - AI API Gateway Platform"
|
||||
@@ -16,8 +21,21 @@ RUN apk add --no-cache \
|
||||
ca-certificates \
|
||||
tzdata \
|
||||
curl \
|
||||
su-exec \
|
||||
libpq \
|
||||
zstd-libs \
|
||||
lz4-libs \
|
||||
krb5-libs \
|
||||
libldap \
|
||||
libedit \
|
||||
&& rm -rf /var/cache/apk/*
|
||||
|
||||
# Copy pg_dump and psql from a version-matched PostgreSQL image so backup and
|
||||
# restore work in the runtime container without requiring Docker socket access.
|
||||
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
|
||||
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
|
||||
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
|
||||
|
||||
# Create non-root user
|
||||
RUN addgroup -g 1000 sub2api && \
|
||||
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
||||
@@ -30,11 +48,15 @@ COPY sub2api /app/sub2api
|
||||
# Create data directory
|
||||
RUN mkdir -p /app/data && chown -R sub2api:sub2api /app
|
||||
|
||||
USER sub2api
|
||||
# Copy entrypoint script (fixes volume permissions then drops to sub2api)
|
||||
COPY deploy/docker-entrypoint.sh /app/docker-entrypoint.sh
|
||||
RUN chmod +x /app/docker-entrypoint.sh
|
||||
|
||||
EXPOSE 8080
|
||||
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
||||
CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
||||
|
||||
ENTRYPOINT ["/app/sub2api"]
|
||||
# Run the application (entrypoint fixes /app/data ownership then execs as sub2api)
|
||||
ENTRYPOINT ["/app/docker-entrypoint.sh"]
|
||||
CMD ["/app/sub2api"]
|
||||
|
||||
76
README.md
76
README.md
@@ -8,27 +8,31 @@
|
||||
[](https://redis.io/)
|
||||
[](https://www.docker.com/)
|
||||
|
||||
<a href="https://trendshift.io/repositories/21823" target="_blank"><img src="https://trendshift.io/api/badge/repositories/21823" alt="Wei-Shaw%2Fsub2api | Trendshift" width="250" height="55"/></a>
|
||||
|
||||
**AI API Gateway Platform for Subscription Quota Distribution**
|
||||
|
||||
English | [中文](README_CN.md)
|
||||
English | [中文](README_CN.md) | [日本語](README_JA.md)
|
||||
|
||||
</div>
|
||||
|
||||
> **Sub2API officially uses only the domains `sub2api.org` and `pincc.ai`. Other websites using the Sub2API name may be third-party deployments or services and are not affiliated with this project. Please verify and exercise your own judgment.**
|
||||
|
||||
---
|
||||
|
||||
## Demo
|
||||
|
||||
Try Sub2API online: **https://demo.sub2api.org/**
|
||||
Try Sub2API online: **[https://demo.sub2api.org/](https://demo.sub2api.org/)**
|
||||
|
||||
Demo credentials (shared demo environment; **not** created automatically for self-hosted installs):
|
||||
|
||||
| Email | Password |
|
||||
|-------|----------|
|
||||
| admin@sub2api.com | admin123 |
|
||||
| admin@sub2api.org | admin123 |
|
||||
|
||||
## Overview
|
||||
|
||||
Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions (like Claude Code $200/month). Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding.
|
||||
Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions. Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding.
|
||||
|
||||
## Features
|
||||
|
||||
@@ -39,6 +43,29 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
|
||||
- **Concurrency Control** - Per-user and per-account concurrency limits
|
||||
- **Rate Limiting** - Configurable request and token rate limits
|
||||
- **Admin Dashboard** - Web interface for monitoring and management
|
||||
- **External System Integration** - Embed external systems (e.g. payment, ticketing) via iframe to extend the admin dashboard
|
||||
|
||||
## Don't Want to Self-Host?
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td width="180" align="center" valign="middle"><a href="https://shop.pincc.ai/"><img src="assets/partners/logos/pincc-logo.png" alt="pincc" width="150"></a></td>
|
||||
<td valign="middle"><b><a href="https://shop.pincc.ai/">PinCC</a></b> is the official relay service built on Sub2API, offering stable access to Claude Code, Codex, Gemini and other popular models — ready to use, no deployment or maintenance required.</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width="180"><a href="https://www.packyapi.com/register?aff=sub2api"><img src="assets/partners/logos/packycode.png" alt="PackyCode" width="150"></a></td>
|
||||
<td>Thanks to PackyCode for sponsoring this project! PackyCode is a reliable and efficient API relay service provider, offering relay services for Claude Code, Codex, Gemini, and more. PackyCode provides special discounts for our software users: register using <a href="https://www.packyapi.com/register?aff=sub2api">this link</a> and enter the "sub2api" promo code during first recharge to get 10% off.</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Ecosystem
|
||||
|
||||
Community projects that extend or integrate with Sub2API:
|
||||
|
||||
| Project | Description | Features |
|
||||
|---------|-------------|----------|
|
||||
| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | Self-service payment system | Self-service top-up and subscription purchase; supports YiPay protocol, WeChat Pay, Alipay, Stripe; embeddable via iframe |
|
||||
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | Mobile admin console | Cross-platform app (iOS/Android/Web) for user management, account management, monitoring dashboard, and multi-backend switching; built with Expo + React Native |
|
||||
|
||||
## Tech Stack
|
||||
|
||||
@@ -51,10 +78,15 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
|
||||
|
||||
---
|
||||
|
||||
## Documentation
|
||||
## Nginx Reverse Proxy Note
|
||||
|
||||
- Dependency Security: `docs/dependency-security.md`
|
||||
- Admin Payment Integration API: `docs/ADMIN_PAYMENT_INTEGRATION_API.md`
|
||||
When using Nginx as a reverse proxy for Sub2API (or CRS) with Codex CLI, add the following to the `http` block in your Nginx configuration:
|
||||
|
||||
```nginx
|
||||
underscores_in_headers on;
|
||||
```
|
||||
|
||||
Nginx drops headers containing underscores by default (e.g. `session_id`), which breaks sticky session routing in multi-account setups.
|
||||
|
||||
---
|
||||
|
||||
@@ -150,10 +182,10 @@ mkdir -p sub2api-deploy && cd sub2api-deploy
|
||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
|
||||
|
||||
# Start services
|
||||
docker-compose up -d
|
||||
docker compose up -d
|
||||
|
||||
# View logs
|
||||
docker-compose logs -f sub2api
|
||||
docker compose logs -f sub2api
|
||||
```
|
||||
|
||||
**What the script does:**
|
||||
@@ -217,16 +249,16 @@ mkdir -p data postgres_data redis_data
|
||||
|
||||
# 5. Start all services
|
||||
# Option A: Local directory version (recommended - easy migration)
|
||||
docker-compose -f docker-compose.local.yml up -d
|
||||
docker compose -f docker-compose.local.yml up -d
|
||||
|
||||
# Option B: Named volumes version (simple setup)
|
||||
docker-compose up -d
|
||||
docker compose up -d
|
||||
|
||||
# 6. Check status
|
||||
docker-compose -f docker-compose.local.yml ps
|
||||
docker compose -f docker-compose.local.yml ps
|
||||
|
||||
# 7. View logs
|
||||
docker-compose -f docker-compose.local.yml logs -f sub2api
|
||||
docker compose -f docker-compose.local.yml logs -f sub2api
|
||||
```
|
||||
|
||||
#### Deployment Versions
|
||||
@@ -244,15 +276,15 @@ Open `http://YOUR_SERVER_IP:8080` in your browser.
|
||||
|
||||
If admin password was auto-generated, find it in logs:
|
||||
```bash
|
||||
docker-compose -f docker-compose.local.yml logs sub2api | grep "admin password"
|
||||
docker compose -f docker-compose.local.yml logs sub2api | grep "admin password"
|
||||
```
|
||||
|
||||
#### Upgrade
|
||||
|
||||
```bash
|
||||
# Pull latest image and recreate container
|
||||
docker-compose -f docker-compose.local.yml pull
|
||||
docker-compose -f docker-compose.local.yml up -d
|
||||
docker compose -f docker-compose.local.yml pull
|
||||
docker compose -f docker-compose.local.yml up -d
|
||||
```
|
||||
|
||||
#### Easy Migration (Local Directory Version)
|
||||
@@ -261,7 +293,7 @@ When using `docker-compose.local.yml`, migrate to a new server easily:
|
||||
|
||||
```bash
|
||||
# On source server
|
||||
docker-compose -f docker-compose.local.yml down
|
||||
docker compose -f docker-compose.local.yml down
|
||||
cd ..
|
||||
tar czf sub2api-complete.tar.gz sub2api-deploy/
|
||||
|
||||
@@ -271,23 +303,23 @@ scp sub2api-complete.tar.gz user@new-server:/path/
|
||||
# On new server
|
||||
tar xzf sub2api-complete.tar.gz
|
||||
cd sub2api-deploy/
|
||||
docker-compose -f docker-compose.local.yml up -d
|
||||
docker compose -f docker-compose.local.yml up -d
|
||||
```
|
||||
|
||||
#### Useful Commands
|
||||
|
||||
```bash
|
||||
# Stop all services
|
||||
docker-compose -f docker-compose.local.yml down
|
||||
docker compose -f docker-compose.local.yml down
|
||||
|
||||
# Restart
|
||||
docker-compose -f docker-compose.local.yml restart
|
||||
docker compose -f docker-compose.local.yml restart
|
||||
|
||||
# View all logs
|
||||
docker-compose -f docker-compose.local.yml logs -f
|
||||
docker compose -f docker-compose.local.yml logs -f
|
||||
|
||||
# Remove all data (caution!)
|
||||
docker-compose -f docker-compose.local.yml down
|
||||
docker compose -f docker-compose.local.yml down
|
||||
rm -rf data/ postgres_data/ redis_data/
|
||||
```
|
||||
|
||||
|
||||
79
README_CN.md
79
README_CN.md
@@ -8,27 +8,30 @@
|
||||
[](https://redis.io/)
|
||||
[](https://www.docker.com/)
|
||||
|
||||
<a href="https://trendshift.io/repositories/21823" target="_blank"><img src="https://trendshift.io/api/badge/repositories/21823" alt="Wei-Shaw%2Fsub2api | Trendshift" width="250" height="55"/></a>
|
||||
|
||||
**AI API 网关平台 - 订阅配额分发管理**
|
||||
|
||||
[English](README.md) | 中文
|
||||
[English](README.md) | 中文 | [日本語](README_JA.md)
|
||||
|
||||
</div>
|
||||
|
||||
> **Sub2API 官方仅使用 `sub2api.org` 与 `pincc.ai` 两个域名。其他使用 Sub2API 名义的网站可能为第三方部署或服务,与本项目无关,请自行甄别。**
|
||||
---
|
||||
|
||||
## 在线体验
|
||||
|
||||
体验地址:**https://v2.pincc.ai/**
|
||||
体验地址:**[https://demo.sub2api.org/](https://demo.sub2api.org/)**
|
||||
|
||||
演示账号(共享演示环境;自建部署不会自动创建该账号):
|
||||
|
||||
| 邮箱 | 密码 |
|
||||
|------|------|
|
||||
| admin@sub2api.com | admin123 |
|
||||
| admin@sub2api.org | admin123 |
|
||||
|
||||
## 项目概述
|
||||
|
||||
Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(如 Claude Code $200/月)的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。
|
||||
Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。
|
||||
|
||||
## 核心功能
|
||||
|
||||
@@ -39,6 +42,29 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
|
||||
- **并发控制** - 用户级和账号级并发限制
|
||||
- **速率限制** - 可配置的请求和 Token 速率限制
|
||||
- **管理后台** - Web 界面进行监控和管理
|
||||
- **外部系统集成** - 支持通过 iframe 嵌入外部系统(如支付、工单等),扩展管理后台功能
|
||||
|
||||
## 不想自建?试试官方中转
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td width="180" align="center" valign="middle"><a href="https://shop.pincc.ai/"><img src="assets/partners/logos/pincc-logo.png" alt="pincc" width="150"></a></td>
|
||||
<td valign="middle"><b><a href="https://shop.pincc.ai/">PinCC</a></b> 是基于 Sub2API 搭建的官方中转服务,提供 Claude Code、Codex、Gemini 等主流模型的稳定中转,开箱即用,免去自建部署与运维烦恼。</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width="180"><a href="https://www.packyapi.com/register?aff=sub2api"><img src="assets/partners/logos/packycode.png" alt="PackyCode" width="150"></a></td>
|
||||
<td>感谢 PackyCode 赞助了本项目!PackyCode 是一家稳定、高效的API中转服务商,提供 Claude Code、Codex、Gemini 等多种中转服务。PackyCode 为本软件的用户提供了特别优惠,使用<a href="https://www.packyapi.com/register?aff=sub2api">此链接</a>注册并在充值时填写"sub2api"优惠码,首次充值可以享受9折优惠!</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## 生态项目
|
||||
|
||||
围绕 Sub2API 的社区扩展与集成项目:
|
||||
|
||||
| 项目 | 说明 | 功能 |
|
||||
|------|------|------|
|
||||
| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | 自助支付系统 | 用户自助充值、自助订阅购买;兼容易支付协议、微信官方支付、支付宝官方支付、Stripe;支持 iframe 嵌入管理后台 |
|
||||
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | 移动端管理控制台 | 跨平台应用(iOS/Android/Web),支持用户管理、账号管理、监控看板、多后端切换;基于 Expo + React Native 构建 |
|
||||
|
||||
## 技术栈
|
||||
|
||||
@@ -51,17 +77,18 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
|
||||
|
||||
---
|
||||
|
||||
## 文档
|
||||
## Nginx 反向代理注意事项
|
||||
|
||||
- 依赖安全:`docs/dependency-security.md`
|
||||
通过 Nginx 反向代理 Sub2API(或 CRS 服务)并搭配 Codex CLI 使用时,需要在 Nginx 配置的 `http` 块中添加:
|
||||
|
||||
```nginx
|
||||
underscores_in_headers on;
|
||||
```
|
||||
|
||||
Nginx 默认会丢弃名称中含下划线的请求头(如 `session_id`),这会导致多账号环境下的粘性会话功能失效。
|
||||
|
||||
---
|
||||
|
||||
## OpenAI Responses 兼容注意事项
|
||||
|
||||
- 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id` 的 `tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`。
|
||||
- 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。
|
||||
|
||||
## 部署方式
|
||||
|
||||
### 方式一:脚本安装(推荐)
|
||||
@@ -154,10 +181,10 @@ mkdir -p sub2api-deploy && cd sub2api-deploy
|
||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
|
||||
|
||||
# 启动服务
|
||||
docker-compose up -d
|
||||
docker compose up -d
|
||||
|
||||
# 查看日志
|
||||
docker-compose logs -f sub2api
|
||||
docker compose logs -f sub2api
|
||||
```
|
||||
|
||||
**脚本功能:**
|
||||
@@ -221,16 +248,16 @@ mkdir -p data postgres_data redis_data
|
||||
|
||||
# 5. 启动所有服务
|
||||
# 选项 A:本地目录版(推荐 - 易于迁移)
|
||||
docker-compose -f docker-compose.local.yml up -d
|
||||
docker compose -f docker-compose.local.yml up -d
|
||||
|
||||
# 选项 B:命名卷版(简单设置)
|
||||
docker-compose up -d
|
||||
docker compose up -d
|
||||
|
||||
# 6. 查看状态
|
||||
docker-compose -f docker-compose.local.yml ps
|
||||
docker compose -f docker-compose.local.yml ps
|
||||
|
||||
# 7. 查看日志
|
||||
docker-compose -f docker-compose.local.yml logs -f sub2api
|
||||
docker compose -f docker-compose.local.yml logs -f sub2api
|
||||
```
|
||||
|
||||
#### 部署版本对比
|
||||
@@ -260,15 +287,15 @@ docker-compose -f docker-compose.local.yml logs -f sub2api
|
||||
|
||||
如果管理员密码是自动生成的,在日志中查找:
|
||||
```bash
|
||||
docker-compose -f docker-compose.local.yml logs sub2api | grep "admin password"
|
||||
docker compose -f docker-compose.local.yml logs sub2api | grep "admin password"
|
||||
```
|
||||
|
||||
#### 升级
|
||||
|
||||
```bash
|
||||
# 拉取最新镜像并重建容器
|
||||
docker-compose -f docker-compose.local.yml pull
|
||||
docker-compose -f docker-compose.local.yml up -d
|
||||
docker compose -f docker-compose.local.yml pull
|
||||
docker compose -f docker-compose.local.yml up -d
|
||||
```
|
||||
|
||||
#### 轻松迁移(本地目录版)
|
||||
@@ -277,7 +304,7 @@ docker-compose -f docker-compose.local.yml up -d
|
||||
|
||||
```bash
|
||||
# 源服务器
|
||||
docker-compose -f docker-compose.local.yml down
|
||||
docker compose -f docker-compose.local.yml down
|
||||
cd ..
|
||||
tar czf sub2api-complete.tar.gz sub2api-deploy/
|
||||
|
||||
@@ -287,23 +314,23 @@ scp sub2api-complete.tar.gz user@new-server:/path/
|
||||
# 新服务器
|
||||
tar xzf sub2api-complete.tar.gz
|
||||
cd sub2api-deploy/
|
||||
docker-compose -f docker-compose.local.yml up -d
|
||||
docker compose -f docker-compose.local.yml up -d
|
||||
```
|
||||
|
||||
#### 常用命令
|
||||
|
||||
```bash
|
||||
# 停止所有服务
|
||||
docker-compose -f docker-compose.local.yml down
|
||||
docker compose -f docker-compose.local.yml down
|
||||
|
||||
# 重启
|
||||
docker-compose -f docker-compose.local.yml restart
|
||||
docker compose -f docker-compose.local.yml restart
|
||||
|
||||
# 查看所有日志
|
||||
docker-compose -f docker-compose.local.yml logs -f
|
||||
docker compose -f docker-compose.local.yml logs -f
|
||||
|
||||
# 删除所有数据(谨慎!)
|
||||
docker-compose -f docker-compose.local.yml down
|
||||
docker compose -f docker-compose.local.yml down
|
||||
rm -rf data/ postgres_data/ redis_data/
|
||||
```
|
||||
|
||||
|
||||
589
README_JA.md
Normal file
589
README_JA.md
Normal file
@@ -0,0 +1,589 @@
|
||||
# Sub2API
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://golang.org/)
|
||||
[](https://vuejs.org/)
|
||||
[](https://www.postgresql.org/)
|
||||
[](https://redis.io/)
|
||||
[](https://www.docker.com/)
|
||||
|
||||
<a href="https://trendshift.io/repositories/21823" target="_blank"><img src="https://trendshift.io/api/badge/repositories/21823" alt="Wei-Shaw%2Fsub2api | Trendshift" width="250" height="55"/></a>
|
||||
|
||||
**サブスクリプションクォータ配分のための AI API ゲートウェイプラットフォーム**
|
||||
|
||||
[English](README.md) | [中文](README_CN.md) | 日本語
|
||||
|
||||
</div>
|
||||
|
||||
> **Sub2API が公式に使用しているドメインは `sub2api.org` と `pincc.ai` のみです。Sub2API の名称を使用している他のウェブサイトは、サードパーティによるデプロイやサービスであり、本プロジェクトとは一切関係がありません。ご利用の際はご自身で確認・判断をお願いします。**
|
||||
|
||||
---
|
||||
|
||||
## デモ
|
||||
|
||||
Sub2API をオンラインでお試しください: **[https://demo.sub2api.org/](https://demo.sub2api.org/)**
|
||||
|
||||
デモ用認証情報(共有デモ環境です。セルフホスト環境では**自動作成されません**):
|
||||
|
||||
| メールアドレス | パスワード |
|
||||
|-------|----------|
|
||||
| admin@sub2api.org | admin123 |
|
||||
|
||||
## 概要
|
||||
|
||||
Sub2API は、AI 製品のサブスクリプションから API クォータを配分・管理するために設計された AI API ゲートウェイプラットフォームです。ユーザーはプラットフォームが生成した API キーを通じて上流の AI サービスにアクセスでき、プラットフォームは認証、課金、負荷分散、リクエスト転送を処理します。
|
||||
|
||||
## 機能
|
||||
|
||||
- **マルチアカウント管理** - 複数の上流アカウントタイプ(OAuth、APIキー)をサポート
|
||||
- **APIキー配布** - ユーザー向けの APIキーの生成と管理
|
||||
- **精密な課金** - トークンレベルの使用量追跡とコスト計算
|
||||
- **スマートスケジューリング** - スティッキーセッション付きのインテリジェントなアカウント選択
|
||||
- **同時実行制御** - ユーザーごと・アカウントごとの同時実行数制限
|
||||
- **レート制限** - 設定可能なリクエスト数およびトークンレート制限
|
||||
- **管理ダッシュボード** - 監視・管理のための Web インターフェース
|
||||
- **外部システム連携** - 外部システム(決済、チケット管理など)を iframe 経由で管理ダッシュボードに埋め込み可能
|
||||
|
||||
## セルフホストが不要な方へ
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td width="180" align="center" valign="middle"><a href="https://shop.pincc.ai/"><img src="assets/partners/logos/pincc-logo.png" alt="pincc" width="150"></a></td>
|
||||
<td valign="middle"><b><a href="https://shop.pincc.ai/">PinCC</a></b> は Sub2API 上に構築された公式リレーサービスで、Claude Code、Codex、Gemini などの人気モデルへの安定したアクセスを提供します。デプロイやメンテナンスは不要で、すぐにご利用いただけます。</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width="180"><a href="https://www.packyapi.com/register?aff=sub2api"><img src="assets/partners/logos/packycode.png" alt="PackyCode" width="150"></a></td>
|
||||
<td>PackyCode のご支援に感謝します!PackyCode は Claude Code、Codex、Gemini などのリレーサービスを提供する信頼性の高い API 中継プラットフォームです。本ソフト利用者向けに特別割引があります:<a href="https://www.packyapi.com/register?aff=sub2api">このリンク</a>で登録し、チャージ時に「sub2api」クーポンを入力すると 10% オフになります。</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## エコシステム
|
||||
|
||||
Sub2API を拡張・統合するコミュニティプロジェクト:
|
||||
|
||||
| プロジェクト | 説明 | 機能 |
|
||||
|---------|-------------|----------|
|
||||
| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | セルフサービス決済システム | セルフサービスによるチャージおよびサブスクリプション購入。YiPay プロトコル、WeChat Pay、Alipay、Stripe 対応。iframe での埋め込み可能 |
|
||||
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | モバイル管理コンソール | ユーザー管理、アカウント管理、監視ダッシュボード、マルチバックエンド切り替えが可能なクロスプラットフォームアプリ(iOS/Android/Web)。Expo + React Native で構築 |
|
||||
|
||||
## 技術スタック
|
||||
|
||||
| コンポーネント | 技術 |
|
||||
|-----------|------------|
|
||||
| バックエンド | Go 1.25.7, Gin, Ent |
|
||||
| フロントエンド | Vue 3.4+, Vite 5+, TailwindCSS |
|
||||
| データベース | PostgreSQL 15+ |
|
||||
| キャッシュ/キュー | Redis 7+ |
|
||||
|
||||
---
|
||||
|
||||
## Nginx リバースプロキシに関する注意
|
||||
|
||||
Sub2API(または CRS)を Nginx でリバースプロキシし、Codex CLI と組み合わせて使用する場合、Nginx の `http` ブロックに以下の設定を追加してください:
|
||||
|
||||
```nginx
|
||||
underscores_in_headers on;
|
||||
```
|
||||
|
||||
Nginx はデフォルトでアンダースコアを含むヘッダー(例: `session_id`)を破棄するため、マルチアカウント構成でのスティッキーセッションルーティングに支障をきたします。
|
||||
|
||||
---
|
||||
|
||||
## デプロイ
|
||||
|
||||
### 方法1: スクリプトによるインストール(推奨)
|
||||
|
||||
GitHub Releases からビルド済みバイナリをダウンロードするワンクリックインストールスクリプトです。
|
||||
|
||||
#### 前提条件
|
||||
|
||||
- Linux サーバー(amd64 または arm64)
|
||||
- PostgreSQL 15+(インストール済みかつ稼働中)
|
||||
- Redis 7+(インストール済みかつ稼働中)
|
||||
- root 権限
|
||||
|
||||
#### インストール手順
|
||||
|
||||
```bash
|
||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash
|
||||
```
|
||||
|
||||
スクリプトは以下を実行します:
|
||||
1. システムアーキテクチャの検出
|
||||
2. 最新リリースのダウンロード
|
||||
3. バイナリを `/opt/sub2api` にインストール
|
||||
4. systemd サービスの作成
|
||||
5. システムユーザーと権限の設定
|
||||
|
||||
#### インストール後の作業
|
||||
|
||||
```bash
|
||||
# 1. サービスを起動
|
||||
sudo systemctl start sub2api
|
||||
|
||||
# 2. 起動時の自動起動を有効化
|
||||
sudo systemctl enable sub2api
|
||||
|
||||
# 3. ブラウザでセットアップウィザードを開く
|
||||
# http://YOUR_SERVER_IP:8080
|
||||
```
|
||||
|
||||
セットアップウィザードでは以下の設定を行います:
|
||||
- データベース設定
|
||||
- Redis 設定
|
||||
- 管理者アカウントの作成
|
||||
|
||||
#### アップグレード
|
||||
|
||||
**管理ダッシュボード**の左上にある**アップデートを確認**ボタンをクリックすることで、ダッシュボードから直接アップグレードできます。
|
||||
|
||||
Web インターフェースでは以下が可能です:
|
||||
- 新しいバージョンの自動確認
|
||||
- ワンクリックでのアップデートのダウンロードと適用
|
||||
- 必要に応じたロールバック
|
||||
|
||||
#### よく使うコマンド
|
||||
|
||||
```bash
|
||||
# ステータスを確認
|
||||
sudo systemctl status sub2api
|
||||
|
||||
# ログを表示
|
||||
sudo journalctl -u sub2api -f
|
||||
|
||||
# サービスを再起動
|
||||
sudo systemctl restart sub2api
|
||||
|
||||
# アンインストール
|
||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash -s -- uninstall -y
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 方法2: Docker Compose(推奨)
|
||||
|
||||
PostgreSQL と Redis のコンテナを含む Docker Compose でデプロイします。
|
||||
|
||||
#### 前提条件
|
||||
|
||||
- Docker 20.10+
|
||||
- Docker Compose v2+
|
||||
|
||||
#### クイックスタート(ワンクリックデプロイ)
|
||||
|
||||
自動デプロイスクリプトを使用して簡単にセットアップできます:
|
||||
|
||||
```bash
|
||||
# デプロイ用ディレクトリを作成
|
||||
mkdir -p sub2api-deploy && cd sub2api-deploy
|
||||
|
||||
# デプロイ準備スクリプトをダウンロードして実行
|
||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
|
||||
|
||||
# サービスを起動
|
||||
docker compose up -d
|
||||
|
||||
# ログを表示
|
||||
docker compose logs -f sub2api
|
||||
```
|
||||
|
||||
**スクリプトの動作内容:**
|
||||
- `docker-compose.local.yml`(`docker-compose.yml` として保存)と `.env.example` をダウンロード
|
||||
- セキュアな認証情報(JWT_SECRET、TOTP_ENCRYPTION_KEY、POSTGRES_PASSWORD)を自動生成
|
||||
- 自動生成されたシークレットで `.env` ファイルを作成
|
||||
- データディレクトリを作成(バックアップ・移行が容易なローカルディレクトリを使用)
|
||||
- 生成された認証情報を参照用に表示
|
||||
|
||||
#### 手動デプロイ
|
||||
|
||||
手動でセットアップする場合:
|
||||
|
||||
```bash
|
||||
# 1. リポジトリをクローン
|
||||
git clone https://github.com/Wei-Shaw/sub2api.git
|
||||
cd sub2api/deploy
|
||||
|
||||
# 2. 環境設定ファイルをコピー
|
||||
cp .env.example .env
|
||||
|
||||
# 3. 設定を編集(セキュアなパスワードを生成)
|
||||
nano .env
|
||||
```
|
||||
|
||||
**`.env` の必須設定:**
|
||||
|
||||
```bash
|
||||
# PostgreSQL パスワード(必須)
|
||||
POSTGRES_PASSWORD=your_secure_password_here
|
||||
|
||||
# JWT シークレット(推奨 - 再起動後もユーザーのログイン状態を保持)
|
||||
JWT_SECRET=your_jwt_secret_here
|
||||
|
||||
# TOTP 暗号化キー(推奨 - 再起動後も二要素認証を維持)
|
||||
TOTP_ENCRYPTION_KEY=your_totp_key_here
|
||||
|
||||
# オプション: 管理者アカウント
|
||||
ADMIN_EMAIL=admin@example.com
|
||||
ADMIN_PASSWORD=your_admin_password
|
||||
|
||||
# オプション: カスタムポート
|
||||
SERVER_PORT=8080
|
||||
```
|
||||
|
||||
**セキュアなシークレットの生成方法:**
|
||||
```bash
|
||||
# JWT_SECRET を生成
|
||||
openssl rand -hex 32
|
||||
|
||||
# TOTP_ENCRYPTION_KEY を生成
|
||||
openssl rand -hex 32
|
||||
|
||||
# POSTGRES_PASSWORD を生成
|
||||
openssl rand -hex 32
|
||||
```
|
||||
|
||||
```bash
|
||||
# 4. データディレクトリを作成(ローカルバージョンの場合)
|
||||
mkdir -p data postgres_data redis_data
|
||||
|
||||
# 5. すべてのサービスを起動
|
||||
# オプション A: ローカルディレクトリバージョン(推奨 - 移行が容易)
|
||||
docker compose -f docker-compose.local.yml up -d
|
||||
|
||||
# オプション B: 名前付きボリュームバージョン(シンプルなセットアップ)
|
||||
docker compose up -d
|
||||
|
||||
# 6. ステータスを確認
|
||||
docker compose -f docker-compose.local.yml ps
|
||||
|
||||
# 7. ログを表示
|
||||
docker compose -f docker-compose.local.yml logs -f sub2api
|
||||
```
|
||||
|
||||
#### デプロイバージョン
|
||||
|
||||
| バージョン | データストレージ | 移行 | 推奨用途 |
|
||||
|---------|-------------|-----------|----------|
|
||||
| **docker-compose.local.yml** | ローカルディレクトリ | ✅ 容易(ディレクトリ全体を tar) | 本番環境、頻繁なバックアップ |
|
||||
| **docker-compose.yml** | 名前付きボリューム | ⚠️ docker コマンドが必要 | シンプルなセットアップ |
|
||||
|
||||
**推奨:** データ管理が容易な `docker-compose.local.yml`(スクリプトによるデプロイ)を使用してください。
|
||||
|
||||
#### アクセス
|
||||
|
||||
ブラウザで `http://YOUR_SERVER_IP:8080` を開いてください。
|
||||
|
||||
管理者パスワードが自動生成された場合は、ログで確認できます:
|
||||
```bash
|
||||
docker compose -f docker-compose.local.yml logs sub2api | grep "admin password"
|
||||
```
|
||||
|
||||
#### アップグレード
|
||||
|
||||
```bash
|
||||
# 最新イメージをプルしてコンテナを再作成
|
||||
docker compose -f docker-compose.local.yml pull
|
||||
docker compose -f docker-compose.local.yml up -d
|
||||
```
|
||||
|
||||
#### 簡単な移行(ローカルディレクトリバージョン)
|
||||
|
||||
`docker-compose.local.yml` を使用している場合、新しいサーバーへの移行が簡単です:
|
||||
|
||||
```bash
|
||||
# 移行元サーバーにて
|
||||
docker compose -f docker-compose.local.yml down
|
||||
cd ..
|
||||
tar czf sub2api-complete.tar.gz sub2api-deploy/
|
||||
|
||||
# 新しいサーバーに転送
|
||||
scp sub2api-complete.tar.gz user@new-server:/path/
|
||||
|
||||
# 移行先サーバーにて
|
||||
tar xzf sub2api-complete.tar.gz
|
||||
cd sub2api-deploy/
|
||||
docker compose -f docker-compose.local.yml up -d
|
||||
```
|
||||
|
||||
#### よく使うコマンド
|
||||
|
||||
```bash
|
||||
# すべてのサービスを停止
|
||||
docker compose -f docker-compose.local.yml down
|
||||
|
||||
# 再起動
|
||||
docker compose -f docker-compose.local.yml restart
|
||||
|
||||
# すべてのログを表示
|
||||
docker compose -f docker-compose.local.yml logs -f
|
||||
|
||||
# すべてのデータを削除(注意!)
|
||||
docker compose -f docker-compose.local.yml down
|
||||
rm -rf data/ postgres_data/ redis_data/
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 方法3: ソースからビルド
|
||||
|
||||
開発やカスタマイズのためにソースコードからビルドして実行します。
|
||||
|
||||
#### 前提条件
|
||||
|
||||
- Go 1.21+
|
||||
- Node.js 18+
|
||||
- PostgreSQL 15+
|
||||
- Redis 7+
|
||||
|
||||
#### ビルド手順
|
||||
|
||||
```bash
|
||||
# 1. リポジトリをクローン
|
||||
git clone https://github.com/Wei-Shaw/sub2api.git
|
||||
cd sub2api
|
||||
|
||||
# 2. pnpm をインストール(未インストールの場合)
|
||||
npm install -g pnpm
|
||||
|
||||
# 3. フロントエンドをビルド
|
||||
cd frontend
|
||||
pnpm install
|
||||
pnpm run build
|
||||
# 出力先: ../backend/internal/web/dist/
|
||||
|
||||
# 4. フロントエンドを組み込んだバックエンドをビルド
|
||||
cd ../backend
|
||||
go build -tags embed -o sub2api ./cmd/server
|
||||
|
||||
# 5. 設定ファイルを作成
|
||||
cp ../deploy/config.example.yaml ./config.yaml
|
||||
|
||||
# 6. 設定を編集
|
||||
nano config.yaml
|
||||
```
|
||||
|
||||
> **注意:** `-tags embed` フラグはフロントエンドをバイナリに組み込みます。このフラグがない場合、バイナリはフロントエンド UI を提供しません。
|
||||
|
||||
**`config.yaml` の主要設定:**
|
||||
|
||||
```yaml
|
||||
server:
|
||||
host: "0.0.0.0"
|
||||
port: 8080
|
||||
mode: "release"
|
||||
|
||||
database:
|
||||
host: "localhost"
|
||||
port: 5432
|
||||
user: "postgres"
|
||||
password: "your_password"
|
||||
dbname: "sub2api"
|
||||
|
||||
redis:
|
||||
host: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
|
||||
jwt:
|
||||
secret: "change-this-to-a-secure-random-string"
|
||||
expire_hour: 24
|
||||
|
||||
default:
|
||||
user_concurrency: 5
|
||||
user_balance: 0
|
||||
api_key_prefix: "sk-"
|
||||
rate_multiplier: 1.0
|
||||
```
|
||||
|
||||
### Sora ステータス(一時的に利用不可)
|
||||
|
||||
> ⚠️ Sora 関連の機能は、上流統合およびメディア配信の技術的問題により一時的に利用できません。
|
||||
> 現時点では本番環境で Sora に依存しないでください。
|
||||
> 既存の `gateway.sora_*` 設定キーは予約されていますが、これらの問題が解決されるまで有効にならない場合があります。
|
||||
|
||||
`config.yaml` では追加のセキュリティ関連オプションも利用できます:
|
||||
|
||||
- `cors.allowed_origins` - CORS 許可リスト
|
||||
- `security.url_allowlist` - 上流/価格/CRS ホストの許可リスト
|
||||
- `security.url_allowlist.enabled` - URL バリデーションの有効/無効を切り替え(`false` で無効化。注意して使用)
|
||||
- `security.url_allowlist.allow_insecure_http` - バリデーション無効時に HTTP URL を許可
|
||||
- `security.url_allowlist.allow_private_hosts` - プライベート/ローカル IP アドレスを許可
|
||||
- `security.response_headers.enabled` - 設定可能なレスポンスヘッダーフィルタリングを有効化(無効時はデフォルトの許可リストを使用)
|
||||
- `security.csp` - Content-Security-Policy ヘッダーの制御
|
||||
- `billing.circuit_breaker` - 課金エラー時にフェイルクローズ
|
||||
- `server.trusted_proxies` - X-Forwarded-For パースの有効化
|
||||
- `turnstile.required` - リリースモードでの Turnstile 必須化
|
||||
|
||||
**⚠️ セキュリティ警告: HTTP URL 設定**
|
||||
|
||||
`security.url_allowlist.enabled=false` の場合、システムはデフォルトで最小限の URL バリデーションを行い、**HTTP URL を拒否**して HTTPS のみを許可します。HTTP URL を許可するには(開発環境や内部テスト用など)、以下を明示的に設定する必要があります:
|
||||
|
||||
```yaml
|
||||
security:
|
||||
url_allowlist:
|
||||
enabled: false # 許可リストチェックを無効化
|
||||
allow_insecure_http: true # HTTP URL を許可(⚠️ セキュリティリスクあり)
|
||||
```
|
||||
|
||||
**または環境変数で設定:**
|
||||
|
||||
```bash
|
||||
SECURITY_URL_ALLOWLIST_ENABLED=false
|
||||
SECURITY_URL_ALLOWLIST_ALLOW_INSECURE_HTTP=true
|
||||
```
|
||||
|
||||
**HTTP を許可するリスク:**
|
||||
- API キーとデータが**平文**で送信される(傍受の危険性)
|
||||
- **中間者攻撃(MITM)**を受けやすい
|
||||
- **本番環境には不適切**
|
||||
|
||||
**HTTP を使用すべき場面:**
|
||||
- ✅ ローカルサーバーでの開発・テスト(http://localhost)
|
||||
- ✅ 信頼できるエンドポイントを持つ内部ネットワーク
|
||||
- ✅ HTTPS 取得前のアカウント接続テスト
|
||||
- ❌ 本番環境(HTTPS のみを使用)
|
||||
|
||||
**この設定なしで表示されるエラー例:**
|
||||
```
|
||||
Invalid base URL: invalid url scheme: http
|
||||
```
|
||||
|
||||
URL バリデーションまたはレスポンスヘッダーフィルタリングを無効にする場合は、ネットワーク層を強化してください:
|
||||
- 上流ドメイン/IP のエグレス許可リストを適用
|
||||
- プライベート/ループバック/リンクローカル範囲をブロック
|
||||
- TLS のみのアウトバウンドトラフィックを強制
|
||||
- プロキシで機密性の高い上流レスポンスヘッダーを除去
|
||||
|
||||
```bash
|
||||
# 6. アプリケーションを実行
|
||||
./sub2api
|
||||
```
|
||||
|
||||
#### 開発モード
|
||||
|
||||
```bash
|
||||
# バックエンド(ホットリロード付き)
|
||||
cd backend
|
||||
go run ./cmd/server
|
||||
|
||||
# フロントエンド(ホットリロード付き)
|
||||
cd frontend
|
||||
pnpm run dev
|
||||
```
|
||||
|
||||
#### コード生成
|
||||
|
||||
`backend/ent/schema` を編集した場合、Ent + Wire を再生成してください:
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
go generate ./ent
|
||||
go generate ./cmd/server
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## シンプルモード
|
||||
|
||||
シンプルモードは、フル SaaS 機能を必要とせず、素早くアクセスしたい個人開発者や社内チーム向けに設計されています。
|
||||
|
||||
- 有効化: 環境変数 `RUN_MODE=simple` を設定
|
||||
- 違い: SaaS 関連機能を非表示にし、課金プロセスをスキップ
|
||||
- セキュリティに関する注意: 本番環境では `SIMPLE_MODE_CONFIRM=true` も設定する必要があります
|
||||
|
||||
---
|
||||
|
||||
## Antigravity サポート
|
||||
|
||||
Sub2API は [Antigravity](https://antigravity.so/) アカウントをサポートしています。認証後、Claude および Gemini モデル用の専用エンドポイントが利用可能になります。
|
||||
|
||||
### 専用エンドポイント
|
||||
|
||||
| エンドポイント | モデル |
|
||||
|----------|-------|
|
||||
| `/antigravity/v1/messages` | Claude モデル |
|
||||
| `/antigravity/v1beta/` | Gemini モデル |
|
||||
|
||||
### Claude Code の設定
|
||||
|
||||
```bash
|
||||
export ANTHROPIC_BASE_URL="http://localhost:8080/antigravity"
|
||||
export ANTHROPIC_AUTH_TOKEN="sk-xxx"
|
||||
```
|
||||
|
||||
### ハイブリッドスケジューリングモード
|
||||
|
||||
Antigravity アカウントはオプションの**ハイブリッドスケジューリング**をサポートしています。有効にすると、汎用エンドポイント `/v1/messages` および `/v1beta/` も Antigravity アカウントにリクエストをルーティングします。
|
||||
|
||||
> **⚠️ 警告**: Anthropic Claude と Antigravity Claude は**同じ会話コンテキスト内で混在させることはできません**。グループを使用して適切に分離してください。
|
||||
|
||||
### 既知の問題
|
||||
|
||||
Claude Code では、Plan Mode を自動的に終了できません。(通常、ネイティブの Claude API を使用する場合、計画が完了すると Claude Code はユーザーに計画を承認または拒否するオプションをポップアップ表示します。)
|
||||
|
||||
**回避策**: `Shift + Tab` を押して手動で Plan Mode を終了し、計画を承認または拒否するためのレスポンスを入力してください。
|
||||
|
||||
---
|
||||
|
||||
## プロジェクト構成
|
||||
|
||||
```
|
||||
sub2api/
|
||||
├── backend/ # Go バックエンドサービス
|
||||
│ ├── cmd/server/ # アプリケーションエントリ
|
||||
│ ├── internal/ # 内部モジュール
|
||||
│ │ ├── config/ # 設定
|
||||
│ │ ├── model/ # データモデル
|
||||
│ │ ├── service/ # ビジネスロジック
|
||||
│ │ ├── handler/ # HTTP ハンドラー
|
||||
│ │ └── gateway/ # API ゲートウェイコア
|
||||
│ └── resources/ # 静的リソース
|
||||
│
|
||||
├── frontend/ # Vue 3 フロントエンド
|
||||
│ └── src/
|
||||
│ ├── api/ # API 呼び出し
|
||||
│ ├── stores/ # 状態管理
|
||||
│ ├── views/ # ページコンポーネント
|
||||
│ └── components/ # 再利用可能なコンポーネント
|
||||
│
|
||||
└── deploy/ # デプロイファイル
|
||||
├── docker-compose.yml # Docker Compose 設定
|
||||
├── .env.example # Docker Compose 用環境変数
|
||||
├── config.example.yaml # バイナリデプロイ用フル設定ファイル
|
||||
└── install.sh # ワンクリックインストールスクリプト
|
||||
```
|
||||
|
||||
## 免責事項
|
||||
|
||||
> **本プロジェクトをご利用の前に、以下をよくお読みください:**
|
||||
>
|
||||
> :rotating_light: **利用規約違反のリスク**: 本プロジェクトの使用は Anthropic の利用規約に違反する可能性があります。使用前に Anthropic のユーザー契約をよくお読みください。本プロジェクトの使用に起因するすべてのリスクは、ユーザー自身が負うものとします。
|
||||
>
|
||||
> :book: **免責事項**: 本プロジェクトは技術的な学習および研究目的のみで提供されています。作者は、本プロジェクトの使用によるアカウント停止、サービス中断、その他の損失について一切の責任を負いません。
|
||||
|
||||
---
|
||||
|
||||
## スター履歴
|
||||
|
||||
<a href="https://star-history.com/#Wei-Shaw/sub2api&Date">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date&theme=dark" />
|
||||
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||
</picture>
|
||||
</a>
|
||||
|
||||
---
|
||||
|
||||
## ライセンス
|
||||
|
||||
MIT License
|
||||
|
||||
---
|
||||
|
||||
<div align="center">
|
||||
|
||||
**このプロジェクトが役に立ったら、ぜひスターをお願いします!**
|
||||
|
||||
</div>
|
||||
BIN
assets/partners/logos/packycode.png
Normal file
BIN
assets/partners/logos/packycode.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 8.1 KiB |
BIN
assets/partners/logos/pincc-logo.png
Normal file
BIN
assets/partners/logos/pincc-logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 171 KiB |
@@ -1 +1 @@
|
||||
0.1.88
|
||||
0.1.107
|
||||
|
||||
@@ -41,6 +41,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
// Server layer ProviderSet
|
||||
server.ProviderSet,
|
||||
|
||||
// Privacy client factory for OpenAI training opt-out
|
||||
providePrivacyClientFactory,
|
||||
|
||||
// BuildInfo provider
|
||||
provideServiceBuildInfo,
|
||||
|
||||
@@ -53,6 +56,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func providePrivacyClientFactory() service.PrivacyClientFactory {
|
||||
return repository.CreatePrivacyReqClient
|
||||
}
|
||||
|
||||
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||
return service.BuildInfo{
|
||||
Version: buildInfo.Version,
|
||||
@@ -87,6 +94,7 @@ func provideCleanup(
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
backupSvc *service.BackupService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -223,6 +231,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"BackupService", func() error {
|
||||
if backupSvc != nil {
|
||||
backupSvc.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
}
|
||||
|
||||
infraSteps := []cleanupStep{
|
||||
|
||||
@@ -49,6 +49,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
refreshTokenCache := repository.NewRefreshTokenCache(redisClient)
|
||||
settingRepository := repository.NewSettingRepository(client)
|
||||
groupRepository := repository.NewGroupRepository(client, db)
|
||||
channelRepository := repository.NewChannelRepository(db)
|
||||
settingService := service.ProvideSettingService(settingRepository, groupRepository, configConfig)
|
||||
emailCache := repository.NewEmailCache(redisClient)
|
||||
emailService := service.NewEmailService(settingRepository, emailCache)
|
||||
@@ -81,6 +82,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
userHandler := handler.NewUserHandler(userService)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
||||
usageBillingRepository := repository.NewUsageBillingRepository(client, db)
|
||||
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||
@@ -100,19 +102,19 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
|
||||
schedulerCache := repository.NewSchedulerCache(redisClient)
|
||||
accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
|
||||
soraAccountRepository := repository.NewSoraAccountRepository(db)
|
||||
proxyRepository := repository.NewProxyRepository(client, db)
|
||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository)
|
||||
privacyClientFactory := providePrivacyClientFactory()
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
|
||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
||||
groupHandler := admin.NewGroupHandler(adminService)
|
||||
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
||||
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
||||
openAIOAuthClient := repository.NewOpenAIOAuthClient()
|
||||
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
|
||||
openAIOAuthService.SetPrivacyClientFactory(privacyClientFactory)
|
||||
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
|
||||
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
|
||||
driveClient := repository.NewGeminiDriveClient()
|
||||
@@ -122,6 +124,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
|
||||
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
|
||||
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
|
||||
oauthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
||||
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
|
||||
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
|
||||
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||
@@ -129,21 +132,31 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
|
||||
usageCache := service.NewUsageCache()
|
||||
identityCache := repository.NewIdentityCache(redisClient)
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache)
|
||||
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
|
||||
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oauthRefreshAPI)
|
||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
|
||||
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI, tempUnschedCache)
|
||||
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
|
||||
tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client)
|
||||
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
|
||||
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
|
||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
|
||||
rpmCache := repository.NewRPMCache(redisClient)
|
||||
groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
|
||||
groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
|
||||
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
|
||||
dataManagementService := service.NewDataManagementService()
|
||||
dataManagementHandler := admin.NewDataManagementHandler(dataManagementService)
|
||||
backupObjectStoreFactory := repository.NewS3BackupStoreFactory()
|
||||
dbDumper := repository.NewPgDumper(configConfig)
|
||||
backupService := service.ProvideBackupService(settingRepository, configConfig, secretEncryptor, backupObjectStoreFactory, dbDumper)
|
||||
backupHandler := admin.NewBackupHandler(backupService, userService)
|
||||
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
||||
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
||||
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
|
||||
@@ -160,20 +173,17 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
billingService := service.NewBillingService(configConfig, pricingService)
|
||||
identityService := service.NewIdentityService(identityCache)
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
||||
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI)
|
||||
digestSessionStore := service.NewDigestSessionStore()
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
|
||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator)
|
||||
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver)
|
||||
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||
soraS3Storage := service.NewSoraS3Storage(settingService)
|
||||
settingService.SetOnS3UpdateCallback(soraS3Storage.RefreshClient)
|
||||
soraGenerationRepository := repository.NewSoraGenerationRepository(db)
|
||||
soraQuotaService := service.NewSoraQuotaService(userRepository, groupRepository, settingService)
|
||||
soraGenerationService := service.NewSoraGenerationService(soraGenerationRepository, soraS3Storage, soraQuotaService)
|
||||
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, soraS3Storage)
|
||||
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService)
|
||||
opsHandler := admin.NewOpsHandler(opsService)
|
||||
updateCache := repository.NewUpdateCache(redisClient)
|
||||
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
|
||||
@@ -194,27 +204,24 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
errorPassthroughCache := repository.NewErrorPassthroughCache(redisClient)
|
||||
errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache)
|
||||
errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService)
|
||||
tlsFingerprintProfileHandler := admin.NewTLSFingerprintProfileHandler(tlsFingerprintProfileService)
|
||||
adminAPIKeyHandler := admin.NewAdminAPIKeyHandler(adminService)
|
||||
scheduledTestPlanRepository := repository.NewScheduledTestPlanRepository(db)
|
||||
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
|
||||
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
|
||||
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
|
||||
channelHandler := admin.NewChannelHandler(channelService, billingService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler)
|
||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService)
|
||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
|
||||
soraSDKClient := service.ProvideSoraSDKClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository)
|
||||
soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
|
||||
soraGatewayService := service.NewSoraGatewayService(soraSDKClient, rateLimitService, httpUpstream, configConfig)
|
||||
soraClientHandler := handler.NewSoraClientHandler(soraGenerationService, soraQuotaService, soraS3Storage, soraGatewayService, gatewayService, soraMediaStorage, apiKeyService)
|
||||
soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, usageRecordWorkerPool, configConfig)
|
||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||
totpHandler := handler.NewTotpHandler(totpService)
|
||||
idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig)
|
||||
idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig)
|
||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, soraClientHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, idempotencyCleanupService)
|
||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, idempotencyCleanupService)
|
||||
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
|
||||
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
|
||||
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
|
||||
@@ -225,12 +232,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
|
||||
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
@@ -245,6 +251,10 @@ type Application struct {
|
||||
Cleanup func()
|
||||
}
|
||||
|
||||
func providePrivacyClientFactory() service.PrivacyClientFactory {
|
||||
return repository.CreatePrivacyReqClient
|
||||
}
|
||||
|
||||
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||
return service.BuildInfo{
|
||||
Version: buildInfo.Version,
|
||||
@@ -261,7 +271,6 @@ func provideCleanup(
|
||||
opsCleanup *service.OpsCleanupService,
|
||||
opsScheduledReport *service.OpsScheduledReportService,
|
||||
opsSystemLogSink *service.OpsSystemLogSink,
|
||||
soraMediaCleanup *service.SoraMediaCleanupService,
|
||||
schedulerSnapshot *service.SchedulerSnapshotService,
|
||||
tokenRefresh *service.TokenRefreshService,
|
||||
accountExpiry *service.AccountExpiryService,
|
||||
@@ -279,6 +288,7 @@ func provideCleanup(
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
backupSvc *service.BackupService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -308,12 +318,6 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"SoraMediaCleanupService", func() error {
|
||||
if soraMediaCleanup != nil {
|
||||
soraMediaCleanup.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"OpsAlertEvaluatorService", func() error {
|
||||
if opsAlertEvaluator != nil {
|
||||
opsAlertEvaluator.Stop()
|
||||
@@ -414,6 +418,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"BackupService", func() error {
|
||||
if backupSvc != nil {
|
||||
backupSvc.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
}
|
||||
|
||||
infraSteps := []cleanupStep{
|
||||
|
||||
@@ -57,7 +57,6 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||
&service.OpsCleanupService{},
|
||||
&service.OpsScheduledReportService{},
|
||||
opsSystemLogSinkSvc,
|
||||
&service.SoraMediaCleanupService{},
|
||||
schedulerSnapshotSvc,
|
||||
tokenRefreshSvc,
|
||||
accountExpirySvc,
|
||||
@@ -75,6 +74,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||
antigravityOAuthSvc,
|
||||
nil, // openAIGateway
|
||||
nil, // scheduledTestRunner
|
||||
nil, // backupSvc
|
||||
)
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
|
||||
@@ -29,6 +29,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
|
||||
"github.com/Wei-Shaw/sub2api/ent/setting"
|
||||
"github.com/Wei-Shaw/sub2api/ent/tlsfingerprintprofile"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usagelog"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
@@ -73,6 +74,8 @@ type Client struct {
|
||||
SecuritySecret *SecuritySecretClient
|
||||
// Setting is the client for interacting with the Setting builders.
|
||||
Setting *SettingClient
|
||||
// TLSFingerprintProfile is the client for interacting with the TLSFingerprintProfile builders.
|
||||
TLSFingerprintProfile *TLSFingerprintProfileClient
|
||||
// UsageCleanupTask is the client for interacting with the UsageCleanupTask builders.
|
||||
UsageCleanupTask *UsageCleanupTaskClient
|
||||
// UsageLog is the client for interacting with the UsageLog builders.
|
||||
@@ -112,6 +115,7 @@ func (c *Client) init() {
|
||||
c.RedeemCode = NewRedeemCodeClient(c.config)
|
||||
c.SecuritySecret = NewSecuritySecretClient(c.config)
|
||||
c.Setting = NewSettingClient(c.config)
|
||||
c.TLSFingerprintProfile = NewTLSFingerprintProfileClient(c.config)
|
||||
c.UsageCleanupTask = NewUsageCleanupTaskClient(c.config)
|
||||
c.UsageLog = NewUsageLogClient(c.config)
|
||||
c.User = NewUserClient(c.config)
|
||||
@@ -225,6 +229,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
|
||||
RedeemCode: NewRedeemCodeClient(cfg),
|
||||
SecuritySecret: NewSecuritySecretClient(cfg),
|
||||
Setting: NewSettingClient(cfg),
|
||||
TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
|
||||
UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
|
||||
UsageLog: NewUsageLogClient(cfg),
|
||||
User: NewUserClient(cfg),
|
||||
@@ -265,6 +270,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
|
||||
RedeemCode: NewRedeemCodeClient(cfg),
|
||||
SecuritySecret: NewSecuritySecretClient(cfg),
|
||||
Setting: NewSettingClient(cfg),
|
||||
TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
|
||||
UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
|
||||
UsageLog: NewUsageLogClient(cfg),
|
||||
User: NewUserClient(cfg),
|
||||
@@ -304,8 +310,9 @@ func (c *Client) Use(hooks ...Hook) {
|
||||
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
|
||||
c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PromoCode,
|
||||
c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
|
||||
c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup,
|
||||
c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription,
|
||||
c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User,
|
||||
c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
|
||||
c.UserSubscription,
|
||||
} {
|
||||
n.Use(hooks...)
|
||||
}
|
||||
@@ -318,8 +325,9 @@ func (c *Client) Intercept(interceptors ...Interceptor) {
|
||||
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
|
||||
c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PromoCode,
|
||||
c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
|
||||
c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup,
|
||||
c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription,
|
||||
c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User,
|
||||
c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
|
||||
c.UserSubscription,
|
||||
} {
|
||||
n.Intercept(interceptors...)
|
||||
}
|
||||
@@ -356,6 +364,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
|
||||
return c.SecuritySecret.mutate(ctx, m)
|
||||
case *SettingMutation:
|
||||
return c.Setting.mutate(ctx, m)
|
||||
case *TLSFingerprintProfileMutation:
|
||||
return c.TLSFingerprintProfile.mutate(ctx, m)
|
||||
case *UsageCleanupTaskMutation:
|
||||
return c.UsageCleanupTask.mutate(ctx, m)
|
||||
case *UsageLogMutation:
|
||||
@@ -2612,6 +2622,139 @@ func (c *SettingClient) mutate(ctx context.Context, m *SettingMutation) (Value,
|
||||
}
|
||||
}
|
||||
|
||||
// TLSFingerprintProfileClient is a client for the TLSFingerprintProfile schema.
|
||||
type TLSFingerprintProfileClient struct {
|
||||
config
|
||||
}
|
||||
|
||||
// NewTLSFingerprintProfileClient returns a client for the TLSFingerprintProfile from the given config.
|
||||
func NewTLSFingerprintProfileClient(c config) *TLSFingerprintProfileClient {
|
||||
return &TLSFingerprintProfileClient{config: c}
|
||||
}
|
||||
|
||||
// Use adds a list of mutation hooks to the hooks stack.
|
||||
// A call to `Use(f, g, h)` equals to `tlsfingerprintprofile.Hooks(f(g(h())))`.
|
||||
func (c *TLSFingerprintProfileClient) Use(hooks ...Hook) {
|
||||
c.hooks.TLSFingerprintProfile = append(c.hooks.TLSFingerprintProfile, hooks...)
|
||||
}
|
||||
|
||||
// Intercept adds a list of query interceptors to the interceptors stack.
|
||||
// A call to `Intercept(f, g, h)` equals to `tlsfingerprintprofile.Intercept(f(g(h())))`.
|
||||
func (c *TLSFingerprintProfileClient) Intercept(interceptors ...Interceptor) {
|
||||
c.inters.TLSFingerprintProfile = append(c.inters.TLSFingerprintProfile, interceptors...)
|
||||
}
|
||||
|
||||
// Create returns a builder for creating a TLSFingerprintProfile entity.
|
||||
func (c *TLSFingerprintProfileClient) Create() *TLSFingerprintProfileCreate {
|
||||
mutation := newTLSFingerprintProfileMutation(c.config, OpCreate)
|
||||
return &TLSFingerprintProfileCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// CreateBulk returns a builder for creating a bulk of TLSFingerprintProfile entities.
|
||||
func (c *TLSFingerprintProfileClient) CreateBulk(builders ...*TLSFingerprintProfileCreate) *TLSFingerprintProfileCreateBulk {
|
||||
return &TLSFingerprintProfileCreateBulk{config: c.config, builders: builders}
|
||||
}
|
||||
|
||||
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
|
||||
// a builder and applies setFunc on it.
|
||||
func (c *TLSFingerprintProfileClient) MapCreateBulk(slice any, setFunc func(*TLSFingerprintProfileCreate, int)) *TLSFingerprintProfileCreateBulk {
|
||||
rv := reflect.ValueOf(slice)
|
||||
if rv.Kind() != reflect.Slice {
|
||||
return &TLSFingerprintProfileCreateBulk{err: fmt.Errorf("calling to TLSFingerprintProfileClient.MapCreateBulk with wrong type %T, need slice", slice)}
|
||||
}
|
||||
builders := make([]*TLSFingerprintProfileCreate, rv.Len())
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
builders[i] = c.Create()
|
||||
setFunc(builders[i], i)
|
||||
}
|
||||
return &TLSFingerprintProfileCreateBulk{config: c.config, builders: builders}
|
||||
}
|
||||
|
||||
// Update returns an update builder for TLSFingerprintProfile.
|
||||
func (c *TLSFingerprintProfileClient) Update() *TLSFingerprintProfileUpdate {
|
||||
mutation := newTLSFingerprintProfileMutation(c.config, OpUpdate)
|
||||
return &TLSFingerprintProfileUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// UpdateOne returns an update builder for the given entity.
|
||||
func (c *TLSFingerprintProfileClient) UpdateOne(_m *TLSFingerprintProfile) *TLSFingerprintProfileUpdateOne {
|
||||
mutation := newTLSFingerprintProfileMutation(c.config, OpUpdateOne, withTLSFingerprintProfile(_m))
|
||||
return &TLSFingerprintProfileUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// UpdateOneID returns an update builder for the given id.
|
||||
func (c *TLSFingerprintProfileClient) UpdateOneID(id int64) *TLSFingerprintProfileUpdateOne {
|
||||
mutation := newTLSFingerprintProfileMutation(c.config, OpUpdateOne, withTLSFingerprintProfileID(id))
|
||||
return &TLSFingerprintProfileUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// Delete returns a delete builder for TLSFingerprintProfile.
|
||||
func (c *TLSFingerprintProfileClient) Delete() *TLSFingerprintProfileDelete {
|
||||
mutation := newTLSFingerprintProfileMutation(c.config, OpDelete)
|
||||
return &TLSFingerprintProfileDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// DeleteOne returns a builder for deleting the given entity.
|
||||
func (c *TLSFingerprintProfileClient) DeleteOne(_m *TLSFingerprintProfile) *TLSFingerprintProfileDeleteOne {
|
||||
return c.DeleteOneID(_m.ID)
|
||||
}
|
||||
|
||||
// DeleteOneID returns a builder for deleting the given entity by its id.
|
||||
func (c *TLSFingerprintProfileClient) DeleteOneID(id int64) *TLSFingerprintProfileDeleteOne {
|
||||
builder := c.Delete().Where(tlsfingerprintprofile.ID(id))
|
||||
builder.mutation.id = &id
|
||||
builder.mutation.op = OpDeleteOne
|
||||
return &TLSFingerprintProfileDeleteOne{builder}
|
||||
}
|
||||
|
||||
// Query returns a query builder for TLSFingerprintProfile.
|
||||
func (c *TLSFingerprintProfileClient) Query() *TLSFingerprintProfileQuery {
|
||||
return &TLSFingerprintProfileQuery{
|
||||
config: c.config,
|
||||
ctx: &QueryContext{Type: TypeTLSFingerprintProfile},
|
||||
inters: c.Interceptors(),
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns a TLSFingerprintProfile entity by its id.
|
||||
func (c *TLSFingerprintProfileClient) Get(ctx context.Context, id int64) (*TLSFingerprintProfile, error) {
|
||||
return c.Query().Where(tlsfingerprintprofile.ID(id)).Only(ctx)
|
||||
}
|
||||
|
||||
// GetX is like Get, but panics if an error occurs.
|
||||
func (c *TLSFingerprintProfileClient) GetX(ctx context.Context, id int64) *TLSFingerprintProfile {
|
||||
obj, err := c.Get(ctx, id)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return obj
|
||||
}
|
||||
|
||||
// Hooks returns the client hooks.
|
||||
func (c *TLSFingerprintProfileClient) Hooks() []Hook {
|
||||
return c.hooks.TLSFingerprintProfile
|
||||
}
|
||||
|
||||
// Interceptors returns the client interceptors.
|
||||
func (c *TLSFingerprintProfileClient) Interceptors() []Interceptor {
|
||||
return c.inters.TLSFingerprintProfile
|
||||
}
|
||||
|
||||
func (c *TLSFingerprintProfileClient) mutate(ctx context.Context, m *TLSFingerprintProfileMutation) (Value, error) {
|
||||
switch m.Op() {
|
||||
case OpCreate:
|
||||
return (&TLSFingerprintProfileCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||
case OpUpdate:
|
||||
return (&TLSFingerprintProfileUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||
case OpUpdateOne:
|
||||
return (&TLSFingerprintProfileUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||
case OpDelete, OpDeleteOne:
|
||||
return (&TLSFingerprintProfileDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
|
||||
default:
|
||||
return nil, fmt.Errorf("ent: unknown TLSFingerprintProfile mutation op: %q", m.Op())
|
||||
}
|
||||
}
|
||||
|
||||
// UsageCleanupTaskClient is a client for the UsageCleanupTask schema.
|
||||
type UsageCleanupTaskClient struct {
|
||||
config
|
||||
@@ -3889,16 +4032,16 @@ type (
|
||||
hooks struct {
|
||||
APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
|
||||
ErrorPassthroughRule, Group, IdempotencyRecord, PromoCode, PromoCodeUsage,
|
||||
Proxy, RedeemCode, SecuritySecret, Setting, UsageCleanupTask, UsageLog, User,
|
||||
UserAllowedGroup, UserAttributeDefinition, UserAttributeValue,
|
||||
UserSubscription []ent.Hook
|
||||
Proxy, RedeemCode, SecuritySecret, Setting, TLSFingerprintProfile,
|
||||
UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
|
||||
UserAttributeValue, UserSubscription []ent.Hook
|
||||
}
|
||||
inters struct {
|
||||
APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
|
||||
ErrorPassthroughRule, Group, IdempotencyRecord, PromoCode, PromoCodeUsage,
|
||||
Proxy, RedeemCode, SecuritySecret, Setting, UsageCleanupTask, UsageLog, User,
|
||||
UserAllowedGroup, UserAttributeDefinition, UserAttributeValue,
|
||||
UserSubscription []ent.Interceptor
|
||||
Proxy, RedeemCode, SecuritySecret, Setting, TLSFingerprintProfile,
|
||||
UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
|
||||
UserAttributeValue, UserSubscription []ent.Interceptor
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
|
||||
"github.com/Wei-Shaw/sub2api/ent/setting"
|
||||
"github.com/Wei-Shaw/sub2api/ent/tlsfingerprintprofile"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usagelog"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
@@ -107,6 +108,7 @@ func checkColumn(t, c string) error {
|
||||
redeemcode.Table: redeemcode.ValidColumn,
|
||||
securitysecret.Table: securitysecret.ValidColumn,
|
||||
setting.Table: setting.ValidColumn,
|
||||
tlsfingerprintprofile.Table: tlsfingerprintprofile.ValidColumn,
|
||||
usagecleanuptask.Table: usagecleanuptask.ValidColumn,
|
||||
usagelog.Table: usagelog.ValidColumn,
|
||||
user.Table: user.ValidColumn,
|
||||
|
||||
@@ -52,16 +52,6 @@ type Group struct {
|
||||
ImagePrice2k *float64 `json:"image_price_2k,omitempty"`
|
||||
// ImagePrice4k holds the value of the "image_price_4k" field.
|
||||
ImagePrice4k *float64 `json:"image_price_4k,omitempty"`
|
||||
// SoraImagePrice360 holds the value of the "sora_image_price_360" field.
|
||||
SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"`
|
||||
// SoraImagePrice540 holds the value of the "sora_image_price_540" field.
|
||||
SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"`
|
||||
// SoraVideoPricePerRequest holds the value of the "sora_video_price_per_request" field.
|
||||
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"`
|
||||
// SoraVideoPricePerRequestHd holds the value of the "sora_video_price_per_request_hd" field.
|
||||
SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"`
|
||||
// SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field.
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"`
|
||||
// 是否仅允许 Claude Code 客户端
|
||||
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
|
||||
// 非 Claude Code 请求降级使用的分组 ID
|
||||
@@ -80,6 +70,10 @@ type Group struct {
|
||||
SortOrder int `json:"sort_order,omitempty"`
|
||||
// 是否允许 /v1/messages 调度到此 OpenAI 分组
|
||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch,omitempty"`
|
||||
// 仅允许非 apikey 类型账号关联到此分组
|
||||
RequireOauthOnly bool `json:"require_oauth_only,omitempty"`
|
||||
// 调度时仅允许 privacy 已成功设置的账号
|
||||
RequirePrivacySet bool `json:"require_privacy_set,omitempty"`
|
||||
// 默认映射模型 ID,当账号级映射找不到时使用此值
|
||||
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
|
||||
// Edges holds the relations/edges for other nodes in the graph.
|
||||
@@ -190,11 +184,11 @@ func (*Group) scanValues(columns []string) ([]any, error) {
|
||||
switch columns[i] {
|
||||
case group.FieldModelRouting, group.FieldSupportedModelScopes:
|
||||
values[i] = new([]byte)
|
||||
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch:
|
||||
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet:
|
||||
values[i] = new(sql.NullBool)
|
||||
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd:
|
||||
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case group.FieldID, group.FieldDefaultValidityDays, group.FieldSoraStorageQuotaBytes, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder:
|
||||
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType, group.FieldDefaultMappedModel:
|
||||
values[i] = new(sql.NullString)
|
||||
@@ -331,40 +325,6 @@ func (_m *Group) assignValues(columns []string, values []any) error {
|
||||
_m.ImagePrice4k = new(float64)
|
||||
*_m.ImagePrice4k = value.Float64
|
||||
}
|
||||
case group.FieldSoraImagePrice360:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field sora_image_price_360", values[i])
|
||||
} else if value.Valid {
|
||||
_m.SoraImagePrice360 = new(float64)
|
||||
*_m.SoraImagePrice360 = value.Float64
|
||||
}
|
||||
case group.FieldSoraImagePrice540:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field sora_image_price_540", values[i])
|
||||
} else if value.Valid {
|
||||
_m.SoraImagePrice540 = new(float64)
|
||||
*_m.SoraImagePrice540 = value.Float64
|
||||
}
|
||||
case group.FieldSoraVideoPricePerRequest:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field sora_video_price_per_request", values[i])
|
||||
} else if value.Valid {
|
||||
_m.SoraVideoPricePerRequest = new(float64)
|
||||
*_m.SoraVideoPricePerRequest = value.Float64
|
||||
}
|
||||
case group.FieldSoraVideoPricePerRequestHd:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field sora_video_price_per_request_hd", values[i])
|
||||
} else if value.Valid {
|
||||
_m.SoraVideoPricePerRequestHd = new(float64)
|
||||
*_m.SoraVideoPricePerRequestHd = value.Float64
|
||||
}
|
||||
case group.FieldSoraStorageQuotaBytes:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field sora_storage_quota_bytes", values[i])
|
||||
} else if value.Valid {
|
||||
_m.SoraStorageQuotaBytes = value.Int64
|
||||
}
|
||||
case group.FieldClaudeCodeOnly:
|
||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field claude_code_only", values[i])
|
||||
@@ -425,6 +385,18 @@ func (_m *Group) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.AllowMessagesDispatch = value.Bool
|
||||
}
|
||||
case group.FieldRequireOauthOnly:
|
||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field require_oauth_only", values[i])
|
||||
} else if value.Valid {
|
||||
_m.RequireOauthOnly = value.Bool
|
||||
}
|
||||
case group.FieldRequirePrivacySet:
|
||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field require_privacy_set", values[i])
|
||||
} else if value.Valid {
|
||||
_m.RequirePrivacySet = value.Bool
|
||||
}
|
||||
case group.FieldDefaultMappedModel:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field default_mapped_model", values[i])
|
||||
@@ -574,29 +546,6 @@ func (_m *Group) String() string {
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.SoraImagePrice360; v != nil {
|
||||
builder.WriteString("sora_image_price_360=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.SoraImagePrice540; v != nil {
|
||||
builder.WriteString("sora_image_price_540=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.SoraVideoPricePerRequest; v != nil {
|
||||
builder.WriteString("sora_video_price_per_request=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.SoraVideoPricePerRequestHd; v != nil {
|
||||
builder.WriteString("sora_video_price_per_request_hd=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("sora_storage_quota_bytes=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageQuotaBytes))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("claude_code_only=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.ClaudeCodeOnly))
|
||||
builder.WriteString(", ")
|
||||
@@ -628,6 +577,12 @@ func (_m *Group) String() string {
|
||||
builder.WriteString("allow_messages_dispatch=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.AllowMessagesDispatch))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("require_oauth_only=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.RequireOauthOnly))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("require_privacy_set=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.RequirePrivacySet))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("default_mapped_model=")
|
||||
builder.WriteString(_m.DefaultMappedModel)
|
||||
builder.WriteByte(')')
|
||||
|
||||
@@ -49,16 +49,6 @@ const (
|
||||
FieldImagePrice2k = "image_price_2k"
|
||||
// FieldImagePrice4k holds the string denoting the image_price_4k field in the database.
|
||||
FieldImagePrice4k = "image_price_4k"
|
||||
// FieldSoraImagePrice360 holds the string denoting the sora_image_price_360 field in the database.
|
||||
FieldSoraImagePrice360 = "sora_image_price_360"
|
||||
// FieldSoraImagePrice540 holds the string denoting the sora_image_price_540 field in the database.
|
||||
FieldSoraImagePrice540 = "sora_image_price_540"
|
||||
// FieldSoraVideoPricePerRequest holds the string denoting the sora_video_price_per_request field in the database.
|
||||
FieldSoraVideoPricePerRequest = "sora_video_price_per_request"
|
||||
// FieldSoraVideoPricePerRequestHd holds the string denoting the sora_video_price_per_request_hd field in the database.
|
||||
FieldSoraVideoPricePerRequestHd = "sora_video_price_per_request_hd"
|
||||
// FieldSoraStorageQuotaBytes holds the string denoting the sora_storage_quota_bytes field in the database.
|
||||
FieldSoraStorageQuotaBytes = "sora_storage_quota_bytes"
|
||||
// FieldClaudeCodeOnly holds the string denoting the claude_code_only field in the database.
|
||||
FieldClaudeCodeOnly = "claude_code_only"
|
||||
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
|
||||
@@ -77,6 +67,10 @@ const (
|
||||
FieldSortOrder = "sort_order"
|
||||
// FieldAllowMessagesDispatch holds the string denoting the allow_messages_dispatch field in the database.
|
||||
FieldAllowMessagesDispatch = "allow_messages_dispatch"
|
||||
// FieldRequireOauthOnly holds the string denoting the require_oauth_only field in the database.
|
||||
FieldRequireOauthOnly = "require_oauth_only"
|
||||
// FieldRequirePrivacySet holds the string denoting the require_privacy_set field in the database.
|
||||
FieldRequirePrivacySet = "require_privacy_set"
|
||||
// FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database.
|
||||
FieldDefaultMappedModel = "default_mapped_model"
|
||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||
@@ -171,11 +165,6 @@ var Columns = []string{
|
||||
FieldImagePrice1k,
|
||||
FieldImagePrice2k,
|
||||
FieldImagePrice4k,
|
||||
FieldSoraImagePrice360,
|
||||
FieldSoraImagePrice540,
|
||||
FieldSoraVideoPricePerRequest,
|
||||
FieldSoraVideoPricePerRequestHd,
|
||||
FieldSoraStorageQuotaBytes,
|
||||
FieldClaudeCodeOnly,
|
||||
FieldFallbackGroupID,
|
||||
FieldFallbackGroupIDOnInvalidRequest,
|
||||
@@ -185,6 +174,8 @@ var Columns = []string{
|
||||
FieldSupportedModelScopes,
|
||||
FieldSortOrder,
|
||||
FieldAllowMessagesDispatch,
|
||||
FieldRequireOauthOnly,
|
||||
FieldRequirePrivacySet,
|
||||
FieldDefaultMappedModel,
|
||||
}
|
||||
|
||||
@@ -241,8 +232,6 @@ var (
|
||||
SubscriptionTypeValidator func(string) error
|
||||
// DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field.
|
||||
DefaultDefaultValidityDays int
|
||||
// DefaultSoraStorageQuotaBytes holds the default value on creation for the "sora_storage_quota_bytes" field.
|
||||
DefaultSoraStorageQuotaBytes int64
|
||||
// DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field.
|
||||
DefaultClaudeCodeOnly bool
|
||||
// DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field.
|
||||
@@ -255,6 +244,10 @@ var (
|
||||
DefaultSortOrder int
|
||||
// DefaultAllowMessagesDispatch holds the default value on creation for the "allow_messages_dispatch" field.
|
||||
DefaultAllowMessagesDispatch bool
|
||||
// DefaultRequireOauthOnly holds the default value on creation for the "require_oauth_only" field.
|
||||
DefaultRequireOauthOnly bool
|
||||
// DefaultRequirePrivacySet holds the default value on creation for the "require_privacy_set" field.
|
||||
DefaultRequirePrivacySet bool
|
||||
// DefaultDefaultMappedModel holds the default value on creation for the "default_mapped_model" field.
|
||||
DefaultDefaultMappedModel string
|
||||
// DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
||||
@@ -354,31 +347,6 @@ func ByImagePrice4k(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldImagePrice4k, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// BySoraImagePrice360 orders the results by the sora_image_price_360 field.
|
||||
func BySoraImagePrice360(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldSoraImagePrice360, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// BySoraImagePrice540 orders the results by the sora_image_price_540 field.
|
||||
func BySoraImagePrice540(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldSoraImagePrice540, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// BySoraVideoPricePerRequest orders the results by the sora_video_price_per_request field.
|
||||
func BySoraVideoPricePerRequest(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldSoraVideoPricePerRequest, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// BySoraVideoPricePerRequestHd orders the results by the sora_video_price_per_request_hd field.
|
||||
func BySoraVideoPricePerRequestHd(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldSoraVideoPricePerRequestHd, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// BySoraStorageQuotaBytes orders the results by the sora_storage_quota_bytes field.
|
||||
func BySoraStorageQuotaBytes(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldSoraStorageQuotaBytes, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByClaudeCodeOnly orders the results by the claude_code_only field.
|
||||
func ByClaudeCodeOnly(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldClaudeCodeOnly, opts...).ToFunc()
|
||||
@@ -414,6 +382,16 @@ func ByAllowMessagesDispatch(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldAllowMessagesDispatch, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByRequireOauthOnly orders the results by the require_oauth_only field.
|
||||
func ByRequireOauthOnly(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldRequireOauthOnly, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByRequirePrivacySet orders the results by the require_privacy_set field.
|
||||
func ByRequirePrivacySet(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldRequirePrivacySet, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByDefaultMappedModel orders the results by the default_mapped_model field.
|
||||
func ByDefaultMappedModel(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc()
|
||||
|
||||
@@ -140,31 +140,6 @@ func ImagePrice4k(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldImagePrice4k, v))
|
||||
}
|
||||
|
||||
// SoraImagePrice360 applies equality check predicate on the "sora_image_price_360" field. It's identical to SoraImagePrice360EQ.
|
||||
func SoraImagePrice360(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldSoraImagePrice360, v))
|
||||
}
|
||||
|
||||
// SoraImagePrice540 applies equality check predicate on the "sora_image_price_540" field. It's identical to SoraImagePrice540EQ.
|
||||
func SoraImagePrice540(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldSoraImagePrice540, v))
|
||||
}
|
||||
|
||||
// SoraVideoPricePerRequest applies equality check predicate on the "sora_video_price_per_request" field. It's identical to SoraVideoPricePerRequestEQ.
|
||||
func SoraVideoPricePerRequest(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequest, v))
|
||||
}
|
||||
|
||||
// SoraVideoPricePerRequestHd applies equality check predicate on the "sora_video_price_per_request_hd" field. It's identical to SoraVideoPricePerRequestHdEQ.
|
||||
func SoraVideoPricePerRequestHd(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v))
|
||||
}
|
||||
|
||||
// SoraStorageQuotaBytes applies equality check predicate on the "sora_storage_quota_bytes" field. It's identical to SoraStorageQuotaBytesEQ.
|
||||
func SoraStorageQuotaBytes(v int64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
|
||||
}
|
||||
|
||||
// ClaudeCodeOnly applies equality check predicate on the "claude_code_only" field. It's identical to ClaudeCodeOnlyEQ.
|
||||
func ClaudeCodeOnly(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
|
||||
@@ -200,6 +175,16 @@ func AllowMessagesDispatch(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldAllowMessagesDispatch, v))
|
||||
}
|
||||
|
||||
// RequireOauthOnly applies equality check predicate on the "require_oauth_only" field. It's identical to RequireOauthOnlyEQ.
|
||||
func RequireOauthOnly(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldRequireOauthOnly, v))
|
||||
}
|
||||
|
||||
// RequirePrivacySet applies equality check predicate on the "require_privacy_set" field. It's identical to RequirePrivacySetEQ.
|
||||
func RequirePrivacySet(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldRequirePrivacySet, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModel applies equality check predicate on the "default_mapped_model" field. It's identical to DefaultMappedModelEQ.
|
||||
func DefaultMappedModel(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v))
|
||||
@@ -1060,246 +1045,6 @@ func ImagePrice4kNotNil() predicate.Group {
|
||||
return predicate.Group(sql.FieldNotNull(FieldImagePrice4k))
|
||||
}
|
||||
|
||||
// SoraImagePrice360EQ applies the EQ predicate on the "sora_image_price_360" field.
|
||||
func SoraImagePrice360EQ(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldSoraImagePrice360, v))
|
||||
}
|
||||
|
||||
// SoraImagePrice360NEQ applies the NEQ predicate on the "sora_image_price_360" field.
|
||||
func SoraImagePrice360NEQ(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldSoraImagePrice360, v))
|
||||
}
|
||||
|
||||
// SoraImagePrice360In applies the In predicate on the "sora_image_price_360" field.
|
||||
func SoraImagePrice360In(vs ...float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldIn(FieldSoraImagePrice360, vs...))
|
||||
}
|
||||
|
||||
// SoraImagePrice360NotIn applies the NotIn predicate on the "sora_image_price_360" field.
|
||||
func SoraImagePrice360NotIn(vs ...float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldNotIn(FieldSoraImagePrice360, vs...))
|
||||
}
|
||||
|
||||
// SoraImagePrice360GT applies the GT predicate on the "sora_image_price_360" field.
|
||||
func SoraImagePrice360GT(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldGT(FieldSoraImagePrice360, v))
|
||||
}
|
||||
|
||||
// SoraImagePrice360GTE applies the GTE predicate on the "sora_image_price_360" field.
|
||||
func SoraImagePrice360GTE(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldGTE(FieldSoraImagePrice360, v))
|
||||
}
|
||||
|
||||
// SoraImagePrice360LT applies the LT predicate on the "sora_image_price_360" field.
|
||||
func SoraImagePrice360LT(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldLT(FieldSoraImagePrice360, v))
|
||||
}
|
||||
|
||||
// SoraImagePrice360LTE applies the LTE predicate on the "sora_image_price_360" field.
|
||||
func SoraImagePrice360LTE(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldLTE(FieldSoraImagePrice360, v))
|
||||
}
|
||||
|
||||
// SoraImagePrice360IsNil applies the IsNil predicate on the "sora_image_price_360" field.
|
||||
func SoraImagePrice360IsNil() predicate.Group {
|
||||
return predicate.Group(sql.FieldIsNull(FieldSoraImagePrice360))
|
||||
}
|
||||
|
||||
// SoraImagePrice360NotNil applies the NotNil predicate on the "sora_image_price_360" field.
|
||||
func SoraImagePrice360NotNil() predicate.Group {
|
||||
return predicate.Group(sql.FieldNotNull(FieldSoraImagePrice360))
|
||||
}
|
||||
|
||||
// SoraImagePrice540EQ applies the EQ predicate on the "sora_image_price_540" field.
|
||||
func SoraImagePrice540EQ(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldSoraImagePrice540, v))
|
||||
}
|
||||
|
||||
// SoraImagePrice540NEQ applies the NEQ predicate on the "sora_image_price_540" field.
|
||||
func SoraImagePrice540NEQ(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldSoraImagePrice540, v))
|
||||
}
|
||||
|
||||
// SoraImagePrice540In applies the In predicate on the "sora_image_price_540" field.
|
||||
func SoraImagePrice540In(vs ...float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldIn(FieldSoraImagePrice540, vs...))
|
||||
}
|
||||
|
||||
// SoraImagePrice540NotIn applies the NotIn predicate on the "sora_image_price_540" field.
|
||||
func SoraImagePrice540NotIn(vs ...float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldNotIn(FieldSoraImagePrice540, vs...))
|
||||
}
|
||||
|
||||
// SoraImagePrice540GT applies the GT predicate on the "sora_image_price_540" field.
|
||||
func SoraImagePrice540GT(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldGT(FieldSoraImagePrice540, v))
|
||||
}
|
||||
|
||||
// SoraImagePrice540GTE applies the GTE predicate on the "sora_image_price_540" field.
|
||||
func SoraImagePrice540GTE(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldGTE(FieldSoraImagePrice540, v))
|
||||
}
|
||||
|
||||
// SoraImagePrice540LT applies the LT predicate on the "sora_image_price_540" field.
|
||||
func SoraImagePrice540LT(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldLT(FieldSoraImagePrice540, v))
|
||||
}
|
||||
|
||||
// SoraImagePrice540LTE applies the LTE predicate on the "sora_image_price_540" field.
|
||||
func SoraImagePrice540LTE(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldLTE(FieldSoraImagePrice540, v))
|
||||
}
|
||||
|
||||
// SoraImagePrice540IsNil applies the IsNil predicate on the "sora_image_price_540" field.
|
||||
func SoraImagePrice540IsNil() predicate.Group {
|
||||
return predicate.Group(sql.FieldIsNull(FieldSoraImagePrice540))
|
||||
}
|
||||
|
||||
// SoraImagePrice540NotNil applies the NotNil predicate on the "sora_image_price_540" field.
|
||||
func SoraImagePrice540NotNil() predicate.Group {
|
||||
return predicate.Group(sql.FieldNotNull(FieldSoraImagePrice540))
|
||||
}
|
||||
|
||||
// SoraVideoPricePerRequestEQ applies the EQ predicate on the "sora_video_price_per_request" field.
|
||||
func SoraVideoPricePerRequestEQ(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequest, v))
|
||||
}
|
||||
|
||||
// SoraVideoPricePerRequestNEQ applies the NEQ predicate on the "sora_video_price_per_request" field.
|
||||
func SoraVideoPricePerRequestNEQ(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldSoraVideoPricePerRequest, v))
|
||||
}
|
||||
|
||||
// SoraVideoPricePerRequestIn applies the In predicate on the "sora_video_price_per_request" field.
|
||||
func SoraVideoPricePerRequestIn(vs ...float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldIn(FieldSoraVideoPricePerRequest, vs...))
|
||||
}
|
||||
|
||||
// SoraVideoPricePerRequestNotIn applies the NotIn predicate on the "sora_video_price_per_request" field.
|
||||
func SoraVideoPricePerRequestNotIn(vs ...float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldNotIn(FieldSoraVideoPricePerRequest, vs...))
|
||||
}
|
||||
|
||||
// SoraVideoPricePerRequestGT applies the GT predicate on the "sora_video_price_per_request" field.
|
||||
func SoraVideoPricePerRequestGT(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldGT(FieldSoraVideoPricePerRequest, v))
|
||||
}
|
||||
|
||||
// SoraVideoPricePerRequestGTE applies the GTE predicate on the "sora_video_price_per_request" field.
|
||||
func SoraVideoPricePerRequestGTE(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldGTE(FieldSoraVideoPricePerRequest, v))
|
||||
}
|
||||
|
||||
// SoraVideoPricePerRequestLT applies the LT predicate on the "sora_video_price_per_request" field.
|
||||
func SoraVideoPricePerRequestLT(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldLT(FieldSoraVideoPricePerRequest, v))
|
||||
}
|
||||
|
||||
// SoraVideoPricePerRequestLTE applies the LTE predicate on the "sora_video_price_per_request" field.
|
||||
func SoraVideoPricePerRequestLTE(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldLTE(FieldSoraVideoPricePerRequest, v))
|
||||
}
|
||||
|
||||
// SoraVideoPricePerRequestIsNil applies the IsNil predicate on the "sora_video_price_per_request" field.
|
||||
func SoraVideoPricePerRequestIsNil() predicate.Group {
|
||||
return predicate.Group(sql.FieldIsNull(FieldSoraVideoPricePerRequest))
|
||||
}
|
||||
|
||||
// SoraVideoPricePerRequestNotNil applies the NotNil predicate on the "sora_video_price_per_request" field.
|
||||
func SoraVideoPricePerRequestNotNil() predicate.Group {
|
||||
return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequest))
|
||||
}
|
||||
|
||||
// SoraVideoPricePerRequestHdEQ applies the EQ predicate on the "sora_video_price_per_request_hd" field.
|
||||
func SoraVideoPricePerRequestHdEQ(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v))
|
||||
}
|
||||
|
||||
// SoraVideoPricePerRequestHdNEQ applies the NEQ predicate on the "sora_video_price_per_request_hd" field.
|
||||
func SoraVideoPricePerRequestHdNEQ(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldSoraVideoPricePerRequestHd, v))
|
||||
}
|
||||
|
||||
// SoraVideoPricePerRequestHdIn applies the In predicate on the "sora_video_price_per_request_hd" field.
|
||||
func SoraVideoPricePerRequestHdIn(vs ...float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldIn(FieldSoraVideoPricePerRequestHd, vs...))
|
||||
}
|
||||
|
||||
// SoraVideoPricePerRequestHdNotIn applies the NotIn predicate on the "sora_video_price_per_request_hd" field.
|
||||
func SoraVideoPricePerRequestHdNotIn(vs ...float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldNotIn(FieldSoraVideoPricePerRequestHd, vs...))
|
||||
}
|
||||
|
||||
// SoraVideoPricePerRequestHdGT applies the GT predicate on the "sora_video_price_per_request_hd" field.
|
||||
func SoraVideoPricePerRequestHdGT(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldGT(FieldSoraVideoPricePerRequestHd, v))
|
||||
}
|
||||
|
||||
// SoraVideoPricePerRequestHdGTE applies the GTE predicate on the "sora_video_price_per_request_hd" field.
|
||||
func SoraVideoPricePerRequestHdGTE(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldGTE(FieldSoraVideoPricePerRequestHd, v))
|
||||
}
|
||||
|
||||
// SoraVideoPricePerRequestHdLT applies the LT predicate on the "sora_video_price_per_request_hd" field.
|
||||
func SoraVideoPricePerRequestHdLT(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldLT(FieldSoraVideoPricePerRequestHd, v))
|
||||
}
|
||||
|
||||
// SoraVideoPricePerRequestHdLTE applies the LTE predicate on the "sora_video_price_per_request_hd" field.
|
||||
func SoraVideoPricePerRequestHdLTE(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldLTE(FieldSoraVideoPricePerRequestHd, v))
|
||||
}
|
||||
|
||||
// SoraVideoPricePerRequestHdIsNil applies the IsNil predicate on the "sora_video_price_per_request_hd" field.
|
||||
func SoraVideoPricePerRequestHdIsNil() predicate.Group {
|
||||
return predicate.Group(sql.FieldIsNull(FieldSoraVideoPricePerRequestHd))
|
||||
}
|
||||
|
||||
// SoraVideoPricePerRequestHdNotNil applies the NotNil predicate on the "sora_video_price_per_request_hd" field.
|
||||
func SoraVideoPricePerRequestHdNotNil() predicate.Group {
|
||||
return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequestHd))
|
||||
}
|
||||
|
||||
// SoraStorageQuotaBytesEQ applies the EQ predicate on the "sora_storage_quota_bytes" field.
|
||||
func SoraStorageQuotaBytesEQ(v int64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
|
||||
}
|
||||
|
||||
// SoraStorageQuotaBytesNEQ applies the NEQ predicate on the "sora_storage_quota_bytes" field.
|
||||
func SoraStorageQuotaBytesNEQ(v int64) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldSoraStorageQuotaBytes, v))
|
||||
}
|
||||
|
||||
// SoraStorageQuotaBytesIn applies the In predicate on the "sora_storage_quota_bytes" field.
|
||||
func SoraStorageQuotaBytesIn(vs ...int64) predicate.Group {
|
||||
return predicate.Group(sql.FieldIn(FieldSoraStorageQuotaBytes, vs...))
|
||||
}
|
||||
|
||||
// SoraStorageQuotaBytesNotIn applies the NotIn predicate on the "sora_storage_quota_bytes" field.
|
||||
func SoraStorageQuotaBytesNotIn(vs ...int64) predicate.Group {
|
||||
return predicate.Group(sql.FieldNotIn(FieldSoraStorageQuotaBytes, vs...))
|
||||
}
|
||||
|
||||
// SoraStorageQuotaBytesGT applies the GT predicate on the "sora_storage_quota_bytes" field.
|
||||
func SoraStorageQuotaBytesGT(v int64) predicate.Group {
|
||||
return predicate.Group(sql.FieldGT(FieldSoraStorageQuotaBytes, v))
|
||||
}
|
||||
|
||||
// SoraStorageQuotaBytesGTE applies the GTE predicate on the "sora_storage_quota_bytes" field.
|
||||
func SoraStorageQuotaBytesGTE(v int64) predicate.Group {
|
||||
return predicate.Group(sql.FieldGTE(FieldSoraStorageQuotaBytes, v))
|
||||
}
|
||||
|
||||
// SoraStorageQuotaBytesLT applies the LT predicate on the "sora_storage_quota_bytes" field.
|
||||
func SoraStorageQuotaBytesLT(v int64) predicate.Group {
|
||||
return predicate.Group(sql.FieldLT(FieldSoraStorageQuotaBytes, v))
|
||||
}
|
||||
|
||||
// SoraStorageQuotaBytesLTE applies the LTE predicate on the "sora_storage_quota_bytes" field.
|
||||
func SoraStorageQuotaBytesLTE(v int64) predicate.Group {
|
||||
return predicate.Group(sql.FieldLTE(FieldSoraStorageQuotaBytes, v))
|
||||
}
|
||||
|
||||
// ClaudeCodeOnlyEQ applies the EQ predicate on the "claude_code_only" field.
|
||||
func ClaudeCodeOnlyEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
|
||||
@@ -1490,6 +1235,26 @@ func AllowMessagesDispatchNEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldAllowMessagesDispatch, v))
|
||||
}
|
||||
|
||||
// RequireOauthOnlyEQ applies the EQ predicate on the "require_oauth_only" field.
|
||||
func RequireOauthOnlyEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldRequireOauthOnly, v))
|
||||
}
|
||||
|
||||
// RequireOauthOnlyNEQ applies the NEQ predicate on the "require_oauth_only" field.
|
||||
func RequireOauthOnlyNEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldRequireOauthOnly, v))
|
||||
}
|
||||
|
||||
// RequirePrivacySetEQ applies the EQ predicate on the "require_privacy_set" field.
|
||||
func RequirePrivacySetEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldRequirePrivacySet, v))
|
||||
}
|
||||
|
||||
// RequirePrivacySetNEQ applies the NEQ predicate on the "require_privacy_set" field.
|
||||
func RequirePrivacySetNEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldRequirePrivacySet, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelEQ applies the EQ predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelEQ(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v))
|
||||
|
||||
@@ -258,76 +258,6 @@ func (_c *GroupCreate) SetNillableImagePrice4k(v *float64) *GroupCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
|
||||
func (_c *GroupCreate) SetSoraImagePrice360(v float64) *GroupCreate {
|
||||
_c.mutation.SetSoraImagePrice360(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableSoraImagePrice360(v *float64) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetSoraImagePrice360(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
|
||||
func (_c *GroupCreate) SetSoraImagePrice540(v float64) *GroupCreate {
|
||||
_c.mutation.SetSoraImagePrice540(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableSoraImagePrice540(v *float64) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetSoraImagePrice540(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
|
||||
func (_c *GroupCreate) SetSoraVideoPricePerRequest(v float64) *GroupCreate {
|
||||
_c.mutation.SetSoraVideoPricePerRequest(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableSoraVideoPricePerRequest(v *float64) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetSoraVideoPricePerRequest(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
|
||||
func (_c *GroupCreate) SetSoraVideoPricePerRequestHd(v float64) *GroupCreate {
|
||||
_c.mutation.SetSoraVideoPricePerRequestHd(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetSoraVideoPricePerRequestHd(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
||||
func (_c *GroupCreate) SetSoraStorageQuotaBytes(v int64) *GroupCreate {
|
||||
_c.mutation.SetSoraStorageQuotaBytes(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableSoraStorageQuotaBytes(v *int64) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetSoraStorageQuotaBytes(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||
func (_c *GroupCreate) SetClaudeCodeOnly(v bool) *GroupCreate {
|
||||
_c.mutation.SetClaudeCodeOnly(v)
|
||||
@@ -438,6 +368,34 @@ func (_c *GroupCreate) SetNillableAllowMessagesDispatch(v *bool) *GroupCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetRequireOauthOnly sets the "require_oauth_only" field.
|
||||
func (_c *GroupCreate) SetRequireOauthOnly(v bool) *GroupCreate {
|
||||
_c.mutation.SetRequireOauthOnly(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableRequireOauthOnly sets the "require_oauth_only" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableRequireOauthOnly(v *bool) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetRequireOauthOnly(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetRequirePrivacySet sets the "require_privacy_set" field.
|
||||
func (_c *GroupCreate) SetRequirePrivacySet(v bool) *GroupCreate {
|
||||
_c.mutation.SetRequirePrivacySet(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableRequirePrivacySet sets the "require_privacy_set" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableRequirePrivacySet(v *bool) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetRequirePrivacySet(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (_c *GroupCreate) SetDefaultMappedModel(v string) *GroupCreate {
|
||||
_c.mutation.SetDefaultMappedModel(v)
|
||||
@@ -617,10 +575,6 @@ func (_c *GroupCreate) defaults() error {
|
||||
v := group.DefaultDefaultValidityDays
|
||||
_c.mutation.SetDefaultValidityDays(v)
|
||||
}
|
||||
if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
|
||||
v := group.DefaultSoraStorageQuotaBytes
|
||||
_c.mutation.SetSoraStorageQuotaBytes(v)
|
||||
}
|
||||
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
|
||||
v := group.DefaultClaudeCodeOnly
|
||||
_c.mutation.SetClaudeCodeOnly(v)
|
||||
@@ -645,6 +599,14 @@ func (_c *GroupCreate) defaults() error {
|
||||
v := group.DefaultAllowMessagesDispatch
|
||||
_c.mutation.SetAllowMessagesDispatch(v)
|
||||
}
|
||||
if _, ok := _c.mutation.RequireOauthOnly(); !ok {
|
||||
v := group.DefaultRequireOauthOnly
|
||||
_c.mutation.SetRequireOauthOnly(v)
|
||||
}
|
||||
if _, ok := _c.mutation.RequirePrivacySet(); !ok {
|
||||
v := group.DefaultRequirePrivacySet
|
||||
_c.mutation.SetRequirePrivacySet(v)
|
||||
}
|
||||
if _, ok := _c.mutation.DefaultMappedModel(); !ok {
|
||||
v := group.DefaultDefaultMappedModel
|
||||
_c.mutation.SetDefaultMappedModel(v)
|
||||
@@ -701,9 +663,6 @@ func (_c *GroupCreate) check() error {
|
||||
if _, ok := _c.mutation.DefaultValidityDays(); !ok {
|
||||
return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
|
||||
return &ValidationError{Name: "sora_storage_quota_bytes", err: errors.New(`ent: missing required field "Group.sora_storage_quota_bytes"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
|
||||
return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)}
|
||||
}
|
||||
@@ -722,6 +681,12 @@ func (_c *GroupCreate) check() error {
|
||||
if _, ok := _c.mutation.AllowMessagesDispatch(); !ok {
|
||||
return &ValidationError{Name: "allow_messages_dispatch", err: errors.New(`ent: missing required field "Group.allow_messages_dispatch"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.RequireOauthOnly(); !ok {
|
||||
return &ValidationError{Name: "require_oauth_only", err: errors.New(`ent: missing required field "Group.require_oauth_only"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.RequirePrivacySet(); !ok {
|
||||
return &ValidationError{Name: "require_privacy_set", err: errors.New(`ent: missing required field "Group.require_privacy_set"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.DefaultMappedModel(); !ok {
|
||||
return &ValidationError{Name: "default_mapped_model", err: errors.New(`ent: missing required field "Group.default_mapped_model"`)}
|
||||
}
|
||||
@@ -825,26 +790,6 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(group.FieldImagePrice4k, field.TypeFloat64, value)
|
||||
_node.ImagePrice4k = &value
|
||||
}
|
||||
if value, ok := _c.mutation.SoraImagePrice360(); ok {
|
||||
_spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
|
||||
_node.SoraImagePrice360 = &value
|
||||
}
|
||||
if value, ok := _c.mutation.SoraImagePrice540(); ok {
|
||||
_spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
|
||||
_node.SoraImagePrice540 = &value
|
||||
}
|
||||
if value, ok := _c.mutation.SoraVideoPricePerRequest(); ok {
|
||||
_spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
|
||||
_node.SoraVideoPricePerRequest = &value
|
||||
}
|
||||
if value, ok := _c.mutation.SoraVideoPricePerRequestHd(); ok {
|
||||
_spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
|
||||
_node.SoraVideoPricePerRequestHd = &value
|
||||
}
|
||||
if value, ok := _c.mutation.SoraStorageQuotaBytes(); ok {
|
||||
_spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
||||
_node.SoraStorageQuotaBytes = value
|
||||
}
|
||||
if value, ok := _c.mutation.ClaudeCodeOnly(); ok {
|
||||
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
|
||||
_node.ClaudeCodeOnly = value
|
||||
@@ -881,6 +826,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value)
|
||||
_node.AllowMessagesDispatch = value
|
||||
}
|
||||
if value, ok := _c.mutation.RequireOauthOnly(); ok {
|
||||
_spec.SetField(group.FieldRequireOauthOnly, field.TypeBool, value)
|
||||
_node.RequireOauthOnly = value
|
||||
}
|
||||
if value, ok := _c.mutation.RequirePrivacySet(); ok {
|
||||
_spec.SetField(group.FieldRequirePrivacySet, field.TypeBool, value)
|
||||
_node.RequirePrivacySet = value
|
||||
}
|
||||
if value, ok := _c.mutation.DefaultMappedModel(); ok {
|
||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||
_node.DefaultMappedModel = value
|
||||
@@ -1329,120 +1282,6 @@ func (u *GroupUpsert) ClearImagePrice4k() *GroupUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
|
||||
func (u *GroupUpsert) SetSoraImagePrice360(v float64) *GroupUpsert {
|
||||
u.Set(group.FieldSoraImagePrice360, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateSoraImagePrice360() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldSoraImagePrice360)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddSoraImagePrice360 adds v to the "sora_image_price_360" field.
|
||||
func (u *GroupUpsert) AddSoraImagePrice360(v float64) *GroupUpsert {
|
||||
u.Add(group.FieldSoraImagePrice360, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
|
||||
func (u *GroupUpsert) ClearSoraImagePrice360() *GroupUpsert {
|
||||
u.SetNull(group.FieldSoraImagePrice360)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
|
||||
func (u *GroupUpsert) SetSoraImagePrice540(v float64) *GroupUpsert {
|
||||
u.Set(group.FieldSoraImagePrice540, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateSoraImagePrice540() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldSoraImagePrice540)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddSoraImagePrice540 adds v to the "sora_image_price_540" field.
|
||||
func (u *GroupUpsert) AddSoraImagePrice540(v float64) *GroupUpsert {
|
||||
u.Add(group.FieldSoraImagePrice540, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
|
||||
func (u *GroupUpsert) ClearSoraImagePrice540() *GroupUpsert {
|
||||
u.SetNull(group.FieldSoraImagePrice540)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
|
||||
func (u *GroupUpsert) SetSoraVideoPricePerRequest(v float64) *GroupUpsert {
|
||||
u.Set(group.FieldSoraVideoPricePerRequest, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateSoraVideoPricePerRequest() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldSoraVideoPricePerRequest)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field.
|
||||
func (u *GroupUpsert) AddSoraVideoPricePerRequest(v float64) *GroupUpsert {
|
||||
u.Add(group.FieldSoraVideoPricePerRequest, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
|
||||
func (u *GroupUpsert) ClearSoraVideoPricePerRequest() *GroupUpsert {
|
||||
u.SetNull(group.FieldSoraVideoPricePerRequest)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
|
||||
func (u *GroupUpsert) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsert {
|
||||
u.Set(group.FieldSoraVideoPricePerRequestHd, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateSoraVideoPricePerRequestHd() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldSoraVideoPricePerRequestHd)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field.
|
||||
func (u *GroupUpsert) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsert {
|
||||
u.Add(group.FieldSoraVideoPricePerRequestHd, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
|
||||
func (u *GroupUpsert) ClearSoraVideoPricePerRequestHd() *GroupUpsert {
|
||||
u.SetNull(group.FieldSoraVideoPricePerRequestHd)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
||||
func (u *GroupUpsert) SetSoraStorageQuotaBytes(v int64) *GroupUpsert {
|
||||
u.Set(group.FieldSoraStorageQuotaBytes, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateSoraStorageQuotaBytes() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldSoraStorageQuotaBytes)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
|
||||
func (u *GroupUpsert) AddSoraStorageQuotaBytes(v int64) *GroupUpsert {
|
||||
u.Add(group.FieldSoraStorageQuotaBytes, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||
func (u *GroupUpsert) SetClaudeCodeOnly(v bool) *GroupUpsert {
|
||||
u.Set(group.FieldClaudeCodeOnly, v)
|
||||
@@ -1587,6 +1426,30 @@ func (u *GroupUpsert) UpdateAllowMessagesDispatch() *GroupUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetRequireOauthOnly sets the "require_oauth_only" field.
|
||||
func (u *GroupUpsert) SetRequireOauthOnly(v bool) *GroupUpsert {
|
||||
u.Set(group.FieldRequireOauthOnly, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateRequireOauthOnly sets the "require_oauth_only" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateRequireOauthOnly() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldRequireOauthOnly)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetRequirePrivacySet sets the "require_privacy_set" field.
|
||||
func (u *GroupUpsert) SetRequirePrivacySet(v bool) *GroupUpsert {
|
||||
u.Set(group.FieldRequirePrivacySet, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateRequirePrivacySet sets the "require_privacy_set" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateRequirePrivacySet() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldRequirePrivacySet)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (u *GroupUpsert) SetDefaultMappedModel(v string) *GroupUpsert {
|
||||
u.Set(group.FieldDefaultMappedModel, v)
|
||||
@@ -1980,139 +1843,6 @@ func (u *GroupUpsertOne) ClearImagePrice4k() *GroupUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
|
||||
func (u *GroupUpsertOne) SetSoraImagePrice360(v float64) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetSoraImagePrice360(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddSoraImagePrice360 adds v to the "sora_image_price_360" field.
|
||||
func (u *GroupUpsertOne) AddSoraImagePrice360(v float64) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.AddSoraImagePrice360(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateSoraImagePrice360() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateSoraImagePrice360()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
|
||||
func (u *GroupUpsertOne) ClearSoraImagePrice360() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.ClearSoraImagePrice360()
|
||||
})
|
||||
}
|
||||
|
||||
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
|
||||
func (u *GroupUpsertOne) SetSoraImagePrice540(v float64) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetSoraImagePrice540(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddSoraImagePrice540 adds v to the "sora_image_price_540" field.
|
||||
func (u *GroupUpsertOne) AddSoraImagePrice540(v float64) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.AddSoraImagePrice540(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateSoraImagePrice540() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateSoraImagePrice540()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
|
||||
func (u *GroupUpsertOne) ClearSoraImagePrice540() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.ClearSoraImagePrice540()
|
||||
})
|
||||
}
|
||||
|
||||
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
|
||||
func (u *GroupUpsertOne) SetSoraVideoPricePerRequest(v float64) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetSoraVideoPricePerRequest(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field.
|
||||
func (u *GroupUpsertOne) AddSoraVideoPricePerRequest(v float64) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.AddSoraVideoPricePerRequest(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateSoraVideoPricePerRequest() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateSoraVideoPricePerRequest()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
|
||||
func (u *GroupUpsertOne) ClearSoraVideoPricePerRequest() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.ClearSoraVideoPricePerRequest()
|
||||
})
|
||||
}
|
||||
|
||||
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
|
||||
func (u *GroupUpsertOne) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetSoraVideoPricePerRequestHd(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field.
|
||||
func (u *GroupUpsertOne) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.AddSoraVideoPricePerRequestHd(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateSoraVideoPricePerRequestHd() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateSoraVideoPricePerRequestHd()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
|
||||
func (u *GroupUpsertOne) ClearSoraVideoPricePerRequestHd() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.ClearSoraVideoPricePerRequestHd()
|
||||
})
|
||||
}
|
||||
|
||||
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
||||
func (u *GroupUpsertOne) SetSoraStorageQuotaBytes(v int64) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetSoraStorageQuotaBytes(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
|
||||
func (u *GroupUpsertOne) AddSoraStorageQuotaBytes(v int64) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.AddSoraStorageQuotaBytes(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateSoraStorageQuotaBytes() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateSoraStorageQuotaBytes()
|
||||
})
|
||||
}
|
||||
|
||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||
func (u *GroupUpsertOne) SetClaudeCodeOnly(v bool) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
@@ -2281,6 +2011,34 @@ func (u *GroupUpsertOne) UpdateAllowMessagesDispatch() *GroupUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetRequireOauthOnly sets the "require_oauth_only" field.
|
||||
func (u *GroupUpsertOne) SetRequireOauthOnly(v bool) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetRequireOauthOnly(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRequireOauthOnly sets the "require_oauth_only" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateRequireOauthOnly() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateRequireOauthOnly()
|
||||
})
|
||||
}
|
||||
|
||||
// SetRequirePrivacySet sets the "require_privacy_set" field.
|
||||
func (u *GroupUpsertOne) SetRequirePrivacySet(v bool) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetRequirePrivacySet(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRequirePrivacySet sets the "require_privacy_set" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateRequirePrivacySet() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateRequirePrivacySet()
|
||||
})
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (u *GroupUpsertOne) SetDefaultMappedModel(v string) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
@@ -2842,139 +2600,6 @@ func (u *GroupUpsertBulk) ClearImagePrice4k() *GroupUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
|
||||
func (u *GroupUpsertBulk) SetSoraImagePrice360(v float64) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetSoraImagePrice360(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddSoraImagePrice360 adds v to the "sora_image_price_360" field.
|
||||
func (u *GroupUpsertBulk) AddSoraImagePrice360(v float64) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.AddSoraImagePrice360(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateSoraImagePrice360() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateSoraImagePrice360()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
|
||||
func (u *GroupUpsertBulk) ClearSoraImagePrice360() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.ClearSoraImagePrice360()
|
||||
})
|
||||
}
|
||||
|
||||
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
|
||||
func (u *GroupUpsertBulk) SetSoraImagePrice540(v float64) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetSoraImagePrice540(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddSoraImagePrice540 adds v to the "sora_image_price_540" field.
|
||||
func (u *GroupUpsertBulk) AddSoraImagePrice540(v float64) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.AddSoraImagePrice540(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateSoraImagePrice540() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateSoraImagePrice540()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
|
||||
func (u *GroupUpsertBulk) ClearSoraImagePrice540() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.ClearSoraImagePrice540()
|
||||
})
|
||||
}
|
||||
|
||||
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
|
||||
func (u *GroupUpsertBulk) SetSoraVideoPricePerRequest(v float64) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetSoraVideoPricePerRequest(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field.
|
||||
func (u *GroupUpsertBulk) AddSoraVideoPricePerRequest(v float64) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.AddSoraVideoPricePerRequest(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateSoraVideoPricePerRequest() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateSoraVideoPricePerRequest()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
|
||||
func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequest() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.ClearSoraVideoPricePerRequest()
|
||||
})
|
||||
}
|
||||
|
||||
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
|
||||
func (u *GroupUpsertBulk) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetSoraVideoPricePerRequestHd(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field.
|
||||
func (u *GroupUpsertBulk) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.AddSoraVideoPricePerRequestHd(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateSoraVideoPricePerRequestHd() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateSoraVideoPricePerRequestHd()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
|
||||
func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequestHd() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.ClearSoraVideoPricePerRequestHd()
|
||||
})
|
||||
}
|
||||
|
||||
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
||||
func (u *GroupUpsertBulk) SetSoraStorageQuotaBytes(v int64) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetSoraStorageQuotaBytes(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
|
||||
func (u *GroupUpsertBulk) AddSoraStorageQuotaBytes(v int64) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.AddSoraStorageQuotaBytes(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateSoraStorageQuotaBytes() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateSoraStorageQuotaBytes()
|
||||
})
|
||||
}
|
||||
|
||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||
func (u *GroupUpsertBulk) SetClaudeCodeOnly(v bool) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
@@ -3143,6 +2768,34 @@ func (u *GroupUpsertBulk) UpdateAllowMessagesDispatch() *GroupUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetRequireOauthOnly sets the "require_oauth_only" field.
|
||||
func (u *GroupUpsertBulk) SetRequireOauthOnly(v bool) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetRequireOauthOnly(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRequireOauthOnly sets the "require_oauth_only" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateRequireOauthOnly() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateRequireOauthOnly()
|
||||
})
|
||||
}
|
||||
|
||||
// SetRequirePrivacySet sets the "require_privacy_set" field.
|
||||
func (u *GroupUpsertBulk) SetRequirePrivacySet(v bool) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetRequirePrivacySet(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRequirePrivacySet sets the "require_privacy_set" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateRequirePrivacySet() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateRequirePrivacySet()
|
||||
})
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (u *GroupUpsertBulk) SetDefaultMappedModel(v string) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
|
||||
@@ -355,135 +355,6 @@ func (_u *GroupUpdate) ClearImagePrice4k() *GroupUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
|
||||
func (_u *GroupUpdate) SetSoraImagePrice360(v float64) *GroupUpdate {
|
||||
_u.mutation.ResetSoraImagePrice360()
|
||||
_u.mutation.SetSoraImagePrice360(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableSoraImagePrice360(v *float64) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetSoraImagePrice360(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddSoraImagePrice360 adds value to the "sora_image_price_360" field.
|
||||
func (_u *GroupUpdate) AddSoraImagePrice360(v float64) *GroupUpdate {
|
||||
_u.mutation.AddSoraImagePrice360(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
|
||||
func (_u *GroupUpdate) ClearSoraImagePrice360() *GroupUpdate {
|
||||
_u.mutation.ClearSoraImagePrice360()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
|
||||
func (_u *GroupUpdate) SetSoraImagePrice540(v float64) *GroupUpdate {
|
||||
_u.mutation.ResetSoraImagePrice540()
|
||||
_u.mutation.SetSoraImagePrice540(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableSoraImagePrice540(v *float64) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetSoraImagePrice540(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddSoraImagePrice540 adds value to the "sora_image_price_540" field.
|
||||
func (_u *GroupUpdate) AddSoraImagePrice540(v float64) *GroupUpdate {
|
||||
_u.mutation.AddSoraImagePrice540(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
|
||||
func (_u *GroupUpdate) ClearSoraImagePrice540() *GroupUpdate {
|
||||
_u.mutation.ClearSoraImagePrice540()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
|
||||
func (_u *GroupUpdate) SetSoraVideoPricePerRequest(v float64) *GroupUpdate {
|
||||
_u.mutation.ResetSoraVideoPricePerRequest()
|
||||
_u.mutation.SetSoraVideoPricePerRequest(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableSoraVideoPricePerRequest(v *float64) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetSoraVideoPricePerRequest(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddSoraVideoPricePerRequest adds value to the "sora_video_price_per_request" field.
|
||||
func (_u *GroupUpdate) AddSoraVideoPricePerRequest(v float64) *GroupUpdate {
|
||||
_u.mutation.AddSoraVideoPricePerRequest(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
|
||||
func (_u *GroupUpdate) ClearSoraVideoPricePerRequest() *GroupUpdate {
|
||||
_u.mutation.ClearSoraVideoPricePerRequest()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
|
||||
func (_u *GroupUpdate) SetSoraVideoPricePerRequestHd(v float64) *GroupUpdate {
|
||||
_u.mutation.ResetSoraVideoPricePerRequestHd()
|
||||
_u.mutation.SetSoraVideoPricePerRequestHd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetSoraVideoPricePerRequestHd(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddSoraVideoPricePerRequestHd adds value to the "sora_video_price_per_request_hd" field.
|
||||
func (_u *GroupUpdate) AddSoraVideoPricePerRequestHd(v float64) *GroupUpdate {
|
||||
_u.mutation.AddSoraVideoPricePerRequestHd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
|
||||
func (_u *GroupUpdate) ClearSoraVideoPricePerRequestHd() *GroupUpdate {
|
||||
_u.mutation.ClearSoraVideoPricePerRequestHd()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
||||
func (_u *GroupUpdate) SetSoraStorageQuotaBytes(v int64) *GroupUpdate {
|
||||
_u.mutation.ResetSoraStorageQuotaBytes()
|
||||
_u.mutation.SetSoraStorageQuotaBytes(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableSoraStorageQuotaBytes(v *int64) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetSoraStorageQuotaBytes(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
|
||||
func (_u *GroupUpdate) AddSoraStorageQuotaBytes(v int64) *GroupUpdate {
|
||||
_u.mutation.AddSoraStorageQuotaBytes(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||
func (_u *GroupUpdate) SetClaudeCodeOnly(v bool) *GroupUpdate {
|
||||
_u.mutation.SetClaudeCodeOnly(v)
|
||||
@@ -639,6 +510,34 @@ func (_u *GroupUpdate) SetNillableAllowMessagesDispatch(v *bool) *GroupUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetRequireOauthOnly sets the "require_oauth_only" field.
|
||||
func (_u *GroupUpdate) SetRequireOauthOnly(v bool) *GroupUpdate {
|
||||
_u.mutation.SetRequireOauthOnly(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableRequireOauthOnly sets the "require_oauth_only" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableRequireOauthOnly(v *bool) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetRequireOauthOnly(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetRequirePrivacySet sets the "require_privacy_set" field.
|
||||
func (_u *GroupUpdate) SetRequirePrivacySet(v bool) *GroupUpdate {
|
||||
_u.mutation.SetRequirePrivacySet(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableRequirePrivacySet sets the "require_privacy_set" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableRequirePrivacySet(v *bool) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetRequirePrivacySet(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (_u *GroupUpdate) SetDefaultMappedModel(v string) *GroupUpdate {
|
||||
_u.mutation.SetDefaultMappedModel(v)
|
||||
@@ -1054,48 +953,6 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if _u.mutation.ImagePrice4kCleared() {
|
||||
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.SoraImagePrice360(); ok {
|
||||
_spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedSoraImagePrice360(); ok {
|
||||
_spec.AddField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
|
||||
}
|
||||
if _u.mutation.SoraImagePrice360Cleared() {
|
||||
_spec.ClearField(group.FieldSoraImagePrice360, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.SoraImagePrice540(); ok {
|
||||
_spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedSoraImagePrice540(); ok {
|
||||
_spec.AddField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
|
||||
}
|
||||
if _u.mutation.SoraImagePrice540Cleared() {
|
||||
_spec.ClearField(group.FieldSoraImagePrice540, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.SoraVideoPricePerRequest(); ok {
|
||||
_spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedSoraVideoPricePerRequest(); ok {
|
||||
_spec.AddField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
|
||||
}
|
||||
if _u.mutation.SoraVideoPricePerRequestCleared() {
|
||||
_spec.ClearField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.SoraVideoPricePerRequestHd(); ok {
|
||||
_spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedSoraVideoPricePerRequestHd(); ok {
|
||||
_spec.AddField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
|
||||
}
|
||||
if _u.mutation.SoraVideoPricePerRequestHdCleared() {
|
||||
_spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
|
||||
_spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
|
||||
_spec.AddField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
|
||||
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
|
||||
}
|
||||
@@ -1146,6 +1003,12 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if value, ok := _u.mutation.AllowMessagesDispatch(); ok {
|
||||
_spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.RequireOauthOnly(); ok {
|
||||
_spec.SetField(group.FieldRequireOauthOnly, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.RequirePrivacySet(); ok {
|
||||
_spec.SetField(group.FieldRequirePrivacySet, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.DefaultMappedModel(); ok {
|
||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||
}
|
||||
@@ -1783,135 +1646,6 @@ func (_u *GroupUpdateOne) ClearImagePrice4k() *GroupUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
|
||||
func (_u *GroupUpdateOne) SetSoraImagePrice360(v float64) *GroupUpdateOne {
|
||||
_u.mutation.ResetSoraImagePrice360()
|
||||
_u.mutation.SetSoraImagePrice360(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableSoraImagePrice360(v *float64) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetSoraImagePrice360(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddSoraImagePrice360 adds value to the "sora_image_price_360" field.
|
||||
func (_u *GroupUpdateOne) AddSoraImagePrice360(v float64) *GroupUpdateOne {
|
||||
_u.mutation.AddSoraImagePrice360(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
|
||||
func (_u *GroupUpdateOne) ClearSoraImagePrice360() *GroupUpdateOne {
|
||||
_u.mutation.ClearSoraImagePrice360()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
|
||||
func (_u *GroupUpdateOne) SetSoraImagePrice540(v float64) *GroupUpdateOne {
|
||||
_u.mutation.ResetSoraImagePrice540()
|
||||
_u.mutation.SetSoraImagePrice540(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableSoraImagePrice540(v *float64) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetSoraImagePrice540(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddSoraImagePrice540 adds value to the "sora_image_price_540" field.
|
||||
func (_u *GroupUpdateOne) AddSoraImagePrice540(v float64) *GroupUpdateOne {
|
||||
_u.mutation.AddSoraImagePrice540(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
|
||||
func (_u *GroupUpdateOne) ClearSoraImagePrice540() *GroupUpdateOne {
|
||||
_u.mutation.ClearSoraImagePrice540()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
|
||||
func (_u *GroupUpdateOne) SetSoraVideoPricePerRequest(v float64) *GroupUpdateOne {
|
||||
_u.mutation.ResetSoraVideoPricePerRequest()
|
||||
_u.mutation.SetSoraVideoPricePerRequest(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableSoraVideoPricePerRequest(v *float64) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetSoraVideoPricePerRequest(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddSoraVideoPricePerRequest adds value to the "sora_video_price_per_request" field.
|
||||
func (_u *GroupUpdateOne) AddSoraVideoPricePerRequest(v float64) *GroupUpdateOne {
|
||||
_u.mutation.AddSoraVideoPricePerRequest(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
|
||||
func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequest() *GroupUpdateOne {
|
||||
_u.mutation.ClearSoraVideoPricePerRequest()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
|
||||
func (_u *GroupUpdateOne) SetSoraVideoPricePerRequestHd(v float64) *GroupUpdateOne {
|
||||
_u.mutation.ResetSoraVideoPricePerRequestHd()
|
||||
_u.mutation.SetSoraVideoPricePerRequestHd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetSoraVideoPricePerRequestHd(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddSoraVideoPricePerRequestHd adds value to the "sora_video_price_per_request_hd" field.
|
||||
func (_u *GroupUpdateOne) AddSoraVideoPricePerRequestHd(v float64) *GroupUpdateOne {
|
||||
_u.mutation.AddSoraVideoPricePerRequestHd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
|
||||
func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequestHd() *GroupUpdateOne {
|
||||
_u.mutation.ClearSoraVideoPricePerRequestHd()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
||||
func (_u *GroupUpdateOne) SetSoraStorageQuotaBytes(v int64) *GroupUpdateOne {
|
||||
_u.mutation.ResetSoraStorageQuotaBytes()
|
||||
_u.mutation.SetSoraStorageQuotaBytes(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableSoraStorageQuotaBytes(v *int64) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetSoraStorageQuotaBytes(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
|
||||
func (_u *GroupUpdateOne) AddSoraStorageQuotaBytes(v int64) *GroupUpdateOne {
|
||||
_u.mutation.AddSoraStorageQuotaBytes(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||
func (_u *GroupUpdateOne) SetClaudeCodeOnly(v bool) *GroupUpdateOne {
|
||||
_u.mutation.SetClaudeCodeOnly(v)
|
||||
@@ -2067,6 +1801,34 @@ func (_u *GroupUpdateOne) SetNillableAllowMessagesDispatch(v *bool) *GroupUpdate
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetRequireOauthOnly sets the "require_oauth_only" field.
|
||||
func (_u *GroupUpdateOne) SetRequireOauthOnly(v bool) *GroupUpdateOne {
|
||||
_u.mutation.SetRequireOauthOnly(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableRequireOauthOnly sets the "require_oauth_only" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableRequireOauthOnly(v *bool) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetRequireOauthOnly(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetRequirePrivacySet sets the "require_privacy_set" field.
|
||||
func (_u *GroupUpdateOne) SetRequirePrivacySet(v bool) *GroupUpdateOne {
|
||||
_u.mutation.SetRequirePrivacySet(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableRequirePrivacySet sets the "require_privacy_set" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableRequirePrivacySet(v *bool) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetRequirePrivacySet(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (_u *GroupUpdateOne) SetDefaultMappedModel(v string) *GroupUpdateOne {
|
||||
_u.mutation.SetDefaultMappedModel(v)
|
||||
@@ -2512,48 +2274,6 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
|
||||
if _u.mutation.ImagePrice4kCleared() {
|
||||
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.SoraImagePrice360(); ok {
|
||||
_spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedSoraImagePrice360(); ok {
|
||||
_spec.AddField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
|
||||
}
|
||||
if _u.mutation.SoraImagePrice360Cleared() {
|
||||
_spec.ClearField(group.FieldSoraImagePrice360, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.SoraImagePrice540(); ok {
|
||||
_spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedSoraImagePrice540(); ok {
|
||||
_spec.AddField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
|
||||
}
|
||||
if _u.mutation.SoraImagePrice540Cleared() {
|
||||
_spec.ClearField(group.FieldSoraImagePrice540, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.SoraVideoPricePerRequest(); ok {
|
||||
_spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedSoraVideoPricePerRequest(); ok {
|
||||
_spec.AddField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
|
||||
}
|
||||
if _u.mutation.SoraVideoPricePerRequestCleared() {
|
||||
_spec.ClearField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.SoraVideoPricePerRequestHd(); ok {
|
||||
_spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedSoraVideoPricePerRequestHd(); ok {
|
||||
_spec.AddField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
|
||||
}
|
||||
if _u.mutation.SoraVideoPricePerRequestHdCleared() {
|
||||
_spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
|
||||
_spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
|
||||
_spec.AddField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
|
||||
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
|
||||
}
|
||||
@@ -2604,6 +2324,12 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
|
||||
if value, ok := _u.mutation.AllowMessagesDispatch(); ok {
|
||||
_spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.RequireOauthOnly(); ok {
|
||||
_spec.SetField(group.FieldRequireOauthOnly, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.RequirePrivacySet(); ok {
|
||||
_spec.SetField(group.FieldRequirePrivacySet, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.DefaultMappedModel(); ok {
|
||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||
}
|
||||
|
||||
@@ -177,6 +177,18 @@ func (f SettingFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, err
|
||||
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.SettingMutation", m)
|
||||
}
|
||||
|
||||
// The TLSFingerprintProfileFunc type is an adapter to allow the use of ordinary
|
||||
// function as TLSFingerprintProfile mutator.
|
||||
type TLSFingerprintProfileFunc func(context.Context, *ent.TLSFingerprintProfileMutation) (ent.Value, error)
|
||||
|
||||
// Mutate calls f(ctx, m).
|
||||
func (f TLSFingerprintProfileFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
|
||||
if mv, ok := m.(*ent.TLSFingerprintProfileMutation); ok {
|
||||
return f(ctx, mv)
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.TLSFingerprintProfileMutation", m)
|
||||
}
|
||||
|
||||
// The UsageCleanupTaskFunc type is an adapter to allow the use of ordinary
|
||||
// function as UsageCleanupTask mutator.
|
||||
type UsageCleanupTaskFunc func(context.Context, *ent.UsageCleanupTaskMutation) (ent.Value, error)
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
|
||||
"github.com/Wei-Shaw/sub2api/ent/setting"
|
||||
"github.com/Wei-Shaw/sub2api/ent/tlsfingerprintprofile"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usagelog"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
@@ -466,6 +467,33 @@ func (f TraverseSetting) Traverse(ctx context.Context, q ent.Query) error {
|
||||
return fmt.Errorf("unexpected query type %T. expect *ent.SettingQuery", q)
|
||||
}
|
||||
|
||||
// The TLSFingerprintProfileFunc type is an adapter to allow the use of ordinary function as a Querier.
|
||||
type TLSFingerprintProfileFunc func(context.Context, *ent.TLSFingerprintProfileQuery) (ent.Value, error)
|
||||
|
||||
// Query calls f(ctx, q).
|
||||
func (f TLSFingerprintProfileFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
|
||||
if q, ok := q.(*ent.TLSFingerprintProfileQuery); ok {
|
||||
return f(ctx, q)
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected query type %T. expect *ent.TLSFingerprintProfileQuery", q)
|
||||
}
|
||||
|
||||
// The TraverseTLSFingerprintProfile type is an adapter to allow the use of ordinary function as Traverser.
|
||||
type TraverseTLSFingerprintProfile func(context.Context, *ent.TLSFingerprintProfileQuery) error
|
||||
|
||||
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
|
||||
func (f TraverseTLSFingerprintProfile) Intercept(next ent.Querier) ent.Querier {
|
||||
return next
|
||||
}
|
||||
|
||||
// Traverse calls f(ctx, q).
|
||||
func (f TraverseTLSFingerprintProfile) Traverse(ctx context.Context, q ent.Query) error {
|
||||
if q, ok := q.(*ent.TLSFingerprintProfileQuery); ok {
|
||||
return f(ctx, q)
|
||||
}
|
||||
return fmt.Errorf("unexpected query type %T. expect *ent.TLSFingerprintProfileQuery", q)
|
||||
}
|
||||
|
||||
// The UsageCleanupTaskFunc type is an adapter to allow the use of ordinary function as a Querier.
|
||||
type UsageCleanupTaskFunc func(context.Context, *ent.UsageCleanupTaskQuery) (ent.Value, error)
|
||||
|
||||
@@ -686,6 +714,8 @@ func NewQuery(q ent.Query) (Query, error) {
|
||||
return &query[*ent.SecuritySecretQuery, predicate.SecuritySecret, securitysecret.OrderOption]{typ: ent.TypeSecuritySecret, tq: q}, nil
|
||||
case *ent.SettingQuery:
|
||||
return &query[*ent.SettingQuery, predicate.Setting, setting.OrderOption]{typ: ent.TypeSetting, tq: q}, nil
|
||||
case *ent.TLSFingerprintProfileQuery:
|
||||
return &query[*ent.TLSFingerprintProfileQuery, predicate.TLSFingerprintProfile, tlsfingerprintprofile.OrderOption]{typ: ent.TypeTLSFingerprintProfile, tq: q}, nil
|
||||
case *ent.UsageCleanupTaskQuery:
|
||||
return &query[*ent.UsageCleanupTaskQuery, predicate.UsageCleanupTask, usagecleanuptask.OrderOption]{typ: ent.TypeUsageCleanupTask, tq: q}, nil
|
||||
case *ent.UsageLogQuery:
|
||||
|
||||
@@ -395,11 +395,6 @@ var (
|
||||
{Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "sora_image_price_360", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "sora_image_price_540", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "sora_video_price_per_request", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "sora_video_price_per_request_hd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "sora_storage_quota_bytes", Type: field.TypeInt64, Default: 0},
|
||||
{Name: "claude_code_only", Type: field.TypeBool, Default: false},
|
||||
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
|
||||
{Name: "fallback_group_id_on_invalid_request", Type: field.TypeInt64, Nullable: true},
|
||||
@@ -409,6 +404,8 @@ var (
|
||||
{Name: "supported_model_scopes", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "sort_order", Type: field.TypeInt, Default: 0},
|
||||
{Name: "allow_messages_dispatch", Type: field.TypeBool, Default: false},
|
||||
{Name: "require_oauth_only", Type: field.TypeBool, Default: false},
|
||||
{Name: "require_privacy_set", Type: field.TypeBool, Default: false},
|
||||
{Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""},
|
||||
}
|
||||
// GroupsTable holds the schema information for the "groups" table.
|
||||
@@ -445,7 +442,7 @@ var (
|
||||
{
|
||||
Name: "group_sort_order",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{GroupsColumns[30]},
|
||||
Columns: []*schema.Column{GroupsColumns[25]},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -673,6 +670,30 @@ var (
|
||||
Columns: SettingsColumns,
|
||||
PrimaryKey: []*schema.Column{SettingsColumns[0]},
|
||||
}
|
||||
// TLSFingerprintProfilesColumns holds the columns for the "tls_fingerprint_profiles" table.
|
||||
TLSFingerprintProfilesColumns = []*schema.Column{
|
||||
{Name: "id", Type: field.TypeInt64, Increment: true},
|
||||
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "name", Type: field.TypeString, Unique: true, Size: 100},
|
||||
{Name: "description", Type: field.TypeString, Nullable: true, Size: 2147483647},
|
||||
{Name: "enable_grease", Type: field.TypeBool, Default: false},
|
||||
{Name: "cipher_suites", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "curves", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "point_formats", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "signature_algorithms", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "alpn_protocols", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "supported_versions", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "key_share_groups", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "psk_modes", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "extensions", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
}
|
||||
// TLSFingerprintProfilesTable holds the schema information for the "tls_fingerprint_profiles" table.
|
||||
TLSFingerprintProfilesTable = &schema.Table{
|
||||
Name: "tls_fingerprint_profiles",
|
||||
Columns: TLSFingerprintProfilesColumns,
|
||||
PrimaryKey: []*schema.Column{TLSFingerprintProfilesColumns[0]},
|
||||
}
|
||||
// UsageCleanupTasksColumns holds the columns for the "usage_cleanup_tasks" table.
|
||||
UsageCleanupTasksColumns = []*schema.Column{
|
||||
{Name: "id", Type: field.TypeInt64, Increment: true},
|
||||
@@ -716,6 +737,12 @@ var (
|
||||
{Name: "id", Type: field.TypeInt64, Increment: true},
|
||||
{Name: "request_id", Type: field.TypeString, Size: 64},
|
||||
{Name: "model", Type: field.TypeString, Size: 100},
|
||||
{Name: "requested_model", Type: field.TypeString, Nullable: true, Size: 100},
|
||||
{Name: "upstream_model", Type: field.TypeString, Nullable: true, Size: 100},
|
||||
{Name: "channel_id", Type: field.TypeInt64, Nullable: true},
|
||||
{Name: "model_mapping_chain", Type: field.TypeString, Nullable: true, Size: 500},
|
||||
{Name: "billing_tier", Type: field.TypeString, Nullable: true, Size: 50},
|
||||
{Name: "billing_mode", Type: field.TypeString, Nullable: true, Size: 20},
|
||||
{Name: "input_tokens", Type: field.TypeInt, Default: 0},
|
||||
{Name: "output_tokens", Type: field.TypeInt, Default: 0},
|
||||
{Name: "cache_creation_tokens", Type: field.TypeInt, Default: 0},
|
||||
@@ -738,7 +765,6 @@ var (
|
||||
{Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45},
|
||||
{Name: "image_count", Type: field.TypeInt, Default: 0},
|
||||
{Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10},
|
||||
{Name: "media_type", Type: field.TypeString, Nullable: true, Size: 16},
|
||||
{Name: "cache_ttl_overridden", Type: field.TypeBool, Default: false},
|
||||
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "api_key_id", Type: field.TypeInt64},
|
||||
@@ -755,31 +781,31 @@ var (
|
||||
ForeignKeys: []*schema.ForeignKey{
|
||||
{
|
||||
Symbol: "usage_logs_api_keys_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[33]},
|
||||
RefColumns: []*schema.Column{APIKeysColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_accounts_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[34]},
|
||||
RefColumns: []*schema.Column{AccountsColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_groups_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[35]},
|
||||
RefColumns: []*schema.Column{GroupsColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_users_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[31]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[36]},
|
||||
RefColumns: []*schema.Column{UsersColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_user_subscriptions_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[32]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[37]},
|
||||
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
@@ -788,38 +814,43 @@ var (
|
||||
{
|
||||
Name: "usagelog_user_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[31]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[36]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_api_key_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[33]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_account_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[34]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_group_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[35]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_subscription_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[32]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[37]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[27]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[32]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_model",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[2]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_requested_model",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[3]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_request_id",
|
||||
Unique: false,
|
||||
@@ -828,17 +859,17 @@ var (
|
||||
{
|
||||
Name: "usagelog_user_id_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[27]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[36], UsageLogsColumns[32]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_api_key_id_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[27]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[33], UsageLogsColumns[32]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_group_id_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[27]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[35], UsageLogsColumns[32]},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -859,8 +890,6 @@ var (
|
||||
{Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
|
||||
{Name: "totp_enabled", Type: field.TypeBool, Default: false},
|
||||
{Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
|
||||
{Name: "sora_storage_quota_bytes", Type: field.TypeInt64, Default: 0},
|
||||
{Name: "sora_storage_used_bytes", Type: field.TypeInt64, Default: 0},
|
||||
}
|
||||
// UsersTable holds the schema information for the "users" table.
|
||||
UsersTable = &schema.Table{
|
||||
@@ -1104,6 +1133,7 @@ var (
|
||||
RedeemCodesTable,
|
||||
SecuritySecretsTable,
|
||||
SettingsTable,
|
||||
TLSFingerprintProfilesTable,
|
||||
UsageCleanupTasksTable,
|
||||
UsageLogsTable,
|
||||
UsersTable,
|
||||
@@ -1168,6 +1198,9 @@ func init() {
|
||||
SettingsTable.Annotation = &entsql.Annotation{
|
||||
Table: "settings",
|
||||
}
|
||||
TLSFingerprintProfilesTable.Annotation = &entsql.Annotation{
|
||||
Table: "tls_fingerprint_profiles",
|
||||
}
|
||||
UsageCleanupTasksTable.Annotation = &entsql.Annotation{
|
||||
Table: "usage_cleanup_tasks",
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -48,6 +48,9 @@ type SecuritySecret func(*sql.Selector)
|
||||
// Setting is the predicate function for setting builders.
|
||||
type Setting func(*sql.Selector)
|
||||
|
||||
// TLSFingerprintProfile is the predicate function for tlsfingerprintprofile builders.
|
||||
type TLSFingerprintProfile func(*sql.Selector)
|
||||
|
||||
// UsageCleanupTask is the predicate function for usagecleanuptask builders.
|
||||
type UsageCleanupTask func(*sql.Selector)
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/schema"
|
||||
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
|
||||
"github.com/Wei-Shaw/sub2api/ent/setting"
|
||||
"github.com/Wei-Shaw/sub2api/ent/tlsfingerprintprofile"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usagelog"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
@@ -429,36 +430,40 @@ func init() {
|
||||
groupDescDefaultValidityDays := groupFields[10].Descriptor()
|
||||
// group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field.
|
||||
group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int)
|
||||
// groupDescSoraStorageQuotaBytes is the schema descriptor for sora_storage_quota_bytes field.
|
||||
groupDescSoraStorageQuotaBytes := groupFields[18].Descriptor()
|
||||
// group.DefaultSoraStorageQuotaBytes holds the default value on creation for the sora_storage_quota_bytes field.
|
||||
group.DefaultSoraStorageQuotaBytes = groupDescSoraStorageQuotaBytes.Default.(int64)
|
||||
// groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field.
|
||||
groupDescClaudeCodeOnly := groupFields[19].Descriptor()
|
||||
groupDescClaudeCodeOnly := groupFields[14].Descriptor()
|
||||
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
|
||||
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
|
||||
// groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field.
|
||||
groupDescModelRoutingEnabled := groupFields[23].Descriptor()
|
||||
groupDescModelRoutingEnabled := groupFields[18].Descriptor()
|
||||
// group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field.
|
||||
group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool)
|
||||
// groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field.
|
||||
groupDescMcpXMLInject := groupFields[24].Descriptor()
|
||||
groupDescMcpXMLInject := groupFields[19].Descriptor()
|
||||
// group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field.
|
||||
group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool)
|
||||
// groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field.
|
||||
groupDescSupportedModelScopes := groupFields[25].Descriptor()
|
||||
groupDescSupportedModelScopes := groupFields[20].Descriptor()
|
||||
// group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field.
|
||||
group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string)
|
||||
// groupDescSortOrder is the schema descriptor for sort_order field.
|
||||
groupDescSortOrder := groupFields[26].Descriptor()
|
||||
groupDescSortOrder := groupFields[21].Descriptor()
|
||||
// group.DefaultSortOrder holds the default value on creation for the sort_order field.
|
||||
group.DefaultSortOrder = groupDescSortOrder.Default.(int)
|
||||
// groupDescAllowMessagesDispatch is the schema descriptor for allow_messages_dispatch field.
|
||||
groupDescAllowMessagesDispatch := groupFields[27].Descriptor()
|
||||
groupDescAllowMessagesDispatch := groupFields[22].Descriptor()
|
||||
// group.DefaultAllowMessagesDispatch holds the default value on creation for the allow_messages_dispatch field.
|
||||
group.DefaultAllowMessagesDispatch = groupDescAllowMessagesDispatch.Default.(bool)
|
||||
// groupDescRequireOauthOnly is the schema descriptor for require_oauth_only field.
|
||||
groupDescRequireOauthOnly := groupFields[23].Descriptor()
|
||||
// group.DefaultRequireOauthOnly holds the default value on creation for the require_oauth_only field.
|
||||
group.DefaultRequireOauthOnly = groupDescRequireOauthOnly.Default.(bool)
|
||||
// groupDescRequirePrivacySet is the schema descriptor for require_privacy_set field.
|
||||
groupDescRequirePrivacySet := groupFields[24].Descriptor()
|
||||
// group.DefaultRequirePrivacySet holds the default value on creation for the require_privacy_set field.
|
||||
group.DefaultRequirePrivacySet = groupDescRequirePrivacySet.Default.(bool)
|
||||
// groupDescDefaultMappedModel is the schema descriptor for default_mapped_model field.
|
||||
groupDescDefaultMappedModel := groupFields[28].Descriptor()
|
||||
groupDescDefaultMappedModel := groupFields[25].Descriptor()
|
||||
// group.DefaultDefaultMappedModel holds the default value on creation for the default_mapped_model field.
|
||||
group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string)
|
||||
// group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
||||
@@ -746,6 +751,43 @@ func init() {
|
||||
setting.DefaultUpdatedAt = settingDescUpdatedAt.Default.(func() time.Time)
|
||||
// setting.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
|
||||
setting.UpdateDefaultUpdatedAt = settingDescUpdatedAt.UpdateDefault.(func() time.Time)
|
||||
tlsfingerprintprofileMixin := schema.TLSFingerprintProfile{}.Mixin()
|
||||
tlsfingerprintprofileMixinFields0 := tlsfingerprintprofileMixin[0].Fields()
|
||||
_ = tlsfingerprintprofileMixinFields0
|
||||
tlsfingerprintprofileFields := schema.TLSFingerprintProfile{}.Fields()
|
||||
_ = tlsfingerprintprofileFields
|
||||
// tlsfingerprintprofileDescCreatedAt is the schema descriptor for created_at field.
|
||||
tlsfingerprintprofileDescCreatedAt := tlsfingerprintprofileMixinFields0[0].Descriptor()
|
||||
// tlsfingerprintprofile.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||
tlsfingerprintprofile.DefaultCreatedAt = tlsfingerprintprofileDescCreatedAt.Default.(func() time.Time)
|
||||
// tlsfingerprintprofileDescUpdatedAt is the schema descriptor for updated_at field.
|
||||
tlsfingerprintprofileDescUpdatedAt := tlsfingerprintprofileMixinFields0[1].Descriptor()
|
||||
// tlsfingerprintprofile.DefaultUpdatedAt holds the default value on creation for the updated_at field.
|
||||
tlsfingerprintprofile.DefaultUpdatedAt = tlsfingerprintprofileDescUpdatedAt.Default.(func() time.Time)
|
||||
// tlsfingerprintprofile.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
|
||||
tlsfingerprintprofile.UpdateDefaultUpdatedAt = tlsfingerprintprofileDescUpdatedAt.UpdateDefault.(func() time.Time)
|
||||
// tlsfingerprintprofileDescName is the schema descriptor for name field.
|
||||
tlsfingerprintprofileDescName := tlsfingerprintprofileFields[0].Descriptor()
|
||||
// tlsfingerprintprofile.NameValidator is a validator for the "name" field. It is called by the builders before save.
|
||||
tlsfingerprintprofile.NameValidator = func() func(string) error {
|
||||
validators := tlsfingerprintprofileDescName.Validators
|
||||
fns := [...]func(string) error{
|
||||
validators[0].(func(string) error),
|
||||
validators[1].(func(string) error),
|
||||
}
|
||||
return func(name string) error {
|
||||
for _, fn := range fns {
|
||||
if err := fn(name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}()
|
||||
// tlsfingerprintprofileDescEnableGrease is the schema descriptor for enable_grease field.
|
||||
tlsfingerprintprofileDescEnableGrease := tlsfingerprintprofileFields[2].Descriptor()
|
||||
// tlsfingerprintprofile.DefaultEnableGrease holds the default value on creation for the enable_grease field.
|
||||
tlsfingerprintprofile.DefaultEnableGrease = tlsfingerprintprofileDescEnableGrease.Default.(bool)
|
||||
usagecleanuptaskMixin := schema.UsageCleanupTask{}.Mixin()
|
||||
usagecleanuptaskMixinFields0 := usagecleanuptaskMixin[0].Fields()
|
||||
_ = usagecleanuptaskMixinFields0
|
||||
@@ -821,92 +863,108 @@ func init() {
|
||||
return nil
|
||||
}
|
||||
}()
|
||||
// usagelogDescRequestedModel is the schema descriptor for requested_model field.
|
||||
usagelogDescRequestedModel := usagelogFields[5].Descriptor()
|
||||
// usagelog.RequestedModelValidator is a validator for the "requested_model" field. It is called by the builders before save.
|
||||
usagelog.RequestedModelValidator = usagelogDescRequestedModel.Validators[0].(func(string) error)
|
||||
// usagelogDescUpstreamModel is the schema descriptor for upstream_model field.
|
||||
usagelogDescUpstreamModel := usagelogFields[6].Descriptor()
|
||||
// usagelog.UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
|
||||
usagelog.UpstreamModelValidator = usagelogDescUpstreamModel.Validators[0].(func(string) error)
|
||||
// usagelogDescModelMappingChain is the schema descriptor for model_mapping_chain field.
|
||||
usagelogDescModelMappingChain := usagelogFields[8].Descriptor()
|
||||
// usagelog.ModelMappingChainValidator is a validator for the "model_mapping_chain" field. It is called by the builders before save.
|
||||
usagelog.ModelMappingChainValidator = usagelogDescModelMappingChain.Validators[0].(func(string) error)
|
||||
// usagelogDescBillingTier is the schema descriptor for billing_tier field.
|
||||
usagelogDescBillingTier := usagelogFields[9].Descriptor()
|
||||
// usagelog.BillingTierValidator is a validator for the "billing_tier" field. It is called by the builders before save.
|
||||
usagelog.BillingTierValidator = usagelogDescBillingTier.Validators[0].(func(string) error)
|
||||
// usagelogDescBillingMode is the schema descriptor for billing_mode field.
|
||||
usagelogDescBillingMode := usagelogFields[10].Descriptor()
|
||||
// usagelog.BillingModeValidator is a validator for the "billing_mode" field. It is called by the builders before save.
|
||||
usagelog.BillingModeValidator = usagelogDescBillingMode.Validators[0].(func(string) error)
|
||||
// usagelogDescInputTokens is the schema descriptor for input_tokens field.
|
||||
usagelogDescInputTokens := usagelogFields[7].Descriptor()
|
||||
usagelogDescInputTokens := usagelogFields[13].Descriptor()
|
||||
// usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field.
|
||||
usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int)
|
||||
// usagelogDescOutputTokens is the schema descriptor for output_tokens field.
|
||||
usagelogDescOutputTokens := usagelogFields[8].Descriptor()
|
||||
usagelogDescOutputTokens := usagelogFields[14].Descriptor()
|
||||
// usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field.
|
||||
usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int)
|
||||
// usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field.
|
||||
usagelogDescCacheCreationTokens := usagelogFields[9].Descriptor()
|
||||
usagelogDescCacheCreationTokens := usagelogFields[15].Descriptor()
|
||||
// usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field.
|
||||
usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int)
|
||||
// usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field.
|
||||
usagelogDescCacheReadTokens := usagelogFields[10].Descriptor()
|
||||
usagelogDescCacheReadTokens := usagelogFields[16].Descriptor()
|
||||
// usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field.
|
||||
usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int)
|
||||
// usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field.
|
||||
usagelogDescCacheCreation5mTokens := usagelogFields[11].Descriptor()
|
||||
usagelogDescCacheCreation5mTokens := usagelogFields[17].Descriptor()
|
||||
// usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field.
|
||||
usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int)
|
||||
// usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field.
|
||||
usagelogDescCacheCreation1hTokens := usagelogFields[12].Descriptor()
|
||||
usagelogDescCacheCreation1hTokens := usagelogFields[18].Descriptor()
|
||||
// usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field.
|
||||
usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int)
|
||||
// usagelogDescInputCost is the schema descriptor for input_cost field.
|
||||
usagelogDescInputCost := usagelogFields[13].Descriptor()
|
||||
usagelogDescInputCost := usagelogFields[19].Descriptor()
|
||||
// usagelog.DefaultInputCost holds the default value on creation for the input_cost field.
|
||||
usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64)
|
||||
// usagelogDescOutputCost is the schema descriptor for output_cost field.
|
||||
usagelogDescOutputCost := usagelogFields[14].Descriptor()
|
||||
usagelogDescOutputCost := usagelogFields[20].Descriptor()
|
||||
// usagelog.DefaultOutputCost holds the default value on creation for the output_cost field.
|
||||
usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64)
|
||||
// usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field.
|
||||
usagelogDescCacheCreationCost := usagelogFields[15].Descriptor()
|
||||
usagelogDescCacheCreationCost := usagelogFields[21].Descriptor()
|
||||
// usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field.
|
||||
usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64)
|
||||
// usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field.
|
||||
usagelogDescCacheReadCost := usagelogFields[16].Descriptor()
|
||||
usagelogDescCacheReadCost := usagelogFields[22].Descriptor()
|
||||
// usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field.
|
||||
usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64)
|
||||
// usagelogDescTotalCost is the schema descriptor for total_cost field.
|
||||
usagelogDescTotalCost := usagelogFields[17].Descriptor()
|
||||
usagelogDescTotalCost := usagelogFields[23].Descriptor()
|
||||
// usagelog.DefaultTotalCost holds the default value on creation for the total_cost field.
|
||||
usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64)
|
||||
// usagelogDescActualCost is the schema descriptor for actual_cost field.
|
||||
usagelogDescActualCost := usagelogFields[18].Descriptor()
|
||||
usagelogDescActualCost := usagelogFields[24].Descriptor()
|
||||
// usagelog.DefaultActualCost holds the default value on creation for the actual_cost field.
|
||||
usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64)
|
||||
// usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field.
|
||||
usagelogDescRateMultiplier := usagelogFields[19].Descriptor()
|
||||
usagelogDescRateMultiplier := usagelogFields[25].Descriptor()
|
||||
// usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
|
||||
usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64)
|
||||
// usagelogDescBillingType is the schema descriptor for billing_type field.
|
||||
usagelogDescBillingType := usagelogFields[21].Descriptor()
|
||||
usagelogDescBillingType := usagelogFields[27].Descriptor()
|
||||
// usagelog.DefaultBillingType holds the default value on creation for the billing_type field.
|
||||
usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8)
|
||||
// usagelogDescStream is the schema descriptor for stream field.
|
||||
usagelogDescStream := usagelogFields[22].Descriptor()
|
||||
usagelogDescStream := usagelogFields[28].Descriptor()
|
||||
// usagelog.DefaultStream holds the default value on creation for the stream field.
|
||||
usagelog.DefaultStream = usagelogDescStream.Default.(bool)
|
||||
// usagelogDescUserAgent is the schema descriptor for user_agent field.
|
||||
usagelogDescUserAgent := usagelogFields[25].Descriptor()
|
||||
usagelogDescUserAgent := usagelogFields[31].Descriptor()
|
||||
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
|
||||
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
|
||||
// usagelogDescIPAddress is the schema descriptor for ip_address field.
|
||||
usagelogDescIPAddress := usagelogFields[26].Descriptor()
|
||||
usagelogDescIPAddress := usagelogFields[32].Descriptor()
|
||||
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
|
||||
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
|
||||
// usagelogDescImageCount is the schema descriptor for image_count field.
|
||||
usagelogDescImageCount := usagelogFields[27].Descriptor()
|
||||
usagelogDescImageCount := usagelogFields[33].Descriptor()
|
||||
// usagelog.DefaultImageCount holds the default value on creation for the image_count field.
|
||||
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
|
||||
// usagelogDescImageSize is the schema descriptor for image_size field.
|
||||
usagelogDescImageSize := usagelogFields[28].Descriptor()
|
||||
usagelogDescImageSize := usagelogFields[34].Descriptor()
|
||||
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
||||
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
|
||||
// usagelogDescMediaType is the schema descriptor for media_type field.
|
||||
usagelogDescMediaType := usagelogFields[29].Descriptor()
|
||||
// usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
|
||||
usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
|
||||
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
|
||||
usagelogDescCacheTTLOverridden := usagelogFields[30].Descriptor()
|
||||
usagelogDescCacheTTLOverridden := usagelogFields[35].Descriptor()
|
||||
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
|
||||
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
|
||||
// usagelogDescCreatedAt is the schema descriptor for created_at field.
|
||||
usagelogDescCreatedAt := usagelogFields[31].Descriptor()
|
||||
usagelogDescCreatedAt := usagelogFields[36].Descriptor()
|
||||
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
|
||||
userMixin := schema.User{}.Mixin()
|
||||
@@ -998,14 +1056,6 @@ func init() {
|
||||
userDescTotpEnabled := userFields[9].Descriptor()
|
||||
// user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field.
|
||||
user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool)
|
||||
// userDescSoraStorageQuotaBytes is the schema descriptor for sora_storage_quota_bytes field.
|
||||
userDescSoraStorageQuotaBytes := userFields[11].Descriptor()
|
||||
// user.DefaultSoraStorageQuotaBytes holds the default value on creation for the sora_storage_quota_bytes field.
|
||||
user.DefaultSoraStorageQuotaBytes = userDescSoraStorageQuotaBytes.Default.(int64)
|
||||
// userDescSoraStorageUsedBytes is the schema descriptor for sora_storage_used_bytes field.
|
||||
userDescSoraStorageUsedBytes := userFields[12].Descriptor()
|
||||
// user.DefaultSoraStorageUsedBytes holds the default value on creation for the sora_storage_used_bytes field.
|
||||
user.DefaultSoraStorageUsedBytes = userDescSoraStorageUsedBytes.Default.(int64)
|
||||
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
|
||||
_ = userallowedgroupFields
|
||||
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field.
|
||||
|
||||
@@ -87,28 +87,6 @@ func (Group) Fields() []ent.Field {
|
||||
Nillable().
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
|
||||
|
||||
// Sora 按次计费配置(阶段 1)
|
||||
field.Float("sora_image_price_360").
|
||||
Optional().
|
||||
Nillable().
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
|
||||
field.Float("sora_image_price_540").
|
||||
Optional().
|
||||
Nillable().
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
|
||||
field.Float("sora_video_price_per_request").
|
||||
Optional().
|
||||
Nillable().
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
|
||||
field.Float("sora_video_price_per_request_hd").
|
||||
Optional().
|
||||
Nillable().
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
|
||||
|
||||
// Sora 存储配额
|
||||
field.Int64("sora_storage_quota_bytes").
|
||||
Default(0),
|
||||
|
||||
// Claude Code 客户端限制 (added by migration 029)
|
||||
field.Bool("claude_code_only").
|
||||
Default(false).
|
||||
@@ -153,6 +131,12 @@ func (Group) Fields() []ent.Field {
|
||||
field.Bool("allow_messages_dispatch").
|
||||
Default(false).
|
||||
Comment("是否允许 /v1/messages 调度到此 OpenAI 分组"),
|
||||
field.Bool("require_oauth_only").
|
||||
Default(false).
|
||||
Comment("仅允许非 apikey 类型账号关联到此分组"),
|
||||
field.Bool("require_privacy_set").
|
||||
Default(false).
|
||||
Comment("调度时仅允许 privacy 已成功设置的账号"),
|
||||
field.String("default_mapped_model").
|
||||
MaxLen(100).
|
||||
Default("").
|
||||
|
||||
100
backend/ent/schema/tls_fingerprint_profile.go
Normal file
100
backend/ent/schema/tls_fingerprint_profile.go
Normal file
@@ -0,0 +1,100 @@
|
||||
// Package schema 定义 Ent ORM 的数据库 schema。
|
||||
package schema
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/entsql"
|
||||
"entgo.io/ent/schema"
|
||||
"entgo.io/ent/schema/field"
|
||||
)
|
||||
|
||||
// TLSFingerprintProfile 定义 TLS 指纹配置模板的 schema。
|
||||
//
|
||||
// TLS 指纹模板用于模拟特定客户端(如 Claude Code / Node.js)的 TLS 握手特征。
|
||||
// 每个模板包含完整的 ClientHello 参数:加密套件、曲线、扩展等。
|
||||
// 通过 Account.Extra.tls_fingerprint_profile_id 绑定到具体账号。
|
||||
type TLSFingerprintProfile struct {
|
||||
ent.Schema
|
||||
}
|
||||
|
||||
// Annotations 返回 schema 的注解配置。
|
||||
func (TLSFingerprintProfile) Annotations() []schema.Annotation {
|
||||
return []schema.Annotation{
|
||||
entsql.Annotation{Table: "tls_fingerprint_profiles"},
|
||||
}
|
||||
}
|
||||
|
||||
// Mixin 返回该 schema 使用的混入组件。
|
||||
func (TLSFingerprintProfile) Mixin() []ent.Mixin {
|
||||
return []ent.Mixin{
|
||||
mixins.TimeMixin{},
|
||||
}
|
||||
}
|
||||
|
||||
// Fields 定义 TLS 指纹模板实体的所有字段。
|
||||
func (TLSFingerprintProfile) Fields() []ent.Field {
|
||||
return []ent.Field{
|
||||
// name: 模板名称,唯一标识
|
||||
field.String("name").
|
||||
MaxLen(100).
|
||||
NotEmpty().
|
||||
Unique(),
|
||||
|
||||
// description: 模板描述
|
||||
field.Text("description").
|
||||
Optional().
|
||||
Nillable(),
|
||||
|
||||
// enable_grease: 是否启用 GREASE 扩展(Chrome 使用,Node.js 不使用)
|
||||
field.Bool("enable_grease").
|
||||
Default(false),
|
||||
|
||||
// cipher_suites: TLS 加密套件列表(顺序敏感,影响 JA3)
|
||||
field.JSON("cipher_suites", []uint16{}).
|
||||
Optional().
|
||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
|
||||
|
||||
// curves: 椭圆曲线/支持的组列表
|
||||
field.JSON("curves", []uint16{}).
|
||||
Optional().
|
||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
|
||||
|
||||
// point_formats: EC 点格式列表
|
||||
field.JSON("point_formats", []uint16{}).
|
||||
Optional().
|
||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
|
||||
|
||||
// signature_algorithms: 签名算法列表
|
||||
field.JSON("signature_algorithms", []uint16{}).
|
||||
Optional().
|
||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
|
||||
|
||||
// alpn_protocols: ALPN 协议列表(如 ["http/1.1"])
|
||||
field.JSON("alpn_protocols", []string{}).
|
||||
Optional().
|
||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
|
||||
|
||||
// supported_versions: 支持的 TLS 版本列表(如 [0x0304, 0x0303])
|
||||
field.JSON("supported_versions", []uint16{}).
|
||||
Optional().
|
||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
|
||||
|
||||
// key_share_groups: Key Share 中发送的曲线组(如 [29] 即 X25519)
|
||||
field.JSON("key_share_groups", []uint16{}).
|
||||
Optional().
|
||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
|
||||
|
||||
// psk_modes: PSK 密钥交换模式(如 [1] 即 psk_dhe_ke)
|
||||
field.JSON("psk_modes", []uint16{}).
|
||||
Optional().
|
||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
|
||||
|
||||
// extensions: TLS 扩展类型 ID 列表,按发送顺序排列
|
||||
field.JSON("extensions", []uint16{}).
|
||||
Optional().
|
||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
|
||||
}
|
||||
}
|
||||
@@ -41,6 +41,22 @@ func (UsageLog) Fields() []ent.Field {
|
||||
field.String("model").
|
||||
MaxLen(100).
|
||||
NotEmpty(),
|
||||
// RequestedModel stores the client-requested model name for stable display and analytics.
|
||||
// NULL means historical rows written before requested_model dual-write was introduced.
|
||||
field.String("requested_model").
|
||||
MaxLen(100).
|
||||
Optional().
|
||||
Nillable(),
|
||||
// UpstreamModel stores the actual upstream model name when model mapping
|
||||
// is applied. NULL means no mapping — the requested model was used as-is.
|
||||
field.String("upstream_model").
|
||||
MaxLen(100).
|
||||
Optional().
|
||||
Nillable(),
|
||||
field.Int64("channel_id").Optional().Nillable().Comment("渠道 ID"),
|
||||
field.String("model_mapping_chain").MaxLen(500).Optional().Nillable().Comment("模型映射链"),
|
||||
field.String("billing_tier").MaxLen(50).Optional().Nillable().Comment("计费层级标签"),
|
||||
field.String("billing_mode").MaxLen(20).Optional().Nillable().Comment("计费模式:token/per_request/image"),
|
||||
field.Int64("group_id").
|
||||
Optional().
|
||||
Nillable(),
|
||||
@@ -118,12 +134,6 @@ func (UsageLog) Fields() []ent.Field {
|
||||
MaxLen(10).
|
||||
Optional().
|
||||
Nillable(),
|
||||
// 媒体类型字段(sora 使用)
|
||||
field.String("media_type").
|
||||
MaxLen(16).
|
||||
Optional().
|
||||
Nillable(),
|
||||
|
||||
// Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费)
|
||||
field.Bool("cache_ttl_overridden").
|
||||
Default(false),
|
||||
@@ -175,6 +185,7 @@ func (UsageLog) Indexes() []ent.Index {
|
||||
index.Fields("subscription_id"),
|
||||
index.Fields("created_at"),
|
||||
index.Fields("model"),
|
||||
index.Fields("requested_model"),
|
||||
index.Fields("request_id"),
|
||||
// 复合索引用于时间范围查询
|
||||
index.Fields("user_id", "created_at"),
|
||||
|
||||
@@ -72,12 +72,6 @@ func (User) Fields() []ent.Field {
|
||||
field.Time("totp_enabled_at").
|
||||
Optional().
|
||||
Nillable(),
|
||||
|
||||
// Sora 存储配额
|
||||
field.Int64("sora_storage_quota_bytes").
|
||||
Default(0),
|
||||
field.Int64("sora_storage_used_bytes").
|
||||
Default(0),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
275
backend/ent/tlsfingerprintprofile.go
Normal file
275
backend/ent/tlsfingerprintprofile.go
Normal file
@@ -0,0 +1,275 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package ent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"github.com/Wei-Shaw/sub2api/ent/tlsfingerprintprofile"
|
||||
)
|
||||
|
||||
// TLSFingerprintProfile is the model entity for the TLSFingerprintProfile schema.
|
||||
type TLSFingerprintProfile struct {
|
||||
config `json:"-"`
|
||||
// ID of the ent.
|
||||
ID int64 `json:"id,omitempty"`
|
||||
// CreatedAt holds the value of the "created_at" field.
|
||||
CreatedAt time.Time `json:"created_at,omitempty"`
|
||||
// UpdatedAt holds the value of the "updated_at" field.
|
||||
UpdatedAt time.Time `json:"updated_at,omitempty"`
|
||||
// Name holds the value of the "name" field.
|
||||
Name string `json:"name,omitempty"`
|
||||
// Description holds the value of the "description" field.
|
||||
Description *string `json:"description,omitempty"`
|
||||
// EnableGrease holds the value of the "enable_grease" field.
|
||||
EnableGrease bool `json:"enable_grease,omitempty"`
|
||||
// CipherSuites holds the value of the "cipher_suites" field.
|
||||
CipherSuites []uint16 `json:"cipher_suites,omitempty"`
|
||||
// Curves holds the value of the "curves" field.
|
||||
Curves []uint16 `json:"curves,omitempty"`
|
||||
// PointFormats holds the value of the "point_formats" field.
|
||||
PointFormats []uint16 `json:"point_formats,omitempty"`
|
||||
// SignatureAlgorithms holds the value of the "signature_algorithms" field.
|
||||
SignatureAlgorithms []uint16 `json:"signature_algorithms,omitempty"`
|
||||
// AlpnProtocols holds the value of the "alpn_protocols" field.
|
||||
AlpnProtocols []string `json:"alpn_protocols,omitempty"`
|
||||
// SupportedVersions holds the value of the "supported_versions" field.
|
||||
SupportedVersions []uint16 `json:"supported_versions,omitempty"`
|
||||
// KeyShareGroups holds the value of the "key_share_groups" field.
|
||||
KeyShareGroups []uint16 `json:"key_share_groups,omitempty"`
|
||||
// PskModes holds the value of the "psk_modes" field.
|
||||
PskModes []uint16 `json:"psk_modes,omitempty"`
|
||||
// Extensions holds the value of the "extensions" field.
|
||||
Extensions []uint16 `json:"extensions,omitempty"`
|
||||
selectValues sql.SelectValues
|
||||
}
|
||||
|
||||
// scanValues returns the types for scanning values from sql.Rows.
|
||||
func (*TLSFingerprintProfile) scanValues(columns []string) ([]any, error) {
|
||||
values := make([]any, len(columns))
|
||||
for i := range columns {
|
||||
switch columns[i] {
|
||||
case tlsfingerprintprofile.FieldCipherSuites, tlsfingerprintprofile.FieldCurves, tlsfingerprintprofile.FieldPointFormats, tlsfingerprintprofile.FieldSignatureAlgorithms, tlsfingerprintprofile.FieldAlpnProtocols, tlsfingerprintprofile.FieldSupportedVersions, tlsfingerprintprofile.FieldKeyShareGroups, tlsfingerprintprofile.FieldPskModes, tlsfingerprintprofile.FieldExtensions:
|
||||
values[i] = new([]byte)
|
||||
case tlsfingerprintprofile.FieldEnableGrease:
|
||||
values[i] = new(sql.NullBool)
|
||||
case tlsfingerprintprofile.FieldID:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case tlsfingerprintprofile.FieldName, tlsfingerprintprofile.FieldDescription:
|
||||
values[i] = new(sql.NullString)
|
||||
case tlsfingerprintprofile.FieldCreatedAt, tlsfingerprintprofile.FieldUpdatedAt:
|
||||
values[i] = new(sql.NullTime)
|
||||
default:
|
||||
values[i] = new(sql.UnknownType)
|
||||
}
|
||||
}
|
||||
return values, nil
|
||||
}
|
||||
|
||||
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the TLSFingerprintProfile fields.
//
// values must line up positionally with columns (as produced by scanValues).
// Nullable scalar fields are only assigned when the sql.Null* wrapper is
// Valid; JSON-backed fields are only unmarshaled when the raw payload is
// non-empty; unrecognized columns are stashed in selectValues for later
// retrieval via Value().
func (_m *TLSFingerprintProfile) assignValues(columns []string, values []any) error {
	// Having more values than columns is tolerated; fewer is an error.
	if m, n := len(values), len(columns); m < n {
		return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
	}
	for i := range columns {
		switch columns[i] {
		case tlsfingerprintprofile.FieldID:
			value, ok := values[i].(*sql.NullInt64)
			if !ok {
				// NOTE(review): on a failed assertion, %T formats the
				// zero-valued assertion target, not the original values[i].
				return fmt.Errorf("unexpected type %T for field id", value)
			}
			_m.ID = int64(value.Int64)
		case tlsfingerprintprofile.FieldCreatedAt:
			if value, ok := values[i].(*sql.NullTime); !ok {
				return fmt.Errorf("unexpected type %T for field created_at", values[i])
			} else if value.Valid {
				_m.CreatedAt = value.Time
			}
		case tlsfingerprintprofile.FieldUpdatedAt:
			if value, ok := values[i].(*sql.NullTime); !ok {
				return fmt.Errorf("unexpected type %T for field updated_at", values[i])
			} else if value.Valid {
				_m.UpdatedAt = value.Time
			}
		case tlsfingerprintprofile.FieldName:
			if value, ok := values[i].(*sql.NullString); !ok {
				return fmt.Errorf("unexpected type %T for field name", values[i])
			} else if value.Valid {
				_m.Name = value.String
			}
		case tlsfingerprintprofile.FieldDescription:
			if value, ok := values[i].(*sql.NullString); !ok {
				return fmt.Errorf("unexpected type %T for field description", values[i])
			} else if value.Valid {
				// Optional field: stored as *string, left nil when NULL.
				_m.Description = new(string)
				*_m.Description = value.String
			}
		case tlsfingerprintprofile.FieldEnableGrease:
			if value, ok := values[i].(*sql.NullBool); !ok {
				return fmt.Errorf("unexpected type %T for field enable_grease", values[i])
			} else if value.Valid {
				_m.EnableGrease = value.Bool
			}
		// The remaining fields are JSON-encoded; the same decode pattern
		// repeats for each: skip NULL/empty payloads, otherwise unmarshal.
		case tlsfingerprintprofile.FieldCipherSuites:
			if value, ok := values[i].(*[]byte); !ok {
				return fmt.Errorf("unexpected type %T for field cipher_suites", values[i])
			} else if value != nil && len(*value) > 0 {
				if err := json.Unmarshal(*value, &_m.CipherSuites); err != nil {
					return fmt.Errorf("unmarshal field cipher_suites: %w", err)
				}
			}
		case tlsfingerprintprofile.FieldCurves:
			if value, ok := values[i].(*[]byte); !ok {
				return fmt.Errorf("unexpected type %T for field curves", values[i])
			} else if value != nil && len(*value) > 0 {
				if err := json.Unmarshal(*value, &_m.Curves); err != nil {
					return fmt.Errorf("unmarshal field curves: %w", err)
				}
			}
		case tlsfingerprintprofile.FieldPointFormats:
			if value, ok := values[i].(*[]byte); !ok {
				return fmt.Errorf("unexpected type %T for field point_formats", values[i])
			} else if value != nil && len(*value) > 0 {
				if err := json.Unmarshal(*value, &_m.PointFormats); err != nil {
					return fmt.Errorf("unmarshal field point_formats: %w", err)
				}
			}
		case tlsfingerprintprofile.FieldSignatureAlgorithms:
			if value, ok := values[i].(*[]byte); !ok {
				return fmt.Errorf("unexpected type %T for field signature_algorithms", values[i])
			} else if value != nil && len(*value) > 0 {
				if err := json.Unmarshal(*value, &_m.SignatureAlgorithms); err != nil {
					return fmt.Errorf("unmarshal field signature_algorithms: %w", err)
				}
			}
		case tlsfingerprintprofile.FieldAlpnProtocols:
			if value, ok := values[i].(*[]byte); !ok {
				return fmt.Errorf("unexpected type %T for field alpn_protocols", values[i])
			} else if value != nil && len(*value) > 0 {
				if err := json.Unmarshal(*value, &_m.AlpnProtocols); err != nil {
					return fmt.Errorf("unmarshal field alpn_protocols: %w", err)
				}
			}
		case tlsfingerprintprofile.FieldSupportedVersions:
			if value, ok := values[i].(*[]byte); !ok {
				return fmt.Errorf("unexpected type %T for field supported_versions", values[i])
			} else if value != nil && len(*value) > 0 {
				if err := json.Unmarshal(*value, &_m.SupportedVersions); err != nil {
					return fmt.Errorf("unmarshal field supported_versions: %w", err)
				}
			}
		case tlsfingerprintprofile.FieldKeyShareGroups:
			if value, ok := values[i].(*[]byte); !ok {
				return fmt.Errorf("unexpected type %T for field key_share_groups", values[i])
			} else if value != nil && len(*value) > 0 {
				if err := json.Unmarshal(*value, &_m.KeyShareGroups); err != nil {
					return fmt.Errorf("unmarshal field key_share_groups: %w", err)
				}
			}
		case tlsfingerprintprofile.FieldPskModes:
			if value, ok := values[i].(*[]byte); !ok {
				return fmt.Errorf("unexpected type %T for field psk_modes", values[i])
			} else if value != nil && len(*value) > 0 {
				if err := json.Unmarshal(*value, &_m.PskModes); err != nil {
					return fmt.Errorf("unmarshal field psk_modes: %w", err)
				}
			}
		case tlsfingerprintprofile.FieldExtensions:
			if value, ok := values[i].(*[]byte); !ok {
				return fmt.Errorf("unexpected type %T for field extensions", values[i])
			} else if value != nil && len(*value) > 0 {
				if err := json.Unmarshal(*value, &_m.Extensions); err != nil {
					return fmt.Errorf("unmarshal field extensions: %w", err)
				}
			}
		default:
			// Dynamically selected column — keep raw for Value(name).
			_m.selectValues.Set(columns[i], values[i])
		}
	}
	return nil
}
|
||||
|
||||
// Value returns the ent.Value that was dynamically selected and assigned to the TLSFingerprintProfile.
|
||||
// This includes values selected through modifiers, order, etc.
|
||||
func (_m *TLSFingerprintProfile) Value(name string) (ent.Value, error) {
|
||||
return _m.selectValues.Get(name)
|
||||
}
|
||||
|
||||
// Update returns a builder for updating this TLSFingerprintProfile.
|
||||
// Note that you need to call TLSFingerprintProfile.Unwrap() before calling this method if this TLSFingerprintProfile
|
||||
// was returned from a transaction, and the transaction was committed or rolled back.
|
||||
func (_m *TLSFingerprintProfile) Update() *TLSFingerprintProfileUpdateOne {
|
||||
return NewTLSFingerprintProfileClient(_m.config).UpdateOne(_m)
|
||||
}
|
||||
|
||||
// Unwrap unwraps the TLSFingerprintProfile entity that was returned from a transaction after it was closed,
|
||||
// so that all future queries will be executed through the driver which created the transaction.
|
||||
func (_m *TLSFingerprintProfile) Unwrap() *TLSFingerprintProfile {
|
||||
_tx, ok := _m.config.driver.(*txDriver)
|
||||
if !ok {
|
||||
panic("ent: TLSFingerprintProfile is not a transactional entity")
|
||||
}
|
||||
_m.config.driver = _tx.drv
|
||||
return _m
|
||||
}
|
||||
|
||||
// String implements the fmt.Stringer.
//
// Produces a single-line "TLSFingerprintProfile(id=..., field=value, ...)"
// dump of all schema fields, formatting timestamps with time.ANSIC.
func (_m *TLSFingerprintProfile) String() string {
	var builder strings.Builder
	builder.WriteString("TLSFingerprintProfile(")
	builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
	builder.WriteString("created_at=")
	builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
	builder.WriteString(", ")
	builder.WriteString("updated_at=")
	builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
	builder.WriteString(", ")
	builder.WriteString("name=")
	builder.WriteString(_m.Name)
	builder.WriteString(", ")
	// Description is optional; the "description=" segment is omitted when nil,
	// but the following ", " separator is written unconditionally.
	if v := _m.Description; v != nil {
		builder.WriteString("description=")
		builder.WriteString(*v)
	}
	builder.WriteString(", ")
	builder.WriteString("enable_grease=")
	builder.WriteString(fmt.Sprintf("%v", _m.EnableGrease))
	builder.WriteString(", ")
	builder.WriteString("cipher_suites=")
	builder.WriteString(fmt.Sprintf("%v", _m.CipherSuites))
	builder.WriteString(", ")
	builder.WriteString("curves=")
	builder.WriteString(fmt.Sprintf("%v", _m.Curves))
	builder.WriteString(", ")
	builder.WriteString("point_formats=")
	builder.WriteString(fmt.Sprintf("%v", _m.PointFormats))
	builder.WriteString(", ")
	builder.WriteString("signature_algorithms=")
	builder.WriteString(fmt.Sprintf("%v", _m.SignatureAlgorithms))
	builder.WriteString(", ")
	builder.WriteString("alpn_protocols=")
	builder.WriteString(fmt.Sprintf("%v", _m.AlpnProtocols))
	builder.WriteString(", ")
	builder.WriteString("supported_versions=")
	builder.WriteString(fmt.Sprintf("%v", _m.SupportedVersions))
	builder.WriteString(", ")
	builder.WriteString("key_share_groups=")
	builder.WriteString(fmt.Sprintf("%v", _m.KeyShareGroups))
	builder.WriteString(", ")
	builder.WriteString("psk_modes=")
	builder.WriteString(fmt.Sprintf("%v", _m.PskModes))
	builder.WriteString(", ")
	builder.WriteString("extensions=")
	builder.WriteString(fmt.Sprintf("%v", _m.Extensions))
	builder.WriteByte(')')
	return builder.String()
}
|
||||
|
||||
// TLSFingerprintProfiles is a parsable slice of TLSFingerprintProfile.
type TLSFingerprintProfiles []*TLSFingerprintProfile
|
||||
121
backend/ent/tlsfingerprintprofile/tlsfingerprintprofile.go
Normal file
121
backend/ent/tlsfingerprintprofile/tlsfingerprintprofile.go
Normal file
@@ -0,0 +1,121 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package tlsfingerprintprofile
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
)
|
||||
|
||||
// Database identifiers for the tls_fingerprint_profiles table and its columns.
const (
	// Label holds the string label denoting the tlsfingerprintprofile type in the database.
	Label = "tls_fingerprint_profile"
	// FieldID holds the string denoting the id field in the database.
	FieldID = "id"
	// FieldCreatedAt holds the string denoting the created_at field in the database.
	FieldCreatedAt = "created_at"
	// FieldUpdatedAt holds the string denoting the updated_at field in the database.
	FieldUpdatedAt = "updated_at"
	// FieldName holds the string denoting the name field in the database.
	FieldName = "name"
	// FieldDescription holds the string denoting the description field in the database.
	FieldDescription = "description"
	// FieldEnableGrease holds the string denoting the enable_grease field in the database.
	FieldEnableGrease = "enable_grease"
	// FieldCipherSuites holds the string denoting the cipher_suites field in the database.
	FieldCipherSuites = "cipher_suites"
	// FieldCurves holds the string denoting the curves field in the database.
	FieldCurves = "curves"
	// FieldPointFormats holds the string denoting the point_formats field in the database.
	FieldPointFormats = "point_formats"
	// FieldSignatureAlgorithms holds the string denoting the signature_algorithms field in the database.
	FieldSignatureAlgorithms = "signature_algorithms"
	// FieldAlpnProtocols holds the string denoting the alpn_protocols field in the database.
	FieldAlpnProtocols = "alpn_protocols"
	// FieldSupportedVersions holds the string denoting the supported_versions field in the database.
	FieldSupportedVersions = "supported_versions"
	// FieldKeyShareGroups holds the string denoting the key_share_groups field in the database.
	FieldKeyShareGroups = "key_share_groups"
	// FieldPskModes holds the string denoting the psk_modes field in the database.
	FieldPskModes = "psk_modes"
	// FieldExtensions holds the string denoting the extensions field in the database.
	FieldExtensions = "extensions"
	// Table holds the table name of the tlsfingerprintprofile in the database.
	Table = "tls_fingerprint_profiles"
)

// Columns holds all SQL columns for tlsfingerprintprofile fields.
// ValidColumn consults this slice to validate column names.
var Columns = []string{
	FieldID,
	FieldCreatedAt,
	FieldUpdatedAt,
	FieldName,
	FieldDescription,
	FieldEnableGrease,
	FieldCipherSuites,
	FieldCurves,
	FieldPointFormats,
	FieldSignatureAlgorithms,
	FieldAlpnProtocols,
	FieldSupportedVersions,
	FieldKeyShareGroups,
	FieldPskModes,
	FieldExtensions,
}
|
||||
|
||||
// ValidColumn reports if the column name is valid (part of the table columns).
|
||||
func ValidColumn(column string) bool {
|
||||
for i := range Columns {
|
||||
if column == Columns[i] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Runtime hooks populated by the schema package (see ent runtime wiring):
// default/update-time generators and field validators.
var (
	// DefaultCreatedAt holds the default value on creation for the "created_at" field.
	DefaultCreatedAt func() time.Time
	// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
	DefaultUpdatedAt func() time.Time
	// UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
	UpdateDefaultUpdatedAt func() time.Time
	// NameValidator is a validator for the "name" field. It is called by the builders before save.
	NameValidator func(string) error
	// DefaultEnableGrease holds the default value on creation for the "enable_grease" field.
	DefaultEnableGrease bool
)

// OrderOption defines the ordering options for the TLSFingerprintProfile queries.
type OrderOption func(*sql.Selector)

// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
	return sql.OrderByField(FieldID, opts...).ToFunc()
}

// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
	return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}

// ByUpdatedAt orders the results by the updated_at field.
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
	return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
}

// ByName orders the results by the name field.
func ByName(opts ...sql.OrderTermOption) OrderOption {
	return sql.OrderByField(FieldName, opts...).ToFunc()
}

// ByDescription orders the results by the description field.
func ByDescription(opts ...sql.OrderTermOption) OrderOption {
	return sql.OrderByField(FieldDescription, opts...).ToFunc()
}

// ByEnableGrease orders the results by the enable_grease field.
func ByEnableGrease(opts ...sql.OrderTermOption) OrderOption {
	return sql.OrderByField(FieldEnableGrease, opts...).ToFunc()
}
|
||||
415
backend/ent/tlsfingerprintprofile/where.go
Normal file
415
backend/ent/tlsfingerprintprofile/where.go
Normal file
@@ -0,0 +1,415 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package tlsfingerprintprofile
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
)
|
||||
|
||||
// Predicate constructors for TLSFingerprintProfile queries. Each is a thin
// wrapper over the sql.Field* helpers returning a composable
// predicate.TLSFingerprintProfile.

// ID filters vertices based on their ID field.
func ID(id int64) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldEQ(FieldID, id))
}

// IDEQ applies the EQ predicate on the ID field.
func IDEQ(id int64) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldEQ(FieldID, id))
}

// IDNEQ applies the NEQ predicate on the ID field.
func IDNEQ(id int64) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldNEQ(FieldID, id))
}

// IDIn applies the In predicate on the ID field.
func IDIn(ids ...int64) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldIn(FieldID, ids...))
}

// IDNotIn applies the NotIn predicate on the ID field.
func IDNotIn(ids ...int64) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldNotIn(FieldID, ids...))
}

// IDGT applies the GT predicate on the ID field.
func IDGT(id int64) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldGT(FieldID, id))
}

// IDGTE applies the GTE predicate on the ID field.
func IDGTE(id int64) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldGTE(FieldID, id))
}

// IDLT applies the LT predicate on the ID field.
func IDLT(id int64) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldLT(FieldID, id))
}

// IDLTE applies the LTE predicate on the ID field.
func IDLTE(id int64) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldLTE(FieldID, id))
}

// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldEQ(FieldCreatedAt, v))
}

// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
func UpdatedAt(v time.Time) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldEQ(FieldUpdatedAt, v))
}

// Name applies equality check predicate on the "name" field. It's identical to NameEQ.
func Name(v string) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldEQ(FieldName, v))
}

// Description applies equality check predicate on the "description" field. It's identical to DescriptionEQ.
func Description(v string) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldEQ(FieldDescription, v))
}

// EnableGrease applies equality check predicate on the "enable_grease" field. It's identical to EnableGreaseEQ.
func EnableGrease(v bool) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldEQ(FieldEnableGrease, v))
}

// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldEQ(FieldCreatedAt, v))
}

// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
func CreatedAtNEQ(v time.Time) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldNEQ(FieldCreatedAt, v))
}

// CreatedAtIn applies the In predicate on the "created_at" field.
func CreatedAtIn(vs ...time.Time) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldIn(FieldCreatedAt, vs...))
}

// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
func CreatedAtNotIn(vs ...time.Time) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldNotIn(FieldCreatedAt, vs...))
}

// CreatedAtGT applies the GT predicate on the "created_at" field.
func CreatedAtGT(v time.Time) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldGT(FieldCreatedAt, v))
}

// CreatedAtGTE applies the GTE predicate on the "created_at" field.
func CreatedAtGTE(v time.Time) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldGTE(FieldCreatedAt, v))
}

// CreatedAtLT applies the LT predicate on the "created_at" field.
func CreatedAtLT(v time.Time) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldLT(FieldCreatedAt, v))
}

// CreatedAtLTE applies the LTE predicate on the "created_at" field.
func CreatedAtLTE(v time.Time) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldLTE(FieldCreatedAt, v))
}

// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
func UpdatedAtEQ(v time.Time) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldEQ(FieldUpdatedAt, v))
}

// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
func UpdatedAtNEQ(v time.Time) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldNEQ(FieldUpdatedAt, v))
}

// UpdatedAtIn applies the In predicate on the "updated_at" field.
func UpdatedAtIn(vs ...time.Time) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldIn(FieldUpdatedAt, vs...))
}

// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
func UpdatedAtNotIn(vs ...time.Time) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldNotIn(FieldUpdatedAt, vs...))
}

// UpdatedAtGT applies the GT predicate on the "updated_at" field.
func UpdatedAtGT(v time.Time) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldGT(FieldUpdatedAt, v))
}

// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
func UpdatedAtGTE(v time.Time) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldGTE(FieldUpdatedAt, v))
}

// UpdatedAtLT applies the LT predicate on the "updated_at" field.
func UpdatedAtLT(v time.Time) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldLT(FieldUpdatedAt, v))
}

// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
func UpdatedAtLTE(v time.Time) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldLTE(FieldUpdatedAt, v))
}

// NameEQ applies the EQ predicate on the "name" field.
func NameEQ(v string) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldEQ(FieldName, v))
}

// NameNEQ applies the NEQ predicate on the "name" field.
func NameNEQ(v string) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldNEQ(FieldName, v))
}

// NameIn applies the In predicate on the "name" field.
func NameIn(vs ...string) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldIn(FieldName, vs...))
}

// NameNotIn applies the NotIn predicate on the "name" field.
func NameNotIn(vs ...string) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldNotIn(FieldName, vs...))
}

// NameGT applies the GT predicate on the "name" field.
func NameGT(v string) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldGT(FieldName, v))
}

// NameGTE applies the GTE predicate on the "name" field.
func NameGTE(v string) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldGTE(FieldName, v))
}

// NameLT applies the LT predicate on the "name" field.
func NameLT(v string) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldLT(FieldName, v))
}

// NameLTE applies the LTE predicate on the "name" field.
func NameLTE(v string) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldLTE(FieldName, v))
}

// NameContains applies the Contains predicate on the "name" field.
func NameContains(v string) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldContains(FieldName, v))
}

// NameHasPrefix applies the HasPrefix predicate on the "name" field.
func NameHasPrefix(v string) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldHasPrefix(FieldName, v))
}

// NameHasSuffix applies the HasSuffix predicate on the "name" field.
func NameHasSuffix(v string) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldHasSuffix(FieldName, v))
}

// NameEqualFold applies the EqualFold predicate on the "name" field.
func NameEqualFold(v string) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldEqualFold(FieldName, v))
}

// NameContainsFold applies the ContainsFold predicate on the "name" field.
func NameContainsFold(v string) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldContainsFold(FieldName, v))
}

// DescriptionEQ applies the EQ predicate on the "description" field.
func DescriptionEQ(v string) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldEQ(FieldDescription, v))
}

// DescriptionNEQ applies the NEQ predicate on the "description" field.
func DescriptionNEQ(v string) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldNEQ(FieldDescription, v))
}

// DescriptionIn applies the In predicate on the "description" field.
func DescriptionIn(vs ...string) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldIn(FieldDescription, vs...))
}

// DescriptionNotIn applies the NotIn predicate on the "description" field.
func DescriptionNotIn(vs ...string) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldNotIn(FieldDescription, vs...))
}

// DescriptionGT applies the GT predicate on the "description" field.
func DescriptionGT(v string) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldGT(FieldDescription, v))
}

// DescriptionGTE applies the GTE predicate on the "description" field.
func DescriptionGTE(v string) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldGTE(FieldDescription, v))
}

// DescriptionLT applies the LT predicate on the "description" field.
func DescriptionLT(v string) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldLT(FieldDescription, v))
}

// DescriptionLTE applies the LTE predicate on the "description" field.
func DescriptionLTE(v string) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldLTE(FieldDescription, v))
}

// DescriptionContains applies the Contains predicate on the "description" field.
func DescriptionContains(v string) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldContains(FieldDescription, v))
}

// DescriptionHasPrefix applies the HasPrefix predicate on the "description" field.
func DescriptionHasPrefix(v string) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldHasPrefix(FieldDescription, v))
}

// DescriptionHasSuffix applies the HasSuffix predicate on the "description" field.
func DescriptionHasSuffix(v string) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldHasSuffix(FieldDescription, v))
}

// DescriptionIsNil applies the IsNil predicate on the "description" field.
func DescriptionIsNil() predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldIsNull(FieldDescription))
}

// DescriptionNotNil applies the NotNil predicate on the "description" field.
func DescriptionNotNil() predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldNotNull(FieldDescription))
}

// DescriptionEqualFold applies the EqualFold predicate on the "description" field.
func DescriptionEqualFold(v string) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldEqualFold(FieldDescription, v))
}

// DescriptionContainsFold applies the ContainsFold predicate on the "description" field.
func DescriptionContainsFold(v string) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldContainsFold(FieldDescription, v))
}

// EnableGreaseEQ applies the EQ predicate on the "enable_grease" field.
func EnableGreaseEQ(v bool) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldEQ(FieldEnableGrease, v))
}

// EnableGreaseNEQ applies the NEQ predicate on the "enable_grease" field.
func EnableGreaseNEQ(v bool) predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldNEQ(FieldEnableGrease, v))
}

// The JSON-encoded fields below only support NULL checks.

// CipherSuitesIsNil applies the IsNil predicate on the "cipher_suites" field.
func CipherSuitesIsNil() predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldIsNull(FieldCipherSuites))
}

// CipherSuitesNotNil applies the NotNil predicate on the "cipher_suites" field.
func CipherSuitesNotNil() predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldNotNull(FieldCipherSuites))
}

// CurvesIsNil applies the IsNil predicate on the "curves" field.
func CurvesIsNil() predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldIsNull(FieldCurves))
}

// CurvesNotNil applies the NotNil predicate on the "curves" field.
func CurvesNotNil() predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldNotNull(FieldCurves))
}

// PointFormatsIsNil applies the IsNil predicate on the "point_formats" field.
func PointFormatsIsNil() predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldIsNull(FieldPointFormats))
}

// PointFormatsNotNil applies the NotNil predicate on the "point_formats" field.
func PointFormatsNotNil() predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldNotNull(FieldPointFormats))
}

// SignatureAlgorithmsIsNil applies the IsNil predicate on the "signature_algorithms" field.
func SignatureAlgorithmsIsNil() predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldIsNull(FieldSignatureAlgorithms))
}

// SignatureAlgorithmsNotNil applies the NotNil predicate on the "signature_algorithms" field.
func SignatureAlgorithmsNotNil() predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldNotNull(FieldSignatureAlgorithms))
}

// AlpnProtocolsIsNil applies the IsNil predicate on the "alpn_protocols" field.
func AlpnProtocolsIsNil() predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldIsNull(FieldAlpnProtocols))
}

// AlpnProtocolsNotNil applies the NotNil predicate on the "alpn_protocols" field.
func AlpnProtocolsNotNil() predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldNotNull(FieldAlpnProtocols))
}

// SupportedVersionsIsNil applies the IsNil predicate on the "supported_versions" field.
func SupportedVersionsIsNil() predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldIsNull(FieldSupportedVersions))
}

// SupportedVersionsNotNil applies the NotNil predicate on the "supported_versions" field.
func SupportedVersionsNotNil() predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldNotNull(FieldSupportedVersions))
}

// KeyShareGroupsIsNil applies the IsNil predicate on the "key_share_groups" field.
func KeyShareGroupsIsNil() predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldIsNull(FieldKeyShareGroups))
}

// KeyShareGroupsNotNil applies the NotNil predicate on the "key_share_groups" field.
func KeyShareGroupsNotNil() predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldNotNull(FieldKeyShareGroups))
}

// PskModesIsNil applies the IsNil predicate on the "psk_modes" field.
func PskModesIsNil() predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldIsNull(FieldPskModes))
}

// PskModesNotNil applies the NotNil predicate on the "psk_modes" field.
func PskModesNotNil() predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldNotNull(FieldPskModes))
}

// ExtensionsIsNil applies the IsNil predicate on the "extensions" field.
func ExtensionsIsNil() predicate.TLSFingerprintProfile {
	return predicate.TLSFingerprintProfile(sql.FieldIsNull(FieldExtensions))
}
|
||||
|
||||
// ExtensionsNotNil applies the NotNil predicate on the "extensions" field.
|
||||
func ExtensionsNotNil() predicate.TLSFingerprintProfile {
|
||||
return predicate.TLSFingerprintProfile(sql.FieldNotNull(FieldExtensions))
|
||||
}
|
||||
|
||||
// And groups predicates with the AND operator between them.
|
||||
func And(predicates ...predicate.TLSFingerprintProfile) predicate.TLSFingerprintProfile {
|
||||
return predicate.TLSFingerprintProfile(sql.AndPredicates(predicates...))
|
||||
}
|
||||
|
||||
// Or groups predicates with the OR operator between them.
|
||||
func Or(predicates ...predicate.TLSFingerprintProfile) predicate.TLSFingerprintProfile {
|
||||
return predicate.TLSFingerprintProfile(sql.OrPredicates(predicates...))
|
||||
}
|
||||
|
||||
// Not applies the not operator on the given predicate.
|
||||
func Not(p predicate.TLSFingerprintProfile) predicate.TLSFingerprintProfile {
|
||||
return predicate.TLSFingerprintProfile(sql.NotPredicates(p))
|
||||
}
|
||||
1341
backend/ent/tlsfingerprintprofile_create.go
Normal file
1341
backend/ent/tlsfingerprintprofile_create.go
Normal file
File diff suppressed because it is too large
Load Diff
88
backend/ent/tlsfingerprintprofile_delete.go
Normal file
88
backend/ent/tlsfingerprintprofile_delete.go
Normal file
@@ -0,0 +1,88 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package ent
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
"github.com/Wei-Shaw/sub2api/ent/tlsfingerprintprofile"
|
||||
)
|
||||
|
||||
// TLSFingerprintProfileDelete is the builder for deleting a TLSFingerprintProfile entity.
|
||||
type TLSFingerprintProfileDelete struct {
|
||||
config
|
||||
hooks []Hook
|
||||
mutation *TLSFingerprintProfileMutation
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the TLSFingerprintProfileDelete builder.
|
||||
func (_d *TLSFingerprintProfileDelete) Where(ps ...predicate.TLSFingerprintProfile) *TLSFingerprintProfileDelete {
|
||||
_d.mutation.Where(ps...)
|
||||
return _d
|
||||
}
|
||||
|
||||
// Exec executes the deletion query and returns how many vertices were deleted.
|
||||
func (_d *TLSFingerprintProfileDelete) Exec(ctx context.Context) (int, error) {
|
||||
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_d *TLSFingerprintProfileDelete) ExecX(ctx context.Context) int {
|
||||
n, err := _d.Exec(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func (_d *TLSFingerprintProfileDelete) sqlExec(ctx context.Context) (int, error) {
|
||||
_spec := sqlgraph.NewDeleteSpec(tlsfingerprintprofile.Table, sqlgraph.NewFieldSpec(tlsfingerprintprofile.FieldID, field.TypeInt64))
|
||||
if ps := _d.mutation.predicates; len(ps) > 0 {
|
||||
_spec.Predicate = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
|
||||
if err != nil && sqlgraph.IsConstraintError(err) {
|
||||
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||
}
|
||||
_d.mutation.done = true
|
||||
return affected, err
|
||||
}
|
||||
|
||||
// TLSFingerprintProfileDeleteOne is the builder for deleting a single TLSFingerprintProfile entity.
|
||||
type TLSFingerprintProfileDeleteOne struct {
|
||||
_d *TLSFingerprintProfileDelete
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the TLSFingerprintProfileDelete builder.
|
||||
func (_d *TLSFingerprintProfileDeleteOne) Where(ps ...predicate.TLSFingerprintProfile) *TLSFingerprintProfileDeleteOne {
|
||||
_d._d.mutation.Where(ps...)
|
||||
return _d
|
||||
}
|
||||
|
||||
// Exec executes the deletion query.
|
||||
func (_d *TLSFingerprintProfileDeleteOne) Exec(ctx context.Context) error {
|
||||
n, err := _d._d.Exec(ctx)
|
||||
switch {
|
||||
case err != nil:
|
||||
return err
|
||||
case n == 0:
|
||||
return &NotFoundError{tlsfingerprintprofile.Label}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_d *TLSFingerprintProfileDeleteOne) ExecX(ctx context.Context) {
|
||||
if err := _d.Exec(ctx); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
564
backend/ent/tlsfingerprintprofile_query.go
Normal file
564
backend/ent/tlsfingerprintprofile_query.go
Normal file
@@ -0,0 +1,564 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package ent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
"github.com/Wei-Shaw/sub2api/ent/tlsfingerprintprofile"
|
||||
)
|
||||
|
||||
// TLSFingerprintProfileQuery is the builder for querying TLSFingerprintProfile entities.
|
||||
type TLSFingerprintProfileQuery struct {
|
||||
config
|
||||
ctx *QueryContext
|
||||
order []tlsfingerprintprofile.OrderOption
|
||||
inters []Interceptor
|
||||
predicates []predicate.TLSFingerprintProfile
|
||||
modifiers []func(*sql.Selector)
|
||||
// intermediate query (i.e. traversal path).
|
||||
sql *sql.Selector
|
||||
path func(context.Context) (*sql.Selector, error)
|
||||
}
|
||||
|
||||
// Where adds a new predicate for the TLSFingerprintProfileQuery builder.
|
||||
func (_q *TLSFingerprintProfileQuery) Where(ps ...predicate.TLSFingerprintProfile) *TLSFingerprintProfileQuery {
|
||||
_q.predicates = append(_q.predicates, ps...)
|
||||
return _q
|
||||
}
|
||||
|
||||
// Limit the number of records to be returned by this query.
|
||||
func (_q *TLSFingerprintProfileQuery) Limit(limit int) *TLSFingerprintProfileQuery {
|
||||
_q.ctx.Limit = &limit
|
||||
return _q
|
||||
}
|
||||
|
||||
// Offset to start from.
|
||||
func (_q *TLSFingerprintProfileQuery) Offset(offset int) *TLSFingerprintProfileQuery {
|
||||
_q.ctx.Offset = &offset
|
||||
return _q
|
||||
}
|
||||
|
||||
// Unique configures the query builder to filter duplicate records on query.
|
||||
// By default, unique is set to true, and can be disabled using this method.
|
||||
func (_q *TLSFingerprintProfileQuery) Unique(unique bool) *TLSFingerprintProfileQuery {
|
||||
_q.ctx.Unique = &unique
|
||||
return _q
|
||||
}
|
||||
|
||||
// Order specifies how the records should be ordered.
|
||||
func (_q *TLSFingerprintProfileQuery) Order(o ...tlsfingerprintprofile.OrderOption) *TLSFingerprintProfileQuery {
|
||||
_q.order = append(_q.order, o...)
|
||||
return _q
|
||||
}
|
||||
|
||||
// First returns the first TLSFingerprintProfile entity from the query.
|
||||
// Returns a *NotFoundError when no TLSFingerprintProfile was found.
|
||||
func (_q *TLSFingerprintProfileQuery) First(ctx context.Context) (*TLSFingerprintProfile, error) {
|
||||
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nodes) == 0 {
|
||||
return nil, &NotFoundError{tlsfingerprintprofile.Label}
|
||||
}
|
||||
return nodes[0], nil
|
||||
}
|
||||
|
||||
// FirstX is like First, but panics if an error occurs.
|
||||
func (_q *TLSFingerprintProfileQuery) FirstX(ctx context.Context) *TLSFingerprintProfile {
|
||||
node, err := _q.First(ctx)
|
||||
if err != nil && !IsNotFound(err) {
|
||||
panic(err)
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
// FirstID returns the first TLSFingerprintProfile ID from the query.
|
||||
// Returns a *NotFoundError when no TLSFingerprintProfile ID was found.
|
||||
func (_q *TLSFingerprintProfileQuery) FirstID(ctx context.Context) (id int64, err error) {
|
||||
var ids []int64
|
||||
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
|
||||
return
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
err = &NotFoundError{tlsfingerprintprofile.Label}
|
||||
return
|
||||
}
|
||||
return ids[0], nil
|
||||
}
|
||||
|
||||
// FirstIDX is like FirstID, but panics if an error occurs.
|
||||
func (_q *TLSFingerprintProfileQuery) FirstIDX(ctx context.Context) int64 {
|
||||
id, err := _q.FirstID(ctx)
|
||||
if err != nil && !IsNotFound(err) {
|
||||
panic(err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// Only returns a single TLSFingerprintProfile entity found by the query, ensuring it only returns one.
|
||||
// Returns a *NotSingularError when more than one TLSFingerprintProfile entity is found.
|
||||
// Returns a *NotFoundError when no TLSFingerprintProfile entities are found.
|
||||
func (_q *TLSFingerprintProfileQuery) Only(ctx context.Context) (*TLSFingerprintProfile, error) {
|
||||
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch len(nodes) {
|
||||
case 1:
|
||||
return nodes[0], nil
|
||||
case 0:
|
||||
return nil, &NotFoundError{tlsfingerprintprofile.Label}
|
||||
default:
|
||||
return nil, &NotSingularError{tlsfingerprintprofile.Label}
|
||||
}
|
||||
}
|
||||
|
||||
// OnlyX is like Only, but panics if an error occurs.
|
||||
func (_q *TLSFingerprintProfileQuery) OnlyX(ctx context.Context) *TLSFingerprintProfile {
|
||||
node, err := _q.Only(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
// OnlyID is like Only, but returns the only TLSFingerprintProfile ID in the query.
|
||||
// Returns a *NotSingularError when more than one TLSFingerprintProfile ID is found.
|
||||
// Returns a *NotFoundError when no entities are found.
|
||||
func (_q *TLSFingerprintProfileQuery) OnlyID(ctx context.Context) (id int64, err error) {
|
||||
var ids []int64
|
||||
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
|
||||
return
|
||||
}
|
||||
switch len(ids) {
|
||||
case 1:
|
||||
id = ids[0]
|
||||
case 0:
|
||||
err = &NotFoundError{tlsfingerprintprofile.Label}
|
||||
default:
|
||||
err = &NotSingularError{tlsfingerprintprofile.Label}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// OnlyIDX is like OnlyID, but panics if an error occurs.
|
||||
func (_q *TLSFingerprintProfileQuery) OnlyIDX(ctx context.Context) int64 {
|
||||
id, err := _q.OnlyID(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// All executes the query and returns a list of TLSFingerprintProfiles.
|
||||
func (_q *TLSFingerprintProfileQuery) All(ctx context.Context) ([]*TLSFingerprintProfile, error) {
|
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
|
||||
if err := _q.prepareQuery(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
qr := querierAll[[]*TLSFingerprintProfile, *TLSFingerprintProfileQuery]()
|
||||
return withInterceptors[[]*TLSFingerprintProfile](ctx, _q, qr, _q.inters)
|
||||
}
|
||||
|
||||
// AllX is like All, but panics if an error occurs.
|
||||
func (_q *TLSFingerprintProfileQuery) AllX(ctx context.Context) []*TLSFingerprintProfile {
|
||||
nodes, err := _q.All(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return nodes
|
||||
}
|
||||
|
||||
// IDs executes the query and returns a list of TLSFingerprintProfile IDs.
|
||||
func (_q *TLSFingerprintProfileQuery) IDs(ctx context.Context) (ids []int64, err error) {
|
||||
if _q.ctx.Unique == nil && _q.path != nil {
|
||||
_q.Unique(true)
|
||||
}
|
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
|
||||
if err = _q.Select(tlsfingerprintprofile.FieldID).Scan(ctx, &ids); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// IDsX is like IDs, but panics if an error occurs.
|
||||
func (_q *TLSFingerprintProfileQuery) IDsX(ctx context.Context) []int64 {
|
||||
ids, err := _q.IDs(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// Count returns the count of the given query.
|
||||
func (_q *TLSFingerprintProfileQuery) Count(ctx context.Context) (int, error) {
|
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
|
||||
if err := _q.prepareQuery(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return withInterceptors[int](ctx, _q, querierCount[*TLSFingerprintProfileQuery](), _q.inters)
|
||||
}
|
||||
|
||||
// CountX is like Count, but panics if an error occurs.
|
||||
func (_q *TLSFingerprintProfileQuery) CountX(ctx context.Context) int {
|
||||
count, err := _q.Count(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// Exist returns true if the query has elements in the graph.
|
||||
func (_q *TLSFingerprintProfileQuery) Exist(ctx context.Context) (bool, error) {
|
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
|
||||
switch _, err := _q.FirstID(ctx); {
|
||||
case IsNotFound(err):
|
||||
return false, nil
|
||||
case err != nil:
|
||||
return false, fmt.Errorf("ent: check existence: %w", err)
|
||||
default:
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
// ExistX is like Exist, but panics if an error occurs.
|
||||
func (_q *TLSFingerprintProfileQuery) ExistX(ctx context.Context) bool {
|
||||
exist, err := _q.Exist(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return exist
|
||||
}
|
||||
|
||||
// Clone returns a duplicate of the TLSFingerprintProfileQuery builder, including all associated steps. It can be
|
||||
// used to prepare common query builders and use them differently after the clone is made.
|
||||
func (_q *TLSFingerprintProfileQuery) Clone() *TLSFingerprintProfileQuery {
|
||||
if _q == nil {
|
||||
return nil
|
||||
}
|
||||
return &TLSFingerprintProfileQuery{
|
||||
config: _q.config,
|
||||
ctx: _q.ctx.Clone(),
|
||||
order: append([]tlsfingerprintprofile.OrderOption{}, _q.order...),
|
||||
inters: append([]Interceptor{}, _q.inters...),
|
||||
predicates: append([]predicate.TLSFingerprintProfile{}, _q.predicates...),
|
||||
// clone intermediate query.
|
||||
sql: _q.sql.Clone(),
|
||||
path: _q.path,
|
||||
}
|
||||
}
|
||||
|
||||
// GroupBy is used to group vertices by one or more fields/columns.
|
||||
// It is often used with aggregate functions, like: count, max, mean, min, sum.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// var v []struct {
|
||||
// CreatedAt time.Time `json:"created_at,omitempty"`
|
||||
// Count int `json:"count,omitempty"`
|
||||
// }
|
||||
//
|
||||
// client.TLSFingerprintProfile.Query().
|
||||
// GroupBy(tlsfingerprintprofile.FieldCreatedAt).
|
||||
// Aggregate(ent.Count()).
|
||||
// Scan(ctx, &v)
|
||||
func (_q *TLSFingerprintProfileQuery) GroupBy(field string, fields ...string) *TLSFingerprintProfileGroupBy {
|
||||
_q.ctx.Fields = append([]string{field}, fields...)
|
||||
grbuild := &TLSFingerprintProfileGroupBy{build: _q}
|
||||
grbuild.flds = &_q.ctx.Fields
|
||||
grbuild.label = tlsfingerprintprofile.Label
|
||||
grbuild.scan = grbuild.Scan
|
||||
return grbuild
|
||||
}
|
||||
|
||||
// Select allows the selection one or more fields/columns for the given query,
|
||||
// instead of selecting all fields in the entity.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// var v []struct {
|
||||
// CreatedAt time.Time `json:"created_at,omitempty"`
|
||||
// }
|
||||
//
|
||||
// client.TLSFingerprintProfile.Query().
|
||||
// Select(tlsfingerprintprofile.FieldCreatedAt).
|
||||
// Scan(ctx, &v)
|
||||
func (_q *TLSFingerprintProfileQuery) Select(fields ...string) *TLSFingerprintProfileSelect {
|
||||
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
|
||||
sbuild := &TLSFingerprintProfileSelect{TLSFingerprintProfileQuery: _q}
|
||||
sbuild.label = tlsfingerprintprofile.Label
|
||||
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
|
||||
return sbuild
|
||||
}
|
||||
|
||||
// Aggregate returns a TLSFingerprintProfileSelect configured with the given aggregations.
|
||||
func (_q *TLSFingerprintProfileQuery) Aggregate(fns ...AggregateFunc) *TLSFingerprintProfileSelect {
|
||||
return _q.Select().Aggregate(fns...)
|
||||
}
|
||||
|
||||
func (_q *TLSFingerprintProfileQuery) prepareQuery(ctx context.Context) error {
|
||||
for _, inter := range _q.inters {
|
||||
if inter == nil {
|
||||
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
|
||||
}
|
||||
if trv, ok := inter.(Traverser); ok {
|
||||
if err := trv.Traverse(ctx, _q); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, f := range _q.ctx.Fields {
|
||||
if !tlsfingerprintprofile.ValidColumn(f) {
|
||||
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
|
||||
}
|
||||
}
|
||||
if _q.path != nil {
|
||||
prev, err := _q.path(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_q.sql = prev
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (_q *TLSFingerprintProfileQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*TLSFingerprintProfile, error) {
|
||||
var (
|
||||
nodes = []*TLSFingerprintProfile{}
|
||||
_spec = _q.querySpec()
|
||||
)
|
||||
_spec.ScanValues = func(columns []string) ([]any, error) {
|
||||
return (*TLSFingerprintProfile).scanValues(nil, columns)
|
||||
}
|
||||
_spec.Assign = func(columns []string, values []any) error {
|
||||
node := &TLSFingerprintProfile{config: _q.config}
|
||||
nodes = append(nodes, node)
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
for i := range hooks {
|
||||
hooks[i](ctx, _spec)
|
||||
}
|
||||
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nodes) == 0 {
|
||||
return nodes, nil
|
||||
}
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
func (_q *TLSFingerprintProfileQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := _q.querySpec()
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
_spec.Node.Columns = _q.ctx.Fields
|
||||
if len(_q.ctx.Fields) > 0 {
|
||||
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
|
||||
}
|
||||
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
|
||||
}
|
||||
|
||||
func (_q *TLSFingerprintProfileQuery) querySpec() *sqlgraph.QuerySpec {
|
||||
_spec := sqlgraph.NewQuerySpec(tlsfingerprintprofile.Table, tlsfingerprintprofile.Columns, sqlgraph.NewFieldSpec(tlsfingerprintprofile.FieldID, field.TypeInt64))
|
||||
_spec.From = _q.sql
|
||||
if unique := _q.ctx.Unique; unique != nil {
|
||||
_spec.Unique = *unique
|
||||
} else if _q.path != nil {
|
||||
_spec.Unique = true
|
||||
}
|
||||
if fields := _q.ctx.Fields; len(fields) > 0 {
|
||||
_spec.Node.Columns = make([]string, 0, len(fields))
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, tlsfingerprintprofile.FieldID)
|
||||
for i := range fields {
|
||||
if fields[i] != tlsfingerprintprofile.FieldID {
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
if ps := _q.predicates; len(ps) > 0 {
|
||||
_spec.Predicate = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
if limit := _q.ctx.Limit; limit != nil {
|
||||
_spec.Limit = *limit
|
||||
}
|
||||
if offset := _q.ctx.Offset; offset != nil {
|
||||
_spec.Offset = *offset
|
||||
}
|
||||
if ps := _q.order; len(ps) > 0 {
|
||||
_spec.Order = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
return _spec
|
||||
}
|
||||
|
||||
func (_q *TLSFingerprintProfileQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
builder := sql.Dialect(_q.driver.Dialect())
|
||||
t1 := builder.Table(tlsfingerprintprofile.Table)
|
||||
columns := _q.ctx.Fields
|
||||
if len(columns) == 0 {
|
||||
columns = tlsfingerprintprofile.Columns
|
||||
}
|
||||
selector := builder.Select(t1.Columns(columns...)...).From(t1)
|
||||
if _q.sql != nil {
|
||||
selector = _q.sql
|
||||
selector.Select(selector.Columns(columns...)...)
|
||||
}
|
||||
if _q.ctx.Unique != nil && *_q.ctx.Unique {
|
||||
selector.Distinct()
|
||||
}
|
||||
for _, m := range _q.modifiers {
|
||||
m(selector)
|
||||
}
|
||||
for _, p := range _q.predicates {
|
||||
p(selector)
|
||||
}
|
||||
for _, p := range _q.order {
|
||||
p(selector)
|
||||
}
|
||||
if offset := _q.ctx.Offset; offset != nil {
|
||||
// limit is mandatory for offset clause. We start
|
||||
// with default value, and override it below if needed.
|
||||
selector.Offset(*offset).Limit(math.MaxInt32)
|
||||
}
|
||||
if limit := _q.ctx.Limit; limit != nil {
|
||||
selector.Limit(*limit)
|
||||
}
|
||||
return selector
|
||||
}
|
||||
|
||||
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
|
||||
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
|
||||
// either committed or rolled-back.
|
||||
func (_q *TLSFingerprintProfileQuery) ForUpdate(opts ...sql.LockOption) *TLSFingerprintProfileQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForUpdate(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
|
||||
// on any rows that are read. Other sessions can read the rows, but cannot modify them
|
||||
// until your transaction commits.
|
||||
func (_q *TLSFingerprintProfileQuery) ForShare(opts ...sql.LockOption) *TLSFingerprintProfileQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForShare(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// TLSFingerprintProfileGroupBy is the group-by builder for TLSFingerprintProfile entities.
|
||||
type TLSFingerprintProfileGroupBy struct {
|
||||
selector
|
||||
build *TLSFingerprintProfileQuery
|
||||
}
|
||||
|
||||
// Aggregate adds the given aggregation functions to the group-by query.
|
||||
func (_g *TLSFingerprintProfileGroupBy) Aggregate(fns ...AggregateFunc) *TLSFingerprintProfileGroupBy {
|
||||
_g.fns = append(_g.fns, fns...)
|
||||
return _g
|
||||
}
|
||||
|
||||
// Scan applies the selector query and scans the result into the given value.
|
||||
func (_g *TLSFingerprintProfileGroupBy) Scan(ctx context.Context, v any) error {
|
||||
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
|
||||
if err := _g.build.prepareQuery(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return scanWithInterceptors[*TLSFingerprintProfileQuery, *TLSFingerprintProfileGroupBy](ctx, _g.build, _g, _g.build.inters, v)
|
||||
}
|
||||
|
||||
func (_g *TLSFingerprintProfileGroupBy) sqlScan(ctx context.Context, root *TLSFingerprintProfileQuery, v any) error {
|
||||
selector := root.sqlQuery(ctx).Select()
|
||||
aggregation := make([]string, 0, len(_g.fns))
|
||||
for _, fn := range _g.fns {
|
||||
aggregation = append(aggregation, fn(selector))
|
||||
}
|
||||
if len(selector.SelectedColumns()) == 0 {
|
||||
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
|
||||
for _, f := range *_g.flds {
|
||||
columns = append(columns, selector.C(f))
|
||||
}
|
||||
columns = append(columns, aggregation...)
|
||||
selector.Select(columns...)
|
||||
}
|
||||
selector.GroupBy(selector.Columns(*_g.flds...)...)
|
||||
if err := selector.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
rows := &sql.Rows{}
|
||||
query, args := selector.Query()
|
||||
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
return sql.ScanSlice(rows, v)
|
||||
}
|
||||
|
||||
// TLSFingerprintProfileSelect is the builder for selecting fields of TLSFingerprintProfile entities.
|
||||
type TLSFingerprintProfileSelect struct {
|
||||
*TLSFingerprintProfileQuery
|
||||
selector
|
||||
}
|
||||
|
||||
// Aggregate adds the given aggregation functions to the selector query.
|
||||
func (_s *TLSFingerprintProfileSelect) Aggregate(fns ...AggregateFunc) *TLSFingerprintProfileSelect {
|
||||
_s.fns = append(_s.fns, fns...)
|
||||
return _s
|
||||
}
|
||||
|
||||
// Scan applies the selector query and scans the result into the given value.
|
||||
func (_s *TLSFingerprintProfileSelect) Scan(ctx context.Context, v any) error {
|
||||
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
|
||||
if err := _s.prepareQuery(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return scanWithInterceptors[*TLSFingerprintProfileQuery, *TLSFingerprintProfileSelect](ctx, _s.TLSFingerprintProfileQuery, _s, _s.inters, v)
|
||||
}
|
||||
|
||||
func (_s *TLSFingerprintProfileSelect) sqlScan(ctx context.Context, root *TLSFingerprintProfileQuery, v any) error {
|
||||
selector := root.sqlQuery(ctx)
|
||||
aggregation := make([]string, 0, len(_s.fns))
|
||||
for _, fn := range _s.fns {
|
||||
aggregation = append(aggregation, fn(selector))
|
||||
}
|
||||
switch n := len(*_s.selector.flds); {
|
||||
case n == 0 && len(aggregation) > 0:
|
||||
selector.Select(aggregation...)
|
||||
case n != 0 && len(aggregation) > 0:
|
||||
selector.AppendSelect(aggregation...)
|
||||
}
|
||||
rows := &sql.Rows{}
|
||||
query, args := selector.Query()
|
||||
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
return sql.ScanSlice(rows, v)
|
||||
}
|
||||
881
backend/ent/tlsfingerprintprofile_update.go
Normal file
881
backend/ent/tlsfingerprintprofile_update.go
Normal file
@@ -0,0 +1,881 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package ent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/dialect/sql/sqljson"
|
||||
"entgo.io/ent/schema/field"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
"github.com/Wei-Shaw/sub2api/ent/tlsfingerprintprofile"
|
||||
)
|
||||
|
||||
// TLSFingerprintProfileUpdate is the builder for updating TLSFingerprintProfile entities.
|
||||
type TLSFingerprintProfileUpdate struct {
|
||||
config
|
||||
hooks []Hook
|
||||
mutation *TLSFingerprintProfileMutation
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the TLSFingerprintProfileUpdate builder.
|
||||
func (_u *TLSFingerprintProfileUpdate) Where(ps ...predicate.TLSFingerprintProfile) *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.Where(ps...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUpdatedAt sets the "updated_at" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) SetUpdatedAt(v time.Time) *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.SetUpdatedAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetName sets the "name" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) SetName(v string) *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.SetName(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableName sets the "name" field if the given value is not nil.
|
||||
func (_u *TLSFingerprintProfileUpdate) SetNillableName(v *string) *TLSFingerprintProfileUpdate {
|
||||
if v != nil {
|
||||
_u.SetName(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDescription sets the "description" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) SetDescription(v string) *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.SetDescription(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDescription sets the "description" field if the given value is not nil.
|
||||
func (_u *TLSFingerprintProfileUpdate) SetNillableDescription(v *string) *TLSFingerprintProfileUpdate {
|
||||
if v != nil {
|
||||
_u.SetDescription(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearDescription clears the value of the "description" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) ClearDescription() *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.ClearDescription()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetEnableGrease sets the "enable_grease" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) SetEnableGrease(v bool) *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.SetEnableGrease(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableEnableGrease sets the "enable_grease" field if the given value is not nil.
|
||||
func (_u *TLSFingerprintProfileUpdate) SetNillableEnableGrease(v *bool) *TLSFingerprintProfileUpdate {
|
||||
if v != nil {
|
||||
_u.SetEnableGrease(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetCipherSuites sets the "cipher_suites" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) SetCipherSuites(v []uint16) *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.SetCipherSuites(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendCipherSuites appends value to the "cipher_suites" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) AppendCipherSuites(v []uint16) *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.AppendCipherSuites(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearCipherSuites clears the value of the "cipher_suites" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) ClearCipherSuites() *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.ClearCipherSuites()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetCurves sets the "curves" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) SetCurves(v []uint16) *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.SetCurves(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendCurves appends value to the "curves" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) AppendCurves(v []uint16) *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.AppendCurves(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearCurves clears the value of the "curves" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) ClearCurves() *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.ClearCurves()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetPointFormats sets the "point_formats" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) SetPointFormats(v []uint16) *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.SetPointFormats(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendPointFormats appends value to the "point_formats" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) AppendPointFormats(v []uint16) *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.AppendPointFormats(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearPointFormats clears the value of the "point_formats" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) ClearPointFormats() *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.ClearPointFormats()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetSignatureAlgorithms sets the "signature_algorithms" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) SetSignatureAlgorithms(v []uint16) *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.SetSignatureAlgorithms(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendSignatureAlgorithms appends value to the "signature_algorithms" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) AppendSignatureAlgorithms(v []uint16) *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.AppendSignatureAlgorithms(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearSignatureAlgorithms clears the value of the "signature_algorithms" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) ClearSignatureAlgorithms() *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.ClearSignatureAlgorithms()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetAlpnProtocols sets the "alpn_protocols" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) SetAlpnProtocols(v []string) *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.SetAlpnProtocols(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendAlpnProtocols appends value to the "alpn_protocols" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) AppendAlpnProtocols(v []string) *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.AppendAlpnProtocols(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearAlpnProtocols clears the value of the "alpn_protocols" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) ClearAlpnProtocols() *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.ClearAlpnProtocols()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetSupportedVersions sets the "supported_versions" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) SetSupportedVersions(v []uint16) *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.SetSupportedVersions(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendSupportedVersions appends value to the "supported_versions" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) AppendSupportedVersions(v []uint16) *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.AppendSupportedVersions(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearSupportedVersions clears the value of the "supported_versions" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) ClearSupportedVersions() *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.ClearSupportedVersions()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetKeyShareGroups sets the "key_share_groups" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) SetKeyShareGroups(v []uint16) *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.SetKeyShareGroups(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendKeyShareGroups appends value to the "key_share_groups" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) AppendKeyShareGroups(v []uint16) *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.AppendKeyShareGroups(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearKeyShareGroups clears the value of the "key_share_groups" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) ClearKeyShareGroups() *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.ClearKeyShareGroups()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetPskModes sets the "psk_modes" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) SetPskModes(v []uint16) *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.SetPskModes(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendPskModes appends value to the "psk_modes" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) AppendPskModes(v []uint16) *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.AppendPskModes(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearPskModes clears the value of the "psk_modes" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) ClearPskModes() *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.ClearPskModes()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetExtensions sets the "extensions" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) SetExtensions(v []uint16) *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.SetExtensions(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendExtensions appends value to the "extensions" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) AppendExtensions(v []uint16) *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.AppendExtensions(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearExtensions clears the value of the "extensions" field.
|
||||
func (_u *TLSFingerprintProfileUpdate) ClearExtensions() *TLSFingerprintProfileUpdate {
|
||||
_u.mutation.ClearExtensions()
|
||||
return _u
|
||||
}
|
||||
|
||||
// Mutation returns the TLSFingerprintProfileMutation object of the builder,
// exposing the accumulated field changes (e.g. to hooks or interceptors).
func (_u *TLSFingerprintProfileUpdate) Mutation() *TLSFingerprintProfileMutation {
	return _u.mutation
}
|
||||
|
||||
// Save executes the query and returns the number of nodes affected by the update operation.
// Builder defaults (updated_at) are applied first, then sqlSave runs wrapped in the
// registered hooks via withHooks.
func (_u *TLSFingerprintProfileUpdate) Save(ctx context.Context) (int, error) {
	_u.defaults()
	return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
|
||||
|
||||
// SaveX is like Save, but panics if an error occurs.
func (_u *TLSFingerprintProfileUpdate) SaveX(ctx context.Context) int {
	affected, err := _u.Save(ctx)
	if err != nil {
		panic(err)
	}
	return affected
}
|
||||
|
||||
// Exec executes the query, discarding the affected-row count.
func (_u *TLSFingerprintProfileUpdate) Exec(ctx context.Context) error {
	_, err := _u.Save(ctx)
	return err
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
func (_u *TLSFingerprintProfileUpdate) ExecX(ctx context.Context) {
	if err := _u.Exec(ctx); err != nil {
		panic(err)
	}
}
|
||||
|
||||
// defaults sets the default values of the builder before save.
// If the caller did not set updated_at explicitly, it is stamped with the
// schema's UpdateDefaultUpdatedAt value.
func (_u *TLSFingerprintProfileUpdate) defaults() {
	if _, ok := _u.mutation.UpdatedAt(); !ok {
		v := tlsfingerprintprofile.UpdateDefaultUpdatedAt()
		_u.mutation.SetUpdatedAt(v)
	}
}
|
||||
|
||||
// check runs all checks and user-defined validators on the builder.
// Only the "name" field has a validator; it is checked only when the
// mutation actually changes the field.
func (_u *TLSFingerprintProfileUpdate) check() error {
	if v, ok := _u.mutation.Name(); ok {
		if err := tlsfingerprintprofile.NameValidator(v); err != nil {
			return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "TLSFingerprintProfile.name": %w`, err)}
		}
	}
	return nil
}
|
||||
|
||||
// sqlSave translates the accumulated mutation into an UPDATE spec and
// executes it, returning the number of rows affected.
func (_u *TLSFingerprintProfileUpdate) sqlSave(ctx context.Context) (_node int, err error) {
	// Run validators before touching the database.
	if err := _u.check(); err != nil {
		return _node, err
	}
	_spec := sqlgraph.NewUpdateSpec(tlsfingerprintprofile.Table, tlsfingerprintprofile.Columns, sqlgraph.NewFieldSpec(tlsfingerprintprofile.FieldID, field.TypeInt64))
	// Apply any Where(...) predicates so only matching rows are updated.
	if ps := _u.mutation.predicates; len(ps) > 0 {
		_spec.Predicate = func(selector *sql.Selector) {
			for i := range ps {
				ps[i](selector)
			}
		}
	}
	if value, ok := _u.mutation.UpdatedAt(); ok {
		_spec.SetField(tlsfingerprintprofile.FieldUpdatedAt, field.TypeTime, value)
	}
	if value, ok := _u.mutation.Name(); ok {
		_spec.SetField(tlsfingerprintprofile.FieldName, field.TypeString, value)
	}
	if value, ok := _u.mutation.Description(); ok {
		_spec.SetField(tlsfingerprintprofile.FieldDescription, field.TypeString, value)
	}
	if _u.mutation.DescriptionCleared() {
		_spec.ClearField(tlsfingerprintprofile.FieldDescription, field.TypeString)
	}
	if value, ok := _u.mutation.EnableGrease(); ok {
		_spec.SetField(tlsfingerprintprofile.FieldEnableGrease, field.TypeBool, value)
	}
	// For each JSON list field below: a full Set overwrites the column,
	// Append becomes an in-database sqljson.Append modifier, and Clear
	// nulls the column.
	if value, ok := _u.mutation.CipherSuites(); ok {
		_spec.SetField(tlsfingerprintprofile.FieldCipherSuites, field.TypeJSON, value)
	}
	if value, ok := _u.mutation.AppendedCipherSuites(); ok {
		_spec.AddModifier(func(u *sql.UpdateBuilder) {
			sqljson.Append(u, tlsfingerprintprofile.FieldCipherSuites, value)
		})
	}
	if _u.mutation.CipherSuitesCleared() {
		_spec.ClearField(tlsfingerprintprofile.FieldCipherSuites, field.TypeJSON)
	}
	if value, ok := _u.mutation.Curves(); ok {
		_spec.SetField(tlsfingerprintprofile.FieldCurves, field.TypeJSON, value)
	}
	if value, ok := _u.mutation.AppendedCurves(); ok {
		_spec.AddModifier(func(u *sql.UpdateBuilder) {
			sqljson.Append(u, tlsfingerprintprofile.FieldCurves, value)
		})
	}
	if _u.mutation.CurvesCleared() {
		_spec.ClearField(tlsfingerprintprofile.FieldCurves, field.TypeJSON)
	}
	if value, ok := _u.mutation.PointFormats(); ok {
		_spec.SetField(tlsfingerprintprofile.FieldPointFormats, field.TypeJSON, value)
	}
	if value, ok := _u.mutation.AppendedPointFormats(); ok {
		_spec.AddModifier(func(u *sql.UpdateBuilder) {
			sqljson.Append(u, tlsfingerprintprofile.FieldPointFormats, value)
		})
	}
	if _u.mutation.PointFormatsCleared() {
		_spec.ClearField(tlsfingerprintprofile.FieldPointFormats, field.TypeJSON)
	}
	if value, ok := _u.mutation.SignatureAlgorithms(); ok {
		_spec.SetField(tlsfingerprintprofile.FieldSignatureAlgorithms, field.TypeJSON, value)
	}
	if value, ok := _u.mutation.AppendedSignatureAlgorithms(); ok {
		_spec.AddModifier(func(u *sql.UpdateBuilder) {
			sqljson.Append(u, tlsfingerprintprofile.FieldSignatureAlgorithms, value)
		})
	}
	if _u.mutation.SignatureAlgorithmsCleared() {
		_spec.ClearField(tlsfingerprintprofile.FieldSignatureAlgorithms, field.TypeJSON)
	}
	if value, ok := _u.mutation.AlpnProtocols(); ok {
		_spec.SetField(tlsfingerprintprofile.FieldAlpnProtocols, field.TypeJSON, value)
	}
	if value, ok := _u.mutation.AppendedAlpnProtocols(); ok {
		_spec.AddModifier(func(u *sql.UpdateBuilder) {
			sqljson.Append(u, tlsfingerprintprofile.FieldAlpnProtocols, value)
		})
	}
	if _u.mutation.AlpnProtocolsCleared() {
		_spec.ClearField(tlsfingerprintprofile.FieldAlpnProtocols, field.TypeJSON)
	}
	if value, ok := _u.mutation.SupportedVersions(); ok {
		_spec.SetField(tlsfingerprintprofile.FieldSupportedVersions, field.TypeJSON, value)
	}
	if value, ok := _u.mutation.AppendedSupportedVersions(); ok {
		_spec.AddModifier(func(u *sql.UpdateBuilder) {
			sqljson.Append(u, tlsfingerprintprofile.FieldSupportedVersions, value)
		})
	}
	if _u.mutation.SupportedVersionsCleared() {
		_spec.ClearField(tlsfingerprintprofile.FieldSupportedVersions, field.TypeJSON)
	}
	if value, ok := _u.mutation.KeyShareGroups(); ok {
		_spec.SetField(tlsfingerprintprofile.FieldKeyShareGroups, field.TypeJSON, value)
	}
	if value, ok := _u.mutation.AppendedKeyShareGroups(); ok {
		_spec.AddModifier(func(u *sql.UpdateBuilder) {
			sqljson.Append(u, tlsfingerprintprofile.FieldKeyShareGroups, value)
		})
	}
	if _u.mutation.KeyShareGroupsCleared() {
		_spec.ClearField(tlsfingerprintprofile.FieldKeyShareGroups, field.TypeJSON)
	}
	if value, ok := _u.mutation.PskModes(); ok {
		_spec.SetField(tlsfingerprintprofile.FieldPskModes, field.TypeJSON, value)
	}
	if value, ok := _u.mutation.AppendedPskModes(); ok {
		_spec.AddModifier(func(u *sql.UpdateBuilder) {
			sqljson.Append(u, tlsfingerprintprofile.FieldPskModes, value)
		})
	}
	if _u.mutation.PskModesCleared() {
		_spec.ClearField(tlsfingerprintprofile.FieldPskModes, field.TypeJSON)
	}
	if value, ok := _u.mutation.Extensions(); ok {
		_spec.SetField(tlsfingerprintprofile.FieldExtensions, field.TypeJSON, value)
	}
	if value, ok := _u.mutation.AppendedExtensions(); ok {
		_spec.AddModifier(func(u *sql.UpdateBuilder) {
			sqljson.Append(u, tlsfingerprintprofile.FieldExtensions, value)
		})
	}
	if _u.mutation.ExtensionsCleared() {
		_spec.ClearField(tlsfingerprintprofile.FieldExtensions, field.TypeJSON)
	}
	// Execute the update, mapping driver errors onto ent's typed errors.
	if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
		if _, ok := err.(*sqlgraph.NotFoundError); ok {
			err = &NotFoundError{tlsfingerprintprofile.Label}
		} else if sqlgraph.IsConstraintError(err) {
			err = &ConstraintError{msg: err.Error(), wrap: err}
		}
		return 0, err
	}
	// Mark the mutation as consumed so it cannot be reused.
	_u.mutation.done = true
	return _node, nil
}
|
||||
|
||||
// TLSFingerprintProfileUpdateOne is the builder for updating a single TLSFingerprintProfile entity.
type TLSFingerprintProfileUpdateOne struct {
	config
	// fields is the optional subset of columns to scan back after the update
	// (populated via Select); empty means all columns.
	fields []string
	// hooks run around sqlSave via withHooks.
	hooks []Hook
	// mutation accumulates the field changes, the target ID, and predicates.
	mutation *TLSFingerprintProfileMutation
}
|
||||
|
||||
// SetUpdatedAt sets the "updated_at" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) SetUpdatedAt(v time.Time) *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.SetUpdatedAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetName sets the "name" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) SetName(v string) *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.SetName(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableName sets the "name" field if the given value is not nil.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) SetNillableName(v *string) *TLSFingerprintProfileUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetName(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDescription sets the "description" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) SetDescription(v string) *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.SetDescription(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDescription sets the "description" field if the given value is not nil.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) SetNillableDescription(v *string) *TLSFingerprintProfileUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetDescription(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearDescription clears the value of the "description" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) ClearDescription() *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.ClearDescription()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetEnableGrease sets the "enable_grease" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) SetEnableGrease(v bool) *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.SetEnableGrease(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableEnableGrease sets the "enable_grease" field if the given value is not nil.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) SetNillableEnableGrease(v *bool) *TLSFingerprintProfileUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetEnableGrease(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetCipherSuites sets the "cipher_suites" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) SetCipherSuites(v []uint16) *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.SetCipherSuites(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendCipherSuites appends value to the "cipher_suites" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) AppendCipherSuites(v []uint16) *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.AppendCipherSuites(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearCipherSuites clears the value of the "cipher_suites" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) ClearCipherSuites() *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.ClearCipherSuites()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetCurves sets the "curves" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) SetCurves(v []uint16) *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.SetCurves(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendCurves appends value to the "curves" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) AppendCurves(v []uint16) *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.AppendCurves(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearCurves clears the value of the "curves" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) ClearCurves() *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.ClearCurves()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetPointFormats sets the "point_formats" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) SetPointFormats(v []uint16) *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.SetPointFormats(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendPointFormats appends value to the "point_formats" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) AppendPointFormats(v []uint16) *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.AppendPointFormats(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearPointFormats clears the value of the "point_formats" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) ClearPointFormats() *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.ClearPointFormats()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetSignatureAlgorithms sets the "signature_algorithms" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) SetSignatureAlgorithms(v []uint16) *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.SetSignatureAlgorithms(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendSignatureAlgorithms appends value to the "signature_algorithms" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) AppendSignatureAlgorithms(v []uint16) *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.AppendSignatureAlgorithms(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearSignatureAlgorithms clears the value of the "signature_algorithms" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) ClearSignatureAlgorithms() *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.ClearSignatureAlgorithms()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetAlpnProtocols sets the "alpn_protocols" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) SetAlpnProtocols(v []string) *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.SetAlpnProtocols(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendAlpnProtocols appends value to the "alpn_protocols" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) AppendAlpnProtocols(v []string) *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.AppendAlpnProtocols(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearAlpnProtocols clears the value of the "alpn_protocols" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) ClearAlpnProtocols() *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.ClearAlpnProtocols()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetSupportedVersions sets the "supported_versions" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) SetSupportedVersions(v []uint16) *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.SetSupportedVersions(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendSupportedVersions appends value to the "supported_versions" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) AppendSupportedVersions(v []uint16) *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.AppendSupportedVersions(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearSupportedVersions clears the value of the "supported_versions" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) ClearSupportedVersions() *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.ClearSupportedVersions()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetKeyShareGroups sets the "key_share_groups" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) SetKeyShareGroups(v []uint16) *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.SetKeyShareGroups(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendKeyShareGroups appends value to the "key_share_groups" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) AppendKeyShareGroups(v []uint16) *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.AppendKeyShareGroups(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearKeyShareGroups clears the value of the "key_share_groups" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) ClearKeyShareGroups() *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.ClearKeyShareGroups()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetPskModes sets the "psk_modes" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) SetPskModes(v []uint16) *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.SetPskModes(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendPskModes appends value to the "psk_modes" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) AppendPskModes(v []uint16) *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.AppendPskModes(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearPskModes clears the value of the "psk_modes" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) ClearPskModes() *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.ClearPskModes()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetExtensions sets the "extensions" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) SetExtensions(v []uint16) *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.SetExtensions(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendExtensions appends value to the "extensions" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) AppendExtensions(v []uint16) *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.AppendExtensions(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearExtensions clears the value of the "extensions" field.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) ClearExtensions() *TLSFingerprintProfileUpdateOne {
|
||||
_u.mutation.ClearExtensions()
|
||||
return _u
|
||||
}
|
||||
|
||||
// Mutation returns the TLSFingerprintProfileMutation object of the builder.
|
||||
func (_u *TLSFingerprintProfileUpdateOne) Mutation() *TLSFingerprintProfileMutation {
|
||||
return _u.mutation
|
||||
}
|
||||
|
||||
// Where appends a list of predicates to the TLSFingerprintProfileUpdateOne builder.
func (_u *TLSFingerprintProfileUpdateOne) Where(ps ...predicate.TLSFingerprintProfile) *TLSFingerprintProfileUpdateOne {
	_u.mutation.Where(ps...)
	return _u
}
|
||||
|
||||
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (_u *TLSFingerprintProfileUpdateOne) Select(field string, fields ...string) *TLSFingerprintProfileUpdateOne {
	// Repeated calls replace (not extend) the previous selection.
	_u.fields = append([]string{field}, fields...)
	return _u
}
|
||||
|
||||
// Save executes the query and returns the updated TLSFingerprintProfile entity.
// Builder defaults (updated_at) are applied first, then sqlSave runs wrapped
// in the registered hooks via withHooks.
func (_u *TLSFingerprintProfileUpdateOne) Save(ctx context.Context) (*TLSFingerprintProfile, error) {
	_u.defaults()
	return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
|
||||
|
||||
// SaveX is like Save, but panics if an error occurs.
func (_u *TLSFingerprintProfileUpdateOne) SaveX(ctx context.Context) *TLSFingerprintProfile {
	node, err := _u.Save(ctx)
	if err != nil {
		panic(err)
	}
	return node
}
|
||||
|
||||
// Exec executes the query on the entity, discarding the updated entity.
func (_u *TLSFingerprintProfileUpdateOne) Exec(ctx context.Context) error {
	_, err := _u.Save(ctx)
	return err
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
func (_u *TLSFingerprintProfileUpdateOne) ExecX(ctx context.Context) {
	if err := _u.Exec(ctx); err != nil {
		panic(err)
	}
}
|
||||
|
||||
// defaults sets the default values of the builder before save.
// If updated_at was not set explicitly, it is stamped with the schema's
// UpdateDefaultUpdatedAt value.
func (_u *TLSFingerprintProfileUpdateOne) defaults() {
	if _, ok := _u.mutation.UpdatedAt(); !ok {
		v := tlsfingerprintprofile.UpdateDefaultUpdatedAt()
		_u.mutation.SetUpdatedAt(v)
	}
}
|
||||
|
||||
// check runs all checks and user-defined validators on the builder.
// Only the "name" field has a validator; it is checked only when the
// mutation actually changes the field.
func (_u *TLSFingerprintProfileUpdateOne) check() error {
	if v, ok := _u.mutation.Name(); ok {
		if err := tlsfingerprintprofile.NameValidator(v); err != nil {
			return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "TLSFingerprintProfile.name": %w`, err)}
		}
	}
	return nil
}
|
||||
|
||||
// sqlSave translates the accumulated mutation into an UPDATE spec targeting a
// single row (identified by the mutation's ID), executes it, and scans the
// updated entity back.
func (_u *TLSFingerprintProfileUpdateOne) sqlSave(ctx context.Context) (_node *TLSFingerprintProfile, err error) {
	// Run validators before touching the database.
	if err := _u.check(); err != nil {
		return _node, err
	}
	_spec := sqlgraph.NewUpdateSpec(tlsfingerprintprofile.Table, tlsfingerprintprofile.Columns, sqlgraph.NewFieldSpec(tlsfingerprintprofile.FieldID, field.TypeInt64))
	// The target row's ID must be present on the mutation (set by the client).
	id, ok := _u.mutation.ID()
	if !ok {
		return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "TLSFingerprintProfile.id" for update`)}
	}
	_spec.Node.ID.Value = id
	// Honor Select(...): scan back only the requested columns, always
	// including the ID column exactly once.
	if fields := _u.fields; len(fields) > 0 {
		_spec.Node.Columns = make([]string, 0, len(fields))
		_spec.Node.Columns = append(_spec.Node.Columns, tlsfingerprintprofile.FieldID)
		for _, f := range fields {
			if !tlsfingerprintprofile.ValidColumn(f) {
				return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
			}
			if f != tlsfingerprintprofile.FieldID {
				_spec.Node.Columns = append(_spec.Node.Columns, f)
			}
		}
	}
	// Apply any Where(...) predicates in addition to the ID match.
	if ps := _u.mutation.predicates; len(ps) > 0 {
		_spec.Predicate = func(selector *sql.Selector) {
			for i := range ps {
				ps[i](selector)
			}
		}
	}
	if value, ok := _u.mutation.UpdatedAt(); ok {
		_spec.SetField(tlsfingerprintprofile.FieldUpdatedAt, field.TypeTime, value)
	}
	if value, ok := _u.mutation.Name(); ok {
		_spec.SetField(tlsfingerprintprofile.FieldName, field.TypeString, value)
	}
	if value, ok := _u.mutation.Description(); ok {
		_spec.SetField(tlsfingerprintprofile.FieldDescription, field.TypeString, value)
	}
	if _u.mutation.DescriptionCleared() {
		_spec.ClearField(tlsfingerprintprofile.FieldDescription, field.TypeString)
	}
	if value, ok := _u.mutation.EnableGrease(); ok {
		_spec.SetField(tlsfingerprintprofile.FieldEnableGrease, field.TypeBool, value)
	}
	// For each JSON list field below: a full Set overwrites the column,
	// Append becomes an in-database sqljson.Append modifier, and Clear
	// nulls the column.
	if value, ok := _u.mutation.CipherSuites(); ok {
		_spec.SetField(tlsfingerprintprofile.FieldCipherSuites, field.TypeJSON, value)
	}
	if value, ok := _u.mutation.AppendedCipherSuites(); ok {
		_spec.AddModifier(func(u *sql.UpdateBuilder) {
			sqljson.Append(u, tlsfingerprintprofile.FieldCipherSuites, value)
		})
	}
	if _u.mutation.CipherSuitesCleared() {
		_spec.ClearField(tlsfingerprintprofile.FieldCipherSuites, field.TypeJSON)
	}
	if value, ok := _u.mutation.Curves(); ok {
		_spec.SetField(tlsfingerprintprofile.FieldCurves, field.TypeJSON, value)
	}
	if value, ok := _u.mutation.AppendedCurves(); ok {
		_spec.AddModifier(func(u *sql.UpdateBuilder) {
			sqljson.Append(u, tlsfingerprintprofile.FieldCurves, value)
		})
	}
	if _u.mutation.CurvesCleared() {
		_spec.ClearField(tlsfingerprintprofile.FieldCurves, field.TypeJSON)
	}
	if value, ok := _u.mutation.PointFormats(); ok {
		_spec.SetField(tlsfingerprintprofile.FieldPointFormats, field.TypeJSON, value)
	}
	if value, ok := _u.mutation.AppendedPointFormats(); ok {
		_spec.AddModifier(func(u *sql.UpdateBuilder) {
			sqljson.Append(u, tlsfingerprintprofile.FieldPointFormats, value)
		})
	}
	if _u.mutation.PointFormatsCleared() {
		_spec.ClearField(tlsfingerprintprofile.FieldPointFormats, field.TypeJSON)
	}
	if value, ok := _u.mutation.SignatureAlgorithms(); ok {
		_spec.SetField(tlsfingerprintprofile.FieldSignatureAlgorithms, field.TypeJSON, value)
	}
	if value, ok := _u.mutation.AppendedSignatureAlgorithms(); ok {
		_spec.AddModifier(func(u *sql.UpdateBuilder) {
			sqljson.Append(u, tlsfingerprintprofile.FieldSignatureAlgorithms, value)
		})
	}
	if _u.mutation.SignatureAlgorithmsCleared() {
		_spec.ClearField(tlsfingerprintprofile.FieldSignatureAlgorithms, field.TypeJSON)
	}
	if value, ok := _u.mutation.AlpnProtocols(); ok {
		_spec.SetField(tlsfingerprintprofile.FieldAlpnProtocols, field.TypeJSON, value)
	}
	if value, ok := _u.mutation.AppendedAlpnProtocols(); ok {
		_spec.AddModifier(func(u *sql.UpdateBuilder) {
			sqljson.Append(u, tlsfingerprintprofile.FieldAlpnProtocols, value)
		})
	}
	if _u.mutation.AlpnProtocolsCleared() {
		_spec.ClearField(tlsfingerprintprofile.FieldAlpnProtocols, field.TypeJSON)
	}
	if value, ok := _u.mutation.SupportedVersions(); ok {
		_spec.SetField(tlsfingerprintprofile.FieldSupportedVersions, field.TypeJSON, value)
	}
	if value, ok := _u.mutation.AppendedSupportedVersions(); ok {
		_spec.AddModifier(func(u *sql.UpdateBuilder) {
			sqljson.Append(u, tlsfingerprintprofile.FieldSupportedVersions, value)
		})
	}
	if _u.mutation.SupportedVersionsCleared() {
		_spec.ClearField(tlsfingerprintprofile.FieldSupportedVersions, field.TypeJSON)
	}
	if value, ok := _u.mutation.KeyShareGroups(); ok {
		_spec.SetField(tlsfingerprintprofile.FieldKeyShareGroups, field.TypeJSON, value)
	}
	if value, ok := _u.mutation.AppendedKeyShareGroups(); ok {
		_spec.AddModifier(func(u *sql.UpdateBuilder) {
			sqljson.Append(u, tlsfingerprintprofile.FieldKeyShareGroups, value)
		})
	}
	if _u.mutation.KeyShareGroupsCleared() {
		_spec.ClearField(tlsfingerprintprofile.FieldKeyShareGroups, field.TypeJSON)
	}
	if value, ok := _u.mutation.PskModes(); ok {
		_spec.SetField(tlsfingerprintprofile.FieldPskModes, field.TypeJSON, value)
	}
	if value, ok := _u.mutation.AppendedPskModes(); ok {
		_spec.AddModifier(func(u *sql.UpdateBuilder) {
			sqljson.Append(u, tlsfingerprintprofile.FieldPskModes, value)
		})
	}
	if _u.mutation.PskModesCleared() {
		_spec.ClearField(tlsfingerprintprofile.FieldPskModes, field.TypeJSON)
	}
	if value, ok := _u.mutation.Extensions(); ok {
		_spec.SetField(tlsfingerprintprofile.FieldExtensions, field.TypeJSON, value)
	}
	if value, ok := _u.mutation.AppendedExtensions(); ok {
		_spec.AddModifier(func(u *sql.UpdateBuilder) {
			sqljson.Append(u, tlsfingerprintprofile.FieldExtensions, value)
		})
	}
	if _u.mutation.ExtensionsCleared() {
		_spec.ClearField(tlsfingerprintprofile.FieldExtensions, field.TypeJSON)
	}
	// Prepare the entity so UpdateNode can scan the refreshed row into it.
	_node = &TLSFingerprintProfile{config: _u.config}
	_spec.Assign = _node.assignValues
	_spec.ScanValues = _node.scanValues
	// Execute the update, mapping driver errors onto ent's typed errors.
	if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
		if _, ok := err.(*sqlgraph.NotFoundError); ok {
			err = &NotFoundError{tlsfingerprintprofile.Label}
		} else if sqlgraph.IsConstraintError(err) {
			err = &ConstraintError{msg: err.Error(), wrap: err}
		}
		return nil, err
	}
	// Mark the mutation as consumed so it cannot be reused.
	_u.mutation.done = true
	return _node, nil
}
|
||||
@@ -42,6 +42,8 @@ type Tx struct {
|
||||
SecuritySecret *SecuritySecretClient
|
||||
// Setting is the client for interacting with the Setting builders.
|
||||
Setting *SettingClient
|
||||
// TLSFingerprintProfile is the client for interacting with the TLSFingerprintProfile builders.
|
||||
TLSFingerprintProfile *TLSFingerprintProfileClient
|
||||
// UsageCleanupTask is the client for interacting with the UsageCleanupTask builders.
|
||||
UsageCleanupTask *UsageCleanupTaskClient
|
||||
// UsageLog is the client for interacting with the UsageLog builders.
|
||||
@@ -201,6 +203,7 @@ func (tx *Tx) init() {
|
||||
tx.RedeemCode = NewRedeemCodeClient(tx.config)
|
||||
tx.SecuritySecret = NewSecuritySecretClient(tx.config)
|
||||
tx.Setting = NewSettingClient(tx.config)
|
||||
tx.TLSFingerprintProfile = NewTLSFingerprintProfileClient(tx.config)
|
||||
tx.UsageCleanupTask = NewUsageCleanupTaskClient(tx.config)
|
||||
tx.UsageLog = NewUsageLogClient(tx.config)
|
||||
tx.User = NewUserClient(tx.config)
|
||||
|
||||
@@ -32,6 +32,18 @@ type UsageLog struct {
|
||||
RequestID string `json:"request_id,omitempty"`
|
||||
// Model holds the value of the "model" field.
|
||||
Model string `json:"model,omitempty"`
|
||||
// RequestedModel holds the value of the "requested_model" field.
|
||||
RequestedModel *string `json:"requested_model,omitempty"`
|
||||
// UpstreamModel holds the value of the "upstream_model" field.
|
||||
UpstreamModel *string `json:"upstream_model,omitempty"`
|
||||
// 渠道 ID
|
||||
ChannelID *int64 `json:"channel_id,omitempty"`
|
||||
// 模型映射链
|
||||
ModelMappingChain *string `json:"model_mapping_chain,omitempty"`
|
||||
// 计费层级标签
|
||||
BillingTier *string `json:"billing_tier,omitempty"`
|
||||
// 计费模式:token/per_request/image
|
||||
BillingMode *string `json:"billing_mode,omitempty"`
|
||||
// GroupID holds the value of the "group_id" field.
|
||||
GroupID *int64 `json:"group_id,omitempty"`
|
||||
// SubscriptionID holds the value of the "subscription_id" field.
|
||||
@@ -80,8 +92,6 @@ type UsageLog struct {
|
||||
ImageCount int `json:"image_count,omitempty"`
|
||||
// ImageSize holds the value of the "image_size" field.
|
||||
ImageSize *string `json:"image_size,omitempty"`
|
||||
// MediaType holds the value of the "media_type" field.
|
||||
MediaType *string `json:"media_type,omitempty"`
|
||||
// CacheTTLOverridden holds the value of the "cache_ttl_overridden" field.
|
||||
CacheTTLOverridden bool `json:"cache_ttl_overridden,omitempty"`
|
||||
// CreatedAt holds the value of the "created_at" field.
|
||||
@@ -173,9 +183,9 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
|
||||
values[i] = new(sql.NullBool)
|
||||
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
|
||||
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldChannelID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
|
||||
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldModelMappingChain, usagelog.FieldBillingTier, usagelog.FieldBillingMode, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize:
|
||||
values[i] = new(sql.NullString)
|
||||
case usagelog.FieldCreatedAt:
|
||||
values[i] = new(sql.NullTime)
|
||||
@@ -230,6 +240,48 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.Model = value.String
|
||||
}
|
||||
case usagelog.FieldRequestedModel:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field requested_model", values[i])
|
||||
} else if value.Valid {
|
||||
_m.RequestedModel = new(string)
|
||||
*_m.RequestedModel = value.String
|
||||
}
|
||||
case usagelog.FieldUpstreamModel:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field upstream_model", values[i])
|
||||
} else if value.Valid {
|
||||
_m.UpstreamModel = new(string)
|
||||
*_m.UpstreamModel = value.String
|
||||
}
|
||||
case usagelog.FieldChannelID:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field channel_id", values[i])
|
||||
} else if value.Valid {
|
||||
_m.ChannelID = new(int64)
|
||||
*_m.ChannelID = value.Int64
|
||||
}
|
||||
case usagelog.FieldModelMappingChain:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field model_mapping_chain", values[i])
|
||||
} else if value.Valid {
|
||||
_m.ModelMappingChain = new(string)
|
||||
*_m.ModelMappingChain = value.String
|
||||
}
|
||||
case usagelog.FieldBillingTier:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field billing_tier", values[i])
|
||||
} else if value.Valid {
|
||||
_m.BillingTier = new(string)
|
||||
*_m.BillingTier = value.String
|
||||
}
|
||||
case usagelog.FieldBillingMode:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field billing_mode", values[i])
|
||||
} else if value.Valid {
|
||||
_m.BillingMode = new(string)
|
||||
*_m.BillingMode = value.String
|
||||
}
|
||||
case usagelog.FieldGroupID:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field group_id", values[i])
|
||||
@@ -382,13 +434,6 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
|
||||
_m.ImageSize = new(string)
|
||||
*_m.ImageSize = value.String
|
||||
}
|
||||
case usagelog.FieldMediaType:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field media_type", values[i])
|
||||
} else if value.Valid {
|
||||
_m.MediaType = new(string)
|
||||
*_m.MediaType = value.String
|
||||
}
|
||||
case usagelog.FieldCacheTTLOverridden:
|
||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field cache_ttl_overridden", values[i])
|
||||
@@ -477,6 +522,36 @@ func (_m *UsageLog) String() string {
|
||||
builder.WriteString("model=")
|
||||
builder.WriteString(_m.Model)
|
||||
builder.WriteString(", ")
|
||||
if v := _m.RequestedModel; v != nil {
|
||||
builder.WriteString("requested_model=")
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.UpstreamModel; v != nil {
|
||||
builder.WriteString("upstream_model=")
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.ChannelID; v != nil {
|
||||
builder.WriteString("channel_id=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.ModelMappingChain; v != nil {
|
||||
builder.WriteString("model_mapping_chain=")
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.BillingTier; v != nil {
|
||||
builder.WriteString("billing_tier=")
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.BillingMode; v != nil {
|
||||
builder.WriteString("billing_mode=")
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.GroupID; v != nil {
|
||||
builder.WriteString("group_id=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
@@ -565,11 +640,6 @@ func (_m *UsageLog) String() string {
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.MediaType; v != nil {
|
||||
builder.WriteString("media_type=")
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("cache_ttl_overridden=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.CacheTTLOverridden))
|
||||
builder.WriteString(", ")
|
||||
|
||||
@@ -24,6 +24,18 @@ const (
|
||||
FieldRequestID = "request_id"
|
||||
// FieldModel holds the string denoting the model field in the database.
|
||||
FieldModel = "model"
|
||||
// FieldRequestedModel holds the string denoting the requested_model field in the database.
|
||||
FieldRequestedModel = "requested_model"
|
||||
// FieldUpstreamModel holds the string denoting the upstream_model field in the database.
|
||||
FieldUpstreamModel = "upstream_model"
|
||||
// FieldChannelID holds the string denoting the channel_id field in the database.
|
||||
FieldChannelID = "channel_id"
|
||||
// FieldModelMappingChain holds the string denoting the model_mapping_chain field in the database.
|
||||
FieldModelMappingChain = "model_mapping_chain"
|
||||
// FieldBillingTier holds the string denoting the billing_tier field in the database.
|
||||
FieldBillingTier = "billing_tier"
|
||||
// FieldBillingMode holds the string denoting the billing_mode field in the database.
|
||||
FieldBillingMode = "billing_mode"
|
||||
// FieldGroupID holds the string denoting the group_id field in the database.
|
||||
FieldGroupID = "group_id"
|
||||
// FieldSubscriptionID holds the string denoting the subscription_id field in the database.
|
||||
@@ -72,8 +84,6 @@ const (
|
||||
FieldImageCount = "image_count"
|
||||
// FieldImageSize holds the string denoting the image_size field in the database.
|
||||
FieldImageSize = "image_size"
|
||||
// FieldMediaType holds the string denoting the media_type field in the database.
|
||||
FieldMediaType = "media_type"
|
||||
// FieldCacheTTLOverridden holds the string denoting the cache_ttl_overridden field in the database.
|
||||
FieldCacheTTLOverridden = "cache_ttl_overridden"
|
||||
// FieldCreatedAt holds the string denoting the created_at field in the database.
|
||||
@@ -135,6 +145,12 @@ var Columns = []string{
|
||||
FieldAccountID,
|
||||
FieldRequestID,
|
||||
FieldModel,
|
||||
FieldRequestedModel,
|
||||
FieldUpstreamModel,
|
||||
FieldChannelID,
|
||||
FieldModelMappingChain,
|
||||
FieldBillingTier,
|
||||
FieldBillingMode,
|
||||
FieldGroupID,
|
||||
FieldSubscriptionID,
|
||||
FieldInputTokens,
|
||||
@@ -159,7 +175,6 @@ var Columns = []string{
|
||||
FieldIPAddress,
|
||||
FieldImageCount,
|
||||
FieldImageSize,
|
||||
FieldMediaType,
|
||||
FieldCacheTTLOverridden,
|
||||
FieldCreatedAt,
|
||||
}
|
||||
@@ -179,6 +194,16 @@ var (
|
||||
RequestIDValidator func(string) error
|
||||
// ModelValidator is a validator for the "model" field. It is called by the builders before save.
|
||||
ModelValidator func(string) error
|
||||
// RequestedModelValidator is a validator for the "requested_model" field. It is called by the builders before save.
|
||||
RequestedModelValidator func(string) error
|
||||
// UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
|
||||
UpstreamModelValidator func(string) error
|
||||
// ModelMappingChainValidator is a validator for the "model_mapping_chain" field. It is called by the builders before save.
|
||||
ModelMappingChainValidator func(string) error
|
||||
// BillingTierValidator is a validator for the "billing_tier" field. It is called by the builders before save.
|
||||
BillingTierValidator func(string) error
|
||||
// BillingModeValidator is a validator for the "billing_mode" field. It is called by the builders before save.
|
||||
BillingModeValidator func(string) error
|
||||
// DefaultInputTokens holds the default value on creation for the "input_tokens" field.
|
||||
DefaultInputTokens int
|
||||
// DefaultOutputTokens holds the default value on creation for the "output_tokens" field.
|
||||
@@ -217,8 +242,6 @@ var (
|
||||
DefaultImageCount int
|
||||
// ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
||||
ImageSizeValidator func(string) error
|
||||
// MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
|
||||
MediaTypeValidator func(string) error
|
||||
// DefaultCacheTTLOverridden holds the default value on creation for the "cache_ttl_overridden" field.
|
||||
DefaultCacheTTLOverridden bool
|
||||
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
|
||||
@@ -258,6 +281,36 @@ func ByModel(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldModel, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByRequestedModel orders the results by the requested_model field.
|
||||
func ByRequestedModel(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldRequestedModel, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByUpstreamModel orders the results by the upstream_model field.
|
||||
func ByUpstreamModel(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldUpstreamModel, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByChannelID orders the results by the channel_id field.
|
||||
func ByChannelID(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldChannelID, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByModelMappingChain orders the results by the model_mapping_chain field.
|
||||
func ByModelMappingChain(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldModelMappingChain, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByBillingTier orders the results by the billing_tier field.
|
||||
func ByBillingTier(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldBillingTier, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByBillingMode orders the results by the billing_mode field.
|
||||
func ByBillingMode(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldBillingMode, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByGroupID orders the results by the group_id field.
|
||||
func ByGroupID(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldGroupID, opts...).ToFunc()
|
||||
@@ -378,11 +431,6 @@ func ByImageSize(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldImageSize, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByMediaType orders the results by the media_type field.
|
||||
func ByMediaType(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldMediaType, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByCacheTTLOverridden orders the results by the cache_ttl_overridden field.
|
||||
func ByCacheTTLOverridden(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldCacheTTLOverridden, opts...).ToFunc()
|
||||
|
||||
@@ -80,6 +80,36 @@ func Model(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldModel, v))
|
||||
}
|
||||
|
||||
// RequestedModel applies equality check predicate on the "requested_model" field. It's identical to RequestedModelEQ.
|
||||
func RequestedModel(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldRequestedModel, v))
|
||||
}
|
||||
|
||||
// UpstreamModel applies equality check predicate on the "upstream_model" field. It's identical to UpstreamModelEQ.
|
||||
func UpstreamModel(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v))
|
||||
}
|
||||
|
||||
// ChannelID applies equality check predicate on the "channel_id" field. It's identical to ChannelIDEQ.
|
||||
func ChannelID(v int64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldChannelID, v))
|
||||
}
|
||||
|
||||
// ModelMappingChain applies equality check predicate on the "model_mapping_chain" field. It's identical to ModelMappingChainEQ.
|
||||
func ModelMappingChain(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldModelMappingChain, v))
|
||||
}
|
||||
|
||||
// BillingTier applies equality check predicate on the "billing_tier" field. It's identical to BillingTierEQ.
|
||||
func BillingTier(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldBillingTier, v))
|
||||
}
|
||||
|
||||
// BillingMode applies equality check predicate on the "billing_mode" field. It's identical to BillingModeEQ.
|
||||
func BillingMode(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldBillingMode, v))
|
||||
}
|
||||
|
||||
// GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ.
|
||||
func GroupID(v int64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))
|
||||
@@ -200,11 +230,6 @@ func ImageSize(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldImageSize, v))
|
||||
}
|
||||
|
||||
// MediaType applies equality check predicate on the "media_type" field. It's identical to MediaTypeEQ.
|
||||
func MediaType(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v))
|
||||
}
|
||||
|
||||
// CacheTTLOverridden applies equality check predicate on the "cache_ttl_overridden" field. It's identical to CacheTTLOverriddenEQ.
|
||||
func CacheTTLOverridden(v bool) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
|
||||
@@ -405,6 +430,431 @@ func ModelContainsFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContainsFold(FieldModel, v))
|
||||
}
|
||||
|
||||
// RequestedModelEQ applies the EQ predicate on the "requested_model" field.
|
||||
func RequestedModelEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldRequestedModel, v))
|
||||
}
|
||||
|
||||
// RequestedModelNEQ applies the NEQ predicate on the "requested_model" field.
|
||||
func RequestedModelNEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNEQ(FieldRequestedModel, v))
|
||||
}
|
||||
|
||||
// RequestedModelIn applies the In predicate on the "requested_model" field.
|
||||
func RequestedModelIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIn(FieldRequestedModel, vs...))
|
||||
}
|
||||
|
||||
// RequestedModelNotIn applies the NotIn predicate on the "requested_model" field.
|
||||
func RequestedModelNotIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotIn(FieldRequestedModel, vs...))
|
||||
}
|
||||
|
||||
// RequestedModelGT applies the GT predicate on the "requested_model" field.
|
||||
func RequestedModelGT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGT(FieldRequestedModel, v))
|
||||
}
|
||||
|
||||
// RequestedModelGTE applies the GTE predicate on the "requested_model" field.
|
||||
func RequestedModelGTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGTE(FieldRequestedModel, v))
|
||||
}
|
||||
|
||||
// RequestedModelLT applies the LT predicate on the "requested_model" field.
|
||||
func RequestedModelLT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLT(FieldRequestedModel, v))
|
||||
}
|
||||
|
||||
// RequestedModelLTE applies the LTE predicate on the "requested_model" field.
|
||||
func RequestedModelLTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLTE(FieldRequestedModel, v))
|
||||
}
|
||||
|
||||
// RequestedModelContains applies the Contains predicate on the "requested_model" field.
|
||||
func RequestedModelContains(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContains(FieldRequestedModel, v))
|
||||
}
|
||||
|
||||
// RequestedModelHasPrefix applies the HasPrefix predicate on the "requested_model" field.
|
||||
func RequestedModelHasPrefix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasPrefix(FieldRequestedModel, v))
|
||||
}
|
||||
|
||||
// RequestedModelHasSuffix applies the HasSuffix predicate on the "requested_model" field.
|
||||
func RequestedModelHasSuffix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasSuffix(FieldRequestedModel, v))
|
||||
}
|
||||
|
||||
// RequestedModelIsNil applies the IsNil predicate on the "requested_model" field.
|
||||
func RequestedModelIsNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIsNull(FieldRequestedModel))
|
||||
}
|
||||
|
||||
// RequestedModelNotNil applies the NotNil predicate on the "requested_model" field.
|
||||
func RequestedModelNotNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotNull(FieldRequestedModel))
|
||||
}
|
||||
|
||||
// RequestedModelEqualFold applies the EqualFold predicate on the "requested_model" field.
|
||||
func RequestedModelEqualFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEqualFold(FieldRequestedModel, v))
|
||||
}
|
||||
|
||||
// RequestedModelContainsFold applies the ContainsFold predicate on the "requested_model" field.
|
||||
func RequestedModelContainsFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContainsFold(FieldRequestedModel, v))
|
||||
}
|
||||
|
||||
// UpstreamModelEQ applies the EQ predicate on the "upstream_model" field.
|
||||
func UpstreamModelEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v))
|
||||
}
|
||||
|
||||
// UpstreamModelNEQ applies the NEQ predicate on the "upstream_model" field.
|
||||
func UpstreamModelNEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNEQ(FieldUpstreamModel, v))
|
||||
}
|
||||
|
||||
// UpstreamModelIn applies the In predicate on the "upstream_model" field.
|
||||
func UpstreamModelIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIn(FieldUpstreamModel, vs...))
|
||||
}
|
||||
|
||||
// UpstreamModelNotIn applies the NotIn predicate on the "upstream_model" field.
|
||||
func UpstreamModelNotIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotIn(FieldUpstreamModel, vs...))
|
||||
}
|
||||
|
||||
// UpstreamModelGT applies the GT predicate on the "upstream_model" field.
|
||||
func UpstreamModelGT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGT(FieldUpstreamModel, v))
|
||||
}
|
||||
|
||||
// UpstreamModelGTE applies the GTE predicate on the "upstream_model" field.
|
||||
func UpstreamModelGTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGTE(FieldUpstreamModel, v))
|
||||
}
|
||||
|
||||
// UpstreamModelLT applies the LT predicate on the "upstream_model" field.
|
||||
func UpstreamModelLT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLT(FieldUpstreamModel, v))
|
||||
}
|
||||
|
||||
// UpstreamModelLTE applies the LTE predicate on the "upstream_model" field.
|
||||
func UpstreamModelLTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLTE(FieldUpstreamModel, v))
|
||||
}
|
||||
|
||||
// UpstreamModelContains applies the Contains predicate on the "upstream_model" field.
|
||||
func UpstreamModelContains(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContains(FieldUpstreamModel, v))
|
||||
}
|
||||
|
||||
// UpstreamModelHasPrefix applies the HasPrefix predicate on the "upstream_model" field.
|
||||
func UpstreamModelHasPrefix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasPrefix(FieldUpstreamModel, v))
|
||||
}
|
||||
|
||||
// UpstreamModelHasSuffix applies the HasSuffix predicate on the "upstream_model" field.
|
||||
func UpstreamModelHasSuffix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasSuffix(FieldUpstreamModel, v))
|
||||
}
|
||||
|
||||
// UpstreamModelIsNil applies the IsNil predicate on the "upstream_model" field.
|
||||
func UpstreamModelIsNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIsNull(FieldUpstreamModel))
|
||||
}
|
||||
|
||||
// UpstreamModelNotNil applies the NotNil predicate on the "upstream_model" field.
|
||||
func UpstreamModelNotNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotNull(FieldUpstreamModel))
|
||||
}
|
||||
|
||||
// UpstreamModelEqualFold applies the EqualFold predicate on the "upstream_model" field.
|
||||
func UpstreamModelEqualFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEqualFold(FieldUpstreamModel, v))
|
||||
}
|
||||
|
||||
// UpstreamModelContainsFold applies the ContainsFold predicate on the "upstream_model" field.
|
||||
func UpstreamModelContainsFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContainsFold(FieldUpstreamModel, v))
|
||||
}
|
||||
|
||||
// ChannelIDEQ applies the EQ predicate on the "channel_id" field.
|
||||
func ChannelIDEQ(v int64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldChannelID, v))
|
||||
}
|
||||
|
||||
// ChannelIDNEQ applies the NEQ predicate on the "channel_id" field.
|
||||
func ChannelIDNEQ(v int64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNEQ(FieldChannelID, v))
|
||||
}
|
||||
|
||||
// ChannelIDIn applies the In predicate on the "channel_id" field.
|
||||
func ChannelIDIn(vs ...int64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIn(FieldChannelID, vs...))
|
||||
}
|
||||
|
||||
// ChannelIDNotIn applies the NotIn predicate on the "channel_id" field.
|
||||
func ChannelIDNotIn(vs ...int64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotIn(FieldChannelID, vs...))
|
||||
}
|
||||
|
||||
// ChannelIDGT applies the GT predicate on the "channel_id" field.
|
||||
func ChannelIDGT(v int64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGT(FieldChannelID, v))
|
||||
}
|
||||
|
||||
// ChannelIDGTE applies the GTE predicate on the "channel_id" field.
|
||||
func ChannelIDGTE(v int64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGTE(FieldChannelID, v))
|
||||
}
|
||||
|
||||
// ChannelIDLT applies the LT predicate on the "channel_id" field.
|
||||
func ChannelIDLT(v int64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLT(FieldChannelID, v))
|
||||
}
|
||||
|
||||
// ChannelIDLTE applies the LTE predicate on the "channel_id" field.
|
||||
func ChannelIDLTE(v int64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLTE(FieldChannelID, v))
|
||||
}
|
||||
|
||||
// ChannelIDIsNil applies the IsNil predicate on the "channel_id" field.
|
||||
func ChannelIDIsNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIsNull(FieldChannelID))
|
||||
}
|
||||
|
||||
// ChannelIDNotNil applies the NotNil predicate on the "channel_id" field.
|
||||
func ChannelIDNotNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotNull(FieldChannelID))
|
||||
}
|
||||
|
||||
// ModelMappingChainEQ applies the EQ predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldModelMappingChain, v))
|
||||
}
|
||||
|
||||
// ModelMappingChainNEQ applies the NEQ predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainNEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNEQ(FieldModelMappingChain, v))
|
||||
}
|
||||
|
||||
// ModelMappingChainIn applies the In predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIn(FieldModelMappingChain, vs...))
|
||||
}
|
||||
|
||||
// ModelMappingChainNotIn applies the NotIn predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainNotIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotIn(FieldModelMappingChain, vs...))
|
||||
}
|
||||
|
||||
// ModelMappingChainGT applies the GT predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainGT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGT(FieldModelMappingChain, v))
|
||||
}
|
||||
|
||||
// ModelMappingChainGTE applies the GTE predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainGTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGTE(FieldModelMappingChain, v))
|
||||
}
|
||||
|
||||
// ModelMappingChainLT applies the LT predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainLT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLT(FieldModelMappingChain, v))
|
||||
}
|
||||
|
||||
// ModelMappingChainLTE applies the LTE predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainLTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLTE(FieldModelMappingChain, v))
|
||||
}
|
||||
|
||||
// ModelMappingChainContains applies the Contains predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainContains(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContains(FieldModelMappingChain, v))
|
||||
}
|
||||
|
||||
// ModelMappingChainHasPrefix applies the HasPrefix predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainHasPrefix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasPrefix(FieldModelMappingChain, v))
|
||||
}
|
||||
|
||||
// ModelMappingChainHasSuffix applies the HasSuffix predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainHasSuffix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasSuffix(FieldModelMappingChain, v))
|
||||
}
|
||||
|
||||
// ModelMappingChainIsNil applies the IsNil predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainIsNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIsNull(FieldModelMappingChain))
|
||||
}
|
||||
|
||||
// ModelMappingChainNotNil applies the NotNil predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainNotNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotNull(FieldModelMappingChain))
|
||||
}
|
||||
|
||||
// ModelMappingChainEqualFold applies the EqualFold predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainEqualFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEqualFold(FieldModelMappingChain, v))
|
||||
}
|
||||
|
||||
// ModelMappingChainContainsFold applies the ContainsFold predicate on the "model_mapping_chain" field.
|
||||
func ModelMappingChainContainsFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContainsFold(FieldModelMappingChain, v))
|
||||
}
|
||||
|
||||
// BillingTierEQ applies the EQ predicate on the "billing_tier" field.
|
||||
func BillingTierEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldBillingTier, v))
|
||||
}
|
||||
|
||||
// BillingTierNEQ applies the NEQ predicate on the "billing_tier" field.
|
||||
func BillingTierNEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNEQ(FieldBillingTier, v))
|
||||
}
|
||||
|
||||
// BillingTierIn applies the In predicate on the "billing_tier" field.
|
||||
func BillingTierIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIn(FieldBillingTier, vs...))
|
||||
}
|
||||
|
||||
// BillingTierNotIn applies the NotIn predicate on the "billing_tier" field.
|
||||
func BillingTierNotIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotIn(FieldBillingTier, vs...))
|
||||
}
|
||||
|
||||
// BillingTierGT applies the GT predicate on the "billing_tier" field.
|
||||
func BillingTierGT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGT(FieldBillingTier, v))
|
||||
}
|
||||
|
||||
// BillingTierGTE applies the GTE predicate on the "billing_tier" field.
|
||||
func BillingTierGTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGTE(FieldBillingTier, v))
|
||||
}
|
||||
|
||||
// BillingTierLT applies the LT predicate on the "billing_tier" field.
|
||||
func BillingTierLT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLT(FieldBillingTier, v))
|
||||
}
|
||||
|
||||
// BillingTierLTE applies the LTE predicate on the "billing_tier" field.
|
||||
func BillingTierLTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLTE(FieldBillingTier, v))
|
||||
}
|
||||
|
||||
// BillingTierContains applies the Contains predicate on the "billing_tier" field.
|
||||
func BillingTierContains(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContains(FieldBillingTier, v))
|
||||
}
|
||||
|
||||
// BillingTierHasPrefix applies the HasPrefix predicate on the "billing_tier" field.
|
||||
func BillingTierHasPrefix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasPrefix(FieldBillingTier, v))
|
||||
}
|
||||
|
||||
// BillingTierHasSuffix applies the HasSuffix predicate on the "billing_tier" field.
|
||||
func BillingTierHasSuffix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasSuffix(FieldBillingTier, v))
|
||||
}
|
||||
|
||||
// BillingTierIsNil applies the IsNil predicate on the "billing_tier" field.
|
||||
func BillingTierIsNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIsNull(FieldBillingTier))
|
||||
}
|
||||
|
||||
// BillingTierNotNil applies the NotNil predicate on the "billing_tier" field.
|
||||
func BillingTierNotNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotNull(FieldBillingTier))
|
||||
}
|
||||
|
||||
// BillingTierEqualFold applies the EqualFold predicate on the "billing_tier" field.
|
||||
func BillingTierEqualFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEqualFold(FieldBillingTier, v))
|
||||
}
|
||||
|
||||
// BillingTierContainsFold applies the ContainsFold predicate on the "billing_tier" field.
|
||||
func BillingTierContainsFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContainsFold(FieldBillingTier, v))
|
||||
}
|
||||
|
||||
// BillingModeEQ applies the EQ predicate on the "billing_mode" field.
|
||||
func BillingModeEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldBillingMode, v))
|
||||
}
|
||||
|
||||
// BillingModeNEQ applies the NEQ predicate on the "billing_mode" field.
|
||||
func BillingModeNEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNEQ(FieldBillingMode, v))
|
||||
}
|
||||
|
||||
// BillingModeIn applies the In predicate on the "billing_mode" field.
|
||||
func BillingModeIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIn(FieldBillingMode, vs...))
|
||||
}
|
||||
|
||||
// BillingModeNotIn applies the NotIn predicate on the "billing_mode" field.
|
||||
func BillingModeNotIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotIn(FieldBillingMode, vs...))
|
||||
}
|
||||
|
||||
// BillingModeGT applies the GT predicate on the "billing_mode" field.
|
||||
func BillingModeGT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGT(FieldBillingMode, v))
|
||||
}
|
||||
|
||||
// BillingModeGTE applies the GTE predicate on the "billing_mode" field.
|
||||
func BillingModeGTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGTE(FieldBillingMode, v))
|
||||
}
|
||||
|
||||
// BillingModeLT applies the LT predicate on the "billing_mode" field.
|
||||
func BillingModeLT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLT(FieldBillingMode, v))
|
||||
}
|
||||
|
||||
// BillingModeLTE applies the LTE predicate on the "billing_mode" field.
|
||||
func BillingModeLTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLTE(FieldBillingMode, v))
|
||||
}
|
||||
|
||||
// BillingModeContains applies the Contains predicate on the "billing_mode" field.
|
||||
func BillingModeContains(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContains(FieldBillingMode, v))
|
||||
}
|
||||
|
||||
// BillingModeHasPrefix applies the HasPrefix predicate on the "billing_mode" field.
|
||||
func BillingModeHasPrefix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasPrefix(FieldBillingMode, v))
|
||||
}
|
||||
|
||||
// BillingModeHasSuffix applies the HasSuffix predicate on the "billing_mode" field.
|
||||
func BillingModeHasSuffix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasSuffix(FieldBillingMode, v))
|
||||
}
|
||||
|
||||
// BillingModeIsNil applies the IsNil predicate on the "billing_mode" field.
|
||||
func BillingModeIsNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIsNull(FieldBillingMode))
|
||||
}
|
||||
|
||||
// BillingModeNotNil applies the NotNil predicate on the "billing_mode" field.
|
||||
func BillingModeNotNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotNull(FieldBillingMode))
|
||||
}
|
||||
|
||||
// BillingModeEqualFold applies the EqualFold predicate on the "billing_mode" field.
|
||||
func BillingModeEqualFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEqualFold(FieldBillingMode, v))
|
||||
}
|
||||
|
||||
// BillingModeContainsFold applies the ContainsFold predicate on the "billing_mode" field.
|
||||
func BillingModeContainsFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContainsFold(FieldBillingMode, v))
|
||||
}
|
||||
|
||||
// GroupIDEQ applies the EQ predicate on the "group_id" field.
|
||||
func GroupIDEQ(v int64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))
|
||||
@@ -1450,81 +1900,6 @@ func ImageSizeContainsFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContainsFold(FieldImageSize, v))
|
||||
}
|
||||
|
||||
// MediaTypeEQ applies the EQ predicate on the "media_type" field.
|
||||
func MediaTypeEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v))
|
||||
}
|
||||
|
||||
// MediaTypeNEQ applies the NEQ predicate on the "media_type" field.
|
||||
func MediaTypeNEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNEQ(FieldMediaType, v))
|
||||
}
|
||||
|
||||
// MediaTypeIn applies the In predicate on the "media_type" field.
|
||||
func MediaTypeIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIn(FieldMediaType, vs...))
|
||||
}
|
||||
|
||||
// MediaTypeNotIn applies the NotIn predicate on the "media_type" field.
|
||||
func MediaTypeNotIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotIn(FieldMediaType, vs...))
|
||||
}
|
||||
|
||||
// MediaTypeGT applies the GT predicate on the "media_type" field.
|
||||
func MediaTypeGT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGT(FieldMediaType, v))
|
||||
}
|
||||
|
||||
// MediaTypeGTE applies the GTE predicate on the "media_type" field.
|
||||
func MediaTypeGTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGTE(FieldMediaType, v))
|
||||
}
|
||||
|
||||
// MediaTypeLT applies the LT predicate on the "media_type" field.
|
||||
func MediaTypeLT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLT(FieldMediaType, v))
|
||||
}
|
||||
|
||||
// MediaTypeLTE applies the LTE predicate on the "media_type" field.
|
||||
func MediaTypeLTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLTE(FieldMediaType, v))
|
||||
}
|
||||
|
||||
// MediaTypeContains applies the Contains predicate on the "media_type" field.
|
||||
func MediaTypeContains(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContains(FieldMediaType, v))
|
||||
}
|
||||
|
||||
// MediaTypeHasPrefix applies the HasPrefix predicate on the "media_type" field.
|
||||
func MediaTypeHasPrefix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasPrefix(FieldMediaType, v))
|
||||
}
|
||||
|
||||
// MediaTypeHasSuffix applies the HasSuffix predicate on the "media_type" field.
|
||||
func MediaTypeHasSuffix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasSuffix(FieldMediaType, v))
|
||||
}
|
||||
|
||||
// MediaTypeIsNil applies the IsNil predicate on the "media_type" field.
|
||||
func MediaTypeIsNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIsNull(FieldMediaType))
|
||||
}
|
||||
|
||||
// MediaTypeNotNil applies the NotNil predicate on the "media_type" field.
|
||||
func MediaTypeNotNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotNull(FieldMediaType))
|
||||
}
|
||||
|
||||
// MediaTypeEqualFold applies the EqualFold predicate on the "media_type" field.
|
||||
func MediaTypeEqualFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEqualFold(FieldMediaType, v))
|
||||
}
|
||||
|
||||
// MediaTypeContainsFold applies the ContainsFold predicate on the "media_type" field.
|
||||
func MediaTypeContainsFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContainsFold(FieldMediaType, v))
|
||||
}
|
||||
|
||||
// CacheTTLOverriddenEQ applies the EQ predicate on the "cache_ttl_overridden" field.
|
||||
func CacheTTLOverriddenEQ(v bool) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
|
||||
|
||||
@@ -57,6 +57,90 @@ func (_c *UsageLogCreate) SetModel(v string) *UsageLogCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetRequestedModel sets the "requested_model" field.
|
||||
func (_c *UsageLogCreate) SetRequestedModel(v string) *UsageLogCreate {
|
||||
_c.mutation.SetRequestedModel(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableRequestedModel sets the "requested_model" field if the given value is not nil.
|
||||
func (_c *UsageLogCreate) SetNillableRequestedModel(v *string) *UsageLogCreate {
|
||||
if v != nil {
|
||||
_c.SetRequestedModel(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetUpstreamModel sets the "upstream_model" field.
|
||||
func (_c *UsageLogCreate) SetUpstreamModel(v string) *UsageLogCreate {
|
||||
_c.mutation.SetUpstreamModel(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil.
|
||||
func (_c *UsageLogCreate) SetNillableUpstreamModel(v *string) *UsageLogCreate {
|
||||
if v != nil {
|
||||
_c.SetUpstreamModel(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetChannelID sets the "channel_id" field.
|
||||
func (_c *UsageLogCreate) SetChannelID(v int64) *UsageLogCreate {
|
||||
_c.mutation.SetChannelID(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableChannelID sets the "channel_id" field if the given value is not nil.
|
||||
func (_c *UsageLogCreate) SetNillableChannelID(v *int64) *UsageLogCreate {
|
||||
if v != nil {
|
||||
_c.SetChannelID(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetModelMappingChain sets the "model_mapping_chain" field.
|
||||
func (_c *UsageLogCreate) SetModelMappingChain(v string) *UsageLogCreate {
|
||||
_c.mutation.SetModelMappingChain(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableModelMappingChain sets the "model_mapping_chain" field if the given value is not nil.
|
||||
func (_c *UsageLogCreate) SetNillableModelMappingChain(v *string) *UsageLogCreate {
|
||||
if v != nil {
|
||||
_c.SetModelMappingChain(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetBillingTier sets the "billing_tier" field.
|
||||
func (_c *UsageLogCreate) SetBillingTier(v string) *UsageLogCreate {
|
||||
_c.mutation.SetBillingTier(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableBillingTier sets the "billing_tier" field if the given value is not nil.
|
||||
func (_c *UsageLogCreate) SetNillableBillingTier(v *string) *UsageLogCreate {
|
||||
if v != nil {
|
||||
_c.SetBillingTier(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetBillingMode sets the "billing_mode" field.
|
||||
func (_c *UsageLogCreate) SetBillingMode(v string) *UsageLogCreate {
|
||||
_c.mutation.SetBillingMode(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableBillingMode sets the "billing_mode" field if the given value is not nil.
|
||||
func (_c *UsageLogCreate) SetNillableBillingMode(v *string) *UsageLogCreate {
|
||||
if v != nil {
|
||||
_c.SetBillingMode(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (_c *UsageLogCreate) SetGroupID(v int64) *UsageLogCreate {
|
||||
_c.mutation.SetGroupID(v)
|
||||
@@ -393,20 +477,6 @@ func (_c *UsageLogCreate) SetNillableImageSize(v *string) *UsageLogCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetMediaType sets the "media_type" field.
|
||||
func (_c *UsageLogCreate) SetMediaType(v string) *UsageLogCreate {
|
||||
_c.mutation.SetMediaType(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableMediaType sets the "media_type" field if the given value is not nil.
|
||||
func (_c *UsageLogCreate) SetNillableMediaType(v *string) *UsageLogCreate {
|
||||
if v != nil {
|
||||
_c.SetMediaType(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||
func (_c *UsageLogCreate) SetCacheTTLOverridden(v bool) *UsageLogCreate {
|
||||
_c.mutation.SetCacheTTLOverridden(v)
|
||||
@@ -596,6 +666,31 @@ func (_c *UsageLogCreate) check() error {
|
||||
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _c.mutation.RequestedModel(); ok {
|
||||
if err := usagelog.RequestedModelValidator(v); err != nil {
|
||||
return &ValidationError{Name: "requested_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.requested_model": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _c.mutation.UpstreamModel(); ok {
|
||||
if err := usagelog.UpstreamModelValidator(v); err != nil {
|
||||
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _c.mutation.ModelMappingChain(); ok {
|
||||
if err := usagelog.ModelMappingChainValidator(v); err != nil {
|
||||
return &ValidationError{Name: "model_mapping_chain", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model_mapping_chain": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _c.mutation.BillingTier(); ok {
|
||||
if err := usagelog.BillingTierValidator(v); err != nil {
|
||||
return &ValidationError{Name: "billing_tier", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_tier": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _c.mutation.BillingMode(); ok {
|
||||
if err := usagelog.BillingModeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "billing_mode", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_mode": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _, ok := _c.mutation.InputTokens(); !ok {
|
||||
return &ValidationError{Name: "input_tokens", err: errors.New(`ent: missing required field "UsageLog.input_tokens"`)}
|
||||
}
|
||||
@@ -659,11 +754,6 @@ func (_c *UsageLogCreate) check() error {
|
||||
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _c.mutation.MediaType(); ok {
|
||||
if err := usagelog.MediaTypeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _, ok := _c.mutation.CacheTTLOverridden(); !ok {
|
||||
return &ValidationError{Name: "cache_ttl_overridden", err: errors.New(`ent: missing required field "UsageLog.cache_ttl_overridden"`)}
|
||||
}
|
||||
@@ -714,6 +804,30 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
|
||||
_node.Model = value
|
||||
}
|
||||
if value, ok := _c.mutation.RequestedModel(); ok {
|
||||
_spec.SetField(usagelog.FieldRequestedModel, field.TypeString, value)
|
||||
_node.RequestedModel = &value
|
||||
}
|
||||
if value, ok := _c.mutation.UpstreamModel(); ok {
|
||||
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
|
||||
_node.UpstreamModel = &value
|
||||
}
|
||||
if value, ok := _c.mutation.ChannelID(); ok {
|
||||
_spec.SetField(usagelog.FieldChannelID, field.TypeInt64, value)
|
||||
_node.ChannelID = &value
|
||||
}
|
||||
if value, ok := _c.mutation.ModelMappingChain(); ok {
|
||||
_spec.SetField(usagelog.FieldModelMappingChain, field.TypeString, value)
|
||||
_node.ModelMappingChain = &value
|
||||
}
|
||||
if value, ok := _c.mutation.BillingTier(); ok {
|
||||
_spec.SetField(usagelog.FieldBillingTier, field.TypeString, value)
|
||||
_node.BillingTier = &value
|
||||
}
|
||||
if value, ok := _c.mutation.BillingMode(); ok {
|
||||
_spec.SetField(usagelog.FieldBillingMode, field.TypeString, value)
|
||||
_node.BillingMode = &value
|
||||
}
|
||||
if value, ok := _c.mutation.InputTokens(); ok {
|
||||
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
|
||||
_node.InputTokens = value
|
||||
@@ -802,10 +916,6 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(usagelog.FieldImageSize, field.TypeString, value)
|
||||
_node.ImageSize = &value
|
||||
}
|
||||
if value, ok := _c.mutation.MediaType(); ok {
|
||||
_spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
|
||||
_node.MediaType = &value
|
||||
}
|
||||
if value, ok := _c.mutation.CacheTTLOverridden(); ok {
|
||||
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
|
||||
_node.CacheTTLOverridden = value
|
||||
@@ -1011,6 +1121,120 @@ func (u *UsageLogUpsert) UpdateModel() *UsageLogUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetRequestedModel sets the "requested_model" field.
|
||||
func (u *UsageLogUpsert) SetRequestedModel(v string) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldRequestedModel, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateRequestedModel sets the "requested_model" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsert) UpdateRequestedModel() *UsageLogUpsert {
|
||||
u.SetExcluded(usagelog.FieldRequestedModel)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearRequestedModel clears the value of the "requested_model" field.
|
||||
func (u *UsageLogUpsert) ClearRequestedModel() *UsageLogUpsert {
|
||||
u.SetNull(usagelog.FieldRequestedModel)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetUpstreamModel sets the "upstream_model" field.
|
||||
func (u *UsageLogUpsert) SetUpstreamModel(v string) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldUpstreamModel, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsert) UpdateUpstreamModel() *UsageLogUpsert {
|
||||
u.SetExcluded(usagelog.FieldUpstreamModel)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearUpstreamModel clears the value of the "upstream_model" field.
|
||||
func (u *UsageLogUpsert) ClearUpstreamModel() *UsageLogUpsert {
|
||||
u.SetNull(usagelog.FieldUpstreamModel)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetChannelID sets the "channel_id" field.
|
||||
func (u *UsageLogUpsert) SetChannelID(v int64) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldChannelID, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateChannelID sets the "channel_id" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsert) UpdateChannelID() *UsageLogUpsert {
|
||||
u.SetExcluded(usagelog.FieldChannelID)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddChannelID adds v to the "channel_id" field.
|
||||
func (u *UsageLogUpsert) AddChannelID(v int64) *UsageLogUpsert {
|
||||
u.Add(usagelog.FieldChannelID, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearChannelID clears the value of the "channel_id" field.
|
||||
func (u *UsageLogUpsert) ClearChannelID() *UsageLogUpsert {
|
||||
u.SetNull(usagelog.FieldChannelID)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetModelMappingChain sets the "model_mapping_chain" field.
|
||||
func (u *UsageLogUpsert) SetModelMappingChain(v string) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldModelMappingChain, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateModelMappingChain sets the "model_mapping_chain" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsert) UpdateModelMappingChain() *UsageLogUpsert {
|
||||
u.SetExcluded(usagelog.FieldModelMappingChain)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearModelMappingChain clears the value of the "model_mapping_chain" field.
|
||||
func (u *UsageLogUpsert) ClearModelMappingChain() *UsageLogUpsert {
|
||||
u.SetNull(usagelog.FieldModelMappingChain)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetBillingTier sets the "billing_tier" field.
|
||||
func (u *UsageLogUpsert) SetBillingTier(v string) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldBillingTier, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateBillingTier sets the "billing_tier" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsert) UpdateBillingTier() *UsageLogUpsert {
|
||||
u.SetExcluded(usagelog.FieldBillingTier)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearBillingTier clears the value of the "billing_tier" field.
|
||||
func (u *UsageLogUpsert) ClearBillingTier() *UsageLogUpsert {
|
||||
u.SetNull(usagelog.FieldBillingTier)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetBillingMode sets the "billing_mode" field.
|
||||
func (u *UsageLogUpsert) SetBillingMode(v string) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldBillingMode, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateBillingMode sets the "billing_mode" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsert) UpdateBillingMode() *UsageLogUpsert {
|
||||
u.SetExcluded(usagelog.FieldBillingMode)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearBillingMode clears the value of the "billing_mode" field.
|
||||
func (u *UsageLogUpsert) ClearBillingMode() *UsageLogUpsert {
|
||||
u.SetNull(usagelog.FieldBillingMode)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (u *UsageLogUpsert) SetGroupID(v int64) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldGroupID, v)
|
||||
@@ -1455,24 +1679,6 @@ func (u *UsageLogUpsert) ClearImageSize() *UsageLogUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetMediaType sets the "media_type" field.
|
||||
func (u *UsageLogUpsert) SetMediaType(v string) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldMediaType, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateMediaType sets the "media_type" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsert) UpdateMediaType() *UsageLogUpsert {
|
||||
u.SetExcluded(usagelog.FieldMediaType)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearMediaType clears the value of the "media_type" field.
|
||||
func (u *UsageLogUpsert) ClearMediaType() *UsageLogUpsert {
|
||||
u.SetNull(usagelog.FieldMediaType)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||
func (u *UsageLogUpsert) SetCacheTTLOverridden(v bool) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldCacheTTLOverridden, v)
|
||||
@@ -1600,6 +1806,139 @@ func (u *UsageLogUpsertOne) UpdateModel() *UsageLogUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetRequestedModel sets the "requested_model" field.
|
||||
func (u *UsageLogUpsertOne) SetRequestedModel(v string) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetRequestedModel(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRequestedModel sets the "requested_model" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertOne) UpdateRequestedModel() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateRequestedModel()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearRequestedModel clears the value of the "requested_model" field.
|
||||
func (u *UsageLogUpsertOne) ClearRequestedModel() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearRequestedModel()
|
||||
})
|
||||
}
|
||||
|
||||
// SetUpstreamModel sets the "upstream_model" field.
|
||||
func (u *UsageLogUpsertOne) SetUpstreamModel(v string) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetUpstreamModel(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertOne) UpdateUpstreamModel() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateUpstreamModel()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearUpstreamModel clears the value of the "upstream_model" field.
|
||||
func (u *UsageLogUpsertOne) ClearUpstreamModel() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearUpstreamModel()
|
||||
})
|
||||
}
|
||||
|
||||
// SetChannelID sets the "channel_id" field.
|
||||
func (u *UsageLogUpsertOne) SetChannelID(v int64) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetChannelID(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddChannelID adds v to the "channel_id" field.
|
||||
func (u *UsageLogUpsertOne) AddChannelID(v int64) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.AddChannelID(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateChannelID sets the "channel_id" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertOne) UpdateChannelID() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateChannelID()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearChannelID clears the value of the "channel_id" field.
|
||||
func (u *UsageLogUpsertOne) ClearChannelID() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearChannelID()
|
||||
})
|
||||
}
|
||||
|
||||
// SetModelMappingChain sets the "model_mapping_chain" field.
|
||||
func (u *UsageLogUpsertOne) SetModelMappingChain(v string) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetModelMappingChain(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateModelMappingChain sets the "model_mapping_chain" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertOne) UpdateModelMappingChain() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateModelMappingChain()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearModelMappingChain clears the value of the "model_mapping_chain" field.
|
||||
func (u *UsageLogUpsertOne) ClearModelMappingChain() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearModelMappingChain()
|
||||
})
|
||||
}
|
||||
|
||||
// SetBillingTier sets the "billing_tier" field.
|
||||
func (u *UsageLogUpsertOne) SetBillingTier(v string) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetBillingTier(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateBillingTier sets the "billing_tier" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertOne) UpdateBillingTier() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateBillingTier()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearBillingTier clears the value of the "billing_tier" field.
|
||||
func (u *UsageLogUpsertOne) ClearBillingTier() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearBillingTier()
|
||||
})
|
||||
}
|
||||
|
||||
// SetBillingMode sets the "billing_mode" field.
|
||||
func (u *UsageLogUpsertOne) SetBillingMode(v string) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetBillingMode(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateBillingMode sets the "billing_mode" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertOne) UpdateBillingMode() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateBillingMode()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearBillingMode clears the value of the "billing_mode" field.
|
||||
func (u *UsageLogUpsertOne) ClearBillingMode() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearBillingMode()
|
||||
})
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (u *UsageLogUpsertOne) SetGroupID(v int64) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
@@ -2118,27 +2457,6 @@ func (u *UsageLogUpsertOne) ClearImageSize() *UsageLogUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetMediaType sets the "media_type" field.
|
||||
func (u *UsageLogUpsertOne) SetMediaType(v string) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetMediaType(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateMediaType sets the "media_type" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertOne) UpdateMediaType() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateMediaType()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearMediaType clears the value of the "media_type" field.
|
||||
func (u *UsageLogUpsertOne) ClearMediaType() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearMediaType()
|
||||
})
|
||||
}
|
||||
|
||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||
func (u *UsageLogUpsertOne) SetCacheTTLOverridden(v bool) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
@@ -2434,6 +2752,139 @@ func (u *UsageLogUpsertBulk) UpdateModel() *UsageLogUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetRequestedModel sets the "requested_model" field.
|
||||
func (u *UsageLogUpsertBulk) SetRequestedModel(v string) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetRequestedModel(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRequestedModel sets the "requested_model" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertBulk) UpdateRequestedModel() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateRequestedModel()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearRequestedModel clears the value of the "requested_model" field.
|
||||
func (u *UsageLogUpsertBulk) ClearRequestedModel() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearRequestedModel()
|
||||
})
|
||||
}
|
||||
|
||||
// SetUpstreamModel sets the "upstream_model" field.
|
||||
func (u *UsageLogUpsertBulk) SetUpstreamModel(v string) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetUpstreamModel(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertBulk) UpdateUpstreamModel() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateUpstreamModel()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearUpstreamModel clears the value of the "upstream_model" field.
|
||||
func (u *UsageLogUpsertBulk) ClearUpstreamModel() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearUpstreamModel()
|
||||
})
|
||||
}
|
||||
|
||||
// SetChannelID sets the "channel_id" field.
|
||||
func (u *UsageLogUpsertBulk) SetChannelID(v int64) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetChannelID(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddChannelID adds v to the "channel_id" field.
|
||||
func (u *UsageLogUpsertBulk) AddChannelID(v int64) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.AddChannelID(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateChannelID sets the "channel_id" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertBulk) UpdateChannelID() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateChannelID()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearChannelID clears the value of the "channel_id" field.
|
||||
func (u *UsageLogUpsertBulk) ClearChannelID() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearChannelID()
|
||||
})
|
||||
}
|
||||
|
||||
// SetModelMappingChain sets the "model_mapping_chain" field.
|
||||
func (u *UsageLogUpsertBulk) SetModelMappingChain(v string) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetModelMappingChain(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateModelMappingChain sets the "model_mapping_chain" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertBulk) UpdateModelMappingChain() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateModelMappingChain()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearModelMappingChain clears the value of the "model_mapping_chain" field.
|
||||
func (u *UsageLogUpsertBulk) ClearModelMappingChain() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearModelMappingChain()
|
||||
})
|
||||
}
|
||||
|
||||
// SetBillingTier sets the "billing_tier" field.
|
||||
func (u *UsageLogUpsertBulk) SetBillingTier(v string) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetBillingTier(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateBillingTier sets the "billing_tier" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertBulk) UpdateBillingTier() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateBillingTier()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearBillingTier clears the value of the "billing_tier" field.
|
||||
func (u *UsageLogUpsertBulk) ClearBillingTier() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearBillingTier()
|
||||
})
|
||||
}
|
||||
|
||||
// SetBillingMode sets the "billing_mode" field.
|
||||
func (u *UsageLogUpsertBulk) SetBillingMode(v string) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetBillingMode(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateBillingMode sets the "billing_mode" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertBulk) UpdateBillingMode() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateBillingMode()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearBillingMode clears the value of the "billing_mode" field.
|
||||
func (u *UsageLogUpsertBulk) ClearBillingMode() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearBillingMode()
|
||||
})
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (u *UsageLogUpsertBulk) SetGroupID(v int64) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
@@ -2952,27 +3403,6 @@ func (u *UsageLogUpsertBulk) ClearImageSize() *UsageLogUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetMediaType sets the "media_type" field.
|
||||
func (u *UsageLogUpsertBulk) SetMediaType(v string) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetMediaType(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateMediaType sets the "media_type" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertBulk) UpdateMediaType() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateMediaType()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearMediaType clears the value of the "media_type" field.
|
||||
func (u *UsageLogUpsertBulk) ClearMediaType() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearMediaType()
|
||||
})
|
||||
}
|
||||
|
||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||
func (u *UsageLogUpsertBulk) SetCacheTTLOverridden(v bool) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
|
||||
@@ -102,6 +102,133 @@ func (_u *UsageLogUpdate) SetNillableModel(v *string) *UsageLogUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetRequestedModel sets the "requested_model" field.
|
||||
func (_u *UsageLogUpdate) SetRequestedModel(v string) *UsageLogUpdate {
|
||||
_u.mutation.SetRequestedModel(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableRequestedModel sets the "requested_model" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdate) SetNillableRequestedModel(v *string) *UsageLogUpdate {
|
||||
if v != nil {
|
||||
_u.SetRequestedModel(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearRequestedModel clears the value of the "requested_model" field.
|
||||
func (_u *UsageLogUpdate) ClearRequestedModel() *UsageLogUpdate {
|
||||
_u.mutation.ClearRequestedModel()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUpstreamModel sets the "upstream_model" field.
|
||||
func (_u *UsageLogUpdate) SetUpstreamModel(v string) *UsageLogUpdate {
|
||||
_u.mutation.SetUpstreamModel(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdate) SetNillableUpstreamModel(v *string) *UsageLogUpdate {
|
||||
if v != nil {
|
||||
_u.SetUpstreamModel(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearUpstreamModel clears the value of the "upstream_model" field.
|
||||
func (_u *UsageLogUpdate) ClearUpstreamModel() *UsageLogUpdate {
|
||||
_u.mutation.ClearUpstreamModel()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetChannelID sets the "channel_id" field.
|
||||
func (_u *UsageLogUpdate) SetChannelID(v int64) *UsageLogUpdate {
|
||||
_u.mutation.ResetChannelID()
|
||||
_u.mutation.SetChannelID(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableChannelID sets the "channel_id" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdate) SetNillableChannelID(v *int64) *UsageLogUpdate {
|
||||
if v != nil {
|
||||
_u.SetChannelID(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddChannelID adds value to the "channel_id" field.
|
||||
func (_u *UsageLogUpdate) AddChannelID(v int64) *UsageLogUpdate {
|
||||
_u.mutation.AddChannelID(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearChannelID clears the value of the "channel_id" field.
|
||||
func (_u *UsageLogUpdate) ClearChannelID() *UsageLogUpdate {
|
||||
_u.mutation.ClearChannelID()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetModelMappingChain sets the "model_mapping_chain" field.
|
||||
func (_u *UsageLogUpdate) SetModelMappingChain(v string) *UsageLogUpdate {
|
||||
_u.mutation.SetModelMappingChain(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableModelMappingChain sets the "model_mapping_chain" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdate) SetNillableModelMappingChain(v *string) *UsageLogUpdate {
|
||||
if v != nil {
|
||||
_u.SetModelMappingChain(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearModelMappingChain clears the value of the "model_mapping_chain" field.
|
||||
func (_u *UsageLogUpdate) ClearModelMappingChain() *UsageLogUpdate {
|
||||
_u.mutation.ClearModelMappingChain()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetBillingTier sets the "billing_tier" field.
|
||||
func (_u *UsageLogUpdate) SetBillingTier(v string) *UsageLogUpdate {
|
||||
_u.mutation.SetBillingTier(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableBillingTier sets the "billing_tier" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdate) SetNillableBillingTier(v *string) *UsageLogUpdate {
|
||||
if v != nil {
|
||||
_u.SetBillingTier(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearBillingTier clears the value of the "billing_tier" field.
|
||||
func (_u *UsageLogUpdate) ClearBillingTier() *UsageLogUpdate {
|
||||
_u.mutation.ClearBillingTier()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetBillingMode sets the "billing_mode" field.
|
||||
func (_u *UsageLogUpdate) SetBillingMode(v string) *UsageLogUpdate {
|
||||
_u.mutation.SetBillingMode(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableBillingMode sets the "billing_mode" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdate) SetNillableBillingMode(v *string) *UsageLogUpdate {
|
||||
if v != nil {
|
||||
_u.SetBillingMode(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearBillingMode clears the value of the "billing_mode" field.
|
||||
func (_u *UsageLogUpdate) ClearBillingMode() *UsageLogUpdate {
|
||||
_u.mutation.ClearBillingMode()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (_u *UsageLogUpdate) SetGroupID(v int64) *UsageLogUpdate {
|
||||
_u.mutation.SetGroupID(v)
|
||||
@@ -612,26 +739,6 @@ func (_u *UsageLogUpdate) ClearImageSize() *UsageLogUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetMediaType sets the "media_type" field.
|
||||
func (_u *UsageLogUpdate) SetMediaType(v string) *UsageLogUpdate {
|
||||
_u.mutation.SetMediaType(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableMediaType sets the "media_type" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdate) SetNillableMediaType(v *string) *UsageLogUpdate {
|
||||
if v != nil {
|
||||
_u.SetMediaType(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearMediaType clears the value of the "media_type" field.
|
||||
func (_u *UsageLogUpdate) ClearMediaType() *UsageLogUpdate {
|
||||
_u.mutation.ClearMediaType()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||
func (_u *UsageLogUpdate) SetCacheTTLOverridden(v bool) *UsageLogUpdate {
|
||||
_u.mutation.SetCacheTTLOverridden(v)
|
||||
@@ -745,6 +852,31 @@ func (_u *UsageLogUpdate) check() error {
|
||||
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.RequestedModel(); ok {
|
||||
if err := usagelog.RequestedModelValidator(v); err != nil {
|
||||
return &ValidationError{Name: "requested_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.requested_model": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.UpstreamModel(); ok {
|
||||
if err := usagelog.UpstreamModelValidator(v); err != nil {
|
||||
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.ModelMappingChain(); ok {
|
||||
if err := usagelog.ModelMappingChainValidator(v); err != nil {
|
||||
return &ValidationError{Name: "model_mapping_chain", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model_mapping_chain": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.BillingTier(); ok {
|
||||
if err := usagelog.BillingTierValidator(v); err != nil {
|
||||
return &ValidationError{Name: "billing_tier", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_tier": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.BillingMode(); ok {
|
||||
if err := usagelog.BillingModeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "billing_mode", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_mode": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.UserAgent(); ok {
|
||||
if err := usagelog.UserAgentValidator(v); err != nil {
|
||||
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
||||
@@ -760,11 +892,6 @@ func (_u *UsageLogUpdate) check() error {
|
||||
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.MediaType(); ok {
|
||||
if err := usagelog.MediaTypeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
|
||||
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
|
||||
}
|
||||
@@ -795,6 +922,45 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if value, ok := _u.mutation.Model(); ok {
|
||||
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.RequestedModel(); ok {
|
||||
_spec.SetField(usagelog.FieldRequestedModel, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.RequestedModelCleared() {
|
||||
_spec.ClearField(usagelog.FieldRequestedModel, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.UpstreamModel(); ok {
|
||||
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.UpstreamModelCleared() {
|
||||
_spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.ChannelID(); ok {
|
||||
_spec.SetField(usagelog.FieldChannelID, field.TypeInt64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedChannelID(); ok {
|
||||
_spec.AddField(usagelog.FieldChannelID, field.TypeInt64, value)
|
||||
}
|
||||
if _u.mutation.ChannelIDCleared() {
|
||||
_spec.ClearField(usagelog.FieldChannelID, field.TypeInt64)
|
||||
}
|
||||
if value, ok := _u.mutation.ModelMappingChain(); ok {
|
||||
_spec.SetField(usagelog.FieldModelMappingChain, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.ModelMappingChainCleared() {
|
||||
_spec.ClearField(usagelog.FieldModelMappingChain, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.BillingTier(); ok {
|
||||
_spec.SetField(usagelog.FieldBillingTier, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.BillingTierCleared() {
|
||||
_spec.ClearField(usagelog.FieldBillingTier, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.BillingMode(); ok {
|
||||
_spec.SetField(usagelog.FieldBillingMode, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.BillingModeCleared() {
|
||||
_spec.ClearField(usagelog.FieldBillingMode, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.InputTokens(); ok {
|
||||
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
|
||||
}
|
||||
@@ -933,12 +1099,6 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if _u.mutation.ImageSizeCleared() {
|
||||
_spec.ClearField(usagelog.FieldImageSize, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.MediaType(); ok {
|
||||
_spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.MediaTypeCleared() {
|
||||
_spec.ClearField(usagelog.FieldMediaType, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
|
||||
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
|
||||
}
|
||||
@@ -1177,6 +1337,133 @@ func (_u *UsageLogUpdateOne) SetNillableModel(v *string) *UsageLogUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetRequestedModel sets the "requested_model" field.
|
||||
func (_u *UsageLogUpdateOne) SetRequestedModel(v string) *UsageLogUpdateOne {
|
||||
_u.mutation.SetRequestedModel(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableRequestedModel sets the "requested_model" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdateOne) SetNillableRequestedModel(v *string) *UsageLogUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetRequestedModel(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearRequestedModel clears the value of the "requested_model" field.
|
||||
func (_u *UsageLogUpdateOne) ClearRequestedModel() *UsageLogUpdateOne {
|
||||
_u.mutation.ClearRequestedModel()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUpstreamModel sets the "upstream_model" field.
|
||||
func (_u *UsageLogUpdateOne) SetUpstreamModel(v string) *UsageLogUpdateOne {
|
||||
_u.mutation.SetUpstreamModel(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdateOne) SetNillableUpstreamModel(v *string) *UsageLogUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetUpstreamModel(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearUpstreamModel clears the value of the "upstream_model" field.
|
||||
func (_u *UsageLogUpdateOne) ClearUpstreamModel() *UsageLogUpdateOne {
|
||||
_u.mutation.ClearUpstreamModel()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetChannelID sets the "channel_id" field.
|
||||
func (_u *UsageLogUpdateOne) SetChannelID(v int64) *UsageLogUpdateOne {
|
||||
_u.mutation.ResetChannelID()
|
||||
_u.mutation.SetChannelID(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableChannelID sets the "channel_id" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdateOne) SetNillableChannelID(v *int64) *UsageLogUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetChannelID(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddChannelID adds value to the "channel_id" field.
|
||||
func (_u *UsageLogUpdateOne) AddChannelID(v int64) *UsageLogUpdateOne {
|
||||
_u.mutation.AddChannelID(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearChannelID clears the value of the "channel_id" field.
|
||||
func (_u *UsageLogUpdateOne) ClearChannelID() *UsageLogUpdateOne {
|
||||
_u.mutation.ClearChannelID()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetModelMappingChain sets the "model_mapping_chain" field.
|
||||
func (_u *UsageLogUpdateOne) SetModelMappingChain(v string) *UsageLogUpdateOne {
|
||||
_u.mutation.SetModelMappingChain(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableModelMappingChain sets the "model_mapping_chain" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdateOne) SetNillableModelMappingChain(v *string) *UsageLogUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetModelMappingChain(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearModelMappingChain clears the value of the "model_mapping_chain" field.
|
||||
func (_u *UsageLogUpdateOne) ClearModelMappingChain() *UsageLogUpdateOne {
|
||||
_u.mutation.ClearModelMappingChain()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetBillingTier sets the "billing_tier" field.
|
||||
func (_u *UsageLogUpdateOne) SetBillingTier(v string) *UsageLogUpdateOne {
|
||||
_u.mutation.SetBillingTier(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableBillingTier sets the "billing_tier" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdateOne) SetNillableBillingTier(v *string) *UsageLogUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetBillingTier(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearBillingTier clears the value of the "billing_tier" field.
|
||||
func (_u *UsageLogUpdateOne) ClearBillingTier() *UsageLogUpdateOne {
|
||||
_u.mutation.ClearBillingTier()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetBillingMode sets the "billing_mode" field.
|
||||
func (_u *UsageLogUpdateOne) SetBillingMode(v string) *UsageLogUpdateOne {
|
||||
_u.mutation.SetBillingMode(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableBillingMode sets the "billing_mode" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdateOne) SetNillableBillingMode(v *string) *UsageLogUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetBillingMode(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearBillingMode clears the value of the "billing_mode" field.
|
||||
func (_u *UsageLogUpdateOne) ClearBillingMode() *UsageLogUpdateOne {
|
||||
_u.mutation.ClearBillingMode()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (_u *UsageLogUpdateOne) SetGroupID(v int64) *UsageLogUpdateOne {
|
||||
_u.mutation.SetGroupID(v)
|
||||
@@ -1687,26 +1974,6 @@ func (_u *UsageLogUpdateOne) ClearImageSize() *UsageLogUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetMediaType sets the "media_type" field.
|
||||
func (_u *UsageLogUpdateOne) SetMediaType(v string) *UsageLogUpdateOne {
|
||||
_u.mutation.SetMediaType(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableMediaType sets the "media_type" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdateOne) SetNillableMediaType(v *string) *UsageLogUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetMediaType(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearMediaType clears the value of the "media_type" field.
|
||||
func (_u *UsageLogUpdateOne) ClearMediaType() *UsageLogUpdateOne {
|
||||
_u.mutation.ClearMediaType()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||
func (_u *UsageLogUpdateOne) SetCacheTTLOverridden(v bool) *UsageLogUpdateOne {
|
||||
_u.mutation.SetCacheTTLOverridden(v)
|
||||
@@ -1833,6 +2100,31 @@ func (_u *UsageLogUpdateOne) check() error {
|
||||
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.RequestedModel(); ok {
|
||||
if err := usagelog.RequestedModelValidator(v); err != nil {
|
||||
return &ValidationError{Name: "requested_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.requested_model": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.UpstreamModel(); ok {
|
||||
if err := usagelog.UpstreamModelValidator(v); err != nil {
|
||||
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.ModelMappingChain(); ok {
|
||||
if err := usagelog.ModelMappingChainValidator(v); err != nil {
|
||||
return &ValidationError{Name: "model_mapping_chain", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model_mapping_chain": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.BillingTier(); ok {
|
||||
if err := usagelog.BillingTierValidator(v); err != nil {
|
||||
return &ValidationError{Name: "billing_tier", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_tier": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.BillingMode(); ok {
|
||||
if err := usagelog.BillingModeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "billing_mode", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_mode": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.UserAgent(); ok {
|
||||
if err := usagelog.UserAgentValidator(v); err != nil {
|
||||
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
||||
@@ -1848,11 +2140,6 @@ func (_u *UsageLogUpdateOne) check() error {
|
||||
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.MediaType(); ok {
|
||||
if err := usagelog.MediaTypeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
|
||||
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
|
||||
}
|
||||
@@ -1900,6 +2187,45 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
|
||||
if value, ok := _u.mutation.Model(); ok {
|
||||
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.RequestedModel(); ok {
|
||||
_spec.SetField(usagelog.FieldRequestedModel, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.RequestedModelCleared() {
|
||||
_spec.ClearField(usagelog.FieldRequestedModel, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.UpstreamModel(); ok {
|
||||
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.UpstreamModelCleared() {
|
||||
_spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.ChannelID(); ok {
|
||||
_spec.SetField(usagelog.FieldChannelID, field.TypeInt64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedChannelID(); ok {
|
||||
_spec.AddField(usagelog.FieldChannelID, field.TypeInt64, value)
|
||||
}
|
||||
if _u.mutation.ChannelIDCleared() {
|
||||
_spec.ClearField(usagelog.FieldChannelID, field.TypeInt64)
|
||||
}
|
||||
if value, ok := _u.mutation.ModelMappingChain(); ok {
|
||||
_spec.SetField(usagelog.FieldModelMappingChain, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.ModelMappingChainCleared() {
|
||||
_spec.ClearField(usagelog.FieldModelMappingChain, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.BillingTier(); ok {
|
||||
_spec.SetField(usagelog.FieldBillingTier, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.BillingTierCleared() {
|
||||
_spec.ClearField(usagelog.FieldBillingTier, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.BillingMode(); ok {
|
||||
_spec.SetField(usagelog.FieldBillingMode, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.BillingModeCleared() {
|
||||
_spec.ClearField(usagelog.FieldBillingMode, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.InputTokens(); ok {
|
||||
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
|
||||
}
|
||||
@@ -2038,12 +2364,6 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
|
||||
if _u.mutation.ImageSizeCleared() {
|
||||
_spec.ClearField(usagelog.FieldImageSize, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.MediaType(); ok {
|
||||
_spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.MediaTypeCleared() {
|
||||
_spec.ClearField(usagelog.FieldMediaType, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
|
||||
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
|
||||
}
|
||||
|
||||
@@ -45,10 +45,6 @@ type User struct {
|
||||
TotpEnabled bool `json:"totp_enabled,omitempty"`
|
||||
// TotpEnabledAt holds the value of the "totp_enabled_at" field.
|
||||
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
|
||||
// SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field.
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"`
|
||||
// SoraStorageUsedBytes holds the value of the "sora_storage_used_bytes" field.
|
||||
SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes,omitempty"`
|
||||
// Edges holds the relations/edges for other nodes in the graph.
|
||||
// The values are being populated by the UserQuery when eager-loading is set.
|
||||
Edges UserEdges `json:"edges"`
|
||||
@@ -181,7 +177,7 @@ func (*User) scanValues(columns []string) ([]any, error) {
|
||||
values[i] = new(sql.NullBool)
|
||||
case user.FieldBalance:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case user.FieldID, user.FieldConcurrency, user.FieldSoraStorageQuotaBytes, user.FieldSoraStorageUsedBytes:
|
||||
case user.FieldID, user.FieldConcurrency:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted:
|
||||
values[i] = new(sql.NullString)
|
||||
@@ -295,18 +291,6 @@ func (_m *User) assignValues(columns []string, values []any) error {
|
||||
_m.TotpEnabledAt = new(time.Time)
|
||||
*_m.TotpEnabledAt = value.Time
|
||||
}
|
||||
case user.FieldSoraStorageQuotaBytes:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field sora_storage_quota_bytes", values[i])
|
||||
} else if value.Valid {
|
||||
_m.SoraStorageQuotaBytes = value.Int64
|
||||
}
|
||||
case user.FieldSoraStorageUsedBytes:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field sora_storage_used_bytes", values[i])
|
||||
} else if value.Valid {
|
||||
_m.SoraStorageUsedBytes = value.Int64
|
||||
}
|
||||
default:
|
||||
_m.selectValues.Set(columns[i], values[i])
|
||||
}
|
||||
@@ -440,12 +424,6 @@ func (_m *User) String() string {
|
||||
builder.WriteString("totp_enabled_at=")
|
||||
builder.WriteString(v.Format(time.ANSIC))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("sora_storage_quota_bytes=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageQuotaBytes))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("sora_storage_used_bytes=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageUsedBytes))
|
||||
builder.WriteByte(')')
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
@@ -43,10 +43,6 @@ const (
|
||||
FieldTotpEnabled = "totp_enabled"
|
||||
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
|
||||
FieldTotpEnabledAt = "totp_enabled_at"
|
||||
// FieldSoraStorageQuotaBytes holds the string denoting the sora_storage_quota_bytes field in the database.
|
||||
FieldSoraStorageQuotaBytes = "sora_storage_quota_bytes"
|
||||
// FieldSoraStorageUsedBytes holds the string denoting the sora_storage_used_bytes field in the database.
|
||||
FieldSoraStorageUsedBytes = "sora_storage_used_bytes"
|
||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||
EdgeAPIKeys = "api_keys"
|
||||
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
||||
@@ -156,8 +152,6 @@ var Columns = []string{
|
||||
FieldTotpSecretEncrypted,
|
||||
FieldTotpEnabled,
|
||||
FieldTotpEnabledAt,
|
||||
FieldSoraStorageQuotaBytes,
|
||||
FieldSoraStorageUsedBytes,
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -214,10 +208,6 @@ var (
|
||||
DefaultNotes string
|
||||
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
|
||||
DefaultTotpEnabled bool
|
||||
// DefaultSoraStorageQuotaBytes holds the default value on creation for the "sora_storage_quota_bytes" field.
|
||||
DefaultSoraStorageQuotaBytes int64
|
||||
// DefaultSoraStorageUsedBytes holds the default value on creation for the "sora_storage_used_bytes" field.
|
||||
DefaultSoraStorageUsedBytes int64
|
||||
)
|
||||
|
||||
// OrderOption defines the ordering options for the User queries.
|
||||
@@ -298,16 +288,6 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// BySoraStorageQuotaBytes orders the results by the sora_storage_quota_bytes field.
|
||||
func BySoraStorageQuotaBytes(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldSoraStorageQuotaBytes, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// BySoraStorageUsedBytes orders the results by the sora_storage_used_bytes field.
|
||||
func BySoraStorageUsedBytes(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldSoraStorageUsedBytes, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByAPIKeysCount orders the results by api_keys count.
|
||||
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
|
||||
@@ -125,16 +125,6 @@ func TotpEnabledAt(v time.Time) predicate.User {
|
||||
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
|
||||
}
|
||||
|
||||
// SoraStorageQuotaBytes applies equality check predicate on the "sora_storage_quota_bytes" field. It's identical to SoraStorageQuotaBytesEQ.
|
||||
func SoraStorageQuotaBytes(v int64) predicate.User {
|
||||
return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
|
||||
}
|
||||
|
||||
// SoraStorageUsedBytes applies equality check predicate on the "sora_storage_used_bytes" field. It's identical to SoraStorageUsedBytesEQ.
|
||||
func SoraStorageUsedBytes(v int64) predicate.User {
|
||||
return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v))
|
||||
}
|
||||
|
||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||
func CreatedAtEQ(v time.Time) predicate.User {
|
||||
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
|
||||
@@ -870,86 +860,6 @@ func TotpEnabledAtNotNil() predicate.User {
|
||||
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
|
||||
}
|
||||
|
||||
// SoraStorageQuotaBytesEQ applies the EQ predicate on the "sora_storage_quota_bytes" field.
|
||||
func SoraStorageQuotaBytesEQ(v int64) predicate.User {
|
||||
return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
|
||||
}
|
||||
|
||||
// SoraStorageQuotaBytesNEQ applies the NEQ predicate on the "sora_storage_quota_bytes" field.
|
||||
func SoraStorageQuotaBytesNEQ(v int64) predicate.User {
|
||||
return predicate.User(sql.FieldNEQ(FieldSoraStorageQuotaBytes, v))
|
||||
}
|
||||
|
||||
// SoraStorageQuotaBytesIn applies the In predicate on the "sora_storage_quota_bytes" field.
|
||||
func SoraStorageQuotaBytesIn(vs ...int64) predicate.User {
|
||||
return predicate.User(sql.FieldIn(FieldSoraStorageQuotaBytes, vs...))
|
||||
}
|
||||
|
||||
// SoraStorageQuotaBytesNotIn applies the NotIn predicate on the "sora_storage_quota_bytes" field.
|
||||
func SoraStorageQuotaBytesNotIn(vs ...int64) predicate.User {
|
||||
return predicate.User(sql.FieldNotIn(FieldSoraStorageQuotaBytes, vs...))
|
||||
}
|
||||
|
||||
// SoraStorageQuotaBytesGT applies the GT predicate on the "sora_storage_quota_bytes" field.
|
||||
func SoraStorageQuotaBytesGT(v int64) predicate.User {
|
||||
return predicate.User(sql.FieldGT(FieldSoraStorageQuotaBytes, v))
|
||||
}
|
||||
|
||||
// SoraStorageQuotaBytesGTE applies the GTE predicate on the "sora_storage_quota_bytes" field.
|
||||
func SoraStorageQuotaBytesGTE(v int64) predicate.User {
|
||||
return predicate.User(sql.FieldGTE(FieldSoraStorageQuotaBytes, v))
|
||||
}
|
||||
|
||||
// SoraStorageQuotaBytesLT applies the LT predicate on the "sora_storage_quota_bytes" field.
|
||||
func SoraStorageQuotaBytesLT(v int64) predicate.User {
|
||||
return predicate.User(sql.FieldLT(FieldSoraStorageQuotaBytes, v))
|
||||
}
|
||||
|
||||
// SoraStorageQuotaBytesLTE applies the LTE predicate on the "sora_storage_quota_bytes" field.
|
||||
func SoraStorageQuotaBytesLTE(v int64) predicate.User {
|
||||
return predicate.User(sql.FieldLTE(FieldSoraStorageQuotaBytes, v))
|
||||
}
|
||||
|
||||
// SoraStorageUsedBytesEQ applies the EQ predicate on the "sora_storage_used_bytes" field.
|
||||
func SoraStorageUsedBytesEQ(v int64) predicate.User {
|
||||
return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v))
|
||||
}
|
||||
|
||||
// SoraStorageUsedBytesNEQ applies the NEQ predicate on the "sora_storage_used_bytes" field.
|
||||
func SoraStorageUsedBytesNEQ(v int64) predicate.User {
|
||||
return predicate.User(sql.FieldNEQ(FieldSoraStorageUsedBytes, v))
|
||||
}
|
||||
|
||||
// SoraStorageUsedBytesIn applies the In predicate on the "sora_storage_used_bytes" field.
|
||||
func SoraStorageUsedBytesIn(vs ...int64) predicate.User {
|
||||
return predicate.User(sql.FieldIn(FieldSoraStorageUsedBytes, vs...))
|
||||
}
|
||||
|
||||
// SoraStorageUsedBytesNotIn applies the NotIn predicate on the "sora_storage_used_bytes" field.
|
||||
func SoraStorageUsedBytesNotIn(vs ...int64) predicate.User {
|
||||
return predicate.User(sql.FieldNotIn(FieldSoraStorageUsedBytes, vs...))
|
||||
}
|
||||
|
||||
// SoraStorageUsedBytesGT applies the GT predicate on the "sora_storage_used_bytes" field.
|
||||
func SoraStorageUsedBytesGT(v int64) predicate.User {
|
||||
return predicate.User(sql.FieldGT(FieldSoraStorageUsedBytes, v))
|
||||
}
|
||||
|
||||
// SoraStorageUsedBytesGTE applies the GTE predicate on the "sora_storage_used_bytes" field.
|
||||
func SoraStorageUsedBytesGTE(v int64) predicate.User {
|
||||
return predicate.User(sql.FieldGTE(FieldSoraStorageUsedBytes, v))
|
||||
}
|
||||
|
||||
// SoraStorageUsedBytesLT applies the LT predicate on the "sora_storage_used_bytes" field.
|
||||
func SoraStorageUsedBytesLT(v int64) predicate.User {
|
||||
return predicate.User(sql.FieldLT(FieldSoraStorageUsedBytes, v))
|
||||
}
|
||||
|
||||
// SoraStorageUsedBytesLTE applies the LTE predicate on the "sora_storage_used_bytes" field.
|
||||
func SoraStorageUsedBytesLTE(v int64) predicate.User {
|
||||
return predicate.User(sql.FieldLTE(FieldSoraStorageUsedBytes, v))
|
||||
}
|
||||
|
||||
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
||||
func HasAPIKeys() predicate.User {
|
||||
return predicate.User(func(s *sql.Selector) {
|
||||
|
||||
@@ -210,34 +210,6 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
||||
func (_c *UserCreate) SetSoraStorageQuotaBytes(v int64) *UserCreate {
|
||||
_c.mutation.SetSoraStorageQuotaBytes(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
|
||||
func (_c *UserCreate) SetNillableSoraStorageQuotaBytes(v *int64) *UserCreate {
|
||||
if v != nil {
|
||||
_c.SetSoraStorageQuotaBytes(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
|
||||
func (_c *UserCreate) SetSoraStorageUsedBytes(v int64) *UserCreate {
|
||||
_c.mutation.SetSoraStorageUsedBytes(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil.
|
||||
func (_c *UserCreate) SetNillableSoraStorageUsedBytes(v *int64) *UserCreate {
|
||||
if v != nil {
|
||||
_c.SetSoraStorageUsedBytes(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
|
||||
_c.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -452,14 +424,6 @@ func (_c *UserCreate) defaults() error {
|
||||
v := user.DefaultTotpEnabled
|
||||
_c.mutation.SetTotpEnabled(v)
|
||||
}
|
||||
if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
|
||||
v := user.DefaultSoraStorageQuotaBytes
|
||||
_c.mutation.SetSoraStorageQuotaBytes(v)
|
||||
}
|
||||
if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok {
|
||||
v := user.DefaultSoraStorageUsedBytes
|
||||
_c.mutation.SetSoraStorageUsedBytes(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -523,12 +487,6 @@ func (_c *UserCreate) check() error {
|
||||
if _, ok := _c.mutation.TotpEnabled(); !ok {
|
||||
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
|
||||
return &ValidationError{Name: "sora_storage_quota_bytes", err: errors.New(`ent: missing required field "User.sora_storage_quota_bytes"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok {
|
||||
return &ValidationError{Name: "sora_storage_used_bytes", err: errors.New(`ent: missing required field "User.sora_storage_used_bytes"`)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -612,14 +570,6 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
|
||||
_node.TotpEnabledAt = &value
|
||||
}
|
||||
if value, ok := _c.mutation.SoraStorageQuotaBytes(); ok {
|
||||
_spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
||||
_node.SoraStorageQuotaBytes = value
|
||||
}
|
||||
if value, ok := _c.mutation.SoraStorageUsedBytes(); ok {
|
||||
_spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
|
||||
_node.SoraStorageUsedBytes = value
|
||||
}
|
||||
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -1006,42 +956,6 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
||||
func (u *UserUpsert) SetSoraStorageQuotaBytes(v int64) *UserUpsert {
|
||||
u.Set(user.FieldSoraStorageQuotaBytes, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
|
||||
func (u *UserUpsert) UpdateSoraStorageQuotaBytes() *UserUpsert {
|
||||
u.SetExcluded(user.FieldSoraStorageQuotaBytes)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
|
||||
func (u *UserUpsert) AddSoraStorageQuotaBytes(v int64) *UserUpsert {
|
||||
u.Add(user.FieldSoraStorageQuotaBytes, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
|
||||
func (u *UserUpsert) SetSoraStorageUsedBytes(v int64) *UserUpsert {
|
||||
u.Set(user.FieldSoraStorageUsedBytes, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create.
|
||||
func (u *UserUpsert) UpdateSoraStorageUsedBytes() *UserUpsert {
|
||||
u.SetExcluded(user.FieldSoraStorageUsedBytes)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field.
|
||||
func (u *UserUpsert) AddSoraStorageUsedBytes(v int64) *UserUpsert {
|
||||
u.Add(user.FieldSoraStorageUsedBytes, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||
// Using this option is equivalent to using:
|
||||
//
|
||||
@@ -1304,48 +1218,6 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
||||
func (u *UserUpsertOne) SetSoraStorageQuotaBytes(v int64) *UserUpsertOne {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.SetSoraStorageQuotaBytes(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
|
||||
func (u *UserUpsertOne) AddSoraStorageQuotaBytes(v int64) *UserUpsertOne {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.AddSoraStorageQuotaBytes(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
|
||||
func (u *UserUpsertOne) UpdateSoraStorageQuotaBytes() *UserUpsertOne {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.UpdateSoraStorageQuotaBytes()
|
||||
})
|
||||
}
|
||||
|
||||
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
|
||||
func (u *UserUpsertOne) SetSoraStorageUsedBytes(v int64) *UserUpsertOne {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.SetSoraStorageUsedBytes(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field.
|
||||
func (u *UserUpsertOne) AddSoraStorageUsedBytes(v int64) *UserUpsertOne {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.AddSoraStorageUsedBytes(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create.
|
||||
func (u *UserUpsertOne) UpdateSoraStorageUsedBytes() *UserUpsertOne {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.UpdateSoraStorageUsedBytes()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *UserUpsertOne) Exec(ctx context.Context) error {
|
||||
if len(u.create.conflict) == 0 {
|
||||
@@ -1774,48 +1646,6 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
||||
func (u *UserUpsertBulk) SetSoraStorageQuotaBytes(v int64) *UserUpsertBulk {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.SetSoraStorageQuotaBytes(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
|
||||
func (u *UserUpsertBulk) AddSoraStorageQuotaBytes(v int64) *UserUpsertBulk {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.AddSoraStorageQuotaBytes(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
|
||||
func (u *UserUpsertBulk) UpdateSoraStorageQuotaBytes() *UserUpsertBulk {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.UpdateSoraStorageQuotaBytes()
|
||||
})
|
||||
}
|
||||
|
||||
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
|
||||
func (u *UserUpsertBulk) SetSoraStorageUsedBytes(v int64) *UserUpsertBulk {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.SetSoraStorageUsedBytes(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field.
|
||||
func (u *UserUpsertBulk) AddSoraStorageUsedBytes(v int64) *UserUpsertBulk {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.AddSoraStorageUsedBytes(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create.
|
||||
func (u *UserUpsertBulk) UpdateSoraStorageUsedBytes() *UserUpsertBulk {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.UpdateSoraStorageUsedBytes()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
|
||||
if u.create.err != nil {
|
||||
|
||||
@@ -242,48 +242,6 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
||||
func (_u *UserUpdate) SetSoraStorageQuotaBytes(v int64) *UserUpdate {
|
||||
_u.mutation.ResetSoraStorageQuotaBytes()
|
||||
_u.mutation.SetSoraStorageQuotaBytes(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
|
||||
func (_u *UserUpdate) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdate {
|
||||
if v != nil {
|
||||
_u.SetSoraStorageQuotaBytes(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
|
||||
func (_u *UserUpdate) AddSoraStorageQuotaBytes(v int64) *UserUpdate {
|
||||
_u.mutation.AddSoraStorageQuotaBytes(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
|
||||
func (_u *UserUpdate) SetSoraStorageUsedBytes(v int64) *UserUpdate {
|
||||
_u.mutation.ResetSoraStorageUsedBytes()
|
||||
_u.mutation.SetSoraStorageUsedBytes(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil.
|
||||
func (_u *UserUpdate) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdate {
|
||||
if v != nil {
|
||||
_u.SetSoraStorageUsedBytes(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field.
|
||||
func (_u *UserUpdate) AddSoraStorageUsedBytes(v int64) *UserUpdate {
|
||||
_u.mutation.AddSoraStorageUsedBytes(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -751,18 +709,6 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if _u.mutation.TotpEnabledAtCleared() {
|
||||
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
|
||||
_spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
|
||||
_spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.SoraStorageUsedBytes(); ok {
|
||||
_spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok {
|
||||
_spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -1406,48 +1352,6 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
||||
func (_u *UserUpdateOne) SetSoraStorageQuotaBytes(v int64) *UserUpdateOne {
|
||||
_u.mutation.ResetSoraStorageQuotaBytes()
|
||||
_u.mutation.SetSoraStorageQuotaBytes(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
|
||||
func (_u *UserUpdateOne) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetSoraStorageQuotaBytes(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
|
||||
func (_u *UserUpdateOne) AddSoraStorageQuotaBytes(v int64) *UserUpdateOne {
|
||||
_u.mutation.AddSoraStorageQuotaBytes(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
|
||||
func (_u *UserUpdateOne) SetSoraStorageUsedBytes(v int64) *UserUpdateOne {
|
||||
_u.mutation.ResetSoraStorageUsedBytes()
|
||||
_u.mutation.SetSoraStorageUsedBytes(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil.
|
||||
func (_u *UserUpdateOne) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetSoraStorageUsedBytes(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field.
|
||||
func (_u *UserUpdateOne) AddSoraStorageUsedBytes(v int64) *UserUpdateOne {
|
||||
_u.mutation.AddSoraStorageUsedBytes(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -1945,18 +1849,6 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
|
||||
if _u.mutation.TotpEnabledAtCleared() {
|
||||
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
|
||||
_spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
|
||||
_spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.SoraStorageUsedBytes(); ok {
|
||||
_spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok {
|
||||
_spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
|
||||
@@ -7,7 +7,7 @@ require (
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||
github.com/DouDOU-start/go-sora2api v1.1.0
|
||||
github.com/alitto/pond/v2 v2.6.2
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.2
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.3
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.10
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
|
||||
@@ -66,7 +66,7 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect
|
||||
github.com/aws/smithy-go v1.24.1 // indirect
|
||||
github.com/aws/smithy-go v1.24.2 // indirect
|
||||
github.com/bdandy/go-errors v1.2.2 // indirect
|
||||
github.com/bdandy/go-socks4 v1.2.3 // indirect
|
||||
github.com/bmatcuk/doublestar v1.3.4 // indirect
|
||||
|
||||
@@ -22,8 +22,8 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo
|
||||
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
||||
github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=
|
||||
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.10 h1:9DMthfO6XWZYLfzZglAgW5Fyou2nRI5CuV44sTedKBI=
|
||||
@@ -58,8 +58,8 @@ github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 h1:edCcNp9eGIUDUCrzoCu1jWA
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15/go.mod h1:lyRQKED9xWfgkYC/wmmYfv7iVIM68Z5OQ88ZdcV1QbU=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb83BbyggcUBVksN7c=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
|
||||
github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0=
|
||||
github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
|
||||
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
|
||||
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
|
||||
github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM=
|
||||
github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic=
|
||||
@@ -199,6 +199,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
|
||||
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
|
||||
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
|
||||
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
|
||||
@@ -77,7 +77,6 @@ type Config struct {
|
||||
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
|
||||
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
||||
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||||
Sora SoraConfig `mapstructure:"sora"`
|
||||
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||
Gemini GeminiConfig `mapstructure:"gemini"`
|
||||
@@ -197,8 +196,6 @@ type TokenRefreshConfig struct {
|
||||
MaxRetries int `mapstructure:"max_retries"`
|
||||
// 重试退避基础时间(秒)
|
||||
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
|
||||
// 是否允许 OpenAI 刷新器同步覆盖关联的 Sora 账号 token(默认关闭)
|
||||
SyncLinkedSoraAccounts bool `mapstructure:"sync_linked_sora_accounts"`
|
||||
}
|
||||
|
||||
type PricingConfig struct {
|
||||
@@ -303,59 +300,6 @@ type ConcurrencyConfig struct {
|
||||
PingInterval int `mapstructure:"ping_interval"`
|
||||
}
|
||||
|
||||
// SoraConfig 直连 Sora 配置
|
||||
type SoraConfig struct {
|
||||
Client SoraClientConfig `mapstructure:"client"`
|
||||
Storage SoraStorageConfig `mapstructure:"storage"`
|
||||
}
|
||||
|
||||
// SoraClientConfig 直连 Sora 客户端配置
|
||||
type SoraClientConfig struct {
|
||||
BaseURL string `mapstructure:"base_url"`
|
||||
TimeoutSeconds int `mapstructure:"timeout_seconds"`
|
||||
MaxRetries int `mapstructure:"max_retries"`
|
||||
CloudflareChallengeCooldownSeconds int `mapstructure:"cloudflare_challenge_cooldown_seconds"`
|
||||
PollIntervalSeconds int `mapstructure:"poll_interval_seconds"`
|
||||
MaxPollAttempts int `mapstructure:"max_poll_attempts"`
|
||||
RecentTaskLimit int `mapstructure:"recent_task_limit"`
|
||||
RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"`
|
||||
Debug bool `mapstructure:"debug"`
|
||||
UseOpenAITokenProvider bool `mapstructure:"use_openai_token_provider"`
|
||||
Headers map[string]string `mapstructure:"headers"`
|
||||
UserAgent string `mapstructure:"user_agent"`
|
||||
DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"`
|
||||
CurlCFFISidecar SoraCurlCFFISidecarConfig `mapstructure:"curl_cffi_sidecar"`
|
||||
}
|
||||
|
||||
// SoraCurlCFFISidecarConfig Sora 专用 curl_cffi sidecar 配置
|
||||
type SoraCurlCFFISidecarConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
BaseURL string `mapstructure:"base_url"`
|
||||
Impersonate string `mapstructure:"impersonate"`
|
||||
TimeoutSeconds int `mapstructure:"timeout_seconds"`
|
||||
SessionReuseEnabled bool `mapstructure:"session_reuse_enabled"`
|
||||
SessionTTLSeconds int `mapstructure:"session_ttl_seconds"`
|
||||
}
|
||||
|
||||
// SoraStorageConfig 媒体存储配置
|
||||
type SoraStorageConfig struct {
|
||||
Type string `mapstructure:"type"`
|
||||
LocalPath string `mapstructure:"local_path"`
|
||||
FallbackToUpstream bool `mapstructure:"fallback_to_upstream"`
|
||||
MaxConcurrentDownloads int `mapstructure:"max_concurrent_downloads"`
|
||||
DownloadTimeoutSeconds int `mapstructure:"download_timeout_seconds"`
|
||||
MaxDownloadBytes int64 `mapstructure:"max_download_bytes"`
|
||||
Debug bool `mapstructure:"debug"`
|
||||
Cleanup SoraStorageCleanupConfig `mapstructure:"cleanup"`
|
||||
}
|
||||
|
||||
// SoraStorageCleanupConfig 媒体清理配置
|
||||
type SoraStorageCleanupConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
Schedule string `mapstructure:"schedule"`
|
||||
RetentionDays int `mapstructure:"retention_days"`
|
||||
}
|
||||
|
||||
// GatewayConfig API网关相关配置
|
||||
type GatewayConfig struct {
|
||||
// 等待上游响应头的超时时间(秒),0表示无超时
|
||||
@@ -424,24 +368,6 @@ type GatewayConfig struct {
|
||||
// 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义)
|
||||
FailoverOn400 bool `mapstructure:"failover_on_400"`
|
||||
|
||||
// Sora 专用配置
|
||||
// SoraMaxBodySize: Sora 请求体最大字节数(0 表示使用 gateway.max_body_size)
|
||||
SoraMaxBodySize int64 `mapstructure:"sora_max_body_size"`
|
||||
// SoraStreamTimeoutSeconds: Sora 流式请求总超时(秒,0 表示不限制)
|
||||
SoraStreamTimeoutSeconds int `mapstructure:"sora_stream_timeout_seconds"`
|
||||
// SoraRequestTimeoutSeconds: Sora 非流式请求超时(秒,0 表示不限制)
|
||||
SoraRequestTimeoutSeconds int `mapstructure:"sora_request_timeout_seconds"`
|
||||
// SoraStreamMode: stream 强制策略(force/error)
|
||||
SoraStreamMode string `mapstructure:"sora_stream_mode"`
|
||||
// SoraModelFilters: 模型列表过滤配置
|
||||
SoraModelFilters SoraModelFiltersConfig `mapstructure:"sora_model_filters"`
|
||||
// SoraMediaRequireAPIKey: 是否要求访问 /sora/media 携带 API Key
|
||||
SoraMediaRequireAPIKey bool `mapstructure:"sora_media_require_api_key"`
|
||||
// SoraMediaSigningKey: /sora/media 临时签名密钥(空表示禁用签名)
|
||||
SoraMediaSigningKey string `mapstructure:"sora_media_signing_key"`
|
||||
// SoraMediaSignedURLTTLSeconds: 临时签名 URL 有效期(秒,<=0 表示禁用)
|
||||
SoraMediaSignedURLTTLSeconds int `mapstructure:"sora_media_signed_url_ttl_seconds"`
|
||||
|
||||
// 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限)
|
||||
MaxAccountSwitches int `mapstructure:"max_account_switches"`
|
||||
// Gemini 账户切换最大次数(Gemini 平台单独配置,因 API 限制更严格)
|
||||
@@ -639,12 +565,6 @@ type GatewayUsageRecordConfig struct {
|
||||
AutoScaleCooldownSeconds int `mapstructure:"auto_scale_cooldown_seconds"`
|
||||
}
|
||||
|
||||
// SoraModelFiltersConfig Sora 模型过滤配置
|
||||
type SoraModelFiltersConfig struct {
|
||||
// HidePromptEnhance 是否隐藏 prompt-enhance 模型
|
||||
HidePromptEnhance bool `mapstructure:"hide_prompt_enhance"`
|
||||
}
|
||||
|
||||
// TLSFingerprintConfig TLS指纹伪装配置
|
||||
// 用于模拟 Claude CLI (Node.js) 的 TLS 握手特征,避免被识别为非官方客户端
|
||||
type TLSFingerprintConfig struct {
|
||||
@@ -656,17 +576,33 @@ type TLSFingerprintConfig struct {
|
||||
}
|
||||
|
||||
// TLSProfileConfig 单个TLS指纹模板的配置
|
||||
// 所有列表字段为空时使用内置默认值(Claude CLI 2.x / Node.js 20.x)
|
||||
// 建议通过 TLS 指纹采集工具 (tests/tls-fingerprint-web) 获取完整配置
|
||||
type TLSProfileConfig struct {
|
||||
// Name: 模板显示名称
|
||||
Name string `mapstructure:"name"`
|
||||
// EnableGREASE: 是否启用GREASE扩展(Chrome使用,Node.js不使用)
|
||||
EnableGREASE bool `mapstructure:"enable_grease"`
|
||||
// CipherSuites: TLS加密套件列表(空则使用内置默认值)
|
||||
// CipherSuites: TLS加密套件列表
|
||||
CipherSuites []uint16 `mapstructure:"cipher_suites"`
|
||||
// Curves: 椭圆曲线列表(空则使用内置默认值)
|
||||
// Curves: 椭圆曲线列表
|
||||
Curves []uint16 `mapstructure:"curves"`
|
||||
// PointFormats: 点格式列表(空则使用内置默认值)
|
||||
PointFormats []uint8 `mapstructure:"point_formats"`
|
||||
// PointFormats: 点格式列表
|
||||
PointFormats []uint16 `mapstructure:"point_formats"`
|
||||
// SignatureAlgorithms: 签名算法列表
|
||||
SignatureAlgorithms []uint16 `mapstructure:"signature_algorithms"`
|
||||
// ALPNProtocols: ALPN协议列表(如 ["h2", "http/1.1"])
|
||||
ALPNProtocols []string `mapstructure:"alpn_protocols"`
|
||||
// SupportedVersions: 支持的TLS版本列表(如 [0x0304, 0x0303] 即 TLS1.3, TLS1.2)
|
||||
SupportedVersions []uint16 `mapstructure:"supported_versions"`
|
||||
// KeyShareGroups: Key Share中发送的曲线组(如 [29] 即 X25519)
|
||||
KeyShareGroups []uint16 `mapstructure:"key_share_groups"`
|
||||
// PSKModes: PSK密钥交换模式(如 [1] 即 psk_dhe_ke)
|
||||
PSKModes []uint16 `mapstructure:"psk_modes"`
|
||||
// Extensions: TLS扩展类型ID列表,按发送顺序排列
|
||||
// 空则使用内置默认顺序 [0,11,10,35,16,22,23,13,43,45,51]
|
||||
// GREASE值(如0x0a0a)会自动插入GREASE扩展
|
||||
Extensions []uint16 `mapstructure:"extensions"`
|
||||
}
|
||||
|
||||
// GatewaySchedulingConfig accounts scheduling configuration.
|
||||
@@ -934,9 +870,10 @@ type DashboardAggregationConfig struct {
|
||||
|
||||
// DashboardAggregationRetentionConfig 预聚合保留窗口
|
||||
type DashboardAggregationRetentionConfig struct {
|
||||
UsageLogsDays int `mapstructure:"usage_logs_days"`
|
||||
HourlyDays int `mapstructure:"hourly_days"`
|
||||
DailyDays int `mapstructure:"daily_days"`
|
||||
UsageLogsDays int `mapstructure:"usage_logs_days"`
|
||||
UsageBillingDedupDays int `mapstructure:"usage_billing_dedup_days"`
|
||||
HourlyDays int `mapstructure:"hourly_days"`
|
||||
DailyDays int `mapstructure:"daily_days"`
|
||||
}
|
||||
|
||||
// UsageCleanupConfig 使用记录清理任务配置
|
||||
@@ -1264,8 +1201,8 @@ func setDefaults() {
|
||||
viper.SetDefault("rate_limit.oauth_401_cooldown_minutes", 10)
|
||||
|
||||
// Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据(固定到 commit,避免分支漂移)
|
||||
viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json")
|
||||
viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.sha256")
|
||||
viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/main/model_prices_and_context_window.json")
|
||||
viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/main/model_prices_and_context_window.sha256")
|
||||
viper.SetDefault("pricing.data_dir", "./data")
|
||||
viper.SetDefault("pricing.fallback_file", "./resources/model-pricing/model_prices_and_context_window.json")
|
||||
viper.SetDefault("pricing.update_interval_hours", 24)
|
||||
@@ -1301,6 +1238,7 @@ func setDefaults() {
|
||||
viper.SetDefault("dashboard_aggregation.backfill_enabled", false)
|
||||
viper.SetDefault("dashboard_aggregation.backfill_max_days", 31)
|
||||
viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90)
|
||||
viper.SetDefault("dashboard_aggregation.retention.usage_billing_dedup_days", 365)
|
||||
viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180)
|
||||
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
|
||||
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
|
||||
@@ -1384,13 +1322,6 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
|
||||
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
|
||||
viper.SetDefault("gateway.gemini_debug_response_headers", false)
|
||||
viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024))
|
||||
viper.SetDefault("gateway.sora_stream_timeout_seconds", 900)
|
||||
viper.SetDefault("gateway.sora_request_timeout_seconds", 180)
|
||||
viper.SetDefault("gateway.sora_stream_mode", "force")
|
||||
viper.SetDefault("gateway.sora_model_filters.hide_prompt_enhance", true)
|
||||
viper.SetDefault("gateway.sora_media_require_api_key", true)
|
||||
viper.SetDefault("gateway.sora_media_signed_url_ttl_seconds", 900)
|
||||
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
|
||||
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
|
||||
viper.SetDefault("gateway.max_idle_conns", 2560) // 最大空闲连接总数(高并发场景可调大)
|
||||
@@ -1447,45 +1378,12 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.tls_fingerprint.enabled", true)
|
||||
viper.SetDefault("concurrency.ping_interval", 10)
|
||||
|
||||
// Sora 直连配置
|
||||
viper.SetDefault("sora.client.base_url", "https://sora.chatgpt.com/backend")
|
||||
viper.SetDefault("sora.client.timeout_seconds", 120)
|
||||
viper.SetDefault("sora.client.max_retries", 3)
|
||||
viper.SetDefault("sora.client.cloudflare_challenge_cooldown_seconds", 900)
|
||||
viper.SetDefault("sora.client.poll_interval_seconds", 2)
|
||||
viper.SetDefault("sora.client.max_poll_attempts", 600)
|
||||
viper.SetDefault("sora.client.recent_task_limit", 50)
|
||||
viper.SetDefault("sora.client.recent_task_limit_max", 200)
|
||||
viper.SetDefault("sora.client.debug", false)
|
||||
viper.SetDefault("sora.client.use_openai_token_provider", false)
|
||||
viper.SetDefault("sora.client.headers", map[string]string{})
|
||||
viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||
viper.SetDefault("sora.client.disable_tls_fingerprint", false)
|
||||
viper.SetDefault("sora.client.curl_cffi_sidecar.enabled", true)
|
||||
viper.SetDefault("sora.client.curl_cffi_sidecar.base_url", "http://sora-curl-cffi-sidecar:8080")
|
||||
viper.SetDefault("sora.client.curl_cffi_sidecar.impersonate", "chrome131")
|
||||
viper.SetDefault("sora.client.curl_cffi_sidecar.timeout_seconds", 60)
|
||||
viper.SetDefault("sora.client.curl_cffi_sidecar.session_reuse_enabled", true)
|
||||
viper.SetDefault("sora.client.curl_cffi_sidecar.session_ttl_seconds", 3600)
|
||||
|
||||
viper.SetDefault("sora.storage.type", "local")
|
||||
viper.SetDefault("sora.storage.local_path", "")
|
||||
viper.SetDefault("sora.storage.fallback_to_upstream", true)
|
||||
viper.SetDefault("sora.storage.max_concurrent_downloads", 4)
|
||||
viper.SetDefault("sora.storage.download_timeout_seconds", 120)
|
||||
viper.SetDefault("sora.storage.max_download_bytes", int64(200<<20))
|
||||
viper.SetDefault("sora.storage.debug", false)
|
||||
viper.SetDefault("sora.storage.cleanup.enabled", true)
|
||||
viper.SetDefault("sora.storage.cleanup.retention_days", 7)
|
||||
viper.SetDefault("sora.storage.cleanup.schedule", "0 3 * * *")
|
||||
|
||||
// TokenRefresh
|
||||
viper.SetDefault("token_refresh.enabled", true)
|
||||
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
|
||||
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token)
|
||||
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
|
||||
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
|
||||
viper.SetDefault("token_refresh.sync_linked_sora_accounts", false) // 默认不跨平台覆盖 Sora token
|
||||
|
||||
// Gemini OAuth - configure via environment variables or config file
|
||||
// GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET
|
||||
@@ -1758,6 +1656,12 @@ func (c *Config) Validate() error {
|
||||
if c.DashboardAgg.Retention.UsageLogsDays <= 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive")
|
||||
}
|
||||
if c.DashboardAgg.Retention.UsageBillingDedupDays <= 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be positive")
|
||||
}
|
||||
if c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
|
||||
}
|
||||
if c.DashboardAgg.Retention.HourlyDays <= 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive")
|
||||
}
|
||||
@@ -1780,6 +1684,14 @@ func (c *Config) Validate() error {
|
||||
if c.DashboardAgg.Retention.UsageLogsDays < 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative")
|
||||
}
|
||||
if c.DashboardAgg.Retention.UsageBillingDedupDays < 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be non-negative")
|
||||
}
|
||||
if c.DashboardAgg.Retention.UsageBillingDedupDays > 0 &&
|
||||
c.DashboardAgg.Retention.UsageLogsDays > 0 &&
|
||||
c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
|
||||
}
|
||||
if c.DashboardAgg.Retention.HourlyDays < 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative")
|
||||
}
|
||||
@@ -1847,86 +1759,6 @@ func (c *Config) Validate() error {
|
||||
if c.Gateway.ProxyProbeResponseReadMaxBytes <= 0 {
|
||||
return fmt.Errorf("gateway.proxy_probe_response_read_max_bytes must be positive")
|
||||
}
|
||||
if c.Gateway.SoraMaxBodySize < 0 {
|
||||
return fmt.Errorf("gateway.sora_max_body_size must be non-negative")
|
||||
}
|
||||
if c.Gateway.SoraStreamTimeoutSeconds < 0 {
|
||||
return fmt.Errorf("gateway.sora_stream_timeout_seconds must be non-negative")
|
||||
}
|
||||
if c.Gateway.SoraRequestTimeoutSeconds < 0 {
|
||||
return fmt.Errorf("gateway.sora_request_timeout_seconds must be non-negative")
|
||||
}
|
||||
if c.Gateway.SoraMediaSignedURLTTLSeconds < 0 {
|
||||
return fmt.Errorf("gateway.sora_media_signed_url_ttl_seconds must be non-negative")
|
||||
}
|
||||
if mode := strings.TrimSpace(strings.ToLower(c.Gateway.SoraStreamMode)); mode != "" {
|
||||
switch mode {
|
||||
case "force", "error":
|
||||
default:
|
||||
return fmt.Errorf("gateway.sora_stream_mode must be one of: force/error")
|
||||
}
|
||||
}
|
||||
if c.Sora.Client.TimeoutSeconds < 0 {
|
||||
return fmt.Errorf("sora.client.timeout_seconds must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.MaxRetries < 0 {
|
||||
return fmt.Errorf("sora.client.max_retries must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.CloudflareChallengeCooldownSeconds < 0 {
|
||||
return fmt.Errorf("sora.client.cloudflare_challenge_cooldown_seconds must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.PollIntervalSeconds < 0 {
|
||||
return fmt.Errorf("sora.client.poll_interval_seconds must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.MaxPollAttempts < 0 {
|
||||
return fmt.Errorf("sora.client.max_poll_attempts must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.RecentTaskLimit < 0 {
|
||||
return fmt.Errorf("sora.client.recent_task_limit must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.RecentTaskLimitMax < 0 {
|
||||
return fmt.Errorf("sora.client.recent_task_limit_max must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.RecentTaskLimitMax > 0 && c.Sora.Client.RecentTaskLimit > 0 &&
|
||||
c.Sora.Client.RecentTaskLimitMax < c.Sora.Client.RecentTaskLimit {
|
||||
c.Sora.Client.RecentTaskLimitMax = c.Sora.Client.RecentTaskLimit
|
||||
}
|
||||
if c.Sora.Client.CurlCFFISidecar.TimeoutSeconds < 0 {
|
||||
return fmt.Errorf("sora.client.curl_cffi_sidecar.timeout_seconds must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.CurlCFFISidecar.SessionTTLSeconds < 0 {
|
||||
return fmt.Errorf("sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative")
|
||||
}
|
||||
if !c.Sora.Client.CurlCFFISidecar.Enabled {
|
||||
return fmt.Errorf("sora.client.curl_cffi_sidecar.enabled must be true")
|
||||
}
|
||||
if strings.TrimSpace(c.Sora.Client.CurlCFFISidecar.BaseURL) == "" {
|
||||
return fmt.Errorf("sora.client.curl_cffi_sidecar.base_url is required")
|
||||
}
|
||||
if c.Sora.Storage.MaxConcurrentDownloads < 0 {
|
||||
return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative")
|
||||
}
|
||||
if c.Sora.Storage.DownloadTimeoutSeconds < 0 {
|
||||
return fmt.Errorf("sora.storage.download_timeout_seconds must be non-negative")
|
||||
}
|
||||
if c.Sora.Storage.MaxDownloadBytes < 0 {
|
||||
return fmt.Errorf("sora.storage.max_download_bytes must be non-negative")
|
||||
}
|
||||
if c.Sora.Storage.Cleanup.Enabled {
|
||||
if c.Sora.Storage.Cleanup.RetentionDays <= 0 {
|
||||
return fmt.Errorf("sora.storage.cleanup.retention_days must be positive")
|
||||
}
|
||||
if strings.TrimSpace(c.Sora.Storage.Cleanup.Schedule) == "" {
|
||||
return fmt.Errorf("sora.storage.cleanup.schedule is required when cleanup is enabled")
|
||||
}
|
||||
} else {
|
||||
if c.Sora.Storage.Cleanup.RetentionDays < 0 {
|
||||
return fmt.Errorf("sora.storage.cleanup.retention_days must be non-negative")
|
||||
}
|
||||
}
|
||||
if storageType := strings.TrimSpace(strings.ToLower(c.Sora.Storage.Type)); storageType != "" && storageType != "local" {
|
||||
return fmt.Errorf("sora.storage.type must be 'local'")
|
||||
}
|
||||
if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" {
|
||||
switch c.Gateway.ConnectionPoolIsolation {
|
||||
case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy:
|
||||
|
||||
@@ -441,6 +441,9 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
|
||||
if cfg.DashboardAgg.Retention.UsageLogsDays != 90 {
|
||||
t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays)
|
||||
}
|
||||
if cfg.DashboardAgg.Retention.UsageBillingDedupDays != 365 {
|
||||
t.Fatalf("DashboardAgg.Retention.UsageBillingDedupDays = %d, want 365", cfg.DashboardAgg.Retention.UsageBillingDedupDays)
|
||||
}
|
||||
if cfg.DashboardAgg.Retention.HourlyDays != 180 {
|
||||
t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays)
|
||||
}
|
||||
@@ -1016,6 +1019,23 @@ func TestValidateConfigErrors(t *testing.T) {
|
||||
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 },
|
||||
wantErr: "dashboard_aggregation.retention.usage_logs_days",
|
||||
},
|
||||
{
|
||||
name: "dashboard aggregation dedup retention",
|
||||
mutate: func(c *Config) {
|
||||
c.DashboardAgg.Enabled = true
|
||||
c.DashboardAgg.Retention.UsageBillingDedupDays = 0
|
||||
},
|
||||
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
|
||||
},
|
||||
{
|
||||
name: "dashboard aggregation dedup retention smaller than usage logs",
|
||||
mutate: func(c *Config) {
|
||||
c.DashboardAgg.Enabled = true
|
||||
c.DashboardAgg.Retention.UsageLogsDays = 30
|
||||
c.DashboardAgg.Retention.UsageBillingDedupDays = 29
|
||||
},
|
||||
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
|
||||
},
|
||||
{
|
||||
name: "dashboard aggregation disabled interval",
|
||||
mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 },
|
||||
@@ -1534,94 +1554,6 @@ func TestValidateConfig_LogRequiredAndRotationBounds(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoraCurlCFFISidecarDefaults(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
if !cfg.Sora.Client.CurlCFFISidecar.Enabled {
|
||||
t.Fatalf("Sora curl_cffi sidecar should be enabled by default")
|
||||
}
|
||||
if cfg.Sora.Client.CloudflareChallengeCooldownSeconds <= 0 {
|
||||
t.Fatalf("Sora cloudflare challenge cooldown should be positive by default")
|
||||
}
|
||||
if cfg.Sora.Client.CurlCFFISidecar.BaseURL == "" {
|
||||
t.Fatalf("Sora curl_cffi sidecar base_url should not be empty by default")
|
||||
}
|
||||
if cfg.Sora.Client.CurlCFFISidecar.Impersonate == "" {
|
||||
t.Fatalf("Sora curl_cffi sidecar impersonate should not be empty by default")
|
||||
}
|
||||
if !cfg.Sora.Client.CurlCFFISidecar.SessionReuseEnabled {
|
||||
t.Fatalf("Sora curl_cffi sidecar session reuse should be enabled by default")
|
||||
}
|
||||
if cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds <= 0 {
|
||||
t.Fatalf("Sora curl_cffi sidecar session ttl should be positive by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSoraCurlCFFISidecarRequired(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.Sora.Client.CurlCFFISidecar.Enabled = false
|
||||
err = cfg.Validate()
|
||||
if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.enabled must be true") {
|
||||
t.Fatalf("Validate() error = %v, want sidecar enabled error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSoraCurlCFFISidecarBaseURLRequired(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.Sora.Client.CurlCFFISidecar.BaseURL = " "
|
||||
err = cfg.Validate()
|
||||
if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.base_url is required") {
|
||||
t.Fatalf("Validate() error = %v, want sidecar base_url required error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSoraCurlCFFISidecarSessionTTLNonNegative(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds = -1
|
||||
err = cfg.Validate()
|
||||
if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative") {
|
||||
t.Fatalf("Validate() error = %v, want sidecar session ttl error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSoraCloudflareChallengeCooldownNonNegative(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.Sora.Client.CloudflareChallengeCooldownSeconds = -1
|
||||
err = cfg.Validate()
|
||||
if err == nil || !strings.Contains(err.Error(), "sora.client.cloudflare_challenge_cooldown_seconds must be non-negative") {
|
||||
t.Fatalf("Validate() error = %v, want cloudflare cooldown error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_DefaultGatewayUsageRecordConfig(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
cfg, err := Load()
|
||||
|
||||
@@ -22,7 +22,6 @@ const (
|
||||
PlatformOpenAI = "openai"
|
||||
PlatformGemini = "gemini"
|
||||
PlatformAntigravity = "antigravity"
|
||||
PlatformSora = "sora"
|
||||
)
|
||||
|
||||
// Account type constants
|
||||
@@ -31,6 +30,7 @@ const (
|
||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
|
||||
)
|
||||
|
||||
// Redeem type constants
|
||||
@@ -81,8 +81,8 @@ var DefaultAntigravityModelMapping = map[string]string{
|
||||
"claude-opus-4-5-20251101": "claude-opus-4-6-thinking", // 迁移旧模型
|
||||
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
|
||||
// Claude Haiku → Sonnet(无 Haiku 支持)
|
||||
"claude-haiku-4-5": "claude-sonnet-4-5",
|
||||
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
|
||||
"claude-haiku-4-5": "claude-sonnet-4-6",
|
||||
"claude-haiku-4-5-20251001": "claude-sonnet-4-6",
|
||||
// Gemini 2.5 白名单
|
||||
"gemini-2.5-flash": "gemini-2.5-flash",
|
||||
"gemini-2.5-flash-image": "gemini-2.5-flash-image",
|
||||
@@ -113,3 +113,27 @@ var DefaultAntigravityModelMapping = map[string]string{
|
||||
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
|
||||
"tab_flash_lite_preview": "tab_flash_lite_preview",
|
||||
}
|
||||
|
||||
// DefaultBedrockModelMapping 是 AWS Bedrock 平台的默认模型映射
|
||||
// 将 Anthropic 标准模型名映射到 Bedrock 模型 ID
|
||||
// 注意:此处的 "us." 前缀仅为默认值,ResolveBedrockModelID 会根据账号配置的
|
||||
// aws_region 自动调整为匹配的区域前缀(如 eu.、apac.、jp. 等)
|
||||
var DefaultBedrockModelMapping = map[string]string{
|
||||
// Claude Opus
|
||||
"claude-opus-4-6-thinking": "us.anthropic.claude-opus-4-6-v1",
|
||||
"claude-opus-4-6": "us.anthropic.claude-opus-4-6-v1",
|
||||
"claude-opus-4-5-thinking": "us.anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
"claude-opus-4-5-20251101": "us.anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
"claude-opus-4-1": "us.anthropic.claude-opus-4-1-20250805-v1:0",
|
||||
"claude-opus-4-20250514": "us.anthropic.claude-opus-4-20250514-v1:0",
|
||||
// Claude Sonnet
|
||||
"claude-sonnet-4-6-thinking": "us.anthropic.claude-sonnet-4-6",
|
||||
"claude-sonnet-4-6": "us.anthropic.claude-sonnet-4-6",
|
||||
"claude-sonnet-4-5": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
"claude-sonnet-4-5-thinking": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
"claude-sonnet-4-5-20250929": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
"claude-sonnet-4-20250514": "us.anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
// Claude Haiku
|
||||
"claude-haiku-4-5": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
"claude-haiku-4-5-20251001": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
}
|
||||
|
||||
@@ -267,6 +267,9 @@ func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest)
|
||||
}
|
||||
}
|
||||
|
||||
// 收集需要异步设置隐私的 Antigravity OAuth 账号
|
||||
var privacyAccounts []*service.Account
|
||||
|
||||
for i := range dataPayload.Accounts {
|
||||
item := dataPayload.Accounts[i]
|
||||
if err := validateDataAccount(item); err != nil {
|
||||
@@ -314,7 +317,8 @@ func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest)
|
||||
SkipDefaultGroupBind: skipDefaultGroupBind,
|
||||
}
|
||||
|
||||
if _, err := h.adminService.CreateAccount(ctx, accountInput); err != nil {
|
||||
created, err := h.adminService.CreateAccount(ctx, accountInput)
|
||||
if err != nil {
|
||||
result.AccountFailed++
|
||||
result.Errors = append(result.Errors, DataImportError{
|
||||
Kind: "account",
|
||||
@@ -323,9 +327,30 @@ func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest)
|
||||
})
|
||||
continue
|
||||
}
|
||||
// 收集 Antigravity OAuth 账号,稍后异步设置隐私
|
||||
if created.Platform == service.PlatformAntigravity && created.Type == service.AccountTypeOAuth {
|
||||
privacyAccounts = append(privacyAccounts, created)
|
||||
}
|
||||
result.AccountCreated++
|
||||
}
|
||||
|
||||
// 异步设置 Antigravity 隐私,避免大量导入时阻塞请求
|
||||
if len(privacyAccounts) > 0 {
|
||||
adminSvc := h.adminService
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
slog.Error("import_antigravity_privacy_panic", "recover", r)
|
||||
}
|
||||
}()
|
||||
bgCtx := context.Background()
|
||||
for _, acc := range privacyAccounts {
|
||||
adminSvc.ForceAntigravityPrivacy(bgCtx, acc)
|
||||
}
|
||||
slog.Info("import_antigravity_privacy_done", "count", len(privacyAccounts))
|
||||
}()
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -352,7 +377,7 @@ func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, acc
|
||||
pageSize := dataPageCap
|
||||
var out []service.Account
|
||||
for {
|
||||
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0)
|
||||
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -542,15 +567,15 @@ func defaultProxyName(name string) string {
|
||||
|
||||
// enrichCredentialsFromIDToken performs best-effort extraction of user info fields
|
||||
// (email, plan_type, chatgpt_account_id, etc.) from id_token in credentials.
|
||||
// Only applies to OpenAI/Sora OAuth accounts. Skips expired token errors silently.
|
||||
// Only applies to OpenAI OAuth accounts. Skips expired token errors silently.
|
||||
// Existing credential values are never overwritten — only missing fields are filled.
|
||||
func enrichCredentialsFromIDToken(item *DataAccount) {
|
||||
if item.Credentials == nil {
|
||||
return
|
||||
}
|
||||
// Only enrich OpenAI/Sora OAuth accounts
|
||||
// Only enrich OpenAI OAuth accounts
|
||||
platform := strings.ToLower(strings.TrimSpace(item.Platform))
|
||||
if platform != service.PlatformOpenAI && platform != service.PlatformSora {
|
||||
if platform != service.PlatformOpenAI {
|
||||
return
|
||||
}
|
||||
if strings.ToLower(strings.TrimSpace(item.Type)) != service.AccountTypeOAuth {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -97,7 +98,7 @@ type CreateAccountRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Notes *string `json:"notes"`
|
||||
Platform string `json:"platform" binding:"required"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock"`
|
||||
Credentials map[string]any `json:"credentials" binding:"required"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
@@ -116,7 +117,7 @@ type CreateAccountRequest struct {
|
||||
type UpdateAccountRequest struct {
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
@@ -165,6 +166,8 @@ type AccountWithConcurrency struct {
|
||||
CurrentRPM *int `json:"current_rpm,omitempty"` // 当前分钟 RPM 计数
|
||||
}
|
||||
|
||||
const accountListGroupUngroupedQueryValue = "ungrouped"
|
||||
|
||||
func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency {
|
||||
item := AccountWithConcurrency{
|
||||
Account: dto.AccountFromService(account),
|
||||
@@ -217,6 +220,7 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
accountType := c.Query("type")
|
||||
status := c.Query("status")
|
||||
search := c.Query("search")
|
||||
privacyMode := strings.TrimSpace(c.Query("privacy_mode"))
|
||||
// 标准化和验证 search 参数
|
||||
search = strings.TrimSpace(search)
|
||||
if len(search) > 100 {
|
||||
@@ -226,10 +230,23 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
|
||||
var groupID int64
|
||||
if groupIDStr := c.Query("group"); groupIDStr != "" {
|
||||
groupID, _ = strconv.ParseInt(groupIDStr, 10, 64)
|
||||
if groupIDStr == accountListGroupUngroupedQueryValue {
|
||||
groupID = service.AccountListGroupUngrouped
|
||||
} else {
|
||||
parsedGroupID, parseErr := strconv.ParseInt(groupIDStr, 10, 64)
|
||||
if parseErr != nil {
|
||||
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_GROUP_FILTER", "invalid group filter"))
|
||||
return
|
||||
}
|
||||
if parsedGroupID < 0 {
|
||||
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_GROUP_FILTER", "invalid group filter"))
|
||||
return
|
||||
}
|
||||
groupID = parsedGroupID
|
||||
}
|
||||
}
|
||||
|
||||
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID)
|
||||
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID, privacyMode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -520,6 +537,10 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
if execErr != nil {
|
||||
return nil, execErr
|
||||
}
|
||||
// Antigravity OAuth: 新账号直接设置隐私
|
||||
h.adminService.ForceAntigravityPrivacy(ctx, account)
|
||||
// OpenAI OAuth: 新账号直接设置隐私
|
||||
h.adminService.ForceOpenAIPrivacy(ctx, account)
|
||||
return h.buildAccountResponseWithRuntime(ctx, account), nil
|
||||
})
|
||||
if err != nil {
|
||||
@@ -766,6 +787,8 @@ func (h *AccountHandler) refreshSingleAccount(ctx context.Context, account *serv
|
||||
if account.IsOpenAI() {
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
// 刷新失败但 access_token 可能仍有效,尝试设置隐私
|
||||
h.adminService.EnsureOpenAIPrivacy(ctx, account)
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -816,6 +839,7 @@ func (h *AccountHandler) refreshSingleAccount(ctx context.Context, account *serv
|
||||
if updateErr != nil {
|
||||
return nil, "", fmt.Errorf("failed to update credentials: %w", updateErr)
|
||||
}
|
||||
h.adminService.EnsureAntigravityPrivacy(ctx, updatedAccount)
|
||||
return updatedAccount, "missing_project_id_temporary", nil
|
||||
}
|
||||
|
||||
@@ -865,6 +889,11 @@ func (h *AccountHandler) refreshSingleAccount(ctx context.Context, account *serv
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI OAuth: 刷新成功后检查并设置 privacy_mode
|
||||
h.adminService.EnsureOpenAIPrivacy(ctx, updatedAccount)
|
||||
// Antigravity OAuth: 刷新成功后检查并设置 privacy_mode
|
||||
h.adminService.EnsureAntigravityPrivacy(ctx, updatedAccount)
|
||||
|
||||
return updatedAccount, "", nil
|
||||
}
|
||||
|
||||
@@ -1135,6 +1164,9 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
|
||||
success := 0
|
||||
failed := 0
|
||||
results := make([]gin.H, 0, len(req.Accounts))
|
||||
// 收集需要异步设置隐私的 OAuth 账号
|
||||
var antigravityPrivacyAccounts []*service.Account
|
||||
var openaiPrivacyAccounts []*service.Account
|
||||
|
||||
for _, item := range req.Accounts {
|
||||
if item.RateMultiplier != nil && *item.RateMultiplier < 0 {
|
||||
@@ -1177,6 +1209,15 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
|
||||
})
|
||||
continue
|
||||
}
|
||||
// 收集需要异步设置隐私的 OAuth 账号
|
||||
if account.Type == service.AccountTypeOAuth {
|
||||
switch account.Platform {
|
||||
case service.PlatformAntigravity:
|
||||
antigravityPrivacyAccounts = append(antigravityPrivacyAccounts, account)
|
||||
case service.PlatformOpenAI:
|
||||
openaiPrivacyAccounts = append(openaiPrivacyAccounts, account)
|
||||
}
|
||||
}
|
||||
success++
|
||||
results = append(results, gin.H{
|
||||
"name": item.Name,
|
||||
@@ -1185,6 +1226,37 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// 异步设置隐私,避免批量创建时阻塞请求
|
||||
adminSvc := h.adminService
|
||||
if len(antigravityPrivacyAccounts) > 0 {
|
||||
accounts := antigravityPrivacyAccounts
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
slog.Error("batch_create_antigravity_privacy_panic", "recover", r)
|
||||
}
|
||||
}()
|
||||
bgCtx := context.Background()
|
||||
for _, acc := range accounts {
|
||||
adminSvc.ForceAntigravityPrivacy(bgCtx, acc)
|
||||
}
|
||||
}()
|
||||
}
|
||||
if len(openaiPrivacyAccounts) > 0 {
|
||||
accounts := openaiPrivacyAccounts
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
slog.Error("batch_create_openai_privacy_panic", "recover", r)
|
||||
}
|
||||
}()
|
||||
bgCtx := context.Background()
|
||||
for _, acc := range accounts {
|
||||
adminSvc.ForceOpenAIPrivacy(bgCtx, acc)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return gin.H{
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
@@ -1493,7 +1565,7 @@ func (h *OAuthHandler) SetupTokenCookieAuth(c *gin.Context) {
|
||||
}
|
||||
|
||||
// GetUsage handles getting account usage information
|
||||
// GET /api/v1/admin/accounts/:id/usage
|
||||
// GET /api/v1/admin/accounts/:id/usage?source=passive|active
|
||||
func (h *AccountHandler) GetUsage(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
@@ -1501,7 +1573,14 @@ func (h *AccountHandler) GetUsage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
usage, err := h.accountUsageService.GetUsage(c.Request.Context(), accountID)
|
||||
source := c.DefaultQuery("source", "active")
|
||||
|
||||
var usage *service.UsageInfo
|
||||
if source == "passive" {
|
||||
usage, err = h.accountUsageService.GetPassiveUsage(c.Request.Context(), accountID)
|
||||
} else {
|
||||
usage, err = h.accountUsageService.GetUsage(c.Request.Context(), accountID)
|
||||
}
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -1715,13 +1794,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
|
||||
// Handle OpenAI accounts
|
||||
if account.IsOpenAI() {
|
||||
// For OAuth accounts: return default OpenAI models
|
||||
if account.IsOAuth() {
|
||||
// OpenAI 自动透传会绕过常规模型改写,测试/模型列表也应回落到默认模型集。
|
||||
if account.IsOpenAIPassthroughEnabled() {
|
||||
response.Success(c, openai.DefaultModels)
|
||||
return
|
||||
}
|
||||
|
||||
// For API Key accounts: check model_mapping
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) == 0 {
|
||||
response.Success(c, openai.DefaultModels)
|
||||
@@ -1797,12 +1875,6 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Handle Sora accounts
|
||||
if account.Platform == service.PlatformSora {
|
||||
response.Success(c, service.DefaultSoraModels(nil))
|
||||
return
|
||||
}
|
||||
|
||||
// Handle Claude/Anthropic accounts
|
||||
// For OAuth and Setup-Token accounts: return default models
|
||||
if account.IsOAuth() {
|
||||
@@ -1844,6 +1916,51 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
response.Success(c, models)
|
||||
}
|
||||
|
||||
// SetPrivacy handles setting privacy for a single OpenAI/Antigravity OAuth account
|
||||
// POST /api/v1/admin/accounts/:id/set-privacy
|
||||
func (h *AccountHandler) SetPrivacy(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Account not found")
|
||||
return
|
||||
}
|
||||
if account.Type != service.AccountTypeOAuth {
|
||||
response.BadRequest(c, "Only OAuth accounts support privacy setting")
|
||||
return
|
||||
}
|
||||
var mode string
|
||||
switch account.Platform {
|
||||
case service.PlatformOpenAI:
|
||||
mode = h.adminService.ForceOpenAIPrivacy(c.Request.Context(), account)
|
||||
case service.PlatformAntigravity:
|
||||
mode = h.adminService.ForceAntigravityPrivacy(c.Request.Context(), account)
|
||||
default:
|
||||
response.BadRequest(c, "Only OpenAI and Antigravity OAuth accounts support privacy setting")
|
||||
return
|
||||
}
|
||||
if mode == "" {
|
||||
response.BadRequest(c, "Cannot set privacy: missing access_token")
|
||||
return
|
||||
}
|
||||
// 从 DB 重新读取以确保返回最新状态
|
||||
updated, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
// 隐私已设置成功但读取失败,回退到内存更新
|
||||
if account.Extra == nil {
|
||||
account.Extra = make(map[string]any)
|
||||
}
|
||||
account.Extra["privacy_mode"] = mode
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
return
|
||||
}
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updated))
|
||||
}
|
||||
|
||||
// RefreshTier handles refreshing Google One tier for a single account
|
||||
// POST /api/v1/admin/accounts/:id/refresh-tier
|
||||
func (h *AccountHandler) RefreshTier(c *gin.Context) {
|
||||
@@ -1912,7 +2029,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
|
||||
accounts := make([]*service.Account, 0)
|
||||
|
||||
if len(req.AccountIDs) == 0 {
|
||||
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0)
|
||||
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0, "")
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type availableModelsAdminService struct {
|
||||
*stubAdminService
|
||||
account service.Account
|
||||
}
|
||||
|
||||
func (s *availableModelsAdminService) GetAccount(_ context.Context, id int64) (*service.Account, error) {
|
||||
if s.account.ID == id {
|
||||
acc := s.account
|
||||
return &acc, nil
|
||||
}
|
||||
return s.stubAdminService.GetAccount(context.Background(), id)
|
||||
}
|
||||
|
||||
func setupAvailableModelsRouter(adminSvc service.AdminService) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
router.GET("/api/v1/admin/accounts/:id/models", handler.GetAvailableModels)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_OpenAIOAuthUsesExplicitModelMapping(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 42,
|
||||
Name: "openai-oauth",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.1",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/42/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Len(t, resp.Data, 1)
|
||||
require.Equal(t, "gpt-5", resp.Data[0].ID)
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefaults(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 43,
|
||||
Name: "openai-oauth-passthrough",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.1",
|
||||
},
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"openai_passthrough": true,
|
||||
},
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/43/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.NotEmpty(t, resp.Data)
|
||||
require.NotEqual(t, "gpt-5", resp.Data[0].ID)
|
||||
}
|
||||
@@ -17,7 +17,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
|
||||
adminSvc := newStubAdminService()
|
||||
|
||||
userHandler := NewUserHandler(adminSvc, nil)
|
||||
groupHandler := NewGroupHandler(adminSvc)
|
||||
groupHandler := NewGroupHandler(adminSvc, nil, nil)
|
||||
proxyHandler := NewProxyHandler(adminSvc)
|
||||
redeemHandler := NewRedeemHandler(adminSvc, nil)
|
||||
|
||||
|
||||
@@ -175,7 +175,19 @@ func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, p
|
||||
return s.apiKeys, int64(len(s.apiKeys)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
|
||||
// GetGroupRateMultipliers is a stub: it reports no per-group rate multipliers.
func (s *stubAdminService) GetGroupRateMultipliers(_ context.Context, _ int64) ([]service.UserGroupRateEntry, error) {
	return nil, nil
}

// ClearGroupRateMultipliers is a stub: it always succeeds without side effects.
func (s *stubAdminService) ClearGroupRateMultipliers(_ context.Context, _ int64) error {
	return nil
}

// BatchSetGroupRateMultipliers is a stub: it always succeeds without side effects.
func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int64, _ []service.GroupRateMultiplierInput) error {
	return nil
}

// ListAccounts ignores every filter argument and returns the stub's canned
// account list together with its length as the total count.
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, int64, error) {
	return s.accounts, int64(len(s.accounts)), nil
}
|
||||
|
||||
@@ -368,7 +380,6 @@ func (s *stubAdminService) CheckProxyQuality(ctx context.Context, id int64) (*se
|
||||
{Target: "openai", Status: "pass", HTTPStatus: 401},
|
||||
{Target: "anthropic", Status: "pass", HTTPStatus: 401},
|
||||
{Target: "gemini", Status: "pass", HTTPStatus: 200},
|
||||
{Target: "sora", Status: "pass", HTTPStatus: 401},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
@@ -429,5 +440,25 @@ func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnsureOpenAIPrivacy is a stub: it reports that no privacy mode was applied.
func (s *stubAdminService) EnsureOpenAIPrivacy(ctx context.Context, account *service.Account) string {
	return ""
}

// EnsureAntigravityPrivacy is a stub: it reports that no privacy mode was applied.
func (s *stubAdminService) EnsureAntigravityPrivacy(ctx context.Context, account *service.Account) string {
	return ""
}

// ForceOpenAIPrivacy is a stub: it reports that no privacy mode was applied.
func (s *stubAdminService) ForceOpenAIPrivacy(ctx context.Context, account *service.Account) string {
	return ""
}

// ForceAntigravityPrivacy is a stub: it reports that no privacy mode was applied.
func (s *stubAdminService) ForceAntigravityPrivacy(ctx context.Context, account *service.Account) string {
	return ""
}

// ReplaceUserGroup is a stub: it reports a successful migration of zero keys.
func (s *stubAdminService) ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*service.ReplaceUserGroupResult, error) {
	return &service.ReplaceUserGroupResult{MigratedKeys: 0}, nil
}

// Ensure stub implements interface.
var _ service.AdminService = (*stubAdminService)(nil)
|
||||
|
||||
205
backend/internal/handler/admin/backup_handler.go
Normal file
205
backend/internal/handler/admin/backup_handler.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// BackupHandler exposes admin endpoints for backup configuration (S3,
// schedule), backup creation/inspection/deletion, and restore.
type BackupHandler struct {
	backupService *service.BackupService
	// userService is used to re-verify the admin password before a restore.
	userService *service.UserService
}

// NewBackupHandler constructs a BackupHandler with its service dependencies.
func NewBackupHandler(backupService *service.BackupService, userService *service.UserService) *BackupHandler {
	return &BackupHandler{
		backupService: backupService,
		userService:   userService,
	}
}
||||
|
||||
// ─── S3 配置 ───
|
||||
|
||||
func (h *BackupHandler) GetS3Config(c *gin.Context) {
|
||||
cfg, err := h.backupService.GetS3Config(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) UpdateS3Config(c *gin.Context) {
|
||||
var req service.BackupS3Config
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
cfg, err := h.backupService.UpdateS3Config(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) TestS3Connection(c *gin.Context) {
|
||||
var req service.BackupS3Config
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
err := h.backupService.TestS3Connection(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.Success(c, gin.H{"ok": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"ok": true, "message": "connection successful"})
|
||||
}
|
||||
|
||||
// ─── 定时备份 ───
|
||||
|
||||
func (h *BackupHandler) GetSchedule(c *gin.Context) {
|
||||
cfg, err := h.backupService.GetSchedule(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) UpdateSchedule(c *gin.Context) {
|
||||
var req service.BackupScheduleConfig
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
cfg, err := h.backupService.UpdateSchedule(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
// ─── 备份操作 ───
|
||||
|
||||
type CreateBackupRequest struct {
|
||||
ExpireDays *int `json:"expire_days"` // nil=使用默认值14,0=永不过期
|
||||
}
|
||||
|
||||
func (h *BackupHandler) CreateBackup(c *gin.Context) {
|
||||
var req CreateBackupRequest
|
||||
_ = c.ShouldBindJSON(&req) // 允许空 body
|
||||
|
||||
expireDays := 14 // 默认14天过期
|
||||
if req.ExpireDays != nil {
|
||||
expireDays = *req.ExpireDays
|
||||
}
|
||||
|
||||
record, err := h.backupService.StartBackup(c.Request.Context(), "manual", expireDays)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Accepted(c, record)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) ListBackups(c *gin.Context) {
|
||||
records, err := h.backupService.ListBackups(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if records == nil {
|
||||
records = []service.BackupRecord{}
|
||||
}
|
||||
response.Success(c, gin.H{"items": records})
|
||||
}
|
||||
|
||||
func (h *BackupHandler) GetBackup(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
record, err := h.backupService.GetBackupRecord(c.Request.Context(), backupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, record)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) DeleteBackup(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
if err := h.backupService.DeleteBackup(c.Request.Context(), backupID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
func (h *BackupHandler) GetDownloadURL(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
url, err := h.backupService.GetBackupDownloadURL(c.Request.Context(), backupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"url": url})
|
||||
}
|
||||
|
||||
// ─── 恢复操作(需要重新输入管理员密码) ───
|
||||
|
||||
type RestoreBackupRequest struct {
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
func (h *BackupHandler) RestoreBackup(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
var req RestoreBackupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "password is required for restore operation")
|
||||
return
|
||||
}
|
||||
|
||||
// 从上下文获取当前管理员用户 ID
|
||||
sub, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取管理员用户并验证密码
|
||||
user, err := h.userService.GetByID(c.Request.Context(), sub.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if !user.CheckPassword(req.Password) {
|
||||
response.BadRequest(c, "incorrect admin password")
|
||||
return
|
||||
}
|
||||
|
||||
record, err := h.backupService.StartRestore(c.Request.Context(), backupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Accepted(c, record)
|
||||
}
|
||||
452
backend/internal/handler/admin/channel_handler.go
Normal file
452
backend/internal/handler/admin/channel_handler.go
Normal file
@@ -0,0 +1,452 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ChannelHandler handles admin channel management
type ChannelHandler struct {
	channelService *service.ChannelService
	// billingService is used to look up default model pricing for the UI.
	billingService *service.BillingService
}

// NewChannelHandler creates a new admin channel handler
func NewChannelHandler(channelService *service.ChannelService, billingService *service.BillingService) *ChannelHandler {
	return &ChannelHandler{channelService: channelService, billingService: billingService}
}
|
||||
|
||||
// --- Request / Response types ---

// createChannelRequest is the payload for creating a channel.
type createChannelRequest struct {
	Name               string                       `json:"name" binding:"required,max=100"`
	Description        string                       `json:"description"`
	GroupIDs           []int64                      `json:"group_ids"`
	ModelPricing       []channelModelPricingRequest `json:"model_pricing"`
	ModelMapping       map[string]map[string]string `json:"model_mapping"`
	BillingModelSource string                       `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
	RestrictModels     bool                         `json:"restrict_models"`
}

// updateChannelRequest is the payload for updating a channel. Pointer fields
// let the handler distinguish "omitted" (nil) from an explicit new value.
type updateChannelRequest struct {
	Name               string                        `json:"name" binding:"omitempty,max=100"`
	Description        *string                       `json:"description"`
	Status             string                        `json:"status" binding:"omitempty,oneof=active disabled"`
	GroupIDs           *[]int64                      `json:"group_ids"`
	ModelPricing       *[]channelModelPricingRequest `json:"model_pricing"`
	ModelMapping       map[string]map[string]string  `json:"model_mapping"`
	BillingModelSource string                        `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
	RestrictModels     *bool                         `json:"restrict_models"`
}

// channelModelPricingRequest describes pricing for a set of models on one
// platform. Nil price pointers mean "not configured".
type channelModelPricingRequest struct {
	Platform         string                   `json:"platform" binding:"omitempty,max=50"`
	Models           []string                 `json:"models" binding:"required,min=1,max=100"`
	BillingMode      string                   `json:"billing_mode" binding:"omitempty,oneof=token per_request image"`
	InputPrice       *float64                 `json:"input_price" binding:"omitempty,min=0"`
	OutputPrice      *float64                 `json:"output_price" binding:"omitempty,min=0"`
	CacheWritePrice  *float64                 `json:"cache_write_price" binding:"omitempty,min=0"`
	CacheReadPrice   *float64                 `json:"cache_read_price" binding:"omitempty,min=0"`
	ImageOutputPrice *float64                 `json:"image_output_price" binding:"omitempty,min=0"`
	PerRequestPrice  *float64                 `json:"per_request_price" binding:"omitempty,min=0"`
	Intervals        []pricingIntervalRequest `json:"intervals"`
}

// pricingIntervalRequest describes one token-range pricing tier.
// MaxTokens == nil means the interval is open-ended.
type pricingIntervalRequest struct {
	MinTokens       int      `json:"min_tokens"`
	MaxTokens       *int     `json:"max_tokens"`
	TierLabel       string   `json:"tier_label"`
	InputPrice      *float64 `json:"input_price"`
	OutputPrice     *float64 `json:"output_price"`
	CacheWritePrice *float64 `json:"cache_write_price"`
	CacheReadPrice  *float64 `json:"cache_read_price"`
	PerRequestPrice *float64 `json:"per_request_price"`
	SortOrder       int      `json:"sort_order"`
}

// channelResponse is the API representation of a channel.
type channelResponse struct {
	ID                 int64                         `json:"id"`
	Name               string                        `json:"name"`
	Description        string                        `json:"description"`
	Status             string                        `json:"status"`
	BillingModelSource string                        `json:"billing_model_source"`
	RestrictModels     bool                          `json:"restrict_models"`
	GroupIDs           []int64                       `json:"group_ids"`
	ModelPricing       []channelModelPricingResponse `json:"model_pricing"`
	ModelMapping       map[string]map[string]string  `json:"model_mapping"`
	CreatedAt          string                        `json:"created_at"`
	UpdatedAt          string                        `json:"updated_at"`
}

// channelModelPricingResponse mirrors channelModelPricingRequest with the
// persisted row ID included.
type channelModelPricingResponse struct {
	ID               int64                     `json:"id"`
	Platform         string                    `json:"platform"`
	Models           []string                  `json:"models"`
	BillingMode      string                    `json:"billing_mode"`
	InputPrice       *float64                  `json:"input_price"`
	OutputPrice      *float64                  `json:"output_price"`
	CacheWritePrice  *float64                  `json:"cache_write_price"`
	CacheReadPrice   *float64                  `json:"cache_read_price"`
	ImageOutputPrice *float64                  `json:"image_output_price"`
	PerRequestPrice  *float64                  `json:"per_request_price"`
	Intervals        []pricingIntervalResponse `json:"intervals"`
}

// pricingIntervalResponse mirrors pricingIntervalRequest with the persisted
// row ID included.
type pricingIntervalResponse struct {
	ID              int64    `json:"id"`
	MinTokens       int      `json:"min_tokens"`
	MaxTokens       *int     `json:"max_tokens"`
	TierLabel       string   `json:"tier_label,omitempty"`
	InputPrice      *float64 `json:"input_price"`
	OutputPrice     *float64 `json:"output_price"`
	CacheWritePrice *float64 `json:"cache_write_price"`
	CacheReadPrice  *float64 `json:"cache_read_price"`
	PerRequestPrice *float64 `json:"per_request_price"`
	SortOrder       int      `json:"sort_order"`
}
|
||||
|
||||
func channelToResponse(ch *service.Channel) *channelResponse {
|
||||
if ch == nil {
|
||||
return nil
|
||||
}
|
||||
resp := &channelResponse{
|
||||
ID: ch.ID,
|
||||
Name: ch.Name,
|
||||
Description: ch.Description,
|
||||
Status: ch.Status,
|
||||
RestrictModels: ch.RestrictModels,
|
||||
GroupIDs: ch.GroupIDs,
|
||||
ModelMapping: ch.ModelMapping,
|
||||
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
||||
UpdatedAt: ch.UpdatedAt.Format("2006-01-02T15:04:05Z"),
|
||||
}
|
||||
resp.BillingModelSource = ch.BillingModelSource
|
||||
if resp.BillingModelSource == "" {
|
||||
resp.BillingModelSource = service.BillingModelSourceChannelMapped
|
||||
}
|
||||
if resp.GroupIDs == nil {
|
||||
resp.GroupIDs = []int64{}
|
||||
}
|
||||
if resp.ModelMapping == nil {
|
||||
resp.ModelMapping = map[string]map[string]string{}
|
||||
}
|
||||
|
||||
resp.ModelPricing = make([]channelModelPricingResponse, 0, len(ch.ModelPricing))
|
||||
for _, p := range ch.ModelPricing {
|
||||
resp.ModelPricing = append(resp.ModelPricing, pricingToResponse(&p))
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func pricingToResponse(p *service.ChannelModelPricing) channelModelPricingResponse {
|
||||
models := p.Models
|
||||
if models == nil {
|
||||
models = []string{}
|
||||
}
|
||||
billingMode := string(p.BillingMode)
|
||||
if billingMode == "" {
|
||||
billingMode = string(service.BillingModeToken)
|
||||
}
|
||||
platform := p.Platform
|
||||
if platform == "" {
|
||||
platform = service.PlatformAnthropic
|
||||
}
|
||||
intervals := make([]pricingIntervalResponse, 0, len(p.Intervals))
|
||||
for _, iv := range p.Intervals {
|
||||
intervals = append(intervals, intervalToResponse(iv))
|
||||
}
|
||||
return channelModelPricingResponse{
|
||||
ID: p.ID,
|
||||
Platform: platform,
|
||||
Models: models,
|
||||
BillingMode: billingMode,
|
||||
InputPrice: p.InputPrice,
|
||||
OutputPrice: p.OutputPrice,
|
||||
CacheWritePrice: p.CacheWritePrice,
|
||||
CacheReadPrice: p.CacheReadPrice,
|
||||
ImageOutputPrice: p.ImageOutputPrice,
|
||||
PerRequestPrice: p.PerRequestPrice,
|
||||
Intervals: intervals,
|
||||
}
|
||||
}
|
||||
|
||||
// intervalToResponse maps a service pricing interval to its API shape 1:1;
// no defaults are applied here.
func intervalToResponse(iv service.PricingInterval) pricingIntervalResponse {
	return pricingIntervalResponse{
		ID:              iv.ID,
		MinTokens:       iv.MinTokens,
		MaxTokens:       iv.MaxTokens,
		TierLabel:       iv.TierLabel,
		InputPrice:      iv.InputPrice,
		OutputPrice:     iv.OutputPrice,
		CacheWritePrice: iv.CacheWritePrice,
		CacheReadPrice:  iv.CacheReadPrice,
		PerRequestPrice: iv.PerRequestPrice,
		SortOrder:       iv.SortOrder,
	}
}
|
||||
|
||||
func pricingRequestToService(reqs []channelModelPricingRequest) []service.ChannelModelPricing {
|
||||
result := make([]service.ChannelModelPricing, 0, len(reqs))
|
||||
for _, r := range reqs {
|
||||
billingMode := service.BillingMode(r.BillingMode)
|
||||
if billingMode == "" {
|
||||
billingMode = service.BillingModeToken
|
||||
}
|
||||
platform := r.Platform
|
||||
if platform == "" {
|
||||
platform = service.PlatformAnthropic
|
||||
}
|
||||
intervals := make([]service.PricingInterval, 0, len(r.Intervals))
|
||||
for _, iv := range r.Intervals {
|
||||
intervals = append(intervals, service.PricingInterval{
|
||||
MinTokens: iv.MinTokens,
|
||||
MaxTokens: iv.MaxTokens,
|
||||
TierLabel: iv.TierLabel,
|
||||
InputPrice: iv.InputPrice,
|
||||
OutputPrice: iv.OutputPrice,
|
||||
CacheWritePrice: iv.CacheWritePrice,
|
||||
CacheReadPrice: iv.CacheReadPrice,
|
||||
PerRequestPrice: iv.PerRequestPrice,
|
||||
SortOrder: iv.SortOrder,
|
||||
})
|
||||
}
|
||||
result = append(result, service.ChannelModelPricing{
|
||||
Platform: platform,
|
||||
Models: r.Models,
|
||||
BillingMode: billingMode,
|
||||
InputPrice: r.InputPrice,
|
||||
OutputPrice: r.OutputPrice,
|
||||
CacheWritePrice: r.CacheWritePrice,
|
||||
CacheReadPrice: r.CacheReadPrice,
|
||||
ImageOutputPrice: r.ImageOutputPrice,
|
||||
PerRequestPrice: r.PerRequestPrice,
|
||||
Intervals: intervals,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// validatePricingBillingMode 校验计费配置
|
||||
func validatePricingBillingMode(pricing []service.ChannelModelPricing) error {
|
||||
for _, p := range pricing {
|
||||
// 按次/图片模式必须配置默认价格或区间
|
||||
if p.BillingMode == service.BillingModePerRequest || p.BillingMode == service.BillingModeImage {
|
||||
if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
|
||||
return errors.New("per-request price or intervals required for per_request/image billing mode")
|
||||
}
|
||||
}
|
||||
// 校验价格不能为负
|
||||
if err := validatePriceNotNegative("input_price", p.InputPrice); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validatePriceNotNegative("output_price", p.OutputPrice); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validatePriceNotNegative("cache_write_price", p.CacheWritePrice); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validatePriceNotNegative("cache_read_price", p.CacheReadPrice); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validatePriceNotNegative("image_output_price", p.ImageOutputPrice); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validatePriceNotNegative("per_request_price", p.PerRequestPrice); err != nil {
|
||||
return err
|
||||
}
|
||||
// 校验 interval:至少有一个价格字段非空
|
||||
for _, iv := range p.Intervals {
|
||||
if iv.InputPrice == nil && iv.OutputPrice == nil &&
|
||||
iv.CacheWritePrice == nil && iv.CacheReadPrice == nil &&
|
||||
iv.PerRequestPrice == nil {
|
||||
return fmt.Errorf("interval [%d, %s] has no price fields set for model %v",
|
||||
iv.MinTokens, formatMaxTokens(iv.MaxTokens), p.Models)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validatePriceNotNegative reports an error naming the field when a
// configured price is negative; nil (unset) prices are accepted.
func validatePriceNotNegative(field string, val *float64) error {
	if val == nil || *val >= 0 {
		return nil
	}
	return fmt.Errorf("%s must be >= 0", field)
}
|
||||
|
||||
// formatMaxTokens renders an interval's upper bound for error messages,
// using "∞" for an open-ended (nil) bound.
//
// The parameter is named maxTokens rather than max to avoid shadowing the
// Go 1.21 `max` builtin, and strconv.Itoa replaces fmt.Sprintf("%d", ...)
// as the direct integer-to-string conversion.
func formatMaxTokens(maxTokens *int) string {
	if maxTokens == nil {
		return "∞"
	}
	return strconv.Itoa(*maxTokens)
}
|
||||
|
||||
// --- Handlers ---
|
||||
|
||||
// List handles listing channels with pagination
|
||||
// GET /api/v1/admin/channels
|
||||
func (h *ChannelHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
status := c.Query("status")
|
||||
search := strings.TrimSpace(c.Query("search"))
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
|
||||
channels, pag, err := h.channelService.List(c.Request.Context(), pagination.PaginationParams{Page: page, PageSize: pageSize}, status, search)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]*channelResponse, 0, len(channels))
|
||||
for i := range channels {
|
||||
out = append(out, channelToResponse(&channels[i]))
|
||||
}
|
||||
response.Paginated(c, out, pag.Total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetByID handles getting a channel by ID
|
||||
// GET /api/v1/admin/channels/:id
|
||||
func (h *ChannelHandler) GetByID(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_CHANNEL_ID", "Invalid channel ID"))
|
||||
return
|
||||
}
|
||||
|
||||
channel, err := h.channelService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, channelToResponse(channel))
|
||||
}
|
||||
|
||||
// Create handles creating a new channel
|
||||
// POST /api/v1/admin/channels
|
||||
func (h *ChannelHandler) Create(c *gin.Context) {
|
||||
var req createChannelRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
pricing := pricingRequestToService(req.ModelPricing)
|
||||
if err := validatePricingBillingMode(pricing); err != nil {
|
||||
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ModelPricing: pricing,
|
||||
ModelMapping: req.ModelMapping,
|
||||
BillingModelSource: req.BillingModelSource,
|
||||
RestrictModels: req.RestrictModels,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, channelToResponse(channel))
|
||||
}
|
||||
|
||||
// Update handles updating a channel
|
||||
// PUT /api/v1/admin/channels/:id
|
||||
func (h *ChannelHandler) Update(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_CHANNEL_ID", "Invalid channel ID"))
|
||||
return
|
||||
}
|
||||
|
||||
var req updateChannelRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
input := &service.UpdateChannelInput{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Status: req.Status,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ModelMapping: req.ModelMapping,
|
||||
BillingModelSource: req.BillingModelSource,
|
||||
RestrictModels: req.RestrictModels,
|
||||
}
|
||||
if req.ModelPricing != nil {
|
||||
pricing := pricingRequestToService(*req.ModelPricing)
|
||||
if err := validatePricingBillingMode(pricing); err != nil {
|
||||
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
|
||||
return
|
||||
}
|
||||
input.ModelPricing = &pricing
|
||||
}
|
||||
|
||||
channel, err := h.channelService.Update(c.Request.Context(), id, input)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, channelToResponse(channel))
|
||||
}
|
||||
|
||||
// Delete handles deleting a channel
|
||||
// DELETE /api/v1/admin/channels/:id
|
||||
func (h *ChannelHandler) Delete(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_CHANNEL_ID", "Invalid channel ID"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.channelService.Delete(c.Request.Context(), id); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Channel deleted successfully"})
|
||||
}
|
||||
|
||||
// GetModelDefaultPricing 获取模型的默认定价(用于前端自动填充)
|
||||
// GET /api/v1/admin/channels/model-pricing?model=claude-sonnet-4
|
||||
func (h *ChannelHandler) GetModelDefaultPricing(c *gin.Context) {
	// The model query parameter is required; surrounding whitespace is ignored.
	model := strings.TrimSpace(c.Query("model"))
	if model == "" {
		response.ErrorFrom(c, infraerrors.BadRequest("MISSING_PARAMETER", "model parameter is required").
			WithMetadata(map[string]string{"param": "model"}))
		return
	}

	pricing, err := h.billingService.GetModelPricing(model)
	if err != nil {
		// Model is not in the pricing table: this is an expected condition,
		// reported as found=false with a 200 rather than as an error.
		response.Success(c, gin.H{"found": false})
		return
	}

	response.Success(c, gin.H{
		"found":             true,
		"input_price":       pricing.InputPricePerToken,
		"output_price":      pricing.OutputPricePerToken,
		"cache_write_price": pricing.CacheCreationPricePerToken,
		"cache_read_price":  pricing.CacheReadPricePerToken,
		"image_output_price": pricing.ImageOutputPricePerToken,
	})
}
|
||||
502
backend/internal/handler/admin/channel_handler_test.go
Normal file
502
backend/internal/handler/admin/channel_handler_test.go
Normal file
@@ -0,0 +1,502 @@
|
||||
//go:build unit
|
||||
|
||||
package admin
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// float64Ptr returns a pointer to a copy of v, for optional price fields.
func float64Ptr(v float64) *float64 {
	p := v
	return &p
}

// intPtr returns a pointer to a copy of v, for optional integer fields.
func intPtr(v int) *int {
	p := v
	return &p
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 1. channelToResponse
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestChannelToResponse_NilInput(t *testing.T) {
|
||||
require.Nil(t, channelToResponse(nil))
|
||||
}
|
||||
|
||||
// TestChannelToResponse_FullChannel verifies that every field of a fully
// populated service.Channel is mapped into the response DTO: scalar fields,
// RFC3339-formatted timestamps, group IDs, model mapping, and per-model pricing.
func TestChannelToResponse_FullChannel(t *testing.T) {
	now := time.Date(2025, 6, 1, 12, 0, 0, 0, time.UTC)
	ch := &service.Channel{
		ID:                 42,
		Name:               "test-channel",
		Description:        "desc",
		Status:             "active",
		BillingModelSource: "upstream",
		RestrictModels:     true,
		CreatedAt:          now,
		UpdatedAt:          now.Add(time.Hour),
		GroupIDs:           []int64{1, 2, 3},
		ModelPricing: []service.ChannelModelPricing{
			{
				ID:              10,
				Platform:        "openai",
				Models:          []string{"gpt-4"},
				BillingMode:     service.BillingModeToken,
				InputPrice:      float64Ptr(0.01),
				OutputPrice:     float64Ptr(0.03),
				CacheWritePrice: float64Ptr(0.005),
				CacheReadPrice:  float64Ptr(0.002),
				PerRequestPrice: float64Ptr(0.5),
			},
		},
		ModelMapping: map[string]map[string]string{
			"anthropic": {"claude-3-haiku": "claude-haiku-3"},
		},
	}

	resp := channelToResponse(ch)
	require.NotNil(t, resp)
	require.Equal(t, int64(42), resp.ID)
	require.Equal(t, "test-channel", resp.Name)
	require.Equal(t, "desc", resp.Description)
	require.Equal(t, "active", resp.Status)
	require.Equal(t, "upstream", resp.BillingModelSource)
	require.True(t, resp.RestrictModels)
	require.Equal(t, []int64{1, 2, 3}, resp.GroupIDs)
	// Timestamps are serialized as RFC3339 UTC strings.
	require.Equal(t, "2025-06-01T12:00:00Z", resp.CreatedAt)
	require.Equal(t, "2025-06-01T13:00:00Z", resp.UpdatedAt)

	// model mapping
	require.Len(t, resp.ModelMapping, 1)
	require.Equal(t, "claude-haiku-3", resp.ModelMapping["anthropic"]["claude-3-haiku"])

	// pricing: enum BillingMode is rendered as its string form.
	require.Len(t, resp.ModelPricing, 1)
	p := resp.ModelPricing[0]
	require.Equal(t, int64(10), p.ID)
	require.Equal(t, "openai", p.Platform)
	require.Equal(t, []string{"gpt-4"}, p.Models)
	require.Equal(t, "token", p.BillingMode)
	require.Equal(t, float64Ptr(0.01), p.InputPrice)
	require.Equal(t, float64Ptr(0.03), p.OutputPrice)
	require.Equal(t, float64Ptr(0.005), p.CacheWritePrice)
	require.Equal(t, float64Ptr(0.002), p.CacheReadPrice)
	require.Equal(t, float64Ptr(0.5), p.PerRequestPrice)
	require.Empty(t, p.Intervals)
}
|
||||
|
||||
// TestChannelToResponse_EmptyDefaults verifies the converter's defaulting
// behavior: an empty BillingModelSource becomes "channel_mapped", nil
// slices/maps become empty non-nil values, and empty pricing Platform /
// BillingMode fall back to "anthropic" / "token".
func TestChannelToResponse_EmptyDefaults(t *testing.T) {
	now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
	ch := &service.Channel{
		ID:                 1,
		Name:               "ch",
		BillingModelSource: "",
		CreatedAt:          now,
		UpdatedAt:          now,
		GroupIDs:           nil,
		ModelMapping:       nil,
		ModelPricing: []service.ChannelModelPricing{
			{
				Platform:    "",
				BillingMode: "",
				Models:      []string{"m1"},
			},
		},
	}

	resp := channelToResponse(ch)
	require.Equal(t, "channel_mapped", resp.BillingModelSource)
	// nil inputs must come back as empty, JSON-friendly collections.
	require.NotNil(t, resp.GroupIDs)
	require.Empty(t, resp.GroupIDs)
	require.NotNil(t, resp.ModelMapping)
	require.Empty(t, resp.ModelMapping)

	require.Len(t, resp.ModelPricing, 1)
	require.Equal(t, "anthropic", resp.ModelPricing[0].Platform)
	require.Equal(t, "token", resp.ModelPricing[0].BillingMode)
}
|
||||
|
||||
func TestChannelToResponse_NilModels(t *testing.T) {
|
||||
now := time.Now()
|
||||
ch := &service.Channel{
|
||||
ID: 1,
|
||||
Name: "ch",
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
ModelPricing: []service.ChannelModelPricing{
|
||||
{
|
||||
Models: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp := channelToResponse(ch)
|
||||
require.Len(t, resp.ModelPricing, 1)
|
||||
require.NotNil(t, resp.ModelPricing[0].Models)
|
||||
require.Empty(t, resp.ModelPricing[0].Models)
|
||||
}
|
||||
|
||||
// TestChannelToResponse_WithIntervals verifies that pricing intervals are
// copied through field-by-field, including an open-ended tier (nil MaxTokens)
// and that interval ordering is preserved.
func TestChannelToResponse_WithIntervals(t *testing.T) {
	now := time.Now()
	ch := &service.Channel{
		ID:        1,
		Name:      "ch",
		CreatedAt: now,
		UpdatedAt: now,
		ModelPricing: []service.ChannelModelPricing{
			{
				Models:      []string{"m1"},
				BillingMode: service.BillingModePerRequest,
				Intervals: []service.PricingInterval{
					{
						ID:              100,
						MinTokens:       0,
						MaxTokens:       intPtr(1000),
						TierLabel:       "1K",
						InputPrice:      float64Ptr(0.01),
						OutputPrice:     float64Ptr(0.02),
						CacheWritePrice: float64Ptr(0.003),
						CacheReadPrice:  float64Ptr(0.001),
						PerRequestPrice: float64Ptr(0.1),
						SortOrder:       1,
					},
					{
						// Open-ended top tier: nil MaxTokens means "no upper bound".
						ID:        101,
						MinTokens: 1000,
						MaxTokens: nil,
						TierLabel: "unlimited",
						SortOrder: 2,
					},
				},
			},
		},
	}

	resp := channelToResponse(ch)
	require.Len(t, resp.ModelPricing, 1)
	intervals := resp.ModelPricing[0].Intervals
	require.Len(t, intervals, 2)

	iv0 := intervals[0]
	require.Equal(t, int64(100), iv0.ID)
	require.Equal(t, 0, iv0.MinTokens)
	require.Equal(t, intPtr(1000), iv0.MaxTokens)
	require.Equal(t, "1K", iv0.TierLabel)
	require.Equal(t, float64Ptr(0.01), iv0.InputPrice)
	require.Equal(t, float64Ptr(0.02), iv0.OutputPrice)
	require.Equal(t, float64Ptr(0.003), iv0.CacheWritePrice)
	require.Equal(t, float64Ptr(0.001), iv0.CacheReadPrice)
	require.Equal(t, float64Ptr(0.1), iv0.PerRequestPrice)
	require.Equal(t, 1, iv0.SortOrder)

	iv1 := intervals[1]
	require.Equal(t, int64(101), iv1.ID)
	require.Equal(t, 1000, iv1.MinTokens)
	require.Nil(t, iv1.MaxTokens)
	require.Equal(t, "unlimited", iv1.TierLabel)
	require.Equal(t, 2, iv1.SortOrder)
}
|
||||
|
||||
// TestChannelToResponse_MultipleEntries verifies that multiple pricing
// entries across different platforms and billing modes are all converted,
// in order, with each BillingMode rendered as its string form.
func TestChannelToResponse_MultipleEntries(t *testing.T) {
	now := time.Now()
	ch := &service.Channel{
		ID:        1,
		Name:      "multi",
		CreatedAt: now,
		UpdatedAt: now,
		ModelPricing: []service.ChannelModelPricing{
			{
				ID:          1,
				Platform:    "anthropic",
				Models:      []string{"claude-sonnet-4"},
				BillingMode: service.BillingModeToken,
				InputPrice:  float64Ptr(0.003),
				OutputPrice: float64Ptr(0.015),
			},
			{
				ID:              2,
				Platform:        "openai",
				Models:          []string{"gpt-4", "gpt-4o"},
				BillingMode:     service.BillingModePerRequest,
				PerRequestPrice: float64Ptr(1.0),
			},
			{
				ID:               3,
				Platform:         "gemini",
				Models:           []string{"gemini-2.5-pro"},
				BillingMode:      service.BillingModeImage,
				ImageOutputPrice: float64Ptr(0.05),
				PerRequestPrice:  float64Ptr(0.2),
			},
		},
	}

	resp := channelToResponse(ch)
	require.Len(t, resp.ModelPricing, 3)

	require.Equal(t, int64(1), resp.ModelPricing[0].ID)
	require.Equal(t, "anthropic", resp.ModelPricing[0].Platform)
	require.Equal(t, []string{"claude-sonnet-4"}, resp.ModelPricing[0].Models)
	require.Equal(t, "token", resp.ModelPricing[0].BillingMode)

	require.Equal(t, int64(2), resp.ModelPricing[1].ID)
	require.Equal(t, "openai", resp.ModelPricing[1].Platform)
	require.Equal(t, []string{"gpt-4", "gpt-4o"}, resp.ModelPricing[1].Models)
	require.Equal(t, "per_request", resp.ModelPricing[1].BillingMode)

	require.Equal(t, int64(3), resp.ModelPricing[2].ID)
	require.Equal(t, "gemini", resp.ModelPricing[2].Platform)
	require.Equal(t, []string{"gemini-2.5-pro"}, resp.ModelPricing[2].Models)
	require.Equal(t, "image", resp.ModelPricing[2].BillingMode)
	require.Equal(t, float64Ptr(0.05), resp.ModelPricing[2].ImageOutputPrice)
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 2. pricingRequestToService
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TestPricingRequestToService_Defaults is a table-driven check of the
// converter's defaulting: an empty BillingMode falls back to token billing
// and an empty Platform falls back to "anthropic".
func TestPricingRequestToService_Defaults(t *testing.T) {
	tests := []struct {
		name      string
		req       channelModelPricingRequest
		wantField string // which default field to check
		wantValue string
	}{
		{
			name: "empty billing mode defaults to token",
			req: channelModelPricingRequest{
				Models:      []string{"m1"},
				BillingMode: "",
			},
			wantField: "BillingMode",
			wantValue: string(service.BillingModeToken),
		},
		{
			name: "empty platform defaults to anthropic",
			req: channelModelPricingRequest{
				Models:   []string{"m1"},
				Platform: "",
			},
			wantField: "Platform",
			wantValue: "anthropic",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := pricingRequestToService([]channelModelPricingRequest{tt.req})
			require.Len(t, result, 1)
			// Dispatch on the field under test since the two defaults have
			// different types (BillingMode enum vs plain string).
			switch tt.wantField {
			case "BillingMode":
				require.Equal(t, service.BillingMode(tt.wantValue), result[0].BillingMode)
			case "Platform":
				require.Equal(t, tt.wantValue, result[0].Platform)
			}
		})
	}
}
|
||||
|
||||
// TestPricingRequestToService_WithAllFields verifies every request field —
// platform, models, billing mode, and all six optional price pointers —
// survives conversion to the service representation.
func TestPricingRequestToService_WithAllFields(t *testing.T) {
	reqs := []channelModelPricingRequest{
		{
			Platform:         "openai",
			Models:           []string{"gpt-4", "gpt-4o"},
			BillingMode:      "per_request",
			InputPrice:       float64Ptr(0.01),
			OutputPrice:      float64Ptr(0.03),
			CacheWritePrice:  float64Ptr(0.005),
			CacheReadPrice:   float64Ptr(0.002),
			ImageOutputPrice: float64Ptr(0.04),
			PerRequestPrice:  float64Ptr(0.5),
		},
	}

	result := pricingRequestToService(reqs)
	require.Len(t, result, 1)
	r := result[0]
	require.Equal(t, "openai", r.Platform)
	require.Equal(t, []string{"gpt-4", "gpt-4o"}, r.Models)
	require.Equal(t, service.BillingModePerRequest, r.BillingMode)
	require.Equal(t, float64Ptr(0.01), r.InputPrice)
	require.Equal(t, float64Ptr(0.03), r.OutputPrice)
	require.Equal(t, float64Ptr(0.005), r.CacheWritePrice)
	require.Equal(t, float64Ptr(0.002), r.CacheReadPrice)
	require.Equal(t, float64Ptr(0.04), r.ImageOutputPrice)
	require.Equal(t, float64Ptr(0.5), r.PerRequestPrice)
}
|
||||
|
||||
// TestPricingRequestToService_WithIntervals verifies that pricing interval
// requests are converted field-by-field, including an open-ended interval
// (nil MaxTokens) and intervals with only the required fields set.
func TestPricingRequestToService_WithIntervals(t *testing.T) {
	reqs := []channelModelPricingRequest{
		{
			Models:      []string{"m1"},
			BillingMode: "per_request",
			Intervals: []pricingIntervalRequest{
				{
					MinTokens:       0,
					MaxTokens:       intPtr(2000),
					TierLabel:       "small",
					InputPrice:      float64Ptr(0.01),
					OutputPrice:     float64Ptr(0.02),
					CacheWritePrice: float64Ptr(0.003),
					CacheReadPrice:  float64Ptr(0.001),
					PerRequestPrice: float64Ptr(0.1),
					SortOrder:       1,
				},
				{
					// Open-ended interval: nil MaxTokens, prices left unset.
					MinTokens: 2000,
					MaxTokens: nil,
					TierLabel: "large",
					SortOrder: 2,
				},
			},
		},
	}

	result := pricingRequestToService(reqs)
	require.Len(t, result, 1)
	require.Len(t, result[0].Intervals, 2)

	iv0 := result[0].Intervals[0]
	require.Equal(t, 0, iv0.MinTokens)
	require.Equal(t, intPtr(2000), iv0.MaxTokens)
	require.Equal(t, "small", iv0.TierLabel)
	require.Equal(t, float64Ptr(0.01), iv0.InputPrice)
	require.Equal(t, float64Ptr(0.02), iv0.OutputPrice)
	require.Equal(t, float64Ptr(0.003), iv0.CacheWritePrice)
	require.Equal(t, float64Ptr(0.001), iv0.CacheReadPrice)
	require.Equal(t, float64Ptr(0.1), iv0.PerRequestPrice)
	require.Equal(t, 1, iv0.SortOrder)

	iv1 := result[0].Intervals[1]
	require.Equal(t, 2000, iv1.MinTokens)
	require.Nil(t, iv1.MaxTokens)
	require.Equal(t, "large", iv1.TierLabel)
	require.Equal(t, 2, iv1.SortOrder)
}
|
||||
|
||||
func TestPricingRequestToService_EmptySlice(t *testing.T) {
|
||||
result := pricingRequestToService([]channelModelPricingRequest{})
|
||||
require.NotNil(t, result)
|
||||
require.Empty(t, result)
|
||||
}
|
||||
|
||||
func TestPricingRequestToService_NilPriceFields(t *testing.T) {
|
||||
reqs := []channelModelPricingRequest{
|
||||
{
|
||||
Models: []string{"m1"},
|
||||
BillingMode: "token",
|
||||
// all price fields are nil by default
|
||||
},
|
||||
}
|
||||
|
||||
result := pricingRequestToService(reqs)
|
||||
require.Len(t, result, 1)
|
||||
r := result[0]
|
||||
require.Nil(t, r.InputPrice)
|
||||
require.Nil(t, r.OutputPrice)
|
||||
require.Nil(t, r.CacheWritePrice)
|
||||
require.Nil(t, r.CacheReadPrice)
|
||||
require.Nil(t, r.ImageOutputPrice)
|
||||
require.Nil(t, r.PerRequestPrice)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 3. validatePricingBillingMode
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TestValidatePricingBillingMode covers the billing-mode invariant enforced
// by validatePricingBillingMode: per_request and image entries must carry
// either a per-request price or at least one pricing interval, while token
// entries (and an empty list) always validate.
func TestValidatePricingBillingMode(t *testing.T) {
	tests := []struct {
		name    string
		pricing []service.ChannelModelPricing
		wantErr bool
	}{
		{
			name: "token mode - valid",
			pricing: []service.ChannelModelPricing{
				{BillingMode: service.BillingModeToken},
			},
			wantErr: false,
		},
		{
			name: "per_request with price - valid",
			pricing: []service.ChannelModelPricing{
				{
					BillingMode:     service.BillingModePerRequest,
					PerRequestPrice: float64Ptr(0.5),
				},
			},
			wantErr: false,
		},
		{
			name: "per_request with intervals - valid",
			pricing: []service.ChannelModelPricing{
				{
					BillingMode: service.BillingModePerRequest,
					Intervals: []service.PricingInterval{
						{MinTokens: 0, MaxTokens: intPtr(1000), PerRequestPrice: float64Ptr(0.1)},
					},
				},
			},
			wantErr: false,
		},
		{
			name: "per_request no price no intervals - invalid",
			pricing: []service.ChannelModelPricing{
				{BillingMode: service.BillingModePerRequest},
			},
			wantErr: true,
		},
		{
			name: "image with price - valid",
			pricing: []service.ChannelModelPricing{
				{
					BillingMode:     service.BillingModeImage,
					PerRequestPrice: float64Ptr(0.2),
				},
			},
			wantErr: false,
		},
		{
			name: "image no price no intervals - invalid",
			pricing: []service.ChannelModelPricing{
				{BillingMode: service.BillingModeImage},
			},
			wantErr: true,
		},
		{
			name:    "empty list - valid",
			pricing: []service.ChannelModelPricing{},
			wantErr: false,
		},
		{
			// One invalid entry must fail the whole list even when the
			// other entries are valid.
			name: "mixed modes with invalid image - invalid",
			pricing: []service.ChannelModelPricing{
				{
					BillingMode: service.BillingModeToken,
					InputPrice:  float64Ptr(0.01),
				},
				{
					BillingMode:     service.BillingModePerRequest,
					PerRequestPrice: float64Ptr(0.5),
				},
				{
					BillingMode: service.BillingModeImage,
				},
			},
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := validatePricingBillingMode(tt.pricing)
			if tt.wantErr {
				require.Error(t, err)
				require.Contains(t, err.Error(), "per-request price or intervals required")
			} else {
				require.NoError(t, err)
			}
		})
	}
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -272,6 +273,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
|
||||
// Parse optional filter params
|
||||
var userID, apiKeyID, accountID, groupID int64
|
||||
modelSource := usagestats.ModelSourceRequested
|
||||
var requestType *int16
|
||||
var stream *bool
|
||||
var billingType *int8
|
||||
@@ -296,6 +298,13 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
groupID = id
|
||||
}
|
||||
}
|
||||
if rawModelSource := strings.TrimSpace(c.Query("model_source")); rawModelSource != "" {
|
||||
if !usagestats.IsValidModelSource(rawModelSource) {
|
||||
response.BadRequest(c, "Invalid model_source, use requested/upstream/mapping")
|
||||
return
|
||||
}
|
||||
modelSource = rawModelSource
|
||||
}
|
||||
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
||||
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
||||
if err != nil {
|
||||
@@ -322,7 +331,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, modelSource, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get model statistics")
|
||||
return
|
||||
@@ -466,9 +475,62 @@ type BatchUsersUsageRequest struct {
|
||||
UserIDs []int64 `json:"user_ids" binding:"required"`
|
||||
}
|
||||
|
||||
var dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
|
||||
var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second)
|
||||
var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second)
|
||||
|
||||
// parseRankingLimit parses a user-supplied limit string and clamps it to
// the inclusive range [1, 50]. Any unparseable, zero, or negative input
// falls back to the default of 12.
func parseRankingLimit(raw string) int {
	const (
		defaultLimit = 12
		maxLimit     = 50
	)
	n, err := strconv.Atoi(strings.TrimSpace(raw))
	switch {
	case err != nil || n <= 0:
		return defaultLimit
	case n > maxLimit:
		return maxLimit
	default:
		return n
	}
}
|
||||
|
||||
// GetUserSpendingRanking handles getting user spending ranking data.
|
||||
// GET /api/v1/admin/dashboard/users-ranking
|
||||
func (h *DashboardHandler) GetUserSpendingRanking(c *gin.Context) {
	startTime, endTime := parseTimeRange(c)
	limit := parseRankingLimit(c.DefaultQuery("limit", "12"))

	// Cache key is the full query identity (UTC range + limit) as JSON.
	// Marshaling this plain struct of strings/int cannot fail, so the
	// error is deliberately discarded.
	keyRaw, _ := json.Marshal(struct {
		Start string `json:"start"`
		End   string `json:"end"`
		Limit int    `json:"limit"`
	}{
		Start: startTime.UTC().Format(time.RFC3339),
		End:   endTime.UTC().Format(time.RFC3339),
		Limit: limit,
	})
	cacheKey := string(keyRaw)
	// Serve from the 5-minute snapshot cache when possible; the header lets
	// clients/tests observe whether this response was cached.
	if cached, ok := dashboardUsersRankingCache.Get(cacheKey); ok {
		c.Header("X-Snapshot-Cache", "hit")
		response.Success(c, cached.Payload)
		return
	}

	ranking, err := h.dashboardService.GetUserSpendingRanking(c.Request.Context(), startTime, endTime, limit)
	if err != nil {
		response.Error(c, 500, "Failed to get user spending ranking")
		return
	}

	payload := gin.H{
		"ranking":           ranking.Ranking,
		"total_actual_cost": ranking.TotalActualCost,
		"total_requests":    ranking.TotalRequests,
		"total_tokens":      ranking.TotalTokens,
		"start_date":        startTime.Format("2006-01-02"),
		// end_date reports endTime minus one day — parseTimeRange presumably
		// returns an exclusive end bound; confirm against its definition.
		"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
	}
	dashboardUsersRankingCache.Set(cacheKey, payload)
	c.Header("X-Snapshot-Cache", "miss")
	response.Success(c, payload)
}
|
||||
|
||||
// GetBatchUsersUsage handles getting usage stats for multiple users
|
||||
// POST /api/v1/admin/dashboard/users-usage
|
||||
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||
@@ -551,3 +613,81 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, payload)
|
||||
}
|
||||
|
||||
// GetUserBreakdown handles getting per-user usage breakdown within a dimension.
|
||||
// GET /api/v1/admin/dashboard/user-breakdown
|
||||
// Query params: start_date, end_date, group_id, model, endpoint, endpoint_type, limit
|
||||
func (h *DashboardHandler) GetUserBreakdown(c *gin.Context) {
	startTime, endTime := parseTimeRange(c)

	dim := usagestats.UserBreakdownDimension{}
	// Numeric filters that fail to parse are silently ignored (treated as unset).
	if v := c.Query("group_id"); v != "" {
		if id, err := strconv.ParseInt(v, 10, 64); err == nil {
			dim.GroupID = id
		}
	}
	dim.Model = c.Query("model")
	// model_source must be a known source (requested/upstream/mapping);
	// anything else is a client error.
	rawModelSource := strings.TrimSpace(c.DefaultQuery("model_source", usagestats.ModelSourceRequested))
	if !usagestats.IsValidModelSource(rawModelSource) {
		response.BadRequest(c, "Invalid model_source, use requested/upstream/mapping")
		return
	}
	dim.ModelType = rawModelSource
	dim.Endpoint = c.Query("endpoint")
	dim.EndpointType = c.DefaultQuery("endpoint_type", "inbound")

	// Additional filter conditions
	if v := c.Query("user_id"); v != "" {
		if id, err := strconv.ParseInt(v, 10, 64); err == nil {
			dim.UserID = id
		}
	}
	if v := c.Query("api_key_id"); v != "" {
		if id, err := strconv.ParseInt(v, 10, 64); err == nil {
			dim.APIKeyID = id
		}
	}
	if v := c.Query("account_id"); v != "" {
		if id, err := strconv.ParseInt(v, 10, 64); err == nil {
			dim.AccountID = id
		}
	}
	if v := c.Query("request_type"); v != "" {
		if rt, err := strconv.ParseInt(v, 10, 16); err == nil {
			rtVal := int16(rt)
			dim.RequestType = &rtVal
		}
	}
	if v := c.Query("stream"); v != "" {
		if s, err := strconv.ParseBool(v); err == nil {
			dim.Stream = &s
		}
	}
	if v := c.Query("billing_type"); v != "" {
		if bt, err := strconv.ParseInt(v, 10, 8); err == nil {
			btVal := int8(bt)
			dim.BillingType = &btVal
		}
	}

	// limit defaults to 50 and is capped at 200; out-of-range or invalid
	// values keep the default rather than erroring.
	limit := 50
	if v := c.Query("limit"); v != "" {
		if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= 200 {
			limit = n
		}
	}

	stats, err := h.dashboardService.GetUserBreakdownStats(
		c.Request.Context(), startTime, endTime, dim, limit,
	)
	if err != nil {
		response.Error(c, 500, "Failed to get user breakdown stats")
		return
	}

	response.Success(c, gin.H{
		"users":      stats,
		"start_date": startTime.Format("2006-01-02"),
		// end_date reports endTime minus one day — parseTimeRange presumably
		// returns an exclusive end bound; confirm against its definition.
		"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
	})
}
|
||||
|
||||
@@ -19,6 +19,9 @@ type dashboardUsageRepoCapture struct {
|
||||
trendStream *bool
|
||||
modelRequestType *int16
|
||||
modelStream *bool
|
||||
rankingLimit int
|
||||
ranking []usagestats.UserSpendingRankingItem
|
||||
rankingTotal float64
|
||||
}
|
||||
|
||||
func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters(
|
||||
@@ -49,6 +52,20 @@ func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters(
|
||||
return []usagestats.ModelStat{}, nil
|
||||
}
|
||||
|
||||
// GetUserSpendingRanking is the mock implementation: it records the limit
// the handler requested (so tests can assert clamping) and returns canned
// ranking data with fixed request/token totals.
func (s *dashboardUsageRepoCapture) GetUserSpendingRanking(
	ctx context.Context,
	startTime, endTime time.Time,
	limit int,
) (*usagestats.UserSpendingRankingResponse, error) {
	s.rankingLimit = limit
	return &usagestats.UserSpendingRankingResponse{
		Ranking:         s.ranking,
		TotalActualCost: s.rankingTotal,
		TotalRequests:   44,
		TotalTokens:     1234,
	}, nil
}
|
||||
|
||||
func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||
@@ -56,6 +73,7 @@ func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Eng
|
||||
router := gin.New()
|
||||
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
|
||||
router.GET("/admin/dashboard/models", handler.GetModelStats)
|
||||
router.GET("/admin/dashboard/users-ranking", handler.GetUserSpendingRanking)
|
||||
return router
|
||||
}
|
||||
|
||||
@@ -130,3 +148,54 @@ func TestDashboardModelStatsInvalidStream(t *testing.T) {
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestDashboardModelStatsInvalidModelSource(t *testing.T) {
|
||||
repo := &dashboardUsageRepoCapture{}
|
||||
router := newDashboardRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?model_source=invalid", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
// TestDashboardModelStatsValidModelSource verifies a recognized
// model_source value ("upstream") is accepted and returns 200.
func TestDashboardModelStatsValidModelSource(t *testing.T) {
	repo := &dashboardUsageRepoCapture{}
	router := newDashboardRequestTypeTestRouter(repo)

	req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?model_source=upstream", nil)
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	require.Equal(t, http.StatusOK, rec.Code)
}
|
||||
|
||||
// TestDashboardUsersRankingLimitAndCache checks two behaviors of the
// users-ranking endpoint: limit=100 is clamped to 50 before reaching the
// repository, and an identical second request is served from the snapshot
// cache (observed via the X-Snapshot-Cache header).
func TestDashboardUsersRankingLimitAndCache(t *testing.T) {
	// Reset the package-level cache so state from other tests cannot leak in.
	dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
	repo := &dashboardUsageRepoCapture{
		ranking: []usagestats.UserSpendingRankingItem{
			{UserID: 7, Email: "rank@example.com", ActualCost: 10.5, Requests: 3, Tokens: 300},
		},
		rankingTotal: 88.8,
	}
	router := newDashboardRequestTypeTestRouter(repo)

	req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	require.Equal(t, http.StatusOK, rec.Code)
	require.Equal(t, 50, repo.rankingLimit)
	require.Contains(t, rec.Body.String(), "\"total_actual_cost\":88.8")
	require.Contains(t, rec.Body.String(), "\"total_requests\":44")
	require.Contains(t, rec.Body.String(), "\"total_tokens\":1234")
	require.Equal(t, "miss", rec.Header().Get("X-Snapshot-Cache"))

	// Same query again: must hit the cache populated by the first request.
	req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
	rec2 := httptest.NewRecorder()
	router.ServeHTTP(rec2, req2)

	require.Equal(t, http.StatusOK, rec2.Code)
	require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
}
|
||||
|
||||
@@ -0,0 +1,229 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- mock repo ---
|
||||
|
||||
// userBreakdownRepoCapture is a test double for service.UsageLogRepository.
// It records the breakdown dimension and limit it was called with and
// returns a configurable canned result. The embedded interface satisfies
// the rest of the contract; calling any method other than
// GetUserBreakdownStats would dereference the nil embedded value.
type userBreakdownRepoCapture struct {
	service.UsageLogRepository
	capturedDim   usagestats.UserBreakdownDimension // last dimension passed in
	capturedLimit int                               // last limit passed in
	result        []usagestats.UserBreakdownItem    // canned response (nil => empty slice)
}
|
||||
|
||||
func (r *userBreakdownRepoCapture) GetUserBreakdownStats(
|
||||
_ context.Context, _, _ time.Time,
|
||||
dim usagestats.UserBreakdownDimension, limit int,
|
||||
) ([]usagestats.UserBreakdownItem, error) {
|
||||
r.capturedDim = dim
|
||||
r.capturedLimit = limit
|
||||
if r.result != nil {
|
||||
return r.result, nil
|
||||
}
|
||||
return []usagestats.UserBreakdownItem{}, nil
|
||||
}
|
||||
|
||||
// newUserBreakdownRouter wires the capture repo into a real DashboardService
// and returns a gin engine exposing only the user-breakdown endpoint, so
// tests exercise the genuine handler + service path with a stubbed repo.
func newUserBreakdownRouter(repo *userBreakdownRepoCapture) *gin.Engine {
	gin.SetMode(gin.TestMode)
	svc := service.NewDashboardService(repo, nil, nil, nil)
	h := NewDashboardHandler(svc, nil)
	router := gin.New()
	router.GET("/admin/dashboard/user-breakdown", h.GetUserBreakdown)
	return router
}
|
||||
|
||||
// --- tests ---
|
||||
|
||||
// TestGetUserBreakdown_GroupIDFilter verifies that group_id is parsed into
// the breakdown dimension, other dimension fields stay empty, and the
// default limit of 50 is applied when no limit is given.
func TestGetUserBreakdown_GroupIDFilter(t *testing.T) {
	repo := &userBreakdownRepoCapture{}
	router := newUserBreakdownRouter(repo)

	req := httptest.NewRequest(http.MethodGet,
		"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=42", nil)
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)

	require.Equal(t, http.StatusOK, w.Code)
	require.Equal(t, int64(42), repo.capturedDim.GroupID)
	require.Empty(t, repo.capturedDim.Model)
	require.Empty(t, repo.capturedDim.Endpoint)
	require.Equal(t, 50, repo.capturedLimit) // default limit
}
|
||||
|
||||
// TestGetUserBreakdown_ModelFilter verifies the model filter is captured,
// model_source defaults to the "requested" source, and group_id stays at
// its zero value when absent.
func TestGetUserBreakdown_ModelFilter(t *testing.T) {
	repo := &userBreakdownRepoCapture{}
	router := newUserBreakdownRouter(repo)

	req := httptest.NewRequest(http.MethodGet,
		"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=claude-opus-4-6", nil)
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)

	require.Equal(t, http.StatusOK, w.Code)
	require.Equal(t, "claude-opus-4-6", repo.capturedDim.Model)
	require.Equal(t, usagestats.ModelSourceRequested, repo.capturedDim.ModelType)
	require.Equal(t, int64(0), repo.capturedDim.GroupID)
}
|
||||
|
||||
func TestGetUserBreakdown_ModelSourceFilter(t *testing.T) {
|
||||
repo := &userBreakdownRepoCapture{}
|
||||
router := newUserBreakdownRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=claude-opus-4-6&model_source=upstream", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
require.Equal(t, usagestats.ModelSourceUpstream, repo.capturedDim.ModelType)
|
||||
}
|
||||
|
||||
func TestGetUserBreakdown_InvalidModelSource(t *testing.T) {
|
||||
repo := &userBreakdownRepoCapture{}
|
||||
router := newUserBreakdownRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model_source=foobar", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, w.Code)
|
||||
}
|
||||
|
||||
func TestGetUserBreakdown_EndpointFilter(t *testing.T) {
|
||||
repo := &userBreakdownRepoCapture{}
|
||||
router := newUserBreakdownRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&endpoint=/v1/messages&endpoint_type=upstream", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
require.Equal(t, "/v1/messages", repo.capturedDim.Endpoint)
|
||||
require.Equal(t, "upstream", repo.capturedDim.EndpointType)
|
||||
}
|
||||
|
||||
func TestGetUserBreakdown_DefaultEndpointType(t *testing.T) {
|
||||
repo := &userBreakdownRepoCapture{}
|
||||
router := newUserBreakdownRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&endpoint=/chat", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
require.Equal(t, "inbound", repo.capturedDim.EndpointType)
|
||||
}
|
||||
|
||||
func TestGetUserBreakdown_CustomLimit(t *testing.T) {
|
||||
repo := &userBreakdownRepoCapture{}
|
||||
router := newUserBreakdownRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=test&limit=100", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
require.Equal(t, 100, repo.capturedLimit)
|
||||
}
|
||||
|
||||
func TestGetUserBreakdown_LimitClamped(t *testing.T) {
|
||||
repo := &userBreakdownRepoCapture{}
|
||||
router := newUserBreakdownRouter(repo)
|
||||
|
||||
// limit > 200 should fall back to default 50
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=test&limit=999", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
require.Equal(t, 50, repo.capturedLimit)
|
||||
}
|
||||
|
||||
func TestGetUserBreakdown_ResponseFormat(t *testing.T) {
|
||||
repo := &userBreakdownRepoCapture{
|
||||
result: []usagestats.UserBreakdownItem{
|
||||
{UserID: 1, Email: "alice@test.com", Requests: 100, TotalTokens: 50000, Cost: 1.5, ActualCost: 1.2},
|
||||
{UserID: 2, Email: "bob@test.com", Requests: 50, TotalTokens: 25000, Cost: 0.8, ActualCost: 0.6},
|
||||
},
|
||||
}
|
||||
router := newUserBreakdownRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=1", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp struct {
|
||||
Code int `json:"code"`
|
||||
Data struct {
|
||||
Users []usagestats.UserBreakdownItem `json:"users"`
|
||||
StartDate string `json:"start_date"`
|
||||
EndDate string `json:"end_date"`
|
||||
} `json:"data"`
|
||||
}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, resp.Code)
|
||||
require.Len(t, resp.Data.Users, 2)
|
||||
require.Equal(t, int64(1), resp.Data.Users[0].UserID)
|
||||
require.Equal(t, "alice@test.com", resp.Data.Users[0].Email)
|
||||
require.Equal(t, int64(100), resp.Data.Users[0].Requests)
|
||||
require.InDelta(t, 1.2, resp.Data.Users[0].ActualCost, 0.001)
|
||||
require.Equal(t, "2026-03-01", resp.Data.StartDate)
|
||||
require.Equal(t, "2026-03-16", resp.Data.EndDate)
|
||||
}
|
||||
|
||||
func TestGetUserBreakdown_EmptyResult(t *testing.T) {
|
||||
repo := &userBreakdownRepoCapture{}
|
||||
router := newUserBreakdownRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=999", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp struct {
|
||||
Data struct {
|
||||
Users []usagestats.UserBreakdownItem `json:"users"`
|
||||
} `json:"data"`
|
||||
}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, resp.Data.Users)
|
||||
}
|
||||
|
||||
func TestGetUserBreakdown_NoFilters(t *testing.T) {
|
||||
repo := &userBreakdownRepoCapture{}
|
||||
router := newUserBreakdownRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
require.Equal(t, int64(0), repo.capturedDim.GroupID)
|
||||
require.Empty(t, repo.capturedDim.Model)
|
||||
require.Empty(t, repo.capturedDim.Endpoint)
|
||||
}
|
||||
@@ -38,6 +38,7 @@ type dashboardModelGroupCacheKey struct {
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
AccountID int64 `json:"account_id"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
ModelSource string `json:"model_source,omitempty"`
|
||||
RequestType *int16 `json:"request_type"`
|
||||
Stream *bool `json:"stream"`
|
||||
BillingType *int8 `json:"billing_type"`
|
||||
@@ -111,6 +112,7 @@ func (h *DashboardHandler) getModelStatsCached(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
modelSource string,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
@@ -122,12 +124,13 @@ func (h *DashboardHandler) getModelStatsCached(
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
ModelSource: usagestats.NormalizeModelSource(modelSource),
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
})
|
||||
entry, hit, err := dashboardModelStatsCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
return h.dashboardService.GetModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, modelSource)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
|
||||
@@ -200,6 +200,7 @@ func (h *DashboardHandler) buildSnapshotV2Response(
|
||||
filters.APIKeyID,
|
||||
filters.AccountID,
|
||||
filters.GroupID,
|
||||
usagestats.ModelSourceRequested,
|
||||
filters.RequestType,
|
||||
filters.Stream,
|
||||
filters.BillingType,
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -13,35 +17,84 @@ import (
|
||||
|
||||
// GroupHandler handles admin group management
|
||||
type GroupHandler struct {
|
||||
adminService service.AdminService
|
||||
adminService service.AdminService
|
||||
dashboardService *service.DashboardService
|
||||
groupCapacityService *service.GroupCapacityService
|
||||
}
|
||||
|
||||
// optionalLimitField is a tri-state JSON field: it distinguishes a key that
// was absent from the payload (set == false) from one that was explicitly
// null or an empty string (set == true, value == nil) and from a concrete
// number (set == true, value != nil).
type optionalLimitField struct {
	set   bool     // true once UnmarshalJSON ran, i.e. the key was present
	value *float64 // parsed number; nil when the payload said null or ""
}
|
||||
|
||||
func (f *optionalLimitField) UnmarshalJSON(data []byte) error {
|
||||
f.set = true
|
||||
|
||||
trimmed := bytes.TrimSpace(data)
|
||||
if bytes.Equal(trimmed, []byte("null")) {
|
||||
f.value = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
var number float64
|
||||
if err := json.Unmarshal(trimmed, &number); err == nil {
|
||||
f.value = &number
|
||||
return nil
|
||||
}
|
||||
|
||||
var text string
|
||||
if err := json.Unmarshal(trimmed, &text); err == nil {
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
f.value = nil
|
||||
return nil
|
||||
}
|
||||
number, err = strconv.ParseFloat(text, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid numeric limit value %q: %w", text, err)
|
||||
}
|
||||
f.value = &number
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid limit value: %s", string(trimmed))
|
||||
}
|
||||
|
||||
func (f optionalLimitField) ToServiceInput() *float64 {
|
||||
if !f.set {
|
||||
return nil
|
||||
}
|
||||
if f.value != nil {
|
||||
return f.value
|
||||
}
|
||||
zero := 0.0
|
||||
return &zero
|
||||
}
|
||||
|
||||
// NewGroupHandler creates a new admin group handler
|
||||
func NewGroupHandler(adminService service.AdminService) *GroupHandler {
|
||||
func NewGroupHandler(adminService service.AdminService, dashboardService *service.DashboardService, groupCapacityService *service.GroupCapacityService) *GroupHandler {
|
||||
return &GroupHandler{
|
||||
adminService: adminService,
|
||||
adminService: adminService,
|
||||
dashboardService: dashboardService,
|
||||
groupCapacityService: groupCapacityService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateGroupRequest represents create group request
|
||||
type CreateGroupRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD optionalLimitField `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
|
||||
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
|
||||
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
|
||||
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||
@@ -51,10 +104,10 @@ type CreateGroupRequest struct {
|
||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
|
||||
RequireOAuthOnly bool `json:"require_oauth_only"`
|
||||
RequirePrivacySet bool `json:"require_privacy_set"`
|
||||
DefaultMappedModel string `json:"default_mapped_model"`
|
||||
// 从指定分组复制账号(创建后自动绑定)
|
||||
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||
@@ -62,24 +115,20 @@ type CreateGroupRequest struct {
|
||||
|
||||
// UpdateGroupRequest represents update group request
|
||||
type UpdateGroupRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD optionalLimitField `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
|
||||
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
|
||||
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
|
||||
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
|
||||
ClaudeCodeOnly *bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||
@@ -89,10 +138,10 @@ type UpdateGroupRequest struct {
|
||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes *[]string `json:"supported_model_scopes"`
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
AllowMessagesDispatch *bool `json:"allow_messages_dispatch"`
|
||||
RequireOAuthOnly *bool `json:"require_oauth_only"`
|
||||
RequirePrivacySet *bool `json:"require_privacy_set"`
|
||||
DefaultMappedModel *string `json:"default_mapped_model"`
|
||||
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||
@@ -191,16 +240,12 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
SoraImagePrice360: req.SoraImagePrice360,
|
||||
SoraImagePrice540: req.SoraImagePrice540,
|
||||
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
|
||||
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
|
||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
|
||||
@@ -208,8 +253,9 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
MCPXMLInject: req.MCPXMLInject,
|
||||
SupportedModelScopes: req.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
||||
RequireOAuthOnly: req.RequireOAuthOnly,
|
||||
RequirePrivacySet: req.RequirePrivacySet,
|
||||
DefaultMappedModel: req.DefaultMappedModel,
|
||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||
})
|
||||
@@ -244,16 +290,12 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
IsExclusive: req.IsExclusive,
|
||||
Status: req.Status,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
SoraImagePrice360: req.SoraImagePrice360,
|
||||
SoraImagePrice540: req.SoraImagePrice540,
|
||||
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
|
||||
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
|
||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
|
||||
@@ -261,8 +303,9 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
MCPXMLInject: req.MCPXMLInject,
|
||||
SupportedModelScopes: req.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
||||
RequireOAuthOnly: req.RequireOAuthOnly,
|
||||
RequirePrivacySet: req.RequirePrivacySet,
|
||||
DefaultMappedModel: req.DefaultMappedModel,
|
||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||
})
|
||||
@@ -311,6 +354,33 @@ func (h *GroupHandler) GetStats(c *gin.Context) {
|
||||
_ = groupID // TODO: implement actual stats
|
||||
}
|
||||
|
||||
// GetUsageSummary returns today's and cumulative cost for all groups.
|
||||
// GET /api/v1/admin/groups/usage-summary?timezone=Asia/Shanghai
|
||||
func (h *GroupHandler) GetUsageSummary(c *gin.Context) {
|
||||
userTZ := c.Query("timezone")
|
||||
now := timezone.NowInUserLocation(userTZ)
|
||||
todayStart := timezone.StartOfDayInUserLocation(now, userTZ)
|
||||
|
||||
results, err := h.dashboardService.GetGroupUsageSummary(c.Request.Context(), todayStart)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get group usage summary")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, results)
|
||||
}
|
||||
|
||||
// GetCapacitySummary returns aggregated capacity (concurrency/sessions/RPM) for all active groups.
|
||||
// GET /api/v1/admin/groups/capacity-summary
|
||||
func (h *GroupHandler) GetCapacitySummary(c *gin.Context) {
|
||||
results, err := h.groupCapacityService.GetAllGroupCapacity(c.Request.Context())
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get group capacity summary")
|
||||
return
|
||||
}
|
||||
response.Success(c, results)
|
||||
}
|
||||
|
||||
// GetGroupAPIKeys handles getting API keys in a group
|
||||
// GET /api/v1/admin/groups/:id/api-keys
|
||||
func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
||||
@@ -335,6 +405,72 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
||||
response.Paginated(c, outKeys, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetGroupRateMultipliers handles getting rate multipliers for users in a group
|
||||
// GET /api/v1/admin/groups/:id/rate-multipliers
|
||||
func (h *GroupHandler) GetGroupRateMultipliers(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
entries, err := h.adminService.GetGroupRateMultipliers(c.Request.Context(), groupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if entries == nil {
|
||||
entries = []service.UserGroupRateEntry{}
|
||||
}
|
||||
response.Success(c, entries)
|
||||
}
|
||||
|
||||
// ClearGroupRateMultipliers handles clearing all rate multipliers for a group
|
||||
// DELETE /api/v1/admin/groups/:id/rate-multipliers
|
||||
func (h *GroupHandler) ClearGroupRateMultipliers(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.adminService.ClearGroupRateMultipliers(c.Request.Context(), groupID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Rate multipliers cleared successfully"})
|
||||
}
|
||||
|
||||
// BatchSetGroupRateMultipliersRequest represents batch set rate multipliers request
|
||||
type BatchSetGroupRateMultipliersRequest struct {
|
||||
Entries []service.GroupRateMultiplierInput `json:"entries" binding:"required"`
|
||||
}
|
||||
|
||||
// BatchSetGroupRateMultipliers handles batch setting rate multipliers for a group
|
||||
// PUT /api/v1/admin/groups/:id/rate-multipliers
|
||||
func (h *GroupHandler) BatchSetGroupRateMultipliers(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req BatchSetGroupRateMultipliersRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.adminService.BatchSetGroupRateMultipliers(c.Request.Context(), groupID, req.Entries); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Rate multipliers updated successfully"})
|
||||
}
|
||||
|
||||
// UpdateSortOrderRequest represents the request to update group sort orders
|
||||
type UpdateSortOrderRequest struct {
|
||||
Updates []struct {
|
||||
|
||||
@@ -19,9 +19,6 @@ type OpenAIOAuthHandler struct {
|
||||
}
|
||||
|
||||
func oauthPlatformFromPath(c *gin.Context) string {
|
||||
if strings.Contains(c.FullPath(), "/admin/sora/") {
|
||||
return service.PlatformSora
|
||||
}
|
||||
return service.PlatformOpenAI
|
||||
}
|
||||
|
||||
@@ -105,7 +102,6 @@ type OpenAIRefreshTokenRequest struct {
|
||||
|
||||
// RefreshToken refreshes an OpenAI OAuth token
|
||||
// POST /api/v1/admin/openai/refresh-token
|
||||
// POST /api/v1/admin/sora/rt2at
|
||||
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
||||
var req OpenAIRefreshTokenRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
@@ -145,39 +141,8 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
// ExchangeSoraSessionToken exchanges Sora session token to access token
|
||||
// POST /api/v1/admin/sora/st2at
|
||||
func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) {
|
||||
var req struct {
|
||||
SessionToken string `json:"session_token"`
|
||||
ST string `json:"st"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sessionToken := strings.TrimSpace(req.SessionToken)
|
||||
if sessionToken == "" {
|
||||
sessionToken = strings.TrimSpace(req.ST)
|
||||
}
|
||||
if sessionToken == "" {
|
||||
response.BadRequest(c, "session_token is required")
|
||||
return
|
||||
}
|
||||
|
||||
tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for a specific OpenAI/Sora account
|
||||
// RefreshAccountToken refreshes token for a specific OpenAI account
|
||||
// POST /api/v1/admin/openai/accounts/:id/refresh
|
||||
// POST /api/v1/admin/sora/accounts/:id/refresh
|
||||
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
@@ -232,9 +197,8 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
||||
response.Success(c, dto.AccountFromService(updatedAccount))
|
||||
}
|
||||
|
||||
// CreateAccountFromOAuth creates a new OpenAI/Sora OAuth account from token info
|
||||
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info
|
||||
// POST /api/v1/admin/openai/create-from-oauth
|
||||
// POST /api/v1/admin/sora/create-from-oauth
|
||||
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||
var req struct {
|
||||
SessionID string `json:"session_id" binding:"required"`
|
||||
@@ -276,11 +240,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||
name = tokenInfo.Email
|
||||
}
|
||||
if name == "" {
|
||||
if platform == service.PlatformSora {
|
||||
name = "Sora OAuth Account"
|
||||
} else {
|
||||
name = "OpenAI OAuth Account"
|
||||
}
|
||||
name = "OpenAI OAuth Account"
|
||||
}
|
||||
|
||||
// Create account
|
||||
@@ -289,6 +249,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||
Platform: platform,
|
||||
Type: "oauth",
|
||||
Credentials: credentials,
|
||||
Extra: nil,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
|
||||
@@ -35,18 +35,21 @@ func NewRedeemHandler(adminService service.AdminService, redeemService *service.
|
||||
type GenerateRedeemCodesRequest struct {
|
||||
Count int `json:"count" binding:"required,min=1,max=100"`
|
||||
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
|
||||
Value float64 `json:"value" binding:"min=0"`
|
||||
GroupID *int64 `json:"group_id"` // 订阅类型必填
|
||||
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // 订阅类型使用,默认30天,最大100年
|
||||
Value float64 `json:"value"`
|
||||
GroupID *int64 `json:"group_id"` // 订阅类型必填
|
||||
ValidityDays int `json:"validity_days"` // 订阅类型使用,正数增加/负数退款扣减
|
||||
}
|
||||
|
||||
// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
|
||||
// Type 为 omitempty 而非 required 是为了向后兼容旧版调用方(不传 type 时默认 balance)。
|
||||
type CreateAndRedeemCodeRequest struct {
|
||||
Code string `json:"code" binding:"required,min=3,max=128"`
|
||||
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
|
||||
Value float64 `json:"value" binding:"required,gt=0"`
|
||||
UserID int64 `json:"user_id" binding:"required,gt=0"`
|
||||
Notes string `json:"notes"`
|
||||
Code string `json:"code" binding:"required,min=3,max=128"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // 不传时默认 balance(向后兼容)
|
||||
Value float64 `json:"value" binding:"required"`
|
||||
UserID int64 `json:"user_id" binding:"required,gt=0"`
|
||||
GroupID *int64 `json:"group_id"` // subscription 类型必填
|
||||
ValidityDays int `json:"validity_days"` // subscription 类型:正数增加,负数退款扣减
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
// List handles listing all redeem codes with pagination
|
||||
@@ -136,6 +139,22 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
req.Code = strings.TrimSpace(req.Code)
|
||||
// 向后兼容:旧版调用方(如 Sub2ApiPay)不传 type 字段,默认当作 balance 充值处理。
|
||||
// 请勿删除此默认值逻辑,否则会导致旧版调用方 400 报错。
|
||||
if req.Type == "" {
|
||||
req.Type = "balance"
|
||||
}
|
||||
|
||||
if req.Type == "subscription" {
|
||||
if req.GroupID == nil {
|
||||
response.BadRequest(c, "group_id is required for subscription type")
|
||||
return
|
||||
}
|
||||
if req.ValidityDays == 0 {
|
||||
response.BadRequest(c, "validity_days must not be zero for subscription type")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
existing, err := h.redeemService.GetByCode(ctx, req.Code)
|
||||
@@ -147,11 +166,13 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
|
||||
}
|
||||
|
||||
createErr := h.redeemService.CreateCode(ctx, &service.RedeemCode{
|
||||
Code: req.Code,
|
||||
Type: req.Type,
|
||||
Value: req.Value,
|
||||
Status: service.StatusUnused,
|
||||
Notes: req.Notes,
|
||||
Code: req.Code,
|
||||
Type: req.Type,
|
||||
Value: req.Value,
|
||||
Status: service.StatusUnused,
|
||||
Notes: req.Notes,
|
||||
GroupID: req.GroupID,
|
||||
ValidityDays: req.ValidityDays,
|
||||
})
|
||||
if createErr != nil {
|
||||
// Unique code race: if code now exists, use idempotent semantics by used_by.
|
||||
|
||||
141
backend/internal/handler/admin/redeem_handler_test.go
Normal file
141
backend/internal/handler/admin/redeem_handler_test.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// newCreateAndRedeemHandler creates a RedeemHandler with a non-nil (but minimal)
|
||||
// RedeemService so that CreateAndRedeem's nil guard passes and we can test the
|
||||
// parameter-validation layer that runs before any service call.
|
||||
func newCreateAndRedeemHandler() *RedeemHandler {
|
||||
return &RedeemHandler{
|
||||
adminService: newStubAdminService(),
|
||||
redeemService: &service.RedeemService{}, // non-nil to pass nil guard
|
||||
}
|
||||
}
|
||||
|
||||
// postCreateAndRedeemValidation calls CreateAndRedeem and returns the response
|
||||
// status code. For cases that pass validation and proceed into the service layer,
|
||||
// a panic may occur (because RedeemService internals are nil); this is expected
|
||||
// and treated as "validation passed" (returns 0 to indicate panic).
|
||||
func postCreateAndRedeemValidation(t *testing.T, handler *RedeemHandler, body any) (code int) {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
jsonBytes, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
c.Request, _ = http.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/create-and-redeem", bytes.NewReader(jsonBytes))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Panic means we passed validation and entered service layer (expected for minimal stub).
|
||||
code = 0
|
||||
}
|
||||
}()
|
||||
handler.CreateAndRedeem(c)
|
||||
return w.Code
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_TypeDefaultsToBalance(t *testing.T) {
|
||||
// 不传 type 字段时应默认 balance,不触发 subscription 校验。
|
||||
// 验证通过后进入 service 层会 panic(返回 0),说明默认值生效。
|
||||
h := newCreateAndRedeemHandler()
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-balance-default",
|
||||
"value": 10.0,
|
||||
"user_id": 1,
|
||||
})
|
||||
|
||||
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||
"omitting type should default to balance and pass validation")
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_SubscriptionRequiresGroupID(t *testing.T) {
|
||||
h := newCreateAndRedeemHandler()
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-sub-no-group",
|
||||
"type": "subscription",
|
||||
"value": 29.9,
|
||||
"user_id": 1,
|
||||
"validity_days": 30,
|
||||
// group_id 缺失
|
||||
})
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, code)
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_SubscriptionRequiresNonZeroValidityDays(t *testing.T) {
|
||||
groupID := int64(5)
|
||||
h := newCreateAndRedeemHandler()
|
||||
|
||||
// zero should be rejected
|
||||
t.Run("zero", func(t *testing.T) {
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-sub-bad-days-zero",
|
||||
"type": "subscription",
|
||||
"value": 29.9,
|
||||
"user_id": 1,
|
||||
"group_id": groupID,
|
||||
"validity_days": 0,
|
||||
})
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, code)
|
||||
})
|
||||
|
||||
// negative should pass validation (used for refund/reduction)
|
||||
t.Run("negative_passes_validation", func(t *testing.T) {
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-sub-negative-days",
|
||||
"type": "subscription",
|
||||
"value": 29.9,
|
||||
"user_id": 1,
|
||||
"group_id": groupID,
|
||||
"validity_days": -7,
|
||||
})
|
||||
|
||||
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||
"negative validity_days should pass validation for refund")
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_SubscriptionValidParamsPassValidation(t *testing.T) {
|
||||
groupID := int64(5)
|
||||
h := newCreateAndRedeemHandler()
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-sub-valid",
|
||||
"type": "subscription",
|
||||
"value": 29.9,
|
||||
"user_id": 1,
|
||||
"group_id": groupID,
|
||||
"validity_days": 31,
|
||||
})
|
||||
|
||||
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||
"valid subscription params should pass validation")
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_BalanceIgnoresSubscriptionFields(t *testing.T) {
|
||||
h := newCreateAndRedeemHandler()
|
||||
// balance 类型不传 group_id 和 validity_days,不应报 400
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-balance-no-extras",
|
||||
"type": "balance",
|
||||
"value": 50.0,
|
||||
"user_id": 1,
|
||||
})
|
||||
|
||||
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||
"balance type should not require group_id or validity_days")
|
||||
}
|
||||
@@ -41,17 +41,15 @@ type SettingHandler struct {
|
||||
emailService *service.EmailService
|
||||
turnstileService *service.TurnstileService
|
||||
opsService *service.OpsService
|
||||
soraS3Storage *service.SoraS3Storage
|
||||
}
|
||||
|
||||
// NewSettingHandler 创建系统设置处理器
|
||||
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, soraS3Storage *service.SoraS3Storage) *SettingHandler {
|
||||
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService) *SettingHandler {
|
||||
return &SettingHandler{
|
||||
settingService: settingService,
|
||||
emailService: emailService,
|
||||
turnstileService: turnstileService,
|
||||
opsService: opsService,
|
||||
soraS3Storage: soraS3Storage,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,6 +78,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
FrontendURL: settings.FrontendURL,
|
||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||
TotpEnabled: settings.TotpEnabled,
|
||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||
@@ -107,8 +106,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||
SoraClientEnabled: settings.SoraClientEnabled,
|
||||
CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems),
|
||||
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
DefaultSubscriptions: defaultSubscriptions,
|
||||
@@ -124,7 +123,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
OpsQueryModeDefault: settings.OpsQueryModeDefault,
|
||||
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
|
||||
MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
|
||||
MaxClaudeCodeVersion: settings.MaxClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
EnableFingerprintUnification: settings.EnableFingerprintUnification,
|
||||
EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -136,6 +139,7 @@ type UpdateSettingsRequest struct {
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
FrontendURL string `json:"frontend_url"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
|
||||
@@ -170,8 +174,8 @@ type UpdateSettingsRequest struct {
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"`
|
||||
CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"`
|
||||
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
@@ -196,9 +200,17 @@ type UpdateSettingsRequest struct {
|
||||
OpsMetricsIntervalSeconds *int `json:"ops_metrics_interval_seconds"`
|
||||
|
||||
MinClaudeCodeVersion string `json:"min_claude_code_version"`
|
||||
MaxClaudeCodeVersion string `json:"max_claude_code_version"`
|
||||
|
||||
// 分组隔离
|
||||
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
||||
|
||||
// Backend Mode
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
|
||||
// Gateway forwarding behavior
|
||||
EnableFingerprintUnification *bool `json:"enable_fingerprint_unification"`
|
||||
EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
|
||||
}
|
||||
|
||||
// UpdateSettings 更新系统设置
|
||||
@@ -223,11 +235,27 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
if req.DefaultBalance < 0 {
|
||||
req.DefaultBalance = 0
|
||||
}
|
||||
req.SMTPHost = strings.TrimSpace(req.SMTPHost)
|
||||
req.SMTPUsername = strings.TrimSpace(req.SMTPUsername)
|
||||
req.SMTPPassword = strings.TrimSpace(req.SMTPPassword)
|
||||
req.SMTPFrom = strings.TrimSpace(req.SMTPFrom)
|
||||
req.SMTPFromName = strings.TrimSpace(req.SMTPFromName)
|
||||
if req.SMTPPort <= 0 {
|
||||
req.SMTPPort = 587
|
||||
}
|
||||
req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions)
|
||||
|
||||
// SMTP 配置保护:如果请求中 smtp_host 为空但数据库中已有配置,则保留已有 SMTP 配置
|
||||
// 防止前端加载设置失败时空表单覆盖已保存的 SMTP 配置
|
||||
if req.SMTPHost == "" && previousSettings.SMTPHost != "" {
|
||||
req.SMTPHost = previousSettings.SMTPHost
|
||||
req.SMTPPort = previousSettings.SMTPPort
|
||||
req.SMTPUsername = previousSettings.SMTPUsername
|
||||
req.SMTPFrom = previousSettings.SMTPFrom
|
||||
req.SMTPFromName = previousSettings.SMTPFromName
|
||||
req.SMTPUseTLS = previousSettings.SMTPUseTLS
|
||||
}
|
||||
|
||||
// Turnstile 参数验证
|
||||
if req.TurnstileEnabled {
|
||||
// 检查必填字段
|
||||
@@ -322,6 +350,15 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Frontend URL 验证
|
||||
req.FrontendURL = strings.TrimSpace(req.FrontendURL)
|
||||
if req.FrontendURL != "" {
|
||||
if err := config.ValidateAbsoluteHTTPURL(req.FrontendURL); err != nil {
|
||||
response.BadRequest(c, "Frontend URL must be an absolute http(s) URL")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 自定义菜单项验证
|
||||
const (
|
||||
maxCustomMenuItems = 20
|
||||
@@ -400,6 +437,55 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
customMenuJSON = string(menuBytes)
|
||||
}
|
||||
|
||||
// 自定义端点验证
|
||||
const (
|
||||
maxCustomEndpoints = 10
|
||||
maxEndpointNameLen = 50
|
||||
maxEndpointURLLen = 2048
|
||||
maxEndpointDescriptionLen = 200
|
||||
)
|
||||
|
||||
customEndpointsJSON := previousSettings.CustomEndpoints
|
||||
if req.CustomEndpoints != nil {
|
||||
endpoints := *req.CustomEndpoints
|
||||
if len(endpoints) > maxCustomEndpoints {
|
||||
response.BadRequest(c, "Too many custom endpoints (max 10)")
|
||||
return
|
||||
}
|
||||
for _, ep := range endpoints {
|
||||
if strings.TrimSpace(ep.Name) == "" {
|
||||
response.BadRequest(c, "Custom endpoint name is required")
|
||||
return
|
||||
}
|
||||
if len(ep.Name) > maxEndpointNameLen {
|
||||
response.BadRequest(c, "Custom endpoint name is too long (max 50 characters)")
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(ep.Endpoint) == "" {
|
||||
response.BadRequest(c, "Custom endpoint URL is required")
|
||||
return
|
||||
}
|
||||
if len(ep.Endpoint) > maxEndpointURLLen {
|
||||
response.BadRequest(c, "Custom endpoint URL is too long (max 2048 characters)")
|
||||
return
|
||||
}
|
||||
if err := config.ValidateAbsoluteHTTPURL(strings.TrimSpace(ep.Endpoint)); err != nil {
|
||||
response.BadRequest(c, "Custom endpoint URL must be an absolute http(s) URL")
|
||||
return
|
||||
}
|
||||
if len(ep.Description) > maxEndpointDescriptionLen {
|
||||
response.BadRequest(c, "Custom endpoint description is too long (max 200 characters)")
|
||||
return
|
||||
}
|
||||
}
|
||||
endpointBytes, err := json.Marshal(endpoints)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to serialize custom endpoints")
|
||||
return
|
||||
}
|
||||
customEndpointsJSON = string(endpointBytes)
|
||||
}
|
||||
|
||||
// Ops metrics collector interval validation (seconds).
|
||||
if req.OpsMetricsIntervalSeconds != nil {
|
||||
v := *req.OpsMetricsIntervalSeconds
|
||||
@@ -427,12 +513,29 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 验证最高版本号格式(空字符串=禁用,或合法 semver)
|
||||
if req.MaxClaudeCodeVersion != "" {
|
||||
if !semverPattern.MatchString(req.MaxClaudeCodeVersion) {
|
||||
response.Error(c, http.StatusBadRequest, "max_claude_code_version must be empty or a valid semver (e.g. 3.0.0)")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 交叉验证:如果同时设置了最低和最高版本号,最高版本号必须 >= 最低版本号
|
||||
if req.MinClaudeCodeVersion != "" && req.MaxClaudeCodeVersion != "" {
|
||||
if service.CompareVersions(req.MaxClaudeCodeVersion, req.MinClaudeCodeVersion) < 0 {
|
||||
response.Error(c, http.StatusBadRequest, "max_claude_code_version must be greater than or equal to min_claude_code_version")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
settings := &service.SystemSettings{
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||
PasswordResetEnabled: req.PasswordResetEnabled,
|
||||
FrontendURL: req.FrontendURL,
|
||||
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
||||
TotpEnabled: req.TotpEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
@@ -459,8 +562,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
HideCcsImportButton: req.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: purchaseEnabled,
|
||||
PurchaseSubscriptionURL: purchaseURL,
|
||||
SoraClientEnabled: req.SoraClientEnabled,
|
||||
CustomMenuItems: customMenuJSON,
|
||||
CustomEndpoints: customEndpointsJSON,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
DefaultSubscriptions: defaultSubscriptions,
|
||||
@@ -472,7 +575,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
|
||||
MaxClaudeCodeVersion: req.MaxClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: req.BackendModeEnabled,
|
||||
OpsMonitoringEnabled: func() bool {
|
||||
if req.OpsMonitoringEnabled != nil {
|
||||
return *req.OpsMonitoringEnabled
|
||||
@@ -497,6 +602,18 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
return previousSettings.OpsMetricsIntervalSeconds
|
||||
}(),
|
||||
EnableFingerprintUnification: func() bool {
|
||||
if req.EnableFingerprintUnification != nil {
|
||||
return *req.EnableFingerprintUnification
|
||||
}
|
||||
return previousSettings.EnableFingerprintUnification
|
||||
}(),
|
||||
EnableMetadataPassthrough: func() bool {
|
||||
if req.EnableMetadataPassthrough != nil {
|
||||
return *req.EnableMetadataPassthrough
|
||||
}
|
||||
return previousSettings.EnableMetadataPassthrough
|
||||
}(),
|
||||
}
|
||||
|
||||
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
|
||||
@@ -526,6 +643,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
|
||||
FrontendURL: updatedSettings.FrontendURL,
|
||||
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
|
||||
TotpEnabled: updatedSettings.TotpEnabled,
|
||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||
@@ -553,8 +671,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
HideCcsImportButton: updatedSettings.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
|
||||
SoraClientEnabled: updatedSettings.SoraClientEnabled,
|
||||
CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems),
|
||||
CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
DefaultSubscriptions: updatedDefaultSubscriptions,
|
||||
@@ -570,7 +688,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault,
|
||||
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
|
||||
MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
|
||||
MaxClaudeCodeVersion: updatedSettings.MaxClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: updatedSettings.BackendModeEnabled,
|
||||
EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification,
|
||||
EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -608,6 +730,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.PasswordResetEnabled != after.PasswordResetEnabled {
|
||||
changed = append(changed, "password_reset_enabled")
|
||||
}
|
||||
if before.FrontendURL != after.FrontendURL {
|
||||
changed = append(changed, "frontend_url")
|
||||
}
|
||||
if before.TotpEnabled != after.TotpEnabled {
|
||||
changed = append(changed, "totp_enabled")
|
||||
}
|
||||
@@ -722,9 +847,15 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.MinClaudeCodeVersion != after.MinClaudeCodeVersion {
|
||||
changed = append(changed, "min_claude_code_version")
|
||||
}
|
||||
if before.MaxClaudeCodeVersion != after.MaxClaudeCodeVersion {
|
||||
changed = append(changed, "max_claude_code_version")
|
||||
}
|
||||
if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling {
|
||||
changed = append(changed, "allow_ungrouped_key_scheduling")
|
||||
}
|
||||
if before.BackendModeEnabled != after.BackendModeEnabled {
|
||||
changed = append(changed, "backend_mode_enabled")
|
||||
}
|
||||
if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled {
|
||||
changed = append(changed, "purchase_subscription_enabled")
|
||||
}
|
||||
@@ -734,6 +865,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.CustomMenuItems != after.CustomMenuItems {
|
||||
changed = append(changed, "custom_menu_items")
|
||||
}
|
||||
if before.EnableFingerprintUnification != after.EnableFingerprintUnification {
|
||||
changed = append(changed, "enable_fingerprint_unification")
|
||||
}
|
||||
if before.EnableMetadataPassthrough != after.EnableMetadataPassthrough {
|
||||
changed = append(changed, "enable_metadata_passthrough")
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
@@ -780,7 +917,7 @@ func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool {
|
||||
|
||||
// TestSMTPRequest 测试SMTP连接请求
|
||||
type TestSMTPRequest struct {
|
||||
SMTPHost string `json:"smtp_host" binding:"required"`
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
SMTPPort int `json:"smtp_port"`
|
||||
SMTPUsername string `json:"smtp_username"`
|
||||
SMTPPassword string `json:"smtp_password"`
|
||||
@@ -796,18 +933,35 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.SMTPPort <= 0 {
|
||||
req.SMTPPort = 587
|
||||
req.SMTPHost = strings.TrimSpace(req.SMTPHost)
|
||||
req.SMTPUsername = strings.TrimSpace(req.SMTPUsername)
|
||||
|
||||
var savedConfig *service.SMTPConfig
|
||||
if cfg, err := h.emailService.GetSMTPConfig(c.Request.Context()); err == nil && cfg != nil {
|
||||
savedConfig = cfg
|
||||
}
|
||||
|
||||
// 如果未提供密码,从数据库获取已保存的密码
|
||||
password := req.SMTPPassword
|
||||
if password == "" {
|
||||
savedConfig, err := h.emailService.GetSMTPConfig(c.Request.Context())
|
||||
if err == nil && savedConfig != nil {
|
||||
password = savedConfig.Password
|
||||
if req.SMTPHost == "" && savedConfig != nil {
|
||||
req.SMTPHost = savedConfig.Host
|
||||
}
|
||||
if req.SMTPPort <= 0 {
|
||||
if savedConfig != nil && savedConfig.Port > 0 {
|
||||
req.SMTPPort = savedConfig.Port
|
||||
} else {
|
||||
req.SMTPPort = 587
|
||||
}
|
||||
}
|
||||
if req.SMTPUsername == "" && savedConfig != nil {
|
||||
req.SMTPUsername = savedConfig.Username
|
||||
}
|
||||
password := strings.TrimSpace(req.SMTPPassword)
|
||||
if password == "" && savedConfig != nil {
|
||||
password = savedConfig.Password
|
||||
}
|
||||
if req.SMTPHost == "" {
|
||||
response.BadRequest(c, "SMTP host is required")
|
||||
return
|
||||
}
|
||||
|
||||
config := &service.SMTPConfig{
|
||||
Host: req.SMTPHost,
|
||||
@@ -829,7 +983,7 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) {
|
||||
// SendTestEmailRequest 发送测试邮件请求
|
||||
type SendTestEmailRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
SMTPHost string `json:"smtp_host" binding:"required"`
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
SMTPPort int `json:"smtp_port"`
|
||||
SMTPUsername string `json:"smtp_username"`
|
||||
SMTPPassword string `json:"smtp_password"`
|
||||
@@ -847,18 +1001,43 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.SMTPPort <= 0 {
|
||||
req.SMTPPort = 587
|
||||
req.SMTPHost = strings.TrimSpace(req.SMTPHost)
|
||||
req.SMTPUsername = strings.TrimSpace(req.SMTPUsername)
|
||||
req.SMTPFrom = strings.TrimSpace(req.SMTPFrom)
|
||||
req.SMTPFromName = strings.TrimSpace(req.SMTPFromName)
|
||||
|
||||
var savedConfig *service.SMTPConfig
|
||||
if cfg, err := h.emailService.GetSMTPConfig(c.Request.Context()); err == nil && cfg != nil {
|
||||
savedConfig = cfg
|
||||
}
|
||||
|
||||
// 如果未提供密码,从数据库获取已保存的密码
|
||||
password := req.SMTPPassword
|
||||
if password == "" {
|
||||
savedConfig, err := h.emailService.GetSMTPConfig(c.Request.Context())
|
||||
if err == nil && savedConfig != nil {
|
||||
password = savedConfig.Password
|
||||
if req.SMTPHost == "" && savedConfig != nil {
|
||||
req.SMTPHost = savedConfig.Host
|
||||
}
|
||||
if req.SMTPPort <= 0 {
|
||||
if savedConfig != nil && savedConfig.Port > 0 {
|
||||
req.SMTPPort = savedConfig.Port
|
||||
} else {
|
||||
req.SMTPPort = 587
|
||||
}
|
||||
}
|
||||
if req.SMTPUsername == "" && savedConfig != nil {
|
||||
req.SMTPUsername = savedConfig.Username
|
||||
}
|
||||
password := strings.TrimSpace(req.SMTPPassword)
|
||||
if password == "" && savedConfig != nil {
|
||||
password = savedConfig.Password
|
||||
}
|
||||
if req.SMTPFrom == "" && savedConfig != nil {
|
||||
req.SMTPFrom = savedConfig.From
|
||||
}
|
||||
if req.SMTPFromName == "" && savedConfig != nil {
|
||||
req.SMTPFromName = savedConfig.FromName
|
||||
}
|
||||
if req.SMTPHost == "" {
|
||||
response.BadRequest(c, "SMTP host is required")
|
||||
return
|
||||
}
|
||||
|
||||
config := &service.SMTPConfig{
|
||||
Host: req.SMTPHost,
|
||||
@@ -952,6 +1131,58 @@ func (h *SettingHandler) DeleteAdminAPIKey(c *gin.Context) {
|
||||
response.Success(c, gin.H{"message": "Admin API key deleted"})
|
||||
}
|
||||
|
||||
// GetOverloadCooldownSettings 获取529过载冷却配置
|
||||
// GET /api/v1/admin/settings/overload-cooldown
|
||||
func (h *SettingHandler) GetOverloadCooldownSettings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetOverloadCooldownSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.OverloadCooldownSettings{
|
||||
Enabled: settings.Enabled,
|
||||
CooldownMinutes: settings.CooldownMinutes,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateOverloadCooldownSettingsRequest 更新529过载冷却配置请求
|
||||
type UpdateOverloadCooldownSettingsRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
CooldownMinutes int `json:"cooldown_minutes"`
|
||||
}
|
||||
|
||||
// UpdateOverloadCooldownSettings 更新529过载冷却配置
|
||||
// PUT /api/v1/admin/settings/overload-cooldown
|
||||
func (h *SettingHandler) UpdateOverloadCooldownSettings(c *gin.Context) {
|
||||
var req UpdateOverloadCooldownSettingsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
settings := &service.OverloadCooldownSettings{
|
||||
Enabled: req.Enabled,
|
||||
CooldownMinutes: req.CooldownMinutes,
|
||||
}
|
||||
|
||||
if err := h.settingService.SetOverloadCooldownSettings(c.Request.Context(), settings); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
updatedSettings, err := h.settingService.GetOverloadCooldownSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.OverloadCooldownSettings{
|
||||
Enabled: updatedSettings.Enabled,
|
||||
CooldownMinutes: updatedSettings.CooldownMinutes,
|
||||
})
|
||||
}
|
||||
|
||||
// GetStreamTimeoutSettings 获取流超时处理配置
|
||||
// GET /api/v1/admin/settings/stream-timeout
|
||||
func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {
|
||||
@@ -970,384 +1201,6 @@ func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func toSoraS3SettingsDTO(settings *service.SoraS3Settings) dto.SoraS3Settings {
|
||||
if settings == nil {
|
||||
return dto.SoraS3Settings{}
|
||||
}
|
||||
return dto.SoraS3Settings{
|
||||
Enabled: settings.Enabled,
|
||||
Endpoint: settings.Endpoint,
|
||||
Region: settings.Region,
|
||||
Bucket: settings.Bucket,
|
||||
AccessKeyID: settings.AccessKeyID,
|
||||
SecretAccessKeyConfigured: settings.SecretAccessKeyConfigured,
|
||||
Prefix: settings.Prefix,
|
||||
ForcePathStyle: settings.ForcePathStyle,
|
||||
CDNURL: settings.CDNURL,
|
||||
DefaultStorageQuotaBytes: settings.DefaultStorageQuotaBytes,
|
||||
}
|
||||
}
|
||||
|
||||
func toSoraS3ProfileDTO(profile service.SoraS3Profile) dto.SoraS3Profile {
|
||||
return dto.SoraS3Profile{
|
||||
ProfileID: profile.ProfileID,
|
||||
Name: profile.Name,
|
||||
IsActive: profile.IsActive,
|
||||
Enabled: profile.Enabled,
|
||||
Endpoint: profile.Endpoint,
|
||||
Region: profile.Region,
|
||||
Bucket: profile.Bucket,
|
||||
AccessKeyID: profile.AccessKeyID,
|
||||
SecretAccessKeyConfigured: profile.SecretAccessKeyConfigured,
|
||||
Prefix: profile.Prefix,
|
||||
ForcePathStyle: profile.ForcePathStyle,
|
||||
CDNURL: profile.CDNURL,
|
||||
DefaultStorageQuotaBytes: profile.DefaultStorageQuotaBytes,
|
||||
UpdatedAt: profile.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func validateSoraS3RequiredWhenEnabled(enabled bool, endpoint, bucket, accessKeyID, secretAccessKey string, hasStoredSecret bool) error {
|
||||
if !enabled {
|
||||
return nil
|
||||
}
|
||||
if strings.TrimSpace(endpoint) == "" {
|
||||
return fmt.Errorf("S3 Endpoint is required when enabled")
|
||||
}
|
||||
if strings.TrimSpace(bucket) == "" {
|
||||
return fmt.Errorf("S3 Bucket is required when enabled")
|
||||
}
|
||||
if strings.TrimSpace(accessKeyID) == "" {
|
||||
return fmt.Errorf("S3 Access Key ID is required when enabled")
|
||||
}
|
||||
if strings.TrimSpace(secretAccessKey) != "" || hasStoredSecret {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("S3 Secret Access Key is required when enabled")
|
||||
}
|
||||
|
||||
func findSoraS3ProfileByID(items []service.SoraS3Profile, profileID string) *service.SoraS3Profile {
|
||||
for idx := range items {
|
||||
if items[idx].ProfileID == profileID {
|
||||
return &items[idx]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置接口)
|
||||
// GET /api/v1/admin/settings/sora-s3
|
||||
func (h *SettingHandler) GetSoraS3Settings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetSoraS3Settings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, toSoraS3SettingsDTO(settings))
|
||||
}
|
||||
|
||||
// ListSoraS3Profiles 获取 Sora S3 多配置
|
||||
// GET /api/v1/admin/settings/sora-s3/profiles
|
||||
func (h *SettingHandler) ListSoraS3Profiles(c *gin.Context) {
|
||||
result, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
items := make([]dto.SoraS3Profile, 0, len(result.Items))
|
||||
for idx := range result.Items {
|
||||
items = append(items, toSoraS3ProfileDTO(result.Items[idx]))
|
||||
}
|
||||
response.Success(c, dto.ListSoraS3ProfilesResponse{
|
||||
ActiveProfileID: result.ActiveProfileID,
|
||||
Items: items,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSoraS3SettingsRequest 更新/测试 Sora S3 配置请求(兼容旧接口)
|
||||
type UpdateSoraS3SettingsRequest struct {
|
||||
ProfileID string `json:"profile_id"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
CDNURL string `json:"cdn_url"`
|
||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||
}
|
||||
|
||||
type CreateSoraS3ProfileRequest struct {
|
||||
ProfileID string `json:"profile_id"`
|
||||
Name string `json:"name"`
|
||||
SetActive bool `json:"set_active"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
CDNURL string `json:"cdn_url"`
|
||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||
}
|
||||
|
||||
type UpdateSoraS3ProfileRequest struct {
|
||||
Name string `json:"name"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
CDNURL string `json:"cdn_url"`
|
||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||
}
|
||||
|
||||
// CreateSoraS3Profile 创建 Sora S3 配置
|
||||
// POST /api/v1/admin/settings/sora-s3/profiles
|
||||
func (h *SettingHandler) CreateSoraS3Profile(c *gin.Context) {
|
||||
var req CreateSoraS3ProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.DefaultStorageQuotaBytes < 0 {
|
||||
req.DefaultStorageQuotaBytes = 0
|
||||
}
|
||||
if strings.TrimSpace(req.Name) == "" {
|
||||
response.BadRequest(c, "Name is required")
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.ProfileID) == "" {
|
||||
response.BadRequest(c, "Profile ID is required")
|
||||
return
|
||||
}
|
||||
if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, false); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
created, err := h.settingService.CreateSoraS3Profile(c.Request.Context(), &service.SoraS3Profile{
|
||||
ProfileID: req.ProfileID,
|
||||
Name: req.Name,
|
||||
Enabled: req.Enabled,
|
||||
Endpoint: req.Endpoint,
|
||||
Region: req.Region,
|
||||
Bucket: req.Bucket,
|
||||
AccessKeyID: req.AccessKeyID,
|
||||
SecretAccessKey: req.SecretAccessKey,
|
||||
Prefix: req.Prefix,
|
||||
ForcePathStyle: req.ForcePathStyle,
|
||||
CDNURL: req.CDNURL,
|
||||
DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
|
||||
}, req.SetActive)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, toSoraS3ProfileDTO(*created))
|
||||
}
|
||||
|
||||
// UpdateSoraS3Profile 更新 Sora S3 配置
|
||||
// PUT /api/v1/admin/settings/sora-s3/profiles/:profile_id
|
||||
func (h *SettingHandler) UpdateSoraS3Profile(c *gin.Context) {
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Profile ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateSoraS3ProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.DefaultStorageQuotaBytes < 0 {
|
||||
req.DefaultStorageQuotaBytes = 0
|
||||
}
|
||||
if strings.TrimSpace(req.Name) == "" {
|
||||
response.BadRequest(c, "Name is required")
|
||||
return
|
||||
}
|
||||
|
||||
existingList, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
existing := findSoraS3ProfileByID(existingList.Items, profileID)
|
||||
if existing == nil {
|
||||
response.ErrorFrom(c, service.ErrSoraS3ProfileNotFound)
|
||||
return
|
||||
}
|
||||
if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
updated, updateErr := h.settingService.UpdateSoraS3Profile(c.Request.Context(), profileID, &service.SoraS3Profile{
|
||||
Name: req.Name,
|
||||
Enabled: req.Enabled,
|
||||
Endpoint: req.Endpoint,
|
||||
Region: req.Region,
|
||||
Bucket: req.Bucket,
|
||||
AccessKeyID: req.AccessKeyID,
|
||||
SecretAccessKey: req.SecretAccessKey,
|
||||
Prefix: req.Prefix,
|
||||
ForcePathStyle: req.ForcePathStyle,
|
||||
CDNURL: req.CDNURL,
|
||||
DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
|
||||
})
|
||||
if updateErr != nil {
|
||||
response.ErrorFrom(c, updateErr)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, toSoraS3ProfileDTO(*updated))
|
||||
}
|
||||
|
||||
// DeleteSoraS3Profile 删除 Sora S3 配置
|
||||
// DELETE /api/v1/admin/settings/sora-s3/profiles/:profile_id
|
||||
func (h *SettingHandler) DeleteSoraS3Profile(c *gin.Context) {
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Profile ID is required")
|
||||
return
|
||||
}
|
||||
if err := h.settingService.DeleteSoraS3Profile(c.Request.Context(), profileID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
// SetActiveSoraS3Profile 切换激活 Sora S3 配置
|
||||
// POST /api/v1/admin/settings/sora-s3/profiles/:profile_id/activate
|
||||
func (h *SettingHandler) SetActiveSoraS3Profile(c *gin.Context) {
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Profile ID is required")
|
||||
return
|
||||
}
|
||||
active, err := h.settingService.SetActiveSoraS3Profile(c.Request.Context(), profileID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, toSoraS3ProfileDTO(*active))
|
||||
}
|
||||
|
||||
// UpdateSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置接口)
|
||||
// PUT /api/v1/admin/settings/sora-s3
|
||||
func (h *SettingHandler) UpdateSoraS3Settings(c *gin.Context) {
|
||||
var req UpdateSoraS3SettingsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
existing, err := h.settingService.GetSoraS3Settings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if req.DefaultStorageQuotaBytes < 0 {
|
||||
req.DefaultStorageQuotaBytes = 0
|
||||
}
|
||||
if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
settings := &service.SoraS3Settings{
|
||||
Enabled: req.Enabled,
|
||||
Endpoint: req.Endpoint,
|
||||
Region: req.Region,
|
||||
Bucket: req.Bucket,
|
||||
AccessKeyID: req.AccessKeyID,
|
||||
SecretAccessKey: req.SecretAccessKey,
|
||||
Prefix: req.Prefix,
|
||||
ForcePathStyle: req.ForcePathStyle,
|
||||
CDNURL: req.CDNURL,
|
||||
DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
|
||||
}
|
||||
if err := h.settingService.SetSoraS3Settings(c.Request.Context(), settings); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
updatedSettings, err := h.settingService.GetSoraS3Settings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, toSoraS3SettingsDTO(updatedSettings))
|
||||
}
|
||||
|
||||
// TestSoraS3Connection 测试 Sora S3 连接(HeadBucket)
|
||||
// POST /api/v1/admin/settings/sora-s3/test
|
||||
func (h *SettingHandler) TestSoraS3Connection(c *gin.Context) {
|
||||
if h.soraS3Storage == nil {
|
||||
response.Error(c, 500, "S3 存储服务未初始化")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateSoraS3SettingsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if !req.Enabled {
|
||||
response.BadRequest(c, "S3 未启用,无法测试连接")
|
||||
return
|
||||
}
|
||||
|
||||
if req.SecretAccessKey == "" {
|
||||
if req.ProfileID != "" {
|
||||
profiles, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
|
||||
if err == nil {
|
||||
profile := findSoraS3ProfileByID(profiles.Items, req.ProfileID)
|
||||
if profile != nil {
|
||||
req.SecretAccessKey = profile.SecretAccessKey
|
||||
}
|
||||
}
|
||||
}
|
||||
if req.SecretAccessKey == "" {
|
||||
existing, err := h.settingService.GetSoraS3Settings(c.Request.Context())
|
||||
if err == nil {
|
||||
req.SecretAccessKey = existing.SecretAccessKey
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
testCfg := &service.SoraS3Settings{
|
||||
Enabled: true,
|
||||
Endpoint: req.Endpoint,
|
||||
Region: req.Region,
|
||||
Bucket: req.Bucket,
|
||||
AccessKeyID: req.AccessKeyID,
|
||||
SecretAccessKey: req.SecretAccessKey,
|
||||
Prefix: req.Prefix,
|
||||
ForcePathStyle: req.ForcePathStyle,
|
||||
CDNURL: req.CDNURL,
|
||||
}
|
||||
if err := h.soraS3Storage.TestConnectionWithSettings(c.Request.Context(), testCfg); err != nil {
|
||||
response.Error(c, 400, "S3 连接测试失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"message": "S3 连接成功"})
|
||||
}
|
||||
|
||||
// GetRectifierSettings 获取请求整流器配置
|
||||
// GET /api/v1/admin/settings/rectifier
|
||||
func (h *SettingHandler) GetRectifierSettings(c *gin.Context) {
|
||||
@@ -1357,18 +1210,26 @@ func (h *SettingHandler) GetRectifierSettings(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
patterns := settings.APIKeySignaturePatterns
|
||||
if patterns == nil {
|
||||
patterns = []string{}
|
||||
}
|
||||
response.Success(c, dto.RectifierSettings{
|
||||
Enabled: settings.Enabled,
|
||||
ThinkingSignatureEnabled: settings.ThinkingSignatureEnabled,
|
||||
ThinkingBudgetEnabled: settings.ThinkingBudgetEnabled,
|
||||
APIKeySignatureEnabled: settings.APIKeySignatureEnabled,
|
||||
APIKeySignaturePatterns: patterns,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRectifierSettingsRequest 更新整流器配置请求
|
||||
type UpdateRectifierSettingsRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"`
|
||||
ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"`
|
||||
Enabled bool `json:"enabled"`
|
||||
ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"`
|
||||
ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"`
|
||||
APIKeySignatureEnabled bool `json:"apikey_signature_enabled"`
|
||||
APIKeySignaturePatterns []string `json:"apikey_signature_patterns"`
|
||||
}
|
||||
|
||||
// UpdateRectifierSettings 更新请求整流器配置
|
||||
@@ -1380,10 +1241,32 @@ func (h *SettingHandler) UpdateRectifierSettings(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 校验并清理自定义匹配关键词
|
||||
const maxPatterns = 50
|
||||
const maxPatternLen = 500
|
||||
if len(req.APIKeySignaturePatterns) > maxPatterns {
|
||||
response.BadRequest(c, "Too many signature patterns (max 50)")
|
||||
return
|
||||
}
|
||||
var cleanedPatterns []string
|
||||
for _, p := range req.APIKeySignaturePatterns {
|
||||
p = strings.TrimSpace(p)
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
if len(p) > maxPatternLen {
|
||||
response.BadRequest(c, "Signature pattern too long (max 500 characters)")
|
||||
return
|
||||
}
|
||||
cleanedPatterns = append(cleanedPatterns, p)
|
||||
}
|
||||
|
||||
settings := &service.RectifierSettings{
|
||||
Enabled: req.Enabled,
|
||||
ThinkingSignatureEnabled: req.ThinkingSignatureEnabled,
|
||||
ThinkingBudgetEnabled: req.ThinkingBudgetEnabled,
|
||||
APIKeySignatureEnabled: req.APIKeySignatureEnabled,
|
||||
APIKeySignaturePatterns: cleanedPatterns,
|
||||
}
|
||||
|
||||
if err := h.settingService.SetRectifierSettings(c.Request.Context(), settings); err != nil {
|
||||
@@ -1398,10 +1281,16 @@ func (h *SettingHandler) UpdateRectifierSettings(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
updatedPatterns := updatedSettings.APIKeySignaturePatterns
|
||||
if updatedPatterns == nil {
|
||||
updatedPatterns = []string{}
|
||||
}
|
||||
response.Success(c, dto.RectifierSettings{
|
||||
Enabled: updatedSettings.Enabled,
|
||||
ThinkingSignatureEnabled: updatedSettings.ThinkingSignatureEnabled,
|
||||
ThinkingBudgetEnabled: updatedSettings.ThinkingBudgetEnabled,
|
||||
APIKeySignatureEnabled: updatedSettings.APIKeySignatureEnabled,
|
||||
APIKeySignaturePatterns: updatedPatterns,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -77,12 +77,13 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
status := c.Query("status")
|
||||
platform := c.Query("platform")
|
||||
|
||||
// Parse sorting parameters
|
||||
sortBy := c.DefaultQuery("sort_by", "created_at")
|
||||
sortOrder := c.DefaultQuery("sort_order", "desc")
|
||||
|
||||
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder)
|
||||
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, platform, sortBy, sortOrder)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -218,11 +219,12 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||
|
||||
// ResetSubscriptionQuotaRequest represents the reset quota request
|
||||
type ResetSubscriptionQuotaRequest struct {
|
||||
Daily bool `json:"daily"`
|
||||
Weekly bool `json:"weekly"`
|
||||
Daily bool `json:"daily"`
|
||||
Weekly bool `json:"weekly"`
|
||||
Monthly bool `json:"monthly"`
|
||||
}
|
||||
|
||||
// ResetQuota resets daily and/or weekly usage for a subscription.
|
||||
// ResetQuota resets daily, weekly, and/or monthly usage for a subscription.
|
||||
// POST /api/v1/admin/subscriptions/:id/reset-quota
|
||||
func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
|
||||
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
@@ -235,11 +237,11 @@ func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if !req.Daily && !req.Weekly {
|
||||
response.BadRequest(c, "At least one of 'daily' or 'weekly' must be true")
|
||||
if !req.Daily && !req.Weekly && !req.Monthly {
|
||||
response.BadRequest(c, "At least one of 'daily', 'weekly', or 'monthly' must be true")
|
||||
return
|
||||
}
|
||||
sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly)
|
||||
sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly, req.Monthly)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -0,0 +1,234 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// TLSFingerprintProfileHandler 处理 TLS 指纹模板的 HTTP 请求
|
||||
type TLSFingerprintProfileHandler struct {
|
||||
service *service.TLSFingerprintProfileService
|
||||
}
|
||||
|
||||
// NewTLSFingerprintProfileHandler 创建 TLS 指纹模板处理器
|
||||
func NewTLSFingerprintProfileHandler(service *service.TLSFingerprintProfileService) *TLSFingerprintProfileHandler {
|
||||
return &TLSFingerprintProfileHandler{service: service}
|
||||
}
|
||||
|
||||
// CreateTLSFingerprintProfileRequest 创建模板请求
|
||||
type CreateTLSFingerprintProfileRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description *string `json:"description"`
|
||||
EnableGREASE *bool `json:"enable_grease"`
|
||||
CipherSuites []uint16 `json:"cipher_suites"`
|
||||
Curves []uint16 `json:"curves"`
|
||||
PointFormats []uint16 `json:"point_formats"`
|
||||
SignatureAlgorithms []uint16 `json:"signature_algorithms"`
|
||||
ALPNProtocols []string `json:"alpn_protocols"`
|
||||
SupportedVersions []uint16 `json:"supported_versions"`
|
||||
KeyShareGroups []uint16 `json:"key_share_groups"`
|
||||
PSKModes []uint16 `json:"psk_modes"`
|
||||
Extensions []uint16 `json:"extensions"`
|
||||
}
|
||||
|
||||
// UpdateTLSFingerprintProfileRequest 更新模板请求(部分更新)
|
||||
type UpdateTLSFingerprintProfileRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Description *string `json:"description"`
|
||||
EnableGREASE *bool `json:"enable_grease"`
|
||||
CipherSuites []uint16 `json:"cipher_suites"`
|
||||
Curves []uint16 `json:"curves"`
|
||||
PointFormats []uint16 `json:"point_formats"`
|
||||
SignatureAlgorithms []uint16 `json:"signature_algorithms"`
|
||||
ALPNProtocols []string `json:"alpn_protocols"`
|
||||
SupportedVersions []uint16 `json:"supported_versions"`
|
||||
KeyShareGroups []uint16 `json:"key_share_groups"`
|
||||
PSKModes []uint16 `json:"psk_modes"`
|
||||
Extensions []uint16 `json:"extensions"`
|
||||
}
|
||||
|
||||
// List 获取所有模板
|
||||
// GET /api/v1/admin/tls-fingerprint-profiles
|
||||
func (h *TLSFingerprintProfileHandler) List(c *gin.Context) {
|
||||
profiles, err := h.service.List(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, profiles)
|
||||
}
|
||||
|
||||
// GetByID 根据 ID 获取模板
|
||||
// GET /api/v1/admin/tls-fingerprint-profiles/:id
|
||||
func (h *TLSFingerprintProfileHandler) GetByID(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid profile ID")
|
||||
return
|
||||
}
|
||||
|
||||
profile, err := h.service.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if profile == nil {
|
||||
response.NotFound(c, "Profile not found")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, profile)
|
||||
}
|
||||
|
||||
// Create 创建模板
|
||||
// POST /api/v1/admin/tls-fingerprint-profiles
|
||||
func (h *TLSFingerprintProfileHandler) Create(c *gin.Context) {
|
||||
var req CreateTLSFingerprintProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
profile := &model.TLSFingerprintProfile{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
CipherSuites: req.CipherSuites,
|
||||
Curves: req.Curves,
|
||||
PointFormats: req.PointFormats,
|
||||
SignatureAlgorithms: req.SignatureAlgorithms,
|
||||
ALPNProtocols: req.ALPNProtocols,
|
||||
SupportedVersions: req.SupportedVersions,
|
||||
KeyShareGroups: req.KeyShareGroups,
|
||||
PSKModes: req.PSKModes,
|
||||
Extensions: req.Extensions,
|
||||
}
|
||||
|
||||
if req.EnableGREASE != nil {
|
||||
profile.EnableGREASE = *req.EnableGREASE
|
||||
}
|
||||
|
||||
created, err := h.service.Create(c.Request.Context(), profile)
|
||||
if err != nil {
|
||||
if _, ok := err.(*model.ValidationError); ok {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, created)
|
||||
}
|
||||
|
||||
// Update 更新模板(支持部分更新)
|
||||
// PUT /api/v1/admin/tls-fingerprint-profiles/:id
|
||||
func (h *TLSFingerprintProfileHandler) Update(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid profile ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateTLSFingerprintProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
existing, err := h.service.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if existing == nil {
|
||||
response.NotFound(c, "Profile not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 部分更新
|
||||
profile := &model.TLSFingerprintProfile{
|
||||
ID: id,
|
||||
Name: existing.Name,
|
||||
Description: existing.Description,
|
||||
EnableGREASE: existing.EnableGREASE,
|
||||
CipherSuites: existing.CipherSuites,
|
||||
Curves: existing.Curves,
|
||||
PointFormats: existing.PointFormats,
|
||||
SignatureAlgorithms: existing.SignatureAlgorithms,
|
||||
ALPNProtocols: existing.ALPNProtocols,
|
||||
SupportedVersions: existing.SupportedVersions,
|
||||
KeyShareGroups: existing.KeyShareGroups,
|
||||
PSKModes: existing.PSKModes,
|
||||
Extensions: existing.Extensions,
|
||||
}
|
||||
|
||||
if req.Name != nil {
|
||||
profile.Name = *req.Name
|
||||
}
|
||||
if req.Description != nil {
|
||||
profile.Description = req.Description
|
||||
}
|
||||
if req.EnableGREASE != nil {
|
||||
profile.EnableGREASE = *req.EnableGREASE
|
||||
}
|
||||
if req.CipherSuites != nil {
|
||||
profile.CipherSuites = req.CipherSuites
|
||||
}
|
||||
if req.Curves != nil {
|
||||
profile.Curves = req.Curves
|
||||
}
|
||||
if req.PointFormats != nil {
|
||||
profile.PointFormats = req.PointFormats
|
||||
}
|
||||
if req.SignatureAlgorithms != nil {
|
||||
profile.SignatureAlgorithms = req.SignatureAlgorithms
|
||||
}
|
||||
if req.ALPNProtocols != nil {
|
||||
profile.ALPNProtocols = req.ALPNProtocols
|
||||
}
|
||||
if req.SupportedVersions != nil {
|
||||
profile.SupportedVersions = req.SupportedVersions
|
||||
}
|
||||
if req.KeyShareGroups != nil {
|
||||
profile.KeyShareGroups = req.KeyShareGroups
|
||||
}
|
||||
if req.PSKModes != nil {
|
||||
profile.PSKModes = req.PSKModes
|
||||
}
|
||||
if req.Extensions != nil {
|
||||
profile.Extensions = req.Extensions
|
||||
}
|
||||
|
||||
updated, err := h.service.Update(c.Request.Context(), profile)
|
||||
if err != nil {
|
||||
if _, ok := err.(*model.ValidationError); ok {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, updated)
|
||||
}
|
||||
|
||||
// Delete 删除模板
|
||||
// DELETE /api/v1/admin/tls-fingerprint-profiles/:id
|
||||
func (h *TLSFingerprintProfileHandler) Delete(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid profile ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.service.Delete(c.Request.Context(), id); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Profile deleted successfully"})
|
||||
}
|
||||
@@ -110,6 +110,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
}
|
||||
|
||||
model := c.Query("model")
|
||||
billingMode := strings.TrimSpace(c.Query("billing_mode"))
|
||||
|
||||
var requestType *int16
|
||||
var stream *bool
|
||||
@@ -159,8 +160,8 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
// Set end time to end of day
|
||||
t = t.Add(24*time.Hour - time.Nanosecond)
|
||||
// Use half-open range [start, end), move to next calendar day start (DST-safe).
|
||||
t = t.AddDate(0, 0, 1)
|
||||
endTime = &t
|
||||
}
|
||||
|
||||
@@ -174,6 +175,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
BillingMode: billingMode,
|
||||
StartTime: startTime,
|
||||
EndTime: endTime,
|
||||
ExactTotal: exactTotal,
|
||||
@@ -234,6 +236,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
}
|
||||
|
||||
model := c.Query("model")
|
||||
billingMode := strings.TrimSpace(c.Query("billing_mode"))
|
||||
|
||||
var requestType *int16
|
||||
var stream *bool
|
||||
@@ -285,7 +288,8 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
|
||||
// 与 SQL 条件 created_at < end 对齐,使用次日 00:00 作为上边界(DST-safe)。
|
||||
endTime = endTime.AddDate(0, 0, 1)
|
||||
} else {
|
||||
period := c.DefaultQuery("period", "today")
|
||||
switch period {
|
||||
@@ -311,6 +315,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
BillingMode: billingMode,
|
||||
StartTime: &startTime,
|
||||
EndTime: &endTime,
|
||||
}
|
||||
|
||||
@@ -34,14 +34,13 @@ func NewUserHandler(adminService service.AdminService, concurrencyService *servi
|
||||
|
||||
// CreateUserRequest represents admin create user request
|
||||
type CreateUserRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
Username string `json:"username"`
|
||||
Notes string `json:"notes"`
|
||||
Balance float64 `json:"balance"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
AllowedGroups []int64 `json:"allowed_groups"`
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
Username string `json:"username"`
|
||||
Notes string `json:"notes"`
|
||||
Balance float64 `json:"balance"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
AllowedGroups []int64 `json:"allowed_groups"`
|
||||
}
|
||||
|
||||
// UpdateUserRequest represents admin update user request
|
||||
@@ -57,8 +56,7 @@ type UpdateUserRequest struct {
|
||||
AllowedGroups *[]int64 `json:"allowed_groups"`
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
// map[groupID]*rate,nil 表示删除该分组的专属倍率
|
||||
GroupRates map[int64]*float64 `json:"group_rates"`
|
||||
SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
|
||||
GroupRates map[int64]*float64 `json:"group_rates"`
|
||||
}
|
||||
|
||||
// UpdateBalanceRequest represents balance update request
|
||||
@@ -75,6 +73,7 @@ type UpdateBalanceRequest struct {
|
||||
// - role: filter by user role
|
||||
// - search: search in email, username
|
||||
// - attr[{id}]: filter by custom attribute value, e.g. attr[1]=company
|
||||
// - group_name: fuzzy filter by allowed group name
|
||||
func (h *UserHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
|
||||
@@ -89,6 +88,7 @@ func (h *UserHandler) List(c *gin.Context) {
|
||||
Status: c.Query("status"),
|
||||
Role: c.Query("role"),
|
||||
Search: search,
|
||||
GroupName: strings.TrimSpace(c.Query("group_name")),
|
||||
Attributes: parseAttributeFilters(c),
|
||||
}
|
||||
if raw, ok := c.GetQuery("include_subscriptions"); ok {
|
||||
@@ -180,14 +180,13 @@ func (h *UserHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
Username: req.Username,
|
||||
Notes: req.Notes,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
Username: req.Username,
|
||||
Notes: req.Notes,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
@@ -214,16 +213,15 @@ func (h *UserHandler) Update(c *gin.Context) {
|
||||
|
||||
// 使用指针类型直接传递,nil 表示未提供该字段
|
||||
user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
Username: req.Username,
|
||||
Notes: req.Notes,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
Status: req.Status,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
GroupRates: req.GroupRates,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
Username: req.Username,
|
||||
Notes: req.Notes,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
Status: req.Status,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
GroupRates: req.GroupRates,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
@@ -366,3 +364,35 @@ func (h *UserHandler) GetBalanceHistory(c *gin.Context) {
|
||||
"total_recharged": totalRecharged,
|
||||
})
|
||||
}
|
||||
|
||||
// ReplaceGroupRequest represents the request to replace a user's exclusive group
|
||||
type ReplaceGroupRequest struct {
|
||||
OldGroupID int64 `json:"old_group_id" binding:"required,gt=0"`
|
||||
NewGroupID int64 `json:"new_group_id" binding:"required,gt=0"`
|
||||
}
|
||||
|
||||
// ReplaceGroup handles replacing a user's exclusive group
|
||||
// POST /api/v1/admin/users/:id/replace-group
|
||||
func (h *UserHandler) ReplaceGroup(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req ReplaceGroupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.adminService.ReplaceUserGroup(c.Request.Context(), userID, req.OldGroupID, req.NewGroupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"migrated_keys": result.MigratedKeys,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -194,6 +194,12 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Backend mode: only admin can login
|
||||
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
|
||||
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||
return
|
||||
}
|
||||
|
||||
h.respondWithTokenPair(c, user)
|
||||
}
|
||||
|
||||
@@ -250,16 +256,22 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the login session
|
||||
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
||||
|
||||
// Get the user
|
||||
// Get the user (before session deletion so we can check backend mode)
|
||||
user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Backend mode: only admin can login (check BEFORE deleting session)
|
||||
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
|
||||
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the login session (only after all checks pass)
|
||||
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
||||
|
||||
h.respondWithTokenPair(c, user)
|
||||
}
|
||||
|
||||
@@ -447,9 +459,9 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
frontendBaseURL := strings.TrimSpace(h.cfg.Server.FrontendURL)
|
||||
frontendBaseURL := strings.TrimSpace(h.settingSvc.GetFrontendURL(c.Request.Context()))
|
||||
if frontendBaseURL == "" {
|
||||
slog.Error("server.frontend_url not configured; cannot build password reset link")
|
||||
slog.Error("frontend_url not configured in settings or config; cannot build password reset link")
|
||||
response.InternalError(c, "Password reset is not configured")
|
||||
return
|
||||
}
|
||||
@@ -522,16 +534,22 @@ func (h *AuthHandler) RefreshToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
|
||||
result, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Backend mode: block non-admin token refresh
|
||||
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && result.UserRole != "admin" {
|
||||
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, RefreshTokenResponse{
|
||||
AccessToken: tokenPair.AccessToken,
|
||||
RefreshToken: tokenPair.RefreshToken,
|
||||
ExpiresIn: tokenPair.ExpiresIn,
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
ExpiresIn: result.ExpiresIn,
|
||||
TokenType: "Bearer",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -59,11 +59,9 @@ func UserFromServiceAdmin(u *service.User) *AdminUser {
|
||||
return nil
|
||||
}
|
||||
return &AdminUser{
|
||||
User: *base,
|
||||
Notes: u.Notes,
|
||||
GroupRates: u.GroupRates,
|
||||
SoraStorageQuotaBytes: u.SoraStorageQuotaBytes,
|
||||
SoraStorageUsedBytes: u.SoraStorageUsedBytes,
|
||||
User: *base,
|
||||
Notes: u.Notes,
|
||||
GroupRates: u.GroupRates,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,14 +133,16 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
||||
return nil
|
||||
}
|
||||
out := &AdminGroup{
|
||||
Group: groupFromServiceBase(g),
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
MCPXMLInject: g.MCPXMLInject,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
AccountCount: g.AccountCount,
|
||||
SortOrder: g.SortOrder,
|
||||
Group: groupFromServiceBase(g),
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
MCPXMLInject: g.MCPXMLInject,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
AccountCount: g.AccountCount,
|
||||
ActiveAccountCount: g.ActiveAccountCount,
|
||||
RateLimitedAccountCount: g.RateLimitedAccountCount,
|
||||
SortOrder: g.SortOrder,
|
||||
}
|
||||
if len(g.AccountGroups) > 0 {
|
||||
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
||||
@@ -170,15 +170,12 @@ func groupFromServiceBase(g *service.Group) Group {
|
||||
ImagePrice1K: g.ImagePrice1K,
|
||||
ImagePrice2K: g.ImagePrice2K,
|
||||
ImagePrice4K: g.ImagePrice4K,
|
||||
SoraImagePrice360: g.SoraImagePrice360,
|
||||
SoraImagePrice540: g.SoraImagePrice540,
|
||||
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
|
||||
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHD,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
|
||||
SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
|
||||
AllowMessagesDispatch: g.AllowMessagesDispatch,
|
||||
RequireOAuthOnly: g.RequireOAuthOnly,
|
||||
RequirePrivacySet: g.RequirePrivacySet,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
@@ -250,6 +247,10 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
enabled := true
|
||||
out.EnableTLSFingerprint = &enabled
|
||||
}
|
||||
// TLS指纹模板ID
|
||||
if profileID := a.GetTLSFingerprintProfileID(); profileID > 0 {
|
||||
out.TLSFingerprintProfileID = &profileID
|
||||
}
|
||||
// 会话ID伪装开关
|
||||
if a.IsSessionIDMaskingEnabled() {
|
||||
enabled := true
|
||||
@@ -262,10 +263,18 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
target := a.GetCacheTTLOverrideTarget()
|
||||
out.CacheTTLOverrideTarget = &target
|
||||
}
|
||||
// 自定义 Base URL 中继转发
|
||||
if a.IsCustomBaseURLEnabled() {
|
||||
enabled := true
|
||||
out.CustomBaseURLEnabled = &enabled
|
||||
if customURL := a.GetCustomBaseURL(); customURL != "" {
|
||||
out.CustomBaseURL = &customURL
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 提取 API Key 账号配额限制(仅 apikey 类型有效)
|
||||
if a.Type == service.AccountTypeAPIKey {
|
||||
// 提取账号配额限制(apikey / bedrock 类型有效)
|
||||
if a.IsAPIKeyOrBedrock() {
|
||||
if limit := a.GetQuotaLimit(); limit > 0 {
|
||||
out.QuotaLimit = &limit
|
||||
used := a.GetQuotaUsed()
|
||||
@@ -274,13 +283,44 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
if limit := a.GetQuotaDailyLimit(); limit > 0 {
|
||||
out.QuotaDailyLimit = &limit
|
||||
used := a.GetQuotaDailyUsed()
|
||||
if a.IsDailyQuotaPeriodExpired() {
|
||||
used = 0
|
||||
}
|
||||
out.QuotaDailyUsed = &used
|
||||
}
|
||||
if limit := a.GetQuotaWeeklyLimit(); limit > 0 {
|
||||
out.QuotaWeeklyLimit = &limit
|
||||
used := a.GetQuotaWeeklyUsed()
|
||||
if a.IsWeeklyQuotaPeriodExpired() {
|
||||
used = 0
|
||||
}
|
||||
out.QuotaWeeklyUsed = &used
|
||||
}
|
||||
// 固定时间重置配置
|
||||
if mode := a.GetQuotaDailyResetMode(); mode == "fixed" {
|
||||
out.QuotaDailyResetMode = &mode
|
||||
hour := a.GetQuotaDailyResetHour()
|
||||
out.QuotaDailyResetHour = &hour
|
||||
}
|
||||
if mode := a.GetQuotaWeeklyResetMode(); mode == "fixed" {
|
||||
out.QuotaWeeklyResetMode = &mode
|
||||
day := a.GetQuotaWeeklyResetDay()
|
||||
out.QuotaWeeklyResetDay = &day
|
||||
hour := a.GetQuotaWeeklyResetHour()
|
||||
out.QuotaWeeklyResetHour = &hour
|
||||
}
|
||||
if a.GetQuotaDailyResetMode() == "fixed" || a.GetQuotaWeeklyResetMode() == "fixed" {
|
||||
tz := a.GetQuotaResetTimezone()
|
||||
out.QuotaResetTimezone = &tz
|
||||
}
|
||||
if a.Extra != nil {
|
||||
if v, ok := a.Extra["quota_daily_reset_at"].(string); ok && v != "" {
|
||||
out.QuotaDailyResetAt = &v
|
||||
}
|
||||
if v, ok := a.Extra["quota_weekly_reset_at"].(string); ok && v != "" {
|
||||
out.QuotaWeeklyResetAt = &v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
@@ -489,15 +529,21 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
// 普通用户 DTO:严禁包含管理员字段(例如 account_rate_multiplier、ip_address、account)。
|
||||
requestType := l.EffectiveRequestType()
|
||||
stream, openAIWSMode := service.ApplyLegacyRequestFields(requestType, l.Stream, l.OpenAIWSMode)
|
||||
requestedModel := l.RequestedModel
|
||||
if requestedModel == "" {
|
||||
requestedModel = l.Model
|
||||
}
|
||||
return UsageLog{
|
||||
ID: l.ID,
|
||||
UserID: l.UserID,
|
||||
APIKeyID: l.APIKeyID,
|
||||
AccountID: l.AccountID,
|
||||
RequestID: l.RequestID,
|
||||
Model: l.Model,
|
||||
Model: requestedModel,
|
||||
ServiceTier: l.ServiceTier,
|
||||
ReasoningEffort: l.ReasoningEffort,
|
||||
InboundEndpoint: l.InboundEndpoint,
|
||||
UpstreamEndpoint: l.UpstreamEndpoint,
|
||||
GroupID: l.GroupID,
|
||||
SubscriptionID: l.SubscriptionID,
|
||||
InputTokens: l.InputTokens,
|
||||
@@ -524,6 +570,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
MediaType: l.MediaType,
|
||||
UserAgent: l.UserAgent,
|
||||
CacheTTLOverridden: l.CacheTTLOverridden,
|
||||
BillingMode: l.BillingMode,
|
||||
CreatedAt: l.CreatedAt,
|
||||
User: UserFromServiceShallow(l.User),
|
||||
APIKey: APIKeyFromService(l.APIKey),
|
||||
@@ -550,6 +597,10 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog {
|
||||
}
|
||||
return &AdminUsageLog{
|
||||
UsageLog: usageLogFromServiceUser(l),
|
||||
UpstreamModel: l.UpstreamModel,
|
||||
ChannelID: l.ChannelID,
|
||||
ModelMappingChain: l.ModelMappingChain,
|
||||
BillingTier: l.BillingTier,
|
||||
AccountRateMultiplier: l.AccountRateMultiplier,
|
||||
IPAddress: l.IPAddress,
|
||||
Account: AccountSummaryFromService(l.Account),
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -76,10 +77,14 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serviceTier := "priority"
|
||||
inboundEndpoint := "/v1/chat/completions"
|
||||
upstreamEndpoint := "/v1/responses"
|
||||
log := &service.UsageLog{
|
||||
RequestID: "req_3",
|
||||
Model: "gpt-5.4",
|
||||
ServiceTier: &serviceTier,
|
||||
InboundEndpoint: &inboundEndpoint,
|
||||
UpstreamEndpoint: &upstreamEndpoint,
|
||||
AccountRateMultiplier: f64Ptr(1.5),
|
||||
}
|
||||
|
||||
@@ -88,12 +93,61 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
|
||||
|
||||
require.NotNil(t, userDTO.ServiceTier)
|
||||
require.Equal(t, serviceTier, *userDTO.ServiceTier)
|
||||
require.NotNil(t, userDTO.InboundEndpoint)
|
||||
require.Equal(t, inboundEndpoint, *userDTO.InboundEndpoint)
|
||||
require.NotNil(t, userDTO.UpstreamEndpoint)
|
||||
require.Equal(t, upstreamEndpoint, *userDTO.UpstreamEndpoint)
|
||||
require.NotNil(t, adminDTO.ServiceTier)
|
||||
require.Equal(t, serviceTier, *adminDTO.ServiceTier)
|
||||
require.NotNil(t, adminDTO.InboundEndpoint)
|
||||
require.Equal(t, inboundEndpoint, *adminDTO.InboundEndpoint)
|
||||
require.NotNil(t, adminDTO.UpstreamEndpoint)
|
||||
require.Equal(t, upstreamEndpoint, *adminDTO.UpstreamEndpoint)
|
||||
require.NotNil(t, adminDTO.AccountRateMultiplier)
|
||||
require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
func TestUsageLogFromService_UsesRequestedModelAndKeepsUpstreamAdminOnly(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
upstreamModel := "claude-sonnet-4-20250514"
|
||||
log := &service.UsageLog{
|
||||
RequestID: "req_4",
|
||||
Model: upstreamModel,
|
||||
RequestedModel: "claude-sonnet-4",
|
||||
UpstreamModel: &upstreamModel,
|
||||
}
|
||||
|
||||
userDTO := UsageLogFromService(log)
|
||||
adminDTO := UsageLogFromServiceAdmin(log)
|
||||
|
||||
require.Equal(t, "claude-sonnet-4", userDTO.Model)
|
||||
require.Equal(t, "claude-sonnet-4", adminDTO.Model)
|
||||
|
||||
userJSON, err := json.Marshal(userDTO)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(userJSON), "upstream_model")
|
||||
|
||||
adminJSON, err := json.Marshal(adminDTO)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(adminJSON), `"upstream_model":"claude-sonnet-4-20250514"`)
|
||||
}
|
||||
|
||||
func TestUsageLogFromService_FallsBackToLegacyModelWhenRequestedModelMissing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
log := &service.UsageLog{
|
||||
RequestID: "req_legacy",
|
||||
Model: "claude-3",
|
||||
}
|
||||
|
||||
userDTO := UsageLogFromService(log)
|
||||
adminDTO := UsageLogFromServiceAdmin(log)
|
||||
|
||||
require.Equal(t, "claude-3", userDTO.Model)
|
||||
require.Equal(t, "claude-3", adminDTO.Model)
|
||||
}
|
||||
|
||||
func f64Ptr(value float64) *float64 {
|
||||
return &value
|
||||
}
|
||||
|
||||
@@ -15,6 +15,13 @@ type CustomMenuItem struct {
|
||||
SortOrder int `json:"sort_order"`
|
||||
}
|
||||
|
||||
// CustomEndpoint represents an admin-configured API endpoint for quick copy.
|
||||
type CustomEndpoint struct {
|
||||
Name string `json:"name"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
// SystemSettings represents the admin settings API response payload.
|
||||
type SystemSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
@@ -22,6 +29,7 @@ type SystemSettings struct {
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
FrontendURL string `json:"frontend_url"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
||||
@@ -53,8 +61,8 @@ type SystemSettings struct {
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
|
||||
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
@@ -78,9 +86,17 @@ type SystemSettings struct {
|
||||
OpsMetricsIntervalSeconds int `json:"ops_metrics_interval_seconds"`
|
||||
|
||||
MinClaudeCodeVersion string `json:"min_claude_code_version"`
|
||||
MaxClaudeCodeVersion string `json:"max_claude_code_version"`
|
||||
|
||||
// 分组隔离
|
||||
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
||||
|
||||
// Backend Mode
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
|
||||
// Gateway forwarding behavior
|
||||
EnableFingerprintUnification bool `json:"enable_fingerprint_unification"`
|
||||
EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
|
||||
}
|
||||
|
||||
type DefaultSubscriptionSetting struct {
|
||||
@@ -109,47 +125,16 @@ type PublicSettings struct {
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// SoraS3Settings Sora S3 存储配置 DTO(响应用,不含敏感字段)
|
||||
type SoraS3Settings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
CDNURL string `json:"cdn_url"`
|
||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||
}
|
||||
|
||||
// SoraS3Profile Sora S3 存储配置项 DTO(响应用,不含敏感字段)
|
||||
type SoraS3Profile struct {
|
||||
ProfileID string `json:"profile_id"`
|
||||
Name string `json:"name"`
|
||||
IsActive bool `json:"is_active"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
CDNURL string `json:"cdn_url"`
|
||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
// ListSoraS3ProfilesResponse Sora S3 配置列表响应
|
||||
type ListSoraS3ProfilesResponse struct {
|
||||
ActiveProfileID string `json:"active_profile_id"`
|
||||
Items []SoraS3Profile `json:"items"`
|
||||
// OverloadCooldownSettings 529过载冷却配置 DTO
|
||||
type OverloadCooldownSettings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
CooldownMinutes int `json:"cooldown_minutes"`
|
||||
}
|
||||
|
||||
// StreamTimeoutSettings 流超时处理配置 DTO
|
||||
@@ -163,9 +148,11 @@ type StreamTimeoutSettings struct {
|
||||
|
||||
// RectifierSettings 请求整流器配置 DTO
|
||||
type RectifierSettings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"`
|
||||
ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"`
|
||||
Enabled bool `json:"enabled"`
|
||||
ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"`
|
||||
ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"`
|
||||
APIKeySignatureEnabled bool `json:"apikey_signature_enabled"`
|
||||
APIKeySignaturePatterns []string `json:"apikey_signature_patterns"`
|
||||
}
|
||||
|
||||
// BetaPolicyRule Beta 策略规则 DTO
|
||||
@@ -206,3 +193,17 @@ func ParseUserVisibleMenuItems(raw string) []CustomMenuItem {
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// ParseCustomEndpoints parses a JSON string into a slice of CustomEndpoint.
|
||||
// Returns empty slice on empty/invalid input.
|
||||
func ParseCustomEndpoints(raw string) []CustomEndpoint {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" || raw == "[]" {
|
||||
return []CustomEndpoint{}
|
||||
}
|
||||
var items []CustomEndpoint
|
||||
if err := json.Unmarshal([]byte(raw), &items); err != nil {
|
||||
return []CustomEndpoint{}
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
@@ -26,9 +26,7 @@ type AdminUser struct {
|
||||
Notes string `json:"notes"`
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
// map[groupID]rateMultiplier
|
||||
GroupRates map[int64]float64 `json:"group_rates,omitempty"`
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||
SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes"`
|
||||
GroupRates map[int64]float64 `json:"group_rates,omitempty"`
|
||||
}
|
||||
|
||||
type APIKey struct {
|
||||
@@ -84,24 +82,19 @@ type Group struct {
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
|
||||
// Sora 按次计费配置
|
||||
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
|
||||
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
|
||||
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
|
||||
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
|
||||
|
||||
// Claude Code 客户端限制
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
// 无效请求兜底分组
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||
|
||||
// OpenAI Messages 调度开关(用户侧需要此字段判断是否展示 Claude Code 教程)
|
||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
|
||||
|
||||
// 账号过滤控制(仅 OpenAI/Antigravity 平台有效)
|
||||
RequireOAuthOnly bool `json:"require_oauth_only"`
|
||||
RequirePrivacySet bool `json:"require_privacy_set"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
@@ -122,9 +115,11 @@ type AdminGroup struct {
|
||||
DefaultMappedModel string `json:"default_mapped_model"`
|
||||
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
AccountCount int64 `json:"account_count,omitempty"`
|
||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
AccountCount int64 `json:"account_count,omitempty"`
|
||||
ActiveAccountCount int64 `json:"active_account_count,omitempty"`
|
||||
RateLimitedAccountCount int64 `json:"rate_limited_account_count,omitempty"`
|
||||
|
||||
// 分组排序
|
||||
SortOrder int `json:"sort_order"`
|
||||
@@ -183,7 +178,8 @@ type Account struct {
|
||||
|
||||
// TLS指纹伪装(仅 Anthropic OAuth/SetupToken 账号有效)
|
||||
// 从 extra 字段提取,方便前端显示和编辑
|
||||
EnableTLSFingerprint *bool `json:"enable_tls_fingerprint,omitempty"`
|
||||
EnableTLSFingerprint *bool `json:"enable_tls_fingerprint,omitempty"`
|
||||
TLSFingerprintProfileID *int64 `json:"tls_fingerprint_profile_id,omitempty"`
|
||||
|
||||
// 会话ID伪装(仅 Anthropic OAuth/SetupToken 账号有效)
|
||||
// 启用后将在15分钟内固定 metadata.user_id 中的 session ID
|
||||
@@ -195,6 +191,10 @@ type Account struct {
|
||||
CacheTTLOverrideEnabled *bool `json:"cache_ttl_override_enabled,omitempty"`
|
||||
CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"`
|
||||
|
||||
// 自定义 Base URL 中继转发(仅 Anthropic OAuth/SetupToken 账号有效)
|
||||
CustomBaseURLEnabled *bool `json:"custom_base_url_enabled,omitempty"`
|
||||
CustomBaseURL *string `json:"custom_base_url,omitempty"`
|
||||
|
||||
// API Key 账号配额限制
|
||||
QuotaLimit *float64 `json:"quota_limit,omitempty"`
|
||||
QuotaUsed *float64 `json:"quota_used,omitempty"`
|
||||
@@ -203,6 +203,16 @@ type Account struct {
|
||||
QuotaWeeklyLimit *float64 `json:"quota_weekly_limit,omitempty"`
|
||||
QuotaWeeklyUsed *float64 `json:"quota_weekly_used,omitempty"`
|
||||
|
||||
// 配额固定时间重置配置
|
||||
QuotaDailyResetMode *string `json:"quota_daily_reset_mode,omitempty"`
|
||||
QuotaDailyResetHour *int `json:"quota_daily_reset_hour,omitempty"`
|
||||
QuotaWeeklyResetMode *string `json:"quota_weekly_reset_mode,omitempty"`
|
||||
QuotaWeeklyResetDay *int `json:"quota_weekly_reset_day,omitempty"`
|
||||
QuotaWeeklyResetHour *int `json:"quota_weekly_reset_hour,omitempty"`
|
||||
QuotaResetTimezone *string `json:"quota_reset_timezone,omitempty"`
|
||||
QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"`
|
||||
QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"`
|
||||
|
||||
Proxy *Proxy `json:"proxy,omitempty"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
|
||||
@@ -324,9 +334,13 @@ type UsageLog struct {
|
||||
Model string `json:"model"`
|
||||
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
|
||||
ServiceTier *string `json:"service_tier,omitempty"`
|
||||
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API).
|
||||
// nil means not provided / not applicable.
|
||||
// ReasoningEffort is the request's reasoning effort level.
|
||||
// OpenAI: "low"/"medium"/"high"/"xhigh"; Claude: "low"/"medium"/"high"/"max".
|
||||
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
||||
// InboundEndpoint is the client-facing API endpoint path, e.g. /v1/chat/completions.
|
||||
InboundEndpoint *string `json:"inbound_endpoint,omitempty"`
|
||||
// UpstreamEndpoint is the normalized upstream endpoint path, e.g. /v1/responses.
|
||||
UpstreamEndpoint *string `json:"upstream_endpoint,omitempty"`
|
||||
|
||||
GroupID *int64 `json:"group_id"`
|
||||
SubscriptionID *int64 `json:"subscription_id"`
|
||||
@@ -365,6 +379,9 @@ type UsageLog struct {
|
||||
// Cache TTL Override 标记
|
||||
CacheTTLOverridden bool `json:"cache_ttl_overridden"`
|
||||
|
||||
// BillingMode 计费模式:token/image
|
||||
BillingMode *string `json:"billing_mode,omitempty"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
@@ -377,6 +394,17 @@ type UsageLog struct {
|
||||
type AdminUsageLog struct {
|
||||
UsageLog
|
||||
|
||||
// UpstreamModel is the actual model sent to the upstream provider after mapping.
|
||||
// Omitted when no mapping was applied (requested model was used as-is).
|
||||
UpstreamModel *string `json:"upstream_model,omitempty"`
|
||||
|
||||
// ChannelID 渠道 ID
|
||||
ChannelID *int64 `json:"channel_id,omitempty"`
|
||||
// ModelMappingChain 模型映射链,如 "a→b→c"
|
||||
ModelMappingChain *string `json:"model_mapping_chain,omitempty"`
|
||||
// BillingTier 计费层级标签(per_request/image 模式)
|
||||
BillingTier *string `json:"billing_tier,omitempty"`
|
||||
|
||||
// AccountRateMultiplier 账号计费倍率快照(nil 表示按 1.0 处理)
|
||||
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
|
||||
|
||||
|
||||
171
backend/internal/handler/endpoint.go
Normal file
171
backend/internal/handler/endpoint.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// Canonical inbound / upstream endpoint paths.
|
||||
// All normalization and derivation reference this single set
|
||||
// of constants — add new paths HERE when a new API surface
|
||||
// is introduced.
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
const (
|
||||
EndpointMessages = "/v1/messages"
|
||||
EndpointChatCompletions = "/v1/chat/completions"
|
||||
EndpointResponses = "/v1/responses"
|
||||
EndpointGeminiModels = "/v1beta/models"
|
||||
)
|
||||
|
||||
// gin.Context keys used by the middleware and helpers below.
|
||||
const (
|
||||
ctxKeyInboundEndpoint = "_gateway_inbound_endpoint"
|
||||
)
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// Normalization functions
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
// NormalizeInboundEndpoint maps a raw request path (which may carry
|
||||
// prefixes like /antigravity, /openai) to its canonical form.
|
||||
//
|
||||
// "/antigravity/v1/messages" → "/v1/messages"
|
||||
// "/v1/chat/completions" → "/v1/chat/completions"
|
||||
// "/openai/v1/responses/foo" → "/v1/responses"
|
||||
// "/v1beta/models/gemini:gen" → "/v1beta/models"
|
||||
func NormalizeInboundEndpoint(path string) string {
|
||||
path = strings.TrimSpace(path)
|
||||
switch {
|
||||
case strings.Contains(path, EndpointChatCompletions):
|
||||
return EndpointChatCompletions
|
||||
case strings.Contains(path, EndpointMessages):
|
||||
return EndpointMessages
|
||||
case strings.Contains(path, EndpointResponses):
|
||||
return EndpointResponses
|
||||
case strings.Contains(path, EndpointGeminiModels):
|
||||
return EndpointGeminiModels
|
||||
default:
|
||||
return path
|
||||
}
|
||||
}
|
||||
|
||||
// DeriveUpstreamEndpoint determines the upstream endpoint from the
|
||||
// account platform and the normalized inbound endpoint.
|
||||
//
|
||||
// Platform-specific rules:
|
||||
// - OpenAI always forwards to /v1/responses (with optional subpath
|
||||
// such as /v1/responses/compact preserved from the raw URL).
|
||||
// - Anthropic → /v1/messages
|
||||
// - Gemini → /v1beta/models
|
||||
// - Antigravity → /v1/messages (Claude) or gemini (Gemini)
|
||||
// - Antigravity routes may target either Claude or Gemini, so the
|
||||
// inbound endpoint is used to distinguish.
|
||||
func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string {
|
||||
inbound = strings.TrimSpace(inbound)
|
||||
|
||||
switch platform {
|
||||
case service.PlatformOpenAI:
|
||||
// OpenAI forwards everything to the Responses API.
|
||||
// Preserve subresource suffix (e.g. /v1/responses/compact).
|
||||
if suffix := responsesSubpathSuffix(rawRequestPath); suffix != "" {
|
||||
return EndpointResponses + suffix
|
||||
}
|
||||
return EndpointResponses
|
||||
|
||||
case service.PlatformAnthropic:
|
||||
return EndpointMessages
|
||||
|
||||
case service.PlatformGemini:
|
||||
return EndpointGeminiModels
|
||||
|
||||
case service.PlatformAntigravity:
|
||||
// Antigravity accounts serve both Claude and Gemini.
|
||||
if inbound == EndpointGeminiModels {
|
||||
return EndpointGeminiModels
|
||||
}
|
||||
return EndpointMessages
|
||||
}
|
||||
|
||||
// Unknown platform — fall back to inbound.
|
||||
return inbound
|
||||
}
|
||||
|
||||
// responsesSubpathSuffix extracts the part after "/responses" in a raw
|
||||
// request path, e.g. "/openai/v1/responses/compact" → "/compact".
|
||||
// Returns "" when there is no meaningful suffix.
|
||||
func responsesSubpathSuffix(rawPath string) string {
|
||||
trimmed := strings.TrimRight(strings.TrimSpace(rawPath), "/")
|
||||
idx := strings.LastIndex(trimmed, "/responses")
|
||||
if idx < 0 {
|
||||
return ""
|
||||
}
|
||||
suffix := trimmed[idx+len("/responses"):]
|
||||
if suffix == "" || suffix == "/" {
|
||||
return ""
|
||||
}
|
||||
if !strings.HasPrefix(suffix, "/") {
|
||||
return ""
|
||||
}
|
||||
return suffix
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// Middleware
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
// InboundEndpointMiddleware normalizes the request path and stores the
|
||||
// canonical inbound endpoint in gin.Context so that every handler in
|
||||
// the chain can read it via GetInboundEndpoint.
|
||||
//
|
||||
// Apply this middleware to all gateway route groups.
|
||||
func InboundEndpointMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
path := c.FullPath()
|
||||
if path == "" && c.Request != nil && c.Request.URL != nil {
|
||||
path = c.Request.URL.Path
|
||||
}
|
||||
c.Set(ctxKeyInboundEndpoint, NormalizeInboundEndpoint(path))
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// Context helpers — used by handlers before building
|
||||
// RecordUsageInput / RecordUsageLongContextInput.
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
// GetInboundEndpoint returns the canonical inbound endpoint stored by
|
||||
// InboundEndpointMiddleware. If the middleware did not run (e.g. in
|
||||
// tests), it falls back to normalizing c.FullPath() on the fly.
|
||||
func GetInboundEndpoint(c *gin.Context) string {
|
||||
if v, ok := c.Get(ctxKeyInboundEndpoint); ok {
|
||||
if s, ok := v.(string); ok && s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
// Fallback: normalize on the fly.
|
||||
path := ""
|
||||
if c != nil {
|
||||
path = c.FullPath()
|
||||
if path == "" && c.Request != nil && c.Request.URL != nil {
|
||||
path = c.Request.URL.Path
|
||||
}
|
||||
}
|
||||
return NormalizeInboundEndpoint(path)
|
||||
}
|
||||
|
||||
// GetUpstreamEndpoint derives the upstream endpoint from the context
|
||||
// and the account platform. Handlers call this after scheduling an
|
||||
// account, passing account.Platform.
|
||||
func GetUpstreamEndpoint(c *gin.Context, platform string) string {
|
||||
inbound := GetInboundEndpoint(c)
|
||||
rawPath := ""
|
||||
if c != nil && c.Request != nil && c.Request.URL != nil {
|
||||
rawPath = c.Request.URL.Path
|
||||
}
|
||||
return DeriveUpstreamEndpoint(inbound, rawPath, platform)
|
||||
}
|
||||
155
backend/internal/handler/endpoint_test.go
Normal file
155
backend/internal/handler/endpoint_test.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func init() { gin.SetMode(gin.TestMode) }
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// NormalizeInboundEndpoint
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
func TestNormalizeInboundEndpoint(t *testing.T) {
|
||||
tests := []struct {
|
||||
path string
|
||||
want string
|
||||
}{
|
||||
// Direct canonical paths.
|
||||
{"/v1/messages", EndpointMessages},
|
||||
{"/v1/chat/completions", EndpointChatCompletions},
|
||||
{"/v1/responses", EndpointResponses},
|
||||
{"/v1beta/models", EndpointGeminiModels},
|
||||
|
||||
// Prefixed paths (antigravity, openai).
|
||||
{"/antigravity/v1/messages", EndpointMessages},
|
||||
{"/openai/v1/responses", EndpointResponses},
|
||||
{"/openai/v1/responses/compact", EndpointResponses},
|
||||
{"/antigravity/v1beta/models/gemini:generateContent", EndpointGeminiModels},
|
||||
|
||||
// Gin route patterns with wildcards.
|
||||
{"/v1beta/models/*modelAction", EndpointGeminiModels},
|
||||
{"/v1/responses/*subpath", EndpointResponses},
|
||||
|
||||
// Unknown path is returned as-is.
|
||||
{"/v1/embeddings", "/v1/embeddings"},
|
||||
{"", ""},
|
||||
{" /v1/messages ", EndpointMessages},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.path, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, NormalizeInboundEndpoint(tt.path))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// DeriveUpstreamEndpoint
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
func TestDeriveUpstreamEndpoint(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inbound string
|
||||
rawPath string
|
||||
platform string
|
||||
want string
|
||||
}{
|
||||
// Anthropic.
|
||||
{"anthropic messages", EndpointMessages, "/v1/messages", service.PlatformAnthropic, EndpointMessages},
|
||||
|
||||
// Gemini.
|
||||
{"gemini models", EndpointGeminiModels, "/v1beta/models/gemini:gen", service.PlatformGemini, EndpointGeminiModels},
|
||||
|
||||
// OpenAI — always /v1/responses.
|
||||
{"openai responses root", EndpointResponses, "/v1/responses", service.PlatformOpenAI, EndpointResponses},
|
||||
{"openai responses compact", EndpointResponses, "/openai/v1/responses/compact", service.PlatformOpenAI, "/v1/responses/compact"},
|
||||
{"openai responses nested", EndpointResponses, "/openai/v1/responses/compact/detail", service.PlatformOpenAI, "/v1/responses/compact/detail"},
|
||||
{"openai from messages", EndpointMessages, "/v1/messages", service.PlatformOpenAI, EndpointResponses},
|
||||
{"openai from completions", EndpointChatCompletions, "/v1/chat/completions", service.PlatformOpenAI, EndpointResponses},
|
||||
|
||||
// Antigravity — uses inbound to pick Claude vs Gemini upstream.
|
||||
{"antigravity claude", EndpointMessages, "/antigravity/v1/messages", service.PlatformAntigravity, EndpointMessages},
|
||||
{"antigravity gemini", EndpointGeminiModels, "/antigravity/v1beta/models", service.PlatformAntigravity, EndpointGeminiModels},
|
||||
|
||||
// Unknown platform — passthrough.
|
||||
{"unknown platform", "/v1/embeddings", "/v1/embeddings", "unknown", "/v1/embeddings"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, DeriveUpstreamEndpoint(tt.inbound, tt.rawPath, tt.platform))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// responsesSubpathSuffix
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
func TestResponsesSubpathSuffix(t *testing.T) {
|
||||
tests := []struct {
|
||||
raw string
|
||||
want string
|
||||
}{
|
||||
{"/v1/responses", ""},
|
||||
{"/v1/responses/", ""},
|
||||
{"/v1/responses/compact", "/compact"},
|
||||
{"/openai/v1/responses/compact/detail", "/compact/detail"},
|
||||
{"/v1/messages", ""},
|
||||
{"", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.raw, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, responsesSubpathSuffix(tt.raw))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// InboundEndpointMiddleware + context helpers
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
func TestInboundEndpointMiddleware(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(InboundEndpointMiddleware())
|
||||
|
||||
var captured string
|
||||
router.POST("/v1/messages", func(c *gin.Context) {
|
||||
captured = GetInboundEndpoint(c)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, EndpointMessages, captured)
|
||||
}
|
||||
|
||||
func TestGetInboundEndpoint_FallbackWithoutMiddleware(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/antigravity/v1/messages", nil)
|
||||
|
||||
// Middleware did not run — fallback to normalizing c.Request.URL.Path.
|
||||
got := GetInboundEndpoint(c)
|
||||
require.Equal(t, EndpointMessages, got)
|
||||
}
|
||||
|
||||
func TestGetUpstreamEndpoint_FullFlow(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses/compact", nil)
|
||||
|
||||
// Simulate middleware.
|
||||
c.Set(ctxKeyInboundEndpoint, NormalizeInboundEndpoint(c.Request.URL.Path))
|
||||
|
||||
got := GetUpstreamEndpoint(c, service.PlatformOpenAI)
|
||||
require.Equal(t, "/v1/responses/compact", got)
|
||||
}
|
||||
@@ -158,6 +158,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
reqStream := parsedReq.Stream
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
// 解析渠道级模型映射
|
||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
|
||||
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
|
||||
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
|
||||
if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) {
|
||||
@@ -178,6 +181,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled()))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
// 验证 model 必填
|
||||
if reqModel == "" {
|
||||
@@ -291,7 +295,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "") // Gemini 不使用会话限制
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制
|
||||
if err != nil {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
@@ -391,6 +395,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if fs.SwitchCount > 0 {
|
||||
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||
}
|
||||
// 记录 Forward 前已写入字节数,Forward 后若增加则说明 SSE 内容已发,禁止 failover
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
|
||||
} else {
|
||||
@@ -402,6 +408,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
// 流式内容已写入客户端,无法撤销,禁止 failover 以防止流拼接腐化
|
||||
if c.Writer.Size() != writerSizeBeforeForward {
|
||||
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true)
|
||||
return
|
||||
}
|
||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
@@ -414,11 +425,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Error("gateway.forward_failed",
|
||||
forwardFailedFields := []zap.Field{
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("account_name", account.Name),
|
||||
zap.String("account_platform", account.Platform),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
if account.Proxy != nil {
|
||||
forwardFailedFields = append(forwardFailedFields,
|
||||
zap.Int64("proxy_id", account.Proxy.ID),
|
||||
zap.String("proxy_name", account.Proxy.Name),
|
||||
zap.String("proxy_host", account.Proxy.Host),
|
||||
zap.Int("proxy_port", account.Proxy.Port),
|
||||
)
|
||||
} else if account.ProxyID != nil {
|
||||
forwardFailedFields = append(forwardFailedFields, zap.Int64p("proxy_id", account.ProxyID))
|
||||
}
|
||||
reqLog.Error("gateway.forward_failed", forwardFailedFields...)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -434,19 +458,30 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
if result.ReasoningEffort == nil {
|
||||
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
|
||||
}
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.messages"),
|
||||
@@ -483,7 +518,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
for {
|
||||
// 选择支持该模型的账号
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID)
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, int64(0))
|
||||
if err != nil {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
@@ -629,12 +664,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
parsedReq.OnUpstreamAccepted = queueRelease
|
||||
// ===== 用户消息串行队列 END =====
|
||||
|
||||
// 应用渠道模型映射到请求
|
||||
if channelMapping.Mapped {
|
||||
parsedReq.Model = channelMapping.MappedModel
|
||||
parsedReq.Body = h.gatewayService.ReplaceModelInBody(parsedReq.Body, channelMapping.MappedModel)
|
||||
body = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
||||
}
|
||||
|
||||
// 转发请求 - 根据账号平台分流
|
||||
var result *service.ForwardResult
|
||||
requestCtx := c.Request.Context()
|
||||
if fs.SwitchCount > 0 {
|
||||
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||
}
|
||||
// 记录 Forward 前已写入字节数,Forward 后若增加则说明 SSE 内容已发,禁止 failover
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
|
||||
} else {
|
||||
@@ -704,6 +748,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
// 流式内容已写入客户端,无法撤销,禁止 failover 以防止流拼接腐化
|
||||
if c.Writer.Size() != writerSizeBeforeForward {
|
||||
h.handleFailoverExhausted(c, failoverErr, account.Platform, true)
|
||||
return
|
||||
}
|
||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
@@ -716,11 +765,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Error("gateway.forward_failed",
|
||||
forwardFailedFields := []zap.Field{
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("account_name", account.Name),
|
||||
zap.String("account_platform", account.Platform),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
if account.Proxy != nil {
|
||||
forwardFailedFields = append(forwardFailedFields,
|
||||
zap.Int64("proxy_id", account.Proxy.ID),
|
||||
zap.String("proxy_name", account.Proxy.Name),
|
||||
zap.String("proxy_host", account.Proxy.Host),
|
||||
zap.Int("proxy_port", account.Proxy.Port),
|
||||
)
|
||||
} else if account.ProxyID != nil {
|
||||
forwardFailedFields = append(forwardFailedFields, zap.Int64p("proxy_id", account.ProxyID))
|
||||
}
|
||||
reqLog.Error("gateway.forward_failed", forwardFailedFields...)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -736,19 +798,30 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
if result.ReasoningEffort == nil {
|
||||
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
|
||||
}
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: currentAPIKey,
|
||||
User: currentAPIKey.User,
|
||||
Account: account,
|
||||
Subscription: currentSubscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: currentAPIKey,
|
||||
User: currentAPIKey.User,
|
||||
Account: account,
|
||||
Subscription: currentSubscription,
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.messages"),
|
||||
@@ -786,14 +859,6 @@ func (h *GatewayHandler) Models(c *gin.Context) {
|
||||
platform = forcedPlatform
|
||||
}
|
||||
|
||||
if platform == service.PlatformSora {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"object": "list",
|
||||
"data": service.DefaultSoraModels(h.cfg),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get available models from account configurations (without platform filter)
|
||||
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "")
|
||||
|
||||
@@ -909,7 +974,7 @@ func (h *GatewayHandler) parseUsageDateRange(c *gin.Context) (time.Time, time.Ti
|
||||
}
|
||||
if s := c.Query("end_date"); s != "" {
|
||||
if t, err := timezone.ParseInLocation("2006-01-02", s); err == nil {
|
||||
endTime = t.Add(24*time.Hour - time.Second) // end of day
|
||||
endTime = t.AddDate(0, 0, 1) // half-open range upper bound
|
||||
}
|
||||
}
|
||||
return startTime, endTime
|
||||
@@ -1185,6 +1250,10 @@ func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *se
|
||||
}
|
||||
}
|
||||
|
||||
// 记录原始上游状态码,以便 ops 错误日志捕获真实的上游错误
|
||||
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
|
||||
|
||||
// 使用默认的错误映射
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
@@ -1193,6 +1262,7 @@ func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *se
|
||||
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
|
||||
func (h *GatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||
service.SetOpsUpstreamError(c, statusCode, errMsg, "")
|
||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
}
|
||||
|
||||
@@ -1242,7 +1312,7 @@ func (h *GatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarte
|
||||
return true
|
||||
}
|
||||
|
||||
// checkClaudeCodeVersion 检查 Claude Code 客户端版本是否满足最低要求
|
||||
// checkClaudeCodeVersion 检查 Claude Code 客户端版本是否满足版本要求
|
||||
// 仅对已识别的 Claude Code 客户端执行,count_tokens 路径除外
|
||||
func (h *GatewayHandler) checkClaudeCodeVersion(c *gin.Context) bool {
|
||||
ctx := c.Request.Context()
|
||||
@@ -1255,8 +1325,8 @@ func (h *GatewayHandler) checkClaudeCodeVersion(c *gin.Context) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
minVersion := h.settingService.GetMinClaudeCodeVersion(ctx)
|
||||
if minVersion == "" {
|
||||
minVersion, maxVersion := h.settingService.GetClaudeCodeVersionBounds(ctx)
|
||||
if minVersion == "" && maxVersion == "" {
|
||||
return true // 未设置,不检查
|
||||
}
|
||||
|
||||
@@ -1267,13 +1337,22 @@ func (h *GatewayHandler) checkClaudeCodeVersion(c *gin.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if service.CompareVersions(clientVersion, minVersion) < 0 {
|
||||
if minVersion != "" && service.CompareVersions(clientVersion, minVersion) < 0 {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error",
|
||||
fmt.Sprintf("Your Claude Code version (%s) is below the minimum required version (%s). Please update: npm update -g @anthropic-ai/claude-code",
|
||||
clientVersion, minVersion))
|
||||
return false
|
||||
}
|
||||
|
||||
if maxVersion != "" && service.CompareVersions(clientVersion, maxVersion) > 0 {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error",
|
||||
fmt.Sprintf("Your Claude Code version (%s) exceeds the maximum allowed version (%s). "+
|
||||
"Please downgrade: npm install -g @anthropic-ai/claude-code@%s && "+
|
||||
"set CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=1 to prevent auto-upgrade",
|
||||
clientVersion, maxVersion, maxVersion))
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -1348,6 +1427,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, parsedReq.Model, parsedReq.Stream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(parsedReq.Stream, false)))
|
||||
|
||||
// 获取订阅信息(可能为nil)
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
297
backend/internal/handler/gateway_handler_chat_completions.go
Normal file
297
backend/internal/handler/gateway_handler_chat_completions.go
Normal file
@@ -0,0 +1,297 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ChatCompletions handles the OpenAI Chat Completions API endpoint for Anthropic
// platform groups.
// POST /v1/chat/completions
// This converts Chat Completions requests to Anthropic format (via Responses
// format chain), forwards to Anthropic upstream, and converts responses back to
// Chat Completions format.
//
// Flow: auth context → body read/validate → model/stream extraction → channel
// model mapping → user concurrency slot → billing re-check → session hash →
// account selection/failover loop → forward → async usage recording.
// All error responses are written in Chat Completions error format via
// chatCompletionsErrorResponse.
func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
	// Tracks whether SSE output has begun; once true, error handlers must not
	// write a fresh JSON body.
	streamStarted := false

	requestStart := time.Now()

	apiKey, ok := middleware2.GetAPIKeyFromContext(c)
	if !ok {
		h.chatCompletionsErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
		return
	}

	subject, ok := middleware2.GetAuthSubjectFromContext(c)
	if !ok {
		h.chatCompletionsErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
		return
	}
	reqLog := requestLogger(
		c,
		"handler.gateway.chat_completions",
		zap.Int64("user_id", subject.UserID),
		zap.Int64("api_key_id", apiKey.ID),
		zap.Any("group_id", apiKey.GroupID),
	)

	// Read request body (size-limited; oversize yields 413 with the limit echoed back).
	body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
	if err != nil {
		if maxErr, ok := extractMaxBytesError(err); ok {
			h.chatCompletionsErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
			return
		}
		h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
		return
	}

	if len(body) == 0 {
		h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
		return
	}

	// Seed ops context early (model/stream unknown yet); refreshed below once parsed.
	setOpsRequestContext(c, "", false, body)

	// Validate JSON
	if !gjson.ValidBytes(body) {
		h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
		return
	}

	// Extract model and stream. Model must be a non-empty JSON string.
	modelResult := gjson.GetBytes(body, "model")
	if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
		h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
		return
	}
	reqModel := modelResult.String()
	reqStream := gjson.GetBytes(body, "stream").Bool()
	reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))

	setOpsRequestContext(c, reqModel, reqStream, body)
	setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))

	// Resolve channel-level model mapping (error intentionally ignored:
	// best-effort; an unmapped model simply forwards as-is).
	channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)

	// Claude Code only restriction: /v1/chat/completions can never originate
	// from a Claude Code client, so restricted groups are rejected outright.
	if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly {
		h.chatCompletionsErrorResponse(c, http.StatusForbidden, "permission_error",
			"This group is restricted to Claude Code clients (/v1/messages only)")
		return
	}

	// Error passthrough binding
	if h.errorPassthroughService != nil {
		service.BindErrorPassthroughService(c, h.errorPassthroughService)
	}

	subscription, _ := middleware2.GetSubscriptionFromContext(c)

	service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())

	// 1. Acquire user concurrency slot.
	// First bound the number of requests allowed to *wait* for a slot; a failed
	// increment is logged but tolerated (best-effort back-pressure).
	maxWait := service.CalculateMaxWait(subject.Concurrency)
	canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
	waitCounted := false
	if err != nil {
		reqLog.Warn("gateway.cc.user_wait_counter_increment_failed", zap.Error(err))
	} else if !canWait {
		h.chatCompletionsErrorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
		return
	}
	if err == nil && canWait {
		waitCounted = true
	}
	// Deferred decrement covers every early return while still counted; the
	// happy path decrements eagerly right after the slot is acquired.
	defer func() {
		if waitCounted {
			h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
		}
	}()

	userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
	if err != nil {
		reqLog.Warn("gateway.cc.user_slot_acquire_failed", zap.Error(err))
		h.handleConcurrencyError(c, err, "user", streamStarted)
		return
	}
	if waitCounted {
		h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
		waitCounted = false
	}
	userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
	if userReleaseFunc != nil {
		defer userReleaseFunc()
	}

	// 2. Re-check billing (cached eligibility may have changed since middleware ran).
	if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
		reqLog.Info("gateway.cc.billing_check_failed", zap.Error(err))
		status, code, message := billingErrorDetails(err)
		h.chatCompletionsErrorResponse(c, status, code, message)
		return
	}

	// Parse request for session hash. Parse failure is tolerated: a minimal
	// ParsedRequest is synthesized so session hashing and forwarding still work.
	parsedReq, _ := service.ParseGatewayRequest(body, "chat_completions")
	if parsedReq == nil {
		parsedReq = &service.ParsedRequest{Model: reqModel, Stream: reqStream, Body: body}
	}
	parsedReq.SessionContext = &service.SessionContext{
		ClientIP:  ip.GetClientIP(c),
		UserAgent: c.GetHeader("User-Agent"),
		APIKeyID:  apiKey.ID,
	}
	sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)

	// 3. Account selection + failover loop
	fs := NewFailoverState(h.maxAccountSwitches, false)

	for {
		selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
		if err != nil {
			// No account ever failed → genuinely nothing available; otherwise
			// consult the failover state for how to finish.
			if len(fs.FailedAccountIDs) == 0 {
				h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
				return
			}
			action := fs.HandleSelectionExhausted(c.Request.Context())
			switch action {
			case FailoverContinue:
				continue
			case FailoverCanceled:
				return
			default:
				if fs.LastFailoverErr != nil {
					h.handleCCFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
				} else {
					h.chatCompletionsErrorResponse(c, http.StatusBadGateway, "server_error", "All available accounts exhausted")
				}
				return
			}
		}
		account := selection.Account
		setOpsSelectedAccount(c, account.ID, account.Platform)

		// 4. Acquire account concurrency slot (unless selection already holds one).
		accountReleaseFunc := selection.ReleaseFunc
		if !selection.Acquired {
			if selection.WaitPlan == nil {
				h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
				return
			}
			accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
				c,
				account.ID,
				selection.WaitPlan.MaxConcurrency,
				selection.WaitPlan.Timeout,
				reqStream,
				&streamStarted,
			)
			if err != nil {
				reqLog.Warn("gateway.cc.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
				h.handleConcurrencyError(c, err, "account", streamStarted)
				return
			}
		}
		accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)

		// 5. Forward request. Record the bytes already written so we can detect
		// whether SSE content reached the client (which forbids failover).
		writerSizeBeforeForward := c.Writer.Size()
		forwardBody := body
		if channelMapping.Mapped {
			forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
		}
		result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq)

		// Release the account slot as soon as the upstream call finishes,
		// regardless of outcome.
		if accountReleaseFunc != nil {
			accountReleaseFunc()
		}

		if err != nil {
			var failoverErr *service.UpstreamFailoverError
			if errors.As(err, &failoverErr) {
				// Streamed bytes already reached the client and cannot be
				// retracted — failover would corrupt the stream, so stop here.
				if c.Writer.Size() != writerSizeBeforeForward {
					h.handleCCFailoverExhausted(c, failoverErr, true)
					return
				}
				action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
				switch action {
				case FailoverContinue:
					continue
				case FailoverExhausted:
					h.handleCCFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
					return
				case FailoverCanceled:
					return
				}
			}
			// Non-failover error: make sure *some* error response was written,
			// then log and bail.
			h.ensureForwardErrorResponse(c, streamStarted)
			reqLog.Error("gateway.cc.forward_failed",
				zap.Int64("account_id", account.ID),
				zap.Error(err),
			)
			return
		}

		// 6. Record usage. Capture request details now so the async task never
		// touches gin.Context from another goroutine.
		userAgent := c.GetHeader("User-Agent")
		clientIP := ip.GetClientIP(c)
		requestPayloadHash := service.HashUsageRequestPayload(body)
		inboundEndpoint := GetInboundEndpoint(c)
		upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)

		// Usage is recorded via a bounded worker pool rather than an unbounded
		// goroutine on the request hot path.
		h.submitUsageRecordTask(func(ctx context.Context) {
			if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
				Result:             result,
				APIKey:             apiKey,
				User:               apiKey.User,
				Account:            account,
				Subscription:       subscription,
				InboundEndpoint:    inboundEndpoint,
				UpstreamEndpoint:   upstreamEndpoint,
				UserAgent:          userAgent,
				IPAddress:          clientIP,
				RequestPayloadHash: requestPayloadHash,
				APIKeyService:      h.apiKeyService,
				ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
			}); err != nil {
				reqLog.Error("gateway.cc.record_usage_failed",
					zap.Int64("account_id", account.ID),
					zap.Error(err),
				)
			}
		})
		return
	}
}
|
||||
|
||||
// chatCompletionsErrorResponse writes an error in OpenAI Chat Completions format.
|
||||
func (h *GatewayHandler) chatCompletionsErrorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// handleCCFailoverExhausted writes a failover-exhausted error in CC format.
|
||||
func (h *GatewayHandler) handleCCFailoverExhausted(c *gin.Context, lastErr *service.UpstreamFailoverError, streamStarted bool) {
|
||||
if streamStarted {
|
||||
return
|
||||
}
|
||||
statusCode := http.StatusBadGateway
|
||||
if lastErr != nil && lastErr.StatusCode > 0 {
|
||||
statusCode = lastErr.StatusCode
|
||||
}
|
||||
h.chatCompletionsErrorResponse(c, statusCode, "server_error", "All available accounts exhausted")
|
||||
}
|
||||
303
backend/internal/handler/gateway_handler_responses.go
Normal file
303
backend/internal/handler/gateway_handler_responses.go
Normal file
@@ -0,0 +1,303 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Responses handles OpenAI Responses API endpoint for Anthropic platform groups.
|
||||
// POST /v1/responses
|
||||
// This converts Responses API requests to Anthropic format, forwards to Anthropic
|
||||
// upstream, and converts responses back to Responses format.
|
||||
func (h *GatewayHandler) Responses(c *gin.Context) {
|
||||
streamStarted := false
|
||||
|
||||
requestStart := time.Now()
|
||||
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.responsesErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.responsesErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.gateway.responses",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
// Read request body
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.responsesErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
|
||||
if len(body) == 0 {
|
||||
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
|
||||
// Validate JSON
|
||||
if !gjson.ValidBytes(body) {
|
||||
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
// Extract model and stream using gjson (like OpenAI handler)
|
||||
modelResult := gjson.GetBytes(body, "model")
|
||||
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
|
||||
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return
|
||||
}
|
||||
reqModel := modelResult.String()
|
||||
reqStream := gjson.GetBytes(body, "stream").Bool()
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
// 解析渠道级模型映射
|
||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
|
||||
// Claude Code only restriction:
|
||||
// /v1/responses is never a Claude Code endpoint.
|
||||
// When claude_code_only is enabled, this endpoint is rejected.
|
||||
// The existing service-layer checkClaudeCodeRestriction handles degradation
|
||||
// to fallback groups when the Forward path calls SelectAccountForModelWithExclusions.
|
||||
// Here we just reject at handler level since /v1/responses clients can't be Claude Code.
|
||||
if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly {
|
||||
h.responsesErrorResponse(c, http.StatusForbidden, "permission_error",
|
||||
"This group is restricted to Claude Code clients (/v1/messages only)")
|
||||
return
|
||||
}
|
||||
|
||||
// Error passthrough binding
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
}
|
||||
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
||||
|
||||
// 1. Acquire user concurrency slot
|
||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
waitCounted := false
|
||||
if err != nil {
|
||||
reqLog.Warn("gateway.responses.user_wait_counter_increment_failed", zap.Error(err))
|
||||
} else if !canWait {
|
||||
h.responsesErrorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
if err == nil && canWait {
|
||||
waitCounted = true
|
||||
}
|
||||
defer func() {
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
}
|
||||
}()
|
||||
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
reqLog.Warn("gateway.responses.user_slot_acquire_failed", zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
return
|
||||
}
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
waitCounted = false
|
||||
}
|
||||
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
// 2. Re-check billing
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.responsesErrorResponse(c, status, code, message)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse request for session hash
|
||||
parsedReq, _ := service.ParseGatewayRequest(body, "responses")
|
||||
if parsedReq == nil {
|
||||
parsedReq = &service.ParsedRequest{Model: reqModel, Stream: reqStream, Body: body}
|
||||
}
|
||||
parsedReq.SessionContext = &service.SessionContext{
|
||||
ClientIP: ip.GetClientIP(c),
|
||||
UserAgent: c.GetHeader("User-Agent"),
|
||||
APIKeyID: apiKey.ID,
|
||||
}
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
|
||||
// 3. Account selection + failover loop
|
||||
fs := NewFailoverState(h.maxAccountSwitches, false)
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
|
||||
if err != nil {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
action := fs.HandleSelectionExhausted(c.Request.Context())
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
continue
|
||||
case FailoverCanceled:
|
||||
return
|
||||
default:
|
||||
if fs.LastFailoverErr != nil {
|
||||
h.handleResponsesFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
|
||||
} else {
|
||||
h.responsesErrorResponse(c, http.StatusBadGateway, "server_error", "All available accounts exhausted")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
// 4. Acquire account concurrency slot
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
|
||||
return
|
||||
}
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
reqStream,
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("gateway.responses.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
}
|
||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||
|
||||
// 5. Forward request
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
forwardBody := body
|
||||
if channelMapping.Mapped {
|
||||
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
||||
}
|
||||
result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, forwardBody, parsedReq)
|
||||
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
// Can't failover if streaming content already sent
|
||||
if c.Writer.Size() != writerSizeBeforeForward {
|
||||
h.handleResponsesFailoverExhausted(c, failoverErr, true)
|
||||
return
|
||||
}
|
||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
continue
|
||||
case FailoverExhausted:
|
||||
h.handleResponsesFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
|
||||
return
|
||||
case FailoverCanceled:
|
||||
return
|
||||
}
|
||||
}
|
||||
h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Error("gateway.responses.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// 6. Record usage
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
reqLog.Error("gateway.responses.record_usage_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// responsesErrorResponse writes an error in OpenAI Responses API format.
|
||||
func (h *GatewayHandler) responsesErrorResponse(c *gin.Context, status int, code, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"code": code,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// handleResponsesFailoverExhausted writes a failover-exhausted error in Responses format.
|
||||
func (h *GatewayHandler) handleResponsesFailoverExhausted(c *gin.Context, lastErr *service.UpstreamFailoverError, streamStarted bool) {
|
||||
if streamStarted {
|
||||
return // Can't write error after stream started
|
||||
}
|
||||
statusCode := http.StatusBadGateway
|
||||
if lastErr != nil && lastErr.StatusCode > 0 {
|
||||
statusCode = lastErr.StatusCode
|
||||
}
|
||||
h.responsesErrorResponse(c, statusCode, "server_error", "All available accounts exhausted")
|
||||
}
|
||||
122
backend/internal/handler/gateway_handler_stream_failover_test.go
Normal file
122
backend/internal/handler/gateway_handler_stream_failover_test.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// partialMessageStartSSE 模拟 handleStreamingResponse 已写入的首批 SSE 事件。
|
||||
const partialMessageStartSSE = "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_01\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-5\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":10,\"output_tokens\":1}}}\n\n" +
|
||||
"event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"
|
||||
|
||||
// TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten 验证:
|
||||
// 当 Forward 在返回 UpstreamFailoverError 前已向客户端写入 SSE 内容时,
|
||||
// 故障转移保护逻辑必须终止循环并发送 SSE 错误事件,而不是进行下一次 Forward。
|
||||
// 具体验证:
|
||||
// 1. c.Writer.Size() 检测条件正确触发(字节数已增加)
|
||||
// 2. handleFailoverExhausted 以 streamStarted=true 调用后,响应体以 SSE 错误事件结尾
|
||||
// 3. 响应体中只出现一个 message_start,不存在第二个(防止流拼接腐化)
|
||||
func TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
// 步骤 1:记录 Forward 前的 writer size(模拟 writerSizeBeforeForward := c.Writer.Size())
|
||||
sizeBeforeForward := c.Writer.Size()
|
||||
require.Equal(t, -1, sizeBeforeForward, "gin writer 初始 Size 应为 -1(未写入任何字节)")
|
||||
|
||||
// 步骤 2:模拟 Forward 已向客户端写入部分 SSE 内容(message_start + content_block_start)
|
||||
_, err := c.Writer.Write([]byte(partialMessageStartSSE))
|
||||
require.NoError(t, err)
|
||||
|
||||
// 步骤 3:验证守卫条件成立(c.Writer.Size() != sizeBeforeForward)
|
||||
require.NotEqual(t, sizeBeforeForward, c.Writer.Size(),
|
||||
"写入 SSE 内容后 writer size 必须增加,守卫条件应为 true")
|
||||
|
||||
// 步骤 4:模拟 UpstreamFailoverError(上游在流中途返回 403)
|
||||
failoverErr := &service.UpstreamFailoverError{
|
||||
StatusCode: http.StatusForbidden,
|
||||
ResponseBody: []byte(`{"error":{"type":"permission_error","message":"forbidden"}}`),
|
||||
}
|
||||
|
||||
// 步骤 5:守卫触发 → 调用 handleFailoverExhausted,streamStarted=true
|
||||
h := &GatewayHandler{}
|
||||
h.handleFailoverExhausted(c, failoverErr, service.PlatformAnthropic, true)
|
||||
|
||||
body := w.Body.String()
|
||||
|
||||
// 断言 A:响应体中包含最初写入的 message_start SSE 事件行
|
||||
require.Contains(t, body, "event: message_start", "响应体应包含已写入的 message_start SSE 事件")
|
||||
|
||||
// 断言 B:响应体以 SSE 错误事件结尾(data: {"type":"error",...}\n\n)
|
||||
require.True(t, strings.HasSuffix(strings.TrimRight(body, "\n"), "}"),
|
||||
"响应体应以 JSON 对象结尾(SSE error event 的 data 字段)")
|
||||
require.Contains(t, body, `"type":"error"`, "响应体末尾必须包含 SSE 错误事件")
|
||||
|
||||
// 断言 C:SSE event 行 "event: message_start" 只出现一次(防止双 message_start 拼接腐化)
|
||||
firstIdx := strings.Index(body, "event: message_start")
|
||||
lastIdx := strings.LastIndex(body, "event: message_start")
|
||||
assert.Equal(t, firstIdx, lastIdx,
|
||||
"响应体中 'event: message_start' 必须只出现一次,不得因 failover 拼接导致两次")
|
||||
}
|
||||
|
||||
// TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten 与上述测试相同,
|
||||
// 验证 Gemini 路径使用 service.PlatformGemini(而非 account.Platform)时行为一致。
|
||||
func TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.0-flash:streamGenerateContent", nil)
|
||||
|
||||
sizeBeforeForward := c.Writer.Size()
|
||||
|
||||
_, err := c.Writer.Write([]byte(partialMessageStartSSE))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotEqual(t, sizeBeforeForward, c.Writer.Size())
|
||||
|
||||
failoverErr := &service.UpstreamFailoverError{
|
||||
StatusCode: http.StatusForbidden,
|
||||
}
|
||||
|
||||
h := &GatewayHandler{}
|
||||
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true)
|
||||
|
||||
body := w.Body.String()
|
||||
|
||||
require.Contains(t, body, "event: message_start")
|
||||
require.Contains(t, body, `"type":"error"`)
|
||||
|
||||
firstIdx := strings.Index(body, "event: message_start")
|
||||
lastIdx := strings.LastIndex(body, "event: message_start")
|
||||
assert.Equal(t, firstIdx, lastIdx, "Gemini 路径不得出现双 message_start")
|
||||
}
|
||||
|
||||
// TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered 验证反向场景:
|
||||
// 当 Forward 返回 UpstreamFailoverError 时若未向客户端写入任何 SSE 内容,
|
||||
// 守卫条件(c.Writer.Size() != sizeBeforeForward)为 false,不应中止 failover。
|
||||
func TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
// 模拟 writerSizeBeforeForward:初始为 -1
|
||||
sizeBeforeForward := c.Writer.Size()
|
||||
|
||||
// Forward 未写入任何字节直接返回错误(例如 401 发生在连接建立前)
|
||||
// c.Writer.Size() 仍为 -1
|
||||
|
||||
// 守卫条件:sizeBeforeForward == c.Writer.Size() → 不触发
|
||||
guardTriggered := c.Writer.Size() != sizeBeforeForward
|
||||
require.False(t, guardTriggered,
|
||||
"未写入任何字节时,守卫条件必须为 false,应允许正常 failover 继续")
|
||||
}
|
||||
@@ -75,8 +75,10 @@ func (f *fakeGroupRepo) ListActive(context.Context) ([]service.Group, error) { r
|
||||
func (f *fakeGroupRepo) ListActiveByPlatform(context.Context, string) ([]service.Group, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil }
|
||||
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, error) { return 0, nil }
|
||||
func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil }
|
||||
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, int64, error) {
|
||||
return 0, 0, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
@@ -139,6 +141,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
|
||||
nil, // accountRepo (not used: scheduler snapshot hit)
|
||||
&fakeGroupRepo{group: group},
|
||||
nil, // usageLogRepo
|
||||
nil, // usageBillingRepo
|
||||
nil, // userRepo
|
||||
nil, // userSubRepo
|
||||
nil, // userGroupRateRepo
|
||||
@@ -157,6 +160,9 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
|
||||
nil, // rpmCache
|
||||
nil, // digestStore
|
||||
nil, // settingService
|
||||
nil, // tlsFPProfileService
|
||||
nil, // channelService
|
||||
nil, // resolver
|
||||
)
|
||||
|
||||
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
|
||||
|
||||
@@ -136,7 +136,7 @@ func validClaudeCodeBodyJSON() []byte {
|
||||
return []byte(`{
|
||||
"model":"claude-3-5-sonnet-20241022",
|
||||
"system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}],
|
||||
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"}
|
||||
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}
|
||||
}`)
|
||||
}
|
||||
|
||||
@@ -190,7 +190,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing
|
||||
System: []any{
|
||||
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
||||
},
|
||||
MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123",
|
||||
MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
|
||||
}
|
||||
|
||||
// body 非法 JSON,如果函数复用 parsedReq 成功则仍应判定为 Claude Code。
|
||||
@@ -209,7 +209,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing
|
||||
"system": []any{
|
||||
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
||||
},
|
||||
"metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"},
|
||||
"metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"},
|
||||
})
|
||||
|
||||
SetClaudeCodeClientContext(c, []byte(`{invalid`), nil)
|
||||
|
||||
@@ -121,7 +121,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
|
||||
googleError(c, http.StatusBadGateway, err.Error())
|
||||
return
|
||||
}
|
||||
if shouldFallbackGeminiModels(res) {
|
||||
if shouldFallbackGeminiModel(modelName, res) {
|
||||
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
|
||||
return
|
||||
}
|
||||
@@ -182,6 +182,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, modelName, stream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
|
||||
|
||||
// 解析渠道级模型映射
|
||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName)
|
||||
reqModel := modelName // 保存映射前的原始模型名
|
||||
if channelMapping.Mapped {
|
||||
modelName = channelMapping.MappedModel
|
||||
}
|
||||
|
||||
// Get subscription (may be nil)
|
||||
subscription, _ := middleware.GetSubscriptionFromContext(c)
|
||||
@@ -352,7 +360,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "") // Gemini 不使用会话限制
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制
|
||||
if err != nil {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
@@ -503,6 +511,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
||||
Result: result,
|
||||
@@ -510,12 +521,16 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
LongContextThreshold: 200000, // Gemini 200K 阈值
|
||||
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gemini_v1beta.models"),
|
||||
@@ -587,6 +602,10 @@ func (h *GatewayHandler) handleGeminiFailoverExhausted(c *gin.Context, failoverE
|
||||
}
|
||||
}
|
||||
|
||||
// 记录原始上游状态码,以便 ops 错误日志捕获真实的上游错误
|
||||
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
|
||||
|
||||
// 使用默认的错误映射
|
||||
status, message := mapGeminiUpstreamError(statusCode)
|
||||
googleError(c, status, message)
|
||||
@@ -663,6 +682,16 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func shouldFallbackGeminiModel(modelName string, res *service.UpstreamHTTPResult) bool {
|
||||
if shouldFallbackGeminiModels(res) {
|
||||
return true
|
||||
}
|
||||
if res == nil || res.StatusCode != http.StatusNotFound {
|
||||
return false
|
||||
}
|
||||
return gemini.HasFallbackModel(modelName)
|
||||
}
|
||||
|
||||
// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。
|
||||
// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。
|
||||
//
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -141,3 +142,28 @@ func TestGeminiV1BetaHandler_GetModelAntigravityFallback(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldFallbackGeminiModel_KnownFallbackOn404(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
res := &service.UpstreamHTTPResult{StatusCode: http.StatusNotFound}
|
||||
require.True(t, shouldFallbackGeminiModel("gemini-3.1-pro-preview-customtools", res))
|
||||
}
|
||||
|
||||
func TestShouldFallbackGeminiModel_UnknownModelOn404(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
res := &service.UpstreamHTTPResult{StatusCode: http.StatusNotFound}
|
||||
require.False(t, shouldFallbackGeminiModel("gemini-future-model", res))
|
||||
}
|
||||
|
||||
func TestShouldFallbackGeminiModel_DelegatesScopeFallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
res := &service.UpstreamHTTPResult{
|
||||
StatusCode: http.StatusForbidden,
|
||||
Headers: http.Header{"Www-Authenticate": []string{"Bearer error=\"insufficient_scope\""}},
|
||||
Body: []byte("insufficient authentication scopes"),
|
||||
}
|
||||
require.True(t, shouldFallbackGeminiModel("gemini-future-model", res))
|
||||
}
|
||||
|
||||
@@ -6,28 +6,31 @@ import (
|
||||
|
||||
// AdminHandlers contains all admin-related HTTP handlers
|
||||
type AdminHandlers struct {
|
||||
Dashboard *admin.DashboardHandler
|
||||
User *admin.UserHandler
|
||||
Group *admin.GroupHandler
|
||||
Account *admin.AccountHandler
|
||||
Announcement *admin.AnnouncementHandler
|
||||
DataManagement *admin.DataManagementHandler
|
||||
OAuth *admin.OAuthHandler
|
||||
OpenAIOAuth *admin.OpenAIOAuthHandler
|
||||
GeminiOAuth *admin.GeminiOAuthHandler
|
||||
AntigravityOAuth *admin.AntigravityOAuthHandler
|
||||
Proxy *admin.ProxyHandler
|
||||
Redeem *admin.RedeemHandler
|
||||
Promo *admin.PromoHandler
|
||||
Setting *admin.SettingHandler
|
||||
Ops *admin.OpsHandler
|
||||
System *admin.SystemHandler
|
||||
Subscription *admin.SubscriptionHandler
|
||||
Usage *admin.UsageHandler
|
||||
UserAttribute *admin.UserAttributeHandler
|
||||
ErrorPassthrough *admin.ErrorPassthroughHandler
|
||||
APIKey *admin.AdminAPIKeyHandler
|
||||
ScheduledTest *admin.ScheduledTestHandler
|
||||
Dashboard *admin.DashboardHandler
|
||||
User *admin.UserHandler
|
||||
Group *admin.GroupHandler
|
||||
Account *admin.AccountHandler
|
||||
Announcement *admin.AnnouncementHandler
|
||||
DataManagement *admin.DataManagementHandler
|
||||
Backup *admin.BackupHandler
|
||||
OAuth *admin.OAuthHandler
|
||||
OpenAIOAuth *admin.OpenAIOAuthHandler
|
||||
GeminiOAuth *admin.GeminiOAuthHandler
|
||||
AntigravityOAuth *admin.AntigravityOAuthHandler
|
||||
Proxy *admin.ProxyHandler
|
||||
Redeem *admin.RedeemHandler
|
||||
Promo *admin.PromoHandler
|
||||
Setting *admin.SettingHandler
|
||||
Ops *admin.OpsHandler
|
||||
System *admin.SystemHandler
|
||||
Subscription *admin.SubscriptionHandler
|
||||
Usage *admin.UsageHandler
|
||||
UserAttribute *admin.UserAttributeHandler
|
||||
ErrorPassthrough *admin.ErrorPassthroughHandler
|
||||
TLSFingerprintProfile *admin.TLSFingerprintProfileHandler
|
||||
APIKey *admin.AdminAPIKeyHandler
|
||||
ScheduledTest *admin.ScheduledTestHandler
|
||||
Channel *admin.ChannelHandler
|
||||
}
|
||||
|
||||
// Handlers contains all HTTP handlers
|
||||
@@ -42,8 +45,6 @@ type Handlers struct {
|
||||
Admin *AdminHandlers
|
||||
Gateway *GatewayHandler
|
||||
OpenAIGateway *OpenAIGatewayHandler
|
||||
SoraGateway *SoraGatewayHandler
|
||||
SoraClient *SoraClientHandler
|
||||
Setting *SettingHandler
|
||||
Totp *TotpHandler
|
||||
}
|
||||
|
||||
@@ -77,6 +77,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
// 解析渠道级模型映射
|
||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
@@ -181,14 +185,12 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
defaultMappedModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultMappedModel = apiKey.Group.DefaultMappedModel
|
||||
defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_chat_completions_fallback_model"))
|
||||
forwardBody := body
|
||||
if channelMapping.Mapped {
|
||||
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
||||
}
|
||||
if fallbackModel := c.GetString("openai_chat_completions_fallback_model"); fallbackModel != "" {
|
||||
defaultMappedModel = fallbackModel
|
||||
}
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel)
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
@@ -262,14 +264,17 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.chat_completions"),
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestOpenAIUpstreamEndpoint_ViaGetUpstreamEndpoint verifies that the
|
||||
// unified GetUpstreamEndpoint helper produces the same results as the
|
||||
// former normalizedOpenAIUpstreamEndpoint for OpenAI platform requests.
|
||||
func TestOpenAIUpstreamEndpoint_ViaGetUpstreamEndpoint(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "responses root maps to responses upstream",
|
||||
path: "/v1/responses",
|
||||
want: EndpointResponses,
|
||||
},
|
||||
{
|
||||
name: "responses compact keeps compact suffix",
|
||||
path: "/openai/v1/responses/compact",
|
||||
want: "/v1/responses/compact",
|
||||
},
|
||||
{
|
||||
name: "responses nested suffix preserved",
|
||||
path: "/openai/v1/responses/compact/detail",
|
||||
want: "/v1/responses/compact/detail",
|
||||
},
|
||||
{
|
||||
name: "non responses path uses platform fallback",
|
||||
path: "/v1/messages",
|
||||
want: EndpointResponses,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, tt.path, nil)
|
||||
|
||||
got := GetUpstreamEndpoint(c, service.PlatformOpenAI)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -37,6 +37,16 @@ type OpenAIGatewayHandler struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func resolveOpenAIForwardDefaultMappedModel(apiKey *service.APIKey, fallbackModel string) string {
|
||||
if fallbackModel = strings.TrimSpace(fallbackModel); fallbackModel != "" {
|
||||
return fallbackModel
|
||||
}
|
||||
if apiKey == nil || apiKey.Group == nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(apiKey.Group.DefaultMappedModel)
|
||||
}
|
||||
|
||||
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
||||
func NewOpenAIGatewayHandler(
|
||||
gatewayService *service.OpenAIGatewayService,
|
||||
@@ -173,6 +183,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
// 解析渠道级模型映射
|
||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
|
||||
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
|
||||
if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
|
||||
@@ -273,7 +287,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// Forward request
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
||||
// 应用渠道模型映射到请求体
|
||||
forwardBody := body
|
||||
if channelMapping.Mapped {
|
||||
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
||||
}
|
||||
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody)
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
@@ -352,18 +371,23 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.responses"),
|
||||
@@ -526,11 +550,16 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
reqModel := modelResult.String()
|
||||
routingModel := service.NormalizeOpenAICompatRequestedModel(reqModel)
|
||||
reqStream := gjson.GetBytes(body, "stream").Bool()
|
||||
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
// 解析渠道级模型映射
|
||||
channelMappingMsg, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
|
||||
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
|
||||
if h.errorPassthroughService != nil {
|
||||
@@ -590,7 +619,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
apiKey.GroupID,
|
||||
"", // no previous_response_id
|
||||
sessionHash,
|
||||
reqModel,
|
||||
routingModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
@@ -605,7 +634,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
if apiKey.Group != nil {
|
||||
defaultModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
if defaultModel != "" && defaultModel != reqModel {
|
||||
if defaultModel != "" && defaultModel != routingModel {
|
||||
reqLog.Info("openai_messages.fallback_to_default_model",
|
||||
zap.String("default_mapped_model", defaultModel),
|
||||
)
|
||||
@@ -653,15 +682,15 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
defaultMappedModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultMappedModel = apiKey.Group.DefaultMappedModel
|
||||
// Forward 层需要始终拿到 group 默认映射模型,这样未命中账号级映射的
|
||||
// Claude 兼容模型才不会在后续 Codex 规范化中意外退化到 gpt-5.1。
|
||||
defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_messages_fallback_model"))
|
||||
// 应用渠道模型映射到请求体
|
||||
forwardBody := body
|
||||
if channelMappingMsg.Mapped {
|
||||
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMappingMsg.MappedModel)
|
||||
}
|
||||
// 如果使用了降级模型调度,强制使用降级模型
|
||||
if fallbackModel := c.GetString("openai_messages_fallback_model"); fallbackModel != "" {
|
||||
defaultMappedModel = fallbackModel
|
||||
}
|
||||
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel)
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
@@ -732,17 +761,22 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelUsageFields: channelMappingMsg.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.messages"),
|
||||
@@ -1083,6 +1117,10 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
zap.String("previous_response_id_kind", previousResponseIDKind),
|
||||
)
|
||||
setOpsRequestContext(c, reqModel, true, firstMessage)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
|
||||
|
||||
// 解析渠道级模型映射
|
||||
channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel)
|
||||
|
||||
var currentUserRelease func()
|
||||
var currentAccountRelease func()
|
||||
@@ -1231,14 +1269,18 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
h.submitUsageRecordTask(func(taskCtx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelUsageFields: channelMappingWS.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
reqLog.Error("openai.websocket_record_usage_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
@@ -1250,7 +1292,13 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
},
|
||||
}
|
||||
|
||||
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, firstMessage, hooks); err != nil {
|
||||
// 应用渠道模型映射到 WebSocket 首条消息
|
||||
wsFirstMessage := firstMessage
|
||||
if channelMappingWS.Mapped {
|
||||
wsFirstMessage = h.gatewayService.ReplaceModelInBody(firstMessage, channelMappingWS.MappedModel)
|
||||
}
|
||||
|
||||
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, wsFirstMessage, hooks); err != nil {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
|
||||
reqLog.Warn("openai.websocket_proxy_failed",
|
||||
@@ -1429,6 +1477,10 @@ func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverE
|
||||
}
|
||||
}
|
||||
|
||||
// 记录原始上游状态码,以便 ops 错误日志捕获真实的上游错误
|
||||
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
|
||||
|
||||
// 使用默认的错误映射
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
@@ -1437,6 +1489,7 @@ func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverE
|
||||
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
|
||||
func (h *OpenAIGatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||
service.SetOpsUpstreamError(c, statusCode, errMsg, "")
|
||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
}
|
||||
|
||||
|
||||
@@ -352,6 +352,30 @@ func TestOpenAIEnsureResponsesDependencies(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) {
|
||||
t.Run("prefers_explicit_fallback_model", func(t *testing.T) {
|
||||
apiKey := &service.APIKey{
|
||||
Group: &service.Group{DefaultMappedModel: "gpt-5.4"},
|
||||
}
|
||||
require.Equal(t, "gpt-5.2", resolveOpenAIForwardDefaultMappedModel(apiKey, " gpt-5.2 "))
|
||||
})
|
||||
|
||||
t.Run("uses_group_default_on_normal_path", func(t *testing.T) {
|
||||
apiKey := &service.APIKey{
|
||||
Group: &service.Group{DefaultMappedModel: "gpt-5.4"},
|
||||
}
|
||||
require.Equal(t, "gpt-5.4", resolveOpenAIForwardDefaultMappedModel(apiKey, ""))
|
||||
})
|
||||
|
||||
t.Run("returns_empty_without_group_default", func(t *testing.T) {
|
||||
require.Empty(t, resolveOpenAIForwardDefaultMappedModel(nil, ""))
|
||||
require.Empty(t, resolveOpenAIForwardDefaultMappedModel(&service.APIKey{}, ""))
|
||||
require.Empty(t, resolveOpenAIForwardDefaultMappedModel(&service.APIKey{
|
||||
Group: &service.Group{},
|
||||
}, ""))
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -26,6 +26,25 @@ const (
|
||||
opsStreamKey = "ops_stream"
|
||||
opsRequestBodyKey = "ops_request_body"
|
||||
opsAccountIDKey = "ops_account_id"
|
||||
|
||||
opsUpstreamModelKey = "ops_upstream_model"
|
||||
opsRequestTypeKey = "ops_request_type"
|
||||
|
||||
// 错误过滤匹配常量 — shouldSkipOpsErrorLog 和错误分类共用
|
||||
opsErrContextCanceled = "context canceled"
|
||||
opsErrNoAvailableAccounts = "no available accounts"
|
||||
opsErrInvalidAPIKey = "invalid_api_key"
|
||||
opsErrAPIKeyRequired = "api_key_required"
|
||||
opsErrInsufficientBalance = "insufficient balance"
|
||||
opsErrInsufficientAccountBalance = "insufficient account balance"
|
||||
opsErrInsufficientQuota = "insufficient_quota"
|
||||
|
||||
// 上游错误码常量 — 错误分类 (normalizeOpsErrorType / classifyOpsPhase / classifyOpsIsBusinessLimited)
|
||||
opsCodeInsufficientBalance = "INSUFFICIENT_BALANCE"
|
||||
opsCodeUsageLimitExceeded = "USAGE_LIMIT_EXCEEDED"
|
||||
opsCodeSubscriptionNotFound = "SUBSCRIPTION_NOT_FOUND"
|
||||
opsCodeSubscriptionInvalid = "SUBSCRIPTION_INVALID"
|
||||
opsCodeUserInactive = "USER_INACTIVE"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -329,6 +348,18 @@ func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody
|
||||
}
|
||||
}
|
||||
|
||||
// setOpsEndpointContext stores upstream model and request type for ops error logging.
|
||||
// Called by handlers after model mapping and request type determination.
|
||||
func setOpsEndpointContext(c *gin.Context, upstreamModel string, requestType int16) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if upstreamModel = strings.TrimSpace(upstreamModel); upstreamModel != "" {
|
||||
c.Set(opsUpstreamModelKey, upstreamModel)
|
||||
}
|
||||
c.Set(opsRequestTypeKey, requestType)
|
||||
}
|
||||
|
||||
func attachOpsRequestBodyToEntry(c *gin.Context, entry *service.OpsInsertErrorLogInput) {
|
||||
if c == nil || entry == nil {
|
||||
return
|
||||
@@ -612,7 +643,30 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
Stream: stream,
|
||||
Stream: stream,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, platform),
|
||||
RequestedModel: modelName,
|
||||
UpstreamModel: func() string {
|
||||
if v, ok := c.Get(opsUpstreamModelKey); ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
RequestType: func() *int16 {
|
||||
if v, ok := c.Get(opsRequestTypeKey); ok {
|
||||
switch t := v.(type) {
|
||||
case int16:
|
||||
return &t
|
||||
case int:
|
||||
v16 := int16(t)
|
||||
return &v16
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}(),
|
||||
UserAgent: c.GetHeader("User-Agent"),
|
||||
|
||||
ErrorPhase: "upstream",
|
||||
@@ -740,7 +794,30 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
Stream: stream,
|
||||
Stream: stream,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, platform),
|
||||
RequestedModel: modelName,
|
||||
UpstreamModel: func() string {
|
||||
if v, ok := c.Get(opsUpstreamModelKey); ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
RequestType: func() *int16 {
|
||||
if v, ok := c.Get(opsRequestTypeKey); ok {
|
||||
switch t := v.(type) {
|
||||
case int16:
|
||||
return &t
|
||||
case int:
|
||||
v16 := int16(t)
|
||||
return &v16
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}(),
|
||||
UserAgent: c.GetHeader("User-Agent"),
|
||||
|
||||
ErrorPhase: phase,
|
||||
@@ -1024,9 +1101,9 @@ func normalizeOpsErrorType(errType string, code string) string {
|
||||
return errType
|
||||
}
|
||||
switch strings.TrimSpace(code) {
|
||||
case "INSUFFICIENT_BALANCE":
|
||||
case opsCodeInsufficientBalance:
|
||||
return "billing_error"
|
||||
case "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
|
||||
case opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
|
||||
return "subscription_error"
|
||||
default:
|
||||
return "api_error"
|
||||
@@ -1038,7 +1115,7 @@ func classifyOpsPhase(errType, message, code string) string {
|
||||
// Standardized phases: request|auth|routing|upstream|network|internal
|
||||
// Map billing/concurrency/response => request; scheduling => routing.
|
||||
switch strings.TrimSpace(code) {
|
||||
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
|
||||
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
|
||||
return "request"
|
||||
}
|
||||
|
||||
@@ -1057,7 +1134,7 @@ func classifyOpsPhase(errType, message, code string) string {
|
||||
case "upstream_error", "overloaded_error":
|
||||
return "upstream"
|
||||
case "api_error":
|
||||
if strings.Contains(msg, "no available accounts") {
|
||||
if strings.Contains(msg, opsErrNoAvailableAccounts) {
|
||||
return "routing"
|
||||
}
|
||||
return "internal"
|
||||
@@ -1103,7 +1180,7 @@ func classifyOpsIsRetryable(errType string, statusCode int) bool {
|
||||
|
||||
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
|
||||
switch strings.TrimSpace(code) {
|
||||
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID", "USER_INACTIVE":
|
||||
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid, opsCodeUserInactive:
|
||||
return true
|
||||
}
|
||||
if phase == "billing" || phase == "concurrency" {
|
||||
@@ -1197,21 +1274,30 @@ func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message
|
||||
|
||||
// Check if context canceled errors should be ignored (client disconnects)
|
||||
if settings.IgnoreContextCanceled {
|
||||
if strings.Contains(msgLower, "context canceled") || strings.Contains(bodyLower, "context canceled") {
|
||||
if strings.Contains(msgLower, opsErrContextCanceled) || strings.Contains(bodyLower, opsErrContextCanceled) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if "no available accounts" errors should be ignored
|
||||
if settings.IgnoreNoAvailableAccounts {
|
||||
if strings.Contains(msgLower, "no available accounts") || strings.Contains(bodyLower, "no available accounts") {
|
||||
if strings.Contains(msgLower, opsErrNoAvailableAccounts) || strings.Contains(bodyLower, opsErrNoAvailableAccounts) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if invalid/missing API key errors should be ignored (user misconfiguration)
|
||||
if settings.IgnoreInvalidApiKeyErrors {
|
||||
if strings.Contains(bodyLower, "invalid_api_key") || strings.Contains(bodyLower, "api_key_required") {
|
||||
if strings.Contains(bodyLower, opsErrInvalidAPIKey) || strings.Contains(bodyLower, opsErrAPIKeyRequired) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if insufficient balance errors should be ignored
|
||||
if settings.IgnoreInsufficientBalanceErrors {
|
||||
if strings.Contains(bodyLower, opsErrInsufficientBalance) || strings.Contains(bodyLower, opsErrInsufficientAccountBalance) ||
|
||||
strings.Contains(bodyLower, opsErrInsufficientQuota) ||
|
||||
strings.Contains(msgLower, opsErrInsufficientBalance) || strings.Contains(msgLower, opsErrInsufficientAccountBalance) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user