mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-04 15:32:13 +08:00
Compare commits
185 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3cc407bc0e | ||
|
|
00a0a12138 | ||
|
|
b08767a4f9 | ||
|
|
ac6bde7a98 | ||
|
|
d2d41d68dd | ||
|
|
944b7f7617 | ||
|
|
53825eb073 | ||
|
|
1a7f49513f | ||
|
|
885a2ce7ef | ||
|
|
14ba80a0af | ||
|
|
5fa22fdf82 | ||
|
|
bcaae2eb91 | ||
|
|
767a41e263 | ||
|
|
252d6c5301 | ||
|
|
7a4e65ad4b | ||
|
|
a582aa89a9 | ||
|
|
acefa1da12 | ||
|
|
a88698f3fc | ||
|
|
ebc6755b33 | ||
|
|
c8eff34388 | ||
|
|
f19b03825b | ||
|
|
25178cdbe1 | ||
|
|
a461538d58 | ||
|
|
b43ee62947 | ||
|
|
ebe6f418f3 | ||
|
|
391e79f8ee | ||
|
|
c7fcb7a84b | ||
|
|
87f4ed591e | ||
|
|
440d2e28ed | ||
|
|
6cb8980404 | ||
|
|
fe752bbd35 | ||
|
|
c74d451fa2 | ||
|
|
12d743fb35 | ||
|
|
6acb9f7910 | ||
|
|
eb6f5c6927 | ||
|
|
7ccb4c8ea3 | ||
|
|
4ce986d47d | ||
|
|
91ef085d7d | ||
|
|
97aaa24733 | ||
|
|
faf6441633 | ||
|
|
00c151b463 | ||
|
|
106b20cdbf | ||
|
|
c069b3b1e8 | ||
|
|
a2ae9f1f27 | ||
|
|
4cd6d86426 | ||
|
|
fa72f1947a | ||
|
|
9ee7d3935d | ||
|
|
1071fe0ac7 | ||
|
|
0be003377f | ||
|
|
ca3f497b56 | ||
|
|
034b84b707 | ||
|
|
1624523c4e | ||
|
|
313afe14ce | ||
|
|
01180b316f | ||
|
|
ee7d061001 | ||
|
|
60c5949a74 | ||
|
|
2ebbd4c94d | ||
|
|
785115c62b | ||
|
|
e643fc382c | ||
|
|
34aad82ac3 | ||
|
|
0c29468f90 | ||
|
|
9301dae63e | ||
|
|
2475d4a205 | ||
|
|
be75fc3474 | ||
|
|
785e049af3 | ||
|
|
be4e49e6d7 | ||
|
|
1307d604e7 | ||
|
|
45d57018eb | ||
|
|
03bf348530 | ||
|
|
cab60ef735 | ||
|
|
a3791104f9 | ||
|
|
2b3e40bb2a | ||
|
|
0c1dcad429 | ||
|
|
101ef0cf62 | ||
|
|
0debe0a80c | ||
|
|
d22e62ac8a | ||
|
|
1ee17383f8 | ||
|
|
b59c79c458 | ||
|
|
bcb6444f89 | ||
|
|
c2b14693b4 | ||
|
|
92d35409de | ||
|
|
351a08f813 | ||
|
|
a58dc787a9 | ||
|
|
7079edc2d0 | ||
|
|
da89583ccc | ||
|
|
a42a1f08e9 | ||
|
|
ebd5253e22 | ||
|
|
6411645ffc | ||
|
|
c0c322ba16 | ||
|
|
d35c5cd491 | ||
|
|
7a353028e7 | ||
|
|
2d8d3b7857 | ||
|
|
4190293b07 | ||
|
|
421b4c0aff | ||
|
|
cd69a7cb85 | ||
|
|
0c9ba9e86c | ||
|
|
1b4d2a41c9 | ||
|
|
0787d2b47a | ||
|
|
97bf1d85ab | ||
|
|
207a493fab | ||
|
|
1f3f9e131e | ||
|
|
4ddedfaaf9 | ||
|
|
3ebebef95f | ||
|
|
9f7ad47598 | ||
|
|
3c83cd8be2 | ||
|
|
963b3b768c | ||
|
|
f6709fb5d6 | ||
|
|
921599948b | ||
|
|
5df3cafa99 | ||
|
|
1a2143c1fe | ||
|
|
dd25281305 | ||
|
|
49d0301dde | ||
|
|
d90e56eb45 | ||
|
|
838ada8864 | ||
|
|
65a106792a | ||
|
|
ee4bfcbb81 | ||
|
|
a087f089b8 | ||
|
|
afbe8bf001 | ||
|
|
2a3ef0be06 | ||
|
|
3403909354 | ||
|
|
005d0c5f53 | ||
|
|
8aaaeb29cc | ||
|
|
230f8abd04 | ||
|
|
a18bbb5f2f | ||
|
|
60fce4f1dc | ||
|
|
9af65efcdb | ||
|
|
bc194a7d8c | ||
|
|
c28f691f32 | ||
|
|
ff1f114989 | ||
|
|
cac230206d | ||
|
|
79ae15d5e8 | ||
|
|
0cce0a8877 | ||
|
|
225fd035ae | ||
|
|
fb7d1346b5 | ||
|
|
491a744481 | ||
|
|
f366026435 | ||
|
|
1a0d4ed668 | ||
|
|
63a8c76946 | ||
|
|
f355a68bc9 | ||
|
|
c87e6526c1 | ||
|
|
af3a5076d6 | ||
|
|
18f2e21414 | ||
|
|
8a8cdeebb4 | ||
|
|
12b33f4ea4 | ||
|
|
01b3a09d7d | ||
|
|
0d6c1c7790 | ||
|
|
95e366b6c6 | ||
|
|
77701143bf | ||
|
|
02dea7b09b | ||
|
|
c26f93c4a0 | ||
|
|
c826ac28ef | ||
|
|
1893b0eb30 | ||
|
|
05527b13db | ||
|
|
ae5d9c8bfc | ||
|
|
9117c2a4ec | ||
|
|
bab4bb9904 | ||
|
|
33bae6f49b | ||
|
|
32d619a56b | ||
|
|
642432cf2a | ||
|
|
61e9598b08 | ||
|
|
d4e34c7514 | ||
|
|
bfe7a5e452 | ||
|
|
77d916ffec | ||
|
|
831abf7977 | ||
|
|
817a491087 | ||
|
|
9a8dacc514 | ||
|
|
8adf80d98b | ||
|
|
62686a6213 | ||
|
|
3a089242f8 | ||
|
|
9d70c38504 | ||
|
|
aeb464f3ca | ||
|
|
7076717b20 | ||
|
|
c0a4fcea0a | ||
|
|
aa2b195c86 | ||
|
|
1d0872e7ca | ||
|
|
33988637b5 | ||
|
|
d4f6ad7225 | ||
|
|
078fefed03 | ||
|
|
5b10af85b4 | ||
|
|
4caf95e5dd | ||
|
|
8e1bcf53bb | ||
|
|
064f9be7e4 | ||
|
|
adcfb44cb7 | ||
|
|
3d79773ba2 | ||
|
|
6aa8cbbf20 |
6
.github/workflows/backend-ci.yml
vendored
6
.github/workflows/backend-ci.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
cache: true
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.25.7'
|
||||
go version | grep -q 'go1.26.1'
|
||||
- name: Unit tests
|
||||
working-directory: backend
|
||||
run: make test-unit
|
||||
@@ -38,10 +38,10 @@ jobs:
|
||||
cache: true
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.25.7'
|
||||
go version | grep -q 'go1.26.1'
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v9
|
||||
with:
|
||||
version: v2.7
|
||||
version: v2.9
|
||||
args: --timeout=30m
|
||||
working-directory: backend
|
||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -115,7 +115,7 @@ jobs:
|
||||
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.25.7'
|
||||
go version | grep -q 'go1.26.1'
|
||||
|
||||
# Docker setup for GoReleaser
|
||||
- name: Set up QEMU
|
||||
|
||||
2
.github/workflows/security-scan.yml
vendored
2
.github/workflows/security-scan.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
cache-dependency-path: backend/go.sum
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.25.7'
|
||||
go version | grep -q 'go1.26.1'
|
||||
- name: Run govulncheck
|
||||
working-directory: backend
|
||||
run: |
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
# =============================================================================
|
||||
|
||||
ARG NODE_IMAGE=node:24-alpine
|
||||
ARG GOLANG_IMAGE=golang:1.25.7-alpine
|
||||
ARG GOLANG_IMAGE=golang:1.26.1-alpine
|
||||
ARG ALPINE_IMAGE=alpine:3.21
|
||||
ARG GOPROXY=https://goproxy.cn,direct
|
||||
ARG GOSUMDB=sum.golang.google.cn
|
||||
|
||||
28
README.md
28
README.md
@@ -150,14 +150,14 @@ mkdir -p sub2api-deploy && cd sub2api-deploy
|
||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
|
||||
|
||||
# Start services
|
||||
docker-compose -f docker-compose.local.yml up -d
|
||||
docker-compose up -d
|
||||
|
||||
# View logs
|
||||
docker-compose -f docker-compose.local.yml logs -f sub2api
|
||||
docker-compose logs -f sub2api
|
||||
```
|
||||
|
||||
**What the script does:**
|
||||
- Downloads `docker-compose.local.yml` and `.env.example`
|
||||
- Downloads `docker-compose.local.yml` (saved as `docker-compose.yml`) and `.env.example`
|
||||
- Generates secure credentials (JWT_SECRET, TOTP_ENCRYPTION_KEY, POSTGRES_PASSWORD)
|
||||
- Creates `.env` file with auto-generated secrets
|
||||
- Creates data directories (uses local directories for easy backup/migration)
|
||||
@@ -522,6 +522,28 @@ sub2api/
|
||||
└── install.sh # One-click installation script
|
||||
```
|
||||
|
||||
## Disclaimer
|
||||
|
||||
> **Please read carefully before using this project:**
|
||||
>
|
||||
> :rotating_light: **Terms of Service Risk**: Using this project may violate Anthropic's Terms of Service. Please read Anthropic's user agreement carefully before use. All risks arising from the use of this project are borne solely by the user.
|
||||
>
|
||||
> :book: **Disclaimer**: This project is for technical learning and research purposes only. The author assumes no responsibility for account suspension, service interruption, or any other losses caused by the use of this project.
|
||||
|
||||
---
|
||||
|
||||
## Star History
|
||||
|
||||
<a href="https://star-history.com/#Wei-Shaw/sub2api&Date">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date&theme=dark" />
|
||||
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||
</picture>
|
||||
</a>
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
MIT License
|
||||
|
||||
28
README_CN.md
28
README_CN.md
@@ -154,14 +154,14 @@ mkdir -p sub2api-deploy && cd sub2api-deploy
|
||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
|
||||
|
||||
# 启动服务
|
||||
docker-compose -f docker-compose.local.yml up -d
|
||||
docker-compose up -d
|
||||
|
||||
# 查看日志
|
||||
docker-compose -f docker-compose.local.yml logs -f sub2api
|
||||
docker-compose logs -f sub2api
|
||||
```
|
||||
|
||||
**脚本功能:**
|
||||
- 下载 `docker-compose.local.yml` 和 `.env.example`
|
||||
- 下载 `docker-compose.local.yml`(本地保存为 `docker-compose.yml`)和 `.env.example`
|
||||
- 自动生成安全凭证(JWT_SECRET、TOTP_ENCRYPTION_KEY、POSTGRES_PASSWORD)
|
||||
- 创建 `.env` 文件并填充自动生成的密钥
|
||||
- 创建数据目录(使用本地目录,便于备份和迁移)
|
||||
@@ -588,6 +588,28 @@ sub2api/
|
||||
└── install.sh # 一键安装脚本
|
||||
```
|
||||
|
||||
## 免责声明
|
||||
|
||||
> **使用本项目前请仔细阅读:**
|
||||
>
|
||||
> :rotating_light: **服务条款风险**: 使用本项目可能违反 Anthropic 的服务条款。请在使用前仔细阅读 Anthropic 的用户协议,使用本项目的一切风险由用户自行承担。
|
||||
>
|
||||
> :book: **免责声明**: 本项目仅供技术学习和研究使用,作者不对因使用本项目导致的账户封禁、服务中断或其他损失承担任何责任。
|
||||
|
||||
---
|
||||
|
||||
## Star History
|
||||
|
||||
<a href="https://star-history.com/#Wei-Shaw/sub2api&Date">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date&theme=dark" />
|
||||
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||
</picture>
|
||||
</a>
|
||||
|
||||
---
|
||||
|
||||
## 许可证
|
||||
|
||||
MIT License
|
||||
|
||||
@@ -93,20 +93,13 @@ linters:
|
||||
check-escaping-errors: true
|
||||
staticcheck:
|
||||
# https://staticcheck.dev/docs/configuration/options/#dot_import_whitelist
|
||||
# Default: ["github.com/mmcloughlin/avo/build", "github.com/mmcloughlin/avo/operand", "github.com/mmcloughlin/avo/reg"]
|
||||
dot-import-whitelist:
|
||||
- fmt
|
||||
# https://staticcheck.dev/docs/configuration/options/#initialisms
|
||||
# Default: ["ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS"]
|
||||
initialisms: [ "ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS" ]
|
||||
# https://staticcheck.dev/docs/configuration/options/#http_status_code_whitelist
|
||||
# Default: ["200", "400", "404", "500"]
|
||||
http-status-code-whitelist: [ "200", "400", "404", "500" ]
|
||||
# SAxxxx checks in https://staticcheck.dev/docs/configuration/options/#checks
|
||||
# Example (to disable some checks): [ "all", "-SA1000", "-SA1001"]
|
||||
# Run `GL_DEBUG=staticcheck golangci-lint run --enable=staticcheck` to see all available checks and enabled by config checks.
|
||||
# Default: ["all", "-ST1000", "-ST1003", "-ST1016", "-ST1020", "-ST1021", "-ST1022"]
|
||||
# Temporarily disable style checks to allow CI to pass
|
||||
# "all" enables every SA/ST/S/QF check; only list the ones to disable.
|
||||
checks:
|
||||
- all
|
||||
- -ST1000 # Package comment format
|
||||
@@ -114,489 +107,19 @@ linters:
|
||||
- -ST1020 # Comment on exported method format
|
||||
- -ST1021 # Comment on exported type format
|
||||
- -ST1022 # Comment on exported variable format
|
||||
# Invalid regular expression.
|
||||
# https://staticcheck.dev/docs/checks/#SA1000
|
||||
- SA1000
|
||||
# Invalid template.
|
||||
# https://staticcheck.dev/docs/checks/#SA1001
|
||||
- SA1001
|
||||
# Invalid format in 'time.Parse'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1002
|
||||
- SA1002
|
||||
# Unsupported argument to functions in 'encoding/binary'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1003
|
||||
- SA1003
|
||||
# Suspiciously small untyped constant in 'time.Sleep'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1004
|
||||
- SA1004
|
||||
# Invalid first argument to 'exec.Command'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1005
|
||||
- SA1005
|
||||
# 'Printf' with dynamic first argument and no further arguments.
|
||||
# https://staticcheck.dev/docs/checks/#SA1006
|
||||
- SA1006
|
||||
# Invalid URL in 'net/url.Parse'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1007
|
||||
- SA1007
|
||||
# Non-canonical key in 'http.Header' map.
|
||||
# https://staticcheck.dev/docs/checks/#SA1008
|
||||
- SA1008
|
||||
# '(*regexp.Regexp).FindAll' called with 'n == 0', which will always return zero results.
|
||||
# https://staticcheck.dev/docs/checks/#SA1010
|
||||
- SA1010
|
||||
# Various methods in the "strings" package expect valid UTF-8, but invalid input is provided.
|
||||
# https://staticcheck.dev/docs/checks/#SA1011
|
||||
- SA1011
|
||||
# A nil 'context.Context' is being passed to a function, consider using 'context.TODO' instead.
|
||||
# https://staticcheck.dev/docs/checks/#SA1012
|
||||
- SA1012
|
||||
# 'io.Seeker.Seek' is being called with the whence constant as the first argument, but it should be the second.
|
||||
# https://staticcheck.dev/docs/checks/#SA1013
|
||||
- SA1013
|
||||
# Non-pointer value passed to 'Unmarshal' or 'Decode'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1014
|
||||
- SA1014
|
||||
# Using 'time.Tick' in a way that will leak. Consider using 'time.NewTicker', and only use 'time.Tick' in tests, commands and endless functions.
|
||||
# https://staticcheck.dev/docs/checks/#SA1015
|
||||
- SA1015
|
||||
# Trapping a signal that cannot be trapped.
|
||||
# https://staticcheck.dev/docs/checks/#SA1016
|
||||
- SA1016
|
||||
# Channels used with 'os/signal.Notify' should be buffered.
|
||||
# https://staticcheck.dev/docs/checks/#SA1017
|
||||
- SA1017
|
||||
# 'strings.Replace' called with 'n == 0', which does nothing.
|
||||
# https://staticcheck.dev/docs/checks/#SA1018
|
||||
- SA1018
|
||||
# Using a deprecated function, variable, constant or field.
|
||||
# https://staticcheck.dev/docs/checks/#SA1019
|
||||
- SA1019
|
||||
# Using an invalid host:port pair with a 'net.Listen'-related function.
|
||||
# https://staticcheck.dev/docs/checks/#SA1020
|
||||
- SA1020
|
||||
# Using 'bytes.Equal' to compare two 'net.IP'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1021
|
||||
- SA1021
|
||||
# Modifying the buffer in an 'io.Writer' implementation.
|
||||
# https://staticcheck.dev/docs/checks/#SA1023
|
||||
- SA1023
|
||||
# A string cutset contains duplicate characters.
|
||||
# https://staticcheck.dev/docs/checks/#SA1024
|
||||
- SA1024
|
||||
# It is not possible to use '(*time.Timer).Reset''s return value correctly.
|
||||
# https://staticcheck.dev/docs/checks/#SA1025
|
||||
- SA1025
|
||||
# Cannot marshal channels or functions.
|
||||
# https://staticcheck.dev/docs/checks/#SA1026
|
||||
- SA1026
|
||||
# Atomic access to 64-bit variable must be 64-bit aligned.
|
||||
# https://staticcheck.dev/docs/checks/#SA1027
|
||||
- SA1027
|
||||
# 'sort.Slice' can only be used on slices.
|
||||
# https://staticcheck.dev/docs/checks/#SA1028
|
||||
- SA1028
|
||||
# Inappropriate key in call to 'context.WithValue'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1029
|
||||
- SA1029
|
||||
# Invalid argument in call to a 'strconv' function.
|
||||
# https://staticcheck.dev/docs/checks/#SA1030
|
||||
- SA1030
|
||||
# Overlapping byte slices passed to an encoder.
|
||||
# https://staticcheck.dev/docs/checks/#SA1031
|
||||
- SA1031
|
||||
# Wrong order of arguments to 'errors.Is'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1032
|
||||
- SA1032
|
||||
# 'sync.WaitGroup.Add' called inside the goroutine, leading to a race condition.
|
||||
# https://staticcheck.dev/docs/checks/#SA2000
|
||||
- SA2000
|
||||
# Empty critical section, did you mean to defer the unlock?.
|
||||
# https://staticcheck.dev/docs/checks/#SA2001
|
||||
- SA2001
|
||||
# Called 'testing.T.FailNow' or 'SkipNow' in a goroutine, which isn't allowed.
|
||||
# https://staticcheck.dev/docs/checks/#SA2002
|
||||
- SA2002
|
||||
# Deferred 'Lock' right after locking, likely meant to defer 'Unlock' instead.
|
||||
# https://staticcheck.dev/docs/checks/#SA2003
|
||||
- SA2003
|
||||
# 'TestMain' doesn't call 'os.Exit', hiding test failures.
|
||||
# https://staticcheck.dev/docs/checks/#SA3000
|
||||
- SA3000
|
||||
# Assigning to 'b.N' in benchmarks distorts the results.
|
||||
# https://staticcheck.dev/docs/checks/#SA3001
|
||||
- SA3001
|
||||
# Binary operator has identical expressions on both sides.
|
||||
# https://staticcheck.dev/docs/checks/#SA4000
|
||||
- SA4000
|
||||
# '&*x' gets simplified to 'x', it does not copy 'x'.
|
||||
# https://staticcheck.dev/docs/checks/#SA4001
|
||||
- SA4001
|
||||
# Comparing unsigned values against negative values is pointless.
|
||||
# https://staticcheck.dev/docs/checks/#SA4003
|
||||
- SA4003
|
||||
# The loop exits unconditionally after one iteration.
|
||||
# https://staticcheck.dev/docs/checks/#SA4004
|
||||
- SA4004
|
||||
# Field assignment that will never be observed. Did you mean to use a pointer receiver?.
|
||||
# https://staticcheck.dev/docs/checks/#SA4005
|
||||
- SA4005
|
||||
# A value assigned to a variable is never read before being overwritten. Forgotten error check or dead code?.
|
||||
# https://staticcheck.dev/docs/checks/#SA4006
|
||||
- SA4006
|
||||
# The variable in the loop condition never changes, are you incrementing the wrong variable?.
|
||||
# https://staticcheck.dev/docs/checks/#SA4008
|
||||
- SA4008
|
||||
# A function argument is overwritten before its first use.
|
||||
# https://staticcheck.dev/docs/checks/#SA4009
|
||||
- SA4009
|
||||
# The result of 'append' will never be observed anywhere.
|
||||
# https://staticcheck.dev/docs/checks/#SA4010
|
||||
- SA4010
|
||||
# Break statement with no effect. Did you mean to break out of an outer loop?.
|
||||
# https://staticcheck.dev/docs/checks/#SA4011
|
||||
- SA4011
|
||||
# Comparing a value against NaN even though no value is equal to NaN.
|
||||
# https://staticcheck.dev/docs/checks/#SA4012
|
||||
- SA4012
|
||||
# Negating a boolean twice ('!!b') is the same as writing 'b'. This is either redundant, or a typo.
|
||||
# https://staticcheck.dev/docs/checks/#SA4013
|
||||
- SA4013
|
||||
# An if/else if chain has repeated conditions and no side-effects; if the condition didn't match the first time, it won't match the second time, either.
|
||||
# https://staticcheck.dev/docs/checks/#SA4014
|
||||
- SA4014
|
||||
# Calling functions like 'math.Ceil' on floats converted from integers doesn't do anything useful.
|
||||
# https://staticcheck.dev/docs/checks/#SA4015
|
||||
- SA4015
|
||||
# Certain bitwise operations, such as 'x ^ 0', do not do anything useful.
|
||||
# https://staticcheck.dev/docs/checks/#SA4016
|
||||
- SA4016
|
||||
# Discarding the return values of a function without side effects, making the call pointless.
|
||||
# https://staticcheck.dev/docs/checks/#SA4017
|
||||
- SA4017
|
||||
# Self-assignment of variables.
|
||||
# https://staticcheck.dev/docs/checks/#SA4018
|
||||
- SA4018
|
||||
# Multiple, identical build constraints in the same file.
|
||||
# https://staticcheck.dev/docs/checks/#SA4019
|
||||
- SA4019
|
||||
# Unreachable case clause in a type switch.
|
||||
# https://staticcheck.dev/docs/checks/#SA4020
|
||||
- SA4020
|
||||
# "x = append(y)" is equivalent to "x = y".
|
||||
# https://staticcheck.dev/docs/checks/#SA4021
|
||||
- SA4021
|
||||
# Comparing the address of a variable against nil.
|
||||
# https://staticcheck.dev/docs/checks/#SA4022
|
||||
- SA4022
|
||||
# Impossible comparison of interface value with untyped nil.
|
||||
# https://staticcheck.dev/docs/checks/#SA4023
|
||||
- SA4023
|
||||
# Checking for impossible return value from a builtin function.
|
||||
# https://staticcheck.dev/docs/checks/#SA4024
|
||||
- SA4024
|
||||
# Integer division of literals that results in zero.
|
||||
# https://staticcheck.dev/docs/checks/#SA4025
|
||||
- SA4025
|
||||
# Go constants cannot express negative zero.
|
||||
# https://staticcheck.dev/docs/checks/#SA4026
|
||||
- SA4026
|
||||
# '(*net/url.URL).Query' returns a copy, modifying it doesn't change the URL.
|
||||
# https://staticcheck.dev/docs/checks/#SA4027
|
||||
- SA4027
|
||||
# 'x % 1' is always zero.
|
||||
# https://staticcheck.dev/docs/checks/#SA4028
|
||||
- SA4028
|
||||
# Ineffective attempt at sorting slice.
|
||||
# https://staticcheck.dev/docs/checks/#SA4029
|
||||
- SA4029
|
||||
# Ineffective attempt at generating random number.
|
||||
# https://staticcheck.dev/docs/checks/#SA4030
|
||||
- SA4030
|
||||
# Checking never-nil value against nil.
|
||||
# https://staticcheck.dev/docs/checks/#SA4031
|
||||
- SA4031
|
||||
# Comparing 'runtime.GOOS' or 'runtime.GOARCH' against impossible value.
|
||||
# https://staticcheck.dev/docs/checks/#SA4032
|
||||
- SA4032
|
||||
# Assignment to nil map.
|
||||
# https://staticcheck.dev/docs/checks/#SA5000
|
||||
- SA5000
|
||||
# Deferring 'Close' before checking for a possible error.
|
||||
# https://staticcheck.dev/docs/checks/#SA5001
|
||||
- SA5001
|
||||
# The empty for loop ("for {}") spins and can block the scheduler.
|
||||
# https://staticcheck.dev/docs/checks/#SA5002
|
||||
- SA5002
|
||||
# Defers in infinite loops will never execute.
|
||||
# https://staticcheck.dev/docs/checks/#SA5003
|
||||
- SA5003
|
||||
# "for { select { ..." with an empty default branch spins.
|
||||
# https://staticcheck.dev/docs/checks/#SA5004
|
||||
- SA5004
|
||||
# The finalizer references the finalized object, preventing garbage collection.
|
||||
# https://staticcheck.dev/docs/checks/#SA5005
|
||||
- SA5005
|
||||
# Infinite recursive call.
|
||||
# https://staticcheck.dev/docs/checks/#SA5007
|
||||
- SA5007
|
||||
# Invalid struct tag.
|
||||
# https://staticcheck.dev/docs/checks/#SA5008
|
||||
- SA5008
|
||||
# Invalid Printf call.
|
||||
# https://staticcheck.dev/docs/checks/#SA5009
|
||||
- SA5009
|
||||
# Impossible type assertion.
|
||||
# https://staticcheck.dev/docs/checks/#SA5010
|
||||
- SA5010
|
||||
# Possible nil pointer dereference.
|
||||
# https://staticcheck.dev/docs/checks/#SA5011
|
||||
- SA5011
|
||||
# Passing odd-sized slice to function expecting even size.
|
||||
# https://staticcheck.dev/docs/checks/#SA5012
|
||||
- SA5012
|
||||
# Using 'regexp.Match' or related in a loop, should use 'regexp.Compile'.
|
||||
# https://staticcheck.dev/docs/checks/#SA6000
|
||||
- SA6000
|
||||
# Missing an optimization opportunity when indexing maps by byte slices.
|
||||
# https://staticcheck.dev/docs/checks/#SA6001
|
||||
- SA6001
|
||||
# Storing non-pointer values in 'sync.Pool' allocates memory.
|
||||
# https://staticcheck.dev/docs/checks/#SA6002
|
||||
- SA6002
|
||||
# Converting a string to a slice of runes before ranging over it.
|
||||
# https://staticcheck.dev/docs/checks/#SA6003
|
||||
- SA6003
|
||||
# Inefficient string comparison with 'strings.ToLower' or 'strings.ToUpper'.
|
||||
# https://staticcheck.dev/docs/checks/#SA6005
|
||||
- SA6005
|
||||
# Using io.WriteString to write '[]byte'.
|
||||
# https://staticcheck.dev/docs/checks/#SA6006
|
||||
- SA6006
|
||||
# Defers in range loops may not run when you expect them to.
|
||||
# https://staticcheck.dev/docs/checks/#SA9001
|
||||
- SA9001
|
||||
# Using a non-octal 'os.FileMode' that looks like it was meant to be in octal.
|
||||
# https://staticcheck.dev/docs/checks/#SA9002
|
||||
- SA9002
|
||||
# Empty body in an if or else branch.
|
||||
# https://staticcheck.dev/docs/checks/#SA9003
|
||||
- SA9003
|
||||
# Only the first constant has an explicit type.
|
||||
# https://staticcheck.dev/docs/checks/#SA9004
|
||||
- SA9004
|
||||
# Trying to marshal a struct with no public fields nor custom marshaling.
|
||||
# https://staticcheck.dev/docs/checks/#SA9005
|
||||
- SA9005
|
||||
# Dubious bit shifting of a fixed size integer value.
|
||||
# https://staticcheck.dev/docs/checks/#SA9006
|
||||
- SA9006
|
||||
# Deleting a directory that shouldn't be deleted.
|
||||
# https://staticcheck.dev/docs/checks/#SA9007
|
||||
- SA9007
|
||||
# 'else' branch of a type assertion is probably not reading the right value.
|
||||
# https://staticcheck.dev/docs/checks/#SA9008
|
||||
- SA9008
|
||||
# Ineffectual Go compiler directive.
|
||||
# https://staticcheck.dev/docs/checks/#SA9009
|
||||
- SA9009
|
||||
# NOTE: ST1000, ST1001, ST1003, ST1020, ST1021, ST1022 are disabled above
|
||||
# Incorrectly formatted error string.
|
||||
# https://staticcheck.dev/docs/checks/#ST1005
|
||||
- ST1005
|
||||
# Poorly chosen receiver name.
|
||||
# https://staticcheck.dev/docs/checks/#ST1006
|
||||
- ST1006
|
||||
# A function's error value should be its last return value.
|
||||
# https://staticcheck.dev/docs/checks/#ST1008
|
||||
- ST1008
|
||||
# Poorly chosen name for variable of type 'time.Duration'.
|
||||
# https://staticcheck.dev/docs/checks/#ST1011
|
||||
- ST1011
|
||||
# Poorly chosen name for error variable.
|
||||
# https://staticcheck.dev/docs/checks/#ST1012
|
||||
- ST1012
|
||||
# Should use constants for HTTP error codes, not magic numbers.
|
||||
# https://staticcheck.dev/docs/checks/#ST1013
|
||||
- ST1013
|
||||
# A switch's default case should be the first or last case.
|
||||
# https://staticcheck.dev/docs/checks/#ST1015
|
||||
- ST1015
|
||||
# Use consistent method receiver names.
|
||||
# https://staticcheck.dev/docs/checks/#ST1016
|
||||
- ST1016
|
||||
# Don't use Yoda conditions.
|
||||
# https://staticcheck.dev/docs/checks/#ST1017
|
||||
- ST1017
|
||||
# Avoid zero-width and control characters in string literals.
|
||||
# https://staticcheck.dev/docs/checks/#ST1018
|
||||
- ST1018
|
||||
# Importing the same package multiple times.
|
||||
# https://staticcheck.dev/docs/checks/#ST1019
|
||||
- ST1019
|
||||
# NOTE: ST1020, ST1021, ST1022 removed (disabled above)
|
||||
# Redundant type in variable declaration.
|
||||
# https://staticcheck.dev/docs/checks/#ST1023
|
||||
- ST1023
|
||||
# Use plain channel send or receive instead of single-case select.
|
||||
# https://staticcheck.dev/docs/checks/#S1000
|
||||
- S1000
|
||||
# Replace for loop with call to copy.
|
||||
# https://staticcheck.dev/docs/checks/#S1001
|
||||
- S1001
|
||||
# Omit comparison with boolean constant.
|
||||
# https://staticcheck.dev/docs/checks/#S1002
|
||||
- S1002
|
||||
# Replace call to 'strings.Index' with 'strings.Contains'.
|
||||
# https://staticcheck.dev/docs/checks/#S1003
|
||||
- S1003
|
||||
# Replace call to 'bytes.Compare' with 'bytes.Equal'.
|
||||
# https://staticcheck.dev/docs/checks/#S1004
|
||||
- S1004
|
||||
# Drop unnecessary use of the blank identifier.
|
||||
# https://staticcheck.dev/docs/checks/#S1005
|
||||
- S1005
|
||||
# Use "for { ... }" for infinite loops.
|
||||
# https://staticcheck.dev/docs/checks/#S1006
|
||||
- S1006
|
||||
# Simplify regular expression by using raw string literal.
|
||||
# https://staticcheck.dev/docs/checks/#S1007
|
||||
- S1007
|
||||
# Simplify returning boolean expression.
|
||||
# https://staticcheck.dev/docs/checks/#S1008
|
||||
- S1008
|
||||
# Omit redundant nil check on slices, maps, and channels.
|
||||
# https://staticcheck.dev/docs/checks/#S1009
|
||||
- S1009
|
||||
# Omit default slice index.
|
||||
# https://staticcheck.dev/docs/checks/#S1010
|
||||
- S1010
|
||||
# Use a single 'append' to concatenate two slices.
|
||||
# https://staticcheck.dev/docs/checks/#S1011
|
||||
- S1011
|
||||
# Replace 'time.Now().Sub(x)' with 'time.Since(x)'.
|
||||
# https://staticcheck.dev/docs/checks/#S1012
|
||||
- S1012
|
||||
# Use a type conversion instead of manually copying struct fields.
|
||||
# https://staticcheck.dev/docs/checks/#S1016
|
||||
- S1016
|
||||
# Replace manual trimming with 'strings.TrimPrefix'.
|
||||
# https://staticcheck.dev/docs/checks/#S1017
|
||||
- S1017
|
||||
# Use "copy" for sliding elements.
|
||||
# https://staticcheck.dev/docs/checks/#S1018
|
||||
- S1018
|
||||
# Simplify "make" call by omitting redundant arguments.
|
||||
# https://staticcheck.dev/docs/checks/#S1019
|
||||
- S1019
|
||||
# Omit redundant nil check in type assertion.
|
||||
# https://staticcheck.dev/docs/checks/#S1020
|
||||
- S1020
|
||||
# Merge variable declaration and assignment.
|
||||
# https://staticcheck.dev/docs/checks/#S1021
|
||||
- S1021
|
||||
# Omit redundant control flow.
|
||||
# https://staticcheck.dev/docs/checks/#S1023
|
||||
- S1023
|
||||
# Replace 'x.Sub(time.Now())' with 'time.Until(x)'.
|
||||
# https://staticcheck.dev/docs/checks/#S1024
|
||||
- S1024
|
||||
# Don't use 'fmt.Sprintf("%s", x)' unnecessarily.
|
||||
# https://staticcheck.dev/docs/checks/#S1025
|
||||
- S1025
|
||||
# Simplify error construction with 'fmt.Errorf'.
|
||||
# https://staticcheck.dev/docs/checks/#S1028
|
||||
- S1028
|
||||
# Range over the string directly.
|
||||
# https://staticcheck.dev/docs/checks/#S1029
|
||||
- S1029
|
||||
# Use 'bytes.Buffer.String' or 'bytes.Buffer.Bytes'.
|
||||
# https://staticcheck.dev/docs/checks/#S1030
|
||||
- S1030
|
||||
# Omit redundant nil check around loop.
|
||||
# https://staticcheck.dev/docs/checks/#S1031
|
||||
- S1031
|
||||
# Use 'sort.Ints(x)', 'sort.Float64s(x)', and 'sort.Strings(x)'.
|
||||
# https://staticcheck.dev/docs/checks/#S1032
|
||||
- S1032
|
||||
# Unnecessary guard around call to "delete".
|
||||
# https://staticcheck.dev/docs/checks/#S1033
|
||||
- S1033
|
||||
# Use result of type assertion to simplify cases.
|
||||
# https://staticcheck.dev/docs/checks/#S1034
|
||||
- S1034
|
||||
# Redundant call to 'net/http.CanonicalHeaderKey' in method call on 'net/http.Header'.
|
||||
# https://staticcheck.dev/docs/checks/#S1035
|
||||
- S1035
|
||||
# Unnecessary guard around map access.
|
||||
# https://staticcheck.dev/docs/checks/#S1036
|
||||
- S1036
|
||||
# Elaborate way of sleeping.
|
||||
# https://staticcheck.dev/docs/checks/#S1037
|
||||
- S1037
|
||||
# Unnecessarily complex way of printing formatted string.
|
||||
# https://staticcheck.dev/docs/checks/#S1038
|
||||
- S1038
|
||||
# Unnecessary use of 'fmt.Sprint'.
|
||||
# https://staticcheck.dev/docs/checks/#S1039
|
||||
- S1039
|
||||
# Type assertion to current type.
|
||||
# https://staticcheck.dev/docs/checks/#S1040
|
||||
- S1040
|
||||
# Apply De Morgan's law.
|
||||
# https://staticcheck.dev/docs/checks/#QF1001
|
||||
- QF1001
|
||||
# Convert untagged switch to tagged switch.
|
||||
# https://staticcheck.dev/docs/checks/#QF1002
|
||||
- QF1002
|
||||
# Convert if/else-if chain to tagged switch.
|
||||
# https://staticcheck.dev/docs/checks/#QF1003
|
||||
- QF1003
|
||||
# Use 'strings.ReplaceAll' instead of 'strings.Replace' with 'n == -1'.
|
||||
# https://staticcheck.dev/docs/checks/#QF1004
|
||||
- QF1004
|
||||
# Expand call to 'math.Pow'.
|
||||
# https://staticcheck.dev/docs/checks/#QF1005
|
||||
- QF1005
|
||||
# Lift 'if'+'break' into loop condition.
|
||||
# https://staticcheck.dev/docs/checks/#QF1006
|
||||
- QF1006
|
||||
# Merge conditional assignment into variable declaration.
|
||||
# https://staticcheck.dev/docs/checks/#QF1007
|
||||
- QF1007
|
||||
# Omit embedded fields from selector expression.
|
||||
# https://staticcheck.dev/docs/checks/#QF1008
|
||||
- QF1008
|
||||
# Use 'time.Time.Equal' instead of '==' operator.
|
||||
# https://staticcheck.dev/docs/checks/#QF1009
|
||||
- QF1009
|
||||
# Convert slice of bytes to string when printing it.
|
||||
# https://staticcheck.dev/docs/checks/#QF1010
|
||||
- QF1010
|
||||
# Omit redundant type from variable declaration.
|
||||
# https://staticcheck.dev/docs/checks/#QF1011
|
||||
- QF1011
|
||||
# Use 'fmt.Fprintf(x, ...)' instead of 'x.Write(fmt.Sprintf(...))'.
|
||||
# https://staticcheck.dev/docs/checks/#QF1012
|
||||
- QF1012
|
||||
unused:
|
||||
# Mark all struct fields that have been written to as used.
|
||||
# Default: true
|
||||
field-writes-are-uses: false
|
||||
# Treat IncDec statement (e.g. `i++` or `i--`) as both read and write operation instead of just write.
|
||||
field-writes-are-uses: true
|
||||
# Default: false
|
||||
post-statements-are-reads: true
|
||||
# Mark all exported fields as used.
|
||||
# default: true
|
||||
exported-fields-are-used: false
|
||||
# Mark all function parameters as used.
|
||||
# default: true
|
||||
parameters-are-used: true
|
||||
# Mark all local variables as used.
|
||||
# default: true
|
||||
local-variables-are-used: false
|
||||
# Mark all identifiers inside generated files as used.
|
||||
# Default: true
|
||||
generated-is-used: false
|
||||
exported-fields-are-used: true
|
||||
# Default: true
|
||||
parameters-are-used: true
|
||||
# Default: true
|
||||
local-variables-are-used: false
|
||||
# Default: true — must be true, ent generates 130K+ lines of code
|
||||
generated-is-used: true
|
||||
|
||||
formatters:
|
||||
enable:
|
||||
|
||||
@@ -33,7 +33,7 @@ func main() {
|
||||
}()
|
||||
|
||||
userRepo := repository.NewUserRepository(client, sqlDB)
|
||||
authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -86,6 +86,7 @@ func provideCleanup(
|
||||
geminiOAuth *service.GeminiOAuthService,
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -216,6 +217,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"ScheduledTestRunnerService", func() error {
|
||||
if scheduledTestRunner != nil {
|
||||
scheduledTestRunner.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
}
|
||||
|
||||
infraSteps := []cleanupStep{
|
||||
|
||||
@@ -67,7 +67,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
|
||||
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
|
||||
authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
|
||||
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
|
||||
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
|
||||
redeemCache := repository.NewRedeemCache(redisClient)
|
||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||
@@ -104,7 +104,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
proxyRepository := repository.NewProxyRepository(client, db)
|
||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService)
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository)
|
||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
||||
@@ -162,9 +162,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
||||
digestSessionStore := service.NewDigestSessionStore()
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
|
||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||
@@ -195,7 +195,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache)
|
||||
errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService)
|
||||
adminAPIKeyHandler := admin.NewAdminAPIKeyHandler(adminService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler)
|
||||
scheduledTestPlanRepository := repository.NewScheduledTestPlanRepository(db)
|
||||
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
|
||||
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
|
||||
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
|
||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||
@@ -225,7 +229,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
@@ -273,6 +278,7 @@ func provideCleanup(
|
||||
geminiOAuth *service.GeminiOAuthService,
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -402,6 +408,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"ScheduledTestRunnerService", func() error {
|
||||
if scheduledTestRunner != nil {
|
||||
scheduledTestRunner.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
}
|
||||
|
||||
infraSteps := []cleanupStep{
|
||||
|
||||
@@ -74,6 +74,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||
geminiOAuthSvc,
|
||||
antigravityOAuthSvc,
|
||||
nil, // openAIGateway
|
||||
nil, // scheduledTestRunner
|
||||
)
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
|
||||
@@ -41,6 +41,8 @@ type Account struct {
|
||||
ProxyID *int64 `json:"proxy_id,omitempty"`
|
||||
// Concurrency holds the value of the "concurrency" field.
|
||||
Concurrency int `json:"concurrency,omitempty"`
|
||||
// LoadFactor holds the value of the "load_factor" field.
|
||||
LoadFactor *int `json:"load_factor,omitempty"`
|
||||
// Priority holds the value of the "priority" field.
|
||||
Priority int `json:"priority,omitempty"`
|
||||
// RateMultiplier holds the value of the "rate_multiplier" field.
|
||||
@@ -143,7 +145,7 @@ func (*Account) scanValues(columns []string) ([]any, error) {
|
||||
values[i] = new(sql.NullBool)
|
||||
case account.FieldRateMultiplier:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldPriority:
|
||||
case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldLoadFactor, account.FieldPriority:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldTempUnschedulableReason, account.FieldSessionWindowStatus:
|
||||
values[i] = new(sql.NullString)
|
||||
@@ -243,6 +245,13 @@ func (_m *Account) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.Concurrency = int(value.Int64)
|
||||
}
|
||||
case account.FieldLoadFactor:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field load_factor", values[i])
|
||||
} else if value.Valid {
|
||||
_m.LoadFactor = new(int)
|
||||
*_m.LoadFactor = int(value.Int64)
|
||||
}
|
||||
case account.FieldPriority:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field priority", values[i])
|
||||
@@ -445,6 +454,11 @@ func (_m *Account) String() string {
|
||||
builder.WriteString("concurrency=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.Concurrency))
|
||||
builder.WriteString(", ")
|
||||
if v := _m.LoadFactor; v != nil {
|
||||
builder.WriteString("load_factor=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("priority=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.Priority))
|
||||
builder.WriteString(", ")
|
||||
|
||||
@@ -37,6 +37,8 @@ const (
|
||||
FieldProxyID = "proxy_id"
|
||||
// FieldConcurrency holds the string denoting the concurrency field in the database.
|
||||
FieldConcurrency = "concurrency"
|
||||
// FieldLoadFactor holds the string denoting the load_factor field in the database.
|
||||
FieldLoadFactor = "load_factor"
|
||||
// FieldPriority holds the string denoting the priority field in the database.
|
||||
FieldPriority = "priority"
|
||||
// FieldRateMultiplier holds the string denoting the rate_multiplier field in the database.
|
||||
@@ -121,6 +123,7 @@ var Columns = []string{
|
||||
FieldExtra,
|
||||
FieldProxyID,
|
||||
FieldConcurrency,
|
||||
FieldLoadFactor,
|
||||
FieldPriority,
|
||||
FieldRateMultiplier,
|
||||
FieldStatus,
|
||||
@@ -250,6 +253,11 @@ func ByConcurrency(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldConcurrency, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByLoadFactor orders the results by the load_factor field.
|
||||
func ByLoadFactor(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldLoadFactor, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByPriority orders the results by the priority field.
|
||||
func ByPriority(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldPriority, opts...).ToFunc()
|
||||
|
||||
@@ -100,6 +100,11 @@ func Concurrency(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldConcurrency, v))
|
||||
}
|
||||
|
||||
// LoadFactor applies equality check predicate on the "load_factor" field. It's identical to LoadFactorEQ.
|
||||
func LoadFactor(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldLoadFactor, v))
|
||||
}
|
||||
|
||||
// Priority applies equality check predicate on the "priority" field. It's identical to PriorityEQ.
|
||||
func Priority(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldPriority, v))
|
||||
@@ -650,6 +655,56 @@ func ConcurrencyLTE(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldLTE(FieldConcurrency, v))
|
||||
}
|
||||
|
||||
// LoadFactorEQ applies the EQ predicate on the "load_factor" field.
|
||||
func LoadFactorEQ(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldLoadFactor, v))
|
||||
}
|
||||
|
||||
// LoadFactorNEQ applies the NEQ predicate on the "load_factor" field.
|
||||
func LoadFactorNEQ(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldNEQ(FieldLoadFactor, v))
|
||||
}
|
||||
|
||||
// LoadFactorIn applies the In predicate on the "load_factor" field.
|
||||
func LoadFactorIn(vs ...int) predicate.Account {
|
||||
return predicate.Account(sql.FieldIn(FieldLoadFactor, vs...))
|
||||
}
|
||||
|
||||
// LoadFactorNotIn applies the NotIn predicate on the "load_factor" field.
|
||||
func LoadFactorNotIn(vs ...int) predicate.Account {
|
||||
return predicate.Account(sql.FieldNotIn(FieldLoadFactor, vs...))
|
||||
}
|
||||
|
||||
// LoadFactorGT applies the GT predicate on the "load_factor" field.
|
||||
func LoadFactorGT(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldGT(FieldLoadFactor, v))
|
||||
}
|
||||
|
||||
// LoadFactorGTE applies the GTE predicate on the "load_factor" field.
|
||||
func LoadFactorGTE(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldGTE(FieldLoadFactor, v))
|
||||
}
|
||||
|
||||
// LoadFactorLT applies the LT predicate on the "load_factor" field.
|
||||
func LoadFactorLT(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldLT(FieldLoadFactor, v))
|
||||
}
|
||||
|
||||
// LoadFactorLTE applies the LTE predicate on the "load_factor" field.
|
||||
func LoadFactorLTE(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldLTE(FieldLoadFactor, v))
|
||||
}
|
||||
|
||||
// LoadFactorIsNil applies the IsNil predicate on the "load_factor" field.
|
||||
func LoadFactorIsNil() predicate.Account {
|
||||
return predicate.Account(sql.FieldIsNull(FieldLoadFactor))
|
||||
}
|
||||
|
||||
// LoadFactorNotNil applies the NotNil predicate on the "load_factor" field.
|
||||
func LoadFactorNotNil() predicate.Account {
|
||||
return predicate.Account(sql.FieldNotNull(FieldLoadFactor))
|
||||
}
|
||||
|
||||
// PriorityEQ applies the EQ predicate on the "priority" field.
|
||||
func PriorityEQ(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldPriority, v))
|
||||
|
||||
@@ -139,6 +139,20 @@ func (_c *AccountCreate) SetNillableConcurrency(v *int) *AccountCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetLoadFactor sets the "load_factor" field.
|
||||
func (_c *AccountCreate) SetLoadFactor(v int) *AccountCreate {
|
||||
_c.mutation.SetLoadFactor(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableLoadFactor sets the "load_factor" field if the given value is not nil.
|
||||
func (_c *AccountCreate) SetNillableLoadFactor(v *int) *AccountCreate {
|
||||
if v != nil {
|
||||
_c.SetLoadFactor(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (_c *AccountCreate) SetPriority(v int) *AccountCreate {
|
||||
_c.mutation.SetPriority(v)
|
||||
@@ -623,6 +637,10 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(account.FieldConcurrency, field.TypeInt, value)
|
||||
_node.Concurrency = value
|
||||
}
|
||||
if value, ok := _c.mutation.LoadFactor(); ok {
|
||||
_spec.SetField(account.FieldLoadFactor, field.TypeInt, value)
|
||||
_node.LoadFactor = &value
|
||||
}
|
||||
if value, ok := _c.mutation.Priority(); ok {
|
||||
_spec.SetField(account.FieldPriority, field.TypeInt, value)
|
||||
_node.Priority = value
|
||||
@@ -936,6 +954,30 @@ func (u *AccountUpsert) AddConcurrency(v int) *AccountUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetLoadFactor sets the "load_factor" field.
|
||||
func (u *AccountUpsert) SetLoadFactor(v int) *AccountUpsert {
|
||||
u.Set(account.FieldLoadFactor, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateLoadFactor sets the "load_factor" field to the value that was provided on create.
|
||||
func (u *AccountUpsert) UpdateLoadFactor() *AccountUpsert {
|
||||
u.SetExcluded(account.FieldLoadFactor)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddLoadFactor adds v to the "load_factor" field.
|
||||
func (u *AccountUpsert) AddLoadFactor(v int) *AccountUpsert {
|
||||
u.Add(account.FieldLoadFactor, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearLoadFactor clears the value of the "load_factor" field.
|
||||
func (u *AccountUpsert) ClearLoadFactor() *AccountUpsert {
|
||||
u.SetNull(account.FieldLoadFactor)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (u *AccountUpsert) SetPriority(v int) *AccountUpsert {
|
||||
u.Set(account.FieldPriority, v)
|
||||
@@ -1419,6 +1461,34 @@ func (u *AccountUpsertOne) UpdateConcurrency() *AccountUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetLoadFactor sets the "load_factor" field.
|
||||
func (u *AccountUpsertOne) SetLoadFactor(v int) *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.SetLoadFactor(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddLoadFactor adds v to the "load_factor" field.
|
||||
func (u *AccountUpsertOne) AddLoadFactor(v int) *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.AddLoadFactor(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateLoadFactor sets the "load_factor" field to the value that was provided on create.
|
||||
func (u *AccountUpsertOne) UpdateLoadFactor() *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.UpdateLoadFactor()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearLoadFactor clears the value of the "load_factor" field.
|
||||
func (u *AccountUpsertOne) ClearLoadFactor() *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.ClearLoadFactor()
|
||||
})
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (u *AccountUpsertOne) SetPriority(v int) *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
@@ -2113,6 +2183,34 @@ func (u *AccountUpsertBulk) UpdateConcurrency() *AccountUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetLoadFactor sets the "load_factor" field.
|
||||
func (u *AccountUpsertBulk) SetLoadFactor(v int) *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.SetLoadFactor(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddLoadFactor adds v to the "load_factor" field.
|
||||
func (u *AccountUpsertBulk) AddLoadFactor(v int) *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.AddLoadFactor(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateLoadFactor sets the "load_factor" field to the value that was provided on create.
|
||||
func (u *AccountUpsertBulk) UpdateLoadFactor() *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.UpdateLoadFactor()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearLoadFactor clears the value of the "load_factor" field.
|
||||
func (u *AccountUpsertBulk) ClearLoadFactor() *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.ClearLoadFactor()
|
||||
})
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (u *AccountUpsertBulk) SetPriority(v int) *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
|
||||
@@ -172,6 +172,33 @@ func (_u *AccountUpdate) AddConcurrency(v int) *AccountUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetLoadFactor sets the "load_factor" field.
|
||||
func (_u *AccountUpdate) SetLoadFactor(v int) *AccountUpdate {
|
||||
_u.mutation.ResetLoadFactor()
|
||||
_u.mutation.SetLoadFactor(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableLoadFactor sets the "load_factor" field if the given value is not nil.
|
||||
func (_u *AccountUpdate) SetNillableLoadFactor(v *int) *AccountUpdate {
|
||||
if v != nil {
|
||||
_u.SetLoadFactor(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddLoadFactor adds value to the "load_factor" field.
|
||||
func (_u *AccountUpdate) AddLoadFactor(v int) *AccountUpdate {
|
||||
_u.mutation.AddLoadFactor(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearLoadFactor clears the value of the "load_factor" field.
|
||||
func (_u *AccountUpdate) ClearLoadFactor() *AccountUpdate {
|
||||
_u.mutation.ClearLoadFactor()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (_u *AccountUpdate) SetPriority(v int) *AccountUpdate {
|
||||
_u.mutation.ResetPriority()
|
||||
@@ -684,6 +711,15 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if value, ok := _u.mutation.AddedConcurrency(); ok {
|
||||
_spec.AddField(account.FieldConcurrency, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.LoadFactor(); ok {
|
||||
_spec.SetField(account.FieldLoadFactor, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedLoadFactor(); ok {
|
||||
_spec.AddField(account.FieldLoadFactor, field.TypeInt, value)
|
||||
}
|
||||
if _u.mutation.LoadFactorCleared() {
|
||||
_spec.ClearField(account.FieldLoadFactor, field.TypeInt)
|
||||
}
|
||||
if value, ok := _u.mutation.Priority(); ok {
|
||||
_spec.SetField(account.FieldPriority, field.TypeInt, value)
|
||||
}
|
||||
@@ -1063,6 +1099,33 @@ func (_u *AccountUpdateOne) AddConcurrency(v int) *AccountUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetLoadFactor sets the "load_factor" field.
|
||||
func (_u *AccountUpdateOne) SetLoadFactor(v int) *AccountUpdateOne {
|
||||
_u.mutation.ResetLoadFactor()
|
||||
_u.mutation.SetLoadFactor(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableLoadFactor sets the "load_factor" field if the given value is not nil.
|
||||
func (_u *AccountUpdateOne) SetNillableLoadFactor(v *int) *AccountUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetLoadFactor(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddLoadFactor adds value to the "load_factor" field.
|
||||
func (_u *AccountUpdateOne) AddLoadFactor(v int) *AccountUpdateOne {
|
||||
_u.mutation.AddLoadFactor(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearLoadFactor clears the value of the "load_factor" field.
|
||||
func (_u *AccountUpdateOne) ClearLoadFactor() *AccountUpdateOne {
|
||||
_u.mutation.ClearLoadFactor()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (_u *AccountUpdateOne) SetPriority(v int) *AccountUpdateOne {
|
||||
_u.mutation.ResetPriority()
|
||||
@@ -1605,6 +1668,15 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er
|
||||
if value, ok := _u.mutation.AddedConcurrency(); ok {
|
||||
_spec.AddField(account.FieldConcurrency, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.LoadFactor(); ok {
|
||||
_spec.SetField(account.FieldLoadFactor, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedLoadFactor(); ok {
|
||||
_spec.AddField(account.FieldLoadFactor, field.TypeInt, value)
|
||||
}
|
||||
if _u.mutation.LoadFactorCleared() {
|
||||
_spec.ClearField(account.FieldLoadFactor, field.TypeInt)
|
||||
}
|
||||
if value, ok := _u.mutation.Priority(); ok {
|
||||
_spec.SetField(account.FieldPriority, field.TypeInt, value)
|
||||
}
|
||||
|
||||
@@ -25,6 +25,8 @@ type Announcement struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
// 状态: draft, active, archived
|
||||
Status string `json:"status,omitempty"`
|
||||
// 通知模式: silent(仅铃铛), popup(弹窗提醒)
|
||||
NotifyMode string `json:"notify_mode,omitempty"`
|
||||
// 展示条件(JSON 规则)
|
||||
Targeting domain.AnnouncementTargeting `json:"targeting,omitempty"`
|
||||
// 开始展示时间(为空表示立即生效)
|
||||
@@ -72,7 +74,7 @@ func (*Announcement) scanValues(columns []string) ([]any, error) {
|
||||
values[i] = new([]byte)
|
||||
case announcement.FieldID, announcement.FieldCreatedBy, announcement.FieldUpdatedBy:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case announcement.FieldTitle, announcement.FieldContent, announcement.FieldStatus:
|
||||
case announcement.FieldTitle, announcement.FieldContent, announcement.FieldStatus, announcement.FieldNotifyMode:
|
||||
values[i] = new(sql.NullString)
|
||||
case announcement.FieldStartsAt, announcement.FieldEndsAt, announcement.FieldCreatedAt, announcement.FieldUpdatedAt:
|
||||
values[i] = new(sql.NullTime)
|
||||
@@ -115,6 +117,12 @@ func (_m *Announcement) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.Status = value.String
|
||||
}
|
||||
case announcement.FieldNotifyMode:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field notify_mode", values[i])
|
||||
} else if value.Valid {
|
||||
_m.NotifyMode = value.String
|
||||
}
|
||||
case announcement.FieldTargeting:
|
||||
if value, ok := values[i].(*[]byte); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field targeting", values[i])
|
||||
@@ -213,6 +221,9 @@ func (_m *Announcement) String() string {
|
||||
builder.WriteString("status=")
|
||||
builder.WriteString(_m.Status)
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("notify_mode=")
|
||||
builder.WriteString(_m.NotifyMode)
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("targeting=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.Targeting))
|
||||
builder.WriteString(", ")
|
||||
|
||||
@@ -20,6 +20,8 @@ const (
|
||||
FieldContent = "content"
|
||||
// FieldStatus holds the string denoting the status field in the database.
|
||||
FieldStatus = "status"
|
||||
// FieldNotifyMode holds the string denoting the notify_mode field in the database.
|
||||
FieldNotifyMode = "notify_mode"
|
||||
// FieldTargeting holds the string denoting the targeting field in the database.
|
||||
FieldTargeting = "targeting"
|
||||
// FieldStartsAt holds the string denoting the starts_at field in the database.
|
||||
@@ -53,6 +55,7 @@ var Columns = []string{
|
||||
FieldTitle,
|
||||
FieldContent,
|
||||
FieldStatus,
|
||||
FieldNotifyMode,
|
||||
FieldTargeting,
|
||||
FieldStartsAt,
|
||||
FieldEndsAt,
|
||||
@@ -81,6 +84,10 @@ var (
|
||||
DefaultStatus string
|
||||
// StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
||||
StatusValidator func(string) error
|
||||
// DefaultNotifyMode holds the default value on creation for the "notify_mode" field.
|
||||
DefaultNotifyMode string
|
||||
// NotifyModeValidator is a validator for the "notify_mode" field. It is called by the builders before save.
|
||||
NotifyModeValidator func(string) error
|
||||
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
|
||||
DefaultCreatedAt func() time.Time
|
||||
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
|
||||
@@ -112,6 +119,11 @@ func ByStatus(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldStatus, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByNotifyMode orders the results by the notify_mode field.
|
||||
func ByNotifyMode(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldNotifyMode, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByStartsAt orders the results by the starts_at field.
|
||||
func ByStartsAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldStartsAt, opts...).ToFunc()
|
||||
|
||||
@@ -70,6 +70,11 @@ func Status(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldEQ(FieldStatus, v))
|
||||
}
|
||||
|
||||
// NotifyMode applies equality check predicate on the "notify_mode" field. It's identical to NotifyModeEQ.
|
||||
func NotifyMode(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldEQ(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// StartsAt applies equality check predicate on the "starts_at" field. It's identical to StartsAtEQ.
|
||||
func StartsAt(v time.Time) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldEQ(FieldStartsAt, v))
|
||||
@@ -295,6 +300,71 @@ func StatusContainsFold(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldContainsFold(FieldStatus, v))
|
||||
}
|
||||
|
||||
// NotifyModeEQ applies the EQ predicate on the "notify_mode" field.
|
||||
func NotifyModeEQ(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldEQ(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeNEQ applies the NEQ predicate on the "notify_mode" field.
|
||||
func NotifyModeNEQ(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldNEQ(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeIn applies the In predicate on the "notify_mode" field.
|
||||
func NotifyModeIn(vs ...string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldIn(FieldNotifyMode, vs...))
|
||||
}
|
||||
|
||||
// NotifyModeNotIn applies the NotIn predicate on the "notify_mode" field.
|
||||
func NotifyModeNotIn(vs ...string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldNotIn(FieldNotifyMode, vs...))
|
||||
}
|
||||
|
||||
// NotifyModeGT applies the GT predicate on the "notify_mode" field.
|
||||
func NotifyModeGT(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldGT(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeGTE applies the GTE predicate on the "notify_mode" field.
|
||||
func NotifyModeGTE(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldGTE(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeLT applies the LT predicate on the "notify_mode" field.
|
||||
func NotifyModeLT(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldLT(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeLTE applies the LTE predicate on the "notify_mode" field.
|
||||
func NotifyModeLTE(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldLTE(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeContains applies the Contains predicate on the "notify_mode" field.
|
||||
func NotifyModeContains(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldContains(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeHasPrefix applies the HasPrefix predicate on the "notify_mode" field.
|
||||
func NotifyModeHasPrefix(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldHasPrefix(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeHasSuffix applies the HasSuffix predicate on the "notify_mode" field.
|
||||
func NotifyModeHasSuffix(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldHasSuffix(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeEqualFold applies the EqualFold predicate on the "notify_mode" field.
|
||||
func NotifyModeEqualFold(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldEqualFold(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeContainsFold applies the ContainsFold predicate on the "notify_mode" field.
|
||||
func NotifyModeContainsFold(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldContainsFold(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// TargetingIsNil applies the IsNil predicate on the "targeting" field.
|
||||
func TargetingIsNil() predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldIsNull(FieldTargeting))
|
||||
|
||||
@@ -50,6 +50,20 @@ func (_c *AnnouncementCreate) SetNillableStatus(v *string) *AnnouncementCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNotifyMode sets the "notify_mode" field.
|
||||
func (_c *AnnouncementCreate) SetNotifyMode(v string) *AnnouncementCreate {
|
||||
_c.mutation.SetNotifyMode(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableNotifyMode sets the "notify_mode" field if the given value is not nil.
|
||||
func (_c *AnnouncementCreate) SetNillableNotifyMode(v *string) *AnnouncementCreate {
|
||||
if v != nil {
|
||||
_c.SetNotifyMode(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetTargeting sets the "targeting" field.
|
||||
func (_c *AnnouncementCreate) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementCreate {
|
||||
_c.mutation.SetTargeting(v)
|
||||
@@ -202,6 +216,10 @@ func (_c *AnnouncementCreate) defaults() {
|
||||
v := announcement.DefaultStatus
|
||||
_c.mutation.SetStatus(v)
|
||||
}
|
||||
if _, ok := _c.mutation.NotifyMode(); !ok {
|
||||
v := announcement.DefaultNotifyMode
|
||||
_c.mutation.SetNotifyMode(v)
|
||||
}
|
||||
if _, ok := _c.mutation.CreatedAt(); !ok {
|
||||
v := announcement.DefaultCreatedAt()
|
||||
_c.mutation.SetCreatedAt(v)
|
||||
@@ -238,6 +256,14 @@ func (_c *AnnouncementCreate) check() error {
|
||||
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Announcement.status": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _, ok := _c.mutation.NotifyMode(); !ok {
|
||||
return &ValidationError{Name: "notify_mode", err: errors.New(`ent: missing required field "Announcement.notify_mode"`)}
|
||||
}
|
||||
if v, ok := _c.mutation.NotifyMode(); ok {
|
||||
if err := announcement.NotifyModeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "notify_mode", err: fmt.Errorf(`ent: validator failed for field "Announcement.notify_mode": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _, ok := _c.mutation.CreatedAt(); !ok {
|
||||
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Announcement.created_at"`)}
|
||||
}
|
||||
@@ -283,6 +309,10 @@ func (_c *AnnouncementCreate) createSpec() (*Announcement, *sqlgraph.CreateSpec)
|
||||
_spec.SetField(announcement.FieldStatus, field.TypeString, value)
|
||||
_node.Status = value
|
||||
}
|
||||
if value, ok := _c.mutation.NotifyMode(); ok {
|
||||
_spec.SetField(announcement.FieldNotifyMode, field.TypeString, value)
|
||||
_node.NotifyMode = value
|
||||
}
|
||||
if value, ok := _c.mutation.Targeting(); ok {
|
||||
_spec.SetField(announcement.FieldTargeting, field.TypeJSON, value)
|
||||
_node.Targeting = value
|
||||
@@ -415,6 +445,18 @@ func (u *AnnouncementUpsert) UpdateStatus() *AnnouncementUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetNotifyMode sets the "notify_mode" field.
|
||||
func (u *AnnouncementUpsert) SetNotifyMode(v string) *AnnouncementUpsert {
|
||||
u.Set(announcement.FieldNotifyMode, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateNotifyMode sets the "notify_mode" field to the value that was provided on create.
|
||||
func (u *AnnouncementUpsert) UpdateNotifyMode() *AnnouncementUpsert {
|
||||
u.SetExcluded(announcement.FieldNotifyMode)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetTargeting sets the "targeting" field.
|
||||
func (u *AnnouncementUpsert) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpsert {
|
||||
u.Set(announcement.FieldTargeting, v)
|
||||
@@ -616,6 +658,20 @@ func (u *AnnouncementUpsertOne) UpdateStatus() *AnnouncementUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetNotifyMode sets the "notify_mode" field.
|
||||
func (u *AnnouncementUpsertOne) SetNotifyMode(v string) *AnnouncementUpsertOne {
|
||||
return u.Update(func(s *AnnouncementUpsert) {
|
||||
s.SetNotifyMode(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateNotifyMode sets the "notify_mode" field to the value that was provided on create.
|
||||
func (u *AnnouncementUpsertOne) UpdateNotifyMode() *AnnouncementUpsertOne {
|
||||
return u.Update(func(s *AnnouncementUpsert) {
|
||||
s.UpdateNotifyMode()
|
||||
})
|
||||
}
|
||||
|
||||
// SetTargeting sets the "targeting" field.
|
||||
func (u *AnnouncementUpsertOne) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpsertOne {
|
||||
return u.Update(func(s *AnnouncementUpsert) {
|
||||
@@ -1002,6 +1058,20 @@ func (u *AnnouncementUpsertBulk) UpdateStatus() *AnnouncementUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetNotifyMode sets the "notify_mode" field.
|
||||
func (u *AnnouncementUpsertBulk) SetNotifyMode(v string) *AnnouncementUpsertBulk {
|
||||
return u.Update(func(s *AnnouncementUpsert) {
|
||||
s.SetNotifyMode(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateNotifyMode sets the "notify_mode" field to the value that was provided on create.
|
||||
func (u *AnnouncementUpsertBulk) UpdateNotifyMode() *AnnouncementUpsertBulk {
|
||||
return u.Update(func(s *AnnouncementUpsert) {
|
||||
s.UpdateNotifyMode()
|
||||
})
|
||||
}
|
||||
|
||||
// SetTargeting sets the "targeting" field.
|
||||
func (u *AnnouncementUpsertBulk) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpsertBulk {
|
||||
return u.Update(func(s *AnnouncementUpsert) {
|
||||
|
||||
@@ -72,6 +72,20 @@ func (_u *AnnouncementUpdate) SetNillableStatus(v *string) *AnnouncementUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNotifyMode sets the "notify_mode" field.
|
||||
func (_u *AnnouncementUpdate) SetNotifyMode(v string) *AnnouncementUpdate {
|
||||
_u.mutation.SetNotifyMode(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableNotifyMode sets the "notify_mode" field if the given value is not nil.
|
||||
func (_u *AnnouncementUpdate) SetNillableNotifyMode(v *string) *AnnouncementUpdate {
|
||||
if v != nil {
|
||||
_u.SetNotifyMode(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetTargeting sets the "targeting" field.
|
||||
func (_u *AnnouncementUpdate) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpdate {
|
||||
_u.mutation.SetTargeting(v)
|
||||
@@ -286,6 +300,11 @@ func (_u *AnnouncementUpdate) check() error {
|
||||
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Announcement.status": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.NotifyMode(); ok {
|
||||
if err := announcement.NotifyModeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "notify_mode", err: fmt.Errorf(`ent: validator failed for field "Announcement.notify_mode": %w`, err)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -310,6 +329,9 @@ func (_u *AnnouncementUpdate) sqlSave(ctx context.Context) (_node int, err error
|
||||
if value, ok := _u.mutation.Status(); ok {
|
||||
_spec.SetField(announcement.FieldStatus, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.NotifyMode(); ok {
|
||||
_spec.SetField(announcement.FieldNotifyMode, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Targeting(); ok {
|
||||
_spec.SetField(announcement.FieldTargeting, field.TypeJSON, value)
|
||||
}
|
||||
@@ -456,6 +478,20 @@ func (_u *AnnouncementUpdateOne) SetNillableStatus(v *string) *AnnouncementUpdat
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNotifyMode sets the "notify_mode" field.
|
||||
func (_u *AnnouncementUpdateOne) SetNotifyMode(v string) *AnnouncementUpdateOne {
|
||||
_u.mutation.SetNotifyMode(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableNotifyMode sets the "notify_mode" field if the given value is not nil.
|
||||
func (_u *AnnouncementUpdateOne) SetNillableNotifyMode(v *string) *AnnouncementUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetNotifyMode(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetTargeting sets the "targeting" field.
|
||||
func (_u *AnnouncementUpdateOne) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpdateOne {
|
||||
_u.mutation.SetTargeting(v)
|
||||
@@ -683,6 +719,11 @@ func (_u *AnnouncementUpdateOne) check() error {
|
||||
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Announcement.status": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.NotifyMode(); ok {
|
||||
if err := announcement.NotifyModeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "notify_mode", err: fmt.Errorf(`ent: validator failed for field "Announcement.notify_mode": %w`, err)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -724,6 +765,9 @@ func (_u *AnnouncementUpdateOne) sqlSave(ctx context.Context) (_node *Announceme
|
||||
if value, ok := _u.mutation.Status(); ok {
|
||||
_spec.SetField(announcement.FieldStatus, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.NotifyMode(); ok {
|
||||
_spec.SetField(announcement.FieldNotifyMode, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Targeting(); ok {
|
||||
_spec.SetField(announcement.FieldTargeting, field.TypeJSON, value)
|
||||
}
|
||||
|
||||
@@ -78,6 +78,10 @@ type Group struct {
|
||||
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
|
||||
// 分组显示排序,数值越小越靠前
|
||||
SortOrder int `json:"sort_order,omitempty"`
|
||||
// 是否允许 /v1/messages 调度到此 OpenAI 分组
|
||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch,omitempty"`
|
||||
// 默认映射模型 ID,当账号级映射找不到时使用此值
|
||||
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
|
||||
// Edges holds the relations/edges for other nodes in the graph.
|
||||
// The values are being populated by the GroupQuery when eager-loading is set.
|
||||
Edges GroupEdges `json:"edges"`
|
||||
@@ -186,13 +190,13 @@ func (*Group) scanValues(columns []string) ([]any, error) {
|
||||
switch columns[i] {
|
||||
case group.FieldModelRouting, group.FieldSupportedModelScopes:
|
||||
values[i] = new([]byte)
|
||||
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject:
|
||||
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch:
|
||||
values[i] = new(sql.NullBool)
|
||||
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case group.FieldID, group.FieldDefaultValidityDays, group.FieldSoraStorageQuotaBytes, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType:
|
||||
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType, group.FieldDefaultMappedModel:
|
||||
values[i] = new(sql.NullString)
|
||||
case group.FieldCreatedAt, group.FieldUpdatedAt, group.FieldDeletedAt:
|
||||
values[i] = new(sql.NullTime)
|
||||
@@ -415,6 +419,18 @@ func (_m *Group) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.SortOrder = int(value.Int64)
|
||||
}
|
||||
case group.FieldAllowMessagesDispatch:
|
||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field allow_messages_dispatch", values[i])
|
||||
} else if value.Valid {
|
||||
_m.AllowMessagesDispatch = value.Bool
|
||||
}
|
||||
case group.FieldDefaultMappedModel:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field default_mapped_model", values[i])
|
||||
} else if value.Valid {
|
||||
_m.DefaultMappedModel = value.String
|
||||
}
|
||||
default:
|
||||
_m.selectValues.Set(columns[i], values[i])
|
||||
}
|
||||
@@ -608,6 +624,12 @@ func (_m *Group) String() string {
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("sort_order=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.SortOrder))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("allow_messages_dispatch=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.AllowMessagesDispatch))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("default_mapped_model=")
|
||||
builder.WriteString(_m.DefaultMappedModel)
|
||||
builder.WriteByte(')')
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
@@ -75,6 +75,10 @@ const (
|
||||
FieldSupportedModelScopes = "supported_model_scopes"
|
||||
// FieldSortOrder holds the string denoting the sort_order field in the database.
|
||||
FieldSortOrder = "sort_order"
|
||||
// FieldAllowMessagesDispatch holds the string denoting the allow_messages_dispatch field in the database.
|
||||
FieldAllowMessagesDispatch = "allow_messages_dispatch"
|
||||
// FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database.
|
||||
FieldDefaultMappedModel = "default_mapped_model"
|
||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||
EdgeAPIKeys = "api_keys"
|
||||
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
||||
@@ -180,6 +184,8 @@ var Columns = []string{
|
||||
FieldMcpXMLInject,
|
||||
FieldSupportedModelScopes,
|
||||
FieldSortOrder,
|
||||
FieldAllowMessagesDispatch,
|
||||
FieldDefaultMappedModel,
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -247,6 +253,12 @@ var (
|
||||
DefaultSupportedModelScopes []string
|
||||
// DefaultSortOrder holds the default value on creation for the "sort_order" field.
|
||||
DefaultSortOrder int
|
||||
// DefaultAllowMessagesDispatch holds the default value on creation for the "allow_messages_dispatch" field.
|
||||
DefaultAllowMessagesDispatch bool
|
||||
// DefaultDefaultMappedModel holds the default value on creation for the "default_mapped_model" field.
|
||||
DefaultDefaultMappedModel string
|
||||
// DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
||||
DefaultMappedModelValidator func(string) error
|
||||
)
|
||||
|
||||
// OrderOption defines the ordering options for the Group queries.
|
||||
@@ -397,6 +409,16 @@ func BySortOrder(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldSortOrder, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByAllowMessagesDispatch orders the results by the allow_messages_dispatch field.
|
||||
func ByAllowMessagesDispatch(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldAllowMessagesDispatch, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByDefaultMappedModel orders the results by the default_mapped_model field.
|
||||
func ByDefaultMappedModel(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByAPIKeysCount orders the results by api_keys count.
|
||||
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
|
||||
@@ -195,6 +195,16 @@ func SortOrder(v int) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldSortOrder, v))
|
||||
}
|
||||
|
||||
// AllowMessagesDispatch applies equality check predicate on the "allow_messages_dispatch" field. It's identical to AllowMessagesDispatchEQ.
|
||||
func AllowMessagesDispatch(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldAllowMessagesDispatch, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModel applies equality check predicate on the "default_mapped_model" field. It's identical to DefaultMappedModelEQ.
|
||||
func DefaultMappedModel(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||
func CreatedAtEQ(v time.Time) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
|
||||
@@ -1470,6 +1480,81 @@ func SortOrderLTE(v int) predicate.Group {
|
||||
return predicate.Group(sql.FieldLTE(FieldSortOrder, v))
|
||||
}
|
||||
|
||||
// AllowMessagesDispatchEQ applies the EQ predicate on the "allow_messages_dispatch" field.
|
||||
func AllowMessagesDispatchEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldAllowMessagesDispatch, v))
|
||||
}
|
||||
|
||||
// AllowMessagesDispatchNEQ applies the NEQ predicate on the "allow_messages_dispatch" field.
|
||||
func AllowMessagesDispatchNEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldAllowMessagesDispatch, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelEQ applies the EQ predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelEQ(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelNEQ applies the NEQ predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelNEQ(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelIn applies the In predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelIn(vs ...string) predicate.Group {
|
||||
return predicate.Group(sql.FieldIn(FieldDefaultMappedModel, vs...))
|
||||
}
|
||||
|
||||
// DefaultMappedModelNotIn applies the NotIn predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelNotIn(vs ...string) predicate.Group {
|
||||
return predicate.Group(sql.FieldNotIn(FieldDefaultMappedModel, vs...))
|
||||
}
|
||||
|
||||
// DefaultMappedModelGT applies the GT predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelGT(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldGT(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelGTE applies the GTE predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelGTE(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldGTE(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelLT applies the LT predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelLT(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldLT(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelLTE applies the LTE predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelLTE(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldLTE(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelContains applies the Contains predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelContains(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldContains(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelHasPrefix applies the HasPrefix predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelHasPrefix(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldHasPrefix(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelHasSuffix applies the HasSuffix predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelHasSuffix(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldHasSuffix(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelEqualFold applies the EqualFold predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelEqualFold(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldEqualFold(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelContainsFold applies the ContainsFold predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelContainsFold(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldContainsFold(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
||||
func HasAPIKeys() predicate.Group {
|
||||
return predicate.Group(func(s *sql.Selector) {
|
||||
|
||||
@@ -424,6 +424,34 @@ func (_c *GroupCreate) SetNillableSortOrder(v *int) *GroupCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
|
||||
func (_c *GroupCreate) SetAllowMessagesDispatch(v bool) *GroupCreate {
|
||||
_c.mutation.SetAllowMessagesDispatch(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableAllowMessagesDispatch sets the "allow_messages_dispatch" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableAllowMessagesDispatch(v *bool) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetAllowMessagesDispatch(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (_c *GroupCreate) SetDefaultMappedModel(v string) *GroupCreate {
|
||||
_c.mutation.SetDefaultMappedModel(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableDefaultMappedModel sets the "default_mapped_model" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableDefaultMappedModel(v *string) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetDefaultMappedModel(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
|
||||
_c.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -613,6 +641,14 @@ func (_c *GroupCreate) defaults() error {
|
||||
v := group.DefaultSortOrder
|
||||
_c.mutation.SetSortOrder(v)
|
||||
}
|
||||
if _, ok := _c.mutation.AllowMessagesDispatch(); !ok {
|
||||
v := group.DefaultAllowMessagesDispatch
|
||||
_c.mutation.SetAllowMessagesDispatch(v)
|
||||
}
|
||||
if _, ok := _c.mutation.DefaultMappedModel(); !ok {
|
||||
v := group.DefaultDefaultMappedModel
|
||||
_c.mutation.SetDefaultMappedModel(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -683,6 +719,17 @@ func (_c *GroupCreate) check() error {
|
||||
if _, ok := _c.mutation.SortOrder(); !ok {
|
||||
return &ValidationError{Name: "sort_order", err: errors.New(`ent: missing required field "Group.sort_order"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.AllowMessagesDispatch(); !ok {
|
||||
return &ValidationError{Name: "allow_messages_dispatch", err: errors.New(`ent: missing required field "Group.allow_messages_dispatch"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.DefaultMappedModel(); !ok {
|
||||
return &ValidationError{Name: "default_mapped_model", err: errors.New(`ent: missing required field "Group.default_mapped_model"`)}
|
||||
}
|
||||
if v, ok := _c.mutation.DefaultMappedModel(); ok {
|
||||
if err := group.DefaultMappedModelValidator(v); err != nil {
|
||||
return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -830,6 +877,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(group.FieldSortOrder, field.TypeInt, value)
|
||||
_node.SortOrder = value
|
||||
}
|
||||
if value, ok := _c.mutation.AllowMessagesDispatch(); ok {
|
||||
_spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value)
|
||||
_node.AllowMessagesDispatch = value
|
||||
}
|
||||
if value, ok := _c.mutation.DefaultMappedModel(); ok {
|
||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||
_node.DefaultMappedModel = value
|
||||
}
|
||||
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -1520,6 +1575,30 @@ func (u *GroupUpsert) AddSortOrder(v int) *GroupUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
|
||||
func (u *GroupUpsert) SetAllowMessagesDispatch(v bool) *GroupUpsert {
|
||||
u.Set(group.FieldAllowMessagesDispatch, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateAllowMessagesDispatch sets the "allow_messages_dispatch" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateAllowMessagesDispatch() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldAllowMessagesDispatch)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (u *GroupUpsert) SetDefaultMappedModel(v string) *GroupUpsert {
|
||||
u.Set(group.FieldDefaultMappedModel, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateDefaultMappedModel sets the "default_mapped_model" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateDefaultMappedModel() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldDefaultMappedModel)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||
// Using this option is equivalent to using:
|
||||
//
|
||||
@@ -2188,6 +2267,34 @@ func (u *GroupUpsertOne) UpdateSortOrder() *GroupUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
|
||||
func (u *GroupUpsertOne) SetAllowMessagesDispatch(v bool) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetAllowMessagesDispatch(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateAllowMessagesDispatch sets the "allow_messages_dispatch" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateAllowMessagesDispatch() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateAllowMessagesDispatch()
|
||||
})
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (u *GroupUpsertOne) SetDefaultMappedModel(v string) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetDefaultMappedModel(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateDefaultMappedModel sets the "default_mapped_model" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateDefaultMappedModel() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateDefaultMappedModel()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
|
||||
if len(u.create.conflict) == 0 {
|
||||
@@ -3022,6 +3129,34 @@ func (u *GroupUpsertBulk) UpdateSortOrder() *GroupUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
|
||||
func (u *GroupUpsertBulk) SetAllowMessagesDispatch(v bool) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetAllowMessagesDispatch(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateAllowMessagesDispatch sets the "allow_messages_dispatch" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateAllowMessagesDispatch() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateAllowMessagesDispatch()
|
||||
})
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (u *GroupUpsertBulk) SetDefaultMappedModel(v string) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetDefaultMappedModel(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateDefaultMappedModel sets the "default_mapped_model" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateDefaultMappedModel() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateDefaultMappedModel()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
|
||||
if u.create.err != nil {
|
||||
|
||||
@@ -625,6 +625,34 @@ func (_u *GroupUpdate) AddSortOrder(v int) *GroupUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
|
||||
func (_u *GroupUpdate) SetAllowMessagesDispatch(v bool) *GroupUpdate {
|
||||
_u.mutation.SetAllowMessagesDispatch(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableAllowMessagesDispatch sets the "allow_messages_dispatch" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableAllowMessagesDispatch(v *bool) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetAllowMessagesDispatch(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (_u *GroupUpdate) SetDefaultMappedModel(v string) *GroupUpdate {
|
||||
_u.mutation.SetDefaultMappedModel(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDefaultMappedModel sets the "default_mapped_model" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableDefaultMappedModel(v *string) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetDefaultMappedModel(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -910,6 +938,11 @@ func (_u *GroupUpdate) check() error {
|
||||
return &ValidationError{Name: "subscription_type", err: fmt.Errorf(`ent: validator failed for field "Group.subscription_type": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.DefaultMappedModel(); ok {
|
||||
if err := group.DefaultMappedModelValidator(v); err != nil {
|
||||
return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1110,6 +1143,12 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if value, ok := _u.mutation.AddedSortOrder(); ok {
|
||||
_spec.AddField(group.FieldSortOrder, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AllowMessagesDispatch(); ok {
|
||||
_spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.DefaultMappedModel(); ok {
|
||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -2014,6 +2053,34 @@ func (_u *GroupUpdateOne) AddSortOrder(v int) *GroupUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
|
||||
func (_u *GroupUpdateOne) SetAllowMessagesDispatch(v bool) *GroupUpdateOne {
|
||||
_u.mutation.SetAllowMessagesDispatch(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableAllowMessagesDispatch sets the "allow_messages_dispatch" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableAllowMessagesDispatch(v *bool) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetAllowMessagesDispatch(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (_u *GroupUpdateOne) SetDefaultMappedModel(v string) *GroupUpdateOne {
|
||||
_u.mutation.SetDefaultMappedModel(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDefaultMappedModel sets the "default_mapped_model" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableDefaultMappedModel(v *string) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetDefaultMappedModel(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -2312,6 +2379,11 @@ func (_u *GroupUpdateOne) check() error {
|
||||
return &ValidationError{Name: "subscription_type", err: fmt.Errorf(`ent: validator failed for field "Group.subscription_type": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.DefaultMappedModel(); ok {
|
||||
if err := group.DefaultMappedModelValidator(v); err != nil {
|
||||
return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2529,6 +2601,12 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
|
||||
if value, ok := _u.mutation.AddedSortOrder(); ok {
|
||||
_spec.AddField(group.FieldSortOrder, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AllowMessagesDispatch(); ok {
|
||||
_spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.DefaultMappedModel(); ok {
|
||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
|
||||
@@ -106,6 +106,7 @@ var (
|
||||
{Name: "credentials", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "extra", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "concurrency", Type: field.TypeInt, Default: 3},
|
||||
{Name: "load_factor", Type: field.TypeInt, Nullable: true},
|
||||
{Name: "priority", Type: field.TypeInt, Default: 50},
|
||||
{Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
|
||||
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
|
||||
@@ -132,7 +133,7 @@ var (
|
||||
ForeignKeys: []*schema.ForeignKey{
|
||||
{
|
||||
Symbol: "accounts_proxies_proxy",
|
||||
Columns: []*schema.Column{AccountsColumns[27]},
|
||||
Columns: []*schema.Column{AccountsColumns[28]},
|
||||
RefColumns: []*schema.Column{ProxiesColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
@@ -151,52 +152,52 @@ var (
|
||||
{
|
||||
Name: "account_status",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[13]},
|
||||
Columns: []*schema.Column{AccountsColumns[14]},
|
||||
},
|
||||
{
|
||||
Name: "account_proxy_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[27]},
|
||||
Columns: []*schema.Column{AccountsColumns[28]},
|
||||
},
|
||||
{
|
||||
Name: "account_priority",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[11]},
|
||||
Columns: []*schema.Column{AccountsColumns[12]},
|
||||
},
|
||||
{
|
||||
Name: "account_last_used_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[15]},
|
||||
Columns: []*schema.Column{AccountsColumns[16]},
|
||||
},
|
||||
{
|
||||
Name: "account_schedulable",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[18]},
|
||||
Columns: []*schema.Column{AccountsColumns[19]},
|
||||
},
|
||||
{
|
||||
Name: "account_rate_limited_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[19]},
|
||||
Columns: []*schema.Column{AccountsColumns[20]},
|
||||
},
|
||||
{
|
||||
Name: "account_rate_limit_reset_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[20]},
|
||||
Columns: []*schema.Column{AccountsColumns[21]},
|
||||
},
|
||||
{
|
||||
Name: "account_overload_until",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[21]},
|
||||
Columns: []*schema.Column{AccountsColumns[22]},
|
||||
},
|
||||
{
|
||||
Name: "account_platform_priority",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[6], AccountsColumns[11]},
|
||||
Columns: []*schema.Column{AccountsColumns[6], AccountsColumns[12]},
|
||||
},
|
||||
{
|
||||
Name: "account_priority_status",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[11], AccountsColumns[13]},
|
||||
Columns: []*schema.Column{AccountsColumns[12], AccountsColumns[14]},
|
||||
},
|
||||
{
|
||||
Name: "account_deleted_at",
|
||||
@@ -250,6 +251,7 @@ var (
|
||||
{Name: "title", Type: field.TypeString, Size: 200},
|
||||
{Name: "content", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
|
||||
{Name: "status", Type: field.TypeString, Size: 20, Default: "draft"},
|
||||
{Name: "notify_mode", Type: field.TypeString, Size: 20, Default: "silent"},
|
||||
{Name: "targeting", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "starts_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "ends_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
@@ -272,17 +274,17 @@ var (
|
||||
{
|
||||
Name: "announcement_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AnnouncementsColumns[9]},
|
||||
Columns: []*schema.Column{AnnouncementsColumns[10]},
|
||||
},
|
||||
{
|
||||
Name: "announcement_starts_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AnnouncementsColumns[5]},
|
||||
Columns: []*schema.Column{AnnouncementsColumns[6]},
|
||||
},
|
||||
{
|
||||
Name: "announcement_ends_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AnnouncementsColumns[6]},
|
||||
Columns: []*schema.Column{AnnouncementsColumns[7]},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -406,6 +408,8 @@ var (
|
||||
{Name: "mcp_xml_inject", Type: field.TypeBool, Default: true},
|
||||
{Name: "supported_model_scopes", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "sort_order", Type: field.TypeInt, Default: 0},
|
||||
{Name: "allow_messages_dispatch", Type: field.TypeBool, Default: false},
|
||||
{Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""},
|
||||
}
|
||||
// GroupsTable holds the schema information for the "groups" table.
|
||||
GroupsTable = &schema.Table{
|
||||
|
||||
@@ -2260,6 +2260,8 @@ type AccountMutation struct {
|
||||
extra *map[string]interface{}
|
||||
concurrency *int
|
||||
addconcurrency *int
|
||||
load_factor *int
|
||||
addload_factor *int
|
||||
priority *int
|
||||
addpriority *int
|
||||
rate_multiplier *float64
|
||||
@@ -2845,6 +2847,76 @@ func (m *AccountMutation) ResetConcurrency() {
|
||||
m.addconcurrency = nil
|
||||
}
|
||||
|
||||
// SetLoadFactor sets the "load_factor" field.
|
||||
func (m *AccountMutation) SetLoadFactor(i int) {
|
||||
m.load_factor = &i
|
||||
m.addload_factor = nil
|
||||
}
|
||||
|
||||
// LoadFactor returns the value of the "load_factor" field in the mutation.
|
||||
func (m *AccountMutation) LoadFactor() (r int, exists bool) {
|
||||
v := m.load_factor
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldLoadFactor returns the old "load_factor" field's value of the Account entity.
|
||||
// If the Account object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *AccountMutation) OldLoadFactor(ctx context.Context) (v *int, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldLoadFactor is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldLoadFactor requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldLoadFactor: %w", err)
|
||||
}
|
||||
return oldValue.LoadFactor, nil
|
||||
}
|
||||
|
||||
// AddLoadFactor adds i to the "load_factor" field.
|
||||
func (m *AccountMutation) AddLoadFactor(i int) {
|
||||
if m.addload_factor != nil {
|
||||
*m.addload_factor += i
|
||||
} else {
|
||||
m.addload_factor = &i
|
||||
}
|
||||
}
|
||||
|
||||
// AddedLoadFactor returns the value that was added to the "load_factor" field in this mutation.
|
||||
func (m *AccountMutation) AddedLoadFactor() (r int, exists bool) {
|
||||
v := m.addload_factor
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// ClearLoadFactor clears the value of the "load_factor" field.
|
||||
func (m *AccountMutation) ClearLoadFactor() {
|
||||
m.load_factor = nil
|
||||
m.addload_factor = nil
|
||||
m.clearedFields[account.FieldLoadFactor] = struct{}{}
|
||||
}
|
||||
|
||||
// LoadFactorCleared returns if the "load_factor" field was cleared in this mutation.
|
||||
func (m *AccountMutation) LoadFactorCleared() bool {
|
||||
_, ok := m.clearedFields[account.FieldLoadFactor]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ResetLoadFactor resets all changes to the "load_factor" field.
|
||||
func (m *AccountMutation) ResetLoadFactor() {
|
||||
m.load_factor = nil
|
||||
m.addload_factor = nil
|
||||
delete(m.clearedFields, account.FieldLoadFactor)
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (m *AccountMutation) SetPriority(i int) {
|
||||
m.priority = &i
|
||||
@@ -3773,7 +3845,7 @@ func (m *AccountMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *AccountMutation) Fields() []string {
|
||||
fields := make([]string, 0, 27)
|
||||
fields := make([]string, 0, 28)
|
||||
if m.created_at != nil {
|
||||
fields = append(fields, account.FieldCreatedAt)
|
||||
}
|
||||
@@ -3807,6 +3879,9 @@ func (m *AccountMutation) Fields() []string {
|
||||
if m.concurrency != nil {
|
||||
fields = append(fields, account.FieldConcurrency)
|
||||
}
|
||||
if m.load_factor != nil {
|
||||
fields = append(fields, account.FieldLoadFactor)
|
||||
}
|
||||
if m.priority != nil {
|
||||
fields = append(fields, account.FieldPriority)
|
||||
}
|
||||
@@ -3885,6 +3960,8 @@ func (m *AccountMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.ProxyID()
|
||||
case account.FieldConcurrency:
|
||||
return m.Concurrency()
|
||||
case account.FieldLoadFactor:
|
||||
return m.LoadFactor()
|
||||
case account.FieldPriority:
|
||||
return m.Priority()
|
||||
case account.FieldRateMultiplier:
|
||||
@@ -3948,6 +4025,8 @@ func (m *AccountMutation) OldField(ctx context.Context, name string) (ent.Value,
|
||||
return m.OldProxyID(ctx)
|
||||
case account.FieldConcurrency:
|
||||
return m.OldConcurrency(ctx)
|
||||
case account.FieldLoadFactor:
|
||||
return m.OldLoadFactor(ctx)
|
||||
case account.FieldPriority:
|
||||
return m.OldPriority(ctx)
|
||||
case account.FieldRateMultiplier:
|
||||
@@ -4066,6 +4145,13 @@ func (m *AccountMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetConcurrency(v)
|
||||
return nil
|
||||
case account.FieldLoadFactor:
|
||||
v, ok := value.(int)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetLoadFactor(v)
|
||||
return nil
|
||||
case account.FieldPriority:
|
||||
v, ok := value.(int)
|
||||
if !ok {
|
||||
@@ -4189,6 +4275,9 @@ func (m *AccountMutation) AddedFields() []string {
|
||||
if m.addconcurrency != nil {
|
||||
fields = append(fields, account.FieldConcurrency)
|
||||
}
|
||||
if m.addload_factor != nil {
|
||||
fields = append(fields, account.FieldLoadFactor)
|
||||
}
|
||||
if m.addpriority != nil {
|
||||
fields = append(fields, account.FieldPriority)
|
||||
}
|
||||
@@ -4205,6 +4294,8 @@ func (m *AccountMutation) AddedField(name string) (ent.Value, bool) {
|
||||
switch name {
|
||||
case account.FieldConcurrency:
|
||||
return m.AddedConcurrency()
|
||||
case account.FieldLoadFactor:
|
||||
return m.AddedLoadFactor()
|
||||
case account.FieldPriority:
|
||||
return m.AddedPriority()
|
||||
case account.FieldRateMultiplier:
|
||||
@@ -4225,6 +4316,13 @@ func (m *AccountMutation) AddField(name string, value ent.Value) error {
|
||||
}
|
||||
m.AddConcurrency(v)
|
||||
return nil
|
||||
case account.FieldLoadFactor:
|
||||
v, ok := value.(int)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.AddLoadFactor(v)
|
||||
return nil
|
||||
case account.FieldPriority:
|
||||
v, ok := value.(int)
|
||||
if !ok {
|
||||
@@ -4256,6 +4354,9 @@ func (m *AccountMutation) ClearedFields() []string {
|
||||
if m.FieldCleared(account.FieldProxyID) {
|
||||
fields = append(fields, account.FieldProxyID)
|
||||
}
|
||||
if m.FieldCleared(account.FieldLoadFactor) {
|
||||
fields = append(fields, account.FieldLoadFactor)
|
||||
}
|
||||
if m.FieldCleared(account.FieldErrorMessage) {
|
||||
fields = append(fields, account.FieldErrorMessage)
|
||||
}
|
||||
@@ -4312,6 +4413,9 @@ func (m *AccountMutation) ClearField(name string) error {
|
||||
case account.FieldProxyID:
|
||||
m.ClearProxyID()
|
||||
return nil
|
||||
case account.FieldLoadFactor:
|
||||
m.ClearLoadFactor()
|
||||
return nil
|
||||
case account.FieldErrorMessage:
|
||||
m.ClearErrorMessage()
|
||||
return nil
|
||||
@@ -4386,6 +4490,9 @@ func (m *AccountMutation) ResetField(name string) error {
|
||||
case account.FieldConcurrency:
|
||||
m.ResetConcurrency()
|
||||
return nil
|
||||
case account.FieldLoadFactor:
|
||||
m.ResetLoadFactor()
|
||||
return nil
|
||||
case account.FieldPriority:
|
||||
m.ResetPriority()
|
||||
return nil
|
||||
@@ -5060,6 +5167,7 @@ type AnnouncementMutation struct {
|
||||
title *string
|
||||
content *string
|
||||
status *string
|
||||
notify_mode *string
|
||||
targeting *domain.AnnouncementTargeting
|
||||
starts_at *time.Time
|
||||
ends_at *time.Time
|
||||
@@ -5284,6 +5392,42 @@ func (m *AnnouncementMutation) ResetStatus() {
|
||||
m.status = nil
|
||||
}
|
||||
|
||||
// SetNotifyMode sets the "notify_mode" field.
|
||||
func (m *AnnouncementMutation) SetNotifyMode(s string) {
|
||||
m.notify_mode = &s
|
||||
}
|
||||
|
||||
// NotifyMode returns the value of the "notify_mode" field in the mutation.
|
||||
func (m *AnnouncementMutation) NotifyMode() (r string, exists bool) {
|
||||
v := m.notify_mode
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldNotifyMode returns the old "notify_mode" field's value of the Announcement entity.
|
||||
// If the Announcement object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *AnnouncementMutation) OldNotifyMode(ctx context.Context) (v string, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldNotifyMode is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldNotifyMode requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldNotifyMode: %w", err)
|
||||
}
|
||||
return oldValue.NotifyMode, nil
|
||||
}
|
||||
|
||||
// ResetNotifyMode resets all changes to the "notify_mode" field.
|
||||
func (m *AnnouncementMutation) ResetNotifyMode() {
|
||||
m.notify_mode = nil
|
||||
}
|
||||
|
||||
// SetTargeting sets the "targeting" field.
|
||||
func (m *AnnouncementMutation) SetTargeting(dt domain.AnnouncementTargeting) {
|
||||
m.targeting = &dt
|
||||
@@ -5731,7 +5875,7 @@ func (m *AnnouncementMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *AnnouncementMutation) Fields() []string {
|
||||
fields := make([]string, 0, 10)
|
||||
fields := make([]string, 0, 11)
|
||||
if m.title != nil {
|
||||
fields = append(fields, announcement.FieldTitle)
|
||||
}
|
||||
@@ -5741,6 +5885,9 @@ func (m *AnnouncementMutation) Fields() []string {
|
||||
if m.status != nil {
|
||||
fields = append(fields, announcement.FieldStatus)
|
||||
}
|
||||
if m.notify_mode != nil {
|
||||
fields = append(fields, announcement.FieldNotifyMode)
|
||||
}
|
||||
if m.targeting != nil {
|
||||
fields = append(fields, announcement.FieldTargeting)
|
||||
}
|
||||
@@ -5776,6 +5923,8 @@ func (m *AnnouncementMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.Content()
|
||||
case announcement.FieldStatus:
|
||||
return m.Status()
|
||||
case announcement.FieldNotifyMode:
|
||||
return m.NotifyMode()
|
||||
case announcement.FieldTargeting:
|
||||
return m.Targeting()
|
||||
case announcement.FieldStartsAt:
|
||||
@@ -5805,6 +5954,8 @@ func (m *AnnouncementMutation) OldField(ctx context.Context, name string) (ent.V
|
||||
return m.OldContent(ctx)
|
||||
case announcement.FieldStatus:
|
||||
return m.OldStatus(ctx)
|
||||
case announcement.FieldNotifyMode:
|
||||
return m.OldNotifyMode(ctx)
|
||||
case announcement.FieldTargeting:
|
||||
return m.OldTargeting(ctx)
|
||||
case announcement.FieldStartsAt:
|
||||
@@ -5849,6 +6000,13 @@ func (m *AnnouncementMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetStatus(v)
|
||||
return nil
|
||||
case announcement.FieldNotifyMode:
|
||||
v, ok := value.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetNotifyMode(v)
|
||||
return nil
|
||||
case announcement.FieldTargeting:
|
||||
v, ok := value.(domain.AnnouncementTargeting)
|
||||
if !ok {
|
||||
@@ -6016,6 +6174,9 @@ func (m *AnnouncementMutation) ResetField(name string) error {
|
||||
case announcement.FieldStatus:
|
||||
m.ResetStatus()
|
||||
return nil
|
||||
case announcement.FieldNotifyMode:
|
||||
m.ResetNotifyMode()
|
||||
return nil
|
||||
case announcement.FieldTargeting:
|
||||
m.ResetTargeting()
|
||||
return nil
|
||||
@@ -8089,6 +8250,8 @@ type GroupMutation struct {
|
||||
appendsupported_model_scopes []string
|
||||
sort_order *int
|
||||
addsort_order *int
|
||||
allow_messages_dispatch *bool
|
||||
default_mapped_model *string
|
||||
clearedFields map[string]struct{}
|
||||
api_keys map[int64]struct{}
|
||||
removedapi_keys map[int64]struct{}
|
||||
@@ -9833,6 +9996,78 @@ func (m *GroupMutation) ResetSortOrder() {
|
||||
m.addsort_order = nil
|
||||
}
|
||||
|
||||
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
|
||||
func (m *GroupMutation) SetAllowMessagesDispatch(b bool) {
|
||||
m.allow_messages_dispatch = &b
|
||||
}
|
||||
|
||||
// AllowMessagesDispatch returns the value of the "allow_messages_dispatch" field in the mutation.
|
||||
func (m *GroupMutation) AllowMessagesDispatch() (r bool, exists bool) {
|
||||
v := m.allow_messages_dispatch
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldAllowMessagesDispatch returns the old "allow_messages_dispatch" field's value of the Group entity.
|
||||
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *GroupMutation) OldAllowMessagesDispatch(ctx context.Context) (v bool, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldAllowMessagesDispatch is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldAllowMessagesDispatch requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldAllowMessagesDispatch: %w", err)
|
||||
}
|
||||
return oldValue.AllowMessagesDispatch, nil
|
||||
}
|
||||
|
||||
// ResetAllowMessagesDispatch resets all changes to the "allow_messages_dispatch" field.
|
||||
func (m *GroupMutation) ResetAllowMessagesDispatch() {
|
||||
m.allow_messages_dispatch = nil
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (m *GroupMutation) SetDefaultMappedModel(s string) {
|
||||
m.default_mapped_model = &s
|
||||
}
|
||||
|
||||
// DefaultMappedModel returns the value of the "default_mapped_model" field in the mutation.
|
||||
func (m *GroupMutation) DefaultMappedModel() (r string, exists bool) {
|
||||
v := m.default_mapped_model
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldDefaultMappedModel returns the old "default_mapped_model" field's value of the Group entity.
|
||||
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *GroupMutation) OldDefaultMappedModel(ctx context.Context) (v string, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldDefaultMappedModel is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldDefaultMappedModel requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldDefaultMappedModel: %w", err)
|
||||
}
|
||||
return oldValue.DefaultMappedModel, nil
|
||||
}
|
||||
|
||||
// ResetDefaultMappedModel resets all changes to the "default_mapped_model" field.
|
||||
func (m *GroupMutation) ResetDefaultMappedModel() {
|
||||
m.default_mapped_model = nil
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
||||
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
|
||||
if m.api_keys == nil {
|
||||
@@ -10191,7 +10426,7 @@ func (m *GroupMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *GroupMutation) Fields() []string {
|
||||
fields := make([]string, 0, 30)
|
||||
fields := make([]string, 0, 32)
|
||||
if m.created_at != nil {
|
||||
fields = append(fields, group.FieldCreatedAt)
|
||||
}
|
||||
@@ -10282,6 +10517,12 @@ func (m *GroupMutation) Fields() []string {
|
||||
if m.sort_order != nil {
|
||||
fields = append(fields, group.FieldSortOrder)
|
||||
}
|
||||
if m.allow_messages_dispatch != nil {
|
||||
fields = append(fields, group.FieldAllowMessagesDispatch)
|
||||
}
|
||||
if m.default_mapped_model != nil {
|
||||
fields = append(fields, group.FieldDefaultMappedModel)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
@@ -10350,6 +10591,10 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.SupportedModelScopes()
|
||||
case group.FieldSortOrder:
|
||||
return m.SortOrder()
|
||||
case group.FieldAllowMessagesDispatch:
|
||||
return m.AllowMessagesDispatch()
|
||||
case group.FieldDefaultMappedModel:
|
||||
return m.DefaultMappedModel()
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
@@ -10419,6 +10664,10 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
|
||||
return m.OldSupportedModelScopes(ctx)
|
||||
case group.FieldSortOrder:
|
||||
return m.OldSortOrder(ctx)
|
||||
case group.FieldAllowMessagesDispatch:
|
||||
return m.OldAllowMessagesDispatch(ctx)
|
||||
case group.FieldDefaultMappedModel:
|
||||
return m.OldDefaultMappedModel(ctx)
|
||||
}
|
||||
return nil, fmt.Errorf("unknown Group field %s", name)
|
||||
}
|
||||
@@ -10638,6 +10887,20 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetSortOrder(v)
|
||||
return nil
|
||||
case group.FieldAllowMessagesDispatch:
|
||||
v, ok := value.(bool)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetAllowMessagesDispatch(v)
|
||||
return nil
|
||||
case group.FieldDefaultMappedModel:
|
||||
v, ok := value.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetDefaultMappedModel(v)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown Group field %s", name)
|
||||
}
|
||||
@@ -11065,6 +11328,12 @@ func (m *GroupMutation) ResetField(name string) error {
|
||||
case group.FieldSortOrder:
|
||||
m.ResetSortOrder()
|
||||
return nil
|
||||
case group.FieldAllowMessagesDispatch:
|
||||
m.ResetAllowMessagesDispatch()
|
||||
return nil
|
||||
case group.FieldDefaultMappedModel:
|
||||
m.ResetDefaultMappedModel()
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown Group field %s", name)
|
||||
}
|
||||
|
||||
@@ -212,29 +212,29 @@ func init() {
|
||||
// account.DefaultConcurrency holds the default value on creation for the concurrency field.
|
||||
account.DefaultConcurrency = accountDescConcurrency.Default.(int)
|
||||
// accountDescPriority is the schema descriptor for priority field.
|
||||
accountDescPriority := accountFields[8].Descriptor()
|
||||
accountDescPriority := accountFields[9].Descriptor()
|
||||
// account.DefaultPriority holds the default value on creation for the priority field.
|
||||
account.DefaultPriority = accountDescPriority.Default.(int)
|
||||
// accountDescRateMultiplier is the schema descriptor for rate_multiplier field.
|
||||
accountDescRateMultiplier := accountFields[9].Descriptor()
|
||||
accountDescRateMultiplier := accountFields[10].Descriptor()
|
||||
// account.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
|
||||
account.DefaultRateMultiplier = accountDescRateMultiplier.Default.(float64)
|
||||
// accountDescStatus is the schema descriptor for status field.
|
||||
accountDescStatus := accountFields[10].Descriptor()
|
||||
accountDescStatus := accountFields[11].Descriptor()
|
||||
// account.DefaultStatus holds the default value on creation for the status field.
|
||||
account.DefaultStatus = accountDescStatus.Default.(string)
|
||||
// account.StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
||||
account.StatusValidator = accountDescStatus.Validators[0].(func(string) error)
|
||||
// accountDescAutoPauseOnExpired is the schema descriptor for auto_pause_on_expired field.
|
||||
accountDescAutoPauseOnExpired := accountFields[14].Descriptor()
|
||||
accountDescAutoPauseOnExpired := accountFields[15].Descriptor()
|
||||
// account.DefaultAutoPauseOnExpired holds the default value on creation for the auto_pause_on_expired field.
|
||||
account.DefaultAutoPauseOnExpired = accountDescAutoPauseOnExpired.Default.(bool)
|
||||
// accountDescSchedulable is the schema descriptor for schedulable field.
|
||||
accountDescSchedulable := accountFields[15].Descriptor()
|
||||
accountDescSchedulable := accountFields[16].Descriptor()
|
||||
// account.DefaultSchedulable holds the default value on creation for the schedulable field.
|
||||
account.DefaultSchedulable = accountDescSchedulable.Default.(bool)
|
||||
// accountDescSessionWindowStatus is the schema descriptor for session_window_status field.
|
||||
accountDescSessionWindowStatus := accountFields[23].Descriptor()
|
||||
accountDescSessionWindowStatus := accountFields[24].Descriptor()
|
||||
// account.SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save.
|
||||
account.SessionWindowStatusValidator = accountDescSessionWindowStatus.Validators[0].(func(string) error)
|
||||
accountgroupFields := schema.AccountGroup{}.Fields()
|
||||
@@ -277,12 +277,18 @@ func init() {
|
||||
announcement.DefaultStatus = announcementDescStatus.Default.(string)
|
||||
// announcement.StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
||||
announcement.StatusValidator = announcementDescStatus.Validators[0].(func(string) error)
|
||||
// announcementDescNotifyMode is the schema descriptor for notify_mode field.
|
||||
announcementDescNotifyMode := announcementFields[3].Descriptor()
|
||||
// announcement.DefaultNotifyMode holds the default value on creation for the notify_mode field.
|
||||
announcement.DefaultNotifyMode = announcementDescNotifyMode.Default.(string)
|
||||
// announcement.NotifyModeValidator is a validator for the "notify_mode" field. It is called by the builders before save.
|
||||
announcement.NotifyModeValidator = announcementDescNotifyMode.Validators[0].(func(string) error)
|
||||
// announcementDescCreatedAt is the schema descriptor for created_at field.
|
||||
announcementDescCreatedAt := announcementFields[8].Descriptor()
|
||||
announcementDescCreatedAt := announcementFields[9].Descriptor()
|
||||
// announcement.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||
announcement.DefaultCreatedAt = announcementDescCreatedAt.Default.(func() time.Time)
|
||||
// announcementDescUpdatedAt is the schema descriptor for updated_at field.
|
||||
announcementDescUpdatedAt := announcementFields[9].Descriptor()
|
||||
announcementDescUpdatedAt := announcementFields[10].Descriptor()
|
||||
// announcement.DefaultUpdatedAt holds the default value on creation for the updated_at field.
|
||||
announcement.DefaultUpdatedAt = announcementDescUpdatedAt.Default.(func() time.Time)
|
||||
// announcement.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
|
||||
@@ -447,6 +453,16 @@ func init() {
|
||||
groupDescSortOrder := groupFields[26].Descriptor()
|
||||
// group.DefaultSortOrder holds the default value on creation for the sort_order field.
|
||||
group.DefaultSortOrder = groupDescSortOrder.Default.(int)
|
||||
// groupDescAllowMessagesDispatch is the schema descriptor for allow_messages_dispatch field.
|
||||
groupDescAllowMessagesDispatch := groupFields[27].Descriptor()
|
||||
// group.DefaultAllowMessagesDispatch holds the default value on creation for the allow_messages_dispatch field.
|
||||
group.DefaultAllowMessagesDispatch = groupDescAllowMessagesDispatch.Default.(bool)
|
||||
// groupDescDefaultMappedModel is the schema descriptor for default_mapped_model field.
|
||||
groupDescDefaultMappedModel := groupFields[28].Descriptor()
|
||||
// group.DefaultDefaultMappedModel holds the default value on creation for the default_mapped_model field.
|
||||
group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string)
|
||||
// group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
||||
group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error)
|
||||
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
|
||||
idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields()
|
||||
_ = idempotencyrecordMixinFields0
|
||||
|
||||
@@ -97,6 +97,8 @@ func (Account) Fields() []ent.Field {
|
||||
field.Int("concurrency").
|
||||
Default(3),
|
||||
|
||||
field.Int("load_factor").Optional().Nillable(),
|
||||
|
||||
// priority: 账户优先级,数值越小优先级越高
|
||||
// 调度器会优先使用高优先级的账户
|
||||
field.Int("priority").
|
||||
|
||||
@@ -41,6 +41,10 @@ func (Announcement) Fields() []ent.Field {
|
||||
MaxLen(20).
|
||||
Default(domain.AnnouncementStatusDraft).
|
||||
Comment("状态: draft, active, archived"),
|
||||
field.String("notify_mode").
|
||||
MaxLen(20).
|
||||
Default(domain.AnnouncementNotifyModeSilent).
|
||||
Comment("通知模式: silent(仅铃铛), popup(弹窗提醒)"),
|
||||
field.JSON("targeting", domain.AnnouncementTargeting{}).
|
||||
Optional().
|
||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
||||
|
||||
@@ -148,6 +148,15 @@ func (Group) Fields() []ent.Field {
|
||||
field.Int("sort_order").
|
||||
Default(0).
|
||||
Comment("分组显示排序,数值越小越靠前"),
|
||||
|
||||
// OpenAI Messages 调度配置 (added by migration 069)
|
||||
field.Bool("allow_messages_dispatch").
|
||||
Default(false).
|
||||
Comment("是否允许 /v1/messages 调度到此 OpenAI 分组"),
|
||||
field.String("default_mapped_model").
|
||||
MaxLen(100).
|
||||
Default("").
|
||||
Comment("默认映射模型 ID,当账号级映射找不到时使用此值"),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
module github.com/Wei-Shaw/sub2api
|
||||
|
||||
go 1.25.7
|
||||
go 1.26.1
|
||||
|
||||
require (
|
||||
entgo.io/ent v0.14.5
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||
github.com/DouDOU-start/go-sora2api v1.1.0
|
||||
github.com/alitto/pond/v2 v2.6.2
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.2
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.10
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
|
||||
@@ -38,8 +39,6 @@ require (
|
||||
golang.org/x/net v0.49.0
|
||||
golang.org/x/sync v0.19.0
|
||||
golang.org/x/term v0.40.0
|
||||
google.golang.org/grpc v1.75.1
|
||||
google.golang.org/protobuf v1.36.10
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
modernc.org/sqlite v1.44.3
|
||||
@@ -53,7 +52,6 @@ require (
|
||||
github.com/agext/levenshtein v1.2.3 // indirect
|
||||
github.com/andybalholm/brotli v1.2.0 // indirect
|
||||
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.2 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 // indirect
|
||||
@@ -109,7 +107,6 @@ require (
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/go-querystring v1.1.0 // indirect
|
||||
github.com/google/subcommands v1.2.0 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/hashicorp/hcl/v2 v2.18.1 // indirect
|
||||
@@ -169,6 +166,7 @@ require (
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
|
||||
go.opentelemetry.io/otel v1.37.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.37.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk v1.37.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.37.0 // indirect
|
||||
go.uber.org/atomic v1.10.0 // indirect
|
||||
go.uber.org/automaxprocs v1.6.0 // indirect
|
||||
@@ -178,8 +176,8 @@ require (
|
||||
golang.org/x/mod v0.32.0 // indirect
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
golang.org/x/text v0.34.0 // indirect
|
||||
golang.org/x/tools v0.41.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 // indirect
|
||||
google.golang.org/grpc v1.75.1 // indirect
|
||||
google.golang.org/protobuf v1.36.10 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
modernc.org/libc v1.67.6 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
|
||||
@@ -171,8 +171,6 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
@@ -182,8 +180,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
|
||||
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
|
||||
@@ -398,8 +394,6 @@ go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/Wgbsd
|
||||
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
|
||||
go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI=
|
||||
go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps=
|
||||
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
|
||||
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
|
||||
go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0=
|
||||
@@ -455,8 +449,6 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm
|
||||
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
|
||||
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4/go.mod h1:NnuHhy+bxcg30o7FnVAZbXsPHUDQ9qKWAQKCD7VxFtk=
|
||||
|
||||
@@ -516,7 +516,7 @@ func (c *UserMessageQueueConfig) GetEffectiveMode() string {
|
||||
type GatewayOpenAIWSConfig struct {
|
||||
// ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false;关闭时保持 legacy 行为)
|
||||
ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"`
|
||||
// IngressModeDefault: ingress 默认模式(off/shared/dedicated)
|
||||
// IngressModeDefault: ingress 默认模式(off/ctx_pool/passthrough)
|
||||
IngressModeDefault string `mapstructure:"ingress_mode_default"`
|
||||
// Enabled: 全局总开关(默认 true)
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
@@ -1335,7 +1335,7 @@ func setDefaults() {
|
||||
// OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚)
|
||||
viper.SetDefault("gateway.openai_ws.enabled", true)
|
||||
viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false)
|
||||
viper.SetDefault("gateway.openai_ws.ingress_mode_default", "shared")
|
||||
viper.SetDefault("gateway.openai_ws.ingress_mode_default", "ctx_pool")
|
||||
viper.SetDefault("gateway.openai_ws.oauth_enabled", true)
|
||||
viper.SetDefault("gateway.openai_ws.apikey_enabled", true)
|
||||
viper.SetDefault("gateway.openai_ws.force_http", false)
|
||||
@@ -1402,7 +1402,7 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
|
||||
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
|
||||
viper.SetDefault("gateway.stream_keepalive_interval", 10)
|
||||
viper.SetDefault("gateway.max_line_size", 40*1024*1024)
|
||||
viper.SetDefault("gateway.max_line_size", 500*1024*1024)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
|
||||
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
|
||||
@@ -2043,9 +2043,11 @@ func (c *Config) Validate() error {
|
||||
}
|
||||
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" {
|
||||
switch mode {
|
||||
case "off", "shared", "dedicated":
|
||||
case "off", "ctx_pool", "passthrough":
|
||||
case "shared", "dedicated":
|
||||
slog.Warn("gateway.openai_ws.ingress_mode_default is deprecated, treating as ctx_pool; please update to off|ctx_pool|passthrough", "value", mode)
|
||||
default:
|
||||
return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|shared|dedicated")
|
||||
return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|ctx_pool|passthrough")
|
||||
}
|
||||
}
|
||||
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" {
|
||||
|
||||
@@ -153,8 +153,8 @@ func TestLoadDefaultOpenAIWSConfig(t *testing.T) {
|
||||
if cfg.Gateway.OpenAIWS.ModeRouterV2Enabled {
|
||||
t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false")
|
||||
}
|
||||
if cfg.Gateway.OpenAIWS.IngressModeDefault != "shared" {
|
||||
t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "shared")
|
||||
if cfg.Gateway.OpenAIWS.IngressModeDefault != "ctx_pool" {
|
||||
t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "ctx_pool")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1373,7 +1373,7 @@ func TestValidateConfig_OpenAIWSRules(t *testing.T) {
|
||||
wantErr: "gateway.openai_ws.store_disabled_conn_mode",
|
||||
},
|
||||
{
|
||||
name: "ingress_mode_default 必须为 off|shared|dedicated",
|
||||
name: "ingress_mode_default 必须为 off|ctx_pool|passthrough",
|
||||
mutate: func(c *Config) { c.Gateway.OpenAIWS.IngressModeDefault = "invalid" },
|
||||
wantErr: "gateway.openai_ws.ingress_mode_default",
|
||||
},
|
||||
|
||||
@@ -13,6 +13,11 @@ const (
|
||||
AnnouncementStatusArchived = "archived"
|
||||
)
|
||||
|
||||
const (
|
||||
AnnouncementNotifyModeSilent = "silent"
|
||||
AnnouncementNotifyModePopup = "popup"
|
||||
)
|
||||
|
||||
const (
|
||||
AnnouncementConditionTypeSubscription = "subscription"
|
||||
AnnouncementConditionTypeBalance = "balance"
|
||||
@@ -195,17 +200,18 @@ func (c AnnouncementCondition) validate() error {
|
||||
}
|
||||
|
||||
type Announcement struct {
|
||||
ID int64
|
||||
Title string
|
||||
Content string
|
||||
Status string
|
||||
Targeting AnnouncementTargeting
|
||||
StartsAt *time.Time
|
||||
EndsAt *time.Time
|
||||
CreatedBy *int64
|
||||
UpdatedBy *int64
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
ID int64
|
||||
Title string
|
||||
Content string
|
||||
Status string
|
||||
NotifyMode string
|
||||
Targeting AnnouncementTargeting
|
||||
StartsAt *time.Time
|
||||
EndsAt *time.Time
|
||||
CreatedBy *int64
|
||||
UpdatedBy *int64
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
func (a *Announcement) IsActiveAt(now time.Time) bool {
|
||||
|
||||
@@ -8,6 +8,9 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"log/slog"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -292,6 +295,8 @@ func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest)
|
||||
}
|
||||
}
|
||||
|
||||
enrichCredentialsFromIDToken(&item)
|
||||
|
||||
accountInput := &service.CreateAccountInput{
|
||||
Name: item.Name,
|
||||
Notes: item.Notes,
|
||||
@@ -535,6 +540,57 @@ func defaultProxyName(name string) string {
|
||||
return name
|
||||
}
|
||||
|
||||
// enrichCredentialsFromIDToken performs best-effort extraction of user info fields
|
||||
// (email, plan_type, chatgpt_account_id, etc.) from id_token in credentials.
|
||||
// Only applies to OpenAI/Sora OAuth accounts. Skips expired token errors silently.
|
||||
// Existing credential values are never overwritten — only missing fields are filled.
|
||||
func enrichCredentialsFromIDToken(item *DataAccount) {
|
||||
if item.Credentials == nil {
|
||||
return
|
||||
}
|
||||
// Only enrich OpenAI/Sora OAuth accounts
|
||||
platform := strings.ToLower(strings.TrimSpace(item.Platform))
|
||||
if platform != service.PlatformOpenAI && platform != service.PlatformSora {
|
||||
return
|
||||
}
|
||||
if strings.ToLower(strings.TrimSpace(item.Type)) != service.AccountTypeOAuth {
|
||||
return
|
||||
}
|
||||
|
||||
idToken, _ := item.Credentials["id_token"].(string)
|
||||
if strings.TrimSpace(idToken) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// DecodeIDToken skips expiry validation — safe for imported data
|
||||
claims, err := openai.DecodeIDToken(idToken)
|
||||
if err != nil {
|
||||
slog.Debug("import_enrich_id_token_decode_failed", "account", item.Name, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
userInfo := claims.GetUserInfo()
|
||||
if userInfo == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Fill missing fields only (never overwrite existing values)
|
||||
setIfMissing := func(key, value string) {
|
||||
if value == "" {
|
||||
return
|
||||
}
|
||||
if existing, _ := item.Credentials[key].(string); existing == "" {
|
||||
item.Credentials[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
setIfMissing("email", userInfo.Email)
|
||||
setIfMissing("plan_type", userInfo.PlanType)
|
||||
setIfMissing("chatgpt_account_id", userInfo.ChatGPTAccountID)
|
||||
setIfMissing("chatgpt_user_id", userInfo.ChatGPTUserID)
|
||||
setIfMissing("organization_id", userInfo.OrganizationID)
|
||||
}
|
||||
|
||||
func normalizeProxyStatus(status string) string {
|
||||
normalized := strings.TrimSpace(strings.ToLower(status))
|
||||
switch normalized {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -18,6 +19,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
@@ -102,6 +104,7 @@ type CreateAccountRequest struct {
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
LoadFactor *int `json:"load_factor"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ExpiresAt *int64 `json:"expires_at"`
|
||||
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
|
||||
@@ -120,7 +123,8 @@ type UpdateAccountRequest struct {
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
LoadFactor *int `json:"load_factor"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
ExpiresAt *int64 `json:"expires_at"`
|
||||
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
|
||||
@@ -135,6 +139,7 @@ type BulkUpdateAccountsRequest struct {
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
LoadFactor *int `json:"load_factor"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
|
||||
Schedulable *bool `json:"schedulable"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
@@ -240,77 +245,77 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
var windowCosts map[int64]float64
|
||||
var activeSessions map[int64]int
|
||||
var rpmCounts map[int64]int
|
||||
if !lite {
|
||||
// Get current concurrency counts for all accounts
|
||||
if h.concurrencyService != nil {
|
||||
if cc, ccErr := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs); ccErr == nil && cc != nil {
|
||||
concurrencyCounts = cc
|
||||
|
||||
// 始终获取并发数(Redis ZCARD,极低开销)
|
||||
if h.concurrencyService != nil {
|
||||
if cc, ccErr := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs); ccErr == nil && cc != nil {
|
||||
concurrencyCounts = cc
|
||||
}
|
||||
}
|
||||
|
||||
// 识别需要查询窗口费用、会话数和 RPM 的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
|
||||
windowCostAccountIDs := make([]int64, 0)
|
||||
sessionLimitAccountIDs := make([]int64, 0)
|
||||
rpmAccountIDs := make([]int64, 0)
|
||||
sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if acc.IsAnthropicOAuthOrSetupToken() {
|
||||
if acc.GetWindowCostLimit() > 0 {
|
||||
windowCostAccountIDs = append(windowCostAccountIDs, acc.ID)
|
||||
}
|
||||
if acc.GetMaxSessions() > 0 {
|
||||
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
|
||||
sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
|
||||
}
|
||||
if acc.GetBaseRPM() > 0 {
|
||||
rpmAccountIDs = append(rpmAccountIDs, acc.ID)
|
||||
}
|
||||
}
|
||||
// 识别需要查询窗口费用、会话数和 RPM 的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
|
||||
windowCostAccountIDs := make([]int64, 0)
|
||||
sessionLimitAccountIDs := make([]int64, 0)
|
||||
rpmAccountIDs := make([]int64, 0)
|
||||
sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置
|
||||
}
|
||||
|
||||
// 始终获取 RPM 计数(Redis GET,极低开销)
|
||||
if len(rpmAccountIDs) > 0 && h.rpmCache != nil {
|
||||
rpmCounts, _ = h.rpmCache.GetRPMBatch(c.Request.Context(), rpmAccountIDs)
|
||||
if rpmCounts == nil {
|
||||
rpmCounts = make(map[int64]int)
|
||||
}
|
||||
}
|
||||
|
||||
// 始终获取活跃会话数(Redis ZCARD,低开销)
|
||||
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
|
||||
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
|
||||
if activeSessions == nil {
|
||||
activeSessions = make(map[int64]int)
|
||||
}
|
||||
}
|
||||
|
||||
// 始终获取窗口费用(PostgreSQL 聚合查询)
|
||||
if len(windowCostAccountIDs) > 0 {
|
||||
windowCosts = make(map[int64]float64)
|
||||
var mu sync.Mutex
|
||||
g, gctx := errgroup.WithContext(c.Request.Context())
|
||||
g.SetLimit(10) // 限制并发数
|
||||
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if acc.IsAnthropicOAuthOrSetupToken() {
|
||||
if acc.GetWindowCostLimit() > 0 {
|
||||
windowCostAccountIDs = append(windowCostAccountIDs, acc.ID)
|
||||
}
|
||||
if acc.GetMaxSessions() > 0 {
|
||||
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
|
||||
sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
|
||||
}
|
||||
if acc.GetBaseRPM() > 0 {
|
||||
rpmAccountIDs = append(rpmAccountIDs, acc.ID)
|
||||
}
|
||||
if !acc.IsAnthropicOAuthOrSetupToken() || acc.GetWindowCostLimit() <= 0 {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// 获取 RPM 计数(批量查询)
|
||||
if len(rpmAccountIDs) > 0 && h.rpmCache != nil {
|
||||
rpmCounts, _ = h.rpmCache.GetRPMBatch(c.Request.Context(), rpmAccountIDs)
|
||||
if rpmCounts == nil {
|
||||
rpmCounts = make(map[int64]int)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置)
|
||||
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
|
||||
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
|
||||
if activeSessions == nil {
|
||||
activeSessions = make(map[int64]int)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取窗口费用(并行查询)
|
||||
if len(windowCostAccountIDs) > 0 {
|
||||
windowCosts = make(map[int64]float64)
|
||||
var mu sync.Mutex
|
||||
g, gctx := errgroup.WithContext(c.Request.Context())
|
||||
g.SetLimit(10) // 限制并发数
|
||||
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if !acc.IsAnthropicOAuthOrSetupToken() || acc.GetWindowCostLimit() <= 0 {
|
||||
continue
|
||||
accCopy := acc // 闭包捕获
|
||||
g.Go(func() error {
|
||||
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
|
||||
startTime := accCopy.GetCurrentWindowStartTime()
|
||||
stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime)
|
||||
if err == nil && stats != nil {
|
||||
mu.Lock()
|
||||
windowCosts[accCopy.ID] = stats.StandardCost // 使用标准费用
|
||||
mu.Unlock()
|
||||
}
|
||||
accCopy := acc // 闭包捕获
|
||||
g.Go(func() error {
|
||||
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
|
||||
startTime := accCopy.GetCurrentWindowStartTime()
|
||||
stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime)
|
||||
if err == nil && stats != nil {
|
||||
mu.Lock()
|
||||
windowCosts[accCopy.ID] = stats.StandardCost // 使用标准费用
|
||||
mu.Unlock()
|
||||
}
|
||||
return nil // 不返回错误,允许部分失败
|
||||
})
|
||||
}
|
||||
_ = g.Wait()
|
||||
return nil // 不返回错误,允许部分失败
|
||||
})
|
||||
}
|
||||
_ = g.Wait()
|
||||
}
|
||||
|
||||
// Build response with concurrency info
|
||||
@@ -506,6 +511,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
LoadFactor: req.LoadFactor,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
AutoPauseOnExpired: req.AutoPauseOnExpired,
|
||||
@@ -575,6 +581,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
|
||||
Concurrency: req.Concurrency, // 指针类型,nil 表示未提供
|
||||
Priority: req.Priority, // 指针类型,nil 表示未提供
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
LoadFactor: req.LoadFactor,
|
||||
Status: req.Status,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
@@ -655,6 +662,42 @@ func (h *AccountHandler) Test(c *gin.Context) {
|
||||
// Error already sent via SSE, just log
|
||||
return
|
||||
}
|
||||
|
||||
if h.rateLimitService != nil {
|
||||
if _, err := h.rateLimitService.RecoverAccountAfterSuccessfulTest(c.Request.Context(), accountID); err != nil {
|
||||
_ = c.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RecoverState handles unified recovery of recoverable account runtime state.
|
||||
// POST /api/v1/admin/accounts/:id/recover-state
|
||||
func (h *AccountHandler) RecoverState(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
if h.rateLimitService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Rate limit service unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := h.rateLimitService.RecoverAccountState(c.Request.Context(), accountID, service.AccountRecoveryOptions{
|
||||
InvalidateToken: true,
|
||||
}); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// SyncFromCRS handles syncing accounts from claude-relay-service (CRS)
|
||||
@@ -710,52 +753,31 @@ func (h *AccountHandler) PreviewFromCRS(c *gin.Context) {
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// Refresh handles refreshing account credentials
|
||||
// POST /api/v1/admin/accounts/:id/refresh
|
||||
func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Get account
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Account not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Only refresh OAuth-based accounts (oauth and setup-token)
|
||||
// refreshSingleAccount refreshes credentials for a single OAuth account.
|
||||
// Returns (updatedAccount, warning, error) where warning is used for Antigravity ProjectIDMissing scenario.
|
||||
func (h *AccountHandler) refreshSingleAccount(ctx context.Context, account *service.Account) (*service.Account, string, error) {
|
||||
if !account.IsOAuth() {
|
||||
response.BadRequest(c, "Cannot refresh non-OAuth account credentials")
|
||||
return
|
||||
return nil, "", infraerrors.BadRequest("NOT_OAUTH", "cannot refresh non-OAuth account")
|
||||
}
|
||||
|
||||
var newCredentials map[string]any
|
||||
|
||||
if account.IsOpenAI() {
|
||||
// Use OpenAI OAuth service to refresh token
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// Build new credentials from token info
|
||||
newCredentials = h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
|
||||
// Preserve non-token settings from existing credentials
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
} else if account.Platform == service.PlatformGemini {
|
||||
tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
|
||||
return
|
||||
return nil, "", fmt.Errorf("failed to refresh credentials: %w", err)
|
||||
}
|
||||
|
||||
newCredentials = h.geminiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
@@ -765,10 +787,9 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
} else if account.Platform == service.PlatformAntigravity {
|
||||
tokenInfo, err := h.antigravityOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
tokenInfo, err := h.antigravityOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
newCredentials = h.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
@@ -787,37 +808,27 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 如果 project_id 获取失败,更新凭证但不标记为 error
|
||||
// LoadCodeAssist 失败可能是临时网络问题,给它机会在下次自动刷新时重试
|
||||
if tokenInfo.ProjectIDMissing {
|
||||
// 先更新凭证(token 本身刷新成功了)
|
||||
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
updatedAccount, updateErr := h.adminService.UpdateAccount(ctx, account.ID, &service.UpdateAccountInput{
|
||||
Credentials: newCredentials,
|
||||
})
|
||||
if updateErr != nil {
|
||||
response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
|
||||
return
|
||||
return nil, "", fmt.Errorf("failed to update credentials: %w", updateErr)
|
||||
}
|
||||
// 不标记为 error,只返回警告信息
|
||||
response.Success(c, gin.H{
|
||||
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
|
||||
"warning": "missing_project_id_temporary",
|
||||
})
|
||||
return
|
||||
return updatedAccount, "missing_project_id_temporary", nil
|
||||
}
|
||||
|
||||
// 成功获取到 project_id,如果之前是 missing_project_id 错误则清除
|
||||
if account.Status == service.StatusError && strings.Contains(account.ErrorMessage, "missing_project_id:") {
|
||||
if _, clearErr := h.adminService.ClearAccountError(c.Request.Context(), accountID); clearErr != nil {
|
||||
response.InternalError(c, "Failed to clear account error: "+clearErr.Error())
|
||||
return
|
||||
if _, clearErr := h.adminService.ClearAccountError(ctx, account.ID); clearErr != nil {
|
||||
return nil, "", fmt.Errorf("failed to clear account error: %w", clearErr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Use Anthropic/Claude OAuth service to refresh token
|
||||
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
tokenInfo, err := h.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
|
||||
@@ -839,20 +850,51 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
updatedAccount, err := h.adminService.UpdateAccount(ctx, account.ID, &service.UpdateAccountInput{
|
||||
Credentials: newCredentials,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// 刷新成功后,清除 token 缓存,确保下次请求使用新 token
|
||||
if h.tokenCacheInvalidator != nil {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(ctx, updatedAccount); invalidateErr != nil {
|
||||
log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", updatedAccount.ID, invalidateErr)
|
||||
}
|
||||
}
|
||||
|
||||
return updatedAccount, "", nil
|
||||
}
|
||||
|
||||
// Refresh handles refreshing account credentials
|
||||
// POST /api/v1/admin/accounts/:id/refresh
|
||||
func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Get account
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Account not found")
|
||||
return
|
||||
}
|
||||
|
||||
updatedAccount, warning, err := h.refreshSingleAccount(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 刷新成功后,清除 token 缓存,确保下次请求使用新 token
|
||||
if h.tokenCacheInvalidator != nil {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), updatedAccount); invalidateErr != nil {
|
||||
// 缓存失效失败只记录日志,不影响主流程
|
||||
_ = c.Error(invalidateErr)
|
||||
}
|
||||
if warning == "missing_project_id_temporary" {
|
||||
response.Success(c, gin.H{
|
||||
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
|
||||
"warning": "missing_project_id_temporary",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updatedAccount))
|
||||
@@ -908,14 +950,175 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
|
||||
// 这解决了管理员重置账号状态后,旧的失效 token 仍在缓存中导致立即再次 401 的问题
|
||||
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), account); invalidateErr != nil {
|
||||
// 缓存失效失败只记录日志,不影响主流程
|
||||
_ = c.Error(invalidateErr)
|
||||
log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", accountID, invalidateErr)
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// BatchClearError handles batch clearing account errors
|
||||
// POST /api/v1/admin/accounts/batch-clear-error
|
||||
func (h *AccountHandler) BatchClearError(c *gin.Context) {
|
||||
var req struct {
|
||||
AccountIDs []int64 `json:"account_ids"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if len(req.AccountIDs) == 0 {
|
||||
response.BadRequest(c, "account_ids is required")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
const maxConcurrency = 10
|
||||
g, gctx := errgroup.WithContext(ctx)
|
||||
g.SetLimit(maxConcurrency)
|
||||
|
||||
var mu sync.Mutex
|
||||
var successCount, failedCount int
|
||||
var errors []gin.H
|
||||
|
||||
// 注意:所有 goroutine 必须 return nil,避免 errgroup cancel 其他并发任务
|
||||
for _, id := range req.AccountIDs {
|
||||
accountID := id // 闭包捕获
|
||||
g.Go(func() error {
|
||||
account, err := h.adminService.ClearAccountError(gctx, accountID)
|
||||
if err != nil {
|
||||
mu.Lock()
|
||||
failedCount++
|
||||
errors = append(errors, gin.H{
|
||||
"account_id": accountID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// 清除错误后,同时清除 token 缓存
|
||||
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(gctx, account); invalidateErr != nil {
|
||||
log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", accountID, invalidateErr)
|
||||
}
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
successCount++
|
||||
mu.Unlock()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"total": len(req.AccountIDs),
|
||||
"success": successCount,
|
||||
"failed": failedCount,
|
||||
"errors": errors,
|
||||
})
|
||||
}
|
||||
|
||||
// BatchRefresh handles batch refreshing account credentials
|
||||
// POST /api/v1/admin/accounts/batch-refresh
|
||||
func (h *AccountHandler) BatchRefresh(c *gin.Context) {
|
||||
var req struct {
|
||||
AccountIDs []int64 `json:"account_ids"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if len(req.AccountIDs) == 0 {
|
||||
response.BadRequest(c, "account_ids is required")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
accounts, err := h.adminService.GetAccountsByIDs(ctx, req.AccountIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 建立已获取账号的 ID 集合,检测缺失的 ID
|
||||
foundIDs := make(map[int64]bool, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
if acc != nil {
|
||||
foundIDs[acc.ID] = true
|
||||
}
|
||||
}
|
||||
|
||||
const maxConcurrency = 10
|
||||
g, gctx := errgroup.WithContext(ctx)
|
||||
g.SetLimit(maxConcurrency)
|
||||
|
||||
var mu sync.Mutex
|
||||
var successCount, failedCount int
|
||||
var errors []gin.H
|
||||
var warnings []gin.H
|
||||
|
||||
// 将不存在的账号 ID 标记为失败
|
||||
for _, id := range req.AccountIDs {
|
||||
if !foundIDs[id] {
|
||||
failedCount++
|
||||
errors = append(errors, gin.H{
|
||||
"account_id": id,
|
||||
"error": "account not found",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 注意:所有 goroutine 必须 return nil,避免 errgroup cancel 其他并发任务
|
||||
for _, account := range accounts {
|
||||
acc := account // 闭包捕获
|
||||
if acc == nil {
|
||||
continue
|
||||
}
|
||||
g.Go(func() error {
|
||||
_, warning, err := h.refreshSingleAccount(gctx, acc)
|
||||
mu.Lock()
|
||||
if err != nil {
|
||||
failedCount++
|
||||
errors = append(errors, gin.H{
|
||||
"account_id": acc.ID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
successCount++
|
||||
if warning != "" {
|
||||
warnings = append(warnings, gin.H{
|
||||
"account_id": acc.ID,
|
||||
"warning": warning,
|
||||
})
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"total": len(req.AccountIDs),
|
||||
"success": successCount,
|
||||
"failed": failedCount,
|
||||
"errors": errors,
|
||||
"warnings": warnings,
|
||||
})
|
||||
}
|
||||
|
||||
// BatchCreate handles batch creating accounts
|
||||
// POST /api/v1/admin/accounts/batch
|
||||
func (h *AccountHandler) BatchCreate(c *gin.Context) {
|
||||
@@ -1101,6 +1304,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
req.Concurrency != nil ||
|
||||
req.Priority != nil ||
|
||||
req.RateMultiplier != nil ||
|
||||
req.LoadFactor != nil ||
|
||||
req.Status != "" ||
|
||||
req.Schedulable != nil ||
|
||||
req.GroupIDs != nil ||
|
||||
@@ -1119,6 +1323,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
LoadFactor: req.LoadFactor,
|
||||
Status: req.Status,
|
||||
Schedulable: req.Schedulable,
|
||||
GroupIDs: req.GroupIDs,
|
||||
@@ -1328,6 +1533,29 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// ResetQuota handles resetting account quota usage
|
||||
// POST /api/v1/admin/accounts/:id/reset-quota
|
||||
func (h *AccountHandler) ResetQuota(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.adminService.ResetAccountQuota(c.Request.Context(), accountID); err != nil {
|
||||
response.InternalError(c, "Failed to reset account quota: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// GetTempUnschedulable handles getting temporary unschedulable status
|
||||
// GET /api/v1/admin/accounts/:id/temp-unschedulable
|
||||
func (h *AccountHandler) GetTempUnschedulable(c *gin.Context) {
|
||||
|
||||
@@ -425,5 +425,9 @@ func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure stub implements interface.
|
||||
var _ service.AdminService = (*stubAdminService)(nil)
|
||||
|
||||
@@ -27,21 +27,23 @@ func NewAnnouncementHandler(announcementService *service.AnnouncementService) *A
|
||||
}
|
||||
|
||||
type CreateAnnouncementRequest struct {
|
||||
Title string `json:"title" binding:"required"`
|
||||
Content string `json:"content" binding:"required"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=draft active archived"`
|
||||
Targeting service.AnnouncementTargeting `json:"targeting"`
|
||||
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0/empty = immediate
|
||||
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0/empty = never
|
||||
Title string `json:"title" binding:"required"`
|
||||
Content string `json:"content" binding:"required"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=draft active archived"`
|
||||
NotifyMode string `json:"notify_mode" binding:"omitempty,oneof=silent popup"`
|
||||
Targeting service.AnnouncementTargeting `json:"targeting"`
|
||||
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0/empty = immediate
|
||||
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0/empty = never
|
||||
}
|
||||
|
||||
type UpdateAnnouncementRequest struct {
|
||||
Title *string `json:"title"`
|
||||
Content *string `json:"content"`
|
||||
Status *string `json:"status" binding:"omitempty,oneof=draft active archived"`
|
||||
Targeting *service.AnnouncementTargeting `json:"targeting"`
|
||||
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0 = clear
|
||||
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0 = clear
|
||||
Title *string `json:"title"`
|
||||
Content *string `json:"content"`
|
||||
Status *string `json:"status" binding:"omitempty,oneof=draft active archived"`
|
||||
NotifyMode *string `json:"notify_mode" binding:"omitempty,oneof=silent popup"`
|
||||
Targeting *service.AnnouncementTargeting `json:"targeting"`
|
||||
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0 = clear
|
||||
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0 = clear
|
||||
}
|
||||
|
||||
// List handles listing announcements with filters
|
||||
@@ -110,11 +112,12 @@ func (h *AnnouncementHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
input := &service.CreateAnnouncementInput{
|
||||
Title: req.Title,
|
||||
Content: req.Content,
|
||||
Status: req.Status,
|
||||
Targeting: req.Targeting,
|
||||
ActorID: &subject.UserID,
|
||||
Title: req.Title,
|
||||
Content: req.Content,
|
||||
Status: req.Status,
|
||||
NotifyMode: req.NotifyMode,
|
||||
Targeting: req.Targeting,
|
||||
ActorID: &subject.UserID,
|
||||
}
|
||||
|
||||
if req.StartsAt != nil && *req.StartsAt > 0 {
|
||||
@@ -157,11 +160,12 @@ func (h *AnnouncementHandler) Update(c *gin.Context) {
|
||||
}
|
||||
|
||||
input := &service.UpdateAnnouncementInput{
|
||||
Title: req.Title,
|
||||
Content: req.Content,
|
||||
Status: req.Status,
|
||||
Targeting: req.Targeting,
|
||||
ActorID: &subject.UserID,
|
||||
Title: req.Title,
|
||||
Content: req.Content,
|
||||
Status: req.Status,
|
||||
NotifyMode: req.NotifyMode,
|
||||
Targeting: req.Targeting,
|
||||
ActorID: &subject.UserID,
|
||||
}
|
||||
|
||||
if req.StartsAt != nil {
|
||||
|
||||
@@ -53,6 +53,9 @@ type CreateGroupRequest struct {
|
||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
|
||||
DefaultMappedModel string `json:"default_mapped_model"`
|
||||
// 从指定分组复制账号(创建后自动绑定)
|
||||
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||
}
|
||||
@@ -88,6 +91,9 @@ type UpdateGroupRequest struct {
|
||||
SupportedModelScopes *[]string `json:"supported_model_scopes"`
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
AllowMessagesDispatch *bool `json:"allow_messages_dispatch"`
|
||||
DefaultMappedModel *string `json:"default_mapped_model"`
|
||||
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||
}
|
||||
@@ -203,6 +209,8 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
MCPXMLInject: req.MCPXMLInject,
|
||||
SupportedModelScopes: req.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
||||
DefaultMappedModel: req.DefaultMappedModel,
|
||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -254,6 +262,8 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
MCPXMLInject: req.MCPXMLInject,
|
||||
SupportedModelScopes: req.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
||||
DefaultMappedModel: req.DefaultMappedModel,
|
||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
163
backend/internal/handler/admin/scheduled_test_handler.go
Normal file
163
backend/internal/handler/admin/scheduled_test_handler.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ScheduledTestHandler handles admin scheduled-test-plan management.
|
||||
type ScheduledTestHandler struct {
|
||||
scheduledTestSvc *service.ScheduledTestService
|
||||
}
|
||||
|
||||
// NewScheduledTestHandler creates a new ScheduledTestHandler.
|
||||
func NewScheduledTestHandler(scheduledTestSvc *service.ScheduledTestService) *ScheduledTestHandler {
|
||||
return &ScheduledTestHandler{scheduledTestSvc: scheduledTestSvc}
|
||||
}
|
||||
|
||||
type createScheduledTestPlanRequest struct {
|
||||
AccountID int64 `json:"account_id" binding:"required"`
|
||||
ModelID string `json:"model_id"`
|
||||
CronExpression string `json:"cron_expression" binding:"required"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
MaxResults int `json:"max_results"`
|
||||
AutoRecover *bool `json:"auto_recover"`
|
||||
}
|
||||
|
||||
type updateScheduledTestPlanRequest struct {
|
||||
ModelID string `json:"model_id"`
|
||||
CronExpression string `json:"cron_expression"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
MaxResults int `json:"max_results"`
|
||||
AutoRecover *bool `json:"auto_recover"`
|
||||
}
|
||||
|
||||
// ListByAccount GET /admin/accounts/:id/scheduled-test-plans
|
||||
func (h *ScheduledTestHandler) ListByAccount(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid account id")
|
||||
return
|
||||
}
|
||||
|
||||
plans, err := h.scheduledTestSvc.ListPlansByAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.InternalError(c, err.Error())
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, plans)
|
||||
}
|
||||
|
||||
// Create POST /admin/scheduled-test-plans
|
||||
func (h *ScheduledTestHandler) Create(c *gin.Context) {
|
||||
var req createScheduledTestPlanRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
plan := &service.ScheduledTestPlan{
|
||||
AccountID: req.AccountID,
|
||||
ModelID: req.ModelID,
|
||||
CronExpression: req.CronExpression,
|
||||
Enabled: true,
|
||||
MaxResults: req.MaxResults,
|
||||
}
|
||||
if req.Enabled != nil {
|
||||
plan.Enabled = *req.Enabled
|
||||
}
|
||||
if req.AutoRecover != nil {
|
||||
plan.AutoRecover = *req.AutoRecover
|
||||
}
|
||||
|
||||
created, err := h.scheduledTestSvc.CreatePlan(c.Request.Context(), plan)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, created)
|
||||
}
|
||||
|
||||
// Update PUT /admin/scheduled-test-plans/:id
|
||||
func (h *ScheduledTestHandler) Update(c *gin.Context) {
|
||||
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid plan id")
|
||||
return
|
||||
}
|
||||
|
||||
existing, err := h.scheduledTestSvc.GetPlan(c.Request.Context(), planID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "plan not found")
|
||||
return
|
||||
}
|
||||
|
||||
var req updateScheduledTestPlanRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.ModelID != "" {
|
||||
existing.ModelID = req.ModelID
|
||||
}
|
||||
if req.CronExpression != "" {
|
||||
existing.CronExpression = req.CronExpression
|
||||
}
|
||||
if req.Enabled != nil {
|
||||
existing.Enabled = *req.Enabled
|
||||
}
|
||||
if req.MaxResults > 0 {
|
||||
existing.MaxResults = req.MaxResults
|
||||
}
|
||||
if req.AutoRecover != nil {
|
||||
existing.AutoRecover = *req.AutoRecover
|
||||
}
|
||||
|
||||
updated, err := h.scheduledTestSvc.UpdatePlan(c.Request.Context(), existing)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, updated)
|
||||
}
|
||||
|
||||
// Delete DELETE /admin/scheduled-test-plans/:id
|
||||
func (h *ScheduledTestHandler) Delete(c *gin.Context) {
|
||||
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid plan id")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.scheduledTestSvc.DeletePlan(c.Request.Context(), planID); err != nil {
|
||||
response.InternalError(c, err.Error())
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "deleted"})
|
||||
}
|
||||
|
||||
// ListResults GET /admin/scheduled-test-plans/:id/results
|
||||
func (h *ScheduledTestHandler) ListResults(c *gin.Context) {
|
||||
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid plan id")
|
||||
return
|
||||
}
|
||||
|
||||
limit := 50
|
||||
if l, err := strconv.Atoi(c.Query("limit")); err == nil && l > 0 {
|
||||
limit = l
|
||||
}
|
||||
|
||||
results, err := h.scheduledTestSvc.ListResults(c.Request.Context(), planID, limit)
|
||||
if err != nil {
|
||||
response.InternalError(c, err.Error())
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, results)
|
||||
}
|
||||
@@ -819,7 +819,7 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) {
|
||||
|
||||
err := h.emailService.TestSMTPConnectionWithConfig(config)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
response.BadRequest(c, "SMTP connection test failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -905,7 +905,7 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
|
||||
`
|
||||
|
||||
if err := h.emailService.SendEmailWithConfig(config, req.Email, subject, body); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
response.BadRequest(c, "Failed to send test email: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1348,6 +1348,118 @@ func (h *SettingHandler) TestSoraS3Connection(c *gin.Context) {
|
||||
response.Success(c, gin.H{"message": "S3 连接成功"})
|
||||
}
|
||||
|
||||
// GetRectifierSettings 获取请求整流器配置
|
||||
// GET /api/v1/admin/settings/rectifier
|
||||
func (h *SettingHandler) GetRectifierSettings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetRectifierSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RectifierSettings{
|
||||
Enabled: settings.Enabled,
|
||||
ThinkingSignatureEnabled: settings.ThinkingSignatureEnabled,
|
||||
ThinkingBudgetEnabled: settings.ThinkingBudgetEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRectifierSettingsRequest 更新整流器配置请求
|
||||
type UpdateRectifierSettingsRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"`
|
||||
ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"`
|
||||
}
|
||||
|
||||
// UpdateRectifierSettings 更新请求整流器配置
|
||||
// PUT /api/v1/admin/settings/rectifier
|
||||
func (h *SettingHandler) UpdateRectifierSettings(c *gin.Context) {
|
||||
var req UpdateRectifierSettingsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
settings := &service.RectifierSettings{
|
||||
Enabled: req.Enabled,
|
||||
ThinkingSignatureEnabled: req.ThinkingSignatureEnabled,
|
||||
ThinkingBudgetEnabled: req.ThinkingBudgetEnabled,
|
||||
}
|
||||
|
||||
if err := h.settingService.SetRectifierSettings(c.Request.Context(), settings); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 重新获取设置返回
|
||||
updatedSettings, err := h.settingService.GetRectifierSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RectifierSettings{
|
||||
Enabled: updatedSettings.Enabled,
|
||||
ThinkingSignatureEnabled: updatedSettings.ThinkingSignatureEnabled,
|
||||
ThinkingBudgetEnabled: updatedSettings.ThinkingBudgetEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
// GetBetaPolicySettings 获取 Beta 策略配置
|
||||
// GET /api/v1/admin/settings/beta-policy
|
||||
func (h *SettingHandler) GetBetaPolicySettings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetBetaPolicySettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
rules := make([]dto.BetaPolicyRule, len(settings.Rules))
|
||||
for i, r := range settings.Rules {
|
||||
rules[i] = dto.BetaPolicyRule(r)
|
||||
}
|
||||
response.Success(c, dto.BetaPolicySettings{Rules: rules})
|
||||
}
|
||||
|
||||
// UpdateBetaPolicySettingsRequest 更新 Beta 策略配置请求
|
||||
type UpdateBetaPolicySettingsRequest struct {
|
||||
Rules []dto.BetaPolicyRule `json:"rules"`
|
||||
}
|
||||
|
||||
// UpdateBetaPolicySettings 更新 Beta 策略配置
|
||||
// PUT /api/v1/admin/settings/beta-policy
|
||||
func (h *SettingHandler) UpdateBetaPolicySettings(c *gin.Context) {
|
||||
var req UpdateBetaPolicySettingsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
rules := make([]service.BetaPolicyRule, len(req.Rules))
|
||||
for i, r := range req.Rules {
|
||||
rules[i] = service.BetaPolicyRule(r)
|
||||
}
|
||||
|
||||
settings := &service.BetaPolicySettings{Rules: rules}
|
||||
if err := h.settingService.SetBetaPolicySettings(c.Request.Context(), settings); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Re-fetch to return updated settings
|
||||
updated, err := h.settingService.GetBetaPolicySettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
outRules := make([]dto.BetaPolicyRule, len(updated.Rules))
|
||||
for i, r := range updated.Rules {
|
||||
outRules[i] = dto.BetaPolicyRule(r)
|
||||
}
|
||||
response.Success(c, dto.BetaPolicySettings{Rules: outRules})
|
||||
}
|
||||
|
||||
// UpdateStreamTimeoutSettingsRequest 更新流超时配置请求
|
||||
type UpdateStreamTimeoutSettingsRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
@@ -211,8 +211,22 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
||||
email = linuxDoSyntheticEmail(subject)
|
||||
}
|
||||
|
||||
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username)
|
||||
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
|
||||
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrOAuthInvitationRequired) {
|
||||
pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username)
|
||||
if tokenErr != nil {
|
||||
redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "")
|
||||
return
|
||||
}
|
||||
fragment := url.Values{}
|
||||
fragment.Set("error", "invitation_required")
|
||||
fragment.Set("pending_oauth_token", pendingToken)
|
||||
fragment.Set("redirect", redirectTo)
|
||||
redirectWithFragment(c, frontendCallback, fragment)
|
||||
return
|
||||
}
|
||||
// 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
|
||||
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
|
||||
return
|
||||
@@ -227,6 +241,41 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
||||
redirectWithFragment(c, frontendCallback, fragment)
|
||||
}
|
||||
|
||||
type completeLinuxDoOAuthRequest struct {
|
||||
PendingOAuthToken string `json:"pending_oauth_token" binding:"required"`
|
||||
InvitationCode string `json:"invitation_code" binding:"required"`
|
||||
}
|
||||
|
||||
// CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating
|
||||
// the invitation code and creating the user account.
|
||||
// POST /api/v1/auth/oauth/linuxdo/complete-registration
|
||||
func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
|
||||
var req completeLinuxDoOAuthRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"})
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"access_token": tokenPair.AccessToken,
|
||||
"refresh_token": tokenPair.RefreshToken,
|
||||
"expires_in": tokenPair.ExpiresIn,
|
||||
"token_type": "Bearer",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) getLinuxDoOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) {
|
||||
if h != nil && h.settingSvc != nil {
|
||||
return h.settingSvc.GetLinuxDoConnectOAuthConfig(ctx)
|
||||
|
||||
@@ -7,10 +7,11 @@ import (
|
||||
)
|
||||
|
||||
type Announcement struct {
|
||||
ID int64 `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
Status string `json:"status"`
|
||||
ID int64 `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
Status string `json:"status"`
|
||||
NotifyMode string `json:"notify_mode"`
|
||||
|
||||
Targeting service.AnnouncementTargeting `json:"targeting"`
|
||||
|
||||
@@ -25,9 +26,10 @@ type Announcement struct {
|
||||
}
|
||||
|
||||
type UserAnnouncement struct {
|
||||
ID int64 `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
ID int64 `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
NotifyMode string `json:"notify_mode"`
|
||||
|
||||
StartsAt *time.Time `json:"starts_at,omitempty"`
|
||||
EndsAt *time.Time `json:"ends_at,omitempty"`
|
||||
@@ -43,17 +45,18 @@ func AnnouncementFromService(a *service.Announcement) *Announcement {
|
||||
return nil
|
||||
}
|
||||
return &Announcement{
|
||||
ID: a.ID,
|
||||
Title: a.Title,
|
||||
Content: a.Content,
|
||||
Status: a.Status,
|
||||
Targeting: a.Targeting,
|
||||
StartsAt: a.StartsAt,
|
||||
EndsAt: a.EndsAt,
|
||||
CreatedBy: a.CreatedBy,
|
||||
UpdatedBy: a.UpdatedBy,
|
||||
CreatedAt: a.CreatedAt,
|
||||
UpdatedAt: a.UpdatedAt,
|
||||
ID: a.ID,
|
||||
Title: a.Title,
|
||||
Content: a.Content,
|
||||
Status: a.Status,
|
||||
NotifyMode: a.NotifyMode,
|
||||
Targeting: a.Targeting,
|
||||
StartsAt: a.StartsAt,
|
||||
EndsAt: a.EndsAt,
|
||||
CreatedBy: a.CreatedBy,
|
||||
UpdatedBy: a.UpdatedBy,
|
||||
CreatedAt: a.CreatedAt,
|
||||
UpdatedAt: a.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,13 +65,14 @@ func UserAnnouncementFromService(a *service.UserAnnouncement) *UserAnnouncement
|
||||
return nil
|
||||
}
|
||||
return &UserAnnouncement{
|
||||
ID: a.Announcement.ID,
|
||||
Title: a.Announcement.Title,
|
||||
Content: a.Announcement.Content,
|
||||
StartsAt: a.Announcement.StartsAt,
|
||||
EndsAt: a.Announcement.EndsAt,
|
||||
ReadAt: a.ReadAt,
|
||||
CreatedAt: a.Announcement.CreatedAt,
|
||||
UpdatedAt: a.Announcement.UpdatedAt,
|
||||
ID: a.Announcement.ID,
|
||||
Title: a.Announcement.Title,
|
||||
Content: a.Announcement.Content,
|
||||
NotifyMode: a.Announcement.NotifyMode,
|
||||
StartsAt: a.Announcement.StartsAt,
|
||||
EndsAt: a.Announcement.EndsAt,
|
||||
ReadAt: a.ReadAt,
|
||||
CreatedAt: a.Announcement.CreatedAt,
|
||||
UpdatedAt: a.Announcement.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,7 +71,7 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
|
||||
if k == nil {
|
||||
return nil
|
||||
}
|
||||
return &APIKey{
|
||||
out := &APIKey{
|
||||
ID: k.ID,
|
||||
UserID: k.UserID,
|
||||
Key: k.Key,
|
||||
@@ -89,15 +89,28 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
|
||||
RateLimit5h: k.RateLimit5h,
|
||||
RateLimit1d: k.RateLimit1d,
|
||||
RateLimit7d: k.RateLimit7d,
|
||||
Usage5h: k.Usage5h,
|
||||
Usage1d: k.Usage1d,
|
||||
Usage7d: k.Usage7d,
|
||||
Usage5h: k.EffectiveUsage5h(),
|
||||
Usage1d: k.EffectiveUsage1d(),
|
||||
Usage7d: k.EffectiveUsage7d(),
|
||||
Window5hStart: k.Window5hStart,
|
||||
Window1dStart: k.Window1dStart,
|
||||
Window7dStart: k.Window7dStart,
|
||||
User: UserFromServiceShallow(k.User),
|
||||
Group: GroupFromServiceShallow(k.Group),
|
||||
}
|
||||
if k.Window5hStart != nil && !service.IsWindowExpired(k.Window5hStart, service.RateLimitWindow5h) {
|
||||
t := k.Window5hStart.Add(service.RateLimitWindow5h)
|
||||
out.Reset5hAt = &t
|
||||
}
|
||||
if k.Window1dStart != nil && !service.IsWindowExpired(k.Window1dStart, service.RateLimitWindow1d) {
|
||||
t := k.Window1dStart.Add(service.RateLimitWindow1d)
|
||||
out.Reset1dAt = &t
|
||||
}
|
||||
if k.Window7dStart != nil && !service.IsWindowExpired(k.Window7dStart, service.RateLimitWindow7d) {
|
||||
t := k.Window7dStart.Add(service.RateLimitWindow7d)
|
||||
out.Reset7dAt = &t
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func GroupFromServiceShallow(g *service.Group) *Group {
|
||||
@@ -126,6 +139,7 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
MCPXMLInject: g.MCPXMLInject,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
AccountCount: g.AccountCount,
|
||||
SortOrder: g.SortOrder,
|
||||
@@ -164,6 +178,7 @@ func groupFromServiceBase(g *service.Group) Group {
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
|
||||
SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
|
||||
AllowMessagesDispatch: g.AllowMessagesDispatch,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
@@ -183,6 +198,7 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
Extra: a.Extra,
|
||||
ProxyID: a.ProxyID,
|
||||
Concurrency: a.Concurrency,
|
||||
LoadFactor: a.LoadFactor,
|
||||
Priority: a.Priority,
|
||||
RateMultiplier: a.BillingRateMultiplier(),
|
||||
Status: a.Status,
|
||||
@@ -248,6 +264,25 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
}
|
||||
}
|
||||
|
||||
// 提取 API Key 账号配额限制(仅 apikey 类型有效)
|
||||
if a.Type == service.AccountTypeAPIKey {
|
||||
if limit := a.GetQuotaLimit(); limit > 0 {
|
||||
out.QuotaLimit = &limit
|
||||
used := a.GetQuotaUsed()
|
||||
out.QuotaUsed = &used
|
||||
}
|
||||
if limit := a.GetQuotaDailyLimit(); limit > 0 {
|
||||
out.QuotaDailyLimit = &limit
|
||||
used := a.GetQuotaDailyUsed()
|
||||
out.QuotaDailyUsed = &used
|
||||
}
|
||||
if limit := a.GetQuotaWeeklyLimit(); limit > 0 {
|
||||
out.QuotaWeeklyLimit = &limit
|
||||
used := a.GetQuotaWeeklyUsed()
|
||||
out.QuotaWeeklyUsed = &used
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -461,6 +496,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
AccountID: l.AccountID,
|
||||
RequestID: l.RequestID,
|
||||
Model: l.Model,
|
||||
ServiceTier: l.ServiceTier,
|
||||
ReasoningEffort: l.ReasoningEffort,
|
||||
GroupID: l.GroupID,
|
||||
SubscriptionID: l.SubscriptionID,
|
||||
|
||||
@@ -71,3 +71,29 @@ func TestRequestTypeStringPtrNil(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Nil(t, requestTypeStringPtr(nil))
|
||||
}
|
||||
|
||||
func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serviceTier := "priority"
|
||||
log := &service.UsageLog{
|
||||
RequestID: "req_3",
|
||||
Model: "gpt-5.4",
|
||||
ServiceTier: &serviceTier,
|
||||
AccountRateMultiplier: f64Ptr(1.5),
|
||||
}
|
||||
|
||||
userDTO := UsageLogFromService(log)
|
||||
adminDTO := UsageLogFromServiceAdmin(log)
|
||||
|
||||
require.NotNil(t, userDTO.ServiceTier)
|
||||
require.Equal(t, serviceTier, *userDTO.ServiceTier)
|
||||
require.NotNil(t, adminDTO.ServiceTier)
|
||||
require.Equal(t, serviceTier, *adminDTO.ServiceTier)
|
||||
require.NotNil(t, adminDTO.AccountRateMultiplier)
|
||||
require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
func f64Ptr(value float64) *float64 {
|
||||
return &value
|
||||
}
|
||||
|
||||
@@ -161,6 +161,26 @@ type StreamTimeoutSettings struct {
|
||||
ThresholdWindowMinutes int `json:"threshold_window_minutes"`
|
||||
}
|
||||
|
||||
// RectifierSettings 请求整流器配置 DTO
|
||||
type RectifierSettings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"`
|
||||
ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"`
|
||||
}
|
||||
|
||||
// BetaPolicyRule Beta 策略规则 DTO
|
||||
type BetaPolicyRule struct {
|
||||
BetaToken string `json:"beta_token"`
|
||||
Action string `json:"action"`
|
||||
Scope string `json:"scope"`
|
||||
ErrorMessage string `json:"error_message,omitempty"`
|
||||
}
|
||||
|
||||
// BetaPolicySettings Beta 策略配置 DTO
|
||||
type BetaPolicySettings struct {
|
||||
Rules []BetaPolicyRule `json:"rules"`
|
||||
}
|
||||
|
||||
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
|
||||
// Returns empty slice on empty/invalid input.
|
||||
func ParseCustomMenuItems(raw string) []CustomMenuItem {
|
||||
|
||||
@@ -57,6 +57,9 @@ type APIKey struct {
|
||||
Window5hStart *time.Time `json:"window_5h_start"`
|
||||
Window1dStart *time.Time `json:"window_1d_start"`
|
||||
Window7dStart *time.Time `json:"window_7d_start"`
|
||||
Reset5hAt *time.Time `json:"reset_5h_at,omitempty"`
|
||||
Reset1dAt *time.Time `json:"reset_1d_at,omitempty"`
|
||||
Reset7dAt *time.Time `json:"reset_7d_at,omitempty"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
@@ -96,6 +99,9 @@ type Group struct {
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||
|
||||
// OpenAI Messages 调度开关(用户侧需要此字段判断是否展示 Claude Code 教程)
|
||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
@@ -112,6 +118,9 @@ type AdminGroup struct {
|
||||
// MCP XML 协议注入(仅 antigravity 平台使用)
|
||||
MCPXMLInject bool `json:"mcp_xml_inject"`
|
||||
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
DefaultMappedModel string `json:"default_mapped_model"`
|
||||
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
@@ -131,6 +140,7 @@ type Account struct {
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
LoadFactor *int `json:"load_factor,omitempty"`
|
||||
Priority int `json:"priority"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
Status string `json:"status"`
|
||||
@@ -185,6 +195,14 @@ type Account struct {
|
||||
CacheTTLOverrideEnabled *bool `json:"cache_ttl_override_enabled,omitempty"`
|
||||
CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"`
|
||||
|
||||
// API Key 账号配额限制
|
||||
QuotaLimit *float64 `json:"quota_limit,omitempty"`
|
||||
QuotaUsed *float64 `json:"quota_used,omitempty"`
|
||||
QuotaDailyLimit *float64 `json:"quota_daily_limit,omitempty"`
|
||||
QuotaDailyUsed *float64 `json:"quota_daily_used,omitempty"`
|
||||
QuotaWeeklyLimit *float64 `json:"quota_weekly_limit,omitempty"`
|
||||
QuotaWeeklyUsed *float64 `json:"quota_weekly_used,omitempty"`
|
||||
|
||||
Proxy *Proxy `json:"proxy,omitempty"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
|
||||
@@ -304,6 +322,8 @@ type UsageLog struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
RequestID string `json:"request_id"`
|
||||
Model string `json:"model"`
|
||||
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
|
||||
ServiceTier *string `json:"service_tier,omitempty"`
|
||||
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API).
|
||||
// nil means not provided / not applicable.
|
||||
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
||||
|
||||
@@ -30,7 +30,7 @@ const (
|
||||
|
||||
const (
|
||||
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
|
||||
maxSameAccountRetries = 2
|
||||
maxSameAccountRetries = 3
|
||||
// sameAccountRetryDelay 同账号重试间隔
|
||||
sameAccountRetryDelay = 500 * time.Millisecond
|
||||
// singleAccountBackoffDelay 单账号分组 503 退避重试固定延时。
|
||||
|
||||
@@ -291,35 +291,31 @@ func TestHandleFailoverError_SameAccountRetry(t *testing.T) {
|
||||
require.Less(t, elapsed, 2*time.Second)
|
||||
})
|
||||
|
||||
t.Run("第二次重试仍返回FailoverContinue", func(t *testing.T) {
|
||||
t.Run("达到最大重试次数前均返回FailoverContinue", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 第一次
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[100])
|
||||
for i := 1; i <= maxSameAccountRetries; i++ {
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, i, fs.SameAccountRetryCount[100])
|
||||
}
|
||||
|
||||
// 第二次
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 2, fs.SameAccountRetryCount[100])
|
||||
|
||||
require.Empty(t, mock.calls, "两次重试期间均不应调用 TempUnschedule")
|
||||
require.Empty(t, mock.calls, "达到最大重试次数前均不应调用 TempUnschedule")
|
||||
})
|
||||
|
||||
t.Run("第三次重试耗尽_触发TempUnschedule并切换", func(t *testing.T) {
|
||||
t.Run("超过最大重试次数后触发TempUnschedule并切换", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 第一次、第二次重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, 2, fs.SameAccountRetryCount[100])
|
||||
for i := 0; i < maxSameAccountRetries; i++ {
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
}
|
||||
require.Equal(t, maxSameAccountRetries, fs.SameAccountRetryCount[100])
|
||||
|
||||
// 第三次:重试已达到 maxSameAccountRetries(2),应切换账号
|
||||
// 第 maxSameAccountRetries+1 次:重试耗尽,应切换账号
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SwitchCount)
|
||||
@@ -354,13 +350,14 @@ func TestHandleFailoverError_SameAccountRetry(t *testing.T) {
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 耗尽账号 100 的重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
// 第三次: 重试耗尽 → 切换
|
||||
for i := 0; i < maxSameAccountRetries; i++ {
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
}
|
||||
// 第 maxSameAccountRetries+1 次: 重试耗尽 → 切换
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
|
||||
// 再次遇到账号 100,计数仍为 2,条件不满足 → 直接切换
|
||||
// 再次遇到账号 100,计数仍为 maxSameAccountRetries,条件不满足 → 直接切换
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Len(t, mock.calls, 2, "第二次耗尽也应调用 TempUnschedule")
|
||||
@@ -386,9 +383,10 @@ func TestHandleFailoverError_TempUnschedule(t *testing.T) {
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(502, true, false)
|
||||
|
||||
// 耗尽重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
for i := 0; i < maxSameAccountRetries; i++ {
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
}
|
||||
// 再次触发时才会执行 TempUnschedule + 切换
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
|
||||
require.Len(t, mock.calls, 1)
|
||||
@@ -521,17 +519,16 @@ func TestHandleFailoverError_IntegrationScenario(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, true) // hasBoundSession=true
|
||||
|
||||
// 1. 账号 100 遇到可重试错误,同账号重试 2 次
|
||||
// 1. 账号 100 遇到可重试错误,同账号重试 maxSameAccountRetries 次
|
||||
retryErr := newTestFailoverErr(400, true, false)
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
for i := 0; i < maxSameAccountRetries; i++ {
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
}
|
||||
require.True(t, fs.ForceCacheBilling, "hasBoundSession=true 应设置 ForceCacheBilling")
|
||||
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
|
||||
// 2. 账号 100 重试耗尽 → TempUnschedule + 切换
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
// 2. 账号 100 超过重试上限 → TempUnschedule + 切换
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SwitchCount)
|
||||
require.Len(t, mock.calls, 1)
|
||||
|
||||
@@ -652,6 +652,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if err != nil {
|
||||
// Beta policy block: return 400 immediately, no failover
|
||||
var betaBlockedErr *service.BetaBlockedError
|
||||
if errors.As(err, &betaBlockedErr) {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", betaBlockedErr.Message)
|
||||
return
|
||||
}
|
||||
|
||||
var promptTooLongErr *service.PromptTooLongError
|
||||
if errors.As(err, &promptTooLongErr) {
|
||||
reqLog.Warn("gateway.prompt_too_long_from_antigravity",
|
||||
@@ -971,34 +978,46 @@ func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context,
|
||||
if err == nil && rateLimitData != nil {
|
||||
var rateLimits []gin.H
|
||||
if apiKey.RateLimit5h > 0 {
|
||||
used := rateLimitData.Usage5h
|
||||
rateLimits = append(rateLimits, gin.H{
|
||||
used := rateLimitData.EffectiveUsage5h()
|
||||
entry := gin.H{
|
||||
"window": "5h",
|
||||
"limit": apiKey.RateLimit5h,
|
||||
"used": used,
|
||||
"remaining": max(0, apiKey.RateLimit5h-used),
|
||||
"window_start": rateLimitData.Window5hStart,
|
||||
})
|
||||
}
|
||||
if rateLimitData.Window5hStart != nil && !service.IsWindowExpired(rateLimitData.Window5hStart, service.RateLimitWindow5h) {
|
||||
entry["reset_at"] = rateLimitData.Window5hStart.Add(service.RateLimitWindow5h)
|
||||
}
|
||||
rateLimits = append(rateLimits, entry)
|
||||
}
|
||||
if apiKey.RateLimit1d > 0 {
|
||||
used := rateLimitData.Usage1d
|
||||
rateLimits = append(rateLimits, gin.H{
|
||||
used := rateLimitData.EffectiveUsage1d()
|
||||
entry := gin.H{
|
||||
"window": "1d",
|
||||
"limit": apiKey.RateLimit1d,
|
||||
"used": used,
|
||||
"remaining": max(0, apiKey.RateLimit1d-used),
|
||||
"window_start": rateLimitData.Window1dStart,
|
||||
})
|
||||
}
|
||||
if rateLimitData.Window1dStart != nil && !service.IsWindowExpired(rateLimitData.Window1dStart, service.RateLimitWindow1d) {
|
||||
entry["reset_at"] = rateLimitData.Window1dStart.Add(service.RateLimitWindow1d)
|
||||
}
|
||||
rateLimits = append(rateLimits, entry)
|
||||
}
|
||||
if apiKey.RateLimit7d > 0 {
|
||||
used := rateLimitData.Usage7d
|
||||
rateLimits = append(rateLimits, gin.H{
|
||||
used := rateLimitData.EffectiveUsage7d()
|
||||
entry := gin.H{
|
||||
"window": "7d",
|
||||
"limit": apiKey.RateLimit7d,
|
||||
"used": used,
|
||||
"remaining": max(0, apiKey.RateLimit7d-used),
|
||||
"window_start": rateLimitData.Window7dStart,
|
||||
})
|
||||
}
|
||||
if rateLimitData.Window7dStart != nil && !service.IsWindowExpired(rateLimitData.Window7dStart, service.RateLimitWindow7d) {
|
||||
entry["reset_at"] = rateLimitData.Window7dStart.Add(service.RateLimitWindow7d)
|
||||
}
|
||||
rateLimits = append(rateLimits, entry)
|
||||
}
|
||||
if len(rateLimits) > 0 {
|
||||
resp["rate_limits"] = rateLimits
|
||||
|
||||
@@ -127,6 +127,7 @@ func (f *fakeConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, acc
|
||||
return result, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil }
|
||||
func (f *fakeConcurrencyCache) CleanupStaleProcessSlots(context.Context, string) error { return nil }
|
||||
|
||||
func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) {
|
||||
t.Helper()
|
||||
@@ -155,6 +156,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
|
||||
nil, // sessionLimitCache
|
||||
nil, // rpmCache
|
||||
nil, // digestStore
|
||||
nil, // settingService
|
||||
)
|
||||
|
||||
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
|
||||
|
||||
@@ -89,6 +89,10 @@ func (m *concurrencyCacheMock) CleanupExpiredAccountSlots(ctx context.Context, a
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) {
|
||||
cache := &concurrencyCacheMock{
|
||||
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
|
||||
@@ -120,6 +120,10 @@ func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Cont
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
@@ -27,6 +27,7 @@ type AdminHandlers struct {
|
||||
UserAttribute *admin.UserAttributeHandler
|
||||
ErrorPassthrough *admin.ErrorPassthroughHandler
|
||||
APIKey *admin.AdminAPIKeyHandler
|
||||
ScheduledTest *admin.ScheduledTestHandler
|
||||
}
|
||||
|
||||
// Handlers contains all HTTP handlers
|
||||
|
||||
192
backend/internal/handler/openai_gateway_compact_log_test.go
Normal file
192
backend/internal/handler/openai_gateway_compact_log_test.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var handlerStructuredLogCaptureMu sync.Mutex
|
||||
|
||||
type handlerInMemoryLogSink struct {
|
||||
mu sync.Mutex
|
||||
events []*logger.LogEvent
|
||||
}
|
||||
|
||||
func (s *handlerInMemoryLogSink) WriteLogEvent(event *logger.LogEvent) {
|
||||
if event == nil {
|
||||
return
|
||||
}
|
||||
cloned := *event
|
||||
if event.Fields != nil {
|
||||
cloned.Fields = make(map[string]any, len(event.Fields))
|
||||
for k, v := range event.Fields {
|
||||
cloned.Fields[k] = v
|
||||
}
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.events = append(s.events, &cloned)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *handlerInMemoryLogSink) ContainsMessageAtLevel(substr, level string) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
wantLevel := strings.ToLower(strings.TrimSpace(level))
|
||||
for _, ev := range s.events {
|
||||
if ev == nil {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(ev.Message, substr) && strings.ToLower(strings.TrimSpace(ev.Level)) == wantLevel {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *handlerInMemoryLogSink) ContainsFieldValue(field, substr string) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for _, ev := range s.events {
|
||||
if ev == nil || ev.Fields == nil {
|
||||
continue
|
||||
}
|
||||
if v, ok := ev.Fields[field]; ok && strings.Contains(fmt.Sprint(v), substr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func captureHandlerStructuredLog(t *testing.T) (*handlerInMemoryLogSink, func()) {
|
||||
t.Helper()
|
||||
handlerStructuredLogCaptureMu.Lock()
|
||||
|
||||
err := logger.Init(logger.InitOptions{
|
||||
Level: "debug",
|
||||
Format: "json",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "test",
|
||||
Output: logger.OutputOptions{
|
||||
ToStdout: true,
|
||||
ToFile: false,
|
||||
},
|
||||
Sampling: logger.SamplingOptions{Enabled: false},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
sink := &handlerInMemoryLogSink{}
|
||||
logger.SetSink(sink)
|
||||
return sink, func() {
|
||||
logger.SetSink(nil)
|
||||
handlerStructuredLogCaptureMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsOpenAIRemoteCompactPath(t *testing.T) {
|
||||
require.False(t, isOpenAIRemoteCompactPath(nil))
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", nil)
|
||||
require.True(t, isOpenAIRemoteCompactPath(c))
|
||||
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact/", nil)
|
||||
require.True(t, isOpenAIRemoteCompactPath(c))
|
||||
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
require.False(t, isOpenAIRemoteCompactPath(c))
|
||||
}
|
||||
|
||||
func TestLogOpenAIRemoteCompactOutcome_Succeeded(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureHandlerStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", nil)
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
|
||||
c.Set(opsModelKey, "gpt-5.3-codex")
|
||||
c.Set(opsAccountIDKey, int64(123))
|
||||
c.Header("x-request-id", "rid-compact-ok")
|
||||
c.Status(http.StatusOK)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.logOpenAIRemoteCompactOutcome(c, time.Now().Add(-8*time.Millisecond))
|
||||
|
||||
require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.succeeded", "info"))
|
||||
require.True(t, logSink.ContainsFieldValue("compact_outcome", "succeeded"))
|
||||
require.True(t, logSink.ContainsFieldValue("status_code", "200"))
|
||||
require.True(t, logSink.ContainsFieldValue("path", "/v1/responses/compact"))
|
||||
require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.3-codex"))
|
||||
require.True(t, logSink.ContainsFieldValue("account_id", "123"))
|
||||
require.True(t, logSink.ContainsFieldValue("upstream_request_id", "rid-compact-ok"))
|
||||
}
|
||||
|
||||
func TestLogOpenAIRemoteCompactOutcome_Failed(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureHandlerStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact", nil)
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
|
||||
c.Status(http.StatusBadGateway)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.logOpenAIRemoteCompactOutcome(c, time.Now())
|
||||
|
||||
require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn"))
|
||||
require.True(t, logSink.ContainsFieldValue("compact_outcome", "failed"))
|
||||
require.True(t, logSink.ContainsFieldValue("status_code", "502"))
|
||||
require.True(t, logSink.ContainsFieldValue("path", "/responses/compact"))
|
||||
}
|
||||
|
||||
func TestLogOpenAIRemoteCompactOutcome_NonCompactSkips(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureHandlerStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
c.Status(http.StatusOK)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.logOpenAIRemoteCompactOutcome(c, time.Now())
|
||||
|
||||
require.False(t, logSink.ContainsMessageAtLevel("codex.remote_compact.succeeded", "info"))
|
||||
require.False(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn"))
|
||||
}
|
||||
|
||||
func TestOpenAIResponses_CompactUnauthorizedLogsFailed(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureHandlerStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", strings.NewReader(`{"model":"gpt-5.3-codex"}`))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.Responses(c)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn"))
|
||||
require.True(t, logSink.ContainsFieldValue("status_code", "401"))
|
||||
require.True(t, logSink.ContainsFieldValue("path", "/v1/responses/compact"))
|
||||
}
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@@ -33,6 +34,7 @@ type OpenAIGatewayHandler struct {
|
||||
errorPassthroughService *service.ErrorPassthroughService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
||||
@@ -61,6 +63,7 @@ func NewOpenAIGatewayHandler(
|
||||
errorPassthroughService: errorPassthroughService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,6 +73,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// 局部兜底:确保该 handler 内部任何 panic 都不会击穿到进程级。
|
||||
streamStarted := false
|
||||
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||
compactStartedAt := time.Now()
|
||||
defer h.logOpenAIRemoteCompactOutcome(c, compactStartedAt)
|
||||
setOpenAIClientTransportHTTP(c)
|
||||
|
||||
requestStart := time.Now()
|
||||
@@ -114,6 +119,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
sessionHashBody := body
|
||||
if service.IsOpenAIResponsesCompactPathForTest(c) {
|
||||
if compactSeed := strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()); compactSeed != "" {
|
||||
c.Set(service.OpenAICompactSessionSeedKeyForTest(), compactSeed)
|
||||
}
|
||||
normalizedCompactBody, normalizedCompact, compactErr := service.NormalizeOpenAICompactRequestBodyForTest(body)
|
||||
if compactErr != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to normalize compact request body")
|
||||
return
|
||||
}
|
||||
if normalizedCompact {
|
||||
body = normalizedCompactBody
|
||||
}
|
||||
}
|
||||
|
||||
// 校验请求体 JSON 合法性
|
||||
if !gjson.ValidBytes(body) {
|
||||
@@ -189,11 +208,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Generate session hash (header first; fallback to prompt_cache_key)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, sessionHashBody)
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int)
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
@@ -241,6 +261,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
zap.Float64("load_skew", scheduleDecision.LoadSkew),
|
||||
)
|
||||
account := selection.Account
|
||||
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
|
||||
reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
@@ -270,6 +291,25 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// 池模式:同账号重试
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
@@ -301,6 +341,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
if result != nil {
|
||||
if account.Type == service.AccountTypeOAuth {
|
||||
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(c.Request.Context(), account.ID, result.ResponseHeaders)
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
} else {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
|
||||
@@ -340,6 +383,432 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func isOpenAIRemoteCompactPath(c *gin.Context) bool {
|
||||
if c == nil || c.Request == nil || c.Request.URL == nil {
|
||||
return false
|
||||
}
|
||||
normalizedPath := strings.TrimRight(strings.TrimSpace(c.Request.URL.Path), "/")
|
||||
return strings.HasSuffix(normalizedPath, "/responses/compact")
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) logOpenAIRemoteCompactOutcome(c *gin.Context, startedAt time.Time) {
|
||||
if !isOpenAIRemoteCompactPath(c) {
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
ctx = context.Background()
|
||||
path string
|
||||
status int
|
||||
)
|
||||
if c != nil {
|
||||
if c.Request != nil {
|
||||
ctx = c.Request.Context()
|
||||
if c.Request.URL != nil {
|
||||
path = strings.TrimSpace(c.Request.URL.Path)
|
||||
}
|
||||
}
|
||||
if c.Writer != nil {
|
||||
status = c.Writer.Status()
|
||||
}
|
||||
}
|
||||
|
||||
outcome := "failed"
|
||||
if status >= 200 && status < 300 {
|
||||
outcome = "succeeded"
|
||||
}
|
||||
latencyMs := time.Since(startedAt).Milliseconds()
|
||||
if latencyMs < 0 {
|
||||
latencyMs = 0
|
||||
}
|
||||
|
||||
fields := []zap.Field{
|
||||
zap.String("component", "handler.openai_gateway.responses"),
|
||||
zap.Bool("remote_compact", true),
|
||||
zap.String("compact_outcome", outcome),
|
||||
zap.Int("status_code", status),
|
||||
zap.Int64("latency_ms", latencyMs),
|
||||
zap.String("path", path),
|
||||
zap.Bool("force_codex_cli", h != nil && h.cfg != nil && h.cfg.Gateway.ForceCodexCLI),
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
if userAgent := strings.TrimSpace(c.GetHeader("User-Agent")); userAgent != "" {
|
||||
fields = append(fields, zap.String("request_user_agent", userAgent))
|
||||
}
|
||||
if v, ok := c.Get(opsModelKey); ok {
|
||||
if model, ok := v.(string); ok && strings.TrimSpace(model) != "" {
|
||||
fields = append(fields, zap.String("request_model", strings.TrimSpace(model)))
|
||||
}
|
||||
}
|
||||
if v, ok := c.Get(opsAccountIDKey); ok {
|
||||
if accountID, ok := v.(int64); ok && accountID > 0 {
|
||||
fields = append(fields, zap.Int64("account_id", accountID))
|
||||
}
|
||||
}
|
||||
if c.Writer != nil {
|
||||
if upstreamRequestID := strings.TrimSpace(c.Writer.Header().Get("x-request-id")); upstreamRequestID != "" {
|
||||
fields = append(fields, zap.String("upstream_request_id", upstreamRequestID))
|
||||
} else if upstreamRequestID := strings.TrimSpace(c.Writer.Header().Get("X-Request-Id")); upstreamRequestID != "" {
|
||||
fields = append(fields, zap.String("upstream_request_id", upstreamRequestID))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log := logger.FromContext(ctx).With(fields...)
|
||||
if outcome == "succeeded" {
|
||||
log.Info("codex.remote_compact.succeeded")
|
||||
return
|
||||
}
|
||||
log.Warn("codex.remote_compact.failed")
|
||||
}
|
||||
|
||||
// Messages handles Anthropic Messages API requests routed to OpenAI platform.
|
||||
// POST /v1/messages (when group platform is OpenAI)
|
||||
func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
streamStarted := false
|
||||
defer h.recoverAnthropicMessagesPanic(c, &streamStarted)
|
||||
|
||||
requestStart := time.Now()
|
||||
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.anthropicErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.anthropicErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.openai_gateway.messages",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
// 检查分组是否允许 /v1/messages 调度
|
||||
if apiKey.Group != nil && !apiKey.Group.AllowMessagesDispatch {
|
||||
h.anthropicErrorResponse(c, http.StatusForbidden, "permission_error",
|
||||
"This group does not allow /v1/messages dispatch")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.ensureResponsesDependencies(c, reqLog) {
|
||||
return
|
||||
}
|
||||
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.anthropicErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.anthropicErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
if len(body) == 0 {
|
||||
h.anthropicErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
if !gjson.ValidBytes(body) {
|
||||
h.anthropicErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
modelResult := gjson.GetBytes(body, "model")
|
||||
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
|
||||
h.anthropicErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return
|
||||
}
|
||||
reqModel := modelResult.String()
|
||||
reqStream := gjson.GetBytes(body, "stream").Bool()
|
||||
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
|
||||
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
}
|
||||
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
||||
routingStart := time.Now()
|
||||
|
||||
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("openai_messages.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.anthropicStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
|
||||
promptCacheKey := h.gatewayService.ExtractSessionID(c, body)
|
||||
|
||||
// Anthropic 格式的请求在 metadata.user_id 中携带 session 标识,
|
||||
// 而非 OpenAI 的 session_id/conversation_id headers。
|
||||
// 从中派生 sessionHash(sticky session)和 promptCacheKey(upstream cache)。
|
||||
if sessionHash == "" || promptCacheKey == "" {
|
||||
if userID := strings.TrimSpace(gjson.GetBytes(body, "metadata.user_id").String()); userID != "" {
|
||||
seed := reqModel + "-" + userID
|
||||
if promptCacheKey == "" {
|
||||
promptCacheKey = service.GenerateSessionUUID(seed)
|
||||
}
|
||||
if sessionHash == "" {
|
||||
sessionHash = service.DeriveSessionHashFromSeed(seed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int)
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
// 清除上一次迭代的降级模型标记,避免残留影响本次迭代
|
||||
c.Set("openai_messages_fallback_model", "")
|
||||
reqLog.Debug("openai_messages.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
"", // no previous_response_id
|
||||
sessionHash,
|
||||
reqModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai_messages.account_select_failed",
|
||||
zap.Error(err),
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
// 首次调度失败 + 有默认映射模型 → 用默认模型重试
|
||||
if len(failedAccountIDs) == 0 {
|
||||
defaultModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
if defaultModel != "" && defaultModel != reqModel {
|
||||
reqLog.Info("openai_messages.fallback_to_default_model",
|
||||
zap.String("default_mapped_model", defaultModel),
|
||||
)
|
||||
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
"",
|
||||
sessionHash,
|
||||
defaultModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
if err == nil && selection != nil {
|
||||
c.Set("openai_messages_fallback_model", defaultModel)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if lastFailoverErr != nil {
|
||||
h.handleAnthropicFailoverExhausted(c, lastFailoverErr, streamStarted)
|
||||
} else {
|
||||
h.anthropicStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
|
||||
reqLog.Debug("openai_messages.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||
_ = scheduleDecision
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
defaultMappedModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultMappedModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
// 如果使用了降级模型调度,强制使用降级模型
|
||||
if fallbackModel := c.GetString("openai_messages_fallback_model"); fallbackModel != "" {
|
||||
defaultMappedModel = fallbackModel
|
||||
}
|
||||
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||
responseLatencyMs := forwardDurationMs
|
||||
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
|
||||
}
|
||||
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
|
||||
if err == nil && result != nil && result.FirstTokenMs != nil {
|
||||
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// 池模式:同账号重试
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai_messages.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai_messages.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureAnthropicErrorResponse(c, streamStarted)
|
||||
reqLog.Warn("openai_messages.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
if result != nil {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
} else {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
|
||||
}
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.messages"),
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64("account_id", account.ID),
|
||||
).Error("openai_messages.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
})
|
||||
reqLog.Debug("openai_messages.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("switch_count", switchCount),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// anthropicErrorResponse writes an error in Anthropic Messages API format.
|
||||
func (h *OpenAIGatewayHandler) anthropicErrorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// anthropicStreamingAwareError handles errors that may occur during streaming,
|
||||
// using Anthropic SSE error format.
|
||||
func (h *OpenAIGatewayHandler) anthropicStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if ok {
|
||||
errPayload, _ := json.Marshal(gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", errPayload) //nolint:errcheck
|
||||
flusher.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
h.anthropicErrorResponse(c, status, errType, message)
|
||||
}
|
||||
|
||||
// handleAnthropicFailoverExhausted maps upstream failover errors to Anthropic format.
|
||||
func (h *OpenAIGatewayHandler) handleAnthropicFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) {
|
||||
status, errType, errMsg := h.mapUpstreamError(failoverErr.StatusCode)
|
||||
h.anthropicStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
}
|
||||
|
||||
// ensureAnthropicErrorResponse writes a fallback Anthropic error if no response was written.
|
||||
func (h *OpenAIGatewayHandler) ensureAnthropicErrorResponse(c *gin.Context, streamStarted bool) bool {
|
||||
if c == nil || c.Writer == nil || c.Writer.Written() {
|
||||
return false
|
||||
}
|
||||
h.anthropicStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted)
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, body []byte, reqLog *zap.Logger) bool {
|
||||
if !gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
|
||||
return true
|
||||
@@ -756,6 +1225,9 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
if turnErr != nil || result == nil {
|
||||
return
|
||||
}
|
||||
if account.Type == service.AccountTypeOAuth {
|
||||
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders)
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
h.submitUsageRecordTask(func(taskCtx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
||||
@@ -817,6 +1289,26 @@ func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStart
|
||||
)
|
||||
}
|
||||
|
||||
// recoverAnthropicMessagesPanic recovers from panics in the Anthropic Messages
|
||||
// handler and returns an Anthropic-formatted error response.
|
||||
func (h *OpenAIGatewayHandler) recoverAnthropicMessagesPanic(c *gin.Context, streamStarted *bool) {
|
||||
recovered := recover()
|
||||
if recovered == nil {
|
||||
return
|
||||
}
|
||||
|
||||
started := streamStarted != nil && *streamStarted
|
||||
requestLogger(c, "handler.openai_gateway.messages").Error(
|
||||
"openai.messages_panic_recovered",
|
||||
zap.Bool("stream_started", started),
|
||||
zap.Any("panic", recovered),
|
||||
zap.ByteString("stack", debug.Stack()),
|
||||
)
|
||||
if !started {
|
||||
h.anthropicErrorResponse(c, http.StatusInternalServerError, "api_error", "Internal server error")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) ensureResponsesDependencies(c *gin.Context, reqLog *zap.Logger) bool {
|
||||
missing := h.missingResponsesDependencies()
|
||||
if len(missing) == 0 {
|
||||
@@ -1022,6 +1514,14 @@ func setOpenAIClientTransportWS(c *gin.Context) {
|
||||
service.SetOpenAIClientTransport(c, service.OpenAIClientTransportWS)
|
||||
}
|
||||
|
||||
func ensureOpenAIPoolModeSessionHash(sessionHash string, account *service.Account) string {
|
||||
if sessionHash != "" || account == nil || !account.IsPoolMode() {
|
||||
return sessionHash
|
||||
}
|
||||
// 为当前请求生成一次性粘性会话键,确保同账号重试不会重新负载均衡到其他账号。
|
||||
return "openai-pool-retry-" + uuid.NewString()
|
||||
}
|
||||
|
||||
func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64) string {
|
||||
gid := int64(0)
|
||||
if groupID != nil {
|
||||
|
||||
@@ -2132,6 +2132,14 @@ func (r *stubAccountRepoForHandler) BulkUpdate(context.Context, []int64, service
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepoForHandler) IncrementQuotaUsed(context.Context, int64, float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepoForHandler) ResetQuotaUsed(context.Context, int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== Stub: SoraClient (用于 SoraGatewayService) ====================
|
||||
|
||||
var _ service.SoraClient = (*stubSoraClientForHandler)(nil)
|
||||
@@ -2199,7 +2207,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac
|
||||
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
|
||||
return service.NewGatewayService(
|
||||
accountRepo, nil, nil, nil, nil, nil, nil, nil,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -216,6 +216,14 @@ func (r *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates s
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepo) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepo) listSchedulable() []service.Account {
|
||||
var result []service.Account
|
||||
for _, acc := range r.accounts {
|
||||
@@ -437,6 +445,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
|
||||
testutil.StubSessionLimitCache{},
|
||||
nil, // rpmCache
|
||||
nil, // digestStore
|
||||
nil, // settingService
|
||||
)
|
||||
|
||||
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}
|
||||
|
||||
@@ -30,6 +30,7 @@ func ProvideAdminHandlers(
|
||||
userAttributeHandler *admin.UserAttributeHandler,
|
||||
errorPassthroughHandler *admin.ErrorPassthroughHandler,
|
||||
apiKeyHandler *admin.AdminAPIKeyHandler,
|
||||
scheduledTestHandler *admin.ScheduledTestHandler,
|
||||
) *AdminHandlers {
|
||||
return &AdminHandlers{
|
||||
Dashboard: dashboardHandler,
|
||||
@@ -53,6 +54,7 @@ func ProvideAdminHandlers(
|
||||
UserAttribute: userAttributeHandler,
|
||||
ErrorPassthrough: errorPassthroughHandler,
|
||||
APIKey: apiKeyHandler,
|
||||
ScheduledTest: scheduledTestHandler,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,6 +143,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewUserAttributeHandler,
|
||||
admin.NewErrorPassthroughHandler,
|
||||
admin.NewAdminAPIKeyHandler,
|
||||
admin.NewScheduledTestHandler,
|
||||
|
||||
// AdminHandlers and Handlers constructors
|
||||
ProvideAdminHandlers,
|
||||
|
||||
@@ -49,8 +49,8 @@ const (
|
||||
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
)
|
||||
|
||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.19.6
|
||||
var defaultUserAgentVersion = "1.19.6"
|
||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.4
|
||||
var defaultUserAgentVersion = "1.20.4"
|
||||
|
||||
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
||||
var defaultClientSecret = "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"
|
||||
|
||||
@@ -690,7 +690,7 @@ func TestConstants_值正确(t *testing.T) {
|
||||
if RedirectURI != "http://localhost:8085/callback" {
|
||||
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
|
||||
}
|
||||
if GetUserAgent() != "antigravity/1.19.6 windows/amd64" {
|
||||
if GetUserAgent() != "antigravity/1.20.4 windows/amd64" {
|
||||
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
|
||||
}
|
||||
if SessionTTL != 30*time.Minute {
|
||||
|
||||
@@ -119,23 +119,33 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// Finish 结束处理,返回最终事件和用量
|
||||
// Finish 结束处理,返回最终事件和用量。
|
||||
// 若整个流未收到任何可解析的上游数据(messageStartSent == false),
|
||||
// 则不补发任何结束事件,防止客户端收到没有 message_start 的残缺流。
|
||||
func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) {
|
||||
var result bytes.Buffer
|
||||
|
||||
if !p.messageStopSent {
|
||||
_, _ = result.Write(p.emitFinish(""))
|
||||
}
|
||||
|
||||
usage := &ClaudeUsage{
|
||||
InputTokens: p.inputTokens,
|
||||
OutputTokens: p.outputTokens,
|
||||
CacheReadInputTokens: p.cacheReadTokens,
|
||||
}
|
||||
|
||||
if !p.messageStartSent {
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
var result bytes.Buffer
|
||||
if !p.messageStopSent {
|
||||
_, _ = result.Write(p.emitFinish(""))
|
||||
}
|
||||
|
||||
return result.Bytes(), usage
|
||||
}
|
||||
|
||||
// MessageStartSent 报告流中是否已发出过 message_start 事件(即是否收到过有效的上游数据)
|
||||
func (p *StreamingProcessor) MessageStartSent() bool {
|
||||
return p.messageStartSent
|
||||
}
|
||||
|
||||
// emitMessageStart 发送 message_start 事件
|
||||
func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte {
|
||||
if p.messageStartSent {
|
||||
|
||||
1009
backend/internal/pkg/apicompat/anthropic_responses_test.go
Normal file
1009
backend/internal/pkg/apicompat/anthropic_responses_test.go
Normal file
File diff suppressed because it is too large
Load Diff
417
backend/internal/pkg/apicompat/anthropic_to_responses.go
Normal file
417
backend/internal/pkg/apicompat/anthropic_to_responses.go
Normal file
@@ -0,0 +1,417 @@
|
||||
package apicompat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// AnthropicToResponses converts an Anthropic Messages request directly into
|
||||
// a Responses API request. This preserves fields that would be lost in a
|
||||
// Chat Completions intermediary round-trip (e.g. thinking, cache_control,
|
||||
// structured system prompts).
|
||||
func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) {
|
||||
input, err := convertAnthropicToResponsesInput(req.System, req.Messages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inputJSON, err := json.Marshal(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := &ResponsesRequest{
|
||||
Model: req.Model,
|
||||
Input: inputJSON,
|
||||
Temperature: req.Temperature,
|
||||
TopP: req.TopP,
|
||||
Stream: req.Stream,
|
||||
Include: []string{"reasoning.encrypted_content"},
|
||||
}
|
||||
|
||||
storeFalse := false
|
||||
out.Store = &storeFalse
|
||||
|
||||
if req.MaxTokens > 0 {
|
||||
v := req.MaxTokens
|
||||
if v < minMaxOutputTokens {
|
||||
v = minMaxOutputTokens
|
||||
}
|
||||
out.MaxOutputTokens = &v
|
||||
}
|
||||
|
||||
if len(req.Tools) > 0 {
|
||||
out.Tools = convertAnthropicToolsToResponses(req.Tools)
|
||||
}
|
||||
|
||||
// Determine reasoning effort: only output_config.effort controls the
|
||||
// level; thinking.type is ignored. Default is xhigh when unset.
|
||||
// Anthropic levels map to OpenAI: low→low, medium→high, high→xhigh.
|
||||
effort := "high" // default → maps to xhigh
|
||||
if req.OutputConfig != nil && req.OutputConfig.Effort != "" {
|
||||
effort = req.OutputConfig.Effort
|
||||
}
|
||||
out.Reasoning = &ResponsesReasoning{
|
||||
Effort: mapAnthropicEffortToResponses(effort),
|
||||
Summary: "auto",
|
||||
}
|
||||
|
||||
// Convert tool_choice
|
||||
if len(req.ToolChoice) > 0 {
|
||||
tc, err := convertAnthropicToolChoiceToResponses(req.ToolChoice)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert tool_choice: %w", err)
|
||||
}
|
||||
out.ToolChoice = tc
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// convertAnthropicToolChoiceToResponses maps Anthropic tool_choice to Responses format.
|
||||
//
|
||||
// {"type":"auto"} → "auto"
|
||||
// {"type":"any"} → "required"
|
||||
// {"type":"none"} → "none"
|
||||
// {"type":"tool","name":"X"} → {"type":"function","function":{"name":"X"}}
|
||||
func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage, error) {
|
||||
var tc struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &tc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch tc.Type {
|
||||
case "auto":
|
||||
return json.Marshal("auto")
|
||||
case "any":
|
||||
return json.Marshal("required")
|
||||
case "none":
|
||||
return json.Marshal("none")
|
||||
case "tool":
|
||||
return json.Marshal(map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]string{"name": tc.Name},
|
||||
})
|
||||
default:
|
||||
// Pass through unknown types as-is
|
||||
return raw, nil
|
||||
}
|
||||
}
|
||||
|
||||
// convertAnthropicToResponsesInput builds the Responses API input items array
|
||||
// from the Anthropic system field and message list.
|
||||
func convertAnthropicToResponsesInput(system json.RawMessage, msgs []AnthropicMessage) ([]ResponsesInputItem, error) {
|
||||
var out []ResponsesInputItem
|
||||
|
||||
// System prompt → system role input item.
|
||||
if len(system) > 0 {
|
||||
sysText, err := parseAnthropicSystemPrompt(system)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if sysText != "" {
|
||||
content, _ := json.Marshal(sysText)
|
||||
out = append(out, ResponsesInputItem{
|
||||
Role: "system",
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for _, m := range msgs {
|
||||
items, err := anthropicMsgToResponsesItems(m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, items...)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// parseAnthropicSystemPrompt handles the Anthropic system field which can be
|
||||
// a plain string or an array of text blocks.
|
||||
func parseAnthropicSystemPrompt(raw json.RawMessage) (string, error) {
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return s, nil
|
||||
}
|
||||
var blocks []AnthropicContentBlock
|
||||
if err := json.Unmarshal(raw, &blocks); err != nil {
|
||||
return "", err
|
||||
}
|
||||
var parts []string
|
||||
for _, b := range blocks {
|
||||
if b.Type == "text" && b.Text != "" {
|
||||
parts = append(parts, b.Text)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n\n"), nil
|
||||
}
|
||||
|
||||
// anthropicMsgToResponsesItems converts a single Anthropic message into one
|
||||
// or more Responses API input items.
|
||||
func anthropicMsgToResponsesItems(m AnthropicMessage) ([]ResponsesInputItem, error) {
|
||||
switch m.Role {
|
||||
case "user":
|
||||
return anthropicUserToResponses(m.Content)
|
||||
case "assistant":
|
||||
return anthropicAssistantToResponses(m.Content)
|
||||
default:
|
||||
return anthropicUserToResponses(m.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// anthropicUserToResponses handles an Anthropic user message. Content can be a
|
||||
// plain string or an array of blocks. tool_result blocks are extracted into
|
||||
// function_call_output items. Image blocks are converted to input_image parts.
|
||||
func anthropicUserToResponses(raw json.RawMessage) ([]ResponsesInputItem, error) {
|
||||
// Try plain string.
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
content, _ := json.Marshal(s)
|
||||
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
|
||||
}
|
||||
|
||||
var blocks []AnthropicContentBlock
|
||||
if err := json.Unmarshal(raw, &blocks); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var out []ResponsesInputItem
|
||||
var toolResultImageParts []ResponsesContentPart
|
||||
|
||||
// Extract tool_result blocks → function_call_output items.
|
||||
// Images inside tool_results are extracted separately because the
|
||||
// Responses API function_call_output.output only accepts strings.
|
||||
for _, b := range blocks {
|
||||
if b.Type != "tool_result" {
|
||||
continue
|
||||
}
|
||||
outputText, imageParts := convertToolResultOutput(b)
|
||||
out = append(out, ResponsesInputItem{
|
||||
Type: "function_call_output",
|
||||
CallID: toResponsesCallID(b.ToolUseID),
|
||||
Output: outputText,
|
||||
})
|
||||
toolResultImageParts = append(toolResultImageParts, imageParts...)
|
||||
}
|
||||
|
||||
// Remaining text + image blocks → user message with content parts.
|
||||
// Also include images extracted from tool_results so the model can see them.
|
||||
var parts []ResponsesContentPart
|
||||
for _, b := range blocks {
|
||||
switch b.Type {
|
||||
case "text":
|
||||
if b.Text != "" {
|
||||
parts = append(parts, ResponsesContentPart{Type: "input_text", Text: b.Text})
|
||||
}
|
||||
case "image":
|
||||
if uri := anthropicImageToDataURI(b.Source); uri != "" {
|
||||
parts = append(parts, ResponsesContentPart{Type: "input_image", ImageURL: uri})
|
||||
}
|
||||
}
|
||||
}
|
||||
parts = append(parts, toolResultImageParts...)
|
||||
|
||||
if len(parts) > 0 {
|
||||
content, err := json.Marshal(parts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, ResponsesInputItem{Role: "user", Content: content})
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// anthropicAssistantToResponses handles an Anthropic assistant message.
|
||||
// Text content → assistant message with output_text parts.
|
||||
// tool_use blocks → function_call items.
|
||||
// thinking blocks → ignored (OpenAI doesn't accept them as input).
|
||||
func anthropicAssistantToResponses(raw json.RawMessage) ([]ResponsesInputItem, error) {
|
||||
// Try plain string.
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
|
||||
partsJSON, err := json.Marshal(parts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []ResponsesInputItem{{Role: "assistant", Content: partsJSON}}, nil
|
||||
}
|
||||
|
||||
var blocks []AnthropicContentBlock
|
||||
if err := json.Unmarshal(raw, &blocks); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var items []ResponsesInputItem
|
||||
|
||||
// Text content → assistant message with output_text content parts.
|
||||
text := extractAnthropicTextFromBlocks(blocks)
|
||||
if text != "" {
|
||||
parts := []ResponsesContentPart{{Type: "output_text", Text: text}}
|
||||
partsJSON, err := json.Marshal(parts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON})
|
||||
}
|
||||
|
||||
// tool_use → function_call items.
|
||||
for _, b := range blocks {
|
||||
if b.Type != "tool_use" {
|
||||
continue
|
||||
}
|
||||
args := "{}"
|
||||
if len(b.Input) > 0 {
|
||||
args = string(b.Input)
|
||||
}
|
||||
fcID := toResponsesCallID(b.ID)
|
||||
items = append(items, ResponsesInputItem{
|
||||
Type: "function_call",
|
||||
CallID: fcID,
|
||||
Name: b.Name,
|
||||
Arguments: args,
|
||||
ID: fcID,
|
||||
})
|
||||
}
|
||||
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// toResponsesCallID converts an Anthropic tool ID (toolu_xxx / call_xxx) to a
|
||||
// Responses API function_call ID that starts with "fc_".
|
||||
func toResponsesCallID(id string) string {
|
||||
if strings.HasPrefix(id, "fc_") {
|
||||
return id
|
||||
}
|
||||
return "fc_" + id
|
||||
}
|
||||
|
||||
// fromResponsesCallID reverses toResponsesCallID, stripping the "fc_" prefix
|
||||
// that was added during request conversion.
|
||||
func fromResponsesCallID(id string) string {
|
||||
if after, ok := strings.CutPrefix(id, "fc_"); ok {
|
||||
// Only strip if the remainder doesn't look like it was already "fc_" prefixed.
|
||||
// E.g. "fc_toolu_xxx" → "toolu_xxx", "fc_call_xxx" → "call_xxx"
|
||||
if strings.HasPrefix(after, "toolu_") || strings.HasPrefix(after, "call_") {
|
||||
return after
|
||||
}
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// anthropicImageToDataURI converts an AnthropicImageSource to a data URI string.
|
||||
// Returns "" if the source is nil or has no data.
|
||||
func anthropicImageToDataURI(src *AnthropicImageSource) string {
|
||||
if src == nil || src.Data == "" {
|
||||
return ""
|
||||
}
|
||||
mediaType := src.MediaType
|
||||
if mediaType == "" {
|
||||
mediaType = "image/png"
|
||||
}
|
||||
return "data:" + mediaType + ";base64," + src.Data
|
||||
}
|
||||
|
||||
// convertToolResultOutput extracts text and image content from a tool_result
|
||||
// block. Returns the text as a string for the function_call_output Output
|
||||
// field, plus any image parts that must be sent in a separate user message
|
||||
// (the Responses API output field only accepts strings).
|
||||
func convertToolResultOutput(b AnthropicContentBlock) (string, []ResponsesContentPart) {
|
||||
if len(b.Content) == 0 {
|
||||
return "(empty)", nil
|
||||
}
|
||||
|
||||
// Try plain string content.
|
||||
var s string
|
||||
if err := json.Unmarshal(b.Content, &s); err == nil {
|
||||
if s == "" {
|
||||
s = "(empty)"
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Array of content blocks — may contain text and/or images.
|
||||
var inner []AnthropicContentBlock
|
||||
if err := json.Unmarshal(b.Content, &inner); err != nil {
|
||||
return "(empty)", nil
|
||||
}
|
||||
|
||||
// Separate text (for function_call_output) from images (for user message).
|
||||
var textParts []string
|
||||
var imageParts []ResponsesContentPart
|
||||
for _, ib := range inner {
|
||||
switch ib.Type {
|
||||
case "text":
|
||||
if ib.Text != "" {
|
||||
textParts = append(textParts, ib.Text)
|
||||
}
|
||||
case "image":
|
||||
if uri := anthropicImageToDataURI(ib.Source); uri != "" {
|
||||
imageParts = append(imageParts, ResponsesContentPart{Type: "input_image", ImageURL: uri})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
text := strings.Join(textParts, "\n\n")
|
||||
if text == "" {
|
||||
text = "(empty)"
|
||||
}
|
||||
return text, imageParts
|
||||
}
|
||||
|
||||
// extractAnthropicTextFromBlocks joins all text blocks, ignoring thinking/
|
||||
// tool_use/tool_result blocks.
|
||||
func extractAnthropicTextFromBlocks(blocks []AnthropicContentBlock) string {
|
||||
var parts []string
|
||||
for _, b := range blocks {
|
||||
if b.Type == "text" && b.Text != "" {
|
||||
parts = append(parts, b.Text)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
|
||||
// mapAnthropicEffortToResponses converts Anthropic reasoning effort levels to
|
||||
// OpenAI Responses API effort levels.
|
||||
//
|
||||
// low → low
|
||||
// medium → high
|
||||
// high → xhigh
|
||||
func mapAnthropicEffortToResponses(effort string) string {
|
||||
switch effort {
|
||||
case "medium":
|
||||
return "high"
|
||||
case "high":
|
||||
return "xhigh"
|
||||
default:
|
||||
return effort // "low" and any unknown values pass through unchanged
|
||||
}
|
||||
}
|
||||
|
||||
// convertAnthropicToolsToResponses maps Anthropic tool definitions to
|
||||
// Responses API tools. Server-side tools like web_search are mapped to their
|
||||
// OpenAI equivalents; regular tools become function tools.
|
||||
func convertAnthropicToolsToResponses(tools []AnthropicTool) []ResponsesTool {
|
||||
var out []ResponsesTool
|
||||
for _, t := range tools {
|
||||
// Anthropic server tools like "web_search_20250305" → OpenAI {"type":"web_search"}
|
||||
if strings.HasPrefix(t.Type, "web_search") {
|
||||
out = append(out, ResponsesTool{Type: "web_search"})
|
||||
continue
|
||||
}
|
||||
out = append(out, ResponsesTool{
|
||||
Type: "function",
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
Parameters: t.InputSchema,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
516
backend/internal/pkg/apicompat/responses_to_anthropic.go
Normal file
516
backend/internal/pkg/apicompat/responses_to_anthropic.go
Normal file
@@ -0,0 +1,516 @@
|
||||
package apicompat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Non-streaming: ResponsesResponse → AnthropicResponse
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ResponsesToAnthropic converts a Responses API response directly into an
|
||||
// Anthropic Messages response. Reasoning output items are mapped to thinking
|
||||
// blocks; function_call items become tool_use blocks.
|
||||
func ResponsesToAnthropic(resp *ResponsesResponse, model string) *AnthropicResponse {
|
||||
out := &AnthropicResponse{
|
||||
ID: resp.ID,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: model,
|
||||
}
|
||||
|
||||
var blocks []AnthropicContentBlock
|
||||
|
||||
for _, item := range resp.Output {
|
||||
switch item.Type {
|
||||
case "reasoning":
|
||||
summaryText := ""
|
||||
for _, s := range item.Summary {
|
||||
if s.Type == "summary_text" && s.Text != "" {
|
||||
summaryText += s.Text
|
||||
}
|
||||
}
|
||||
if summaryText != "" {
|
||||
blocks = append(blocks, AnthropicContentBlock{
|
||||
Type: "thinking",
|
||||
Thinking: summaryText,
|
||||
})
|
||||
}
|
||||
case "message":
|
||||
for _, part := range item.Content {
|
||||
if part.Type == "output_text" && part.Text != "" {
|
||||
blocks = append(blocks, AnthropicContentBlock{
|
||||
Type: "text",
|
||||
Text: part.Text,
|
||||
})
|
||||
}
|
||||
}
|
||||
case "function_call":
|
||||
blocks = append(blocks, AnthropicContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: fromResponsesCallID(item.CallID),
|
||||
Name: item.Name,
|
||||
Input: json.RawMessage(item.Arguments),
|
||||
})
|
||||
case "web_search_call":
|
||||
toolUseID := "srvtoolu_" + item.ID
|
||||
query := ""
|
||||
if item.Action != nil {
|
||||
query = item.Action.Query
|
||||
}
|
||||
inputJSON, _ := json.Marshal(map[string]string{"query": query})
|
||||
blocks = append(blocks, AnthropicContentBlock{
|
||||
Type: "server_tool_use",
|
||||
ID: toolUseID,
|
||||
Name: "web_search",
|
||||
Input: inputJSON,
|
||||
})
|
||||
emptyResults, _ := json.Marshal([]struct{}{})
|
||||
blocks = append(blocks, AnthropicContentBlock{
|
||||
Type: "web_search_tool_result",
|
||||
ToolUseID: toolUseID,
|
||||
Content: emptyResults,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(blocks) == 0 {
|
||||
blocks = append(blocks, AnthropicContentBlock{Type: "text", Text: ""})
|
||||
}
|
||||
out.Content = blocks
|
||||
|
||||
out.StopReason = responsesStatusToAnthropicStopReason(resp.Status, resp.IncompleteDetails, blocks)
|
||||
|
||||
if resp.Usage != nil {
|
||||
out.Usage = AnthropicUsage{
|
||||
InputTokens: resp.Usage.InputTokens,
|
||||
OutputTokens: resp.Usage.OutputTokens,
|
||||
}
|
||||
if resp.Usage.InputTokensDetails != nil {
|
||||
out.Usage.CacheReadInputTokens = resp.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncompleteDetails, blocks []AnthropicContentBlock) string {
|
||||
switch status {
|
||||
case "incomplete":
|
||||
if details != nil && details.Reason == "max_output_tokens" {
|
||||
return "max_tokens"
|
||||
}
|
||||
return "end_turn"
|
||||
case "completed":
|
||||
if len(blocks) > 0 && blocks[len(blocks)-1].Type == "tool_use" {
|
||||
return "tool_use"
|
||||
}
|
||||
return "end_turn"
|
||||
default:
|
||||
return "end_turn"
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Streaming: ResponsesStreamEvent → []AnthropicStreamEvent (stateful converter)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ResponsesEventToAnthropicState tracks state for converting a sequence of
|
||||
// Responses SSE events directly into Anthropic SSE events.
|
||||
type ResponsesEventToAnthropicState struct {
|
||||
MessageStartSent bool
|
||||
MessageStopSent bool
|
||||
|
||||
ContentBlockIndex int
|
||||
ContentBlockOpen bool
|
||||
CurrentBlockType string // "text" | "thinking" | "tool_use"
|
||||
|
||||
// OutputIndexToBlockIdx maps Responses output_index → Anthropic content block index.
|
||||
OutputIndexToBlockIdx map[int]int
|
||||
|
||||
InputTokens int
|
||||
OutputTokens int
|
||||
CacheReadInputTokens int
|
||||
|
||||
ResponseID string
|
||||
Model string
|
||||
Created int64
|
||||
}
|
||||
|
||||
// NewResponsesEventToAnthropicState returns an initialised stream state.
|
||||
func NewResponsesEventToAnthropicState() *ResponsesEventToAnthropicState {
|
||||
return &ResponsesEventToAnthropicState{
|
||||
OutputIndexToBlockIdx: make(map[int]int),
|
||||
Created: time.Now().Unix(),
|
||||
}
|
||||
}
|
||||
|
||||
// ResponsesEventToAnthropicEvents converts a single Responses SSE event into
|
||||
// zero or more Anthropic SSE events, updating state as it goes.
|
||||
func ResponsesEventToAnthropicEvents(
|
||||
evt *ResponsesStreamEvent,
|
||||
state *ResponsesEventToAnthropicState,
|
||||
) []AnthropicStreamEvent {
|
||||
switch evt.Type {
|
||||
case "response.created":
|
||||
return resToAnthHandleCreated(evt, state)
|
||||
case "response.output_item.added":
|
||||
return resToAnthHandleOutputItemAdded(evt, state)
|
||||
case "response.output_text.delta":
|
||||
return resToAnthHandleTextDelta(evt, state)
|
||||
case "response.output_text.done":
|
||||
return resToAnthHandleBlockDone(state)
|
||||
case "response.function_call_arguments.delta":
|
||||
return resToAnthHandleFuncArgsDelta(evt, state)
|
||||
case "response.function_call_arguments.done":
|
||||
return resToAnthHandleBlockDone(state)
|
||||
case "response.output_item.done":
|
||||
return resToAnthHandleOutputItemDone(evt, state)
|
||||
case "response.reasoning_summary_text.delta":
|
||||
return resToAnthHandleReasoningDelta(evt, state)
|
||||
case "response.reasoning_summary_text.done":
|
||||
return resToAnthHandleBlockDone(state)
|
||||
case "response.completed", "response.incomplete", "response.failed":
|
||||
return resToAnthHandleCompleted(evt, state)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// FinalizeResponsesAnthropicStream emits synthetic termination events if the
|
||||
// stream ended without a proper completion event.
|
||||
func FinalizeResponsesAnthropicStream(state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if !state.MessageStartSent || state.MessageStopSent {
|
||||
return nil
|
||||
}
|
||||
|
||||
var events []AnthropicStreamEvent
|
||||
events = append(events, closeCurrentBlock(state)...)
|
||||
|
||||
events = append(events,
|
||||
AnthropicStreamEvent{
|
||||
Type: "message_delta",
|
||||
Delta: &AnthropicDelta{
|
||||
StopReason: "end_turn",
|
||||
},
|
||||
Usage: &AnthropicUsage{
|
||||
InputTokens: state.InputTokens,
|
||||
OutputTokens: state.OutputTokens,
|
||||
CacheReadInputTokens: state.CacheReadInputTokens,
|
||||
},
|
||||
},
|
||||
AnthropicStreamEvent{Type: "message_stop"},
|
||||
)
|
||||
state.MessageStopSent = true
|
||||
return events
|
||||
}
|
||||
|
||||
// ResponsesAnthropicEventToSSE formats an AnthropicStreamEvent as an SSE line pair.
|
||||
func ResponsesAnthropicEventToSSE(evt AnthropicStreamEvent) (string, error) {
|
||||
data, err := json.Marshal(evt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("event: %s\ndata: %s\n\n", evt.Type, data), nil
|
||||
}
|
||||
|
||||
// --- internal handlers ---
|
||||
|
||||
func resToAnthHandleCreated(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if evt.Response != nil {
|
||||
state.ResponseID = evt.Response.ID
|
||||
// Only use upstream model if no override was set (e.g. originalModel)
|
||||
if state.Model == "" {
|
||||
state.Model = evt.Response.Model
|
||||
}
|
||||
}
|
||||
|
||||
if state.MessageStartSent {
|
||||
return nil
|
||||
}
|
||||
state.MessageStartSent = true
|
||||
|
||||
return []AnthropicStreamEvent{{
|
||||
Type: "message_start",
|
||||
Message: &AnthropicResponse{
|
||||
ID: state.ResponseID,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Content: []AnthropicContentBlock{},
|
||||
Model: state.Model,
|
||||
Usage: AnthropicUsage{
|
||||
InputTokens: 0,
|
||||
OutputTokens: 0,
|
||||
},
|
||||
},
|
||||
}}
|
||||
}
|
||||
|
||||
func resToAnthHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if evt.Item == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch evt.Item.Type {
|
||||
case "function_call":
|
||||
var events []AnthropicStreamEvent
|
||||
events = append(events, closeCurrentBlock(state)...)
|
||||
|
||||
idx := state.ContentBlockIndex
|
||||
state.OutputIndexToBlockIdx[evt.OutputIndex] = idx
|
||||
state.ContentBlockOpen = true
|
||||
state.CurrentBlockType = "tool_use"
|
||||
|
||||
events = append(events, AnthropicStreamEvent{
|
||||
Type: "content_block_start",
|
||||
Index: &idx,
|
||||
ContentBlock: &AnthropicContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: fromResponsesCallID(evt.Item.CallID),
|
||||
Name: evt.Item.Name,
|
||||
Input: json.RawMessage("{}"),
|
||||
},
|
||||
})
|
||||
return events
|
||||
|
||||
case "reasoning":
|
||||
var events []AnthropicStreamEvent
|
||||
events = append(events, closeCurrentBlock(state)...)
|
||||
|
||||
idx := state.ContentBlockIndex
|
||||
state.OutputIndexToBlockIdx[evt.OutputIndex] = idx
|
||||
state.ContentBlockOpen = true
|
||||
state.CurrentBlockType = "thinking"
|
||||
|
||||
events = append(events, AnthropicStreamEvent{
|
||||
Type: "content_block_start",
|
||||
Index: &idx,
|
||||
ContentBlock: &AnthropicContentBlock{
|
||||
Type: "thinking",
|
||||
Thinking: "",
|
||||
},
|
||||
})
|
||||
return events
|
||||
|
||||
case "message":
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func resToAnthHandleTextDelta(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if evt.Delta == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var events []AnthropicStreamEvent
|
||||
|
||||
if !state.ContentBlockOpen || state.CurrentBlockType != "text" {
|
||||
events = append(events, closeCurrentBlock(state)...)
|
||||
|
||||
idx := state.ContentBlockIndex
|
||||
state.ContentBlockOpen = true
|
||||
state.CurrentBlockType = "text"
|
||||
|
||||
events = append(events, AnthropicStreamEvent{
|
||||
Type: "content_block_start",
|
||||
Index: &idx,
|
||||
ContentBlock: &AnthropicContentBlock{
|
||||
Type: "text",
|
||||
Text: "",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
idx := state.ContentBlockIndex
|
||||
events = append(events, AnthropicStreamEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: &idx,
|
||||
Delta: &AnthropicDelta{
|
||||
Type: "text_delta",
|
||||
Text: evt.Delta,
|
||||
},
|
||||
})
|
||||
return events
|
||||
}
|
||||
|
||||
func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if evt.Delta == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
blockIdx, ok := state.OutputIndexToBlockIdx[evt.OutputIndex]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return []AnthropicStreamEvent{{
|
||||
Type: "content_block_delta",
|
||||
Index: &blockIdx,
|
||||
Delta: &AnthropicDelta{
|
||||
Type: "input_json_delta",
|
||||
PartialJSON: evt.Delta,
|
||||
},
|
||||
}}
|
||||
}
|
||||
|
||||
func resToAnthHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if evt.Delta == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
blockIdx, ok := state.OutputIndexToBlockIdx[evt.OutputIndex]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return []AnthropicStreamEvent{{
|
||||
Type: "content_block_delta",
|
||||
Index: &blockIdx,
|
||||
Delta: &AnthropicDelta{
|
||||
Type: "thinking_delta",
|
||||
Thinking: evt.Delta,
|
||||
},
|
||||
}}
|
||||
}
|
||||
|
||||
func resToAnthHandleBlockDone(state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if !state.ContentBlockOpen {
|
||||
return nil
|
||||
}
|
||||
return closeCurrentBlock(state)
|
||||
}
|
||||
|
||||
func resToAnthHandleOutputItemDone(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if evt.Item == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle web_search_call → synthesize server_tool_use + web_search_tool_result blocks.
|
||||
if evt.Item.Type == "web_search_call" && evt.Item.Status == "completed" {
|
||||
return resToAnthHandleWebSearchDone(evt, state)
|
||||
}
|
||||
|
||||
if state.ContentBlockOpen {
|
||||
return closeCurrentBlock(state)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// resToAnthHandleWebSearchDone converts an OpenAI web_search_call output item
|
||||
// into Anthropic server_tool_use + web_search_tool_result content block pairs.
|
||||
// This allows Claude Code to count the searches performed.
|
||||
func resToAnthHandleWebSearchDone(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
var events []AnthropicStreamEvent
|
||||
events = append(events, closeCurrentBlock(state)...)
|
||||
|
||||
toolUseID := "srvtoolu_" + evt.Item.ID
|
||||
query := ""
|
||||
if evt.Item.Action != nil {
|
||||
query = evt.Item.Action.Query
|
||||
}
|
||||
inputJSON, _ := json.Marshal(map[string]string{"query": query})
|
||||
|
||||
// Emit server_tool_use block (start + stop).
|
||||
idx1 := state.ContentBlockIndex
|
||||
events = append(events, AnthropicStreamEvent{
|
||||
Type: "content_block_start",
|
||||
Index: &idx1,
|
||||
ContentBlock: &AnthropicContentBlock{
|
||||
Type: "server_tool_use",
|
||||
ID: toolUseID,
|
||||
Name: "web_search",
|
||||
Input: inputJSON,
|
||||
},
|
||||
})
|
||||
events = append(events, AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: &idx1,
|
||||
})
|
||||
state.ContentBlockIndex++
|
||||
|
||||
// Emit web_search_tool_result block (start + stop).
|
||||
// Content is empty because OpenAI does not expose individual search results;
|
||||
// the model consumes them internally and produces text output.
|
||||
emptyResults, _ := json.Marshal([]struct{}{})
|
||||
idx2 := state.ContentBlockIndex
|
||||
events = append(events, AnthropicStreamEvent{
|
||||
Type: "content_block_start",
|
||||
Index: &idx2,
|
||||
ContentBlock: &AnthropicContentBlock{
|
||||
Type: "web_search_tool_result",
|
||||
ToolUseID: toolUseID,
|
||||
Content: emptyResults,
|
||||
},
|
||||
})
|
||||
events = append(events, AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: &idx2,
|
||||
})
|
||||
state.ContentBlockIndex++
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
func resToAnthHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if state.MessageStopSent {
|
||||
return nil
|
||||
}
|
||||
|
||||
var events []AnthropicStreamEvent
|
||||
events = append(events, closeCurrentBlock(state)...)
|
||||
|
||||
stopReason := "end_turn"
|
||||
if evt.Response != nil {
|
||||
if evt.Response.Usage != nil {
|
||||
state.InputTokens = evt.Response.Usage.InputTokens
|
||||
state.OutputTokens = evt.Response.Usage.OutputTokens
|
||||
if evt.Response.Usage.InputTokensDetails != nil {
|
||||
state.CacheReadInputTokens = evt.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
switch evt.Response.Status {
|
||||
case "incomplete":
|
||||
if evt.Response.IncompleteDetails != nil && evt.Response.IncompleteDetails.Reason == "max_output_tokens" {
|
||||
stopReason = "max_tokens"
|
||||
}
|
||||
case "completed":
|
||||
if state.ContentBlockIndex > 0 && state.CurrentBlockType == "tool_use" {
|
||||
stopReason = "tool_use"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
events = append(events,
|
||||
AnthropicStreamEvent{
|
||||
Type: "message_delta",
|
||||
Delta: &AnthropicDelta{
|
||||
StopReason: stopReason,
|
||||
},
|
||||
Usage: &AnthropicUsage{
|
||||
InputTokens: state.InputTokens,
|
||||
OutputTokens: state.OutputTokens,
|
||||
CacheReadInputTokens: state.CacheReadInputTokens,
|
||||
},
|
||||
},
|
||||
AnthropicStreamEvent{Type: "message_stop"},
|
||||
)
|
||||
state.MessageStopSent = true
|
||||
return events
|
||||
}
|
||||
|
||||
func closeCurrentBlock(state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if !state.ContentBlockOpen {
|
||||
return nil
|
||||
}
|
||||
idx := state.ContentBlockIndex
|
||||
state.ContentBlockOpen = false
|
||||
state.ContentBlockIndex++
|
||||
return []AnthropicStreamEvent{{
|
||||
Type: "content_block_stop",
|
||||
Index: &idx,
|
||||
}}
|
||||
}
|
||||
338
backend/internal/pkg/apicompat/types.go
Normal file
338
backend/internal/pkg/apicompat/types.go
Normal file
@@ -0,0 +1,338 @@
|
||||
// Package apicompat provides type definitions and conversion utilities for
|
||||
// translating between Anthropic Messages and OpenAI Responses API formats.
|
||||
// It enables multi-protocol support so that clients using different API
|
||||
// formats can be served through a unified gateway.
|
||||
package apicompat
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Anthropic Messages API types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// AnthropicRequest is the request body for POST /v1/messages.
|
||||
type AnthropicRequest struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
System json.RawMessage `json:"system,omitempty"` // string or []AnthropicContentBlock
|
||||
Messages []AnthropicMessage `json:"messages"`
|
||||
Tools []AnthropicTool `json:"tools,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
StopSeqs []string `json:"stop_sequences,omitempty"`
|
||||
Thinking *AnthropicThinking `json:"thinking,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
OutputConfig *AnthropicOutputConfig `json:"output_config,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicOutputConfig controls output generation parameters.
|
||||
type AnthropicOutputConfig struct {
|
||||
Effort string `json:"effort,omitempty"` // "low" | "medium" | "high"
|
||||
}
|
||||
|
||||
// AnthropicThinking configures extended thinking in the Anthropic API.
|
||||
type AnthropicThinking struct {
|
||||
Type string `json:"type"` // "enabled" | "adaptive" | "disabled"
|
||||
BudgetTokens int `json:"budget_tokens,omitempty"` // max thinking tokens
|
||||
}
|
||||
|
||||
// AnthropicMessage is a single message in the Anthropic conversation.
|
||||
type AnthropicMessage struct {
|
||||
Role string `json:"role"` // "user" | "assistant"
|
||||
Content json.RawMessage `json:"content"`
|
||||
}
|
||||
|
||||
// AnthropicContentBlock is one block inside a message's content array.
|
||||
type AnthropicContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
|
||||
// type=text
|
||||
Text string `json:"text,omitempty"`
|
||||
|
||||
// type=thinking
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
|
||||
// type=image
|
||||
Source *AnthropicImageSource `json:"source,omitempty"`
|
||||
|
||||
// type=tool_use
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input json.RawMessage `json:"input,omitempty"`
|
||||
|
||||
// type=tool_result
|
||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||
Content json.RawMessage `json:"content,omitempty"` // string or []AnthropicContentBlock
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicImageSource describes the source data for an image content block.
|
||||
type AnthropicImageSource struct {
|
||||
Type string `json:"type"` // "base64"
|
||||
MediaType string `json:"media_type"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// AnthropicTool describes a tool available to the model.
|
||||
type AnthropicTool struct {
|
||||
Type string `json:"type,omitempty"` // e.g. "web_search_20250305" for server tools
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema json.RawMessage `json:"input_schema"` // JSON Schema object
|
||||
}
|
||||
|
||||
// AnthropicResponse is the non-streaming response from POST /v1/messages.
|
||||
type AnthropicResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"` // "message"
|
||||
Role string `json:"role"` // "assistant"
|
||||
Content []AnthropicContentBlock `json:"content"`
|
||||
Model string `json:"model"`
|
||||
StopReason string `json:"stop_reason"`
|
||||
StopSequence *string `json:"stop_sequence,omitempty"`
|
||||
Usage AnthropicUsage `json:"usage"`
|
||||
}
|
||||
|
||||
// AnthropicUsage holds token counts in Anthropic format.
|
||||
type AnthropicUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Anthropic SSE event types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// AnthropicStreamEvent is a single SSE event in the Anthropic streaming protocol.
|
||||
type AnthropicStreamEvent struct {
|
||||
Type string `json:"type"`
|
||||
|
||||
// message_start
|
||||
Message *AnthropicResponse `json:"message,omitempty"`
|
||||
|
||||
// content_block_start
|
||||
Index *int `json:"index,omitempty"`
|
||||
ContentBlock *AnthropicContentBlock `json:"content_block,omitempty"`
|
||||
|
||||
// content_block_delta
|
||||
Delta *AnthropicDelta `json:"delta,omitempty"`
|
||||
|
||||
// message_delta
|
||||
Usage *AnthropicUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicDelta carries incremental content in streaming events.
|
||||
type AnthropicDelta struct {
|
||||
Type string `json:"type,omitempty"` // "text_delta" | "input_json_delta" | "thinking_delta" | "signature_delta"
|
||||
|
||||
// text_delta
|
||||
Text string `json:"text,omitempty"`
|
||||
|
||||
// input_json_delta
|
||||
PartialJSON string `json:"partial_json,omitempty"`
|
||||
|
||||
// thinking_delta
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
|
||||
// signature_delta
|
||||
Signature string `json:"signature,omitempty"`
|
||||
|
||||
// message_delta fields
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
StopSequence *string `json:"stop_sequence,omitempty"`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// OpenAI Responses API types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ResponsesRequest is the request body for POST /v1/responses.
|
||||
type ResponsesRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input json.RawMessage `json:"input"` // string or []ResponsesInputItem
|
||||
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Tools []ResponsesTool `json:"tools,omitempty"`
|
||||
Include []string `json:"include,omitempty"`
|
||||
Store *bool `json:"store,omitempty"`
|
||||
Reasoning *ResponsesReasoning `json:"reasoning,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesReasoning configures reasoning effort in the Responses API.
|
||||
type ResponsesReasoning struct {
|
||||
Effort string `json:"effort"` // "low" | "medium" | "high"
|
||||
Summary string `json:"summary,omitempty"` // "auto" | "concise" | "detailed"
|
||||
}
|
||||
|
||||
// ResponsesInputItem is one item in the Responses API input array.
|
||||
// The Type field determines which other fields are populated.
|
||||
type ResponsesInputItem struct {
|
||||
// Common
|
||||
Type string `json:"type,omitempty"` // "" for role-based messages
|
||||
|
||||
// Role-based messages (system/user/assistant)
|
||||
Role string `json:"role,omitempty"`
|
||||
Content json.RawMessage `json:"content,omitempty"` // string or []ResponsesContentPart
|
||||
|
||||
// type=function_call
|
||||
CallID string `json:"call_id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
|
||||
// type=function_call_output
|
||||
Output string `json:"output,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesContentPart is a typed content part in a Responses message.
|
||||
type ResponsesContentPart struct {
|
||||
Type string `json:"type"` // "input_text" | "output_text" | "input_image"
|
||||
Text string `json:"text,omitempty"`
|
||||
ImageURL string `json:"image_url,omitempty"` // data URI for input_image
|
||||
}
|
||||
|
||||
// ResponsesTool describes a tool in the Responses API.
|
||||
type ResponsesTool struct {
|
||||
Type string `json:"type"` // "function" | "web_search" | "local_shell" etc.
|
||||
Name string `json:"name,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Parameters json.RawMessage `json:"parameters,omitempty"`
|
||||
Strict *bool `json:"strict,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesResponse is the non-streaming response from POST /v1/responses.
|
||||
type ResponsesResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"` // "response"
|
||||
Model string `json:"model"`
|
||||
Status string `json:"status"` // "completed" | "incomplete" | "failed"
|
||||
Output []ResponsesOutput `json:"output"`
|
||||
Usage *ResponsesUsage `json:"usage,omitempty"`
|
||||
|
||||
// incomplete_details is present when status="incomplete"
|
||||
IncompleteDetails *ResponsesIncompleteDetails `json:"incomplete_details,omitempty"`
|
||||
|
||||
// Error is present when status="failed"
|
||||
Error *ResponsesError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesError describes an error in a failed response.
|
||||
type ResponsesError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ResponsesIncompleteDetails explains why a response is incomplete.
|
||||
type ResponsesIncompleteDetails struct {
|
||||
Reason string `json:"reason"` // "max_output_tokens" | "content_filter"
|
||||
}
|
||||
|
||||
// ResponsesOutput is one output item in a Responses API response.
|
||||
type ResponsesOutput struct {
|
||||
Type string `json:"type"` // "message" | "reasoning" | "function_call" | "web_search_call"
|
||||
|
||||
// type=message
|
||||
ID string `json:"id,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Content []ResponsesContentPart `json:"content,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
|
||||
// type=reasoning
|
||||
EncryptedContent string `json:"encrypted_content,omitempty"`
|
||||
Summary []ResponsesSummary `json:"summary,omitempty"`
|
||||
|
||||
// type=function_call
|
||||
CallID string `json:"call_id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
|
||||
// type=web_search_call
|
||||
Action *WebSearchAction `json:"action,omitempty"`
|
||||
}
|
||||
|
||||
// WebSearchAction describes the search action in a web_search_call output item.
|
||||
type WebSearchAction struct {
|
||||
Type string `json:"type,omitempty"` // "search"
|
||||
Query string `json:"query,omitempty"` // primary search query
|
||||
}
|
||||
|
||||
// ResponsesSummary is a summary text block inside a reasoning output.
|
||||
type ResponsesSummary struct {
|
||||
Type string `json:"type"` // "summary_text"
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// ResponsesUsage holds token counts in Responses API format.
|
||||
type ResponsesUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
|
||||
// Optional detailed breakdown
|
||||
InputTokensDetails *ResponsesInputTokensDetails `json:"input_tokens_details,omitempty"`
|
||||
OutputTokensDetails *ResponsesOutputTokensDetails `json:"output_tokens_details,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesInputTokensDetails breaks down input token usage.
|
||||
type ResponsesInputTokensDetails struct {
|
||||
CachedTokens int `json:"cached_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesOutputTokensDetails breaks down output token usage.
|
||||
type ResponsesOutputTokensDetails struct {
|
||||
ReasoningTokens int `json:"reasoning_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Responses SSE event types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ResponsesStreamEvent is a single SSE event in the Responses streaming protocol.
|
||||
// The Type field corresponds to the "type" in the JSON payload.
|
||||
type ResponsesStreamEvent struct {
|
||||
Type string `json:"type"`
|
||||
|
||||
// response.created / response.completed / response.failed / response.incomplete
|
||||
Response *ResponsesResponse `json:"response,omitempty"`
|
||||
|
||||
// response.output_item.added / response.output_item.done
|
||||
Item *ResponsesOutput `json:"item,omitempty"`
|
||||
|
||||
// response.output_text.delta / response.output_text.done
|
||||
OutputIndex int `json:"output_index,omitempty"`
|
||||
ContentIndex int `json:"content_index,omitempty"`
|
||||
Delta string `json:"delta,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ItemID string `json:"item_id,omitempty"`
|
||||
|
||||
// response.function_call_arguments.delta / done
|
||||
CallID string `json:"call_id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
|
||||
// response.reasoning_summary_text.delta / done
|
||||
// Reuses Text/Delta fields above, SummaryIndex identifies which summary part
|
||||
SummaryIndex int `json:"summary_index,omitempty"`
|
||||
|
||||
// error event fields
|
||||
Code string `json:"code,omitempty"`
|
||||
Param string `json:"param,omitempty"`
|
||||
|
||||
// Sequence number for ordering events
|
||||
SequenceNumber int `json:"sequence_number,omitempty"`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared constants
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// minMaxOutputTokens is the floor for max_output_tokens in a Responses request.
|
||||
// Very small values may cause upstream API errors, so we enforce a minimum.
|
||||
const minMaxOutputTokens = 128
|
||||
@@ -16,7 +16,7 @@ const (
|
||||
|
||||
// DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。
|
||||
// 这些 token 是客户端特有的,不应透传给上游 API。
|
||||
var DroppedBetas = []string{BetaContext1M, BetaFastMode}
|
||||
var DroppedBetas = []string{}
|
||||
|
||||
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
||||
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||
|
||||
@@ -15,6 +15,7 @@ type Model struct {
|
||||
|
||||
// DefaultModels OpenAI models list
|
||||
var DefaultModels = []Model{
|
||||
{ID: "gpt-5.4", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4"},
|
||||
{ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"},
|
||||
{ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"},
|
||||
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
|
||||
|
||||
@@ -268,6 +268,7 @@ type IDTokenClaims struct {
|
||||
type OpenAIAuthClaims struct {
|
||||
ChatGPTAccountID string `json:"chatgpt_account_id"`
|
||||
ChatGPTUserID string `json:"chatgpt_user_id"`
|
||||
ChatGPTPlanType string `json:"chatgpt_plan_type"`
|
||||
UserID string `json:"user_id"`
|
||||
Organizations []OrganizationClaim `json:"organizations"`
|
||||
}
|
||||
@@ -325,12 +326,9 @@ func (r *RefreshTokenRequest) ToFormData() string {
|
||||
return params.Encode()
|
||||
}
|
||||
|
||||
// ParseIDToken parses the ID Token JWT and extracts claims.
|
||||
// 注意:当前仅解码 payload 并校验 exp,未验证 JWT 签名。
|
||||
// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名:
|
||||
//
|
||||
// https://auth.openai.com/.well-known/jwks.json
|
||||
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
// DecodeIDToken decodes the ID Token JWT payload without validating expiration.
|
||||
// Use this for best-effort extraction (e.g., during data import) where the token may be expired.
|
||||
func DecodeIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
parts := strings.Split(idToken, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
|
||||
@@ -360,6 +358,20 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
// ParseIDToken parses the ID Token JWT and extracts claims.
|
||||
// 注意:当前仅解码 payload 并校验 exp,未验证 JWT 签名。
|
||||
// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名:
|
||||
//
|
||||
// https://auth.openai.com/.well-known/jwks.json
|
||||
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
claims, err := DecodeIDToken(idToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 校验 ID Token 是否已过期(允许 2 分钟时钟偏差,防止因服务器时钟略有差异误判刚颁发的令牌)
|
||||
const clockSkewTolerance = 120 // 秒
|
||||
now := time.Now().Unix()
|
||||
@@ -367,7 +379,7 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
return nil, fmt.Errorf("id_token has expired (exp: %d, now: %d, skew_tolerance: %ds)", claims.Exp, now, clockSkewTolerance)
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// UserInfo represents user information extracted from ID Token claims.
|
||||
@@ -375,6 +387,7 @@ type UserInfo struct {
|
||||
Email string
|
||||
ChatGPTAccountID string
|
||||
ChatGPTUserID string
|
||||
PlanType string
|
||||
UserID string
|
||||
OrganizationID string
|
||||
Organizations []OrganizationClaim
|
||||
@@ -389,6 +402,7 @@ func (c *IDTokenClaims) GetUserInfo() *UserInfo {
|
||||
if c.OpenAIAuth != nil {
|
||||
info.ChatGPTAccountID = c.OpenAIAuth.ChatGPTAccountID
|
||||
info.ChatGPTUserID = c.OpenAIAuth.ChatGPTUserID
|
||||
info.PlanType = c.OpenAIAuth.ChatGPTPlanType
|
||||
info.UserID = c.OpenAIAuth.UserID
|
||||
info.Organizations = c.OpenAIAuth.Organizations
|
||||
|
||||
|
||||
@@ -58,6 +58,12 @@ func IsCodexOfficialClientOriginator(originator string) bool {
|
||||
return matchCodexClientHeaderPrefixes(v, CodexOfficialClientOriginatorPrefixes)
|
||||
}
|
||||
|
||||
// IsCodexOfficialClientByHeaders checks whether the request headers indicate an
|
||||
// official Codex client family request.
|
||||
func IsCodexOfficialClientByHeaders(userAgent, originator string) bool {
|
||||
return IsCodexOfficialClientRequest(userAgent) || IsCodexOfficialClientOriginator(originator)
|
||||
}
|
||||
|
||||
func normalizeCodexClientHeader(value string) string {
|
||||
return strings.ToLower(strings.TrimSpace(value))
|
||||
}
|
||||
|
||||
@@ -85,3 +85,26 @@ func TestIsCodexOfficialClientOriginator(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsCodexOfficialClientByHeaders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ua string
|
||||
originator string
|
||||
want bool
|
||||
}{
|
||||
{name: "仅 originator 命中 desktop", originator: "Codex Desktop", want: true},
|
||||
{name: "仅 originator 命中 vscode", originator: "codex_vscode", want: true},
|
||||
{name: "仅 ua 命中 desktop", ua: "Codex Desktop/1.2.3", want: true},
|
||||
{name: "ua 与 originator 都未命中", ua: "curl/8.0.1", originator: "my_client", want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsCodexOfficialClientByHeaders(tt.ua, tt.originator)
|
||||
if got != tt.want {
|
||||
t.Fatalf("IsCodexOfficialClientByHeaders(%q, %q) = %v, want %v", tt.ua, tt.originator, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,25 +57,28 @@ type DashboardStats struct {
|
||||
|
||||
// TrendDataPoint represents a single point in trend data
|
||||
type TrendDataPoint struct {
|
||||
Date string `json:"date"`
|
||||
Requests int64 `json:"requests"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
CacheTokens int64 `json:"cache_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
Date string `json:"date"`
|
||||
Requests int64 `json:"requests"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
CacheCreationTokens int64 `json:"cache_creation_tokens"`
|
||||
CacheReadTokens int64 `json:"cache_read_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// ModelStat represents usage statistics for a single model
|
||||
type ModelStat struct {
|
||||
Model string `json:"model"`
|
||||
Requests int64 `json:"requests"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
Model string `json:"model"`
|
||||
Requests int64 `json:"requests"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
CacheCreationTokens int64 `json:"cache_creation_tokens"`
|
||||
CacheReadTokens int64 `json:"cache_read_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// GroupStat represents usage statistics for a single group
|
||||
|
||||
@@ -84,6 +84,9 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
|
||||
if account.RateMultiplier != nil {
|
||||
builder.SetRateMultiplier(*account.RateMultiplier)
|
||||
}
|
||||
if account.LoadFactor != nil {
|
||||
builder.SetLoadFactor(*account.LoadFactor)
|
||||
}
|
||||
|
||||
if account.ProxyID != nil {
|
||||
builder.SetProxyID(*account.ProxyID)
|
||||
@@ -318,6 +321,11 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
||||
if account.RateMultiplier != nil {
|
||||
builder.SetRateMultiplier(*account.RateMultiplier)
|
||||
}
|
||||
if account.LoadFactor != nil {
|
||||
builder.SetLoadFactor(*account.LoadFactor)
|
||||
} else {
|
||||
builder.ClearLoadFactor()
|
||||
}
|
||||
|
||||
if account.ProxyID != nil {
|
||||
builder.SetProxyID(*account.ProxyID)
|
||||
@@ -437,6 +445,14 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
|
||||
switch status {
|
||||
case "rate_limited":
|
||||
q = q.Where(dbaccount.RateLimitResetAtGT(time.Now()))
|
||||
case "temp_unschedulable":
|
||||
q = q.Where(dbpredicate.Account(func(s *entsql.Selector) {
|
||||
col := s.C("temp_unschedulable_until")
|
||||
s.Where(entsql.And(
|
||||
entsql.Not(entsql.IsNull(col)),
|
||||
entsql.GT(col, entsql.Expr("NOW()")),
|
||||
))
|
||||
}))
|
||||
default:
|
||||
q = q.Where(dbaccount.StatusEQ(status))
|
||||
}
|
||||
@@ -640,7 +656,14 @@ func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
|
||||
SetStatus(service.StatusActive).
|
||||
SetErrorMessage("").
|
||||
Save(ctx)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear error failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
|
||||
@@ -899,6 +922,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1014,6 +1038,7 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1205,6 +1230,15 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
|
||||
args = append(args, *updates.RateMultiplier)
|
||||
idx++
|
||||
}
|
||||
if updates.LoadFactor != nil {
|
||||
if *updates.LoadFactor <= 0 {
|
||||
setClauses = append(setClauses, "load_factor = NULL")
|
||||
} else {
|
||||
setClauses = append(setClauses, "load_factor = $"+itoa(idx))
|
||||
args = append(args, *updates.LoadFactor)
|
||||
idx++
|
||||
}
|
||||
}
|
||||
if updates.Status != nil {
|
||||
setClauses = append(setClauses, "status = $"+itoa(idx))
|
||||
args = append(args, *updates.Status)
|
||||
@@ -1527,6 +1561,7 @@ func accountEntityToService(m *dbent.Account) *service.Account {
|
||||
Concurrency: m.Concurrency,
|
||||
Priority: m.Priority,
|
||||
RateMultiplier: &rateMultiplier,
|
||||
LoadFactor: m.LoadFactor,
|
||||
Status: m.Status,
|
||||
ErrorMessage: derefString(m.ErrorMessage),
|
||||
LastUsedAt: m.LastUsedAt,
|
||||
@@ -1639,3 +1674,93 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va
|
||||
|
||||
return r.accountsToService(ctx, accounts)
|
||||
}
|
||||
|
||||
// nowUTC is a SQL expression to generate a UTC RFC3339 timestamp string.
|
||||
const nowUTC = `to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"')`
|
||||
|
||||
// IncrementQuotaUsed 原子递增账号的配额用量(总/日/周三个维度)
|
||||
// 日/周额度在周期过期时自动重置为 0 再递增。
|
||||
func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
rows, err := r.sql.QueryContext(ctx,
|
||||
`UPDATE accounts SET extra = (
|
||||
COALESCE(extra, '{}'::jsonb)
|
||||
-- 总额度:始终递增
|
||||
|| jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1)
|
||||
-- 日额度:仅在 quota_daily_limit > 0 时处理
|
||||
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_daily_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
|
||||
'quota_daily_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
|
||||
)
|
||||
ELSE '{}'::jsonb END
|
||||
-- 周额度:仅在 quota_weekly_limit > 0 时处理
|
||||
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_weekly_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
|
||||
'quota_weekly_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
|
||||
)
|
||||
ELSE '{}'::jsonb END
|
||||
), updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
RETURNING
|
||||
COALESCE((extra->>'quota_used')::numeric, 0),
|
||||
COALESCE((extra->>'quota_limit')::numeric, 0)`,
|
||||
amount, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var newUsed, limit float64
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&newUsed, &limit); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 任一维度配额刚超限时触发调度快照刷新
|
||||
if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", id, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetQuotaUsed 重置账号所有维度的配额用量为 0
|
||||
func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
_, err := r.sql.ExecContext(ctx,
|
||||
`UPDATE accounts SET extra = (
|
||||
COALESCE(extra, '{}'::jsonb)
|
||||
|| '{"quota_used": 0, "quota_daily_used": 0, "quota_weekly_used": 0}'::jsonb
|
||||
) - 'quota_daily_start' - 'quota_weekly_start', updated_at = NOW()
|
||||
WHERE id = $1 AND deleted_at IS NULL`,
|
||||
id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 重置配额后触发调度快照刷新,使账号重新参与调度
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue quota reset failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -558,6 +558,26 @@ func (s *AccountRepoSuite) TestSetError() {
|
||||
s.Require().Equal("something went wrong", got.ErrorMessage)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestClearError_SyncSchedulerSnapshotOnRecovery() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc-clear-err",
|
||||
Status: service.StatusError,
|
||||
ErrorMessage: "temporary error",
|
||||
})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
s.Require().NoError(s.repo.ClearError(s.ctx, account.ID))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(service.StatusActive, got.Status)
|
||||
s.Require().Empty(got.ErrorMessage)
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
s.Require().Equal(service.StatusActive, cacheRecorder.setAccounts[0].Status)
|
||||
}
|
||||
|
||||
// --- UpdateSessionWindow ---
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateSessionWindow() {
|
||||
|
||||
@@ -24,6 +24,7 @@ func (r *announcementRepository) Create(ctx context.Context, a *service.Announce
|
||||
SetTitle(a.Title).
|
||||
SetContent(a.Content).
|
||||
SetStatus(a.Status).
|
||||
SetNotifyMode(a.NotifyMode).
|
||||
SetTargeting(a.Targeting)
|
||||
|
||||
if a.StartsAt != nil {
|
||||
@@ -64,6 +65,7 @@ func (r *announcementRepository) Update(ctx context.Context, a *service.Announce
|
||||
SetTitle(a.Title).
|
||||
SetContent(a.Content).
|
||||
SetStatus(a.Status).
|
||||
SetNotifyMode(a.NotifyMode).
|
||||
SetTargeting(a.Targeting)
|
||||
|
||||
if a.StartsAt != nil {
|
||||
@@ -169,17 +171,18 @@ func announcementEntityToService(m *dbent.Announcement) *service.Announcement {
|
||||
return nil
|
||||
}
|
||||
return &service.Announcement{
|
||||
ID: m.ID,
|
||||
Title: m.Title,
|
||||
Content: m.Content,
|
||||
Status: m.Status,
|
||||
Targeting: m.Targeting,
|
||||
StartsAt: m.StartsAt,
|
||||
EndsAt: m.EndsAt,
|
||||
CreatedBy: m.CreatedBy,
|
||||
UpdatedBy: m.UpdatedBy,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
ID: m.ID,
|
||||
Title: m.Title,
|
||||
Content: m.Content,
|
||||
Status: m.Status,
|
||||
NotifyMode: m.NotifyMode,
|
||||
Targeting: m.Targeting,
|
||||
StartsAt: m.StartsAt,
|
||||
EndsAt: m.EndsAt,
|
||||
CreatedBy: m.CreatedBy,
|
||||
UpdatedBy: m.UpdatedBy,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -165,6 +165,8 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
group.FieldModelRouting,
|
||||
group.FieldMcpXMLInject,
|
||||
group.FieldSupportedModelScopes,
|
||||
group.FieldAllowMessagesDispatch,
|
||||
group.FieldDefaultMappedModel,
|
||||
)
|
||||
}).
|
||||
Only(ctx)
|
||||
@@ -470,12 +472,12 @@ func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt
|
||||
func (r *apiKeyRepository) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE api_keys SET
|
||||
usage_5h = usage_5h + $1,
|
||||
usage_1d = usage_1d + $1,
|
||||
usage_7d = usage_7d + $1,
|
||||
window_5h_start = COALESCE(window_5h_start, NOW()),
|
||||
window_1d_start = COALESCE(window_1d_start, NOW()),
|
||||
window_7d_start = COALESCE(window_7d_start, NOW()),
|
||||
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END,
|
||||
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END,
|
||||
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END,
|
||||
window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL`,
|
||||
cost, id)
|
||||
@@ -489,9 +491,9 @@ func (r *apiKeyRepository) ResetRateLimitWindows(ctx context.Context, id int64)
|
||||
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN 0 ELSE usage_5h END,
|
||||
window_5h_start = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
|
||||
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN 0 ELSE usage_1d END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN NOW() ELSE window_1d_start END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
|
||||
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN 0 ELSE usage_7d END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN NOW() ELSE window_7d_start END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1 AND deleted_at IS NULL`,
|
||||
id)
|
||||
@@ -619,6 +621,8 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
MCPXMLInject: g.McpXMLInject,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
SortOrder: g.SortOrder,
|
||||
AllowMessagesDispatch: g.AllowMessagesDispatch,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
@@ -95,7 +96,8 @@ func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *se
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
|
||||
msg := fmt.Sprintf("API returned status %d: %s", resp.StatusCode, string(body))
|
||||
return nil, infraerrors.New(http.StatusInternalServerError, "UPSTREAM_ERROR", msg)
|
||||
}
|
||||
|
||||
var usageResp service.ClaudeUsageResponse
|
||||
|
||||
@@ -147,17 +147,47 @@ var (
|
||||
return 1
|
||||
`)
|
||||
|
||||
// cleanupExpiredSlotsScript - remove expired slots
|
||||
// KEYS[1] = concurrency:account:{accountID}
|
||||
// ARGV[1] = TTL (seconds)
|
||||
// cleanupExpiredSlotsScript 清理单个账号/用户有序集合中过期槽位
|
||||
// KEYS[1] = 有序集合键
|
||||
// ARGV[1] = TTL(秒)
|
||||
cleanupExpiredSlotsScript = redis.NewScript(`
|
||||
local key = KEYS[1]
|
||||
local ttl = tonumber(ARGV[1])
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
local expireBefore = now - ttl
|
||||
return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||||
`)
|
||||
local key = KEYS[1]
|
||||
local ttl = tonumber(ARGV[1])
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
local expireBefore = now - ttl
|
||||
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||||
if redis.call('ZCARD', key) == 0 then
|
||||
redis.call('DEL', key)
|
||||
else
|
||||
redis.call('EXPIRE', key, ttl)
|
||||
end
|
||||
return 1
|
||||
`)
|
||||
|
||||
// startupCleanupScript 清理非当前进程前缀的槽位成员。
|
||||
// KEYS 是有序集合键列表,ARGV[1] 是当前进程前缀,ARGV[2] 是槽位 TTL。
|
||||
// 遍历每个 KEYS[i],移除前缀不匹配的成员,清空后删 key,否则刷新 EXPIRE。
|
||||
startupCleanupScript = redis.NewScript(`
|
||||
local activePrefix = ARGV[1]
|
||||
local slotTTL = tonumber(ARGV[2])
|
||||
local removed = 0
|
||||
for i = 1, #KEYS do
|
||||
local key = KEYS[i]
|
||||
local members = redis.call('ZRANGE', key, 0, -1)
|
||||
for _, member in ipairs(members) do
|
||||
if string.sub(member, 1, string.len(activePrefix)) ~= activePrefix then
|
||||
removed = removed + redis.call('ZREM', key, member)
|
||||
end
|
||||
end
|
||||
if redis.call('ZCARD', key) == 0 then
|
||||
redis.call('DEL', key)
|
||||
else
|
||||
redis.call('EXPIRE', key, slotTTL)
|
||||
end
|
||||
end
|
||||
return removed
|
||||
`)
|
||||
)
|
||||
|
||||
type concurrencyCache struct {
|
||||
@@ -463,3 +493,72 @@ func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accou
|
||||
_, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
|
||||
if activeRequestPrefix == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 1. 清理有序集合中非当前进程前缀的成员
|
||||
slotPatterns := []string{accountSlotKeyPrefix + "*", userSlotKeyPrefix + "*"}
|
||||
for _, pattern := range slotPatterns {
|
||||
if err := c.cleanupSlotsByPattern(ctx, pattern, activeRequestPrefix); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 删除所有等待队列计数器(重启后计数器失效)
|
||||
waitPatterns := []string{accountWaitKeyPrefix + "*", waitQueueKeyPrefix + "*"}
|
||||
for _, pattern := range waitPatterns {
|
||||
if err := c.deleteKeysByPattern(ctx, pattern); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupSlotsByPattern 扫描匹配 pattern 的有序集合键,批量调用 Lua 脚本清理非当前进程成员。
|
||||
func (c *concurrencyCache) cleanupSlotsByPattern(ctx context.Context, pattern, activePrefix string) error {
|
||||
const scanCount = 200
|
||||
var cursor uint64
|
||||
for {
|
||||
keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("scan %s: %w", pattern, err)
|
||||
}
|
||||
if len(keys) > 0 {
|
||||
_, err := startupCleanupScript.Run(ctx, c.rdb, keys, activePrefix, c.slotTTLSeconds).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cleanup slots %s: %w", pattern, err)
|
||||
}
|
||||
}
|
||||
cursor = nextCursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteKeysByPattern 扫描匹配 pattern 的键并删除。
|
||||
func (c *concurrencyCache) deleteKeysByPattern(ctx context.Context, pattern string) error {
|
||||
const scanCount = 200
|
||||
var cursor uint64
|
||||
for {
|
||||
keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("scan %s: %w", pattern, err)
|
||||
}
|
||||
if len(keys) > 0 {
|
||||
if err := c.rdb.Del(ctx, keys...).Err(); err != nil {
|
||||
return fmt.Errorf("del %s: %w", pattern, err)
|
||||
}
|
||||
}
|
||||
cursor = nextCursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -25,6 +25,10 @@ type ConcurrencyCacheSuite struct {
|
||||
cache service.ConcurrencyCache
|
||||
}
|
||||
|
||||
func TestConcurrencyCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(ConcurrencyCacheSuite))
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds()))
|
||||
@@ -247,17 +251,41 @@ func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {
|
||||
require.Equal(s.T(), 1, val, "expected account wait count 1")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() {
|
||||
accountID := int64(301)
|
||||
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots() {
|
||||
accountID := int64(901)
|
||||
userID := int64(902)
|
||||
accountKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
userKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||||
userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
|
||||
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key")
|
||||
now := time.Now().Unix()
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountKey,
|
||||
redis.Z{Score: float64(now), Member: "oldproc-1"},
|
||||
redis.Z{Score: float64(now), Member: "keep-1"},
|
||||
).Err())
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userKey,
|
||||
redis.Z{Score: float64(now), Member: "oldproc-2"},
|
||||
redis.Z{Score: float64(now), Member: "keep-2"},
|
||||
).Err())
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, time.Minute).Err())
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, time.Minute).Err())
|
||||
|
||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey")
|
||||
}
|
||||
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty")
|
||||
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "keep-"))
|
||||
|
||||
accountMembers, err := s.rdb.ZRange(s.ctx, accountKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), []string{"keep-1"}, accountMembers)
|
||||
|
||||
userMembers, err := s.rdb.ZRange(s.ctx, userKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), []string{"keep-2"}, userMembers)
|
||||
|
||||
_, err = s.rdb.Get(s.ctx, userWaitKey).Result()
|
||||
require.True(s.T(), errors.Is(err, redis.Nil))
|
||||
|
||||
_, err = s.rdb.Get(s.ctx, accountWaitKey).Result()
|
||||
require.True(s.T(), errors.Is(err, redis.Nil))
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
|
||||
@@ -407,6 +435,53 @@ func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() {
|
||||
require.Equal(s.T(), 2, cur)
|
||||
}
|
||||
|
||||
func TestConcurrencyCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(ConcurrencyCacheSuite))
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_RemovesOldPrefixesAndWaitCounters() {
|
||||
accountID := int64(901)
|
||||
userID := int64(902)
|
||||
accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
userSlotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||||
userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
|
||||
now := float64(time.Now().Unix())
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey,
|
||||
redis.Z{Score: now, Member: "oldproc-1"},
|
||||
redis.Z{Score: now, Member: "activeproc-1"},
|
||||
).Err())
|
||||
require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err())
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userSlotKey,
|
||||
redis.Z{Score: now, Member: "oldproc-2"},
|
||||
redis.Z{Score: now, Member: "activeproc-2"},
|
||||
).Err())
|
||||
require.NoError(s.T(), s.rdb.Expire(s.ctx, userSlotKey, testSlotTTL).Err())
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, testSlotTTL).Err())
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, testSlotTTL).Err())
|
||||
|
||||
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-"))
|
||||
|
||||
accountMembers, err := s.rdb.ZRange(s.ctx, accountSlotKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), []string{"activeproc-1"}, accountMembers)
|
||||
|
||||
userMembers, err := s.rdb.ZRange(s.ctx, userSlotKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), []string{"activeproc-2"}, userMembers)
|
||||
|
||||
_, err = s.rdb.Get(s.ctx, userWaitKey).Result()
|
||||
require.ErrorIs(s.T(), err, redis.Nil)
|
||||
_, err = s.rdb.Get(s.ctx, accountWaitKey).Result()
|
||||
require.ErrorIs(s.T(), err, redis.Nil)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_DeletesEmptySlotKeys() {
|
||||
accountID := int64(903)
|
||||
accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey, redis.Z{Score: float64(time.Now().Unix()), Member: "oldproc-1"}).Err())
|
||||
require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err())
|
||||
|
||||
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-"))
|
||||
|
||||
exists, err := s.rdb.Exists(s.ctx, accountSlotKey).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.EqualValues(s.T(), 0, exists)
|
||||
}
|
||||
|
||||
@@ -89,6 +89,10 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
|
||||
_ = client.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := ensureSimpleModeAdminConcurrency(seedCtx, client); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return client, drv.DB(), nil
|
||||
|
||||
@@ -59,7 +59,9 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
||||
SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes)
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
|
||||
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
||||
SetDefaultMappedModel(groupIn.DefaultMappedModel)
|
||||
|
||||
// 设置模型路由配置
|
||||
if groupIn.ModelRouting != nil {
|
||||
@@ -125,7 +127,9 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes)
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
|
||||
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
||||
SetDefaultMappedModel(groupIn.DefaultMappedModel)
|
||||
|
||||
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
|
||||
if groupIn.DailyLimitUSD != nil {
|
||||
|
||||
183
backend/internal/repository/scheduled_test_repo.go
Normal file
183
backend/internal/repository/scheduled_test_repo.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// --- Plan Repository ---
|
||||
|
||||
type scheduledTestPlanRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewScheduledTestPlanRepository(db *sql.DB) service.ScheduledTestPlanRepository {
|
||||
return &scheduledTestPlanRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) Create(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO scheduled_test_plans (account_id, model_id, cron_expression, enabled, max_results, auto_recover, next_run_at, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), NOW())
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.AccountID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.AutoRecover, plan.NextRunAt)
|
||||
return scanPlan(row)
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) GetByID(ctx context.Context, id int64) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans WHERE id = $1
|
||||
`, id)
|
||||
return scanPlan(row)
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) ListByAccountID(ctx context.Context, accountID int64) ([]*service.ScheduledTestPlan, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans WHERE account_id = $1
|
||||
ORDER BY created_at DESC
|
||||
`, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
return scanPlans(rows)
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) ListDue(ctx context.Context, now time.Time) ([]*service.ScheduledTestPlan, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans
|
||||
WHERE enabled = true AND next_run_at <= $1
|
||||
ORDER BY next_run_at ASC
|
||||
`, now)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
return scanPlans(rows)
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) Update(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
UPDATE scheduled_test_plans
|
||||
SET model_id = $2, cron_expression = $3, enabled = $4, max_results = $5, auto_recover = $6, next_run_at = $7, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.ID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.AutoRecover, plan.NextRunAt)
|
||||
return scanPlan(row)
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) Delete(ctx context.Context, id int64) error {
|
||||
_, err := r.db.ExecContext(ctx, `DELETE FROM scheduled_test_plans WHERE id = $1`, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) UpdateAfterRun(ctx context.Context, id int64, lastRunAt time.Time, nextRunAt time.Time) error {
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
UPDATE scheduled_test_plans SET last_run_at = $2, next_run_at = $3, updated_at = NOW() WHERE id = $1
|
||||
`, id, lastRunAt, nextRunAt)
|
||||
return err
|
||||
}
|
||||
|
||||
// --- Result Repository ---
|
||||
|
||||
type scheduledTestResultRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewScheduledTestResultRepository(db *sql.DB) service.ScheduledTestResultRepository {
|
||||
return &scheduledTestResultRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *scheduledTestResultRepository) Create(ctx context.Context, result *service.ScheduledTestResult) (*service.ScheduledTestResult, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO scheduled_test_results (plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
|
||||
RETURNING id, plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at
|
||||
`, result.PlanID, result.Status, result.ResponseText, result.ErrorMessage, result.LatencyMs, result.StartedAt, result.FinishedAt)
|
||||
|
||||
out := &service.ScheduledTestResult{}
|
||||
if err := row.Scan(
|
||||
&out.ID, &out.PlanID, &out.Status, &out.ResponseText, &out.ErrorMessage,
|
||||
&out.LatencyMs, &out.StartedAt, &out.FinishedAt, &out.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *scheduledTestResultRepository) ListByPlanID(ctx context.Context, planID int64, limit int) ([]*service.ScheduledTestResult, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at
|
||||
FROM scheduled_test_results
|
||||
WHERE plan_id = $1
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $2
|
||||
`, planID, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var results []*service.ScheduledTestResult
|
||||
for rows.Next() {
|
||||
r := &service.ScheduledTestResult{}
|
||||
if err := rows.Scan(
|
||||
&r.ID, &r.PlanID, &r.Status, &r.ResponseText, &r.ErrorMessage,
|
||||
&r.LatencyMs, &r.StartedAt, &r.FinishedAt, &r.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, r)
|
||||
}
|
||||
return results, rows.Err()
|
||||
}
|
||||
|
||||
func (r *scheduledTestResultRepository) PruneOldResults(ctx context.Context, planID int64, keepCount int) error {
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
DELETE FROM scheduled_test_results
|
||||
WHERE id IN (
|
||||
SELECT id FROM (
|
||||
SELECT id, ROW_NUMBER() OVER (PARTITION BY plan_id ORDER BY created_at DESC) AS rn
|
||||
FROM scheduled_test_results
|
||||
WHERE plan_id = $1
|
||||
) ranked
|
||||
WHERE rn > $2
|
||||
)
|
||||
`, planID, keepCount)
|
||||
return err
|
||||
}
|
||||
|
||||
// --- scan helpers ---
|
||||
|
||||
type scannable interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
func scanPlan(row scannable) (*service.ScheduledTestPlan, error) {
|
||||
p := &service.ScheduledTestPlan{}
|
||||
if err := row.Scan(
|
||||
&p.ID, &p.AccountID, &p.ModelID, &p.CronExpression, &p.Enabled, &p.MaxResults, &p.AutoRecover,
|
||||
&p.LastRunAt, &p.NextRunAt, &p.CreatedAt, &p.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func scanPlans(rows *sql.Rows) ([]*service.ScheduledTestPlan, error) {
|
||||
var plans []*service.ScheduledTestPlan
|
||||
for rows.Next() {
|
||||
p, err := scanPlan(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plans = append(plans, p)
|
||||
}
|
||||
return plans, rows.Err()
|
||||
}
|
||||
55
backend/internal/repository/simple_mode_admin_concurrency.go
Normal file
55
backend/internal/repository/simple_mode_admin_concurrency.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/setting"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
const (
|
||||
simpleModeAdminConcurrencyUpgradeKey = "simple_mode_admin_concurrency_upgraded_30"
|
||||
simpleModeLegacyAdminConcurrency = 5
|
||||
simpleModeTargetAdminConcurrency = 30
|
||||
)
|
||||
|
||||
func ensureSimpleModeAdminConcurrency(ctx context.Context, client *dbent.Client) error {
|
||||
if client == nil {
|
||||
return fmt.Errorf("nil ent client")
|
||||
}
|
||||
|
||||
upgraded, err := client.Setting.Query().Where(setting.KeyEQ(simpleModeAdminConcurrencyUpgradeKey)).Exist(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check admin concurrency upgrade marker: %w", err)
|
||||
}
|
||||
if upgraded {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := client.User.Update().
|
||||
Where(
|
||||
dbuser.RoleEQ(service.RoleAdmin),
|
||||
dbuser.ConcurrencyEQ(simpleModeLegacyAdminConcurrency),
|
||||
).
|
||||
SetConcurrency(simpleModeTargetAdminConcurrency).
|
||||
Save(ctx); err != nil {
|
||||
return fmt.Errorf("upgrade simple mode admin concurrency: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
if err := client.Setting.Create().
|
||||
SetKey(simpleModeAdminConcurrencyUpgradeKey).
|
||||
SetValue(now.Format(time.RFC3339)).
|
||||
SetUpdatedAt(now).
|
||||
OnConflictColumns(setting.FieldKey).
|
||||
UpdateNewValues().
|
||||
Exec(ctx); err != nil {
|
||||
return fmt.Errorf("persist admin concurrency upgrade marker: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at"
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, cache_ttl_overridden, created_at"
|
||||
|
||||
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
|
||||
var dateFormatWhitelist = map[string]string{
|
||||
@@ -135,6 +135,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
image_count,
|
||||
image_size,
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
@@ -144,7 +145,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
$8, $9, $10, $11,
|
||||
$12, $13,
|
||||
$14, $15, $16, $17, $18, $19,
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
@@ -158,6 +159,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
ipAddress := nullString(log.IPAddress)
|
||||
imageSize := nullString(log.ImageSize)
|
||||
mediaType := nullString(log.MediaType)
|
||||
serviceTier := nullString(log.ServiceTier)
|
||||
reasoningEffort := nullString(log.ReasoningEffort)
|
||||
|
||||
var requestIDArg any
|
||||
@@ -198,6 +200,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
log.ImageCount,
|
||||
imageSize,
|
||||
mediaType,
|
||||
serviceTier,
|
||||
reasoningEffort,
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
@@ -1363,7 +1366,8 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||
@@ -1401,6 +1405,8 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||
@@ -1664,7 +1670,8 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||
@@ -1747,7 +1754,8 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st
|
||||
total_requests as requests,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
(cache_creation_tokens + cache_read_tokens) as cache_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens,
|
||||
total_cost as cost,
|
||||
actual_cost
|
||||
@@ -1762,7 +1770,8 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st
|
||||
total_requests as requests,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
(cache_creation_tokens + cache_read_tokens) as cache_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens,
|
||||
total_cost as cost,
|
||||
actual_cost
|
||||
@@ -1806,6 +1815,8 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
%s
|
||||
@@ -2497,6 +2508,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
imageCount int
|
||||
imageSize sql.NullString
|
||||
mediaType sql.NullString
|
||||
serviceTier sql.NullString
|
||||
reasoningEffort sql.NullString
|
||||
cacheTTLOverridden bool
|
||||
createdAt time.Time
|
||||
@@ -2536,6 +2548,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
&imageCount,
|
||||
&imageSize,
|
||||
&mediaType,
|
||||
&serviceTier,
|
||||
&reasoningEffort,
|
||||
&cacheTTLOverridden,
|
||||
&createdAt,
|
||||
@@ -2606,6 +2619,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
if mediaType.Valid {
|
||||
log.MediaType = &mediaType.String
|
||||
}
|
||||
if serviceTier.Valid {
|
||||
log.ServiceTier = &serviceTier.String
|
||||
}
|
||||
if reasoningEffort.Valid {
|
||||
log.ReasoningEffort = &reasoningEffort.String
|
||||
}
|
||||
@@ -2622,7 +2638,8 @@ func scanTrendRows(rows *sql.Rows) ([]TrendDataPoint, error) {
|
||||
&row.Requests,
|
||||
&row.InputTokens,
|
||||
&row.OutputTokens,
|
||||
&row.CacheTokens,
|
||||
&row.CacheCreationTokens,
|
||||
&row.CacheReadTokens,
|
||||
&row.TotalTokens,
|
||||
&row.Cost,
|
||||
&row.ActualCost,
|
||||
@@ -2646,6 +2663,8 @@ func scanModelStatsRows(rows *sql.Rows) ([]ModelStat, error) {
|
||||
&row.Requests,
|
||||
&row.InputTokens,
|
||||
&row.OutputTokens,
|
||||
&row.CacheCreationTokens,
|
||||
&row.CacheReadTokens,
|
||||
&row.TotalTokens,
|
||||
&row.Cost,
|
||||
&row.ActualCost,
|
||||
|
||||
@@ -71,6 +71,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
||||
log.ImageCount,
|
||||
sqlmock.AnyArg(), // image_size
|
||||
sqlmock.AnyArg(), // media_type
|
||||
sqlmock.AnyArg(), // service_tier
|
||||
sqlmock.AnyArg(), // reasoning_effort
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
@@ -81,12 +82,76 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.True(t, inserted)
|
||||
require.Equal(t, int64(99), log.ID)
|
||||
require.Nil(t, log.ServiceTier)
|
||||
require.Equal(t, service.RequestTypeWSV2, log.RequestType)
|
||||
require.True(t, log.Stream)
|
||||
require.True(t, log.OpenAIWSMode)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageLogRepository{sql: db}
|
||||
|
||||
createdAt := time.Date(2025, 1, 2, 12, 0, 0, 0, time.UTC)
|
||||
serviceTier := "priority"
|
||||
log := &service.UsageLog{
|
||||
UserID: 1,
|
||||
APIKeyID: 2,
|
||||
AccountID: 3,
|
||||
RequestID: "req-service-tier",
|
||||
Model: "gpt-5.4",
|
||||
ServiceTier: &serviceTier,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
|
||||
mock.ExpectQuery("INSERT INTO usage_logs").
|
||||
WithArgs(
|
||||
log.UserID,
|
||||
log.APIKeyID,
|
||||
log.AccountID,
|
||||
log.RequestID,
|
||||
log.Model,
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
log.InputTokens,
|
||||
log.OutputTokens,
|
||||
log.CacheCreationTokens,
|
||||
log.CacheReadTokens,
|
||||
log.CacheCreation5mTokens,
|
||||
log.CacheCreation1hTokens,
|
||||
log.InputCost,
|
||||
log.OutputCost,
|
||||
log.CacheCreationCost,
|
||||
log.CacheReadCost,
|
||||
log.TotalCost,
|
||||
log.ActualCost,
|
||||
log.RateMultiplier,
|
||||
log.AccountRateMultiplier,
|
||||
log.BillingType,
|
||||
int16(service.RequestTypeSync),
|
||||
false,
|
||||
false,
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
log.ImageCount,
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
serviceTier,
|
||||
sqlmock.AnyArg(),
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt))
|
||||
|
||||
inserted, err := repo.Create(context.Background(), log)
|
||||
require.NoError(t, err)
|
||||
require.True(t, inserted)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageLogRepository{sql: db}
|
||||
@@ -125,7 +190,7 @@ func TestUsageLogRepositoryGetUsageTrendWithFiltersRequestTypePriority(t *testin
|
||||
|
||||
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE\\)\\)").
|
||||
WithArgs(start, end, requestType).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_tokens", "total_tokens", "cost", "actual_cost"}))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost"}))
|
||||
|
||||
trend, err := repo.GetUsageTrendWithFilters(context.Background(), start, end, "day", 0, 0, 0, 0, "", &requestType, &stream, nil)
|
||||
require.NoError(t, err)
|
||||
@@ -144,7 +209,7 @@ func TestUsageLogRepositoryGetModelStatsWithFiltersRequestTypePriority(t *testin
|
||||
|
||||
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").
|
||||
WithArgs(start, end, requestType).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "total_tokens", "cost", "actual_cost"}))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost"}))
|
||||
|
||||
stats, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, &requestType, &stream, nil)
|
||||
require.NoError(t, err)
|
||||
@@ -280,11 +345,14 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
0,
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{Valid: true, String: "priority"},
|
||||
sql.NullString{},
|
||||
false,
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, log.ServiceTier)
|
||||
require.Equal(t, "priority", *log.ServiceTier)
|
||||
require.Equal(t, service.RequestTypeWSV2, log.RequestType)
|
||||
require.True(t, log.Stream)
|
||||
require.True(t, log.OpenAIWSMode)
|
||||
@@ -316,13 +384,53 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
0,
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{Valid: true, String: "flex"},
|
||||
sql.NullString{},
|
||||
false,
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, log.ServiceTier)
|
||||
require.Equal(t, "flex", *log.ServiceTier)
|
||||
require.Equal(t, service.RequestTypeStream, log.RequestType)
|
||||
require.True(t, log.Stream)
|
||||
require.False(t, log.OpenAIWSMode)
|
||||
})
|
||||
|
||||
t.Run("service_tier_is_scanned", func(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
log, err := scanUsageLog(usageLogScannerStub{values: []any{
|
||||
int64(3),
|
||||
int64(12),
|
||||
int64(22),
|
||||
int64(32),
|
||||
sql.NullString{Valid: true, String: "req-3"},
|
||||
"gpt-5.4",
|
||||
sql.NullInt64{},
|
||||
sql.NullInt64{},
|
||||
1, 2, 3, 4, 5, 6,
|
||||
0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
|
||||
1.0,
|
||||
sql.NullFloat64{},
|
||||
int16(service.BillingTypeBalance),
|
||||
int16(service.RequestTypeSync),
|
||||
false,
|
||||
false,
|
||||
sql.NullInt64{},
|
||||
sql.NullInt64{},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
0,
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{Valid: true, String: "priority"},
|
||||
sql.NullString{},
|
||||
false,
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, log.ServiceTier)
|
||||
require.Equal(t, "priority", *log.ServiceTier)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
@@ -53,7 +53,9 @@ var ProviderSet = wire.NewSet(
|
||||
NewAPIKeyRepository,
|
||||
NewGroupRepository,
|
||||
NewAccountRepository,
|
||||
NewSoraAccountRepository, // Sora 账号扩展表仓储
|
||||
NewSoraAccountRepository, // Sora 账号扩展表仓储
|
||||
NewScheduledTestPlanRepository, // 定时测试计划仓储
|
||||
NewScheduledTestResultRepository, // 定时测试结果仓储
|
||||
NewProxyRepository,
|
||||
NewRedeemCodeRepository,
|
||||
NewPromoCodeRepository,
|
||||
|
||||
@@ -210,8 +210,10 @@ func TestAPIContracts(t *testing.T) {
|
||||
"sora_video_price_per_request": null,
|
||||
"sora_video_price_per_request_hd": null,
|
||||
"claude_code_only": false,
|
||||
"allow_messages_dispatch": false,
|
||||
"fallback_group_id": null,
|
||||
"fallback_group_id_on_invalid_request": null,
|
||||
"allow_messages_dispatch": false,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
"updated_at": "2025-01-02T03:04:05Z"
|
||||
}
|
||||
@@ -643,7 +645,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
settingRepo := newStubSettingRepo()
|
||||
settingService := service.NewSettingService(settingRepo, cfg)
|
||||
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
@@ -1096,6 +1098,14 @@ func (s *stubAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
|
||||
s.bulkUpdateIDs = append([]int64{}, ids...)
|
||||
return int64(len(ids)), nil
|
||||
|
||||
@@ -19,7 +19,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
|
||||
authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
admin := &service.User{
|
||||
ID: 1,
|
||||
|
||||
@@ -40,7 +40,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
|
||||
cfg.JWT.AccessTokenExpireMinutes = 60
|
||||
|
||||
userRepo := &stubJWTUserRepo{users: users}
|
||||
authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
userSvc := service.NewUserService(userRepo, nil, nil)
|
||||
mw := NewJWTAuthMiddleware(authSvc, userSvc)
|
||||
|
||||
|
||||
@@ -78,6 +78,9 @@ func RegisterAdminRoutes(
|
||||
|
||||
// API Key 管理
|
||||
registerAdminAPIKeyRoutes(admin, h)
|
||||
|
||||
// 定时测试计划
|
||||
registerScheduledTestRoutes(admin, h)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -241,6 +244,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts.PUT("/:id", h.Admin.Account.Update)
|
||||
accounts.DELETE("/:id", h.Admin.Account.Delete)
|
||||
accounts.POST("/:id/test", h.Admin.Account.Test)
|
||||
accounts.POST("/:id/recover-state", h.Admin.Account.RecoverState)
|
||||
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
|
||||
accounts.POST("/:id/refresh-tier", h.Admin.Account.RefreshTier)
|
||||
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
|
||||
@@ -249,6 +253,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
|
||||
accounts.POST("/today-stats/batch", h.Admin.Account.GetBatchTodayStats)
|
||||
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
|
||||
accounts.POST("/:id/reset-quota", h.Admin.Account.ResetQuota)
|
||||
accounts.GET("/:id/temp-unschedulable", h.Admin.Account.GetTempUnschedulable)
|
||||
accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable)
|
||||
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
|
||||
@@ -259,6 +264,8 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials)
|
||||
accounts.POST("/batch-refresh-tier", h.Admin.Account.BatchRefreshTier)
|
||||
accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate)
|
||||
accounts.POST("/batch-clear-error", h.Admin.Account.BatchClearError)
|
||||
accounts.POST("/batch-refresh", h.Admin.Account.BatchRefresh)
|
||||
|
||||
// Antigravity 默认模型映射
|
||||
accounts.GET("/antigravity/default-model-mapping", h.Admin.Account.GetAntigravityDefaultModelMapping)
|
||||
@@ -388,6 +395,12 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
// 流超时处理配置
|
||||
adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings)
|
||||
adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings)
|
||||
// 请求整流器配置
|
||||
adminSettings.GET("/rectifier", h.Admin.Setting.GetRectifierSettings)
|
||||
adminSettings.PUT("/rectifier", h.Admin.Setting.UpdateRectifierSettings)
|
||||
// Beta 策略配置
|
||||
adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings)
|
||||
adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings)
|
||||
// Sora S3 存储配置
|
||||
adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings)
|
||||
adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings)
|
||||
@@ -478,6 +491,18 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
}
|
||||
}
|
||||
|
||||
func registerScheduledTestRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
plans := admin.Group("/scheduled-test-plans")
|
||||
{
|
||||
plans.POST("", h.Admin.ScheduledTest.Create)
|
||||
plans.PUT("/:id", h.Admin.ScheduledTest.Update)
|
||||
plans.DELETE("/:id", h.Admin.ScheduledTest.Delete)
|
||||
plans.GET("/:id/results", h.Admin.ScheduledTest.ListResults)
|
||||
}
|
||||
// Nested under accounts
|
||||
admin.GET("/accounts/:id/scheduled-test-plans", h.Admin.ScheduledTest.ListByAccount)
|
||||
}
|
||||
|
||||
func registerErrorPassthroughRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
rules := admin.Group("/error-passthrough-rules")
|
||||
{
|
||||
|
||||
@@ -61,6 +61,12 @@ func RegisterAuthRoutes(
|
||||
}), h.Auth.ResetPassword)
|
||||
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
|
||||
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
|
||||
auth.POST("/oauth/linuxdo/complete-registration",
|
||||
rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}),
|
||||
h.Auth.CompleteLinuxDoOAuthRegistration,
|
||||
)
|
||||
}
|
||||
|
||||
// 公开设置(无需认证)
|
||||
|
||||
@@ -43,12 +43,33 @@ func RegisterGatewayRoutes(
|
||||
gateway.Use(gin.HandlerFunc(apiKeyAuth))
|
||||
gateway.Use(requireGroupAnthropic)
|
||||
{
|
||||
gateway.POST("/messages", h.Gateway.Messages)
|
||||
gateway.POST("/messages/count_tokens", h.Gateway.CountTokens)
|
||||
// /v1/messages: auto-route based on group platform
|
||||
gateway.POST("/messages", func(c *gin.Context) {
|
||||
if getGroupPlatform(c) == service.PlatformOpenAI {
|
||||
h.OpenAIGateway.Messages(c)
|
||||
return
|
||||
}
|
||||
h.Gateway.Messages(c)
|
||||
})
|
||||
// /v1/messages/count_tokens: OpenAI groups get 404
|
||||
gateway.POST("/messages/count_tokens", func(c *gin.Context) {
|
||||
if getGroupPlatform(c) == service.PlatformOpenAI {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "not_found_error",
|
||||
"message": "Token counting is not supported for this platform",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
h.Gateway.CountTokens(c)
|
||||
})
|
||||
gateway.GET("/models", h.Gateway.Models)
|
||||
gateway.GET("/usage", h.Gateway.Usage)
|
||||
// OpenAI Responses API
|
||||
gateway.POST("/responses", h.OpenAIGateway.Responses)
|
||||
gateway.POST("/responses/*subpath", h.OpenAIGateway.Responses)
|
||||
gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket)
|
||||
// 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。
|
||||
gateway.POST("/chat/completions", func(c *gin.Context) {
|
||||
@@ -77,6 +98,7 @@ func RegisterGatewayRoutes(
|
||||
|
||||
// OpenAI Responses API(不带v1前缀的别名)
|
||||
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
|
||||
|
||||
// Antigravity 模型列表
|
||||
@@ -132,3 +154,12 @@ func RegisterGatewayRoutes(
|
||||
// Sora 媒体代理(签名 URL,无需 API Key)
|
||||
r.GET("/sora/media-signed/*filepath", h.SoraGateway.MediaProxySigned)
|
||||
}
|
||||
|
||||
// getGroupPlatform extracts the group platform from the API Key stored in context.
|
||||
func getGroupPlatform(c *gin.Context) string {
|
||||
apiKey, ok := middleware.GetAPIKeyFromContext(c)
|
||||
if !ok || apiKey.Group == nil {
|
||||
return ""
|
||||
}
|
||||
return apiKey.Group.Platform
|
||||
}
|
||||
|
||||
51
backend/internal/server/routes/gateway_test.go
Normal file
51
backend/internal/server/routes/gateway_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newGatewayRoutesTestRouter() *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
RegisterGatewayRoutes(
|
||||
router,
|
||||
&handler.Handlers{
|
||||
Gateway: &handler.GatewayHandler{},
|
||||
OpenAIGateway: &handler.OpenAIGatewayHandler{},
|
||||
SoraGateway: &handler.SoraGatewayHandler{},
|
||||
},
|
||||
servermiddleware.APIKeyAuthMiddleware(func(c *gin.Context) {
|
||||
c.Next()
|
||||
}),
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
&config.Config{},
|
||||
)
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
func TestGatewayRoutesOpenAIResponsesCompactPathIsRegistered(t *testing.T) {
|
||||
router := newGatewayRoutesTestRouter()
|
||||
|
||||
for _, path := range []string{"/v1/responses/compact", "/responses/compact"} {
|
||||
req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{"model":"gpt-5"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
require.NotEqual(t, http.StatusNotFound, w.Code, "path=%s should hit OpenAI responses handler", path)
|
||||
}
|
||||
}
|
||||
@@ -28,6 +28,7 @@ type Account struct {
|
||||
// RateMultiplier 账号计费倍率(>=0,允许 0 表示该账号计费为 0)。
|
||||
// 使用指针用于兼容旧版本调度缓存(Redis)中缺字段的情况:nil 表示按 1.0 处理。
|
||||
RateMultiplier *float64
|
||||
LoadFactor *int // 调度负载因子;nil 表示使用 Concurrency
|
||||
Status string
|
||||
ErrorMessage string
|
||||
LastUsedAt *time.Time
|
||||
@@ -88,6 +89,19 @@ func (a *Account) BillingRateMultiplier() float64 {
|
||||
return *a.RateMultiplier
|
||||
}
|
||||
|
||||
func (a *Account) EffectiveLoadFactor() int {
|
||||
if a == nil {
|
||||
return 1
|
||||
}
|
||||
if a.LoadFactor != nil && *a.LoadFactor > 0 {
|
||||
return *a.LoadFactor
|
||||
}
|
||||
if a.Concurrency > 0 {
|
||||
return a.Concurrency
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func (a *Account) IsSchedulable() bool {
|
||||
if !a.IsActive() || !a.Schedulable {
|
||||
return false
|
||||
@@ -633,6 +647,75 @@ func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// IsPoolMode 检查 API Key 账号是否启用池模式。
|
||||
// 池模式下,上游错误不标记本地账号状态,而是在同一账号上重试。
|
||||
func (a *Account) IsPoolMode() bool {
|
||||
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Credentials["pool_mode"]; ok {
|
||||
if enabled, ok := v.(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
const (
|
||||
defaultPoolModeRetryCount = 3
|
||||
maxPoolModeRetryCount = 10
|
||||
)
|
||||
|
||||
// GetPoolModeRetryCount 返回池模式同账号重试次数。
|
||||
// 未配置或配置非法时回退为默认值 3;小于 0 按 0 处理;过大则截断到 10。
|
||||
func (a *Account) GetPoolModeRetryCount() int {
|
||||
if a == nil || !a.IsPoolMode() || a.Credentials == nil {
|
||||
return defaultPoolModeRetryCount
|
||||
}
|
||||
raw, ok := a.Credentials["pool_mode_retry_count"]
|
||||
if !ok || raw == nil {
|
||||
return defaultPoolModeRetryCount
|
||||
}
|
||||
count := parsePoolModeRetryCount(raw)
|
||||
if count < 0 {
|
||||
return 0
|
||||
}
|
||||
if count > maxPoolModeRetryCount {
|
||||
return maxPoolModeRetryCount
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func parsePoolModeRetryCount(value any) int {
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return v
|
||||
case int64:
|
||||
return int(v)
|
||||
case float64:
|
||||
return int(v)
|
||||
case json.Number:
|
||||
if i, err := v.Int64(); err == nil {
|
||||
return int(i)
|
||||
}
|
||||
case string:
|
||||
if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return defaultPoolModeRetryCount
|
||||
}
|
||||
|
||||
// isPoolModeRetryableStatus 池模式下应触发同账号重试的状态码
|
||||
func isPoolModeRetryableStatus(statusCode int) bool {
|
||||
switch statusCode {
|
||||
case 401, 403, 429:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) GetCustomErrorCodes() []int {
|
||||
if a.Credentials == nil {
|
||||
return nil
|
||||
@@ -853,15 +936,21 @@ func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool {
|
||||
}
|
||||
|
||||
const (
|
||||
OpenAIWSIngressModeOff = "off"
|
||||
OpenAIWSIngressModeShared = "shared"
|
||||
OpenAIWSIngressModeDedicated = "dedicated"
|
||||
OpenAIWSIngressModeOff = "off"
|
||||
OpenAIWSIngressModeShared = "shared"
|
||||
OpenAIWSIngressModeDedicated = "dedicated"
|
||||
OpenAIWSIngressModeCtxPool = "ctx_pool"
|
||||
OpenAIWSIngressModePassthrough = "passthrough"
|
||||
)
|
||||
|
||||
func normalizeOpenAIWSIngressMode(mode string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(mode)) {
|
||||
case OpenAIWSIngressModeOff:
|
||||
return OpenAIWSIngressModeOff
|
||||
case OpenAIWSIngressModeCtxPool:
|
||||
return OpenAIWSIngressModeCtxPool
|
||||
case OpenAIWSIngressModePassthrough:
|
||||
return OpenAIWSIngressModePassthrough
|
||||
case OpenAIWSIngressModeShared:
|
||||
return OpenAIWSIngressModeShared
|
||||
case OpenAIWSIngressModeDedicated:
|
||||
@@ -873,18 +962,21 @@ func normalizeOpenAIWSIngressMode(mode string) string {
|
||||
|
||||
func normalizeOpenAIWSIngressDefaultMode(mode string) string {
|
||||
if normalized := normalizeOpenAIWSIngressMode(mode); normalized != "" {
|
||||
if normalized == OpenAIWSIngressModeShared || normalized == OpenAIWSIngressModeDedicated {
|
||||
return OpenAIWSIngressModeCtxPool
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
return OpenAIWSIngressModeShared
|
||||
return OpenAIWSIngressModeCtxPool
|
||||
}
|
||||
|
||||
// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/shared/dedicated)。
|
||||
// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/ctx_pool/passthrough)。
|
||||
//
|
||||
// 优先级:
|
||||
// 1. 分类型 mode 新字段(string)
|
||||
// 2. 分类型 enabled 旧字段(bool)
|
||||
// 3. 兼容 enabled 旧字段(bool)
|
||||
// 4. defaultMode(非法时回退 shared)
|
||||
// 4. defaultMode(非法时回退 ctx_pool)
|
||||
func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string {
|
||||
resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode)
|
||||
if a == nil || !a.IsOpenAI() {
|
||||
@@ -919,7 +1011,7 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
|
||||
return "", false
|
||||
}
|
||||
if enabled {
|
||||
return OpenAIWSIngressModeShared, true
|
||||
return OpenAIWSIngressModeCtxPool, true
|
||||
}
|
||||
return OpenAIWSIngressModeOff, true
|
||||
}
|
||||
@@ -946,6 +1038,10 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
|
||||
if mode, ok := resolveBoolMode("openai_ws_enabled"); ok {
|
||||
return mode
|
||||
}
|
||||
// 兼容旧值:shared/dedicated 语义都归并到 ctx_pool。
|
||||
if resolvedDefault == OpenAIWSIngressModeShared || resolvedDefault == OpenAIWSIngressModeDedicated {
|
||||
return OpenAIWSIngressModeCtxPool
|
||||
}
|
||||
return resolvedDefault
|
||||
}
|
||||
|
||||
@@ -1104,6 +1200,102 @@ func (a *Account) GetCacheTTLOverrideTarget() string {
|
||||
return "5m"
|
||||
}
|
||||
|
||||
// GetQuotaLimit 获取 API Key 账号的配额限制(美元)
|
||||
// 返回 0 表示未启用
|
||||
func (a *Account) GetQuotaLimit() float64 {
|
||||
return a.getExtraFloat64("quota_limit")
|
||||
}
|
||||
|
||||
// GetQuotaUsed 获取 API Key 账号的已用配额(美元)
|
||||
func (a *Account) GetQuotaUsed() float64 {
|
||||
return a.getExtraFloat64("quota_used")
|
||||
}
|
||||
|
||||
// GetQuotaDailyLimit 获取日额度限制(美元),0 表示未启用
|
||||
func (a *Account) GetQuotaDailyLimit() float64 {
|
||||
return a.getExtraFloat64("quota_daily_limit")
|
||||
}
|
||||
|
||||
// GetQuotaDailyUsed 获取当日已用额度(美元)
|
||||
func (a *Account) GetQuotaDailyUsed() float64 {
|
||||
return a.getExtraFloat64("quota_daily_used")
|
||||
}
|
||||
|
||||
// GetQuotaWeeklyLimit 获取周额度限制(美元),0 表示未启用
|
||||
func (a *Account) GetQuotaWeeklyLimit() float64 {
|
||||
return a.getExtraFloat64("quota_weekly_limit")
|
||||
}
|
||||
|
||||
// GetQuotaWeeklyUsed 获取本周已用额度(美元)
|
||||
func (a *Account) GetQuotaWeeklyUsed() float64 {
|
||||
return a.getExtraFloat64("quota_weekly_used")
|
||||
}
|
||||
|
||||
// getExtraFloat64 从 Extra 中读取指定 key 的 float64 值
|
||||
func (a *Account) getExtraFloat64(key string) float64 {
|
||||
if a.Extra == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := a.Extra[key]; ok {
|
||||
return parseExtraFloat64(v)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// getExtraTime 从 Extra 中读取 RFC3339 时间戳
|
||||
func (a *Account) getExtraTime(key string) time.Time {
|
||||
if a.Extra == nil {
|
||||
return time.Time{}
|
||||
}
|
||||
if v, ok := a.Extra[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
if t, err := time.Parse(time.RFC3339Nano, s); err == nil {
|
||||
return t
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
||||
return t
|
||||
}
|
||||
}
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
// HasAnyQuotaLimit 检查是否配置了任一维度的配额限制
|
||||
func (a *Account) HasAnyQuotaLimit() bool {
|
||||
return a.GetQuotaLimit() > 0 || a.GetQuotaDailyLimit() > 0 || a.GetQuotaWeeklyLimit() > 0
|
||||
}
|
||||
|
||||
// isPeriodExpired 检查指定周期(自 periodStart 起经过 dur)是否已过期
|
||||
func isPeriodExpired(periodStart time.Time, dur time.Duration) bool {
|
||||
if periodStart.IsZero() {
|
||||
return true // 从未使用过,视为过期(下次 increment 会初始化)
|
||||
}
|
||||
return time.Since(periodStart) >= dur
|
||||
}
|
||||
|
||||
// IsQuotaExceeded 检查 API Key 账号配额是否已超限(任一维度超限即返回 true)
|
||||
func (a *Account) IsQuotaExceeded() bool {
|
||||
// 总额度
|
||||
if limit := a.GetQuotaLimit(); limit > 0 && a.GetQuotaUsed() >= limit {
|
||||
return true
|
||||
}
|
||||
// 日额度(周期过期视为未超限,下次 increment 会重置)
|
||||
if limit := a.GetQuotaDailyLimit(); limit > 0 {
|
||||
start := a.getExtraTime("quota_daily_start")
|
||||
if !isPeriodExpired(start, 24*time.Hour) && a.GetQuotaDailyUsed() >= limit {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// 周额度
|
||||
if limit := a.GetQuotaWeeklyLimit(); limit > 0 {
|
||||
start := a.getExtraTime("quota_weekly_start")
|
||||
if !isPeriodExpired(start, 7*24*time.Hour) && a.GetQuotaWeeklyUsed() >= limit {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
|
||||
// 返回 0 表示未启用
|
||||
func (a *Account) GetWindowCostLimit() float64 {
|
||||
|
||||
46
backend/internal/service/account_load_factor_test.go
Normal file
46
backend/internal/service/account_load_factor_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func intPtrHelper(v int) *int { return &v }
|
||||
|
||||
func TestEffectiveLoadFactor_NilAccount(t *testing.T) {
|
||||
var a *Account
|
||||
require.Equal(t, 1, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_NilLoadFactor_PositiveConcurrency(t *testing.T) {
|
||||
a := &Account{Concurrency: 5}
|
||||
require.Equal(t, 5, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_NilLoadFactor_ZeroConcurrency(t *testing.T) {
|
||||
a := &Account{Concurrency: 0}
|
||||
require.Equal(t, 1, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_PositiveLoadFactor(t *testing.T) {
|
||||
a := &Account{Concurrency: 5, LoadFactor: intPtrHelper(20)}
|
||||
require.Equal(t, 20, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_ZeroLoadFactor_FallbackToConcurrency(t *testing.T) {
|
||||
a := &Account{Concurrency: 5, LoadFactor: intPtrHelper(0)}
|
||||
require.Equal(t, 5, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_NegativeLoadFactor_FallbackToConcurrency(t *testing.T) {
|
||||
a := &Account{Concurrency: 3, LoadFactor: intPtrHelper(-1)}
|
||||
require.Equal(t, 3, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_ZeroLoadFactor_ZeroConcurrency(t *testing.T) {
|
||||
a := &Account{Concurrency: 0, LoadFactor: intPtrHelper(0)}
|
||||
require.Equal(t, 1, a.EffectiveLoadFactor())
|
||||
}
|
||||
@@ -206,14 +206,14 @@ func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
|
||||
t.Run("default fallback to shared", func(t *testing.T) {
|
||||
t.Run("default fallback to ctx_pool", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{},
|
||||
}
|
||||
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
|
||||
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
|
||||
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
|
||||
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
|
||||
})
|
||||
|
||||
t.Run("oauth mode field has highest priority", func(t *testing.T) {
|
||||
@@ -221,15 +221,15 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
|
||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
|
||||
"openai_oauth_responses_websockets_v2_enabled": false,
|
||||
"responses_websockets_v2_enabled": false,
|
||||
},
|
||||
}
|
||||
require.Equal(t, OpenAIWSIngressModeDedicated, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
|
||||
require.Equal(t, OpenAIWSIngressModePassthrough, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool))
|
||||
})
|
||||
|
||||
t.Run("legacy enabled maps to shared", func(t *testing.T) {
|
||||
t.Run("legacy enabled maps to ctx_pool", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
@@ -237,7 +237,28 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
|
||||
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
|
||||
})
|
||||
|
||||
t.Run("shared/dedicated mode strings are compatible with ctx_pool", func(t *testing.T) {
|
||||
shared := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared,
|
||||
},
|
||||
}
|
||||
dedicated := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
|
||||
},
|
||||
}
|
||||
require.Equal(t, OpenAIWSIngressModeShared, shared.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
|
||||
require.Equal(t, OpenAIWSIngressModeDedicated, dedicated.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
|
||||
require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeShared))
|
||||
require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeDedicated))
|
||||
})
|
||||
|
||||
t.Run("legacy disabled maps to off", func(t *testing.T) {
|
||||
@@ -249,7 +270,7 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
|
||||
require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool))
|
||||
})
|
||||
|
||||
t.Run("non openai always off", func(t *testing.T) {
|
||||
|
||||
117
backend/internal/service/account_pool_mode_test.go
Normal file
117
backend/internal/service/account_pool_mode_test.go
Normal file
@@ -0,0 +1,117 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetPoolModeRetryCount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
name: "default_when_not_pool_mode",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
expected: defaultPoolModeRetryCount,
|
||||
},
|
||||
{
|
||||
name: "default_when_missing_retry_count",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
},
|
||||
},
|
||||
expected: defaultPoolModeRetryCount,
|
||||
},
|
||||
{
|
||||
name: "supports_float64_from_json_credentials",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": float64(5),
|
||||
},
|
||||
},
|
||||
expected: 5,
|
||||
},
|
||||
{
|
||||
name: "supports_json_number",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": json.Number("4"),
|
||||
},
|
||||
},
|
||||
expected: 4,
|
||||
},
|
||||
{
|
||||
name: "supports_string_value",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": "2",
|
||||
},
|
||||
},
|
||||
expected: 2,
|
||||
},
|
||||
{
|
||||
name: "negative_value_is_clamped_to_zero",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": -1,
|
||||
},
|
||||
},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "oversized_value_is_clamped_to_max",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": 99,
|
||||
},
|
||||
},
|
||||
expected: maxPoolModeRetryCount,
|
||||
},
|
||||
{
|
||||
name: "invalid_value_falls_back_to_default",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": "oops",
|
||||
},
|
||||
},
|
||||
expected: defaultPoolModeRetryCount,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.expected, tt.account.GetPoolModeRetryCount())
|
||||
})
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user