mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-04 15:32:13 +08:00
Compare commits
106 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bcb6444f89 | ||
|
|
c2b14693b4 | ||
|
|
92d35409de | ||
|
|
351a08f813 | ||
|
|
a58dc787a9 | ||
|
|
7079edc2d0 | ||
|
|
da89583ccc | ||
|
|
a42a1f08e9 | ||
|
|
ebd5253e22 | ||
|
|
6411645ffc | ||
|
|
c0c322ba16 | ||
|
|
d35c5cd491 | ||
|
|
7a353028e7 | ||
|
|
2d8d3b7857 | ||
|
|
4190293b07 | ||
|
|
421b4c0aff | ||
|
|
cd69a7cb85 | ||
|
|
0c9ba9e86c | ||
|
|
1b4d2a41c9 | ||
|
|
0787d2b47a | ||
|
|
97bf1d85ab | ||
|
|
207a493fab | ||
|
|
1f3f9e131e | ||
|
|
4ddedfaaf9 | ||
|
|
3ebebef95f | ||
|
|
9f7ad47598 | ||
|
|
3c83cd8be2 | ||
|
|
963b3b768c | ||
|
|
f6709fb5d6 | ||
|
|
921599948b | ||
|
|
5df3cafa99 | ||
|
|
1a2143c1fe | ||
|
|
dd25281305 | ||
|
|
49d0301dde | ||
|
|
d90e56eb45 | ||
|
|
838ada8864 | ||
|
|
65a106792a | ||
|
|
ee4bfcbb81 | ||
|
|
a087f089b8 | ||
|
|
afbe8bf001 | ||
|
|
2a3ef0be06 | ||
|
|
3403909354 | ||
|
|
005d0c5f53 | ||
|
|
8aaaeb29cc | ||
|
|
230f8abd04 | ||
|
|
a18bbb5f2f | ||
|
|
60fce4f1dc | ||
|
|
9af65efcdb | ||
|
|
bc194a7d8c | ||
|
|
ff1f114989 | ||
|
|
cac230206d | ||
|
|
79ae15d5e8 | ||
|
|
0cce0a8877 | ||
|
|
225fd035ae | ||
|
|
fb7d1346b5 | ||
|
|
491a744481 | ||
|
|
f366026435 | ||
|
|
1a0d4ed668 | ||
|
|
63a8c76946 | ||
|
|
f355a68bc9 | ||
|
|
c87e6526c1 | ||
|
|
af3a5076d6 | ||
|
|
18f2e21414 | ||
|
|
8a8cdeebb4 | ||
|
|
12b33f4ea4 | ||
|
|
01b3a09d7d | ||
|
|
0d6c1c7790 | ||
|
|
95e366b6c6 | ||
|
|
77701143bf | ||
|
|
02dea7b09b | ||
|
|
c26f93c4a0 | ||
|
|
c826ac28ef | ||
|
|
1893b0eb30 | ||
|
|
05527b13db | ||
|
|
ae5d9c8bfc | ||
|
|
9117c2a4ec | ||
|
|
bab4bb9904 | ||
|
|
33bae6f49b | ||
|
|
32d619a56b | ||
|
|
642432cf2a | ||
|
|
61e9598b08 | ||
|
|
d4e34c7514 | ||
|
|
bfe7a5e452 | ||
|
|
77d916ffec | ||
|
|
831abf7977 | ||
|
|
817a491087 | ||
|
|
9a8dacc514 | ||
|
|
8adf80d98b | ||
|
|
62686a6213 | ||
|
|
3a089242f8 | ||
|
|
9d70c38504 | ||
|
|
aeb464f3ca | ||
|
|
7076717b20 | ||
|
|
c0a4fcea0a | ||
|
|
aa2b195c86 | ||
|
|
1d0872e7ca | ||
|
|
33988637b5 | ||
|
|
d4f6ad7225 | ||
|
|
078fefed03 | ||
|
|
5b10af85b4 | ||
|
|
4caf95e5dd | ||
|
|
8e1bcf53bb | ||
|
|
064f9be7e4 | ||
|
|
adcfb44cb7 | ||
|
|
3d79773ba2 | ||
|
|
6aa8cbbf20 |
6
.github/workflows/backend-ci.yml
vendored
6
.github/workflows/backend-ci.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
cache: true
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.25.7'
|
||||
go version | grep -q 'go1.26.1'
|
||||
- name: Unit tests
|
||||
working-directory: backend
|
||||
run: make test-unit
|
||||
@@ -38,10 +38,10 @@ jobs:
|
||||
cache: true
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.25.7'
|
||||
go version | grep -q 'go1.26.1'
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v9
|
||||
with:
|
||||
version: v2.7
|
||||
version: v2.11
|
||||
args: --timeout=30m
|
||||
working-directory: backend
|
||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -115,7 +115,7 @@ jobs:
|
||||
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.25.7'
|
||||
go version | grep -q 'go1.26.1'
|
||||
|
||||
# Docker setup for GoReleaser
|
||||
- name: Set up QEMU
|
||||
|
||||
2
.github/workflows/security-scan.yml
vendored
2
.github/workflows/security-scan.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
cache-dependency-path: backend/go.sum
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.25.7'
|
||||
go version | grep -q 'go1.26.1'
|
||||
- name: Run govulncheck
|
||||
working-directory: backend
|
||||
run: |
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
# =============================================================================
|
||||
|
||||
ARG NODE_IMAGE=node:24-alpine
|
||||
ARG GOLANG_IMAGE=golang:1.25.7-alpine
|
||||
ARG GOLANG_IMAGE=golang:1.26.1-alpine
|
||||
ARG ALPINE_IMAGE=alpine:3.21
|
||||
ARG GOPROXY=https://goproxy.cn,direct
|
||||
ARG GOSUMDB=sum.golang.google.cn
|
||||
|
||||
@@ -93,20 +93,13 @@ linters:
|
||||
check-escaping-errors: true
|
||||
staticcheck:
|
||||
# https://staticcheck.dev/docs/configuration/options/#dot_import_whitelist
|
||||
# Default: ["github.com/mmcloughlin/avo/build", "github.com/mmcloughlin/avo/operand", "github.com/mmcloughlin/avo/reg"]
|
||||
dot-import-whitelist:
|
||||
- fmt
|
||||
# https://staticcheck.dev/docs/configuration/options/#initialisms
|
||||
# Default: ["ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS"]
|
||||
initialisms: [ "ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS" ]
|
||||
# https://staticcheck.dev/docs/configuration/options/#http_status_code_whitelist
|
||||
# Default: ["200", "400", "404", "500"]
|
||||
http-status-code-whitelist: [ "200", "400", "404", "500" ]
|
||||
# SAxxxx checks in https://staticcheck.dev/docs/configuration/options/#checks
|
||||
# Example (to disable some checks): [ "all", "-SA1000", "-SA1001"]
|
||||
# Run `GL_DEBUG=staticcheck golangci-lint run --enable=staticcheck` to see all available checks and enabled by config checks.
|
||||
# Default: ["all", "-ST1000", "-ST1003", "-ST1016", "-ST1020", "-ST1021", "-ST1022"]
|
||||
# Temporarily disable style checks to allow CI to pass
|
||||
# "all" enables every SA/ST/S/QF check; only list the ones to disable.
|
||||
checks:
|
||||
- all
|
||||
- -ST1000 # Package comment format
|
||||
@@ -114,489 +107,19 @@ linters:
|
||||
- -ST1020 # Comment on exported method format
|
||||
- -ST1021 # Comment on exported type format
|
||||
- -ST1022 # Comment on exported variable format
|
||||
# Invalid regular expression.
|
||||
# https://staticcheck.dev/docs/checks/#SA1000
|
||||
- SA1000
|
||||
# Invalid template.
|
||||
# https://staticcheck.dev/docs/checks/#SA1001
|
||||
- SA1001
|
||||
# Invalid format in 'time.Parse'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1002
|
||||
- SA1002
|
||||
# Unsupported argument to functions in 'encoding/binary'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1003
|
||||
- SA1003
|
||||
# Suspiciously small untyped constant in 'time.Sleep'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1004
|
||||
- SA1004
|
||||
# Invalid first argument to 'exec.Command'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1005
|
||||
- SA1005
|
||||
# 'Printf' with dynamic first argument and no further arguments.
|
||||
# https://staticcheck.dev/docs/checks/#SA1006
|
||||
- SA1006
|
||||
# Invalid URL in 'net/url.Parse'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1007
|
||||
- SA1007
|
||||
# Non-canonical key in 'http.Header' map.
|
||||
# https://staticcheck.dev/docs/checks/#SA1008
|
||||
- SA1008
|
||||
# '(*regexp.Regexp).FindAll' called with 'n == 0', which will always return zero results.
|
||||
# https://staticcheck.dev/docs/checks/#SA1010
|
||||
- SA1010
|
||||
# Various methods in the "strings" package expect valid UTF-8, but invalid input is provided.
|
||||
# https://staticcheck.dev/docs/checks/#SA1011
|
||||
- SA1011
|
||||
# A nil 'context.Context' is being passed to a function, consider using 'context.TODO' instead.
|
||||
# https://staticcheck.dev/docs/checks/#SA1012
|
||||
- SA1012
|
||||
# 'io.Seeker.Seek' is being called with the whence constant as the first argument, but it should be the second.
|
||||
# https://staticcheck.dev/docs/checks/#SA1013
|
||||
- SA1013
|
||||
# Non-pointer value passed to 'Unmarshal' or 'Decode'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1014
|
||||
- SA1014
|
||||
# Using 'time.Tick' in a way that will leak. Consider using 'time.NewTicker', and only use 'time.Tick' in tests, commands and endless functions.
|
||||
# https://staticcheck.dev/docs/checks/#SA1015
|
||||
- SA1015
|
||||
# Trapping a signal that cannot be trapped.
|
||||
# https://staticcheck.dev/docs/checks/#SA1016
|
||||
- SA1016
|
||||
# Channels used with 'os/signal.Notify' should be buffered.
|
||||
# https://staticcheck.dev/docs/checks/#SA1017
|
||||
- SA1017
|
||||
# 'strings.Replace' called with 'n == 0', which does nothing.
|
||||
# https://staticcheck.dev/docs/checks/#SA1018
|
||||
- SA1018
|
||||
# Using a deprecated function, variable, constant or field.
|
||||
# https://staticcheck.dev/docs/checks/#SA1019
|
||||
- SA1019
|
||||
# Using an invalid host:port pair with a 'net.Listen'-related function.
|
||||
# https://staticcheck.dev/docs/checks/#SA1020
|
||||
- SA1020
|
||||
# Using 'bytes.Equal' to compare two 'net.IP'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1021
|
||||
- SA1021
|
||||
# Modifying the buffer in an 'io.Writer' implementation.
|
||||
# https://staticcheck.dev/docs/checks/#SA1023
|
||||
- SA1023
|
||||
# A string cutset contains duplicate characters.
|
||||
# https://staticcheck.dev/docs/checks/#SA1024
|
||||
- SA1024
|
||||
# It is not possible to use '(*time.Timer).Reset''s return value correctly.
|
||||
# https://staticcheck.dev/docs/checks/#SA1025
|
||||
- SA1025
|
||||
# Cannot marshal channels or functions.
|
||||
# https://staticcheck.dev/docs/checks/#SA1026
|
||||
- SA1026
|
||||
# Atomic access to 64-bit variable must be 64-bit aligned.
|
||||
# https://staticcheck.dev/docs/checks/#SA1027
|
||||
- SA1027
|
||||
# 'sort.Slice' can only be used on slices.
|
||||
# https://staticcheck.dev/docs/checks/#SA1028
|
||||
- SA1028
|
||||
# Inappropriate key in call to 'context.WithValue'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1029
|
||||
- SA1029
|
||||
# Invalid argument in call to a 'strconv' function.
|
||||
# https://staticcheck.dev/docs/checks/#SA1030
|
||||
- SA1030
|
||||
# Overlapping byte slices passed to an encoder.
|
||||
# https://staticcheck.dev/docs/checks/#SA1031
|
||||
- SA1031
|
||||
# Wrong order of arguments to 'errors.Is'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1032
|
||||
- SA1032
|
||||
# 'sync.WaitGroup.Add' called inside the goroutine, leading to a race condition.
|
||||
# https://staticcheck.dev/docs/checks/#SA2000
|
||||
- SA2000
|
||||
# Empty critical section, did you mean to defer the unlock?.
|
||||
# https://staticcheck.dev/docs/checks/#SA2001
|
||||
- SA2001
|
||||
# Called 'testing.T.FailNow' or 'SkipNow' in a goroutine, which isn't allowed.
|
||||
# https://staticcheck.dev/docs/checks/#SA2002
|
||||
- SA2002
|
||||
# Deferred 'Lock' right after locking, likely meant to defer 'Unlock' instead.
|
||||
# https://staticcheck.dev/docs/checks/#SA2003
|
||||
- SA2003
|
||||
# 'TestMain' doesn't call 'os.Exit', hiding test failures.
|
||||
# https://staticcheck.dev/docs/checks/#SA3000
|
||||
- SA3000
|
||||
# Assigning to 'b.N' in benchmarks distorts the results.
|
||||
# https://staticcheck.dev/docs/checks/#SA3001
|
||||
- SA3001
|
||||
# Binary operator has identical expressions on both sides.
|
||||
# https://staticcheck.dev/docs/checks/#SA4000
|
||||
- SA4000
|
||||
# '&*x' gets simplified to 'x', it does not copy 'x'.
|
||||
# https://staticcheck.dev/docs/checks/#SA4001
|
||||
- SA4001
|
||||
# Comparing unsigned values against negative values is pointless.
|
||||
# https://staticcheck.dev/docs/checks/#SA4003
|
||||
- SA4003
|
||||
# The loop exits unconditionally after one iteration.
|
||||
# https://staticcheck.dev/docs/checks/#SA4004
|
||||
- SA4004
|
||||
# Field assignment that will never be observed. Did you mean to use a pointer receiver?.
|
||||
# https://staticcheck.dev/docs/checks/#SA4005
|
||||
- SA4005
|
||||
# A value assigned to a variable is never read before being overwritten. Forgotten error check or dead code?.
|
||||
# https://staticcheck.dev/docs/checks/#SA4006
|
||||
- SA4006
|
||||
# The variable in the loop condition never changes, are you incrementing the wrong variable?.
|
||||
# https://staticcheck.dev/docs/checks/#SA4008
|
||||
- SA4008
|
||||
# A function argument is overwritten before its first use.
|
||||
# https://staticcheck.dev/docs/checks/#SA4009
|
||||
- SA4009
|
||||
# The result of 'append' will never be observed anywhere.
|
||||
# https://staticcheck.dev/docs/checks/#SA4010
|
||||
- SA4010
|
||||
# Break statement with no effect. Did you mean to break out of an outer loop?.
|
||||
# https://staticcheck.dev/docs/checks/#SA4011
|
||||
- SA4011
|
||||
# Comparing a value against NaN even though no value is equal to NaN.
|
||||
# https://staticcheck.dev/docs/checks/#SA4012
|
||||
- SA4012
|
||||
# Negating a boolean twice ('!!b') is the same as writing 'b'. This is either redundant, or a typo.
|
||||
# https://staticcheck.dev/docs/checks/#SA4013
|
||||
- SA4013
|
||||
# An if/else if chain has repeated conditions and no side-effects; if the condition didn't match the first time, it won't match the second time, either.
|
||||
# https://staticcheck.dev/docs/checks/#SA4014
|
||||
- SA4014
|
||||
# Calling functions like 'math.Ceil' on floats converted from integers doesn't do anything useful.
|
||||
# https://staticcheck.dev/docs/checks/#SA4015
|
||||
- SA4015
|
||||
# Certain bitwise operations, such as 'x ^ 0', do not do anything useful.
|
||||
# https://staticcheck.dev/docs/checks/#SA4016
|
||||
- SA4016
|
||||
# Discarding the return values of a function without side effects, making the call pointless.
|
||||
# https://staticcheck.dev/docs/checks/#SA4017
|
||||
- SA4017
|
||||
# Self-assignment of variables.
|
||||
# https://staticcheck.dev/docs/checks/#SA4018
|
||||
- SA4018
|
||||
# Multiple, identical build constraints in the same file.
|
||||
# https://staticcheck.dev/docs/checks/#SA4019
|
||||
- SA4019
|
||||
# Unreachable case clause in a type switch.
|
||||
# https://staticcheck.dev/docs/checks/#SA4020
|
||||
- SA4020
|
||||
# "x = append(y)" is equivalent to "x = y".
|
||||
# https://staticcheck.dev/docs/checks/#SA4021
|
||||
- SA4021
|
||||
# Comparing the address of a variable against nil.
|
||||
# https://staticcheck.dev/docs/checks/#SA4022
|
||||
- SA4022
|
||||
# Impossible comparison of interface value with untyped nil.
|
||||
# https://staticcheck.dev/docs/checks/#SA4023
|
||||
- SA4023
|
||||
# Checking for impossible return value from a builtin function.
|
||||
# https://staticcheck.dev/docs/checks/#SA4024
|
||||
- SA4024
|
||||
# Integer division of literals that results in zero.
|
||||
# https://staticcheck.dev/docs/checks/#SA4025
|
||||
- SA4025
|
||||
# Go constants cannot express negative zero.
|
||||
# https://staticcheck.dev/docs/checks/#SA4026
|
||||
- SA4026
|
||||
# '(*net/url.URL).Query' returns a copy, modifying it doesn't change the URL.
|
||||
# https://staticcheck.dev/docs/checks/#SA4027
|
||||
- SA4027
|
||||
# 'x % 1' is always zero.
|
||||
# https://staticcheck.dev/docs/checks/#SA4028
|
||||
- SA4028
|
||||
# Ineffective attempt at sorting slice.
|
||||
# https://staticcheck.dev/docs/checks/#SA4029
|
||||
- SA4029
|
||||
# Ineffective attempt at generating random number.
|
||||
# https://staticcheck.dev/docs/checks/#SA4030
|
||||
- SA4030
|
||||
# Checking never-nil value against nil.
|
||||
# https://staticcheck.dev/docs/checks/#SA4031
|
||||
- SA4031
|
||||
# Comparing 'runtime.GOOS' or 'runtime.GOARCH' against impossible value.
|
||||
# https://staticcheck.dev/docs/checks/#SA4032
|
||||
- SA4032
|
||||
# Assignment to nil map.
|
||||
# https://staticcheck.dev/docs/checks/#SA5000
|
||||
- SA5000
|
||||
# Deferring 'Close' before checking for a possible error.
|
||||
# https://staticcheck.dev/docs/checks/#SA5001
|
||||
- SA5001
|
||||
# The empty for loop ("for {}") spins and can block the scheduler.
|
||||
# https://staticcheck.dev/docs/checks/#SA5002
|
||||
- SA5002
|
||||
# Defers in infinite loops will never execute.
|
||||
# https://staticcheck.dev/docs/checks/#SA5003
|
||||
- SA5003
|
||||
# "for { select { ..." with an empty default branch spins.
|
||||
# https://staticcheck.dev/docs/checks/#SA5004
|
||||
- SA5004
|
||||
# The finalizer references the finalized object, preventing garbage collection.
|
||||
# https://staticcheck.dev/docs/checks/#SA5005
|
||||
- SA5005
|
||||
# Infinite recursive call.
|
||||
# https://staticcheck.dev/docs/checks/#SA5007
|
||||
- SA5007
|
||||
# Invalid struct tag.
|
||||
# https://staticcheck.dev/docs/checks/#SA5008
|
||||
- SA5008
|
||||
# Invalid Printf call.
|
||||
# https://staticcheck.dev/docs/checks/#SA5009
|
||||
- SA5009
|
||||
# Impossible type assertion.
|
||||
# https://staticcheck.dev/docs/checks/#SA5010
|
||||
- SA5010
|
||||
# Possible nil pointer dereference.
|
||||
# https://staticcheck.dev/docs/checks/#SA5011
|
||||
- SA5011
|
||||
# Passing odd-sized slice to function expecting even size.
|
||||
# https://staticcheck.dev/docs/checks/#SA5012
|
||||
- SA5012
|
||||
# Using 'regexp.Match' or related in a loop, should use 'regexp.Compile'.
|
||||
# https://staticcheck.dev/docs/checks/#SA6000
|
||||
- SA6000
|
||||
# Missing an optimization opportunity when indexing maps by byte slices.
|
||||
# https://staticcheck.dev/docs/checks/#SA6001
|
||||
- SA6001
|
||||
# Storing non-pointer values in 'sync.Pool' allocates memory.
|
||||
# https://staticcheck.dev/docs/checks/#SA6002
|
||||
- SA6002
|
||||
# Converting a string to a slice of runes before ranging over it.
|
||||
# https://staticcheck.dev/docs/checks/#SA6003
|
||||
- SA6003
|
||||
# Inefficient string comparison with 'strings.ToLower' or 'strings.ToUpper'.
|
||||
# https://staticcheck.dev/docs/checks/#SA6005
|
||||
- SA6005
|
||||
# Using io.WriteString to write '[]byte'.
|
||||
# https://staticcheck.dev/docs/checks/#SA6006
|
||||
- SA6006
|
||||
# Defers in range loops may not run when you expect them to.
|
||||
# https://staticcheck.dev/docs/checks/#SA9001
|
||||
- SA9001
|
||||
# Using a non-octal 'os.FileMode' that looks like it was meant to be in octal.
|
||||
# https://staticcheck.dev/docs/checks/#SA9002
|
||||
- SA9002
|
||||
# Empty body in an if or else branch.
|
||||
# https://staticcheck.dev/docs/checks/#SA9003
|
||||
- SA9003
|
||||
# Only the first constant has an explicit type.
|
||||
# https://staticcheck.dev/docs/checks/#SA9004
|
||||
- SA9004
|
||||
# Trying to marshal a struct with no public fields nor custom marshaling.
|
||||
# https://staticcheck.dev/docs/checks/#SA9005
|
||||
- SA9005
|
||||
# Dubious bit shifting of a fixed size integer value.
|
||||
# https://staticcheck.dev/docs/checks/#SA9006
|
||||
- SA9006
|
||||
# Deleting a directory that shouldn't be deleted.
|
||||
# https://staticcheck.dev/docs/checks/#SA9007
|
||||
- SA9007
|
||||
# 'else' branch of a type assertion is probably not reading the right value.
|
||||
# https://staticcheck.dev/docs/checks/#SA9008
|
||||
- SA9008
|
||||
# Ineffectual Go compiler directive.
|
||||
# https://staticcheck.dev/docs/checks/#SA9009
|
||||
- SA9009
|
||||
# NOTE: ST1000, ST1001, ST1003, ST1020, ST1021, ST1022 are disabled above
|
||||
# Incorrectly formatted error string.
|
||||
# https://staticcheck.dev/docs/checks/#ST1005
|
||||
- ST1005
|
||||
# Poorly chosen receiver name.
|
||||
# https://staticcheck.dev/docs/checks/#ST1006
|
||||
- ST1006
|
||||
# A function's error value should be its last return value.
|
||||
# https://staticcheck.dev/docs/checks/#ST1008
|
||||
- ST1008
|
||||
# Poorly chosen name for variable of type 'time.Duration'.
|
||||
# https://staticcheck.dev/docs/checks/#ST1011
|
||||
- ST1011
|
||||
# Poorly chosen name for error variable.
|
||||
# https://staticcheck.dev/docs/checks/#ST1012
|
||||
- ST1012
|
||||
# Should use constants for HTTP error codes, not magic numbers.
|
||||
# https://staticcheck.dev/docs/checks/#ST1013
|
||||
- ST1013
|
||||
# A switch's default case should be the first or last case.
|
||||
# https://staticcheck.dev/docs/checks/#ST1015
|
||||
- ST1015
|
||||
# Use consistent method receiver names.
|
||||
# https://staticcheck.dev/docs/checks/#ST1016
|
||||
- ST1016
|
||||
# Don't use Yoda conditions.
|
||||
# https://staticcheck.dev/docs/checks/#ST1017
|
||||
- ST1017
|
||||
# Avoid zero-width and control characters in string literals.
|
||||
# https://staticcheck.dev/docs/checks/#ST1018
|
||||
- ST1018
|
||||
# Importing the same package multiple times.
|
||||
# https://staticcheck.dev/docs/checks/#ST1019
|
||||
- ST1019
|
||||
# NOTE: ST1020, ST1021, ST1022 removed (disabled above)
|
||||
# Redundant type in variable declaration.
|
||||
# https://staticcheck.dev/docs/checks/#ST1023
|
||||
- ST1023
|
||||
# Use plain channel send or receive instead of single-case select.
|
||||
# https://staticcheck.dev/docs/checks/#S1000
|
||||
- S1000
|
||||
# Replace for loop with call to copy.
|
||||
# https://staticcheck.dev/docs/checks/#S1001
|
||||
- S1001
|
||||
# Omit comparison with boolean constant.
|
||||
# https://staticcheck.dev/docs/checks/#S1002
|
||||
- S1002
|
||||
# Replace call to 'strings.Index' with 'strings.Contains'.
|
||||
# https://staticcheck.dev/docs/checks/#S1003
|
||||
- S1003
|
||||
# Replace call to 'bytes.Compare' with 'bytes.Equal'.
|
||||
# https://staticcheck.dev/docs/checks/#S1004
|
||||
- S1004
|
||||
# Drop unnecessary use of the blank identifier.
|
||||
# https://staticcheck.dev/docs/checks/#S1005
|
||||
- S1005
|
||||
# Use "for { ... }" for infinite loops.
|
||||
# https://staticcheck.dev/docs/checks/#S1006
|
||||
- S1006
|
||||
# Simplify regular expression by using raw string literal.
|
||||
# https://staticcheck.dev/docs/checks/#S1007
|
||||
- S1007
|
||||
# Simplify returning boolean expression.
|
||||
# https://staticcheck.dev/docs/checks/#S1008
|
||||
- S1008
|
||||
# Omit redundant nil check on slices, maps, and channels.
|
||||
# https://staticcheck.dev/docs/checks/#S1009
|
||||
- S1009
|
||||
# Omit default slice index.
|
||||
# https://staticcheck.dev/docs/checks/#S1010
|
||||
- S1010
|
||||
# Use a single 'append' to concatenate two slices.
|
||||
# https://staticcheck.dev/docs/checks/#S1011
|
||||
- S1011
|
||||
# Replace 'time.Now().Sub(x)' with 'time.Since(x)'.
|
||||
# https://staticcheck.dev/docs/checks/#S1012
|
||||
- S1012
|
||||
# Use a type conversion instead of manually copying struct fields.
|
||||
# https://staticcheck.dev/docs/checks/#S1016
|
||||
- S1016
|
||||
# Replace manual trimming with 'strings.TrimPrefix'.
|
||||
# https://staticcheck.dev/docs/checks/#S1017
|
||||
- S1017
|
||||
# Use "copy" for sliding elements.
|
||||
# https://staticcheck.dev/docs/checks/#S1018
|
||||
- S1018
|
||||
# Simplify "make" call by omitting redundant arguments.
|
||||
# https://staticcheck.dev/docs/checks/#S1019
|
||||
- S1019
|
||||
# Omit redundant nil check in type assertion.
|
||||
# https://staticcheck.dev/docs/checks/#S1020
|
||||
- S1020
|
||||
# Merge variable declaration and assignment.
|
||||
# https://staticcheck.dev/docs/checks/#S1021
|
||||
- S1021
|
||||
# Omit redundant control flow.
|
||||
# https://staticcheck.dev/docs/checks/#S1023
|
||||
- S1023
|
||||
# Replace 'x.Sub(time.Now())' with 'time.Until(x)'.
|
||||
# https://staticcheck.dev/docs/checks/#S1024
|
||||
- S1024
|
||||
# Don't use 'fmt.Sprintf("%s", x)' unnecessarily.
|
||||
# https://staticcheck.dev/docs/checks/#S1025
|
||||
- S1025
|
||||
# Simplify error construction with 'fmt.Errorf'.
|
||||
# https://staticcheck.dev/docs/checks/#S1028
|
||||
- S1028
|
||||
# Range over the string directly.
|
||||
# https://staticcheck.dev/docs/checks/#S1029
|
||||
- S1029
|
||||
# Use 'bytes.Buffer.String' or 'bytes.Buffer.Bytes'.
|
||||
# https://staticcheck.dev/docs/checks/#S1030
|
||||
- S1030
|
||||
# Omit redundant nil check around loop.
|
||||
# https://staticcheck.dev/docs/checks/#S1031
|
||||
- S1031
|
||||
# Use 'sort.Ints(x)', 'sort.Float64s(x)', and 'sort.Strings(x)'.
|
||||
# https://staticcheck.dev/docs/checks/#S1032
|
||||
- S1032
|
||||
# Unnecessary guard around call to "delete".
|
||||
# https://staticcheck.dev/docs/checks/#S1033
|
||||
- S1033
|
||||
# Use result of type assertion to simplify cases.
|
||||
# https://staticcheck.dev/docs/checks/#S1034
|
||||
- S1034
|
||||
# Redundant call to 'net/http.CanonicalHeaderKey' in method call on 'net/http.Header'.
|
||||
# https://staticcheck.dev/docs/checks/#S1035
|
||||
- S1035
|
||||
# Unnecessary guard around map access.
|
||||
# https://staticcheck.dev/docs/checks/#S1036
|
||||
- S1036
|
||||
# Elaborate way of sleeping.
|
||||
# https://staticcheck.dev/docs/checks/#S1037
|
||||
- S1037
|
||||
# Unnecessarily complex way of printing formatted string.
|
||||
# https://staticcheck.dev/docs/checks/#S1038
|
||||
- S1038
|
||||
# Unnecessary use of 'fmt.Sprint'.
|
||||
# https://staticcheck.dev/docs/checks/#S1039
|
||||
- S1039
|
||||
# Type assertion to current type.
|
||||
# https://staticcheck.dev/docs/checks/#S1040
|
||||
- S1040
|
||||
# Apply De Morgan's law.
|
||||
# https://staticcheck.dev/docs/checks/#QF1001
|
||||
- QF1001
|
||||
# Convert untagged switch to tagged switch.
|
||||
# https://staticcheck.dev/docs/checks/#QF1002
|
||||
- QF1002
|
||||
# Convert if/else-if chain to tagged switch.
|
||||
# https://staticcheck.dev/docs/checks/#QF1003
|
||||
- QF1003
|
||||
# Use 'strings.ReplaceAll' instead of 'strings.Replace' with 'n == -1'.
|
||||
# https://staticcheck.dev/docs/checks/#QF1004
|
||||
- QF1004
|
||||
# Expand call to 'math.Pow'.
|
||||
# https://staticcheck.dev/docs/checks/#QF1005
|
||||
- QF1005
|
||||
# Lift 'if'+'break' into loop condition.
|
||||
# https://staticcheck.dev/docs/checks/#QF1006
|
||||
- QF1006
|
||||
# Merge conditional assignment into variable declaration.
|
||||
# https://staticcheck.dev/docs/checks/#QF1007
|
||||
- QF1007
|
||||
# Omit embedded fields from selector expression.
|
||||
# https://staticcheck.dev/docs/checks/#QF1008
|
||||
- QF1008
|
||||
# Use 'time.Time.Equal' instead of '==' operator.
|
||||
# https://staticcheck.dev/docs/checks/#QF1009
|
||||
- QF1009
|
||||
# Convert slice of bytes to string when printing it.
|
||||
# https://staticcheck.dev/docs/checks/#QF1010
|
||||
- QF1010
|
||||
# Omit redundant type from variable declaration.
|
||||
# https://staticcheck.dev/docs/checks/#QF1011
|
||||
- QF1011
|
||||
# Use 'fmt.Fprintf(x, ...)' instead of 'x.Write(fmt.Sprintf(...))'.
|
||||
# https://staticcheck.dev/docs/checks/#QF1012
|
||||
- QF1012
|
||||
unused:
|
||||
# Mark all struct fields that have been written to as used.
|
||||
# Default: true
|
||||
field-writes-are-uses: false
|
||||
# Treat IncDec statement (e.g. `i++` or `i--`) as both read and write operation instead of just write.
|
||||
field-writes-are-uses: true
|
||||
# Default: false
|
||||
post-statements-are-reads: true
|
||||
# Mark all exported fields as used.
|
||||
# default: true
|
||||
exported-fields-are-used: false
|
||||
# Mark all function parameters as used.
|
||||
# default: true
|
||||
parameters-are-used: true
|
||||
# Mark all local variables as used.
|
||||
# default: true
|
||||
local-variables-are-used: false
|
||||
# Mark all identifiers inside generated files as used.
|
||||
# Default: true
|
||||
generated-is-used: false
|
||||
exported-fields-are-used: true
|
||||
# Default: true
|
||||
parameters-are-used: true
|
||||
# Default: true
|
||||
local-variables-are-used: false
|
||||
# Default: true — must be true, ent generates 130K+ lines of code
|
||||
generated-is-used: true
|
||||
|
||||
formatters:
|
||||
enable:
|
||||
|
||||
@@ -86,6 +86,7 @@ func provideCleanup(
|
||||
geminiOAuth *service.GeminiOAuthService,
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -216,6 +217,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"ScheduledTestRunnerService", func() error {
|
||||
if scheduledTestRunner != nil {
|
||||
scheduledTestRunner.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
}
|
||||
|
||||
infraSteps := []cleanupStep{
|
||||
|
||||
@@ -164,7 +164,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
digestSessionStore := service.NewDigestSessionStore()
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore)
|
||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||
@@ -195,7 +195,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache)
|
||||
errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService)
|
||||
adminAPIKeyHandler := admin.NewAdminAPIKeyHandler(adminService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler)
|
||||
scheduledTestPlanRepository := repository.NewScheduledTestPlanRepository(db)
|
||||
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
|
||||
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
|
||||
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
|
||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||
@@ -225,7 +229,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, configConfig)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
@@ -273,6 +278,7 @@ func provideCleanup(
|
||||
geminiOAuth *service.GeminiOAuthService,
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -402,6 +408,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"ScheduledTestRunnerService", func() error {
|
||||
if scheduledTestRunner != nil {
|
||||
scheduledTestRunner.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
}
|
||||
|
||||
infraSteps := []cleanupStep{
|
||||
|
||||
@@ -74,6 +74,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||
geminiOAuthSvc,
|
||||
antigravityOAuthSvc,
|
||||
nil, // openAIGateway
|
||||
nil, // scheduledTestRunner
|
||||
)
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
|
||||
@@ -41,6 +41,8 @@ type Account struct {
|
||||
ProxyID *int64 `json:"proxy_id,omitempty"`
|
||||
// Concurrency holds the value of the "concurrency" field.
|
||||
Concurrency int `json:"concurrency,omitempty"`
|
||||
// LoadFactor holds the value of the "load_factor" field.
|
||||
LoadFactor *int `json:"load_factor,omitempty"`
|
||||
// Priority holds the value of the "priority" field.
|
||||
Priority int `json:"priority,omitempty"`
|
||||
// RateMultiplier holds the value of the "rate_multiplier" field.
|
||||
@@ -143,7 +145,7 @@ func (*Account) scanValues(columns []string) ([]any, error) {
|
||||
values[i] = new(sql.NullBool)
|
||||
case account.FieldRateMultiplier:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldPriority:
|
||||
case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldLoadFactor, account.FieldPriority:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldTempUnschedulableReason, account.FieldSessionWindowStatus:
|
||||
values[i] = new(sql.NullString)
|
||||
@@ -243,6 +245,13 @@ func (_m *Account) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.Concurrency = int(value.Int64)
|
||||
}
|
||||
case account.FieldLoadFactor:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field load_factor", values[i])
|
||||
} else if value.Valid {
|
||||
_m.LoadFactor = new(int)
|
||||
*_m.LoadFactor = int(value.Int64)
|
||||
}
|
||||
case account.FieldPriority:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field priority", values[i])
|
||||
@@ -445,6 +454,11 @@ func (_m *Account) String() string {
|
||||
builder.WriteString("concurrency=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.Concurrency))
|
||||
builder.WriteString(", ")
|
||||
if v := _m.LoadFactor; v != nil {
|
||||
builder.WriteString("load_factor=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("priority=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.Priority))
|
||||
builder.WriteString(", ")
|
||||
|
||||
@@ -37,6 +37,8 @@ const (
|
||||
FieldProxyID = "proxy_id"
|
||||
// FieldConcurrency holds the string denoting the concurrency field in the database.
|
||||
FieldConcurrency = "concurrency"
|
||||
// FieldLoadFactor holds the string denoting the load_factor field in the database.
|
||||
FieldLoadFactor = "load_factor"
|
||||
// FieldPriority holds the string denoting the priority field in the database.
|
||||
FieldPriority = "priority"
|
||||
// FieldRateMultiplier holds the string denoting the rate_multiplier field in the database.
|
||||
@@ -121,6 +123,7 @@ var Columns = []string{
|
||||
FieldExtra,
|
||||
FieldProxyID,
|
||||
FieldConcurrency,
|
||||
FieldLoadFactor,
|
||||
FieldPriority,
|
||||
FieldRateMultiplier,
|
||||
FieldStatus,
|
||||
@@ -250,6 +253,11 @@ func ByConcurrency(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldConcurrency, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByLoadFactor orders the results by the load_factor field.
|
||||
func ByLoadFactor(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldLoadFactor, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByPriority orders the results by the priority field.
|
||||
func ByPriority(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldPriority, opts...).ToFunc()
|
||||
|
||||
@@ -100,6 +100,11 @@ func Concurrency(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldConcurrency, v))
|
||||
}
|
||||
|
||||
// LoadFactor applies equality check predicate on the "load_factor" field. It's identical to LoadFactorEQ.
|
||||
func LoadFactor(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldLoadFactor, v))
|
||||
}
|
||||
|
||||
// Priority applies equality check predicate on the "priority" field. It's identical to PriorityEQ.
|
||||
func Priority(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldPriority, v))
|
||||
@@ -650,6 +655,56 @@ func ConcurrencyLTE(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldLTE(FieldConcurrency, v))
|
||||
}
|
||||
|
||||
// LoadFactorEQ applies the EQ predicate on the "load_factor" field.
|
||||
func LoadFactorEQ(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldLoadFactor, v))
|
||||
}
|
||||
|
||||
// LoadFactorNEQ applies the NEQ predicate on the "load_factor" field.
|
||||
func LoadFactorNEQ(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldNEQ(FieldLoadFactor, v))
|
||||
}
|
||||
|
||||
// LoadFactorIn applies the In predicate on the "load_factor" field.
|
||||
func LoadFactorIn(vs ...int) predicate.Account {
|
||||
return predicate.Account(sql.FieldIn(FieldLoadFactor, vs...))
|
||||
}
|
||||
|
||||
// LoadFactorNotIn applies the NotIn predicate on the "load_factor" field.
|
||||
func LoadFactorNotIn(vs ...int) predicate.Account {
|
||||
return predicate.Account(sql.FieldNotIn(FieldLoadFactor, vs...))
|
||||
}
|
||||
|
||||
// LoadFactorGT applies the GT predicate on the "load_factor" field.
|
||||
func LoadFactorGT(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldGT(FieldLoadFactor, v))
|
||||
}
|
||||
|
||||
// LoadFactorGTE applies the GTE predicate on the "load_factor" field.
|
||||
func LoadFactorGTE(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldGTE(FieldLoadFactor, v))
|
||||
}
|
||||
|
||||
// LoadFactorLT applies the LT predicate on the "load_factor" field.
|
||||
func LoadFactorLT(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldLT(FieldLoadFactor, v))
|
||||
}
|
||||
|
||||
// LoadFactorLTE applies the LTE predicate on the "load_factor" field.
|
||||
func LoadFactorLTE(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldLTE(FieldLoadFactor, v))
|
||||
}
|
||||
|
||||
// LoadFactorIsNil applies the IsNil predicate on the "load_factor" field.
|
||||
func LoadFactorIsNil() predicate.Account {
|
||||
return predicate.Account(sql.FieldIsNull(FieldLoadFactor))
|
||||
}
|
||||
|
||||
// LoadFactorNotNil applies the NotNil predicate on the "load_factor" field.
|
||||
func LoadFactorNotNil() predicate.Account {
|
||||
return predicate.Account(sql.FieldNotNull(FieldLoadFactor))
|
||||
}
|
||||
|
||||
// PriorityEQ applies the EQ predicate on the "priority" field.
|
||||
func PriorityEQ(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldPriority, v))
|
||||
|
||||
@@ -139,6 +139,20 @@ func (_c *AccountCreate) SetNillableConcurrency(v *int) *AccountCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetLoadFactor sets the "load_factor" field.
|
||||
func (_c *AccountCreate) SetLoadFactor(v int) *AccountCreate {
|
||||
_c.mutation.SetLoadFactor(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableLoadFactor sets the "load_factor" field if the given value is not nil.
|
||||
func (_c *AccountCreate) SetNillableLoadFactor(v *int) *AccountCreate {
|
||||
if v != nil {
|
||||
_c.SetLoadFactor(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (_c *AccountCreate) SetPriority(v int) *AccountCreate {
|
||||
_c.mutation.SetPriority(v)
|
||||
@@ -623,6 +637,10 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(account.FieldConcurrency, field.TypeInt, value)
|
||||
_node.Concurrency = value
|
||||
}
|
||||
if value, ok := _c.mutation.LoadFactor(); ok {
|
||||
_spec.SetField(account.FieldLoadFactor, field.TypeInt, value)
|
||||
_node.LoadFactor = &value
|
||||
}
|
||||
if value, ok := _c.mutation.Priority(); ok {
|
||||
_spec.SetField(account.FieldPriority, field.TypeInt, value)
|
||||
_node.Priority = value
|
||||
@@ -936,6 +954,30 @@ func (u *AccountUpsert) AddConcurrency(v int) *AccountUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetLoadFactor sets the "load_factor" field.
|
||||
func (u *AccountUpsert) SetLoadFactor(v int) *AccountUpsert {
|
||||
u.Set(account.FieldLoadFactor, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateLoadFactor sets the "load_factor" field to the value that was provided on create.
|
||||
func (u *AccountUpsert) UpdateLoadFactor() *AccountUpsert {
|
||||
u.SetExcluded(account.FieldLoadFactor)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddLoadFactor adds v to the "load_factor" field.
|
||||
func (u *AccountUpsert) AddLoadFactor(v int) *AccountUpsert {
|
||||
u.Add(account.FieldLoadFactor, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearLoadFactor clears the value of the "load_factor" field.
|
||||
func (u *AccountUpsert) ClearLoadFactor() *AccountUpsert {
|
||||
u.SetNull(account.FieldLoadFactor)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (u *AccountUpsert) SetPriority(v int) *AccountUpsert {
|
||||
u.Set(account.FieldPriority, v)
|
||||
@@ -1419,6 +1461,34 @@ func (u *AccountUpsertOne) UpdateConcurrency() *AccountUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetLoadFactor sets the "load_factor" field.
|
||||
func (u *AccountUpsertOne) SetLoadFactor(v int) *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.SetLoadFactor(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddLoadFactor adds v to the "load_factor" field.
|
||||
func (u *AccountUpsertOne) AddLoadFactor(v int) *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.AddLoadFactor(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateLoadFactor sets the "load_factor" field to the value that was provided on create.
|
||||
func (u *AccountUpsertOne) UpdateLoadFactor() *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.UpdateLoadFactor()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearLoadFactor clears the value of the "load_factor" field.
|
||||
func (u *AccountUpsertOne) ClearLoadFactor() *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.ClearLoadFactor()
|
||||
})
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (u *AccountUpsertOne) SetPriority(v int) *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
@@ -2113,6 +2183,34 @@ func (u *AccountUpsertBulk) UpdateConcurrency() *AccountUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetLoadFactor sets the "load_factor" field.
|
||||
func (u *AccountUpsertBulk) SetLoadFactor(v int) *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.SetLoadFactor(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddLoadFactor adds v to the "load_factor" field.
|
||||
func (u *AccountUpsertBulk) AddLoadFactor(v int) *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.AddLoadFactor(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateLoadFactor sets the "load_factor" field to the value that was provided on create.
|
||||
func (u *AccountUpsertBulk) UpdateLoadFactor() *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.UpdateLoadFactor()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearLoadFactor clears the value of the "load_factor" field.
|
||||
func (u *AccountUpsertBulk) ClearLoadFactor() *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.ClearLoadFactor()
|
||||
})
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (u *AccountUpsertBulk) SetPriority(v int) *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
|
||||
@@ -172,6 +172,33 @@ func (_u *AccountUpdate) AddConcurrency(v int) *AccountUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetLoadFactor sets the "load_factor" field.
|
||||
func (_u *AccountUpdate) SetLoadFactor(v int) *AccountUpdate {
|
||||
_u.mutation.ResetLoadFactor()
|
||||
_u.mutation.SetLoadFactor(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableLoadFactor sets the "load_factor" field if the given value is not nil.
|
||||
func (_u *AccountUpdate) SetNillableLoadFactor(v *int) *AccountUpdate {
|
||||
if v != nil {
|
||||
_u.SetLoadFactor(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddLoadFactor adds value to the "load_factor" field.
|
||||
func (_u *AccountUpdate) AddLoadFactor(v int) *AccountUpdate {
|
||||
_u.mutation.AddLoadFactor(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearLoadFactor clears the value of the "load_factor" field.
|
||||
func (_u *AccountUpdate) ClearLoadFactor() *AccountUpdate {
|
||||
_u.mutation.ClearLoadFactor()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (_u *AccountUpdate) SetPriority(v int) *AccountUpdate {
|
||||
_u.mutation.ResetPriority()
|
||||
@@ -684,6 +711,15 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if value, ok := _u.mutation.AddedConcurrency(); ok {
|
||||
_spec.AddField(account.FieldConcurrency, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.LoadFactor(); ok {
|
||||
_spec.SetField(account.FieldLoadFactor, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedLoadFactor(); ok {
|
||||
_spec.AddField(account.FieldLoadFactor, field.TypeInt, value)
|
||||
}
|
||||
if _u.mutation.LoadFactorCleared() {
|
||||
_spec.ClearField(account.FieldLoadFactor, field.TypeInt)
|
||||
}
|
||||
if value, ok := _u.mutation.Priority(); ok {
|
||||
_spec.SetField(account.FieldPriority, field.TypeInt, value)
|
||||
}
|
||||
@@ -1063,6 +1099,33 @@ func (_u *AccountUpdateOne) AddConcurrency(v int) *AccountUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetLoadFactor sets the "load_factor" field.
|
||||
func (_u *AccountUpdateOne) SetLoadFactor(v int) *AccountUpdateOne {
|
||||
_u.mutation.ResetLoadFactor()
|
||||
_u.mutation.SetLoadFactor(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableLoadFactor sets the "load_factor" field if the given value is not nil.
|
||||
func (_u *AccountUpdateOne) SetNillableLoadFactor(v *int) *AccountUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetLoadFactor(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddLoadFactor adds value to the "load_factor" field.
|
||||
func (_u *AccountUpdateOne) AddLoadFactor(v int) *AccountUpdateOne {
|
||||
_u.mutation.AddLoadFactor(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearLoadFactor clears the value of the "load_factor" field.
|
||||
func (_u *AccountUpdateOne) ClearLoadFactor() *AccountUpdateOne {
|
||||
_u.mutation.ClearLoadFactor()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (_u *AccountUpdateOne) SetPriority(v int) *AccountUpdateOne {
|
||||
_u.mutation.ResetPriority()
|
||||
@@ -1605,6 +1668,15 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er
|
||||
if value, ok := _u.mutation.AddedConcurrency(); ok {
|
||||
_spec.AddField(account.FieldConcurrency, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.LoadFactor(); ok {
|
||||
_spec.SetField(account.FieldLoadFactor, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedLoadFactor(); ok {
|
||||
_spec.AddField(account.FieldLoadFactor, field.TypeInt, value)
|
||||
}
|
||||
if _u.mutation.LoadFactorCleared() {
|
||||
_spec.ClearField(account.FieldLoadFactor, field.TypeInt)
|
||||
}
|
||||
if value, ok := _u.mutation.Priority(); ok {
|
||||
_spec.SetField(account.FieldPriority, field.TypeInt, value)
|
||||
}
|
||||
|
||||
@@ -25,6 +25,8 @@ type Announcement struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
// 状态: draft, active, archived
|
||||
Status string `json:"status,omitempty"`
|
||||
// 通知模式: silent(仅铃铛), popup(弹窗提醒)
|
||||
NotifyMode string `json:"notify_mode,omitempty"`
|
||||
// 展示条件(JSON 规则)
|
||||
Targeting domain.AnnouncementTargeting `json:"targeting,omitempty"`
|
||||
// 开始展示时间(为空表示立即生效)
|
||||
@@ -72,7 +74,7 @@ func (*Announcement) scanValues(columns []string) ([]any, error) {
|
||||
values[i] = new([]byte)
|
||||
case announcement.FieldID, announcement.FieldCreatedBy, announcement.FieldUpdatedBy:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case announcement.FieldTitle, announcement.FieldContent, announcement.FieldStatus:
|
||||
case announcement.FieldTitle, announcement.FieldContent, announcement.FieldStatus, announcement.FieldNotifyMode:
|
||||
values[i] = new(sql.NullString)
|
||||
case announcement.FieldStartsAt, announcement.FieldEndsAt, announcement.FieldCreatedAt, announcement.FieldUpdatedAt:
|
||||
values[i] = new(sql.NullTime)
|
||||
@@ -115,6 +117,12 @@ func (_m *Announcement) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.Status = value.String
|
||||
}
|
||||
case announcement.FieldNotifyMode:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field notify_mode", values[i])
|
||||
} else if value.Valid {
|
||||
_m.NotifyMode = value.String
|
||||
}
|
||||
case announcement.FieldTargeting:
|
||||
if value, ok := values[i].(*[]byte); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field targeting", values[i])
|
||||
@@ -213,6 +221,9 @@ func (_m *Announcement) String() string {
|
||||
builder.WriteString("status=")
|
||||
builder.WriteString(_m.Status)
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("notify_mode=")
|
||||
builder.WriteString(_m.NotifyMode)
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("targeting=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.Targeting))
|
||||
builder.WriteString(", ")
|
||||
|
||||
@@ -20,6 +20,8 @@ const (
|
||||
FieldContent = "content"
|
||||
// FieldStatus holds the string denoting the status field in the database.
|
||||
FieldStatus = "status"
|
||||
// FieldNotifyMode holds the string denoting the notify_mode field in the database.
|
||||
FieldNotifyMode = "notify_mode"
|
||||
// FieldTargeting holds the string denoting the targeting field in the database.
|
||||
FieldTargeting = "targeting"
|
||||
// FieldStartsAt holds the string denoting the starts_at field in the database.
|
||||
@@ -53,6 +55,7 @@ var Columns = []string{
|
||||
FieldTitle,
|
||||
FieldContent,
|
||||
FieldStatus,
|
||||
FieldNotifyMode,
|
||||
FieldTargeting,
|
||||
FieldStartsAt,
|
||||
FieldEndsAt,
|
||||
@@ -81,6 +84,10 @@ var (
|
||||
DefaultStatus string
|
||||
// StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
||||
StatusValidator func(string) error
|
||||
// DefaultNotifyMode holds the default value on creation for the "notify_mode" field.
|
||||
DefaultNotifyMode string
|
||||
// NotifyModeValidator is a validator for the "notify_mode" field. It is called by the builders before save.
|
||||
NotifyModeValidator func(string) error
|
||||
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
|
||||
DefaultCreatedAt func() time.Time
|
||||
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
|
||||
@@ -112,6 +119,11 @@ func ByStatus(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldStatus, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByNotifyMode orders the results by the notify_mode field.
|
||||
func ByNotifyMode(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldNotifyMode, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByStartsAt orders the results by the starts_at field.
|
||||
func ByStartsAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldStartsAt, opts...).ToFunc()
|
||||
|
||||
@@ -70,6 +70,11 @@ func Status(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldEQ(FieldStatus, v))
|
||||
}
|
||||
|
||||
// NotifyMode applies equality check predicate on the "notify_mode" field. It's identical to NotifyModeEQ.
|
||||
func NotifyMode(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldEQ(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// StartsAt applies equality check predicate on the "starts_at" field. It's identical to StartsAtEQ.
|
||||
func StartsAt(v time.Time) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldEQ(FieldStartsAt, v))
|
||||
@@ -295,6 +300,71 @@ func StatusContainsFold(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldContainsFold(FieldStatus, v))
|
||||
}
|
||||
|
||||
// NotifyModeEQ applies the EQ predicate on the "notify_mode" field.
|
||||
func NotifyModeEQ(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldEQ(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeNEQ applies the NEQ predicate on the "notify_mode" field.
|
||||
func NotifyModeNEQ(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldNEQ(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeIn applies the In predicate on the "notify_mode" field.
|
||||
func NotifyModeIn(vs ...string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldIn(FieldNotifyMode, vs...))
|
||||
}
|
||||
|
||||
// NotifyModeNotIn applies the NotIn predicate on the "notify_mode" field.
|
||||
func NotifyModeNotIn(vs ...string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldNotIn(FieldNotifyMode, vs...))
|
||||
}
|
||||
|
||||
// NotifyModeGT applies the GT predicate on the "notify_mode" field.
|
||||
func NotifyModeGT(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldGT(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeGTE applies the GTE predicate on the "notify_mode" field.
|
||||
func NotifyModeGTE(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldGTE(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeLT applies the LT predicate on the "notify_mode" field.
|
||||
func NotifyModeLT(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldLT(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeLTE applies the LTE predicate on the "notify_mode" field.
|
||||
func NotifyModeLTE(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldLTE(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeContains applies the Contains predicate on the "notify_mode" field.
|
||||
func NotifyModeContains(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldContains(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeHasPrefix applies the HasPrefix predicate on the "notify_mode" field.
|
||||
func NotifyModeHasPrefix(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldHasPrefix(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeHasSuffix applies the HasSuffix predicate on the "notify_mode" field.
|
||||
func NotifyModeHasSuffix(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldHasSuffix(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeEqualFold applies the EqualFold predicate on the "notify_mode" field.
|
||||
func NotifyModeEqualFold(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldEqualFold(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeContainsFold applies the ContainsFold predicate on the "notify_mode" field.
|
||||
func NotifyModeContainsFold(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldContainsFold(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// TargetingIsNil applies the IsNil predicate on the "targeting" field.
|
||||
func TargetingIsNil() predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldIsNull(FieldTargeting))
|
||||
|
||||
@@ -50,6 +50,20 @@ func (_c *AnnouncementCreate) SetNillableStatus(v *string) *AnnouncementCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNotifyMode sets the "notify_mode" field.
|
||||
func (_c *AnnouncementCreate) SetNotifyMode(v string) *AnnouncementCreate {
|
||||
_c.mutation.SetNotifyMode(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableNotifyMode sets the "notify_mode" field if the given value is not nil.
|
||||
func (_c *AnnouncementCreate) SetNillableNotifyMode(v *string) *AnnouncementCreate {
|
||||
if v != nil {
|
||||
_c.SetNotifyMode(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetTargeting sets the "targeting" field.
|
||||
func (_c *AnnouncementCreate) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementCreate {
|
||||
_c.mutation.SetTargeting(v)
|
||||
@@ -202,6 +216,10 @@ func (_c *AnnouncementCreate) defaults() {
|
||||
v := announcement.DefaultStatus
|
||||
_c.mutation.SetStatus(v)
|
||||
}
|
||||
if _, ok := _c.mutation.NotifyMode(); !ok {
|
||||
v := announcement.DefaultNotifyMode
|
||||
_c.mutation.SetNotifyMode(v)
|
||||
}
|
||||
if _, ok := _c.mutation.CreatedAt(); !ok {
|
||||
v := announcement.DefaultCreatedAt()
|
||||
_c.mutation.SetCreatedAt(v)
|
||||
@@ -238,6 +256,14 @@ func (_c *AnnouncementCreate) check() error {
|
||||
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Announcement.status": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _, ok := _c.mutation.NotifyMode(); !ok {
|
||||
return &ValidationError{Name: "notify_mode", err: errors.New(`ent: missing required field "Announcement.notify_mode"`)}
|
||||
}
|
||||
if v, ok := _c.mutation.NotifyMode(); ok {
|
||||
if err := announcement.NotifyModeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "notify_mode", err: fmt.Errorf(`ent: validator failed for field "Announcement.notify_mode": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _, ok := _c.mutation.CreatedAt(); !ok {
|
||||
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Announcement.created_at"`)}
|
||||
}
|
||||
@@ -283,6 +309,10 @@ func (_c *AnnouncementCreate) createSpec() (*Announcement, *sqlgraph.CreateSpec)
|
||||
_spec.SetField(announcement.FieldStatus, field.TypeString, value)
|
||||
_node.Status = value
|
||||
}
|
||||
if value, ok := _c.mutation.NotifyMode(); ok {
|
||||
_spec.SetField(announcement.FieldNotifyMode, field.TypeString, value)
|
||||
_node.NotifyMode = value
|
||||
}
|
||||
if value, ok := _c.mutation.Targeting(); ok {
|
||||
_spec.SetField(announcement.FieldTargeting, field.TypeJSON, value)
|
||||
_node.Targeting = value
|
||||
@@ -415,6 +445,18 @@ func (u *AnnouncementUpsert) UpdateStatus() *AnnouncementUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetNotifyMode sets the "notify_mode" field.
|
||||
func (u *AnnouncementUpsert) SetNotifyMode(v string) *AnnouncementUpsert {
|
||||
u.Set(announcement.FieldNotifyMode, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateNotifyMode sets the "notify_mode" field to the value that was provided on create.
|
||||
func (u *AnnouncementUpsert) UpdateNotifyMode() *AnnouncementUpsert {
|
||||
u.SetExcluded(announcement.FieldNotifyMode)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetTargeting sets the "targeting" field.
|
||||
func (u *AnnouncementUpsert) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpsert {
|
||||
u.Set(announcement.FieldTargeting, v)
|
||||
@@ -616,6 +658,20 @@ func (u *AnnouncementUpsertOne) UpdateStatus() *AnnouncementUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetNotifyMode sets the "notify_mode" field.
|
||||
func (u *AnnouncementUpsertOne) SetNotifyMode(v string) *AnnouncementUpsertOne {
|
||||
return u.Update(func(s *AnnouncementUpsert) {
|
||||
s.SetNotifyMode(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateNotifyMode sets the "notify_mode" field to the value that was provided on create.
|
||||
func (u *AnnouncementUpsertOne) UpdateNotifyMode() *AnnouncementUpsertOne {
|
||||
return u.Update(func(s *AnnouncementUpsert) {
|
||||
s.UpdateNotifyMode()
|
||||
})
|
||||
}
|
||||
|
||||
// SetTargeting sets the "targeting" field.
|
||||
func (u *AnnouncementUpsertOne) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpsertOne {
|
||||
return u.Update(func(s *AnnouncementUpsert) {
|
||||
@@ -1002,6 +1058,20 @@ func (u *AnnouncementUpsertBulk) UpdateStatus() *AnnouncementUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetNotifyMode sets the "notify_mode" field.
|
||||
func (u *AnnouncementUpsertBulk) SetNotifyMode(v string) *AnnouncementUpsertBulk {
|
||||
return u.Update(func(s *AnnouncementUpsert) {
|
||||
s.SetNotifyMode(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateNotifyMode sets the "notify_mode" field to the value that was provided on create.
|
||||
func (u *AnnouncementUpsertBulk) UpdateNotifyMode() *AnnouncementUpsertBulk {
|
||||
return u.Update(func(s *AnnouncementUpsert) {
|
||||
s.UpdateNotifyMode()
|
||||
})
|
||||
}
|
||||
|
||||
// SetTargeting sets the "targeting" field.
|
||||
func (u *AnnouncementUpsertBulk) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpsertBulk {
|
||||
return u.Update(func(s *AnnouncementUpsert) {
|
||||
|
||||
@@ -72,6 +72,20 @@ func (_u *AnnouncementUpdate) SetNillableStatus(v *string) *AnnouncementUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNotifyMode sets the "notify_mode" field.
|
||||
func (_u *AnnouncementUpdate) SetNotifyMode(v string) *AnnouncementUpdate {
|
||||
_u.mutation.SetNotifyMode(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableNotifyMode sets the "notify_mode" field if the given value is not nil.
|
||||
func (_u *AnnouncementUpdate) SetNillableNotifyMode(v *string) *AnnouncementUpdate {
|
||||
if v != nil {
|
||||
_u.SetNotifyMode(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetTargeting sets the "targeting" field.
|
||||
func (_u *AnnouncementUpdate) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpdate {
|
||||
_u.mutation.SetTargeting(v)
|
||||
@@ -286,6 +300,11 @@ func (_u *AnnouncementUpdate) check() error {
|
||||
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Announcement.status": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.NotifyMode(); ok {
|
||||
if err := announcement.NotifyModeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "notify_mode", err: fmt.Errorf(`ent: validator failed for field "Announcement.notify_mode": %w`, err)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -310,6 +329,9 @@ func (_u *AnnouncementUpdate) sqlSave(ctx context.Context) (_node int, err error
|
||||
if value, ok := _u.mutation.Status(); ok {
|
||||
_spec.SetField(announcement.FieldStatus, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.NotifyMode(); ok {
|
||||
_spec.SetField(announcement.FieldNotifyMode, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Targeting(); ok {
|
||||
_spec.SetField(announcement.FieldTargeting, field.TypeJSON, value)
|
||||
}
|
||||
@@ -456,6 +478,20 @@ func (_u *AnnouncementUpdateOne) SetNillableStatus(v *string) *AnnouncementUpdat
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNotifyMode sets the "notify_mode" field.
|
||||
func (_u *AnnouncementUpdateOne) SetNotifyMode(v string) *AnnouncementUpdateOne {
|
||||
_u.mutation.SetNotifyMode(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableNotifyMode sets the "notify_mode" field if the given value is not nil.
|
||||
func (_u *AnnouncementUpdateOne) SetNillableNotifyMode(v *string) *AnnouncementUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetNotifyMode(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetTargeting sets the "targeting" field.
|
||||
func (_u *AnnouncementUpdateOne) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpdateOne {
|
||||
_u.mutation.SetTargeting(v)
|
||||
@@ -683,6 +719,11 @@ func (_u *AnnouncementUpdateOne) check() error {
|
||||
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Announcement.status": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.NotifyMode(); ok {
|
||||
if err := announcement.NotifyModeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "notify_mode", err: fmt.Errorf(`ent: validator failed for field "Announcement.notify_mode": %w`, err)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -724,6 +765,9 @@ func (_u *AnnouncementUpdateOne) sqlSave(ctx context.Context) (_node *Announceme
|
||||
if value, ok := _u.mutation.Status(); ok {
|
||||
_spec.SetField(announcement.FieldStatus, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.NotifyMode(); ok {
|
||||
_spec.SetField(announcement.FieldNotifyMode, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Targeting(); ok {
|
||||
_spec.SetField(announcement.FieldTargeting, field.TypeJSON, value)
|
||||
}
|
||||
|
||||
@@ -78,6 +78,10 @@ type Group struct {
|
||||
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
|
||||
// 分组显示排序,数值越小越靠前
|
||||
SortOrder int `json:"sort_order,omitempty"`
|
||||
// 是否允许 /v1/messages 调度到此 OpenAI 分组
|
||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch,omitempty"`
|
||||
// 默认映射模型 ID,当账号级映射找不到时使用此值
|
||||
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
|
||||
// Edges holds the relations/edges for other nodes in the graph.
|
||||
// The values are being populated by the GroupQuery when eager-loading is set.
|
||||
Edges GroupEdges `json:"edges"`
|
||||
@@ -186,13 +190,13 @@ func (*Group) scanValues(columns []string) ([]any, error) {
|
||||
switch columns[i] {
|
||||
case group.FieldModelRouting, group.FieldSupportedModelScopes:
|
||||
values[i] = new([]byte)
|
||||
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject:
|
||||
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch:
|
||||
values[i] = new(sql.NullBool)
|
||||
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case group.FieldID, group.FieldDefaultValidityDays, group.FieldSoraStorageQuotaBytes, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType:
|
||||
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType, group.FieldDefaultMappedModel:
|
||||
values[i] = new(sql.NullString)
|
||||
case group.FieldCreatedAt, group.FieldUpdatedAt, group.FieldDeletedAt:
|
||||
values[i] = new(sql.NullTime)
|
||||
@@ -415,6 +419,18 @@ func (_m *Group) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.SortOrder = int(value.Int64)
|
||||
}
|
||||
case group.FieldAllowMessagesDispatch:
|
||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field allow_messages_dispatch", values[i])
|
||||
} else if value.Valid {
|
||||
_m.AllowMessagesDispatch = value.Bool
|
||||
}
|
||||
case group.FieldDefaultMappedModel:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field default_mapped_model", values[i])
|
||||
} else if value.Valid {
|
||||
_m.DefaultMappedModel = value.String
|
||||
}
|
||||
default:
|
||||
_m.selectValues.Set(columns[i], values[i])
|
||||
}
|
||||
@@ -608,6 +624,12 @@ func (_m *Group) String() string {
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("sort_order=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.SortOrder))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("allow_messages_dispatch=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.AllowMessagesDispatch))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("default_mapped_model=")
|
||||
builder.WriteString(_m.DefaultMappedModel)
|
||||
builder.WriteByte(')')
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
@@ -75,6 +75,10 @@ const (
|
||||
FieldSupportedModelScopes = "supported_model_scopes"
|
||||
// FieldSortOrder holds the string denoting the sort_order field in the database.
|
||||
FieldSortOrder = "sort_order"
|
||||
// FieldAllowMessagesDispatch holds the string denoting the allow_messages_dispatch field in the database.
|
||||
FieldAllowMessagesDispatch = "allow_messages_dispatch"
|
||||
// FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database.
|
||||
FieldDefaultMappedModel = "default_mapped_model"
|
||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||
EdgeAPIKeys = "api_keys"
|
||||
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
||||
@@ -180,6 +184,8 @@ var Columns = []string{
|
||||
FieldMcpXMLInject,
|
||||
FieldSupportedModelScopes,
|
||||
FieldSortOrder,
|
||||
FieldAllowMessagesDispatch,
|
||||
FieldDefaultMappedModel,
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -247,6 +253,12 @@ var (
|
||||
DefaultSupportedModelScopes []string
|
||||
// DefaultSortOrder holds the default value on creation for the "sort_order" field.
|
||||
DefaultSortOrder int
|
||||
// DefaultAllowMessagesDispatch holds the default value on creation for the "allow_messages_dispatch" field.
|
||||
DefaultAllowMessagesDispatch bool
|
||||
// DefaultDefaultMappedModel holds the default value on creation for the "default_mapped_model" field.
|
||||
DefaultDefaultMappedModel string
|
||||
// DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
||||
DefaultMappedModelValidator func(string) error
|
||||
)
|
||||
|
||||
// OrderOption defines the ordering options for the Group queries.
|
||||
@@ -397,6 +409,16 @@ func BySortOrder(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldSortOrder, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByAllowMessagesDispatch orders the results by the allow_messages_dispatch field.
|
||||
func ByAllowMessagesDispatch(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldAllowMessagesDispatch, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByDefaultMappedModel orders the results by the default_mapped_model field.
|
||||
func ByDefaultMappedModel(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByAPIKeysCount orders the results by api_keys count.
|
||||
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
|
||||
@@ -195,6 +195,16 @@ func SortOrder(v int) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldSortOrder, v))
|
||||
}
|
||||
|
||||
// AllowMessagesDispatch applies equality check predicate on the "allow_messages_dispatch" field. It's identical to AllowMessagesDispatchEQ.
|
||||
func AllowMessagesDispatch(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldAllowMessagesDispatch, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModel applies equality check predicate on the "default_mapped_model" field. It's identical to DefaultMappedModelEQ.
|
||||
func DefaultMappedModel(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||
func CreatedAtEQ(v time.Time) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
|
||||
@@ -1470,6 +1480,81 @@ func SortOrderLTE(v int) predicate.Group {
|
||||
return predicate.Group(sql.FieldLTE(FieldSortOrder, v))
|
||||
}
|
||||
|
||||
// AllowMessagesDispatchEQ applies the EQ predicate on the "allow_messages_dispatch" field.
|
||||
func AllowMessagesDispatchEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldAllowMessagesDispatch, v))
|
||||
}
|
||||
|
||||
// AllowMessagesDispatchNEQ applies the NEQ predicate on the "allow_messages_dispatch" field.
|
||||
func AllowMessagesDispatchNEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldAllowMessagesDispatch, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelEQ applies the EQ predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelEQ(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelNEQ applies the NEQ predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelNEQ(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelIn applies the In predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelIn(vs ...string) predicate.Group {
|
||||
return predicate.Group(sql.FieldIn(FieldDefaultMappedModel, vs...))
|
||||
}
|
||||
|
||||
// DefaultMappedModelNotIn applies the NotIn predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelNotIn(vs ...string) predicate.Group {
|
||||
return predicate.Group(sql.FieldNotIn(FieldDefaultMappedModel, vs...))
|
||||
}
|
||||
|
||||
// DefaultMappedModelGT applies the GT predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelGT(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldGT(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelGTE applies the GTE predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelGTE(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldGTE(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelLT applies the LT predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelLT(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldLT(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelLTE applies the LTE predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelLTE(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldLTE(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelContains applies the Contains predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelContains(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldContains(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelHasPrefix applies the HasPrefix predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelHasPrefix(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldHasPrefix(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelHasSuffix applies the HasSuffix predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelHasSuffix(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldHasSuffix(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelEqualFold applies the EqualFold predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelEqualFold(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldEqualFold(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelContainsFold applies the ContainsFold predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelContainsFold(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldContainsFold(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
||||
func HasAPIKeys() predicate.Group {
|
||||
return predicate.Group(func(s *sql.Selector) {
|
||||
|
||||
@@ -424,6 +424,34 @@ func (_c *GroupCreate) SetNillableSortOrder(v *int) *GroupCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
|
||||
func (_c *GroupCreate) SetAllowMessagesDispatch(v bool) *GroupCreate {
|
||||
_c.mutation.SetAllowMessagesDispatch(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableAllowMessagesDispatch sets the "allow_messages_dispatch" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableAllowMessagesDispatch(v *bool) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetAllowMessagesDispatch(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (_c *GroupCreate) SetDefaultMappedModel(v string) *GroupCreate {
|
||||
_c.mutation.SetDefaultMappedModel(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableDefaultMappedModel sets the "default_mapped_model" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableDefaultMappedModel(v *string) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetDefaultMappedModel(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
|
||||
_c.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -613,6 +641,14 @@ func (_c *GroupCreate) defaults() error {
|
||||
v := group.DefaultSortOrder
|
||||
_c.mutation.SetSortOrder(v)
|
||||
}
|
||||
if _, ok := _c.mutation.AllowMessagesDispatch(); !ok {
|
||||
v := group.DefaultAllowMessagesDispatch
|
||||
_c.mutation.SetAllowMessagesDispatch(v)
|
||||
}
|
||||
if _, ok := _c.mutation.DefaultMappedModel(); !ok {
|
||||
v := group.DefaultDefaultMappedModel
|
||||
_c.mutation.SetDefaultMappedModel(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -683,6 +719,17 @@ func (_c *GroupCreate) check() error {
|
||||
if _, ok := _c.mutation.SortOrder(); !ok {
|
||||
return &ValidationError{Name: "sort_order", err: errors.New(`ent: missing required field "Group.sort_order"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.AllowMessagesDispatch(); !ok {
|
||||
return &ValidationError{Name: "allow_messages_dispatch", err: errors.New(`ent: missing required field "Group.allow_messages_dispatch"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.DefaultMappedModel(); !ok {
|
||||
return &ValidationError{Name: "default_mapped_model", err: errors.New(`ent: missing required field "Group.default_mapped_model"`)}
|
||||
}
|
||||
if v, ok := _c.mutation.DefaultMappedModel(); ok {
|
||||
if err := group.DefaultMappedModelValidator(v); err != nil {
|
||||
return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -830,6 +877,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(group.FieldSortOrder, field.TypeInt, value)
|
||||
_node.SortOrder = value
|
||||
}
|
||||
if value, ok := _c.mutation.AllowMessagesDispatch(); ok {
|
||||
_spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value)
|
||||
_node.AllowMessagesDispatch = value
|
||||
}
|
||||
if value, ok := _c.mutation.DefaultMappedModel(); ok {
|
||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||
_node.DefaultMappedModel = value
|
||||
}
|
||||
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -1520,6 +1575,30 @@ func (u *GroupUpsert) AddSortOrder(v int) *GroupUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
|
||||
func (u *GroupUpsert) SetAllowMessagesDispatch(v bool) *GroupUpsert {
|
||||
u.Set(group.FieldAllowMessagesDispatch, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateAllowMessagesDispatch sets the "allow_messages_dispatch" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateAllowMessagesDispatch() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldAllowMessagesDispatch)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (u *GroupUpsert) SetDefaultMappedModel(v string) *GroupUpsert {
|
||||
u.Set(group.FieldDefaultMappedModel, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateDefaultMappedModel sets the "default_mapped_model" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateDefaultMappedModel() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldDefaultMappedModel)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||
// Using this option is equivalent to using:
|
||||
//
|
||||
@@ -2188,6 +2267,34 @@ func (u *GroupUpsertOne) UpdateSortOrder() *GroupUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
|
||||
func (u *GroupUpsertOne) SetAllowMessagesDispatch(v bool) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetAllowMessagesDispatch(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateAllowMessagesDispatch sets the "allow_messages_dispatch" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateAllowMessagesDispatch() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateAllowMessagesDispatch()
|
||||
})
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (u *GroupUpsertOne) SetDefaultMappedModel(v string) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetDefaultMappedModel(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateDefaultMappedModel sets the "default_mapped_model" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateDefaultMappedModel() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateDefaultMappedModel()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
|
||||
if len(u.create.conflict) == 0 {
|
||||
@@ -3022,6 +3129,34 @@ func (u *GroupUpsertBulk) UpdateSortOrder() *GroupUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
|
||||
func (u *GroupUpsertBulk) SetAllowMessagesDispatch(v bool) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetAllowMessagesDispatch(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateAllowMessagesDispatch sets the "allow_messages_dispatch" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateAllowMessagesDispatch() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateAllowMessagesDispatch()
|
||||
})
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (u *GroupUpsertBulk) SetDefaultMappedModel(v string) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetDefaultMappedModel(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateDefaultMappedModel sets the "default_mapped_model" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateDefaultMappedModel() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateDefaultMappedModel()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
|
||||
if u.create.err != nil {
|
||||
|
||||
@@ -625,6 +625,34 @@ func (_u *GroupUpdate) AddSortOrder(v int) *GroupUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
|
||||
func (_u *GroupUpdate) SetAllowMessagesDispatch(v bool) *GroupUpdate {
|
||||
_u.mutation.SetAllowMessagesDispatch(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableAllowMessagesDispatch sets the "allow_messages_dispatch" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableAllowMessagesDispatch(v *bool) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetAllowMessagesDispatch(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (_u *GroupUpdate) SetDefaultMappedModel(v string) *GroupUpdate {
|
||||
_u.mutation.SetDefaultMappedModel(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDefaultMappedModel sets the "default_mapped_model" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableDefaultMappedModel(v *string) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetDefaultMappedModel(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -910,6 +938,11 @@ func (_u *GroupUpdate) check() error {
|
||||
return &ValidationError{Name: "subscription_type", err: fmt.Errorf(`ent: validator failed for field "Group.subscription_type": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.DefaultMappedModel(); ok {
|
||||
if err := group.DefaultMappedModelValidator(v); err != nil {
|
||||
return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1110,6 +1143,12 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if value, ok := _u.mutation.AddedSortOrder(); ok {
|
||||
_spec.AddField(group.FieldSortOrder, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AllowMessagesDispatch(); ok {
|
||||
_spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.DefaultMappedModel(); ok {
|
||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -2014,6 +2053,34 @@ func (_u *GroupUpdateOne) AddSortOrder(v int) *GroupUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
|
||||
func (_u *GroupUpdateOne) SetAllowMessagesDispatch(v bool) *GroupUpdateOne {
|
||||
_u.mutation.SetAllowMessagesDispatch(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableAllowMessagesDispatch sets the "allow_messages_dispatch" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableAllowMessagesDispatch(v *bool) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetAllowMessagesDispatch(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (_u *GroupUpdateOne) SetDefaultMappedModel(v string) *GroupUpdateOne {
|
||||
_u.mutation.SetDefaultMappedModel(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDefaultMappedModel sets the "default_mapped_model" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableDefaultMappedModel(v *string) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetDefaultMappedModel(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -2312,6 +2379,11 @@ func (_u *GroupUpdateOne) check() error {
|
||||
return &ValidationError{Name: "subscription_type", err: fmt.Errorf(`ent: validator failed for field "Group.subscription_type": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.DefaultMappedModel(); ok {
|
||||
if err := group.DefaultMappedModelValidator(v); err != nil {
|
||||
return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2529,6 +2601,12 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
|
||||
if value, ok := _u.mutation.AddedSortOrder(); ok {
|
||||
_spec.AddField(group.FieldSortOrder, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AllowMessagesDispatch(); ok {
|
||||
_spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.DefaultMappedModel(); ok {
|
||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
|
||||
@@ -106,6 +106,7 @@ var (
|
||||
{Name: "credentials", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "extra", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "concurrency", Type: field.TypeInt, Default: 3},
|
||||
{Name: "load_factor", Type: field.TypeInt, Nullable: true},
|
||||
{Name: "priority", Type: field.TypeInt, Default: 50},
|
||||
{Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
|
||||
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
|
||||
@@ -132,7 +133,7 @@ var (
|
||||
ForeignKeys: []*schema.ForeignKey{
|
||||
{
|
||||
Symbol: "accounts_proxies_proxy",
|
||||
Columns: []*schema.Column{AccountsColumns[27]},
|
||||
Columns: []*schema.Column{AccountsColumns[28]},
|
||||
RefColumns: []*schema.Column{ProxiesColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
@@ -151,52 +152,52 @@ var (
|
||||
{
|
||||
Name: "account_status",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[13]},
|
||||
Columns: []*schema.Column{AccountsColumns[14]},
|
||||
},
|
||||
{
|
||||
Name: "account_proxy_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[27]},
|
||||
Columns: []*schema.Column{AccountsColumns[28]},
|
||||
},
|
||||
{
|
||||
Name: "account_priority",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[11]},
|
||||
Columns: []*schema.Column{AccountsColumns[12]},
|
||||
},
|
||||
{
|
||||
Name: "account_last_used_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[15]},
|
||||
Columns: []*schema.Column{AccountsColumns[16]},
|
||||
},
|
||||
{
|
||||
Name: "account_schedulable",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[18]},
|
||||
Columns: []*schema.Column{AccountsColumns[19]},
|
||||
},
|
||||
{
|
||||
Name: "account_rate_limited_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[19]},
|
||||
Columns: []*schema.Column{AccountsColumns[20]},
|
||||
},
|
||||
{
|
||||
Name: "account_rate_limit_reset_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[20]},
|
||||
Columns: []*schema.Column{AccountsColumns[21]},
|
||||
},
|
||||
{
|
||||
Name: "account_overload_until",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[21]},
|
||||
Columns: []*schema.Column{AccountsColumns[22]},
|
||||
},
|
||||
{
|
||||
Name: "account_platform_priority",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[6], AccountsColumns[11]},
|
||||
Columns: []*schema.Column{AccountsColumns[6], AccountsColumns[12]},
|
||||
},
|
||||
{
|
||||
Name: "account_priority_status",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[11], AccountsColumns[13]},
|
||||
Columns: []*schema.Column{AccountsColumns[12], AccountsColumns[14]},
|
||||
},
|
||||
{
|
||||
Name: "account_deleted_at",
|
||||
@@ -250,6 +251,7 @@ var (
|
||||
{Name: "title", Type: field.TypeString, Size: 200},
|
||||
{Name: "content", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
|
||||
{Name: "status", Type: field.TypeString, Size: 20, Default: "draft"},
|
||||
{Name: "notify_mode", Type: field.TypeString, Size: 20, Default: "silent"},
|
||||
{Name: "targeting", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "starts_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "ends_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
@@ -272,17 +274,17 @@ var (
|
||||
{
|
||||
Name: "announcement_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AnnouncementsColumns[9]},
|
||||
Columns: []*schema.Column{AnnouncementsColumns[10]},
|
||||
},
|
||||
{
|
||||
Name: "announcement_starts_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AnnouncementsColumns[5]},
|
||||
Columns: []*schema.Column{AnnouncementsColumns[6]},
|
||||
},
|
||||
{
|
||||
Name: "announcement_ends_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AnnouncementsColumns[6]},
|
||||
Columns: []*schema.Column{AnnouncementsColumns[7]},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -406,6 +408,8 @@ var (
|
||||
{Name: "mcp_xml_inject", Type: field.TypeBool, Default: true},
|
||||
{Name: "supported_model_scopes", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "sort_order", Type: field.TypeInt, Default: 0},
|
||||
{Name: "allow_messages_dispatch", Type: field.TypeBool, Default: false},
|
||||
{Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""},
|
||||
}
|
||||
// GroupsTable holds the schema information for the "groups" table.
|
||||
GroupsTable = &schema.Table{
|
||||
|
||||
@@ -2260,6 +2260,8 @@ type AccountMutation struct {
|
||||
extra *map[string]interface{}
|
||||
concurrency *int
|
||||
addconcurrency *int
|
||||
load_factor *int
|
||||
addload_factor *int
|
||||
priority *int
|
||||
addpriority *int
|
||||
rate_multiplier *float64
|
||||
@@ -2845,6 +2847,76 @@ func (m *AccountMutation) ResetConcurrency() {
|
||||
m.addconcurrency = nil
|
||||
}
|
||||
|
||||
// SetLoadFactor sets the "load_factor" field.
|
||||
func (m *AccountMutation) SetLoadFactor(i int) {
|
||||
m.load_factor = &i
|
||||
m.addload_factor = nil
|
||||
}
|
||||
|
||||
// LoadFactor returns the value of the "load_factor" field in the mutation.
|
||||
func (m *AccountMutation) LoadFactor() (r int, exists bool) {
|
||||
v := m.load_factor
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldLoadFactor returns the old "load_factor" field's value of the Account entity.
|
||||
// If the Account object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *AccountMutation) OldLoadFactor(ctx context.Context) (v *int, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldLoadFactor is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldLoadFactor requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldLoadFactor: %w", err)
|
||||
}
|
||||
return oldValue.LoadFactor, nil
|
||||
}
|
||||
|
||||
// AddLoadFactor adds i to the "load_factor" field.
|
||||
func (m *AccountMutation) AddLoadFactor(i int) {
|
||||
if m.addload_factor != nil {
|
||||
*m.addload_factor += i
|
||||
} else {
|
||||
m.addload_factor = &i
|
||||
}
|
||||
}
|
||||
|
||||
// AddedLoadFactor returns the value that was added to the "load_factor" field in this mutation.
|
||||
func (m *AccountMutation) AddedLoadFactor() (r int, exists bool) {
|
||||
v := m.addload_factor
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// ClearLoadFactor clears the value of the "load_factor" field.
|
||||
func (m *AccountMutation) ClearLoadFactor() {
|
||||
m.load_factor = nil
|
||||
m.addload_factor = nil
|
||||
m.clearedFields[account.FieldLoadFactor] = struct{}{}
|
||||
}
|
||||
|
||||
// LoadFactorCleared returns if the "load_factor" field was cleared in this mutation.
|
||||
func (m *AccountMutation) LoadFactorCleared() bool {
|
||||
_, ok := m.clearedFields[account.FieldLoadFactor]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ResetLoadFactor resets all changes to the "load_factor" field.
|
||||
func (m *AccountMutation) ResetLoadFactor() {
|
||||
m.load_factor = nil
|
||||
m.addload_factor = nil
|
||||
delete(m.clearedFields, account.FieldLoadFactor)
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (m *AccountMutation) SetPriority(i int) {
|
||||
m.priority = &i
|
||||
@@ -3773,7 +3845,7 @@ func (m *AccountMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *AccountMutation) Fields() []string {
|
||||
fields := make([]string, 0, 27)
|
||||
fields := make([]string, 0, 28)
|
||||
if m.created_at != nil {
|
||||
fields = append(fields, account.FieldCreatedAt)
|
||||
}
|
||||
@@ -3807,6 +3879,9 @@ func (m *AccountMutation) Fields() []string {
|
||||
if m.concurrency != nil {
|
||||
fields = append(fields, account.FieldConcurrency)
|
||||
}
|
||||
if m.load_factor != nil {
|
||||
fields = append(fields, account.FieldLoadFactor)
|
||||
}
|
||||
if m.priority != nil {
|
||||
fields = append(fields, account.FieldPriority)
|
||||
}
|
||||
@@ -3885,6 +3960,8 @@ func (m *AccountMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.ProxyID()
|
||||
case account.FieldConcurrency:
|
||||
return m.Concurrency()
|
||||
case account.FieldLoadFactor:
|
||||
return m.LoadFactor()
|
||||
case account.FieldPriority:
|
||||
return m.Priority()
|
||||
case account.FieldRateMultiplier:
|
||||
@@ -3948,6 +4025,8 @@ func (m *AccountMutation) OldField(ctx context.Context, name string) (ent.Value,
|
||||
return m.OldProxyID(ctx)
|
||||
case account.FieldConcurrency:
|
||||
return m.OldConcurrency(ctx)
|
||||
case account.FieldLoadFactor:
|
||||
return m.OldLoadFactor(ctx)
|
||||
case account.FieldPriority:
|
||||
return m.OldPriority(ctx)
|
||||
case account.FieldRateMultiplier:
|
||||
@@ -4066,6 +4145,13 @@ func (m *AccountMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetConcurrency(v)
|
||||
return nil
|
||||
case account.FieldLoadFactor:
|
||||
v, ok := value.(int)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetLoadFactor(v)
|
||||
return nil
|
||||
case account.FieldPriority:
|
||||
v, ok := value.(int)
|
||||
if !ok {
|
||||
@@ -4189,6 +4275,9 @@ func (m *AccountMutation) AddedFields() []string {
|
||||
if m.addconcurrency != nil {
|
||||
fields = append(fields, account.FieldConcurrency)
|
||||
}
|
||||
if m.addload_factor != nil {
|
||||
fields = append(fields, account.FieldLoadFactor)
|
||||
}
|
||||
if m.addpriority != nil {
|
||||
fields = append(fields, account.FieldPriority)
|
||||
}
|
||||
@@ -4205,6 +4294,8 @@ func (m *AccountMutation) AddedField(name string) (ent.Value, bool) {
|
||||
switch name {
|
||||
case account.FieldConcurrency:
|
||||
return m.AddedConcurrency()
|
||||
case account.FieldLoadFactor:
|
||||
return m.AddedLoadFactor()
|
||||
case account.FieldPriority:
|
||||
return m.AddedPriority()
|
||||
case account.FieldRateMultiplier:
|
||||
@@ -4225,6 +4316,13 @@ func (m *AccountMutation) AddField(name string, value ent.Value) error {
|
||||
}
|
||||
m.AddConcurrency(v)
|
||||
return nil
|
||||
case account.FieldLoadFactor:
|
||||
v, ok := value.(int)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.AddLoadFactor(v)
|
||||
return nil
|
||||
case account.FieldPriority:
|
||||
v, ok := value.(int)
|
||||
if !ok {
|
||||
@@ -4256,6 +4354,9 @@ func (m *AccountMutation) ClearedFields() []string {
|
||||
if m.FieldCleared(account.FieldProxyID) {
|
||||
fields = append(fields, account.FieldProxyID)
|
||||
}
|
||||
if m.FieldCleared(account.FieldLoadFactor) {
|
||||
fields = append(fields, account.FieldLoadFactor)
|
||||
}
|
||||
if m.FieldCleared(account.FieldErrorMessage) {
|
||||
fields = append(fields, account.FieldErrorMessage)
|
||||
}
|
||||
@@ -4312,6 +4413,9 @@ func (m *AccountMutation) ClearField(name string) error {
|
||||
case account.FieldProxyID:
|
||||
m.ClearProxyID()
|
||||
return nil
|
||||
case account.FieldLoadFactor:
|
||||
m.ClearLoadFactor()
|
||||
return nil
|
||||
case account.FieldErrorMessage:
|
||||
m.ClearErrorMessage()
|
||||
return nil
|
||||
@@ -4386,6 +4490,9 @@ func (m *AccountMutation) ResetField(name string) error {
|
||||
case account.FieldConcurrency:
|
||||
m.ResetConcurrency()
|
||||
return nil
|
||||
case account.FieldLoadFactor:
|
||||
m.ResetLoadFactor()
|
||||
return nil
|
||||
case account.FieldPriority:
|
||||
m.ResetPriority()
|
||||
return nil
|
||||
@@ -5060,6 +5167,7 @@ type AnnouncementMutation struct {
|
||||
title *string
|
||||
content *string
|
||||
status *string
|
||||
notify_mode *string
|
||||
targeting *domain.AnnouncementTargeting
|
||||
starts_at *time.Time
|
||||
ends_at *time.Time
|
||||
@@ -5284,6 +5392,42 @@ func (m *AnnouncementMutation) ResetStatus() {
|
||||
m.status = nil
|
||||
}
|
||||
|
||||
// SetNotifyMode sets the "notify_mode" field.
|
||||
func (m *AnnouncementMutation) SetNotifyMode(s string) {
|
||||
m.notify_mode = &s
|
||||
}
|
||||
|
||||
// NotifyMode returns the value of the "notify_mode" field in the mutation.
|
||||
func (m *AnnouncementMutation) NotifyMode() (r string, exists bool) {
|
||||
v := m.notify_mode
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldNotifyMode returns the old "notify_mode" field's value of the Announcement entity.
|
||||
// If the Announcement object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *AnnouncementMutation) OldNotifyMode(ctx context.Context) (v string, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldNotifyMode is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldNotifyMode requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldNotifyMode: %w", err)
|
||||
}
|
||||
return oldValue.NotifyMode, nil
|
||||
}
|
||||
|
||||
// ResetNotifyMode resets all changes to the "notify_mode" field.
|
||||
func (m *AnnouncementMutation) ResetNotifyMode() {
|
||||
m.notify_mode = nil
|
||||
}
|
||||
|
||||
// SetTargeting sets the "targeting" field.
|
||||
func (m *AnnouncementMutation) SetTargeting(dt domain.AnnouncementTargeting) {
|
||||
m.targeting = &dt
|
||||
@@ -5731,7 +5875,7 @@ func (m *AnnouncementMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *AnnouncementMutation) Fields() []string {
|
||||
fields := make([]string, 0, 10)
|
||||
fields := make([]string, 0, 11)
|
||||
if m.title != nil {
|
||||
fields = append(fields, announcement.FieldTitle)
|
||||
}
|
||||
@@ -5741,6 +5885,9 @@ func (m *AnnouncementMutation) Fields() []string {
|
||||
if m.status != nil {
|
||||
fields = append(fields, announcement.FieldStatus)
|
||||
}
|
||||
if m.notify_mode != nil {
|
||||
fields = append(fields, announcement.FieldNotifyMode)
|
||||
}
|
||||
if m.targeting != nil {
|
||||
fields = append(fields, announcement.FieldTargeting)
|
||||
}
|
||||
@@ -5776,6 +5923,8 @@ func (m *AnnouncementMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.Content()
|
||||
case announcement.FieldStatus:
|
||||
return m.Status()
|
||||
case announcement.FieldNotifyMode:
|
||||
return m.NotifyMode()
|
||||
case announcement.FieldTargeting:
|
||||
return m.Targeting()
|
||||
case announcement.FieldStartsAt:
|
||||
@@ -5805,6 +5954,8 @@ func (m *AnnouncementMutation) OldField(ctx context.Context, name string) (ent.V
|
||||
return m.OldContent(ctx)
|
||||
case announcement.FieldStatus:
|
||||
return m.OldStatus(ctx)
|
||||
case announcement.FieldNotifyMode:
|
||||
return m.OldNotifyMode(ctx)
|
||||
case announcement.FieldTargeting:
|
||||
return m.OldTargeting(ctx)
|
||||
case announcement.FieldStartsAt:
|
||||
@@ -5849,6 +6000,13 @@ func (m *AnnouncementMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetStatus(v)
|
||||
return nil
|
||||
case announcement.FieldNotifyMode:
|
||||
v, ok := value.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetNotifyMode(v)
|
||||
return nil
|
||||
case announcement.FieldTargeting:
|
||||
v, ok := value.(domain.AnnouncementTargeting)
|
||||
if !ok {
|
||||
@@ -6016,6 +6174,9 @@ func (m *AnnouncementMutation) ResetField(name string) error {
|
||||
case announcement.FieldStatus:
|
||||
m.ResetStatus()
|
||||
return nil
|
||||
case announcement.FieldNotifyMode:
|
||||
m.ResetNotifyMode()
|
||||
return nil
|
||||
case announcement.FieldTargeting:
|
||||
m.ResetTargeting()
|
||||
return nil
|
||||
@@ -8089,6 +8250,8 @@ type GroupMutation struct {
|
||||
appendsupported_model_scopes []string
|
||||
sort_order *int
|
||||
addsort_order *int
|
||||
allow_messages_dispatch *bool
|
||||
default_mapped_model *string
|
||||
clearedFields map[string]struct{}
|
||||
api_keys map[int64]struct{}
|
||||
removedapi_keys map[int64]struct{}
|
||||
@@ -9833,6 +9996,78 @@ func (m *GroupMutation) ResetSortOrder() {
|
||||
m.addsort_order = nil
|
||||
}
|
||||
|
||||
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
|
||||
func (m *GroupMutation) SetAllowMessagesDispatch(b bool) {
|
||||
m.allow_messages_dispatch = &b
|
||||
}
|
||||
|
||||
// AllowMessagesDispatch returns the value of the "allow_messages_dispatch" field in the mutation.
|
||||
func (m *GroupMutation) AllowMessagesDispatch() (r bool, exists bool) {
|
||||
v := m.allow_messages_dispatch
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldAllowMessagesDispatch returns the old "allow_messages_dispatch" field's value of the Group entity.
|
||||
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *GroupMutation) OldAllowMessagesDispatch(ctx context.Context) (v bool, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldAllowMessagesDispatch is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldAllowMessagesDispatch requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldAllowMessagesDispatch: %w", err)
|
||||
}
|
||||
return oldValue.AllowMessagesDispatch, nil
|
||||
}
|
||||
|
||||
// ResetAllowMessagesDispatch resets all changes to the "allow_messages_dispatch" field.
|
||||
func (m *GroupMutation) ResetAllowMessagesDispatch() {
|
||||
m.allow_messages_dispatch = nil
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (m *GroupMutation) SetDefaultMappedModel(s string) {
|
||||
m.default_mapped_model = &s
|
||||
}
|
||||
|
||||
// DefaultMappedModel returns the value of the "default_mapped_model" field in the mutation.
|
||||
func (m *GroupMutation) DefaultMappedModel() (r string, exists bool) {
|
||||
v := m.default_mapped_model
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldDefaultMappedModel returns the old "default_mapped_model" field's value of the Group entity.
|
||||
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *GroupMutation) OldDefaultMappedModel(ctx context.Context) (v string, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldDefaultMappedModel is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldDefaultMappedModel requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldDefaultMappedModel: %w", err)
|
||||
}
|
||||
return oldValue.DefaultMappedModel, nil
|
||||
}
|
||||
|
||||
// ResetDefaultMappedModel resets all changes to the "default_mapped_model" field.
|
||||
func (m *GroupMutation) ResetDefaultMappedModel() {
|
||||
m.default_mapped_model = nil
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
||||
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
|
||||
if m.api_keys == nil {
|
||||
@@ -10191,7 +10426,7 @@ func (m *GroupMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *GroupMutation) Fields() []string {
|
||||
fields := make([]string, 0, 30)
|
||||
fields := make([]string, 0, 32)
|
||||
if m.created_at != nil {
|
||||
fields = append(fields, group.FieldCreatedAt)
|
||||
}
|
||||
@@ -10282,6 +10517,12 @@ func (m *GroupMutation) Fields() []string {
|
||||
if m.sort_order != nil {
|
||||
fields = append(fields, group.FieldSortOrder)
|
||||
}
|
||||
if m.allow_messages_dispatch != nil {
|
||||
fields = append(fields, group.FieldAllowMessagesDispatch)
|
||||
}
|
||||
if m.default_mapped_model != nil {
|
||||
fields = append(fields, group.FieldDefaultMappedModel)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
@@ -10350,6 +10591,10 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.SupportedModelScopes()
|
||||
case group.FieldSortOrder:
|
||||
return m.SortOrder()
|
||||
case group.FieldAllowMessagesDispatch:
|
||||
return m.AllowMessagesDispatch()
|
||||
case group.FieldDefaultMappedModel:
|
||||
return m.DefaultMappedModel()
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
@@ -10419,6 +10664,10 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
|
||||
return m.OldSupportedModelScopes(ctx)
|
||||
case group.FieldSortOrder:
|
||||
return m.OldSortOrder(ctx)
|
||||
case group.FieldAllowMessagesDispatch:
|
||||
return m.OldAllowMessagesDispatch(ctx)
|
||||
case group.FieldDefaultMappedModel:
|
||||
return m.OldDefaultMappedModel(ctx)
|
||||
}
|
||||
return nil, fmt.Errorf("unknown Group field %s", name)
|
||||
}
|
||||
@@ -10638,6 +10887,20 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetSortOrder(v)
|
||||
return nil
|
||||
case group.FieldAllowMessagesDispatch:
|
||||
v, ok := value.(bool)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetAllowMessagesDispatch(v)
|
||||
return nil
|
||||
case group.FieldDefaultMappedModel:
|
||||
v, ok := value.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetDefaultMappedModel(v)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown Group field %s", name)
|
||||
}
|
||||
@@ -11065,6 +11328,12 @@ func (m *GroupMutation) ResetField(name string) error {
|
||||
case group.FieldSortOrder:
|
||||
m.ResetSortOrder()
|
||||
return nil
|
||||
case group.FieldAllowMessagesDispatch:
|
||||
m.ResetAllowMessagesDispatch()
|
||||
return nil
|
||||
case group.FieldDefaultMappedModel:
|
||||
m.ResetDefaultMappedModel()
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown Group field %s", name)
|
||||
}
|
||||
|
||||
@@ -212,29 +212,29 @@ func init() {
|
||||
// account.DefaultConcurrency holds the default value on creation for the concurrency field.
|
||||
account.DefaultConcurrency = accountDescConcurrency.Default.(int)
|
||||
// accountDescPriority is the schema descriptor for priority field.
|
||||
accountDescPriority := accountFields[8].Descriptor()
|
||||
accountDescPriority := accountFields[9].Descriptor()
|
||||
// account.DefaultPriority holds the default value on creation for the priority field.
|
||||
account.DefaultPriority = accountDescPriority.Default.(int)
|
||||
// accountDescRateMultiplier is the schema descriptor for rate_multiplier field.
|
||||
accountDescRateMultiplier := accountFields[9].Descriptor()
|
||||
accountDescRateMultiplier := accountFields[10].Descriptor()
|
||||
// account.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
|
||||
account.DefaultRateMultiplier = accountDescRateMultiplier.Default.(float64)
|
||||
// accountDescStatus is the schema descriptor for status field.
|
||||
accountDescStatus := accountFields[10].Descriptor()
|
||||
accountDescStatus := accountFields[11].Descriptor()
|
||||
// account.DefaultStatus holds the default value on creation for the status field.
|
||||
account.DefaultStatus = accountDescStatus.Default.(string)
|
||||
// account.StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
||||
account.StatusValidator = accountDescStatus.Validators[0].(func(string) error)
|
||||
// accountDescAutoPauseOnExpired is the schema descriptor for auto_pause_on_expired field.
|
||||
accountDescAutoPauseOnExpired := accountFields[14].Descriptor()
|
||||
accountDescAutoPauseOnExpired := accountFields[15].Descriptor()
|
||||
// account.DefaultAutoPauseOnExpired holds the default value on creation for the auto_pause_on_expired field.
|
||||
account.DefaultAutoPauseOnExpired = accountDescAutoPauseOnExpired.Default.(bool)
|
||||
// accountDescSchedulable is the schema descriptor for schedulable field.
|
||||
accountDescSchedulable := accountFields[15].Descriptor()
|
||||
accountDescSchedulable := accountFields[16].Descriptor()
|
||||
// account.DefaultSchedulable holds the default value on creation for the schedulable field.
|
||||
account.DefaultSchedulable = accountDescSchedulable.Default.(bool)
|
||||
// accountDescSessionWindowStatus is the schema descriptor for session_window_status field.
|
||||
accountDescSessionWindowStatus := accountFields[23].Descriptor()
|
||||
accountDescSessionWindowStatus := accountFields[24].Descriptor()
|
||||
// account.SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save.
|
||||
account.SessionWindowStatusValidator = accountDescSessionWindowStatus.Validators[0].(func(string) error)
|
||||
accountgroupFields := schema.AccountGroup{}.Fields()
|
||||
@@ -277,12 +277,18 @@ func init() {
|
||||
announcement.DefaultStatus = announcementDescStatus.Default.(string)
|
||||
// announcement.StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
||||
announcement.StatusValidator = announcementDescStatus.Validators[0].(func(string) error)
|
||||
// announcementDescNotifyMode is the schema descriptor for notify_mode field.
|
||||
announcementDescNotifyMode := announcementFields[3].Descriptor()
|
||||
// announcement.DefaultNotifyMode holds the default value on creation for the notify_mode field.
|
||||
announcement.DefaultNotifyMode = announcementDescNotifyMode.Default.(string)
|
||||
// announcement.NotifyModeValidator is a validator for the "notify_mode" field. It is called by the builders before save.
|
||||
announcement.NotifyModeValidator = announcementDescNotifyMode.Validators[0].(func(string) error)
|
||||
// announcementDescCreatedAt is the schema descriptor for created_at field.
|
||||
announcementDescCreatedAt := announcementFields[8].Descriptor()
|
||||
announcementDescCreatedAt := announcementFields[9].Descriptor()
|
||||
// announcement.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||
announcement.DefaultCreatedAt = announcementDescCreatedAt.Default.(func() time.Time)
|
||||
// announcementDescUpdatedAt is the schema descriptor for updated_at field.
|
||||
announcementDescUpdatedAt := announcementFields[9].Descriptor()
|
||||
announcementDescUpdatedAt := announcementFields[10].Descriptor()
|
||||
// announcement.DefaultUpdatedAt holds the default value on creation for the updated_at field.
|
||||
announcement.DefaultUpdatedAt = announcementDescUpdatedAt.Default.(func() time.Time)
|
||||
// announcement.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
|
||||
@@ -447,6 +453,16 @@ func init() {
|
||||
groupDescSortOrder := groupFields[26].Descriptor()
|
||||
// group.DefaultSortOrder holds the default value on creation for the sort_order field.
|
||||
group.DefaultSortOrder = groupDescSortOrder.Default.(int)
|
||||
// groupDescAllowMessagesDispatch is the schema descriptor for allow_messages_dispatch field.
|
||||
groupDescAllowMessagesDispatch := groupFields[27].Descriptor()
|
||||
// group.DefaultAllowMessagesDispatch holds the default value on creation for the allow_messages_dispatch field.
|
||||
group.DefaultAllowMessagesDispatch = groupDescAllowMessagesDispatch.Default.(bool)
|
||||
// groupDescDefaultMappedModel is the schema descriptor for default_mapped_model field.
|
||||
groupDescDefaultMappedModel := groupFields[28].Descriptor()
|
||||
// group.DefaultDefaultMappedModel holds the default value on creation for the default_mapped_model field.
|
||||
group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string)
|
||||
// group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
||||
group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error)
|
||||
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
|
||||
idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields()
|
||||
_ = idempotencyrecordMixinFields0
|
||||
|
||||
@@ -97,6 +97,8 @@ func (Account) Fields() []ent.Field {
|
||||
field.Int("concurrency").
|
||||
Default(3),
|
||||
|
||||
field.Int("load_factor").Optional().Nillable(),
|
||||
|
||||
// priority: 账户优先级,数值越小优先级越高
|
||||
// 调度器会优先使用高优先级的账户
|
||||
field.Int("priority").
|
||||
|
||||
@@ -41,6 +41,10 @@ func (Announcement) Fields() []ent.Field {
|
||||
MaxLen(20).
|
||||
Default(domain.AnnouncementStatusDraft).
|
||||
Comment("状态: draft, active, archived"),
|
||||
field.String("notify_mode").
|
||||
MaxLen(20).
|
||||
Default(domain.AnnouncementNotifyModeSilent).
|
||||
Comment("通知模式: silent(仅铃铛), popup(弹窗提醒)"),
|
||||
field.JSON("targeting", domain.AnnouncementTargeting{}).
|
||||
Optional().
|
||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
||||
|
||||
@@ -148,6 +148,15 @@ func (Group) Fields() []ent.Field {
|
||||
field.Int("sort_order").
|
||||
Default(0).
|
||||
Comment("分组显示排序,数值越小越靠前"),
|
||||
|
||||
// OpenAI Messages 调度配置 (added by migration 069)
|
||||
field.Bool("allow_messages_dispatch").
|
||||
Default(false).
|
||||
Comment("是否允许 /v1/messages 调度到此 OpenAI 分组"),
|
||||
field.String("default_mapped_model").
|
||||
MaxLen(100).
|
||||
Default("").
|
||||
Comment("默认映射模型 ID,当账号级映射找不到时使用此值"),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
module github.com/Wei-Shaw/sub2api
|
||||
|
||||
go 1.25.7
|
||||
go 1.26.1
|
||||
|
||||
require (
|
||||
entgo.io/ent v0.14.5
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||
github.com/DouDOU-start/go-sora2api v1.1.0
|
||||
github.com/alitto/pond/v2 v2.6.2
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.2
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.10
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
|
||||
@@ -38,8 +39,6 @@ require (
|
||||
golang.org/x/net v0.49.0
|
||||
golang.org/x/sync v0.19.0
|
||||
golang.org/x/term v0.40.0
|
||||
google.golang.org/grpc v1.75.1
|
||||
google.golang.org/protobuf v1.36.10
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
modernc.org/sqlite v1.44.3
|
||||
@@ -53,7 +52,6 @@ require (
|
||||
github.com/agext/levenshtein v1.2.3 // indirect
|
||||
github.com/andybalholm/brotli v1.2.0 // indirect
|
||||
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.2 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 // indirect
|
||||
@@ -109,7 +107,6 @@ require (
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/go-querystring v1.1.0 // indirect
|
||||
github.com/google/subcommands v1.2.0 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/hashicorp/hcl/v2 v2.18.1 // indirect
|
||||
@@ -169,6 +166,7 @@ require (
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
|
||||
go.opentelemetry.io/otel v1.37.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.37.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk v1.37.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.37.0 // indirect
|
||||
go.uber.org/atomic v1.10.0 // indirect
|
||||
go.uber.org/automaxprocs v1.6.0 // indirect
|
||||
@@ -178,8 +176,8 @@ require (
|
||||
golang.org/x/mod v0.32.0 // indirect
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
golang.org/x/text v0.34.0 // indirect
|
||||
golang.org/x/tools v0.41.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 // indirect
|
||||
google.golang.org/grpc v1.75.1 // indirect
|
||||
google.golang.org/protobuf v1.36.10 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
modernc.org/libc v1.67.6 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
|
||||
@@ -171,8 +171,6 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
@@ -182,8 +180,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
|
||||
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
|
||||
@@ -398,8 +394,6 @@ go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/Wgbsd
|
||||
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
|
||||
go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI=
|
||||
go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps=
|
||||
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
|
||||
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
|
||||
go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0=
|
||||
@@ -455,8 +449,6 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm
|
||||
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
|
||||
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4/go.mod h1:NnuHhy+bxcg30o7FnVAZbXsPHUDQ9qKWAQKCD7VxFtk=
|
||||
|
||||
@@ -516,7 +516,7 @@ func (c *UserMessageQueueConfig) GetEffectiveMode() string {
|
||||
type GatewayOpenAIWSConfig struct {
|
||||
// ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false;关闭时保持 legacy 行为)
|
||||
ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"`
|
||||
// IngressModeDefault: ingress 默认模式(off/shared/dedicated)
|
||||
// IngressModeDefault: ingress 默认模式(off/ctx_pool/passthrough)
|
||||
IngressModeDefault string `mapstructure:"ingress_mode_default"`
|
||||
// Enabled: 全局总开关(默认 true)
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
@@ -1335,7 +1335,7 @@ func setDefaults() {
|
||||
// OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚)
|
||||
viper.SetDefault("gateway.openai_ws.enabled", true)
|
||||
viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false)
|
||||
viper.SetDefault("gateway.openai_ws.ingress_mode_default", "shared")
|
||||
viper.SetDefault("gateway.openai_ws.ingress_mode_default", "ctx_pool")
|
||||
viper.SetDefault("gateway.openai_ws.oauth_enabled", true)
|
||||
viper.SetDefault("gateway.openai_ws.apikey_enabled", true)
|
||||
viper.SetDefault("gateway.openai_ws.force_http", false)
|
||||
@@ -2043,9 +2043,11 @@ func (c *Config) Validate() error {
|
||||
}
|
||||
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" {
|
||||
switch mode {
|
||||
case "off", "shared", "dedicated":
|
||||
case "off", "ctx_pool", "passthrough":
|
||||
case "shared", "dedicated":
|
||||
slog.Warn("gateway.openai_ws.ingress_mode_default is deprecated, treating as ctx_pool; please update to off|ctx_pool|passthrough", "value", mode)
|
||||
default:
|
||||
return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|shared|dedicated")
|
||||
return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|ctx_pool|passthrough")
|
||||
}
|
||||
}
|
||||
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" {
|
||||
|
||||
@@ -153,8 +153,8 @@ func TestLoadDefaultOpenAIWSConfig(t *testing.T) {
|
||||
if cfg.Gateway.OpenAIWS.ModeRouterV2Enabled {
|
||||
t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false")
|
||||
}
|
||||
if cfg.Gateway.OpenAIWS.IngressModeDefault != "shared" {
|
||||
t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "shared")
|
||||
if cfg.Gateway.OpenAIWS.IngressModeDefault != "ctx_pool" {
|
||||
t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "ctx_pool")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1373,7 +1373,7 @@ func TestValidateConfig_OpenAIWSRules(t *testing.T) {
|
||||
wantErr: "gateway.openai_ws.store_disabled_conn_mode",
|
||||
},
|
||||
{
|
||||
name: "ingress_mode_default 必须为 off|shared|dedicated",
|
||||
name: "ingress_mode_default 必须为 off|ctx_pool|passthrough",
|
||||
mutate: func(c *Config) { c.Gateway.OpenAIWS.IngressModeDefault = "invalid" },
|
||||
wantErr: "gateway.openai_ws.ingress_mode_default",
|
||||
},
|
||||
|
||||
@@ -13,6 +13,11 @@ const (
|
||||
AnnouncementStatusArchived = "archived"
|
||||
)
|
||||
|
||||
const (
|
||||
AnnouncementNotifyModeSilent = "silent"
|
||||
AnnouncementNotifyModePopup = "popup"
|
||||
)
|
||||
|
||||
const (
|
||||
AnnouncementConditionTypeSubscription = "subscription"
|
||||
AnnouncementConditionTypeBalance = "balance"
|
||||
@@ -195,17 +200,18 @@ func (c AnnouncementCondition) validate() error {
|
||||
}
|
||||
|
||||
type Announcement struct {
|
||||
ID int64
|
||||
Title string
|
||||
Content string
|
||||
Status string
|
||||
Targeting AnnouncementTargeting
|
||||
StartsAt *time.Time
|
||||
EndsAt *time.Time
|
||||
CreatedBy *int64
|
||||
UpdatedBy *int64
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
ID int64
|
||||
Title string
|
||||
Content string
|
||||
Status string
|
||||
NotifyMode string
|
||||
Targeting AnnouncementTargeting
|
||||
StartsAt *time.Time
|
||||
EndsAt *time.Time
|
||||
CreatedBy *int64
|
||||
UpdatedBy *int64
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
func (a *Announcement) IsActiveAt(now time.Time) bool {
|
||||
|
||||
@@ -102,6 +102,7 @@ type CreateAccountRequest struct {
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
LoadFactor *int `json:"load_factor"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ExpiresAt *int64 `json:"expires_at"`
|
||||
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
|
||||
@@ -120,7 +121,8 @@ type UpdateAccountRequest struct {
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
LoadFactor *int `json:"load_factor"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
ExpiresAt *int64 `json:"expires_at"`
|
||||
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
|
||||
@@ -135,6 +137,7 @@ type BulkUpdateAccountsRequest struct {
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
LoadFactor *int `json:"load_factor"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
|
||||
Schedulable *bool `json:"schedulable"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
@@ -240,77 +243,77 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
var windowCosts map[int64]float64
|
||||
var activeSessions map[int64]int
|
||||
var rpmCounts map[int64]int
|
||||
if !lite {
|
||||
// Get current concurrency counts for all accounts
|
||||
if h.concurrencyService != nil {
|
||||
if cc, ccErr := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs); ccErr == nil && cc != nil {
|
||||
concurrencyCounts = cc
|
||||
|
||||
// 始终获取并发数(Redis ZCARD,极低开销)
|
||||
if h.concurrencyService != nil {
|
||||
if cc, ccErr := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs); ccErr == nil && cc != nil {
|
||||
concurrencyCounts = cc
|
||||
}
|
||||
}
|
||||
|
||||
// 识别需要查询窗口费用、会话数和 RPM 的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
|
||||
windowCostAccountIDs := make([]int64, 0)
|
||||
sessionLimitAccountIDs := make([]int64, 0)
|
||||
rpmAccountIDs := make([]int64, 0)
|
||||
sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if acc.IsAnthropicOAuthOrSetupToken() {
|
||||
if acc.GetWindowCostLimit() > 0 {
|
||||
windowCostAccountIDs = append(windowCostAccountIDs, acc.ID)
|
||||
}
|
||||
if acc.GetMaxSessions() > 0 {
|
||||
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
|
||||
sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
|
||||
}
|
||||
if acc.GetBaseRPM() > 0 {
|
||||
rpmAccountIDs = append(rpmAccountIDs, acc.ID)
|
||||
}
|
||||
}
|
||||
// 识别需要查询窗口费用、会话数和 RPM 的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
|
||||
windowCostAccountIDs := make([]int64, 0)
|
||||
sessionLimitAccountIDs := make([]int64, 0)
|
||||
rpmAccountIDs := make([]int64, 0)
|
||||
sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置
|
||||
}
|
||||
|
||||
// 始终获取 RPM 计数(Redis GET,极低开销)
|
||||
if len(rpmAccountIDs) > 0 && h.rpmCache != nil {
|
||||
rpmCounts, _ = h.rpmCache.GetRPMBatch(c.Request.Context(), rpmAccountIDs)
|
||||
if rpmCounts == nil {
|
||||
rpmCounts = make(map[int64]int)
|
||||
}
|
||||
}
|
||||
|
||||
// 始终获取活跃会话数(Redis ZCARD,低开销)
|
||||
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
|
||||
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
|
||||
if activeSessions == nil {
|
||||
activeSessions = make(map[int64]int)
|
||||
}
|
||||
}
|
||||
|
||||
// 始终获取窗口费用(PostgreSQL 聚合查询)
|
||||
if len(windowCostAccountIDs) > 0 {
|
||||
windowCosts = make(map[int64]float64)
|
||||
var mu sync.Mutex
|
||||
g, gctx := errgroup.WithContext(c.Request.Context())
|
||||
g.SetLimit(10) // 限制并发数
|
||||
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if acc.IsAnthropicOAuthOrSetupToken() {
|
||||
if acc.GetWindowCostLimit() > 0 {
|
||||
windowCostAccountIDs = append(windowCostAccountIDs, acc.ID)
|
||||
}
|
||||
if acc.GetMaxSessions() > 0 {
|
||||
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
|
||||
sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
|
||||
}
|
||||
if acc.GetBaseRPM() > 0 {
|
||||
rpmAccountIDs = append(rpmAccountIDs, acc.ID)
|
||||
}
|
||||
if !acc.IsAnthropicOAuthOrSetupToken() || acc.GetWindowCostLimit() <= 0 {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// 获取 RPM 计数(批量查询)
|
||||
if len(rpmAccountIDs) > 0 && h.rpmCache != nil {
|
||||
rpmCounts, _ = h.rpmCache.GetRPMBatch(c.Request.Context(), rpmAccountIDs)
|
||||
if rpmCounts == nil {
|
||||
rpmCounts = make(map[int64]int)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置)
|
||||
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
|
||||
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
|
||||
if activeSessions == nil {
|
||||
activeSessions = make(map[int64]int)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取窗口费用(并行查询)
|
||||
if len(windowCostAccountIDs) > 0 {
|
||||
windowCosts = make(map[int64]float64)
|
||||
var mu sync.Mutex
|
||||
g, gctx := errgroup.WithContext(c.Request.Context())
|
||||
g.SetLimit(10) // 限制并发数
|
||||
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if !acc.IsAnthropicOAuthOrSetupToken() || acc.GetWindowCostLimit() <= 0 {
|
||||
continue
|
||||
accCopy := acc // 闭包捕获
|
||||
g.Go(func() error {
|
||||
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
|
||||
startTime := accCopy.GetCurrentWindowStartTime()
|
||||
stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime)
|
||||
if err == nil && stats != nil {
|
||||
mu.Lock()
|
||||
windowCosts[accCopy.ID] = stats.StandardCost // 使用标准费用
|
||||
mu.Unlock()
|
||||
}
|
||||
accCopy := acc // 闭包捕获
|
||||
g.Go(func() error {
|
||||
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
|
||||
startTime := accCopy.GetCurrentWindowStartTime()
|
||||
stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime)
|
||||
if err == nil && stats != nil {
|
||||
mu.Lock()
|
||||
windowCosts[accCopy.ID] = stats.StandardCost // 使用标准费用
|
||||
mu.Unlock()
|
||||
}
|
||||
return nil // 不返回错误,允许部分失败
|
||||
})
|
||||
}
|
||||
_ = g.Wait()
|
||||
return nil // 不返回错误,允许部分失败
|
||||
})
|
||||
}
|
||||
_ = g.Wait()
|
||||
}
|
||||
|
||||
// Build response with concurrency info
|
||||
@@ -506,6 +509,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
LoadFactor: req.LoadFactor,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
AutoPauseOnExpired: req.AutoPauseOnExpired,
|
||||
@@ -575,6 +579,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
|
||||
Concurrency: req.Concurrency, // 指针类型,nil 表示未提供
|
||||
Priority: req.Priority, // 指针类型,nil 表示未提供
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
LoadFactor: req.LoadFactor,
|
||||
Status: req.Status,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
@@ -1101,6 +1106,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
req.Concurrency != nil ||
|
||||
req.Priority != nil ||
|
||||
req.RateMultiplier != nil ||
|
||||
req.LoadFactor != nil ||
|
||||
req.Status != "" ||
|
||||
req.Schedulable != nil ||
|
||||
req.GroupIDs != nil ||
|
||||
@@ -1119,6 +1125,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
LoadFactor: req.LoadFactor,
|
||||
Status: req.Status,
|
||||
Schedulable: req.Schedulable,
|
||||
GroupIDs: req.GroupIDs,
|
||||
@@ -1328,6 +1335,29 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// ResetQuota handles resetting account quota usage
|
||||
// POST /api/v1/admin/accounts/:id/reset-quota
|
||||
func (h *AccountHandler) ResetQuota(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.adminService.ResetAccountQuota(c.Request.Context(), accountID); err != nil {
|
||||
response.InternalError(c, "Failed to reset account quota: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// GetTempUnschedulable handles getting temporary unschedulable status
|
||||
// GET /api/v1/admin/accounts/:id/temp-unschedulable
|
||||
func (h *AccountHandler) GetTempUnschedulable(c *gin.Context) {
|
||||
|
||||
@@ -425,5 +425,9 @@ func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure stub implements interface.
|
||||
var _ service.AdminService = (*stubAdminService)(nil)
|
||||
|
||||
@@ -27,21 +27,23 @@ func NewAnnouncementHandler(announcementService *service.AnnouncementService) *A
|
||||
}
|
||||
|
||||
type CreateAnnouncementRequest struct {
|
||||
Title string `json:"title" binding:"required"`
|
||||
Content string `json:"content" binding:"required"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=draft active archived"`
|
||||
Targeting service.AnnouncementTargeting `json:"targeting"`
|
||||
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0/empty = immediate
|
||||
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0/empty = never
|
||||
Title string `json:"title" binding:"required"`
|
||||
Content string `json:"content" binding:"required"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=draft active archived"`
|
||||
NotifyMode string `json:"notify_mode" binding:"omitempty,oneof=silent popup"`
|
||||
Targeting service.AnnouncementTargeting `json:"targeting"`
|
||||
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0/empty = immediate
|
||||
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0/empty = never
|
||||
}
|
||||
|
||||
type UpdateAnnouncementRequest struct {
|
||||
Title *string `json:"title"`
|
||||
Content *string `json:"content"`
|
||||
Status *string `json:"status" binding:"omitempty,oneof=draft active archived"`
|
||||
Targeting *service.AnnouncementTargeting `json:"targeting"`
|
||||
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0 = clear
|
||||
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0 = clear
|
||||
Title *string `json:"title"`
|
||||
Content *string `json:"content"`
|
||||
Status *string `json:"status" binding:"omitempty,oneof=draft active archived"`
|
||||
NotifyMode *string `json:"notify_mode" binding:"omitempty,oneof=silent popup"`
|
||||
Targeting *service.AnnouncementTargeting `json:"targeting"`
|
||||
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0 = clear
|
||||
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0 = clear
|
||||
}
|
||||
|
||||
// List handles listing announcements with filters
|
||||
@@ -110,11 +112,12 @@ func (h *AnnouncementHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
input := &service.CreateAnnouncementInput{
|
||||
Title: req.Title,
|
||||
Content: req.Content,
|
||||
Status: req.Status,
|
||||
Targeting: req.Targeting,
|
||||
ActorID: &subject.UserID,
|
||||
Title: req.Title,
|
||||
Content: req.Content,
|
||||
Status: req.Status,
|
||||
NotifyMode: req.NotifyMode,
|
||||
Targeting: req.Targeting,
|
||||
ActorID: &subject.UserID,
|
||||
}
|
||||
|
||||
if req.StartsAt != nil && *req.StartsAt > 0 {
|
||||
@@ -157,11 +160,12 @@ func (h *AnnouncementHandler) Update(c *gin.Context) {
|
||||
}
|
||||
|
||||
input := &service.UpdateAnnouncementInput{
|
||||
Title: req.Title,
|
||||
Content: req.Content,
|
||||
Status: req.Status,
|
||||
Targeting: req.Targeting,
|
||||
ActorID: &subject.UserID,
|
||||
Title: req.Title,
|
||||
Content: req.Content,
|
||||
Status: req.Status,
|
||||
NotifyMode: req.NotifyMode,
|
||||
Targeting: req.Targeting,
|
||||
ActorID: &subject.UserID,
|
||||
}
|
||||
|
||||
if req.StartsAt != nil {
|
||||
|
||||
@@ -53,6 +53,9 @@ type CreateGroupRequest struct {
|
||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
|
||||
DefaultMappedModel string `json:"default_mapped_model"`
|
||||
// 从指定分组复制账号(创建后自动绑定)
|
||||
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||
}
|
||||
@@ -88,6 +91,9 @@ type UpdateGroupRequest struct {
|
||||
SupportedModelScopes *[]string `json:"supported_model_scopes"`
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
AllowMessagesDispatch *bool `json:"allow_messages_dispatch"`
|
||||
DefaultMappedModel *string `json:"default_mapped_model"`
|
||||
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||
}
|
||||
@@ -203,6 +209,8 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
MCPXMLInject: req.MCPXMLInject,
|
||||
SupportedModelScopes: req.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
||||
DefaultMappedModel: req.DefaultMappedModel,
|
||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -254,6 +262,8 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
MCPXMLInject: req.MCPXMLInject,
|
||||
SupportedModelScopes: req.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
||||
DefaultMappedModel: req.DefaultMappedModel,
|
||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
155
backend/internal/handler/admin/scheduled_test_handler.go
Normal file
155
backend/internal/handler/admin/scheduled_test_handler.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ScheduledTestHandler handles admin scheduled-test-plan management.
|
||||
type ScheduledTestHandler struct {
|
||||
scheduledTestSvc *service.ScheduledTestService
|
||||
}
|
||||
|
||||
// NewScheduledTestHandler creates a new ScheduledTestHandler.
|
||||
func NewScheduledTestHandler(scheduledTestSvc *service.ScheduledTestService) *ScheduledTestHandler {
|
||||
return &ScheduledTestHandler{scheduledTestSvc: scheduledTestSvc}
|
||||
}
|
||||
|
||||
type createScheduledTestPlanRequest struct {
|
||||
AccountID int64 `json:"account_id" binding:"required"`
|
||||
ModelID string `json:"model_id"`
|
||||
CronExpression string `json:"cron_expression" binding:"required"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
MaxResults int `json:"max_results"`
|
||||
}
|
||||
|
||||
type updateScheduledTestPlanRequest struct {
|
||||
ModelID string `json:"model_id"`
|
||||
CronExpression string `json:"cron_expression"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
MaxResults int `json:"max_results"`
|
||||
}
|
||||
|
||||
// ListByAccount GET /admin/accounts/:id/scheduled-test-plans
|
||||
func (h *ScheduledTestHandler) ListByAccount(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid account id")
|
||||
return
|
||||
}
|
||||
|
||||
plans, err := h.scheduledTestSvc.ListPlansByAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.InternalError(c, err.Error())
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, plans)
|
||||
}
|
||||
|
||||
// Create POST /admin/scheduled-test-plans
|
||||
func (h *ScheduledTestHandler) Create(c *gin.Context) {
|
||||
var req createScheduledTestPlanRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
plan := &service.ScheduledTestPlan{
|
||||
AccountID: req.AccountID,
|
||||
ModelID: req.ModelID,
|
||||
CronExpression: req.CronExpression,
|
||||
Enabled: true,
|
||||
MaxResults: req.MaxResults,
|
||||
}
|
||||
if req.Enabled != nil {
|
||||
plan.Enabled = *req.Enabled
|
||||
}
|
||||
|
||||
created, err := h.scheduledTestSvc.CreatePlan(c.Request.Context(), plan)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, created)
|
||||
}
|
||||
|
||||
// Update PUT /admin/scheduled-test-plans/:id
|
||||
func (h *ScheduledTestHandler) Update(c *gin.Context) {
|
||||
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid plan id")
|
||||
return
|
||||
}
|
||||
|
||||
existing, err := h.scheduledTestSvc.GetPlan(c.Request.Context(), planID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "plan not found")
|
||||
return
|
||||
}
|
||||
|
||||
var req updateScheduledTestPlanRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.ModelID != "" {
|
||||
existing.ModelID = req.ModelID
|
||||
}
|
||||
if req.CronExpression != "" {
|
||||
existing.CronExpression = req.CronExpression
|
||||
}
|
||||
if req.Enabled != nil {
|
||||
existing.Enabled = *req.Enabled
|
||||
}
|
||||
if req.MaxResults > 0 {
|
||||
existing.MaxResults = req.MaxResults
|
||||
}
|
||||
|
||||
updated, err := h.scheduledTestSvc.UpdatePlan(c.Request.Context(), existing)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, updated)
|
||||
}
|
||||
|
||||
// Delete DELETE /admin/scheduled-test-plans/:id
|
||||
func (h *ScheduledTestHandler) Delete(c *gin.Context) {
|
||||
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid plan id")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.scheduledTestSvc.DeletePlan(c.Request.Context(), planID); err != nil {
|
||||
response.InternalError(c, err.Error())
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "deleted"})
|
||||
}
|
||||
|
||||
// ListResults GET /admin/scheduled-test-plans/:id/results
|
||||
func (h *ScheduledTestHandler) ListResults(c *gin.Context) {
|
||||
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid plan id")
|
||||
return
|
||||
}
|
||||
|
||||
limit := 50
|
||||
if l, err := strconv.Atoi(c.Query("limit")); err == nil && l > 0 {
|
||||
limit = l
|
||||
}
|
||||
|
||||
results, err := h.scheduledTestSvc.ListResults(c.Request.Context(), planID, limit)
|
||||
if err != nil {
|
||||
response.InternalError(c, err.Error())
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, results)
|
||||
}
|
||||
@@ -819,7 +819,7 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) {
|
||||
|
||||
err := h.emailService.TestSMTPConnectionWithConfig(config)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
response.BadRequest(c, "SMTP connection test failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -905,7 +905,7 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
|
||||
`
|
||||
|
||||
if err := h.emailService.SendEmailWithConfig(config, req.Email, subject, body); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
response.BadRequest(c, "Failed to send test email: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -7,10 +7,11 @@ import (
|
||||
)
|
||||
|
||||
type Announcement struct {
|
||||
ID int64 `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
Status string `json:"status"`
|
||||
ID int64 `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
Status string `json:"status"`
|
||||
NotifyMode string `json:"notify_mode"`
|
||||
|
||||
Targeting service.AnnouncementTargeting `json:"targeting"`
|
||||
|
||||
@@ -25,9 +26,10 @@ type Announcement struct {
|
||||
}
|
||||
|
||||
type UserAnnouncement struct {
|
||||
ID int64 `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
ID int64 `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
NotifyMode string `json:"notify_mode"`
|
||||
|
||||
StartsAt *time.Time `json:"starts_at,omitempty"`
|
||||
EndsAt *time.Time `json:"ends_at,omitempty"`
|
||||
@@ -43,17 +45,18 @@ func AnnouncementFromService(a *service.Announcement) *Announcement {
|
||||
return nil
|
||||
}
|
||||
return &Announcement{
|
||||
ID: a.ID,
|
||||
Title: a.Title,
|
||||
Content: a.Content,
|
||||
Status: a.Status,
|
||||
Targeting: a.Targeting,
|
||||
StartsAt: a.StartsAt,
|
||||
EndsAt: a.EndsAt,
|
||||
CreatedBy: a.CreatedBy,
|
||||
UpdatedBy: a.UpdatedBy,
|
||||
CreatedAt: a.CreatedAt,
|
||||
UpdatedAt: a.UpdatedAt,
|
||||
ID: a.ID,
|
||||
Title: a.Title,
|
||||
Content: a.Content,
|
||||
Status: a.Status,
|
||||
NotifyMode: a.NotifyMode,
|
||||
Targeting: a.Targeting,
|
||||
StartsAt: a.StartsAt,
|
||||
EndsAt: a.EndsAt,
|
||||
CreatedBy: a.CreatedBy,
|
||||
UpdatedBy: a.UpdatedBy,
|
||||
CreatedAt: a.CreatedAt,
|
||||
UpdatedAt: a.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,13 +65,14 @@ func UserAnnouncementFromService(a *service.UserAnnouncement) *UserAnnouncement
|
||||
return nil
|
||||
}
|
||||
return &UserAnnouncement{
|
||||
ID: a.Announcement.ID,
|
||||
Title: a.Announcement.Title,
|
||||
Content: a.Announcement.Content,
|
||||
StartsAt: a.Announcement.StartsAt,
|
||||
EndsAt: a.Announcement.EndsAt,
|
||||
ReadAt: a.ReadAt,
|
||||
CreatedAt: a.Announcement.CreatedAt,
|
||||
UpdatedAt: a.Announcement.UpdatedAt,
|
||||
ID: a.Announcement.ID,
|
||||
Title: a.Announcement.Title,
|
||||
Content: a.Announcement.Content,
|
||||
NotifyMode: a.Announcement.NotifyMode,
|
||||
StartsAt: a.Announcement.StartsAt,
|
||||
EndsAt: a.Announcement.EndsAt,
|
||||
ReadAt: a.ReadAt,
|
||||
CreatedAt: a.Announcement.CreatedAt,
|
||||
UpdatedAt: a.Announcement.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,9 +89,9 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
|
||||
RateLimit5h: k.RateLimit5h,
|
||||
RateLimit1d: k.RateLimit1d,
|
||||
RateLimit7d: k.RateLimit7d,
|
||||
Usage5h: k.Usage5h,
|
||||
Usage1d: k.Usage1d,
|
||||
Usage7d: k.Usage7d,
|
||||
Usage5h: k.EffectiveUsage5h(),
|
||||
Usage1d: k.EffectiveUsage1d(),
|
||||
Usage7d: k.EffectiveUsage7d(),
|
||||
Window5hStart: k.Window5hStart,
|
||||
Window1dStart: k.Window1dStart,
|
||||
Window7dStart: k.Window7dStart,
|
||||
@@ -125,8 +125,9 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
||||
Group: groupFromServiceBase(g),
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
MCPXMLInject: g.MCPXMLInject,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
MCPXMLInject: g.MCPXMLInject,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
AccountCount: g.AccountCount,
|
||||
SortOrder: g.SortOrder,
|
||||
}
|
||||
@@ -164,6 +165,7 @@ func groupFromServiceBase(g *service.Group) Group {
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
|
||||
SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
|
||||
AllowMessagesDispatch: g.AllowMessagesDispatch,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
@@ -183,6 +185,7 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
Extra: a.Extra,
|
||||
ProxyID: a.ProxyID,
|
||||
Concurrency: a.Concurrency,
|
||||
LoadFactor: a.LoadFactor,
|
||||
Priority: a.Priority,
|
||||
RateMultiplier: a.BillingRateMultiplier(),
|
||||
Status: a.Status,
|
||||
@@ -248,6 +251,17 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
}
|
||||
}
|
||||
|
||||
// 提取 API Key 账号配额限制(仅 apikey 类型有效)
|
||||
if a.Type == service.AccountTypeAPIKey {
|
||||
if limit := a.GetQuotaLimit(); limit > 0 {
|
||||
out.QuotaLimit = &limit
|
||||
}
|
||||
used := a.GetQuotaUsed()
|
||||
if out.QuotaLimit != nil {
|
||||
out.QuotaUsed = &used
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
|
||||
@@ -96,6 +96,9 @@ type Group struct {
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||
|
||||
// OpenAI Messages 调度开关(用户侧需要此字段判断是否展示 Claude Code 教程)
|
||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
@@ -112,6 +115,9 @@ type AdminGroup struct {
|
||||
// MCP XML 协议注入(仅 antigravity 平台使用)
|
||||
MCPXMLInject bool `json:"mcp_xml_inject"`
|
||||
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
DefaultMappedModel string `json:"default_mapped_model"`
|
||||
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
@@ -131,6 +137,7 @@ type Account struct {
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
LoadFactor *int `json:"load_factor,omitempty"`
|
||||
Priority int `json:"priority"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
Status string `json:"status"`
|
||||
@@ -185,6 +192,10 @@ type Account struct {
|
||||
CacheTTLOverrideEnabled *bool `json:"cache_ttl_override_enabled,omitempty"`
|
||||
CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"`
|
||||
|
||||
// API Key 账号配额限制
|
||||
QuotaLimit *float64 `json:"quota_limit,omitempty"`
|
||||
QuotaUsed *float64 `json:"quota_used,omitempty"`
|
||||
|
||||
Proxy *Proxy `json:"proxy,omitempty"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
|
||||
|
||||
@@ -971,7 +971,7 @@ func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context,
|
||||
if err == nil && rateLimitData != nil {
|
||||
var rateLimits []gin.H
|
||||
if apiKey.RateLimit5h > 0 {
|
||||
used := rateLimitData.Usage5h
|
||||
used := rateLimitData.EffectiveUsage5h()
|
||||
rateLimits = append(rateLimits, gin.H{
|
||||
"window": "5h",
|
||||
"limit": apiKey.RateLimit5h,
|
||||
@@ -981,7 +981,7 @@ func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context,
|
||||
})
|
||||
}
|
||||
if apiKey.RateLimit1d > 0 {
|
||||
used := rateLimitData.Usage1d
|
||||
used := rateLimitData.EffectiveUsage1d()
|
||||
rateLimits = append(rateLimits, gin.H{
|
||||
"window": "1d",
|
||||
"limit": apiKey.RateLimit1d,
|
||||
@@ -991,7 +991,7 @@ func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context,
|
||||
})
|
||||
}
|
||||
if apiKey.RateLimit7d > 0 {
|
||||
used := rateLimitData.Usage7d
|
||||
used := rateLimitData.EffectiveUsage7d()
|
||||
rateLimits = append(rateLimits, gin.H{
|
||||
"window": "7d",
|
||||
"limit": apiKey.RateLimit7d,
|
||||
|
||||
@@ -27,6 +27,7 @@ type AdminHandlers struct {
|
||||
UserAttribute *admin.UserAttributeHandler
|
||||
ErrorPassthrough *admin.ErrorPassthroughHandler
|
||||
APIKey *admin.AdminAPIKeyHandler
|
||||
ScheduledTest *admin.ScheduledTestHandler
|
||||
}
|
||||
|
||||
// Handlers contains all HTTP handlers
|
||||
|
||||
192
backend/internal/handler/openai_gateway_compact_log_test.go
Normal file
192
backend/internal/handler/openai_gateway_compact_log_test.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var handlerStructuredLogCaptureMu sync.Mutex
|
||||
|
||||
type handlerInMemoryLogSink struct {
|
||||
mu sync.Mutex
|
||||
events []*logger.LogEvent
|
||||
}
|
||||
|
||||
func (s *handlerInMemoryLogSink) WriteLogEvent(event *logger.LogEvent) {
|
||||
if event == nil {
|
||||
return
|
||||
}
|
||||
cloned := *event
|
||||
if event.Fields != nil {
|
||||
cloned.Fields = make(map[string]any, len(event.Fields))
|
||||
for k, v := range event.Fields {
|
||||
cloned.Fields[k] = v
|
||||
}
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.events = append(s.events, &cloned)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *handlerInMemoryLogSink) ContainsMessageAtLevel(substr, level string) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
wantLevel := strings.ToLower(strings.TrimSpace(level))
|
||||
for _, ev := range s.events {
|
||||
if ev == nil {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(ev.Message, substr) && strings.ToLower(strings.TrimSpace(ev.Level)) == wantLevel {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *handlerInMemoryLogSink) ContainsFieldValue(field, substr string) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for _, ev := range s.events {
|
||||
if ev == nil || ev.Fields == nil {
|
||||
continue
|
||||
}
|
||||
if v, ok := ev.Fields[field]; ok && strings.Contains(fmt.Sprint(v), substr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func captureHandlerStructuredLog(t *testing.T) (*handlerInMemoryLogSink, func()) {
|
||||
t.Helper()
|
||||
handlerStructuredLogCaptureMu.Lock()
|
||||
|
||||
err := logger.Init(logger.InitOptions{
|
||||
Level: "debug",
|
||||
Format: "json",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "test",
|
||||
Output: logger.OutputOptions{
|
||||
ToStdout: true,
|
||||
ToFile: false,
|
||||
},
|
||||
Sampling: logger.SamplingOptions{Enabled: false},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
sink := &handlerInMemoryLogSink{}
|
||||
logger.SetSink(sink)
|
||||
return sink, func() {
|
||||
logger.SetSink(nil)
|
||||
handlerStructuredLogCaptureMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsOpenAIRemoteCompactPath(t *testing.T) {
|
||||
require.False(t, isOpenAIRemoteCompactPath(nil))
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", nil)
|
||||
require.True(t, isOpenAIRemoteCompactPath(c))
|
||||
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact/", nil)
|
||||
require.True(t, isOpenAIRemoteCompactPath(c))
|
||||
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
require.False(t, isOpenAIRemoteCompactPath(c))
|
||||
}
|
||||
|
||||
func TestLogOpenAIRemoteCompactOutcome_Succeeded(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureHandlerStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", nil)
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
|
||||
c.Set(opsModelKey, "gpt-5.3-codex")
|
||||
c.Set(opsAccountIDKey, int64(123))
|
||||
c.Header("x-request-id", "rid-compact-ok")
|
||||
c.Status(http.StatusOK)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.logOpenAIRemoteCompactOutcome(c, time.Now().Add(-8*time.Millisecond))
|
||||
|
||||
require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.succeeded", "info"))
|
||||
require.True(t, logSink.ContainsFieldValue("compact_outcome", "succeeded"))
|
||||
require.True(t, logSink.ContainsFieldValue("status_code", "200"))
|
||||
require.True(t, logSink.ContainsFieldValue("path", "/v1/responses/compact"))
|
||||
require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.3-codex"))
|
||||
require.True(t, logSink.ContainsFieldValue("account_id", "123"))
|
||||
require.True(t, logSink.ContainsFieldValue("upstream_request_id", "rid-compact-ok"))
|
||||
}
|
||||
|
||||
func TestLogOpenAIRemoteCompactOutcome_Failed(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureHandlerStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact", nil)
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
|
||||
c.Status(http.StatusBadGateway)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.logOpenAIRemoteCompactOutcome(c, time.Now())
|
||||
|
||||
require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn"))
|
||||
require.True(t, logSink.ContainsFieldValue("compact_outcome", "failed"))
|
||||
require.True(t, logSink.ContainsFieldValue("status_code", "502"))
|
||||
require.True(t, logSink.ContainsFieldValue("path", "/responses/compact"))
|
||||
}
|
||||
|
||||
func TestLogOpenAIRemoteCompactOutcome_NonCompactSkips(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureHandlerStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
c.Status(http.StatusOK)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.logOpenAIRemoteCompactOutcome(c, time.Now())
|
||||
|
||||
require.False(t, logSink.ContainsMessageAtLevel("codex.remote_compact.succeeded", "info"))
|
||||
require.False(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn"))
|
||||
}
|
||||
|
||||
func TestOpenAIResponses_CompactUnauthorizedLogsFailed(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureHandlerStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", strings.NewReader(`{"model":"gpt-5.3-codex"}`))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.Responses(c)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn"))
|
||||
require.True(t, logSink.ContainsFieldValue("status_code", "401"))
|
||||
require.True(t, logSink.ContainsFieldValue("path", "/v1/responses/compact"))
|
||||
}
|
||||
@@ -33,6 +33,7 @@ type OpenAIGatewayHandler struct {
|
||||
errorPassthroughService *service.ErrorPassthroughService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
||||
@@ -61,6 +62,7 @@ func NewOpenAIGatewayHandler(
|
||||
errorPassthroughService: errorPassthroughService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,6 +72,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// 局部兜底:确保该 handler 内部任何 panic 都不会击穿到进程级。
|
||||
streamStarted := false
|
||||
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||
compactStartedAt := time.Now()
|
||||
defer h.logOpenAIRemoteCompactOutcome(c, compactStartedAt)
|
||||
setOpenAIClientTransportHTTP(c)
|
||||
|
||||
requestStart := time.Now()
|
||||
@@ -114,6 +118,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
sessionHashBody := body
|
||||
if service.IsOpenAIResponsesCompactPathForTest(c) {
|
||||
if compactSeed := strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()); compactSeed != "" {
|
||||
c.Set(service.OpenAICompactSessionSeedKeyForTest(), compactSeed)
|
||||
}
|
||||
normalizedCompactBody, normalizedCompact, compactErr := service.NormalizeOpenAICompactRequestBodyForTest(body)
|
||||
if compactErr != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to normalize compact request body")
|
||||
return
|
||||
}
|
||||
if normalizedCompact {
|
||||
body = normalizedCompactBody
|
||||
}
|
||||
}
|
||||
|
||||
// 校验请求体 JSON 合法性
|
||||
if !gjson.ValidBytes(body) {
|
||||
@@ -189,7 +207,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Generate session hash (header first; fallback to prompt_cache_key)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, sessionHashBody)
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
@@ -301,6 +319,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
if result != nil {
|
||||
if account.Type == service.AccountTypeOAuth {
|
||||
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(c.Request.Context(), account.ID, result.ResponseHeaders)
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
} else {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
|
||||
@@ -340,6 +361,396 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func isOpenAIRemoteCompactPath(c *gin.Context) bool {
|
||||
if c == nil || c.Request == nil || c.Request.URL == nil {
|
||||
return false
|
||||
}
|
||||
normalizedPath := strings.TrimRight(strings.TrimSpace(c.Request.URL.Path), "/")
|
||||
return strings.HasSuffix(normalizedPath, "/responses/compact")
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) logOpenAIRemoteCompactOutcome(c *gin.Context, startedAt time.Time) {
|
||||
if !isOpenAIRemoteCompactPath(c) {
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
ctx = context.Background()
|
||||
path string
|
||||
status int
|
||||
)
|
||||
if c != nil {
|
||||
if c.Request != nil {
|
||||
ctx = c.Request.Context()
|
||||
if c.Request.URL != nil {
|
||||
path = strings.TrimSpace(c.Request.URL.Path)
|
||||
}
|
||||
}
|
||||
if c.Writer != nil {
|
||||
status = c.Writer.Status()
|
||||
}
|
||||
}
|
||||
|
||||
outcome := "failed"
|
||||
if status >= 200 && status < 300 {
|
||||
outcome = "succeeded"
|
||||
}
|
||||
latencyMs := time.Since(startedAt).Milliseconds()
|
||||
if latencyMs < 0 {
|
||||
latencyMs = 0
|
||||
}
|
||||
|
||||
fields := []zap.Field{
|
||||
zap.String("component", "handler.openai_gateway.responses"),
|
||||
zap.Bool("remote_compact", true),
|
||||
zap.String("compact_outcome", outcome),
|
||||
zap.Int("status_code", status),
|
||||
zap.Int64("latency_ms", latencyMs),
|
||||
zap.String("path", path),
|
||||
zap.Bool("force_codex_cli", h != nil && h.cfg != nil && h.cfg.Gateway.ForceCodexCLI),
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
if userAgent := strings.TrimSpace(c.GetHeader("User-Agent")); userAgent != "" {
|
||||
fields = append(fields, zap.String("request_user_agent", userAgent))
|
||||
}
|
||||
if v, ok := c.Get(opsModelKey); ok {
|
||||
if model, ok := v.(string); ok && strings.TrimSpace(model) != "" {
|
||||
fields = append(fields, zap.String("request_model", strings.TrimSpace(model)))
|
||||
}
|
||||
}
|
||||
if v, ok := c.Get(opsAccountIDKey); ok {
|
||||
if accountID, ok := v.(int64); ok && accountID > 0 {
|
||||
fields = append(fields, zap.Int64("account_id", accountID))
|
||||
}
|
||||
}
|
||||
if c.Writer != nil {
|
||||
if upstreamRequestID := strings.TrimSpace(c.Writer.Header().Get("x-request-id")); upstreamRequestID != "" {
|
||||
fields = append(fields, zap.String("upstream_request_id", upstreamRequestID))
|
||||
} else if upstreamRequestID := strings.TrimSpace(c.Writer.Header().Get("X-Request-Id")); upstreamRequestID != "" {
|
||||
fields = append(fields, zap.String("upstream_request_id", upstreamRequestID))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log := logger.FromContext(ctx).With(fields...)
|
||||
if outcome == "succeeded" {
|
||||
log.Info("codex.remote_compact.succeeded")
|
||||
return
|
||||
}
|
||||
log.Warn("codex.remote_compact.failed")
|
||||
}
|
||||
|
||||
// Messages handles Anthropic Messages API requests routed to OpenAI platform.
|
||||
// POST /v1/messages (when group platform is OpenAI)
|
||||
func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
streamStarted := false
|
||||
defer h.recoverAnthropicMessagesPanic(c, &streamStarted)
|
||||
|
||||
requestStart := time.Now()
|
||||
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.anthropicErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.anthropicErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.openai_gateway.messages",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
// 检查分组是否允许 /v1/messages 调度
|
||||
if apiKey.Group != nil && !apiKey.Group.AllowMessagesDispatch {
|
||||
h.anthropicErrorResponse(c, http.StatusForbidden, "permission_error",
|
||||
"This group does not allow /v1/messages dispatch")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.ensureResponsesDependencies(c, reqLog) {
|
||||
return
|
||||
}
|
||||
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.anthropicErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.anthropicErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
if len(body) == 0 {
|
||||
h.anthropicErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
if !gjson.ValidBytes(body) {
|
||||
h.anthropicErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
modelResult := gjson.GetBytes(body, "model")
|
||||
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
|
||||
h.anthropicErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return
|
||||
}
|
||||
reqModel := modelResult.String()
|
||||
reqStream := gjson.GetBytes(body, "stream").Bool()
|
||||
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
|
||||
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
}
|
||||
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
||||
routingStart := time.Now()
|
||||
|
||||
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("openai_messages.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.anthropicStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
|
||||
promptCacheKey := h.gatewayService.ExtractSessionID(c, body)
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
// 清除上一次迭代的降级模型标记,避免残留影响本次迭代
|
||||
c.Set("openai_messages_fallback_model", "")
|
||||
reqLog.Debug("openai_messages.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
"", // no previous_response_id
|
||||
sessionHash,
|
||||
reqModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai_messages.account_select_failed",
|
||||
zap.Error(err),
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
// 首次调度失败 + 有默认映射模型 → 用默认模型重试
|
||||
if len(failedAccountIDs) == 0 {
|
||||
defaultModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
if defaultModel != "" && defaultModel != reqModel {
|
||||
reqLog.Info("openai_messages.fallback_to_default_model",
|
||||
zap.String("default_mapped_model", defaultModel),
|
||||
)
|
||||
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
"",
|
||||
sessionHash,
|
||||
defaultModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
if err == nil && selection != nil {
|
||||
c.Set("openai_messages_fallback_model", defaultModel)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if lastFailoverErr != nil {
|
||||
h.handleAnthropicFailoverExhausted(c, lastFailoverErr, streamStarted)
|
||||
} else {
|
||||
h.anthropicStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
reqLog.Debug("openai_messages.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||
_ = scheduleDecision
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
defaultMappedModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultMappedModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
// 如果使用了降级模型调度,强制使用降级模型
|
||||
if fallbackModel := c.GetString("openai_messages_fallback_model"); fallbackModel != "" {
|
||||
defaultMappedModel = fallbackModel
|
||||
}
|
||||
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||
responseLatencyMs := forwardDurationMs
|
||||
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
|
||||
}
|
||||
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
|
||||
if err == nil && result != nil && result.FirstTokenMs != nil {
|
||||
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai_messages.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureAnthropicErrorResponse(c, streamStarted)
|
||||
reqLog.Warn("openai_messages.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
if result != nil {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
} else {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
|
||||
}
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.messages"),
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64("account_id", account.ID),
|
||||
).Error("openai_messages.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
})
|
||||
reqLog.Debug("openai_messages.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("switch_count", switchCount),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// anthropicErrorResponse writes an error in Anthropic Messages API format.
|
||||
func (h *OpenAIGatewayHandler) anthropicErrorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// anthropicStreamingAwareError handles errors that may occur during streaming,
|
||||
// using Anthropic SSE error format.
|
||||
func (h *OpenAIGatewayHandler) anthropicStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if ok {
|
||||
errPayload, _ := json.Marshal(gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", errPayload) //nolint:errcheck
|
||||
flusher.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
h.anthropicErrorResponse(c, status, errType, message)
|
||||
}
|
||||
|
||||
// handleAnthropicFailoverExhausted maps upstream failover errors to Anthropic format.
|
||||
func (h *OpenAIGatewayHandler) handleAnthropicFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) {
|
||||
status, errType, errMsg := h.mapUpstreamError(failoverErr.StatusCode)
|
||||
h.anthropicStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
}
|
||||
|
||||
// ensureAnthropicErrorResponse writes a fallback Anthropic error if no response was written.
|
||||
func (h *OpenAIGatewayHandler) ensureAnthropicErrorResponse(c *gin.Context, streamStarted bool) bool {
|
||||
if c == nil || c.Writer == nil || c.Writer.Written() {
|
||||
return false
|
||||
}
|
||||
h.anthropicStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted)
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, body []byte, reqLog *zap.Logger) bool {
|
||||
if !gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
|
||||
return true
|
||||
@@ -756,6 +1167,9 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
if turnErr != nil || result == nil {
|
||||
return
|
||||
}
|
||||
if account.Type == service.AccountTypeOAuth {
|
||||
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders)
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
h.submitUsageRecordTask(func(taskCtx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
||||
@@ -817,6 +1231,26 @@ func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStart
|
||||
)
|
||||
}
|
||||
|
||||
// recoverAnthropicMessagesPanic recovers from panics in the Anthropic Messages
|
||||
// handler and returns an Anthropic-formatted error response.
|
||||
func (h *OpenAIGatewayHandler) recoverAnthropicMessagesPanic(c *gin.Context, streamStarted *bool) {
|
||||
recovered := recover()
|
||||
if recovered == nil {
|
||||
return
|
||||
}
|
||||
|
||||
started := streamStarted != nil && *streamStarted
|
||||
requestLogger(c, "handler.openai_gateway.messages").Error(
|
||||
"openai.messages_panic_recovered",
|
||||
zap.Bool("stream_started", started),
|
||||
zap.Any("panic", recovered),
|
||||
zap.ByteString("stack", debug.Stack()),
|
||||
)
|
||||
if !started {
|
||||
h.anthropicErrorResponse(c, http.StatusInternalServerError, "api_error", "Internal server error")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) ensureResponsesDependencies(c *gin.Context, reqLog *zap.Logger) bool {
|
||||
missing := h.missingResponsesDependencies()
|
||||
if len(missing) == 0 {
|
||||
|
||||
@@ -2132,6 +2132,14 @@ func (r *stubAccountRepoForHandler) BulkUpdate(context.Context, []int64, service
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepoForHandler) IncrementQuotaUsed(context.Context, int64, float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepoForHandler) ResetQuotaUsed(context.Context, int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== Stub: SoraClient (用于 SoraGatewayService) ====================
|
||||
|
||||
var _ service.SoraClient = (*stubSoraClientForHandler)(nil)
|
||||
|
||||
@@ -216,6 +216,14 @@ func (r *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates s
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepo) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepo) listSchedulable() []service.Account {
|
||||
var result []service.Account
|
||||
for _, acc := range r.accounts {
|
||||
|
||||
@@ -30,6 +30,7 @@ func ProvideAdminHandlers(
|
||||
userAttributeHandler *admin.UserAttributeHandler,
|
||||
errorPassthroughHandler *admin.ErrorPassthroughHandler,
|
||||
apiKeyHandler *admin.AdminAPIKeyHandler,
|
||||
scheduledTestHandler *admin.ScheduledTestHandler,
|
||||
) *AdminHandlers {
|
||||
return &AdminHandlers{
|
||||
Dashboard: dashboardHandler,
|
||||
@@ -53,6 +54,7 @@ func ProvideAdminHandlers(
|
||||
UserAttribute: userAttributeHandler,
|
||||
ErrorPassthrough: errorPassthroughHandler,
|
||||
APIKey: apiKeyHandler,
|
||||
ScheduledTest: scheduledTestHandler,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,6 +143,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewUserAttributeHandler,
|
||||
admin.NewErrorPassthroughHandler,
|
||||
admin.NewAdminAPIKeyHandler,
|
||||
admin.NewScheduledTestHandler,
|
||||
|
||||
// AdminHandlers and Handlers constructors
|
||||
ProvideAdminHandlers,
|
||||
|
||||
@@ -119,23 +119,33 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// Finish 结束处理,返回最终事件和用量
|
||||
// Finish 结束处理,返回最终事件和用量。
|
||||
// 若整个流未收到任何可解析的上游数据(messageStartSent == false),
|
||||
// 则不补发任何结束事件,防止客户端收到没有 message_start 的残缺流。
|
||||
func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) {
|
||||
var result bytes.Buffer
|
||||
|
||||
if !p.messageStopSent {
|
||||
_, _ = result.Write(p.emitFinish(""))
|
||||
}
|
||||
|
||||
usage := &ClaudeUsage{
|
||||
InputTokens: p.inputTokens,
|
||||
OutputTokens: p.outputTokens,
|
||||
CacheReadInputTokens: p.cacheReadTokens,
|
||||
}
|
||||
|
||||
if !p.messageStartSent {
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
var result bytes.Buffer
|
||||
if !p.messageStopSent {
|
||||
_, _ = result.Write(p.emitFinish(""))
|
||||
}
|
||||
|
||||
return result.Bytes(), usage
|
||||
}
|
||||
|
||||
// MessageStartSent 报告流中是否已发出过 message_start 事件(即是否收到过有效的上游数据)
|
||||
func (p *StreamingProcessor) MessageStartSent() bool {
|
||||
return p.messageStartSent
|
||||
}
|
||||
|
||||
// emitMessageStart 发送 message_start 事件
|
||||
func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte {
|
||||
if p.messageStartSent {
|
||||
|
||||
735
backend/internal/pkg/apicompat/anthropic_responses_test.go
Normal file
735
backend/internal/pkg/apicompat/anthropic_responses_test.go
Normal file
@@ -0,0 +1,735 @@
|
||||
package apicompat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// AnthropicToResponses tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestAnthropicToResponses_BasicText(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Stream: true,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Hello"`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gpt-5.2", resp.Model)
|
||||
assert.True(t, resp.Stream)
|
||||
assert.Equal(t, 1024, *resp.MaxOutputTokens)
|
||||
assert.False(t, *resp.Store)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 1)
|
||||
assert.Equal(t, "user", items[0].Role)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_SystemPrompt(t *testing.T) {
|
||||
t.Run("string", func(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 100,
|
||||
System: json.RawMessage(`"You are helpful."`),
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||
}
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 2)
|
||||
assert.Equal(t, "system", items[0].Role)
|
||||
})
|
||||
|
||||
t.Run("array", func(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 100,
|
||||
System: json.RawMessage(`[{"type":"text","text":"Part 1"},{"type":"text","text":"Part 2"}]`),
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||
}
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 2)
|
||||
assert.Equal(t, "system", items[0].Role)
|
||||
// System text should be joined with double newline.
|
||||
var text string
|
||||
require.NoError(t, json.Unmarshal(items[0].Content, &text))
|
||||
assert.Equal(t, "Part 1\n\nPart 2", text)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_ToolUse(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"What is the weather?"`)},
|
||||
{Role: "assistant", Content: json.RawMessage(`[{"type":"text","text":"Let me check."},{"type":"tool_use","id":"call_1","name":"get_weather","input":{"city":"NYC"}}]`)},
|
||||
{Role: "user", Content: json.RawMessage(`[{"type":"tool_result","tool_use_id":"call_1","content":"Sunny, 72°F"}]`)},
|
||||
},
|
||||
Tools: []AnthropicTool{
|
||||
{Name: "get_weather", Description: "Get weather", InputSchema: json.RawMessage(`{"type":"object","properties":{"city":{"type":"string"}}}`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check tools
|
||||
require.Len(t, resp.Tools, 1)
|
||||
assert.Equal(t, "function", resp.Tools[0].Type)
|
||||
assert.Equal(t, "get_weather", resp.Tools[0].Name)
|
||||
|
||||
// Check input items
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
// user + assistant + function_call + function_call_output = 4
|
||||
require.Len(t, items, 4)
|
||||
|
||||
assert.Equal(t, "user", items[0].Role)
|
||||
assert.Equal(t, "assistant", items[1].Role)
|
||||
assert.Equal(t, "function_call", items[2].Type)
|
||||
assert.Equal(t, "fc_call_1", items[2].CallID)
|
||||
assert.Equal(t, "function_call_output", items[3].Type)
|
||||
assert.Equal(t, "fc_call_1", items[3].CallID)
|
||||
assert.Equal(t, "Sunny, 72°F", items[3].Output)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_ThinkingIgnored(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Hello"`)},
|
||||
{Role: "assistant", Content: json.RawMessage(`[{"type":"thinking","thinking":"deep thought"},{"type":"text","text":"Hi!"}]`)},
|
||||
{Role: "user", Content: json.RawMessage(`"More"`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
// user + assistant(text only, thinking ignored) + user = 3
|
||||
require.Len(t, items, 3)
|
||||
assert.Equal(t, "assistant", items[1].Role)
|
||||
// Assistant content should only have text, not thinking.
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[1].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "output_text", parts[0].Type)
|
||||
assert.Equal(t, "Hi!", parts[0].Text)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_MaxTokensFloor(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 10, // below minMaxOutputTokens (128)
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 128, *resp.MaxOutputTokens)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ResponsesToAnthropic (non-streaming) tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestResponsesToAnthropic_TextOnly(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_123",
|
||||
Model: "gpt-5.2",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "message",
|
||||
Content: []ResponsesContentPart{
|
||||
{Type: "output_text", Text: "Hello there!"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Usage: &ResponsesUsage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15},
|
||||
}
|
||||
|
||||
anth := ResponsesToAnthropic(resp, "claude-opus-4-6")
|
||||
assert.Equal(t, "resp_123", anth.ID)
|
||||
assert.Equal(t, "claude-opus-4-6", anth.Model)
|
||||
assert.Equal(t, "end_turn", anth.StopReason)
|
||||
require.Len(t, anth.Content, 1)
|
||||
assert.Equal(t, "text", anth.Content[0].Type)
|
||||
assert.Equal(t, "Hello there!", anth.Content[0].Text)
|
||||
assert.Equal(t, 10, anth.Usage.InputTokens)
|
||||
assert.Equal(t, 5, anth.Usage.OutputTokens)
|
||||
}
|
||||
|
||||
func TestResponsesToAnthropic_ToolUse(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_456",
|
||||
Model: "gpt-5.2",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "message",
|
||||
Content: []ResponsesContentPart{
|
||||
{Type: "output_text", Text: "Let me check."},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function_call",
|
||||
CallID: "call_1",
|
||||
Name: "get_weather",
|
||||
Arguments: `{"city":"NYC"}`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
anth := ResponsesToAnthropic(resp, "claude-opus-4-6")
|
||||
assert.Equal(t, "tool_use", anth.StopReason)
|
||||
require.Len(t, anth.Content, 2)
|
||||
assert.Equal(t, "text", anth.Content[0].Type)
|
||||
assert.Equal(t, "tool_use", anth.Content[1].Type)
|
||||
assert.Equal(t, "call_1", anth.Content[1].ID)
|
||||
assert.Equal(t, "get_weather", anth.Content[1].Name)
|
||||
}
|
||||
|
||||
func TestResponsesToAnthropic_Reasoning(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_789",
|
||||
Model: "gpt-5.2",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "reasoning",
|
||||
Summary: []ResponsesSummary{
|
||||
{Type: "summary_text", Text: "Thinking about the answer..."},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "message",
|
||||
Content: []ResponsesContentPart{
|
||||
{Type: "output_text", Text: "42"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
anth := ResponsesToAnthropic(resp, "claude-opus-4-6")
|
||||
require.Len(t, anth.Content, 2)
|
||||
assert.Equal(t, "thinking", anth.Content[0].Type)
|
||||
assert.Equal(t, "Thinking about the answer...", anth.Content[0].Thinking)
|
||||
assert.Equal(t, "text", anth.Content[1].Type)
|
||||
assert.Equal(t, "42", anth.Content[1].Text)
|
||||
}
|
||||
|
||||
func TestResponsesToAnthropic_Incomplete(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_inc",
|
||||
Model: "gpt-5.2",
|
||||
Status: "incomplete",
|
||||
IncompleteDetails: &ResponsesIncompleteDetails{
|
||||
Reason: "max_output_tokens",
|
||||
},
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "message",
|
||||
Content: []ResponsesContentPart{{Type: "output_text", Text: "Partial..."}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
anth := ResponsesToAnthropic(resp, "claude-opus-4-6")
|
||||
assert.Equal(t, "max_tokens", anth.StopReason)
|
||||
}
|
||||
|
||||
func TestResponsesToAnthropic_EmptyOutput(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_empty",
|
||||
Model: "gpt-5.2",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{},
|
||||
}
|
||||
|
||||
anth := ResponsesToAnthropic(resp, "claude-opus-4-6")
|
||||
require.Len(t, anth.Content, 1)
|
||||
assert.Equal(t, "text", anth.Content[0].Type)
|
||||
assert.Equal(t, "", anth.Content[0].Text)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Streaming: ResponsesEventToAnthropicEvents tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestStreamingTextOnly(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
|
||||
// 1. response.created
|
||||
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.created",
|
||||
Response: &ResponsesResponse{
|
||||
ID: "resp_1",
|
||||
Model: "gpt-5.2",
|
||||
},
|
||||
}, state)
|
||||
require.Len(t, events, 1)
|
||||
assert.Equal(t, "message_start", events[0].Type)
|
||||
|
||||
// 2. output_item.added (message)
|
||||
events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.output_item.added",
|
||||
OutputIndex: 0,
|
||||
Item: &ResponsesOutput{Type: "message"},
|
||||
}, state)
|
||||
assert.Len(t, events, 0) // message item doesn't emit events
|
||||
|
||||
// 3. text delta
|
||||
events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.output_text.delta",
|
||||
Delta: "Hello",
|
||||
}, state)
|
||||
require.Len(t, events, 2) // content_block_start + content_block_delta
|
||||
assert.Equal(t, "content_block_start", events[0].Type)
|
||||
assert.Equal(t, "text", events[0].ContentBlock.Type)
|
||||
assert.Equal(t, "content_block_delta", events[1].Type)
|
||||
assert.Equal(t, "text_delta", events[1].Delta.Type)
|
||||
assert.Equal(t, "Hello", events[1].Delta.Text)
|
||||
|
||||
// 4. more text
|
||||
events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.output_text.delta",
|
||||
Delta: " world",
|
||||
}, state)
|
||||
require.Len(t, events, 1) // only delta, no new block start
|
||||
assert.Equal(t, "content_block_delta", events[0].Type)
|
||||
|
||||
// 5. text done
|
||||
events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.output_text.done",
|
||||
}, state)
|
||||
require.Len(t, events, 1)
|
||||
assert.Equal(t, "content_block_stop", events[0].Type)
|
||||
|
||||
// 6. completed
|
||||
events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.completed",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "completed",
|
||||
Usage: &ResponsesUsage{InputTokens: 10, OutputTokens: 5},
|
||||
},
|
||||
}, state)
|
||||
require.Len(t, events, 2) // message_delta + message_stop
|
||||
assert.Equal(t, "message_delta", events[0].Type)
|
||||
assert.Equal(t, "end_turn", events[0].Delta.StopReason)
|
||||
assert.Equal(t, 10, events[0].Usage.InputTokens)
|
||||
assert.Equal(t, 5, events[0].Usage.OutputTokens)
|
||||
assert.Equal(t, "message_stop", events[1].Type)
|
||||
}
|
||||
|
||||
func TestStreamingToolCall(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
|
||||
// 1. response.created
|
||||
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.created",
|
||||
Response: &ResponsesResponse{ID: "resp_2", Model: "gpt-5.2"},
|
||||
}, state)
|
||||
|
||||
// 2. function_call added
|
||||
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.output_item.added",
|
||||
OutputIndex: 0,
|
||||
Item: &ResponsesOutput{Type: "function_call", CallID: "call_1", Name: "get_weather"},
|
||||
}, state)
|
||||
require.Len(t, events, 1)
|
||||
assert.Equal(t, "content_block_start", events[0].Type)
|
||||
assert.Equal(t, "tool_use", events[0].ContentBlock.Type)
|
||||
assert.Equal(t, "call_1", events[0].ContentBlock.ID)
|
||||
|
||||
// 3. arguments delta
|
||||
events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.function_call_arguments.delta",
|
||||
OutputIndex: 0,
|
||||
Delta: `{"city":`,
|
||||
}, state)
|
||||
require.Len(t, events, 1)
|
||||
assert.Equal(t, "content_block_delta", events[0].Type)
|
||||
assert.Equal(t, "input_json_delta", events[0].Delta.Type)
|
||||
assert.Equal(t, `{"city":`, events[0].Delta.PartialJSON)
|
||||
|
||||
// 4. arguments done
|
||||
events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.function_call_arguments.done",
|
||||
}, state)
|
||||
require.Len(t, events, 1)
|
||||
assert.Equal(t, "content_block_stop", events[0].Type)
|
||||
|
||||
// 5. completed with tool_calls
|
||||
events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.completed",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "completed",
|
||||
Usage: &ResponsesUsage{InputTokens: 20, OutputTokens: 10},
|
||||
},
|
||||
}, state)
|
||||
require.Len(t, events, 2)
|
||||
assert.Equal(t, "tool_use", events[0].Delta.StopReason)
|
||||
}
|
||||
|
||||
func TestStreamingReasoning(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
|
||||
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.created",
|
||||
Response: &ResponsesResponse{ID: "resp_3", Model: "gpt-5.2"},
|
||||
}, state)
|
||||
|
||||
// reasoning item added
|
||||
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.output_item.added",
|
||||
OutputIndex: 0,
|
||||
Item: &ResponsesOutput{Type: "reasoning"},
|
||||
}, state)
|
||||
require.Len(t, events, 1)
|
||||
assert.Equal(t, "content_block_start", events[0].Type)
|
||||
assert.Equal(t, "thinking", events[0].ContentBlock.Type)
|
||||
|
||||
// reasoning text delta
|
||||
events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.reasoning_summary_text.delta",
|
||||
OutputIndex: 0,
|
||||
Delta: "Let me think...",
|
||||
}, state)
|
||||
require.Len(t, events, 1)
|
||||
assert.Equal(t, "content_block_delta", events[0].Type)
|
||||
assert.Equal(t, "thinking_delta", events[0].Delta.Type)
|
||||
assert.Equal(t, "Let me think...", events[0].Delta.Thinking)
|
||||
|
||||
// reasoning done
|
||||
events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.reasoning_summary_text.done",
|
||||
}, state)
|
||||
require.Len(t, events, 1)
|
||||
assert.Equal(t, "content_block_stop", events[0].Type)
|
||||
}
|
||||
|
||||
func TestStreamingIncomplete(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
|
||||
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.created",
|
||||
Response: &ResponsesResponse{ID: "resp_4", Model: "gpt-5.2"},
|
||||
}, state)
|
||||
|
||||
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.output_text.delta",
|
||||
Delta: "Partial output...",
|
||||
}, state)
|
||||
|
||||
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.incomplete",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "incomplete",
|
||||
IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"},
|
||||
Usage: &ResponsesUsage{InputTokens: 100, OutputTokens: 4096},
|
||||
},
|
||||
}, state)
|
||||
|
||||
// Should close the text block + message_delta + message_stop
|
||||
require.Len(t, events, 3)
|
||||
assert.Equal(t, "content_block_stop", events[0].Type)
|
||||
assert.Equal(t, "message_delta", events[1].Type)
|
||||
assert.Equal(t, "max_tokens", events[1].Delta.StopReason)
|
||||
assert.Equal(t, "message_stop", events[2].Type)
|
||||
}
|
||||
|
||||
func TestFinalizeStream_NeverStarted(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
events := FinalizeResponsesAnthropicStream(state)
|
||||
assert.Nil(t, events)
|
||||
}
|
||||
|
||||
func TestFinalizeStream_AlreadyCompleted(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
state.MessageStartSent = true
|
||||
state.MessageStopSent = true
|
||||
events := FinalizeResponsesAnthropicStream(state)
|
||||
assert.Nil(t, events)
|
||||
}
|
||||
|
||||
func TestFinalizeStream_AbnormalTermination(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
|
||||
// Simulate a stream that started but never completed
|
||||
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.created",
|
||||
Response: &ResponsesResponse{ID: "resp_5", Model: "gpt-5.2"},
|
||||
}, state)
|
||||
|
||||
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.output_text.delta",
|
||||
Delta: "Interrupted...",
|
||||
}, state)
|
||||
|
||||
// Stream ends without response.completed
|
||||
events := FinalizeResponsesAnthropicStream(state)
|
||||
require.Len(t, events, 3) // content_block_stop + message_delta + message_stop
|
||||
assert.Equal(t, "content_block_stop", events[0].Type)
|
||||
assert.Equal(t, "message_delta", events[1].Type)
|
||||
assert.Equal(t, "end_turn", events[1].Delta.StopReason)
|
||||
assert.Equal(t, "message_stop", events[2].Type)
|
||||
}
|
||||
|
||||
func TestStreamingEmptyResponse(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
|
||||
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.created",
|
||||
Response: &ResponsesResponse{ID: "resp_6", Model: "gpt-5.2"},
|
||||
}, state)
|
||||
|
||||
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.completed",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "completed",
|
||||
Usage: &ResponsesUsage{InputTokens: 5, OutputTokens: 0},
|
||||
},
|
||||
}, state)
|
||||
|
||||
require.Len(t, events, 2) // message_delta + message_stop
|
||||
assert.Equal(t, "message_delta", events[0].Type)
|
||||
assert.Equal(t, "end_turn", events[0].Delta.StopReason)
|
||||
}
|
||||
|
||||
func TestResponsesAnthropicEventToSSE(t *testing.T) {
|
||||
evt := AnthropicStreamEvent{
|
||||
Type: "message_start",
|
||||
Message: &AnthropicResponse{
|
||||
ID: "resp_1",
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
},
|
||||
}
|
||||
sse, err := ResponsesAnthropicEventToSSE(evt)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, sse, "event: message_start\n")
|
||||
assert.Contains(t, sse, "data: ")
|
||||
assert.Contains(t, sse, `"resp_1"`)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// response.failed tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestStreamingFailed(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
|
||||
// 1. response.created
|
||||
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.created",
|
||||
Response: &ResponsesResponse{ID: "resp_fail_1", Model: "gpt-5.2"},
|
||||
}, state)
|
||||
|
||||
// 2. Some text output before failure
|
||||
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.output_text.delta",
|
||||
Delta: "Partial output before failure",
|
||||
}, state)
|
||||
|
||||
// 3. response.failed
|
||||
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.failed",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "failed",
|
||||
Error: &ResponsesError{Code: "server_error", Message: "Internal error"},
|
||||
Usage: &ResponsesUsage{InputTokens: 50, OutputTokens: 10},
|
||||
},
|
||||
}, state)
|
||||
|
||||
// Should close text block + message_delta + message_stop
|
||||
require.Len(t, events, 3)
|
||||
assert.Equal(t, "content_block_stop", events[0].Type)
|
||||
assert.Equal(t, "message_delta", events[1].Type)
|
||||
assert.Equal(t, "end_turn", events[1].Delta.StopReason)
|
||||
assert.Equal(t, 50, events[1].Usage.InputTokens)
|
||||
assert.Equal(t, 10, events[1].Usage.OutputTokens)
|
||||
assert.Equal(t, "message_stop", events[2].Type)
|
||||
}
|
||||
|
||||
func TestStreamingFailedNoOutput(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
|
||||
// 1. response.created
|
||||
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.created",
|
||||
Response: &ResponsesResponse{ID: "resp_fail_2", Model: "gpt-5.2"},
|
||||
}, state)
|
||||
|
||||
// 2. response.failed with no prior output
|
||||
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.failed",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "failed",
|
||||
Error: &ResponsesError{Code: "rate_limit_error", Message: "Too many requests"},
|
||||
Usage: &ResponsesUsage{InputTokens: 20, OutputTokens: 0},
|
||||
},
|
||||
}, state)
|
||||
|
||||
// Should emit message_delta + message_stop (no block to close)
|
||||
require.Len(t, events, 2)
|
||||
assert.Equal(t, "message_delta", events[0].Type)
|
||||
assert.Equal(t, "end_turn", events[0].Delta.StopReason)
|
||||
assert.Equal(t, "message_stop", events[1].Type)
|
||||
}
|
||||
|
||||
func TestResponsesToAnthropic_Failed(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_fail_3",
|
||||
Model: "gpt-5.2",
|
||||
Status: "failed",
|
||||
Error: &ResponsesError{Code: "server_error", Message: "Something went wrong"},
|
||||
Output: []ResponsesOutput{},
|
||||
Usage: &ResponsesUsage{InputTokens: 30, OutputTokens: 0},
|
||||
}
|
||||
|
||||
anth := ResponsesToAnthropic(resp, "claude-opus-4-6")
|
||||
// Failed status defaults to "end_turn" stop reason
|
||||
assert.Equal(t, "end_turn", anth.StopReason)
|
||||
// Should have at least an empty text block
|
||||
require.Len(t, anth.Content, 1)
|
||||
assert.Equal(t, "text", anth.Content[0].Type)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// thinking → reasoning conversion tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestAnthropicToResponses_ThinkingEnabled(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
Thinking: &AnthropicThinking{Type: "enabled", BudgetTokens: 10000},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
assert.Contains(t, resp.Include, "reasoning.encrypted_content")
|
||||
assert.NotContains(t, resp.Include, "reasoning.summary")
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_ThinkingAdaptive(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
Thinking: &AnthropicThinking{Type: "adaptive", BudgetTokens: 5000},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "medium", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
assert.NotContains(t, resp.Include, "reasoning.summary")
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_ThinkingDisabled(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
Thinking: &AnthropicThinking{Type: "disabled"},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, resp.Reasoning)
|
||||
assert.NotContains(t, resp.Include, "reasoning.summary")
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_NoThinking(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, resp.Reasoning)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// tool_choice conversion tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestAnthropicToResponses_ToolChoiceAuto(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
ToolChoice: json.RawMessage(`{"type":"auto"}`),
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var tc string
|
||||
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
|
||||
assert.Equal(t, "auto", tc)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_ToolChoiceAny(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
ToolChoice: json.RawMessage(`{"type":"any"}`),
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var tc string
|
||||
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
|
||||
assert.Equal(t, "required", tc)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_ToolChoiceSpecific(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
ToolChoice: json.RawMessage(`{"type":"tool","name":"get_weather"}`),
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var tc map[string]any
|
||||
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
|
||||
assert.Equal(t, "function", tc["type"])
|
||||
fn, ok := tc["function"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "get_weather", fn["name"])
|
||||
}
|
||||
346
backend/internal/pkg/apicompat/anthropic_to_responses.go
Normal file
346
backend/internal/pkg/apicompat/anthropic_to_responses.go
Normal file
@@ -0,0 +1,346 @@
|
||||
package apicompat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// AnthropicToResponses converts an Anthropic Messages request directly into
|
||||
// a Responses API request. This preserves fields that would be lost in a
|
||||
// Chat Completions intermediary round-trip (e.g. thinking, cache_control,
|
||||
// structured system prompts).
|
||||
func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) {
|
||||
input, err := convertAnthropicToResponsesInput(req.System, req.Messages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inputJSON, err := json.Marshal(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := &ResponsesRequest{
|
||||
Model: req.Model,
|
||||
Input: inputJSON,
|
||||
Temperature: req.Temperature,
|
||||
TopP: req.TopP,
|
||||
Stream: req.Stream,
|
||||
Include: []string{"reasoning.encrypted_content"},
|
||||
}
|
||||
|
||||
storeFalse := false
|
||||
out.Store = &storeFalse
|
||||
|
||||
if req.MaxTokens > 0 {
|
||||
v := req.MaxTokens
|
||||
if v < minMaxOutputTokens {
|
||||
v = minMaxOutputTokens
|
||||
}
|
||||
out.MaxOutputTokens = &v
|
||||
}
|
||||
|
||||
if len(req.Tools) > 0 {
|
||||
out.Tools = convertAnthropicToolsToResponses(req.Tools)
|
||||
}
|
||||
|
||||
// Convert thinking → reasoning.
|
||||
// generate_summary="auto" causes the upstream to emit reasoning_summary_text
|
||||
// streaming events; the include array only needs reasoning.encrypted_content
|
||||
// (already set above) for content continuity.
|
||||
if req.Thinking != nil {
|
||||
switch req.Thinking.Type {
|
||||
case "enabled":
|
||||
out.Reasoning = &ResponsesReasoning{Effort: "high", Summary: "auto"}
|
||||
case "adaptive":
|
||||
out.Reasoning = &ResponsesReasoning{Effort: "medium", Summary: "auto"}
|
||||
}
|
||||
// "disabled" or unknown → omit reasoning
|
||||
}
|
||||
|
||||
// Convert tool_choice
|
||||
if len(req.ToolChoice) > 0 {
|
||||
tc, err := convertAnthropicToolChoiceToResponses(req.ToolChoice)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert tool_choice: %w", err)
|
||||
}
|
||||
out.ToolChoice = tc
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// convertAnthropicToolChoiceToResponses maps Anthropic tool_choice to Responses format.
|
||||
//
|
||||
// {"type":"auto"} → "auto"
|
||||
// {"type":"any"} → "required"
|
||||
// {"type":"none"} → "none"
|
||||
// {"type":"tool","name":"X"} → {"type":"function","function":{"name":"X"}}
|
||||
func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage, error) {
|
||||
var tc struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &tc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch tc.Type {
|
||||
case "auto":
|
||||
return json.Marshal("auto")
|
||||
case "any":
|
||||
return json.Marshal("required")
|
||||
case "none":
|
||||
return json.Marshal("none")
|
||||
case "tool":
|
||||
return json.Marshal(map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]string{"name": tc.Name},
|
||||
})
|
||||
default:
|
||||
// Pass through unknown types as-is
|
||||
return raw, nil
|
||||
}
|
||||
}
|
||||
|
||||
// convertAnthropicToResponsesInput builds the Responses API input items array
|
||||
// from the Anthropic system field and message list.
|
||||
func convertAnthropicToResponsesInput(system json.RawMessage, msgs []AnthropicMessage) ([]ResponsesInputItem, error) {
|
||||
var out []ResponsesInputItem
|
||||
|
||||
// System prompt → system role input item.
|
||||
if len(system) > 0 {
|
||||
sysText, err := parseAnthropicSystemPrompt(system)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if sysText != "" {
|
||||
content, _ := json.Marshal(sysText)
|
||||
out = append(out, ResponsesInputItem{
|
||||
Role: "system",
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for _, m := range msgs {
|
||||
items, err := anthropicMsgToResponsesItems(m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, items...)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// parseAnthropicSystemPrompt handles the Anthropic system field which can be
|
||||
// a plain string or an array of text blocks.
|
||||
func parseAnthropicSystemPrompt(raw json.RawMessage) (string, error) {
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return s, nil
|
||||
}
|
||||
var blocks []AnthropicContentBlock
|
||||
if err := json.Unmarshal(raw, &blocks); err != nil {
|
||||
return "", err
|
||||
}
|
||||
var parts []string
|
||||
for _, b := range blocks {
|
||||
if b.Type == "text" && b.Text != "" {
|
||||
parts = append(parts, b.Text)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n\n"), nil
|
||||
}
|
||||
|
||||
// anthropicMsgToResponsesItems converts a single Anthropic message into one
|
||||
// or more Responses API input items.
|
||||
func anthropicMsgToResponsesItems(m AnthropicMessage) ([]ResponsesInputItem, error) {
|
||||
switch m.Role {
|
||||
case "user":
|
||||
return anthropicUserToResponses(m.Content)
|
||||
case "assistant":
|
||||
return anthropicAssistantToResponses(m.Content)
|
||||
default:
|
||||
return anthropicUserToResponses(m.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// anthropicUserToResponses handles an Anthropic user message. Content can be a
|
||||
// plain string or an array of blocks. tool_result blocks are extracted into
|
||||
// function_call_output items.
|
||||
func anthropicUserToResponses(raw json.RawMessage) ([]ResponsesInputItem, error) {
|
||||
// Try plain string.
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
content, _ := json.Marshal(s)
|
||||
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
|
||||
}
|
||||
|
||||
var blocks []AnthropicContentBlock
|
||||
if err := json.Unmarshal(raw, &blocks); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var out []ResponsesInputItem
|
||||
|
||||
// Extract tool_result blocks → function_call_output items.
|
||||
for _, b := range blocks {
|
||||
if b.Type != "tool_result" {
|
||||
continue
|
||||
}
|
||||
text := extractAnthropicToolResultText(b)
|
||||
if text == "" {
|
||||
// OpenAI Responses API requires "output" field; use placeholder for empty results.
|
||||
text = "(empty)"
|
||||
}
|
||||
out = append(out, ResponsesInputItem{
|
||||
Type: "function_call_output",
|
||||
CallID: toResponsesCallID(b.ToolUseID),
|
||||
Output: text,
|
||||
})
|
||||
}
|
||||
|
||||
// Remaining text blocks → user message.
|
||||
text := extractAnthropicTextFromBlocks(blocks)
|
||||
if text != "" {
|
||||
content, _ := json.Marshal(text)
|
||||
out = append(out, ResponsesInputItem{Role: "user", Content: content})
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// anthropicAssistantToResponses handles an Anthropic assistant message.
|
||||
// Text content → assistant message with output_text parts.
|
||||
// tool_use blocks → function_call items.
|
||||
// thinking blocks → ignored (OpenAI doesn't accept them as input).
|
||||
func anthropicAssistantToResponses(raw json.RawMessage) ([]ResponsesInputItem, error) {
|
||||
// Try plain string.
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
|
||||
partsJSON, err := json.Marshal(parts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []ResponsesInputItem{{Role: "assistant", Content: partsJSON}}, nil
|
||||
}
|
||||
|
||||
var blocks []AnthropicContentBlock
|
||||
if err := json.Unmarshal(raw, &blocks); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var items []ResponsesInputItem
|
||||
|
||||
// Text content → assistant message with output_text content parts.
|
||||
text := extractAnthropicTextFromBlocks(blocks)
|
||||
if text != "" {
|
||||
parts := []ResponsesContentPart{{Type: "output_text", Text: text}}
|
||||
partsJSON, err := json.Marshal(parts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON})
|
||||
}
|
||||
|
||||
// tool_use → function_call items.
|
||||
for _, b := range blocks {
|
||||
if b.Type != "tool_use" {
|
||||
continue
|
||||
}
|
||||
args := "{}"
|
||||
if len(b.Input) > 0 {
|
||||
args = string(b.Input)
|
||||
}
|
||||
fcID := toResponsesCallID(b.ID)
|
||||
items = append(items, ResponsesInputItem{
|
||||
Type: "function_call",
|
||||
CallID: fcID,
|
||||
Name: b.Name,
|
||||
Arguments: args,
|
||||
ID: fcID,
|
||||
})
|
||||
}
|
||||
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// toResponsesCallID converts an Anthropic tool ID (toolu_xxx / call_xxx) to a
|
||||
// Responses API function_call ID that starts with "fc_".
|
||||
func toResponsesCallID(id string) string {
|
||||
if strings.HasPrefix(id, "fc_") {
|
||||
return id
|
||||
}
|
||||
return "fc_" + id
|
||||
}
|
||||
|
||||
// fromResponsesCallID reverses toResponsesCallID, stripping the "fc_" prefix
|
||||
// that was added during request conversion.
|
||||
func fromResponsesCallID(id string) string {
|
||||
if after, ok := strings.CutPrefix(id, "fc_"); ok {
|
||||
// Only strip if the remainder doesn't look like it was already "fc_" prefixed.
|
||||
// E.g. "fc_toolu_xxx" → "toolu_xxx", "fc_call_xxx" → "call_xxx"
|
||||
if strings.HasPrefix(after, "toolu_") || strings.HasPrefix(after, "call_") {
|
||||
return after
|
||||
}
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// extractAnthropicToolResultText gets the text content from a tool_result block.
|
||||
func extractAnthropicToolResultText(b AnthropicContentBlock) string {
|
||||
if len(b.Content) == 0 {
|
||||
return ""
|
||||
}
|
||||
var s string
|
||||
if err := json.Unmarshal(b.Content, &s); err == nil {
|
||||
return s
|
||||
}
|
||||
var inner []AnthropicContentBlock
|
||||
if err := json.Unmarshal(b.Content, &inner); err == nil {
|
||||
var parts []string
|
||||
for _, ib := range inner {
|
||||
if ib.Type == "text" && ib.Text != "" {
|
||||
parts = append(parts, ib.Text)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractAnthropicTextFromBlocks joins all text blocks, ignoring thinking/
|
||||
// tool_use/tool_result blocks.
|
||||
func extractAnthropicTextFromBlocks(blocks []AnthropicContentBlock) string {
|
||||
var parts []string
|
||||
for _, b := range blocks {
|
||||
if b.Type == "text" && b.Text != "" {
|
||||
parts = append(parts, b.Text)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
|
||||
// convertAnthropicToolsToResponses maps Anthropic tool definitions to
|
||||
// Responses API tools. Server-side tools like web_search are mapped to their
|
||||
// OpenAI equivalents; regular tools become function tools.
|
||||
func convertAnthropicToolsToResponses(tools []AnthropicTool) []ResponsesTool {
|
||||
var out []ResponsesTool
|
||||
for _, t := range tools {
|
||||
// Anthropic server tools like "web_search_20250305" → OpenAI {"type":"web_search"}
|
||||
if strings.HasPrefix(t.Type, "web_search") {
|
||||
out = append(out, ResponsesTool{Type: "web_search"})
|
||||
continue
|
||||
}
|
||||
out = append(out, ResponsesTool{
|
||||
Type: "function",
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
Parameters: t.InputSchema,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
516
backend/internal/pkg/apicompat/responses_to_anthropic.go
Normal file
516
backend/internal/pkg/apicompat/responses_to_anthropic.go
Normal file
@@ -0,0 +1,516 @@
|
||||
package apicompat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Non-streaming: ResponsesResponse → AnthropicResponse
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ResponsesToAnthropic converts a Responses API response directly into an
|
||||
// Anthropic Messages response. Reasoning output items are mapped to thinking
|
||||
// blocks; function_call items become tool_use blocks.
|
||||
func ResponsesToAnthropic(resp *ResponsesResponse, model string) *AnthropicResponse {
|
||||
out := &AnthropicResponse{
|
||||
ID: resp.ID,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: model,
|
||||
}
|
||||
|
||||
var blocks []AnthropicContentBlock
|
||||
|
||||
for _, item := range resp.Output {
|
||||
switch item.Type {
|
||||
case "reasoning":
|
||||
summaryText := ""
|
||||
for _, s := range item.Summary {
|
||||
if s.Type == "summary_text" && s.Text != "" {
|
||||
summaryText += s.Text
|
||||
}
|
||||
}
|
||||
if summaryText != "" {
|
||||
blocks = append(blocks, AnthropicContentBlock{
|
||||
Type: "thinking",
|
||||
Thinking: summaryText,
|
||||
})
|
||||
}
|
||||
case "message":
|
||||
for _, part := range item.Content {
|
||||
if part.Type == "output_text" && part.Text != "" {
|
||||
blocks = append(blocks, AnthropicContentBlock{
|
||||
Type: "text",
|
||||
Text: part.Text,
|
||||
})
|
||||
}
|
||||
}
|
||||
case "function_call":
|
||||
blocks = append(blocks, AnthropicContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: fromResponsesCallID(item.CallID),
|
||||
Name: item.Name,
|
||||
Input: json.RawMessage(item.Arguments),
|
||||
})
|
||||
case "web_search_call":
|
||||
toolUseID := "srvtoolu_" + item.ID
|
||||
query := ""
|
||||
if item.Action != nil {
|
||||
query = item.Action.Query
|
||||
}
|
||||
inputJSON, _ := json.Marshal(map[string]string{"query": query})
|
||||
blocks = append(blocks, AnthropicContentBlock{
|
||||
Type: "server_tool_use",
|
||||
ID: toolUseID,
|
||||
Name: "web_search",
|
||||
Input: inputJSON,
|
||||
})
|
||||
emptyResults, _ := json.Marshal([]struct{}{})
|
||||
blocks = append(blocks, AnthropicContentBlock{
|
||||
Type: "web_search_tool_result",
|
||||
ToolUseID: toolUseID,
|
||||
Content: emptyResults,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(blocks) == 0 {
|
||||
blocks = append(blocks, AnthropicContentBlock{Type: "text", Text: ""})
|
||||
}
|
||||
out.Content = blocks
|
||||
|
||||
out.StopReason = responsesStatusToAnthropicStopReason(resp.Status, resp.IncompleteDetails, blocks)
|
||||
|
||||
if resp.Usage != nil {
|
||||
out.Usage = AnthropicUsage{
|
||||
InputTokens: resp.Usage.InputTokens,
|
||||
OutputTokens: resp.Usage.OutputTokens,
|
||||
}
|
||||
if resp.Usage.InputTokensDetails != nil {
|
||||
out.Usage.CacheReadInputTokens = resp.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncompleteDetails, blocks []AnthropicContentBlock) string {
|
||||
switch status {
|
||||
case "incomplete":
|
||||
if details != nil && details.Reason == "max_output_tokens" {
|
||||
return "max_tokens"
|
||||
}
|
||||
return "end_turn"
|
||||
case "completed":
|
||||
if len(blocks) > 0 && blocks[len(blocks)-1].Type == "tool_use" {
|
||||
return "tool_use"
|
||||
}
|
||||
return "end_turn"
|
||||
default:
|
||||
return "end_turn"
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Streaming: ResponsesStreamEvent → []AnthropicStreamEvent (stateful converter)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ResponsesEventToAnthropicState tracks state for converting a sequence of
|
||||
// Responses SSE events directly into Anthropic SSE events.
|
||||
type ResponsesEventToAnthropicState struct {
|
||||
MessageStartSent bool
|
||||
MessageStopSent bool
|
||||
|
||||
ContentBlockIndex int
|
||||
ContentBlockOpen bool
|
||||
CurrentBlockType string // "text" | "thinking" | "tool_use"
|
||||
|
||||
// OutputIndexToBlockIdx maps Responses output_index → Anthropic content block index.
|
||||
OutputIndexToBlockIdx map[int]int
|
||||
|
||||
InputTokens int
|
||||
OutputTokens int
|
||||
CacheReadInputTokens int
|
||||
|
||||
ResponseID string
|
||||
Model string
|
||||
Created int64
|
||||
}
|
||||
|
||||
// NewResponsesEventToAnthropicState returns an initialised stream state.
|
||||
func NewResponsesEventToAnthropicState() *ResponsesEventToAnthropicState {
|
||||
return &ResponsesEventToAnthropicState{
|
||||
OutputIndexToBlockIdx: make(map[int]int),
|
||||
Created: time.Now().Unix(),
|
||||
}
|
||||
}
|
||||
|
||||
// ResponsesEventToAnthropicEvents converts a single Responses SSE event into
|
||||
// zero or more Anthropic SSE events, updating state as it goes.
|
||||
func ResponsesEventToAnthropicEvents(
|
||||
evt *ResponsesStreamEvent,
|
||||
state *ResponsesEventToAnthropicState,
|
||||
) []AnthropicStreamEvent {
|
||||
switch evt.Type {
|
||||
case "response.created":
|
||||
return resToAnthHandleCreated(evt, state)
|
||||
case "response.output_item.added":
|
||||
return resToAnthHandleOutputItemAdded(evt, state)
|
||||
case "response.output_text.delta":
|
||||
return resToAnthHandleTextDelta(evt, state)
|
||||
case "response.output_text.done":
|
||||
return resToAnthHandleBlockDone(state)
|
||||
case "response.function_call_arguments.delta":
|
||||
return resToAnthHandleFuncArgsDelta(evt, state)
|
||||
case "response.function_call_arguments.done":
|
||||
return resToAnthHandleBlockDone(state)
|
||||
case "response.output_item.done":
|
||||
return resToAnthHandleOutputItemDone(evt, state)
|
||||
case "response.reasoning_summary_text.delta":
|
||||
return resToAnthHandleReasoningDelta(evt, state)
|
||||
case "response.reasoning_summary_text.done":
|
||||
return resToAnthHandleBlockDone(state)
|
||||
case "response.completed", "response.incomplete", "response.failed":
|
||||
return resToAnthHandleCompleted(evt, state)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// FinalizeResponsesAnthropicStream emits synthetic termination events if the
|
||||
// stream ended without a proper completion event.
|
||||
func FinalizeResponsesAnthropicStream(state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if !state.MessageStartSent || state.MessageStopSent {
|
||||
return nil
|
||||
}
|
||||
|
||||
var events []AnthropicStreamEvent
|
||||
events = append(events, closeCurrentBlock(state)...)
|
||||
|
||||
events = append(events,
|
||||
AnthropicStreamEvent{
|
||||
Type: "message_delta",
|
||||
Delta: &AnthropicDelta{
|
||||
StopReason: "end_turn",
|
||||
},
|
||||
Usage: &AnthropicUsage{
|
||||
InputTokens: state.InputTokens,
|
||||
OutputTokens: state.OutputTokens,
|
||||
CacheReadInputTokens: state.CacheReadInputTokens,
|
||||
},
|
||||
},
|
||||
AnthropicStreamEvent{Type: "message_stop"},
|
||||
)
|
||||
state.MessageStopSent = true
|
||||
return events
|
||||
}
|
||||
|
||||
// ResponsesAnthropicEventToSSE formats an AnthropicStreamEvent as an SSE line pair.
|
||||
func ResponsesAnthropicEventToSSE(evt AnthropicStreamEvent) (string, error) {
|
||||
data, err := json.Marshal(evt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("event: %s\ndata: %s\n\n", evt.Type, data), nil
|
||||
}
|
||||
|
||||
// --- internal handlers ---
|
||||
|
||||
func resToAnthHandleCreated(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if evt.Response != nil {
|
||||
state.ResponseID = evt.Response.ID
|
||||
// Only use upstream model if no override was set (e.g. originalModel)
|
||||
if state.Model == "" {
|
||||
state.Model = evt.Response.Model
|
||||
}
|
||||
}
|
||||
|
||||
if state.MessageStartSent {
|
||||
return nil
|
||||
}
|
||||
state.MessageStartSent = true
|
||||
|
||||
return []AnthropicStreamEvent{{
|
||||
Type: "message_start",
|
||||
Message: &AnthropicResponse{
|
||||
ID: state.ResponseID,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Content: []AnthropicContentBlock{},
|
||||
Model: state.Model,
|
||||
Usage: AnthropicUsage{
|
||||
InputTokens: 0,
|
||||
OutputTokens: 0,
|
||||
},
|
||||
},
|
||||
}}
|
||||
}
|
||||
|
||||
func resToAnthHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if evt.Item == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch evt.Item.Type {
|
||||
case "function_call":
|
||||
var events []AnthropicStreamEvent
|
||||
events = append(events, closeCurrentBlock(state)...)
|
||||
|
||||
idx := state.ContentBlockIndex
|
||||
state.OutputIndexToBlockIdx[evt.OutputIndex] = idx
|
||||
state.ContentBlockOpen = true
|
||||
state.CurrentBlockType = "tool_use"
|
||||
|
||||
events = append(events, AnthropicStreamEvent{
|
||||
Type: "content_block_start",
|
||||
Index: &idx,
|
||||
ContentBlock: &AnthropicContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: fromResponsesCallID(evt.Item.CallID),
|
||||
Name: evt.Item.Name,
|
||||
Input: json.RawMessage("{}"),
|
||||
},
|
||||
})
|
||||
return events
|
||||
|
||||
case "reasoning":
|
||||
var events []AnthropicStreamEvent
|
||||
events = append(events, closeCurrentBlock(state)...)
|
||||
|
||||
idx := state.ContentBlockIndex
|
||||
state.OutputIndexToBlockIdx[evt.OutputIndex] = idx
|
||||
state.ContentBlockOpen = true
|
||||
state.CurrentBlockType = "thinking"
|
||||
|
||||
events = append(events, AnthropicStreamEvent{
|
||||
Type: "content_block_start",
|
||||
Index: &idx,
|
||||
ContentBlock: &AnthropicContentBlock{
|
||||
Type: "thinking",
|
||||
Thinking: "",
|
||||
},
|
||||
})
|
||||
return events
|
||||
|
||||
case "message":
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func resToAnthHandleTextDelta(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if evt.Delta == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var events []AnthropicStreamEvent
|
||||
|
||||
if !state.ContentBlockOpen || state.CurrentBlockType != "text" {
|
||||
events = append(events, closeCurrentBlock(state)...)
|
||||
|
||||
idx := state.ContentBlockIndex
|
||||
state.ContentBlockOpen = true
|
||||
state.CurrentBlockType = "text"
|
||||
|
||||
events = append(events, AnthropicStreamEvent{
|
||||
Type: "content_block_start",
|
||||
Index: &idx,
|
||||
ContentBlock: &AnthropicContentBlock{
|
||||
Type: "text",
|
||||
Text: "",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
idx := state.ContentBlockIndex
|
||||
events = append(events, AnthropicStreamEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: &idx,
|
||||
Delta: &AnthropicDelta{
|
||||
Type: "text_delta",
|
||||
Text: evt.Delta,
|
||||
},
|
||||
})
|
||||
return events
|
||||
}
|
||||
|
||||
func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if evt.Delta == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
blockIdx, ok := state.OutputIndexToBlockIdx[evt.OutputIndex]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return []AnthropicStreamEvent{{
|
||||
Type: "content_block_delta",
|
||||
Index: &blockIdx,
|
||||
Delta: &AnthropicDelta{
|
||||
Type: "input_json_delta",
|
||||
PartialJSON: evt.Delta,
|
||||
},
|
||||
}}
|
||||
}
|
||||
|
||||
func resToAnthHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if evt.Delta == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
blockIdx, ok := state.OutputIndexToBlockIdx[evt.OutputIndex]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return []AnthropicStreamEvent{{
|
||||
Type: "content_block_delta",
|
||||
Index: &blockIdx,
|
||||
Delta: &AnthropicDelta{
|
||||
Type: "thinking_delta",
|
||||
Thinking: evt.Delta,
|
||||
},
|
||||
}}
|
||||
}
|
||||
|
||||
func resToAnthHandleBlockDone(state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if !state.ContentBlockOpen {
|
||||
return nil
|
||||
}
|
||||
return closeCurrentBlock(state)
|
||||
}
|
||||
|
||||
func resToAnthHandleOutputItemDone(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if evt.Item == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle web_search_call → synthesize server_tool_use + web_search_tool_result blocks.
|
||||
if evt.Item.Type == "web_search_call" && evt.Item.Status == "completed" {
|
||||
return resToAnthHandleWebSearchDone(evt, state)
|
||||
}
|
||||
|
||||
if state.ContentBlockOpen {
|
||||
return closeCurrentBlock(state)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// resToAnthHandleWebSearchDone converts an OpenAI web_search_call output item
|
||||
// into Anthropic server_tool_use + web_search_tool_result content block pairs.
|
||||
// This allows Claude Code to count the searches performed.
|
||||
func resToAnthHandleWebSearchDone(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
var events []AnthropicStreamEvent
|
||||
events = append(events, closeCurrentBlock(state)...)
|
||||
|
||||
toolUseID := "srvtoolu_" + evt.Item.ID
|
||||
query := ""
|
||||
if evt.Item.Action != nil {
|
||||
query = evt.Item.Action.Query
|
||||
}
|
||||
inputJSON, _ := json.Marshal(map[string]string{"query": query})
|
||||
|
||||
// Emit server_tool_use block (start + stop).
|
||||
idx1 := state.ContentBlockIndex
|
||||
events = append(events, AnthropicStreamEvent{
|
||||
Type: "content_block_start",
|
||||
Index: &idx1,
|
||||
ContentBlock: &AnthropicContentBlock{
|
||||
Type: "server_tool_use",
|
||||
ID: toolUseID,
|
||||
Name: "web_search",
|
||||
Input: inputJSON,
|
||||
},
|
||||
})
|
||||
events = append(events, AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: &idx1,
|
||||
})
|
||||
state.ContentBlockIndex++
|
||||
|
||||
// Emit web_search_tool_result block (start + stop).
|
||||
// Content is empty because OpenAI does not expose individual search results;
|
||||
// the model consumes them internally and produces text output.
|
||||
emptyResults, _ := json.Marshal([]struct{}{})
|
||||
idx2 := state.ContentBlockIndex
|
||||
events = append(events, AnthropicStreamEvent{
|
||||
Type: "content_block_start",
|
||||
Index: &idx2,
|
||||
ContentBlock: &AnthropicContentBlock{
|
||||
Type: "web_search_tool_result",
|
||||
ToolUseID: toolUseID,
|
||||
Content: emptyResults,
|
||||
},
|
||||
})
|
||||
events = append(events, AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: &idx2,
|
||||
})
|
||||
state.ContentBlockIndex++
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
func resToAnthHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if state.MessageStopSent {
|
||||
return nil
|
||||
}
|
||||
|
||||
var events []AnthropicStreamEvent
|
||||
events = append(events, closeCurrentBlock(state)...)
|
||||
|
||||
stopReason := "end_turn"
|
||||
if evt.Response != nil {
|
||||
if evt.Response.Usage != nil {
|
||||
state.InputTokens = evt.Response.Usage.InputTokens
|
||||
state.OutputTokens = evt.Response.Usage.OutputTokens
|
||||
if evt.Response.Usage.InputTokensDetails != nil {
|
||||
state.CacheReadInputTokens = evt.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
switch evt.Response.Status {
|
||||
case "incomplete":
|
||||
if evt.Response.IncompleteDetails != nil && evt.Response.IncompleteDetails.Reason == "max_output_tokens" {
|
||||
stopReason = "max_tokens"
|
||||
}
|
||||
case "completed":
|
||||
if state.ContentBlockIndex > 0 && state.CurrentBlockType == "tool_use" {
|
||||
stopReason = "tool_use"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
events = append(events,
|
||||
AnthropicStreamEvent{
|
||||
Type: "message_delta",
|
||||
Delta: &AnthropicDelta{
|
||||
StopReason: stopReason,
|
||||
},
|
||||
Usage: &AnthropicUsage{
|
||||
InputTokens: state.InputTokens,
|
||||
OutputTokens: state.OutputTokens,
|
||||
CacheReadInputTokens: state.CacheReadInputTokens,
|
||||
},
|
||||
},
|
||||
AnthropicStreamEvent{Type: "message_stop"},
|
||||
)
|
||||
state.MessageStopSent = true
|
||||
return events
|
||||
}
|
||||
|
||||
func closeCurrentBlock(state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if !state.ContentBlockOpen {
|
||||
return nil
|
||||
}
|
||||
idx := state.ContentBlockIndex
|
||||
state.ContentBlockOpen = false
|
||||
state.ContentBlockIndex++
|
||||
return []AnthropicStreamEvent{{
|
||||
Type: "content_block_stop",
|
||||
Index: &idx,
|
||||
}}
|
||||
}
|
||||
320
backend/internal/pkg/apicompat/types.go
Normal file
320
backend/internal/pkg/apicompat/types.go
Normal file
@@ -0,0 +1,320 @@
|
||||
// Package apicompat provides type definitions and conversion utilities for
|
||||
// translating between Anthropic Messages and OpenAI Responses API formats.
|
||||
// It enables multi-protocol support so that clients using different API
|
||||
// formats can be served through a unified gateway.
|
||||
package apicompat
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Anthropic Messages API types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// AnthropicRequest is the request body for POST /v1/messages.
|
||||
type AnthropicRequest struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
System json.RawMessage `json:"system,omitempty"` // string or []AnthropicContentBlock
|
||||
Messages []AnthropicMessage `json:"messages"`
|
||||
Tools []AnthropicTool `json:"tools,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
StopSeqs []string `json:"stop_sequences,omitempty"`
|
||||
Thinking *AnthropicThinking `json:"thinking,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicThinking configures extended thinking in the Anthropic API.
|
||||
type AnthropicThinking struct {
|
||||
Type string `json:"type"` // "enabled" | "adaptive" | "disabled"
|
||||
BudgetTokens int `json:"budget_tokens,omitempty"` // max thinking tokens
|
||||
}
|
||||
|
||||
// AnthropicMessage is a single message in the Anthropic conversation.
|
||||
type AnthropicMessage struct {
|
||||
Role string `json:"role"` // "user" | "assistant"
|
||||
Content json.RawMessage `json:"content"`
|
||||
}
|
||||
|
||||
// AnthropicContentBlock is one block inside a message's content array.
|
||||
type AnthropicContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
|
||||
// type=text
|
||||
Text string `json:"text,omitempty"`
|
||||
|
||||
// type=thinking
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
|
||||
// type=tool_use
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input json.RawMessage `json:"input,omitempty"`
|
||||
|
||||
// type=tool_result
|
||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||
Content json.RawMessage `json:"content,omitempty"` // string or []AnthropicContentBlock
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicTool describes a tool available to the model.
|
||||
type AnthropicTool struct {
|
||||
Type string `json:"type,omitempty"` // e.g. "web_search_20250305" for server tools
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema json.RawMessage `json:"input_schema"` // JSON Schema object
|
||||
}
|
||||
|
||||
// AnthropicResponse is the non-streaming response from POST /v1/messages.
|
||||
type AnthropicResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"` // "message"
|
||||
Role string `json:"role"` // "assistant"
|
||||
Content []AnthropicContentBlock `json:"content"`
|
||||
Model string `json:"model"`
|
||||
StopReason string `json:"stop_reason"`
|
||||
StopSequence *string `json:"stop_sequence,omitempty"`
|
||||
Usage AnthropicUsage `json:"usage"`
|
||||
}
|
||||
|
||||
// AnthropicUsage holds token counts in Anthropic format.
|
||||
type AnthropicUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Anthropic SSE event types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// AnthropicStreamEvent is a single SSE event in the Anthropic streaming protocol.
|
||||
type AnthropicStreamEvent struct {
|
||||
Type string `json:"type"`
|
||||
|
||||
// message_start
|
||||
Message *AnthropicResponse `json:"message,omitempty"`
|
||||
|
||||
// content_block_start
|
||||
Index *int `json:"index,omitempty"`
|
||||
ContentBlock *AnthropicContentBlock `json:"content_block,omitempty"`
|
||||
|
||||
// content_block_delta
|
||||
Delta *AnthropicDelta `json:"delta,omitempty"`
|
||||
|
||||
// message_delta
|
||||
Usage *AnthropicUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicDelta carries incremental content in streaming events.
|
||||
type AnthropicDelta struct {
|
||||
Type string `json:"type,omitempty"` // "text_delta" | "input_json_delta" | "thinking_delta" | "signature_delta"
|
||||
|
||||
// text_delta
|
||||
Text string `json:"text,omitempty"`
|
||||
|
||||
// input_json_delta
|
||||
PartialJSON string `json:"partial_json,omitempty"`
|
||||
|
||||
// thinking_delta
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
|
||||
// signature_delta
|
||||
Signature string `json:"signature,omitempty"`
|
||||
|
||||
// message_delta fields
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
StopSequence *string `json:"stop_sequence,omitempty"`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// OpenAI Responses API types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ResponsesRequest is the request body for POST /v1/responses.
|
||||
type ResponsesRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input json.RawMessage `json:"input"` // string or []ResponsesInputItem
|
||||
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Tools []ResponsesTool `json:"tools,omitempty"`
|
||||
Include []string `json:"include,omitempty"`
|
||||
Store *bool `json:"store,omitempty"`
|
||||
Reasoning *ResponsesReasoning `json:"reasoning,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesReasoning configures reasoning effort in the Responses API.
|
||||
type ResponsesReasoning struct {
|
||||
Effort string `json:"effort"` // "low" | "medium" | "high"
|
||||
Summary string `json:"summary,omitempty"` // "auto" | "concise" | "detailed"
|
||||
}
|
||||
|
||||
// ResponsesInputItem is one item in the Responses API input array.
|
||||
// The Type field determines which other fields are populated.
|
||||
type ResponsesInputItem struct {
|
||||
// Common
|
||||
Type string `json:"type,omitempty"` // "" for role-based messages
|
||||
|
||||
// Role-based messages (system/user/assistant)
|
||||
Role string `json:"role,omitempty"`
|
||||
Content json.RawMessage `json:"content,omitempty"` // string or []ResponsesContentPart
|
||||
|
||||
// type=function_call
|
||||
CallID string `json:"call_id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
|
||||
// type=function_call_output
|
||||
Output string `json:"output,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesContentPart is a typed content part in a Responses message.
|
||||
type ResponsesContentPart struct {
|
||||
Type string `json:"type"` // "input_text" | "output_text" | "input_image"
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesTool describes a tool in the Responses API.
|
||||
type ResponsesTool struct {
|
||||
Type string `json:"type"` // "function" | "web_search" | "local_shell" etc.
|
||||
Name string `json:"name,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Parameters json.RawMessage `json:"parameters,omitempty"`
|
||||
Strict *bool `json:"strict,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesResponse is the non-streaming response from POST /v1/responses.
|
||||
type ResponsesResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"` // "response"
|
||||
Model string `json:"model"`
|
||||
Status string `json:"status"` // "completed" | "incomplete" | "failed"
|
||||
Output []ResponsesOutput `json:"output"`
|
||||
Usage *ResponsesUsage `json:"usage,omitempty"`
|
||||
|
||||
// incomplete_details is present when status="incomplete"
|
||||
IncompleteDetails *ResponsesIncompleteDetails `json:"incomplete_details,omitempty"`
|
||||
|
||||
// Error is present when status="failed"
|
||||
Error *ResponsesError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesError describes an error in a failed response.
|
||||
type ResponsesError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ResponsesIncompleteDetails explains why a response is incomplete.
|
||||
type ResponsesIncompleteDetails struct {
|
||||
Reason string `json:"reason"` // "max_output_tokens" | "content_filter"
|
||||
}
|
||||
|
||||
// ResponsesOutput is one output item in a Responses API response.
|
||||
type ResponsesOutput struct {
|
||||
Type string `json:"type"` // "message" | "reasoning" | "function_call" | "web_search_call"
|
||||
|
||||
// type=message
|
||||
ID string `json:"id,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Content []ResponsesContentPart `json:"content,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
|
||||
// type=reasoning
|
||||
EncryptedContent string `json:"encrypted_content,omitempty"`
|
||||
Summary []ResponsesSummary `json:"summary,omitempty"`
|
||||
|
||||
// type=function_call
|
||||
CallID string `json:"call_id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
|
||||
// type=web_search_call
|
||||
Action *WebSearchAction `json:"action,omitempty"`
|
||||
}
|
||||
|
||||
// WebSearchAction describes the search action in a web_search_call output item.
|
||||
type WebSearchAction struct {
|
||||
Type string `json:"type,omitempty"` // "search"
|
||||
Query string `json:"query,omitempty"` // primary search query
|
||||
}
|
||||
|
||||
// ResponsesSummary is a summary text block inside a reasoning output.
|
||||
type ResponsesSummary struct {
|
||||
Type string `json:"type"` // "summary_text"
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// ResponsesUsage holds token counts in Responses API format.
|
||||
type ResponsesUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
|
||||
// Optional detailed breakdown
|
||||
InputTokensDetails *ResponsesInputTokensDetails `json:"input_tokens_details,omitempty"`
|
||||
OutputTokensDetails *ResponsesOutputTokensDetails `json:"output_tokens_details,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesInputTokensDetails breaks down input token usage.
|
||||
type ResponsesInputTokensDetails struct {
|
||||
CachedTokens int `json:"cached_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesOutputTokensDetails breaks down output token usage.
|
||||
type ResponsesOutputTokensDetails struct {
|
||||
ReasoningTokens int `json:"reasoning_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Responses SSE event types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ResponsesStreamEvent is a single SSE event in the Responses streaming protocol.
|
||||
// The Type field corresponds to the "type" in the JSON payload.
|
||||
type ResponsesStreamEvent struct {
|
||||
Type string `json:"type"`
|
||||
|
||||
// response.created / response.completed / response.failed / response.incomplete
|
||||
Response *ResponsesResponse `json:"response,omitempty"`
|
||||
|
||||
// response.output_item.added / response.output_item.done
|
||||
Item *ResponsesOutput `json:"item,omitempty"`
|
||||
|
||||
// response.output_text.delta / response.output_text.done
|
||||
OutputIndex int `json:"output_index,omitempty"`
|
||||
ContentIndex int `json:"content_index,omitempty"`
|
||||
Delta string `json:"delta,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ItemID string `json:"item_id,omitempty"`
|
||||
|
||||
// response.function_call_arguments.delta / done
|
||||
CallID string `json:"call_id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
|
||||
// response.reasoning_summary_text.delta / done
|
||||
// Reuses Text/Delta fields above, SummaryIndex identifies which summary part
|
||||
SummaryIndex int `json:"summary_index,omitempty"`
|
||||
|
||||
// error event fields
|
||||
Code string `json:"code,omitempty"`
|
||||
Param string `json:"param,omitempty"`
|
||||
|
||||
// Sequence number for ordering events
|
||||
SequenceNumber int `json:"sequence_number,omitempty"`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared constants
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// minMaxOutputTokens is the floor for max_output_tokens in a Responses request.
|
||||
// Very small values may cause upstream API errors, so we enforce a minimum.
|
||||
const minMaxOutputTokens = 128
|
||||
@@ -15,6 +15,7 @@ type Model struct {
|
||||
|
||||
// DefaultModels OpenAI models list
|
||||
var DefaultModels = []Model{
|
||||
{ID: "gpt-5.4", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4"},
|
||||
{ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"},
|
||||
{ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"},
|
||||
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
|
||||
|
||||
@@ -58,6 +58,12 @@ func IsCodexOfficialClientOriginator(originator string) bool {
|
||||
return matchCodexClientHeaderPrefixes(v, CodexOfficialClientOriginatorPrefixes)
|
||||
}
|
||||
|
||||
// IsCodexOfficialClientByHeaders checks whether the request headers indicate an
|
||||
// official Codex client family request.
|
||||
func IsCodexOfficialClientByHeaders(userAgent, originator string) bool {
|
||||
return IsCodexOfficialClientRequest(userAgent) || IsCodexOfficialClientOriginator(originator)
|
||||
}
|
||||
|
||||
func normalizeCodexClientHeader(value string) string {
|
||||
return strings.ToLower(strings.TrimSpace(value))
|
||||
}
|
||||
|
||||
@@ -85,3 +85,26 @@ func TestIsCodexOfficialClientOriginator(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsCodexOfficialClientByHeaders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ua string
|
||||
originator string
|
||||
want bool
|
||||
}{
|
||||
{name: "仅 originator 命中 desktop", originator: "Codex Desktop", want: true},
|
||||
{name: "仅 originator 命中 vscode", originator: "codex_vscode", want: true},
|
||||
{name: "仅 ua 命中 desktop", ua: "Codex Desktop/1.2.3", want: true},
|
||||
{name: "ua 与 originator 都未命中", ua: "curl/8.0.1", originator: "my_client", want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsCodexOfficialClientByHeaders(tt.ua, tt.originator)
|
||||
if got != tt.want {
|
||||
t.Fatalf("IsCodexOfficialClientByHeaders(%q, %q) = %v, want %v", tt.ua, tt.originator, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,25 +57,28 @@ type DashboardStats struct {
|
||||
|
||||
// TrendDataPoint represents a single point in trend data
|
||||
type TrendDataPoint struct {
|
||||
Date string `json:"date"`
|
||||
Requests int64 `json:"requests"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
CacheTokens int64 `json:"cache_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
Date string `json:"date"`
|
||||
Requests int64 `json:"requests"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
CacheCreationTokens int64 `json:"cache_creation_tokens"`
|
||||
CacheReadTokens int64 `json:"cache_read_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// ModelStat represents usage statistics for a single model
|
||||
type ModelStat struct {
|
||||
Model string `json:"model"`
|
||||
Requests int64 `json:"requests"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
Model string `json:"model"`
|
||||
Requests int64 `json:"requests"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
CacheCreationTokens int64 `json:"cache_creation_tokens"`
|
||||
CacheReadTokens int64 `json:"cache_read_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// GroupStat represents usage statistics for a single group
|
||||
|
||||
@@ -84,6 +84,9 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
|
||||
if account.RateMultiplier != nil {
|
||||
builder.SetRateMultiplier(*account.RateMultiplier)
|
||||
}
|
||||
if account.LoadFactor != nil {
|
||||
builder.SetLoadFactor(*account.LoadFactor)
|
||||
}
|
||||
|
||||
if account.ProxyID != nil {
|
||||
builder.SetProxyID(*account.ProxyID)
|
||||
@@ -318,6 +321,11 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
||||
if account.RateMultiplier != nil {
|
||||
builder.SetRateMultiplier(*account.RateMultiplier)
|
||||
}
|
||||
if account.LoadFactor != nil {
|
||||
builder.SetLoadFactor(*account.LoadFactor)
|
||||
} else {
|
||||
builder.ClearLoadFactor()
|
||||
}
|
||||
|
||||
if account.ProxyID != nil {
|
||||
builder.SetProxyID(*account.ProxyID)
|
||||
@@ -437,6 +445,14 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
|
||||
switch status {
|
||||
case "rate_limited":
|
||||
q = q.Where(dbaccount.RateLimitResetAtGT(time.Now()))
|
||||
case "temp_unschedulable":
|
||||
q = q.Where(dbpredicate.Account(func(s *entsql.Selector) {
|
||||
col := s.C("temp_unschedulable_until")
|
||||
s.Where(entsql.And(
|
||||
entsql.Not(entsql.IsNull(col)),
|
||||
entsql.GT(col, entsql.Expr("NOW()")),
|
||||
))
|
||||
}))
|
||||
default:
|
||||
q = q.Where(dbaccount.StatusEQ(status))
|
||||
}
|
||||
@@ -640,7 +656,17 @@ func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
|
||||
SetStatus(service.StatusActive).
|
||||
SetErrorMessage("").
|
||||
Save(ctx)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 清除临时不可调度状态,重置 401 升级链
|
||||
_, _ = r.sql.ExecContext(ctx, `
|
||||
UPDATE accounts
|
||||
SET temp_unschedulable_until = NULL,
|
||||
temp_unschedulable_reason = NULL
|
||||
WHERE id = $1 AND deleted_at IS NULL
|
||||
`, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
|
||||
@@ -1205,6 +1231,15 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
|
||||
args = append(args, *updates.RateMultiplier)
|
||||
idx++
|
||||
}
|
||||
if updates.LoadFactor != nil {
|
||||
if *updates.LoadFactor <= 0 {
|
||||
setClauses = append(setClauses, "load_factor = NULL")
|
||||
} else {
|
||||
setClauses = append(setClauses, "load_factor = $"+itoa(idx))
|
||||
args = append(args, *updates.LoadFactor)
|
||||
idx++
|
||||
}
|
||||
}
|
||||
if updates.Status != nil {
|
||||
setClauses = append(setClauses, "status = $"+itoa(idx))
|
||||
args = append(args, *updates.Status)
|
||||
@@ -1527,6 +1562,7 @@ func accountEntityToService(m *dbent.Account) *service.Account {
|
||||
Concurrency: m.Concurrency,
|
||||
Priority: m.Priority,
|
||||
RateMultiplier: &rateMultiplier,
|
||||
LoadFactor: m.LoadFactor,
|
||||
Status: m.Status,
|
||||
ErrorMessage: derefString(m.ErrorMessage),
|
||||
LastUsedAt: m.LastUsedAt,
|
||||
@@ -1639,3 +1675,60 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va
|
||||
|
||||
return r.accountsToService(ctx, accounts)
|
||||
}
|
||||
|
||||
// IncrementQuotaUsed 原子递增账号的 extra.quota_used 字段
|
||||
func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
rows, err := r.sql.QueryContext(ctx,
|
||||
`UPDATE accounts SET extra = jsonb_set(
|
||||
COALESCE(extra, '{}'::jsonb),
|
||||
'{quota_used}',
|
||||
to_jsonb(COALESCE((extra->>'quota_used')::numeric, 0) + $1)
|
||||
), updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
RETURNING
|
||||
COALESCE((extra->>'quota_used')::numeric, 0),
|
||||
COALESCE((extra->>'quota_limit')::numeric, 0)`,
|
||||
amount, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var newUsed, limit float64
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&newUsed, &limit); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 配额刚超限时触发调度快照刷新,使账号及时从调度候选中移除
|
||||
if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", id, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetQuotaUsed 重置账号的 extra.quota_used 为 0
|
||||
func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
_, err := r.sql.ExecContext(ctx,
|
||||
`UPDATE accounts SET extra = jsonb_set(
|
||||
COALESCE(extra, '{}'::jsonb),
|
||||
'{quota_used}',
|
||||
'0'::jsonb
|
||||
), updated_at = NOW()
|
||||
WHERE id = $1 AND deleted_at IS NULL`,
|
||||
id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 重置配额后触发调度快照刷新,使账号重新参与调度
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue quota reset failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ func (r *announcementRepository) Create(ctx context.Context, a *service.Announce
|
||||
SetTitle(a.Title).
|
||||
SetContent(a.Content).
|
||||
SetStatus(a.Status).
|
||||
SetNotifyMode(a.NotifyMode).
|
||||
SetTargeting(a.Targeting)
|
||||
|
||||
if a.StartsAt != nil {
|
||||
@@ -64,6 +65,7 @@ func (r *announcementRepository) Update(ctx context.Context, a *service.Announce
|
||||
SetTitle(a.Title).
|
||||
SetContent(a.Content).
|
||||
SetStatus(a.Status).
|
||||
SetNotifyMode(a.NotifyMode).
|
||||
SetTargeting(a.Targeting)
|
||||
|
||||
if a.StartsAt != nil {
|
||||
@@ -169,17 +171,18 @@ func announcementEntityToService(m *dbent.Announcement) *service.Announcement {
|
||||
return nil
|
||||
}
|
||||
return &service.Announcement{
|
||||
ID: m.ID,
|
||||
Title: m.Title,
|
||||
Content: m.Content,
|
||||
Status: m.Status,
|
||||
Targeting: m.Targeting,
|
||||
StartsAt: m.StartsAt,
|
||||
EndsAt: m.EndsAt,
|
||||
CreatedBy: m.CreatedBy,
|
||||
UpdatedBy: m.UpdatedBy,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
ID: m.ID,
|
||||
Title: m.Title,
|
||||
Content: m.Content,
|
||||
Status: m.Status,
|
||||
NotifyMode: m.NotifyMode,
|
||||
Targeting: m.Targeting,
|
||||
StartsAt: m.StartsAt,
|
||||
EndsAt: m.EndsAt,
|
||||
CreatedBy: m.CreatedBy,
|
||||
UpdatedBy: m.UpdatedBy,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -165,6 +165,8 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
group.FieldModelRouting,
|
||||
group.FieldMcpXMLInject,
|
||||
group.FieldSupportedModelScopes,
|
||||
group.FieldAllowMessagesDispatch,
|
||||
group.FieldDefaultMappedModel,
|
||||
)
|
||||
}).
|
||||
Only(ctx)
|
||||
@@ -470,12 +472,12 @@ func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt
|
||||
func (r *apiKeyRepository) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE api_keys SET
|
||||
usage_5h = usage_5h + $1,
|
||||
usage_1d = usage_1d + $1,
|
||||
usage_7d = usage_7d + $1,
|
||||
window_5h_start = COALESCE(window_5h_start, NOW()),
|
||||
window_1d_start = COALESCE(window_1d_start, NOW()),
|
||||
window_7d_start = COALESCE(window_7d_start, NOW()),
|
||||
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END,
|
||||
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END,
|
||||
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END,
|
||||
window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN NOW() ELSE window_1d_start END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN NOW() ELSE window_7d_start END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL`,
|
||||
cost, id)
|
||||
@@ -619,6 +621,8 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
MCPXMLInject: g.McpXMLInject,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
SortOrder: g.SortOrder,
|
||||
AllowMessagesDispatch: g.AllowMessagesDispatch,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
@@ -95,7 +96,8 @@ func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *se
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
|
||||
msg := fmt.Sprintf("API returned status %d: %s", resp.StatusCode, string(body))
|
||||
return nil, infraerrors.New(http.StatusInternalServerError, "UPSTREAM_ERROR", msg)
|
||||
}
|
||||
|
||||
var usageResp service.ClaudeUsageResponse
|
||||
|
||||
@@ -59,7 +59,9 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
||||
SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes)
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
|
||||
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
||||
SetDefaultMappedModel(groupIn.DefaultMappedModel)
|
||||
|
||||
// 设置模型路由配置
|
||||
if groupIn.ModelRouting != nil {
|
||||
@@ -125,7 +127,9 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes)
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
|
||||
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
||||
SetDefaultMappedModel(groupIn.DefaultMappedModel)
|
||||
|
||||
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
|
||||
if groupIn.DailyLimitUSD != nil {
|
||||
|
||||
183
backend/internal/repository/scheduled_test_repo.go
Normal file
183
backend/internal/repository/scheduled_test_repo.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// --- Plan Repository ---
|
||||
|
||||
type scheduledTestPlanRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewScheduledTestPlanRepository(db *sql.DB) service.ScheduledTestPlanRepository {
|
||||
return &scheduledTestPlanRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) Create(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO scheduled_test_plans (account_id, model_id, cron_expression, enabled, max_results, next_run_at, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, NOW(), NOW())
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.AccountID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.NextRunAt)
|
||||
return scanPlan(row)
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) GetByID(ctx context.Context, id int64) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans WHERE id = $1
|
||||
`, id)
|
||||
return scanPlan(row)
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) ListByAccountID(ctx context.Context, accountID int64) ([]*service.ScheduledTestPlan, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans WHERE account_id = $1
|
||||
ORDER BY created_at DESC
|
||||
`, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
return scanPlans(rows)
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) ListDue(ctx context.Context, now time.Time) ([]*service.ScheduledTestPlan, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans
|
||||
WHERE enabled = true AND next_run_at <= $1
|
||||
ORDER BY next_run_at ASC
|
||||
`, now)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
return scanPlans(rows)
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) Update(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
UPDATE scheduled_test_plans
|
||||
SET model_id = $2, cron_expression = $3, enabled = $4, max_results = $5, next_run_at = $6, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.ID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.NextRunAt)
|
||||
return scanPlan(row)
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) Delete(ctx context.Context, id int64) error {
|
||||
_, err := r.db.ExecContext(ctx, `DELETE FROM scheduled_test_plans WHERE id = $1`, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) UpdateAfterRun(ctx context.Context, id int64, lastRunAt time.Time, nextRunAt time.Time) error {
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
UPDATE scheduled_test_plans SET last_run_at = $2, next_run_at = $3, updated_at = NOW() WHERE id = $1
|
||||
`, id, lastRunAt, nextRunAt)
|
||||
return err
|
||||
}
|
||||
|
||||
// --- Result Repository ---
|
||||
|
||||
type scheduledTestResultRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewScheduledTestResultRepository(db *sql.DB) service.ScheduledTestResultRepository {
|
||||
return &scheduledTestResultRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *scheduledTestResultRepository) Create(ctx context.Context, result *service.ScheduledTestResult) (*service.ScheduledTestResult, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO scheduled_test_results (plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
|
||||
RETURNING id, plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at
|
||||
`, result.PlanID, result.Status, result.ResponseText, result.ErrorMessage, result.LatencyMs, result.StartedAt, result.FinishedAt)
|
||||
|
||||
out := &service.ScheduledTestResult{}
|
||||
if err := row.Scan(
|
||||
&out.ID, &out.PlanID, &out.Status, &out.ResponseText, &out.ErrorMessage,
|
||||
&out.LatencyMs, &out.StartedAt, &out.FinishedAt, &out.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *scheduledTestResultRepository) ListByPlanID(ctx context.Context, planID int64, limit int) ([]*service.ScheduledTestResult, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at
|
||||
FROM scheduled_test_results
|
||||
WHERE plan_id = $1
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $2
|
||||
`, planID, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var results []*service.ScheduledTestResult
|
||||
for rows.Next() {
|
||||
r := &service.ScheduledTestResult{}
|
||||
if err := rows.Scan(
|
||||
&r.ID, &r.PlanID, &r.Status, &r.ResponseText, &r.ErrorMessage,
|
||||
&r.LatencyMs, &r.StartedAt, &r.FinishedAt, &r.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, r)
|
||||
}
|
||||
return results, rows.Err()
|
||||
}
|
||||
|
||||
func (r *scheduledTestResultRepository) PruneOldResults(ctx context.Context, planID int64, keepCount int) error {
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
DELETE FROM scheduled_test_results
|
||||
WHERE id IN (
|
||||
SELECT id FROM (
|
||||
SELECT id, ROW_NUMBER() OVER (PARTITION BY plan_id ORDER BY created_at DESC) AS rn
|
||||
FROM scheduled_test_results
|
||||
WHERE plan_id = $1
|
||||
) ranked
|
||||
WHERE rn > $2
|
||||
)
|
||||
`, planID, keepCount)
|
||||
return err
|
||||
}
|
||||
|
||||
// --- scan helpers ---
|
||||
|
||||
type scannable interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
func scanPlan(row scannable) (*service.ScheduledTestPlan, error) {
|
||||
p := &service.ScheduledTestPlan{}
|
||||
if err := row.Scan(
|
||||
&p.ID, &p.AccountID, &p.ModelID, &p.CronExpression, &p.Enabled, &p.MaxResults,
|
||||
&p.LastRunAt, &p.NextRunAt, &p.CreatedAt, &p.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func scanPlans(rows *sql.Rows) ([]*service.ScheduledTestPlan, error) {
|
||||
var plans []*service.ScheduledTestPlan
|
||||
for rows.Next() {
|
||||
p, err := scanPlan(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plans = append(plans, p)
|
||||
}
|
||||
return plans, rows.Err()
|
||||
}
|
||||
@@ -1363,7 +1363,8 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||
@@ -1401,6 +1402,8 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||
@@ -1664,7 +1667,8 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||
@@ -1747,7 +1751,8 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st
|
||||
total_requests as requests,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
(cache_creation_tokens + cache_read_tokens) as cache_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens,
|
||||
total_cost as cost,
|
||||
actual_cost
|
||||
@@ -1762,7 +1767,8 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st
|
||||
total_requests as requests,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
(cache_creation_tokens + cache_read_tokens) as cache_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens,
|
||||
total_cost as cost,
|
||||
actual_cost
|
||||
@@ -1806,6 +1812,8 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
%s
|
||||
@@ -2622,7 +2630,8 @@ func scanTrendRows(rows *sql.Rows) ([]TrendDataPoint, error) {
|
||||
&row.Requests,
|
||||
&row.InputTokens,
|
||||
&row.OutputTokens,
|
||||
&row.CacheTokens,
|
||||
&row.CacheCreationTokens,
|
||||
&row.CacheReadTokens,
|
||||
&row.TotalTokens,
|
||||
&row.Cost,
|
||||
&row.ActualCost,
|
||||
@@ -2646,6 +2655,8 @@ func scanModelStatsRows(rows *sql.Rows) ([]ModelStat, error) {
|
||||
&row.Requests,
|
||||
&row.InputTokens,
|
||||
&row.OutputTokens,
|
||||
&row.CacheCreationTokens,
|
||||
&row.CacheReadTokens,
|
||||
&row.TotalTokens,
|
||||
&row.Cost,
|
||||
&row.ActualCost,
|
||||
|
||||
@@ -125,7 +125,7 @@ func TestUsageLogRepositoryGetUsageTrendWithFiltersRequestTypePriority(t *testin
|
||||
|
||||
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE\\)\\)").
|
||||
WithArgs(start, end, requestType).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_tokens", "total_tokens", "cost", "actual_cost"}))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost"}))
|
||||
|
||||
trend, err := repo.GetUsageTrendWithFilters(context.Background(), start, end, "day", 0, 0, 0, 0, "", &requestType, &stream, nil)
|
||||
require.NoError(t, err)
|
||||
@@ -144,7 +144,7 @@ func TestUsageLogRepositoryGetModelStatsWithFiltersRequestTypePriority(t *testin
|
||||
|
||||
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").
|
||||
WithArgs(start, end, requestType).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "total_tokens", "cost", "actual_cost"}))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost"}))
|
||||
|
||||
stats, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, &requestType, &stream, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -53,7 +53,9 @@ var ProviderSet = wire.NewSet(
|
||||
NewAPIKeyRepository,
|
||||
NewGroupRepository,
|
||||
NewAccountRepository,
|
||||
NewSoraAccountRepository, // Sora 账号扩展表仓储
|
||||
NewSoraAccountRepository, // Sora 账号扩展表仓储
|
||||
NewScheduledTestPlanRepository, // 定时测试计划仓储
|
||||
NewScheduledTestResultRepository, // 定时测试结果仓储
|
||||
NewProxyRepository,
|
||||
NewRedeemCodeRepository,
|
||||
NewPromoCodeRepository,
|
||||
|
||||
@@ -1096,6 +1096,14 @@ func (s *stubAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
|
||||
s.bulkUpdateIDs = append([]int64{}, ids...)
|
||||
return int64(len(ids)), nil
|
||||
|
||||
@@ -78,6 +78,9 @@ func RegisterAdminRoutes(
|
||||
|
||||
// API Key 管理
|
||||
registerAdminAPIKeyRoutes(admin, h)
|
||||
|
||||
// 定时测试计划
|
||||
registerScheduledTestRoutes(admin, h)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -249,6 +252,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
|
||||
accounts.POST("/today-stats/batch", h.Admin.Account.GetBatchTodayStats)
|
||||
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
|
||||
accounts.POST("/:id/reset-quota", h.Admin.Account.ResetQuota)
|
||||
accounts.GET("/:id/temp-unschedulable", h.Admin.Account.GetTempUnschedulable)
|
||||
accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable)
|
||||
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
|
||||
@@ -478,6 +482,18 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
}
|
||||
}
|
||||
|
||||
func registerScheduledTestRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
plans := admin.Group("/scheduled-test-plans")
|
||||
{
|
||||
plans.POST("", h.Admin.ScheduledTest.Create)
|
||||
plans.PUT("/:id", h.Admin.ScheduledTest.Update)
|
||||
plans.DELETE("/:id", h.Admin.ScheduledTest.Delete)
|
||||
plans.GET("/:id/results", h.Admin.ScheduledTest.ListResults)
|
||||
}
|
||||
// Nested under accounts
|
||||
admin.GET("/accounts/:id/scheduled-test-plans", h.Admin.ScheduledTest.ListByAccount)
|
||||
}
|
||||
|
||||
func registerErrorPassthroughRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
rules := admin.Group("/error-passthrough-rules")
|
||||
{
|
||||
|
||||
@@ -43,12 +43,33 @@ func RegisterGatewayRoutes(
|
||||
gateway.Use(gin.HandlerFunc(apiKeyAuth))
|
||||
gateway.Use(requireGroupAnthropic)
|
||||
{
|
||||
gateway.POST("/messages", h.Gateway.Messages)
|
||||
gateway.POST("/messages/count_tokens", h.Gateway.CountTokens)
|
||||
// /v1/messages: auto-route based on group platform
|
||||
gateway.POST("/messages", func(c *gin.Context) {
|
||||
if getGroupPlatform(c) == service.PlatformOpenAI {
|
||||
h.OpenAIGateway.Messages(c)
|
||||
return
|
||||
}
|
||||
h.Gateway.Messages(c)
|
||||
})
|
||||
// /v1/messages/count_tokens: OpenAI groups get 404
|
||||
gateway.POST("/messages/count_tokens", func(c *gin.Context) {
|
||||
if getGroupPlatform(c) == service.PlatformOpenAI {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "not_found_error",
|
||||
"message": "Token counting is not supported for this platform",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
h.Gateway.CountTokens(c)
|
||||
})
|
||||
gateway.GET("/models", h.Gateway.Models)
|
||||
gateway.GET("/usage", h.Gateway.Usage)
|
||||
// OpenAI Responses API
|
||||
gateway.POST("/responses", h.OpenAIGateway.Responses)
|
||||
gateway.POST("/responses/*subpath", h.OpenAIGateway.Responses)
|
||||
gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket)
|
||||
// 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。
|
||||
gateway.POST("/chat/completions", func(c *gin.Context) {
|
||||
@@ -77,6 +98,7 @@ func RegisterGatewayRoutes(
|
||||
|
||||
// OpenAI Responses API(不带v1前缀的别名)
|
||||
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
|
||||
|
||||
// Antigravity 模型列表
|
||||
@@ -132,3 +154,12 @@ func RegisterGatewayRoutes(
|
||||
// Sora 媒体代理(签名 URL,无需 API Key)
|
||||
r.GET("/sora/media-signed/*filepath", h.SoraGateway.MediaProxySigned)
|
||||
}
|
||||
|
||||
// getGroupPlatform extracts the group platform from the API Key stored in context.
|
||||
func getGroupPlatform(c *gin.Context) string {
|
||||
apiKey, ok := middleware.GetAPIKeyFromContext(c)
|
||||
if !ok || apiKey.Group == nil {
|
||||
return ""
|
||||
}
|
||||
return apiKey.Group.Platform
|
||||
}
|
||||
|
||||
51
backend/internal/server/routes/gateway_test.go
Normal file
51
backend/internal/server/routes/gateway_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newGatewayRoutesTestRouter() *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
RegisterGatewayRoutes(
|
||||
router,
|
||||
&handler.Handlers{
|
||||
Gateway: &handler.GatewayHandler{},
|
||||
OpenAIGateway: &handler.OpenAIGatewayHandler{},
|
||||
SoraGateway: &handler.SoraGatewayHandler{},
|
||||
},
|
||||
servermiddleware.APIKeyAuthMiddleware(func(c *gin.Context) {
|
||||
c.Next()
|
||||
}),
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
&config.Config{},
|
||||
)
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
func TestGatewayRoutesOpenAIResponsesCompactPathIsRegistered(t *testing.T) {
|
||||
router := newGatewayRoutesTestRouter()
|
||||
|
||||
for _, path := range []string{"/v1/responses/compact", "/responses/compact"} {
|
||||
req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{"model":"gpt-5"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
require.NotEqual(t, http.StatusNotFound, w.Code, "path=%s should hit OpenAI responses handler", path)
|
||||
}
|
||||
}
|
||||
@@ -28,6 +28,7 @@ type Account struct {
|
||||
// RateMultiplier 账号计费倍率(>=0,允许 0 表示该账号计费为 0)。
|
||||
// 使用指针用于兼容旧版本调度缓存(Redis)中缺字段的情况:nil 表示按 1.0 处理。
|
||||
RateMultiplier *float64
|
||||
LoadFactor *int // 调度负载因子;nil 表示使用 Concurrency
|
||||
Status string
|
||||
ErrorMessage string
|
||||
LastUsedAt *time.Time
|
||||
@@ -88,6 +89,19 @@ func (a *Account) BillingRateMultiplier() float64 {
|
||||
return *a.RateMultiplier
|
||||
}
|
||||
|
||||
func (a *Account) EffectiveLoadFactor() int {
|
||||
if a == nil {
|
||||
return 1
|
||||
}
|
||||
if a.LoadFactor != nil && *a.LoadFactor > 0 {
|
||||
return *a.LoadFactor
|
||||
}
|
||||
if a.Concurrency > 0 {
|
||||
return a.Concurrency
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func (a *Account) IsSchedulable() bool {
|
||||
if !a.IsActive() || !a.Schedulable {
|
||||
return false
|
||||
@@ -853,15 +867,21 @@ func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool {
|
||||
}
|
||||
|
||||
const (
|
||||
OpenAIWSIngressModeOff = "off"
|
||||
OpenAIWSIngressModeShared = "shared"
|
||||
OpenAIWSIngressModeDedicated = "dedicated"
|
||||
OpenAIWSIngressModeOff = "off"
|
||||
OpenAIWSIngressModeShared = "shared"
|
||||
OpenAIWSIngressModeDedicated = "dedicated"
|
||||
OpenAIWSIngressModeCtxPool = "ctx_pool"
|
||||
OpenAIWSIngressModePassthrough = "passthrough"
|
||||
)
|
||||
|
||||
func normalizeOpenAIWSIngressMode(mode string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(mode)) {
|
||||
case OpenAIWSIngressModeOff:
|
||||
return OpenAIWSIngressModeOff
|
||||
case OpenAIWSIngressModeCtxPool:
|
||||
return OpenAIWSIngressModeCtxPool
|
||||
case OpenAIWSIngressModePassthrough:
|
||||
return OpenAIWSIngressModePassthrough
|
||||
case OpenAIWSIngressModeShared:
|
||||
return OpenAIWSIngressModeShared
|
||||
case OpenAIWSIngressModeDedicated:
|
||||
@@ -873,18 +893,21 @@ func normalizeOpenAIWSIngressMode(mode string) string {
|
||||
|
||||
func normalizeOpenAIWSIngressDefaultMode(mode string) string {
|
||||
if normalized := normalizeOpenAIWSIngressMode(mode); normalized != "" {
|
||||
if normalized == OpenAIWSIngressModeShared || normalized == OpenAIWSIngressModeDedicated {
|
||||
return OpenAIWSIngressModeCtxPool
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
return OpenAIWSIngressModeShared
|
||||
return OpenAIWSIngressModeCtxPool
|
||||
}
|
||||
|
||||
// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/shared/dedicated)。
|
||||
// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/ctx_pool/passthrough)。
|
||||
//
|
||||
// 优先级:
|
||||
// 1. 分类型 mode 新字段(string)
|
||||
// 2. 分类型 enabled 旧字段(bool)
|
||||
// 3. 兼容 enabled 旧字段(bool)
|
||||
// 4. defaultMode(非法时回退 shared)
|
||||
// 4. defaultMode(非法时回退 ctx_pool)
|
||||
func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string {
|
||||
resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode)
|
||||
if a == nil || !a.IsOpenAI() {
|
||||
@@ -919,7 +942,7 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
|
||||
return "", false
|
||||
}
|
||||
if enabled {
|
||||
return OpenAIWSIngressModeShared, true
|
||||
return OpenAIWSIngressModeCtxPool, true
|
||||
}
|
||||
return OpenAIWSIngressModeOff, true
|
||||
}
|
||||
@@ -946,6 +969,10 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
|
||||
if mode, ok := resolveBoolMode("openai_ws_enabled"); ok {
|
||||
return mode
|
||||
}
|
||||
// 兼容旧值:shared/dedicated 语义都归并到 ctx_pool。
|
||||
if resolvedDefault == OpenAIWSIngressModeShared || resolvedDefault == OpenAIWSIngressModeDedicated {
|
||||
return OpenAIWSIngressModeCtxPool
|
||||
}
|
||||
return resolvedDefault
|
||||
}
|
||||
|
||||
@@ -1104,6 +1131,38 @@ func (a *Account) GetCacheTTLOverrideTarget() string {
|
||||
return "5m"
|
||||
}
|
||||
|
||||
// GetQuotaLimit 获取 API Key 账号的配额限制(美元)
|
||||
// 返回 0 表示未启用
|
||||
func (a *Account) GetQuotaLimit() float64 {
|
||||
if a.Extra == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := a.Extra["quota_limit"]; ok {
|
||||
return parseExtraFloat64(v)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetQuotaUsed 获取 API Key 账号的已用配额(美元)
|
||||
func (a *Account) GetQuotaUsed() float64 {
|
||||
if a.Extra == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := a.Extra["quota_used"]; ok {
|
||||
return parseExtraFloat64(v)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// IsQuotaExceeded 检查 API Key 账号配额是否已超限
|
||||
func (a *Account) IsQuotaExceeded() bool {
|
||||
limit := a.GetQuotaLimit()
|
||||
if limit <= 0 {
|
||||
return false
|
||||
}
|
||||
return a.GetQuotaUsed() >= limit
|
||||
}
|
||||
|
||||
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
|
||||
// 返回 0 表示未启用
|
||||
func (a *Account) GetWindowCostLimit() float64 {
|
||||
|
||||
46
backend/internal/service/account_load_factor_test.go
Normal file
46
backend/internal/service/account_load_factor_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func intPtrHelper(v int) *int { return &v }
|
||||
|
||||
func TestEffectiveLoadFactor_NilAccount(t *testing.T) {
|
||||
var a *Account
|
||||
require.Equal(t, 1, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_NilLoadFactor_PositiveConcurrency(t *testing.T) {
|
||||
a := &Account{Concurrency: 5}
|
||||
require.Equal(t, 5, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_NilLoadFactor_ZeroConcurrency(t *testing.T) {
|
||||
a := &Account{Concurrency: 0}
|
||||
require.Equal(t, 1, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_PositiveLoadFactor(t *testing.T) {
|
||||
a := &Account{Concurrency: 5, LoadFactor: intPtrHelper(20)}
|
||||
require.Equal(t, 20, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_ZeroLoadFactor_FallbackToConcurrency(t *testing.T) {
|
||||
a := &Account{Concurrency: 5, LoadFactor: intPtrHelper(0)}
|
||||
require.Equal(t, 5, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_NegativeLoadFactor_FallbackToConcurrency(t *testing.T) {
|
||||
a := &Account{Concurrency: 3, LoadFactor: intPtrHelper(-1)}
|
||||
require.Equal(t, 3, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_ZeroLoadFactor_ZeroConcurrency(t *testing.T) {
|
||||
a := &Account{Concurrency: 0, LoadFactor: intPtrHelper(0)}
|
||||
require.Equal(t, 1, a.EffectiveLoadFactor())
|
||||
}
|
||||
@@ -206,14 +206,14 @@ func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
|
||||
t.Run("default fallback to shared", func(t *testing.T) {
|
||||
t.Run("default fallback to ctx_pool", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{},
|
||||
}
|
||||
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
|
||||
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
|
||||
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
|
||||
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
|
||||
})
|
||||
|
||||
t.Run("oauth mode field has highest priority", func(t *testing.T) {
|
||||
@@ -221,15 +221,15 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
|
||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
|
||||
"openai_oauth_responses_websockets_v2_enabled": false,
|
||||
"responses_websockets_v2_enabled": false,
|
||||
},
|
||||
}
|
||||
require.Equal(t, OpenAIWSIngressModeDedicated, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
|
||||
require.Equal(t, OpenAIWSIngressModePassthrough, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool))
|
||||
})
|
||||
|
||||
t.Run("legacy enabled maps to shared", func(t *testing.T) {
|
||||
t.Run("legacy enabled maps to ctx_pool", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
@@ -237,7 +237,28 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
|
||||
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
|
||||
})
|
||||
|
||||
t.Run("shared/dedicated mode strings are compatible with ctx_pool", func(t *testing.T) {
|
||||
shared := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared,
|
||||
},
|
||||
}
|
||||
dedicated := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
|
||||
},
|
||||
}
|
||||
require.Equal(t, OpenAIWSIngressModeShared, shared.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
|
||||
require.Equal(t, OpenAIWSIngressModeDedicated, dedicated.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
|
||||
require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeShared))
|
||||
require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeDedicated))
|
||||
})
|
||||
|
||||
t.Run("legacy disabled maps to off", func(t *testing.T) {
|
||||
@@ -249,7 +270,7 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
|
||||
require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool))
|
||||
})
|
||||
|
||||
t.Run("non openai always off", func(t *testing.T) {
|
||||
|
||||
@@ -68,6 +68,10 @@ type AccountRepository interface {
|
||||
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
|
||||
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
|
||||
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
|
||||
// IncrementQuotaUsed 原子递增 API Key 账号的配额用量
|
||||
IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error
|
||||
// ResetQuotaUsed 重置 API Key 账号的配额用量为 0
|
||||
ResetQuotaUsed(ctx context.Context, id int64) error
|
||||
}
|
||||
|
||||
// AccountBulkUpdate describes the fields that can be updated in a bulk operation.
|
||||
@@ -78,6 +82,7 @@ type AccountBulkUpdate struct {
|
||||
Concurrency *int
|
||||
Priority *int
|
||||
RateMultiplier *float64
|
||||
LoadFactor *int
|
||||
Status *string
|
||||
Schedulable *bool
|
||||
Credentials map[string]any
|
||||
|
||||
@@ -199,6 +199,14 @@ func (s *accountRepoStub) BulkUpdate(ctx context.Context, ids []int64, updates A
|
||||
panic("unexpected BulkUpdate call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestAccountService_Delete_NotFound 测试删除不存在的账号时返回正确的错误。
|
||||
// 预期行为:
|
||||
// - ExistsByID 返回 false(账号不存在)
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
@@ -33,7 +34,7 @@ import (
|
||||
var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
|
||||
|
||||
const (
|
||||
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
|
||||
testClaudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
||||
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
|
||||
soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接
|
||||
soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions"
|
||||
@@ -179,7 +180,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
}
|
||||
|
||||
if account.Platform == PlatformAntigravity {
|
||||
return s.testAntigravityAccountConnection(c, account, modelID)
|
||||
return s.routeAntigravityTest(c, account, modelID)
|
||||
}
|
||||
|
||||
if account.Platform == PlatformSora {
|
||||
@@ -238,7 +239,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
|
||||
}
|
||||
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages"
|
||||
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages?beta=true"
|
||||
} else {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
||||
}
|
||||
@@ -1176,6 +1177,18 @@ func truncateSoraErrorBody(body []byte, max int) string {
|
||||
return soraerror.TruncateBody(body, max)
|
||||
}
|
||||
|
||||
// routeAntigravityTest 路由 Antigravity 账号的测试请求。
|
||||
// APIKey 类型走原生协议(与 gateway_handler 路由一致),OAuth/Upstream 走 CRS 中转。
|
||||
func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string) error {
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if strings.HasPrefix(modelID, "gemini-") {
|
||||
return s.testGeminiAccountConnection(c, account, modelID)
|
||||
}
|
||||
return s.testClaudeAccountConnection(c, account, modelID)
|
||||
}
|
||||
return s.testAntigravityAccountConnection(c, account, modelID)
|
||||
}
|
||||
|
||||
// testAntigravityAccountConnection tests an Antigravity account's connection
|
||||
// 支持 Claude 和 Gemini 两种协议,使用非流式请求
|
||||
func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error {
|
||||
@@ -1560,3 +1573,62 @@ func (s *AccountTestService) sendErrorAndEnd(c *gin.Context, errorMsg string) er
|
||||
s.sendEvent(c, TestEvent{Type: "error", Error: errorMsg})
|
||||
return fmt.Errorf("%s", errorMsg)
|
||||
}
|
||||
|
||||
// RunTestBackground executes an account test in-memory (no real HTTP client),
|
||||
// capturing SSE output via httptest.NewRecorder, then parses the result.
|
||||
func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID int64, modelID string) (*ScheduledTestResult, error) {
|
||||
startedAt := time.Now()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(w)
|
||||
ginCtx.Request = (&http.Request{}).WithContext(ctx)
|
||||
|
||||
testErr := s.TestAccountConnection(ginCtx, accountID, modelID)
|
||||
|
||||
finishedAt := time.Now()
|
||||
body := w.Body.String()
|
||||
responseText, errMsg := parseTestSSEOutput(body)
|
||||
|
||||
status := "success"
|
||||
if testErr != nil || errMsg != "" {
|
||||
status = "failed"
|
||||
if errMsg == "" && testErr != nil {
|
||||
errMsg = testErr.Error()
|
||||
}
|
||||
}
|
||||
|
||||
return &ScheduledTestResult{
|
||||
Status: status,
|
||||
ResponseText: responseText,
|
||||
ErrorMessage: errMsg,
|
||||
LatencyMs: finishedAt.Sub(startedAt).Milliseconds(),
|
||||
StartedAt: startedAt,
|
||||
FinishedAt: finishedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseTestSSEOutput extracts response text and error message from captured SSE output.
|
||||
func parseTestSSEOutput(body string) (responseText, errMsg string) {
|
||||
var texts []string
|
||||
for _, line := range strings.Split(body, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
jsonStr := strings.TrimPrefix(line, "data: ")
|
||||
var event TestEvent
|
||||
if err := json.Unmarshal([]byte(jsonStr), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
switch event.Type {
|
||||
case "content":
|
||||
if event.Text != "" {
|
||||
texts = append(texts, event.Text)
|
||||
}
|
||||
case "error":
|
||||
errMsg = event.Error
|
||||
}
|
||||
}
|
||||
responseText = strings.Join(texts, "")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,17 +1,24 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
httppool "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
openaipkg "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
type UsageLogRepository interface {
|
||||
@@ -70,8 +77,10 @@ type accountWindowStatsBatchReader interface {
|
||||
}
|
||||
|
||||
// apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at)
|
||||
// 同时支持缓存错误响应(负缓存),防止 429 等错误导致的重试风暴
|
||||
type apiUsageCache struct {
|
||||
response *ClaudeUsageResponse
|
||||
err error // 非 nil 表示缓存的错误(负缓存)
|
||||
timestamp time.Time
|
||||
}
|
||||
|
||||
@@ -88,15 +97,21 @@ type antigravityUsageCache struct {
|
||||
}
|
||||
|
||||
const (
|
||||
apiCacheTTL = 3 * time.Minute
|
||||
windowStatsCacheTTL = 1 * time.Minute
|
||||
apiCacheTTL = 3 * time.Minute
|
||||
apiErrorCacheTTL = 1 * time.Minute // 负缓存 TTL:429 等错误缓存 1 分钟
|
||||
apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟
|
||||
windowStatsCacheTTL = 1 * time.Minute
|
||||
openAIProbeCacheTTL = 10 * time.Minute
|
||||
openAICodexProbeVersion = "0.104.0"
|
||||
)
|
||||
|
||||
// UsageCache 封装账户使用量相关的缓存
|
||||
type UsageCache struct {
|
||||
apiCache sync.Map // accountID -> *apiUsageCache
|
||||
windowStatsCache sync.Map // accountID -> *windowStatsCache
|
||||
antigravityCache sync.Map // accountID -> *antigravityUsageCache
|
||||
apiCache sync.Map // accountID -> *apiUsageCache
|
||||
windowStatsCache sync.Map // accountID -> *windowStatsCache
|
||||
antigravityCache sync.Map // accountID -> *antigravityUsageCache
|
||||
apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存
|
||||
openAIProbeCache sync.Map // accountID -> time.Time
|
||||
}
|
||||
|
||||
// NewUsageCache 创建 UsageCache 实例
|
||||
@@ -224,6 +239,14 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
||||
return nil, fmt.Errorf("get account failed: %w", err)
|
||||
}
|
||||
|
||||
if account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth {
|
||||
usage, err := s.getOpenAIUsage(ctx, account)
|
||||
if err == nil {
|
||||
s.tryClearRecoverableAccountError(ctx, account)
|
||||
}
|
||||
return usage, err
|
||||
}
|
||||
|
||||
if account.Platform == PlatformGemini {
|
||||
usage, err := s.getGeminiUsage(ctx, account)
|
||||
if err == nil {
|
||||
@@ -245,24 +268,65 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
||||
if account.CanGetUsage() {
|
||||
var apiResp *ClaudeUsageResponse
|
||||
|
||||
// 1. 检查 API 缓存(10 分钟)
|
||||
// 1. 检查缓存(成功响应 3 分钟 / 错误响应 1 分钟)
|
||||
if cached, ok := s.cache.apiCache.Load(accountID); ok {
|
||||
if cache, ok := cached.(*apiUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL {
|
||||
apiResp = cache.response
|
||||
if cache, ok := cached.(*apiUsageCache); ok {
|
||||
age := time.Since(cache.timestamp)
|
||||
if cache.err != nil && age < apiErrorCacheTTL {
|
||||
// 负缓存命中:返回缓存的错误,避免重试风暴
|
||||
return nil, cache.err
|
||||
}
|
||||
if cache.response != nil && age < apiCacheTTL {
|
||||
apiResp = cache.response
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 如果没有缓存,从 API 获取
|
||||
// 2. 如果没有有效缓存,通过 singleflight 从 API 获取(防止并发击穿)
|
||||
if apiResp == nil {
|
||||
apiResp, err = s.fetchOAuthUsageRaw(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// 随机延迟:打散多账号并发请求,避免同一时刻大量相同 TLS 指纹请求
|
||||
// 触发上游反滥用检测。延迟范围 0~800ms,仅在缓存未命中时生效。
|
||||
jitter := time.Duration(rand.Int64N(int64(apiQueryMaxJitter)))
|
||||
select {
|
||||
case <-time.After(jitter):
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
// 缓存 API 响应
|
||||
s.cache.apiCache.Store(accountID, &apiUsageCache{
|
||||
response: apiResp,
|
||||
timestamp: time.Now(),
|
||||
|
||||
flightKey := fmt.Sprintf("usage:%d", accountID)
|
||||
result, flightErr, _ := s.cache.apiFlight.Do(flightKey, func() (any, error) {
|
||||
// 再次检查缓存(可能在等待 singleflight 期间被其他请求填充)
|
||||
if cached, ok := s.cache.apiCache.Load(accountID); ok {
|
||||
if cache, ok := cached.(*apiUsageCache); ok {
|
||||
age := time.Since(cache.timestamp)
|
||||
if cache.err != nil && age < apiErrorCacheTTL {
|
||||
return nil, cache.err
|
||||
}
|
||||
if cache.response != nil && age < apiCacheTTL {
|
||||
return cache.response, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
resp, fetchErr := s.fetchOAuthUsageRaw(ctx, account)
|
||||
if fetchErr != nil {
|
||||
// 负缓存:缓存错误响应,防止后续请求重复触发 429
|
||||
s.cache.apiCache.Store(accountID, &apiUsageCache{
|
||||
err: fetchErr,
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
return nil, fetchErr
|
||||
}
|
||||
// 缓存成功响应
|
||||
s.cache.apiCache.Store(accountID, &apiUsageCache{
|
||||
response: resp,
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
return resp, nil
|
||||
})
|
||||
if flightErr != nil {
|
||||
return nil, flightErr
|
||||
}
|
||||
apiResp, _ = result.(*ClaudeUsageResponse)
|
||||
}
|
||||
|
||||
// 3. 构建 UsageInfo(每次都重新计算 RemainingSeconds)
|
||||
@@ -288,6 +352,161 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
||||
return nil, fmt.Errorf("account type %s does not support usage query", account.Type)
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
|
||||
now := time.Now()
|
||||
usage := &UsageInfo{UpdatedAt: &now}
|
||||
|
||||
if account == nil {
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
if progress := buildCodexUsageProgressFromExtra(account.Extra, "5h", now); progress != nil {
|
||||
usage.FiveHour = progress
|
||||
}
|
||||
if progress := buildCodexUsageProgressFromExtra(account.Extra, "7d", now); progress != nil {
|
||||
usage.SevenDay = progress
|
||||
}
|
||||
|
||||
if (usage.FiveHour == nil || usage.SevenDay == nil) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) {
|
||||
if updates, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && len(updates) > 0 {
|
||||
mergeAccountExtra(account, updates)
|
||||
if usage.UpdatedAt == nil {
|
||||
usage.UpdatedAt = &now
|
||||
}
|
||||
if progress := buildCodexUsageProgressFromExtra(account.Extra, "5h", now); progress != nil {
|
||||
usage.FiveHour = progress
|
||||
}
|
||||
if progress := buildCodexUsageProgressFromExtra(account.Extra, "7d", now); progress != nil {
|
||||
usage.SevenDay = progress
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if s.usageLogRepo == nil {
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
if stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, now.Add(-5*time.Hour)); err == nil {
|
||||
windowStats := windowStatsFromAccountStats(stats)
|
||||
if hasMeaningfulWindowStats(windowStats) {
|
||||
if usage.FiveHour == nil {
|
||||
usage.FiveHour = &UsageProgress{Utilization: 0}
|
||||
}
|
||||
usage.FiveHour.WindowStats = windowStats
|
||||
}
|
||||
}
|
||||
|
||||
if stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, now.Add(-7*24*time.Hour)); err == nil {
|
||||
windowStats := windowStatsFromAccountStats(stats)
|
||||
if hasMeaningfulWindowStats(windowStats) {
|
||||
if usage.SevenDay == nil {
|
||||
usage.SevenDay = &UsageProgress{Utilization: 0}
|
||||
}
|
||||
usage.SevenDay.WindowStats = windowStats
|
||||
}
|
||||
}
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, now time.Time) bool {
|
||||
if s == nil || s.cache == nil || accountID <= 0 {
|
||||
return true
|
||||
}
|
||||
if cached, ok := s.cache.openAIProbeCache.Load(accountID); ok {
|
||||
if ts, ok := cached.(time.Time); ok && now.Sub(ts) < openAIProbeCacheTTL {
|
||||
return false
|
||||
}
|
||||
}
|
||||
s.cache.openAIProbeCache.Store(accountID, now)
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, error) {
|
||||
if account == nil || !account.IsOAuth() {
|
||||
return nil, nil
|
||||
}
|
||||
accessToken := account.GetOpenAIAccessToken()
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("no access token available")
|
||||
}
|
||||
modelID := openaipkg.DefaultTestModel
|
||||
payload := createOpenAITestPayload(modelID, true)
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal openai probe payload: %w", err)
|
||||
}
|
||||
|
||||
reqCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, chatgptCodexURL, bytes.NewReader(payloadBytes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create openai probe request: %w", err)
|
||||
}
|
||||
req.Host = "chatgpt.com"
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
req.Header.Set("OpenAI-Beta", "responses=experimental")
|
||||
req.Header.Set("Originator", "codex_cli_rs")
|
||||
req.Header.Set("Version", openAICodexProbeVersion)
|
||||
req.Header.Set("User-Agent", codexCLIUserAgent)
|
||||
if s.identityCache != nil {
|
||||
if fp, fpErr := s.identityCache.GetFingerprint(reqCtx, account.ID); fpErr == nil && fp != nil && strings.TrimSpace(fp.UserAgent) != "" {
|
||||
req.Header.Set("User-Agent", strings.TrimSpace(fp.UserAgent))
|
||||
}
|
||||
}
|
||||
if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" {
|
||||
req.Header.Set("chatgpt-account-id", chatgptAccountID)
|
||||
}
|
||||
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
client, err := httppool.GetClient(httppool.Options{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 15 * time.Second,
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build openai probe client: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("openai codex probe request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
|
||||
}
|
||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||
updates := buildCodexUsageExtraUpdates(snapshot, time.Now())
|
||||
if len(updates) > 0 {
|
||||
go func(accountID int64, updates map[string]any) {
|
||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer updateCancel()
|
||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||
}(account.ID, updates)
|
||||
return updates, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func mergeAccountExtra(account *Account, updates map[string]any) {
|
||||
if account == nil || len(updates) == 0 {
|
||||
return
|
||||
}
|
||||
if account.Extra == nil {
|
||||
account.Extra = make(map[string]any, len(updates))
|
||||
}
|
||||
for k, v := range updates {
|
||||
account.Extra[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
|
||||
now := time.Now()
|
||||
usage := &UsageInfo{
|
||||
@@ -519,6 +738,72 @@ func windowStatsFromAccountStats(stats *usagestats.AccountStats) *WindowStats {
|
||||
}
|
||||
}
|
||||
|
||||
func hasMeaningfulWindowStats(stats *WindowStats) bool {
|
||||
if stats == nil {
|
||||
return false
|
||||
}
|
||||
return stats.Requests > 0 || stats.Tokens > 0 || stats.Cost > 0 || stats.StandardCost > 0 || stats.UserCost > 0
|
||||
}
|
||||
|
||||
func buildCodexUsageProgressFromExtra(extra map[string]any, window string, now time.Time) *UsageProgress {
|
||||
if len(extra) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
usedPercentKey string
|
||||
resetAfterKey string
|
||||
resetAtKey string
|
||||
)
|
||||
|
||||
switch window {
|
||||
case "5h":
|
||||
usedPercentKey = "codex_5h_used_percent"
|
||||
resetAfterKey = "codex_5h_reset_after_seconds"
|
||||
resetAtKey = "codex_5h_reset_at"
|
||||
case "7d":
|
||||
usedPercentKey = "codex_7d_used_percent"
|
||||
resetAfterKey = "codex_7d_reset_after_seconds"
|
||||
resetAtKey = "codex_7d_reset_at"
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
usedRaw, ok := extra[usedPercentKey]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
progress := &UsageProgress{Utilization: parseExtraFloat64(usedRaw)}
|
||||
if resetAtRaw, ok := extra[resetAtKey]; ok {
|
||||
if resetAt, err := parseTime(fmt.Sprint(resetAtRaw)); err == nil {
|
||||
progress.ResetsAt = &resetAt
|
||||
progress.RemainingSeconds = int(time.Until(resetAt).Seconds())
|
||||
if progress.RemainingSeconds < 0 {
|
||||
progress.RemainingSeconds = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
if progress.ResetsAt == nil {
|
||||
if resetAfterSeconds := parseExtraInt(extra[resetAfterKey]); resetAfterSeconds > 0 {
|
||||
base := now
|
||||
if updatedAtRaw, ok := extra["codex_usage_updated_at"]; ok {
|
||||
if updatedAt, err := parseTime(fmt.Sprint(updatedAtRaw)); err == nil {
|
||||
base = updatedAt
|
||||
}
|
||||
}
|
||||
resetAt := base.Add(time.Duration(resetAfterSeconds) * time.Second)
|
||||
progress.ResetsAt = &resetAt
|
||||
progress.RemainingSeconds = int(time.Until(resetAt).Seconds())
|
||||
if progress.RemainingSeconds < 0 {
|
||||
progress.RemainingSeconds = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return progress
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) {
|
||||
stats, err := s.usageLogRepo.GetAccountUsageStats(ctx, accountID, startTime, endTime)
|
||||
if err != nil {
|
||||
@@ -666,15 +951,30 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
|
||||
remaining = 0
|
||||
}
|
||||
|
||||
// 根据状态估算使用率 (百分比形式,100 = 100%)
|
||||
// 优先使用响应头中存储的真实 utilization 值(0-1 小数,转为 0-100 百分比)
|
||||
var utilization float64
|
||||
switch account.SessionWindowStatus {
|
||||
case "rejected":
|
||||
utilization = 100.0
|
||||
case "allowed_warning":
|
||||
utilization = 80.0
|
||||
default:
|
||||
utilization = 0.0
|
||||
var found bool
|
||||
if stored, ok := account.Extra["session_window_utilization"]; ok {
|
||||
switch v := stored.(type) {
|
||||
case float64:
|
||||
utilization = v * 100
|
||||
found = true
|
||||
case json.Number:
|
||||
if f, err := v.Float64(); err == nil {
|
||||
utilization = f * 100
|
||||
found = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有存储的 utilization,回退到状态估算
|
||||
if !found {
|
||||
switch account.SessionWindowStatus {
|
||||
case "rejected":
|
||||
utilization = 100.0
|
||||
case "allowed_warning":
|
||||
utilization = 80.0
|
||||
}
|
||||
}
|
||||
|
||||
info.FiveHour = &UsageProgress{
|
||||
|
||||
@@ -84,6 +84,7 @@ type AdminService interface {
|
||||
DeleteRedeemCode(ctx context.Context, id int64) error
|
||||
BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error)
|
||||
ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
|
||||
ResetAccountQuota(ctx context.Context, id int64) error
|
||||
}
|
||||
|
||||
// CreateUserInput represents input for creating a new user via admin operations.
|
||||
@@ -144,6 +145,9 @@ type CreateGroupInput struct {
|
||||
SupportedModelScopes []string
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes int64
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
AllowMessagesDispatch bool
|
||||
DefaultMappedModel string
|
||||
// 从指定分组复制账号(创建分组后在同一事务内绑定)
|
||||
CopyAccountsFromGroupIDs []int64
|
||||
}
|
||||
@@ -180,6 +184,9 @@ type UpdateGroupInput struct {
|
||||
SupportedModelScopes *[]string
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes *int64
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
AllowMessagesDispatch *bool
|
||||
DefaultMappedModel *string
|
||||
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||
CopyAccountsFromGroupIDs []int64
|
||||
}
|
||||
@@ -195,6 +202,7 @@ type CreateAccountInput struct {
|
||||
Concurrency int
|
||||
Priority int
|
||||
RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
|
||||
LoadFactor *int
|
||||
GroupIDs []int64
|
||||
ExpiresAt *int64
|
||||
AutoPauseOnExpired *bool
|
||||
@@ -215,6 +223,7 @@ type UpdateAccountInput struct {
|
||||
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
||||
Priority *int // 使用指针区分"未提供"和"设置为0"
|
||||
RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
|
||||
LoadFactor *int
|
||||
Status string
|
||||
GroupIDs *[]int64
|
||||
ExpiresAt *int64
|
||||
@@ -230,6 +239,7 @@ type BulkUpdateAccountsInput struct {
|
||||
Concurrency *int
|
||||
Priority *int
|
||||
RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
|
||||
LoadFactor *int
|
||||
Status string
|
||||
Schedulable *bool
|
||||
GroupIDs *[]int64
|
||||
@@ -905,6 +915,8 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
MCPXMLInject: mcpXMLInject,
|
||||
SupportedModelScopes: input.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
|
||||
AllowMessagesDispatch: input.AllowMessagesDispatch,
|
||||
DefaultMappedModel: input.DefaultMappedModel,
|
||||
}
|
||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||
return nil, err
|
||||
@@ -1118,6 +1130,14 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
group.SupportedModelScopes = *input.SupportedModelScopes
|
||||
}
|
||||
|
||||
// OpenAI Messages 调度配置
|
||||
if input.AllowMessagesDispatch != nil {
|
||||
group.AllowMessagesDispatch = *input.AllowMessagesDispatch
|
||||
}
|
||||
if input.DefaultMappedModel != nil {
|
||||
group.DefaultMappedModel = *input.DefaultMappedModel
|
||||
}
|
||||
|
||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1413,6 +1433,12 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
||||
}
|
||||
account.RateMultiplier = input.RateMultiplier
|
||||
}
|
||||
if input.LoadFactor != nil && *input.LoadFactor > 0 {
|
||||
if *input.LoadFactor > 10000 {
|
||||
return nil, errors.New("load_factor must be <= 10000")
|
||||
}
|
||||
account.LoadFactor = input.LoadFactor
|
||||
}
|
||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1458,6 +1484,10 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
account.Credentials = input.Credentials
|
||||
}
|
||||
if len(input.Extra) > 0 {
|
||||
// 保留 quota_used,防止编辑账号时意外重置配额用量
|
||||
if oldQuotaUsed, ok := account.Extra["quota_used"]; ok {
|
||||
input.Extra["quota_used"] = oldQuotaUsed
|
||||
}
|
||||
account.Extra = input.Extra
|
||||
}
|
||||
if input.ProxyID != nil {
|
||||
@@ -1483,6 +1513,15 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
}
|
||||
account.RateMultiplier = input.RateMultiplier
|
||||
}
|
||||
if input.LoadFactor != nil {
|
||||
if *input.LoadFactor <= 0 {
|
||||
account.LoadFactor = nil // 0 或负数表示清除
|
||||
} else if *input.LoadFactor > 10000 {
|
||||
return nil, errors.New("load_factor must be <= 10000")
|
||||
} else {
|
||||
account.LoadFactor = input.LoadFactor
|
||||
}
|
||||
}
|
||||
if input.Status != "" {
|
||||
account.Status = input.Status
|
||||
}
|
||||
@@ -1616,6 +1655,15 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
if input.RateMultiplier != nil {
|
||||
repoUpdates.RateMultiplier = input.RateMultiplier
|
||||
}
|
||||
if input.LoadFactor != nil {
|
||||
if *input.LoadFactor <= 0 {
|
||||
repoUpdates.LoadFactor = nil // 0 或负数表示清除
|
||||
} else if *input.LoadFactor > 10000 {
|
||||
return nil, errors.New("load_factor must be <= 10000")
|
||||
} else {
|
||||
repoUpdates.LoadFactor = input.LoadFactor
|
||||
}
|
||||
}
|
||||
if input.Status != "" {
|
||||
repoUpdates.Status = &input.Status
|
||||
}
|
||||
@@ -2439,3 +2487,7 @@ func (e *MixedChannelError) Error() string {
|
||||
return fmt.Sprintf("mixed_channel_warning: Group '%s' contains both %s and %s accounts. Using mixed channels in the same context may cause thinking block signature validation issues, which will fallback to non-thinking mode for historical messages.",
|
||||
e.GroupName, e.CurrentPlatform, e.OtherPlatform)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) ResetAccountQuota(ctx context.Context, id int64) error {
|
||||
return s.accountRepo.ResetQuotaUsed(ctx, id)
|
||||
}
|
||||
|
||||
@@ -14,6 +14,11 @@ const (
|
||||
AnnouncementStatusArchived = domain.AnnouncementStatusArchived
|
||||
)
|
||||
|
||||
const (
|
||||
AnnouncementNotifyModeSilent = domain.AnnouncementNotifyModeSilent
|
||||
AnnouncementNotifyModePopup = domain.AnnouncementNotifyModePopup
|
||||
)
|
||||
|
||||
const (
|
||||
AnnouncementConditionTypeSubscription = domain.AnnouncementConditionTypeSubscription
|
||||
AnnouncementConditionTypeBalance = domain.AnnouncementConditionTypeBalance
|
||||
|
||||
@@ -33,23 +33,25 @@ func NewAnnouncementService(
|
||||
}
|
||||
|
||||
type CreateAnnouncementInput struct {
|
||||
Title string
|
||||
Content string
|
||||
Status string
|
||||
Targeting AnnouncementTargeting
|
||||
StartsAt *time.Time
|
||||
EndsAt *time.Time
|
||||
ActorID *int64 // 管理员用户ID
|
||||
Title string
|
||||
Content string
|
||||
Status string
|
||||
NotifyMode string
|
||||
Targeting AnnouncementTargeting
|
||||
StartsAt *time.Time
|
||||
EndsAt *time.Time
|
||||
ActorID *int64 // 管理员用户ID
|
||||
}
|
||||
|
||||
type UpdateAnnouncementInput struct {
|
||||
Title *string
|
||||
Content *string
|
||||
Status *string
|
||||
Targeting *AnnouncementTargeting
|
||||
StartsAt **time.Time
|
||||
EndsAt **time.Time
|
||||
ActorID *int64 // 管理员用户ID
|
||||
Title *string
|
||||
Content *string
|
||||
Status *string
|
||||
NotifyMode *string
|
||||
Targeting *AnnouncementTargeting
|
||||
StartsAt **time.Time
|
||||
EndsAt **time.Time
|
||||
ActorID *int64 // 管理员用户ID
|
||||
}
|
||||
|
||||
type UserAnnouncement struct {
|
||||
@@ -93,6 +95,14 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem
|
||||
return nil, err
|
||||
}
|
||||
|
||||
notifyMode := strings.TrimSpace(input.NotifyMode)
|
||||
if notifyMode == "" {
|
||||
notifyMode = AnnouncementNotifyModeSilent
|
||||
}
|
||||
if !isValidAnnouncementNotifyMode(notifyMode) {
|
||||
return nil, fmt.Errorf("create announcement: invalid notify_mode")
|
||||
}
|
||||
|
||||
if input.StartsAt != nil && input.EndsAt != nil {
|
||||
if !input.StartsAt.Before(*input.EndsAt) {
|
||||
return nil, fmt.Errorf("create announcement: starts_at must be before ends_at")
|
||||
@@ -100,12 +110,13 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem
|
||||
}
|
||||
|
||||
a := &Announcement{
|
||||
Title: title,
|
||||
Content: content,
|
||||
Status: status,
|
||||
Targeting: targeting,
|
||||
StartsAt: input.StartsAt,
|
||||
EndsAt: input.EndsAt,
|
||||
Title: title,
|
||||
Content: content,
|
||||
Status: status,
|
||||
NotifyMode: notifyMode,
|
||||
Targeting: targeting,
|
||||
StartsAt: input.StartsAt,
|
||||
EndsAt: input.EndsAt,
|
||||
}
|
||||
if input.ActorID != nil && *input.ActorID > 0 {
|
||||
a.CreatedBy = input.ActorID
|
||||
@@ -150,6 +161,14 @@ func (s *AnnouncementService) Update(ctx context.Context, id int64, input *Updat
|
||||
a.Status = status
|
||||
}
|
||||
|
||||
if input.NotifyMode != nil {
|
||||
notifyMode := strings.TrimSpace(*input.NotifyMode)
|
||||
if !isValidAnnouncementNotifyMode(notifyMode) {
|
||||
return nil, fmt.Errorf("update announcement: invalid notify_mode")
|
||||
}
|
||||
a.NotifyMode = notifyMode
|
||||
}
|
||||
|
||||
if input.Targeting != nil {
|
||||
targeting, err := domain.AnnouncementTargeting(*input.Targeting).NormalizeAndValidate()
|
||||
if err != nil {
|
||||
@@ -376,3 +395,12 @@ func isValidAnnouncementStatus(status string) bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isValidAnnouncementNotifyMode(mode string) bool {
|
||||
switch mode {
|
||||
case AnnouncementNotifyModeSilent, AnnouncementNotifyModePopup:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3696,6 +3696,15 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
finalEvents, agUsage := processor.Finish()
|
||||
if len(finalEvents) > 0 {
|
||||
cw.Write(finalEvents)
|
||||
} else if !processor.MessageStartSent() && !cw.Disconnected() {
|
||||
// 整个流未收到任何可解析的上游数据(全部 SSE 行均无法被 JSON 解析),
|
||||
// 触发 failover 在同账号重试,避免向客户端发出缺少 message_start 的残缺流
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Claude-Stream] empty stream response (no valid events parsed), triggering failover")
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: http.StatusBadGateway,
|
||||
ResponseBody: []byte(`{"error":"empty stream response from upstream"}`),
|
||||
RetryableOnSameAccount: true,
|
||||
}
|
||||
}
|
||||
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()}, nil
|
||||
}
|
||||
|
||||
@@ -998,6 +998,46 @@ func TestHandleClaudeStreamingResponse_ClientDisconnect(t *testing.T) {
|
||||
require.True(t, result.clientDisconnect)
|
||||
}
|
||||
|
||||
// TestHandleClaudeStreamingResponse_EmptyStream
|
||||
// 验证:上游只返回无法解析的 SSE 行时,触发 UpstreamFailoverError 而不是向客户端发出残缺流
|
||||
func TestHandleClaudeStreamingResponse_EmptyStream(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newAntigravityTestService(&config.Config{
|
||||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
// 所有行均为无法 JSON 解析的内容,ProcessLine 全部返回 nil
|
||||
fmt.Fprintln(pw, "data: not-valid-json")
|
||||
fmt.Fprintln(pw, "")
|
||||
fmt.Fprintln(pw, "data: also-invalid")
|
||||
fmt.Fprintln(pw, "")
|
||||
}()
|
||||
|
||||
_, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
||||
_ = pr.Close()
|
||||
|
||||
// 应当返回 UpstreamFailoverError 而非 nil,以便上层触发 failover
|
||||
require.Error(t, err)
|
||||
var failoverErr *UpstreamFailoverError
|
||||
require.ErrorAs(t, err, &failoverErr)
|
||||
require.True(t, failoverErr.RetryableOnSameAccount)
|
||||
|
||||
// 客户端不应收到任何 SSE 事件(既无 message_start 也无 message_stop)
|
||||
body := rec.Body.String()
|
||||
require.NotContains(t, body, "event: message_start")
|
||||
require.NotContains(t, body, "event: message_stop")
|
||||
require.NotContains(t, body, "event: message_delta")
|
||||
}
|
||||
|
||||
// TestHandleClaudeStreamingResponse_ContextCanceled
|
||||
// 验证:context 取消时不注入错误事件
|
||||
func TestHandleClaudeStreamingResponse_ContextCanceled(t *testing.T) {
|
||||
|
||||
@@ -14,6 +14,18 @@ const (
|
||||
StatusAPIKeyExpired = "expired"
|
||||
)
|
||||
|
||||
// Rate limit window durations
|
||||
const (
|
||||
RateLimitWindow5h = 5 * time.Hour
|
||||
RateLimitWindow1d = 24 * time.Hour
|
||||
RateLimitWindow7d = 7 * 24 * time.Hour
|
||||
)
|
||||
|
||||
// IsWindowExpired returns true if the window starting at windowStart has exceeded the given duration.
|
||||
func IsWindowExpired(windowStart *time.Time, duration time.Duration) bool {
|
||||
return windowStart != nil && time.Since(*windowStart) >= duration
|
||||
}
|
||||
|
||||
type APIKey struct {
|
||||
ID int64
|
||||
UserID int64
|
||||
@@ -98,6 +110,30 @@ func (k *APIKey) GetDaysUntilExpiry() int {
|
||||
return int(duration.Hours() / 24)
|
||||
}
|
||||
|
||||
// EffectiveUsage5h returns the 5h window usage, or 0 if the window has expired.
|
||||
func (k *APIKey) EffectiveUsage5h() float64 {
|
||||
if IsWindowExpired(k.Window5hStart, RateLimitWindow5h) {
|
||||
return 0
|
||||
}
|
||||
return k.Usage5h
|
||||
}
|
||||
|
||||
// EffectiveUsage1d returns the 1d window usage, or 0 if the window has expired.
|
||||
func (k *APIKey) EffectiveUsage1d() float64 {
|
||||
if IsWindowExpired(k.Window1dStart, RateLimitWindow1d) {
|
||||
return 0
|
||||
}
|
||||
return k.Usage1d
|
||||
}
|
||||
|
||||
// EffectiveUsage7d returns the 7d window usage, or 0 if the window has expired.
|
||||
func (k *APIKey) EffectiveUsage7d() float64 {
|
||||
if IsWindowExpired(k.Window7dStart, RateLimitWindow7d) {
|
||||
return 0
|
||||
}
|
||||
return k.Usage7d
|
||||
}
|
||||
|
||||
// APIKeyListFilters holds optional filtering parameters for listing API keys.
|
||||
type APIKeyListFilters struct {
|
||||
Search string
|
||||
|
||||
@@ -65,6 +65,10 @@ type APIKeyAuthGroupSnapshot struct {
|
||||
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
|
||||
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
|
||||
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
|
||||
}
|
||||
|
||||
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
|
||||
|
||||
@@ -245,6 +245,8 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
||||
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
|
||||
MCPXMLInject: apiKey.Group.MCPXMLInject,
|
||||
SupportedModelScopes: apiKey.Group.SupportedModelScopes,
|
||||
AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch,
|
||||
DefaultMappedModel: apiKey.Group.DefaultMappedModel,
|
||||
}
|
||||
}
|
||||
return snapshot
|
||||
@@ -302,6 +304,8 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
||||
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
|
||||
MCPXMLInject: snapshot.Group.MCPXMLInject,
|
||||
SupportedModelScopes: snapshot.Group.SupportedModelScopes,
|
||||
AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch,
|
||||
DefaultMappedModel: snapshot.Group.DefaultMappedModel,
|
||||
}
|
||||
}
|
||||
s.compileAPIKeyIPRules(apiKey)
|
||||
|
||||
245
backend/internal/service/api_key_rate_limit_test.go
Normal file
245
backend/internal/service/api_key_rate_limit_test.go
Normal file
@@ -0,0 +1,245 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIsWindowExpired(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
start *time.Time
|
||||
duration time.Duration
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil window start",
|
||||
start: nil,
|
||||
duration: RateLimitWindow5h,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "active window (started 1h ago, 5h window)",
|
||||
start: rateLimitTimePtr(now.Add(-1 * time.Hour)),
|
||||
duration: RateLimitWindow5h,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "expired window (started 6h ago, 5h window)",
|
||||
start: rateLimitTimePtr(now.Add(-6 * time.Hour)),
|
||||
duration: RateLimitWindow5h,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "exactly at boundary (started 5h ago, 5h window)",
|
||||
start: rateLimitTimePtr(now.Add(-5 * time.Hour)),
|
||||
duration: RateLimitWindow5h,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "active 1d window (started 12h ago)",
|
||||
start: rateLimitTimePtr(now.Add(-12 * time.Hour)),
|
||||
duration: RateLimitWindow1d,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "expired 1d window (started 25h ago)",
|
||||
start: rateLimitTimePtr(now.Add(-25 * time.Hour)),
|
||||
duration: RateLimitWindow1d,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "active 7d window (started 3d ago)",
|
||||
start: rateLimitTimePtr(now.Add(-3 * 24 * time.Hour)),
|
||||
duration: RateLimitWindow7d,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "expired 7d window (started 8d ago)",
|
||||
start: rateLimitTimePtr(now.Add(-8 * 24 * time.Hour)),
|
||||
duration: RateLimitWindow7d,
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsWindowExpired(tt.start, tt.duration)
|
||||
if got != tt.want {
|
||||
t.Errorf("IsWindowExpired() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIKey_EffectiveUsage(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key APIKey
|
||||
want5h float64
|
||||
want1d float64
|
||||
want7d float64
|
||||
}{
|
||||
{
|
||||
name: "all windows active",
|
||||
key: APIKey{
|
||||
Usage5h: 5.0,
|
||||
Usage1d: 10.0,
|
||||
Usage7d: 50.0,
|
||||
Window5hStart: rateLimitTimePtr(now.Add(-1 * time.Hour)),
|
||||
Window1dStart: rateLimitTimePtr(now.Add(-12 * time.Hour)),
|
||||
Window7dStart: rateLimitTimePtr(now.Add(-3 * 24 * time.Hour)),
|
||||
},
|
||||
want5h: 5.0,
|
||||
want1d: 10.0,
|
||||
want7d: 50.0,
|
||||
},
|
||||
{
|
||||
name: "all windows expired",
|
||||
key: APIKey{
|
||||
Usage5h: 5.0,
|
||||
Usage1d: 10.0,
|
||||
Usage7d: 50.0,
|
||||
Window5hStart: rateLimitTimePtr(now.Add(-6 * time.Hour)),
|
||||
Window1dStart: rateLimitTimePtr(now.Add(-25 * time.Hour)),
|
||||
Window7dStart: rateLimitTimePtr(now.Add(-8 * 24 * time.Hour)),
|
||||
},
|
||||
want5h: 0,
|
||||
want1d: 0,
|
||||
want7d: 0,
|
||||
},
|
||||
{
|
||||
name: "nil window starts return raw usage",
|
||||
key: APIKey{
|
||||
Usage5h: 5.0,
|
||||
Usage1d: 10.0,
|
||||
Usage7d: 50.0,
|
||||
Window5hStart: nil,
|
||||
Window1dStart: nil,
|
||||
Window7dStart: nil,
|
||||
},
|
||||
want5h: 5.0,
|
||||
want1d: 10.0,
|
||||
want7d: 50.0,
|
||||
},
|
||||
{
|
||||
name: "mixed: 5h expired, 1d active, 7d nil",
|
||||
key: APIKey{
|
||||
Usage5h: 5.0,
|
||||
Usage1d: 10.0,
|
||||
Usage7d: 50.0,
|
||||
Window5hStart: rateLimitTimePtr(now.Add(-6 * time.Hour)),
|
||||
Window1dStart: rateLimitTimePtr(now.Add(-12 * time.Hour)),
|
||||
Window7dStart: nil,
|
||||
},
|
||||
want5h: 0,
|
||||
want1d: 10.0,
|
||||
want7d: 50.0,
|
||||
},
|
||||
{
|
||||
name: "zero usage with active windows",
|
||||
key: APIKey{
|
||||
Usage5h: 0,
|
||||
Usage1d: 0,
|
||||
Usage7d: 0,
|
||||
Window5hStart: rateLimitTimePtr(now.Add(-1 * time.Hour)),
|
||||
Window1dStart: rateLimitTimePtr(now.Add(-1 * time.Hour)),
|
||||
Window7dStart: rateLimitTimePtr(now.Add(-1 * time.Hour)),
|
||||
},
|
||||
want5h: 0,
|
||||
want1d: 0,
|
||||
want7d: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.key.EffectiveUsage5h(); got != tt.want5h {
|
||||
t.Errorf("EffectiveUsage5h() = %v, want %v", got, tt.want5h)
|
||||
}
|
||||
if got := tt.key.EffectiveUsage1d(); got != tt.want1d {
|
||||
t.Errorf("EffectiveUsage1d() = %v, want %v", got, tt.want1d)
|
||||
}
|
||||
if got := tt.key.EffectiveUsage7d(); got != tt.want7d {
|
||||
t.Errorf("EffectiveUsage7d() = %v, want %v", got, tt.want7d)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIKeyRateLimitData_EffectiveUsage(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
data APIKeyRateLimitData
|
||||
want5h float64
|
||||
want1d float64
|
||||
want7d float64
|
||||
}{
|
||||
{
|
||||
name: "all windows active",
|
||||
data: APIKeyRateLimitData{
|
||||
Usage5h: 3.0,
|
||||
Usage1d: 8.0,
|
||||
Usage7d: 40.0,
|
||||
Window5hStart: rateLimitTimePtr(now.Add(-2 * time.Hour)),
|
||||
Window1dStart: rateLimitTimePtr(now.Add(-10 * time.Hour)),
|
||||
Window7dStart: rateLimitTimePtr(now.Add(-2 * 24 * time.Hour)),
|
||||
},
|
||||
want5h: 3.0,
|
||||
want1d: 8.0,
|
||||
want7d: 40.0,
|
||||
},
|
||||
{
|
||||
name: "all windows expired",
|
||||
data: APIKeyRateLimitData{
|
||||
Usage5h: 3.0,
|
||||
Usage1d: 8.0,
|
||||
Usage7d: 40.0,
|
||||
Window5hStart: rateLimitTimePtr(now.Add(-10 * time.Hour)),
|
||||
Window1dStart: rateLimitTimePtr(now.Add(-48 * time.Hour)),
|
||||
Window7dStart: rateLimitTimePtr(now.Add(-10 * 24 * time.Hour)),
|
||||
},
|
||||
want5h: 0,
|
||||
want1d: 0,
|
||||
want7d: 0,
|
||||
},
|
||||
{
|
||||
name: "nil window starts return raw usage",
|
||||
data: APIKeyRateLimitData{
|
||||
Usage5h: 3.0,
|
||||
Usage1d: 8.0,
|
||||
Usage7d: 40.0,
|
||||
Window5hStart: nil,
|
||||
Window1dStart: nil,
|
||||
Window7dStart: nil,
|
||||
},
|
||||
want5h: 3.0,
|
||||
want1d: 8.0,
|
||||
want7d: 40.0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.data.EffectiveUsage5h(); got != tt.want5h {
|
||||
t.Errorf("EffectiveUsage5h() = %v, want %v", got, tt.want5h)
|
||||
}
|
||||
if got := tt.data.EffectiveUsage1d(); got != tt.want1d {
|
||||
t.Errorf("EffectiveUsage1d() = %v, want %v", got, tt.want1d)
|
||||
}
|
||||
if got := tt.data.EffectiveUsage7d(); got != tt.want7d {
|
||||
t.Errorf("EffectiveUsage7d() = %v, want %v", got, tt.want7d)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func rateLimitTimePtr(t time.Time) *time.Time {
|
||||
return &t
|
||||
}
|
||||
@@ -86,6 +86,30 @@ type APIKeyRateLimitData struct {
|
||||
Window7dStart *time.Time
|
||||
}
|
||||
|
||||
// EffectiveUsage5h returns the 5h window usage, or 0 if the window has expired.
|
||||
func (d *APIKeyRateLimitData) EffectiveUsage5h() float64 {
|
||||
if IsWindowExpired(d.Window5hStart, RateLimitWindow5h) {
|
||||
return 0
|
||||
}
|
||||
return d.Usage5h
|
||||
}
|
||||
|
||||
// EffectiveUsage1d returns the 1d window usage, or 0 if the window has expired.
|
||||
func (d *APIKeyRateLimitData) EffectiveUsage1d() float64 {
|
||||
if IsWindowExpired(d.Window1dStart, RateLimitWindow1d) {
|
||||
return 0
|
||||
}
|
||||
return d.Usage1d
|
||||
}
|
||||
|
||||
// EffectiveUsage7d returns the 7d window usage, or 0 if the window has expired.
|
||||
func (d *APIKeyRateLimitData) EffectiveUsage7d() float64 {
|
||||
if IsWindowExpired(d.Window7dStart, RateLimitWindow7d) {
|
||||
return 0
|
||||
}
|
||||
return d.Usage7d
|
||||
}
|
||||
|
||||
// APIKeyCache defines cache operations for API key service
|
||||
type APIKeyCache interface {
|
||||
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
|
||||
|
||||
@@ -565,15 +565,15 @@ func (s *BillingCacheService) evaluateRateLimits(ctx context.Context, apiKey *AP
|
||||
needsReset := false
|
||||
|
||||
// Reset expired windows in-memory for check purposes
|
||||
if w5h != nil && time.Since(*w5h) >= 5*time.Hour {
|
||||
if IsWindowExpired(w5h, RateLimitWindow5h) {
|
||||
usage5h = 0
|
||||
needsReset = true
|
||||
}
|
||||
if w1d != nil && time.Since(*w1d) >= 24*time.Hour {
|
||||
if IsWindowExpired(w1d, RateLimitWindow1d) {
|
||||
usage1d = 0
|
||||
needsReset = true
|
||||
}
|
||||
if w7d != nil && time.Since(*w7d) >= 7*24*time.Hour {
|
||||
if IsWindowExpired(w7d, RateLimitWindow7d) {
|
||||
usage7d = 0
|
||||
needsReset = true
|
||||
}
|
||||
@@ -589,12 +589,16 @@ func (s *BillingCacheService) evaluateRateLimits(ctx context.Context, apiKey *AP
|
||||
if loader, ok := s.apiKeyRateLimitLoader.(interface {
|
||||
ResetRateLimitWindows(ctx context.Context, id int64) error
|
||||
}); ok {
|
||||
_ = loader.ResetRateLimitWindows(resetCtx, keyID)
|
||||
if err := loader.ResetRateLimitWindows(resetCtx, keyID); err != nil {
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: reset rate limit windows failed for api key %d: %v", keyID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Invalidate cache so next request loads fresh data
|
||||
if s.cache != nil {
|
||||
_ = s.cache.InvalidateAPIKeyRateLimit(resetCtx, keyID)
|
||||
if err := s.cache.InvalidateAPIKeyRateLimit(resetCtx, keyID); err != nil {
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: invalidate rate limit cache failed for api key %d: %v", keyID, err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -43,15 +43,24 @@ type BillingCache interface {
|
||||
|
||||
// ModelPricing 模型价格配置(per-token价格,与LiteLLM格式一致)
|
||||
type ModelPricing struct {
|
||||
InputPricePerToken float64 // 每token输入价格 (USD)
|
||||
OutputPricePerToken float64 // 每token输出价格 (USD)
|
||||
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
|
||||
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
|
||||
CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD)
|
||||
CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD)
|
||||
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
|
||||
InputPricePerToken float64 // 每token输入价格 (USD)
|
||||
OutputPricePerToken float64 // 每token输出价格 (USD)
|
||||
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
|
||||
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
|
||||
CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD)
|
||||
CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD)
|
||||
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
|
||||
LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格
|
||||
LongContextInputMultiplier float64 // 长上下文整次会话输入倍率
|
||||
LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率
|
||||
}
|
||||
|
||||
const (
|
||||
openAIGPT54LongContextInputThreshold = 272000
|
||||
openAIGPT54LongContextInputMultiplier = 2.0
|
||||
openAIGPT54LongContextOutputMultiplier = 1.5
|
||||
)
|
||||
|
||||
// UsageTokens 使用的token数量
|
||||
type UsageTokens struct {
|
||||
InputTokens int
|
||||
@@ -161,6 +170,35 @@ func (s *BillingService) initFallbackPricing() {
|
||||
CacheReadPricePerToken: 0.2e-6, // $0.20 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
|
||||
// OpenAI GPT-5.1(本地兜底,防止动态定价不可用时拒绝计费)
|
||||
s.fallbackPrices["gpt-5.1"] = &ModelPricing{
|
||||
InputPricePerToken: 1.25e-6, // $1.25 per MTok
|
||||
OutputPricePerToken: 10e-6, // $10 per MTok
|
||||
CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok
|
||||
CacheReadPricePerToken: 0.125e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
// OpenAI GPT-5.4(业务指定价格)
|
||||
s.fallbackPrices["gpt-5.4"] = &ModelPricing{
|
||||
InputPricePerToken: 2.5e-6, // $2.5 per MTok
|
||||
OutputPricePerToken: 15e-6, // $15 per MTok
|
||||
CacheCreationPricePerToken: 2.5e-6, // $2.5 per MTok
|
||||
CacheReadPricePerToken: 0.25e-6, // $0.25 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
LongContextInputThreshold: openAIGPT54LongContextInputThreshold,
|
||||
LongContextInputMultiplier: openAIGPT54LongContextInputMultiplier,
|
||||
LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier,
|
||||
}
|
||||
// Codex 族兜底统一按 GPT-5.1 Codex 价格计费
|
||||
s.fallbackPrices["gpt-5.1-codex"] = &ModelPricing{
|
||||
InputPricePerToken: 1.5e-6, // $1.5 per MTok
|
||||
OutputPricePerToken: 12e-6, // $12 per MTok
|
||||
CacheCreationPricePerToken: 1.5e-6, // $1.5 per MTok
|
||||
CacheReadPricePerToken: 0.15e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
s.fallbackPrices["gpt-5.3-codex"] = s.fallbackPrices["gpt-5.1-codex"]
|
||||
}
|
||||
|
||||
// getFallbackPricing 根据模型系列获取回退价格
|
||||
@@ -189,12 +227,30 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
|
||||
}
|
||||
return s.fallbackPrices["claude-3-haiku"]
|
||||
}
|
||||
// Claude 未知型号统一回退到 Sonnet,避免计费中断。
|
||||
if strings.Contains(modelLower, "claude") {
|
||||
return s.fallbackPrices["claude-sonnet-4"]
|
||||
}
|
||||
if strings.Contains(modelLower, "gemini-3.1-pro") || strings.Contains(modelLower, "gemini-3-1-pro") {
|
||||
return s.fallbackPrices["gemini-3.1-pro"]
|
||||
}
|
||||
|
||||
// 默认使用Sonnet价格
|
||||
return s.fallbackPrices["claude-sonnet-4"]
|
||||
// OpenAI 仅匹配已知 GPT-5/Codex 族,避免未知 OpenAI 型号误计价。
|
||||
if strings.Contains(modelLower, "gpt-5") || strings.Contains(modelLower, "codex") {
|
||||
normalized := normalizeCodexModel(modelLower)
|
||||
switch normalized {
|
||||
case "gpt-5.4":
|
||||
return s.fallbackPrices["gpt-5.4"]
|
||||
case "gpt-5.3-codex":
|
||||
return s.fallbackPrices["gpt-5.3-codex"]
|
||||
case "gpt-5.1-codex", "gpt-5.1-codex-max", "gpt-5.1-codex-mini", "codex-mini-latest":
|
||||
return s.fallbackPrices["gpt-5.1-codex"]
|
||||
case "gpt-5.1":
|
||||
return s.fallbackPrices["gpt-5.1"]
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetModelPricing 获取模型价格配置
|
||||
@@ -212,15 +268,18 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
|
||||
price5m := litellmPricing.CacheCreationInputTokenCost
|
||||
price1h := litellmPricing.CacheCreationInputTokenCostAbove1hr
|
||||
enableBreakdown := price1h > 0 && price1h > price5m
|
||||
return &ModelPricing{
|
||||
InputPricePerToken: litellmPricing.InputCostPerToken,
|
||||
OutputPricePerToken: litellmPricing.OutputCostPerToken,
|
||||
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
|
||||
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
|
||||
CacheCreation5mPrice: price5m,
|
||||
CacheCreation1hPrice: price1h,
|
||||
SupportsCacheBreakdown: enableBreakdown,
|
||||
}, nil
|
||||
return s.applyModelSpecificPricingPolicy(model, &ModelPricing{
|
||||
InputPricePerToken: litellmPricing.InputCostPerToken,
|
||||
OutputPricePerToken: litellmPricing.OutputCostPerToken,
|
||||
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
|
||||
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
|
||||
CacheCreation5mPrice: price5m,
|
||||
CacheCreation1hPrice: price1h,
|
||||
SupportsCacheBreakdown: enableBreakdown,
|
||||
LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold,
|
||||
LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier,
|
||||
LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier,
|
||||
}), nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -228,7 +287,7 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
|
||||
fallback := s.getFallbackPricing(model)
|
||||
if fallback != nil {
|
||||
log.Printf("[Billing] Using fallback pricing for model: %s", model)
|
||||
return fallback, nil
|
||||
return s.applyModelSpecificPricingPolicy(model, fallback), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("pricing not found for model: %s", model)
|
||||
@@ -242,12 +301,18 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul
|
||||
}
|
||||
|
||||
breakdown := &CostBreakdown{}
|
||||
inputPricePerToken := pricing.InputPricePerToken
|
||||
outputPricePerToken := pricing.OutputPricePerToken
|
||||
if s.shouldApplySessionLongContextPricing(tokens, pricing) {
|
||||
inputPricePerToken *= pricing.LongContextInputMultiplier
|
||||
outputPricePerToken *= pricing.LongContextOutputMultiplier
|
||||
}
|
||||
|
||||
// 计算输入token费用(使用per-token价格)
|
||||
breakdown.InputCost = float64(tokens.InputTokens) * pricing.InputPricePerToken
|
||||
breakdown.InputCost = float64(tokens.InputTokens) * inputPricePerToken
|
||||
|
||||
// 计算输出token费用
|
||||
breakdown.OutputCost = float64(tokens.OutputTokens) * pricing.OutputPricePerToken
|
||||
breakdown.OutputCost = float64(tokens.OutputTokens) * outputPricePerToken
|
||||
|
||||
// 计算缓存费用
|
||||
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
|
||||
@@ -279,6 +344,45 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul
|
||||
return breakdown, nil
|
||||
}
|
||||
|
||||
func (s *BillingService) applyModelSpecificPricingPolicy(model string, pricing *ModelPricing) *ModelPricing {
|
||||
if pricing == nil {
|
||||
return nil
|
||||
}
|
||||
if !isOpenAIGPT54Model(model) {
|
||||
return pricing
|
||||
}
|
||||
if pricing.LongContextInputThreshold > 0 && pricing.LongContextInputMultiplier > 0 && pricing.LongContextOutputMultiplier > 0 {
|
||||
return pricing
|
||||
}
|
||||
cloned := *pricing
|
||||
if cloned.LongContextInputThreshold <= 0 {
|
||||
cloned.LongContextInputThreshold = openAIGPT54LongContextInputThreshold
|
||||
}
|
||||
if cloned.LongContextInputMultiplier <= 0 {
|
||||
cloned.LongContextInputMultiplier = openAIGPT54LongContextInputMultiplier
|
||||
}
|
||||
if cloned.LongContextOutputMultiplier <= 0 {
|
||||
cloned.LongContextOutputMultiplier = openAIGPT54LongContextOutputMultiplier
|
||||
}
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func (s *BillingService) shouldApplySessionLongContextPricing(tokens UsageTokens, pricing *ModelPricing) bool {
|
||||
if pricing == nil || pricing.LongContextInputThreshold <= 0 {
|
||||
return false
|
||||
}
|
||||
if pricing.LongContextInputMultiplier <= 1 && pricing.LongContextOutputMultiplier <= 1 {
|
||||
return false
|
||||
}
|
||||
totalInputTokens := tokens.InputTokens + tokens.CacheReadTokens
|
||||
return totalInputTokens > pricing.LongContextInputThreshold
|
||||
}
|
||||
|
||||
func isOpenAIGPT54Model(model string) bool {
|
||||
normalized := normalizeCodexModel(strings.TrimSpace(strings.ToLower(model)))
|
||||
return normalized == "gpt-5.4"
|
||||
}
|
||||
|
||||
// CalculateCostWithConfig 使用配置中的默认倍率计算费用
|
||||
func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageTokens) (*CostBreakdown, error) {
|
||||
multiplier := s.cfg.Default.RateMultiplier
|
||||
|
||||
@@ -133,7 +133,7 @@ func TestGetModelPricing_CaseInsensitive(t *testing.T) {
|
||||
require.Equal(t, p1.InputPricePerToken, p2.InputPricePerToken)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_UnknownModelFallsBackToSonnet(t *testing.T) {
|
||||
func TestGetModelPricing_UnknownClaudeModelFallsBackToSonnet(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
// 不包含 opus/sonnet/haiku 关键词的 Claude 模型会走默认 Sonnet 价格
|
||||
@@ -142,6 +142,93 @@ func TestGetModelPricing_UnknownModelFallsBackToSonnet(t *testing.T) {
|
||||
require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_UnknownOpenAIModelReturnsError(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
pricing, err := svc.GetModelPricing("gpt-unknown-model")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, pricing)
|
||||
require.Contains(t, err.Error(), "pricing not found")
|
||||
}
|
||||
|
||||
func TestGetModelPricing_OpenAIGPT51Fallback(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
pricing, err := svc.GetModelPricing("gpt-5.1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pricing)
|
||||
require.InDelta(t, 1.25e-6, pricing.InputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_OpenAIGPT54Fallback(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
pricing, err := svc.GetModelPricing("gpt-5.4")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pricing)
|
||||
require.InDelta(t, 2.5e-6, pricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 15e-6, pricing.OutputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 0.25e-6, pricing.CacheReadPricePerToken, 1e-12)
|
||||
require.Equal(t, 272000, pricing.LongContextInputThreshold)
|
||||
require.InDelta(t, 2.0, pricing.LongContextInputMultiplier, 1e-12)
|
||||
require.InDelta(t, 1.5, pricing.LongContextOutputMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateCost_OpenAIGPT54LongContextAppliesWholeSessionMultipliers(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 300000,
|
||||
OutputTokens: 4000,
|
||||
}
|
||||
|
||||
cost, err := svc.CalculateCost("gpt-5.4-2026-03-05", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedInput := float64(tokens.InputTokens) * 2.5e-6 * 2.0
|
||||
expectedOutput := float64(tokens.OutputTokens) * 15e-6 * 1.5
|
||||
require.InDelta(t, expectedInput, cost.InputCost, 1e-10)
|
||||
require.InDelta(t, expectedOutput, cost.OutputCost, 1e-10)
|
||||
require.InDelta(t, expectedInput+expectedOutput, cost.TotalCost, 1e-10)
|
||||
require.InDelta(t, expectedInput+expectedOutput, cost.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestGetFallbackPricing_FamilyMatching(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
expectedInput float64
|
||||
expectNilPricing bool
|
||||
}{
|
||||
{name: "empty model", model: " ", expectNilPricing: true},
|
||||
{name: "claude opus 4.6", model: "claude-opus-4.6-20260201", expectedInput: 5e-6},
|
||||
{name: "claude opus 4.5 alt separator", model: "claude-opus-4-5-20260101", expectedInput: 5e-6},
|
||||
{name: "claude generic model fallback sonnet", model: "claude-foo-bar", expectedInput: 3e-6},
|
||||
{name: "gemini explicit fallback", model: "gemini-3-1-pro", expectedInput: 2e-6},
|
||||
{name: "gemini unknown no fallback", model: "gemini-2.0-pro", expectNilPricing: true},
|
||||
{name: "openai gpt5.1", model: "gpt-5.1", expectedInput: 1.25e-6},
|
||||
{name: "openai gpt5.4", model: "gpt-5.4", expectedInput: 2.5e-6},
|
||||
{name: "openai gpt5.3 codex", model: "gpt-5.3-codex", expectedInput: 1.5e-6},
|
||||
{name: "openai gpt5.1 codex max alias", model: "gpt-5.1-codex-max", expectedInput: 1.5e-6},
|
||||
{name: "openai codex mini latest alias", model: "codex-mini-latest", expectedInput: 1.5e-6},
|
||||
{name: "openai unknown no fallback", model: "gpt-unknown-model", expectNilPricing: true},
|
||||
{name: "non supported family", model: "qwen-max", expectNilPricing: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pricing := svc.getFallbackPricing(tt.model)
|
||||
if tt.expectNilPricing {
|
||||
require.Nil(t, pricing)
|
||||
return
|
||||
}
|
||||
require.NotNil(t, pricing)
|
||||
require.InDelta(t, tt.expectedInput, pricing.InputPricePerToken, 1e-12)
|
||||
})
|
||||
}
|
||||
}
|
||||
func TestCalculateCostWithLongContext_BelowThreshold(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
|
||||
@@ -88,6 +88,49 @@ func TestCheckErrorPolicy(t *testing.T) {
|
||||
body: []byte(`overloaded service`),
|
||||
expected: ErrorPolicyTempUnscheduled,
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_401_first_hit_returns_temp_unscheduled",
|
||||
account: &Account{
|
||||
ID: 14,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(401),
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
statusCode: 401,
|
||||
body: []byte(`unauthorized`),
|
||||
expected: ErrorPolicyTempUnscheduled,
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_401_second_hit_upgrades_to_none",
|
||||
account: &Account{
|
||||
ID: 15,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(401),
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
statusCode: 401,
|
||||
body: []byte(`unauthorized`),
|
||||
expected: ErrorPolicyNone,
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_body_miss_returns_none",
|
||||
account: &Account{
|
||||
|
||||
@@ -171,8 +171,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.Stream)
|
||||
|
||||
require.Equal(t, body, upstream.lastBody, "透传模式不应改写上游请求体")
|
||||
require.Equal(t, "claude-3-7-sonnet-20250219", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||
require.Equal(t, "claude-3-haiku-20240307", gjson.GetBytes(upstream.lastBody, "model").String(), "透传模式应应用账号级模型映射")
|
||||
|
||||
require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key"))
|
||||
require.Empty(t, upstream.lastReq.Header.Get("authorization"))
|
||||
@@ -190,7 +189,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
|
||||
require.True(t, ok)
|
||||
bodyBytes, ok := rawBody.([]byte)
|
||||
require.True(t, ok, "应以 []byte 形式缓存上游请求体,避免重复 string 拷贝")
|
||||
require.Equal(t, body, bodyBytes)
|
||||
require.Equal(t, "claude-3-haiku-20240307", gjson.GetBytes(bodyBytes, "model").String(), "缓存的上游请求体应包含映射后的模型")
|
||||
}
|
||||
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBody(t *testing.T) {
|
||||
@@ -253,8 +252,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
|
||||
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, body, upstream.lastBody, "count_tokens 透传模式不应改写请求体")
|
||||
require.Equal(t, "claude-3-5-sonnet-latest", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||
require.Equal(t, "claude-3-opus-20240229", gjson.GetBytes(upstream.lastBody, "model").String(), "count_tokens 透传模式应应用账号级模型映射")
|
||||
require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key"))
|
||||
require.Empty(t, upstream.lastReq.Header.Get("authorization"))
|
||||
require.Empty(t, upstream.lastReq.Header.Get("cookie"))
|
||||
@@ -263,6 +261,273 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
|
||||
require.Empty(t, rec.Header().Get("Set-Cookie"))
|
||||
}
|
||||
|
||||
// TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingEdgeCases 覆盖透传模式下模型映射的各种边界情况
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingEdgeCases(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
modelMapping map[string]any // nil = 不配置映射
|
||||
expectedModel string
|
||||
endpoint string // "messages" or "count_tokens"
|
||||
}{
|
||||
{
|
||||
name: "Forward: 无映射配置时不改写模型",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
modelMapping: nil,
|
||||
expectedModel: "claude-sonnet-4-20250514",
|
||||
endpoint: "messages",
|
||||
},
|
||||
{
|
||||
name: "Forward: 空映射配置时不改写模型",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
modelMapping: map[string]any{},
|
||||
expectedModel: "claude-sonnet-4-20250514",
|
||||
endpoint: "messages",
|
||||
},
|
||||
{
|
||||
name: "Forward: 模型不在映射表中时不改写",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
modelMapping: map[string]any{"claude-3-haiku-20240307": "claude-3-opus-20240229"},
|
||||
expectedModel: "claude-sonnet-4-20250514",
|
||||
endpoint: "messages",
|
||||
},
|
||||
{
|
||||
name: "Forward: 精确匹配映射应改写模型",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
modelMapping: map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"},
|
||||
expectedModel: "claude-sonnet-4-5-20241022",
|
||||
endpoint: "messages",
|
||||
},
|
||||
{
|
||||
name: "Forward: 通配符映射应改写模型",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
modelMapping: map[string]any{"claude-sonnet-4-*": "claude-sonnet-4-5-20241022"},
|
||||
expectedModel: "claude-sonnet-4-5-20241022",
|
||||
endpoint: "messages",
|
||||
},
|
||||
{
|
||||
name: "CountTokens: 无映射配置时不改写模型",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
modelMapping: nil,
|
||||
expectedModel: "claude-sonnet-4-20250514",
|
||||
endpoint: "count_tokens",
|
||||
},
|
||||
{
|
||||
name: "CountTokens: 模型不在映射表中时不改写",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
modelMapping: map[string]any{"claude-3-haiku-20240307": "claude-3-opus-20240229"},
|
||||
expectedModel: "claude-sonnet-4-20250514",
|
||||
endpoint: "count_tokens",
|
||||
},
|
||||
{
|
||||
name: "CountTokens: 精确匹配映射应改写模型",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
modelMapping: map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"},
|
||||
expectedModel: "claude-sonnet-4-5-20241022",
|
||||
endpoint: "count_tokens",
|
||||
},
|
||||
{
|
||||
name: "CountTokens: 通配符映射应改写模型",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
modelMapping: map[string]any{"claude-sonnet-4-*": "claude-sonnet-4-5-20241022"},
|
||||
expectedModel: "claude-sonnet-4-5-20241022",
|
||||
endpoint: "count_tokens",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
body := []byte(`{"model":"` + tt.model + `","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
|
||||
parsed := &ParsedRequest{
|
||||
Body: body,
|
||||
Model: tt.model,
|
||||
}
|
||||
|
||||
credentials := map[string]any{
|
||||
"api_key": "upstream-key",
|
||||
"base_url": "https://api.anthropic.com",
|
||||
}
|
||||
if tt.modelMapping != nil {
|
||||
credentials["model_mapping"] = tt.modelMapping
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 300,
|
||||
Name: "edge-case-test",
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: credentials,
|
||||
Extra: map[string]any{"anthropic_passthrough": true},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
|
||||
if tt.endpoint == "messages" {
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
parsed.Stream = false
|
||||
|
||||
upstreamJSON := `{"id":"msg_1","type":"message","usage":{"input_tokens":5,"output_tokens":3}}`
|
||||
upstream := &anthropicHTTPUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamJSON)),
|
||||
},
|
||||
}
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{},
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, parsed)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, tt.expectedModel, gjson.GetBytes(upstream.lastBody, "model").String(),
|
||||
"Forward 上游请求体中的模型应为: %s", tt.expectedModel)
|
||||
} else {
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
|
||||
|
||||
upstreamRespBody := `{"input_tokens":42}`
|
||||
upstream := &anthropicHTTPUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamRespBody)),
|
||||
},
|
||||
}
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}},
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
|
||||
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.expectedModel, gjson.GetBytes(upstream.lastBody, "model").String(),
|
||||
"CountTokens 上游请求体中的模型应为: %s", tt.expectedModel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingPreservesOtherFields
|
||||
// 确保模型映射只替换 model 字段,不影响请求体中的其他字段
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingPreservesOtherFields(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
|
||||
|
||||
// 包含复杂字段的请求体:system、thinking、messages
|
||||
body := []byte(`{"model":"claude-sonnet-4-20250514","system":[{"type":"text","text":"You are a helpful assistant."}],"messages":[{"role":"user","content":[{"type":"text","text":"hello world"}]}],"thinking":{"type":"enabled","budget_tokens":5000},"max_tokens":1024}`)
|
||||
parsed := &ParsedRequest{
|
||||
Body: body,
|
||||
Model: "claude-sonnet-4-20250514",
|
||||
}
|
||||
|
||||
upstreamRespBody := `{"input_tokens":42}`
|
||||
upstream := &anthropicHTTPUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamRespBody)),
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}},
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 301,
|
||||
Name: "preserve-fields-test",
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "upstream-key",
|
||||
"base_url": "https://api.anthropic.com",
|
||||
"model_mapping": map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"},
|
||||
},
|
||||
Extra: map[string]any{"anthropic_passthrough": true},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
|
||||
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
|
||||
require.NoError(t, err)
|
||||
|
||||
sentBody := upstream.lastBody
|
||||
require.Equal(t, "claude-sonnet-4-5-20241022", gjson.GetBytes(sentBody, "model").String(), "model 应被映射")
|
||||
require.Equal(t, "You are a helpful assistant.", gjson.GetBytes(sentBody, "system.0.text").String(), "system 字段不应被修改")
|
||||
require.Equal(t, "hello world", gjson.GetBytes(sentBody, "messages.0.content.0.text").String(), "messages 字段不应被修改")
|
||||
require.Equal(t, "enabled", gjson.GetBytes(sentBody, "thinking.type").String(), "thinking 字段不应被修改")
|
||||
require.Equal(t, int64(5000), gjson.GetBytes(sentBody, "thinking.budget_tokens").Int(), "thinking.budget_tokens 不应被修改")
|
||||
require.Equal(t, int64(1024), gjson.GetBytes(sentBody, "max_tokens").Int(), "max_tokens 不应被修改")
|
||||
}
|
||||
|
||||
// TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping
|
||||
// 确保空模型名不会触发映射逻辑
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
|
||||
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`)
|
||||
parsed := &ParsedRequest{
|
||||
Body: body,
|
||||
Model: "", // 空模型
|
||||
}
|
||||
|
||||
upstreamRespBody := `{"input_tokens":10}`
|
||||
upstream := &anthropicHTTPUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamRespBody)),
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}},
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 302,
|
||||
Name: "empty-model-test",
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "upstream-key",
|
||||
"base_url": "https://api.anthropic.com",
|
||||
"model_mapping": map[string]any{"*": "claude-3-opus-20240229"},
|
||||
},
|
||||
Extra: map[string]any{"anthropic_passthrough": true},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
|
||||
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
|
||||
require.NoError(t, err)
|
||||
// 空模型名时,body 应原样透传,不应触发映射
|
||||
require.Equal(t, body, upstream.lastBody, "空模型名时请求体不应被修改")
|
||||
}
|
||||
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -187,6 +187,14 @@ func (m *mockAccountRepoForPlatform) BulkUpdate(ctx context.Context, ids []int64
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForPlatform) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForPlatform) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Verify interface implementation
|
||||
var _ AccountRepository = (*mockAccountRepoForPlatform)(nil)
|
||||
|
||||
|
||||
@@ -501,33 +501,34 @@ func (s *GatewayService) TempUnscheduleRetryableError(ctx context.Context, accou
|
||||
|
||||
// GatewayService handles API gateway operations
|
||||
type GatewayService struct {
|
||||
accountRepo AccountRepository
|
||||
groupRepo GroupRepository
|
||||
usageLogRepo UsageLogRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
cache GatewayCache
|
||||
digestStore *DigestSessionStore
|
||||
cfg *config.Config
|
||||
schedulerSnapshot *SchedulerSnapshotService
|
||||
billingService *BillingService
|
||||
rateLimitService *RateLimitService
|
||||
billingCacheService *BillingCacheService
|
||||
identityService *IdentityService
|
||||
httpUpstream HTTPUpstream
|
||||
deferredService *DeferredService
|
||||
concurrencyService *ConcurrencyService
|
||||
claudeTokenProvider *ClaudeTokenProvider
|
||||
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
|
||||
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken)
|
||||
userGroupRateCache *gocache.Cache
|
||||
userGroupRateSF singleflight.Group
|
||||
modelsListCache *gocache.Cache
|
||||
modelsListCacheTTL time.Duration
|
||||
responseHeaderFilter *responseheaders.CompiledHeaderFilter
|
||||
debugModelRouting atomic.Bool
|
||||
debugClaudeMimic atomic.Bool
|
||||
accountRepo AccountRepository
|
||||
groupRepo GroupRepository
|
||||
usageLogRepo UsageLogRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
cache GatewayCache
|
||||
digestStore *DigestSessionStore
|
||||
cfg *config.Config
|
||||
schedulerSnapshot *SchedulerSnapshotService
|
||||
billingService *BillingService
|
||||
rateLimitService *RateLimitService
|
||||
billingCacheService *BillingCacheService
|
||||
identityService *IdentityService
|
||||
httpUpstream HTTPUpstream
|
||||
deferredService *DeferredService
|
||||
concurrencyService *ConcurrencyService
|
||||
claudeTokenProvider *ClaudeTokenProvider
|
||||
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
|
||||
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken)
|
||||
userGroupRateResolver *userGroupRateResolver
|
||||
userGroupRateCache *gocache.Cache
|
||||
userGroupRateSF singleflight.Group
|
||||
modelsListCache *gocache.Cache
|
||||
modelsListCacheTTL time.Duration
|
||||
responseHeaderFilter *responseheaders.CompiledHeaderFilter
|
||||
debugModelRouting atomic.Bool
|
||||
debugClaudeMimic atomic.Bool
|
||||
}
|
||||
|
||||
// NewGatewayService creates a new GatewayService
|
||||
@@ -582,6 +583,13 @@ func NewGatewayService(
|
||||
modelsListCacheTTL: modelsListTTL,
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
}
|
||||
svc.userGroupRateResolver = newUserGroupRateResolver(
|
||||
userGroupRateRepo,
|
||||
svc.userGroupRateCache,
|
||||
userGroupRateTTL,
|
||||
&svc.userGroupRateSF,
|
||||
"service.gateway",
|
||||
)
|
||||
svc.debugModelRouting.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING")))
|
||||
svc.debugClaudeMimic.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC")))
|
||||
return svc
|
||||
@@ -1228,6 +1236,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
modelScopeSkippedIDs = append(modelScopeSkippedIDs, account.ID)
|
||||
continue
|
||||
}
|
||||
// 配额检查
|
||||
if !s.isAccountSchedulableForQuota(account) {
|
||||
continue
|
||||
}
|
||||
// 窗口费用检查(非粘性会话路径)
|
||||
if !s.isAccountSchedulableForWindowCost(ctx, account, false) {
|
||||
filteredWindowCost++
|
||||
@@ -1260,6 +1272,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
|
||||
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) &&
|
||||
s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) &&
|
||||
s.isAccountSchedulableForQuota(stickyAccount) &&
|
||||
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) &&
|
||||
|
||||
s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查
|
||||
@@ -1311,7 +1324,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
for _, acc := range routingCandidates {
|
||||
routingLoads = append(routingLoads, AccountWithConcurrency{
|
||||
ID: acc.ID,
|
||||
MaxConcurrency: acc.Concurrency,
|
||||
MaxConcurrency: acc.EffectiveLoadFactor(),
|
||||
})
|
||||
}
|
||||
routingLoadMap, _ := s.concurrencyService.GetAccountsLoadBatch(ctx, routingLoads)
|
||||
@@ -1416,6 +1429,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
|
||||
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) &&
|
||||
s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) &&
|
||||
s.isAccountSchedulableForQuota(account) &&
|
||||
s.isAccountSchedulableForWindowCost(ctx, account, true) &&
|
||||
|
||||
s.isAccountSchedulableForRPM(ctx, account, true) { // 粘性会话窗口费用+RPM 检查
|
||||
@@ -1480,6 +1494,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
// 配额检查
|
||||
if !s.isAccountSchedulableForQuota(acc) {
|
||||
continue
|
||||
}
|
||||
// 窗口费用检查(非粘性会话路径)
|
||||
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
|
||||
continue
|
||||
@@ -1499,7 +1517,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
for _, acc := range candidates {
|
||||
accountLoads = append(accountLoads, AccountWithConcurrency{
|
||||
ID: acc.ID,
|
||||
MaxConcurrency: acc.Concurrency,
|
||||
MaxConcurrency: acc.EffectiveLoadFactor(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2113,6 +2131,15 @@ func (s *GatewayService) withWindowCostPrefetch(ctx context.Context, accounts []
|
||||
return context.WithValue(ctx, windowCostPrefetchContextKey, costs)
|
||||
}
|
||||
|
||||
// isAccountSchedulableForQuota 检查 API Key 账号是否在配额限制内
|
||||
// 仅适用于配置了 quota_limit 的 apikey 类型账号
|
||||
func (s *GatewayService) isAccountSchedulableForQuota(account *Account) bool {
|
||||
if account.Type != AccountTypeAPIKey {
|
||||
return true
|
||||
}
|
||||
return !account.IsQuotaExceeded()
|
||||
}
|
||||
|
||||
// isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度
|
||||
// 仅适用于 Anthropic OAuth/SetupToken 账号
|
||||
// 返回 true 表示可调度,false 表示不可调度
|
||||
@@ -2590,7 +2617,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if s.debugModelRoutingEnabled() {
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
}
|
||||
@@ -2644,6 +2671,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForQuota(acc) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
|
||||
continue
|
||||
}
|
||||
@@ -2700,7 +2730,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
@@ -2743,6 +2773,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForQuota(acc) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
|
||||
continue
|
||||
}
|
||||
@@ -2818,7 +2851,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
if s.debugModelRoutingEnabled() {
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
@@ -2874,6 +2907,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForQuota(acc) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
|
||||
continue
|
||||
}
|
||||
@@ -2930,7 +2966,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
return account, nil
|
||||
}
|
||||
@@ -2975,6 +3011,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForQuota(acc) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
|
||||
continue
|
||||
}
|
||||
@@ -3289,6 +3328,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
|
||||
if account.Platform == PlatformSora {
|
||||
return s.isSoraModelSupportedByAccount(account, requestedModel)
|
||||
}
|
||||
// OpenAI 透传模式:仅替换认证,允许所有模型
|
||||
if account.Platform == PlatformOpenAI && account.IsOpenAIPassthroughEnabled() {
|
||||
return true
|
||||
}
|
||||
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
|
||||
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
requestedModel = claude.NormalizeModelID(requestedModel)
|
||||
@@ -3889,7 +3932,16 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}
|
||||
|
||||
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
|
||||
return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body, parsed.Model, parsed.Stream, startTime)
|
||||
passthroughBody := parsed.Body
|
||||
passthroughModel := parsed.Model
|
||||
if passthroughModel != "" {
|
||||
if mappedModel := account.GetMappedModel(passthroughModel); mappedModel != passthroughModel {
|
||||
passthroughBody = s.replaceModelInBody(passthroughBody, mappedModel)
|
||||
logger.LegacyPrintf("service.gateway", "Passthrough model mapping: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name)
|
||||
passthroughModel = mappedModel
|
||||
}
|
||||
}
|
||||
return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody, passthroughModel, parsed.Stream, startTime)
|
||||
}
|
||||
|
||||
body := parsed.Body
|
||||
@@ -4574,7 +4626,7 @@ func (s *GatewayService) buildUpstreamRequestAnthropicAPIKeyPassthrough(
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetURL = validatedURL + "/v1/messages"
|
||||
targetURL = validatedURL + "/v1/messages?beta=true"
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
|
||||
@@ -4954,7 +5006,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetURL = validatedURL + "/v1/messages"
|
||||
targetURL = validatedURL + "/v1/messages?beta=true"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5377,6 +5429,11 @@ func extractUpstreamErrorMessage(body []byte) string {
|
||||
return m
|
||||
}
|
||||
|
||||
// ChatGPT 内部 API 风格:{"detail":"..."}
|
||||
if d := gjson.GetBytes(body, "detail").String(); strings.TrimSpace(d) != "" {
|
||||
return d
|
||||
}
|
||||
|
||||
// 兜底:尝试顶层 message
|
||||
return gjson.GetBytes(body, "message").String()
|
||||
}
|
||||
@@ -6292,63 +6349,20 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
|
||||
}
|
||||
|
||||
func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 {
|
||||
if s == nil || userID <= 0 || groupID <= 0 {
|
||||
if s == nil {
|
||||
return groupDefaultMultiplier
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%d:%d", userID, groupID)
|
||||
if s.userGroupRateCache != nil {
|
||||
if cached, ok := s.userGroupRateCache.Get(key); ok {
|
||||
if multiplier, castOK := cached.(float64); castOK {
|
||||
userGroupRateCacheHitTotal.Add(1)
|
||||
return multiplier
|
||||
}
|
||||
}
|
||||
resolver := s.userGroupRateResolver
|
||||
if resolver == nil {
|
||||
resolver = newUserGroupRateResolver(
|
||||
s.userGroupRateRepo,
|
||||
s.userGroupRateCache,
|
||||
resolveUserGroupRateCacheTTL(s.cfg),
|
||||
&s.userGroupRateSF,
|
||||
"service.gateway",
|
||||
)
|
||||
}
|
||||
if s.userGroupRateRepo == nil {
|
||||
return groupDefaultMultiplier
|
||||
}
|
||||
userGroupRateCacheMissTotal.Add(1)
|
||||
|
||||
value, err, shared := s.userGroupRateSF.Do(key, func() (any, error) {
|
||||
if s.userGroupRateCache != nil {
|
||||
if cached, ok := s.userGroupRateCache.Get(key); ok {
|
||||
if multiplier, castOK := cached.(float64); castOK {
|
||||
userGroupRateCacheHitTotal.Add(1)
|
||||
return multiplier, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
userGroupRateCacheLoadTotal.Add(1)
|
||||
userRate, repoErr := s.userGroupRateRepo.GetByUserAndGroup(ctx, userID, groupID)
|
||||
if repoErr != nil {
|
||||
return nil, repoErr
|
||||
}
|
||||
multiplier := groupDefaultMultiplier
|
||||
if userRate != nil {
|
||||
multiplier = *userRate
|
||||
}
|
||||
if s.userGroupRateCache != nil {
|
||||
s.userGroupRateCache.Set(key, multiplier, resolveUserGroupRateCacheTTL(s.cfg))
|
||||
}
|
||||
return multiplier, nil
|
||||
})
|
||||
if shared {
|
||||
userGroupRateCacheSFSharedTotal.Add(1)
|
||||
}
|
||||
if err != nil {
|
||||
userGroupRateCacheFallbackTotal.Add(1)
|
||||
logger.LegacyPrintf("service.gateway", "get user group rate failed, fallback to group default: user=%d group=%d err=%v", userID, groupID, err)
|
||||
return groupDefaultMultiplier
|
||||
}
|
||||
|
||||
multiplier, ok := value.(float64)
|
||||
if !ok {
|
||||
userGroupRateCacheFallbackTotal.Add(1)
|
||||
return groupDefaultMultiplier
|
||||
}
|
||||
return multiplier
|
||||
return resolver.Resolve(ctx, userID, groupID, groupDefaultMultiplier)
|
||||
}
|
||||
|
||||
// RecordUsageInput 记录使用量的输入参数
|
||||
@@ -6370,6 +6384,89 @@ type APIKeyQuotaUpdater interface {
|
||||
UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error
|
||||
}
|
||||
|
||||
// postUsageBillingParams 统一扣费所需的参数
|
||||
type postUsageBillingParams struct {
|
||||
Cost *CostBreakdown
|
||||
User *User
|
||||
APIKey *APIKey
|
||||
Account *Account
|
||||
Subscription *UserSubscription
|
||||
IsSubscriptionBill bool
|
||||
AccountRateMultiplier float64
|
||||
APIKeyService APIKeyQuotaUpdater
|
||||
}
|
||||
|
||||
// postUsageBilling 统一处理使用量记录后的扣费逻辑:
|
||||
// - 订阅/余额扣费
|
||||
// - API Key 配额更新
|
||||
// - API Key 限速用量更新
|
||||
// - 账号配额用量更新(账号口径:TotalCost × 账号计费倍率)
|
||||
func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) {
|
||||
cost := p.Cost
|
||||
|
||||
// 1. 订阅 / 余额扣费
|
||||
if p.IsSubscriptionBill {
|
||||
if cost.TotalCost > 0 {
|
||||
if err := deps.userSubRepo.IncrementUsage(ctx, p.Subscription.ID, cost.TotalCost); err != nil {
|
||||
slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err)
|
||||
}
|
||||
deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost)
|
||||
}
|
||||
} else {
|
||||
if cost.ActualCost > 0 {
|
||||
if err := deps.userRepo.DeductBalance(ctx, p.User.ID, cost.ActualCost); err != nil {
|
||||
slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err)
|
||||
}
|
||||
deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. API Key 配额
|
||||
if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil {
|
||||
if err := p.APIKeyService.UpdateQuotaUsed(ctx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||
slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. API Key 限速用量
|
||||
if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
|
||||
if err := p.APIKeyService.UpdateRateLimitUsage(ctx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||
slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err)
|
||||
}
|
||||
deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, cost.ActualCost)
|
||||
}
|
||||
|
||||
// 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率)
|
||||
if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.GetQuotaLimit() > 0 {
|
||||
accountCost := cost.TotalCost * p.AccountRateMultiplier
|
||||
if err := deps.accountRepo.IncrementQuotaUsed(ctx, p.Account.ID, accountCost); err != nil {
|
||||
slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 5. 更新账号最近使用时间
|
||||
deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
|
||||
}
|
||||
|
||||
// billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供)
|
||||
type billingDeps struct {
|
||||
accountRepo AccountRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
billingCacheService *BillingCacheService
|
||||
deferredService *DeferredService
|
||||
}
|
||||
|
||||
func (s *GatewayService) billingDeps() *billingDeps {
|
||||
return &billingDeps{
|
||||
accountRepo: s.accountRepo,
|
||||
userRepo: s.userRepo,
|
||||
userSubRepo: s.userSubRepo,
|
||||
billingCacheService: s.billingCacheService,
|
||||
deferredService: s.deferredService,
|
||||
}
|
||||
}
|
||||
|
||||
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
||||
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
|
||||
result := input.Result
|
||||
@@ -6533,45 +6630,21 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
|
||||
shouldBill := inserted || err != nil
|
||||
|
||||
// 根据计费类型执行扣费
|
||||
if isSubscriptionBilling {
|
||||
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
|
||||
if shouldBill && cost.TotalCost > 0 {
|
||||
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err)
|
||||
}
|
||||
// 异步更新订阅缓存
|
||||
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
|
||||
}
|
||||
if shouldBill {
|
||||
postUsageBilling(ctx, &postUsageBillingParams{
|
||||
Cost: cost,
|
||||
User: user,
|
||||
APIKey: apiKey,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
IsSubscriptionBill: isSubscriptionBilling,
|
||||
AccountRateMultiplier: accountRateMultiplier,
|
||||
APIKeyService: input.APIKeyService,
|
||||
}, s.billingDeps())
|
||||
} else {
|
||||
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
|
||||
if shouldBill && cost.ActualCost > 0 {
|
||||
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err)
|
||||
}
|
||||
// 异步更新余额缓存
|
||||
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
|
||||
}
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
}
|
||||
|
||||
// 更新 API Key 配额(如果设置了配额限制)
|
||||
if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil {
|
||||
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Update API key quota failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Update API Key rate limit usage
|
||||
if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil {
|
||||
if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err)
|
||||
}
|
||||
s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost)
|
||||
}
|
||||
|
||||
// Schedule batch update for account last_used_at
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -6731,44 +6804,21 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
|
||||
shouldBill := inserted || err != nil
|
||||
|
||||
// 根据计费类型执行扣费
|
||||
if isSubscriptionBilling {
|
||||
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
|
||||
if shouldBill && cost.TotalCost > 0 {
|
||||
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err)
|
||||
}
|
||||
// 异步更新订阅缓存
|
||||
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
|
||||
}
|
||||
if shouldBill {
|
||||
postUsageBilling(ctx, &postUsageBillingParams{
|
||||
Cost: cost,
|
||||
User: user,
|
||||
APIKey: apiKey,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
IsSubscriptionBill: isSubscriptionBilling,
|
||||
AccountRateMultiplier: accountRateMultiplier,
|
||||
APIKeyService: input.APIKeyService,
|
||||
}, s.billingDeps())
|
||||
} else {
|
||||
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
|
||||
if shouldBill && cost.ActualCost > 0 {
|
||||
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err)
|
||||
}
|
||||
// 异步更新余额缓存
|
||||
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
|
||||
// API Key 独立配额扣费
|
||||
if input.APIKeyService != nil && apiKey.Quota > 0 {
|
||||
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Add API key quota used failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
}
|
||||
|
||||
// Update API Key rate limit usage
|
||||
if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil {
|
||||
if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err)
|
||||
}
|
||||
s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost)
|
||||
}
|
||||
|
||||
// Schedule batch update for account last_used_at
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -6781,7 +6831,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
|
||||
return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body)
|
||||
passthroughBody := parsed.Body
|
||||
if reqModel := parsed.Model; reqModel != "" {
|
||||
if mappedModel := account.GetMappedModel(reqModel); mappedModel != reqModel {
|
||||
passthroughBody = s.replaceModelInBody(passthroughBody, mappedModel)
|
||||
logger.LegacyPrintf("service.gateway", "CountTokens passthrough model mapping: %s -> %s (account: %s)", reqModel, mappedModel, account.Name)
|
||||
}
|
||||
}
|
||||
return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody)
|
||||
}
|
||||
|
||||
body := parsed.Body
|
||||
@@ -7072,7 +7129,7 @@ func (s *GatewayService) buildCountTokensRequestAnthropicAPIKeyPassthrough(
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetURL = validatedURL + "/v1/messages/count_tokens"
|
||||
targetURL = validatedURL + "/v1/messages/count_tokens?beta=true"
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
|
||||
@@ -7119,7 +7176,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetURL = validatedURL + "/v1/messages/count_tokens"
|
||||
targetURL = validatedURL + "/v1/messages/count_tokens?beta=true"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -122,6 +122,28 @@ func TestCheckErrorPolicy_GeminiAccounts(t *testing.T) {
|
||||
body: []byte(`overloaded service`),
|
||||
expected: ErrorPolicyTempUnscheduled,
|
||||
},
|
||||
{
|
||||
name: "gemini_apikey_temp_unschedulable_401_second_hit_returns_none",
|
||||
account: &Account{
|
||||
ID: 105,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(401),
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
statusCode: 401,
|
||||
body: []byte(`unauthorized`),
|
||||
expected: ErrorPolicyNone,
|
||||
},
|
||||
{
|
||||
name: "gemini_custom_codes_override_temp_unschedulable",
|
||||
account: &Account{
|
||||
|
||||
@@ -176,6 +176,14 @@ func (m *mockAccountRepoForGemini) BulkUpdate(ctx context.Context, ids []int64,
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForGemini) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForGemini) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Verify interface implementation
|
||||
var _ AccountRepository = (*mockAccountRepoForGemini)(nil)
|
||||
|
||||
|
||||
@@ -57,6 +57,10 @@ type Group struct {
|
||||
// 分组排序
|
||||
SortOrder int
|
||||
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
AllowMessagesDispatch bool
|
||||
DefaultMappedModel string
|
||||
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
|
||||
@@ -19,8 +19,10 @@ import (
|
||||
|
||||
// 预编译正则表达式(避免每次调用重新编译)
|
||||
var (
|
||||
// 匹配 user_id 格式: user_{64位hex}_account__session_{uuid}
|
||||
userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account__session_([a-f0-9-]{36})$`)
|
||||
// 匹配 user_id 格式:
|
||||
// 旧格式: user_{64位hex}_account__session_{uuid} (account 后无 UUID)
|
||||
// 新格式: user_{64位hex}_account_{uuid}_session_{uuid} (account 后有 UUID)
|
||||
userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account_([a-f0-9-]*)_session_([a-f0-9-]{36})$`)
|
||||
// 匹配 User-Agent 版本号: xxx/x.y.z
|
||||
userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
|
||||
)
|
||||
@@ -239,13 +241,16 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// 匹配格式: user_{64位hex}_account__session_{uuid}
|
||||
// 匹配格式:
|
||||
// 旧格式: user_{64位hex}_account__session_{uuid}
|
||||
// 新格式: user_{64位hex}_account_{uuid}_session_{uuid}
|
||||
matches := userIDRegex.FindStringSubmatch(userID)
|
||||
if matches == nil {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
sessionTail := matches[1] // 原始session UUID
|
||||
// matches[1] = account UUID (可能为空), matches[2] = session UUID
|
||||
sessionTail := matches[2] // 原始session UUID
|
||||
|
||||
// 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式
|
||||
seed := fmt.Sprintf("%d::%s", accountID, sessionTail)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user