mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-07 17:00:20 +08:00
Compare commits
107 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
25178cdbe1 | ||
|
|
a461538d58 | ||
|
|
ebe6f418f3 | ||
|
|
391e79f8ee | ||
|
|
c7fcb7a84b | ||
|
|
87f4ed591e | ||
|
|
440d2e28ed | ||
|
|
6cb8980404 | ||
|
|
fe752bbd35 | ||
|
|
c74d451fa2 | ||
|
|
12d743fb35 | ||
|
|
6acb9f7910 | ||
|
|
eb6f5c6927 | ||
|
|
7ccb4c8ea3 | ||
|
|
4ce986d47d | ||
|
|
91ef085d7d | ||
|
|
97aaa24733 | ||
|
|
faf6441633 | ||
|
|
00c151b463 | ||
|
|
a2ae9f1f27 | ||
|
|
4cd6d86426 | ||
|
|
fa72f1947a | ||
|
|
9ee7d3935d | ||
|
|
1071fe0ac7 | ||
|
|
0be003377f | ||
|
|
ca3f497b56 | ||
|
|
034b84b707 | ||
|
|
1624523c4e | ||
|
|
313afe14ce | ||
|
|
01180b316f | ||
|
|
ee7d061001 | ||
|
|
60c5949a74 | ||
|
|
2ebbd4c94d | ||
|
|
785115c62b | ||
|
|
e643fc382c | ||
|
|
34aad82ac3 | ||
|
|
0c29468f90 | ||
|
|
9301dae63e | ||
|
|
2475d4a205 | ||
|
|
be75fc3474 | ||
|
|
785e049af3 | ||
|
|
be4e49e6d7 | ||
|
|
1307d604e7 | ||
|
|
45d57018eb | ||
|
|
03bf348530 | ||
|
|
cab60ef735 | ||
|
|
a3791104f9 | ||
|
|
2b3e40bb2a | ||
|
|
0c1dcad429 | ||
|
|
101ef0cf62 | ||
|
|
0debe0a80c | ||
|
|
d22e62ac8a | ||
|
|
1ee17383f8 | ||
|
|
b59c79c458 | ||
|
|
bcb6444f89 | ||
|
|
c2b14693b4 | ||
|
|
92d35409de | ||
|
|
351a08f813 | ||
|
|
a58dc787a9 | ||
|
|
7079edc2d0 | ||
|
|
da89583ccc | ||
|
|
a42a1f08e9 | ||
|
|
ebd5253e22 | ||
|
|
6411645ffc | ||
|
|
c0c322ba16 | ||
|
|
d35c5cd491 | ||
|
|
7a353028e7 | ||
|
|
2d8d3b7857 | ||
|
|
4190293b07 | ||
|
|
421b4c0aff | ||
|
|
cd69a7cb85 | ||
|
|
0c9ba9e86c | ||
|
|
1b4d2a41c9 | ||
|
|
0787d2b47a | ||
|
|
97bf1d85ab | ||
|
|
207a493fab | ||
|
|
1f3f9e131e | ||
|
|
4ddedfaaf9 | ||
|
|
3ebebef95f | ||
|
|
9f7ad47598 | ||
|
|
3c83cd8be2 | ||
|
|
963b3b768c | ||
|
|
f6709fb5d6 | ||
|
|
921599948b | ||
|
|
5df3cafa99 | ||
|
|
1a2143c1fe | ||
|
|
dd25281305 | ||
|
|
49d0301dde | ||
|
|
d90e56eb45 | ||
|
|
838ada8864 | ||
|
|
65a106792a | ||
|
|
ee4bfcbb81 | ||
|
|
a087f089b8 | ||
|
|
afbe8bf001 | ||
|
|
2a3ef0be06 | ||
|
|
3403909354 | ||
|
|
005d0c5f53 | ||
|
|
8aaaeb29cc | ||
|
|
230f8abd04 | ||
|
|
a18bbb5f2f | ||
|
|
60fce4f1dc | ||
|
|
9af65efcdb | ||
|
|
bc194a7d8c | ||
|
|
c28f691f32 | ||
|
|
ff1f114989 | ||
|
|
cac230206d | ||
|
|
79ae15d5e8 |
6
.github/workflows/backend-ci.yml
vendored
6
.github/workflows/backend-ci.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
cache: true
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.25.7'
|
||||
go version | grep -q 'go1.26.1'
|
||||
- name: Unit tests
|
||||
working-directory: backend
|
||||
run: make test-unit
|
||||
@@ -38,10 +38,10 @@ jobs:
|
||||
cache: true
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.25.7'
|
||||
go version | grep -q 'go1.26.1'
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v9
|
||||
with:
|
||||
version: v2.7
|
||||
version: v2.9
|
||||
args: --timeout=30m
|
||||
working-directory: backend
|
||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -115,7 +115,7 @@ jobs:
|
||||
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.25.7'
|
||||
go version | grep -q 'go1.26.1'
|
||||
|
||||
# Docker setup for GoReleaser
|
||||
- name: Set up QEMU
|
||||
|
||||
2
.github/workflows/security-scan.yml
vendored
2
.github/workflows/security-scan.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
cache-dependency-path: backend/go.sum
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.25.7'
|
||||
go version | grep -q 'go1.26.1'
|
||||
- name: Run govulncheck
|
||||
working-directory: backend
|
||||
run: |
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
# =============================================================================
|
||||
|
||||
ARG NODE_IMAGE=node:24-alpine
|
||||
ARG GOLANG_IMAGE=golang:1.25.7-alpine
|
||||
ARG GOLANG_IMAGE=golang:1.26.1-alpine
|
||||
ARG ALPINE_IMAGE=alpine:3.21
|
||||
ARG GOPROXY=https://goproxy.cn,direct
|
||||
ARG GOSUMDB=sum.golang.google.cn
|
||||
|
||||
@@ -93,20 +93,13 @@ linters:
|
||||
check-escaping-errors: true
|
||||
staticcheck:
|
||||
# https://staticcheck.dev/docs/configuration/options/#dot_import_whitelist
|
||||
# Default: ["github.com/mmcloughlin/avo/build", "github.com/mmcloughlin/avo/operand", "github.com/mmcloughlin/avo/reg"]
|
||||
dot-import-whitelist:
|
||||
- fmt
|
||||
# https://staticcheck.dev/docs/configuration/options/#initialisms
|
||||
# Default: ["ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS"]
|
||||
initialisms: [ "ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS" ]
|
||||
# https://staticcheck.dev/docs/configuration/options/#http_status_code_whitelist
|
||||
# Default: ["200", "400", "404", "500"]
|
||||
http-status-code-whitelist: [ "200", "400", "404", "500" ]
|
||||
# SAxxxx checks in https://staticcheck.dev/docs/configuration/options/#checks
|
||||
# Example (to disable some checks): [ "all", "-SA1000", "-SA1001"]
|
||||
# Run `GL_DEBUG=staticcheck golangci-lint run --enable=staticcheck` to see all available checks and enabled by config checks.
|
||||
# Default: ["all", "-ST1000", "-ST1003", "-ST1016", "-ST1020", "-ST1021", "-ST1022"]
|
||||
# Temporarily disable style checks to allow CI to pass
|
||||
# "all" enables every SA/ST/S/QF check; only list the ones to disable.
|
||||
checks:
|
||||
- all
|
||||
- -ST1000 # Package comment format
|
||||
@@ -114,489 +107,19 @@ linters:
|
||||
- -ST1020 # Comment on exported method format
|
||||
- -ST1021 # Comment on exported type format
|
||||
- -ST1022 # Comment on exported variable format
|
||||
# Invalid regular expression.
|
||||
# https://staticcheck.dev/docs/checks/#SA1000
|
||||
- SA1000
|
||||
# Invalid template.
|
||||
# https://staticcheck.dev/docs/checks/#SA1001
|
||||
- SA1001
|
||||
# Invalid format in 'time.Parse'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1002
|
||||
- SA1002
|
||||
# Unsupported argument to functions in 'encoding/binary'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1003
|
||||
- SA1003
|
||||
# Suspiciously small untyped constant in 'time.Sleep'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1004
|
||||
- SA1004
|
||||
# Invalid first argument to 'exec.Command'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1005
|
||||
- SA1005
|
||||
# 'Printf' with dynamic first argument and no further arguments.
|
||||
# https://staticcheck.dev/docs/checks/#SA1006
|
||||
- SA1006
|
||||
# Invalid URL in 'net/url.Parse'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1007
|
||||
- SA1007
|
||||
# Non-canonical key in 'http.Header' map.
|
||||
# https://staticcheck.dev/docs/checks/#SA1008
|
||||
- SA1008
|
||||
# '(*regexp.Regexp).FindAll' called with 'n == 0', which will always return zero results.
|
||||
# https://staticcheck.dev/docs/checks/#SA1010
|
||||
- SA1010
|
||||
# Various methods in the "strings" package expect valid UTF-8, but invalid input is provided.
|
||||
# https://staticcheck.dev/docs/checks/#SA1011
|
||||
- SA1011
|
||||
# A nil 'context.Context' is being passed to a function, consider using 'context.TODO' instead.
|
||||
# https://staticcheck.dev/docs/checks/#SA1012
|
||||
- SA1012
|
||||
# 'io.Seeker.Seek' is being called with the whence constant as the first argument, but it should be the second.
|
||||
# https://staticcheck.dev/docs/checks/#SA1013
|
||||
- SA1013
|
||||
# Non-pointer value passed to 'Unmarshal' or 'Decode'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1014
|
||||
- SA1014
|
||||
# Using 'time.Tick' in a way that will leak. Consider using 'time.NewTicker', and only use 'time.Tick' in tests, commands and endless functions.
|
||||
# https://staticcheck.dev/docs/checks/#SA1015
|
||||
- SA1015
|
||||
# Trapping a signal that cannot be trapped.
|
||||
# https://staticcheck.dev/docs/checks/#SA1016
|
||||
- SA1016
|
||||
# Channels used with 'os/signal.Notify' should be buffered.
|
||||
# https://staticcheck.dev/docs/checks/#SA1017
|
||||
- SA1017
|
||||
# 'strings.Replace' called with 'n == 0', which does nothing.
|
||||
# https://staticcheck.dev/docs/checks/#SA1018
|
||||
- SA1018
|
||||
# Using a deprecated function, variable, constant or field.
|
||||
# https://staticcheck.dev/docs/checks/#SA1019
|
||||
- SA1019
|
||||
# Using an invalid host:port pair with a 'net.Listen'-related function.
|
||||
# https://staticcheck.dev/docs/checks/#SA1020
|
||||
- SA1020
|
||||
# Using 'bytes.Equal' to compare two 'net.IP'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1021
|
||||
- SA1021
|
||||
# Modifying the buffer in an 'io.Writer' implementation.
|
||||
# https://staticcheck.dev/docs/checks/#SA1023
|
||||
- SA1023
|
||||
# A string cutset contains duplicate characters.
|
||||
# https://staticcheck.dev/docs/checks/#SA1024
|
||||
- SA1024
|
||||
# It is not possible to use '(*time.Timer).Reset''s return value correctly.
|
||||
# https://staticcheck.dev/docs/checks/#SA1025
|
||||
- SA1025
|
||||
# Cannot marshal channels or functions.
|
||||
# https://staticcheck.dev/docs/checks/#SA1026
|
||||
- SA1026
|
||||
# Atomic access to 64-bit variable must be 64-bit aligned.
|
||||
# https://staticcheck.dev/docs/checks/#SA1027
|
||||
- SA1027
|
||||
# 'sort.Slice' can only be used on slices.
|
||||
# https://staticcheck.dev/docs/checks/#SA1028
|
||||
- SA1028
|
||||
# Inappropriate key in call to 'context.WithValue'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1029
|
||||
- SA1029
|
||||
# Invalid argument in call to a 'strconv' function.
|
||||
# https://staticcheck.dev/docs/checks/#SA1030
|
||||
- SA1030
|
||||
# Overlapping byte slices passed to an encoder.
|
||||
# https://staticcheck.dev/docs/checks/#SA1031
|
||||
- SA1031
|
||||
# Wrong order of arguments to 'errors.Is'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1032
|
||||
- SA1032
|
||||
# 'sync.WaitGroup.Add' called inside the goroutine, leading to a race condition.
|
||||
# https://staticcheck.dev/docs/checks/#SA2000
|
||||
- SA2000
|
||||
# Empty critical section, did you mean to defer the unlock?.
|
||||
# https://staticcheck.dev/docs/checks/#SA2001
|
||||
- SA2001
|
||||
# Called 'testing.T.FailNow' or 'SkipNow' in a goroutine, which isn't allowed.
|
||||
# https://staticcheck.dev/docs/checks/#SA2002
|
||||
- SA2002
|
||||
# Deferred 'Lock' right after locking, likely meant to defer 'Unlock' instead.
|
||||
# https://staticcheck.dev/docs/checks/#SA2003
|
||||
- SA2003
|
||||
# 'TestMain' doesn't call 'os.Exit', hiding test failures.
|
||||
# https://staticcheck.dev/docs/checks/#SA3000
|
||||
- SA3000
|
||||
# Assigning to 'b.N' in benchmarks distorts the results.
|
||||
# https://staticcheck.dev/docs/checks/#SA3001
|
||||
- SA3001
|
||||
# Binary operator has identical expressions on both sides.
|
||||
# https://staticcheck.dev/docs/checks/#SA4000
|
||||
- SA4000
|
||||
# '&*x' gets simplified to 'x', it does not copy 'x'.
|
||||
# https://staticcheck.dev/docs/checks/#SA4001
|
||||
- SA4001
|
||||
# Comparing unsigned values against negative values is pointless.
|
||||
# https://staticcheck.dev/docs/checks/#SA4003
|
||||
- SA4003
|
||||
# The loop exits unconditionally after one iteration.
|
||||
# https://staticcheck.dev/docs/checks/#SA4004
|
||||
- SA4004
|
||||
# Field assignment that will never be observed. Did you mean to use a pointer receiver?.
|
||||
# https://staticcheck.dev/docs/checks/#SA4005
|
||||
- SA4005
|
||||
# A value assigned to a variable is never read before being overwritten. Forgotten error check or dead code?.
|
||||
# https://staticcheck.dev/docs/checks/#SA4006
|
||||
- SA4006
|
||||
# The variable in the loop condition never changes, are you incrementing the wrong variable?.
|
||||
# https://staticcheck.dev/docs/checks/#SA4008
|
||||
- SA4008
|
||||
# A function argument is overwritten before its first use.
|
||||
# https://staticcheck.dev/docs/checks/#SA4009
|
||||
- SA4009
|
||||
# The result of 'append' will never be observed anywhere.
|
||||
# https://staticcheck.dev/docs/checks/#SA4010
|
||||
- SA4010
|
||||
# Break statement with no effect. Did you mean to break out of an outer loop?.
|
||||
# https://staticcheck.dev/docs/checks/#SA4011
|
||||
- SA4011
|
||||
# Comparing a value against NaN even though no value is equal to NaN.
|
||||
# https://staticcheck.dev/docs/checks/#SA4012
|
||||
- SA4012
|
||||
# Negating a boolean twice ('!!b') is the same as writing 'b'. This is either redundant, or a typo.
|
||||
# https://staticcheck.dev/docs/checks/#SA4013
|
||||
- SA4013
|
||||
# An if/else if chain has repeated conditions and no side-effects; if the condition didn't match the first time, it won't match the second time, either.
|
||||
# https://staticcheck.dev/docs/checks/#SA4014
|
||||
- SA4014
|
||||
# Calling functions like 'math.Ceil' on floats converted from integers doesn't do anything useful.
|
||||
# https://staticcheck.dev/docs/checks/#SA4015
|
||||
- SA4015
|
||||
# Certain bitwise operations, such as 'x ^ 0', do not do anything useful.
|
||||
# https://staticcheck.dev/docs/checks/#SA4016
|
||||
- SA4016
|
||||
# Discarding the return values of a function without side effects, making the call pointless.
|
||||
# https://staticcheck.dev/docs/checks/#SA4017
|
||||
- SA4017
|
||||
# Self-assignment of variables.
|
||||
# https://staticcheck.dev/docs/checks/#SA4018
|
||||
- SA4018
|
||||
# Multiple, identical build constraints in the same file.
|
||||
# https://staticcheck.dev/docs/checks/#SA4019
|
||||
- SA4019
|
||||
# Unreachable case clause in a type switch.
|
||||
# https://staticcheck.dev/docs/checks/#SA4020
|
||||
- SA4020
|
||||
# "x = append(y)" is equivalent to "x = y".
|
||||
# https://staticcheck.dev/docs/checks/#SA4021
|
||||
- SA4021
|
||||
# Comparing the address of a variable against nil.
|
||||
# https://staticcheck.dev/docs/checks/#SA4022
|
||||
- SA4022
|
||||
# Impossible comparison of interface value with untyped nil.
|
||||
# https://staticcheck.dev/docs/checks/#SA4023
|
||||
- SA4023
|
||||
# Checking for impossible return value from a builtin function.
|
||||
# https://staticcheck.dev/docs/checks/#SA4024
|
||||
- SA4024
|
||||
# Integer division of literals that results in zero.
|
||||
# https://staticcheck.dev/docs/checks/#SA4025
|
||||
- SA4025
|
||||
# Go constants cannot express negative zero.
|
||||
# https://staticcheck.dev/docs/checks/#SA4026
|
||||
- SA4026
|
||||
# '(*net/url.URL).Query' returns a copy, modifying it doesn't change the URL.
|
||||
# https://staticcheck.dev/docs/checks/#SA4027
|
||||
- SA4027
|
||||
# 'x % 1' is always zero.
|
||||
# https://staticcheck.dev/docs/checks/#SA4028
|
||||
- SA4028
|
||||
# Ineffective attempt at sorting slice.
|
||||
# https://staticcheck.dev/docs/checks/#SA4029
|
||||
- SA4029
|
||||
# Ineffective attempt at generating random number.
|
||||
# https://staticcheck.dev/docs/checks/#SA4030
|
||||
- SA4030
|
||||
# Checking never-nil value against nil.
|
||||
# https://staticcheck.dev/docs/checks/#SA4031
|
||||
- SA4031
|
||||
# Comparing 'runtime.GOOS' or 'runtime.GOARCH' against impossible value.
|
||||
# https://staticcheck.dev/docs/checks/#SA4032
|
||||
- SA4032
|
||||
# Assignment to nil map.
|
||||
# https://staticcheck.dev/docs/checks/#SA5000
|
||||
- SA5000
|
||||
# Deferring 'Close' before checking for a possible error.
|
||||
# https://staticcheck.dev/docs/checks/#SA5001
|
||||
- SA5001
|
||||
# The empty for loop ("for {}") spins and can block the scheduler.
|
||||
# https://staticcheck.dev/docs/checks/#SA5002
|
||||
- SA5002
|
||||
# Defers in infinite loops will never execute.
|
||||
# https://staticcheck.dev/docs/checks/#SA5003
|
||||
- SA5003
|
||||
# "for { select { ..." with an empty default branch spins.
|
||||
# https://staticcheck.dev/docs/checks/#SA5004
|
||||
- SA5004
|
||||
# The finalizer references the finalized object, preventing garbage collection.
|
||||
# https://staticcheck.dev/docs/checks/#SA5005
|
||||
- SA5005
|
||||
# Infinite recursive call.
|
||||
# https://staticcheck.dev/docs/checks/#SA5007
|
||||
- SA5007
|
||||
# Invalid struct tag.
|
||||
# https://staticcheck.dev/docs/checks/#SA5008
|
||||
- SA5008
|
||||
# Invalid Printf call.
|
||||
# https://staticcheck.dev/docs/checks/#SA5009
|
||||
- SA5009
|
||||
# Impossible type assertion.
|
||||
# https://staticcheck.dev/docs/checks/#SA5010
|
||||
- SA5010
|
||||
# Possible nil pointer dereference.
|
||||
# https://staticcheck.dev/docs/checks/#SA5011
|
||||
- SA5011
|
||||
# Passing odd-sized slice to function expecting even size.
|
||||
# https://staticcheck.dev/docs/checks/#SA5012
|
||||
- SA5012
|
||||
# Using 'regexp.Match' or related in a loop, should use 'regexp.Compile'.
|
||||
# https://staticcheck.dev/docs/checks/#SA6000
|
||||
- SA6000
|
||||
# Missing an optimization opportunity when indexing maps by byte slices.
|
||||
# https://staticcheck.dev/docs/checks/#SA6001
|
||||
- SA6001
|
||||
# Storing non-pointer values in 'sync.Pool' allocates memory.
|
||||
# https://staticcheck.dev/docs/checks/#SA6002
|
||||
- SA6002
|
||||
# Converting a string to a slice of runes before ranging over it.
|
||||
# https://staticcheck.dev/docs/checks/#SA6003
|
||||
- SA6003
|
||||
# Inefficient string comparison with 'strings.ToLower' or 'strings.ToUpper'.
|
||||
# https://staticcheck.dev/docs/checks/#SA6005
|
||||
- SA6005
|
||||
# Using io.WriteString to write '[]byte'.
|
||||
# https://staticcheck.dev/docs/checks/#SA6006
|
||||
- SA6006
|
||||
# Defers in range loops may not run when you expect them to.
|
||||
# https://staticcheck.dev/docs/checks/#SA9001
|
||||
- SA9001
|
||||
# Using a non-octal 'os.FileMode' that looks like it was meant to be in octal.
|
||||
# https://staticcheck.dev/docs/checks/#SA9002
|
||||
- SA9002
|
||||
# Empty body in an if or else branch.
|
||||
# https://staticcheck.dev/docs/checks/#SA9003
|
||||
- SA9003
|
||||
# Only the first constant has an explicit type.
|
||||
# https://staticcheck.dev/docs/checks/#SA9004
|
||||
- SA9004
|
||||
# Trying to marshal a struct with no public fields nor custom marshaling.
|
||||
# https://staticcheck.dev/docs/checks/#SA9005
|
||||
- SA9005
|
||||
# Dubious bit shifting of a fixed size integer value.
|
||||
# https://staticcheck.dev/docs/checks/#SA9006
|
||||
- SA9006
|
||||
# Deleting a directory that shouldn't be deleted.
|
||||
# https://staticcheck.dev/docs/checks/#SA9007
|
||||
- SA9007
|
||||
# 'else' branch of a type assertion is probably not reading the right value.
|
||||
# https://staticcheck.dev/docs/checks/#SA9008
|
||||
- SA9008
|
||||
# Ineffectual Go compiler directive.
|
||||
# https://staticcheck.dev/docs/checks/#SA9009
|
||||
- SA9009
|
||||
# NOTE: ST1000, ST1001, ST1003, ST1020, ST1021, ST1022 are disabled above
|
||||
# Incorrectly formatted error string.
|
||||
# https://staticcheck.dev/docs/checks/#ST1005
|
||||
- ST1005
|
||||
# Poorly chosen receiver name.
|
||||
# https://staticcheck.dev/docs/checks/#ST1006
|
||||
- ST1006
|
||||
# A function's error value should be its last return value.
|
||||
# https://staticcheck.dev/docs/checks/#ST1008
|
||||
- ST1008
|
||||
# Poorly chosen name for variable of type 'time.Duration'.
|
||||
# https://staticcheck.dev/docs/checks/#ST1011
|
||||
- ST1011
|
||||
# Poorly chosen name for error variable.
|
||||
# https://staticcheck.dev/docs/checks/#ST1012
|
||||
- ST1012
|
||||
# Should use constants for HTTP error codes, not magic numbers.
|
||||
# https://staticcheck.dev/docs/checks/#ST1013
|
||||
- ST1013
|
||||
# A switch's default case should be the first or last case.
|
||||
# https://staticcheck.dev/docs/checks/#ST1015
|
||||
- ST1015
|
||||
# Use consistent method receiver names.
|
||||
# https://staticcheck.dev/docs/checks/#ST1016
|
||||
- ST1016
|
||||
# Don't use Yoda conditions.
|
||||
# https://staticcheck.dev/docs/checks/#ST1017
|
||||
- ST1017
|
||||
# Avoid zero-width and control characters in string literals.
|
||||
# https://staticcheck.dev/docs/checks/#ST1018
|
||||
- ST1018
|
||||
# Importing the same package multiple times.
|
||||
# https://staticcheck.dev/docs/checks/#ST1019
|
||||
- ST1019
|
||||
# NOTE: ST1020, ST1021, ST1022 removed (disabled above)
|
||||
# Redundant type in variable declaration.
|
||||
# https://staticcheck.dev/docs/checks/#ST1023
|
||||
- ST1023
|
||||
# Use plain channel send or receive instead of single-case select.
|
||||
# https://staticcheck.dev/docs/checks/#S1000
|
||||
- S1000
|
||||
# Replace for loop with call to copy.
|
||||
# https://staticcheck.dev/docs/checks/#S1001
|
||||
- S1001
|
||||
# Omit comparison with boolean constant.
|
||||
# https://staticcheck.dev/docs/checks/#S1002
|
||||
- S1002
|
||||
# Replace call to 'strings.Index' with 'strings.Contains'.
|
||||
# https://staticcheck.dev/docs/checks/#S1003
|
||||
- S1003
|
||||
# Replace call to 'bytes.Compare' with 'bytes.Equal'.
|
||||
# https://staticcheck.dev/docs/checks/#S1004
|
||||
- S1004
|
||||
# Drop unnecessary use of the blank identifier.
|
||||
# https://staticcheck.dev/docs/checks/#S1005
|
||||
- S1005
|
||||
# Use "for { ... }" for infinite loops.
|
||||
# https://staticcheck.dev/docs/checks/#S1006
|
||||
- S1006
|
||||
# Simplify regular expression by using raw string literal.
|
||||
# https://staticcheck.dev/docs/checks/#S1007
|
||||
- S1007
|
||||
# Simplify returning boolean expression.
|
||||
# https://staticcheck.dev/docs/checks/#S1008
|
||||
- S1008
|
||||
# Omit redundant nil check on slices, maps, and channels.
|
||||
# https://staticcheck.dev/docs/checks/#S1009
|
||||
- S1009
|
||||
# Omit default slice index.
|
||||
# https://staticcheck.dev/docs/checks/#S1010
|
||||
- S1010
|
||||
# Use a single 'append' to concatenate two slices.
|
||||
# https://staticcheck.dev/docs/checks/#S1011
|
||||
- S1011
|
||||
# Replace 'time.Now().Sub(x)' with 'time.Since(x)'.
|
||||
# https://staticcheck.dev/docs/checks/#S1012
|
||||
- S1012
|
||||
# Use a type conversion instead of manually copying struct fields.
|
||||
# https://staticcheck.dev/docs/checks/#S1016
|
||||
- S1016
|
||||
# Replace manual trimming with 'strings.TrimPrefix'.
|
||||
# https://staticcheck.dev/docs/checks/#S1017
|
||||
- S1017
|
||||
# Use "copy" for sliding elements.
|
||||
# https://staticcheck.dev/docs/checks/#S1018
|
||||
- S1018
|
||||
# Simplify "make" call by omitting redundant arguments.
|
||||
# https://staticcheck.dev/docs/checks/#S1019
|
||||
- S1019
|
||||
# Omit redundant nil check in type assertion.
|
||||
# https://staticcheck.dev/docs/checks/#S1020
|
||||
- S1020
|
||||
# Merge variable declaration and assignment.
|
||||
# https://staticcheck.dev/docs/checks/#S1021
|
||||
- S1021
|
||||
# Omit redundant control flow.
|
||||
# https://staticcheck.dev/docs/checks/#S1023
|
||||
- S1023
|
||||
# Replace 'x.Sub(time.Now())' with 'time.Until(x)'.
|
||||
# https://staticcheck.dev/docs/checks/#S1024
|
||||
- S1024
|
||||
# Don't use 'fmt.Sprintf("%s", x)' unnecessarily.
|
||||
# https://staticcheck.dev/docs/checks/#S1025
|
||||
- S1025
|
||||
# Simplify error construction with 'fmt.Errorf'.
|
||||
# https://staticcheck.dev/docs/checks/#S1028
|
||||
- S1028
|
||||
# Range over the string directly.
|
||||
# https://staticcheck.dev/docs/checks/#S1029
|
||||
- S1029
|
||||
# Use 'bytes.Buffer.String' or 'bytes.Buffer.Bytes'.
|
||||
# https://staticcheck.dev/docs/checks/#S1030
|
||||
- S1030
|
||||
# Omit redundant nil check around loop.
|
||||
# https://staticcheck.dev/docs/checks/#S1031
|
||||
- S1031
|
||||
# Use 'sort.Ints(x)', 'sort.Float64s(x)', and 'sort.Strings(x)'.
|
||||
# https://staticcheck.dev/docs/checks/#S1032
|
||||
- S1032
|
||||
# Unnecessary guard around call to "delete".
|
||||
# https://staticcheck.dev/docs/checks/#S1033
|
||||
- S1033
|
||||
# Use result of type assertion to simplify cases.
|
||||
# https://staticcheck.dev/docs/checks/#S1034
|
||||
- S1034
|
||||
# Redundant call to 'net/http.CanonicalHeaderKey' in method call on 'net/http.Header'.
|
||||
# https://staticcheck.dev/docs/checks/#S1035
|
||||
- S1035
|
||||
# Unnecessary guard around map access.
|
||||
# https://staticcheck.dev/docs/checks/#S1036
|
||||
- S1036
|
||||
# Elaborate way of sleeping.
|
||||
# https://staticcheck.dev/docs/checks/#S1037
|
||||
- S1037
|
||||
# Unnecessarily complex way of printing formatted string.
|
||||
# https://staticcheck.dev/docs/checks/#S1038
|
||||
- S1038
|
||||
# Unnecessary use of 'fmt.Sprint'.
|
||||
# https://staticcheck.dev/docs/checks/#S1039
|
||||
- S1039
|
||||
# Type assertion to current type.
|
||||
# https://staticcheck.dev/docs/checks/#S1040
|
||||
- S1040
|
||||
# Apply De Morgan's law.
|
||||
# https://staticcheck.dev/docs/checks/#QF1001
|
||||
- QF1001
|
||||
# Convert untagged switch to tagged switch.
|
||||
# https://staticcheck.dev/docs/checks/#QF1002
|
||||
- QF1002
|
||||
# Convert if/else-if chain to tagged switch.
|
||||
# https://staticcheck.dev/docs/checks/#QF1003
|
||||
- QF1003
|
||||
# Use 'strings.ReplaceAll' instead of 'strings.Replace' with 'n == -1'.
|
||||
# https://staticcheck.dev/docs/checks/#QF1004
|
||||
- QF1004
|
||||
# Expand call to 'math.Pow'.
|
||||
# https://staticcheck.dev/docs/checks/#QF1005
|
||||
- QF1005
|
||||
# Lift 'if'+'break' into loop condition.
|
||||
# https://staticcheck.dev/docs/checks/#QF1006
|
||||
- QF1006
|
||||
# Merge conditional assignment into variable declaration.
|
||||
# https://staticcheck.dev/docs/checks/#QF1007
|
||||
- QF1007
|
||||
# Omit embedded fields from selector expression.
|
||||
# https://staticcheck.dev/docs/checks/#QF1008
|
||||
- QF1008
|
||||
# Use 'time.Time.Equal' instead of '==' operator.
|
||||
# https://staticcheck.dev/docs/checks/#QF1009
|
||||
- QF1009
|
||||
# Convert slice of bytes to string when printing it.
|
||||
# https://staticcheck.dev/docs/checks/#QF1010
|
||||
- QF1010
|
||||
# Omit redundant type from variable declaration.
|
||||
# https://staticcheck.dev/docs/checks/#QF1011
|
||||
- QF1011
|
||||
# Use 'fmt.Fprintf(x, ...)' instead of 'x.Write(fmt.Sprintf(...))'.
|
||||
# https://staticcheck.dev/docs/checks/#QF1012
|
||||
- QF1012
|
||||
unused:
|
||||
# Mark all struct fields that have been written to as used.
|
||||
# Default: true
|
||||
field-writes-are-uses: false
|
||||
# Treat IncDec statement (e.g. `i++` or `i--`) as both read and write operation instead of just write.
|
||||
field-writes-are-uses: true
|
||||
# Default: false
|
||||
post-statements-are-reads: true
|
||||
# Mark all exported fields as used.
|
||||
# default: true
|
||||
exported-fields-are-used: false
|
||||
# Mark all function parameters as used.
|
||||
# default: true
|
||||
parameters-are-used: true
|
||||
# Mark all local variables as used.
|
||||
# default: true
|
||||
local-variables-are-used: false
|
||||
# Mark all identifiers inside generated files as used.
|
||||
# Default: true
|
||||
generated-is-used: false
|
||||
exported-fields-are-used: true
|
||||
# Default: true
|
||||
parameters-are-used: true
|
||||
# Default: true
|
||||
local-variables-are-used: false
|
||||
# Default: true — must be true, ent generates 130K+ lines of code
|
||||
generated-is-used: true
|
||||
|
||||
formatters:
|
||||
enable:
|
||||
|
||||
@@ -162,9 +162,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
||||
digestSessionStore := service.NewDigestSessionStore()
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
|
||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||
@@ -229,7 +229,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, configConfig)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
|
||||
@@ -25,6 +25,8 @@ type Announcement struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
// 状态: draft, active, archived
|
||||
Status string `json:"status,omitempty"`
|
||||
// 通知模式: silent(仅铃铛), popup(弹窗提醒)
|
||||
NotifyMode string `json:"notify_mode,omitempty"`
|
||||
// 展示条件(JSON 规则)
|
||||
Targeting domain.AnnouncementTargeting `json:"targeting,omitempty"`
|
||||
// 开始展示时间(为空表示立即生效)
|
||||
@@ -72,7 +74,7 @@ func (*Announcement) scanValues(columns []string) ([]any, error) {
|
||||
values[i] = new([]byte)
|
||||
case announcement.FieldID, announcement.FieldCreatedBy, announcement.FieldUpdatedBy:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case announcement.FieldTitle, announcement.FieldContent, announcement.FieldStatus:
|
||||
case announcement.FieldTitle, announcement.FieldContent, announcement.FieldStatus, announcement.FieldNotifyMode:
|
||||
values[i] = new(sql.NullString)
|
||||
case announcement.FieldStartsAt, announcement.FieldEndsAt, announcement.FieldCreatedAt, announcement.FieldUpdatedAt:
|
||||
values[i] = new(sql.NullTime)
|
||||
@@ -115,6 +117,12 @@ func (_m *Announcement) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.Status = value.String
|
||||
}
|
||||
case announcement.FieldNotifyMode:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field notify_mode", values[i])
|
||||
} else if value.Valid {
|
||||
_m.NotifyMode = value.String
|
||||
}
|
||||
case announcement.FieldTargeting:
|
||||
if value, ok := values[i].(*[]byte); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field targeting", values[i])
|
||||
@@ -213,6 +221,9 @@ func (_m *Announcement) String() string {
|
||||
builder.WriteString("status=")
|
||||
builder.WriteString(_m.Status)
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("notify_mode=")
|
||||
builder.WriteString(_m.NotifyMode)
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("targeting=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.Targeting))
|
||||
builder.WriteString(", ")
|
||||
|
||||
@@ -20,6 +20,8 @@ const (
|
||||
FieldContent = "content"
|
||||
// FieldStatus holds the string denoting the status field in the database.
|
||||
FieldStatus = "status"
|
||||
// FieldNotifyMode holds the string denoting the notify_mode field in the database.
|
||||
FieldNotifyMode = "notify_mode"
|
||||
// FieldTargeting holds the string denoting the targeting field in the database.
|
||||
FieldTargeting = "targeting"
|
||||
// FieldStartsAt holds the string denoting the starts_at field in the database.
|
||||
@@ -53,6 +55,7 @@ var Columns = []string{
|
||||
FieldTitle,
|
||||
FieldContent,
|
||||
FieldStatus,
|
||||
FieldNotifyMode,
|
||||
FieldTargeting,
|
||||
FieldStartsAt,
|
||||
FieldEndsAt,
|
||||
@@ -81,6 +84,10 @@ var (
|
||||
DefaultStatus string
|
||||
// StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
||||
StatusValidator func(string) error
|
||||
// DefaultNotifyMode holds the default value on creation for the "notify_mode" field.
|
||||
DefaultNotifyMode string
|
||||
// NotifyModeValidator is a validator for the "notify_mode" field. It is called by the builders before save.
|
||||
NotifyModeValidator func(string) error
|
||||
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
|
||||
DefaultCreatedAt func() time.Time
|
||||
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
|
||||
@@ -112,6 +119,11 @@ func ByStatus(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldStatus, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByNotifyMode orders the results by the notify_mode field.
|
||||
func ByNotifyMode(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldNotifyMode, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByStartsAt orders the results by the starts_at field.
|
||||
func ByStartsAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldStartsAt, opts...).ToFunc()
|
||||
|
||||
@@ -70,6 +70,11 @@ func Status(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldEQ(FieldStatus, v))
|
||||
}
|
||||
|
||||
// NotifyMode applies equality check predicate on the "notify_mode" field. It's identical to NotifyModeEQ.
|
||||
func NotifyMode(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldEQ(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// StartsAt applies equality check predicate on the "starts_at" field. It's identical to StartsAtEQ.
|
||||
func StartsAt(v time.Time) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldEQ(FieldStartsAt, v))
|
||||
@@ -295,6 +300,71 @@ func StatusContainsFold(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldContainsFold(FieldStatus, v))
|
||||
}
|
||||
|
||||
// NotifyModeEQ applies the EQ predicate on the "notify_mode" field.
|
||||
func NotifyModeEQ(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldEQ(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeNEQ applies the NEQ predicate on the "notify_mode" field.
|
||||
func NotifyModeNEQ(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldNEQ(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeIn applies the In predicate on the "notify_mode" field.
|
||||
func NotifyModeIn(vs ...string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldIn(FieldNotifyMode, vs...))
|
||||
}
|
||||
|
||||
// NotifyModeNotIn applies the NotIn predicate on the "notify_mode" field.
|
||||
func NotifyModeNotIn(vs ...string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldNotIn(FieldNotifyMode, vs...))
|
||||
}
|
||||
|
||||
// NotifyModeGT applies the GT predicate on the "notify_mode" field.
|
||||
func NotifyModeGT(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldGT(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeGTE applies the GTE predicate on the "notify_mode" field.
|
||||
func NotifyModeGTE(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldGTE(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeLT applies the LT predicate on the "notify_mode" field.
|
||||
func NotifyModeLT(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldLT(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeLTE applies the LTE predicate on the "notify_mode" field.
|
||||
func NotifyModeLTE(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldLTE(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeContains applies the Contains predicate on the "notify_mode" field.
|
||||
func NotifyModeContains(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldContains(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeHasPrefix applies the HasPrefix predicate on the "notify_mode" field.
|
||||
func NotifyModeHasPrefix(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldHasPrefix(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeHasSuffix applies the HasSuffix predicate on the "notify_mode" field.
|
||||
func NotifyModeHasSuffix(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldHasSuffix(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeEqualFold applies the EqualFold predicate on the "notify_mode" field.
|
||||
func NotifyModeEqualFold(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldEqualFold(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// NotifyModeContainsFold applies the ContainsFold predicate on the "notify_mode" field.
|
||||
func NotifyModeContainsFold(v string) predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldContainsFold(FieldNotifyMode, v))
|
||||
}
|
||||
|
||||
// TargetingIsNil applies the IsNil predicate on the "targeting" field.
|
||||
func TargetingIsNil() predicate.Announcement {
|
||||
return predicate.Announcement(sql.FieldIsNull(FieldTargeting))
|
||||
|
||||
@@ -50,6 +50,20 @@ func (_c *AnnouncementCreate) SetNillableStatus(v *string) *AnnouncementCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNotifyMode sets the "notify_mode" field.
|
||||
func (_c *AnnouncementCreate) SetNotifyMode(v string) *AnnouncementCreate {
|
||||
_c.mutation.SetNotifyMode(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableNotifyMode sets the "notify_mode" field if the given value is not nil.
|
||||
func (_c *AnnouncementCreate) SetNillableNotifyMode(v *string) *AnnouncementCreate {
|
||||
if v != nil {
|
||||
_c.SetNotifyMode(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetTargeting sets the "targeting" field.
|
||||
func (_c *AnnouncementCreate) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementCreate {
|
||||
_c.mutation.SetTargeting(v)
|
||||
@@ -202,6 +216,10 @@ func (_c *AnnouncementCreate) defaults() {
|
||||
v := announcement.DefaultStatus
|
||||
_c.mutation.SetStatus(v)
|
||||
}
|
||||
if _, ok := _c.mutation.NotifyMode(); !ok {
|
||||
v := announcement.DefaultNotifyMode
|
||||
_c.mutation.SetNotifyMode(v)
|
||||
}
|
||||
if _, ok := _c.mutation.CreatedAt(); !ok {
|
||||
v := announcement.DefaultCreatedAt()
|
||||
_c.mutation.SetCreatedAt(v)
|
||||
@@ -238,6 +256,14 @@ func (_c *AnnouncementCreate) check() error {
|
||||
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Announcement.status": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _, ok := _c.mutation.NotifyMode(); !ok {
|
||||
return &ValidationError{Name: "notify_mode", err: errors.New(`ent: missing required field "Announcement.notify_mode"`)}
|
||||
}
|
||||
if v, ok := _c.mutation.NotifyMode(); ok {
|
||||
if err := announcement.NotifyModeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "notify_mode", err: fmt.Errorf(`ent: validator failed for field "Announcement.notify_mode": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _, ok := _c.mutation.CreatedAt(); !ok {
|
||||
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Announcement.created_at"`)}
|
||||
}
|
||||
@@ -283,6 +309,10 @@ func (_c *AnnouncementCreate) createSpec() (*Announcement, *sqlgraph.CreateSpec)
|
||||
_spec.SetField(announcement.FieldStatus, field.TypeString, value)
|
||||
_node.Status = value
|
||||
}
|
||||
if value, ok := _c.mutation.NotifyMode(); ok {
|
||||
_spec.SetField(announcement.FieldNotifyMode, field.TypeString, value)
|
||||
_node.NotifyMode = value
|
||||
}
|
||||
if value, ok := _c.mutation.Targeting(); ok {
|
||||
_spec.SetField(announcement.FieldTargeting, field.TypeJSON, value)
|
||||
_node.Targeting = value
|
||||
@@ -415,6 +445,18 @@ func (u *AnnouncementUpsert) UpdateStatus() *AnnouncementUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetNotifyMode sets the "notify_mode" field.
|
||||
func (u *AnnouncementUpsert) SetNotifyMode(v string) *AnnouncementUpsert {
|
||||
u.Set(announcement.FieldNotifyMode, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateNotifyMode sets the "notify_mode" field to the value that was provided on create.
|
||||
func (u *AnnouncementUpsert) UpdateNotifyMode() *AnnouncementUpsert {
|
||||
u.SetExcluded(announcement.FieldNotifyMode)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetTargeting sets the "targeting" field.
|
||||
func (u *AnnouncementUpsert) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpsert {
|
||||
u.Set(announcement.FieldTargeting, v)
|
||||
@@ -616,6 +658,20 @@ func (u *AnnouncementUpsertOne) UpdateStatus() *AnnouncementUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetNotifyMode sets the "notify_mode" field.
|
||||
func (u *AnnouncementUpsertOne) SetNotifyMode(v string) *AnnouncementUpsertOne {
|
||||
return u.Update(func(s *AnnouncementUpsert) {
|
||||
s.SetNotifyMode(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateNotifyMode sets the "notify_mode" field to the value that was provided on create.
|
||||
func (u *AnnouncementUpsertOne) UpdateNotifyMode() *AnnouncementUpsertOne {
|
||||
return u.Update(func(s *AnnouncementUpsert) {
|
||||
s.UpdateNotifyMode()
|
||||
})
|
||||
}
|
||||
|
||||
// SetTargeting sets the "targeting" field.
|
||||
func (u *AnnouncementUpsertOne) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpsertOne {
|
||||
return u.Update(func(s *AnnouncementUpsert) {
|
||||
@@ -1002,6 +1058,20 @@ func (u *AnnouncementUpsertBulk) UpdateStatus() *AnnouncementUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetNotifyMode sets the "notify_mode" field.
|
||||
func (u *AnnouncementUpsertBulk) SetNotifyMode(v string) *AnnouncementUpsertBulk {
|
||||
return u.Update(func(s *AnnouncementUpsert) {
|
||||
s.SetNotifyMode(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateNotifyMode sets the "notify_mode" field to the value that was provided on create.
|
||||
func (u *AnnouncementUpsertBulk) UpdateNotifyMode() *AnnouncementUpsertBulk {
|
||||
return u.Update(func(s *AnnouncementUpsert) {
|
||||
s.UpdateNotifyMode()
|
||||
})
|
||||
}
|
||||
|
||||
// SetTargeting sets the "targeting" field.
|
||||
func (u *AnnouncementUpsertBulk) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpsertBulk {
|
||||
return u.Update(func(s *AnnouncementUpsert) {
|
||||
|
||||
@@ -72,6 +72,20 @@ func (_u *AnnouncementUpdate) SetNillableStatus(v *string) *AnnouncementUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNotifyMode sets the "notify_mode" field.
|
||||
func (_u *AnnouncementUpdate) SetNotifyMode(v string) *AnnouncementUpdate {
|
||||
_u.mutation.SetNotifyMode(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableNotifyMode sets the "notify_mode" field if the given value is not nil.
|
||||
func (_u *AnnouncementUpdate) SetNillableNotifyMode(v *string) *AnnouncementUpdate {
|
||||
if v != nil {
|
||||
_u.SetNotifyMode(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetTargeting sets the "targeting" field.
|
||||
func (_u *AnnouncementUpdate) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpdate {
|
||||
_u.mutation.SetTargeting(v)
|
||||
@@ -286,6 +300,11 @@ func (_u *AnnouncementUpdate) check() error {
|
||||
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Announcement.status": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.NotifyMode(); ok {
|
||||
if err := announcement.NotifyModeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "notify_mode", err: fmt.Errorf(`ent: validator failed for field "Announcement.notify_mode": %w`, err)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -310,6 +329,9 @@ func (_u *AnnouncementUpdate) sqlSave(ctx context.Context) (_node int, err error
|
||||
if value, ok := _u.mutation.Status(); ok {
|
||||
_spec.SetField(announcement.FieldStatus, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.NotifyMode(); ok {
|
||||
_spec.SetField(announcement.FieldNotifyMode, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Targeting(); ok {
|
||||
_spec.SetField(announcement.FieldTargeting, field.TypeJSON, value)
|
||||
}
|
||||
@@ -456,6 +478,20 @@ func (_u *AnnouncementUpdateOne) SetNillableStatus(v *string) *AnnouncementUpdat
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNotifyMode sets the "notify_mode" field.
|
||||
func (_u *AnnouncementUpdateOne) SetNotifyMode(v string) *AnnouncementUpdateOne {
|
||||
_u.mutation.SetNotifyMode(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableNotifyMode sets the "notify_mode" field if the given value is not nil.
|
||||
func (_u *AnnouncementUpdateOne) SetNillableNotifyMode(v *string) *AnnouncementUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetNotifyMode(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetTargeting sets the "targeting" field.
|
||||
func (_u *AnnouncementUpdateOne) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpdateOne {
|
||||
_u.mutation.SetTargeting(v)
|
||||
@@ -683,6 +719,11 @@ func (_u *AnnouncementUpdateOne) check() error {
|
||||
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Announcement.status": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.NotifyMode(); ok {
|
||||
if err := announcement.NotifyModeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "notify_mode", err: fmt.Errorf(`ent: validator failed for field "Announcement.notify_mode": %w`, err)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -724,6 +765,9 @@ func (_u *AnnouncementUpdateOne) sqlSave(ctx context.Context) (_node *Announceme
|
||||
if value, ok := _u.mutation.Status(); ok {
|
||||
_spec.SetField(announcement.FieldStatus, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.NotifyMode(); ok {
|
||||
_spec.SetField(announcement.FieldNotifyMode, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Targeting(); ok {
|
||||
_spec.SetField(announcement.FieldTargeting, field.TypeJSON, value)
|
||||
}
|
||||
|
||||
@@ -78,6 +78,10 @@ type Group struct {
|
||||
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
|
||||
// 分组显示排序,数值越小越靠前
|
||||
SortOrder int `json:"sort_order,omitempty"`
|
||||
// 是否允许 /v1/messages 调度到此 OpenAI 分组
|
||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch,omitempty"`
|
||||
// 默认映射模型 ID,当账号级映射找不到时使用此值
|
||||
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
|
||||
// Edges holds the relations/edges for other nodes in the graph.
|
||||
// The values are being populated by the GroupQuery when eager-loading is set.
|
||||
Edges GroupEdges `json:"edges"`
|
||||
@@ -186,13 +190,13 @@ func (*Group) scanValues(columns []string) ([]any, error) {
|
||||
switch columns[i] {
|
||||
case group.FieldModelRouting, group.FieldSupportedModelScopes:
|
||||
values[i] = new([]byte)
|
||||
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject:
|
||||
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch:
|
||||
values[i] = new(sql.NullBool)
|
||||
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case group.FieldID, group.FieldDefaultValidityDays, group.FieldSoraStorageQuotaBytes, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType:
|
||||
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType, group.FieldDefaultMappedModel:
|
||||
values[i] = new(sql.NullString)
|
||||
case group.FieldCreatedAt, group.FieldUpdatedAt, group.FieldDeletedAt:
|
||||
values[i] = new(sql.NullTime)
|
||||
@@ -415,6 +419,18 @@ func (_m *Group) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.SortOrder = int(value.Int64)
|
||||
}
|
||||
case group.FieldAllowMessagesDispatch:
|
||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field allow_messages_dispatch", values[i])
|
||||
} else if value.Valid {
|
||||
_m.AllowMessagesDispatch = value.Bool
|
||||
}
|
||||
case group.FieldDefaultMappedModel:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field default_mapped_model", values[i])
|
||||
} else if value.Valid {
|
||||
_m.DefaultMappedModel = value.String
|
||||
}
|
||||
default:
|
||||
_m.selectValues.Set(columns[i], values[i])
|
||||
}
|
||||
@@ -608,6 +624,12 @@ func (_m *Group) String() string {
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("sort_order=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.SortOrder))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("allow_messages_dispatch=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.AllowMessagesDispatch))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("default_mapped_model=")
|
||||
builder.WriteString(_m.DefaultMappedModel)
|
||||
builder.WriteByte(')')
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
@@ -75,6 +75,10 @@ const (
|
||||
FieldSupportedModelScopes = "supported_model_scopes"
|
||||
// FieldSortOrder holds the string denoting the sort_order field in the database.
|
||||
FieldSortOrder = "sort_order"
|
||||
// FieldAllowMessagesDispatch holds the string denoting the allow_messages_dispatch field in the database.
|
||||
FieldAllowMessagesDispatch = "allow_messages_dispatch"
|
||||
// FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database.
|
||||
FieldDefaultMappedModel = "default_mapped_model"
|
||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||
EdgeAPIKeys = "api_keys"
|
||||
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
||||
@@ -180,6 +184,8 @@ var Columns = []string{
|
||||
FieldMcpXMLInject,
|
||||
FieldSupportedModelScopes,
|
||||
FieldSortOrder,
|
||||
FieldAllowMessagesDispatch,
|
||||
FieldDefaultMappedModel,
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -247,6 +253,12 @@ var (
|
||||
DefaultSupportedModelScopes []string
|
||||
// DefaultSortOrder holds the default value on creation for the "sort_order" field.
|
||||
DefaultSortOrder int
|
||||
// DefaultAllowMessagesDispatch holds the default value on creation for the "allow_messages_dispatch" field.
|
||||
DefaultAllowMessagesDispatch bool
|
||||
// DefaultDefaultMappedModel holds the default value on creation for the "default_mapped_model" field.
|
||||
DefaultDefaultMappedModel string
|
||||
// DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
||||
DefaultMappedModelValidator func(string) error
|
||||
)
|
||||
|
||||
// OrderOption defines the ordering options for the Group queries.
|
||||
@@ -397,6 +409,16 @@ func BySortOrder(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldSortOrder, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByAllowMessagesDispatch orders the results by the allow_messages_dispatch field.
|
||||
func ByAllowMessagesDispatch(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldAllowMessagesDispatch, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByDefaultMappedModel orders the results by the default_mapped_model field.
|
||||
func ByDefaultMappedModel(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByAPIKeysCount orders the results by api_keys count.
|
||||
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
|
||||
@@ -195,6 +195,16 @@ func SortOrder(v int) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldSortOrder, v))
|
||||
}
|
||||
|
||||
// AllowMessagesDispatch applies equality check predicate on the "allow_messages_dispatch" field. It's identical to AllowMessagesDispatchEQ.
|
||||
func AllowMessagesDispatch(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldAllowMessagesDispatch, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModel applies equality check predicate on the "default_mapped_model" field. It's identical to DefaultMappedModelEQ.
|
||||
func DefaultMappedModel(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||
func CreatedAtEQ(v time.Time) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
|
||||
@@ -1470,6 +1480,81 @@ func SortOrderLTE(v int) predicate.Group {
|
||||
return predicate.Group(sql.FieldLTE(FieldSortOrder, v))
|
||||
}
|
||||
|
||||
// AllowMessagesDispatchEQ applies the EQ predicate on the "allow_messages_dispatch" field.
|
||||
func AllowMessagesDispatchEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldAllowMessagesDispatch, v))
|
||||
}
|
||||
|
||||
// AllowMessagesDispatchNEQ applies the NEQ predicate on the "allow_messages_dispatch" field.
|
||||
func AllowMessagesDispatchNEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldAllowMessagesDispatch, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelEQ applies the EQ predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelEQ(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelNEQ applies the NEQ predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelNEQ(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelIn applies the In predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelIn(vs ...string) predicate.Group {
|
||||
return predicate.Group(sql.FieldIn(FieldDefaultMappedModel, vs...))
|
||||
}
|
||||
|
||||
// DefaultMappedModelNotIn applies the NotIn predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelNotIn(vs ...string) predicate.Group {
|
||||
return predicate.Group(sql.FieldNotIn(FieldDefaultMappedModel, vs...))
|
||||
}
|
||||
|
||||
// DefaultMappedModelGT applies the GT predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelGT(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldGT(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelGTE applies the GTE predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelGTE(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldGTE(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelLT applies the LT predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelLT(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldLT(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelLTE applies the LTE predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelLTE(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldLTE(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelContains applies the Contains predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelContains(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldContains(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelHasPrefix applies the HasPrefix predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelHasPrefix(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldHasPrefix(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelHasSuffix applies the HasSuffix predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelHasSuffix(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldHasSuffix(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelEqualFold applies the EqualFold predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelEqualFold(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldEqualFold(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// DefaultMappedModelContainsFold applies the ContainsFold predicate on the "default_mapped_model" field.
|
||||
func DefaultMappedModelContainsFold(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldContainsFold(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
||||
func HasAPIKeys() predicate.Group {
|
||||
return predicate.Group(func(s *sql.Selector) {
|
||||
|
||||
@@ -424,6 +424,34 @@ func (_c *GroupCreate) SetNillableSortOrder(v *int) *GroupCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
|
||||
func (_c *GroupCreate) SetAllowMessagesDispatch(v bool) *GroupCreate {
|
||||
_c.mutation.SetAllowMessagesDispatch(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableAllowMessagesDispatch sets the "allow_messages_dispatch" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableAllowMessagesDispatch(v *bool) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetAllowMessagesDispatch(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (_c *GroupCreate) SetDefaultMappedModel(v string) *GroupCreate {
|
||||
_c.mutation.SetDefaultMappedModel(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableDefaultMappedModel sets the "default_mapped_model" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableDefaultMappedModel(v *string) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetDefaultMappedModel(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
|
||||
_c.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -613,6 +641,14 @@ func (_c *GroupCreate) defaults() error {
|
||||
v := group.DefaultSortOrder
|
||||
_c.mutation.SetSortOrder(v)
|
||||
}
|
||||
if _, ok := _c.mutation.AllowMessagesDispatch(); !ok {
|
||||
v := group.DefaultAllowMessagesDispatch
|
||||
_c.mutation.SetAllowMessagesDispatch(v)
|
||||
}
|
||||
if _, ok := _c.mutation.DefaultMappedModel(); !ok {
|
||||
v := group.DefaultDefaultMappedModel
|
||||
_c.mutation.SetDefaultMappedModel(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -683,6 +719,17 @@ func (_c *GroupCreate) check() error {
|
||||
if _, ok := _c.mutation.SortOrder(); !ok {
|
||||
return &ValidationError{Name: "sort_order", err: errors.New(`ent: missing required field "Group.sort_order"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.AllowMessagesDispatch(); !ok {
|
||||
return &ValidationError{Name: "allow_messages_dispatch", err: errors.New(`ent: missing required field "Group.allow_messages_dispatch"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.DefaultMappedModel(); !ok {
|
||||
return &ValidationError{Name: "default_mapped_model", err: errors.New(`ent: missing required field "Group.default_mapped_model"`)}
|
||||
}
|
||||
if v, ok := _c.mutation.DefaultMappedModel(); ok {
|
||||
if err := group.DefaultMappedModelValidator(v); err != nil {
|
||||
return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -830,6 +877,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(group.FieldSortOrder, field.TypeInt, value)
|
||||
_node.SortOrder = value
|
||||
}
|
||||
if value, ok := _c.mutation.AllowMessagesDispatch(); ok {
|
||||
_spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value)
|
||||
_node.AllowMessagesDispatch = value
|
||||
}
|
||||
if value, ok := _c.mutation.DefaultMappedModel(); ok {
|
||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||
_node.DefaultMappedModel = value
|
||||
}
|
||||
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -1520,6 +1575,30 @@ func (u *GroupUpsert) AddSortOrder(v int) *GroupUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
|
||||
func (u *GroupUpsert) SetAllowMessagesDispatch(v bool) *GroupUpsert {
|
||||
u.Set(group.FieldAllowMessagesDispatch, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateAllowMessagesDispatch sets the "allow_messages_dispatch" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateAllowMessagesDispatch() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldAllowMessagesDispatch)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (u *GroupUpsert) SetDefaultMappedModel(v string) *GroupUpsert {
|
||||
u.Set(group.FieldDefaultMappedModel, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateDefaultMappedModel sets the "default_mapped_model" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateDefaultMappedModel() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldDefaultMappedModel)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||
// Using this option is equivalent to using:
|
||||
//
|
||||
@@ -2188,6 +2267,34 @@ func (u *GroupUpsertOne) UpdateSortOrder() *GroupUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
|
||||
func (u *GroupUpsertOne) SetAllowMessagesDispatch(v bool) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetAllowMessagesDispatch(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateAllowMessagesDispatch sets the "allow_messages_dispatch" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateAllowMessagesDispatch() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateAllowMessagesDispatch()
|
||||
})
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (u *GroupUpsertOne) SetDefaultMappedModel(v string) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetDefaultMappedModel(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateDefaultMappedModel sets the "default_mapped_model" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateDefaultMappedModel() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateDefaultMappedModel()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
|
||||
if len(u.create.conflict) == 0 {
|
||||
@@ -3022,6 +3129,34 @@ func (u *GroupUpsertBulk) UpdateSortOrder() *GroupUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
|
||||
func (u *GroupUpsertBulk) SetAllowMessagesDispatch(v bool) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetAllowMessagesDispatch(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateAllowMessagesDispatch sets the "allow_messages_dispatch" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateAllowMessagesDispatch() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateAllowMessagesDispatch()
|
||||
})
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (u *GroupUpsertBulk) SetDefaultMappedModel(v string) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetDefaultMappedModel(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateDefaultMappedModel sets the "default_mapped_model" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateDefaultMappedModel() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateDefaultMappedModel()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
|
||||
if u.create.err != nil {
|
||||
|
||||
@@ -625,6 +625,34 @@ func (_u *GroupUpdate) AddSortOrder(v int) *GroupUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
|
||||
func (_u *GroupUpdate) SetAllowMessagesDispatch(v bool) *GroupUpdate {
|
||||
_u.mutation.SetAllowMessagesDispatch(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableAllowMessagesDispatch sets the "allow_messages_dispatch" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableAllowMessagesDispatch(v *bool) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetAllowMessagesDispatch(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (_u *GroupUpdate) SetDefaultMappedModel(v string) *GroupUpdate {
|
||||
_u.mutation.SetDefaultMappedModel(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDefaultMappedModel sets the "default_mapped_model" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableDefaultMappedModel(v *string) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetDefaultMappedModel(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -910,6 +938,11 @@ func (_u *GroupUpdate) check() error {
|
||||
return &ValidationError{Name: "subscription_type", err: fmt.Errorf(`ent: validator failed for field "Group.subscription_type": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.DefaultMappedModel(); ok {
|
||||
if err := group.DefaultMappedModelValidator(v); err != nil {
|
||||
return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1110,6 +1143,12 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if value, ok := _u.mutation.AddedSortOrder(); ok {
|
||||
_spec.AddField(group.FieldSortOrder, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AllowMessagesDispatch(); ok {
|
||||
_spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.DefaultMappedModel(); ok {
|
||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -2014,6 +2053,34 @@ func (_u *GroupUpdateOne) AddSortOrder(v int) *GroupUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
|
||||
func (_u *GroupUpdateOne) SetAllowMessagesDispatch(v bool) *GroupUpdateOne {
|
||||
_u.mutation.SetAllowMessagesDispatch(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableAllowMessagesDispatch sets the "allow_messages_dispatch" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableAllowMessagesDispatch(v *bool) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetAllowMessagesDispatch(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (_u *GroupUpdateOne) SetDefaultMappedModel(v string) *GroupUpdateOne {
|
||||
_u.mutation.SetDefaultMappedModel(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDefaultMappedModel sets the "default_mapped_model" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableDefaultMappedModel(v *string) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetDefaultMappedModel(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -2312,6 +2379,11 @@ func (_u *GroupUpdateOne) check() error {
|
||||
return &ValidationError{Name: "subscription_type", err: fmt.Errorf(`ent: validator failed for field "Group.subscription_type": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.DefaultMappedModel(); ok {
|
||||
if err := group.DefaultMappedModelValidator(v); err != nil {
|
||||
return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2529,6 +2601,12 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
|
||||
if value, ok := _u.mutation.AddedSortOrder(); ok {
|
||||
_spec.AddField(group.FieldSortOrder, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AllowMessagesDispatch(); ok {
|
||||
_spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.DefaultMappedModel(); ok {
|
||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
|
||||
@@ -251,6 +251,7 @@ var (
|
||||
{Name: "title", Type: field.TypeString, Size: 200},
|
||||
{Name: "content", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
|
||||
{Name: "status", Type: field.TypeString, Size: 20, Default: "draft"},
|
||||
{Name: "notify_mode", Type: field.TypeString, Size: 20, Default: "silent"},
|
||||
{Name: "targeting", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "starts_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "ends_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
@@ -273,17 +274,17 @@ var (
|
||||
{
|
||||
Name: "announcement_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AnnouncementsColumns[9]},
|
||||
Columns: []*schema.Column{AnnouncementsColumns[10]},
|
||||
},
|
||||
{
|
||||
Name: "announcement_starts_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AnnouncementsColumns[5]},
|
||||
Columns: []*schema.Column{AnnouncementsColumns[6]},
|
||||
},
|
||||
{
|
||||
Name: "announcement_ends_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AnnouncementsColumns[6]},
|
||||
Columns: []*schema.Column{AnnouncementsColumns[7]},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -407,6 +408,8 @@ var (
|
||||
{Name: "mcp_xml_inject", Type: field.TypeBool, Default: true},
|
||||
{Name: "supported_model_scopes", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "sort_order", Type: field.TypeInt, Default: 0},
|
||||
{Name: "allow_messages_dispatch", Type: field.TypeBool, Default: false},
|
||||
{Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""},
|
||||
}
|
||||
// GroupsTable holds the schema information for the "groups" table.
|
||||
GroupsTable = &schema.Table{
|
||||
|
||||
@@ -5167,6 +5167,7 @@ type AnnouncementMutation struct {
|
||||
title *string
|
||||
content *string
|
||||
status *string
|
||||
notify_mode *string
|
||||
targeting *domain.AnnouncementTargeting
|
||||
starts_at *time.Time
|
||||
ends_at *time.Time
|
||||
@@ -5391,6 +5392,42 @@ func (m *AnnouncementMutation) ResetStatus() {
|
||||
m.status = nil
|
||||
}
|
||||
|
||||
// SetNotifyMode sets the "notify_mode" field.
|
||||
func (m *AnnouncementMutation) SetNotifyMode(s string) {
|
||||
m.notify_mode = &s
|
||||
}
|
||||
|
||||
// NotifyMode returns the value of the "notify_mode" field in the mutation.
|
||||
func (m *AnnouncementMutation) NotifyMode() (r string, exists bool) {
|
||||
v := m.notify_mode
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldNotifyMode returns the old "notify_mode" field's value of the Announcement entity.
|
||||
// If the Announcement object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *AnnouncementMutation) OldNotifyMode(ctx context.Context) (v string, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldNotifyMode is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldNotifyMode requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldNotifyMode: %w", err)
|
||||
}
|
||||
return oldValue.NotifyMode, nil
|
||||
}
|
||||
|
||||
// ResetNotifyMode resets all changes to the "notify_mode" field.
|
||||
func (m *AnnouncementMutation) ResetNotifyMode() {
|
||||
m.notify_mode = nil
|
||||
}
|
||||
|
||||
// SetTargeting sets the "targeting" field.
|
||||
func (m *AnnouncementMutation) SetTargeting(dt domain.AnnouncementTargeting) {
|
||||
m.targeting = &dt
|
||||
@@ -5838,7 +5875,7 @@ func (m *AnnouncementMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *AnnouncementMutation) Fields() []string {
|
||||
fields := make([]string, 0, 10)
|
||||
fields := make([]string, 0, 11)
|
||||
if m.title != nil {
|
||||
fields = append(fields, announcement.FieldTitle)
|
||||
}
|
||||
@@ -5848,6 +5885,9 @@ func (m *AnnouncementMutation) Fields() []string {
|
||||
if m.status != nil {
|
||||
fields = append(fields, announcement.FieldStatus)
|
||||
}
|
||||
if m.notify_mode != nil {
|
||||
fields = append(fields, announcement.FieldNotifyMode)
|
||||
}
|
||||
if m.targeting != nil {
|
||||
fields = append(fields, announcement.FieldTargeting)
|
||||
}
|
||||
@@ -5883,6 +5923,8 @@ func (m *AnnouncementMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.Content()
|
||||
case announcement.FieldStatus:
|
||||
return m.Status()
|
||||
case announcement.FieldNotifyMode:
|
||||
return m.NotifyMode()
|
||||
case announcement.FieldTargeting:
|
||||
return m.Targeting()
|
||||
case announcement.FieldStartsAt:
|
||||
@@ -5912,6 +5954,8 @@ func (m *AnnouncementMutation) OldField(ctx context.Context, name string) (ent.V
|
||||
return m.OldContent(ctx)
|
||||
case announcement.FieldStatus:
|
||||
return m.OldStatus(ctx)
|
||||
case announcement.FieldNotifyMode:
|
||||
return m.OldNotifyMode(ctx)
|
||||
case announcement.FieldTargeting:
|
||||
return m.OldTargeting(ctx)
|
||||
case announcement.FieldStartsAt:
|
||||
@@ -5956,6 +6000,13 @@ func (m *AnnouncementMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetStatus(v)
|
||||
return nil
|
||||
case announcement.FieldNotifyMode:
|
||||
v, ok := value.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetNotifyMode(v)
|
||||
return nil
|
||||
case announcement.FieldTargeting:
|
||||
v, ok := value.(domain.AnnouncementTargeting)
|
||||
if !ok {
|
||||
@@ -6123,6 +6174,9 @@ func (m *AnnouncementMutation) ResetField(name string) error {
|
||||
case announcement.FieldStatus:
|
||||
m.ResetStatus()
|
||||
return nil
|
||||
case announcement.FieldNotifyMode:
|
||||
m.ResetNotifyMode()
|
||||
return nil
|
||||
case announcement.FieldTargeting:
|
||||
m.ResetTargeting()
|
||||
return nil
|
||||
@@ -8196,6 +8250,8 @@ type GroupMutation struct {
|
||||
appendsupported_model_scopes []string
|
||||
sort_order *int
|
||||
addsort_order *int
|
||||
allow_messages_dispatch *bool
|
||||
default_mapped_model *string
|
||||
clearedFields map[string]struct{}
|
||||
api_keys map[int64]struct{}
|
||||
removedapi_keys map[int64]struct{}
|
||||
@@ -9940,6 +9996,78 @@ func (m *GroupMutation) ResetSortOrder() {
|
||||
m.addsort_order = nil
|
||||
}
|
||||
|
||||
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
|
||||
func (m *GroupMutation) SetAllowMessagesDispatch(b bool) {
|
||||
m.allow_messages_dispatch = &b
|
||||
}
|
||||
|
||||
// AllowMessagesDispatch returns the value of the "allow_messages_dispatch" field in the mutation.
|
||||
func (m *GroupMutation) AllowMessagesDispatch() (r bool, exists bool) {
|
||||
v := m.allow_messages_dispatch
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldAllowMessagesDispatch returns the old "allow_messages_dispatch" field's value of the Group entity.
|
||||
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *GroupMutation) OldAllowMessagesDispatch(ctx context.Context) (v bool, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldAllowMessagesDispatch is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldAllowMessagesDispatch requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldAllowMessagesDispatch: %w", err)
|
||||
}
|
||||
return oldValue.AllowMessagesDispatch, nil
|
||||
}
|
||||
|
||||
// ResetAllowMessagesDispatch resets all changes to the "allow_messages_dispatch" field.
|
||||
func (m *GroupMutation) ResetAllowMessagesDispatch() {
|
||||
m.allow_messages_dispatch = nil
|
||||
}
|
||||
|
||||
// SetDefaultMappedModel sets the "default_mapped_model" field.
|
||||
func (m *GroupMutation) SetDefaultMappedModel(s string) {
|
||||
m.default_mapped_model = &s
|
||||
}
|
||||
|
||||
// DefaultMappedModel returns the value of the "default_mapped_model" field in the mutation.
|
||||
func (m *GroupMutation) DefaultMappedModel() (r string, exists bool) {
|
||||
v := m.default_mapped_model
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldDefaultMappedModel returns the old "default_mapped_model" field's value of the Group entity.
|
||||
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *GroupMutation) OldDefaultMappedModel(ctx context.Context) (v string, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldDefaultMappedModel is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldDefaultMappedModel requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldDefaultMappedModel: %w", err)
|
||||
}
|
||||
return oldValue.DefaultMappedModel, nil
|
||||
}
|
||||
|
||||
// ResetDefaultMappedModel resets all changes to the "default_mapped_model" field.
|
||||
func (m *GroupMutation) ResetDefaultMappedModel() {
|
||||
m.default_mapped_model = nil
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
||||
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
|
||||
if m.api_keys == nil {
|
||||
@@ -10298,7 +10426,7 @@ func (m *GroupMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *GroupMutation) Fields() []string {
|
||||
fields := make([]string, 0, 31)
|
||||
fields := make([]string, 0, 32)
|
||||
if m.created_at != nil {
|
||||
fields = append(fields, group.FieldCreatedAt)
|
||||
}
|
||||
@@ -10389,6 +10517,12 @@ func (m *GroupMutation) Fields() []string {
|
||||
if m.sort_order != nil {
|
||||
fields = append(fields, group.FieldSortOrder)
|
||||
}
|
||||
if m.allow_messages_dispatch != nil {
|
||||
fields = append(fields, group.FieldAllowMessagesDispatch)
|
||||
}
|
||||
if m.default_mapped_model != nil {
|
||||
fields = append(fields, group.FieldDefaultMappedModel)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
@@ -10457,6 +10591,10 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.SupportedModelScopes()
|
||||
case group.FieldSortOrder:
|
||||
return m.SortOrder()
|
||||
case group.FieldAllowMessagesDispatch:
|
||||
return m.AllowMessagesDispatch()
|
||||
case group.FieldDefaultMappedModel:
|
||||
return m.DefaultMappedModel()
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
@@ -10526,6 +10664,10 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
|
||||
return m.OldSupportedModelScopes(ctx)
|
||||
case group.FieldSortOrder:
|
||||
return m.OldSortOrder(ctx)
|
||||
case group.FieldAllowMessagesDispatch:
|
||||
return m.OldAllowMessagesDispatch(ctx)
|
||||
case group.FieldDefaultMappedModel:
|
||||
return m.OldDefaultMappedModel(ctx)
|
||||
}
|
||||
return nil, fmt.Errorf("unknown Group field %s", name)
|
||||
}
|
||||
@@ -10745,6 +10887,20 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetSortOrder(v)
|
||||
return nil
|
||||
case group.FieldAllowMessagesDispatch:
|
||||
v, ok := value.(bool)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetAllowMessagesDispatch(v)
|
||||
return nil
|
||||
case group.FieldDefaultMappedModel:
|
||||
v, ok := value.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetDefaultMappedModel(v)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown Group field %s", name)
|
||||
}
|
||||
@@ -11172,6 +11328,12 @@ func (m *GroupMutation) ResetField(name string) error {
|
||||
case group.FieldSortOrder:
|
||||
m.ResetSortOrder()
|
||||
return nil
|
||||
case group.FieldAllowMessagesDispatch:
|
||||
m.ResetAllowMessagesDispatch()
|
||||
return nil
|
||||
case group.FieldDefaultMappedModel:
|
||||
m.ResetDefaultMappedModel()
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown Group field %s", name)
|
||||
}
|
||||
|
||||
@@ -277,12 +277,18 @@ func init() {
|
||||
announcement.DefaultStatus = announcementDescStatus.Default.(string)
|
||||
// announcement.StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
||||
announcement.StatusValidator = announcementDescStatus.Validators[0].(func(string) error)
|
||||
// announcementDescNotifyMode is the schema descriptor for notify_mode field.
|
||||
announcementDescNotifyMode := announcementFields[3].Descriptor()
|
||||
// announcement.DefaultNotifyMode holds the default value on creation for the notify_mode field.
|
||||
announcement.DefaultNotifyMode = announcementDescNotifyMode.Default.(string)
|
||||
// announcement.NotifyModeValidator is a validator for the "notify_mode" field. It is called by the builders before save.
|
||||
announcement.NotifyModeValidator = announcementDescNotifyMode.Validators[0].(func(string) error)
|
||||
// announcementDescCreatedAt is the schema descriptor for created_at field.
|
||||
announcementDescCreatedAt := announcementFields[8].Descriptor()
|
||||
announcementDescCreatedAt := announcementFields[9].Descriptor()
|
||||
// announcement.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||
announcement.DefaultCreatedAt = announcementDescCreatedAt.Default.(func() time.Time)
|
||||
// announcementDescUpdatedAt is the schema descriptor for updated_at field.
|
||||
announcementDescUpdatedAt := announcementFields[9].Descriptor()
|
||||
announcementDescUpdatedAt := announcementFields[10].Descriptor()
|
||||
// announcement.DefaultUpdatedAt holds the default value on creation for the updated_at field.
|
||||
announcement.DefaultUpdatedAt = announcementDescUpdatedAt.Default.(func() time.Time)
|
||||
// announcement.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
|
||||
@@ -447,6 +453,16 @@ func init() {
|
||||
groupDescSortOrder := groupFields[26].Descriptor()
|
||||
// group.DefaultSortOrder holds the default value on creation for the sort_order field.
|
||||
group.DefaultSortOrder = groupDescSortOrder.Default.(int)
|
||||
// groupDescAllowMessagesDispatch is the schema descriptor for allow_messages_dispatch field.
|
||||
groupDescAllowMessagesDispatch := groupFields[27].Descriptor()
|
||||
// group.DefaultAllowMessagesDispatch holds the default value on creation for the allow_messages_dispatch field.
|
||||
group.DefaultAllowMessagesDispatch = groupDescAllowMessagesDispatch.Default.(bool)
|
||||
// groupDescDefaultMappedModel is the schema descriptor for default_mapped_model field.
|
||||
groupDescDefaultMappedModel := groupFields[28].Descriptor()
|
||||
// group.DefaultDefaultMappedModel holds the default value on creation for the default_mapped_model field.
|
||||
group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string)
|
||||
// group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
||||
group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error)
|
||||
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
|
||||
idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields()
|
||||
_ = idempotencyrecordMixinFields0
|
||||
|
||||
@@ -41,6 +41,10 @@ func (Announcement) Fields() []ent.Field {
|
||||
MaxLen(20).
|
||||
Default(domain.AnnouncementStatusDraft).
|
||||
Comment("状态: draft, active, archived"),
|
||||
field.String("notify_mode").
|
||||
MaxLen(20).
|
||||
Default(domain.AnnouncementNotifyModeSilent).
|
||||
Comment("通知模式: silent(仅铃铛), popup(弹窗提醒)"),
|
||||
field.JSON("targeting", domain.AnnouncementTargeting{}).
|
||||
Optional().
|
||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
||||
|
||||
@@ -148,6 +148,15 @@ func (Group) Fields() []ent.Field {
|
||||
field.Int("sort_order").
|
||||
Default(0).
|
||||
Comment("分组显示排序,数值越小越靠前"),
|
||||
|
||||
// OpenAI Messages 调度配置 (added by migration 069)
|
||||
field.Bool("allow_messages_dispatch").
|
||||
Default(false).
|
||||
Comment("是否允许 /v1/messages 调度到此 OpenAI 分组"),
|
||||
field.String("default_mapped_model").
|
||||
MaxLen(100).
|
||||
Default("").
|
||||
Comment("默认映射模型 ID,当账号级映射找不到时使用此值"),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
module github.com/Wei-Shaw/sub2api
|
||||
|
||||
go 1.25.7
|
||||
go 1.26.1
|
||||
|
||||
require (
|
||||
entgo.io/ent v0.14.5
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||
github.com/DouDOU-start/go-sora2api v1.1.0
|
||||
github.com/alitto/pond/v2 v2.6.2
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.2
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.10
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
|
||||
@@ -38,8 +39,6 @@ require (
|
||||
golang.org/x/net v0.49.0
|
||||
golang.org/x/sync v0.19.0
|
||||
golang.org/x/term v0.40.0
|
||||
google.golang.org/grpc v1.75.1
|
||||
google.golang.org/protobuf v1.36.10
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
modernc.org/sqlite v1.44.3
|
||||
@@ -53,7 +52,6 @@ require (
|
||||
github.com/agext/levenshtein v1.2.3 // indirect
|
||||
github.com/andybalholm/brotli v1.2.0 // indirect
|
||||
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.2 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 // indirect
|
||||
@@ -109,7 +107,6 @@ require (
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/go-querystring v1.1.0 // indirect
|
||||
github.com/google/subcommands v1.2.0 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/hashicorp/hcl/v2 v2.18.1 // indirect
|
||||
@@ -169,6 +166,7 @@ require (
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
|
||||
go.opentelemetry.io/otel v1.37.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.37.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk v1.37.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.37.0 // indirect
|
||||
go.uber.org/atomic v1.10.0 // indirect
|
||||
go.uber.org/automaxprocs v1.6.0 // indirect
|
||||
@@ -178,8 +176,8 @@ require (
|
||||
golang.org/x/mod v0.32.0 // indirect
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
golang.org/x/text v0.34.0 // indirect
|
||||
golang.org/x/tools v0.41.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 // indirect
|
||||
google.golang.org/grpc v1.75.1 // indirect
|
||||
google.golang.org/protobuf v1.36.10 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
modernc.org/libc v1.67.6 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
|
||||
@@ -124,8 +124,6 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
|
||||
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
|
||||
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
|
||||
@@ -182,7 +180,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
|
||||
@@ -202,8 +199,6 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
|
||||
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
|
||||
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
|
||||
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
@@ -286,10 +281,6 @@ github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6
|
||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo=
|
||||
github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
|
||||
github.com/pkoukk/tiktoken-go-loader v0.0.2 h1:LUKws63GV3pVHwH1srkBplBv+7URgmOmhSkRxsIvsK4=
|
||||
github.com/pkoukk/tiktoken-go-loader v0.0.2/go.mod h1:4mIkYyZooFlnenDlormIo6cd5wrlUKNr97wp9nGgEKo=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
|
||||
@@ -1402,7 +1402,7 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
|
||||
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
|
||||
viper.SetDefault("gateway.stream_keepalive_interval", 10)
|
||||
viper.SetDefault("gateway.max_line_size", 40*1024*1024)
|
||||
viper.SetDefault("gateway.max_line_size", 500*1024*1024)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
|
||||
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
|
||||
|
||||
@@ -13,6 +13,11 @@ const (
|
||||
AnnouncementStatusArchived = "archived"
|
||||
)
|
||||
|
||||
const (
|
||||
AnnouncementNotifyModeSilent = "silent"
|
||||
AnnouncementNotifyModePopup = "popup"
|
||||
)
|
||||
|
||||
const (
|
||||
AnnouncementConditionTypeSubscription = "subscription"
|
||||
AnnouncementConditionTypeBalance = "balance"
|
||||
@@ -195,17 +200,18 @@ func (c AnnouncementCondition) validate() error {
|
||||
}
|
||||
|
||||
type Announcement struct {
|
||||
ID int64
|
||||
Title string
|
||||
Content string
|
||||
Status string
|
||||
Targeting AnnouncementTargeting
|
||||
StartsAt *time.Time
|
||||
EndsAt *time.Time
|
||||
CreatedBy *int64
|
||||
UpdatedBy *int64
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
ID int64
|
||||
Title string
|
||||
Content string
|
||||
Status string
|
||||
NotifyMode string
|
||||
Targeting AnnouncementTargeting
|
||||
StartsAt *time.Time
|
||||
EndsAt *time.Time
|
||||
CreatedBy *int64
|
||||
UpdatedBy *int64
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
func (a *Announcement) IsActiveAt(now time.Time) bool {
|
||||
|
||||
@@ -122,7 +122,7 @@ type UpdateAccountRequest struct {
|
||||
Priority *int `json:"priority"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
LoadFactor *int `json:"load_factor"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
ExpiresAt *int64 `json:"expires_at"`
|
||||
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
|
||||
@@ -288,48 +288,32 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 窗口费用获取:lite 模式从快照缓存读取,非 lite 模式执行 PostgreSQL 查询后写入缓存
|
||||
// 始终获取窗口费用(PostgreSQL 聚合查询)
|
||||
if len(windowCostAccountIDs) > 0 {
|
||||
if lite {
|
||||
// lite 模式:尝试从快照缓存读取
|
||||
cacheKey := buildWindowCostCacheKey(windowCostAccountIDs)
|
||||
if cached, ok := accountWindowCostCache.Get(cacheKey); ok {
|
||||
if costs, ok := cached.Payload.(map[int64]float64); ok {
|
||||
windowCosts = costs
|
||||
}
|
||||
}
|
||||
// 缓存未命中则 windowCosts 保持 nil(仅发生在服务刚启动时)
|
||||
} else {
|
||||
// 非 lite 模式:执行 PostgreSQL 聚合查询(高开销)
|
||||
windowCosts = make(map[int64]float64)
|
||||
var mu sync.Mutex
|
||||
g, gctx := errgroup.WithContext(c.Request.Context())
|
||||
g.SetLimit(10) // 限制并发数
|
||||
windowCosts = make(map[int64]float64)
|
||||
var mu sync.Mutex
|
||||
g, gctx := errgroup.WithContext(c.Request.Context())
|
||||
g.SetLimit(10) // 限制并发数
|
||||
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if !acc.IsAnthropicOAuthOrSetupToken() || acc.GetWindowCostLimit() <= 0 {
|
||||
continue
|
||||
}
|
||||
accCopy := acc // 闭包捕获
|
||||
g.Go(func() error {
|
||||
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
|
||||
startTime := accCopy.GetCurrentWindowStartTime()
|
||||
stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime)
|
||||
if err == nil && stats != nil {
|
||||
mu.Lock()
|
||||
windowCosts[accCopy.ID] = stats.StandardCost // 使用标准费用
|
||||
mu.Unlock()
|
||||
}
|
||||
return nil // 不返回错误,允许部分失败
|
||||
})
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if !acc.IsAnthropicOAuthOrSetupToken() || acc.GetWindowCostLimit() <= 0 {
|
||||
continue
|
||||
}
|
||||
_ = g.Wait()
|
||||
|
||||
// 查询完毕后写入快照缓存,供 lite 模式使用
|
||||
cacheKey := buildWindowCostCacheKey(windowCostAccountIDs)
|
||||
accountWindowCostCache.Set(cacheKey, windowCosts)
|
||||
accCopy := acc // 闭包捕获
|
||||
g.Go(func() error {
|
||||
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
|
||||
startTime := accCopy.GetCurrentWindowStartTime()
|
||||
stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime)
|
||||
if err == nil && stats != nil {
|
||||
mu.Lock()
|
||||
windowCosts[accCopy.ID] = stats.StandardCost // 使用标准费用
|
||||
mu.Unlock()
|
||||
}
|
||||
return nil // 不返回错误,允许部分失败
|
||||
})
|
||||
}
|
||||
_ = g.Wait()
|
||||
}
|
||||
|
||||
// Build response with concurrency info
|
||||
@@ -676,6 +660,42 @@ func (h *AccountHandler) Test(c *gin.Context) {
|
||||
// Error already sent via SSE, just log
|
||||
return
|
||||
}
|
||||
|
||||
if h.rateLimitService != nil {
|
||||
if _, err := h.rateLimitService.RecoverAccountAfterSuccessfulTest(c.Request.Context(), accountID); err != nil {
|
||||
_ = c.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RecoverState handles unified recovery of recoverable account runtime state.
|
||||
// POST /api/v1/admin/accounts/:id/recover-state
|
||||
func (h *AccountHandler) RecoverState(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
if h.rateLimitService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Rate limit service unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := h.rateLimitService.RecoverAccountState(c.Request.Context(), accountID, service.AccountRecoveryOptions{
|
||||
InvalidateToken: true,
|
||||
}); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// SyncFromCRS handles syncing accounts from claude-relay-service (CRS)
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var accountWindowCostCache = newSnapshotCache(30 * time.Second)
|
||||
|
||||
func buildWindowCostCacheKey(accountIDs []int64) string {
|
||||
if len(accountIDs) == 0 {
|
||||
return "accounts_window_cost_empty"
|
||||
}
|
||||
var b strings.Builder
|
||||
b.Grow(len(accountIDs) * 6)
|
||||
_, _ = b.WriteString("accounts_window_cost:")
|
||||
for i, id := range accountIDs {
|
||||
if i > 0 {
|
||||
_ = b.WriteByte(',')
|
||||
}
|
||||
_, _ = b.WriteString(strconv.FormatInt(id, 10))
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
@@ -27,21 +27,23 @@ func NewAnnouncementHandler(announcementService *service.AnnouncementService) *A
|
||||
}
|
||||
|
||||
type CreateAnnouncementRequest struct {
|
||||
Title string `json:"title" binding:"required"`
|
||||
Content string `json:"content" binding:"required"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=draft active archived"`
|
||||
Targeting service.AnnouncementTargeting `json:"targeting"`
|
||||
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0/empty = immediate
|
||||
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0/empty = never
|
||||
Title string `json:"title" binding:"required"`
|
||||
Content string `json:"content" binding:"required"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=draft active archived"`
|
||||
NotifyMode string `json:"notify_mode" binding:"omitempty,oneof=silent popup"`
|
||||
Targeting service.AnnouncementTargeting `json:"targeting"`
|
||||
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0/empty = immediate
|
||||
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0/empty = never
|
||||
}
|
||||
|
||||
type UpdateAnnouncementRequest struct {
|
||||
Title *string `json:"title"`
|
||||
Content *string `json:"content"`
|
||||
Status *string `json:"status" binding:"omitempty,oneof=draft active archived"`
|
||||
Targeting *service.AnnouncementTargeting `json:"targeting"`
|
||||
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0 = clear
|
||||
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0 = clear
|
||||
Title *string `json:"title"`
|
||||
Content *string `json:"content"`
|
||||
Status *string `json:"status" binding:"omitempty,oneof=draft active archived"`
|
||||
NotifyMode *string `json:"notify_mode" binding:"omitempty,oneof=silent popup"`
|
||||
Targeting *service.AnnouncementTargeting `json:"targeting"`
|
||||
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0 = clear
|
||||
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0 = clear
|
||||
}
|
||||
|
||||
// List handles listing announcements with filters
|
||||
@@ -110,11 +112,12 @@ func (h *AnnouncementHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
input := &service.CreateAnnouncementInput{
|
||||
Title: req.Title,
|
||||
Content: req.Content,
|
||||
Status: req.Status,
|
||||
Targeting: req.Targeting,
|
||||
ActorID: &subject.UserID,
|
||||
Title: req.Title,
|
||||
Content: req.Content,
|
||||
Status: req.Status,
|
||||
NotifyMode: req.NotifyMode,
|
||||
Targeting: req.Targeting,
|
||||
ActorID: &subject.UserID,
|
||||
}
|
||||
|
||||
if req.StartsAt != nil && *req.StartsAt > 0 {
|
||||
@@ -157,11 +160,12 @@ func (h *AnnouncementHandler) Update(c *gin.Context) {
|
||||
}
|
||||
|
||||
input := &service.UpdateAnnouncementInput{
|
||||
Title: req.Title,
|
||||
Content: req.Content,
|
||||
Status: req.Status,
|
||||
Targeting: req.Targeting,
|
||||
ActorID: &subject.UserID,
|
||||
Title: req.Title,
|
||||
Content: req.Content,
|
||||
Status: req.Status,
|
||||
NotifyMode: req.NotifyMode,
|
||||
Targeting: req.Targeting,
|
||||
ActorID: &subject.UserID,
|
||||
}
|
||||
|
||||
if req.StartsAt != nil {
|
||||
|
||||
@@ -53,6 +53,9 @@ type CreateGroupRequest struct {
|
||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
|
||||
DefaultMappedModel string `json:"default_mapped_model"`
|
||||
// 从指定分组复制账号(创建后自动绑定)
|
||||
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||
}
|
||||
@@ -88,6 +91,9 @@ type UpdateGroupRequest struct {
|
||||
SupportedModelScopes *[]string `json:"supported_model_scopes"`
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
AllowMessagesDispatch *bool `json:"allow_messages_dispatch"`
|
||||
DefaultMappedModel *string `json:"default_mapped_model"`
|
||||
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||
}
|
||||
@@ -203,6 +209,8 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
MCPXMLInject: req.MCPXMLInject,
|
||||
SupportedModelScopes: req.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
||||
DefaultMappedModel: req.DefaultMappedModel,
|
||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -254,6 +262,8 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
MCPXMLInject: req.MCPXMLInject,
|
||||
SupportedModelScopes: req.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
||||
DefaultMappedModel: req.DefaultMappedModel,
|
||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -25,6 +25,7 @@ type createScheduledTestPlanRequest struct {
|
||||
CronExpression string `json:"cron_expression" binding:"required"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
MaxResults int `json:"max_results"`
|
||||
AutoRecover *bool `json:"auto_recover"`
|
||||
}
|
||||
|
||||
type updateScheduledTestPlanRequest struct {
|
||||
@@ -32,6 +33,7 @@ type updateScheduledTestPlanRequest struct {
|
||||
CronExpression string `json:"cron_expression"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
MaxResults int `json:"max_results"`
|
||||
AutoRecover *bool `json:"auto_recover"`
|
||||
}
|
||||
|
||||
// ListByAccount GET /admin/accounts/:id/scheduled-test-plans
|
||||
@@ -68,6 +70,9 @@ func (h *ScheduledTestHandler) Create(c *gin.Context) {
|
||||
if req.Enabled != nil {
|
||||
plan.Enabled = *req.Enabled
|
||||
}
|
||||
if req.AutoRecover != nil {
|
||||
plan.AutoRecover = *req.AutoRecover
|
||||
}
|
||||
|
||||
created, err := h.scheduledTestSvc.CreatePlan(c.Request.Context(), plan)
|
||||
if err != nil {
|
||||
@@ -109,6 +114,9 @@ func (h *ScheduledTestHandler) Update(c *gin.Context) {
|
||||
if req.MaxResults > 0 {
|
||||
existing.MaxResults = req.MaxResults
|
||||
}
|
||||
if req.AutoRecover != nil {
|
||||
existing.AutoRecover = *req.AutoRecover
|
||||
}
|
||||
|
||||
updated, err := h.scheduledTestSvc.UpdatePlan(c.Request.Context(), existing)
|
||||
if err != nil {
|
||||
|
||||
@@ -1348,6 +1348,63 @@ func (h *SettingHandler) TestSoraS3Connection(c *gin.Context) {
|
||||
response.Success(c, gin.H{"message": "S3 连接成功"})
|
||||
}
|
||||
|
||||
// GetRectifierSettings 获取请求整流器配置
|
||||
// GET /api/v1/admin/settings/rectifier
|
||||
func (h *SettingHandler) GetRectifierSettings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetRectifierSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RectifierSettings{
|
||||
Enabled: settings.Enabled,
|
||||
ThinkingSignatureEnabled: settings.ThinkingSignatureEnabled,
|
||||
ThinkingBudgetEnabled: settings.ThinkingBudgetEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRectifierSettingsRequest 更新整流器配置请求
|
||||
type UpdateRectifierSettingsRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"`
|
||||
ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"`
|
||||
}
|
||||
|
||||
// UpdateRectifierSettings 更新请求整流器配置
|
||||
// PUT /api/v1/admin/settings/rectifier
|
||||
func (h *SettingHandler) UpdateRectifierSettings(c *gin.Context) {
|
||||
var req UpdateRectifierSettingsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
settings := &service.RectifierSettings{
|
||||
Enabled: req.Enabled,
|
||||
ThinkingSignatureEnabled: req.ThinkingSignatureEnabled,
|
||||
ThinkingBudgetEnabled: req.ThinkingBudgetEnabled,
|
||||
}
|
||||
|
||||
if err := h.settingService.SetRectifierSettings(c.Request.Context(), settings); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 重新获取设置返回
|
||||
updatedSettings, err := h.settingService.GetRectifierSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RectifierSettings{
|
||||
Enabled: updatedSettings.Enabled,
|
||||
ThinkingSignatureEnabled: updatedSettings.ThinkingSignatureEnabled,
|
||||
ThinkingBudgetEnabled: updatedSettings.ThinkingBudgetEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateStreamTimeoutSettingsRequest 更新流超时配置请求
|
||||
type UpdateStreamTimeoutSettingsRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
@@ -7,10 +7,11 @@ import (
|
||||
)
|
||||
|
||||
type Announcement struct {
|
||||
ID int64 `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
Status string `json:"status"`
|
||||
ID int64 `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
Status string `json:"status"`
|
||||
NotifyMode string `json:"notify_mode"`
|
||||
|
||||
Targeting service.AnnouncementTargeting `json:"targeting"`
|
||||
|
||||
@@ -25,9 +26,10 @@ type Announcement struct {
|
||||
}
|
||||
|
||||
type UserAnnouncement struct {
|
||||
ID int64 `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
ID int64 `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
NotifyMode string `json:"notify_mode"`
|
||||
|
||||
StartsAt *time.Time `json:"starts_at,omitempty"`
|
||||
EndsAt *time.Time `json:"ends_at,omitempty"`
|
||||
@@ -43,17 +45,18 @@ func AnnouncementFromService(a *service.Announcement) *Announcement {
|
||||
return nil
|
||||
}
|
||||
return &Announcement{
|
||||
ID: a.ID,
|
||||
Title: a.Title,
|
||||
Content: a.Content,
|
||||
Status: a.Status,
|
||||
Targeting: a.Targeting,
|
||||
StartsAt: a.StartsAt,
|
||||
EndsAt: a.EndsAt,
|
||||
CreatedBy: a.CreatedBy,
|
||||
UpdatedBy: a.UpdatedBy,
|
||||
CreatedAt: a.CreatedAt,
|
||||
UpdatedAt: a.UpdatedAt,
|
||||
ID: a.ID,
|
||||
Title: a.Title,
|
||||
Content: a.Content,
|
||||
Status: a.Status,
|
||||
NotifyMode: a.NotifyMode,
|
||||
Targeting: a.Targeting,
|
||||
StartsAt: a.StartsAt,
|
||||
EndsAt: a.EndsAt,
|
||||
CreatedBy: a.CreatedBy,
|
||||
UpdatedBy: a.UpdatedBy,
|
||||
CreatedAt: a.CreatedAt,
|
||||
UpdatedAt: a.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,13 +65,14 @@ func UserAnnouncementFromService(a *service.UserAnnouncement) *UserAnnouncement
|
||||
return nil
|
||||
}
|
||||
return &UserAnnouncement{
|
||||
ID: a.Announcement.ID,
|
||||
Title: a.Announcement.Title,
|
||||
Content: a.Announcement.Content,
|
||||
StartsAt: a.Announcement.StartsAt,
|
||||
EndsAt: a.Announcement.EndsAt,
|
||||
ReadAt: a.ReadAt,
|
||||
CreatedAt: a.Announcement.CreatedAt,
|
||||
UpdatedAt: a.Announcement.UpdatedAt,
|
||||
ID: a.Announcement.ID,
|
||||
Title: a.Announcement.Title,
|
||||
Content: a.Announcement.Content,
|
||||
NotifyMode: a.Announcement.NotifyMode,
|
||||
StartsAt: a.Announcement.StartsAt,
|
||||
EndsAt: a.Announcement.EndsAt,
|
||||
ReadAt: a.ReadAt,
|
||||
CreatedAt: a.Announcement.CreatedAt,
|
||||
UpdatedAt: a.Announcement.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,7 +71,7 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
|
||||
if k == nil {
|
||||
return nil
|
||||
}
|
||||
return &APIKey{
|
||||
out := &APIKey{
|
||||
ID: k.ID,
|
||||
UserID: k.UserID,
|
||||
Key: k.Key,
|
||||
@@ -89,15 +89,28 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
|
||||
RateLimit5h: k.RateLimit5h,
|
||||
RateLimit1d: k.RateLimit1d,
|
||||
RateLimit7d: k.RateLimit7d,
|
||||
Usage5h: k.Usage5h,
|
||||
Usage1d: k.Usage1d,
|
||||
Usage7d: k.Usage7d,
|
||||
Usage5h: k.EffectiveUsage5h(),
|
||||
Usage1d: k.EffectiveUsage1d(),
|
||||
Usage7d: k.EffectiveUsage7d(),
|
||||
Window5hStart: k.Window5hStart,
|
||||
Window1dStart: k.Window1dStart,
|
||||
Window7dStart: k.Window7dStart,
|
||||
User: UserFromServiceShallow(k.User),
|
||||
Group: GroupFromServiceShallow(k.Group),
|
||||
}
|
||||
if k.Window5hStart != nil && !service.IsWindowExpired(k.Window5hStart, service.RateLimitWindow5h) {
|
||||
t := k.Window5hStart.Add(service.RateLimitWindow5h)
|
||||
out.Reset5hAt = &t
|
||||
}
|
||||
if k.Window1dStart != nil && !service.IsWindowExpired(k.Window1dStart, service.RateLimitWindow1d) {
|
||||
t := k.Window1dStart.Add(service.RateLimitWindow1d)
|
||||
out.Reset1dAt = &t
|
||||
}
|
||||
if k.Window7dStart != nil && !service.IsWindowExpired(k.Window7dStart, service.RateLimitWindow7d) {
|
||||
t := k.Window7dStart.Add(service.RateLimitWindow7d)
|
||||
out.Reset7dAt = &t
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func GroupFromServiceShallow(g *service.Group) *Group {
|
||||
@@ -126,6 +139,7 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
MCPXMLInject: g.MCPXMLInject,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
AccountCount: g.AccountCount,
|
||||
SortOrder: g.SortOrder,
|
||||
@@ -164,6 +178,7 @@ func groupFromServiceBase(g *service.Group) Group {
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
|
||||
SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
|
||||
AllowMessagesDispatch: g.AllowMessagesDispatch,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
@@ -253,11 +268,19 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
if a.Type == service.AccountTypeAPIKey {
|
||||
if limit := a.GetQuotaLimit(); limit > 0 {
|
||||
out.QuotaLimit = &limit
|
||||
}
|
||||
used := a.GetQuotaUsed()
|
||||
if out.QuotaLimit != nil {
|
||||
used := a.GetQuotaUsed()
|
||||
out.QuotaUsed = &used
|
||||
}
|
||||
if limit := a.GetQuotaDailyLimit(); limit > 0 {
|
||||
out.QuotaDailyLimit = &limit
|
||||
used := a.GetQuotaDailyUsed()
|
||||
out.QuotaDailyUsed = &used
|
||||
}
|
||||
if limit := a.GetQuotaWeeklyLimit(); limit > 0 {
|
||||
out.QuotaWeeklyLimit = &limit
|
||||
used := a.GetQuotaWeeklyUsed()
|
||||
out.QuotaWeeklyUsed = &used
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
@@ -473,6 +496,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
AccountID: l.AccountID,
|
||||
RequestID: l.RequestID,
|
||||
Model: l.Model,
|
||||
ServiceTier: l.ServiceTier,
|
||||
ReasoningEffort: l.ReasoningEffort,
|
||||
GroupID: l.GroupID,
|
||||
SubscriptionID: l.SubscriptionID,
|
||||
|
||||
@@ -71,3 +71,29 @@ func TestRequestTypeStringPtrNil(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Nil(t, requestTypeStringPtr(nil))
|
||||
}
|
||||
|
||||
func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serviceTier := "priority"
|
||||
log := &service.UsageLog{
|
||||
RequestID: "req_3",
|
||||
Model: "gpt-5.4",
|
||||
ServiceTier: &serviceTier,
|
||||
AccountRateMultiplier: f64Ptr(1.5),
|
||||
}
|
||||
|
||||
userDTO := UsageLogFromService(log)
|
||||
adminDTO := UsageLogFromServiceAdmin(log)
|
||||
|
||||
require.NotNil(t, userDTO.ServiceTier)
|
||||
require.Equal(t, serviceTier, *userDTO.ServiceTier)
|
||||
require.NotNil(t, adminDTO.ServiceTier)
|
||||
require.Equal(t, serviceTier, *adminDTO.ServiceTier)
|
||||
require.NotNil(t, adminDTO.AccountRateMultiplier)
|
||||
require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
func f64Ptr(value float64) *float64 {
|
||||
return &value
|
||||
}
|
||||
|
||||
@@ -161,6 +161,13 @@ type StreamTimeoutSettings struct {
|
||||
ThresholdWindowMinutes int `json:"threshold_window_minutes"`
|
||||
}
|
||||
|
||||
// RectifierSettings 请求整流器配置 DTO
|
||||
type RectifierSettings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"`
|
||||
ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"`
|
||||
}
|
||||
|
||||
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
|
||||
// Returns empty slice on empty/invalid input.
|
||||
func ParseCustomMenuItems(raw string) []CustomMenuItem {
|
||||
|
||||
@@ -57,6 +57,9 @@ type APIKey struct {
|
||||
Window5hStart *time.Time `json:"window_5h_start"`
|
||||
Window1dStart *time.Time `json:"window_1d_start"`
|
||||
Window7dStart *time.Time `json:"window_7d_start"`
|
||||
Reset5hAt *time.Time `json:"reset_5h_at,omitempty"`
|
||||
Reset1dAt *time.Time `json:"reset_1d_at,omitempty"`
|
||||
Reset7dAt *time.Time `json:"reset_7d_at,omitempty"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
@@ -96,6 +99,9 @@ type Group struct {
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||
|
||||
// OpenAI Messages 调度开关(用户侧需要此字段判断是否展示 Claude Code 教程)
|
||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
@@ -112,6 +118,9 @@ type AdminGroup struct {
|
||||
// MCP XML 协议注入(仅 antigravity 平台使用)
|
||||
MCPXMLInject bool `json:"mcp_xml_inject"`
|
||||
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
DefaultMappedModel string `json:"default_mapped_model"`
|
||||
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
@@ -187,8 +196,12 @@ type Account struct {
|
||||
CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"`
|
||||
|
||||
// API Key 账号配额限制
|
||||
QuotaLimit *float64 `json:"quota_limit,omitempty"`
|
||||
QuotaUsed *float64 `json:"quota_used,omitempty"`
|
||||
QuotaLimit *float64 `json:"quota_limit,omitempty"`
|
||||
QuotaUsed *float64 `json:"quota_used,omitempty"`
|
||||
QuotaDailyLimit *float64 `json:"quota_daily_limit,omitempty"`
|
||||
QuotaDailyUsed *float64 `json:"quota_daily_used,omitempty"`
|
||||
QuotaWeeklyLimit *float64 `json:"quota_weekly_limit,omitempty"`
|
||||
QuotaWeeklyUsed *float64 `json:"quota_weekly_used,omitempty"`
|
||||
|
||||
Proxy *Proxy `json:"proxy,omitempty"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
@@ -309,6 +322,8 @@ type UsageLog struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
RequestID string `json:"request_id"`
|
||||
Model string `json:"model"`
|
||||
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
|
||||
ServiceTier *string `json:"service_tier,omitempty"`
|
||||
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API).
|
||||
// nil means not provided / not applicable.
|
||||
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
||||
|
||||
@@ -30,7 +30,7 @@ const (
|
||||
|
||||
const (
|
||||
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
|
||||
maxSameAccountRetries = 2
|
||||
maxSameAccountRetries = 3
|
||||
// sameAccountRetryDelay 同账号重试间隔
|
||||
sameAccountRetryDelay = 500 * time.Millisecond
|
||||
// singleAccountBackoffDelay 单账号分组 503 退避重试固定延时。
|
||||
|
||||
@@ -291,35 +291,31 @@ func TestHandleFailoverError_SameAccountRetry(t *testing.T) {
|
||||
require.Less(t, elapsed, 2*time.Second)
|
||||
})
|
||||
|
||||
t.Run("第二次重试仍返回FailoverContinue", func(t *testing.T) {
|
||||
t.Run("达到最大重试次数前均返回FailoverContinue", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 第一次
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[100])
|
||||
for i := 1; i <= maxSameAccountRetries; i++ {
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, i, fs.SameAccountRetryCount[100])
|
||||
}
|
||||
|
||||
// 第二次
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 2, fs.SameAccountRetryCount[100])
|
||||
|
||||
require.Empty(t, mock.calls, "两次重试期间均不应调用 TempUnschedule")
|
||||
require.Empty(t, mock.calls, "达到最大重试次数前均不应调用 TempUnschedule")
|
||||
})
|
||||
|
||||
t.Run("第三次重试耗尽_触发TempUnschedule并切换", func(t *testing.T) {
|
||||
t.Run("超过最大重试次数后触发TempUnschedule并切换", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 第一次、第二次重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, 2, fs.SameAccountRetryCount[100])
|
||||
for i := 0; i < maxSameAccountRetries; i++ {
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
}
|
||||
require.Equal(t, maxSameAccountRetries, fs.SameAccountRetryCount[100])
|
||||
|
||||
// 第三次:重试已达到 maxSameAccountRetries(2),应切换账号
|
||||
// 第 maxSameAccountRetries+1 次:重试耗尽,应切换账号
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SwitchCount)
|
||||
@@ -354,13 +350,14 @@ func TestHandleFailoverError_SameAccountRetry(t *testing.T) {
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 耗尽账号 100 的重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
// 第三次: 重试耗尽 → 切换
|
||||
for i := 0; i < maxSameAccountRetries; i++ {
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
}
|
||||
// 第 maxSameAccountRetries+1 次: 重试耗尽 → 切换
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
|
||||
// 再次遇到账号 100,计数仍为 2,条件不满足 → 直接切换
|
||||
// 再次遇到账号 100,计数仍为 maxSameAccountRetries,条件不满足 → 直接切换
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Len(t, mock.calls, 2, "第二次耗尽也应调用 TempUnschedule")
|
||||
@@ -386,9 +383,10 @@ func TestHandleFailoverError_TempUnschedule(t *testing.T) {
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(502, true, false)
|
||||
|
||||
// 耗尽重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
for i := 0; i < maxSameAccountRetries; i++ {
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
}
|
||||
// 再次触发时才会执行 TempUnschedule + 切换
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
|
||||
require.Len(t, mock.calls, 1)
|
||||
@@ -521,17 +519,16 @@ func TestHandleFailoverError_IntegrationScenario(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, true) // hasBoundSession=true
|
||||
|
||||
// 1. 账号 100 遇到可重试错误,同账号重试 2 次
|
||||
// 1. 账号 100 遇到可重试错误,同账号重试 maxSameAccountRetries 次
|
||||
retryErr := newTestFailoverErr(400, true, false)
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
for i := 0; i < maxSameAccountRetries; i++ {
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
}
|
||||
require.True(t, fs.ForceCacheBilling, "hasBoundSession=true 应设置 ForceCacheBilling")
|
||||
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
|
||||
// 2. 账号 100 重试耗尽 → TempUnschedule + 切换
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
// 2. 账号 100 超过重试上限 → TempUnschedule + 切换
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SwitchCount)
|
||||
require.Len(t, mock.calls, 1)
|
||||
|
||||
@@ -971,34 +971,46 @@ func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context,
|
||||
if err == nil && rateLimitData != nil {
|
||||
var rateLimits []gin.H
|
||||
if apiKey.RateLimit5h > 0 {
|
||||
used := rateLimitData.Usage5h
|
||||
rateLimits = append(rateLimits, gin.H{
|
||||
used := rateLimitData.EffectiveUsage5h()
|
||||
entry := gin.H{
|
||||
"window": "5h",
|
||||
"limit": apiKey.RateLimit5h,
|
||||
"used": used,
|
||||
"remaining": max(0, apiKey.RateLimit5h-used),
|
||||
"window_start": rateLimitData.Window5hStart,
|
||||
})
|
||||
}
|
||||
if rateLimitData.Window5hStart != nil && !service.IsWindowExpired(rateLimitData.Window5hStart, service.RateLimitWindow5h) {
|
||||
entry["reset_at"] = rateLimitData.Window5hStart.Add(service.RateLimitWindow5h)
|
||||
}
|
||||
rateLimits = append(rateLimits, entry)
|
||||
}
|
||||
if apiKey.RateLimit1d > 0 {
|
||||
used := rateLimitData.Usage1d
|
||||
rateLimits = append(rateLimits, gin.H{
|
||||
used := rateLimitData.EffectiveUsage1d()
|
||||
entry := gin.H{
|
||||
"window": "1d",
|
||||
"limit": apiKey.RateLimit1d,
|
||||
"used": used,
|
||||
"remaining": max(0, apiKey.RateLimit1d-used),
|
||||
"window_start": rateLimitData.Window1dStart,
|
||||
})
|
||||
}
|
||||
if rateLimitData.Window1dStart != nil && !service.IsWindowExpired(rateLimitData.Window1dStart, service.RateLimitWindow1d) {
|
||||
entry["reset_at"] = rateLimitData.Window1dStart.Add(service.RateLimitWindow1d)
|
||||
}
|
||||
rateLimits = append(rateLimits, entry)
|
||||
}
|
||||
if apiKey.RateLimit7d > 0 {
|
||||
used := rateLimitData.Usage7d
|
||||
rateLimits = append(rateLimits, gin.H{
|
||||
used := rateLimitData.EffectiveUsage7d()
|
||||
entry := gin.H{
|
||||
"window": "7d",
|
||||
"limit": apiKey.RateLimit7d,
|
||||
"used": used,
|
||||
"remaining": max(0, apiKey.RateLimit7d-used),
|
||||
"window_start": rateLimitData.Window7dStart,
|
||||
})
|
||||
}
|
||||
if rateLimitData.Window7dStart != nil && !service.IsWindowExpired(rateLimitData.Window7dStart, service.RateLimitWindow7d) {
|
||||
entry["reset_at"] = rateLimitData.Window7dStart.Add(service.RateLimitWindow7d)
|
||||
}
|
||||
rateLimits = append(rateLimits, entry)
|
||||
}
|
||||
if len(rateLimits) > 0 {
|
||||
resp["rate_limits"] = rateLimits
|
||||
|
||||
@@ -155,6 +155,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
|
||||
nil, // sessionLimitCache
|
||||
nil, // rpmCache
|
||||
nil, // digestStore
|
||||
nil, // settingService
|
||||
)
|
||||
|
||||
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@@ -118,6 +119,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
sessionHashBody := body
|
||||
if service.IsOpenAIResponsesCompactPathForTest(c) {
|
||||
if compactSeed := strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()); compactSeed != "" {
|
||||
c.Set(service.OpenAICompactSessionSeedKeyForTest(), compactSeed)
|
||||
}
|
||||
normalizedCompactBody, normalizedCompact, compactErr := service.NormalizeOpenAICompactRequestBodyForTest(body)
|
||||
if compactErr != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to normalize compact request body")
|
||||
return
|
||||
}
|
||||
if normalizedCompact {
|
||||
body = normalizedCompactBody
|
||||
}
|
||||
}
|
||||
|
||||
// 校验请求体 JSON 合法性
|
||||
if !gjson.ValidBytes(body) {
|
||||
@@ -193,11 +208,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Generate session hash (header first; fallback to prompt_cache_key)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, sessionHashBody)
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int)
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
@@ -245,6 +261,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
zap.Float64("load_skew", scheduleDecision.LoadSkew),
|
||||
)
|
||||
account := selection.Account
|
||||
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
|
||||
reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
@@ -274,6 +291,25 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// 池模式:同账号重试
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
@@ -305,6 +341,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
if result != nil {
|
||||
if account.Type == service.AccountTypeOAuth {
|
||||
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(c.Request.Context(), account.ID, result.ResponseHeaders)
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
} else {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
|
||||
@@ -424,6 +463,352 @@ func (h *OpenAIGatewayHandler) logOpenAIRemoteCompactOutcome(c *gin.Context, sta
|
||||
log.Warn("codex.remote_compact.failed")
|
||||
}
|
||||
|
||||
// Messages handles Anthropic Messages API requests routed to OpenAI platform.
|
||||
// POST /v1/messages (when group platform is OpenAI)
|
||||
func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
streamStarted := false
|
||||
defer h.recoverAnthropicMessagesPanic(c, &streamStarted)
|
||||
|
||||
requestStart := time.Now()
|
||||
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.anthropicErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.anthropicErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.openai_gateway.messages",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
// 检查分组是否允许 /v1/messages 调度
|
||||
if apiKey.Group != nil && !apiKey.Group.AllowMessagesDispatch {
|
||||
h.anthropicErrorResponse(c, http.StatusForbidden, "permission_error",
|
||||
"This group does not allow /v1/messages dispatch")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.ensureResponsesDependencies(c, reqLog) {
|
||||
return
|
||||
}
|
||||
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.anthropicErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.anthropicErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
if len(body) == 0 {
|
||||
h.anthropicErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
if !gjson.ValidBytes(body) {
|
||||
h.anthropicErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
modelResult := gjson.GetBytes(body, "model")
|
||||
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
|
||||
h.anthropicErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return
|
||||
}
|
||||
reqModel := modelResult.String()
|
||||
reqStream := gjson.GetBytes(body, "stream").Bool()
|
||||
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
|
||||
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
}
|
||||
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
||||
routingStart := time.Now()
|
||||
|
||||
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("openai_messages.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.anthropicStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
|
||||
promptCacheKey := h.gatewayService.ExtractSessionID(c, body)
|
||||
|
||||
// Anthropic 格式的请求在 metadata.user_id 中携带 session 标识,
|
||||
// 而非 OpenAI 的 session_id/conversation_id headers。
|
||||
// 从中派生 sessionHash(sticky session)和 promptCacheKey(upstream cache)。
|
||||
if sessionHash == "" || promptCacheKey == "" {
|
||||
if userID := strings.TrimSpace(gjson.GetBytes(body, "metadata.user_id").String()); userID != "" {
|
||||
seed := reqModel + "-" + userID
|
||||
if promptCacheKey == "" {
|
||||
promptCacheKey = service.GenerateSessionUUID(seed)
|
||||
}
|
||||
if sessionHash == "" {
|
||||
sessionHash = service.DeriveSessionHashFromSeed(seed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int)
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
// 清除上一次迭代的降级模型标记,避免残留影响本次迭代
|
||||
c.Set("openai_messages_fallback_model", "")
|
||||
reqLog.Debug("openai_messages.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
"", // no previous_response_id
|
||||
sessionHash,
|
||||
reqModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai_messages.account_select_failed",
|
||||
zap.Error(err),
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
// 首次调度失败 + 有默认映射模型 → 用默认模型重试
|
||||
if len(failedAccountIDs) == 0 {
|
||||
defaultModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
if defaultModel != "" && defaultModel != reqModel {
|
||||
reqLog.Info("openai_messages.fallback_to_default_model",
|
||||
zap.String("default_mapped_model", defaultModel),
|
||||
)
|
||||
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
"",
|
||||
sessionHash,
|
||||
defaultModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
if err == nil && selection != nil {
|
||||
c.Set("openai_messages_fallback_model", defaultModel)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if lastFailoverErr != nil {
|
||||
h.handleAnthropicFailoverExhausted(c, lastFailoverErr, streamStarted)
|
||||
} else {
|
||||
h.anthropicStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
|
||||
reqLog.Debug("openai_messages.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||
_ = scheduleDecision
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
defaultMappedModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultMappedModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
// 如果使用了降级模型调度,强制使用降级模型
|
||||
if fallbackModel := c.GetString("openai_messages_fallback_model"); fallbackModel != "" {
|
||||
defaultMappedModel = fallbackModel
|
||||
}
|
||||
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||
responseLatencyMs := forwardDurationMs
|
||||
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
|
||||
}
|
||||
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
|
||||
if err == nil && result != nil && result.FirstTokenMs != nil {
|
||||
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// 池模式:同账号重试
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai_messages.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai_messages.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureAnthropicErrorResponse(c, streamStarted)
|
||||
reqLog.Warn("openai_messages.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
if result != nil {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
} else {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
|
||||
}
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.messages"),
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64("account_id", account.ID),
|
||||
).Error("openai_messages.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
})
|
||||
reqLog.Debug("openai_messages.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("switch_count", switchCount),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// anthropicErrorResponse writes an error in Anthropic Messages API format.
|
||||
func (h *OpenAIGatewayHandler) anthropicErrorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// anthropicStreamingAwareError handles errors that may occur during streaming,
|
||||
// using Anthropic SSE error format.
|
||||
func (h *OpenAIGatewayHandler) anthropicStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if ok {
|
||||
errPayload, _ := json.Marshal(gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", errPayload) //nolint:errcheck
|
||||
flusher.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
h.anthropicErrorResponse(c, status, errType, message)
|
||||
}
|
||||
|
||||
// handleAnthropicFailoverExhausted maps upstream failover errors to Anthropic format.
|
||||
func (h *OpenAIGatewayHandler) handleAnthropicFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) {
|
||||
status, errType, errMsg := h.mapUpstreamError(failoverErr.StatusCode)
|
||||
h.anthropicStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
}
|
||||
|
||||
// ensureAnthropicErrorResponse writes a fallback Anthropic error if no response was written.
|
||||
func (h *OpenAIGatewayHandler) ensureAnthropicErrorResponse(c *gin.Context, streamStarted bool) bool {
|
||||
if c == nil || c.Writer == nil || c.Writer.Written() {
|
||||
return false
|
||||
}
|
||||
h.anthropicStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted)
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, body []byte, reqLog *zap.Logger) bool {
|
||||
if !gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
|
||||
return true
|
||||
@@ -840,6 +1225,9 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
if turnErr != nil || result == nil {
|
||||
return
|
||||
}
|
||||
if account.Type == service.AccountTypeOAuth {
|
||||
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders)
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
h.submitUsageRecordTask(func(taskCtx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
||||
@@ -901,6 +1289,26 @@ func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStart
|
||||
)
|
||||
}
|
||||
|
||||
// recoverAnthropicMessagesPanic recovers from panics in the Anthropic Messages
|
||||
// handler and returns an Anthropic-formatted error response.
|
||||
func (h *OpenAIGatewayHandler) recoverAnthropicMessagesPanic(c *gin.Context, streamStarted *bool) {
|
||||
recovered := recover()
|
||||
if recovered == nil {
|
||||
return
|
||||
}
|
||||
|
||||
started := streamStarted != nil && *streamStarted
|
||||
requestLogger(c, "handler.openai_gateway.messages").Error(
|
||||
"openai.messages_panic_recovered",
|
||||
zap.Bool("stream_started", started),
|
||||
zap.Any("panic", recovered),
|
||||
zap.ByteString("stack", debug.Stack()),
|
||||
)
|
||||
if !started {
|
||||
h.anthropicErrorResponse(c, http.StatusInternalServerError, "api_error", "Internal server error")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) ensureResponsesDependencies(c *gin.Context, reqLog *zap.Logger) bool {
|
||||
missing := h.missingResponsesDependencies()
|
||||
if len(missing) == 0 {
|
||||
@@ -1106,6 +1514,14 @@ func setOpenAIClientTransportWS(c *gin.Context) {
|
||||
service.SetOpenAIClientTransport(c, service.OpenAIClientTransportWS)
|
||||
}
|
||||
|
||||
func ensureOpenAIPoolModeSessionHash(sessionHash string, account *service.Account) string {
|
||||
if sessionHash != "" || account == nil || !account.IsPoolMode() {
|
||||
return sessionHash
|
||||
}
|
||||
// 为当前请求生成一次性粘性会话键,确保同账号重试不会重新负载均衡到其他账号。
|
||||
return "openai-pool-retry-" + uuid.NewString()
|
||||
}
|
||||
|
||||
func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64) string {
|
||||
gid := int64(0)
|
||||
if groupID != nil {
|
||||
|
||||
@@ -2207,7 +2207,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac
|
||||
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
|
||||
return service.NewGatewayService(
|
||||
accountRepo, nil, nil, nil, nil, nil, nil, nil,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -445,6 +445,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
|
||||
testutil.StubSessionLimitCache{},
|
||||
nil, // rpmCache
|
||||
nil, // digestStore
|
||||
nil, // settingService
|
||||
)
|
||||
|
||||
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}
|
||||
|
||||
@@ -49,8 +49,8 @@ const (
|
||||
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
)
|
||||
|
||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.19.6
|
||||
var defaultUserAgentVersion = "1.19.6"
|
||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.4
|
||||
var defaultUserAgentVersion = "1.20.4"
|
||||
|
||||
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
||||
var defaultClientSecret = "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"
|
||||
|
||||
@@ -690,7 +690,7 @@ func TestConstants_值正确(t *testing.T) {
|
||||
if RedirectURI != "http://localhost:8085/callback" {
|
||||
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
|
||||
}
|
||||
if GetUserAgent() != "antigravity/1.19.6 windows/amd64" {
|
||||
if GetUserAgent() != "antigravity/1.20.4 windows/amd64" {
|
||||
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
|
||||
}
|
||||
if SessionTTL != 30*time.Minute {
|
||||
|
||||
@@ -119,23 +119,33 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// Finish 结束处理,返回最终事件和用量
|
||||
// Finish 结束处理,返回最终事件和用量。
|
||||
// 若整个流未收到任何可解析的上游数据(messageStartSent == false),
|
||||
// 则不补发任何结束事件,防止客户端收到没有 message_start 的残缺流。
|
||||
func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) {
|
||||
var result bytes.Buffer
|
||||
|
||||
if !p.messageStopSent {
|
||||
_, _ = result.Write(p.emitFinish(""))
|
||||
}
|
||||
|
||||
usage := &ClaudeUsage{
|
||||
InputTokens: p.inputTokens,
|
||||
OutputTokens: p.outputTokens,
|
||||
CacheReadInputTokens: p.cacheReadTokens,
|
||||
}
|
||||
|
||||
if !p.messageStartSent {
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
var result bytes.Buffer
|
||||
if !p.messageStopSent {
|
||||
_, _ = result.Write(p.emitFinish(""))
|
||||
}
|
||||
|
||||
return result.Bytes(), usage
|
||||
}
|
||||
|
||||
// MessageStartSent 报告流中是否已发出过 message_start 事件(即是否收到过有效的上游数据)
|
||||
func (p *StreamingProcessor) MessageStartSent() bool {
|
||||
return p.messageStartSent
|
||||
}
|
||||
|
||||
// emitMessageStart 发送 message_start 事件
|
||||
func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte {
|
||||
if p.messageStartSent {
|
||||
|
||||
1009
backend/internal/pkg/apicompat/anthropic_responses_test.go
Normal file
1009
backend/internal/pkg/apicompat/anthropic_responses_test.go
Normal file
File diff suppressed because it is too large
Load Diff
417
backend/internal/pkg/apicompat/anthropic_to_responses.go
Normal file
417
backend/internal/pkg/apicompat/anthropic_to_responses.go
Normal file
@@ -0,0 +1,417 @@
|
||||
package apicompat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// AnthropicToResponses converts an Anthropic Messages request directly into
|
||||
// a Responses API request. This preserves fields that would be lost in a
|
||||
// Chat Completions intermediary round-trip (e.g. thinking, cache_control,
|
||||
// structured system prompts).
|
||||
func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) {
|
||||
input, err := convertAnthropicToResponsesInput(req.System, req.Messages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inputJSON, err := json.Marshal(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := &ResponsesRequest{
|
||||
Model: req.Model,
|
||||
Input: inputJSON,
|
||||
Temperature: req.Temperature,
|
||||
TopP: req.TopP,
|
||||
Stream: req.Stream,
|
||||
Include: []string{"reasoning.encrypted_content"},
|
||||
}
|
||||
|
||||
storeFalse := false
|
||||
out.Store = &storeFalse
|
||||
|
||||
if req.MaxTokens > 0 {
|
||||
v := req.MaxTokens
|
||||
if v < minMaxOutputTokens {
|
||||
v = minMaxOutputTokens
|
||||
}
|
||||
out.MaxOutputTokens = &v
|
||||
}
|
||||
|
||||
if len(req.Tools) > 0 {
|
||||
out.Tools = convertAnthropicToolsToResponses(req.Tools)
|
||||
}
|
||||
|
||||
// Determine reasoning effort: only output_config.effort controls the
|
||||
// level; thinking.type is ignored. Default is xhigh when unset.
|
||||
// Anthropic levels map to OpenAI: low→low, medium→high, high→xhigh.
|
||||
effort := "high" // default → maps to xhigh
|
||||
if req.OutputConfig != nil && req.OutputConfig.Effort != "" {
|
||||
effort = req.OutputConfig.Effort
|
||||
}
|
||||
out.Reasoning = &ResponsesReasoning{
|
||||
Effort: mapAnthropicEffortToResponses(effort),
|
||||
Summary: "auto",
|
||||
}
|
||||
|
||||
// Convert tool_choice
|
||||
if len(req.ToolChoice) > 0 {
|
||||
tc, err := convertAnthropicToolChoiceToResponses(req.ToolChoice)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert tool_choice: %w", err)
|
||||
}
|
||||
out.ToolChoice = tc
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// convertAnthropicToolChoiceToResponses maps Anthropic tool_choice to Responses format.
|
||||
//
|
||||
// {"type":"auto"} → "auto"
|
||||
// {"type":"any"} → "required"
|
||||
// {"type":"none"} → "none"
|
||||
// {"type":"tool","name":"X"} → {"type":"function","function":{"name":"X"}}
|
||||
func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage, error) {
|
||||
var tc struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &tc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch tc.Type {
|
||||
case "auto":
|
||||
return json.Marshal("auto")
|
||||
case "any":
|
||||
return json.Marshal("required")
|
||||
case "none":
|
||||
return json.Marshal("none")
|
||||
case "tool":
|
||||
return json.Marshal(map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]string{"name": tc.Name},
|
||||
})
|
||||
default:
|
||||
// Pass through unknown types as-is
|
||||
return raw, nil
|
||||
}
|
||||
}
|
||||
|
||||
// convertAnthropicToResponsesInput builds the Responses API input items array
|
||||
// from the Anthropic system field and message list.
|
||||
func convertAnthropicToResponsesInput(system json.RawMessage, msgs []AnthropicMessage) ([]ResponsesInputItem, error) {
|
||||
var out []ResponsesInputItem
|
||||
|
||||
// System prompt → system role input item.
|
||||
if len(system) > 0 {
|
||||
sysText, err := parseAnthropicSystemPrompt(system)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if sysText != "" {
|
||||
content, _ := json.Marshal(sysText)
|
||||
out = append(out, ResponsesInputItem{
|
||||
Role: "system",
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for _, m := range msgs {
|
||||
items, err := anthropicMsgToResponsesItems(m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, items...)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// parseAnthropicSystemPrompt handles the Anthropic system field which can be
|
||||
// a plain string or an array of text blocks.
|
||||
func parseAnthropicSystemPrompt(raw json.RawMessage) (string, error) {
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return s, nil
|
||||
}
|
||||
var blocks []AnthropicContentBlock
|
||||
if err := json.Unmarshal(raw, &blocks); err != nil {
|
||||
return "", err
|
||||
}
|
||||
var parts []string
|
||||
for _, b := range blocks {
|
||||
if b.Type == "text" && b.Text != "" {
|
||||
parts = append(parts, b.Text)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n\n"), nil
|
||||
}
|
||||
|
||||
// anthropicMsgToResponsesItems converts a single Anthropic message into one
|
||||
// or more Responses API input items.
|
||||
func anthropicMsgToResponsesItems(m AnthropicMessage) ([]ResponsesInputItem, error) {
|
||||
switch m.Role {
|
||||
case "user":
|
||||
return anthropicUserToResponses(m.Content)
|
||||
case "assistant":
|
||||
return anthropicAssistantToResponses(m.Content)
|
||||
default:
|
||||
return anthropicUserToResponses(m.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// anthropicUserToResponses handles an Anthropic user message. Content can be a
|
||||
// plain string or an array of blocks. tool_result blocks are extracted into
|
||||
// function_call_output items. Image blocks are converted to input_image parts.
|
||||
func anthropicUserToResponses(raw json.RawMessage) ([]ResponsesInputItem, error) {
|
||||
// Try plain string.
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
content, _ := json.Marshal(s)
|
||||
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
|
||||
}
|
||||
|
||||
var blocks []AnthropicContentBlock
|
||||
if err := json.Unmarshal(raw, &blocks); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var out []ResponsesInputItem
|
||||
var toolResultImageParts []ResponsesContentPart
|
||||
|
||||
// Extract tool_result blocks → function_call_output items.
|
||||
// Images inside tool_results are extracted separately because the
|
||||
// Responses API function_call_output.output only accepts strings.
|
||||
for _, b := range blocks {
|
||||
if b.Type != "tool_result" {
|
||||
continue
|
||||
}
|
||||
outputText, imageParts := convertToolResultOutput(b)
|
||||
out = append(out, ResponsesInputItem{
|
||||
Type: "function_call_output",
|
||||
CallID: toResponsesCallID(b.ToolUseID),
|
||||
Output: outputText,
|
||||
})
|
||||
toolResultImageParts = append(toolResultImageParts, imageParts...)
|
||||
}
|
||||
|
||||
// Remaining text + image blocks → user message with content parts.
|
||||
// Also include images extracted from tool_results so the model can see them.
|
||||
var parts []ResponsesContentPart
|
||||
for _, b := range blocks {
|
||||
switch b.Type {
|
||||
case "text":
|
||||
if b.Text != "" {
|
||||
parts = append(parts, ResponsesContentPart{Type: "input_text", Text: b.Text})
|
||||
}
|
||||
case "image":
|
||||
if uri := anthropicImageToDataURI(b.Source); uri != "" {
|
||||
parts = append(parts, ResponsesContentPart{Type: "input_image", ImageURL: uri})
|
||||
}
|
||||
}
|
||||
}
|
||||
parts = append(parts, toolResultImageParts...)
|
||||
|
||||
if len(parts) > 0 {
|
||||
content, err := json.Marshal(parts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, ResponsesInputItem{Role: "user", Content: content})
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// anthropicAssistantToResponses handles an Anthropic assistant message.
|
||||
// Text content → assistant message with output_text parts.
|
||||
// tool_use blocks → function_call items.
|
||||
// thinking blocks → ignored (OpenAI doesn't accept them as input).
|
||||
func anthropicAssistantToResponses(raw json.RawMessage) ([]ResponsesInputItem, error) {
|
||||
// Try plain string.
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
|
||||
partsJSON, err := json.Marshal(parts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []ResponsesInputItem{{Role: "assistant", Content: partsJSON}}, nil
|
||||
}
|
||||
|
||||
var blocks []AnthropicContentBlock
|
||||
if err := json.Unmarshal(raw, &blocks); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var items []ResponsesInputItem
|
||||
|
||||
// Text content → assistant message with output_text content parts.
|
||||
text := extractAnthropicTextFromBlocks(blocks)
|
||||
if text != "" {
|
||||
parts := []ResponsesContentPart{{Type: "output_text", Text: text}}
|
||||
partsJSON, err := json.Marshal(parts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON})
|
||||
}
|
||||
|
||||
// tool_use → function_call items.
|
||||
for _, b := range blocks {
|
||||
if b.Type != "tool_use" {
|
||||
continue
|
||||
}
|
||||
args := "{}"
|
||||
if len(b.Input) > 0 {
|
||||
args = string(b.Input)
|
||||
}
|
||||
fcID := toResponsesCallID(b.ID)
|
||||
items = append(items, ResponsesInputItem{
|
||||
Type: "function_call",
|
||||
CallID: fcID,
|
||||
Name: b.Name,
|
||||
Arguments: args,
|
||||
ID: fcID,
|
||||
})
|
||||
}
|
||||
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// toResponsesCallID converts an Anthropic tool ID (toolu_xxx / call_xxx) to a
|
||||
// Responses API function_call ID that starts with "fc_".
|
||||
func toResponsesCallID(id string) string {
|
||||
if strings.HasPrefix(id, "fc_") {
|
||||
return id
|
||||
}
|
||||
return "fc_" + id
|
||||
}
|
||||
|
||||
// fromResponsesCallID reverses toResponsesCallID, stripping the "fc_" prefix
|
||||
// that was added during request conversion.
|
||||
func fromResponsesCallID(id string) string {
|
||||
if after, ok := strings.CutPrefix(id, "fc_"); ok {
|
||||
// Only strip if the remainder doesn't look like it was already "fc_" prefixed.
|
||||
// E.g. "fc_toolu_xxx" → "toolu_xxx", "fc_call_xxx" → "call_xxx"
|
||||
if strings.HasPrefix(after, "toolu_") || strings.HasPrefix(after, "call_") {
|
||||
return after
|
||||
}
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// anthropicImageToDataURI converts an AnthropicImageSource to a data URI string.
|
||||
// Returns "" if the source is nil or has no data.
|
||||
func anthropicImageToDataURI(src *AnthropicImageSource) string {
|
||||
if src == nil || src.Data == "" {
|
||||
return ""
|
||||
}
|
||||
mediaType := src.MediaType
|
||||
if mediaType == "" {
|
||||
mediaType = "image/png"
|
||||
}
|
||||
return "data:" + mediaType + ";base64," + src.Data
|
||||
}
|
||||
|
||||
// convertToolResultOutput extracts text and image content from a tool_result
|
||||
// block. Returns the text as a string for the function_call_output Output
|
||||
// field, plus any image parts that must be sent in a separate user message
|
||||
// (the Responses API output field only accepts strings).
|
||||
func convertToolResultOutput(b AnthropicContentBlock) (string, []ResponsesContentPart) {
|
||||
if len(b.Content) == 0 {
|
||||
return "(empty)", nil
|
||||
}
|
||||
|
||||
// Try plain string content.
|
||||
var s string
|
||||
if err := json.Unmarshal(b.Content, &s); err == nil {
|
||||
if s == "" {
|
||||
s = "(empty)"
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Array of content blocks — may contain text and/or images.
|
||||
var inner []AnthropicContentBlock
|
||||
if err := json.Unmarshal(b.Content, &inner); err != nil {
|
||||
return "(empty)", nil
|
||||
}
|
||||
|
||||
// Separate text (for function_call_output) from images (for user message).
|
||||
var textParts []string
|
||||
var imageParts []ResponsesContentPart
|
||||
for _, ib := range inner {
|
||||
switch ib.Type {
|
||||
case "text":
|
||||
if ib.Text != "" {
|
||||
textParts = append(textParts, ib.Text)
|
||||
}
|
||||
case "image":
|
||||
if uri := anthropicImageToDataURI(ib.Source); uri != "" {
|
||||
imageParts = append(imageParts, ResponsesContentPart{Type: "input_image", ImageURL: uri})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
text := strings.Join(textParts, "\n\n")
|
||||
if text == "" {
|
||||
text = "(empty)"
|
||||
}
|
||||
return text, imageParts
|
||||
}
|
||||
|
||||
// extractAnthropicTextFromBlocks joins all text blocks, ignoring thinking/
|
||||
// tool_use/tool_result blocks.
|
||||
func extractAnthropicTextFromBlocks(blocks []AnthropicContentBlock) string {
|
||||
var parts []string
|
||||
for _, b := range blocks {
|
||||
if b.Type == "text" && b.Text != "" {
|
||||
parts = append(parts, b.Text)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
|
||||
// mapAnthropicEffortToResponses converts Anthropic reasoning effort levels to
|
||||
// OpenAI Responses API effort levels.
|
||||
//
|
||||
// low → low
|
||||
// medium → high
|
||||
// high → xhigh
|
||||
func mapAnthropicEffortToResponses(effort string) string {
|
||||
switch effort {
|
||||
case "medium":
|
||||
return "high"
|
||||
case "high":
|
||||
return "xhigh"
|
||||
default:
|
||||
return effort // "low" and any unknown values pass through unchanged
|
||||
}
|
||||
}
|
||||
|
||||
// convertAnthropicToolsToResponses maps Anthropic tool definitions to
|
||||
// Responses API tools. Server-side tools like web_search are mapped to their
|
||||
// OpenAI equivalents; regular tools become function tools.
|
||||
func convertAnthropicToolsToResponses(tools []AnthropicTool) []ResponsesTool {
|
||||
var out []ResponsesTool
|
||||
for _, t := range tools {
|
||||
// Anthropic server tools like "web_search_20250305" → OpenAI {"type":"web_search"}
|
||||
if strings.HasPrefix(t.Type, "web_search") {
|
||||
out = append(out, ResponsesTool{Type: "web_search"})
|
||||
continue
|
||||
}
|
||||
out = append(out, ResponsesTool{
|
||||
Type: "function",
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
Parameters: t.InputSchema,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
516
backend/internal/pkg/apicompat/responses_to_anthropic.go
Normal file
516
backend/internal/pkg/apicompat/responses_to_anthropic.go
Normal file
@@ -0,0 +1,516 @@
|
||||
package apicompat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Non-streaming: ResponsesResponse → AnthropicResponse
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ResponsesToAnthropic converts a Responses API response directly into an
|
||||
// Anthropic Messages response. Reasoning output items are mapped to thinking
|
||||
// blocks; function_call items become tool_use blocks.
|
||||
func ResponsesToAnthropic(resp *ResponsesResponse, model string) *AnthropicResponse {
|
||||
out := &AnthropicResponse{
|
||||
ID: resp.ID,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: model,
|
||||
}
|
||||
|
||||
var blocks []AnthropicContentBlock
|
||||
|
||||
for _, item := range resp.Output {
|
||||
switch item.Type {
|
||||
case "reasoning":
|
||||
summaryText := ""
|
||||
for _, s := range item.Summary {
|
||||
if s.Type == "summary_text" && s.Text != "" {
|
||||
summaryText += s.Text
|
||||
}
|
||||
}
|
||||
if summaryText != "" {
|
||||
blocks = append(blocks, AnthropicContentBlock{
|
||||
Type: "thinking",
|
||||
Thinking: summaryText,
|
||||
})
|
||||
}
|
||||
case "message":
|
||||
for _, part := range item.Content {
|
||||
if part.Type == "output_text" && part.Text != "" {
|
||||
blocks = append(blocks, AnthropicContentBlock{
|
||||
Type: "text",
|
||||
Text: part.Text,
|
||||
})
|
||||
}
|
||||
}
|
||||
case "function_call":
|
||||
blocks = append(blocks, AnthropicContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: fromResponsesCallID(item.CallID),
|
||||
Name: item.Name,
|
||||
Input: json.RawMessage(item.Arguments),
|
||||
})
|
||||
case "web_search_call":
|
||||
toolUseID := "srvtoolu_" + item.ID
|
||||
query := ""
|
||||
if item.Action != nil {
|
||||
query = item.Action.Query
|
||||
}
|
||||
inputJSON, _ := json.Marshal(map[string]string{"query": query})
|
||||
blocks = append(blocks, AnthropicContentBlock{
|
||||
Type: "server_tool_use",
|
||||
ID: toolUseID,
|
||||
Name: "web_search",
|
||||
Input: inputJSON,
|
||||
})
|
||||
emptyResults, _ := json.Marshal([]struct{}{})
|
||||
blocks = append(blocks, AnthropicContentBlock{
|
||||
Type: "web_search_tool_result",
|
||||
ToolUseID: toolUseID,
|
||||
Content: emptyResults,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(blocks) == 0 {
|
||||
blocks = append(blocks, AnthropicContentBlock{Type: "text", Text: ""})
|
||||
}
|
||||
out.Content = blocks
|
||||
|
||||
out.StopReason = responsesStatusToAnthropicStopReason(resp.Status, resp.IncompleteDetails, blocks)
|
||||
|
||||
if resp.Usage != nil {
|
||||
out.Usage = AnthropicUsage{
|
||||
InputTokens: resp.Usage.InputTokens,
|
||||
OutputTokens: resp.Usage.OutputTokens,
|
||||
}
|
||||
if resp.Usage.InputTokensDetails != nil {
|
||||
out.Usage.CacheReadInputTokens = resp.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncompleteDetails, blocks []AnthropicContentBlock) string {
|
||||
switch status {
|
||||
case "incomplete":
|
||||
if details != nil && details.Reason == "max_output_tokens" {
|
||||
return "max_tokens"
|
||||
}
|
||||
return "end_turn"
|
||||
case "completed":
|
||||
if len(blocks) > 0 && blocks[len(blocks)-1].Type == "tool_use" {
|
||||
return "tool_use"
|
||||
}
|
||||
return "end_turn"
|
||||
default:
|
||||
return "end_turn"
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Streaming: ResponsesStreamEvent → []AnthropicStreamEvent (stateful converter)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ResponsesEventToAnthropicState tracks state for converting a sequence of
|
||||
// Responses SSE events directly into Anthropic SSE events.
|
||||
type ResponsesEventToAnthropicState struct {
|
||||
MessageStartSent bool
|
||||
MessageStopSent bool
|
||||
|
||||
ContentBlockIndex int
|
||||
ContentBlockOpen bool
|
||||
CurrentBlockType string // "text" | "thinking" | "tool_use"
|
||||
|
||||
// OutputIndexToBlockIdx maps Responses output_index → Anthropic content block index.
|
||||
OutputIndexToBlockIdx map[int]int
|
||||
|
||||
InputTokens int
|
||||
OutputTokens int
|
||||
CacheReadInputTokens int
|
||||
|
||||
ResponseID string
|
||||
Model string
|
||||
Created int64
|
||||
}
|
||||
|
||||
// NewResponsesEventToAnthropicState returns an initialised stream state.
|
||||
func NewResponsesEventToAnthropicState() *ResponsesEventToAnthropicState {
|
||||
return &ResponsesEventToAnthropicState{
|
||||
OutputIndexToBlockIdx: make(map[int]int),
|
||||
Created: time.Now().Unix(),
|
||||
}
|
||||
}
|
||||
|
||||
// ResponsesEventToAnthropicEvents converts a single Responses SSE event into
|
||||
// zero or more Anthropic SSE events, updating state as it goes.
|
||||
func ResponsesEventToAnthropicEvents(
|
||||
evt *ResponsesStreamEvent,
|
||||
state *ResponsesEventToAnthropicState,
|
||||
) []AnthropicStreamEvent {
|
||||
switch evt.Type {
|
||||
case "response.created":
|
||||
return resToAnthHandleCreated(evt, state)
|
||||
case "response.output_item.added":
|
||||
return resToAnthHandleOutputItemAdded(evt, state)
|
||||
case "response.output_text.delta":
|
||||
return resToAnthHandleTextDelta(evt, state)
|
||||
case "response.output_text.done":
|
||||
return resToAnthHandleBlockDone(state)
|
||||
case "response.function_call_arguments.delta":
|
||||
return resToAnthHandleFuncArgsDelta(evt, state)
|
||||
case "response.function_call_arguments.done":
|
||||
return resToAnthHandleBlockDone(state)
|
||||
case "response.output_item.done":
|
||||
return resToAnthHandleOutputItemDone(evt, state)
|
||||
case "response.reasoning_summary_text.delta":
|
||||
return resToAnthHandleReasoningDelta(evt, state)
|
||||
case "response.reasoning_summary_text.done":
|
||||
return resToAnthHandleBlockDone(state)
|
||||
case "response.completed", "response.incomplete", "response.failed":
|
||||
return resToAnthHandleCompleted(evt, state)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// FinalizeResponsesAnthropicStream emits synthetic termination events if the
|
||||
// stream ended without a proper completion event.
|
||||
func FinalizeResponsesAnthropicStream(state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if !state.MessageStartSent || state.MessageStopSent {
|
||||
return nil
|
||||
}
|
||||
|
||||
var events []AnthropicStreamEvent
|
||||
events = append(events, closeCurrentBlock(state)...)
|
||||
|
||||
events = append(events,
|
||||
AnthropicStreamEvent{
|
||||
Type: "message_delta",
|
||||
Delta: &AnthropicDelta{
|
||||
StopReason: "end_turn",
|
||||
},
|
||||
Usage: &AnthropicUsage{
|
||||
InputTokens: state.InputTokens,
|
||||
OutputTokens: state.OutputTokens,
|
||||
CacheReadInputTokens: state.CacheReadInputTokens,
|
||||
},
|
||||
},
|
||||
AnthropicStreamEvent{Type: "message_stop"},
|
||||
)
|
||||
state.MessageStopSent = true
|
||||
return events
|
||||
}
|
||||
|
||||
// ResponsesAnthropicEventToSSE formats an AnthropicStreamEvent as an SSE line pair.
|
||||
func ResponsesAnthropicEventToSSE(evt AnthropicStreamEvent) (string, error) {
|
||||
data, err := json.Marshal(evt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("event: %s\ndata: %s\n\n", evt.Type, data), nil
|
||||
}
|
||||
|
||||
// --- internal handlers ---
|
||||
|
||||
func resToAnthHandleCreated(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if evt.Response != nil {
|
||||
state.ResponseID = evt.Response.ID
|
||||
// Only use upstream model if no override was set (e.g. originalModel)
|
||||
if state.Model == "" {
|
||||
state.Model = evt.Response.Model
|
||||
}
|
||||
}
|
||||
|
||||
if state.MessageStartSent {
|
||||
return nil
|
||||
}
|
||||
state.MessageStartSent = true
|
||||
|
||||
return []AnthropicStreamEvent{{
|
||||
Type: "message_start",
|
||||
Message: &AnthropicResponse{
|
||||
ID: state.ResponseID,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Content: []AnthropicContentBlock{},
|
||||
Model: state.Model,
|
||||
Usage: AnthropicUsage{
|
||||
InputTokens: 0,
|
||||
OutputTokens: 0,
|
||||
},
|
||||
},
|
||||
}}
|
||||
}
|
||||
|
||||
func resToAnthHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if evt.Item == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch evt.Item.Type {
|
||||
case "function_call":
|
||||
var events []AnthropicStreamEvent
|
||||
events = append(events, closeCurrentBlock(state)...)
|
||||
|
||||
idx := state.ContentBlockIndex
|
||||
state.OutputIndexToBlockIdx[evt.OutputIndex] = idx
|
||||
state.ContentBlockOpen = true
|
||||
state.CurrentBlockType = "tool_use"
|
||||
|
||||
events = append(events, AnthropicStreamEvent{
|
||||
Type: "content_block_start",
|
||||
Index: &idx,
|
||||
ContentBlock: &AnthropicContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: fromResponsesCallID(evt.Item.CallID),
|
||||
Name: evt.Item.Name,
|
||||
Input: json.RawMessage("{}"),
|
||||
},
|
||||
})
|
||||
return events
|
||||
|
||||
case "reasoning":
|
||||
var events []AnthropicStreamEvent
|
||||
events = append(events, closeCurrentBlock(state)...)
|
||||
|
||||
idx := state.ContentBlockIndex
|
||||
state.OutputIndexToBlockIdx[evt.OutputIndex] = idx
|
||||
state.ContentBlockOpen = true
|
||||
state.CurrentBlockType = "thinking"
|
||||
|
||||
events = append(events, AnthropicStreamEvent{
|
||||
Type: "content_block_start",
|
||||
Index: &idx,
|
||||
ContentBlock: &AnthropicContentBlock{
|
||||
Type: "thinking",
|
||||
Thinking: "",
|
||||
},
|
||||
})
|
||||
return events
|
||||
|
||||
case "message":
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func resToAnthHandleTextDelta(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if evt.Delta == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var events []AnthropicStreamEvent
|
||||
|
||||
if !state.ContentBlockOpen || state.CurrentBlockType != "text" {
|
||||
events = append(events, closeCurrentBlock(state)...)
|
||||
|
||||
idx := state.ContentBlockIndex
|
||||
state.ContentBlockOpen = true
|
||||
state.CurrentBlockType = "text"
|
||||
|
||||
events = append(events, AnthropicStreamEvent{
|
||||
Type: "content_block_start",
|
||||
Index: &idx,
|
||||
ContentBlock: &AnthropicContentBlock{
|
||||
Type: "text",
|
||||
Text: "",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
idx := state.ContentBlockIndex
|
||||
events = append(events, AnthropicStreamEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: &idx,
|
||||
Delta: &AnthropicDelta{
|
||||
Type: "text_delta",
|
||||
Text: evt.Delta,
|
||||
},
|
||||
})
|
||||
return events
|
||||
}
|
||||
|
||||
func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if evt.Delta == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
blockIdx, ok := state.OutputIndexToBlockIdx[evt.OutputIndex]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return []AnthropicStreamEvent{{
|
||||
Type: "content_block_delta",
|
||||
Index: &blockIdx,
|
||||
Delta: &AnthropicDelta{
|
||||
Type: "input_json_delta",
|
||||
PartialJSON: evt.Delta,
|
||||
},
|
||||
}}
|
||||
}
|
||||
|
||||
func resToAnthHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if evt.Delta == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
blockIdx, ok := state.OutputIndexToBlockIdx[evt.OutputIndex]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return []AnthropicStreamEvent{{
|
||||
Type: "content_block_delta",
|
||||
Index: &blockIdx,
|
||||
Delta: &AnthropicDelta{
|
||||
Type: "thinking_delta",
|
||||
Thinking: evt.Delta,
|
||||
},
|
||||
}}
|
||||
}
|
||||
|
||||
func resToAnthHandleBlockDone(state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if !state.ContentBlockOpen {
|
||||
return nil
|
||||
}
|
||||
return closeCurrentBlock(state)
|
||||
}
|
||||
|
||||
func resToAnthHandleOutputItemDone(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if evt.Item == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle web_search_call → synthesize server_tool_use + web_search_tool_result blocks.
|
||||
if evt.Item.Type == "web_search_call" && evt.Item.Status == "completed" {
|
||||
return resToAnthHandleWebSearchDone(evt, state)
|
||||
}
|
||||
|
||||
if state.ContentBlockOpen {
|
||||
return closeCurrentBlock(state)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// resToAnthHandleWebSearchDone converts an OpenAI web_search_call output item
|
||||
// into Anthropic server_tool_use + web_search_tool_result content block pairs.
|
||||
// This allows Claude Code to count the searches performed.
|
||||
func resToAnthHandleWebSearchDone(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
var events []AnthropicStreamEvent
|
||||
events = append(events, closeCurrentBlock(state)...)
|
||||
|
||||
toolUseID := "srvtoolu_" + evt.Item.ID
|
||||
query := ""
|
||||
if evt.Item.Action != nil {
|
||||
query = evt.Item.Action.Query
|
||||
}
|
||||
inputJSON, _ := json.Marshal(map[string]string{"query": query})
|
||||
|
||||
// Emit server_tool_use block (start + stop).
|
||||
idx1 := state.ContentBlockIndex
|
||||
events = append(events, AnthropicStreamEvent{
|
||||
Type: "content_block_start",
|
||||
Index: &idx1,
|
||||
ContentBlock: &AnthropicContentBlock{
|
||||
Type: "server_tool_use",
|
||||
ID: toolUseID,
|
||||
Name: "web_search",
|
||||
Input: inputJSON,
|
||||
},
|
||||
})
|
||||
events = append(events, AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: &idx1,
|
||||
})
|
||||
state.ContentBlockIndex++
|
||||
|
||||
// Emit web_search_tool_result block (start + stop).
|
||||
// Content is empty because OpenAI does not expose individual search results;
|
||||
// the model consumes them internally and produces text output.
|
||||
emptyResults, _ := json.Marshal([]struct{}{})
|
||||
idx2 := state.ContentBlockIndex
|
||||
events = append(events, AnthropicStreamEvent{
|
||||
Type: "content_block_start",
|
||||
Index: &idx2,
|
||||
ContentBlock: &AnthropicContentBlock{
|
||||
Type: "web_search_tool_result",
|
||||
ToolUseID: toolUseID,
|
||||
Content: emptyResults,
|
||||
},
|
||||
})
|
||||
events = append(events, AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: &idx2,
|
||||
})
|
||||
state.ContentBlockIndex++
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
func resToAnthHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if state.MessageStopSent {
|
||||
return nil
|
||||
}
|
||||
|
||||
var events []AnthropicStreamEvent
|
||||
events = append(events, closeCurrentBlock(state)...)
|
||||
|
||||
stopReason := "end_turn"
|
||||
if evt.Response != nil {
|
||||
if evt.Response.Usage != nil {
|
||||
state.InputTokens = evt.Response.Usage.InputTokens
|
||||
state.OutputTokens = evt.Response.Usage.OutputTokens
|
||||
if evt.Response.Usage.InputTokensDetails != nil {
|
||||
state.CacheReadInputTokens = evt.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
switch evt.Response.Status {
|
||||
case "incomplete":
|
||||
if evt.Response.IncompleteDetails != nil && evt.Response.IncompleteDetails.Reason == "max_output_tokens" {
|
||||
stopReason = "max_tokens"
|
||||
}
|
||||
case "completed":
|
||||
if state.ContentBlockIndex > 0 && state.CurrentBlockType == "tool_use" {
|
||||
stopReason = "tool_use"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
events = append(events,
|
||||
AnthropicStreamEvent{
|
||||
Type: "message_delta",
|
||||
Delta: &AnthropicDelta{
|
||||
StopReason: stopReason,
|
||||
},
|
||||
Usage: &AnthropicUsage{
|
||||
InputTokens: state.InputTokens,
|
||||
OutputTokens: state.OutputTokens,
|
||||
CacheReadInputTokens: state.CacheReadInputTokens,
|
||||
},
|
||||
},
|
||||
AnthropicStreamEvent{Type: "message_stop"},
|
||||
)
|
||||
state.MessageStopSent = true
|
||||
return events
|
||||
}
|
||||
|
||||
func closeCurrentBlock(state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if !state.ContentBlockOpen {
|
||||
return nil
|
||||
}
|
||||
idx := state.ContentBlockIndex
|
||||
state.ContentBlockOpen = false
|
||||
state.ContentBlockIndex++
|
||||
return []AnthropicStreamEvent{{
|
||||
Type: "content_block_stop",
|
||||
Index: &idx,
|
||||
}}
|
||||
}
|
||||
338
backend/internal/pkg/apicompat/types.go
Normal file
338
backend/internal/pkg/apicompat/types.go
Normal file
@@ -0,0 +1,338 @@
|
||||
// Package apicompat provides type definitions and conversion utilities for
|
||||
// translating between Anthropic Messages and OpenAI Responses API formats.
|
||||
// It enables multi-protocol support so that clients using different API
|
||||
// formats can be served through a unified gateway.
|
||||
package apicompat
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Anthropic Messages API types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// AnthropicRequest is the request body for POST /v1/messages.
|
||||
type AnthropicRequest struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
System json.RawMessage `json:"system,omitempty"` // string or []AnthropicContentBlock
|
||||
Messages []AnthropicMessage `json:"messages"`
|
||||
Tools []AnthropicTool `json:"tools,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
StopSeqs []string `json:"stop_sequences,omitempty"`
|
||||
Thinking *AnthropicThinking `json:"thinking,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
OutputConfig *AnthropicOutputConfig `json:"output_config,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicOutputConfig controls output generation parameters.
|
||||
type AnthropicOutputConfig struct {
|
||||
Effort string `json:"effort,omitempty"` // "low" | "medium" | "high"
|
||||
}
|
||||
|
||||
// AnthropicThinking configures extended thinking in the Anthropic API.
|
||||
type AnthropicThinking struct {
|
||||
Type string `json:"type"` // "enabled" | "adaptive" | "disabled"
|
||||
BudgetTokens int `json:"budget_tokens,omitempty"` // max thinking tokens
|
||||
}
|
||||
|
||||
// AnthropicMessage is a single message in the Anthropic conversation.
|
||||
type AnthropicMessage struct {
|
||||
Role string `json:"role"` // "user" | "assistant"
|
||||
Content json.RawMessage `json:"content"`
|
||||
}
|
||||
|
||||
// AnthropicContentBlock is one block inside a message's content array.
|
||||
type AnthropicContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
|
||||
// type=text
|
||||
Text string `json:"text,omitempty"`
|
||||
|
||||
// type=thinking
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
|
||||
// type=image
|
||||
Source *AnthropicImageSource `json:"source,omitempty"`
|
||||
|
||||
// type=tool_use
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input json.RawMessage `json:"input,omitempty"`
|
||||
|
||||
// type=tool_result
|
||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||
Content json.RawMessage `json:"content,omitempty"` // string or []AnthropicContentBlock
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicImageSource describes the source data for an image content block.
|
||||
type AnthropicImageSource struct {
|
||||
Type string `json:"type"` // "base64"
|
||||
MediaType string `json:"media_type"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// AnthropicTool describes a tool available to the model.
|
||||
type AnthropicTool struct {
|
||||
Type string `json:"type,omitempty"` // e.g. "web_search_20250305" for server tools
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema json.RawMessage `json:"input_schema"` // JSON Schema object
|
||||
}
|
||||
|
||||
// AnthropicResponse is the non-streaming response from POST /v1/messages.
|
||||
type AnthropicResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"` // "message"
|
||||
Role string `json:"role"` // "assistant"
|
||||
Content []AnthropicContentBlock `json:"content"`
|
||||
Model string `json:"model"`
|
||||
StopReason string `json:"stop_reason"`
|
||||
StopSequence *string `json:"stop_sequence,omitempty"`
|
||||
Usage AnthropicUsage `json:"usage"`
|
||||
}
|
||||
|
||||
// AnthropicUsage holds token counts in Anthropic format.
|
||||
type AnthropicUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Anthropic SSE event types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// AnthropicStreamEvent is a single SSE event in the Anthropic streaming protocol.
|
||||
type AnthropicStreamEvent struct {
|
||||
Type string `json:"type"`
|
||||
|
||||
// message_start
|
||||
Message *AnthropicResponse `json:"message,omitempty"`
|
||||
|
||||
// content_block_start
|
||||
Index *int `json:"index,omitempty"`
|
||||
ContentBlock *AnthropicContentBlock `json:"content_block,omitempty"`
|
||||
|
||||
// content_block_delta
|
||||
Delta *AnthropicDelta `json:"delta,omitempty"`
|
||||
|
||||
// message_delta
|
||||
Usage *AnthropicUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicDelta carries incremental content in streaming events.
|
||||
type AnthropicDelta struct {
|
||||
Type string `json:"type,omitempty"` // "text_delta" | "input_json_delta" | "thinking_delta" | "signature_delta"
|
||||
|
||||
// text_delta
|
||||
Text string `json:"text,omitempty"`
|
||||
|
||||
// input_json_delta
|
||||
PartialJSON string `json:"partial_json,omitempty"`
|
||||
|
||||
// thinking_delta
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
|
||||
// signature_delta
|
||||
Signature string `json:"signature,omitempty"`
|
||||
|
||||
// message_delta fields
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
StopSequence *string `json:"stop_sequence,omitempty"`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// OpenAI Responses API types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ResponsesRequest is the request body for POST /v1/responses.
|
||||
type ResponsesRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input json.RawMessage `json:"input"` // string or []ResponsesInputItem
|
||||
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Tools []ResponsesTool `json:"tools,omitempty"`
|
||||
Include []string `json:"include,omitempty"`
|
||||
Store *bool `json:"store,omitempty"`
|
||||
Reasoning *ResponsesReasoning `json:"reasoning,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesReasoning configures reasoning effort in the Responses API.
|
||||
type ResponsesReasoning struct {
|
||||
Effort string `json:"effort"` // "low" | "medium" | "high"
|
||||
Summary string `json:"summary,omitempty"` // "auto" | "concise" | "detailed"
|
||||
}
|
||||
|
||||
// ResponsesInputItem is one item in the Responses API input array.
|
||||
// The Type field determines which other fields are populated.
|
||||
type ResponsesInputItem struct {
|
||||
// Common
|
||||
Type string `json:"type,omitempty"` // "" for role-based messages
|
||||
|
||||
// Role-based messages (system/user/assistant)
|
||||
Role string `json:"role,omitempty"`
|
||||
Content json.RawMessage `json:"content,omitempty"` // string or []ResponsesContentPart
|
||||
|
||||
// type=function_call
|
||||
CallID string `json:"call_id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
|
||||
// type=function_call_output
|
||||
Output string `json:"output,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesContentPart is a typed content part in a Responses message.
|
||||
type ResponsesContentPart struct {
|
||||
Type string `json:"type"` // "input_text" | "output_text" | "input_image"
|
||||
Text string `json:"text,omitempty"`
|
||||
ImageURL string `json:"image_url,omitempty"` // data URI for input_image
|
||||
}
|
||||
|
||||
// ResponsesTool describes a tool in the Responses API.
|
||||
type ResponsesTool struct {
|
||||
Type string `json:"type"` // "function" | "web_search" | "local_shell" etc.
|
||||
Name string `json:"name,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Parameters json.RawMessage `json:"parameters,omitempty"`
|
||||
Strict *bool `json:"strict,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesResponse is the non-streaming response from POST /v1/responses.
|
||||
type ResponsesResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"` // "response"
|
||||
Model string `json:"model"`
|
||||
Status string `json:"status"` // "completed" | "incomplete" | "failed"
|
||||
Output []ResponsesOutput `json:"output"`
|
||||
Usage *ResponsesUsage `json:"usage,omitempty"`
|
||||
|
||||
// incomplete_details is present when status="incomplete"
|
||||
IncompleteDetails *ResponsesIncompleteDetails `json:"incomplete_details,omitempty"`
|
||||
|
||||
// Error is present when status="failed"
|
||||
Error *ResponsesError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesError describes an error in a failed response.
|
||||
type ResponsesError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ResponsesIncompleteDetails explains why a response is incomplete.
|
||||
type ResponsesIncompleteDetails struct {
|
||||
Reason string `json:"reason"` // "max_output_tokens" | "content_filter"
|
||||
}
|
||||
|
||||
// ResponsesOutput is one output item in a Responses API response.
|
||||
type ResponsesOutput struct {
|
||||
Type string `json:"type"` // "message" | "reasoning" | "function_call" | "web_search_call"
|
||||
|
||||
// type=message
|
||||
ID string `json:"id,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Content []ResponsesContentPart `json:"content,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
|
||||
// type=reasoning
|
||||
EncryptedContent string `json:"encrypted_content,omitempty"`
|
||||
Summary []ResponsesSummary `json:"summary,omitempty"`
|
||||
|
||||
// type=function_call
|
||||
CallID string `json:"call_id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
|
||||
// type=web_search_call
|
||||
Action *WebSearchAction `json:"action,omitempty"`
|
||||
}
|
||||
|
||||
// WebSearchAction describes the search action in a web_search_call output item.
|
||||
type WebSearchAction struct {
|
||||
Type string `json:"type,omitempty"` // "search"
|
||||
Query string `json:"query,omitempty"` // primary search query
|
||||
}
|
||||
|
||||
// ResponsesSummary is a summary text block inside a reasoning output.
|
||||
type ResponsesSummary struct {
|
||||
Type string `json:"type"` // "summary_text"
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// ResponsesUsage holds token counts in Responses API format.
|
||||
type ResponsesUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
|
||||
// Optional detailed breakdown
|
||||
InputTokensDetails *ResponsesInputTokensDetails `json:"input_tokens_details,omitempty"`
|
||||
OutputTokensDetails *ResponsesOutputTokensDetails `json:"output_tokens_details,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesInputTokensDetails breaks down input token usage.
|
||||
type ResponsesInputTokensDetails struct {
|
||||
CachedTokens int `json:"cached_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesOutputTokensDetails breaks down output token usage.
|
||||
type ResponsesOutputTokensDetails struct {
|
||||
ReasoningTokens int `json:"reasoning_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Responses SSE event types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ResponsesStreamEvent is a single SSE event in the Responses streaming protocol.
|
||||
// The Type field corresponds to the "type" in the JSON payload.
|
||||
type ResponsesStreamEvent struct {
|
||||
Type string `json:"type"`
|
||||
|
||||
// response.created / response.completed / response.failed / response.incomplete
|
||||
Response *ResponsesResponse `json:"response,omitempty"`
|
||||
|
||||
// response.output_item.added / response.output_item.done
|
||||
Item *ResponsesOutput `json:"item,omitempty"`
|
||||
|
||||
// response.output_text.delta / response.output_text.done
|
||||
OutputIndex int `json:"output_index,omitempty"`
|
||||
ContentIndex int `json:"content_index,omitempty"`
|
||||
Delta string `json:"delta,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ItemID string `json:"item_id,omitempty"`
|
||||
|
||||
// response.function_call_arguments.delta / done
|
||||
CallID string `json:"call_id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
|
||||
// response.reasoning_summary_text.delta / done
|
||||
// Reuses Text/Delta fields above, SummaryIndex identifies which summary part
|
||||
SummaryIndex int `json:"summary_index,omitempty"`
|
||||
|
||||
// error event fields
|
||||
Code string `json:"code,omitempty"`
|
||||
Param string `json:"param,omitempty"`
|
||||
|
||||
// Sequence number for ordering events
|
||||
SequenceNumber int `json:"sequence_number,omitempty"`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared constants
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// minMaxOutputTokens is the floor for max_output_tokens in a Responses request.
|
||||
// Very small values may cause upstream API errors, so we enforce a minimum.
|
||||
const minMaxOutputTokens = 128
|
||||
@@ -16,7 +16,7 @@ const (
|
||||
|
||||
// DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。
|
||||
// 这些 token 是客户端特有的,不应透传给上游 API。
|
||||
var DroppedBetas = []string{BetaContext1M, BetaFastMode}
|
||||
var DroppedBetas = []string{BetaFastMode}
|
||||
|
||||
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
||||
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||
|
||||
@@ -58,6 +58,12 @@ func IsCodexOfficialClientOriginator(originator string) bool {
|
||||
return matchCodexClientHeaderPrefixes(v, CodexOfficialClientOriginatorPrefixes)
|
||||
}
|
||||
|
||||
// IsCodexOfficialClientByHeaders checks whether the request headers indicate an
|
||||
// official Codex client family request.
|
||||
func IsCodexOfficialClientByHeaders(userAgent, originator string) bool {
|
||||
return IsCodexOfficialClientRequest(userAgent) || IsCodexOfficialClientOriginator(originator)
|
||||
}
|
||||
|
||||
func normalizeCodexClientHeader(value string) string {
|
||||
return strings.ToLower(strings.TrimSpace(value))
|
||||
}
|
||||
|
||||
@@ -85,3 +85,26 @@ func TestIsCodexOfficialClientOriginator(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsCodexOfficialClientByHeaders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ua string
|
||||
originator string
|
||||
want bool
|
||||
}{
|
||||
{name: "仅 originator 命中 desktop", originator: "Codex Desktop", want: true},
|
||||
{name: "仅 originator 命中 vscode", originator: "codex_vscode", want: true},
|
||||
{name: "仅 ua 命中 desktop", ua: "Codex Desktop/1.2.3", want: true},
|
||||
{name: "ua 与 originator 都未命中", ua: "curl/8.0.1", originator: "my_client", want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsCodexOfficialClientByHeaders(tt.ua, tt.originator)
|
||||
if got != tt.want {
|
||||
t.Fatalf("IsCodexOfficialClientByHeaders(%q, %q) = %v, want %v", tt.ua, tt.originator, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -659,13 +659,10 @@ func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 清除临时不可调度状态,重置 401 升级链
|
||||
_, _ = r.sql.ExecContext(ctx, `
|
||||
UPDATE accounts
|
||||
SET temp_unschedulable_until = NULL,
|
||||
temp_unschedulable_reason = NULL
|
||||
WHERE id = $1 AND deleted_at IS NULL
|
||||
`, id)
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear error failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -925,6 +922,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1040,6 +1038,7 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1676,13 +1675,47 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va
|
||||
return r.accountsToService(ctx, accounts)
|
||||
}
|
||||
|
||||
// IncrementQuotaUsed 原子递增账号的 extra.quota_used 字段
|
||||
// nowUTC is a SQL expression to generate a UTC RFC3339 timestamp string.
|
||||
const nowUTC = `to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"')`
|
||||
|
||||
// IncrementQuotaUsed 原子递增账号的配额用量(总/日/周三个维度)
|
||||
// 日/周额度在周期过期时自动重置为 0 再递增。
|
||||
func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
rows, err := r.sql.QueryContext(ctx,
|
||||
`UPDATE accounts SET extra = jsonb_set(
|
||||
COALESCE(extra, '{}'::jsonb),
|
||||
'{quota_used}',
|
||||
to_jsonb(COALESCE((extra->>'quota_used')::numeric, 0) + $1)
|
||||
`UPDATE accounts SET extra = (
|
||||
COALESCE(extra, '{}'::jsonb)
|
||||
-- 总额度:始终递增
|
||||
|| jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1)
|
||||
-- 日额度:仅在 quota_daily_limit > 0 时处理
|
||||
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_daily_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
|
||||
'quota_daily_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
|
||||
)
|
||||
ELSE '{}'::jsonb END
|
||||
-- 周额度:仅在 quota_weekly_limit > 0 时处理
|
||||
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_weekly_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
|
||||
'quota_weekly_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
|
||||
)
|
||||
ELSE '{}'::jsonb END
|
||||
), updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
RETURNING
|
||||
@@ -1704,7 +1737,7 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|
||||
return err
|
||||
}
|
||||
|
||||
// 配额刚超限时触发调度快照刷新,使账号及时从调度候选中移除
|
||||
// 任一维度配额刚超限时触发调度快照刷新
|
||||
if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", id, err)
|
||||
@@ -1713,14 +1746,13 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetQuotaUsed 重置账号的 extra.quota_used 为 0
|
||||
// ResetQuotaUsed 重置账号所有维度的配额用量为 0
|
||||
func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
_, err := r.sql.ExecContext(ctx,
|
||||
`UPDATE accounts SET extra = jsonb_set(
|
||||
COALESCE(extra, '{}'::jsonb),
|
||||
'{quota_used}',
|
||||
'0'::jsonb
|
||||
), updated_at = NOW()
|
||||
`UPDATE accounts SET extra = (
|
||||
COALESCE(extra, '{}'::jsonb)
|
||||
|| '{"quota_used": 0, "quota_daily_used": 0, "quota_weekly_used": 0}'::jsonb
|
||||
) - 'quota_daily_start' - 'quota_weekly_start', updated_at = NOW()
|
||||
WHERE id = $1 AND deleted_at IS NULL`,
|
||||
id)
|
||||
if err != nil {
|
||||
|
||||
@@ -558,6 +558,26 @@ func (s *AccountRepoSuite) TestSetError() {
|
||||
s.Require().Equal("something went wrong", got.ErrorMessage)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestClearError_SyncSchedulerSnapshotOnRecovery() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc-clear-err",
|
||||
Status: service.StatusError,
|
||||
ErrorMessage: "temporary error",
|
||||
})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
s.Require().NoError(s.repo.ClearError(s.ctx, account.ID))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(service.StatusActive, got.Status)
|
||||
s.Require().Empty(got.ErrorMessage)
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
s.Require().Equal(service.StatusActive, cacheRecorder.setAccounts[0].Status)
|
||||
}
|
||||
|
||||
// --- UpdateSessionWindow ---
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateSessionWindow() {
|
||||
|
||||
@@ -24,6 +24,7 @@ func (r *announcementRepository) Create(ctx context.Context, a *service.Announce
|
||||
SetTitle(a.Title).
|
||||
SetContent(a.Content).
|
||||
SetStatus(a.Status).
|
||||
SetNotifyMode(a.NotifyMode).
|
||||
SetTargeting(a.Targeting)
|
||||
|
||||
if a.StartsAt != nil {
|
||||
@@ -64,6 +65,7 @@ func (r *announcementRepository) Update(ctx context.Context, a *service.Announce
|
||||
SetTitle(a.Title).
|
||||
SetContent(a.Content).
|
||||
SetStatus(a.Status).
|
||||
SetNotifyMode(a.NotifyMode).
|
||||
SetTargeting(a.Targeting)
|
||||
|
||||
if a.StartsAt != nil {
|
||||
@@ -169,17 +171,18 @@ func announcementEntityToService(m *dbent.Announcement) *service.Announcement {
|
||||
return nil
|
||||
}
|
||||
return &service.Announcement{
|
||||
ID: m.ID,
|
||||
Title: m.Title,
|
||||
Content: m.Content,
|
||||
Status: m.Status,
|
||||
Targeting: m.Targeting,
|
||||
StartsAt: m.StartsAt,
|
||||
EndsAt: m.EndsAt,
|
||||
CreatedBy: m.CreatedBy,
|
||||
UpdatedBy: m.UpdatedBy,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
ID: m.ID,
|
||||
Title: m.Title,
|
||||
Content: m.Content,
|
||||
Status: m.Status,
|
||||
NotifyMode: m.NotifyMode,
|
||||
Targeting: m.Targeting,
|
||||
StartsAt: m.StartsAt,
|
||||
EndsAt: m.EndsAt,
|
||||
CreatedBy: m.CreatedBy,
|
||||
UpdatedBy: m.UpdatedBy,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -165,6 +165,8 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
group.FieldModelRouting,
|
||||
group.FieldMcpXMLInject,
|
||||
group.FieldSupportedModelScopes,
|
||||
group.FieldAllowMessagesDispatch,
|
||||
group.FieldDefaultMappedModel,
|
||||
)
|
||||
}).
|
||||
Only(ctx)
|
||||
@@ -470,12 +472,12 @@ func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt
|
||||
func (r *apiKeyRepository) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE api_keys SET
|
||||
usage_5h = usage_5h + $1,
|
||||
usage_1d = usage_1d + $1,
|
||||
usage_7d = usage_7d + $1,
|
||||
window_5h_start = COALESCE(window_5h_start, NOW()),
|
||||
window_1d_start = COALESCE(window_1d_start, NOW()),
|
||||
window_7d_start = COALESCE(window_7d_start, NOW()),
|
||||
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END,
|
||||
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END,
|
||||
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END,
|
||||
window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL`,
|
||||
cost, id)
|
||||
@@ -489,9 +491,9 @@ func (r *apiKeyRepository) ResetRateLimitWindows(ctx context.Context, id int64)
|
||||
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN 0 ELSE usage_5h END,
|
||||
window_5h_start = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
|
||||
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN 0 ELSE usage_1d END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN NOW() ELSE window_1d_start END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
|
||||
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN 0 ELSE usage_7d END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN NOW() ELSE window_7d_start END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1 AND deleted_at IS NULL`,
|
||||
id)
|
||||
@@ -619,6 +621,8 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
MCPXMLInject: g.McpXMLInject,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
SortOrder: g.SortOrder,
|
||||
AllowMessagesDispatch: g.AllowMessagesDispatch,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
|
||||
@@ -89,6 +89,10 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
|
||||
_ = client.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := ensureSimpleModeAdminConcurrency(seedCtx, client); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return client, drv.DB(), nil
|
||||
|
||||
@@ -59,7 +59,9 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
||||
SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes)
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
|
||||
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
||||
SetDefaultMappedModel(groupIn.DefaultMappedModel)
|
||||
|
||||
// 设置模型路由配置
|
||||
if groupIn.ModelRouting != nil {
|
||||
@@ -125,7 +127,9 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes)
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
|
||||
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
||||
SetDefaultMappedModel(groupIn.DefaultMappedModel)
|
||||
|
||||
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
|
||||
if groupIn.DailyLimitUSD != nil {
|
||||
|
||||
@@ -20,16 +20,16 @@ func NewScheduledTestPlanRepository(db *sql.DB) service.ScheduledTestPlanReposit
|
||||
|
||||
func (r *scheduledTestPlanRepository) Create(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO scheduled_test_plans (account_id, model_id, cron_expression, enabled, max_results, next_run_at, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, NOW(), NOW())
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.AccountID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.NextRunAt)
|
||||
INSERT INTO scheduled_test_plans (account_id, model_id, cron_expression, enabled, max_results, auto_recover, next_run_at, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), NOW())
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.AccountID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.AutoRecover, plan.NextRunAt)
|
||||
return scanPlan(row)
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) GetByID(ctx context.Context, id int64) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans WHERE id = $1
|
||||
`, id)
|
||||
return scanPlan(row)
|
||||
@@ -37,7 +37,7 @@ func (r *scheduledTestPlanRepository) GetByID(ctx context.Context, id int64) (*s
|
||||
|
||||
func (r *scheduledTestPlanRepository) ListByAccountID(ctx context.Context, accountID int64) ([]*service.ScheduledTestPlan, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans WHERE account_id = $1
|
||||
ORDER BY created_at DESC
|
||||
`, accountID)
|
||||
@@ -50,7 +50,7 @@ func (r *scheduledTestPlanRepository) ListByAccountID(ctx context.Context, accou
|
||||
|
||||
func (r *scheduledTestPlanRepository) ListDue(ctx context.Context, now time.Time) ([]*service.ScheduledTestPlan, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans
|
||||
WHERE enabled = true AND next_run_at <= $1
|
||||
ORDER BY next_run_at ASC
|
||||
@@ -65,10 +65,10 @@ func (r *scheduledTestPlanRepository) ListDue(ctx context.Context, now time.Time
|
||||
func (r *scheduledTestPlanRepository) Update(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
UPDATE scheduled_test_plans
|
||||
SET model_id = $2, cron_expression = $3, enabled = $4, max_results = $5, next_run_at = $6, updated_at = NOW()
|
||||
SET model_id = $2, cron_expression = $3, enabled = $4, max_results = $5, auto_recover = $6, next_run_at = $7, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.ID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.NextRunAt)
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.ID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.AutoRecover, plan.NextRunAt)
|
||||
return scanPlan(row)
|
||||
}
|
||||
|
||||
@@ -162,7 +162,7 @@ type scannable interface {
|
||||
func scanPlan(row scannable) (*service.ScheduledTestPlan, error) {
|
||||
p := &service.ScheduledTestPlan{}
|
||||
if err := row.Scan(
|
||||
&p.ID, &p.AccountID, &p.ModelID, &p.CronExpression, &p.Enabled, &p.MaxResults,
|
||||
&p.ID, &p.AccountID, &p.ModelID, &p.CronExpression, &p.Enabled, &p.MaxResults, &p.AutoRecover,
|
||||
&p.LastRunAt, &p.NextRunAt, &p.CreatedAt, &p.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
|
||||
55
backend/internal/repository/simple_mode_admin_concurrency.go
Normal file
55
backend/internal/repository/simple_mode_admin_concurrency.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/setting"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
const (
|
||||
simpleModeAdminConcurrencyUpgradeKey = "simple_mode_admin_concurrency_upgraded_30"
|
||||
simpleModeLegacyAdminConcurrency = 5
|
||||
simpleModeTargetAdminConcurrency = 30
|
||||
)
|
||||
|
||||
func ensureSimpleModeAdminConcurrency(ctx context.Context, client *dbent.Client) error {
|
||||
if client == nil {
|
||||
return fmt.Errorf("nil ent client")
|
||||
}
|
||||
|
||||
upgraded, err := client.Setting.Query().Where(setting.KeyEQ(simpleModeAdminConcurrencyUpgradeKey)).Exist(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check admin concurrency upgrade marker: %w", err)
|
||||
}
|
||||
if upgraded {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := client.User.Update().
|
||||
Where(
|
||||
dbuser.RoleEQ(service.RoleAdmin),
|
||||
dbuser.ConcurrencyEQ(simpleModeLegacyAdminConcurrency),
|
||||
).
|
||||
SetConcurrency(simpleModeTargetAdminConcurrency).
|
||||
Save(ctx); err != nil {
|
||||
return fmt.Errorf("upgrade simple mode admin concurrency: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
if err := client.Setting.Create().
|
||||
SetKey(simpleModeAdminConcurrencyUpgradeKey).
|
||||
SetValue(now.Format(time.RFC3339)).
|
||||
SetUpdatedAt(now).
|
||||
OnConflictColumns(setting.FieldKey).
|
||||
UpdateNewValues().
|
||||
Exec(ctx); err != nil {
|
||||
return fmt.Errorf("persist admin concurrency upgrade marker: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at"
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, cache_ttl_overridden, created_at"
|
||||
|
||||
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
|
||||
var dateFormatWhitelist = map[string]string{
|
||||
@@ -135,6 +135,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
image_count,
|
||||
image_size,
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
@@ -144,7 +145,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
$8, $9, $10, $11,
|
||||
$12, $13,
|
||||
$14, $15, $16, $17, $18, $19,
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
@@ -158,6 +159,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
ipAddress := nullString(log.IPAddress)
|
||||
imageSize := nullString(log.ImageSize)
|
||||
mediaType := nullString(log.MediaType)
|
||||
serviceTier := nullString(log.ServiceTier)
|
||||
reasoningEffort := nullString(log.ReasoningEffort)
|
||||
|
||||
var requestIDArg any
|
||||
@@ -198,6 +200,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
log.ImageCount,
|
||||
imageSize,
|
||||
mediaType,
|
||||
serviceTier,
|
||||
reasoningEffort,
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
@@ -2505,6 +2508,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
imageCount int
|
||||
imageSize sql.NullString
|
||||
mediaType sql.NullString
|
||||
serviceTier sql.NullString
|
||||
reasoningEffort sql.NullString
|
||||
cacheTTLOverridden bool
|
||||
createdAt time.Time
|
||||
@@ -2544,6 +2548,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
&imageCount,
|
||||
&imageSize,
|
||||
&mediaType,
|
||||
&serviceTier,
|
||||
&reasoningEffort,
|
||||
&cacheTTLOverridden,
|
||||
&createdAt,
|
||||
@@ -2614,6 +2619,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
if mediaType.Valid {
|
||||
log.MediaType = &mediaType.String
|
||||
}
|
||||
if serviceTier.Valid {
|
||||
log.ServiceTier = &serviceTier.String
|
||||
}
|
||||
if reasoningEffort.Valid {
|
||||
log.ReasoningEffort = &reasoningEffort.String
|
||||
}
|
||||
|
||||
@@ -71,6 +71,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
||||
log.ImageCount,
|
||||
sqlmock.AnyArg(), // image_size
|
||||
sqlmock.AnyArg(), // media_type
|
||||
sqlmock.AnyArg(), // service_tier
|
||||
sqlmock.AnyArg(), // reasoning_effort
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
@@ -81,12 +82,76 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.True(t, inserted)
|
||||
require.Equal(t, int64(99), log.ID)
|
||||
require.Nil(t, log.ServiceTier)
|
||||
require.Equal(t, service.RequestTypeWSV2, log.RequestType)
|
||||
require.True(t, log.Stream)
|
||||
require.True(t, log.OpenAIWSMode)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageLogRepository{sql: db}
|
||||
|
||||
createdAt := time.Date(2025, 1, 2, 12, 0, 0, 0, time.UTC)
|
||||
serviceTier := "priority"
|
||||
log := &service.UsageLog{
|
||||
UserID: 1,
|
||||
APIKeyID: 2,
|
||||
AccountID: 3,
|
||||
RequestID: "req-service-tier",
|
||||
Model: "gpt-5.4",
|
||||
ServiceTier: &serviceTier,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
|
||||
mock.ExpectQuery("INSERT INTO usage_logs").
|
||||
WithArgs(
|
||||
log.UserID,
|
||||
log.APIKeyID,
|
||||
log.AccountID,
|
||||
log.RequestID,
|
||||
log.Model,
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
log.InputTokens,
|
||||
log.OutputTokens,
|
||||
log.CacheCreationTokens,
|
||||
log.CacheReadTokens,
|
||||
log.CacheCreation5mTokens,
|
||||
log.CacheCreation1hTokens,
|
||||
log.InputCost,
|
||||
log.OutputCost,
|
||||
log.CacheCreationCost,
|
||||
log.CacheReadCost,
|
||||
log.TotalCost,
|
||||
log.ActualCost,
|
||||
log.RateMultiplier,
|
||||
log.AccountRateMultiplier,
|
||||
log.BillingType,
|
||||
int16(service.RequestTypeSync),
|
||||
false,
|
||||
false,
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
log.ImageCount,
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
serviceTier,
|
||||
sqlmock.AnyArg(),
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt))
|
||||
|
||||
inserted, err := repo.Create(context.Background(), log)
|
||||
require.NoError(t, err)
|
||||
require.True(t, inserted)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageLogRepository{sql: db}
|
||||
@@ -280,11 +345,14 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
0,
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{Valid: true, String: "priority"},
|
||||
sql.NullString{},
|
||||
false,
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, log.ServiceTier)
|
||||
require.Equal(t, "priority", *log.ServiceTier)
|
||||
require.Equal(t, service.RequestTypeWSV2, log.RequestType)
|
||||
require.True(t, log.Stream)
|
||||
require.True(t, log.OpenAIWSMode)
|
||||
@@ -316,13 +384,53 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
0,
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{Valid: true, String: "flex"},
|
||||
sql.NullString{},
|
||||
false,
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, log.ServiceTier)
|
||||
require.Equal(t, "flex", *log.ServiceTier)
|
||||
require.Equal(t, service.RequestTypeStream, log.RequestType)
|
||||
require.True(t, log.Stream)
|
||||
require.False(t, log.OpenAIWSMode)
|
||||
})
|
||||
|
||||
t.Run("service_tier_is_scanned", func(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
log, err := scanUsageLog(usageLogScannerStub{values: []any{
|
||||
int64(3),
|
||||
int64(12),
|
||||
int64(22),
|
||||
int64(32),
|
||||
sql.NullString{Valid: true, String: "req-3"},
|
||||
"gpt-5.4",
|
||||
sql.NullInt64{},
|
||||
sql.NullInt64{},
|
||||
1, 2, 3, 4, 5, 6,
|
||||
0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
|
||||
1.0,
|
||||
sql.NullFloat64{},
|
||||
int16(service.BillingTypeBalance),
|
||||
int16(service.RequestTypeSync),
|
||||
false,
|
||||
false,
|
||||
sql.NullInt64{},
|
||||
sql.NullInt64{},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
0,
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{Valid: true, String: "priority"},
|
||||
sql.NullString{},
|
||||
false,
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, log.ServiceTier)
|
||||
require.Equal(t, "priority", *log.ServiceTier)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
@@ -210,8 +210,10 @@ func TestAPIContracts(t *testing.T) {
|
||||
"sora_video_price_per_request": null,
|
||||
"sora_video_price_per_request_hd": null,
|
||||
"claude_code_only": false,
|
||||
"allow_messages_dispatch": false,
|
||||
"fallback_group_id": null,
|
||||
"fallback_group_id_on_invalid_request": null,
|
||||
"allow_messages_dispatch": false,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
"updated_at": "2025-01-02T03:04:05Z"
|
||||
}
|
||||
|
||||
@@ -244,6 +244,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts.PUT("/:id", h.Admin.Account.Update)
|
||||
accounts.DELETE("/:id", h.Admin.Account.Delete)
|
||||
accounts.POST("/:id/test", h.Admin.Account.Test)
|
||||
accounts.POST("/:id/recover-state", h.Admin.Account.RecoverState)
|
||||
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
|
||||
accounts.POST("/:id/refresh-tier", h.Admin.Account.RefreshTier)
|
||||
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
|
||||
@@ -392,6 +393,9 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
// 流超时处理配置
|
||||
adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings)
|
||||
adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings)
|
||||
// 请求整流器配置
|
||||
adminSettings.GET("/rectifier", h.Admin.Setting.GetRectifierSettings)
|
||||
adminSettings.PUT("/rectifier", h.Admin.Setting.UpdateRectifierSettings)
|
||||
// Sora S3 存储配置
|
||||
adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings)
|
||||
adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings)
|
||||
|
||||
@@ -43,12 +43,33 @@ func RegisterGatewayRoutes(
|
||||
gateway.Use(gin.HandlerFunc(apiKeyAuth))
|
||||
gateway.Use(requireGroupAnthropic)
|
||||
{
|
||||
gateway.POST("/messages", h.Gateway.Messages)
|
||||
gateway.POST("/messages/count_tokens", h.Gateway.CountTokens)
|
||||
// /v1/messages: auto-route based on group platform
|
||||
gateway.POST("/messages", func(c *gin.Context) {
|
||||
if getGroupPlatform(c) == service.PlatformOpenAI {
|
||||
h.OpenAIGateway.Messages(c)
|
||||
return
|
||||
}
|
||||
h.Gateway.Messages(c)
|
||||
})
|
||||
// /v1/messages/count_tokens: OpenAI groups get 404
|
||||
gateway.POST("/messages/count_tokens", func(c *gin.Context) {
|
||||
if getGroupPlatform(c) == service.PlatformOpenAI {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "not_found_error",
|
||||
"message": "Token counting is not supported for this platform",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
h.Gateway.CountTokens(c)
|
||||
})
|
||||
gateway.GET("/models", h.Gateway.Models)
|
||||
gateway.GET("/usage", h.Gateway.Usage)
|
||||
// OpenAI Responses API
|
||||
gateway.POST("/responses", h.OpenAIGateway.Responses)
|
||||
gateway.POST("/responses/*subpath", h.OpenAIGateway.Responses)
|
||||
gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket)
|
||||
// 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。
|
||||
gateway.POST("/chat/completions", func(c *gin.Context) {
|
||||
@@ -77,6 +98,7 @@ func RegisterGatewayRoutes(
|
||||
|
||||
// OpenAI Responses API(不带v1前缀的别名)
|
||||
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
|
||||
|
||||
// Antigravity 模型列表
|
||||
@@ -132,3 +154,12 @@ func RegisterGatewayRoutes(
|
||||
// Sora 媒体代理(签名 URL,无需 API Key)
|
||||
r.GET("/sora/media-signed/*filepath", h.SoraGateway.MediaProxySigned)
|
||||
}
|
||||
|
||||
// getGroupPlatform extracts the group platform from the API Key stored in context.
|
||||
func getGroupPlatform(c *gin.Context) string {
|
||||
apiKey, ok := middleware.GetAPIKeyFromContext(c)
|
||||
if !ok || apiKey.Group == nil {
|
||||
return ""
|
||||
}
|
||||
return apiKey.Group.Platform
|
||||
}
|
||||
|
||||
51
backend/internal/server/routes/gateway_test.go
Normal file
51
backend/internal/server/routes/gateway_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newGatewayRoutesTestRouter() *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
RegisterGatewayRoutes(
|
||||
router,
|
||||
&handler.Handlers{
|
||||
Gateway: &handler.GatewayHandler{},
|
||||
OpenAIGateway: &handler.OpenAIGatewayHandler{},
|
||||
SoraGateway: &handler.SoraGatewayHandler{},
|
||||
},
|
||||
servermiddleware.APIKeyAuthMiddleware(func(c *gin.Context) {
|
||||
c.Next()
|
||||
}),
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
&config.Config{},
|
||||
)
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
func TestGatewayRoutesOpenAIResponsesCompactPathIsRegistered(t *testing.T) {
|
||||
router := newGatewayRoutesTestRouter()
|
||||
|
||||
for _, path := range []string{"/v1/responses/compact", "/responses/compact"} {
|
||||
req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{"model":"gpt-5"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
require.NotEqual(t, http.StatusNotFound, w.Code, "path=%s should hit OpenAI responses handler", path)
|
||||
}
|
||||
}
|
||||
@@ -647,6 +647,75 @@ func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// IsPoolMode 检查 API Key 账号是否启用池模式。
|
||||
// 池模式下,上游错误不标记本地账号状态,而是在同一账号上重试。
|
||||
func (a *Account) IsPoolMode() bool {
|
||||
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Credentials["pool_mode"]; ok {
|
||||
if enabled, ok := v.(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
const (
|
||||
defaultPoolModeRetryCount = 3
|
||||
maxPoolModeRetryCount = 10
|
||||
)
|
||||
|
||||
// GetPoolModeRetryCount 返回池模式同账号重试次数。
|
||||
// 未配置或配置非法时回退为默认值 3;小于 0 按 0 处理;过大则截断到 10。
|
||||
func (a *Account) GetPoolModeRetryCount() int {
|
||||
if a == nil || !a.IsPoolMode() || a.Credentials == nil {
|
||||
return defaultPoolModeRetryCount
|
||||
}
|
||||
raw, ok := a.Credentials["pool_mode_retry_count"]
|
||||
if !ok || raw == nil {
|
||||
return defaultPoolModeRetryCount
|
||||
}
|
||||
count := parsePoolModeRetryCount(raw)
|
||||
if count < 0 {
|
||||
return 0
|
||||
}
|
||||
if count > maxPoolModeRetryCount {
|
||||
return maxPoolModeRetryCount
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func parsePoolModeRetryCount(value any) int {
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return v
|
||||
case int64:
|
||||
return int(v)
|
||||
case float64:
|
||||
return int(v)
|
||||
case json.Number:
|
||||
if i, err := v.Int64(); err == nil {
|
||||
return int(i)
|
||||
}
|
||||
case string:
|
||||
if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return defaultPoolModeRetryCount
|
||||
}
|
||||
|
||||
// isPoolModeRetryableStatus 池模式下应触发同账号重试的状态码
|
||||
func isPoolModeRetryableStatus(statusCode int) bool {
|
||||
switch statusCode {
|
||||
case 401, 403, 429:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) GetCustomErrorCodes() []int {
|
||||
if a.Credentials == nil {
|
||||
return nil
|
||||
@@ -1134,33 +1203,97 @@ func (a *Account) GetCacheTTLOverrideTarget() string {
|
||||
// GetQuotaLimit 获取 API Key 账号的配额限制(美元)
|
||||
// 返回 0 表示未启用
|
||||
func (a *Account) GetQuotaLimit() float64 {
|
||||
if a.Extra == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := a.Extra["quota_limit"]; ok {
|
||||
return parseExtraFloat64(v)
|
||||
}
|
||||
return 0
|
||||
return a.getExtraFloat64("quota_limit")
|
||||
}
|
||||
|
||||
// GetQuotaUsed 获取 API Key 账号的已用配额(美元)
|
||||
func (a *Account) GetQuotaUsed() float64 {
|
||||
return a.getExtraFloat64("quota_used")
|
||||
}
|
||||
|
||||
// GetQuotaDailyLimit 获取日额度限制(美元),0 表示未启用
|
||||
func (a *Account) GetQuotaDailyLimit() float64 {
|
||||
return a.getExtraFloat64("quota_daily_limit")
|
||||
}
|
||||
|
||||
// GetQuotaDailyUsed 获取当日已用额度(美元)
|
||||
func (a *Account) GetQuotaDailyUsed() float64 {
|
||||
return a.getExtraFloat64("quota_daily_used")
|
||||
}
|
||||
|
||||
// GetQuotaWeeklyLimit 获取周额度限制(美元),0 表示未启用
|
||||
func (a *Account) GetQuotaWeeklyLimit() float64 {
|
||||
return a.getExtraFloat64("quota_weekly_limit")
|
||||
}
|
||||
|
||||
// GetQuotaWeeklyUsed 获取本周已用额度(美元)
|
||||
func (a *Account) GetQuotaWeeklyUsed() float64 {
|
||||
return a.getExtraFloat64("quota_weekly_used")
|
||||
}
|
||||
|
||||
// getExtraFloat64 从 Extra 中读取指定 key 的 float64 值
|
||||
func (a *Account) getExtraFloat64(key string) float64 {
|
||||
if a.Extra == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := a.Extra["quota_used"]; ok {
|
||||
if v, ok := a.Extra[key]; ok {
|
||||
return parseExtraFloat64(v)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// IsQuotaExceeded 检查 API Key 账号配额是否已超限
|
||||
func (a *Account) IsQuotaExceeded() bool {
|
||||
limit := a.GetQuotaLimit()
|
||||
if limit <= 0 {
|
||||
return false
|
||||
// getExtraTime 从 Extra 中读取 RFC3339 时间戳
|
||||
func (a *Account) getExtraTime(key string) time.Time {
|
||||
if a.Extra == nil {
|
||||
return time.Time{}
|
||||
}
|
||||
return a.GetQuotaUsed() >= limit
|
||||
if v, ok := a.Extra[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
if t, err := time.Parse(time.RFC3339Nano, s); err == nil {
|
||||
return t
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
||||
return t
|
||||
}
|
||||
}
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
// HasAnyQuotaLimit 检查是否配置了任一维度的配额限制
|
||||
func (a *Account) HasAnyQuotaLimit() bool {
|
||||
return a.GetQuotaLimit() > 0 || a.GetQuotaDailyLimit() > 0 || a.GetQuotaWeeklyLimit() > 0
|
||||
}
|
||||
|
||||
// isPeriodExpired 检查指定周期(自 periodStart 起经过 dur)是否已过期
|
||||
func isPeriodExpired(periodStart time.Time, dur time.Duration) bool {
|
||||
if periodStart.IsZero() {
|
||||
return true // 从未使用过,视为过期(下次 increment 会初始化)
|
||||
}
|
||||
return time.Since(periodStart) >= dur
|
||||
}
|
||||
|
||||
// IsQuotaExceeded 检查 API Key 账号配额是否已超限(任一维度超限即返回 true)
|
||||
func (a *Account) IsQuotaExceeded() bool {
|
||||
// 总额度
|
||||
if limit := a.GetQuotaLimit(); limit > 0 && a.GetQuotaUsed() >= limit {
|
||||
return true
|
||||
}
|
||||
// 日额度(周期过期视为未超限,下次 increment 会重置)
|
||||
if limit := a.GetQuotaDailyLimit(); limit > 0 {
|
||||
start := a.getExtraTime("quota_daily_start")
|
||||
if !isPeriodExpired(start, 24*time.Hour) && a.GetQuotaDailyUsed() >= limit {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// 周额度
|
||||
if limit := a.GetQuotaWeeklyLimit(); limit > 0 {
|
||||
start := a.getExtraTime("quota_weekly_start")
|
||||
if !isPeriodExpired(start, 7*24*time.Hour) && a.GetQuotaWeeklyUsed() >= limit {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
|
||||
|
||||
117
backend/internal/service/account_pool_mode_test.go
Normal file
117
backend/internal/service/account_pool_mode_test.go
Normal file
@@ -0,0 +1,117 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetPoolModeRetryCount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
name: "default_when_not_pool_mode",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
expected: defaultPoolModeRetryCount,
|
||||
},
|
||||
{
|
||||
name: "default_when_missing_retry_count",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
},
|
||||
},
|
||||
expected: defaultPoolModeRetryCount,
|
||||
},
|
||||
{
|
||||
name: "supports_float64_from_json_credentials",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": float64(5),
|
||||
},
|
||||
},
|
||||
expected: 5,
|
||||
},
|
||||
{
|
||||
name: "supports_json_number",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": json.Number("4"),
|
||||
},
|
||||
},
|
||||
expected: 4,
|
||||
},
|
||||
{
|
||||
name: "supports_string_value",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": "2",
|
||||
},
|
||||
},
|
||||
expected: 2,
|
||||
},
|
||||
{
|
||||
name: "negative_value_is_clamped_to_zero",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": -1,
|
||||
},
|
||||
},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "oversized_value_is_clamped_to_max",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": 99,
|
||||
},
|
||||
},
|
||||
expected: maxPoolModeRetryCount,
|
||||
},
|
||||
{
|
||||
name: "invalid_value_falls_back_to_default",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": "oops",
|
||||
},
|
||||
},
|
||||
expected: defaultPoolModeRetryCount,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.expected, tt.account.GetPoolModeRetryCount())
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -68,9 +68,9 @@ type AccountRepository interface {
|
||||
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
|
||||
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
|
||||
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
|
||||
// IncrementQuotaUsed 原子递增 API Key 账号的配额用量
|
||||
// IncrementQuotaUsed 原子递增 API Key 账号的配额用量(总/日/周)
|
||||
IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error
|
||||
// ResetQuotaUsed 重置 API Key 账号的配额用量为 0
|
||||
// ResetQuotaUsed 重置 API Key 账号所有维度的配额用量为 0
|
||||
ResetQuotaUsed(ctx context.Context, id int64) error
|
||||
}
|
||||
|
||||
|
||||
@@ -406,8 +406,27 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if isOAuth && s.accountRepo != nil {
|
||||
if updates, err := extractOpenAICodexProbeUpdates(resp); err == nil && len(updates) > 0 {
|
||||
_ = s.accountRepo.UpdateExtra(ctx, account.ID, updates)
|
||||
mergeAccountExtra(account, updates)
|
||||
}
|
||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||
if resetAt := codexRateLimitResetAtFromSnapshot(snapshot, time.Now()); resetAt != nil {
|
||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt)
|
||||
account.RateLimitResetAt = resetAt
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if isOAuth && s.accountRepo != nil {
|
||||
if resetAt := (&RateLimitService{}).calculateOpenAI429ResetTime(resp.Header); resetAt != nil {
|
||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt)
|
||||
account.RateLimitResetAt = resetAt
|
||||
}
|
||||
}
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
|
||||
}
|
||||
|
||||
|
||||
102
backend/internal/service/account_test_service_openai_test.go
Normal file
102
backend/internal/service/account_test_service_openai_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type openAIAccountTestRepo struct {
|
||||
mockAccountRepoForGemini
|
||||
updatedExtra map[string]any
|
||||
rateLimitedID int64
|
||||
rateLimitedAt *time.Time
|
||||
}
|
||||
|
||||
func (r *openAIAccountTestRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
|
||||
r.updatedExtra = updates
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, resetAt time.Time) error {
|
||||
r.rateLimitedID = id
|
||||
r.rateLimitedAt = &resetAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, recorder := newSoraTestContext()
|
||||
|
||||
resp := newJSONResponse(http.StatusOK, "")
|
||||
resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.completed"}
|
||||
|
||||
`))
|
||||
resp.Header.Set("x-codex-primary-used-percent", "88")
|
||||
resp.Header.Set("x-codex-primary-reset-after-seconds", "604800")
|
||||
resp.Header.Set("x-codex-primary-window-minutes", "10080")
|
||||
resp.Header.Set("x-codex-secondary-used-percent", "42")
|
||||
resp.Header.Set("x-codex-secondary-reset-after-seconds", "18000")
|
||||
resp.Header.Set("x-codex-secondary-window-minutes", "300")
|
||||
|
||||
repo := &openAIAccountTestRepo{}
|
||||
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
|
||||
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 89,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "test-token"},
|
||||
}
|
||||
|
||||
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, repo.updatedExtra)
|
||||
require.Equal(t, 42.0, repo.updatedExtra["codex_5h_used_percent"])
|
||||
require.Equal(t, 88.0, repo.updatedExtra["codex_7d_used_percent"])
|
||||
require.Contains(t, recorder.Body.String(), "test_complete")
|
||||
}
|
||||
|
||||
func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := newSoraTestContext()
|
||||
|
||||
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
|
||||
resp.Header.Set("x-codex-primary-used-percent", "100")
|
||||
resp.Header.Set("x-codex-primary-reset-after-seconds", "604800")
|
||||
resp.Header.Set("x-codex-primary-window-minutes", "10080")
|
||||
resp.Header.Set("x-codex-secondary-used-percent", "100")
|
||||
resp.Header.Set("x-codex-secondary-reset-after-seconds", "18000")
|
||||
resp.Header.Set("x-codex-secondary-window-minutes", "300")
|
||||
|
||||
repo := &openAIAccountTestRepo{}
|
||||
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
|
||||
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 88,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "test-token"},
|
||||
}
|
||||
|
||||
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
|
||||
require.Error(t, err)
|
||||
require.NotEmpty(t, repo.updatedExtra)
|
||||
require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"])
|
||||
require.Equal(t, int64(88), repo.rateLimitedID)
|
||||
require.NotNil(t, repo.rateLimitedAt)
|
||||
require.NotNil(t, account.RateLimitResetAt)
|
||||
if account.RateLimitResetAt != nil && repo.rateLimitedAt != nil {
|
||||
require.WithinDuration(t, *repo.rateLimitedAt, *account.RateLimitResetAt, time.Second)
|
||||
}
|
||||
}
|
||||
@@ -1,17 +1,24 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
httppool "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
openaipkg "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
type UsageLogRepository interface {
|
||||
@@ -70,8 +77,10 @@ type accountWindowStatsBatchReader interface {
|
||||
}
|
||||
|
||||
// apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at)
|
||||
// 同时支持缓存错误响应(负缓存),防止 429 等错误导致的重试风暴
|
||||
type apiUsageCache struct {
|
||||
response *ClaudeUsageResponse
|
||||
err error // 非 nil 表示缓存的错误(负缓存)
|
||||
timestamp time.Time
|
||||
}
|
||||
|
||||
@@ -88,15 +97,21 @@ type antigravityUsageCache struct {
|
||||
}
|
||||
|
||||
const (
|
||||
apiCacheTTL = 3 * time.Minute
|
||||
windowStatsCacheTTL = 1 * time.Minute
|
||||
apiCacheTTL = 3 * time.Minute
|
||||
apiErrorCacheTTL = 1 * time.Minute // 负缓存 TTL:429 等错误缓存 1 分钟
|
||||
apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟
|
||||
windowStatsCacheTTL = 1 * time.Minute
|
||||
openAIProbeCacheTTL = 10 * time.Minute
|
||||
openAICodexProbeVersion = "0.104.0"
|
||||
)
|
||||
|
||||
// UsageCache 封装账户使用量相关的缓存
|
||||
type UsageCache struct {
|
||||
apiCache sync.Map // accountID -> *apiUsageCache
|
||||
windowStatsCache sync.Map // accountID -> *windowStatsCache
|
||||
antigravityCache sync.Map // accountID -> *antigravityUsageCache
|
||||
apiCache sync.Map // accountID -> *apiUsageCache
|
||||
windowStatsCache sync.Map // accountID -> *windowStatsCache
|
||||
antigravityCache sync.Map // accountID -> *antigravityUsageCache
|
||||
apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存
|
||||
openAIProbeCache sync.Map // accountID -> time.Time
|
||||
}
|
||||
|
||||
// NewUsageCache 创建 UsageCache 实例
|
||||
@@ -224,6 +239,14 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
||||
return nil, fmt.Errorf("get account failed: %w", err)
|
||||
}
|
||||
|
||||
if account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth {
|
||||
usage, err := s.getOpenAIUsage(ctx, account)
|
||||
if err == nil {
|
||||
s.tryClearRecoverableAccountError(ctx, account)
|
||||
}
|
||||
return usage, err
|
||||
}
|
||||
|
||||
if account.Platform == PlatformGemini {
|
||||
usage, err := s.getGeminiUsage(ctx, account)
|
||||
if err == nil {
|
||||
@@ -245,24 +268,65 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
||||
if account.CanGetUsage() {
|
||||
var apiResp *ClaudeUsageResponse
|
||||
|
||||
// 1. 检查 API 缓存(10 分钟)
|
||||
// 1. 检查缓存(成功响应 3 分钟 / 错误响应 1 分钟)
|
||||
if cached, ok := s.cache.apiCache.Load(accountID); ok {
|
||||
if cache, ok := cached.(*apiUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL {
|
||||
apiResp = cache.response
|
||||
if cache, ok := cached.(*apiUsageCache); ok {
|
||||
age := time.Since(cache.timestamp)
|
||||
if cache.err != nil && age < apiErrorCacheTTL {
|
||||
// 负缓存命中:返回缓存的错误,避免重试风暴
|
||||
return nil, cache.err
|
||||
}
|
||||
if cache.response != nil && age < apiCacheTTL {
|
||||
apiResp = cache.response
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 如果没有缓存,从 API 获取
|
||||
// 2. 如果没有有效缓存,通过 singleflight 从 API 获取(防止并发击穿)
|
||||
if apiResp == nil {
|
||||
apiResp, err = s.fetchOAuthUsageRaw(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// 随机延迟:打散多账号并发请求,避免同一时刻大量相同 TLS 指纹请求
|
||||
// 触发上游反滥用检测。延迟范围 0~800ms,仅在缓存未命中时生效。
|
||||
jitter := time.Duration(rand.Int64N(int64(apiQueryMaxJitter)))
|
||||
select {
|
||||
case <-time.After(jitter):
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
// 缓存 API 响应
|
||||
s.cache.apiCache.Store(accountID, &apiUsageCache{
|
||||
response: apiResp,
|
||||
timestamp: time.Now(),
|
||||
|
||||
flightKey := fmt.Sprintf("usage:%d", accountID)
|
||||
result, flightErr, _ := s.cache.apiFlight.Do(flightKey, func() (any, error) {
|
||||
// 再次检查缓存(可能在等待 singleflight 期间被其他请求填充)
|
||||
if cached, ok := s.cache.apiCache.Load(accountID); ok {
|
||||
if cache, ok := cached.(*apiUsageCache); ok {
|
||||
age := time.Since(cache.timestamp)
|
||||
if cache.err != nil && age < apiErrorCacheTTL {
|
||||
return nil, cache.err
|
||||
}
|
||||
if cache.response != nil && age < apiCacheTTL {
|
||||
return cache.response, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
resp, fetchErr := s.fetchOAuthUsageRaw(ctx, account)
|
||||
if fetchErr != nil {
|
||||
// 负缓存:缓存错误响应,防止后续请求重复触发 429
|
||||
s.cache.apiCache.Store(accountID, &apiUsageCache{
|
||||
err: fetchErr,
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
return nil, fetchErr
|
||||
}
|
||||
// 缓存成功响应
|
||||
s.cache.apiCache.Store(accountID, &apiUsageCache{
|
||||
response: resp,
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
return resp, nil
|
||||
})
|
||||
if flightErr != nil {
|
||||
return nil, flightErr
|
||||
}
|
||||
apiResp, _ = result.(*ClaudeUsageResponse)
|
||||
}
|
||||
|
||||
// 3. 构建 UsageInfo(每次都重新计算 RemainingSeconds)
|
||||
@@ -288,6 +352,210 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
||||
return nil, fmt.Errorf("account type %s does not support usage query", account.Type)
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
|
||||
now := time.Now()
|
||||
usage := &UsageInfo{UpdatedAt: &now}
|
||||
|
||||
if account == nil {
|
||||
return usage, nil
|
||||
}
|
||||
syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, account, now)
|
||||
|
||||
if progress := buildCodexUsageProgressFromExtra(account.Extra, "5h", now); progress != nil {
|
||||
usage.FiveHour = progress
|
||||
}
|
||||
if progress := buildCodexUsageProgressFromExtra(account.Extra, "7d", now); progress != nil {
|
||||
usage.SevenDay = progress
|
||||
}
|
||||
|
||||
if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) {
|
||||
if updates, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && len(updates) > 0 {
|
||||
mergeAccountExtra(account, updates)
|
||||
if usage.UpdatedAt == nil {
|
||||
usage.UpdatedAt = &now
|
||||
}
|
||||
if progress := buildCodexUsageProgressFromExtra(account.Extra, "5h", now); progress != nil {
|
||||
usage.FiveHour = progress
|
||||
}
|
||||
if progress := buildCodexUsageProgressFromExtra(account.Extra, "7d", now); progress != nil {
|
||||
usage.SevenDay = progress
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if s.usageLogRepo == nil {
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
if stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, now.Add(-5*time.Hour)); err == nil {
|
||||
windowStats := windowStatsFromAccountStats(stats)
|
||||
if hasMeaningfulWindowStats(windowStats) {
|
||||
if usage.FiveHour == nil {
|
||||
usage.FiveHour = &UsageProgress{Utilization: 0}
|
||||
}
|
||||
usage.FiveHour.WindowStats = windowStats
|
||||
}
|
||||
}
|
||||
|
||||
if stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, now.Add(-7*24*time.Hour)); err == nil {
|
||||
windowStats := windowStatsFromAccountStats(stats)
|
||||
if hasMeaningfulWindowStats(windowStats) {
|
||||
if usage.SevenDay == nil {
|
||||
usage.SevenDay = &UsageProgress{Utilization: 0}
|
||||
}
|
||||
usage.SevenDay.WindowStats = windowStats
|
||||
}
|
||||
}
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func shouldRefreshOpenAICodexSnapshot(account *Account, usage *UsageInfo, now time.Time) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if usage == nil {
|
||||
return true
|
||||
}
|
||||
if usage.FiveHour == nil || usage.SevenDay == nil {
|
||||
return true
|
||||
}
|
||||
if account.IsRateLimited() {
|
||||
return true
|
||||
}
|
||||
return isOpenAICodexSnapshotStale(account, now)
|
||||
}
|
||||
|
||||
func isOpenAICodexSnapshotStale(account *Account, now time.Time) bool {
|
||||
if account == nil || !account.IsOpenAIOAuth() || !account.IsOpenAIResponsesWebSocketV2Enabled() {
|
||||
return false
|
||||
}
|
||||
if account.Extra == nil {
|
||||
return true
|
||||
}
|
||||
raw, ok := account.Extra["codex_usage_updated_at"]
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
ts, err := parseTime(fmt.Sprint(raw))
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
return now.Sub(ts) >= openAIProbeCacheTTL
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, now time.Time) bool {
|
||||
if s == nil || s.cache == nil || accountID <= 0 {
|
||||
return true
|
||||
}
|
||||
if cached, ok := s.cache.openAIProbeCache.Load(accountID); ok {
|
||||
if ts, ok := cached.(time.Time); ok && now.Sub(ts) < openAIProbeCacheTTL {
|
||||
return false
|
||||
}
|
||||
}
|
||||
s.cache.openAIProbeCache.Store(accountID, now)
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, error) {
|
||||
if account == nil || !account.IsOAuth() {
|
||||
return nil, nil
|
||||
}
|
||||
accessToken := account.GetOpenAIAccessToken()
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("no access token available")
|
||||
}
|
||||
modelID := openaipkg.DefaultTestModel
|
||||
payload := createOpenAITestPayload(modelID, true)
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal openai probe payload: %w", err)
|
||||
}
|
||||
|
||||
reqCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, chatgptCodexURL, bytes.NewReader(payloadBytes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create openai probe request: %w", err)
|
||||
}
|
||||
req.Host = "chatgpt.com"
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
req.Header.Set("OpenAI-Beta", "responses=experimental")
|
||||
req.Header.Set("Originator", "codex_cli_rs")
|
||||
req.Header.Set("Version", openAICodexProbeVersion)
|
||||
req.Header.Set("User-Agent", codexCLIUserAgent)
|
||||
if s.identityCache != nil {
|
||||
if fp, fpErr := s.identityCache.GetFingerprint(reqCtx, account.ID); fpErr == nil && fp != nil && strings.TrimSpace(fp.UserAgent) != "" {
|
||||
req.Header.Set("User-Agent", strings.TrimSpace(fp.UserAgent))
|
||||
}
|
||||
}
|
||||
if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" {
|
||||
req.Header.Set("chatgpt-account-id", chatgptAccountID)
|
||||
}
|
||||
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
client, err := httppool.GetClient(httppool.Options{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 15 * time.Second,
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build openai probe client: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("openai codex probe request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
updates, err := extractOpenAICodexProbeUpdates(resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(updates) > 0 {
|
||||
go func(accountID int64, updates map[string]any) {
|
||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer updateCancel()
|
||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||
}(account.ID, updates)
|
||||
return updates, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) {
|
||||
if resp == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||
updates := buildCodexUsageExtraUpdates(snapshot, time.Now())
|
||||
if len(updates) > 0 {
|
||||
return updates, nil
|
||||
}
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func mergeAccountExtra(account *Account, updates map[string]any) {
|
||||
if account == nil || len(updates) == 0 {
|
||||
return
|
||||
}
|
||||
if account.Extra == nil {
|
||||
account.Extra = make(map[string]any, len(updates))
|
||||
}
|
||||
for k, v := range updates {
|
||||
account.Extra[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
|
||||
now := time.Now()
|
||||
usage := &UsageInfo{
|
||||
@@ -519,6 +787,72 @@ func windowStatsFromAccountStats(stats *usagestats.AccountStats) *WindowStats {
|
||||
}
|
||||
}
|
||||
|
||||
func hasMeaningfulWindowStats(stats *WindowStats) bool {
|
||||
if stats == nil {
|
||||
return false
|
||||
}
|
||||
return stats.Requests > 0 || stats.Tokens > 0 || stats.Cost > 0 || stats.StandardCost > 0 || stats.UserCost > 0
|
||||
}
|
||||
|
||||
func buildCodexUsageProgressFromExtra(extra map[string]any, window string, now time.Time) *UsageProgress {
|
||||
if len(extra) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
usedPercentKey string
|
||||
resetAfterKey string
|
||||
resetAtKey string
|
||||
)
|
||||
|
||||
switch window {
|
||||
case "5h":
|
||||
usedPercentKey = "codex_5h_used_percent"
|
||||
resetAfterKey = "codex_5h_reset_after_seconds"
|
||||
resetAtKey = "codex_5h_reset_at"
|
||||
case "7d":
|
||||
usedPercentKey = "codex_7d_used_percent"
|
||||
resetAfterKey = "codex_7d_reset_after_seconds"
|
||||
resetAtKey = "codex_7d_reset_at"
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
usedRaw, ok := extra[usedPercentKey]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
progress := &UsageProgress{Utilization: parseExtraFloat64(usedRaw)}
|
||||
if resetAtRaw, ok := extra[resetAtKey]; ok {
|
||||
if resetAt, err := parseTime(fmt.Sprint(resetAtRaw)); err == nil {
|
||||
progress.ResetsAt = &resetAt
|
||||
progress.RemainingSeconds = int(time.Until(resetAt).Seconds())
|
||||
if progress.RemainingSeconds < 0 {
|
||||
progress.RemainingSeconds = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
if progress.ResetsAt == nil {
|
||||
if resetAfterSeconds := parseExtraInt(extra[resetAfterKey]); resetAfterSeconds > 0 {
|
||||
base := now
|
||||
if updatedAtRaw, ok := extra["codex_usage_updated_at"]; ok {
|
||||
if updatedAt, err := parseTime(fmt.Sprint(updatedAtRaw)); err == nil {
|
||||
base = updatedAt
|
||||
}
|
||||
}
|
||||
resetAt := base.Add(time.Duration(resetAfterSeconds) * time.Second)
|
||||
progress.ResetsAt = &resetAt
|
||||
progress.RemainingSeconds = int(time.Until(resetAt).Seconds())
|
||||
if progress.RemainingSeconds < 0 {
|
||||
progress.RemainingSeconds = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return progress
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) {
|
||||
stats, err := s.usageLogRepo.GetAccountUsageStats(ctx, accountID, startTime, endTime)
|
||||
if err != nil {
|
||||
@@ -666,15 +1000,30 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
|
||||
remaining = 0
|
||||
}
|
||||
|
||||
// 根据状态估算使用率 (百分比形式,100 = 100%)
|
||||
// 优先使用响应头中存储的真实 utilization 值(0-1 小数,转为 0-100 百分比)
|
||||
var utilization float64
|
||||
switch account.SessionWindowStatus {
|
||||
case "rejected":
|
||||
utilization = 100.0
|
||||
case "allowed_warning":
|
||||
utilization = 80.0
|
||||
default:
|
||||
utilization = 0.0
|
||||
var found bool
|
||||
if stored, ok := account.Extra["session_window_utilization"]; ok {
|
||||
switch v := stored.(type) {
|
||||
case float64:
|
||||
utilization = v * 100
|
||||
found = true
|
||||
case json.Number:
|
||||
if f, err := v.Float64(); err == nil {
|
||||
utilization = f * 100
|
||||
found = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有存储的 utilization,回退到状态估算
|
||||
if !found {
|
||||
switch account.SessionWindowStatus {
|
||||
case "rejected":
|
||||
utilization = 100.0
|
||||
case "allowed_warning":
|
||||
utilization = 80.0
|
||||
}
|
||||
}
|
||||
|
||||
info.FiveHour = &UsageProgress{
|
||||
|
||||
68
backend/internal/service/account_usage_service_test.go
Normal file
68
backend/internal/service/account_usage_service_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestShouldRefreshOpenAICodexSnapshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rateLimitedUntil := time.Now().Add(5 * time.Minute)
|
||||
now := time.Now()
|
||||
usage := &UsageInfo{
|
||||
FiveHour: &UsageProgress{Utilization: 0},
|
||||
SevenDay: &UsageProgress{Utilization: 0},
|
||||
}
|
||||
|
||||
if !shouldRefreshOpenAICodexSnapshot(&Account{RateLimitResetAt: &rateLimitedUntil}, usage, now) {
|
||||
t.Fatal("expected rate-limited account to force codex snapshot refresh")
|
||||
}
|
||||
|
||||
if shouldRefreshOpenAICodexSnapshot(&Account{}, usage, now) {
|
||||
t.Fatal("expected complete non-rate-limited usage to skip codex snapshot refresh")
|
||||
}
|
||||
|
||||
if !shouldRefreshOpenAICodexSnapshot(&Account{}, &UsageInfo{FiveHour: nil, SevenDay: &UsageProgress{}}, now) {
|
||||
t.Fatal("expected missing 5h snapshot to require refresh")
|
||||
}
|
||||
|
||||
staleAt := now.Add(-(openAIProbeCacheTTL + time.Minute)).Format(time.RFC3339)
|
||||
if !shouldRefreshOpenAICodexSnapshot(&Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_oauth_responses_websockets_v2_enabled": true,
|
||||
"codex_usage_updated_at": staleAt,
|
||||
},
|
||||
}, usage, now) {
|
||||
t.Fatal("expected stale ws snapshot to trigger refresh")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractOpenAICodexProbeUpdatesAccepts429WithCodexHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("x-codex-primary-used-percent", "100")
|
||||
headers.Set("x-codex-primary-reset-after-seconds", "604800")
|
||||
headers.Set("x-codex-primary-window-minutes", "10080")
|
||||
headers.Set("x-codex-secondary-used-percent", "100")
|
||||
headers.Set("x-codex-secondary-reset-after-seconds", "18000")
|
||||
headers.Set("x-codex-secondary-window-minutes", "300")
|
||||
|
||||
updates, err := extractOpenAICodexProbeUpdates(&http.Response{StatusCode: http.StatusTooManyRequests, Header: headers})
|
||||
if err != nil {
|
||||
t.Fatalf("extractOpenAICodexProbeUpdates() error = %v", err)
|
||||
}
|
||||
if len(updates) == 0 {
|
||||
t.Fatal("expected codex probe updates from 429 headers")
|
||||
}
|
||||
if got := updates["codex_5h_used_percent"]; got != 100.0 {
|
||||
t.Fatalf("codex_5h_used_percent = %v, want 100", got)
|
||||
}
|
||||
if got := updates["codex_7d_used_percent"]; got != 100.0 {
|
||||
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
|
||||
}
|
||||
}
|
||||
@@ -145,6 +145,9 @@ type CreateGroupInput struct {
|
||||
SupportedModelScopes []string
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes int64
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
AllowMessagesDispatch bool
|
||||
DefaultMappedModel string
|
||||
// 从指定分组复制账号(创建分组后在同一事务内绑定)
|
||||
CopyAccountsFromGroupIDs []int64
|
||||
}
|
||||
@@ -181,6 +184,9 @@ type UpdateGroupInput struct {
|
||||
SupportedModelScopes *[]string
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes *int64
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
AllowMessagesDispatch *bool
|
||||
DefaultMappedModel *string
|
||||
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||
CopyAccountsFromGroupIDs []int64
|
||||
}
|
||||
@@ -909,6 +915,8 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
MCPXMLInject: mcpXMLInject,
|
||||
SupportedModelScopes: input.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
|
||||
AllowMessagesDispatch: input.AllowMessagesDispatch,
|
||||
DefaultMappedModel: input.DefaultMappedModel,
|
||||
}
|
||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||
return nil, err
|
||||
@@ -1122,6 +1130,14 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
group.SupportedModelScopes = *input.SupportedModelScopes
|
||||
}
|
||||
|
||||
// OpenAI Messages 调度配置
|
||||
if input.AllowMessagesDispatch != nil {
|
||||
group.AllowMessagesDispatch = *input.AllowMessagesDispatch
|
||||
}
|
||||
if input.DefaultMappedModel != nil {
|
||||
group.DefaultMappedModel = *input.DefaultMappedModel
|
||||
}
|
||||
|
||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1333,6 +1349,10 @@ func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int,
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
now := time.Now()
|
||||
for i := range accounts {
|
||||
syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, &accounts[i], now)
|
||||
}
|
||||
return accounts, result.Total, nil
|
||||
}
|
||||
|
||||
@@ -1468,9 +1488,11 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
account.Credentials = input.Credentials
|
||||
}
|
||||
if len(input.Extra) > 0 {
|
||||
// 保留 quota_used,防止编辑账号时意外重置配额用量
|
||||
if oldQuotaUsed, ok := account.Extra["quota_used"]; ok {
|
||||
input.Extra["quota_used"] = oldQuotaUsed
|
||||
// 保留配额用量字段,防止编辑账号时意外重置
|
||||
for _, key := range []string{"quota_used", "quota_daily_used", "quota_daily_start", "quota_weekly_used", "quota_weekly_start"} {
|
||||
if v, ok := account.Extra[key]; ok {
|
||||
input.Extra[key] = v
|
||||
}
|
||||
}
|
||||
account.Extra = input.Extra
|
||||
}
|
||||
@@ -1701,16 +1723,10 @@ func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Account, error) {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if err := s.accountRepo.ClearError(ctx, id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
account.Status = StatusActive
|
||||
account.ErrorMessage = ""
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return account, nil
|
||||
return s.accountRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) SetAccountError(ctx context.Context, id int64, errorMsg string) error {
|
||||
|
||||
@@ -14,6 +14,11 @@ const (
|
||||
AnnouncementStatusArchived = domain.AnnouncementStatusArchived
|
||||
)
|
||||
|
||||
const (
|
||||
AnnouncementNotifyModeSilent = domain.AnnouncementNotifyModeSilent
|
||||
AnnouncementNotifyModePopup = domain.AnnouncementNotifyModePopup
|
||||
)
|
||||
|
||||
const (
|
||||
AnnouncementConditionTypeSubscription = domain.AnnouncementConditionTypeSubscription
|
||||
AnnouncementConditionTypeBalance = domain.AnnouncementConditionTypeBalance
|
||||
|
||||
@@ -33,23 +33,25 @@ func NewAnnouncementService(
|
||||
}
|
||||
|
||||
type CreateAnnouncementInput struct {
|
||||
Title string
|
||||
Content string
|
||||
Status string
|
||||
Targeting AnnouncementTargeting
|
||||
StartsAt *time.Time
|
||||
EndsAt *time.Time
|
||||
ActorID *int64 // 管理员用户ID
|
||||
Title string
|
||||
Content string
|
||||
Status string
|
||||
NotifyMode string
|
||||
Targeting AnnouncementTargeting
|
||||
StartsAt *time.Time
|
||||
EndsAt *time.Time
|
||||
ActorID *int64 // 管理员用户ID
|
||||
}
|
||||
|
||||
type UpdateAnnouncementInput struct {
|
||||
Title *string
|
||||
Content *string
|
||||
Status *string
|
||||
Targeting *AnnouncementTargeting
|
||||
StartsAt **time.Time
|
||||
EndsAt **time.Time
|
||||
ActorID *int64 // 管理员用户ID
|
||||
Title *string
|
||||
Content *string
|
||||
Status *string
|
||||
NotifyMode *string
|
||||
Targeting *AnnouncementTargeting
|
||||
StartsAt **time.Time
|
||||
EndsAt **time.Time
|
||||
ActorID *int64 // 管理员用户ID
|
||||
}
|
||||
|
||||
type UserAnnouncement struct {
|
||||
@@ -93,6 +95,14 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem
|
||||
return nil, err
|
||||
}
|
||||
|
||||
notifyMode := strings.TrimSpace(input.NotifyMode)
|
||||
if notifyMode == "" {
|
||||
notifyMode = AnnouncementNotifyModeSilent
|
||||
}
|
||||
if !isValidAnnouncementNotifyMode(notifyMode) {
|
||||
return nil, fmt.Errorf("create announcement: invalid notify_mode")
|
||||
}
|
||||
|
||||
if input.StartsAt != nil && input.EndsAt != nil {
|
||||
if !input.StartsAt.Before(*input.EndsAt) {
|
||||
return nil, fmt.Errorf("create announcement: starts_at must be before ends_at")
|
||||
@@ -100,12 +110,13 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem
|
||||
}
|
||||
|
||||
a := &Announcement{
|
||||
Title: title,
|
||||
Content: content,
|
||||
Status: status,
|
||||
Targeting: targeting,
|
||||
StartsAt: input.StartsAt,
|
||||
EndsAt: input.EndsAt,
|
||||
Title: title,
|
||||
Content: content,
|
||||
Status: status,
|
||||
NotifyMode: notifyMode,
|
||||
Targeting: targeting,
|
||||
StartsAt: input.StartsAt,
|
||||
EndsAt: input.EndsAt,
|
||||
}
|
||||
if input.ActorID != nil && *input.ActorID > 0 {
|
||||
a.CreatedBy = input.ActorID
|
||||
@@ -150,6 +161,14 @@ func (s *AnnouncementService) Update(ctx context.Context, id int64, input *Updat
|
||||
a.Status = status
|
||||
}
|
||||
|
||||
if input.NotifyMode != nil {
|
||||
notifyMode := strings.TrimSpace(*input.NotifyMode)
|
||||
if !isValidAnnouncementNotifyMode(notifyMode) {
|
||||
return nil, fmt.Errorf("update announcement: invalid notify_mode")
|
||||
}
|
||||
a.NotifyMode = notifyMode
|
||||
}
|
||||
|
||||
if input.Targeting != nil {
|
||||
targeting, err := domain.AnnouncementTargeting(*input.Targeting).NormalizeAndValidate()
|
||||
if err != nil {
|
||||
@@ -376,3 +395,12 @@ func isValidAnnouncementStatus(status string) bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isValidAnnouncementNotifyMode(mode string) bool {
|
||||
switch mode {
|
||||
case AnnouncementNotifyModeSilent, AnnouncementNotifyModePopup:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1384,7 +1384,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
// 优先检测 thinking block 的 signature 相关错误(400)并重试一次:
|
||||
// Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验,
|
||||
// 当历史消息携带的 signature 不合法时会直接 400;去除 thinking 后可继续完成请求。
|
||||
if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) {
|
||||
if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) {
|
||||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
logBody, maxBytes := s.getLogConfig()
|
||||
@@ -1517,6 +1517,80 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
}
|
||||
|
||||
// Budget 整流:检测 budget_tokens 约束错误并自动修正重试
|
||||
if resp.StatusCode == http.StatusBadRequest && respBody != nil && !isSignatureRelatedError(respBody) {
|
||||
errMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||
if isThinkingBudgetConstraintError(errMsg) && s.settingService.IsBudgetRectifierEnabled(ctx) {
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "budget_constraint_error",
|
||||
Message: errMsg,
|
||||
Detail: s.getUpstreamErrorDetail(respBody),
|
||||
})
|
||||
|
||||
// 修正 claudeReq 的 thinking 参数(adaptive 模式不修正)
|
||||
if claudeReq.Thinking == nil || claudeReq.Thinking.Type != "adaptive" {
|
||||
retryClaudeReq := claudeReq
|
||||
retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...)
|
||||
// 创建新的 ThinkingConfig 避免修改原始 claudeReq.Thinking 指针
|
||||
retryClaudeReq.Thinking = &antigravity.ThinkingConfig{
|
||||
Type: "enabled",
|
||||
BudgetTokens: BudgetRectifyBudgetTokens,
|
||||
}
|
||||
if retryClaudeReq.MaxTokens < BudgetRectifyMinMaxTokens {
|
||||
retryClaudeReq.MaxTokens = BudgetRectifyMaxTokens
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity account %d: detected budget_tokens constraint error, retrying with rectified budget (budget_tokens=%d, max_tokens=%d)", account.ID, BudgetRectifyBudgetTokens, BudgetRectifyMaxTokens)
|
||||
|
||||
retryGeminiBody, txErr := antigravity.TransformClaudeToGeminiWithOptions(&retryClaudeReq, projectID, mappedModel, transformOpts)
|
||||
if txErr == nil {
|
||||
retryResult, retryErr := s.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
ctx: ctx,
|
||||
prefix: prefix,
|
||||
account: account,
|
||||
proxyURL: proxyURL,
|
||||
accessToken: accessToken,
|
||||
action: action,
|
||||
body: retryGeminiBody,
|
||||
c: c,
|
||||
httpUpstream: s.httpUpstream,
|
||||
settingService: s.settingService,
|
||||
accountRepo: s.accountRepo,
|
||||
handleError: s.handleUpstreamError,
|
||||
requestedModel: originalModel,
|
||||
isStickySession: isStickySession,
|
||||
groupID: 0,
|
||||
sessionHash: "",
|
||||
})
|
||||
if retryErr == nil {
|
||||
retryResp := retryResult.resp
|
||||
if retryResp.StatusCode < 400 {
|
||||
_ = resp.Body.Close()
|
||||
resp = retryResp
|
||||
respBody = nil
|
||||
} else {
|
||||
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
|
||||
_ = retryResp.Body.Close()
|
||||
respBody = retryBody
|
||||
resp = &http.Response{
|
||||
StatusCode: retryResp.StatusCode,
|
||||
Header: retryResp.Header.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(retryBody)),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity account %d: budget rectifier retry failed: %v", account.ID, retryErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理错误响应(重试后仍失败或不触发重试)
|
||||
if resp.StatusCode >= 400 {
|
||||
// 检测 prompt too long 错误,返回特殊错误类型供上层 fallback
|
||||
@@ -3696,6 +3770,15 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
finalEvents, agUsage := processor.Finish()
|
||||
if len(finalEvents) > 0 {
|
||||
cw.Write(finalEvents)
|
||||
} else if !processor.MessageStartSent() && !cw.Disconnected() {
|
||||
// 整个流未收到任何可解析的上游数据(全部 SSE 行均无法被 JSON 解析),
|
||||
// 触发 failover 在同账号重试,避免向客户端发出缺少 message_start 的残缺流
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Claude-Stream] empty stream response (no valid events parsed), triggering failover")
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: http.StatusBadGateway,
|
||||
ResponseBody: []byte(`{"error":"empty stream response from upstream"}`),
|
||||
RetryableOnSameAccount: true,
|
||||
}
|
||||
}
|
||||
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()}, nil
|
||||
}
|
||||
|
||||
@@ -998,6 +998,46 @@ func TestHandleClaudeStreamingResponse_ClientDisconnect(t *testing.T) {
|
||||
require.True(t, result.clientDisconnect)
|
||||
}
|
||||
|
||||
// TestHandleClaudeStreamingResponse_EmptyStream
|
||||
// 验证:上游只返回无法解析的 SSE 行时,触发 UpstreamFailoverError 而不是向客户端发出残缺流
|
||||
func TestHandleClaudeStreamingResponse_EmptyStream(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newAntigravityTestService(&config.Config{
|
||||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
// 所有行均为无法 JSON 解析的内容,ProcessLine 全部返回 nil
|
||||
fmt.Fprintln(pw, "data: not-valid-json")
|
||||
fmt.Fprintln(pw, "")
|
||||
fmt.Fprintln(pw, "data: also-invalid")
|
||||
fmt.Fprintln(pw, "")
|
||||
}()
|
||||
|
||||
_, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
||||
_ = pr.Close()
|
||||
|
||||
// 应当返回 UpstreamFailoverError 而非 nil,以便上层触发 failover
|
||||
require.Error(t, err)
|
||||
var failoverErr *UpstreamFailoverError
|
||||
require.ErrorAs(t, err, &failoverErr)
|
||||
require.True(t, failoverErr.RetryableOnSameAccount)
|
||||
|
||||
// 客户端不应收到任何 SSE 事件(既无 message_start 也无 message_stop)
|
||||
body := rec.Body.String()
|
||||
require.NotContains(t, body, "event: message_start")
|
||||
require.NotContains(t, body, "event: message_stop")
|
||||
require.NotContains(t, body, "event: message_delta")
|
||||
}
|
||||
|
||||
// TestHandleClaudeStreamingResponse_ContextCanceled
|
||||
// 验证:context 取消时不注入错误事件
|
||||
func TestHandleClaudeStreamingResponse_ContextCanceled(t *testing.T) {
|
||||
|
||||
@@ -14,6 +14,18 @@ const (
|
||||
StatusAPIKeyExpired = "expired"
|
||||
)
|
||||
|
||||
// Rate limit window durations
|
||||
const (
|
||||
RateLimitWindow5h = 5 * time.Hour
|
||||
RateLimitWindow1d = 24 * time.Hour
|
||||
RateLimitWindow7d = 7 * 24 * time.Hour
|
||||
)
|
||||
|
||||
// IsWindowExpired returns true if the window starting at windowStart has exceeded the given duration.
|
||||
func IsWindowExpired(windowStart *time.Time, duration time.Duration) bool {
|
||||
return windowStart != nil && time.Since(*windowStart) >= duration
|
||||
}
|
||||
|
||||
type APIKey struct {
|
||||
ID int64
|
||||
UserID int64
|
||||
@@ -98,6 +110,30 @@ func (k *APIKey) GetDaysUntilExpiry() int {
|
||||
return int(duration.Hours() / 24)
|
||||
}
|
||||
|
||||
// EffectiveUsage5h returns the 5h window usage, or 0 if the window has expired.
|
||||
func (k *APIKey) EffectiveUsage5h() float64 {
|
||||
if IsWindowExpired(k.Window5hStart, RateLimitWindow5h) {
|
||||
return 0
|
||||
}
|
||||
return k.Usage5h
|
||||
}
|
||||
|
||||
// EffectiveUsage1d returns the 1d window usage, or 0 if the window has expired.
|
||||
func (k *APIKey) EffectiveUsage1d() float64 {
|
||||
if IsWindowExpired(k.Window1dStart, RateLimitWindow1d) {
|
||||
return 0
|
||||
}
|
||||
return k.Usage1d
|
||||
}
|
||||
|
||||
// EffectiveUsage7d returns the 7d window usage, or 0 if the window has expired.
|
||||
func (k *APIKey) EffectiveUsage7d() float64 {
|
||||
if IsWindowExpired(k.Window7dStart, RateLimitWindow7d) {
|
||||
return 0
|
||||
}
|
||||
return k.Usage7d
|
||||
}
|
||||
|
||||
// APIKeyListFilters holds optional filtering parameters for listing API keys.
|
||||
type APIKeyListFilters struct {
|
||||
Search string
|
||||
|
||||
@@ -65,6 +65,10 @@ type APIKeyAuthGroupSnapshot struct {
|
||||
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
|
||||
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
|
||||
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
|
||||
}
|
||||
|
||||
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
|
||||
|
||||
@@ -245,6 +245,8 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
||||
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
|
||||
MCPXMLInject: apiKey.Group.MCPXMLInject,
|
||||
SupportedModelScopes: apiKey.Group.SupportedModelScopes,
|
||||
AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch,
|
||||
DefaultMappedModel: apiKey.Group.DefaultMappedModel,
|
||||
}
|
||||
}
|
||||
return snapshot
|
||||
@@ -302,6 +304,8 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
||||
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
|
||||
MCPXMLInject: snapshot.Group.MCPXMLInject,
|
||||
SupportedModelScopes: snapshot.Group.SupportedModelScopes,
|
||||
AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch,
|
||||
DefaultMappedModel: snapshot.Group.DefaultMappedModel,
|
||||
}
|
||||
}
|
||||
s.compileAPIKeyIPRules(apiKey)
|
||||
|
||||
245
backend/internal/service/api_key_rate_limit_test.go
Normal file
245
backend/internal/service/api_key_rate_limit_test.go
Normal file
@@ -0,0 +1,245 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIsWindowExpired(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
start *time.Time
|
||||
duration time.Duration
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil window start",
|
||||
start: nil,
|
||||
duration: RateLimitWindow5h,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "active window (started 1h ago, 5h window)",
|
||||
start: rateLimitTimePtr(now.Add(-1 * time.Hour)),
|
||||
duration: RateLimitWindow5h,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "expired window (started 6h ago, 5h window)",
|
||||
start: rateLimitTimePtr(now.Add(-6 * time.Hour)),
|
||||
duration: RateLimitWindow5h,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "exactly at boundary (started 5h ago, 5h window)",
|
||||
start: rateLimitTimePtr(now.Add(-5 * time.Hour)),
|
||||
duration: RateLimitWindow5h,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "active 1d window (started 12h ago)",
|
||||
start: rateLimitTimePtr(now.Add(-12 * time.Hour)),
|
||||
duration: RateLimitWindow1d,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "expired 1d window (started 25h ago)",
|
||||
start: rateLimitTimePtr(now.Add(-25 * time.Hour)),
|
||||
duration: RateLimitWindow1d,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "active 7d window (started 3d ago)",
|
||||
start: rateLimitTimePtr(now.Add(-3 * 24 * time.Hour)),
|
||||
duration: RateLimitWindow7d,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "expired 7d window (started 8d ago)",
|
||||
start: rateLimitTimePtr(now.Add(-8 * 24 * time.Hour)),
|
||||
duration: RateLimitWindow7d,
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsWindowExpired(tt.start, tt.duration)
|
||||
if got != tt.want {
|
||||
t.Errorf("IsWindowExpired() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIKey_EffectiveUsage(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key APIKey
|
||||
want5h float64
|
||||
want1d float64
|
||||
want7d float64
|
||||
}{
|
||||
{
|
||||
name: "all windows active",
|
||||
key: APIKey{
|
||||
Usage5h: 5.0,
|
||||
Usage1d: 10.0,
|
||||
Usage7d: 50.0,
|
||||
Window5hStart: rateLimitTimePtr(now.Add(-1 * time.Hour)),
|
||||
Window1dStart: rateLimitTimePtr(now.Add(-12 * time.Hour)),
|
||||
Window7dStart: rateLimitTimePtr(now.Add(-3 * 24 * time.Hour)),
|
||||
},
|
||||
want5h: 5.0,
|
||||
want1d: 10.0,
|
||||
want7d: 50.0,
|
||||
},
|
||||
{
|
||||
name: "all windows expired",
|
||||
key: APIKey{
|
||||
Usage5h: 5.0,
|
||||
Usage1d: 10.0,
|
||||
Usage7d: 50.0,
|
||||
Window5hStart: rateLimitTimePtr(now.Add(-6 * time.Hour)),
|
||||
Window1dStart: rateLimitTimePtr(now.Add(-25 * time.Hour)),
|
||||
Window7dStart: rateLimitTimePtr(now.Add(-8 * 24 * time.Hour)),
|
||||
},
|
||||
want5h: 0,
|
||||
want1d: 0,
|
||||
want7d: 0,
|
||||
},
|
||||
{
|
||||
name: "nil window starts return raw usage",
|
||||
key: APIKey{
|
||||
Usage5h: 5.0,
|
||||
Usage1d: 10.0,
|
||||
Usage7d: 50.0,
|
||||
Window5hStart: nil,
|
||||
Window1dStart: nil,
|
||||
Window7dStart: nil,
|
||||
},
|
||||
want5h: 5.0,
|
||||
want1d: 10.0,
|
||||
want7d: 50.0,
|
||||
},
|
||||
{
|
||||
name: "mixed: 5h expired, 1d active, 7d nil",
|
||||
key: APIKey{
|
||||
Usage5h: 5.0,
|
||||
Usage1d: 10.0,
|
||||
Usage7d: 50.0,
|
||||
Window5hStart: rateLimitTimePtr(now.Add(-6 * time.Hour)),
|
||||
Window1dStart: rateLimitTimePtr(now.Add(-12 * time.Hour)),
|
||||
Window7dStart: nil,
|
||||
},
|
||||
want5h: 0,
|
||||
want1d: 10.0,
|
||||
want7d: 50.0,
|
||||
},
|
||||
{
|
||||
name: "zero usage with active windows",
|
||||
key: APIKey{
|
||||
Usage5h: 0,
|
||||
Usage1d: 0,
|
||||
Usage7d: 0,
|
||||
Window5hStart: rateLimitTimePtr(now.Add(-1 * time.Hour)),
|
||||
Window1dStart: rateLimitTimePtr(now.Add(-1 * time.Hour)),
|
||||
Window7dStart: rateLimitTimePtr(now.Add(-1 * time.Hour)),
|
||||
},
|
||||
want5h: 0,
|
||||
want1d: 0,
|
||||
want7d: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.key.EffectiveUsage5h(); got != tt.want5h {
|
||||
t.Errorf("EffectiveUsage5h() = %v, want %v", got, tt.want5h)
|
||||
}
|
||||
if got := tt.key.EffectiveUsage1d(); got != tt.want1d {
|
||||
t.Errorf("EffectiveUsage1d() = %v, want %v", got, tt.want1d)
|
||||
}
|
||||
if got := tt.key.EffectiveUsage7d(); got != tt.want7d {
|
||||
t.Errorf("EffectiveUsage7d() = %v, want %v", got, tt.want7d)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIKeyRateLimitData_EffectiveUsage(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
data APIKeyRateLimitData
|
||||
want5h float64
|
||||
want1d float64
|
||||
want7d float64
|
||||
}{
|
||||
{
|
||||
name: "all windows active",
|
||||
data: APIKeyRateLimitData{
|
||||
Usage5h: 3.0,
|
||||
Usage1d: 8.0,
|
||||
Usage7d: 40.0,
|
||||
Window5hStart: rateLimitTimePtr(now.Add(-2 * time.Hour)),
|
||||
Window1dStart: rateLimitTimePtr(now.Add(-10 * time.Hour)),
|
||||
Window7dStart: rateLimitTimePtr(now.Add(-2 * 24 * time.Hour)),
|
||||
},
|
||||
want5h: 3.0,
|
||||
want1d: 8.0,
|
||||
want7d: 40.0,
|
||||
},
|
||||
{
|
||||
name: "all windows expired",
|
||||
data: APIKeyRateLimitData{
|
||||
Usage5h: 3.0,
|
||||
Usage1d: 8.0,
|
||||
Usage7d: 40.0,
|
||||
Window5hStart: rateLimitTimePtr(now.Add(-10 * time.Hour)),
|
||||
Window1dStart: rateLimitTimePtr(now.Add(-48 * time.Hour)),
|
||||
Window7dStart: rateLimitTimePtr(now.Add(-10 * 24 * time.Hour)),
|
||||
},
|
||||
want5h: 0,
|
||||
want1d: 0,
|
||||
want7d: 0,
|
||||
},
|
||||
{
|
||||
name: "nil window starts return raw usage",
|
||||
data: APIKeyRateLimitData{
|
||||
Usage5h: 3.0,
|
||||
Usage1d: 8.0,
|
||||
Usage7d: 40.0,
|
||||
Window5hStart: nil,
|
||||
Window1dStart: nil,
|
||||
Window7dStart: nil,
|
||||
},
|
||||
want5h: 3.0,
|
||||
want1d: 8.0,
|
||||
want7d: 40.0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.data.EffectiveUsage5h(); got != tt.want5h {
|
||||
t.Errorf("EffectiveUsage5h() = %v, want %v", got, tt.want5h)
|
||||
}
|
||||
if got := tt.data.EffectiveUsage1d(); got != tt.want1d {
|
||||
t.Errorf("EffectiveUsage1d() = %v, want %v", got, tt.want1d)
|
||||
}
|
||||
if got := tt.data.EffectiveUsage7d(); got != tt.want7d {
|
||||
t.Errorf("EffectiveUsage7d() = %v, want %v", got, tt.want7d)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func rateLimitTimePtr(t time.Time) *time.Time {
|
||||
return &t
|
||||
}
|
||||
@@ -86,6 +86,30 @@ type APIKeyRateLimitData struct {
|
||||
Window7dStart *time.Time
|
||||
}
|
||||
|
||||
// EffectiveUsage5h returns the 5h window usage, or 0 if the window has expired.
|
||||
func (d *APIKeyRateLimitData) EffectiveUsage5h() float64 {
|
||||
if IsWindowExpired(d.Window5hStart, RateLimitWindow5h) {
|
||||
return 0
|
||||
}
|
||||
return d.Usage5h
|
||||
}
|
||||
|
||||
// EffectiveUsage1d returns the 1d window usage, or 0 if the window has expired.
|
||||
func (d *APIKeyRateLimitData) EffectiveUsage1d() float64 {
|
||||
if IsWindowExpired(d.Window1dStart, RateLimitWindow1d) {
|
||||
return 0
|
||||
}
|
||||
return d.Usage1d
|
||||
}
|
||||
|
||||
// EffectiveUsage7d returns the 7d window usage, or 0 if the window has expired.
|
||||
func (d *APIKeyRateLimitData) EffectiveUsage7d() float64 {
|
||||
if IsWindowExpired(d.Window7dStart, RateLimitWindow7d) {
|
||||
return 0
|
||||
}
|
||||
return d.Usage7d
|
||||
}
|
||||
|
||||
// APIKeyCache defines cache operations for API key service
|
||||
type APIKeyCache interface {
|
||||
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
|
||||
|
||||
@@ -565,15 +565,15 @@ func (s *BillingCacheService) evaluateRateLimits(ctx context.Context, apiKey *AP
|
||||
needsReset := false
|
||||
|
||||
// Reset expired windows in-memory for check purposes
|
||||
if w5h != nil && time.Since(*w5h) >= 5*time.Hour {
|
||||
if IsWindowExpired(w5h, RateLimitWindow5h) {
|
||||
usage5h = 0
|
||||
needsReset = true
|
||||
}
|
||||
if w1d != nil && time.Since(*w1d) >= 24*time.Hour {
|
||||
if IsWindowExpired(w1d, RateLimitWindow1d) {
|
||||
usage1d = 0
|
||||
needsReset = true
|
||||
}
|
||||
if w7d != nil && time.Since(*w7d) >= 7*24*time.Hour {
|
||||
if IsWindowExpired(w7d, RateLimitWindow7d) {
|
||||
usage7d = 0
|
||||
needsReset = true
|
||||
}
|
||||
@@ -589,12 +589,16 @@ func (s *BillingCacheService) evaluateRateLimits(ctx context.Context, apiKey *AP
|
||||
if loader, ok := s.apiKeyRateLimitLoader.(interface {
|
||||
ResetRateLimitWindows(ctx context.Context, id int64) error
|
||||
}); ok {
|
||||
_ = loader.ResetRateLimitWindows(resetCtx, keyID)
|
||||
if err := loader.ResetRateLimitWindows(resetCtx, keyID); err != nil {
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: reset rate limit windows failed for api key %d: %v", keyID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Invalidate cache so next request loads fresh data
|
||||
if s.cache != nil {
|
||||
_ = s.cache.InvalidateAPIKeyRateLimit(resetCtx, keyID)
|
||||
if err := s.cache.InvalidateAPIKeyRateLimit(resetCtx, keyID); err != nil {
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: invalidate rate limit cache failed for api key %d: %v", keyID, err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -43,16 +43,19 @@ type BillingCache interface {
|
||||
|
||||
// ModelPricing 模型价格配置(per-token价格,与LiteLLM格式一致)
|
||||
type ModelPricing struct {
|
||||
InputPricePerToken float64 // 每token输入价格 (USD)
|
||||
OutputPricePerToken float64 // 每token输出价格 (USD)
|
||||
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
|
||||
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
|
||||
CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD)
|
||||
CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD)
|
||||
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
|
||||
LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格
|
||||
LongContextInputMultiplier float64 // 长上下文整次会话输入倍率
|
||||
LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率
|
||||
InputPricePerToken float64 // 每token输入价格 (USD)
|
||||
InputPricePerTokenPriority float64 // priority service tier 下每token输入价格 (USD)
|
||||
OutputPricePerToken float64 // 每token输出价格 (USD)
|
||||
OutputPricePerTokenPriority float64 // priority service tier 下每token输出价格 (USD)
|
||||
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
|
||||
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
|
||||
CacheReadPricePerTokenPriority float64 // priority service tier 下缓存读取每token价格 (USD)
|
||||
CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD)
|
||||
CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD)
|
||||
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
|
||||
LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格
|
||||
LongContextInputMultiplier float64 // 长上下文整次会话输入倍率
|
||||
LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -61,6 +64,28 @@ const (
|
||||
openAIGPT54LongContextOutputMultiplier = 1.5
|
||||
)
|
||||
|
||||
func normalizeBillingServiceTier(serviceTier string) string {
|
||||
return strings.ToLower(strings.TrimSpace(serviceTier))
|
||||
}
|
||||
|
||||
func usePriorityServiceTierPricing(serviceTier string, pricing *ModelPricing) bool {
|
||||
if pricing == nil || normalizeBillingServiceTier(serviceTier) != "priority" {
|
||||
return false
|
||||
}
|
||||
return pricing.InputPricePerTokenPriority > 0 || pricing.OutputPricePerTokenPriority > 0 || pricing.CacheReadPricePerTokenPriority > 0
|
||||
}
|
||||
|
||||
func serviceTierCostMultiplier(serviceTier string) float64 {
|
||||
switch normalizeBillingServiceTier(serviceTier) {
|
||||
case "priority":
|
||||
return 2.0
|
||||
case "flex":
|
||||
return 0.5
|
||||
default:
|
||||
return 1.0
|
||||
}
|
||||
}
|
||||
|
||||
// UsageTokens 使用的token数量
|
||||
type UsageTokens struct {
|
||||
InputTokens int
|
||||
@@ -173,30 +198,60 @@ func (s *BillingService) initFallbackPricing() {
|
||||
|
||||
// OpenAI GPT-5.1(本地兜底,防止动态定价不可用时拒绝计费)
|
||||
s.fallbackPrices["gpt-5.1"] = &ModelPricing{
|
||||
InputPricePerToken: 1.25e-6, // $1.25 per MTok
|
||||
OutputPricePerToken: 10e-6, // $10 per MTok
|
||||
CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok
|
||||
CacheReadPricePerToken: 0.125e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
InputPricePerToken: 1.25e-6, // $1.25 per MTok
|
||||
InputPricePerTokenPriority: 2.5e-6, // $2.5 per MTok
|
||||
OutputPricePerToken: 10e-6, // $10 per MTok
|
||||
OutputPricePerTokenPriority: 20e-6, // $20 per MTok
|
||||
CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok
|
||||
CacheReadPricePerToken: 0.125e-6,
|
||||
CacheReadPricePerTokenPriority: 0.25e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
// OpenAI GPT-5.4(业务指定价格)
|
||||
s.fallbackPrices["gpt-5.4"] = &ModelPricing{
|
||||
InputPricePerToken: 2.5e-6, // $2.5 per MTok
|
||||
OutputPricePerToken: 15e-6, // $15 per MTok
|
||||
CacheCreationPricePerToken: 2.5e-6, // $2.5 per MTok
|
||||
CacheReadPricePerToken: 0.25e-6, // $0.25 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
LongContextInputThreshold: openAIGPT54LongContextInputThreshold,
|
||||
LongContextInputMultiplier: openAIGPT54LongContextInputMultiplier,
|
||||
LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier,
|
||||
InputPricePerToken: 2.5e-6, // $2.5 per MTok
|
||||
InputPricePerTokenPriority: 5e-6, // $5 per MTok
|
||||
OutputPricePerToken: 15e-6, // $15 per MTok
|
||||
OutputPricePerTokenPriority: 30e-6, // $30 per MTok
|
||||
CacheCreationPricePerToken: 2.5e-6, // $2.5 per MTok
|
||||
CacheReadPricePerToken: 0.25e-6, // $0.25 per MTok
|
||||
CacheReadPricePerTokenPriority: 0.5e-6, // $0.5 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
LongContextInputThreshold: openAIGPT54LongContextInputThreshold,
|
||||
LongContextInputMultiplier: openAIGPT54LongContextInputMultiplier,
|
||||
LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier,
|
||||
}
|
||||
// OpenAI GPT-5.2(本地兜底)
|
||||
s.fallbackPrices["gpt-5.2"] = &ModelPricing{
|
||||
InputPricePerToken: 1.75e-6,
|
||||
InputPricePerTokenPriority: 3.5e-6,
|
||||
OutputPricePerToken: 14e-6,
|
||||
OutputPricePerTokenPriority: 28e-6,
|
||||
CacheCreationPricePerToken: 1.75e-6,
|
||||
CacheReadPricePerToken: 0.175e-6,
|
||||
CacheReadPricePerTokenPriority: 0.35e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
// Codex 族兜底统一按 GPT-5.1 Codex 价格计费
|
||||
s.fallbackPrices["gpt-5.1-codex"] = &ModelPricing{
|
||||
InputPricePerToken: 1.5e-6, // $1.5 per MTok
|
||||
OutputPricePerToken: 12e-6, // $12 per MTok
|
||||
CacheCreationPricePerToken: 1.5e-6, // $1.5 per MTok
|
||||
CacheReadPricePerToken: 0.15e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
InputPricePerToken: 1.5e-6, // $1.5 per MTok
|
||||
InputPricePerTokenPriority: 3e-6, // $3 per MTok
|
||||
OutputPricePerToken: 12e-6, // $12 per MTok
|
||||
OutputPricePerTokenPriority: 24e-6, // $24 per MTok
|
||||
CacheCreationPricePerToken: 1.5e-6, // $1.5 per MTok
|
||||
CacheReadPricePerToken: 0.15e-6,
|
||||
CacheReadPricePerTokenPriority: 0.3e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
s.fallbackPrices["gpt-5.2-codex"] = &ModelPricing{
|
||||
InputPricePerToken: 1.75e-6,
|
||||
InputPricePerTokenPriority: 3.5e-6,
|
||||
OutputPricePerToken: 14e-6,
|
||||
OutputPricePerTokenPriority: 28e-6,
|
||||
CacheCreationPricePerToken: 1.75e-6,
|
||||
CacheReadPricePerToken: 0.175e-6,
|
||||
CacheReadPricePerTokenPriority: 0.35e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
s.fallbackPrices["gpt-5.3-codex"] = s.fallbackPrices["gpt-5.1-codex"]
|
||||
}
|
||||
@@ -241,6 +296,10 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
|
||||
switch normalized {
|
||||
case "gpt-5.4":
|
||||
return s.fallbackPrices["gpt-5.4"]
|
||||
case "gpt-5.2":
|
||||
return s.fallbackPrices["gpt-5.2"]
|
||||
case "gpt-5.2-codex":
|
||||
return s.fallbackPrices["gpt-5.2-codex"]
|
||||
case "gpt-5.3-codex":
|
||||
return s.fallbackPrices["gpt-5.3-codex"]
|
||||
case "gpt-5.1-codex", "gpt-5.1-codex-max", "gpt-5.1-codex-mini", "codex-mini-latest":
|
||||
@@ -269,16 +328,19 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
|
||||
price1h := litellmPricing.CacheCreationInputTokenCostAbove1hr
|
||||
enableBreakdown := price1h > 0 && price1h > price5m
|
||||
return s.applyModelSpecificPricingPolicy(model, &ModelPricing{
|
||||
InputPricePerToken: litellmPricing.InputCostPerToken,
|
||||
OutputPricePerToken: litellmPricing.OutputCostPerToken,
|
||||
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
|
||||
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
|
||||
CacheCreation5mPrice: price5m,
|
||||
CacheCreation1hPrice: price1h,
|
||||
SupportsCacheBreakdown: enableBreakdown,
|
||||
LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold,
|
||||
LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier,
|
||||
LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier,
|
||||
InputPricePerToken: litellmPricing.InputCostPerToken,
|
||||
InputPricePerTokenPriority: litellmPricing.InputCostPerTokenPriority,
|
||||
OutputPricePerToken: litellmPricing.OutputCostPerToken,
|
||||
OutputPricePerTokenPriority: litellmPricing.OutputCostPerTokenPriority,
|
||||
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
|
||||
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
|
||||
CacheReadPricePerTokenPriority: litellmPricing.CacheReadInputTokenCostPriority,
|
||||
CacheCreation5mPrice: price5m,
|
||||
CacheCreation1hPrice: price1h,
|
||||
SupportsCacheBreakdown: enableBreakdown,
|
||||
LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold,
|
||||
LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier,
|
||||
LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier,
|
||||
}), nil
|
||||
}
|
||||
}
|
||||
@@ -295,6 +357,10 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
|
||||
|
||||
// CalculateCost 计算使用费用
|
||||
func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMultiplier float64) (*CostBreakdown, error) {
|
||||
return s.CalculateCostWithServiceTier(model, tokens, rateMultiplier, "")
|
||||
}
|
||||
|
||||
func (s *BillingService) CalculateCostWithServiceTier(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string) (*CostBreakdown, error) {
|
||||
pricing, err := s.GetModelPricing(model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -303,6 +369,21 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul
|
||||
breakdown := &CostBreakdown{}
|
||||
inputPricePerToken := pricing.InputPricePerToken
|
||||
outputPricePerToken := pricing.OutputPricePerToken
|
||||
cacheReadPricePerToken := pricing.CacheReadPricePerToken
|
||||
tierMultiplier := 1.0
|
||||
if usePriorityServiceTierPricing(serviceTier, pricing) {
|
||||
if pricing.InputPricePerTokenPriority > 0 {
|
||||
inputPricePerToken = pricing.InputPricePerTokenPriority
|
||||
}
|
||||
if pricing.OutputPricePerTokenPriority > 0 {
|
||||
outputPricePerToken = pricing.OutputPricePerTokenPriority
|
||||
}
|
||||
if pricing.CacheReadPricePerTokenPriority > 0 {
|
||||
cacheReadPricePerToken = pricing.CacheReadPricePerTokenPriority
|
||||
}
|
||||
} else {
|
||||
tierMultiplier = serviceTierCostMultiplier(serviceTier)
|
||||
}
|
||||
if s.shouldApplySessionLongContextPricing(tokens, pricing) {
|
||||
inputPricePerToken *= pricing.LongContextInputMultiplier
|
||||
outputPricePerToken *= pricing.LongContextOutputMultiplier
|
||||
@@ -329,7 +410,14 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul
|
||||
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
|
||||
}
|
||||
|
||||
breakdown.CacheReadCost = float64(tokens.CacheReadTokens) * pricing.CacheReadPricePerToken
|
||||
breakdown.CacheReadCost = float64(tokens.CacheReadTokens) * cacheReadPricePerToken
|
||||
|
||||
if tierMultiplier != 1.0 {
|
||||
breakdown.InputCost *= tierMultiplier
|
||||
breakdown.OutputCost *= tierMultiplier
|
||||
breakdown.CacheCreationCost *= tierMultiplier
|
||||
breakdown.CacheReadCost *= tierMultiplier
|
||||
}
|
||||
|
||||
// 计算总费用
|
||||
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost +
|
||||
|
||||
@@ -522,3 +522,189 @@ func TestCalculateCost_LargeTokenCount(t *testing.T) {
|
||||
require.False(t, math.IsNaN(cost.TotalCost))
|
||||
require.False(t, math.IsInf(cost.TotalCost, 0))
|
||||
}
|
||||
|
||||
func TestServiceTierCostMultiplier(t *testing.T) {
|
||||
require.InDelta(t, 2.0, serviceTierCostMultiplier("priority"), 1e-12)
|
||||
require.InDelta(t, 2.0, serviceTierCostMultiplier(" Priority "), 1e-12)
|
||||
require.InDelta(t, 0.5, serviceTierCostMultiplier("flex"), 1e-12)
|
||||
require.InDelta(t, 1.0, serviceTierCostMultiplier(""), 1e-12)
|
||||
require.InDelta(t, 1.0, serviceTierCostMultiplier("default"), 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateCostWithServiceTier_OpenAIPriorityUsesPriorityPricing(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50, CacheReadTokens: 20}
|
||||
|
||||
baseCost, err := svc.CalculateCost("gpt-5.1-codex", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
priorityCost, err := svc.CalculateCostWithServiceTier("gpt-5.1-codex", tokens, 1.0, "priority")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, baseCost.InputCost*2, priorityCost.InputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.OutputCost*2, priorityCost.OutputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.CacheReadCost*2, priorityCost.CacheReadCost, 1e-10)
|
||||
require.InDelta(t, baseCost.TotalCost*2, priorityCost.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCostWithServiceTier_FlexAppliesHalfMultiplier(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50, CacheCreationTokens: 40, CacheReadTokens: 20}
|
||||
|
||||
baseCost, err := svc.CalculateCost("gpt-5.4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
flexCost, err := svc.CalculateCostWithServiceTier("gpt-5.4", tokens, 1.0, "flex")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, baseCost.InputCost*0.5, flexCost.InputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.OutputCost*0.5, flexCost.OutputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.CacheCreationCost*0.5, flexCost.CacheCreationCost, 1e-10)
|
||||
require.InDelta(t, baseCost.CacheReadCost*0.5, flexCost.CacheReadCost, 1e-10)
|
||||
require.InDelta(t, baseCost.TotalCost*0.5, flexCost.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCostWithServiceTier_PriorityFallsBackToTierMultiplierWithoutExplicitPriorityPrice(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
tokens := UsageTokens{InputTokens: 120, OutputTokens: 30, CacheCreationTokens: 12, CacheReadTokens: 8}
|
||||
|
||||
baseCost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
priorityCost, err := svc.CalculateCostWithServiceTier("claude-sonnet-4", tokens, 1.0, "priority")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, baseCost.InputCost*2, priorityCost.InputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.OutputCost*2, priorityCost.OutputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.CacheCreationCost*2, priorityCost.CacheCreationCost, 1e-10)
|
||||
require.InDelta(t, baseCost.CacheReadCost*2, priorityCost.CacheReadCost, 1e-10)
|
||||
require.InDelta(t, baseCost.TotalCost*2, priorityCost.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestBillingServiceGetModelPricing_UsesDynamicPriorityFields(t *testing.T) {
|
||||
pricingSvc := &PricingService{
|
||||
pricingData: map[string]*LiteLLMModelPricing{
|
||||
"gpt-5.4": {
|
||||
InputCostPerToken: 2.5e-6,
|
||||
InputCostPerTokenPriority: 5e-6,
|
||||
OutputCostPerToken: 15e-6,
|
||||
OutputCostPerTokenPriority: 30e-6,
|
||||
CacheCreationInputTokenCost: 2.5e-6,
|
||||
CacheReadInputTokenCost: 0.25e-6,
|
||||
CacheReadInputTokenCostPriority: 0.5e-6,
|
||||
LongContextInputTokenThreshold: 272000,
|
||||
LongContextInputCostMultiplier: 2.0,
|
||||
LongContextOutputCostMultiplier: 1.5,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewBillingService(&config.Config{}, pricingSvc)
|
||||
|
||||
pricing, err := svc.GetModelPricing("gpt-5.4")
|
||||
require.NoError(t, err)
|
||||
require.InDelta(t, 2.5e-6, pricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 5e-6, pricing.InputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 15e-6, pricing.OutputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 30e-6, pricing.OutputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 0.25e-6, pricing.CacheReadPricePerToken, 1e-12)
|
||||
require.InDelta(t, 0.5e-6, pricing.CacheReadPricePerTokenPriority, 1e-12)
|
||||
require.Equal(t, 272000, pricing.LongContextInputThreshold)
|
||||
require.InDelta(t, 2.0, pricing.LongContextInputMultiplier, 1e-12)
|
||||
require.InDelta(t, 1.5, pricing.LongContextOutputMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
func TestBillingServiceGetModelPricing_OpenAIFallbackGpt52Variants(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
gpt52, err := svc.GetModelPricing("gpt-5.2")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, gpt52)
|
||||
require.InDelta(t, 1.75e-6, gpt52.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 3.5e-6, gpt52.InputPricePerTokenPriority, 1e-12)
|
||||
|
||||
gpt52Codex, err := svc.GetModelPricing("gpt-5.2-codex")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, gpt52Codex)
|
||||
require.InDelta(t, 1.75e-6, gpt52Codex.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 3.5e-6, gpt52Codex.InputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 28e-6, gpt52Codex.OutputPricePerTokenPriority, 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateCostWithServiceTier_PriorityFallsBackToTierMultiplierWhenExplicitPriceMissing(t *testing.T) {
|
||||
svc := NewBillingService(&config.Config{}, &PricingService{
|
||||
pricingData: map[string]*LiteLLMModelPricing{
|
||||
"custom-no-priority": {
|
||||
InputCostPerToken: 1e-6,
|
||||
OutputCostPerToken: 2e-6,
|
||||
CacheCreationInputTokenCost: 0.5e-6,
|
||||
CacheReadInputTokenCost: 0.25e-6,
|
||||
},
|
||||
},
|
||||
})
|
||||
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50, CacheCreationTokens: 40, CacheReadTokens: 20}
|
||||
|
||||
baseCost, err := svc.CalculateCost("custom-no-priority", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
priorityCost, err := svc.CalculateCostWithServiceTier("custom-no-priority", tokens, 1.0, "priority")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, baseCost.InputCost*2, priorityCost.InputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.OutputCost*2, priorityCost.OutputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.CacheCreationCost*2, priorityCost.CacheCreationCost, 1e-10)
|
||||
require.InDelta(t, baseCost.CacheReadCost*2, priorityCost.CacheReadCost, 1e-10)
|
||||
require.InDelta(t, baseCost.TotalCost*2, priorityCost.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_OpenAIGpt52FallbacksExposePriorityPrices(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
gpt52, err := svc.GetModelPricing("gpt-5.2")
|
||||
require.NoError(t, err)
|
||||
require.InDelta(t, 1.75e-6, gpt52.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 3.5e-6, gpt52.InputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 14e-6, gpt52.OutputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 28e-6, gpt52.OutputPricePerTokenPriority, 1e-12)
|
||||
|
||||
gpt52Codex, err := svc.GetModelPricing("gpt-5.2-codex")
|
||||
require.NoError(t, err)
|
||||
require.InDelta(t, 1.75e-6, gpt52Codex.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 3.5e-6, gpt52Codex.InputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 14e-6, gpt52Codex.OutputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 28e-6, gpt52Codex.OutputPricePerTokenPriority, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_MapsDynamicPriorityFieldsIntoBillingPricing(t *testing.T) {
|
||||
svc := NewBillingService(&config.Config{}, &PricingService{
|
||||
pricingData: map[string]*LiteLLMModelPricing{
|
||||
"dynamic-tier-model": {
|
||||
InputCostPerToken: 1e-6,
|
||||
InputCostPerTokenPriority: 2e-6,
|
||||
OutputCostPerToken: 3e-6,
|
||||
OutputCostPerTokenPriority: 6e-6,
|
||||
CacheCreationInputTokenCost: 4e-6,
|
||||
CacheCreationInputTokenCostAbove1hr: 5e-6,
|
||||
CacheReadInputTokenCost: 7e-7,
|
||||
CacheReadInputTokenCostPriority: 8e-7,
|
||||
LongContextInputTokenThreshold: 999,
|
||||
LongContextInputCostMultiplier: 1.5,
|
||||
LongContextOutputCostMultiplier: 1.25,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
pricing, err := svc.GetModelPricing("dynamic-tier-model")
|
||||
require.NoError(t, err)
|
||||
require.InDelta(t, 1e-6, pricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 2e-6, pricing.InputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 3e-6, pricing.OutputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 6e-6, pricing.OutputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 4e-6, pricing.CacheCreation5mPrice, 1e-12)
|
||||
require.InDelta(t, 5e-6, pricing.CacheCreation1hPrice, 1e-12)
|
||||
require.True(t, pricing.SupportsCacheBreakdown)
|
||||
require.InDelta(t, 7e-7, pricing.CacheReadPricePerToken, 1e-12)
|
||||
require.InDelta(t, 8e-7, pricing.CacheReadPricePerTokenPriority, 1e-12)
|
||||
require.Equal(t, 999, pricing.LongContextInputThreshold)
|
||||
require.InDelta(t, 1.5, pricing.LongContextInputMultiplier, 1e-12)
|
||||
require.InDelta(t, 1.25, pricing.LongContextOutputMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
@@ -175,6 +175,13 @@ const (
|
||||
// SettingKeyStreamTimeoutSettings stores JSON config for stream timeout handling.
|
||||
SettingKeyStreamTimeoutSettings = "stream_timeout_settings"
|
||||
|
||||
// =========================
|
||||
// Request Rectifier (请求整流器)
|
||||
// =========================
|
||||
|
||||
// SettingKeyRectifierSettings stores JSON config for rectifier settings (thinking signature + budget).
|
||||
SettingKeyRectifierSettings = "rectifier_settings"
|
||||
|
||||
// =========================
|
||||
// Sora S3 存储配置
|
||||
// =========================
|
||||
|
||||
@@ -177,6 +177,36 @@ func TestCheckErrorPolicy(t *testing.T) {
|
||||
body: []byte(`overloaded`),
|
||||
expected: ErrorPolicyMatched, // custom codes take precedence
|
||||
},
|
||||
{
|
||||
name: "pool_mode_custom_error_codes_hit_returns_matched",
|
||||
account: &Account{
|
||||
ID: 7,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(401), float64(403)},
|
||||
},
|
||||
},
|
||||
statusCode: 401,
|
||||
body: []byte(`unauthorized`),
|
||||
expected: ErrorPolicyMatched,
|
||||
},
|
||||
{
|
||||
name: "pool_mode_without_custom_error_codes_returns_skipped",
|
||||
account: &Account{
|
||||
ID: 8,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
},
|
||||
},
|
||||
statusCode: 401,
|
||||
body: []byte(`unauthorized`),
|
||||
expected: ErrorPolicySkipped,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -190,6 +220,48 @@ func TestCheckErrorPolicy(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleUpstreamError_PoolModeCustomErrorCodesOverride(t *testing.T) {
|
||||
t.Run("pool_mode_without_custom_error_codes_still_skips", func(t *testing.T) {
|
||||
repo := &errorPolicyRepoStub{}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
account := &Account{
|
||||
ID: 30,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
},
|
||||
}
|
||||
|
||||
shouldDisable := svc.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||
|
||||
require.False(t, shouldDisable)
|
||||
require.Equal(t, 0, repo.setErrCalls)
|
||||
require.Equal(t, 0, repo.tempCalls)
|
||||
})
|
||||
|
||||
t.Run("pool_mode_with_custom_error_codes_uses_local_error_policy", func(t *testing.T) {
|
||||
repo := &errorPolicyRepoStub{}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
account := &Account{
|
||||
ID: 31,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(401)},
|
||||
},
|
||||
}
|
||||
|
||||
shouldDisable := svc.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||
|
||||
require.True(t, shouldDisable)
|
||||
require.Equal(t, 1, repo.setErrCalls)
|
||||
require.Equal(t, 0, repo.tempCalls)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestApplyErrorPolicy — 4 table-driven cases for the wrapper method
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@@ -86,10 +86,10 @@ func TestStripBetaTokens(t *testing.T) {
|
||||
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||
},
|
||||
{
|
||||
name: "DroppedBetas removes both context-1m and fast-mode",
|
||||
name: "DroppedBetas removes fast-mode only",
|
||||
header: "oauth-2025-04-20,context-1m-2025-08-07,fast-mode-2026-02-01,interleaved-thinking-2025-05-14",
|
||||
tokens: claude.DroppedBetas,
|
||||
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||
want: "oauth-2025-04-20,context-1m-2025-08-07,interleaved-thinking-2025-05-14",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -117,21 +117,21 @@ func TestMergeAnthropicBetaDropping_DroppedBetas(t *testing.T) {
|
||||
drop := droppedBetaSet()
|
||||
|
||||
got := mergeAnthropicBetaDropping(required, incoming, drop)
|
||||
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo-beta", got)
|
||||
require.NotContains(t, got, "context-1m-2025-08-07")
|
||||
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,context-1m-2025-08-07,foo-beta", got)
|
||||
require.Contains(t, got, "context-1m-2025-08-07")
|
||||
require.NotContains(t, got, "fast-mode-2026-02-01")
|
||||
}
|
||||
|
||||
func TestDroppedBetaSet(t *testing.T) {
|
||||
// Base set contains DroppedBetas
|
||||
base := droppedBetaSet()
|
||||
require.Contains(t, base, claude.BetaContext1M)
|
||||
require.NotContains(t, base, claude.BetaContext1M)
|
||||
require.Contains(t, base, claude.BetaFastMode)
|
||||
require.Len(t, base, len(claude.DroppedBetas))
|
||||
|
||||
// With extra tokens
|
||||
extended := droppedBetaSet(claude.BetaClaudeCode)
|
||||
require.Contains(t, extended, claude.BetaContext1M)
|
||||
require.NotContains(t, extended, claude.BetaContext1M)
|
||||
require.Contains(t, extended, claude.BetaFastMode)
|
||||
require.Contains(t, extended, claude.BetaClaudeCode)
|
||||
require.Len(t, extended, len(claude.DroppedBetas)+1)
|
||||
@@ -148,6 +148,32 @@ func TestBuildBetaTokenSet(t *testing.T) {
|
||||
require.Empty(t, empty)
|
||||
}
|
||||
|
||||
func TestContainsBetaToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
header string
|
||||
token string
|
||||
want bool
|
||||
}{
|
||||
{"present in middle", "oauth-2025-04-20,fast-mode-2026-02-01,interleaved-thinking-2025-05-14", "fast-mode-2026-02-01", true},
|
||||
{"present at start", "fast-mode-2026-02-01,oauth-2025-04-20", "fast-mode-2026-02-01", true},
|
||||
{"present at end", "oauth-2025-04-20,fast-mode-2026-02-01", "fast-mode-2026-02-01", true},
|
||||
{"only token", "fast-mode-2026-02-01", "fast-mode-2026-02-01", true},
|
||||
{"not present", "oauth-2025-04-20,interleaved-thinking-2025-05-14", "fast-mode-2026-02-01", false},
|
||||
{"with spaces", "oauth-2025-04-20, fast-mode-2026-02-01 , interleaved-thinking-2025-05-14", "fast-mode-2026-02-01", true},
|
||||
{"empty header", "", "fast-mode-2026-02-01", false},
|
||||
{"empty token", "fast-mode-2026-02-01", "", false},
|
||||
{"partial match", "fast-mode-2026-02-01-extra", "fast-mode-2026-02-01", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := containsBetaToken(tt.header, tt.token)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripBetaTokensWithSet_EmptyDropSet(t *testing.T) {
|
||||
header := "oauth-2025-04-20,interleaved-thinking-2025-05-14"
|
||||
got := stripBetaTokensWithSet(header, map[string]struct{}{})
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
@@ -258,6 +259,7 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
|
||||
if !hasEmptyContent && !containsThinkingBlocks {
|
||||
if topThinking := gjson.Get(jsonStr, "thinking"); topThinking.Exists() {
|
||||
if out, err := sjson.DeleteBytes(body, "thinking"); err == nil {
|
||||
out = removeThinkingDependentContextStrategies(out)
|
||||
return out
|
||||
}
|
||||
return body
|
||||
@@ -395,6 +397,10 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
|
||||
} else {
|
||||
return body
|
||||
}
|
||||
// Removing "thinking" makes any context_management strategy that requires it invalid
|
||||
// (e.g. clear_thinking_20251015). Strip those entries so the retry request does not
|
||||
// receive a 400 "strategy requires thinking to be enabled or adaptive".
|
||||
out = removeThinkingDependentContextStrategies(out)
|
||||
}
|
||||
if modified {
|
||||
msgsBytes, err := json.Marshal(messages)
|
||||
@@ -409,6 +415,49 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
|
||||
return out
|
||||
}
|
||||
|
||||
// removeThinkingDependentContextStrategies 从 context_management.edits 中移除
|
||||
// 需要 thinking 启用的策略(如 clear_thinking_20251015)。
|
||||
// 当顶层 "thinking" 字段被禁用时必须调用,否则上游会返回
|
||||
// "strategy requires thinking to be enabled or adaptive"。
|
||||
func removeThinkingDependentContextStrategies(body []byte) []byte {
|
||||
jsonStr := *(*string)(unsafe.Pointer(&body))
|
||||
editsRes := gjson.Get(jsonStr, "context_management.edits")
|
||||
if !editsRes.Exists() || !editsRes.IsArray() {
|
||||
return body
|
||||
}
|
||||
|
||||
var filtered []json.RawMessage
|
||||
hasRemoved := false
|
||||
editsRes.ForEach(func(_, v gjson.Result) bool {
|
||||
if v.Get("type").String() == "clear_thinking_20251015" {
|
||||
hasRemoved = true
|
||||
return true
|
||||
}
|
||||
filtered = append(filtered, json.RawMessage(v.Raw))
|
||||
return true
|
||||
})
|
||||
|
||||
if !hasRemoved {
|
||||
return body
|
||||
}
|
||||
|
||||
if len(filtered) == 0 {
|
||||
if b, err := sjson.DeleteBytes(body, "context_management.edits"); err == nil {
|
||||
return b
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
filteredBytes, err := json.Marshal(filtered)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
if b, err := sjson.SetRawBytes(body, "context_management.edits", filteredBytes); err == nil {
|
||||
return b
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// FilterSignatureSensitiveBlocksForRetry is a stronger retry filter for cases where upstream errors indicate
|
||||
// signature/thought_signature validation issues involving tool blocks.
|
||||
//
|
||||
@@ -444,6 +493,28 @@ func FilterSignatureSensitiveBlocksForRetry(body []byte) []byte {
|
||||
if _, exists := req["thinking"]; exists {
|
||||
delete(req, "thinking")
|
||||
modified = true
|
||||
// Remove context_management strategies that require thinking to be enabled
|
||||
// (e.g. clear_thinking_20251015), otherwise upstream returns 400.
|
||||
if cm, ok := req["context_management"].(map[string]any); ok {
|
||||
if edits, ok := cm["edits"].([]any); ok {
|
||||
filtered := make([]any, 0, len(edits))
|
||||
for _, edit := range edits {
|
||||
if editMap, ok := edit.(map[string]any); ok {
|
||||
if editMap["type"] == "clear_thinking_20251015" {
|
||||
continue
|
||||
}
|
||||
}
|
||||
filtered = append(filtered, edit)
|
||||
}
|
||||
if len(filtered) != len(edits) {
|
||||
if len(filtered) == 0 {
|
||||
delete(cm, "edits")
|
||||
} else {
|
||||
cm["edits"] = filtered
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
messages, ok := req["messages"].([]any)
|
||||
@@ -675,3 +746,90 @@ func filterThinkingBlocksInternal(body []byte, _ bool) []byte {
|
||||
}
|
||||
return newBody
|
||||
}
|
||||
|
||||
// =========================
|
||||
// Thinking Budget Rectifier
|
||||
// =========================
|
||||
|
||||
const (
|
||||
// BudgetRectifyBudgetTokens is the budget_tokens value to set when rectifying.
|
||||
BudgetRectifyBudgetTokens = 32000
|
||||
// BudgetRectifyMaxTokens is the max_tokens value to set when rectifying.
|
||||
BudgetRectifyMaxTokens = 64000
|
||||
// BudgetRectifyMinMaxTokens is the minimum max_tokens that must exceed budget_tokens.
|
||||
BudgetRectifyMinMaxTokens = 32001
|
||||
)
|
||||
|
||||
// isThinkingBudgetConstraintError detects whether an upstream error message indicates
|
||||
// a budget_tokens constraint violation (e.g. "budget_tokens >= 1024").
|
||||
// Matches three conditions (all must be true):
|
||||
// 1. Contains "budget_tokens" or "budget tokens"
|
||||
// 2. Contains "thinking"
|
||||
// 3. Contains ">= 1024" or "greater than or equal to 1024" or ("1024" + "input should be")
|
||||
func isThinkingBudgetConstraintError(errMsg string) bool {
|
||||
m := strings.ToLower(errMsg)
|
||||
|
||||
// Condition 1: budget_tokens or budget tokens
|
||||
hasBudget := strings.Contains(m, "budget_tokens") || strings.Contains(m, "budget tokens")
|
||||
if !hasBudget {
|
||||
return false
|
||||
}
|
||||
|
||||
// Condition 2: thinking
|
||||
if !strings.Contains(m, "thinking") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Condition 3: constraint indicator
|
||||
if strings.Contains(m, ">= 1024") || strings.Contains(m, "greater than or equal to 1024") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(m, "1024") && strings.Contains(m, "input should be") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// RectifyThinkingBudget modifies the request body to fix budget_tokens constraint errors.
|
||||
// It sets thinking.budget_tokens = 32000, thinking.type = "enabled" (unless adaptive),
|
||||
// and ensures max_tokens >= 32001.
|
||||
// Returns (modified body, true) if changes were applied, or (original body, false) if not.
|
||||
func RectifyThinkingBudget(body []byte) ([]byte, bool) {
|
||||
// If thinking type is "adaptive", skip rectification entirely
|
||||
thinkingType := gjson.GetBytes(body, "thinking.type").String()
|
||||
if thinkingType == "adaptive" {
|
||||
return body, false
|
||||
}
|
||||
|
||||
modified := body
|
||||
changed := false
|
||||
|
||||
// Set thinking.type = "enabled"
|
||||
if thinkingType != "enabled" {
|
||||
if result, err := sjson.SetBytes(modified, "thinking.type", "enabled"); err == nil {
|
||||
modified = result
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
|
||||
// Set thinking.budget_tokens = 32000
|
||||
currentBudget := gjson.GetBytes(modified, "thinking.budget_tokens").Int()
|
||||
if currentBudget != BudgetRectifyBudgetTokens {
|
||||
if result, err := sjson.SetBytes(modified, "thinking.budget_tokens", BudgetRectifyBudgetTokens); err == nil {
|
||||
modified = result
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure max_tokens >= BudgetRectifyMinMaxTokens
|
||||
maxTokens := gjson.GetBytes(modified, "max_tokens").Int()
|
||||
if maxTokens < int64(BudgetRectifyMinMaxTokens) {
|
||||
if result, err := sjson.SetBytes(modified, "max_tokens", BudgetRectifyMaxTokens); err == nil {
|
||||
modified = result
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
|
||||
return modified, changed
|
||||
}
|
||||
|
||||
@@ -439,6 +439,210 @@ func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) {
|
||||
require.Contains(t, content1["text"], "tool_result")
|
||||
}
|
||||
|
||||
// ============ Group 6b: context_management.edits 清理测试 ============
|
||||
|
||||
// removeThinkingDependentContextStrategies — 边界用例
|
||||
|
||||
func TestRemoveThinkingDependentContextStrategies_NoContextManagement(t *testing.T) {
|
||||
input := []byte(`{"thinking":{"type":"enabled"},"messages":[]}`)
|
||||
out := removeThinkingDependentContextStrategies(input)
|
||||
require.Equal(t, input, out, "无 context_management 字段时应原样返回")
|
||||
}
|
||||
|
||||
func TestRemoveThinkingDependentContextStrategies_EmptyEdits(t *testing.T) {
|
||||
input := []byte(`{"context_management":{"edits":[]},"messages":[]}`)
|
||||
out := removeThinkingDependentContextStrategies(input)
|
||||
require.Equal(t, input, out, "edits 为空数组时应原样返回")
|
||||
}
|
||||
|
||||
func TestRemoveThinkingDependentContextStrategies_NoClearThinkingEntry(t *testing.T) {
|
||||
input := []byte(`{"context_management":{"edits":[{"type":"other_strategy"}]},"messages":[]}`)
|
||||
out := removeThinkingDependentContextStrategies(input)
|
||||
require.Equal(t, input, out, "edits 中无 clear_thinking_20251015 时应原样返回")
|
||||
}
|
||||
|
||||
func TestRemoveThinkingDependentContextStrategies_RemovesSingleEntry(t *testing.T) {
|
||||
input := []byte(`{"context_management":{"edits":[{"type":"clear_thinking_20251015"}]},"messages":[]}`)
|
||||
out := removeThinkingDependentContextStrategies(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
cm, ok := req["context_management"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
_, hasEdits := cm["edits"]
|
||||
require.False(t, hasEdits, "所有 edits 均为 clear_thinking_20251015 时应删除 edits 键")
|
||||
}
|
||||
|
||||
func TestRemoveThinkingDependentContextStrategies_MixedEntries(t *testing.T) {
|
||||
input := []byte(`{"context_management":{"edits":[{"type":"clear_thinking_20251015"},{"type":"other_strategy","param":1}]},"messages":[]}`)
|
||||
out := removeThinkingDependentContextStrategies(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
cm, ok := req["context_management"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
edits, ok := cm["edits"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, edits, 1, "仅移除 clear_thinking_20251015,保留其他条目")
|
||||
edit0, ok := edits[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "other_strategy", edit0["type"])
|
||||
}
|
||||
|
||||
// FilterThinkingBlocksForRetry — 包含 context_management 的场景
|
||||
|
||||
func TestFilterThinkingBlocksForRetry_RemovesClearThinkingStrategy_FastPath(t *testing.T) {
|
||||
// 快速路径:messages 中无 thinking 块,仅有顶层 thinking 字段
|
||||
// 这条路径曾因提前 return 跳过 removeThinkingDependentContextStrategies 而存在 bug
|
||||
input := []byte(`{
|
||||
"thinking":{"type":"enabled","budget_tokens":1024},
|
||||
"context_management":{"edits":[{"type":"clear_thinking_20251015"}]},
|
||||
"messages":[
|
||||
{"role":"user","content":[{"type":"text","text":"Hello"}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := FilterThinkingBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
_, hasThinking := req["thinking"]
|
||||
require.False(t, hasThinking, "顶层 thinking 应被移除")
|
||||
|
||||
cm, ok := req["context_management"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
_, hasEdits := cm["edits"]
|
||||
require.False(t, hasEdits, "fast path 下 clear_thinking_20251015 应被移除,edits 键应被删除")
|
||||
}
|
||||
|
||||
func TestFilterThinkingBlocksForRetry_RemovesClearThinkingStrategy_WithThinkingBlocks(t *testing.T) {
|
||||
// 完整路径:messages 中有 thinking 块(非 fast path)
|
||||
input := []byte(`{
|
||||
"thinking":{"type":"enabled","budget_tokens":1024},
|
||||
"context_management":{"edits":[{"type":"clear_thinking_20251015"},{"type":"keep_this"}]},
|
||||
"messages":[
|
||||
{"role":"assistant","content":[
|
||||
{"type":"thinking","thinking":"some thought","signature":"sig"},
|
||||
{"type":"text","text":"Answer"}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := FilterThinkingBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
_, hasThinking := req["thinking"]
|
||||
require.False(t, hasThinking, "顶层 thinking 应被移除")
|
||||
|
||||
cm, ok := req["context_management"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
edits, ok := cm["edits"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, edits, 1, "仅移除 clear_thinking_20251015,保留 keep_this")
|
||||
edit0, ok := edits[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "keep_this", edit0["type"])
|
||||
}
|
||||
|
||||
func TestFilterThinkingBlocksForRetry_NoContextManagement_Unaffected(t *testing.T) {
|
||||
// 无 context_management 时不应报错,且 thinking 正常被移除
|
||||
input := []byte(`{
|
||||
"thinking":{"type":"enabled"},
|
||||
"messages":[{"role":"user","content":[{"type":"text","text":"Hi"}]}]
|
||||
}`)
|
||||
|
||||
out := FilterThinkingBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
_, hasThinking := req["thinking"]
|
||||
require.False(t, hasThinking)
|
||||
_, hasCM := req["context_management"]
|
||||
require.False(t, hasCM)
|
||||
}
|
||||
|
||||
// FilterSignatureSensitiveBlocksForRetry — 包含 context_management 的场景
|
||||
|
||||
func TestFilterSignatureSensitiveBlocksForRetry_RemovesClearThinkingStrategy(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"thinking":{"type":"enabled","budget_tokens":1024},
|
||||
"context_management":{"edits":[{"type":"clear_thinking_20251015"}]},
|
||||
"messages":[
|
||||
{"role":"assistant","content":[
|
||||
{"type":"thinking","thinking":"thought","signature":"sig"}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := FilterSignatureSensitiveBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
_, hasThinking := req["thinking"]
|
||||
require.False(t, hasThinking, "顶层 thinking 应被移除")
|
||||
|
||||
cm, ok := req["context_management"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
if rawEdits, hasEdits := cm["edits"]; hasEdits {
|
||||
edits, ok := rawEdits.([]any)
|
||||
require.True(t, ok)
|
||||
for _, e := range edits {
|
||||
em, ok := e.(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.NotEqual(t, "clear_thinking_20251015", em["type"], "clear_thinking_20251015 应被移除")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterSignatureSensitiveBlocksForRetry_PreservesNonThinkingStrategies(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"thinking":{"type":"enabled"},
|
||||
"context_management":{"edits":[{"type":"clear_thinking_20251015"},{"type":"other_edit"}]},
|
||||
"messages":[
|
||||
{"role":"assistant","content":[
|
||||
{"type":"thinking","thinking":"t","signature":"s"}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := FilterSignatureSensitiveBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
|
||||
cm, ok := req["context_management"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
edits, ok := cm["edits"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, edits, 1, "仅移除 clear_thinking_20251015,保留 other_edit")
|
||||
edit0, ok := edits[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "other_edit", edit0["type"])
|
||||
}
|
||||
|
||||
func TestFilterSignatureSensitiveBlocksForRetry_NoThinkingField_ContextManagementUntouched(t *testing.T) {
|
||||
// 没有顶层 thinking 字段时,context_management 不应被修改
|
||||
input := []byte(`{
|
||||
"context_management":{"edits":[{"type":"clear_thinking_20251015"}]},
|
||||
"messages":[
|
||||
{"role":"assistant","content":[
|
||||
{"type":"thinking","thinking":"t","signature":"s"}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := FilterSignatureSensitiveBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
cm, ok := req["context_management"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
edits, ok := cm["edits"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, edits, 1, "无顶层 thinking 时 context_management 不应被修改")
|
||||
}
|
||||
|
||||
// ============ Group 7: ParseGatewayRequest 补充单元测试 ============
|
||||
|
||||
// Task 7.1 — 类型校验边界测试
|
||||
|
||||
@@ -41,7 +41,7 @@ const (
|
||||
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
||||
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
|
||||
stickySessionTTL = time.Hour // 粘性会话TTL
|
||||
defaultMaxLineSize = 40 * 1024 * 1024
|
||||
defaultMaxLineSize = 500 * 1024 * 1024
|
||||
// Canonical Claude Code banner. Keep it EXACT (no trailing whitespace/newlines)
|
||||
// to match real Claude CLI traffic as closely as possible. When we need a visual
|
||||
// separator between system blocks, we add "\n\n" at concatenation time.
|
||||
@@ -501,33 +501,35 @@ func (s *GatewayService) TempUnscheduleRetryableError(ctx context.Context, accou
|
||||
|
||||
// GatewayService handles API gateway operations
|
||||
type GatewayService struct {
|
||||
accountRepo AccountRepository
|
||||
groupRepo GroupRepository
|
||||
usageLogRepo UsageLogRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
cache GatewayCache
|
||||
digestStore *DigestSessionStore
|
||||
cfg *config.Config
|
||||
schedulerSnapshot *SchedulerSnapshotService
|
||||
billingService *BillingService
|
||||
rateLimitService *RateLimitService
|
||||
billingCacheService *BillingCacheService
|
||||
identityService *IdentityService
|
||||
httpUpstream HTTPUpstream
|
||||
deferredService *DeferredService
|
||||
concurrencyService *ConcurrencyService
|
||||
claudeTokenProvider *ClaudeTokenProvider
|
||||
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
|
||||
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken)
|
||||
userGroupRateCache *gocache.Cache
|
||||
userGroupRateSF singleflight.Group
|
||||
modelsListCache *gocache.Cache
|
||||
modelsListCacheTTL time.Duration
|
||||
responseHeaderFilter *responseheaders.CompiledHeaderFilter
|
||||
debugModelRouting atomic.Bool
|
||||
debugClaudeMimic atomic.Bool
|
||||
accountRepo AccountRepository
|
||||
groupRepo GroupRepository
|
||||
usageLogRepo UsageLogRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
cache GatewayCache
|
||||
digestStore *DigestSessionStore
|
||||
cfg *config.Config
|
||||
schedulerSnapshot *SchedulerSnapshotService
|
||||
billingService *BillingService
|
||||
rateLimitService *RateLimitService
|
||||
billingCacheService *BillingCacheService
|
||||
identityService *IdentityService
|
||||
httpUpstream HTTPUpstream
|
||||
deferredService *DeferredService
|
||||
concurrencyService *ConcurrencyService
|
||||
claudeTokenProvider *ClaudeTokenProvider
|
||||
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
|
||||
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken)
|
||||
userGroupRateResolver *userGroupRateResolver
|
||||
userGroupRateCache *gocache.Cache
|
||||
userGroupRateSF singleflight.Group
|
||||
modelsListCache *gocache.Cache
|
||||
modelsListCacheTTL time.Duration
|
||||
settingService *SettingService
|
||||
responseHeaderFilter *responseheaders.CompiledHeaderFilter
|
||||
debugModelRouting atomic.Bool
|
||||
debugClaudeMimic atomic.Bool
|
||||
}
|
||||
|
||||
// NewGatewayService creates a new GatewayService
|
||||
@@ -552,6 +554,7 @@ func NewGatewayService(
|
||||
sessionLimitCache SessionLimitCache,
|
||||
rpmCache RPMCache,
|
||||
digestStore *DigestSessionStore,
|
||||
settingService *SettingService,
|
||||
) *GatewayService {
|
||||
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
|
||||
modelsListTTL := resolveModelsListCacheTTL(cfg)
|
||||
@@ -578,10 +581,18 @@ func NewGatewayService(
|
||||
sessionLimitCache: sessionLimitCache,
|
||||
rpmCache: rpmCache,
|
||||
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
|
||||
settingService: settingService,
|
||||
modelsListCache: gocache.New(modelsListTTL, time.Minute),
|
||||
modelsListCacheTTL: modelsListTTL,
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
}
|
||||
svc.userGroupRateResolver = newUserGroupRateResolver(
|
||||
userGroupRateRepo,
|
||||
svc.userGroupRateCache,
|
||||
userGroupRateTTL,
|
||||
&svc.userGroupRateSF,
|
||||
"service.gateway",
|
||||
)
|
||||
svc.debugModelRouting.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING")))
|
||||
svc.debugClaudeMimic.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC")))
|
||||
return svc
|
||||
@@ -986,6 +997,11 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account
|
||||
return fmt.Sprintf("user_%s_account__session_%s", userID, sessionID)
|
||||
}
|
||||
|
||||
// GenerateSessionUUID creates a deterministic UUID4 from a seed string.
|
||||
func GenerateSessionUUID(seed string) string {
|
||||
return generateSessionUUID(seed)
|
||||
}
|
||||
|
||||
func generateSessionUUID(seed string) string {
|
||||
if seed == "" {
|
||||
return uuid.NewString()
|
||||
@@ -4057,7 +4073,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
if readErr == nil {
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if s.isThinkingBlockSignatureError(respBody) {
|
||||
if s.isThinkingBlockSignatureError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) {
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
@@ -4174,7 +4190,45 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
break
|
||||
}
|
||||
// 不是thinking签名错误,恢复响应体
|
||||
// 不是签名错误(或整流器已关闭),继续检查 budget 约束
|
||||
errMsg := extractUpstreamErrorMessage(respBody)
|
||||
if isThinkingBudgetConstraintError(errMsg) && s.settingService.IsBudgetRectifierEnabled(ctx) {
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "budget_constraint_error",
|
||||
Message: errMsg,
|
||||
Detail: func() string {
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
})
|
||||
|
||||
rectifiedBody, applied := RectifyThinkingBudget(body)
|
||||
if applied && time.Since(retryStart) < maxRetryElapsed {
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: detected budget_tokens constraint error, retrying with rectified budget (budget_tokens=%d, max_tokens=%d)", account.ID, BudgetRectifyBudgetTokens, BudgetRectifyMaxTokens)
|
||||
budgetRetryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, rectifiedBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
if buildErr == nil {
|
||||
budgetRetryResp, retryErr := s.httpUpstream.DoWithTLS(budgetRetryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if retryErr == nil {
|
||||
resp = budgetRetryResp
|
||||
break
|
||||
}
|
||||
if budgetRetryResp != nil && budgetRetryResp.Body != nil {
|
||||
_ = budgetRetryResp.Body.Close()
|
||||
}
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: budget rectifier retry failed: %v", account.ID, retryErr)
|
||||
} else {
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: budget rectifier retry build failed: %v", account.ID, buildErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
}
|
||||
}
|
||||
@@ -4266,7 +4320,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
return ""
|
||||
}(),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
||||
}
|
||||
}
|
||||
return s.handleRetryExhaustedError(ctx, resp, c, account)
|
||||
}
|
||||
@@ -4296,7 +4354,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
return ""
|
||||
}(),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
||||
}
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
// 可选:对部分 400 触发 failover(默认关闭以保持语义)
|
||||
@@ -4531,7 +4593,11 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
|
||||
return ""
|
||||
}(),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
||||
}
|
||||
}
|
||||
return s.handleRetryExhaustedError(ctx, resp, c, account)
|
||||
}
|
||||
@@ -4561,7 +4627,11 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
|
||||
return ""
|
||||
}(),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
||||
}
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
@@ -5276,6 +5346,19 @@ func droppedBetaSet(extra ...string) map[string]struct{} {
|
||||
return m
|
||||
}
|
||||
|
||||
// containsBetaToken checks if a comma-separated header value contains the given token.
|
||||
func containsBetaToken(header, token string) bool {
|
||||
if header == "" || token == "" {
|
||||
return false
|
||||
}
|
||||
for _, p := range strings.Split(header, ",") {
|
||||
if strings.TrimSpace(p) == token {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func buildBetaTokenSet(tokens []string) map[string]struct{} {
|
||||
m := make(map[string]struct{}, len(tokens))
|
||||
for _, t := range tokens {
|
||||
@@ -5417,6 +5500,11 @@ func extractUpstreamErrorMessage(body []byte) string {
|
||||
return m
|
||||
}
|
||||
|
||||
// ChatGPT 内部 API 风格:{"detail":"..."}
|
||||
if d := gjson.GetBytes(body, "detail").String(); strings.TrimSpace(d) != "" {
|
||||
return d
|
||||
}
|
||||
|
||||
// 兜底:尝试顶层 message
|
||||
return gjson.GetBytes(body, "message").String()
|
||||
}
|
||||
@@ -6332,63 +6420,20 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
|
||||
}
|
||||
|
||||
func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 {
|
||||
if s == nil || userID <= 0 || groupID <= 0 {
|
||||
if s == nil {
|
||||
return groupDefaultMultiplier
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%d:%d", userID, groupID)
|
||||
if s.userGroupRateCache != nil {
|
||||
if cached, ok := s.userGroupRateCache.Get(key); ok {
|
||||
if multiplier, castOK := cached.(float64); castOK {
|
||||
userGroupRateCacheHitTotal.Add(1)
|
||||
return multiplier
|
||||
}
|
||||
}
|
||||
resolver := s.userGroupRateResolver
|
||||
if resolver == nil {
|
||||
resolver = newUserGroupRateResolver(
|
||||
s.userGroupRateRepo,
|
||||
s.userGroupRateCache,
|
||||
resolveUserGroupRateCacheTTL(s.cfg),
|
||||
&s.userGroupRateSF,
|
||||
"service.gateway",
|
||||
)
|
||||
}
|
||||
if s.userGroupRateRepo == nil {
|
||||
return groupDefaultMultiplier
|
||||
}
|
||||
userGroupRateCacheMissTotal.Add(1)
|
||||
|
||||
value, err, shared := s.userGroupRateSF.Do(key, func() (any, error) {
|
||||
if s.userGroupRateCache != nil {
|
||||
if cached, ok := s.userGroupRateCache.Get(key); ok {
|
||||
if multiplier, castOK := cached.(float64); castOK {
|
||||
userGroupRateCacheHitTotal.Add(1)
|
||||
return multiplier, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
userGroupRateCacheLoadTotal.Add(1)
|
||||
userRate, repoErr := s.userGroupRateRepo.GetByUserAndGroup(ctx, userID, groupID)
|
||||
if repoErr != nil {
|
||||
return nil, repoErr
|
||||
}
|
||||
multiplier := groupDefaultMultiplier
|
||||
if userRate != nil {
|
||||
multiplier = *userRate
|
||||
}
|
||||
if s.userGroupRateCache != nil {
|
||||
s.userGroupRateCache.Set(key, multiplier, resolveUserGroupRateCacheTTL(s.cfg))
|
||||
}
|
||||
return multiplier, nil
|
||||
})
|
||||
if shared {
|
||||
userGroupRateCacheSFSharedTotal.Add(1)
|
||||
}
|
||||
if err != nil {
|
||||
userGroupRateCacheFallbackTotal.Add(1)
|
||||
logger.LegacyPrintf("service.gateway", "get user group rate failed, fallback to group default: user=%d group=%d err=%v", userID, groupID, err)
|
||||
return groupDefaultMultiplier
|
||||
}
|
||||
|
||||
multiplier, ok := value.(float64)
|
||||
if !ok {
|
||||
userGroupRateCacheFallbackTotal.Add(1)
|
||||
return groupDefaultMultiplier
|
||||
}
|
||||
return multiplier
|
||||
return resolver.Resolve(ctx, userID, groupID, groupDefaultMultiplier)
|
||||
}
|
||||
|
||||
// RecordUsageInput 记录使用量的输入参数
|
||||
@@ -6463,7 +6508,7 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
|
||||
}
|
||||
|
||||
// 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率)
|
||||
if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.GetQuotaLimit() > 0 {
|
||||
if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() {
|
||||
accountCost := cost.TotalCost * p.AccountRateMultiplier
|
||||
if err := deps.accountRepo.IncrementQuotaUsed(ctx, p.Account.ID, accountCost); err != nil {
|
||||
slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
|
||||
@@ -6954,7 +6999,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// 检测 thinking block 签名错误(400)并重试一次(过滤 thinking blocks)
|
||||
if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) {
|
||||
if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) {
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID)
|
||||
|
||||
filteredBody := FilterThinkingBlocksForRetry(body)
|
||||
|
||||
@@ -57,6 +57,10 @@ type Group struct {
|
||||
// 分组排序
|
||||
SortOrder int
|
||||
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
AllowMessagesDispatch bool
|
||||
DefaultMappedModel string
|
||||
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
|
||||
@@ -319,7 +319,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
|
||||
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
|
||||
return nil, nil
|
||||
}
|
||||
if shouldClearStickySession(account, req.RequestedModel) || !account.IsOpenAI() {
|
||||
if shouldClearStickySession(account, req.RequestedModel) || !account.IsOpenAI() || !account.IsSchedulable() {
|
||||
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
|
||||
return nil, nil
|
||||
}
|
||||
@@ -687,16 +687,20 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
|
||||
for i := 0; i < len(selectionOrder); i++ {
|
||||
candidate := selectionOrder[i]
|
||||
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, candidate.account.ID, candidate.account.Concurrency)
|
||||
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
|
||||
continue
|
||||
}
|
||||
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||
if acquireErr != nil {
|
||||
return nil, len(candidates), topK, loadSkew, acquireErr
|
||||
}
|
||||
if result != nil && result.Acquired {
|
||||
if req.SessionHash != "" {
|
||||
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, candidate.account.ID)
|
||||
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, fresh.ID)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: candidate.account,
|
||||
Account: fresh,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, len(candidates), topK, loadSkew, nil
|
||||
@@ -705,16 +709,23 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
|
||||
cfg := s.service.schedulingConfig()
|
||||
// WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
|
||||
candidate := selectionOrder[0]
|
||||
return &AccountSelectionResult{
|
||||
Account: candidate.account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: candidate.account.ID,
|
||||
MaxConcurrency: candidate.account.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
}, len(candidates), topK, loadSkew, nil
|
||||
for _, candidate := range selectionOrder {
|
||||
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
|
||||
continue
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: fresh,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: fresh.ID,
|
||||
MaxConcurrency: fresh.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
}, len(candidates), topK, loadSkew, nil
|
||||
}
|
||||
|
||||
return nil, len(candidates), topK, loadSkew, errors.New("no available accounts")
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
|
||||
|
||||
@@ -12,6 +12,78 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type openAISnapshotCacheStub struct {
|
||||
SchedulerCache
|
||||
snapshotAccounts []*Account
|
||||
accountsByID map[int64]*Account
|
||||
}
|
||||
|
||||
func (s *openAISnapshotCacheStub) GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error) {
|
||||
if len(s.snapshotAccounts) == 0 {
|
||||
return nil, false, nil
|
||||
}
|
||||
out := make([]*Account, 0, len(s.snapshotAccounts))
|
||||
for _, account := range s.snapshotAccounts {
|
||||
if account == nil {
|
||||
continue
|
||||
}
|
||||
cloned := *account
|
||||
out = append(out, &cloned)
|
||||
}
|
||||
return out, true, nil
|
||||
}
|
||||
|
||||
func (s *openAISnapshotCacheStub) GetAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||
if s.accountsByID == nil {
|
||||
return nil, nil
|
||||
}
|
||||
account := s.accountsByID[accountID]
|
||||
if account == nil {
|
||||
return nil, nil
|
||||
}
|
||||
cloned := *account
|
||||
return &cloned, nil
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimitedAccountFallsBackToFreshCandidate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(10101)
|
||||
rateLimitedUntil := time.Now().Add(30 * time.Minute)
|
||||
staleSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0}
|
||||
staleBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||
freshSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
|
||||
freshBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||
cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}}
|
||||
snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{staleSticky, staleBackup}, accountsByID: map[int64]*Account{31001: freshSticky, 31002: freshBackup}}
|
||||
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
|
||||
svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}}, cache: cache, cfg: &config.Config{}, schedulerSnapshot: snapshotService, concurrencyService: NewConcurrencyService(stubConcurrencyCache{})}
|
||||
|
||||
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
require.NotNil(t, selection.Account)
|
||||
require.Equal(t, int64(31002), selection.Account.ID)
|
||||
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRateLimitedSnapshotCandidate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(10102)
|
||||
rateLimitedUntil := time.Now().Add(30 * time.Minute)
|
||||
stalePrimary := &Account{ID: 32001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0}
|
||||
staleSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||
freshPrimary := &Account{ID: 32001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
|
||||
freshSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||
snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{stalePrimary, staleSecondary}, accountsByID: map[int64]*Account{32001: freshPrimary, 32002: freshSecondary}}
|
||||
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
|
||||
svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}}, cfg: &config.Config{}, schedulerSnapshot: snapshotService}
|
||||
|
||||
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, account)
|
||||
require.Equal(t, int64(32002), account.ID)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(9)
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"strings"
|
||||
)
|
||||
|
||||
//go:embed prompts/codex_cli_instructions.md
|
||||
var codexCLIInstructions string
|
||||
|
||||
var codexModelMap = map[string]string{
|
||||
"gpt-5.4": "gpt-5.4",
|
||||
"gpt-5.4-none": "gpt-5.4",
|
||||
@@ -77,7 +73,7 @@ type codexTransformResult struct {
|
||||
PromptCacheKey string
|
||||
}
|
||||
|
||||
func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTransformResult {
|
||||
func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact bool) codexTransformResult {
|
||||
result := codexTransformResult{}
|
||||
// 工具续链需求会影响存储策略与 input 过滤逻辑。
|
||||
needsToolContinuation := NeedsToolContinuation(reqBody)
|
||||
@@ -95,15 +91,26 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTran
|
||||
result.NormalizedModel = normalizedModel
|
||||
}
|
||||
|
||||
// OAuth 走 ChatGPT internal API 时,store 必须为 false;显式 true 也会强制覆盖。
|
||||
// 避免上游返回 "Store must be set to false"。
|
||||
if v, ok := reqBody["store"].(bool); !ok || v {
|
||||
reqBody["store"] = false
|
||||
result.Modified = true
|
||||
}
|
||||
if v, ok := reqBody["stream"].(bool); !ok || !v {
|
||||
reqBody["stream"] = true
|
||||
result.Modified = true
|
||||
if isCompact {
|
||||
if _, ok := reqBody["store"]; ok {
|
||||
delete(reqBody, "store")
|
||||
result.Modified = true
|
||||
}
|
||||
if _, ok := reqBody["stream"]; ok {
|
||||
delete(reqBody, "stream")
|
||||
result.Modified = true
|
||||
}
|
||||
} else {
|
||||
// OAuth 走 ChatGPT internal API 时,store 必须为 false;显式 true 也会强制覆盖。
|
||||
// 避免上游返回 "Store must be set to false"。
|
||||
if v, ok := reqBody["store"].(bool); !ok || v {
|
||||
reqBody["store"] = false
|
||||
result.Modified = true
|
||||
}
|
||||
if v, ok := reqBody["stream"].(bool); !ok || !v {
|
||||
reqBody["stream"] = true
|
||||
result.Modified = true
|
||||
}
|
||||
}
|
||||
|
||||
// Strip parameters unsupported by codex models via the Responses API.
|
||||
@@ -219,72 +226,13 @@ func getNormalizedCodexModel(modelID string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func getOpenCodeCodexHeader() string {
|
||||
// 兼容保留:历史上这里会从 opencode 仓库拉取 codex_header.txt。
|
||||
// 现在我们与 Codex CLI 一致,直接使用仓库内置的 instructions,避免读写缓存与外网依赖。
|
||||
return getCodexCLIInstructions()
|
||||
}
|
||||
|
||||
func getCodexCLIInstructions() string {
|
||||
return codexCLIInstructions
|
||||
}
|
||||
|
||||
func GetOpenCodeInstructions() string {
|
||||
return getOpenCodeCodexHeader()
|
||||
}
|
||||
|
||||
// GetCodexCLIInstructions 返回内置的 Codex CLI 指令内容。
|
||||
func GetCodexCLIInstructions() string {
|
||||
return getCodexCLIInstructions()
|
||||
}
|
||||
|
||||
// applyInstructions 处理 instructions 字段
|
||||
// isCodexCLI=true: 仅补充缺失的 instructions(使用内置 Codex CLI 指令)
|
||||
// isCodexCLI=false: 优先使用内置 Codex CLI 指令覆盖
|
||||
// applyInstructions 处理 instructions 字段:仅在 instructions 为空时填充默认值。
|
||||
func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool {
|
||||
if isCodexCLI {
|
||||
return applyCodexCLIInstructions(reqBody)
|
||||
}
|
||||
return applyOpenCodeInstructions(reqBody)
|
||||
}
|
||||
|
||||
// applyCodexCLIInstructions 为 Codex CLI 请求补充缺失的 instructions
|
||||
// 仅在 instructions 为空时添加内置 Codex CLI 指令(不依赖 opencode 缓存/回源)
|
||||
func applyCodexCLIInstructions(reqBody map[string]any) bool {
|
||||
if !isInstructionsEmpty(reqBody) {
|
||||
return false // 已有有效 instructions,不修改
|
||||
return false
|
||||
}
|
||||
|
||||
instructions := strings.TrimSpace(getCodexCLIInstructions())
|
||||
if instructions != "" {
|
||||
reqBody["instructions"] = instructions
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// applyOpenCodeInstructions 为非 Codex CLI 请求应用内置 Codex CLI 指令(兼容历史函数名)
|
||||
// 优先使用内置 Codex CLI 指令覆盖
|
||||
func applyOpenCodeInstructions(reqBody map[string]any) bool {
|
||||
instructions := strings.TrimSpace(getOpenCodeCodexHeader())
|
||||
existingInstructions, _ := reqBody["instructions"].(string)
|
||||
existingInstructions = strings.TrimSpace(existingInstructions)
|
||||
|
||||
if instructions != "" {
|
||||
if existingInstructions != instructions {
|
||||
reqBody["instructions"] = instructions
|
||||
return true
|
||||
}
|
||||
} else if existingInstructions == "" {
|
||||
codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
|
||||
if codexInstructions != "" {
|
||||
reqBody["instructions"] = codexInstructions
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
reqBody["instructions"] = "You are a helpful coding assistant."
|
||||
return true
|
||||
}
|
||||
|
||||
// isInstructionsEmpty 检查 instructions 字段是否为空
|
||||
|
||||
@@ -18,7 +18,7 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
|
||||
"tool_choice": "auto",
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody, false)
|
||||
applyCodexOAuthTransform(reqBody, false, false)
|
||||
|
||||
// 未显式设置 store=true,默认为 false。
|
||||
store, ok := reqBody["store"].(bool)
|
||||
@@ -53,7 +53,7 @@ func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
|
||||
"tool_choice": "auto",
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody, false)
|
||||
applyCodexOAuthTransform(reqBody, false, false)
|
||||
|
||||
store, ok := reqBody["store"].(bool)
|
||||
require.True(t, ok)
|
||||
@@ -72,13 +72,29 @@ func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) {
|
||||
"tool_choice": "auto",
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody, false)
|
||||
applyCodexOAuthTransform(reqBody, false, false)
|
||||
|
||||
store, ok := reqBody["store"].(bool)
|
||||
require.True(t, ok)
|
||||
require.False(t, store)
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_CompactForcesNonStreaming(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1-codex",
|
||||
"store": true,
|
||||
"stream": true,
|
||||
}
|
||||
|
||||
result := applyCodexOAuthTransform(reqBody, true, true)
|
||||
|
||||
_, hasStore := reqBody["store"]
|
||||
require.False(t, hasStore)
|
||||
_, hasStream := reqBody["stream"]
|
||||
require.False(t, hasStream)
|
||||
require.True(t, result.Modified)
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs(t *testing.T) {
|
||||
// 非续链场景:未设置 store 时默认 false,并移除 input 中的 id。
|
||||
|
||||
@@ -89,7 +105,7 @@ func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs(
|
||||
},
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody, false)
|
||||
applyCodexOAuthTransform(reqBody, false, false)
|
||||
|
||||
store, ok := reqBody["store"].(bool)
|
||||
require.True(t, ok)
|
||||
@@ -138,7 +154,7 @@ func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunction
|
||||
},
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody, false)
|
||||
applyCodexOAuthTransform(reqBody, false, false)
|
||||
|
||||
tools, ok := reqBody["tools"].([]any)
|
||||
require.True(t, ok)
|
||||
@@ -158,7 +174,7 @@ func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
|
||||
"input": []any{},
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody, false)
|
||||
applyCodexOAuthTransform(reqBody, false, false)
|
||||
|
||||
input, ok := reqBody["input"].([]any)
|
||||
require.True(t, ok)
|
||||
@@ -193,7 +209,7 @@ func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *test
|
||||
"instructions": "existing instructions",
|
||||
}
|
||||
|
||||
result := applyCodexOAuthTransform(reqBody, true) // isCodexCLI=true
|
||||
result := applyCodexOAuthTransform(reqBody, true, false) // isCodexCLI=true
|
||||
|
||||
instructions, ok := reqBody["instructions"].(string)
|
||||
require.True(t, ok)
|
||||
@@ -210,7 +226,7 @@ func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T
|
||||
// 没有 instructions 字段
|
||||
}
|
||||
|
||||
result := applyCodexOAuthTransform(reqBody, true) // isCodexCLI=true
|
||||
result := applyCodexOAuthTransform(reqBody, true, false) // isCodexCLI=true
|
||||
|
||||
instructions, ok := reqBody["instructions"].(string)
|
||||
require.True(t, ok)
|
||||
@@ -218,20 +234,19 @@ func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T
|
||||
require.True(t, result.Modified)
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_NonCodexCLI_OverridesInstructions(t *testing.T) {
|
||||
// 非 Codex CLI 场景:使用内置 Codex CLI 指令覆盖
|
||||
func TestApplyCodexOAuthTransform_NonCodexCLI_PreservesExistingInstructions(t *testing.T) {
|
||||
// 非 Codex CLI 场景:已有 instructions 时保留客户端的值,不再覆盖
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
"instructions": "old instructions",
|
||||
}
|
||||
|
||||
result := applyCodexOAuthTransform(reqBody, false) // isCodexCLI=false
|
||||
applyCodexOAuthTransform(reqBody, false, false) // isCodexCLI=false
|
||||
|
||||
instructions, ok := reqBody["instructions"].(string)
|
||||
require.True(t, ok)
|
||||
require.NotEqual(t, "old instructions", instructions)
|
||||
require.True(t, result.Modified)
|
||||
require.Equal(t, "old instructions", instructions)
|
||||
}
|
||||
|
||||
func TestIsInstructionsEmpty(t *testing.T) {
|
||||
|
||||
574
backend/internal/service/openai_gateway_messages.go
Normal file
574
backend/internal/service/openai_gateway_messages.go
Normal file
@@ -0,0 +1,574 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ForwardAsAnthropic accepts an Anthropic Messages request body, converts it
|
||||
// to OpenAI Responses API format, forwards to the OpenAI upstream, and converts
|
||||
// the response back to Anthropic Messages format. This enables Claude Code
|
||||
// clients to access OpenAI models through the standard /v1/messages endpoint.
|
||||
func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
body []byte,
|
||||
promptCacheKey string,
|
||||
defaultMappedModel string,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// 1. Parse Anthropic request
|
||||
var anthropicReq apicompat.AnthropicRequest
|
||||
if err := json.Unmarshal(body, &anthropicReq); err != nil {
|
||||
return nil, fmt.Errorf("parse anthropic request: %w", err)
|
||||
}
|
||||
originalModel := anthropicReq.Model
|
||||
clientStream := anthropicReq.Stream // client's original stream preference
|
||||
|
||||
// 2. Convert Anthropic → Responses
|
||||
responsesReq, err := apicompat.AnthropicToResponses(&anthropicReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert anthropic to responses: %w", err)
|
||||
}
|
||||
|
||||
// Upstream always uses streaming (upstream may not support sync mode).
|
||||
// The client's original preference determines the response format.
|
||||
responsesReq.Stream = true
|
||||
isStream := true
|
||||
|
||||
// 2b. Handle BetaFastMode → service_tier: "priority"
|
||||
if containsBetaToken(c.GetHeader("anthropic-beta"), claude.BetaFastMode) {
|
||||
responsesReq.ServiceTier = "priority"
|
||||
}
|
||||
|
||||
// 3. Model mapping
|
||||
mappedModel := account.GetMappedModel(originalModel)
|
||||
// 分组级降级:账号未映射时使用分组默认映射模型
|
||||
if mappedModel == originalModel && defaultMappedModel != "" {
|
||||
mappedModel = defaultMappedModel
|
||||
}
|
||||
responsesReq.Model = mappedModel
|
||||
|
||||
logger.L().Debug("openai messages: model mapping applied",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("original_model", originalModel),
|
||||
zap.String("mapped_model", mappedModel),
|
||||
zap.Bool("stream", isStream),
|
||||
)
|
||||
|
||||
// 4. Marshal Responses request body, then apply OAuth codex transform
|
||||
responsesBody, err := json.Marshal(responsesReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal responses request: %w", err)
|
||||
}
|
||||
|
||||
if account.Type == AccountTypeOAuth {
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(responsesBody, &reqBody); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
|
||||
}
|
||||
codexResult := applyCodexOAuthTransform(reqBody, false, false)
|
||||
if codexResult.PromptCacheKey != "" {
|
||||
promptCacheKey = codexResult.PromptCacheKey
|
||||
} else if promptCacheKey != "" {
|
||||
reqBody["prompt_cache_key"] = promptCacheKey
|
||||
}
|
||||
// OAuth codex transform forces stream=true upstream, so always use
|
||||
// the streaming response handler regardless of what the client asked.
|
||||
isStream = true
|
||||
responsesBody, err = json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("remarshal after codex transform: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Get access token
|
||||
token, _, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get access token: %w", err)
|
||||
}
|
||||
|
||||
// 6. Build upstream request
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, isStream, promptCacheKey, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
|
||||
// Override session_id with a deterministic UUID derived from the sticky
|
||||
// session key (buildUpstreamRequest may have set it to the raw value).
|
||||
if promptCacheKey != "" {
|
||||
upstreamReq.Header.Set("session_id", generateSessionUUID(promptCacheKey))
|
||||
}
|
||||
|
||||
// 7. Send request
|
||||
proxyURL := ""
|
||||
if account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
writeAnthropicError(c, http.StatusBadGateway, "api_error", "Upstream request failed")
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// 8. Handle error response with failover
|
||||
if resp.StatusCode >= 400 {
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
// Non-failover error: return Anthropic-formatted error to client
|
||||
return s.handleAnthropicErrorResponse(resp, c, account)
|
||||
}
|
||||
|
||||
// 9. Handle normal response
|
||||
// Upstream is always streaming; choose response format based on client preference.
|
||||
var result *OpenAIForwardResult
|
||||
var handleErr error
|
||||
if clientStream {
|
||||
result, handleErr = s.handleAnthropicStreamingResponse(resp, c, originalModel, mappedModel, startTime)
|
||||
} else {
|
||||
// Client wants JSON: buffer the streaming response and assemble a JSON reply.
|
||||
result, handleErr = s.handleAnthropicBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime)
|
||||
}
|
||||
|
||||
// Propagate ServiceTier and ReasoningEffort to result for billing
|
||||
if handleErr == nil && result != nil {
|
||||
if responsesReq.ServiceTier != "" {
|
||||
st := responsesReq.ServiceTier
|
||||
result.ServiceTier = &st
|
||||
}
|
||||
if responsesReq.Reasoning != nil && responsesReq.Reasoning.Effort != "" {
|
||||
re := responsesReq.Reasoning.Effort
|
||||
result.ReasoningEffort = &re
|
||||
}
|
||||
}
|
||||
|
||||
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
|
||||
if handleErr == nil && account.Type == AccountTypeOAuth {
|
||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
|
||||
}
|
||||
}
|
||||
|
||||
return result, handleErr
|
||||
}
|
||||
|
||||
// handleAnthropicErrorResponse reads an upstream error and returns it in
|
||||
// Anthropic error format.
|
||||
func (s *OpenAIGatewayService) handleAnthropicErrorResponse(
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
||||
if upstreamMsg == "" {
|
||||
upstreamMsg = fmt.Sprintf("Upstream error: %d", resp.StatusCode)
|
||||
}
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
|
||||
// Record upstream error details for ops logging
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(body), maxBytes)
|
||||
}
|
||||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||
|
||||
// Apply error passthrough rules (matches handleErrorResponse pattern in openai_gateway_service.go)
|
||||
if status, errType, errMsg, matched := applyErrorPassthroughRule(
|
||||
c, account.Platform, resp.StatusCode, body,
|
||||
http.StatusBadGateway, "api_error", "Upstream request failed",
|
||||
); matched {
|
||||
writeAnthropicError(c, status, errType, errMsg)
|
||||
if upstreamMsg == "" {
|
||||
upstreamMsg = errMsg
|
||||
}
|
||||
if upstreamMsg == "" {
|
||||
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode)
|
||||
}
|
||||
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg)
|
||||
}
|
||||
|
||||
errType := "api_error"
|
||||
switch {
|
||||
case resp.StatusCode == 400:
|
||||
errType = "invalid_request_error"
|
||||
case resp.StatusCode == 404:
|
||||
errType = "not_found_error"
|
||||
case resp.StatusCode == 429:
|
||||
errType = "rate_limit_error"
|
||||
case resp.StatusCode >= 500:
|
||||
errType = "api_error"
|
||||
}
|
||||
|
||||
writeAnthropicError(c, resp.StatusCode, errType, upstreamMsg)
|
||||
return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
|
||||
}
|
||||
|
||||
// handleAnthropicBufferedStreamingResponse reads all Responses SSE events from
|
||||
// the upstream streaming response, finds the terminal event (response.completed
|
||||
// / response.incomplete / response.failed), converts the complete response to
|
||||
// Anthropic Messages JSON format, and writes it to the client.
|
||||
// This is used when the client requested stream=false but the upstream is always
|
||||
// streaming.
|
||||
func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
originalModel string,
|
||||
mappedModel string,
|
||||
startTime time.Time,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
|
||||
var finalResponse *apicompat.ResponsesResponse
|
||||
var usage OpenAIUsage
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
continue
|
||||
}
|
||||
payload := line[6:]
|
||||
|
||||
var event apicompat.ResponsesStreamEvent
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
logger.L().Warn("openai messages buffered: failed to parse event",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
// Terminal events carry the complete ResponsesResponse with output + usage.
|
||||
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
|
||||
event.Response != nil {
|
||||
finalResponse = event.Response
|
||||
if event.Response.Usage != nil {
|
||||
usage = OpenAIUsage{
|
||||
InputTokens: event.Response.Usage.InputTokens,
|
||||
OutputTokens: event.Response.Usage.OutputTokens,
|
||||
}
|
||||
if event.Response.Usage.InputTokensDetails != nil {
|
||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||
logger.L().Warn("openai messages buffered: read error",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if finalResponse == nil {
|
||||
writeAnthropicError(c, http.StatusBadGateway, "api_error", "Upstream stream ended without a terminal response event")
|
||||
return nil, fmt.Errorf("upstream stream ended without terminal event")
|
||||
}
|
||||
|
||||
anthropicResp := apicompat.ResponsesToAnthropic(finalResponse, originalModel)
|
||||
|
||||
if s.responseHeaderFilter != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||
}
|
||||
c.JSON(http.StatusOK, anthropicResp)
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
BillingModel: mappedModel,
|
||||
Stream: false,
|
||||
Duration: time.Since(startTime),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// handleAnthropicStreamingResponse reads Responses SSE events from upstream,
|
||||
// converts each to Anthropic SSE events, and writes them to the client.
|
||||
// When StreamKeepaliveInterval is configured, it uses a goroutine + channel
|
||||
// pattern to send Anthropic ping events during periods of upstream silence,
|
||||
// preventing proxy/client timeout disconnections.
|
||||
func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
originalModel string,
|
||||
mappedModel string,
|
||||
startTime time.Time,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
|
||||
if s.responseHeaderFilter != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
|
||||
state := apicompat.NewResponsesEventToAnthropicState()
|
||||
state.Model = originalModel
|
||||
var usage OpenAIUsage
|
||||
var firstTokenMs *int
|
||||
firstChunk := true
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
|
||||
// resultWithUsage builds the final result snapshot.
|
||||
resultWithUsage := func() *OpenAIForwardResult {
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
BillingModel: mappedModel,
|
||||
Stream: true,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}
|
||||
}
|
||||
|
||||
// processDataLine handles a single "data: ..." SSE line from upstream.
|
||||
// Returns (clientDisconnected bool).
|
||||
processDataLine := func(payload string) bool {
|
||||
if firstChunk {
|
||||
firstChunk = false
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
|
||||
var event apicompat.ResponsesStreamEvent
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
logger.L().Warn("openai messages stream: failed to parse event",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
return false
|
||||
}
|
||||
|
||||
// Extract usage from completion events
|
||||
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
|
||||
event.Response != nil && event.Response.Usage != nil {
|
||||
usage = OpenAIUsage{
|
||||
InputTokens: event.Response.Usage.InputTokens,
|
||||
OutputTokens: event.Response.Usage.OutputTokens,
|
||||
}
|
||||
if event.Response.Usage.InputTokensDetails != nil {
|
||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to Anthropic events
|
||||
events := apicompat.ResponsesEventToAnthropicEvents(&event, state)
|
||||
for _, evt := range events {
|
||||
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
|
||||
if err != nil {
|
||||
logger.L().Warn("openai messages stream: failed to marshal event",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||
logger.L().Info("openai messages stream: client disconnected",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
return true
|
||||
}
|
||||
}
|
||||
if len(events) > 0 {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// finalizeStream sends any remaining Anthropic events and returns the result.
|
||||
finalizeStream := func() (*OpenAIForwardResult, error) {
|
||||
if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 {
|
||||
for _, evt := range finalEvents {
|
||||
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
fmt.Fprint(c.Writer, sse) //nolint:errcheck
|
||||
}
|
||||
c.Writer.Flush()
|
||||
}
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
|
||||
// handleScanErr logs scanner errors if meaningful.
|
||||
handleScanErr := func(err error) {
|
||||
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||
logger.L().Warn("openai messages stream: read error",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Determine keepalive interval ──
|
||||
keepaliveInterval := time.Duration(0)
|
||||
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
|
||||
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
|
||||
}
|
||||
|
||||
// ── No keepalive: fast synchronous path (no goroutine overhead) ──
|
||||
if keepaliveInterval <= 0 {
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
continue
|
||||
}
|
||||
if processDataLine(line[6:]) {
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
}
|
||||
handleScanErr(scanner.Err())
|
||||
return finalizeStream()
|
||||
}
|
||||
|
||||
// ── With keepalive: goroutine + channel + select ──
|
||||
type scanEvent struct {
|
||||
line string
|
||||
err error
|
||||
}
|
||||
events := make(chan scanEvent, 16)
|
||||
done := make(chan struct{})
|
||||
sendEvent := func(ev scanEvent) bool {
|
||||
select {
|
||||
case events <- ev:
|
||||
return true
|
||||
case <-done:
|
||||
return false
|
||||
}
|
||||
}
|
||||
go func() {
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
_ = sendEvent(scanEvent{err: err})
|
||||
}
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
keepaliveTicker := time.NewTicker(keepaliveInterval)
|
||||
defer keepaliveTicker.Stop()
|
||||
lastDataAt := time.Now()
|
||||
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
// Upstream closed
|
||||
return finalizeStream()
|
||||
}
|
||||
if ev.err != nil {
|
||||
handleScanErr(ev.err)
|
||||
return finalizeStream()
|
||||
}
|
||||
lastDataAt = time.Now()
|
||||
line := ev.line
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
continue
|
||||
}
|
||||
if processDataLine(line[6:]) {
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
|
||||
case <-keepaliveTicker.C:
|
||||
if time.Since(lastDataAt) < keepaliveInterval {
|
||||
continue
|
||||
}
|
||||
// Send Anthropic-format ping event
|
||||
if _, err := fmt.Fprint(c.Writer, "event: ping\ndata: {\"type\":\"ping\"}\n\n"); err != nil {
|
||||
// Client disconnected
|
||||
logger.L().Info("openai messages stream: client disconnected during keepalive",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
c.Writer.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// writeAnthropicError writes an error response in Anthropic Messages API format.
|
||||
func writeAnthropicError(c *gin.Context, statusCode int, errType, message string) {
|
||||
c.JSON(statusCode, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
}
|
||||
558
backend/internal/service/openai_gateway_record_usage_test.go
Normal file
558
backend/internal/service/openai_gateway_record_usage_test.go
Normal file
@@ -0,0 +1,558 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type openAIRecordUsageLogRepoStub struct {
|
||||
UsageLogRepository
|
||||
|
||||
inserted bool
|
||||
err error
|
||||
calls int
|
||||
lastLog *UsageLog
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) {
|
||||
s.calls++
|
||||
s.lastLog = log
|
||||
return s.inserted, s.err
|
||||
}
|
||||
|
||||
type openAIRecordUsageUserRepoStub struct {
|
||||
UserRepository
|
||||
|
||||
deductCalls int
|
||||
deductErr error
|
||||
lastAmount float64
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageUserRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error {
|
||||
s.deductCalls++
|
||||
s.lastAmount = amount
|
||||
return s.deductErr
|
||||
}
|
||||
|
||||
type openAIRecordUsageSubRepoStub struct {
|
||||
UserSubscriptionRepository
|
||||
|
||||
incrementCalls int
|
||||
incrementErr error
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageSubRepoStub) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
|
||||
s.incrementCalls++
|
||||
return s.incrementErr
|
||||
}
|
||||
|
||||
type openAIRecordUsageAPIKeyQuotaStub struct {
|
||||
quotaCalls int
|
||||
rateLimitCalls int
|
||||
err error
|
||||
lastAmount float64
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error {
|
||||
s.quotaCalls++
|
||||
s.lastAmount = cost
|
||||
return s.err
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error {
|
||||
s.rateLimitCalls++
|
||||
s.lastAmount = cost
|
||||
return s.err
|
||||
}
|
||||
|
||||
type openAIUserGroupRateRepoStub struct {
|
||||
UserGroupRateRepository
|
||||
|
||||
rate *float64
|
||||
err error
|
||||
calls int
|
||||
}
|
||||
|
||||
func (s *openAIUserGroupRateRepoStub) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
|
||||
s.calls++
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
return s.rate, nil
|
||||
}
|
||||
|
||||
func i64p(v int64) *int64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService {
|
||||
cfg := &config.Config{}
|
||||
cfg.Default.RateMultiplier = 1.1
|
||||
|
||||
return &OpenAIGatewayService{
|
||||
usageLogRepo: usageRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: subRepo,
|
||||
cfg: cfg,
|
||||
billingService: NewBillingService(cfg, nil),
|
||||
billingCacheService: &BillingCacheService{},
|
||||
deferredService: &DeferredService{},
|
||||
userGroupRateResolver: newUserGroupRateResolver(
|
||||
rateRepo,
|
||||
nil,
|
||||
resolveUserGroupRateCacheTTL(cfg),
|
||||
nil,
|
||||
"service.openai_gateway.test",
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func expectedOpenAICost(t *testing.T, svc *OpenAIGatewayService, model string, usage OpenAIUsage, multiplier float64) *CostBreakdown {
|
||||
t.Helper()
|
||||
|
||||
cost, err := svc.billingService.CalculateCost(model, UsageTokens{
|
||||
InputTokens: max(usage.InputTokens-usage.CacheReadInputTokens, 0),
|
||||
OutputTokens: usage.OutputTokens,
|
||||
CacheCreationTokens: usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: usage.CacheReadInputTokens,
|
||||
}, multiplier)
|
||||
require.NoError(t, err)
|
||||
return cost
|
||||
}
|
||||
|
||||
func max(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T) {
|
||||
groupID := int64(11)
|
||||
groupRate := 1.4
|
||||
userRate := 1.8
|
||||
usage := OpenAIUsage{InputTokens: 15, OutputTokens: 4, CacheReadInputTokens: 3}
|
||||
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
rateRepo := &openAIUserGroupRateRepoStub{rate: &userRate}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_user_group_rate",
|
||||
Usage: usage,
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 1001,
|
||||
GroupID: i64p(groupID),
|
||||
Group: &Group{
|
||||
ID: groupID,
|
||||
RateMultiplier: groupRate,
|
||||
},
|
||||
},
|
||||
User: &User{ID: 2001},
|
||||
Account: &Account{ID: 3001},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, rateRepo.calls)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, userRate, usageRepo.lastLog.RateMultiplier)
|
||||
require.Equal(t, 12, usageRepo.lastLog.InputTokens)
|
||||
require.Equal(t, 3, usageRepo.lastLog.CacheReadTokens)
|
||||
|
||||
expected := expectedOpenAICost(t, svc, "gpt-5.1", usage, userRate)
|
||||
require.InDelta(t, expected.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
|
||||
require.InDelta(t, expected.ActualCost, userRepo.lastAmount, 1e-12)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateOnResolverError(t *testing.T) {
|
||||
groupID := int64(12)
|
||||
groupRate := 1.6
|
||||
usage := OpenAIUsage{InputTokens: 10, OutputTokens: 5, CacheReadInputTokens: 2}
|
||||
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
rateRepo := &openAIUserGroupRateRepoStub{err: errors.New("db unavailable")}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_group_default_on_error",
|
||||
Usage: usage,
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 1002,
|
||||
GroupID: i64p(groupID),
|
||||
Group: &Group{
|
||||
ID: groupID,
|
||||
RateMultiplier: groupRate,
|
||||
},
|
||||
},
|
||||
User: &User{ID: 2002},
|
||||
Account: &Account{ID: 3002},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, rateRepo.calls)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, groupRate, usageRepo.lastLog.RateMultiplier)
|
||||
|
||||
expected := expectedOpenAICost(t, svc, "gpt-5.1", usage, groupRate)
|
||||
require.InDelta(t, expected.ActualCost, userRepo.lastAmount, 1e-12)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateWhenResolverMissing(t *testing.T) {
|
||||
groupID := int64(13)
|
||||
groupRate := 1.25
|
||||
usage := OpenAIUsage{InputTokens: 9, OutputTokens: 4, CacheReadInputTokens: 1}
|
||||
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
svc.userGroupRateResolver = nil
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_group_default_nil_resolver",
|
||||
Usage: usage,
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 1003,
|
||||
GroupID: i64p(groupID),
|
||||
Group: &Group{
|
||||
ID: groupID,
|
||||
RateMultiplier: groupRate,
|
||||
},
|
||||
},
|
||||
User: &User{ID: 2003},
|
||||
Account: &Account{ID: 3003},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, groupRate, usageRepo.lastLog.RateMultiplier)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_duplicate",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 8,
|
||||
OutputTokens: 4,
|
||||
},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 1004},
|
||||
User: &User{ID: 2004},
|
||||
Account: &Account{ID: 3004},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 0, userRepo.deductCalls)
|
||||
require.Equal(t, 0, subRepo.incrementCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_UpdatesAPIKeyQuotaWhenConfigured(t *testing.T) {
|
||||
usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2}
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_quota_update",
|
||||
Usage: usage,
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 1005,
|
||||
Quota: 100,
|
||||
},
|
||||
User: &User{ID: 2005},
|
||||
Account: &Account{ID: 3005},
|
||||
APIKeyService: quotaSvc,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, quotaSvc.quotaCalls)
|
||||
require.Equal(t, 0, quotaSvc.rateLimitCalls)
|
||||
expected := expectedOpenAICost(t, svc, "gpt-5.1", usage, 1.1)
|
||||
require.InDelta(t, expected.ActualCost, quotaSvc.lastAmount, 1e-12)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_ClampsActualInputTokensToZero(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_clamp_actual_input",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 2,
|
||||
OutputTokens: 1,
|
||||
CacheReadInputTokens: 5,
|
||||
},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 1006},
|
||||
User: &User{ID: 2006},
|
||||
Account: &Account{ID: 3006},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, 0, usageRepo.lastLog.InputTokens)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_Gpt54LongContextBillsWholeSession(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_gpt54_long_context",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 300000,
|
||||
OutputTokens: 2000,
|
||||
},
|
||||
Model: "gpt-5.4-2026-03-05",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 1014},
|
||||
User: &User{ID: 2014},
|
||||
Account: &Account{ID: 3014},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
|
||||
expectedInput := 300000 * 2.5e-6 * 2.0
|
||||
expectedOutput := 2000 * 15e-6 * 1.5
|
||||
require.InDelta(t, expectedInput, usageRepo.lastLog.InputCost, 1e-10)
|
||||
require.InDelta(t, expectedOutput, usageRepo.lastLog.OutputCost, 1e-10)
|
||||
require.InDelta(t, expectedInput+expectedOutput, usageRepo.lastLog.TotalCost, 1e-10)
|
||||
require.InDelta(t, (expectedInput+expectedOutput)*1.1, usageRepo.lastLog.ActualCost, 1e-10)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_ServiceTierPriorityUsesFastPricing(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
serviceTier := "priority"
|
||||
usage := OpenAIUsage{InputTokens: 100, OutputTokens: 50}
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_service_tier_priority",
|
||||
ServiceTier: &serviceTier,
|
||||
Usage: usage,
|
||||
Model: "gpt-5.4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 1015},
|
||||
User: &User{ID: 2015},
|
||||
Account: &Account{ID: 3015},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.NotNil(t, usageRepo.lastLog.ServiceTier)
|
||||
require.Equal(t, serviceTier, *usageRepo.lastLog.ServiceTier)
|
||||
|
||||
baseCost, calcErr := svc.billingService.CalculateCost("gpt-5.4", UsageTokens{InputTokens: 100, OutputTokens: 50}, 1.0)
|
||||
require.NoError(t, calcErr)
|
||||
require.InDelta(t, baseCost.TotalCost*2, usageRepo.lastLog.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_ServiceTierFlexHalvesCost(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
serviceTier := "flex"
|
||||
usage := OpenAIUsage{InputTokens: 100, OutputTokens: 50, CacheReadInputTokens: 20}
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_service_tier_flex",
|
||||
ServiceTier: &serviceTier,
|
||||
Usage: usage,
|
||||
Model: "gpt-5.4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 1016},
|
||||
User: &User{ID: 2016},
|
||||
Account: &Account{ID: 3016},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
|
||||
baseCost, calcErr := svc.billingService.CalculateCost("gpt-5.4", UsageTokens{InputTokens: 80, OutputTokens: 50, CacheReadTokens: 20}, 1.0)
|
||||
require.NoError(t, calcErr)
|
||||
require.InDelta(t, baseCost.TotalCost*0.5, usageRepo.lastLog.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIServiceTier(t *testing.T) {
|
||||
t.Run("fast maps to priority", func(t *testing.T) {
|
||||
got := normalizeOpenAIServiceTier(" fast ")
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, "priority", *got)
|
||||
})
|
||||
|
||||
t.Run("default ignored", func(t *testing.T) {
|
||||
require.Nil(t, normalizeOpenAIServiceTier("default"))
|
||||
})
|
||||
|
||||
t.Run("invalid ignored", func(t *testing.T) {
|
||||
require.Nil(t, normalizeOpenAIServiceTier("turbo"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractOpenAIServiceTier(t *testing.T) {
|
||||
require.Equal(t, "priority", *extractOpenAIServiceTier(map[string]any{"service_tier": "fast"}))
|
||||
require.Equal(t, "flex", *extractOpenAIServiceTier(map[string]any{"service_tier": "flex"}))
|
||||
require.Nil(t, extractOpenAIServiceTier(map[string]any{"service_tier": 1}))
|
||||
require.Nil(t, extractOpenAIServiceTier(nil))
|
||||
}
|
||||
|
||||
func TestExtractOpenAIServiceTierFromBody(t *testing.T) {
|
||||
require.Equal(t, "priority", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"fast"}`)))
|
||||
require.Equal(t, "flex", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"flex"}`)))
|
||||
require.Nil(t, extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"default"}`)))
|
||||
require.Nil(t, extractOpenAIServiceTierFromBody(nil))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
serviceTier := "priority"
|
||||
reasoning := "high"
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_billing_model_override",
|
||||
BillingModel: "gpt-5.1-codex",
|
||||
Model: "gpt-5.1",
|
||||
ServiceTier: &serviceTier,
|
||||
ReasoningEffort: &reasoning,
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
},
|
||||
Duration: 2 * time.Second,
|
||||
FirstTokenMs: func() *int { v := 120; return &v }(),
|
||||
},
|
||||
APIKey: &APIKey{ID: 10, GroupID: i64p(11), Group: &Group{ID: 11, RateMultiplier: 1.2}},
|
||||
User: &User{ID: 20},
|
||||
Account: &Account{ID: 30},
|
||||
UserAgent: "codex-cli/1.0",
|
||||
IPAddress: "127.0.0.1",
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "gpt-5.1-codex", usageRepo.lastLog.Model)
|
||||
require.NotNil(t, usageRepo.lastLog.ServiceTier)
|
||||
require.Equal(t, serviceTier, *usageRepo.lastLog.ServiceTier)
|
||||
require.NotNil(t, usageRepo.lastLog.ReasoningEffort)
|
||||
require.Equal(t, reasoning, *usageRepo.lastLog.ReasoningEffort)
|
||||
require.NotNil(t, usageRepo.lastLog.UserAgent)
|
||||
require.Equal(t, "codex-cli/1.0", *usageRepo.lastLog.UserAgent)
|
||||
require.NotNil(t, usageRepo.lastLog.IPAddress)
|
||||
require.Equal(t, "127.0.0.1", *usageRepo.lastLog.IPAddress)
|
||||
require.NotNil(t, usageRepo.lastLog.GroupID)
|
||||
require.Equal(t, int64(11), *usageRepo.lastLog.GroupID)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
subscription := &UserSubscription{ID: 99}
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_subscription_billing",
|
||||
Usage: OpenAIUsage{InputTokens: 10, OutputTokens: 5},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription}},
|
||||
User: &User{ID: 200},
|
||||
Account: &Account{ID: 300},
|
||||
Subscription: subscription,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, BillingTypeSubscription, usageRepo.lastLog.BillingType)
|
||||
require.NotNil(t, usageRepo.lastLog.SubscriptionID)
|
||||
require.Equal(t, subscription.ID, *usageRepo.lastLog.SubscriptionID)
|
||||
require.Equal(t, 1, subRepo.incrementCalls)
|
||||
require.Equal(t, 0, userRepo.deductCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_SimpleModeSkipsBillingAfterPersist(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
svc.cfg.RunMode = config.RunModeSimple
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_simple_mode",
|
||||
Usage: OpenAIUsage{InputTokens: 10, OutputTokens: 5},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 1000},
|
||||
User: &User{ID: 2000},
|
||||
Account: &Account{ID: 3000},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 0, userRepo.deductCalls)
|
||||
require.Equal(t, 0, subRepo.incrementCalls)
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user