mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-04 23:42:13 +08:00
Compare commits
31 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5831eb8a6a | ||
|
|
61838cdb3d | ||
|
|
50dba656fd | ||
|
|
0e2821456c | ||
|
|
f25ac3aff5 | ||
|
|
f6341b7f2b | ||
|
|
4e257512b9 | ||
|
|
e53b34f321 | ||
|
|
12ddae0184 | ||
|
|
7b9c3f165e | ||
|
|
0b8e84f942 | ||
|
|
d9e27df9af | ||
|
|
f0fabf89a1 | ||
|
|
5bbfbcdae9 | ||
|
|
eb55947ec4 | ||
|
|
5f7e5184eb | ||
|
|
008a111268 | ||
|
|
fda753278c | ||
|
|
6c469b42ed | ||
|
|
dacf3a2a6e | ||
|
|
e6add93ae3 | ||
|
|
b2273ec695 | ||
|
|
aa89777dda | ||
|
|
1e1f3c0c74 | ||
|
|
1fab9204eb | ||
|
|
dbd3e71637 | ||
|
|
974f67211b | ||
|
|
0338c83b90 | ||
|
|
c6b3de1199 | ||
|
|
f1325e9ae6 | ||
|
|
587012396b |
38
.github/workflows/backend-ci.yml
vendored
Normal file
38
.github/workflows/backend-ci.yml
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
pull_request:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: backend/go.mod
|
||||
check-latest: true
|
||||
cache: true
|
||||
- name: Run tests
|
||||
working-directory: backend
|
||||
run: go test ./...
|
||||
|
||||
golangci-lint:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: backend/go.mod
|
||||
check-latest: true
|
||||
cache: true
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v9
|
||||
with:
|
||||
version: v2.7
|
||||
args: --timeout=5m
|
||||
working-directory: backend
|
||||
8
.gitignore
vendored
8
.gitignore
vendored
@@ -28,6 +28,7 @@ node_modules/
|
||||
frontend/node_modules/
|
||||
frontend/dist/
|
||||
*.local
|
||||
*.tsbuildinfo
|
||||
|
||||
# 日志
|
||||
npm-debug.log*
|
||||
@@ -81,7 +82,12 @@ build/
|
||||
release/
|
||||
|
||||
# 后端嵌入的前端构建产物
|
||||
# Keep a placeholder file so `//go:embed all:dist` always has a match in CI/lint,
|
||||
# while still ignoring generated frontend build outputs.
|
||||
backend/internal/web/dist/
|
||||
!backend/internal/web/dist/
|
||||
backend/internal/web/dist/*
|
||||
!backend/internal/web/dist/.keep
|
||||
|
||||
# 后端运行时缓存数据
|
||||
backend/data/
|
||||
@@ -92,4 +98,4 @@ backend/data/
|
||||
tests
|
||||
CLAUDE.md
|
||||
.claude
|
||||
scripts
|
||||
scripts
|
||||
|
||||
@@ -11,6 +11,8 @@ builds:
|
||||
dir: backend
|
||||
main: ./cmd/server
|
||||
binary: sub2api
|
||||
flags:
|
||||
- -tags=embed
|
||||
env:
|
||||
- CGO_ENABLED=0
|
||||
goos:
|
||||
|
||||
11
Dockerfile
11
Dockerfile
@@ -40,14 +40,15 @@ WORKDIR /app/backend
|
||||
COPY backend/go.mod backend/go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
# Copy frontend dist from previous stage
|
||||
COPY --from=frontend-builder /app/frontend/../backend/internal/web/dist ./internal/web/dist
|
||||
|
||||
# Copy backend source
|
||||
# Copy backend source first
|
||||
COPY backend/ ./
|
||||
|
||||
# Build the binary (BuildType=release for CI builds)
|
||||
# Copy frontend dist from previous stage (must be after backend copy to avoid being overwritten)
|
||||
COPY --from=frontend-builder /app/backend/internal/web/dist ./internal/web/dist
|
||||
|
||||
# Build the binary (BuildType=release for CI builds, embed frontend)
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build \
|
||||
-tags embed \
|
||||
-ldflags="-s -w -X main.Commit=${COMMIT} -X main.Date=${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)} -X main.BuildType=release" \
|
||||
-o /app/sub2api \
|
||||
./cmd/server
|
||||
|
||||
16
README.md
16
README.md
@@ -220,21 +220,21 @@ cd sub2api
|
||||
cd frontend
|
||||
npm install
|
||||
npm run build
|
||||
# Output will be in ../backend/internal/web/dist/
|
||||
|
||||
# 3. Copy frontend build to backend (for embedding)
|
||||
cp -r dist ../backend/internal/web/
|
||||
|
||||
# 4. Build backend (requires frontend dist to be present)
|
||||
# 3. Build backend with embedded frontend
|
||||
cd ../backend
|
||||
go build -o sub2api ./cmd/server
|
||||
go build -tags embed -o sub2api ./cmd/server
|
||||
|
||||
# 5. Create configuration file
|
||||
# 4. Create configuration file
|
||||
cp ../deploy/config.example.yaml ./config.yaml
|
||||
|
||||
# 6. Edit configuration
|
||||
# 5. Edit configuration
|
||||
nano config.yaml
|
||||
```
|
||||
|
||||
> **Note:** The `-tags embed` flag embeds the frontend into the binary. Without this flag, the binary will not serve the frontend UI.
|
||||
|
||||
**Key configuration in `config.yaml`:**
|
||||
|
||||
```yaml
|
||||
@@ -265,7 +265,7 @@ default:
|
||||
```
|
||||
|
||||
```bash
|
||||
# 7. Run the application
|
||||
# 6. Run the application
|
||||
./sub2api
|
||||
```
|
||||
|
||||
|
||||
16
README_CN.md
16
README_CN.md
@@ -220,21 +220,21 @@ cd sub2api
|
||||
cd frontend
|
||||
npm install
|
||||
npm run build
|
||||
# 构建产物输出到 ../backend/internal/web/dist/
|
||||
|
||||
# 3. 复制前端构建产物到后端(用于嵌入)
|
||||
cp -r dist ../backend/internal/web/
|
||||
|
||||
# 4. 编译后端(需要前端 dist 目录存在)
|
||||
# 3. 编译后端(嵌入前端)
|
||||
cd ../backend
|
||||
go build -o sub2api ./cmd/server
|
||||
go build -tags embed -o sub2api ./cmd/server
|
||||
|
||||
# 5. 创建配置文件
|
||||
# 4. 创建配置文件
|
||||
cp ../deploy/config.example.yaml ./config.yaml
|
||||
|
||||
# 6. 编辑配置
|
||||
# 5. 编辑配置
|
||||
nano config.yaml
|
||||
```
|
||||
|
||||
> **注意:** `-tags embed` 参数会将前端嵌入到二进制文件中。不使用此参数编译的程序将不包含前端界面。
|
||||
|
||||
**`config.yaml` 关键配置:**
|
||||
|
||||
```yaml
|
||||
@@ -265,7 +265,7 @@ default:
|
||||
```
|
||||
|
||||
```bash
|
||||
# 7. 运行应用
|
||||
# 6. 运行应用
|
||||
./sub2api
|
||||
```
|
||||
|
||||
|
||||
587
backend/.golangci.yml
Normal file
587
backend/.golangci.yml
Normal file
@@ -0,0 +1,587 @@
|
||||
version: "2"
|
||||
|
||||
linters:
|
||||
default: none
|
||||
enable:
|
||||
- depguard
|
||||
- errcheck
|
||||
- govet
|
||||
- ineffassign
|
||||
- staticcheck
|
||||
- unused
|
||||
|
||||
settings:
|
||||
depguard:
|
||||
rules:
|
||||
# Enforce: service must not depend on repository.
|
||||
service-no-repository:
|
||||
list-mode: original
|
||||
files:
|
||||
- internal/service/**
|
||||
deny:
|
||||
- pkg: sub2api/internal/repository
|
||||
desc: "service must not import repository"
|
||||
errcheck:
|
||||
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
|
||||
# Such cases aren't reported by default.
|
||||
# Default: false
|
||||
check-type-assertions: true
|
||||
# report about assignment of errors to blank identifier: `num, _ := strconv.Atoi(numStr)`.
|
||||
# Such cases aren't reported by default.
|
||||
# Default: false
|
||||
check-blank: false
|
||||
# To disable the errcheck built-in exclude list.
|
||||
# See `-excludeonly` option in https://github.com/kisielk/errcheck#excluding-functions for details.
|
||||
# Default: false
|
||||
disable-default-exclusions: true
|
||||
# List of functions to exclude from checking, where each entry is a single function to exclude.
|
||||
# See https://github.com/kisielk/errcheck#excluding-functions for details.
|
||||
exclude-functions:
|
||||
- io/ioutil.ReadFile
|
||||
- io.Copy(*bytes.Buffer)
|
||||
- io.Copy(os.Stdout)
|
||||
- fmt.Println
|
||||
- fmt.Print
|
||||
- fmt.Printf
|
||||
- fmt.Fprint
|
||||
- fmt.Fprintf
|
||||
- fmt.Fprintln
|
||||
# Display function signature instead of selector.
|
||||
# Default: false
|
||||
verbose: true
|
||||
ineffassign:
|
||||
# Check escaping variables of type error, may cause false positives.
|
||||
# Default: false
|
||||
check-escaping-errors: true
|
||||
staticcheck:
|
||||
# https://staticcheck.dev/docs/configuration/options/#dot_import_whitelist
|
||||
# Default: ["github.com/mmcloughlin/avo/build", "github.com/mmcloughlin/avo/operand", "github.com/mmcloughlin/avo/reg"]
|
||||
dot-import-whitelist:
|
||||
- fmt
|
||||
# https://staticcheck.dev/docs/configuration/options/#initialisms
|
||||
# Default: ["ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS"]
|
||||
initialisms: [ "ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS" ]
|
||||
# https://staticcheck.dev/docs/configuration/options/#http_status_code_whitelist
|
||||
# Default: ["200", "400", "404", "500"]
|
||||
http-status-code-whitelist: [ "200", "400", "404", "500" ]
|
||||
# SAxxxx checks in https://staticcheck.dev/docs/configuration/options/#checks
|
||||
# Example (to disable some checks): [ "all", "-SA1000", "-SA1001"]
|
||||
# Run `GL_DEBUG=staticcheck golangci-lint run --enable=staticcheck` to see all available checks and enabled by config checks.
|
||||
# Default: ["all", "-ST1000", "-ST1003", "-ST1016", "-ST1020", "-ST1021", "-ST1022"]
|
||||
checks:
|
||||
# Invalid regular expression.
|
||||
# https://staticcheck.dev/docs/checks/#SA1000
|
||||
- SA1000
|
||||
# Invalid template.
|
||||
# https://staticcheck.dev/docs/checks/#SA1001
|
||||
- SA1001
|
||||
# Invalid format in 'time.Parse'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1002
|
||||
- SA1002
|
||||
# Unsupported argument to functions in 'encoding/binary'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1003
|
||||
- SA1003
|
||||
# Suspiciously small untyped constant in 'time.Sleep'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1004
|
||||
- SA1004
|
||||
# Invalid first argument to 'exec.Command'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1005
|
||||
- SA1005
|
||||
# 'Printf' with dynamic first argument and no further arguments.
|
||||
# https://staticcheck.dev/docs/checks/#SA1006
|
||||
- SA1006
|
||||
# Invalid URL in 'net/url.Parse'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1007
|
||||
- SA1007
|
||||
# Non-canonical key in 'http.Header' map.
|
||||
# https://staticcheck.dev/docs/checks/#SA1008
|
||||
- SA1008
|
||||
# '(*regexp.Regexp).FindAll' called with 'n == 0', which will always return zero results.
|
||||
# https://staticcheck.dev/docs/checks/#SA1010
|
||||
- SA1010
|
||||
# Various methods in the "strings" package expect valid UTF-8, but invalid input is provided.
|
||||
# https://staticcheck.dev/docs/checks/#SA1011
|
||||
- SA1011
|
||||
# A nil 'context.Context' is being passed to a function, consider using 'context.TODO' instead.
|
||||
# https://staticcheck.dev/docs/checks/#SA1012
|
||||
- SA1012
|
||||
# 'io.Seeker.Seek' is being called with the whence constant as the first argument, but it should be the second.
|
||||
# https://staticcheck.dev/docs/checks/#SA1013
|
||||
- SA1013
|
||||
# Non-pointer value passed to 'Unmarshal' or 'Decode'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1014
|
||||
- SA1014
|
||||
# Using 'time.Tick' in a way that will leak. Consider using 'time.NewTicker', and only use 'time.Tick' in tests, commands and endless functions.
|
||||
# https://staticcheck.dev/docs/checks/#SA1015
|
||||
- SA1015
|
||||
# Trapping a signal that cannot be trapped.
|
||||
# https://staticcheck.dev/docs/checks/#SA1016
|
||||
- SA1016
|
||||
# Channels used with 'os/signal.Notify' should be buffered.
|
||||
# https://staticcheck.dev/docs/checks/#SA1017
|
||||
- SA1017
|
||||
# 'strings.Replace' called with 'n == 0', which does nothing.
|
||||
# https://staticcheck.dev/docs/checks/#SA1018
|
||||
- SA1018
|
||||
# Using a deprecated function, variable, constant or field.
|
||||
# https://staticcheck.dev/docs/checks/#SA1019
|
||||
- SA1019
|
||||
# Using an invalid host:port pair with a 'net.Listen'-related function.
|
||||
# https://staticcheck.dev/docs/checks/#SA1020
|
||||
- SA1020
|
||||
# Using 'bytes.Equal' to compare two 'net.IP'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1021
|
||||
- SA1021
|
||||
# Modifying the buffer in an 'io.Writer' implementation.
|
||||
# https://staticcheck.dev/docs/checks/#SA1023
|
||||
- SA1023
|
||||
# A string cutset contains duplicate characters.
|
||||
# https://staticcheck.dev/docs/checks/#SA1024
|
||||
- SA1024
|
||||
# It is not possible to use '(*time.Timer).Reset''s return value correctly.
|
||||
# https://staticcheck.dev/docs/checks/#SA1025
|
||||
- SA1025
|
||||
# Cannot marshal channels or functions.
|
||||
# https://staticcheck.dev/docs/checks/#SA1026
|
||||
- SA1026
|
||||
# Atomic access to 64-bit variable must be 64-bit aligned.
|
||||
# https://staticcheck.dev/docs/checks/#SA1027
|
||||
- SA1027
|
||||
# 'sort.Slice' can only be used on slices.
|
||||
# https://staticcheck.dev/docs/checks/#SA1028
|
||||
- SA1028
|
||||
# Inappropriate key in call to 'context.WithValue'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1029
|
||||
- SA1029
|
||||
# Invalid argument in call to a 'strconv' function.
|
||||
# https://staticcheck.dev/docs/checks/#SA1030
|
||||
- SA1030
|
||||
# Overlapping byte slices passed to an encoder.
|
||||
# https://staticcheck.dev/docs/checks/#SA1031
|
||||
- SA1031
|
||||
# Wrong order of arguments to 'errors.Is'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1032
|
||||
- SA1032
|
||||
# 'sync.WaitGroup.Add' called inside the goroutine, leading to a race condition.
|
||||
# https://staticcheck.dev/docs/checks/#SA2000
|
||||
- SA2000
|
||||
# Empty critical section, did you mean to defer the unlock?.
|
||||
# https://staticcheck.dev/docs/checks/#SA2001
|
||||
- SA2001
|
||||
# Called 'testing.T.FailNow' or 'SkipNow' in a goroutine, which isn't allowed.
|
||||
# https://staticcheck.dev/docs/checks/#SA2002
|
||||
- SA2002
|
||||
# Deferred 'Lock' right after locking, likely meant to defer 'Unlock' instead.
|
||||
# https://staticcheck.dev/docs/checks/#SA2003
|
||||
- SA2003
|
||||
# 'TestMain' doesn't call 'os.Exit', hiding test failures.
|
||||
# https://staticcheck.dev/docs/checks/#SA3000
|
||||
- SA3000
|
||||
# Assigning to 'b.N' in benchmarks distorts the results.
|
||||
# https://staticcheck.dev/docs/checks/#SA3001
|
||||
- SA3001
|
||||
# Binary operator has identical expressions on both sides.
|
||||
# https://staticcheck.dev/docs/checks/#SA4000
|
||||
- SA4000
|
||||
# '&*x' gets simplified to 'x', it does not copy 'x'.
|
||||
# https://staticcheck.dev/docs/checks/#SA4001
|
||||
- SA4001
|
||||
# Comparing unsigned values against negative values is pointless.
|
||||
# https://staticcheck.dev/docs/checks/#SA4003
|
||||
- SA4003
|
||||
# The loop exits unconditionally after one iteration.
|
||||
# https://staticcheck.dev/docs/checks/#SA4004
|
||||
- SA4004
|
||||
# Field assignment that will never be observed. Did you mean to use a pointer receiver?.
|
||||
# https://staticcheck.dev/docs/checks/#SA4005
|
||||
- SA4005
|
||||
# A value assigned to a variable is never read before being overwritten. Forgotten error check or dead code?.
|
||||
# https://staticcheck.dev/docs/checks/#SA4006
|
||||
- SA4006
|
||||
# The variable in the loop condition never changes, are you incrementing the wrong variable?.
|
||||
# https://staticcheck.dev/docs/checks/#SA4008
|
||||
- SA4008
|
||||
# A function argument is overwritten before its first use.
|
||||
# https://staticcheck.dev/docs/checks/#SA4009
|
||||
- SA4009
|
||||
# The result of 'append' will never be observed anywhere.
|
||||
# https://staticcheck.dev/docs/checks/#SA4010
|
||||
- SA4010
|
||||
# Break statement with no effect. Did you mean to break out of an outer loop?.
|
||||
# https://staticcheck.dev/docs/checks/#SA4011
|
||||
- SA4011
|
||||
# Comparing a value against NaN even though no value is equal to NaN.
|
||||
# https://staticcheck.dev/docs/checks/#SA4012
|
||||
- SA4012
|
||||
# Negating a boolean twice ('!!b') is the same as writing 'b'. This is either redundant, or a typo.
|
||||
# https://staticcheck.dev/docs/checks/#SA4013
|
||||
- SA4013
|
||||
# An if/else if chain has repeated conditions and no side-effects; if the condition didn't match the first time, it won't match the second time, either.
|
||||
# https://staticcheck.dev/docs/checks/#SA4014
|
||||
- SA4014
|
||||
# Calling functions like 'math.Ceil' on floats converted from integers doesn't do anything useful.
|
||||
# https://staticcheck.dev/docs/checks/#SA4015
|
||||
- SA4015
|
||||
# Certain bitwise operations, such as 'x ^ 0', do not do anything useful.
|
||||
# https://staticcheck.dev/docs/checks/#SA4016
|
||||
- SA4016
|
||||
# Discarding the return values of a function without side effects, making the call pointless.
|
||||
# https://staticcheck.dev/docs/checks/#SA4017
|
||||
- SA4017
|
||||
# Self-assignment of variables.
|
||||
# https://staticcheck.dev/docs/checks/#SA4018
|
||||
- SA4018
|
||||
# Multiple, identical build constraints in the same file.
|
||||
# https://staticcheck.dev/docs/checks/#SA4019
|
||||
- SA4019
|
||||
# Unreachable case clause in a type switch.
|
||||
# https://staticcheck.dev/docs/checks/#SA4020
|
||||
- SA4020
|
||||
# "x = append(y)" is equivalent to "x = y".
|
||||
# https://staticcheck.dev/docs/checks/#SA4021
|
||||
- SA4021
|
||||
# Comparing the address of a variable against nil.
|
||||
# https://staticcheck.dev/docs/checks/#SA4022
|
||||
- SA4022
|
||||
# Impossible comparison of interface value with untyped nil.
|
||||
# https://staticcheck.dev/docs/checks/#SA4023
|
||||
- SA4023
|
||||
# Checking for impossible return value from a builtin function.
|
||||
# https://staticcheck.dev/docs/checks/#SA4024
|
||||
- SA4024
|
||||
# Integer division of literals that results in zero.
|
||||
# https://staticcheck.dev/docs/checks/#SA4025
|
||||
- SA4025
|
||||
# Go constants cannot express negative zero.
|
||||
# https://staticcheck.dev/docs/checks/#SA4026
|
||||
- SA4026
|
||||
# '(*net/url.URL).Query' returns a copy, modifying it doesn't change the URL.
|
||||
# https://staticcheck.dev/docs/checks/#SA4027
|
||||
- SA4027
|
||||
# 'x % 1' is always zero.
|
||||
# https://staticcheck.dev/docs/checks/#SA4028
|
||||
- SA4028
|
||||
# Ineffective attempt at sorting slice.
|
||||
# https://staticcheck.dev/docs/checks/#SA4029
|
||||
- SA4029
|
||||
# Ineffective attempt at generating random number.
|
||||
# https://staticcheck.dev/docs/checks/#SA4030
|
||||
- SA4030
|
||||
# Checking never-nil value against nil.
|
||||
# https://staticcheck.dev/docs/checks/#SA4031
|
||||
- SA4031
|
||||
# Comparing 'runtime.GOOS' or 'runtime.GOARCH' against impossible value.
|
||||
# https://staticcheck.dev/docs/checks/#SA4032
|
||||
- SA4032
|
||||
# Assignment to nil map.
|
||||
# https://staticcheck.dev/docs/checks/#SA5000
|
||||
- SA5000
|
||||
# Deferring 'Close' before checking for a possible error.
|
||||
# https://staticcheck.dev/docs/checks/#SA5001
|
||||
- SA5001
|
||||
# The empty for loop ("for {}") spins and can block the scheduler.
|
||||
# https://staticcheck.dev/docs/checks/#SA5002
|
||||
- SA5002
|
||||
# Defers in infinite loops will never execute.
|
||||
# https://staticcheck.dev/docs/checks/#SA5003
|
||||
- SA5003
|
||||
# "for { select { ..." with an empty default branch spins.
|
||||
# https://staticcheck.dev/docs/checks/#SA5004
|
||||
- SA5004
|
||||
# The finalizer references the finalized object, preventing garbage collection.
|
||||
# https://staticcheck.dev/docs/checks/#SA5005
|
||||
- SA5005
|
||||
# Infinite recursive call.
|
||||
# https://staticcheck.dev/docs/checks/#SA5007
|
||||
- SA5007
|
||||
# Invalid struct tag.
|
||||
# https://staticcheck.dev/docs/checks/#SA5008
|
||||
- SA5008
|
||||
# Invalid Printf call.
|
||||
# https://staticcheck.dev/docs/checks/#SA5009
|
||||
- SA5009
|
||||
# Impossible type assertion.
|
||||
# https://staticcheck.dev/docs/checks/#SA5010
|
||||
- SA5010
|
||||
# Possible nil pointer dereference.
|
||||
# https://staticcheck.dev/docs/checks/#SA5011
|
||||
- SA5011
|
||||
# Passing odd-sized slice to function expecting even size.
|
||||
# https://staticcheck.dev/docs/checks/#SA5012
|
||||
- SA5012
|
||||
# Using 'regexp.Match' or related in a loop, should use 'regexp.Compile'.
|
||||
# https://staticcheck.dev/docs/checks/#SA6000
|
||||
- SA6000
|
||||
# Missing an optimization opportunity when indexing maps by byte slices.
|
||||
# https://staticcheck.dev/docs/checks/#SA6001
|
||||
- SA6001
|
||||
# Storing non-pointer values in 'sync.Pool' allocates memory.
|
||||
# https://staticcheck.dev/docs/checks/#SA6002
|
||||
- SA6002
|
||||
# Converting a string to a slice of runes before ranging over it.
|
||||
# https://staticcheck.dev/docs/checks/#SA6003
|
||||
- SA6003
|
||||
# Inefficient string comparison with 'strings.ToLower' or 'strings.ToUpper'.
|
||||
# https://staticcheck.dev/docs/checks/#SA6005
|
||||
- SA6005
|
||||
# Using io.WriteString to write '[]byte'.
|
||||
# https://staticcheck.dev/docs/checks/#SA6006
|
||||
- SA6006
|
||||
# Defers in range loops may not run when you expect them to.
|
||||
# https://staticcheck.dev/docs/checks/#SA9001
|
||||
- SA9001
|
||||
# Using a non-octal 'os.FileMode' that looks like it was meant to be in octal.
|
||||
# https://staticcheck.dev/docs/checks/#SA9002
|
||||
- SA9002
|
||||
# Empty body in an if or else branch.
|
||||
# https://staticcheck.dev/docs/checks/#SA9003
|
||||
- SA9003
|
||||
# Only the first constant has an explicit type.
|
||||
# https://staticcheck.dev/docs/checks/#SA9004
|
||||
- SA9004
|
||||
# Trying to marshal a struct with no public fields nor custom marshaling.
|
||||
# https://staticcheck.dev/docs/checks/#SA9005
|
||||
- SA9005
|
||||
# Dubious bit shifting of a fixed size integer value.
|
||||
# https://staticcheck.dev/docs/checks/#SA9006
|
||||
- SA9006
|
||||
# Deleting a directory that shouldn't be deleted.
|
||||
# https://staticcheck.dev/docs/checks/#SA9007
|
||||
- SA9007
|
||||
# 'else' branch of a type assertion is probably not reading the right value.
|
||||
# https://staticcheck.dev/docs/checks/#SA9008
|
||||
- SA9008
|
||||
# Ineffectual Go compiler directive.
|
||||
# https://staticcheck.dev/docs/checks/#SA9009
|
||||
- SA9009
|
||||
# Incorrect or missing package comment.
|
||||
# https://staticcheck.dev/docs/checks/#ST1000
|
||||
- ST1000
|
||||
# Dot imports are discouraged.
|
||||
# https://staticcheck.dev/docs/checks/#ST1001
|
||||
- ST1001
|
||||
# Poorly chosen identifier.
|
||||
# https://staticcheck.dev/docs/checks/#ST1003
|
||||
- ST1003
|
||||
# Incorrectly formatted error string.
|
||||
# https://staticcheck.dev/docs/checks/#ST1005
|
||||
- ST1005
|
||||
# Poorly chosen receiver name.
|
||||
# https://staticcheck.dev/docs/checks/#ST1006
|
||||
- ST1006
|
||||
# A function's error value should be its last return value.
|
||||
# https://staticcheck.dev/docs/checks/#ST1008
|
||||
- ST1008
|
||||
# Poorly chosen name for variable of type 'time.Duration'.
|
||||
# https://staticcheck.dev/docs/checks/#ST1011
|
||||
- ST1011
|
||||
# Poorly chosen name for error variable.
|
||||
# https://staticcheck.dev/docs/checks/#ST1012
|
||||
- ST1012
|
||||
# Should use constants for HTTP error codes, not magic numbers.
|
||||
# https://staticcheck.dev/docs/checks/#ST1013
|
||||
- ST1013
|
||||
# A switch's default case should be the first or last case.
|
||||
# https://staticcheck.dev/docs/checks/#ST1015
|
||||
- ST1015
|
||||
# Use consistent method receiver names.
|
||||
# https://staticcheck.dev/docs/checks/#ST1016
|
||||
- ST1016
|
||||
# Don't use Yoda conditions.
|
||||
# https://staticcheck.dev/docs/checks/#ST1017
|
||||
- ST1017
|
||||
# Avoid zero-width and control characters in string literals.
|
||||
# https://staticcheck.dev/docs/checks/#ST1018
|
||||
- ST1018
|
||||
# Importing the same package multiple times.
|
||||
# https://staticcheck.dev/docs/checks/#ST1019
|
||||
- ST1019
|
||||
# The documentation of an exported function should start with the function's name.
|
||||
# https://staticcheck.dev/docs/checks/#ST1020
|
||||
- ST1020
|
||||
# The documentation of an exported type should start with type's name.
|
||||
# https://staticcheck.dev/docs/checks/#ST1021
|
||||
- ST1021
|
||||
# The documentation of an exported variable or constant should start with variable's name.
|
||||
# https://staticcheck.dev/docs/checks/#ST1022
|
||||
- ST1022
|
||||
# Redundant type in variable declaration.
|
||||
# https://staticcheck.dev/docs/checks/#ST1023
|
||||
- ST1023
|
||||
# Use plain channel send or receive instead of single-case select.
|
||||
# https://staticcheck.dev/docs/checks/#S1000
|
||||
- S1000
|
||||
# Replace for loop with call to copy.
|
||||
# https://staticcheck.dev/docs/checks/#S1001
|
||||
- S1001
|
||||
# Omit comparison with boolean constant.
|
||||
# https://staticcheck.dev/docs/checks/#S1002
|
||||
- S1002
|
||||
# Replace call to 'strings.Index' with 'strings.Contains'.
|
||||
# https://staticcheck.dev/docs/checks/#S1003
|
||||
- S1003
|
||||
# Replace call to 'bytes.Compare' with 'bytes.Equal'.
|
||||
# https://staticcheck.dev/docs/checks/#S1004
|
||||
- S1004
|
||||
# Drop unnecessary use of the blank identifier.
|
||||
# https://staticcheck.dev/docs/checks/#S1005
|
||||
- S1005
|
||||
# Use "for { ... }" for infinite loops.
|
||||
# https://staticcheck.dev/docs/checks/#S1006
|
||||
- S1006
|
||||
# Simplify regular expression by using raw string literal.
|
||||
# https://staticcheck.dev/docs/checks/#S1007
|
||||
- S1007
|
||||
# Simplify returning boolean expression.
|
||||
# https://staticcheck.dev/docs/checks/#S1008
|
||||
- S1008
|
||||
# Omit redundant nil check on slices, maps, and channels.
|
||||
# https://staticcheck.dev/docs/checks/#S1009
|
||||
- S1009
|
||||
# Omit default slice index.
|
||||
# https://staticcheck.dev/docs/checks/#S1010
|
||||
- S1010
|
||||
# Use a single 'append' to concatenate two slices.
|
||||
# https://staticcheck.dev/docs/checks/#S1011
|
||||
- S1011
|
||||
# Replace 'time.Now().Sub(x)' with 'time.Since(x)'.
|
||||
# https://staticcheck.dev/docs/checks/#S1012
|
||||
- S1012
|
||||
# Use a type conversion instead of manually copying struct fields.
|
||||
# https://staticcheck.dev/docs/checks/#S1016
|
||||
- S1016
|
||||
# Replace manual trimming with 'strings.TrimPrefix'.
|
||||
# https://staticcheck.dev/docs/checks/#S1017
|
||||
- S1017
|
||||
# Use "copy" for sliding elements.
|
||||
# https://staticcheck.dev/docs/checks/#S1018
|
||||
- S1018
|
||||
# Simplify "make" call by omitting redundant arguments.
|
||||
# https://staticcheck.dev/docs/checks/#S1019
|
||||
- S1019
|
||||
# Omit redundant nil check in type assertion.
|
||||
# https://staticcheck.dev/docs/checks/#S1020
|
||||
- S1020
|
||||
# Merge variable declaration and assignment.
|
||||
# https://staticcheck.dev/docs/checks/#S1021
|
||||
- S1021
|
||||
# Omit redundant control flow.
|
||||
# https://staticcheck.dev/docs/checks/#S1023
|
||||
- S1023
|
||||
# Replace 'x.Sub(time.Now())' with 'time.Until(x)'.
|
||||
# https://staticcheck.dev/docs/checks/#S1024
|
||||
- S1024
|
||||
# Don't use 'fmt.Sprintf("%s", x)' unnecessarily.
|
||||
# https://staticcheck.dev/docs/checks/#S1025
|
||||
- S1025
|
||||
# Simplify error construction with 'fmt.Errorf'.
|
||||
# https://staticcheck.dev/docs/checks/#S1028
|
||||
- S1028
|
||||
# Range over the string directly.
|
||||
# https://staticcheck.dev/docs/checks/#S1029
|
||||
- S1029
|
||||
# Use 'bytes.Buffer.String' or 'bytes.Buffer.Bytes'.
|
||||
# https://staticcheck.dev/docs/checks/#S1030
|
||||
- S1030
|
||||
# Omit redundant nil check around loop.
|
||||
# https://staticcheck.dev/docs/checks/#S1031
|
||||
- S1031
|
||||
# Use 'sort.Ints(x)', 'sort.Float64s(x)', and 'sort.Strings(x)'.
|
||||
# https://staticcheck.dev/docs/checks/#S1032
|
||||
- S1032
|
||||
# Unnecessary guard around call to "delete".
|
||||
# https://staticcheck.dev/docs/checks/#S1033
|
||||
- S1033
|
||||
# Use result of type assertion to simplify cases.
|
||||
# https://staticcheck.dev/docs/checks/#S1034
|
||||
- S1034
|
||||
# Redundant call to 'net/http.CanonicalHeaderKey' in method call on 'net/http.Header'.
|
||||
# https://staticcheck.dev/docs/checks/#S1035
|
||||
- S1035
|
||||
# Unnecessary guard around map access.
|
||||
# https://staticcheck.dev/docs/checks/#S1036
|
||||
- S1036
|
||||
# Elaborate way of sleeping.
|
||||
# https://staticcheck.dev/docs/checks/#S1037
|
||||
- S1037
|
||||
# Unnecessarily complex way of printing formatted string.
|
||||
# https://staticcheck.dev/docs/checks/#S1038
|
||||
- S1038
|
||||
# Unnecessary use of 'fmt.Sprint'.
|
||||
# https://staticcheck.dev/docs/checks/#S1039
|
||||
- S1039
|
||||
# Type assertion to current type.
|
||||
# https://staticcheck.dev/docs/checks/#S1040
|
||||
- S1040
|
||||
# Apply De Morgan's law.
|
||||
# https://staticcheck.dev/docs/checks/#QF1001
|
||||
- QF1001
|
||||
# Convert untagged switch to tagged switch.
|
||||
# https://staticcheck.dev/docs/checks/#QF1002
|
||||
- QF1002
|
||||
# Convert if/else-if chain to tagged switch.
|
||||
# https://staticcheck.dev/docs/checks/#QF1003
|
||||
- QF1003
|
||||
# Use 'strings.ReplaceAll' instead of 'strings.Replace' with 'n == -1'.
|
||||
# https://staticcheck.dev/docs/checks/#QF1004
|
||||
- QF1004
|
||||
# Expand call to 'math.Pow'.
|
||||
# https://staticcheck.dev/docs/checks/#QF1005
|
||||
- QF1005
|
||||
# Lift 'if'+'break' into loop condition.
|
||||
# https://staticcheck.dev/docs/checks/#QF1006
|
||||
- QF1006
|
||||
# Merge conditional assignment into variable declaration.
|
||||
# https://staticcheck.dev/docs/checks/#QF1007
|
||||
- QF1007
|
||||
# Omit embedded fields from selector expression.
|
||||
# https://staticcheck.dev/docs/checks/#QF1008
|
||||
- QF1008
|
||||
# Use 'time.Time.Equal' instead of '==' operator.
|
||||
# https://staticcheck.dev/docs/checks/#QF1009
|
||||
- QF1009
|
||||
# Convert slice of bytes to string when printing it.
|
||||
# https://staticcheck.dev/docs/checks/#QF1010
|
||||
- QF1010
|
||||
# Omit redundant type from variable declaration.
|
||||
# https://staticcheck.dev/docs/checks/#QF1011
|
||||
- QF1011
|
||||
# Use 'fmt.Fprintf(x, ...)' instead of 'x.Write(fmt.Sprintf(...))'.
|
||||
# https://staticcheck.dev/docs/checks/#QF1012
|
||||
- QF1012
|
||||
unused:
|
||||
# Mark all struct fields that have been written to as used.
|
||||
# Default: true
|
||||
field-writes-are-uses: false
|
||||
# Treat IncDec statement (e.g. `i++` or `i--`) as both read and write operation instead of just write.
|
||||
# Default: false
|
||||
post-statements-are-reads: true
|
||||
# Mark all exported fields as used.
|
||||
# default: true
|
||||
exported-fields-are-used: false
|
||||
# Mark all function parameters as used.
|
||||
# default: true
|
||||
parameters-are-used: true
|
||||
# Mark all local variables as used.
|
||||
# default: true
|
||||
local-variables-are-used: false
|
||||
# Mark all identifiers inside generated files as used.
|
||||
# Default: true
|
||||
generated-is-used: false
|
||||
|
||||
formatters:
|
||||
enable:
|
||||
- gofmt
|
||||
settings:
|
||||
gofmt:
|
||||
# Simplify code: gofmt with `-s` option.
|
||||
# Default: true
|
||||
simplify: false
|
||||
# Apply the rewrite rules to the source before reformatting.
|
||||
# https://pkg.go.dev/cmd/gofmt
|
||||
# Default: []
|
||||
rewrite-rules:
|
||||
- pattern: 'interface{}'
|
||||
replacement: 'any'
|
||||
- pattern: 'a[b:len(a)]'
|
||||
replacement: 'a[b:]'
|
||||
@@ -1,6 +1,16 @@
|
||||
.PHONY: wire
|
||||
.PHONY: wire build build-embed
|
||||
|
||||
wire:
|
||||
@echo "生成 Wire 代码..."
|
||||
@cd cmd/server && go generate
|
||||
@echo "Wire 代码生成完成"
|
||||
@echo "Wire 代码生成完成"
|
||||
|
||||
build:
|
||||
@echo "构建后端(不嵌入前端)..."
|
||||
@go build -o bin/server ./cmd/server
|
||||
@echo "构建完成: bin/server"
|
||||
|
||||
build-embed:
|
||||
@echo "构建后端(嵌入前端)..."
|
||||
@go build -tags embed -o bin/server ./cmd/server
|
||||
@echo "构建完成: bin/server (with embedded frontend)"
|
||||
@@ -85,6 +85,14 @@ func provideCleanup(
|
||||
services.EmailQueue.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"OAuthService", func() error {
|
||||
services.OAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"OpenAIOAuthService", func() error {
|
||||
services.OpenAIOAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"Redis", func() error {
|
||||
return rdb.Close()
|
||||
}},
|
||||
|
||||
@@ -48,7 +48,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
emailQueueService := service.ProvideEmailQueueService(emailService)
|
||||
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
|
||||
authHandler := handler.NewAuthHandler(authService)
|
||||
userService := service.NewUserService(userRepository, configConfig)
|
||||
userService := service.NewUserService(userRepository)
|
||||
userHandler := handler.NewUserHandler(userService)
|
||||
apiKeyRepository := repository.NewApiKeyRepository(db)
|
||||
groupRepository := repository.NewGroupRepository(db)
|
||||
@@ -67,22 +67,25 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService)
|
||||
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
||||
dashboardHandler := admin.NewDashboardHandler(usageLogRepository)
|
||||
accountRepository := repository.NewAccountRepository(db)
|
||||
proxyRepository := repository.NewProxyRepository(db)
|
||||
proxyExitInfoProber := repository.NewProxyExitInfoProber()
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, usageLogRepository, userSubscriptionRepository, billingCacheService, proxyExitInfoProber)
|
||||
dashboardHandler := admin.NewDashboardHandler(adminService, usageLogRepository)
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber)
|
||||
adminUserHandler := admin.NewUserHandler(adminService)
|
||||
groupHandler := admin.NewGroupHandler(adminService)
|
||||
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
||||
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
||||
openAIOAuthClient := repository.NewOpenAIOAuthClient()
|
||||
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
|
||||
rateLimitService := service.NewRateLimitService(accountRepository, configConfig)
|
||||
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, oAuthService, claudeUsageFetcher)
|
||||
claudeUpstream := repository.NewClaudeUpstream(configConfig)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, claudeUpstream)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, rateLimitService, accountUsageService, accountTestService)
|
||||
oAuthHandler := admin.NewOAuthHandler(oAuthService, adminService)
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher)
|
||||
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, httpUpstream)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, rateLimitService, accountUsageService, accountTestService, usageLogRepository)
|
||||
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
||||
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
||||
proxyHandler := admin.NewProxyHandler(adminService)
|
||||
adminRedeemHandler := admin.NewRedeemHandler(adminService)
|
||||
settingHandler := admin.NewSettingHandler(settingService, emailService)
|
||||
@@ -93,7 +96,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
systemHandler := handler.ProvideSystemHandler(updateService)
|
||||
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
|
||||
adminUsageHandler := admin.NewUsageHandler(usageLogRepository, apiKeyRepository, usageService, adminService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
|
||||
gatewayCache := repository.NewGatewayCache(client)
|
||||
pricingRemoteClient := repository.NewPricingRemoteClient()
|
||||
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
|
||||
@@ -103,43 +106,47 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
billingService := service.NewBillingService(configConfig, pricingService)
|
||||
identityCache := repository.NewIdentityCache(client)
|
||||
identityService := service.NewIdentityService(identityCache)
|
||||
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService, claudeUpstream)
|
||||
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream)
|
||||
concurrencyCache := repository.NewConcurrencyCache(client)
|
||||
concurrencyService := service.NewConcurrencyService(concurrencyCache)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, configConfig)
|
||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream)
|
||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService)
|
||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, handlerSettingHandler)
|
||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
|
||||
groupService := service.NewGroupService(groupRepository)
|
||||
accountService := service.NewAccountService(accountRepository, groupRepository)
|
||||
proxyService := service.NewProxyService(proxyRepository)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, configConfig)
|
||||
services := &service.Services{
|
||||
Auth: authService,
|
||||
User: userService,
|
||||
ApiKey: apiKeyService,
|
||||
Group: groupService,
|
||||
Account: accountService,
|
||||
Proxy: proxyService,
|
||||
Redeem: redeemService,
|
||||
Usage: usageService,
|
||||
Pricing: pricingService,
|
||||
Billing: billingService,
|
||||
BillingCache: billingCacheService,
|
||||
Admin: adminService,
|
||||
Gateway: gatewayService,
|
||||
OAuth: oAuthService,
|
||||
RateLimit: rateLimitService,
|
||||
AccountUsage: accountUsageService,
|
||||
AccountTest: accountTestService,
|
||||
Setting: settingService,
|
||||
Email: emailService,
|
||||
EmailQueue: emailQueueService,
|
||||
Turnstile: turnstileService,
|
||||
Subscription: subscriptionService,
|
||||
Concurrency: concurrencyService,
|
||||
Identity: identityService,
|
||||
Update: updateService,
|
||||
TokenRefresh: tokenRefreshService,
|
||||
Auth: authService,
|
||||
User: userService,
|
||||
ApiKey: apiKeyService,
|
||||
Group: groupService,
|
||||
Account: accountService,
|
||||
Proxy: proxyService,
|
||||
Redeem: redeemService,
|
||||
Usage: usageService,
|
||||
Pricing: pricingService,
|
||||
Billing: billingService,
|
||||
BillingCache: billingCacheService,
|
||||
Admin: adminService,
|
||||
Gateway: gatewayService,
|
||||
OpenAIGateway: openAIGatewayService,
|
||||
OAuth: oAuthService,
|
||||
OpenAIOAuth: openAIOAuthService,
|
||||
RateLimit: rateLimitService,
|
||||
AccountUsage: accountUsageService,
|
||||
AccountTest: accountTestService,
|
||||
Setting: settingService,
|
||||
Email: emailService,
|
||||
EmailQueue: emailQueueService,
|
||||
Turnstile: turnstileService,
|
||||
Subscription: subscriptionService,
|
||||
Concurrency: concurrencyService,
|
||||
Identity: identityService,
|
||||
Update: updateService,
|
||||
TokenRefresh: tokenRefreshService,
|
||||
}
|
||||
repositories := &repository.Repositories{
|
||||
User: userRepository,
|
||||
@@ -201,6 +208,14 @@ func provideCleanup(
|
||||
services.EmailQueue.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"OAuthService", func() error {
|
||||
services.OAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"OpenAIOAuthService", func() error {
|
||||
services.OpenAIOAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"Redis", func() error {
|
||||
return rdb.Close()
|
||||
}},
|
||||
|
||||
@@ -52,7 +52,7 @@ type PricingConfig struct {
|
||||
type ServerConfig struct {
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
Mode string `mapstructure:"mode"` // debug/release
|
||||
Mode string `mapstructure:"mode"` // debug/release
|
||||
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
|
||||
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
|
||||
}
|
||||
@@ -163,7 +163,7 @@ func setDefaults() {
|
||||
viper.SetDefault("server.port", 8080)
|
||||
viper.SetDefault("server.mode", "debug")
|
||||
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
|
||||
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
|
||||
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
|
||||
|
||||
// Database
|
||||
viper.SetDefault("database.host", "localhost")
|
||||
@@ -210,10 +210,10 @@ func setDefaults() {
|
||||
|
||||
// TokenRefresh
|
||||
viper.SetDefault("token_refresh.enabled", true)
|
||||
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
|
||||
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
|
||||
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 1.5) // 提前1.5小时刷新
|
||||
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
|
||||
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
|
||||
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
|
||||
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
|
||||
}
|
||||
|
||||
func (c *Config) Validate() error {
|
||||
|
||||
@@ -4,7 +4,10 @@ import (
|
||||
"strconv"
|
||||
|
||||
"sub2api/internal/pkg/claude"
|
||||
"sub2api/internal/pkg/openai"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -13,14 +16,12 @@ import (
|
||||
// OAuthHandler handles OAuth-related operations for accounts
|
||||
type OAuthHandler struct {
|
||||
oauthService *service.OAuthService
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewOAuthHandler creates a new OAuth handler
|
||||
func NewOAuthHandler(oauthService *service.OAuthService, adminService service.AdminService) *OAuthHandler {
|
||||
func NewOAuthHandler(oauthService *service.OAuthService) *OAuthHandler {
|
||||
return &OAuthHandler{
|
||||
oauthService: oauthService,
|
||||
adminService: adminService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,47 +29,51 @@ func NewOAuthHandler(oauthService *service.OAuthService, adminService service.Ad
|
||||
type AccountHandler struct {
|
||||
adminService service.AdminService
|
||||
oauthService *service.OAuthService
|
||||
openaiOAuthService *service.OpenAIOAuthService
|
||||
rateLimitService *service.RateLimitService
|
||||
accountUsageService *service.AccountUsageService
|
||||
accountTestService *service.AccountTestService
|
||||
usageLogRepo *repository.UsageLogRepository
|
||||
}
|
||||
|
||||
// NewAccountHandler creates a new admin account handler
|
||||
func NewAccountHandler(adminService service.AdminService, oauthService *service.OAuthService, rateLimitService *service.RateLimitService, accountUsageService *service.AccountUsageService, accountTestService *service.AccountTestService) *AccountHandler {
|
||||
func NewAccountHandler(adminService service.AdminService, oauthService *service.OAuthService, openaiOAuthService *service.OpenAIOAuthService, rateLimitService *service.RateLimitService, accountUsageService *service.AccountUsageService, accountTestService *service.AccountTestService, usageLogRepo *repository.UsageLogRepository) *AccountHandler {
|
||||
return &AccountHandler{
|
||||
adminService: adminService,
|
||||
oauthService: oauthService,
|
||||
openaiOAuthService: openaiOAuthService,
|
||||
rateLimitService: rateLimitService,
|
||||
accountUsageService: accountUsageService,
|
||||
accountTestService: accountTestService,
|
||||
usageLogRepo: usageLogRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateAccountRequest represents create account request
|
||||
type CreateAccountRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Platform string `json:"platform" binding:"required"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"`
|
||||
Credentials map[string]interface{} `json:"credentials" binding:"required"`
|
||||
Extra map[string]interface{} `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
Platform string `json:"platform" binding:"required"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"`
|
||||
Credentials map[string]any `json:"credentials" binding:"required"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
}
|
||||
|
||||
// UpdateAccountRequest represents update account request
|
||||
// 使用指针类型来区分"未提供"和"设置为0"
|
||||
type UpdateAccountRequest struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
|
||||
Credentials map[string]interface{} `json:"credentials"`
|
||||
Extra map[string]interface{} `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
}
|
||||
|
||||
// List handles listing all accounts with pagination
|
||||
@@ -234,26 +239,47 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Use OAuth service to refresh token
|
||||
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
|
||||
return
|
||||
}
|
||||
var newCredentials map[string]any
|
||||
|
||||
// Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
|
||||
newCredentials := make(map[string]interface{})
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
if account.IsOpenAI() {
|
||||
// Use OpenAI OAuth service to refresh token
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Update token-related fields
|
||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||
newCredentials["token_type"] = tokenInfo.TokenType
|
||||
newCredentials["expires_in"] = tokenInfo.ExpiresIn
|
||||
newCredentials["expires_at"] = tokenInfo.ExpiresAt
|
||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
newCredentials["scope"] = tokenInfo.Scope
|
||||
// Build new credentials from token info
|
||||
newCredentials = h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
|
||||
// Preserve non-token settings from existing credentials
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Use Anthropic/Claude OAuth service to refresh token
|
||||
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
|
||||
newCredentials = make(map[string]any)
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
|
||||
// Update token-related fields
|
||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||
newCredentials["token_type"] = tokenInfo.TokenType
|
||||
newCredentials["expires_in"] = tokenInfo.ExpiresIn
|
||||
newCredentials["expires_at"] = tokenInfo.ExpiresAt
|
||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
newCredentials["scope"] = tokenInfo.Scope
|
||||
}
|
||||
|
||||
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
Credentials: newCredentials,
|
||||
@@ -275,15 +301,26 @@ func (h *AccountHandler) GetStats(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Return mock data for now
|
||||
_ = accountID
|
||||
response.Success(c, gin.H{
|
||||
"total_requests": 0,
|
||||
"successful_requests": 0,
|
||||
"failed_requests": 0,
|
||||
"total_tokens": 0,
|
||||
"average_response_time": 0,
|
||||
})
|
||||
// Parse days parameter (default 30)
|
||||
days := 30
|
||||
if daysStr := c.Query("days"); daysStr != "" {
|
||||
if d, err := strconv.Atoi(daysStr); err == nil && d > 0 && d <= 90 {
|
||||
days = d
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate time range
|
||||
now := timezone.Now()
|
||||
endTime := timezone.StartOfDay(now.AddDate(0, 0, 1))
|
||||
startTime := timezone.StartOfDay(now.AddDate(0, 0, -days+1))
|
||||
|
||||
stats, err := h.usageLogRepo.GetAccountUsageStats(c.Request.Context(), accountID, startTime, endTime)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get account stats: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, stats)
|
||||
}
|
||||
|
||||
// ClearError handles clearing account error
|
||||
@@ -565,6 +602,46 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Handle OpenAI accounts
|
||||
if account.IsOpenAI() {
|
||||
// For OAuth accounts: return default OpenAI models
|
||||
if account.IsOAuth() {
|
||||
response.Success(c, openai.DefaultModels)
|
||||
return
|
||||
}
|
||||
|
||||
// For API Key accounts: check model_mapping
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) == 0 {
|
||||
response.Success(c, openai.DefaultModels)
|
||||
return
|
||||
}
|
||||
|
||||
// Return mapped models
|
||||
var models []openai.Model
|
||||
for requestedModel := range mapping {
|
||||
var found bool
|
||||
for _, dm := range openai.DefaultModels {
|
||||
if dm.ID == requestedModel {
|
||||
models = append(models, dm)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
models = append(models, openai.Model{
|
||||
ID: requestedModel,
|
||||
Object: "model",
|
||||
Type: "model",
|
||||
DisplayName: requestedModel,
|
||||
})
|
||||
}
|
||||
}
|
||||
response.Success(c, models)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle Claude/Anthropic accounts
|
||||
// For OAuth and Setup-Token accounts: return default models
|
||||
if account.IsOAuth() {
|
||||
response.Success(c, claude.DefaultModels)
|
||||
@@ -573,7 +650,7 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
|
||||
// For API Key accounts: return models based on model_mapping
|
||||
mapping := account.GetModelMapping()
|
||||
if mapping == nil || len(mapping) == 0 {
|
||||
if len(mapping) == 0 {
|
||||
// No mapping configured, return default models
|
||||
response.Success(c, claude.DefaultModels)
|
||||
return
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -13,17 +12,15 @@ import (
|
||||
|
||||
// DashboardHandler handles admin dashboard statistics
|
||||
type DashboardHandler struct {
|
||||
adminService service.AdminService
|
||||
usageRepo *repository.UsageLogRepository
|
||||
startTime time.Time // Server start time for uptime calculation
|
||||
usageRepo *repository.UsageLogRepository
|
||||
startTime time.Time // Server start time for uptime calculation
|
||||
}
|
||||
|
||||
// NewDashboardHandler creates a new admin dashboard handler
|
||||
func NewDashboardHandler(adminService service.AdminService, usageRepo *repository.UsageLogRepository) *DashboardHandler {
|
||||
func NewDashboardHandler(usageRepo *repository.UsageLogRepository) *DashboardHandler {
|
||||
return &DashboardHandler{
|
||||
adminService: adminService,
|
||||
usageRepo: usageRepo,
|
||||
startTime: time.Now(),
|
||||
usageRepo: usageRepo,
|
||||
startTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -178,7 +175,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := h.usageRepo.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID)
|
||||
stats, err := h.usageRepo.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, 0)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get model statistics")
|
||||
return
|
||||
@@ -258,7 +255,7 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||
}
|
||||
|
||||
if len(req.UserIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]interface{}{}})
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -286,7 +283,7 @@ func (h *DashboardHandler) GetBatchApiKeysUsage(c *gin.Context) {
|
||||
}
|
||||
|
||||
if len(req.ApiKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]interface{}{}})
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
228
backend/internal/handler/admin/openai_oauth_handler.go
Normal file
228
backend/internal/handler/admin/openai_oauth_handler.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// OpenAIOAuthHandler handles OpenAI OAuth-related operations
|
||||
type OpenAIOAuthHandler struct {
|
||||
openaiOAuthService *service.OpenAIOAuthService
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewOpenAIOAuthHandler creates a new OpenAI OAuth handler
|
||||
func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler {
|
||||
return &OpenAIOAuthHandler{
|
||||
openaiOAuthService: openaiOAuthService,
|
||||
adminService: adminService,
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAIGenerateAuthURLRequest represents the request for generating OpenAI auth URL
|
||||
type OpenAIGenerateAuthURLRequest struct {
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
}
|
||||
|
||||
// GenerateAuthURL generates OpenAI OAuth authorization URL
|
||||
// POST /api/v1/admin/openai/generate-auth-url
|
||||
func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||
var req OpenAIGenerateAuthURLRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
// Allow empty body
|
||||
req = OpenAIGenerateAuthURLRequest{}
|
||||
}
|
||||
|
||||
result, err := h.openaiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to generate auth URL: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// OpenAIExchangeCodeRequest represents the request for exchanging OpenAI auth code
|
||||
type OpenAIExchangeCodeRequest struct {
|
||||
SessionID string `json:"session_id" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
|
||||
// ExchangeCode exchanges OpenAI authorization code for tokens
|
||||
// POST /api/v1/admin/openai/exchange-code
|
||||
func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
|
||||
var req OpenAIExchangeCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
|
||||
SessionID: req.SessionID,
|
||||
Code: req.Code,
|
||||
RedirectURI: req.RedirectURI,
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to exchange code: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
// OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token
|
||||
type OpenAIRefreshTokenRequest struct {
|
||||
RefreshToken string `json:"refresh_token" binding:"required"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an OpenAI OAuth token
|
||||
// POST /api/v1/admin/openai/refresh-token
|
||||
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
||||
var req OpenAIRefreshTokenRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
if req.ProxyID != nil {
|
||||
proxy, err := h.adminService.GetProxy(c.Request.Context(), *req.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to refresh token: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for a specific OpenAI account
|
||||
// POST /api/v1/admin/openai/accounts/:id/refresh
|
||||
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Get account
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Account not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure account is OpenAI platform
|
||||
if !account.IsOpenAI() {
|
||||
response.BadRequest(c, "Account is not an OpenAI account")
|
||||
return
|
||||
}
|
||||
|
||||
// Only refresh OAuth-based accounts
|
||||
if !account.IsOAuth() {
|
||||
response.BadRequest(c, "Cannot refresh non-OAuth account credentials")
|
||||
return
|
||||
}
|
||||
|
||||
// Use OpenAI OAuth service to refresh token
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Build new credentials from token info
|
||||
newCredentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
|
||||
// Preserve non-token settings from existing credentials
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
Credentials: newCredentials,
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update account credentials: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, updatedAccount)
|
||||
}
|
||||
|
||||
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info
|
||||
// POST /api/v1/admin/openai/create-from-oauth
|
||||
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||
var req struct {
|
||||
SessionID string `json:"session_id" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Name string `json:"name"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Exchange code for tokens
|
||||
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
|
||||
SessionID: req.SessionID,
|
||||
Code: req.Code,
|
||||
RedirectURI: req.RedirectURI,
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to exchange code: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Build credentials from token info
|
||||
credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
|
||||
// Use email as default name if not provided
|
||||
name := req.Name
|
||||
if name == "" && tokenInfo.Email != "" {
|
||||
name = tokenInfo.Email
|
||||
}
|
||||
if name == "" {
|
||||
name = "OpenAI OAuth Account"
|
||||
}
|
||||
|
||||
// Create account
|
||||
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
|
||||
Name: name,
|
||||
Platform: "openai",
|
||||
Type: "oauth",
|
||||
Credentials: credentials,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
GroupIDs: req.GroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to create account: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, account)
|
||||
}
|
||||
@@ -236,7 +236,6 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
|
||||
response.Paginated(c, accounts, total, page, pageSize)
|
||||
}
|
||||
|
||||
|
||||
// BatchCreateProxyItem represents a single proxy in batch create request
|
||||
type BatchCreateProxyItem struct {
|
||||
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
|
||||
|
||||
@@ -156,10 +156,10 @@ func (h *RedeemHandler) Expire(c *gin.Context) {
|
||||
func (h *RedeemHandler) GetStats(c *gin.Context) {
|
||||
// Return mock data for now
|
||||
response.Success(c, gin.H{
|
||||
"total_codes": 0,
|
||||
"active_codes": 0,
|
||||
"used_codes": 0,
|
||||
"expired_codes": 0,
|
||||
"total_codes": 0,
|
||||
"active_codes": 0,
|
||||
"used_codes": 0,
|
||||
"expired_codes": 0,
|
||||
"total_value_distributed": 0.0,
|
||||
"by_type": gin.H{
|
||||
"balance": 0,
|
||||
@@ -187,7 +187,10 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
||||
writer := csv.NewWriter(&buf)
|
||||
|
||||
// Write header
|
||||
writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"})
|
||||
if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"}); err != nil {
|
||||
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Write data rows
|
||||
for _, code := range codes {
|
||||
@@ -199,7 +202,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
||||
if code.UsedAt != nil {
|
||||
usedAt = code.UsedAt.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
writer.Write([]string{
|
||||
if err := writer.Write([]string{
|
||||
fmt.Sprintf("%d", code.ID),
|
||||
code.Code,
|
||||
code.Type,
|
||||
@@ -208,10 +211,17 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
||||
usedBy,
|
||||
usedAt,
|
||||
code.CreatedAt.Format("2006-01-02 15:04:05"),
|
||||
})
|
||||
}); err != nil {
|
||||
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
writer.Flush()
|
||||
if err := writer.Error(); err != nil {
|
||||
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/csv")
|
||||
c.Header("Content-Disposition", "attachment; filename=redeem_codes.csv")
|
||||
|
||||
@@ -256,3 +256,43 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
|
||||
|
||||
response.Success(c, gin.H{"message": "Test email sent successfully"})
|
||||
}
|
||||
|
||||
// GetAdminApiKey 获取管理员 API Key 状态
|
||||
// GET /api/v1/admin/settings/admin-api-key
|
||||
func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
|
||||
maskedKey, exists, err := h.settingService.GetAdminApiKeyStatus(c.Request.Context())
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get admin API key status: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"exists": exists,
|
||||
"masked_key": maskedKey,
|
||||
})
|
||||
}
|
||||
|
||||
// RegenerateAdminApiKey 生成/重新生成管理员 API Key
|
||||
// POST /api/v1/admin/settings/admin-api-key/regenerate
|
||||
func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
|
||||
key, err := h.settingService.GenerateAdminApiKey(c.Request.Context())
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to generate admin API key: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"key": key, // 完整 key 只在生成时返回一次
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteAdminApiKey 删除管理员 API Key
|
||||
// DELETE /api/v1/admin/settings/admin-api-key
|
||||
func (h *SettingHandler) DeleteAdminApiKey(c *gin.Context) {
|
||||
if err := h.settingService.DeleteAdminApiKey(c.Request.Context()); err != nil {
|
||||
response.InternalError(c, "Failed to delete admin API key: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Admin API key deleted"})
|
||||
}
|
||||
|
||||
@@ -193,7 +193,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
func (h *UsageHandler) SearchUsers(c *gin.Context) {
|
||||
keyword := c.Query("q")
|
||||
if keyword == "" {
|
||||
response.Success(c, []interface{}{})
|
||||
response.Success(c, []any{})
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -25,6 +25,9 @@ func NewUserHandler(adminService service.AdminService) *UserHandler {
|
||||
type CreateUserRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
Username string `json:"username"`
|
||||
Wechat string `json:"wechat"`
|
||||
Notes string `json:"notes"`
|
||||
Balance float64 `json:"balance"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
AllowedGroups []int64 `json:"allowed_groups"`
|
||||
@@ -35,6 +38,9 @@ type CreateUserRequest struct {
|
||||
type UpdateUserRequest struct {
|
||||
Email string `json:"email" binding:"omitempty,email"`
|
||||
Password string `json:"password" binding:"omitempty,min=6"`
|
||||
Username *string `json:"username"`
|
||||
Wechat *string `json:"wechat"`
|
||||
Notes *string `json:"notes"`
|
||||
Balance *float64 `json:"balance"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
|
||||
@@ -43,8 +49,9 @@ type UpdateUserRequest struct {
|
||||
|
||||
// UpdateBalanceRequest represents balance update request
|
||||
type UpdateBalanceRequest struct {
|
||||
Balance float64 `json:"balance" binding:"required"`
|
||||
Balance float64 `json:"balance" binding:"required,gt=0"`
|
||||
Operation string `json:"operation" binding:"required,oneof=set add subtract"`
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
// List handles listing all users with pagination
|
||||
@@ -94,6 +101,9 @@ func (h *UserHandler) Create(c *gin.Context) {
|
||||
user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
Username: req.Username,
|
||||
Wechat: req.Wechat,
|
||||
Notes: req.Notes,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
@@ -125,6 +135,9 @@ func (h *UserHandler) Update(c *gin.Context) {
|
||||
user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
Username: req.Username,
|
||||
Wechat: req.Wechat,
|
||||
Notes: req.Notes,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
Status: req.Status,
|
||||
@@ -171,7 +184,7 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation)
|
||||
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update balance: "+err.Error())
|
||||
return
|
||||
|
||||
@@ -13,24 +13,18 @@ import (
|
||||
"sub2api/internal/middleware"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/claude"
|
||||
"sub2api/internal/pkg/openai"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
// Maximum wait time for concurrency slot
|
||||
maxConcurrencyWait = 60 * time.Second
|
||||
// Ping interval during wait
|
||||
pingInterval = 5 * time.Second
|
||||
)
|
||||
|
||||
// GatewayHandler handles API gateway requests
|
||||
type GatewayHandler struct {
|
||||
gatewayService *service.GatewayService
|
||||
userService *service.UserService
|
||||
concurrencyService *service.ConcurrencyService
|
||||
billingCacheService *service.BillingCacheService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
}
|
||||
|
||||
// NewGatewayHandler creates a new GatewayHandler
|
||||
@@ -38,8 +32,8 @@ func NewGatewayHandler(gatewayService *service.GatewayService, userService *serv
|
||||
return &GatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
userService: userService,
|
||||
concurrencyService: concurrencyService,
|
||||
billingCacheService: billingCacheService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,7 +83,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
// 0. 检查wait队列是否已满
|
||||
maxWait := service.CalculateMaxWait(user.Concurrency)
|
||||
canWait, err := h.concurrencyService.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
|
||||
if err != nil {
|
||||
log.Printf("Increment wait count failed: %v", err)
|
||||
// On error, allow request to proceed
|
||||
@@ -98,10 +92,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
// 确保在函数退出时减少wait计数
|
||||
defer h.concurrencyService.DecrementWaitCount(c.Request.Context(), user.ID)
|
||||
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), user.ID)
|
||||
|
||||
// 1. 首先获取用户并发槽位
|
||||
userReleaseFunc, err := h.acquireUserSlotWithWait(c, user, req.Stream, &streamStarted)
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, user, req.Stream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("User concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
@@ -139,7 +133,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
accountReleaseFunc, err := h.acquireAccountSlotWithWait(c, account, req.Stream, &streamStarted)
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account, req.Stream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
@@ -173,133 +167,25 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}()
|
||||
}
|
||||
|
||||
// acquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary
|
||||
// For streaming requests, sends ping events during the wait
|
||||
// streamStarted is updated if streaming response has begun
|
||||
func (h *GatewayHandler) acquireUserSlotWithWait(c *gin.Context, user *model.User, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Try to acquire immediately
|
||||
result, err := h.concurrencyService.AcquireUserSlot(ctx, user.ID, user.Concurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
|
||||
// Need to wait - handle streaming ping if needed
|
||||
return h.waitForSlotWithPing(c, "user", user.ID, user.Concurrency, isStream, streamStarted)
|
||||
}
|
||||
|
||||
// acquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary
|
||||
// For streaming requests, sends ping events during the wait
|
||||
// streamStarted is updated if streaming response has begun
|
||||
func (h *GatewayHandler) acquireAccountSlotWithWait(c *gin.Context, account *model.Account, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Try to acquire immediately
|
||||
result, err := h.concurrencyService.AcquireAccountSlot(ctx, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
|
||||
// Need to wait - handle streaming ping if needed
|
||||
return h.waitForSlotWithPing(c, "account", account.ID, account.Concurrency, isStream, streamStarted)
|
||||
}
|
||||
|
||||
// concurrencyError represents a concurrency limit error with context
|
||||
type concurrencyError struct {
|
||||
SlotType string
|
||||
IsTimeout bool
|
||||
}
|
||||
|
||||
func (e *concurrencyError) Error() string {
|
||||
if e.IsTimeout {
|
||||
return fmt.Sprintf("timeout waiting for %s concurrency slot", e.SlotType)
|
||||
}
|
||||
return fmt.Sprintf("%s concurrency limit reached", e.SlotType)
|
||||
}
|
||||
|
||||
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests
|
||||
// Note: For streaming requests, we send ping to keep the connection alive.
|
||||
// streamStarted pointer is updated when streaming begins (for proper error handling by caller)
|
||||
func (h *GatewayHandler) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait)
|
||||
defer cancel()
|
||||
|
||||
// For streaming requests, set up SSE headers for ping
|
||||
var flusher http.Flusher
|
||||
if isStream {
|
||||
var ok bool
|
||||
flusher, ok = c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("streaming not supported")
|
||||
}
|
||||
}
|
||||
|
||||
pingTicker := time.NewTicker(pingInterval)
|
||||
defer pingTicker.Stop()
|
||||
|
||||
pollTicker := time.NewTicker(100 * time.Millisecond)
|
||||
defer pollTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, &concurrencyError{
|
||||
SlotType: slotType,
|
||||
IsTimeout: true,
|
||||
}
|
||||
|
||||
case <-pingTicker.C:
|
||||
// Send ping for streaming requests to keep connection alive
|
||||
if isStream && flusher != nil {
|
||||
// Set headers on first ping (lazy initialization)
|
||||
if !*streamStarted {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
*streamStarted = true
|
||||
}
|
||||
fmt.Fprintf(c.Writer, "data: {\"type\": \"ping\"}\n\n")
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
case <-pollTicker.C:
|
||||
// Try to acquire slot
|
||||
var result *service.AcquireResult
|
||||
var err error
|
||||
|
||||
if slotType == "user" {
|
||||
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
|
||||
} else {
|
||||
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Models handles listing available models
|
||||
// GET /v1/models
|
||||
// Returns different model lists based on the API key's group platform
|
||||
func (h *GatewayHandler) Models(c *gin.Context) {
|
||||
apiKey, _ := middleware.GetApiKeyFromContext(c)
|
||||
|
||||
// Return OpenAI models for OpenAI platform groups
|
||||
if apiKey != nil && apiKey.Group != nil && apiKey.Group.Platform == "openai" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"object": "list",
|
||||
"data": openai.DefaultModels,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Default: Claude models
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"data": claude.DefaultModels,
|
||||
"object": "list",
|
||||
"data": claude.DefaultModels,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -414,7 +300,9 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
|
||||
if ok {
|
||||
// Send error event in SSE format
|
||||
errorEvent := fmt.Sprintf(`data: {"type": "error", "error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
|
||||
fmt.Fprint(c.Writer, errorEvent)
|
||||
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||||
_ = c.Error(err)
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
return
|
||||
@@ -574,11 +462,11 @@ func sendMockWarmupStream(c *gin.Context, model string) {
|
||||
// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截)
|
||||
func sendMockWarmupResponse(c *gin.Context, model string) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"id": "msg_mock_warmup",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": []gin.H{{"type": "text", "text": "New Conversation"}},
|
||||
"id": "msg_mock_warmup",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": []gin.H{{"type": "text", "text": "New Conversation"}},
|
||||
"stop_reason": "end_turn",
|
||||
"usage": gin.H{
|
||||
"input_tokens": 10,
|
||||
|
||||
180
backend/internal/handler/gateway_helper.go
Normal file
180
backend/internal/handler/gateway_helper.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxConcurrencyWait is the maximum time to wait for a concurrency slot
|
||||
maxConcurrencyWait = 30 * time.Second
|
||||
// pingInterval is the interval for sending ping events during slot wait
|
||||
pingInterval = 15 * time.Second
|
||||
)
|
||||
|
||||
// SSEPingFormat defines the format of SSE ping events for different platforms
|
||||
type SSEPingFormat string
|
||||
|
||||
const (
|
||||
// SSEPingFormatClaude is the Claude/Anthropic SSE ping format
|
||||
SSEPingFormatClaude SSEPingFormat = "data: {\"type\": \"ping\"}\n\n"
|
||||
// SSEPingFormatNone indicates no ping should be sent (e.g., OpenAI has no ping spec)
|
||||
SSEPingFormatNone SSEPingFormat = ""
|
||||
)
|
||||
|
||||
// ConcurrencyError represents a concurrency limit error with context
|
||||
type ConcurrencyError struct {
|
||||
SlotType string
|
||||
IsTimeout bool
|
||||
}
|
||||
|
||||
func (e *ConcurrencyError) Error() string {
|
||||
if e.IsTimeout {
|
||||
return fmt.Sprintf("timeout waiting for %s concurrency slot", e.SlotType)
|
||||
}
|
||||
return fmt.Sprintf("%s concurrency limit reached", e.SlotType)
|
||||
}
|
||||
|
||||
// ConcurrencyHelper provides common concurrency slot management for gateway handlers
|
||||
type ConcurrencyHelper struct {
|
||||
concurrencyService *service.ConcurrencyService
|
||||
pingFormat SSEPingFormat
|
||||
}
|
||||
|
||||
// NewConcurrencyHelper creates a new ConcurrencyHelper
|
||||
func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat) *ConcurrencyHelper {
|
||||
return &ConcurrencyHelper{
|
||||
concurrencyService: concurrencyService,
|
||||
pingFormat: pingFormat,
|
||||
}
|
||||
}
|
||||
|
||||
// IncrementWaitCount increments the wait count for a user
|
||||
func (h *ConcurrencyHelper) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
return h.concurrencyService.IncrementWaitCount(ctx, userID, maxWait)
|
||||
}
|
||||
|
||||
// DecrementWaitCount decrements the wait count for a user
|
||||
func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64) {
|
||||
h.concurrencyService.DecrementWaitCount(ctx, userID)
|
||||
}
|
||||
|
||||
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
|
||||
// For streaming requests, sends ping events during the wait.
|
||||
// streamStarted is updated if streaming response has begun.
|
||||
func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, user *model.User, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Try to acquire immediately
|
||||
result, err := h.concurrencyService.AcquireUserSlot(ctx, user.ID, user.Concurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
|
||||
// Need to wait - handle streaming ping if needed
|
||||
return h.waitForSlotWithPing(c, "user", user.ID, user.Concurrency, isStream, streamStarted)
|
||||
}
|
||||
|
||||
// AcquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary.
|
||||
// For streaming requests, sends ping events during the wait.
|
||||
// streamStarted is updated if streaming response has begun.
|
||||
func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, account *model.Account, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Try to acquire immediately
|
||||
result, err := h.concurrencyService.AcquireAccountSlot(ctx, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
|
||||
// Need to wait - handle streaming ping if needed
|
||||
return h.waitForSlotWithPing(c, "account", account.ID, account.Concurrency, isStream, streamStarted)
|
||||
}
|
||||
|
||||
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
|
||||
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
|
||||
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait)
|
||||
defer cancel()
|
||||
|
||||
// Determine if ping is needed (streaming + ping format defined)
|
||||
needPing := isStream && h.pingFormat != ""
|
||||
|
||||
var flusher http.Flusher
|
||||
if needPing {
|
||||
var ok bool
|
||||
flusher, ok = c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("streaming not supported")
|
||||
}
|
||||
}
|
||||
|
||||
// Only create ping ticker if ping is needed
|
||||
var pingCh <-chan time.Time
|
||||
if needPing {
|
||||
pingTicker := time.NewTicker(pingInterval)
|
||||
defer pingTicker.Stop()
|
||||
pingCh = pingTicker.C
|
||||
}
|
||||
|
||||
pollTicker := time.NewTicker(100 * time.Millisecond)
|
||||
defer pollTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, &ConcurrencyError{
|
||||
SlotType: slotType,
|
||||
IsTimeout: true,
|
||||
}
|
||||
|
||||
case <-pingCh:
|
||||
// Send ping to keep connection alive
|
||||
if !*streamStarted {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
*streamStarted = true
|
||||
}
|
||||
if _, err := fmt.Fprint(c.Writer, string(h.pingFormat)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
case <-pollTicker.C:
|
||||
// Try to acquire slot
|
||||
var result *service.AcquireResult
|
||||
var err error
|
||||
|
||||
if slotType == "user" {
|
||||
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
|
||||
} else {
|
||||
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -11,6 +11,7 @@ type AdminHandlers struct {
|
||||
Group *admin.GroupHandler
|
||||
Account *admin.AccountHandler
|
||||
OAuth *admin.OAuthHandler
|
||||
OpenAIOAuth *admin.OpenAIOAuthHandler
|
||||
Proxy *admin.ProxyHandler
|
||||
Redeem *admin.RedeemHandler
|
||||
Setting *admin.SettingHandler
|
||||
@@ -21,15 +22,16 @@ type AdminHandlers struct {
|
||||
|
||||
// Handlers contains all HTTP handlers
|
||||
type Handlers struct {
|
||||
Auth *AuthHandler
|
||||
User *UserHandler
|
||||
APIKey *APIKeyHandler
|
||||
Usage *UsageHandler
|
||||
Redeem *RedeemHandler
|
||||
Subscription *SubscriptionHandler
|
||||
Admin *AdminHandlers
|
||||
Gateway *GatewayHandler
|
||||
Setting *SettingHandler
|
||||
Auth *AuthHandler
|
||||
User *UserHandler
|
||||
APIKey *APIKeyHandler
|
||||
Usage *UsageHandler
|
||||
Redeem *RedeemHandler
|
||||
Subscription *SubscriptionHandler
|
||||
Admin *AdminHandlers
|
||||
Gateway *GatewayHandler
|
||||
OpenAIGateway *OpenAIGatewayHandler
|
||||
Setting *SettingHandler
|
||||
}
|
||||
|
||||
// BuildInfo contains build-time information
|
||||
|
||||
209
backend/internal/handler/openai_gateway_handler.go
Normal file
209
backend/internal/handler/openai_gateway_handler.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/middleware"
|
||||
"sub2api/internal/pkg/openai"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// OpenAIGatewayHandler handles OpenAI API gateway requests
|
||||
type OpenAIGatewayHandler struct {
|
||||
gatewayService *service.OpenAIGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
}
|
||||
|
||||
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
||||
func NewOpenAIGatewayHandler(
|
||||
gatewayService *service.OpenAIGatewayService,
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
billingCacheService *service.BillingCacheService,
|
||||
) *OpenAIGatewayHandler {
|
||||
return &OpenAIGatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatNone),
|
||||
}
|
||||
}
|
||||
|
||||
// Responses handles OpenAI Responses API endpoint
|
||||
// POST /openai/v1/responses
|
||||
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// Get apiKey and user from context (set by ApiKeyAuth middleware)
|
||||
apiKey, ok := middleware.GetApiKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := middleware.GetUserFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Read request body
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
|
||||
if len(body) == 0 {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse request body to map for potential modification
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(body, &reqBody); err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
// Extract model and stream
|
||||
reqModel, _ := reqBody["model"].(string)
|
||||
reqStream, _ := reqBody["stream"].(bool)
|
||||
|
||||
// For non-Codex CLI requests, set default instructions
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
if !openai.IsCodexCLIRequest(userAgent) {
|
||||
reqBody["instructions"] = openai.DefaultInstructions
|
||||
// Re-serialize body
|
||||
body, err = json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Track if we've started streaming (for error handling)
|
||||
streamStarted := false
|
||||
|
||||
// Get subscription info (may be nil)
|
||||
subscription, _ := middleware.GetSubscriptionFromContext(c)
|
||||
|
||||
// 0. Check if wait queue is full
|
||||
maxWait := service.CalculateMaxWait(user.Concurrency)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
|
||||
if err != nil {
|
||||
log.Printf("Increment wait count failed: %v", err)
|
||||
// On error, allow request to proceed
|
||||
} else if !canWait {
|
||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
// Ensure wait count is decremented when function exits
|
||||
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), user.ID)
|
||||
|
||||
// 1. First acquire user concurrency slot
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, user, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("User concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
return
|
||||
}
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
// 2. Re-check billing eligibility after wait
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
|
||||
log.Printf("Billing eligibility check failed after wait: %v", err)
|
||||
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate session hash (from header for OpenAI)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c)
|
||||
|
||||
// Select account supporting the requested model
|
||||
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
|
||||
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel)
|
||||
if err != nil {
|
||||
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
|
||||
|
||||
// 3. Acquire account concurrency slot
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if accountReleaseFunc != nil {
|
||||
defer accountReleaseFunc()
|
||||
}
|
||||
|
||||
// Forward request
|
||||
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
||||
if err != nil {
|
||||
// Error response already handled in Forward, just log
|
||||
log.Printf("Forward request failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Async record usage
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
User: user,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// handleConcurrencyError handles concurrency-related errors with proper 429 response
|
||||
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
|
||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||
}
|
||||
|
||||
// handleStreamingAwareError handles errors that may occur after streaming has started
|
||||
func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
// Stream already started, send error as SSE event then close
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if ok {
|
||||
// Send error event in OpenAI SSE format
|
||||
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
|
||||
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||||
_ = c.Error(err)
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Normal case: return JSON response with proper status code
|
||||
h.errorResponse(c, status, errType, message)
|
||||
}
|
||||
|
||||
// errorResponse returns OpenAI API format error response
|
||||
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -358,7 +358,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
|
||||
}
|
||||
|
||||
if len(req.ApiKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]interface{}{}})
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -383,7 +383,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
|
||||
}
|
||||
|
||||
if len(validApiKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]interface{}{}})
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -26,6 +26,12 @@ type ChangePasswordRequest struct {
|
||||
NewPassword string `json:"new_password" binding:"required,min=6"`
|
||||
}
|
||||
|
||||
// UpdateProfileRequest represents the update profile request payload
|
||||
type UpdateProfileRequest struct {
|
||||
Username *string `json:"username"`
|
||||
Wechat *string `json:"wechat"`
|
||||
}
|
||||
|
||||
// GetProfile handles getting user profile
|
||||
// GET /api/v1/users/me
|
||||
func (h *UserHandler) GetProfile(c *gin.Context) {
|
||||
@@ -47,6 +53,9 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 清空notes字段,普通用户不应看到备注
|
||||
userData.Notes = ""
|
||||
|
||||
response.Success(c, userData)
|
||||
}
|
||||
|
||||
@@ -83,3 +92,40 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
|
||||
|
||||
response.Success(c, gin.H{"message": "Password changed successfully"})
|
||||
}
|
||||
|
||||
// UpdateProfile handles updating user profile
|
||||
// PUT /api/v1/users/me
|
||||
func (h *UserHandler) UpdateProfile(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
svcReq := service.UpdateProfileRequest{
|
||||
Username: req.Username,
|
||||
Wechat: req.Wechat,
|
||||
}
|
||||
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), user.ID, svcReq)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to update profile: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 清空notes字段,普通用户不应看到备注
|
||||
updatedUser.Notes = ""
|
||||
|
||||
response.Success(c, updatedUser)
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ func ProvideAdminHandlers(
|
||||
groupHandler *admin.GroupHandler,
|
||||
accountHandler *admin.AccountHandler,
|
||||
oauthHandler *admin.OAuthHandler,
|
||||
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
||||
proxyHandler *admin.ProxyHandler,
|
||||
redeemHandler *admin.RedeemHandler,
|
||||
settingHandler *admin.SettingHandler,
|
||||
@@ -27,6 +28,7 @@ func ProvideAdminHandlers(
|
||||
Group: groupHandler,
|
||||
Account: accountHandler,
|
||||
OAuth: oauthHandler,
|
||||
OpenAIOAuth: openaiOAuthHandler,
|
||||
Proxy: proxyHandler,
|
||||
Redeem: redeemHandler,
|
||||
Setting: settingHandler,
|
||||
@@ -56,18 +58,20 @@ func ProvideHandlers(
|
||||
subscriptionHandler *SubscriptionHandler,
|
||||
adminHandlers *AdminHandlers,
|
||||
gatewayHandler *GatewayHandler,
|
||||
openaiGatewayHandler *OpenAIGatewayHandler,
|
||||
settingHandler *SettingHandler,
|
||||
) *Handlers {
|
||||
return &Handlers{
|
||||
Auth: authHandler,
|
||||
User: userHandler,
|
||||
APIKey: apiKeyHandler,
|
||||
Usage: usageHandler,
|
||||
Redeem: redeemHandler,
|
||||
Subscription: subscriptionHandler,
|
||||
Admin: adminHandlers,
|
||||
Gateway: gatewayHandler,
|
||||
Setting: settingHandler,
|
||||
Auth: authHandler,
|
||||
User: userHandler,
|
||||
APIKey: apiKeyHandler,
|
||||
Usage: usageHandler,
|
||||
Redeem: redeemHandler,
|
||||
Subscription: subscriptionHandler,
|
||||
Admin: adminHandlers,
|
||||
Gateway: gatewayHandler,
|
||||
OpenAIGateway: openaiGatewayHandler,
|
||||
Setting: settingHandler,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,6 +85,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewRedeemHandler,
|
||||
NewSubscriptionHandler,
|
||||
NewGatewayHandler,
|
||||
NewOpenAIGatewayHandler,
|
||||
ProvideSettingHandler,
|
||||
|
||||
// Admin handlers
|
||||
@@ -89,6 +94,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewGroupHandler,
|
||||
admin.NewAccountHandler,
|
||||
admin.NewOAuthHandler,
|
||||
admin.NewOpenAIOAuthHandler,
|
||||
admin.NewProxyHandler,
|
||||
admin.NewRedeemHandler,
|
||||
admin.NewSettingHandler,
|
||||
|
||||
130
backend/internal/middleware/admin_auth.go
Normal file
130
backend/internal/middleware/admin_auth.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"strings"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AdminAuth 管理员认证中间件
|
||||
// 支持两种认证方式(通过不同的 header 区分):
|
||||
// 1. Admin API Key: x-api-key: <admin-api-key>
|
||||
// 2. JWT Token: Authorization: Bearer <jwt-token> (需要管理员角色)
|
||||
func AdminAuth(
|
||||
authService *service.AuthService,
|
||||
userRepo interface {
|
||||
GetByID(ctx context.Context, id int64) (*model.User, error)
|
||||
GetFirstAdmin(ctx context.Context) (*model.User, error)
|
||||
},
|
||||
settingService *service.SettingService,
|
||||
) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 检查 x-api-key header(Admin API Key 认证)
|
||||
apiKey := c.GetHeader("x-api-key")
|
||||
if apiKey != "" {
|
||||
if !validateAdminApiKey(c, apiKey, settingService, userRepo) {
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查 Authorization header(JWT 认证)
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader != "" {
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) == 2 && parts[0] == "Bearer" {
|
||||
if !validateJWTForAdmin(c, parts[1], authService, userRepo) {
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 无有效认证信息
|
||||
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization required")
|
||||
}
|
||||
}
|
||||
|
||||
// validateAdminApiKey 验证管理员 API Key
|
||||
func validateAdminApiKey(
|
||||
c *gin.Context,
|
||||
key string,
|
||||
settingService *service.SettingService,
|
||||
userRepo interface {
|
||||
GetFirstAdmin(ctx context.Context) (*model.User, error)
|
||||
},
|
||||
) bool {
|
||||
storedKey, err := settingService.GetAdminApiKey(c.Request.Context())
|
||||
if err != nil {
|
||||
AbortWithError(c, 500, "INTERNAL_ERROR", "Internal server error")
|
||||
return false
|
||||
}
|
||||
|
||||
// 未配置或不匹配,统一返回相同错误(避免信息泄露)
|
||||
if storedKey == "" || subtle.ConstantTimeCompare([]byte(key), []byte(storedKey)) != 1 {
|
||||
AbortWithError(c, 401, "INVALID_ADMIN_KEY", "Invalid admin API key")
|
||||
return false
|
||||
}
|
||||
|
||||
// 获取真实的管理员用户
|
||||
admin, err := userRepo.GetFirstAdmin(c.Request.Context())
|
||||
if err != nil {
|
||||
AbortWithError(c, 500, "INTERNAL_ERROR", "No admin user found")
|
||||
return false
|
||||
}
|
||||
|
||||
c.Set(string(ContextKeyUser), admin)
|
||||
c.Set("auth_method", "admin_api_key")
|
||||
return true
|
||||
}
|
||||
|
||||
// validateJWTForAdmin 验证 JWT 并检查管理员权限
|
||||
func validateJWTForAdmin(
|
||||
c *gin.Context,
|
||||
token string,
|
||||
authService *service.AuthService,
|
||||
userRepo interface {
|
||||
GetByID(ctx context.Context, id int64) (*model.User, error)
|
||||
},
|
||||
) bool {
|
||||
// 验证 JWT token
|
||||
claims, err := authService.ValidateToken(token)
|
||||
if err != nil {
|
||||
if err == service.ErrTokenExpired {
|
||||
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
|
||||
return false
|
||||
}
|
||||
AbortWithError(c, 401, "INVALID_TOKEN", "Invalid token")
|
||||
return false
|
||||
}
|
||||
|
||||
// 从数据库获取用户
|
||||
user, err := userRepo.GetByID(c.Request.Context(), claims.UserID)
|
||||
if err != nil {
|
||||
AbortWithError(c, 401, "USER_NOT_FOUND", "User not found")
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if !user.IsActive() {
|
||||
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查管理员权限
|
||||
if user.Role != model.RoleAdmin {
|
||||
AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
|
||||
return false
|
||||
}
|
||||
|
||||
c.Set(string(ContextKeyUser), user)
|
||||
c.Set("auth_method", "jwt")
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
// JSONB 用于存储JSONB数据
|
||||
type JSONB map[string]interface{}
|
||||
type JSONB map[string]any
|
||||
|
||||
func (j JSONB) Value() (driver.Value, error) {
|
||||
if j == nil {
|
||||
@@ -19,7 +19,7 @@ func (j JSONB) Value() (driver.Value, error) {
|
||||
return json.Marshal(j)
|
||||
}
|
||||
|
||||
func (j *JSONB) Scan(value interface{}) error {
|
||||
func (j *JSONB) Scan(value any) error {
|
||||
if value == nil {
|
||||
*j = nil
|
||||
return nil
|
||||
@@ -40,8 +40,8 @@ type Account struct {
|
||||
Extra JSONB `gorm:"type:jsonb;default:'{}'" json:"extra"` // 扩展信息
|
||||
ProxyID *int64 `gorm:"index" json:"proxy_id"`
|
||||
Concurrency int `gorm:"default:3;not null" json:"concurrency"`
|
||||
Priority int `gorm:"default:50;not null" json:"priority"` // 1-100,越小越高
|
||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error
|
||||
Priority int `gorm:"default:50;not null" json:"priority"` // 1-100,越小越高
|
||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error
|
||||
ErrorMessage string `gorm:"type:text" json:"error_message"`
|
||||
LastUsedAt *time.Time `gorm:"index" json:"last_used_at"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
@@ -68,7 +68,8 @@ type Account struct {
|
||||
AccountGroups []AccountGroup `gorm:"foreignKey:AccountID" json:"account_groups,omitempty"`
|
||||
|
||||
// 虚拟字段 (不存储到数据库)
|
||||
GroupIDs []int64 `gorm:"-" json:"group_ids,omitempty"`
|
||||
GroupIDs []int64 `gorm:"-" json:"group_ids,omitempty"`
|
||||
Groups []*Group `gorm:"-" json:"groups,omitempty"`
|
||||
}
|
||||
|
||||
func (Account) TableName() string {
|
||||
@@ -145,7 +146,7 @@ func (a *Account) GetModelMapping() map[string]string {
|
||||
return nil
|
||||
}
|
||||
// 处理map[string]interface{}类型
|
||||
if m, ok := raw.(map[string]interface{}); ok {
|
||||
if m, ok := raw.(map[string]any); ok {
|
||||
result := make(map[string]string)
|
||||
for k, v := range m {
|
||||
if s, ok := v.(string); ok {
|
||||
@@ -163,7 +164,7 @@ func (a *Account) GetModelMapping() map[string]string {
|
||||
// 如果没有设置模型映射,则支持所有模型
|
||||
func (a *Account) IsModelSupported(requestedModel string) bool {
|
||||
mapping := a.GetModelMapping()
|
||||
if mapping == nil || len(mapping) == 0 {
|
||||
if len(mapping) == 0 {
|
||||
return true // 没有映射配置,支持所有模型
|
||||
}
|
||||
_, exists := mapping[requestedModel]
|
||||
@@ -174,7 +175,7 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
|
||||
// 如果没有映射,返回原始模型名
|
||||
func (a *Account) GetMappedModel(requestedModel string) string {
|
||||
mapping := a.GetModelMapping()
|
||||
if mapping == nil || len(mapping) == 0 {
|
||||
if len(mapping) == 0 {
|
||||
return requestedModel
|
||||
}
|
||||
if mappedModel, exists := mapping[requestedModel]; exists {
|
||||
@@ -231,7 +232,7 @@ func (a *Account) GetCustomErrorCodes() []int {
|
||||
return nil
|
||||
}
|
||||
// 处理 []interface{} 类型(JSON反序列化后的格式)
|
||||
if arr, ok := raw.([]interface{}); ok {
|
||||
if arr, ok := raw.([]any); ok {
|
||||
result := make([]int, 0, len(arr))
|
||||
for _, v := range arr {
|
||||
// JSON 数字默认解析为 float64
|
||||
@@ -277,3 +278,138 @@ func (a *Account) IsInterceptWarmupEnabled() bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// =============== OpenAI 相关方法 ===============
|
||||
|
||||
// IsOpenAI 检查是否为 OpenAI 平台账号
|
||||
func (a *Account) IsOpenAI() bool {
|
||||
return a.Platform == PlatformOpenAI
|
||||
}
|
||||
|
||||
// IsAnthropic 检查是否为 Anthropic 平台账号
|
||||
func (a *Account) IsAnthropic() bool {
|
||||
return a.Platform == PlatformAnthropic
|
||||
}
|
||||
|
||||
// IsOpenAIOAuth 检查是否为 OpenAI OAuth 类型账号
|
||||
func (a *Account) IsOpenAIOAuth() bool {
|
||||
return a.IsOpenAI() && a.Type == AccountTypeOAuth
|
||||
}
|
||||
|
||||
// IsOpenAIApiKey 检查是否为 OpenAI API Key 类型账号(Response 账号)
|
||||
func (a *Account) IsOpenAIApiKey() bool {
|
||||
return a.IsOpenAI() && a.Type == AccountTypeApiKey
|
||||
}
|
||||
|
||||
// GetOpenAIBaseURL 获取 OpenAI API 基础 URL
|
||||
// 对于 API Key 类型账号,从 credentials 中获取 base_url
|
||||
// 对于 OAuth 类型账号,返回默认的 OpenAI API URL
|
||||
func (a *Account) GetOpenAIBaseURL() string {
|
||||
if !a.IsOpenAI() {
|
||||
return ""
|
||||
}
|
||||
if a.Type == AccountTypeApiKey {
|
||||
baseURL := a.GetCredential("base_url")
|
||||
if baseURL != "" {
|
||||
return baseURL
|
||||
}
|
||||
}
|
||||
return "https://api.openai.com" // OpenAI 默认 API URL
|
||||
}
|
||||
|
||||
// GetOpenAIAccessToken 获取 OpenAI 访问令牌
|
||||
func (a *Account) GetOpenAIAccessToken() string {
|
||||
if !a.IsOpenAI() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("access_token")
|
||||
}
|
||||
|
||||
// GetOpenAIRefreshToken 获取 OpenAI 刷新令牌
|
||||
func (a *Account) GetOpenAIRefreshToken() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("refresh_token")
|
||||
}
|
||||
|
||||
// GetOpenAIIDToken 获取 OpenAI ID Token(JWT,包含用户信息)
|
||||
func (a *Account) GetOpenAIIDToken() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("id_token")
|
||||
}
|
||||
|
||||
// GetOpenAIApiKey 获取 OpenAI API Key(用于 Response 账号)
|
||||
func (a *Account) GetOpenAIApiKey() string {
|
||||
if !a.IsOpenAIApiKey() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("api_key")
|
||||
}
|
||||
|
||||
// GetOpenAIUserAgent 获取 OpenAI 自定义 User-Agent
|
||||
// 返回空字符串表示透传原始 User-Agent
|
||||
func (a *Account) GetOpenAIUserAgent() string {
|
||||
if !a.IsOpenAI() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("user_agent")
|
||||
}
|
||||
|
||||
// GetChatGPTAccountID 获取 ChatGPT 账号 ID(从 ID Token 解析)
|
||||
func (a *Account) GetChatGPTAccountID() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("chatgpt_account_id")
|
||||
}
|
||||
|
||||
// GetChatGPTUserID 获取 ChatGPT 用户 ID(从 ID Token 解析)
|
||||
func (a *Account) GetChatGPTUserID() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("chatgpt_user_id")
|
||||
}
|
||||
|
||||
// GetOpenAIOrganizationID 获取 OpenAI 组织 ID
|
||||
func (a *Account) GetOpenAIOrganizationID() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("organization_id")
|
||||
}
|
||||
|
||||
// GetOpenAITokenExpiresAt 获取 OpenAI Token 过期时间
|
||||
func (a *Account) GetOpenAITokenExpiresAt() *time.Time {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return nil
|
||||
}
|
||||
expiresAtStr := a.GetCredential("expires_at")
|
||||
if expiresAtStr == "" {
|
||||
return nil
|
||||
}
|
||||
// 尝试解析时间
|
||||
t, err := time.Parse(time.RFC3339, expiresAtStr)
|
||||
if err != nil {
|
||||
// 尝试解析为 Unix 时间戳
|
||||
if v, ok := a.Credentials["expires_at"].(float64); ok {
|
||||
t = time.Unix(int64(v), 0)
|
||||
return &t
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return &t
|
||||
}
|
||||
|
||||
// IsOpenAITokenExpired 检查 OpenAI Token 是否过期
|
||||
func (a *Account) IsOpenAITokenExpired() bool {
|
||||
expiresAt := a.GetOpenAITokenExpiresAt()
|
||||
if expiresAt == nil {
|
||||
return false // 没有过期时间信息,假设未过期
|
||||
}
|
||||
// 提前 60 秒认为过期,便于刷新
|
||||
return time.Now().Add(60 * time.Second).After(*expiresAt)
|
||||
}
|
||||
|
||||
@@ -13,13 +13,13 @@ const (
|
||||
)
|
||||
|
||||
type Group struct {
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"uniqueIndex;size:100;not null" json:"name"`
|
||||
Description string `gorm:"type:text" json:"description"`
|
||||
Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini
|
||||
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"`
|
||||
IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"`
|
||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"uniqueIndex;size:100;not null" json:"name"`
|
||||
Description string `gorm:"type:text" json:"description"`
|
||||
Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini
|
||||
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"`
|
||||
IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"`
|
||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
|
||||
|
||||
// 订阅功能字段
|
||||
SubscriptionType string `gorm:"size:20;default:standard;not null" json:"subscription_type"` // standard/subscription
|
||||
|
||||
@@ -9,15 +9,16 @@ import (
|
||||
type RedeemCode struct {
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
Code string `gorm:"uniqueIndex;size:32;not null" json:"code"`
|
||||
Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription
|
||||
Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数
|
||||
Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription
|
||||
Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数
|
||||
Status string `gorm:"size:20;default:unused;not null" json:"status"` // unused/used
|
||||
UsedBy *int64 `gorm:"index" json:"used_by"`
|
||||
UsedAt *time.Time `json:"used_at"`
|
||||
Notes string `gorm:"type:text" json:"notes"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
|
||||
// 订阅类型专用字段
|
||||
GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用)
|
||||
GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用)
|
||||
ValidityDays int `gorm:"default:30" json:"validity_days"` // 订阅有效天数 (仅subscription类型使用)
|
||||
|
||||
// 关联
|
||||
@@ -40,8 +41,10 @@ func (r *RedeemCode) CanUse() bool {
|
||||
}
|
||||
|
||||
// GenerateRedeemCode 生成唯一的兑换码
|
||||
func GenerateRedeemCode() string {
|
||||
func GenerateRedeemCode() (string, error) {
|
||||
b := make([]byte, 16)
|
||||
rand.Read(b)
|
||||
return hex.EncodeToString(b)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
@@ -19,17 +19,17 @@ func (Setting) TableName() string {
|
||||
// 设置Key常量
|
||||
const (
|
||||
// 注册设置
|
||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||
|
||||
// 邮件服务设置
|
||||
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
|
||||
SettingKeySmtpPort = "smtp_port" // SMTP端口
|
||||
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
|
||||
SettingKeySmtpPassword = "smtp_password" // SMTP密码(加密存储)
|
||||
SettingKeySmtpFrom = "smtp_from" // 发件人地址
|
||||
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
|
||||
SettingKeySmtpPort = "smtp_port" // SMTP端口
|
||||
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
|
||||
SettingKeySmtpPassword = "smtp_password" // SMTP密码(加密存储)
|
||||
SettingKeySmtpFrom = "smtp_from" // 发件人地址
|
||||
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
|
||||
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
|
||||
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
|
||||
|
||||
// Cloudflare Turnstile 设置
|
||||
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
|
||||
@@ -46,8 +46,14 @@ const (
|
||||
// 默认配置
|
||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
|
||||
|
||||
// 管理员 API Key
|
||||
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
|
||||
)
|
||||
|
||||
// 管理员 API Key 前缀(与用户 sk- 前缀区分)
|
||||
const AdminApiKeyPrefix = "admin-"
|
||||
|
||||
// SystemSettings 系统设置结构体(用于API响应)
|
||||
type SystemSettings struct {
|
||||
// 注册设置
|
||||
|
||||
@@ -37,7 +37,7 @@ type UsageLog struct {
|
||||
OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"output_cost"`
|
||||
CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_creation_cost"`
|
||||
CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_read_cost"`
|
||||
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用
|
||||
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用
|
||||
ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"actual_cost"` // 实际扣除费用
|
||||
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null" json:"rate_multiplier"` // 计费倍率
|
||||
|
||||
|
||||
@@ -9,8 +9,11 @@ import (
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
Email string `gorm:"uniqueIndex;size:255;not null" json:"email"`
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
Email string `gorm:"uniqueIndex;size:255;not null" json:"email"`
|
||||
Username string `gorm:"size:100;default:''" json:"username"`
|
||||
Wechat string `gorm:"size:100;default:''" json:"wechat"`
|
||||
Notes string `gorm:"type:text;default:''" json:"notes"`
|
||||
PasswordHash string `gorm:"size:255;not null" json:"-"`
|
||||
Role string `gorm:"size:20;default:user;not null" json:"role"` // admin/user
|
||||
Balance float64 `gorm:"type:decimal(20,8);default:0;not null" json:"balance"`
|
||||
@@ -22,7 +25,8 @@ type User struct {
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
|
||||
|
||||
// 关联
|
||||
ApiKeys []ApiKey `gorm:"foreignKey:UserID" json:"api_keys,omitempty"`
|
||||
ApiKeys []ApiKey `gorm:"foreignKey:UserID" json:"api_keys,omitempty"`
|
||||
Subscriptions []UserSubscription `gorm:"foreignKey:UserID" json:"subscriptions,omitempty"`
|
||||
}
|
||||
|
||||
func (User) TableName() string {
|
||||
|
||||
@@ -43,18 +43,25 @@ type OAuthSession struct {
|
||||
type SessionStore struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*OAuthSession
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// NewSessionStore creates a new session store
|
||||
func NewSessionStore() *SessionStore {
|
||||
store := &SessionStore{
|
||||
sessions: make(map[string]*OAuthSession),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
// Start cleanup goroutine
|
||||
go store.cleanup()
|
||||
return store
|
||||
}
|
||||
|
||||
// Stop stops the cleanup goroutine
|
||||
func (s *SessionStore) Stop() {
|
||||
close(s.stopCh)
|
||||
}
|
||||
|
||||
// Set stores a session
|
||||
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
|
||||
s.mu.Lock()
|
||||
@@ -87,14 +94,20 @@ func (s *SessionStore) Delete(sessionID string) {
|
||||
// cleanup removes expired sessions periodically
|
||||
func (s *SessionStore) cleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
for range ticker.C {
|
||||
s.mu.Lock()
|
||||
for id, session := range s.sessions {
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
delete(s.sessions, id)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.mu.Lock()
|
||||
for id, session := range s.sessions {
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
delete(s.sessions, id)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
42
backend/internal/pkg/openai/constants.go
Normal file
42
backend/internal/pkg/openai/constants.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package openai
|
||||
|
||||
import _ "embed"
|
||||
|
||||
// Model represents an OpenAI model
|
||||
type Model struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
Type string `json:"type"`
|
||||
DisplayName string `json:"display_name"`
|
||||
}
|
||||
|
||||
// DefaultModels OpenAI models list
|
||||
var DefaultModels = []Model{
|
||||
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
|
||||
{ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"},
|
||||
{ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"},
|
||||
{ID: "gpt-5.1-codex", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex"},
|
||||
{ID: "gpt-5.1", Object: "model", Created: 1731456000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1"},
|
||||
{ID: "gpt-5.1-codex-mini", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Mini"},
|
||||
{ID: "gpt-5", Object: "model", Created: 1722988800, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5"},
|
||||
}
|
||||
|
||||
// DefaultModelIDs returns the default model ID list
|
||||
func DefaultModelIDs() []string {
|
||||
ids := make([]string, len(DefaultModels))
|
||||
for i, m := range DefaultModels {
|
||||
ids[i] = m.ID
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// DefaultTestModel default model for testing OpenAI accounts
|
||||
const DefaultTestModel = "gpt-5.1-codex"
|
||||
|
||||
// DefaultInstructions default instructions for non-Codex CLI requests
|
||||
// Content loaded from instructions.txt at compile time
|
||||
//
|
||||
//go:embed instructions.txt
|
||||
var DefaultInstructions string
|
||||
118
backend/internal/pkg/openai/instructions.txt
Normal file
118
backend/internal/pkg/openai/instructions.txt
Normal file
@@ -0,0 +1,118 @@
|
||||
You are Codex, based on GPT-5. You are running as a coding agent in the Codex CLI on a user's computer.
|
||||
|
||||
## General
|
||||
|
||||
- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.)
|
||||
|
||||
## Editing constraints
|
||||
|
||||
- Default to ASCII when editing or creating files. Only introduce non-ASCII or other Unicode characters when there is a clear justification and the file already uses them.
|
||||
- Add succinct code comments that explain what is going on if code is not self-explanatory. You should not add comments like \"Assigns the value to the variable\", but a brief comment might be useful ahead of a complex code block that the user would otherwise have to spend time parsing out. Usage of these comments should be rare.
|
||||
- Try to use apply_patch for single file edits, but it is fine to explore other options to make the edit if it does not work well. Do not use apply_patch for changes that are auto-generated (i.e. generating package.json or running a lint or format command like gofmt) or when scripting is more efficient (such as search and replacing a string across a codebase).
|
||||
- You may be in a dirty git worktree.
|
||||
* NEVER revert existing changes you did not make unless explicitly requested, since these changes were made by the user.
|
||||
* If asked to make a commit or code edits and there are unrelated changes to your work or changes that you didn't make in those files, don't revert those changes.
|
||||
* If the changes are in files you've touched recently, you should read carefully and understand how you can work with the changes rather than reverting them.
|
||||
* If the changes are in unrelated files, just ignore them and don't revert them.
|
||||
- Do not amend a commit unless explicitly requested to do so.
|
||||
- While you are working, you might notice unexpected changes that you didn't make. If this happens, STOP IMMEDIATELY and ask the user how they would like to proceed.
|
||||
- **NEVER** use destructive commands like `git reset --hard` or `git checkout --` unless specifically requested or approved by the user.
|
||||
|
||||
## Plan tool
|
||||
|
||||
When using the planning tool:
|
||||
- Skip using the planning tool for straightforward tasks (roughly the easiest 25%).
|
||||
- Do not make single-step plans.
|
||||
- When you made a plan, update it after having performed one of the sub-tasks that you shared on the plan.
|
||||
|
||||
## Codex CLI harness, sandboxing, and approvals
|
||||
|
||||
The Codex CLI harness supports several different configurations for sandboxing and escalation approvals that the user can choose from.
|
||||
|
||||
Filesystem sandboxing defines which files can be read or written. The options for `sandbox_mode` are:
|
||||
- **read-only**: The sandbox only permits reading files.
|
||||
- **workspace-write**: The sandbox permits reading files, and editing files in `cwd` and `writable_roots`. Editing files in other directories requires approval.
|
||||
- **danger-full-access**: No filesystem sandboxing - all commands are permitted.
|
||||
|
||||
Network sandboxing defines whether network can be accessed without approval. Options for `network_access` are:
|
||||
- **restricted**: Requires approval
|
||||
- **enabled**: No approval needed
|
||||
|
||||
Approvals are your mechanism to get user consent to run shell commands without the sandbox. Possible configuration options for `approval_policy` are
|
||||
- **untrusted**: The harness will escalate most commands for user approval, apart from a limited allowlist of safe \"read\" commands.
|
||||
- **on-failure**: The harness will allow all commands to run in the sandbox (if enabled), and failures will be escalated to the user for approval to run again without the sandbox.
|
||||
- **on-request**: Commands will be run in the sandbox by default, and you can specify in your tool call if you want to escalate a command to run without sandboxing. (Note that this mode is not always available. If it is, you'll see parameters for it in the `shell` command description.)
|
||||
- **never**: This is a non-interactive mode where you may NEVER ask the user for approval to run commands. Instead, you must always persist and work around constraints to solve the task for the user. You MUST do your utmost best to finish the task and validate your work before yielding. If this mode is paired with `danger-full-access`, take advantage of it to deliver the best outcome for the user. Further, in this mode, your default testing philosophy is overridden: Even if you don't see local patterns for testing, you may add tests and scripts to validate your work. Just remove them before yielding.
|
||||
|
||||
When you are running with `approval_policy == on-request`, and sandboxing enabled, here are scenarios where you'll need to request approval:
|
||||
- You need to run a command that writes to a directory that requires it (e.g. running tests that write to /var)
|
||||
- You need to run a GUI app (e.g., open/xdg-open/osascript) to open browsers or files.
|
||||
- You are running sandboxed and need to run a command that requires network access (e.g. installing packages)
|
||||
- If you run a command that is important to solving the user's query, but it fails because of sandboxing, rerun the command with approval. ALWAYS proceed to use the `sandbox_permissions` and `justification` parameters - do not message the user before requesting approval for the command.
|
||||
- You are about to take a potentially destructive action such as an `rm` or `git reset` that the user did not explicitly ask for
|
||||
- (for all of these, you should weigh alternative paths that do not require approval)
|
||||
|
||||
When `sandbox_mode` is set to read-only, you'll need to request approval for any command that isn't a read.
|
||||
|
||||
You will be told what filesystem sandboxing, network sandboxing, and approval mode are active in a developer or user message. If you are not told about this, assume that you are running with workspace-write, network sandboxing enabled, and approval on-failure.
|
||||
|
||||
Although they introduce friction to the user because your work is paused until the user responds, you should leverage them when necessary to accomplish important work. If the completing the task requires escalated permissions, Do not let these settings or the sandbox deter you from attempting to accomplish the user's task unless it is set to \"never\", in which case never ask for approvals.
|
||||
|
||||
When requesting approval to execute a command that will require escalated privileges:
|
||||
- Provide the `sandbox_permissions` parameter with the value `\"require_escalated\"`
|
||||
- Include a short, 1 sentence explanation for why you need escalated permissions in the justification parameter
|
||||
|
||||
## Special user requests
|
||||
|
||||
- If the user makes a simple request (such as asking for the time) which you can fulfill by running a terminal command (such as `date`), you should do so.
|
||||
- If the user asks for a \"review\", default to a code review mindset: prioritise identifying bugs, risks, behavioural regressions, and missing tests. Findings must be the primary focus of the response - keep summaries or overviews brief and only after enumerating the issues. Present findings first (ordered by severity with file/line references), follow with open questions or assumptions, and offer a change-summary only as a secondary detail. If no findings are discovered, state that explicitly and mention any residual risks or testing gaps.
|
||||
|
||||
## Frontend tasks
|
||||
When doing frontend design tasks, avoid collapsing into \"AI slop\" or safe, average-looking layouts.
|
||||
Aim for interfaces that feel intentional, bold, and a bit surprising.
|
||||
- Typography: Use expressive, purposeful fonts and avoid default stacks (Inter, Roboto, Arial, system).
|
||||
- Color & Look: Choose a clear visual direction; define CSS variables; avoid purple-on-white defaults. No purple bias or dark mode bias.
|
||||
- Motion: Use a few meaningful animations (page-load, staggered reveals) instead of generic micro-motions.
|
||||
- Background: Don't rely on flat, single-color backgrounds; use gradients, shapes, or subtle patterns to build atmosphere.
|
||||
- Overall: Avoid boilerplate layouts and interchangeable UI patterns. Vary themes, type families, and visual languages across outputs.
|
||||
- Ensure the page loads properly on both desktop and mobile
|
||||
|
||||
Exception: If working within an existing website or design system, preserve the established patterns, structure, and visual language.
|
||||
|
||||
## Presenting your work and final message
|
||||
|
||||
You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.
|
||||
|
||||
- Default: be very concise; friendly coding teammate tone.
|
||||
- Ask only when needed; suggest ideas; mirror the user's style.
|
||||
- For substantial work, summarize clearly; follow final‑answer formatting.
|
||||
- Skip heavy formatting for simple confirmations.
|
||||
- Don't dump large files you've written; reference paths only.
|
||||
- No \"save/copy this file\" - User is on the same machine.
|
||||
- Offer logical next steps (tests, commits, build) briefly; add verify steps if you couldn't do something.
|
||||
- For code changes:
|
||||
* Lead with a quick explanation of the change, and then give more details on the context covering where and why a change was made. Do not start this explanation with \"summary\", just jump right in.
|
||||
* If there are natural next steps the user may want to take, suggest them at the end of your response. Do not make suggestions if there are no natural next steps.
|
||||
* When suggesting multiple options, use numeric lists for the suggestions so the user can quickly respond with a single number.
|
||||
- The user does not command execution outputs. When asked to show the output of a command (e.g. `git show`), relay the important details in your answer or summarize the key lines so the user understands the result.
|
||||
|
||||
### Final answer structure and style guidelines
|
||||
|
||||
- Plain text; CLI handles styling. Use structure only when it helps scanability.
|
||||
- Headers: optional; short Title Case (1-3 words) wrapped in **…**; no blank line before the first bullet; add only if they truly help.
|
||||
- Bullets: use - ; merge related points; keep to one line when possible; 4–6 per list ordered by importance; keep phrasing consistent.
|
||||
- Monospace: backticks for commands/paths/env vars/code ids and inline examples; use for literal keyword bullets; never combine with **.
|
||||
- Code samples or multi-line snippets should be wrapped in fenced code blocks; include an info string as often as possible.
|
||||
- Structure: group related bullets; order sections general → specific → supporting; for subsections, start with a bolded keyword bullet, then items; match complexity to the task.
|
||||
- Tone: collaborative, concise, factual; present tense, active voice; self‑contained; no \"above/below\"; parallel wording.
|
||||
- Don'ts: no nested bullets/hierarchies; no ANSI codes; don't cram unrelated keywords; keep keyword lists short—wrap/reformat if long; avoid naming formatting styles in answers.
|
||||
- Adaptation: code explanations → precise, structured with code refs; simple tasks → lead with outcome; big changes → logical walkthrough + rationale + next actions; casual one-offs → plain sentences, no headers/bullets.
|
||||
- File References: When referencing files in your response follow the below rules:
|
||||
* Use inline code to make file paths clickable.
|
||||
* Each reference should have a stand alone path. Even if it's the same file.
|
||||
* Accepted: absolute, workspace‑relative, a/ or b/ diff prefixes, or bare filename/suffix.
|
||||
* Optionally include line/column (1‑based): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
|
||||
* Do not use URIs like file://, vscode://, or https://.
|
||||
* Do not provide range of lines
|
||||
* Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\\repo\\project\\main.rs:12:5
|
||||
|
||||
366
backend/internal/pkg/openai/oauth.go
Normal file
366
backend/internal/pkg/openai/oauth.go
Normal file
@@ -0,0 +1,366 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OpenAI OAuth Constants (from CRS project - Codex CLI client)
|
||||
const (
|
||||
// OAuth Client ID for OpenAI (Codex CLI official)
|
||||
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
|
||||
// OAuth endpoints
|
||||
AuthorizeURL = "https://auth.openai.com/oauth/authorize"
|
||||
TokenURL = "https://auth.openai.com/oauth/token"
|
||||
|
||||
// Default redirect URI (can be customized)
|
||||
DefaultRedirectURI = "http://localhost:1455/auth/callback"
|
||||
|
||||
// Scopes
|
||||
DefaultScopes = "openid profile email offline_access"
|
||||
// RefreshScopes - scope for token refresh (without offline_access, aligned with CRS project)
|
||||
RefreshScopes = "openid profile email"
|
||||
|
||||
// Session TTL
|
||||
SessionTTL = 30 * time.Minute
|
||||
)
|
||||
|
||||
// OAuthSession stores OAuth flow state for OpenAI
|
||||
type OAuthSession struct {
|
||||
State string `json:"state"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
ProxyURL string `json:"proxy_url,omitempty"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// SessionStore manages OAuth sessions in memory
|
||||
type SessionStore struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*OAuthSession
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// NewSessionStore creates a new session store
|
||||
func NewSessionStore() *SessionStore {
|
||||
store := &SessionStore{
|
||||
sessions: make(map[string]*OAuthSession),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
// Start cleanup goroutine
|
||||
go store.cleanup()
|
||||
return store
|
||||
}
|
||||
|
||||
// Set stores a session
|
||||
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sessions[sessionID] = session
|
||||
}
|
||||
|
||||
// Get retrieves a session
|
||||
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
session, ok := s.sessions[sessionID]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
// Check if expired
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
return nil, false
|
||||
}
|
||||
return session, true
|
||||
}
|
||||
|
||||
// Delete removes a session
|
||||
func (s *SessionStore) Delete(sessionID string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.sessions, sessionID)
|
||||
}
|
||||
|
||||
// Stop stops the cleanup goroutine
|
||||
func (s *SessionStore) Stop() {
|
||||
close(s.stopCh)
|
||||
}
|
||||
|
||||
// cleanup removes expired sessions periodically
|
||||
func (s *SessionStore) cleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.mu.Lock()
|
||||
for id, session := range s.sessions {
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
delete(s.sessions, id)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateRandomBytes generates cryptographically secure random bytes
|
||||
func GenerateRandomBytes(n int) ([]byte, error) {
|
||||
b := make([]byte, n)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// GenerateState generates a random state string for OAuth
|
||||
func GenerateState() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateSessionID generates a unique session ID
|
||||
func GenerateSessionID() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(16)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateCodeVerifier generates a PKCE code verifier (64 bytes -> hex for OpenAI)
|
||||
// OpenAI uses hex encoding instead of base64url
|
||||
func GenerateCodeVerifier() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(64)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
|
||||
// Uses base64url encoding as per RFC 7636
|
||||
func GenerateCodeChallenge(verifier string) string {
|
||||
hash := sha256.Sum256([]byte(verifier))
|
||||
return base64URLEncode(hash[:])
|
||||
}
|
||||
|
||||
// base64URLEncode encodes bytes to base64url without padding
|
||||
func base64URLEncode(data []byte) string {
|
||||
encoded := base64.URLEncoding.EncodeToString(data)
|
||||
// Remove padding
|
||||
return strings.TrimRight(encoded, "=")
|
||||
}
|
||||
|
||||
// BuildAuthorizationURL builds the OpenAI OAuth authorization URL
|
||||
func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string {
|
||||
if redirectURI == "" {
|
||||
redirectURI = DefaultRedirectURI
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("response_type", "code")
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("redirect_uri", redirectURI)
|
||||
params.Set("scope", DefaultScopes)
|
||||
params.Set("state", state)
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
// OpenAI specific parameters
|
||||
params.Set("id_token_add_organizations", "true")
|
||||
params.Set("codex_cli_simplified_flow", "true")
|
||||
|
||||
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
|
||||
}
|
||||
|
||||
// TokenRequest represents the token exchange request body
|
||||
type TokenRequest struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
ClientID string `json:"client_id"`
|
||||
Code string `json:"code"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
}
|
||||
|
||||
// TokenResponse represents the token response from OpenAI OAuth
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
// RefreshTokenRequest represents the refresh token request
|
||||
type RefreshTokenRequest struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ClientID string `json:"client_id"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
// IDTokenClaims represents the claims from OpenAI ID Token
|
||||
type IDTokenClaims struct {
|
||||
// Standard claims
|
||||
Sub string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
Iss string `json:"iss"`
|
||||
Aud []string `json:"aud"` // OpenAI returns aud as an array
|
||||
Exp int64 `json:"exp"`
|
||||
Iat int64 `json:"iat"`
|
||||
|
||||
// OpenAI specific claims (nested under https://api.openai.com/auth)
|
||||
OpenAIAuth *OpenAIAuthClaims `json:"https://api.openai.com/auth,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIAuthClaims represents the OpenAI specific auth claims
|
||||
type OpenAIAuthClaims struct {
|
||||
ChatGPTAccountID string `json:"chatgpt_account_id"`
|
||||
ChatGPTUserID string `json:"chatgpt_user_id"`
|
||||
UserID string `json:"user_id"`
|
||||
Organizations []OrganizationClaim `json:"organizations"`
|
||||
}
|
||||
|
||||
// OrganizationClaim represents an organization in the ID Token
|
||||
type OrganizationClaim struct {
|
||||
ID string `json:"id"`
|
||||
Role string `json:"role"`
|
||||
Title string `json:"title"`
|
||||
IsDefault bool `json:"is_default"`
|
||||
}
|
||||
|
||||
// BuildTokenRequest creates a token exchange request for OpenAI
|
||||
func BuildTokenRequest(code, codeVerifier, redirectURI string) *TokenRequest {
|
||||
if redirectURI == "" {
|
||||
redirectURI = DefaultRedirectURI
|
||||
}
|
||||
return &TokenRequest{
|
||||
GrantType: "authorization_code",
|
||||
ClientID: ClientID,
|
||||
Code: code,
|
||||
RedirectURI: redirectURI,
|
||||
CodeVerifier: codeVerifier,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildRefreshTokenRequest creates a refresh token request for OpenAI
|
||||
func BuildRefreshTokenRequest(refreshToken string) *RefreshTokenRequest {
|
||||
return &RefreshTokenRequest{
|
||||
GrantType: "refresh_token",
|
||||
RefreshToken: refreshToken,
|
||||
ClientID: ClientID,
|
||||
Scope: RefreshScopes,
|
||||
}
|
||||
}
|
||||
|
||||
// ToFormData converts TokenRequest to URL-encoded form data
|
||||
func (r *TokenRequest) ToFormData() string {
|
||||
params := url.Values{}
|
||||
params.Set("grant_type", r.GrantType)
|
||||
params.Set("client_id", r.ClientID)
|
||||
params.Set("code", r.Code)
|
||||
params.Set("redirect_uri", r.RedirectURI)
|
||||
params.Set("code_verifier", r.CodeVerifier)
|
||||
return params.Encode()
|
||||
}
|
||||
|
||||
// ToFormData converts RefreshTokenRequest to URL-encoded form data
|
||||
func (r *RefreshTokenRequest) ToFormData() string {
|
||||
params := url.Values{}
|
||||
params.Set("grant_type", r.GrantType)
|
||||
params.Set("client_id", r.ClientID)
|
||||
params.Set("refresh_token", r.RefreshToken)
|
||||
params.Set("scope", r.Scope)
|
||||
return params.Encode()
|
||||
}
|
||||
|
||||
// ParseIDToken parses the ID Token JWT and extracts claims
|
||||
// Note: This does NOT verify the signature - it only decodes the payload
|
||||
// For production, you should verify the token signature using OpenAI's public keys
|
||||
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
parts := strings.Split(idToken, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
// Decode payload (second part)
|
||||
payload := parts[1]
|
||||
// Add padding if necessary
|
||||
switch len(payload) % 4 {
|
||||
case 2:
|
||||
payload += "=="
|
||||
case 3:
|
||||
payload += "="
|
||||
}
|
||||
|
||||
decoded, err := base64.URLEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
// Try standard encoding
|
||||
decoded, err = base64.StdEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode JWT payload: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var claims IDTokenClaims
|
||||
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
// ExtractUserInfo extracts user information from ID Token claims
|
||||
type UserInfo struct {
|
||||
Email string
|
||||
ChatGPTAccountID string
|
||||
ChatGPTUserID string
|
||||
UserID string
|
||||
OrganizationID string
|
||||
Organizations []OrganizationClaim
|
||||
}
|
||||
|
||||
// GetUserInfo extracts user info from ID Token claims
|
||||
func (c *IDTokenClaims) GetUserInfo() *UserInfo {
|
||||
info := &UserInfo{
|
||||
Email: c.Email,
|
||||
}
|
||||
|
||||
if c.OpenAIAuth != nil {
|
||||
info.ChatGPTAccountID = c.OpenAIAuth.ChatGPTAccountID
|
||||
info.ChatGPTUserID = c.OpenAIAuth.ChatGPTUserID
|
||||
info.UserID = c.OpenAIAuth.UserID
|
||||
info.Organizations = c.OpenAIAuth.Organizations
|
||||
|
||||
// Get default organization ID
|
||||
for _, org := range c.OpenAIAuth.Organizations {
|
||||
if org.IsDefault {
|
||||
info.OrganizationID = org.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
// If no default, use first org
|
||||
if info.OrganizationID == "" && len(c.OpenAIAuth.Organizations) > 0 {
|
||||
info.OrganizationID = c.OpenAIAuth.Organizations[0].ID
|
||||
}
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
18
backend/internal/pkg/openai/request.go
Normal file
18
backend/internal/pkg/openai/request.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package openai
|
||||
|
||||
// CodexCLIUserAgentPrefixes matches Codex CLI User-Agent patterns
|
||||
// Examples: "codex_vscode/1.0.0", "codex_cli_rs/0.1.2"
|
||||
var CodexCLIUserAgentPrefixes = []string{
|
||||
"codex_vscode/",
|
||||
"codex_cli_rs/",
|
||||
}
|
||||
|
||||
// IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request
|
||||
func IsCodexCLIRequest(userAgent string) bool {
|
||||
for _, prefix := range CodexCLIUserAgentPrefixes {
|
||||
if len(userAgent) >= len(prefix) && userAgent[:len(prefix)] == prefix {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -9,22 +9,22 @@ import (
|
||||
|
||||
// Response 标准API响应格式
|
||||
type Response struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data any `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// PaginatedData 分页数据格式(匹配前端期望)
|
||||
type PaginatedData struct {
|
||||
Items interface{} `json:"items"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
Pages int `json:"pages"`
|
||||
Items any `json:"items"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
Pages int `json:"pages"`
|
||||
}
|
||||
|
||||
// Success 返回成功响应
|
||||
func Success(c *gin.Context, data interface{}) {
|
||||
func Success(c *gin.Context, data any) {
|
||||
c.JSON(http.StatusOK, Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
@@ -33,7 +33,7 @@ func Success(c *gin.Context, data interface{}) {
|
||||
}
|
||||
|
||||
// Created 返回创建成功响应
|
||||
func Created(c *gin.Context, data interface{}) {
|
||||
func Created(c *gin.Context, data any) {
|
||||
c.JSON(http.StatusCreated, Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
@@ -75,7 +75,7 @@ func InternalError(c *gin.Context, message string) {
|
||||
}
|
||||
|
||||
// Paginated 返回分页数据
|
||||
func Paginated(c *gin.Context, items interface{}, total int64, page, pageSize int) {
|
||||
func Paginated(c *gin.Context, items any, total int64, page, pageSize int) {
|
||||
pages := int(math.Ceil(float64(total) / float64(pageSize)))
|
||||
if pages < 1 {
|
||||
pages = 1
|
||||
@@ -99,7 +99,7 @@ type PaginationResult struct {
|
||||
}
|
||||
|
||||
// PaginatedWithResult 使用PaginationResult返回分页数据
|
||||
func PaginatedWithResult(c *gin.Context, items interface{}, pagination *PaginationResult) {
|
||||
func PaginatedWithResult(c *gin.Context, items any, pagination *PaginationResult) {
|
||||
if pagination == nil {
|
||||
Success(c, PaginatedData{
|
||||
Items: items,
|
||||
|
||||
@@ -37,11 +37,15 @@ func TestInitInvalidTimezone(t *testing.T) {
|
||||
|
||||
func TestTimeNowAffected(t *testing.T) {
|
||||
// Reset to UTC first
|
||||
Init("UTC")
|
||||
if err := Init("UTC"); err != nil {
|
||||
t.Fatalf("Init failed with UTC: %v", err)
|
||||
}
|
||||
utcNow := time.Now()
|
||||
|
||||
// Switch to Shanghai (UTC+8)
|
||||
Init("Asia/Shanghai")
|
||||
if err := Init("Asia/Shanghai"); err != nil {
|
||||
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||
}
|
||||
shanghaiNow := time.Now()
|
||||
|
||||
// The times should be the same instant, but different timezone representation
|
||||
@@ -58,7 +62,9 @@ func TestTimeNowAffected(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestToday(t *testing.T) {
|
||||
Init("Asia/Shanghai")
|
||||
if err := Init("Asia/Shanghai"); err != nil {
|
||||
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||
}
|
||||
|
||||
today := Today()
|
||||
now := Now()
|
||||
@@ -75,7 +81,9 @@ func TestToday(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStartOfDay(t *testing.T) {
|
||||
Init("Asia/Shanghai")
|
||||
if err := Init("Asia/Shanghai"); err != nil {
|
||||
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||
}
|
||||
|
||||
// Create a time at 15:30:45
|
||||
testTime := time.Date(2024, 6, 15, 15, 30, 45, 123456789, Location())
|
||||
@@ -91,7 +99,9 @@ func TestTruncateVsStartOfDay(t *testing.T) {
|
||||
// This test demonstrates why Truncate(24*time.Hour) can be problematic
|
||||
// and why StartOfDay is more reliable for timezone-aware code
|
||||
|
||||
Init("Asia/Shanghai")
|
||||
if err := Init("Asia/Shanghai"); err != nil {
|
||||
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||
}
|
||||
|
||||
now := Now()
|
||||
|
||||
|
||||
@@ -23,14 +23,18 @@ func (r *AccountRepository) Create(ctx context.Context, account *model.Account)
|
||||
|
||||
func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) {
|
||||
var account model.Account
|
||||
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups").First(&account, id).Error
|
||||
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&account, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 填充 GroupIDs 虚拟字段
|
||||
// 填充 GroupIDs 和 Groups 虚拟字段
|
||||
account.GroupIDs = make([]int64, 0, len(account.AccountGroups))
|
||||
account.Groups = make([]*model.Group, 0, len(account.AccountGroups))
|
||||
for _, ag := range account.AccountGroups {
|
||||
account.GroupIDs = append(account.GroupIDs, ag.GroupID)
|
||||
if ag.Group != nil {
|
||||
account.Groups = append(account.Groups, ag.Group)
|
||||
}
|
||||
}
|
||||
return &account, nil
|
||||
}
|
||||
@@ -78,15 +82,19 @@ func (r *AccountRepository) ListWithFilters(ctx context.Context, params paginati
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := db.Preload("Proxy").Preload("AccountGroups").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&accounts).Error; err != nil {
|
||||
if err := db.Preload("Proxy").Preload("AccountGroups.Group").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&accounts).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 填充每个 Account 的 GroupIDs 虚拟字段
|
||||
// 填充每个 Account 的虚拟字段(GroupIDs 和 Groups)
|
||||
for i := range accounts {
|
||||
accounts[i].GroupIDs = make([]int64, 0, len(accounts[i].AccountGroups))
|
||||
accounts[i].Groups = make([]*model.Group, 0, len(accounts[i].AccountGroups))
|
||||
for _, ag := range accounts[i].AccountGroups {
|
||||
accounts[i].GroupIDs = append(accounts[i].GroupIDs, ag.GroupID)
|
||||
if ag.Group != nil {
|
||||
accounts[i].Groups = append(accounts[i].Groups, ag.Group)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,7 +139,7 @@ func (r *AccountRepository) UpdateLastUsed(ctx context.Context, id int64) error
|
||||
|
||||
func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
Updates(map[string]any{
|
||||
"status": model.StatusError,
|
||||
"error_message": errorMsg,
|
||||
}).Error
|
||||
@@ -222,11 +230,43 @@ func (r *AccountRepository) ListSchedulableByGroupID(ctx context.Context, groupI
|
||||
return accounts, err
|
||||
}
|
||||
|
||||
// ListSchedulableByPlatform 按平台获取可调度的账号
|
||||
func (r *AccountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
now := time.Now()
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("platform = ?", platform).
|
||||
Where("status = ? AND schedulable = ?", model.StatusActive, true).
|
||||
Where("(overload_until IS NULL OR overload_until <= ?)", now).
|
||||
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now).
|
||||
Preload("Proxy").
|
||||
Order("priority ASC").
|
||||
Find(&accounts).Error
|
||||
return accounts, err
|
||||
}
|
||||
|
||||
// ListSchedulableByGroupIDAndPlatform 按组和平台获取可调度的账号
|
||||
func (r *AccountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
now := time.Now()
|
||||
err := r.db.WithContext(ctx).
|
||||
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
|
||||
Where("account_groups.group_id = ?", groupID).
|
||||
Where("accounts.platform = ?", platform).
|
||||
Where("accounts.status = ? AND accounts.schedulable = ?", model.StatusActive, true).
|
||||
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
|
||||
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
|
||||
Preload("Proxy").
|
||||
Order("account_groups.priority ASC, accounts.priority ASC").
|
||||
Find(&accounts).Error
|
||||
return accounts, err
|
||||
}
|
||||
|
||||
// SetRateLimited 标记账号为限流状态(429)
|
||||
func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
now := time.Now()
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
Updates(map[string]any{
|
||||
"rate_limited_at": now,
|
||||
"rate_limit_reset_at": resetAt,
|
||||
}).Error
|
||||
@@ -241,7 +281,7 @@ func (r *AccountRepository) SetOverloaded(ctx context.Context, id int64, until t
|
||||
// ClearRateLimit 清除账号的限流状态
|
||||
func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
Updates(map[string]any{
|
||||
"rate_limited_at": nil,
|
||||
"rate_limit_reset_at": nil,
|
||||
"overload_until": nil,
|
||||
@@ -250,7 +290,7 @@ func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error
|
||||
|
||||
// UpdateSessionWindow 更新账号的5小时时间窗口信息
|
||||
func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
updates := map[string]interface{}{
|
||||
updates := map[string]any{
|
||||
"session_window_status": status,
|
||||
}
|
||||
if start != nil {
|
||||
@@ -267,3 +307,31 @@ func (r *AccountRepository) SetSchedulable(ctx context.Context, id int64, schedu
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
Update("schedulable", schedulable).Error
|
||||
}
|
||||
|
||||
// UpdateExtra updates specific fields in account's Extra JSONB field
|
||||
// It merges the updates into existing Extra data without overwriting other fields
|
||||
func (r *AccountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
|
||||
if len(updates) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get current account to preserve existing Extra data
|
||||
var account model.Account
|
||||
if err := r.db.WithContext(ctx).Select("extra").Where("id = ?", id).First(&account).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize Extra if nil
|
||||
if account.Extra == nil {
|
||||
account.Extra = make(model.JSONB)
|
||||
}
|
||||
|
||||
// Merge updates into existing Extra
|
||||
for k, v := range updates {
|
||||
account.Extra[k] = v
|
||||
}
|
||||
|
||||
// Save updated Extra
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
Update("extra", account.Extra).Error
|
||||
}
|
||||
|
||||
@@ -143,7 +143,7 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID
|
||||
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
fields := map[string]interface{}{
|
||||
fields := map[string]any{
|
||||
subFieldStatus: data.Status,
|
||||
subFieldExpiresAt: data.ExpiresAt.Unix(),
|
||||
subFieldDailyUsage: data.DailyUsage,
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/pkg/oauth"
|
||||
@@ -64,7 +65,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
|
||||
|
||||
authURL := fmt.Sprintf("https://claude.ai/v1/oauth/%s/authorize", orgUUID)
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
reqBody := map[string]any{
|
||||
"response_type": "code",
|
||||
"client_id": oauth.ClientID,
|
||||
"organization_uuid": orgUUID,
|
||||
@@ -139,23 +140,15 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
|
||||
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) {
|
||||
client := createReqClient(proxyURL)
|
||||
|
||||
// Parse code which may contain state in format "authCode#state"
|
||||
authCode := code
|
||||
codeState := ""
|
||||
if len(code) > 0 {
|
||||
parts := make([]string, 0, 2)
|
||||
for i, part := range []rune(code) {
|
||||
if part == '#' {
|
||||
authCode = code[:i]
|
||||
codeState = code[i+1:]
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(parts) == 0 {
|
||||
authCode = code
|
||||
}
|
||||
if idx := strings.Index(code, "#"); idx != -1 {
|
||||
authCode = code[:idx]
|
||||
codeState = code[idx+1:]
|
||||
}
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
reqBody := map[string]any{
|
||||
"code": authCode,
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": oauth.ClientID,
|
||||
|
||||
@@ -19,7 +19,11 @@ func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
|
||||
}
|
||||
|
||||
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport, ok := http.DefaultTransport.(*http.Transport)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to get default transport")
|
||||
}
|
||||
transport = transport.Clone()
|
||||
if proxyURL != "" {
|
||||
if parsedURL, err := url.Parse(proxyURL); err == nil {
|
||||
transport.Proxy = http.ProxyURL(parsedURL)
|
||||
@@ -43,7 +47,7 @@ func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyU
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
@@ -38,7 +38,7 @@ func (c *githubReleaseClient) FetchLatestRelease(ctx context.Context, repo strin
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode)
|
||||
@@ -63,7 +63,7 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("download returned %d", resp.StatusCode)
|
||||
@@ -78,7 +78,7 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
defer func() { _ = out.Close() }()
|
||||
|
||||
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
|
||||
limited := io.LimitReader(resp.Body, maxSize+1)
|
||||
@@ -89,7 +89,7 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
|
||||
|
||||
// Check if we hit the limit (downloaded more than maxSize)
|
||||
if written > maxSize {
|
||||
os.Remove(dest) // Clean up partial file
|
||||
_ = os.Remove(dest) // Clean up partial file (best-effort)
|
||||
return fmt.Errorf("download exceeded maximum size of %d bytes", maxSize)
|
||||
}
|
||||
|
||||
@@ -106,7 +106,7 @@ func (c *githubReleaseClient) FetchChecksumFile(ctx context.Context, url string)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
|
||||
@@ -6,15 +6,18 @@ import (
|
||||
"time"
|
||||
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/service"
|
||||
"sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
type claudeUpstreamService struct {
|
||||
// httpUpstreamService is a generic HTTP upstream service that can be used for
|
||||
// making requests to any HTTP API (Claude, OpenAI, etc.) with optional proxy support.
|
||||
type httpUpstreamService struct {
|
||||
defaultClient *http.Client
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func NewClaudeUpstream(cfg *config.Config) service.ClaudeUpstream {
|
||||
// NewHTTPUpstream creates a new generic HTTP upstream service
|
||||
func NewHTTPUpstream(cfg *config.Config) ports.HTTPUpstream {
|
||||
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
||||
if responseHeaderTimeout == 0 {
|
||||
responseHeaderTimeout = 300 * time.Second
|
||||
@@ -27,13 +30,13 @@ func NewClaudeUpstream(cfg *config.Config) service.ClaudeUpstream {
|
||||
ResponseHeaderTimeout: responseHeaderTimeout,
|
||||
}
|
||||
|
||||
return &claudeUpstreamService{
|
||||
return &httpUpstreamService{
|
||||
defaultClient: &http.Client{Transport: transport},
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *claudeUpstreamService) Do(req *http.Request, proxyURL string) (*http.Response, error) {
|
||||
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string) (*http.Response, error) {
|
||||
if proxyURL == "" {
|
||||
return s.defaultClient.Do(req)
|
||||
}
|
||||
@@ -41,7 +44,7 @@ func (s *claudeUpstreamService) Do(req *http.Request, proxyURL string) (*http.Re
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (s *claudeUpstreamService) createProxyClient(proxyURL string) *http.Client {
|
||||
func (s *httpUpstreamService) createProxyClient(proxyURL string) *http.Client {
|
||||
parsedURL, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
return s.defaultClient
|
||||
92
backend/internal/repository/openai_oauth_service.go
Normal file
92
backend/internal/repository/openai_oauth_service.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/pkg/openai"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
type openaiOAuthService struct{}
|
||||
|
||||
// NewOpenAIOAuthClient creates a new OpenAI OAuth client
|
||||
func NewOpenAIOAuthClient() ports.OpenAIOAuthClient {
|
||||
return &openaiOAuthService{}
|
||||
}
|
||||
|
||||
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
|
||||
client := createOpenAIReqClient(proxyURL)
|
||||
|
||||
if redirectURI == "" {
|
||||
redirectURI = openai.DefaultRedirectURI
|
||||
}
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "authorization_code")
|
||||
formData.Set("client_id", openai.ClientID)
|
||||
formData.Set("code", code)
|
||||
formData.Set("redirect_uri", redirectURI)
|
||||
formData.Set("code_verifier", codeVerifier)
|
||||
|
||||
var tokenResp openai.TokenResponse
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetFormDataFromValues(formData).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(openai.TokenURL)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
||||
client := createOpenAIReqClient(proxyURL)
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "refresh_token")
|
||||
formData.Set("refresh_token", refreshToken)
|
||||
formData.Set("client_id", openai.ClientID)
|
||||
formData.Set("scope", openai.RefreshScopes)
|
||||
|
||||
var tokenResp openai.TokenResponse
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetFormDataFromValues(formData).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(openai.TokenURL)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func createOpenAIReqClient(proxyURL string) *req.Client {
|
||||
client := req.C().
|
||||
SetTimeout(60 * time.Second)
|
||||
|
||||
if proxyURL != "" {
|
||||
client.SetProxyURL(proxyURL)
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
@@ -33,7 +33,7 @@ func (c *pricingRemoteClient) FetchPricingJSON(ctx context.Context, url string)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
@@ -52,7 +52,7 @@ func (c *pricingRemoteClient) FetchHashText(ctx context.Context, url string) (st
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
|
||||
@@ -43,7 +43,7 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("proxy connection failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
latencyMs := time.Since(startTime).Milliseconds()
|
||||
|
||||
|
||||
@@ -99,7 +99,7 @@ func (r *RedeemCodeRepository) Use(ctx context.Context, id, userID int64) error
|
||||
now := time.Now()
|
||||
result := r.db.WithContext(ctx).Model(&model.RedeemCode{}).
|
||||
Where("id = ? AND status = ?", id, model.StatusUnused).
|
||||
Updates(map[string]interface{}{
|
||||
Updates(map[string]any{
|
||||
"status": model.StatusUsed,
|
||||
"used_by": userID,
|
||||
"used_at": now,
|
||||
|
||||
@@ -44,7 +44,7 @@ func (v *turnstileVerifier) VerifyToken(ctx context.Context, secretKey, token, r
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
var result service.TurnstileVerifyResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
|
||||
@@ -937,7 +937,7 @@ func (r *UsageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
|
||||
}
|
||||
|
||||
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters
|
||||
func (r *UsageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID int64) ([]ModelStat, error) {
|
||||
func (r *UsageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) {
|
||||
var results []ModelStat
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
@@ -958,6 +958,9 @@ func (r *UsageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
|
||||
if apiKeyID > 0 {
|
||||
db = db.Where("api_key_id = ?", apiKeyID)
|
||||
}
|
||||
if accountID > 0 {
|
||||
db = db.Where("account_id = ?", accountID)
|
||||
}
|
||||
|
||||
err := db.Group("model").Order("total_tokens DESC").Scan(&results).Error
|
||||
if err != nil {
|
||||
@@ -1007,3 +1010,209 @@ func (r *UsageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
|
||||
AverageDurationMs: stats.AverageDurationMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AccountUsageHistory represents daily usage history for an account
|
||||
type AccountUsageHistory struct {
|
||||
Date string `json:"date"`
|
||||
Label string `json:"label"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
ActualCost float64 `json:"actual_cost"`
|
||||
}
|
||||
|
||||
// AccountUsageSummary represents summary statistics for an account
|
||||
type AccountUsageSummary struct {
|
||||
Days int `json:"days"`
|
||||
ActualDaysUsed int `json:"actual_days_used"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
TotalStandardCost float64 `json:"total_standard_cost"`
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
AvgDailyCost float64 `json:"avg_daily_cost"`
|
||||
AvgDailyRequests float64 `json:"avg_daily_requests"`
|
||||
AvgDailyTokens float64 `json:"avg_daily_tokens"`
|
||||
AvgDurationMs float64 `json:"avg_duration_ms"`
|
||||
Today *struct {
|
||||
Date string `json:"date"`
|
||||
Cost float64 `json:"cost"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
} `json:"today"`
|
||||
HighestCostDay *struct {
|
||||
Date string `json:"date"`
|
||||
Label string `json:"label"`
|
||||
Cost float64 `json:"cost"`
|
||||
Requests int64 `json:"requests"`
|
||||
} `json:"highest_cost_day"`
|
||||
HighestRequestDay *struct {
|
||||
Date string `json:"date"`
|
||||
Label string `json:"label"`
|
||||
Requests int64 `json:"requests"`
|
||||
Cost float64 `json:"cost"`
|
||||
} `json:"highest_request_day"`
|
||||
}
|
||||
|
||||
// AccountUsageStatsResponse represents the full usage statistics response for an account
|
||||
type AccountUsageStatsResponse struct {
|
||||
History []AccountUsageHistory `json:"history"`
|
||||
Summary AccountUsageSummary `json:"summary"`
|
||||
Models []ModelStat `json:"models"`
|
||||
}
|
||||
|
||||
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
|
||||
func (r *UsageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*AccountUsageStatsResponse, error) {
|
||||
daysCount := int(endTime.Sub(startTime).Hours()/24) + 1
|
||||
if daysCount <= 0 {
|
||||
daysCount = 30
|
||||
}
|
||||
|
||||
// Get daily history
|
||||
var historyResults []struct {
|
||||
Date string `gorm:"column:date"`
|
||||
Requests int64 `gorm:"column:requests"`
|
||||
Tokens int64 `gorm:"column:tokens"`
|
||||
Cost float64 `gorm:"column:cost"`
|
||||
ActualCost float64 `gorm:"column:actual_cost"`
|
||||
}
|
||||
|
||||
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
Select(`
|
||||
TO_CHAR(created_at, 'YYYY-MM-DD') as date,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||
`).
|
||||
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
|
||||
Group("date").
|
||||
Order("date ASC").
|
||||
Scan(&historyResults).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Build history with labels
|
||||
history := make([]AccountUsageHistory, 0, len(historyResults))
|
||||
for _, h := range historyResults {
|
||||
// Parse date to get label (MM/DD)
|
||||
t, _ := time.Parse("2006-01-02", h.Date)
|
||||
label := t.Format("01/02")
|
||||
history = append(history, AccountUsageHistory{
|
||||
Date: h.Date,
|
||||
Label: label,
|
||||
Requests: h.Requests,
|
||||
Tokens: h.Tokens,
|
||||
Cost: h.Cost,
|
||||
ActualCost: h.ActualCost,
|
||||
})
|
||||
}
|
||||
|
||||
// Calculate summary
|
||||
var totalActualCost, totalStandardCost float64
|
||||
var totalRequests, totalTokens int64
|
||||
var highestCostDay, highestRequestDay *AccountUsageHistory
|
||||
|
||||
for i := range history {
|
||||
h := &history[i]
|
||||
totalActualCost += h.ActualCost
|
||||
totalStandardCost += h.Cost
|
||||
totalRequests += h.Requests
|
||||
totalTokens += h.Tokens
|
||||
|
||||
if highestCostDay == nil || h.ActualCost > highestCostDay.ActualCost {
|
||||
highestCostDay = h
|
||||
}
|
||||
if highestRequestDay == nil || h.Requests > highestRequestDay.Requests {
|
||||
highestRequestDay = h
|
||||
}
|
||||
}
|
||||
|
||||
actualDaysUsed := len(history)
|
||||
if actualDaysUsed == 0 {
|
||||
actualDaysUsed = 1
|
||||
}
|
||||
|
||||
// Get average duration
|
||||
var avgDuration struct {
|
||||
AvgDurationMs float64 `gorm:"column:avg_duration_ms"`
|
||||
}
|
||||
r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
Select("COALESCE(AVG(duration_ms), 0) as avg_duration_ms").
|
||||
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
|
||||
Scan(&avgDuration)
|
||||
|
||||
summary := AccountUsageSummary{
|
||||
Days: daysCount,
|
||||
ActualDaysUsed: actualDaysUsed,
|
||||
TotalCost: totalActualCost,
|
||||
TotalStandardCost: totalStandardCost,
|
||||
TotalRequests: totalRequests,
|
||||
TotalTokens: totalTokens,
|
||||
AvgDailyCost: totalActualCost / float64(actualDaysUsed),
|
||||
AvgDailyRequests: float64(totalRequests) / float64(actualDaysUsed),
|
||||
AvgDailyTokens: float64(totalTokens) / float64(actualDaysUsed),
|
||||
AvgDurationMs: avgDuration.AvgDurationMs,
|
||||
}
|
||||
|
||||
// Set today's stats
|
||||
todayStr := timezone.Now().Format("2006-01-02")
|
||||
for i := range history {
|
||||
if history[i].Date == todayStr {
|
||||
summary.Today = &struct {
|
||||
Date string `json:"date"`
|
||||
Cost float64 `json:"cost"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
}{
|
||||
Date: history[i].Date,
|
||||
Cost: history[i].ActualCost,
|
||||
Requests: history[i].Requests,
|
||||
Tokens: history[i].Tokens,
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Set highest cost day
|
||||
if highestCostDay != nil {
|
||||
summary.HighestCostDay = &struct {
|
||||
Date string `json:"date"`
|
||||
Label string `json:"label"`
|
||||
Cost float64 `json:"cost"`
|
||||
Requests int64 `json:"requests"`
|
||||
}{
|
||||
Date: highestCostDay.Date,
|
||||
Label: highestCostDay.Label,
|
||||
Cost: highestCostDay.ActualCost,
|
||||
Requests: highestCostDay.Requests,
|
||||
}
|
||||
}
|
||||
|
||||
// Set highest request day
|
||||
if highestRequestDay != nil {
|
||||
summary.HighestRequestDay = &struct {
|
||||
Date string `json:"date"`
|
||||
Label string `json:"label"`
|
||||
Requests int64 `json:"requests"`
|
||||
Cost float64 `json:"cost"`
|
||||
}{
|
||||
Date: highestRequestDay.Date,
|
||||
Label: highestRequestDay.Label,
|
||||
Requests: highestRequestDay.Requests,
|
||||
Cost: highestRequestDay.ActualCost,
|
||||
}
|
||||
}
|
||||
|
||||
// Get model statistics using the unified method
|
||||
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID)
|
||||
if err != nil {
|
||||
models = []ModelStat{}
|
||||
}
|
||||
|
||||
return &AccountUsageStatsResponse{
|
||||
History: history,
|
||||
Summary: summary,
|
||||
Models: models,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -66,17 +66,47 @@ func (r *UserRepository) ListWithFilters(ctx context.Context, params pagination.
|
||||
}
|
||||
if search != "" {
|
||||
searchPattern := "%" + search + "%"
|
||||
db = db.Where("email ILIKE ?", searchPattern)
|
||||
db = db.Where(
|
||||
"email ILIKE ? OR username ILIKE ? OR wechat ILIKE ?",
|
||||
searchPattern, searchPattern, searchPattern,
|
||||
)
|
||||
}
|
||||
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Query users with pagination (reuse the same db with filters applied)
|
||||
if err := db.Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&users).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Batch load subscriptions for all users (avoid N+1)
|
||||
if len(users) > 0 {
|
||||
userIDs := make([]int64, len(users))
|
||||
userMap := make(map[int64]*model.User, len(users))
|
||||
for i := range users {
|
||||
userIDs[i] = users[i].ID
|
||||
userMap[users[i].ID] = &users[i]
|
||||
}
|
||||
|
||||
// Query active subscriptions with groups in one query
|
||||
var subscriptions []model.UserSubscription
|
||||
if err := r.db.WithContext(ctx).
|
||||
Preload("Group").
|
||||
Where("user_id IN ? AND status = ?", userIDs, model.SubscriptionStatusActive).
|
||||
Find(&subscriptions).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Associate subscriptions with users
|
||||
for i := range subscriptions {
|
||||
if user, ok := userMap[subscriptions[i].UserID]; ok {
|
||||
user.Subscriptions = append(user.Subscriptions, subscriptions[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pages := int(total) / params.Limit()
|
||||
if int(total)%params.Limit() > 0 {
|
||||
pages++
|
||||
@@ -128,3 +158,16 @@ func (r *UserRepository) RemoveGroupFromAllowedGroups(ctx context.Context, group
|
||||
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID))
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// GetFirstAdmin 获取第一个管理员用户(用于 Admin API Key 认证)
|
||||
func (r *UserRepository) GetFirstAdmin(ctx context.Context) (*model.User, error) {
|
||||
var user model.User
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("role = ? AND status = ?", model.RoleAdmin, model.StatusActive).
|
||||
Order("id ASC").
|
||||
First(&user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
@@ -185,7 +185,7 @@ func (r *UserSubscriptionRepository) List(ctx context.Context, params pagination
|
||||
func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
Updates(map[string]any{
|
||||
"daily_usage_usd": gorm.Expr("daily_usage_usd + ?", costUSD),
|
||||
"weekly_usage_usd": gorm.Expr("weekly_usage_usd + ?", costUSD),
|
||||
"monthly_usage_usd": gorm.Expr("monthly_usage_usd + ?", costUSD),
|
||||
@@ -197,7 +197,7 @@ func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
|
||||
func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
Updates(map[string]any{
|
||||
"daily_usage_usd": 0,
|
||||
"daily_window_start": newWindowStart,
|
||||
"updated_at": time.Now(),
|
||||
@@ -208,7 +208,7 @@ func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int
|
||||
func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
Updates(map[string]any{
|
||||
"weekly_usage_usd": 0,
|
||||
"weekly_window_start": newWindowStart,
|
||||
"updated_at": time.Now(),
|
||||
@@ -219,7 +219,7 @@ func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id in
|
||||
func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
Updates(map[string]any{
|
||||
"monthly_usage_usd": 0,
|
||||
"monthly_window_start": newWindowStart,
|
||||
"updated_at": time.Now(),
|
||||
@@ -230,7 +230,7 @@ func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id i
|
||||
func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
Updates(map[string]any{
|
||||
"daily_window_start": activateTime,
|
||||
"weekly_window_start": activateTime,
|
||||
"monthly_window_start": activateTime,
|
||||
@@ -242,7 +242,7 @@ func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int
|
||||
func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
Updates(map[string]any{
|
||||
"status": status,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
@@ -252,7 +252,7 @@ func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64,
|
||||
func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
Updates(map[string]any{
|
||||
"expires_at": newExpiresAt,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
@@ -262,7 +262,7 @@ func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64,
|
||||
func (r *UserSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
Updates(map[string]any{
|
||||
"notes": notes,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
@@ -281,7 +281,7 @@ func (r *UserSubscriptionRepository) ListExpired(ctx context.Context) ([]model.U
|
||||
func (r *UserSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
|
||||
result := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
|
||||
Updates(map[string]interface{}{
|
||||
Updates(map[string]any{
|
||||
"status": model.SubscriptionStatusExpired,
|
||||
"updated_at": time.Now(),
|
||||
})
|
||||
|
||||
@@ -36,7 +36,8 @@ var ProviderSet = wire.NewSet(
|
||||
NewProxyExitInfoProber,
|
||||
NewClaudeUsageFetcher,
|
||||
NewClaudeOAuthClient,
|
||||
NewClaudeUpstream,
|
||||
NewHTTPUpstream,
|
||||
NewOpenAIOAuthClient,
|
||||
|
||||
// Bind concrete repositories to service port interfaces
|
||||
wire.Bind(new(ports.UserRepository), new(*UserRepository)),
|
||||
|
||||
@@ -82,6 +82,7 @@ func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, rep
|
||||
{
|
||||
user.GET("/profile", h.User.GetProfile)
|
||||
user.PUT("/password", h.User.ChangePassword)
|
||||
user.PUT("", h.User.UpdateProfile)
|
||||
}
|
||||
|
||||
// API Key管理
|
||||
@@ -132,7 +133,7 @@ func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, rep
|
||||
|
||||
// 管理员接口
|
||||
admin := v1.Group("/admin")
|
||||
admin.Use(middleware.JWTAuth(s.Auth, repos.User), middleware.AdminOnly())
|
||||
admin.Use(middleware.AdminAuth(s.Auth, repos.User, s.Setting))
|
||||
{
|
||||
// 仪表盘
|
||||
dashboard := admin.Group("/dashboard")
|
||||
@@ -192,7 +193,7 @@ func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, rep
|
||||
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
|
||||
accounts.POST("/batch", h.Admin.Account.BatchCreate)
|
||||
|
||||
// OAuth routes
|
||||
// Claude OAuth routes
|
||||
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
|
||||
accounts.POST("/generate-setup-token-url", h.Admin.OAuth.GenerateSetupTokenURL)
|
||||
accounts.POST("/exchange-code", h.Admin.OAuth.ExchangeCode)
|
||||
@@ -201,6 +202,16 @@ func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, rep
|
||||
accounts.POST("/setup-token-cookie-auth", h.Admin.OAuth.SetupTokenCookieAuth)
|
||||
}
|
||||
|
||||
// OpenAI OAuth routes
|
||||
openai := admin.Group("/openai")
|
||||
{
|
||||
openai.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL)
|
||||
openai.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode)
|
||||
openai.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken)
|
||||
openai.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken)
|
||||
openai.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth)
|
||||
}
|
||||
|
||||
// 代理管理
|
||||
proxies := admin.Group("/proxies")
|
||||
{
|
||||
@@ -236,6 +247,10 @@ func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, rep
|
||||
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
|
||||
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection)
|
||||
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
|
||||
// Admin API Key 管理
|
||||
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminApiKey)
|
||||
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminApiKey)
|
||||
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminApiKey)
|
||||
}
|
||||
|
||||
// 系统管理
|
||||
@@ -285,5 +300,10 @@ func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, rep
|
||||
gateway.POST("/messages/count_tokens", h.Gateway.CountTokens)
|
||||
gateway.GET("/models", h.Gateway.Models)
|
||||
gateway.GET("/usage", h.Gateway.Usage)
|
||||
// OpenAI Responses API
|
||||
gateway.POST("/responses", h.OpenAIGateway.Responses)
|
||||
}
|
||||
|
||||
// OpenAI Responses API(不带v1前缀的别名)
|
||||
r.POST("/responses", middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription), h.OpenAIGateway.Responses)
|
||||
}
|
||||
|
||||
@@ -17,27 +17,27 @@ var (
|
||||
|
||||
// CreateAccountRequest 创建账号请求
|
||||
type CreateAccountRequest struct {
|
||||
Name string `json:"name"`
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
Credentials map[string]interface{} `json:"credentials"`
|
||||
Extra map[string]interface{} `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
Name string `json:"name"`
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
}
|
||||
|
||||
// UpdateAccountRequest 更新账号请求
|
||||
type UpdateAccountRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Credentials *map[string]interface{} `json:"credentials"`
|
||||
Extra *map[string]interface{} `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
Status *string `json:"status"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
Name *string `json:"name"`
|
||||
Credentials *map[string]any `json:"credentials"`
|
||||
Extra *map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
Status *string `json:"status"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
}
|
||||
|
||||
// AccountService 账号管理服务
|
||||
|
||||
@@ -14,7 +14,9 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/claude"
|
||||
"sub2api/internal/pkg/openai"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -22,7 +24,9 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
|
||||
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
|
||||
testOpenAIAPIURL = "https://api.openai.com/v1/responses"
|
||||
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
|
||||
)
|
||||
|
||||
// TestEvent represents a SSE event for account testing
|
||||
@@ -36,37 +40,46 @@ type TestEvent struct {
|
||||
|
||||
// AccountTestService handles account testing operations
|
||||
type AccountTestService struct {
|
||||
accountRepo ports.AccountRepository
|
||||
oauthService *OAuthService
|
||||
claudeUpstream ClaudeUpstream
|
||||
accountRepo ports.AccountRepository
|
||||
oauthService *OAuthService
|
||||
openaiOAuthService *OpenAIOAuthService
|
||||
httpUpstream ports.HTTPUpstream
|
||||
}
|
||||
|
||||
// NewAccountTestService creates a new AccountTestService
|
||||
func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OAuthService, claudeUpstream ClaudeUpstream) *AccountTestService {
|
||||
func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OAuthService, openaiOAuthService *OpenAIOAuthService, httpUpstream ports.HTTPUpstream) *AccountTestService {
|
||||
return &AccountTestService{
|
||||
accountRepo: accountRepo,
|
||||
oauthService: oauthService,
|
||||
claudeUpstream: claudeUpstream,
|
||||
accountRepo: accountRepo,
|
||||
oauthService: oauthService,
|
||||
openaiOAuthService: openaiOAuthService,
|
||||
httpUpstream: httpUpstream,
|
||||
}
|
||||
}
|
||||
|
||||
// generateSessionString generates a Claude Code style session string
|
||||
func generateSessionString() string {
|
||||
func generateSessionString() (string, error) {
|
||||
bytes := make([]byte, 32)
|
||||
rand.Read(bytes)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
hex64 := hex.EncodeToString(bytes)
|
||||
sessionUUID := uuid.New().String()
|
||||
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID)
|
||||
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID), nil
|
||||
}
|
||||
|
||||
// createTestPayload creates a Claude Code style test request payload
|
||||
func createTestPayload(modelID string) map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
func createTestPayload(modelID string) (map[string]any, error) {
|
||||
sessionID, err := generateSessionString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"model": modelID,
|
||||
"messages": []map[string]interface{}{
|
||||
"messages": []map[string]any{
|
||||
{
|
||||
"role": "user",
|
||||
"content": []map[string]interface{}{
|
||||
"content": []map[string]any{
|
||||
{
|
||||
"type": "text",
|
||||
"text": "hi",
|
||||
@@ -77,7 +90,7 @@ func createTestPayload(modelID string) map[string]interface{} {
|
||||
},
|
||||
},
|
||||
},
|
||||
"system": []map[string]interface{}{
|
||||
"system": []map[string]any{
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
@@ -87,12 +100,12 @@ func createTestPayload(modelID string) map[string]interface{} {
|
||||
},
|
||||
},
|
||||
"metadata": map[string]string{
|
||||
"user_id": generateSessionString(),
|
||||
"user_id": sessionID,
|
||||
},
|
||||
"max_tokens": 1024,
|
||||
"temperature": 1,
|
||||
"stream": true,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestAccountConnection tests an account's connection by sending a test request
|
||||
@@ -107,6 +120,18 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
return s.sendErrorAndEnd(c, "Account not found")
|
||||
}
|
||||
|
||||
// Route to platform-specific test method
|
||||
if account.IsOpenAI() {
|
||||
return s.testOpenAIAccountConnection(c, account, modelID)
|
||||
}
|
||||
|
||||
return s.testClaudeAccountConnection(c, account, modelID)
|
||||
}
|
||||
|
||||
// testClaudeAccountConnection tests an Anthropic Claude account's connection
|
||||
func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account *model.Account, modelID string) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Determine the model to use
|
||||
testModelID := modelID
|
||||
if testModelID == "" {
|
||||
@@ -116,7 +141,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
// For API Key accounts with model mapping, map the model
|
||||
if account.Type == "apikey" {
|
||||
mapping := account.GetModelMapping()
|
||||
if mapping != nil && len(mapping) > 0 {
|
||||
if len(mapping) > 0 {
|
||||
if mappedModel, exists := mapping[testModelID]; exists {
|
||||
testModelID = mappedModel
|
||||
}
|
||||
@@ -178,7 +203,10 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
c.Writer.Flush()
|
||||
|
||||
// Create Claude Code style payload (same for all account types)
|
||||
payload := createTestPayload(testModelID)
|
||||
payload, err := createTestPayload(testModelID)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, "Failed to create test payload")
|
||||
}
|
||||
payloadBytes, _ := json.Marshal(payload)
|
||||
|
||||
// Send test_start event
|
||||
@@ -212,11 +240,11 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.claudeUpstream.Do(req, proxyURL)
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -224,11 +252,153 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
}
|
||||
|
||||
// Process SSE stream
|
||||
return s.processStream(c, resp.Body)
|
||||
return s.processClaudeStream(c, resp.Body)
|
||||
}
|
||||
|
||||
// processStream processes the SSE stream from Claude API
|
||||
func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error {
|
||||
// testOpenAIAccountConnection tests an OpenAI account's connection
|
||||
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *model.Account, modelID string) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Default to openai.DefaultTestModel for OpenAI testing
|
||||
testModelID := modelID
|
||||
if testModelID == "" {
|
||||
testModelID = openai.DefaultTestModel
|
||||
}
|
||||
|
||||
// For API Key accounts with model mapping, map the model
|
||||
if account.Type == "apikey" {
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) > 0 {
|
||||
if mappedModel, exists := mapping[testModelID]; exists {
|
||||
testModelID = mappedModel
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Determine authentication method and API URL
|
||||
var authToken string
|
||||
var apiURL string
|
||||
var isOAuth bool
|
||||
var chatgptAccountID string
|
||||
|
||||
if account.IsOAuth() {
|
||||
isOAuth = true
|
||||
// OAuth - use Bearer token with ChatGPT internal API
|
||||
authToken = account.GetOpenAIAccessToken()
|
||||
if authToken == "" {
|
||||
return s.sendErrorAndEnd(c, "No access token available")
|
||||
}
|
||||
|
||||
// Check if token is expired and refresh if needed
|
||||
if account.IsOpenAITokenExpired() && s.openaiOAuthService != nil {
|
||||
tokenInfo, err := s.openaiOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to refresh token: %s", err.Error()))
|
||||
}
|
||||
authToken = tokenInfo.AccessToken
|
||||
}
|
||||
|
||||
// OAuth uses ChatGPT internal API
|
||||
apiURL = chatgptCodexAPIURL
|
||||
chatgptAccountID = account.GetChatGPTAccountID()
|
||||
} else if account.Type == "apikey" {
|
||||
// API Key - use Platform API
|
||||
authToken = account.GetOpenAIApiKey()
|
||||
if authToken == "" {
|
||||
return s.sendErrorAndEnd(c, "No API key available")
|
||||
}
|
||||
|
||||
baseURL := account.GetOpenAIBaseURL()
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com"
|
||||
}
|
||||
apiURL = strings.TrimSuffix(baseURL, "/") + "/v1/responses"
|
||||
} else {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
||||
}
|
||||
|
||||
// Set SSE headers
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.Flush()
|
||||
|
||||
// Create OpenAI Responses API payload
|
||||
payload := createOpenAITestPayload(testModelID, isOAuth)
|
||||
payloadBytes, _ := json.Marshal(payload)
|
||||
|
||||
// Send test_start event
|
||||
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, "Failed to create request")
|
||||
}
|
||||
|
||||
// Set common headers
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+authToken)
|
||||
|
||||
// Set OAuth-specific headers for ChatGPT internal API
|
||||
if isOAuth {
|
||||
req.Host = "chatgpt.com"
|
||||
req.Header.Set("accept", "text/event-stream")
|
||||
if chatgptAccountID != "" {
|
||||
req.Header.Set("chatgpt-account-id", chatgptAccountID)
|
||||
}
|
||||
}
|
||||
|
||||
// Get proxy URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
|
||||
}
|
||||
|
||||
// Process SSE stream
|
||||
return s.processOpenAIStream(c, resp.Body)
|
||||
}
|
||||
|
||||
// createOpenAITestPayload creates a test payload for OpenAI Responses API
|
||||
func createOpenAITestPayload(modelID string, isOAuth bool) map[string]any {
|
||||
payload := map[string]any{
|
||||
"model": modelID,
|
||||
"input": []map[string]any{
|
||||
{
|
||||
"role": "user",
|
||||
"content": []map[string]any{
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": "hi",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"stream": true,
|
||||
}
|
||||
|
||||
// OAuth accounts using ChatGPT internal API require store: false and instructions
|
||||
if isOAuth {
|
||||
payload["store"] = false
|
||||
payload["instructions"] = openai.DefaultInstructions
|
||||
}
|
||||
|
||||
return payload
|
||||
}
|
||||
|
||||
// processClaudeStream processes the SSE stream from Claude API
|
||||
func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader) error {
|
||||
reader := bufio.NewReader(body)
|
||||
|
||||
for {
|
||||
@@ -252,7 +422,7 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
|
||||
return nil
|
||||
}
|
||||
|
||||
var data map[string]interface{}
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
|
||||
continue
|
||||
}
|
||||
@@ -261,7 +431,7 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
|
||||
|
||||
switch eventType {
|
||||
case "content_block_delta":
|
||||
if delta, ok := data["delta"].(map[string]interface{}); ok {
|
||||
if delta, ok := data["delta"].(map[string]any); ok {
|
||||
if text, ok := delta["text"].(string); ok {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: text})
|
||||
}
|
||||
@@ -271,7 +441,60 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
|
||||
return nil
|
||||
case "error":
|
||||
errorMsg := "Unknown error"
|
||||
if errData, ok := data["error"].(map[string]interface{}); ok {
|
||||
if errData, ok := data["error"].(map[string]any); ok {
|
||||
if msg, ok := errData["message"].(string); ok {
|
||||
errorMsg = msg
|
||||
}
|
||||
}
|
||||
return s.sendErrorAndEnd(c, errorMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processOpenAIStream processes the SSE stream from OpenAI Responses API
|
||||
func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader) error {
|
||||
reader := bufio.NewReader(body)
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error()))
|
||||
}
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
jsonStr := strings.TrimPrefix(line, "data: ")
|
||||
if jsonStr == "[DONE]" {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
eventType, _ := data["type"].(string)
|
||||
|
||||
switch eventType {
|
||||
case "response.output_text.delta":
|
||||
// OpenAI Responses API uses "delta" field for text content
|
||||
if delta, ok := data["delta"].(string); ok && delta != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: delta})
|
||||
}
|
||||
case "response.completed":
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
case "error":
|
||||
errorMsg := "Unknown error"
|
||||
if errData, ok := data["error"].(map[string]any); ok {
|
||||
if msg, ok := errData["message"].(string); ok {
|
||||
errorMsg = msg
|
||||
}
|
||||
@@ -284,7 +507,10 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
|
||||
// sendEvent sends a SSE event to the client
|
||||
func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
|
||||
eventJSON, _ := json.Marshal(event)
|
||||
fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON)
|
||||
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil {
|
||||
log.Printf("failed to write SSE event: %v", err)
|
||||
return
|
||||
}
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
|
||||
@@ -70,16 +70,14 @@ type ClaudeUsageFetcher interface {
|
||||
type AccountUsageService struct {
|
||||
accountRepo ports.AccountRepository
|
||||
usageLogRepo ports.UsageLogRepository
|
||||
oauthService *OAuthService
|
||||
usageFetcher ClaudeUsageFetcher
|
||||
}
|
||||
|
||||
// NewAccountUsageService 创建AccountUsageService实例
|
||||
func NewAccountUsageService(accountRepo ports.AccountRepository, usageLogRepo ports.UsageLogRepository, oauthService *OAuthService, usageFetcher ClaudeUsageFetcher) *AccountUsageService {
|
||||
func NewAccountUsageService(accountRepo ports.AccountRepository, usageLogRepo ports.UsageLogRepository, usageFetcher ClaudeUsageFetcher) *AccountUsageService {
|
||||
return &AccountUsageService{
|
||||
accountRepo: accountRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
oauthService: oauthService,
|
||||
usageFetcher: usageFetcher,
|
||||
}
|
||||
}
|
||||
@@ -98,8 +96,10 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
||||
if account.CanGetUsage() {
|
||||
// 检查缓存
|
||||
if cached, ok := usageCacheMap.Load(accountID); ok {
|
||||
cache := cached.(*usageCache)
|
||||
if time.Since(cache.timestamp) < cacheTTL {
|
||||
cache, ok := cached.(*usageCache)
|
||||
if !ok {
|
||||
usageCacheMap.Delete(accountID)
|
||||
} else if time.Since(cache.timestamp) < cacheTTL {
|
||||
return cache.data, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
@@ -21,9 +22,9 @@ type AdminService interface {
|
||||
CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error)
|
||||
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*model.User, error)
|
||||
DeleteUser(ctx context.Context, id int64) error
|
||||
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string) (*model.User, error)
|
||||
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*model.User, error)
|
||||
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error)
|
||||
GetUserUsageStats(ctx context.Context, userID int64, period string) (interface{}, error)
|
||||
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
|
||||
|
||||
// Group management
|
||||
ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error)
|
||||
@@ -70,6 +71,9 @@ type AdminService interface {
|
||||
type CreateUserInput struct {
|
||||
Email string
|
||||
Password string
|
||||
Username string
|
||||
Wechat string
|
||||
Notes string
|
||||
Balance float64
|
||||
Concurrency int
|
||||
AllowedGroups []int64
|
||||
@@ -78,6 +82,9 @@ type CreateUserInput struct {
|
||||
type UpdateUserInput struct {
|
||||
Email string
|
||||
Password string
|
||||
Username *string
|
||||
Wechat *string
|
||||
Notes *string
|
||||
Balance *float64 // 使用指针区分"未提供"和"设置为0"
|
||||
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
||||
Status string
|
||||
@@ -113,8 +120,8 @@ type CreateAccountInput struct {
|
||||
Name string
|
||||
Platform string
|
||||
Type string
|
||||
Credentials map[string]interface{}
|
||||
Extra map[string]interface{}
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
ProxyID *int64
|
||||
Concurrency int
|
||||
Priority int
|
||||
@@ -124,8 +131,8 @@ type CreateAccountInput struct {
|
||||
type UpdateAccountInput struct {
|
||||
Name string
|
||||
Type string // Account type: oauth, setup-token, apikey
|
||||
Credentials map[string]interface{}
|
||||
Extra map[string]interface{}
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
ProxyID *int64
|
||||
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
||||
Priority *int // 使用指针区分"未提供"和"设置为0"
|
||||
@@ -192,8 +199,6 @@ type adminServiceImpl struct {
|
||||
proxyRepo ports.ProxyRepository
|
||||
apiKeyRepo ports.ApiKeyRepository
|
||||
redeemCodeRepo ports.RedeemCodeRepository
|
||||
usageLogRepo ports.UsageLogRepository
|
||||
userSubRepo ports.UserSubscriptionRepository
|
||||
billingCacheService *BillingCacheService
|
||||
proxyProber ProxyExitInfoProber
|
||||
}
|
||||
@@ -206,8 +211,6 @@ func NewAdminService(
|
||||
proxyRepo ports.ProxyRepository,
|
||||
apiKeyRepo ports.ApiKeyRepository,
|
||||
redeemCodeRepo ports.RedeemCodeRepository,
|
||||
usageLogRepo ports.UsageLogRepository,
|
||||
userSubRepo ports.UserSubscriptionRepository,
|
||||
billingCacheService *BillingCacheService,
|
||||
proxyProber ProxyExitInfoProber,
|
||||
) AdminService {
|
||||
@@ -218,8 +221,6 @@ func NewAdminService(
|
||||
proxyRepo: proxyRepo,
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
redeemCodeRepo: redeemCodeRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
billingCacheService: billingCacheService,
|
||||
proxyProber: proxyProber,
|
||||
}
|
||||
@@ -242,6 +243,9 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*model.User,
|
||||
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error) {
|
||||
user := &model.User{
|
||||
Email: input.Email,
|
||||
Username: input.Username,
|
||||
Wechat: input.Wechat,
|
||||
Notes: input.Notes,
|
||||
Role: "user", // Always create as regular user, never admin
|
||||
Balance: input.Balance,
|
||||
Concurrency: input.Concurrency,
|
||||
@@ -267,8 +271,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
return nil, errors.New("cannot disable admin user")
|
||||
}
|
||||
|
||||
// Track balance and concurrency changes for logging
|
||||
oldBalance := user.Balance
|
||||
oldConcurrency := user.Concurrency
|
||||
|
||||
if input.Email != "" {
|
||||
@@ -279,22 +281,25 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// Role is not allowed to be changed via API to prevent privilege escalation
|
||||
|
||||
if input.Username != nil {
|
||||
user.Username = *input.Username
|
||||
}
|
||||
if input.Wechat != nil {
|
||||
user.Wechat = *input.Wechat
|
||||
}
|
||||
if input.Notes != nil {
|
||||
user.Notes = *input.Notes
|
||||
}
|
||||
|
||||
if input.Status != "" {
|
||||
user.Status = input.Status
|
||||
}
|
||||
|
||||
// 只在指针非 nil 时更新 Balance(支持设置为 0)
|
||||
if input.Balance != nil {
|
||||
user.Balance = *input.Balance
|
||||
}
|
||||
|
||||
// 只在指针非 nil 时更新 Concurrency(支持设置为任意值)
|
||||
if input.Concurrency != nil {
|
||||
user.Concurrency = *input.Concurrency
|
||||
}
|
||||
|
||||
// 只在指针非 nil 时更新 AllowedGroups
|
||||
if input.AllowedGroups != nil {
|
||||
user.AllowedGroups = *input.AllowedGroups
|
||||
}
|
||||
@@ -303,39 +308,15 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 余额变化时失效缓存
|
||||
if input.Balance != nil && *input.Balance != oldBalance {
|
||||
if s.billingCacheService != nil {
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateUserBalance(cacheCtx, id)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// Create adjustment records for balance/concurrency changes
|
||||
balanceDiff := user.Balance - oldBalance
|
||||
if balanceDiff != 0 {
|
||||
adjustmentRecord := &model.RedeemCode{
|
||||
Code: model.GenerateRedeemCode(),
|
||||
Type: model.AdjustmentTypeAdminBalance,
|
||||
Value: balanceDiff,
|
||||
Status: model.StatusUsed,
|
||||
UsedBy: &user.ID,
|
||||
}
|
||||
now := time.Now()
|
||||
adjustmentRecord.UsedAt = &now
|
||||
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
|
||||
// Log error but don't fail the update
|
||||
// The user update has already succeeded
|
||||
}
|
||||
}
|
||||
|
||||
concurrencyDiff := user.Concurrency - oldConcurrency
|
||||
if concurrencyDiff != 0 {
|
||||
code, err := model.GenerateRedeemCode()
|
||||
if err != nil {
|
||||
log.Printf("failed to generate adjustment redeem code: %v", err)
|
||||
return user, nil
|
||||
}
|
||||
adjustmentRecord := &model.RedeemCode{
|
||||
Code: model.GenerateRedeemCode(),
|
||||
Code: code,
|
||||
Type: model.AdjustmentTypeAdminConcurrency,
|
||||
Value: float64(concurrencyDiff),
|
||||
Status: model.StatusUsed,
|
||||
@@ -344,8 +325,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
now := time.Now()
|
||||
adjustmentRecord.UsedAt = &now
|
||||
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
|
||||
// Log error but don't fail the update
|
||||
// The user update has already succeeded
|
||||
log.Printf("failed to create concurrency adjustment redeem code: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -364,12 +344,14 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
|
||||
return s.userRepo.Delete(ctx, id)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string) (*model.User, error) {
|
||||
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*model.User, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
oldBalance := user.Balance
|
||||
|
||||
switch operation {
|
||||
case "set":
|
||||
user.Balance = balance
|
||||
@@ -379,19 +361,48 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
||||
user.Balance -= balance
|
||||
}
|
||||
|
||||
if user.Balance < 0 {
|
||||
return nil, fmt.Errorf("balance cannot be negative, current balance: %.2f, requested operation would result in: %.2f", oldBalance, user.Balance)
|
||||
}
|
||||
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 失效余额缓存
|
||||
if s.billingCacheService != nil {
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
||||
if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, userID); err != nil {
|
||||
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
balanceDiff := user.Balance - oldBalance
|
||||
if balanceDiff != 0 {
|
||||
code, err := model.GenerateRedeemCode()
|
||||
if err != nil {
|
||||
log.Printf("failed to generate adjustment redeem code: %v", err)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
adjustmentRecord := &model.RedeemCode{
|
||||
Code: code,
|
||||
Type: model.AdjustmentTypeAdminBalance,
|
||||
Value: balanceDiff,
|
||||
Status: model.StatusUsed,
|
||||
UsedBy: &user.ID,
|
||||
Notes: notes,
|
||||
}
|
||||
now := time.Now()
|
||||
adjustmentRecord.UsedAt = &now
|
||||
|
||||
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
|
||||
log.Printf("failed to create balance adjustment redeem code: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
@@ -404,9 +415,9 @@ func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, pag
|
||||
return keys, result.Total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (interface{}, error) {
|
||||
func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) {
|
||||
// Return mock data for now
|
||||
return map[string]interface{}{
|
||||
return map[string]any{
|
||||
"period": period,
|
||||
"total_requests": 0,
|
||||
"total_cost": 0.0,
|
||||
@@ -579,7 +590,9 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
for _, userID := range affectedUserIDs {
|
||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
if err := s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID); err != nil {
|
||||
log.Printf("invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, groupID, err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -646,10 +659,10 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
if input.Type != "" {
|
||||
account.Type = input.Type
|
||||
}
|
||||
if input.Credentials != nil && len(input.Credentials) > 0 {
|
||||
if len(input.Credentials) > 0 {
|
||||
account.Credentials = model.JSONB(input.Credentials)
|
||||
}
|
||||
if input.Extra != nil && len(input.Extra) > 0 {
|
||||
if len(input.Extra) > 0 {
|
||||
account.Extra = model.JSONB(input.Extra)
|
||||
}
|
||||
if input.ProxyID != nil {
|
||||
@@ -831,8 +844,12 @@ func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *Gener
|
||||
|
||||
codes := make([]model.RedeemCode, 0, input.Count)
|
||||
for i := 0; i < input.Count; i++ {
|
||||
codeValue, err := model.GenerateRedeemCode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
code := model.RedeemCode{
|
||||
Code: model.GenerateRedeemCode(),
|
||||
Code: codeValue,
|
||||
Type: input.Type,
|
||||
Value: input.Value,
|
||||
Status: model.StatusUnused,
|
||||
|
||||
@@ -100,10 +100,13 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error {
|
||||
|
||||
// 检查字符:只允许字母、数字、下划线、连字符
|
||||
for _, c := range key {
|
||||
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
|
||||
(c >= '0' && c <= '9') || c == '_' || c == '-') {
|
||||
return ErrApiKeyInvalidChars
|
||||
if (c >= 'a' && c <= 'z') ||
|
||||
(c >= 'A' && c <= 'Z') ||
|
||||
(c >= '0' && c <= '9') ||
|
||||
c == '_' || c == '-' {
|
||||
continue
|
||||
}
|
||||
return ErrApiKeyInvalidChars
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -23,6 +23,7 @@ var (
|
||||
ErrTokenExpired = errors.New("token has expired")
|
||||
ErrEmailVerifyRequired = errors.New("email verification is required")
|
||||
ErrRegDisabled = errors.New("registration is currently disabled")
|
||||
ErrServiceUnavailable = errors.New("service temporarily unavailable")
|
||||
)
|
||||
|
||||
// JWTClaims JWT载荷数据
|
||||
@@ -90,7 +91,8 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
// 检查邮箱是否已存在
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("check email exists: %w", err)
|
||||
log.Printf("[Auth] Database error checking email exists: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
if existsEmail {
|
||||
return "", nil, ErrEmailExists
|
||||
@@ -121,7 +123,8 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||
return "", nil, fmt.Errorf("create user: %w", err)
|
||||
log.Printf("[Auth] Database error creating user: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// 生成token
|
||||
@@ -148,7 +151,8 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
|
||||
// 检查邮箱是否已存在
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check email exists: %w", err)
|
||||
log.Printf("[Auth] Database error checking email exists: %v", err)
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
if existsEmail {
|
||||
return ErrEmailExists
|
||||
@@ -181,8 +185,8 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
|
||||
// 检查邮箱是否已存在
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
if err != nil {
|
||||
log.Printf("[Auth] Error checking email exists: %v", err)
|
||||
return nil, fmt.Errorf("check email exists: %w", err)
|
||||
log.Printf("[Auth] Database error checking email exists: %v", err)
|
||||
return nil, ErrServiceUnavailable
|
||||
}
|
||||
if existsEmail {
|
||||
log.Printf("[Auth] Email already exists: %s", email)
|
||||
@@ -254,7 +258,9 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return "", nil, ErrInvalidCredentials
|
||||
}
|
||||
return "", nil, fmt.Errorf("get user by email: %w", err)
|
||||
// 记录数据库错误但不暴露给用户
|
||||
log.Printf("[Auth] Database error during login: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// 验证密码
|
||||
@@ -278,7 +284,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
|
||||
|
||||
// ValidateToken 验证JWT token并返回用户声明
|
||||
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (any, error) {
|
||||
// 验证签名方法
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
@@ -354,7 +360,8 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return "", ErrInvalidToken
|
||||
}
|
||||
return "", fmt.Errorf("get user: %w", err)
|
||||
log.Printf("[Auth] Database error refreshing token: %v", err)
|
||||
return "", ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
|
||||
@@ -259,11 +259,11 @@ func (s *BillingService) GetEstimatedCost(model string, estimatedInputTokens, es
|
||||
}
|
||||
|
||||
// GetPricingServiceStatus 获取价格服务状态
|
||||
func (s *BillingService) GetPricingServiceStatus() map[string]interface{} {
|
||||
func (s *BillingService) GetPricingServiceStatus() map[string]any {
|
||||
if s.pricingService != nil {
|
||||
return s.pricingService.GetStatus()
|
||||
}
|
||||
return map[string]interface{}{
|
||||
return map[string]any{
|
||||
"model_count": len(s.fallbackPrices),
|
||||
"last_updated": "using fallback",
|
||||
"local_hash": "N/A",
|
||||
|
||||
@@ -9,12 +9,6 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// Wait polling interval
|
||||
waitPollInterval = 100 * time.Millisecond
|
||||
|
||||
// Default max wait time
|
||||
defaultMaxWait = 60 * time.Second
|
||||
|
||||
// Default extra wait slots beyond concurrency limit
|
||||
defaultExtraWaitSlots = 20
|
||||
)
|
||||
@@ -31,7 +25,7 @@ func NewConcurrencyService(cache ports.ConcurrencyCache) *ConcurrencyService {
|
||||
|
||||
// AcquireResult represents the result of acquiring a concurrency slot
|
||||
type AcquireResult struct {
|
||||
Acquired bool
|
||||
Acquired bool
|
||||
ReleaseFunc func() // Must be called when done (typically via defer)
|
||||
}
|
||||
|
||||
@@ -54,7 +48,7 @@ func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID i
|
||||
|
||||
if acquired {
|
||||
return &AcquireResult{
|
||||
Acquired: true,
|
||||
Acquired: true,
|
||||
ReleaseFunc: func() {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
@@ -90,7 +84,7 @@ func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64,
|
||||
|
||||
if acquired {
|
||||
return &AcquireResult{
|
||||
Acquired: true,
|
||||
Acquired: true,
|
||||
ReleaseFunc: func() {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -133,13 +133,13 @@ func (s *EmailService) sendMailTLS(addr string, auth smtp.Auth, from, to string,
|
||||
if err != nil {
|
||||
return fmt.Errorf("tls dial: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
client, err := smtp.NewClient(conn, host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("new smtp client: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
defer func() { _ = client.Close() }()
|
||||
|
||||
if err = client.Auth(auth); err != nil {
|
||||
return fmt.Errorf("smtp auth: %w", err)
|
||||
@@ -303,13 +303,13 @@ func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("tls connection failed: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
client, err := smtp.NewClient(conn, config.Host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("smtp client creation failed: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
defer func() { _ = client.Close() }()
|
||||
|
||||
auth := smtp.PlainAuth("", config.Username, config.Password, config.Host)
|
||||
if err = client.Auth(auth); err != nil {
|
||||
@@ -324,7 +324,7 @@ func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("smtp connection failed: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
defer func() { _ = client.Close() }()
|
||||
|
||||
auth := smtp.PlainAuth("", config.Username, config.Password, config.Host)
|
||||
if err = client.Auth(auth); err != nil {
|
||||
|
||||
@@ -24,11 +24,6 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ClaudeUpstream handles HTTP requests to Claude API
|
||||
type ClaudeUpstream interface {
|
||||
Do(req *http.Request, proxyURL string) (*http.Response, error)
|
||||
}
|
||||
|
||||
const (
|
||||
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
||||
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
|
||||
@@ -53,7 +48,6 @@ var allowedHeaders = map[string]bool{
|
||||
"anthropic-beta": true,
|
||||
"accept-language": true,
|
||||
"sec-fetch-mode": true,
|
||||
"accept-encoding": true,
|
||||
"user-agent": true,
|
||||
"content-type": true,
|
||||
}
|
||||
@@ -84,12 +78,11 @@ type GatewayService struct {
|
||||
userSubRepo ports.UserSubscriptionRepository
|
||||
cache ports.GatewayCache
|
||||
cfg *config.Config
|
||||
oauthService *OAuthService
|
||||
billingService *BillingService
|
||||
rateLimitService *RateLimitService
|
||||
billingCacheService *BillingCacheService
|
||||
identityService *IdentityService
|
||||
claudeUpstream ClaudeUpstream
|
||||
httpUpstream ports.HTTPUpstream
|
||||
}
|
||||
|
||||
// NewGatewayService creates a new GatewayService
|
||||
@@ -100,12 +93,11 @@ func NewGatewayService(
|
||||
userSubRepo ports.UserSubscriptionRepository,
|
||||
cache ports.GatewayCache,
|
||||
cfg *config.Config,
|
||||
oauthService *OAuthService,
|
||||
billingService *BillingService,
|
||||
rateLimitService *RateLimitService,
|
||||
billingCacheService *BillingCacheService,
|
||||
identityService *IdentityService,
|
||||
claudeUpstream ClaudeUpstream,
|
||||
httpUpstream ports.HTTPUpstream,
|
||||
) *GatewayService {
|
||||
return &GatewayService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -114,24 +106,23 @@ func NewGatewayService(
|
||||
userSubRepo: userSubRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
oauthService: oauthService,
|
||||
billingService: billingService,
|
||||
rateLimitService: rateLimitService,
|
||||
billingCacheService: billingCacheService,
|
||||
identityService: identityService,
|
||||
claudeUpstream: claudeUpstream,
|
||||
httpUpstream: httpUpstream,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateSessionHash 从请求体计算粘性会话hash
|
||||
func (s *GatewayService) GenerateSessionHash(body []byte) string {
|
||||
var req map[string]interface{}
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 1. 最高优先级:从metadata.user_id提取session_xxx
|
||||
if metadata, ok := req["metadata"].(map[string]interface{}); ok {
|
||||
if metadata, ok := req["metadata"].(map[string]any); ok {
|
||||
if userID, ok := metadata["user_id"].(string); ok {
|
||||
re := regexp.MustCompile(`session_([a-f0-9-]{36})`)
|
||||
if match := re.FindStringSubmatch(userID); len(match) > 1 {
|
||||
@@ -155,8 +146,8 @@ func (s *GatewayService) GenerateSessionHash(body []byte) string {
|
||||
}
|
||||
|
||||
// 4. 最后fallback: 使用第一条消息
|
||||
if messages, ok := req["messages"].([]interface{}); ok && len(messages) > 0 {
|
||||
if firstMsg, ok := messages[0].(map[string]interface{}); ok {
|
||||
if messages, ok := req["messages"].([]any); ok && len(messages) > 0 {
|
||||
if firstMsg, ok := messages[0].(map[string]any); ok {
|
||||
msgText := s.extractTextFromContent(firstMsg["content"])
|
||||
if msgText != "" {
|
||||
return s.hashContent(msgText)
|
||||
@@ -167,14 +158,14 @@ func (s *GatewayService) GenerateSessionHash(body []byte) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *GatewayService) extractCacheableContent(req map[string]interface{}) string {
|
||||
func (s *GatewayService) extractCacheableContent(req map[string]any) string {
|
||||
var content string
|
||||
|
||||
// 检查system中的cacheable内容
|
||||
if system, ok := req["system"].([]interface{}); ok {
|
||||
if system, ok := req["system"].([]any); ok {
|
||||
for _, part := range system {
|
||||
if partMap, ok := part.(map[string]interface{}); ok {
|
||||
if cc, ok := partMap["cache_control"].(map[string]interface{}); ok {
|
||||
if partMap, ok := part.(map[string]any); ok {
|
||||
if cc, ok := partMap["cache_control"].(map[string]any); ok {
|
||||
if cc["type"] == "ephemeral" {
|
||||
if text, ok := partMap["text"].(string); ok {
|
||||
content += text
|
||||
@@ -186,13 +177,13 @@ func (s *GatewayService) extractCacheableContent(req map[string]interface{}) str
|
||||
}
|
||||
|
||||
// 检查messages中的cacheable内容
|
||||
if messages, ok := req["messages"].([]interface{}); ok {
|
||||
if messages, ok := req["messages"].([]any); ok {
|
||||
for _, msg := range messages {
|
||||
if msgMap, ok := msg.(map[string]interface{}); ok {
|
||||
if msgContent, ok := msgMap["content"].([]interface{}); ok {
|
||||
if msgMap, ok := msg.(map[string]any); ok {
|
||||
if msgContent, ok := msgMap["content"].([]any); ok {
|
||||
for _, part := range msgContent {
|
||||
if partMap, ok := part.(map[string]interface{}); ok {
|
||||
if cc, ok := partMap["cache_control"].(map[string]interface{}); ok {
|
||||
if partMap, ok := part.(map[string]any); ok {
|
||||
if cc, ok := partMap["cache_control"].(map[string]any); ok {
|
||||
if cc["type"] == "ephemeral" {
|
||||
// 找到cacheable内容,提取第一条消息的文本
|
||||
return s.extractTextFromContent(msgMap["content"])
|
||||
@@ -208,14 +199,14 @@ func (s *GatewayService) extractCacheableContent(req map[string]interface{}) str
|
||||
return content
|
||||
}
|
||||
|
||||
func (s *GatewayService) extractTextFromSystem(system interface{}) string {
|
||||
func (s *GatewayService) extractTextFromSystem(system any) string {
|
||||
switch v := system.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []interface{}:
|
||||
case []any:
|
||||
var texts []string
|
||||
for _, part := range v {
|
||||
if partMap, ok := part.(map[string]interface{}); ok {
|
||||
if partMap, ok := part.(map[string]any); ok {
|
||||
if text, ok := partMap["text"].(string); ok {
|
||||
texts = append(texts, text)
|
||||
}
|
||||
@@ -226,14 +217,14 @@ func (s *GatewayService) extractTextFromSystem(system interface{}) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *GatewayService) extractTextFromContent(content interface{}) string {
|
||||
func (s *GatewayService) extractTextFromContent(content any) string {
|
||||
switch v := content.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []interface{}:
|
||||
case []any:
|
||||
var texts []string
|
||||
for _, part := range v {
|
||||
if partMap, ok := part.(map[string]interface{}); ok {
|
||||
if partMap, ok := part.(map[string]any); ok {
|
||||
if partMap["type"] == "text" {
|
||||
if text, ok := partMap["text"].(string); ok {
|
||||
texts = append(texts, text)
|
||||
@@ -253,7 +244,7 @@ func (s *GatewayService) hashContent(content string) string {
|
||||
|
||||
// replaceModelInBody 替换请求体中的model字段
|
||||
func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte {
|
||||
var req map[string]interface{}
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return body
|
||||
}
|
||||
@@ -281,19 +272,21 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
||||
// 同时检查模型支持
|
||||
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
// 续期粘性会话
|
||||
s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL)
|
||||
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 获取可调度账号列表(排除限流和过载的账号)
|
||||
// 2. 获取可调度账号列表(排除限流和过载的账号,仅限 Anthropic 平台)
|
||||
var accounts []model.Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupID(ctx, *groupID)
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformAnthropic)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulable(ctx)
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformAnthropic)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
@@ -331,7 +324,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
||||
|
||||
// 4. 建立粘性绑定
|
||||
if sessionHash != "" {
|
||||
s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL)
|
||||
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
@@ -407,11 +402,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := s.claudeUpstream.Do(upstreamReq, proxyURL)
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// 处理错误响应(包括401,由后台TokenRefreshService维护token有效性)
|
||||
if resp.StatusCode >= 400 {
|
||||
@@ -481,7 +476,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
|
||||
// 设置认证头
|
||||
if tokenType == "oauth" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("authorization", "Bearer "+token)
|
||||
} else {
|
||||
req.Header.Set("x-api-key", token)
|
||||
}
|
||||
@@ -502,8 +497,8 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
|
||||
// 确保必要的headers存在
|
||||
if req.Header.Get("Content-Type") == "" {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if req.Header.Get("content-type") == "" {
|
||||
req.Header.Set("content-type", "application/json")
|
||||
}
|
||||
if req.Header.Get("anthropic-version") == "" {
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
@@ -557,7 +552,7 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
|
||||
|
||||
// 客户端没传,根据模型生成
|
||||
var modelID string
|
||||
var reqMap map[string]interface{}
|
||||
var reqMap map[string]any
|
||||
if json.Unmarshal(body, &reqMap) == nil {
|
||||
if m, ok := reqMap["model"].(string); ok {
|
||||
modelID = m
|
||||
@@ -678,7 +673,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
|
||||
// 转发行
|
||||
fmt.Fprintf(w, "%s\n", line)
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
// 解析usage数据
|
||||
@@ -707,7 +704,7 @@ func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string)
|
||||
return line
|
||||
}
|
||||
|
||||
var event map[string]interface{}
|
||||
var event map[string]any
|
||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||
return line
|
||||
}
|
||||
@@ -717,7 +714,7 @@ func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string)
|
||||
return line
|
||||
}
|
||||
|
||||
msg, ok := event["message"].(map[string]interface{})
|
||||
msg, ok := event["message"].(map[string]any)
|
||||
if !ok {
|
||||
return line
|
||||
}
|
||||
@@ -737,7 +734,7 @@ func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string)
|
||||
}
|
||||
|
||||
func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
||||
// 解析message_start获取input tokens
|
||||
// 解析message_start获取input tokens(标准Claude API格式)
|
||||
var msgStart struct {
|
||||
Type string `json:"type"`
|
||||
Message struct {
|
||||
@@ -750,15 +747,30 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
||||
usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens
|
||||
}
|
||||
|
||||
// 解析message_delta获取output tokens
|
||||
// 解析message_delta获取tokens(兼容GLM等把所有usage放在delta中的API)
|
||||
var msgDelta struct {
|
||||
Type string `json:"type"`
|
||||
Usage struct {
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
if json.Unmarshal([]byte(data), &msgDelta) == nil && msgDelta.Type == "message_delta" {
|
||||
// output_tokens 总是从 message_delta 获取
|
||||
usage.OutputTokens = msgDelta.Usage.OutputTokens
|
||||
|
||||
// 如果 message_start 中没有值,则从 message_delta 获取(兼容GLM等API)
|
||||
if usage.InputTokens == 0 {
|
||||
usage.InputTokens = msgDelta.Usage.InputTokens
|
||||
}
|
||||
if usage.CacheCreationInputTokens == 0 {
|
||||
usage.CacheCreationInputTokens = msgDelta.Usage.CacheCreationInputTokens
|
||||
}
|
||||
if usage.CacheReadInputTokens == 0 {
|
||||
usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -799,7 +811,7 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
||||
|
||||
// replaceModelInResponseBody 替换响应体中的model字段
|
||||
func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
|
||||
var resp map[string]interface{}
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return body
|
||||
}
|
||||
@@ -980,12 +992,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := s.claudeUpstream.Do(upstreamReq, proxyURL)
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL)
|
||||
if err != nil {
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
|
||||
return fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
// 读取响应体
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
@@ -1045,7 +1059,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
|
||||
// 设置认证头
|
||||
if tokenType == "oauth" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("authorization", "Bearer "+token)
|
||||
} else {
|
||||
req.Header.Set("x-api-key", token)
|
||||
}
|
||||
@@ -1069,8 +1083,8 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
}
|
||||
|
||||
// 确保必要的 headers 存在
|
||||
if req.Header.Get("Content-Type") == "" {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if req.Header.Get("content-type") == "" {
|
||||
req.Header.Set("content-type", "application/json")
|
||||
}
|
||||
if req.Header.Get("anthropic-version") == "" {
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
@@ -167,7 +167,7 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error {
|
||||
}
|
||||
|
||||
// GetStats 获取分组统计信息
|
||||
func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]interface{}, error) {
|
||||
func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]any, error) {
|
||||
group, err := s.groupRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
@@ -182,7 +182,7 @@ func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]inter
|
||||
return nil, fmt.Errorf("get account count: %w", err)
|
||||
}
|
||||
|
||||
stats := map[string]interface{}{
|
||||
stats := map[string]any{
|
||||
"id": group.ID,
|
||||
"name": group.Name,
|
||||
"rate_multiplier": group.RateMultiplier,
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
|
||||
// 预编译正则表达式(避免每次调用重新编译)
|
||||
var (
|
||||
// 匹配 user_id 格式: user_{64位hex}_account__session_{uuid}
|
||||
@@ -115,12 +114,12 @@ func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *ports.Fingerpr
|
||||
return
|
||||
}
|
||||
|
||||
// 设置User-Agent
|
||||
// 设置user-agent
|
||||
if fp.UserAgent != "" {
|
||||
req.Header.Set("User-Agent", fp.UserAgent)
|
||||
req.Header.Set("user-agent", fp.UserAgent)
|
||||
}
|
||||
|
||||
// 设置x-stainless-*头(使用正确的大小写)
|
||||
// 设置x-stainless-*头
|
||||
if fp.StainlessLang != "" {
|
||||
req.Header.Set("X-Stainless-Lang", fp.StainlessLang)
|
||||
}
|
||||
@@ -150,12 +149,12 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
|
||||
}
|
||||
|
||||
// 解析JSON
|
||||
var reqMap map[string]interface{}
|
||||
var reqMap map[string]any
|
||||
if err := json.Unmarshal(body, &reqMap); err != nil {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
metadata, ok := reqMap["metadata"].(map[string]interface{})
|
||||
metadata, ok := reqMap["metadata"].(map[string]any)
|
||||
if !ok {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
@@ -284,3 +284,8 @@ func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *model.A
|
||||
|
||||
return s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
}
|
||||
|
||||
// Stop stops the session store cleanup goroutine
|
||||
func (s *OAuthService) Stop() {
|
||||
s.sessionStore.Stop()
|
||||
}
|
||||
|
||||
836
backend/internal/service/openai_gateway_service.go
Normal file
836
backend/internal/service/openai_gateway_service.go
Normal file
@@ -0,0 +1,836 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
// ChatGPT internal API for OAuth accounts
|
||||
chatgptCodexURL = "https://chatgpt.com/backend-api/codex/responses"
|
||||
// OpenAI Platform API for API Key accounts (fallback)
|
||||
openaiPlatformAPIURL = "https://api.openai.com/v1/responses"
|
||||
openaiStickySessionTTL = time.Hour // 粘性会话TTL
|
||||
)
|
||||
|
||||
// OpenAI allowed headers whitelist (for non-OAuth accounts)
|
||||
var openaiAllowedHeaders = map[string]bool{
|
||||
"accept-language": true,
|
||||
"content-type": true,
|
||||
"user-agent": true,
|
||||
"originator": true,
|
||||
"session_id": true,
|
||||
}
|
||||
|
||||
// OpenAICodexUsageSnapshot represents Codex API usage limits from response headers
|
||||
type OpenAICodexUsageSnapshot struct {
|
||||
PrimaryUsedPercent *float64 `json:"primary_used_percent,omitempty"`
|
||||
PrimaryResetAfterSeconds *int `json:"primary_reset_after_seconds,omitempty"`
|
||||
PrimaryWindowMinutes *int `json:"primary_window_minutes,omitempty"`
|
||||
SecondaryUsedPercent *float64 `json:"secondary_used_percent,omitempty"`
|
||||
SecondaryResetAfterSeconds *int `json:"secondary_reset_after_seconds,omitempty"`
|
||||
SecondaryWindowMinutes *int `json:"secondary_window_minutes,omitempty"`
|
||||
PrimaryOverSecondaryPercent *float64 `json:"primary_over_secondary_percent,omitempty"`
|
||||
UpdatedAt string `json:"updated_at,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIUsage represents OpenAI API response usage
|
||||
type OpenAIUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIForwardResult represents the result of forwarding
|
||||
type OpenAIForwardResult struct {
|
||||
RequestID string
|
||||
Usage OpenAIUsage
|
||||
Model string
|
||||
Stream bool
|
||||
Duration time.Duration
|
||||
FirstTokenMs *int
|
||||
}
|
||||
|
||||
// OpenAIGatewayService handles OpenAI API gateway operations
|
||||
type OpenAIGatewayService struct {
|
||||
accountRepo ports.AccountRepository
|
||||
usageLogRepo ports.UsageLogRepository
|
||||
userRepo ports.UserRepository
|
||||
userSubRepo ports.UserSubscriptionRepository
|
||||
cache ports.GatewayCache
|
||||
cfg *config.Config
|
||||
billingService *BillingService
|
||||
rateLimitService *RateLimitService
|
||||
billingCacheService *BillingCacheService
|
||||
httpUpstream ports.HTTPUpstream
|
||||
}
|
||||
|
||||
// NewOpenAIGatewayService creates a new OpenAIGatewayService
|
||||
func NewOpenAIGatewayService(
|
||||
accountRepo ports.AccountRepository,
|
||||
usageLogRepo ports.UsageLogRepository,
|
||||
userRepo ports.UserRepository,
|
||||
userSubRepo ports.UserSubscriptionRepository,
|
||||
cache ports.GatewayCache,
|
||||
cfg *config.Config,
|
||||
billingService *BillingService,
|
||||
rateLimitService *RateLimitService,
|
||||
billingCacheService *BillingCacheService,
|
||||
httpUpstream ports.HTTPUpstream,
|
||||
) *OpenAIGatewayService {
|
||||
return &OpenAIGatewayService{
|
||||
accountRepo: accountRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
billingService: billingService,
|
||||
rateLimitService: rateLimitService,
|
||||
billingCacheService: billingCacheService,
|
||||
httpUpstream: httpUpstream,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateSessionHash generates session hash from header (OpenAI uses session_id header)
|
||||
func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string {
|
||||
sessionID := c.GetHeader("session_id")
|
||||
if sessionID == "" {
|
||||
return ""
|
||||
}
|
||||
hash := sha256.Sum256([]byte(sessionID))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// SelectAccount selects an OpenAI account with sticky session support
|
||||
func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*model.Account, error) {
|
||||
return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
|
||||
}
|
||||
|
||||
// SelectAccountForModel selects an account supporting the requested model
|
||||
func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) {
|
||||
// 1. Check sticky session
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash)
|
||||
if err == nil && accountID > 0 {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
// Refresh sticky session TTL
|
||||
_ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL)
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Get schedulable OpenAI accounts
|
||||
var accounts []model.Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformOpenAI)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformOpenAI)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
|
||||
// 3. Select by priority + LRU
|
||||
var selected *model.Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
// Check model support
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
continue
|
||||
}
|
||||
if selected == nil {
|
||||
selected = acc
|
||||
continue
|
||||
}
|
||||
// Lower priority value means higher priority
|
||||
if acc.Priority < selected.Priority {
|
||||
selected = acc
|
||||
} else if acc.Priority == selected.Priority {
|
||||
// Same priority, select least recently used
|
||||
if acc.LastUsedAt == nil || (selected.LastUsedAt != nil && acc.LastUsedAt.Before(*selected.LastUsedAt)) {
|
||||
selected = acc
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if selected == nil {
|
||||
if requestedModel != "" {
|
||||
return nil, fmt.Errorf("no available OpenAI accounts supporting model: %s", requestedModel)
|
||||
}
|
||||
return nil, errors.New("no available OpenAI accounts")
|
||||
}
|
||||
|
||||
// 4. Set sticky session
|
||||
if sessionHash != "" {
|
||||
_ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, selected.ID, openaiStickySessionTTL)
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
// GetAccessToken gets the access token for an OpenAI account
|
||||
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *model.Account) (string, string, error) {
|
||||
switch account.Type {
|
||||
case model.AccountTypeOAuth:
|
||||
accessToken := account.GetOpenAIAccessToken()
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
return accessToken, "oauth", nil
|
||||
case model.AccountTypeApiKey:
|
||||
apiKey := account.GetOpenAIApiKey()
|
||||
if apiKey == "" {
|
||||
return "", "", errors.New("api_key not found in credentials")
|
||||
}
|
||||
return apiKey, "apikey", nil
|
||||
default:
|
||||
return "", "", fmt.Errorf("unsupported account type: %s", account.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// Forward forwards request to OpenAI API
|
||||
func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*OpenAIForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// Parse request body once (avoid multiple parse/serialize cycles)
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(body, &reqBody); err != nil {
|
||||
return nil, fmt.Errorf("parse request: %w", err)
|
||||
}
|
||||
|
||||
// Extract model and stream from parsed body
|
||||
reqModel, _ := reqBody["model"].(string)
|
||||
reqStream, _ := reqBody["stream"].(bool)
|
||||
|
||||
// Track if body needs re-serialization
|
||||
bodyModified := false
|
||||
originalModel := reqModel
|
||||
|
||||
// Apply model mapping
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
reqBody["model"] = mappedModel
|
||||
bodyModified = true
|
||||
}
|
||||
|
||||
// For OAuth accounts using ChatGPT internal API, add store: false
|
||||
if account.Type == model.AccountTypeOAuth {
|
||||
reqBody["store"] = false
|
||||
bodyModified = true
|
||||
}
|
||||
|
||||
// Re-serialize body only if modified
|
||||
if bodyModified {
|
||||
var err error
|
||||
body, err = json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("serialize request body: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get access token
|
||||
token, _, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Build upstream request
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get proxy URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// Send request
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// Handle error response
|
||||
if resp.StatusCode >= 400 {
|
||||
return s.handleErrorResponse(ctx, resp, c, account)
|
||||
}
|
||||
|
||||
// Handle normal response
|
||||
var usage *OpenAIUsage
|
||||
var firstTokenMs *int
|
||||
if reqStream {
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
usage = streamResult.usage
|
||||
firstTokenMs = streamResult.firstTokenMs
|
||||
} else {
|
||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, mappedModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
|
||||
if account.Type == model.AccountTypeOAuth {
|
||||
if snapshot := extractCodexUsageHeaders(resp.Header); snapshot != nil {
|
||||
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
|
||||
}
|
||||
}
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
Stream: reqStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token string, isStream bool) (*http.Request, error) {
|
||||
// Determine target URL based on account type
|
||||
var targetURL string
|
||||
switch account.Type {
|
||||
case model.AccountTypeOAuth:
|
||||
// OAuth accounts use ChatGPT internal API
|
||||
targetURL = chatgptCodexURL
|
||||
case model.AccountTypeApiKey:
|
||||
// API Key accounts use Platform API or custom base URL
|
||||
baseURL := account.GetOpenAIBaseURL()
|
||||
if baseURL != "" {
|
||||
targetURL = baseURL + "/v1/responses"
|
||||
} else {
|
||||
targetURL = openaiPlatformAPIURL
|
||||
}
|
||||
default:
|
||||
targetURL = openaiPlatformAPIURL
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set authentication header
|
||||
req.Header.Set("authorization", "Bearer "+token)
|
||||
|
||||
// Set headers specific to OAuth accounts (ChatGPT internal API)
|
||||
if account.Type == model.AccountTypeOAuth {
|
||||
// Required: set Host for ChatGPT API (must use req.Host, not Header.Set)
|
||||
req.Host = "chatgpt.com"
|
||||
// Required: set chatgpt-account-id header
|
||||
chatgptAccountID := account.GetChatGPTAccountID()
|
||||
if chatgptAccountID != "" {
|
||||
req.Header.Set("chatgpt-account-id", chatgptAccountID)
|
||||
}
|
||||
// Set accept header based on stream mode
|
||||
if isStream {
|
||||
req.Header.Set("accept", "text/event-stream")
|
||||
} else {
|
||||
req.Header.Set("accept", "application/json")
|
||||
}
|
||||
}
|
||||
|
||||
// Whitelist passthrough headers
|
||||
for key, values := range c.Request.Header {
|
||||
lowerKey := strings.ToLower(key)
|
||||
if openaiAllowedHeaders[lowerKey] {
|
||||
for _, v := range values {
|
||||
req.Header.Add(key, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply custom User-Agent if configured
|
||||
customUA := account.GetOpenAIUserAgent()
|
||||
if customUA != "" {
|
||||
req.Header.Set("user-agent", customUA)
|
||||
}
|
||||
|
||||
// Ensure required headers exist
|
||||
if req.Header.Get("content-type") == "" {
|
||||
req.Header.Set("content-type", "application/json")
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*OpenAIForwardResult, error) {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
// Check custom error codes
|
||||
if !account.ShouldHandleErrorCode(resp.StatusCode) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream gateway error",
|
||||
},
|
||||
})
|
||||
return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Handle upstream error (mark account status)
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
|
||||
// Return appropriate error response
|
||||
var errType, errMsg string
|
||||
var statusCode int
|
||||
|
||||
switch resp.StatusCode {
|
||||
case 401:
|
||||
statusCode = http.StatusBadGateway
|
||||
errType = "upstream_error"
|
||||
errMsg = "Upstream authentication failed, please contact administrator"
|
||||
case 403:
|
||||
statusCode = http.StatusBadGateway
|
||||
errType = "upstream_error"
|
||||
errMsg = "Upstream access forbidden, please contact administrator"
|
||||
case 429:
|
||||
statusCode = http.StatusTooManyRequests
|
||||
errType = "rate_limit_error"
|
||||
errMsg = "Upstream rate limit exceeded, please retry later"
|
||||
default:
|
||||
statusCode = http.StatusBadGateway
|
||||
errType = "upstream_error"
|
||||
errMsg = "Upstream request failed"
|
||||
}
|
||||
|
||||
c.JSON(statusCode, gin.H{
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
"message": errMsg,
|
||||
},
|
||||
})
|
||||
|
||||
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// openaiStreamingResult streaming response result
|
||||
type openaiStreamingResult struct {
|
||||
usage *OpenAIUsage
|
||||
firstTokenMs *int
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) {
|
||||
// Set SSE response headers
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
// Pass through other headers
|
||||
if v := resp.Header.Get("x-request-id"); v != "" {
|
||||
c.Header("x-request-id", v)
|
||||
}
|
||||
|
||||
w := c.Writer
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
return nil, errors.New("streaming not supported")
|
||||
}
|
||||
|
||||
usage := &OpenAIUsage{}
|
||||
var firstTokenMs *int
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
|
||||
|
||||
needModelReplace := originalModel != mappedModel
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
// Replace model in response if needed
|
||||
if needModelReplace && strings.HasPrefix(line, "data: ") {
|
||||
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
||||
}
|
||||
|
||||
// Forward line
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
// Parse usage data
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
data := line[6:]
|
||||
// Record first token time
|
||||
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
s.parseSSEUsage(data, usage)
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
|
||||
}
|
||||
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
|
||||
data := line[6:]
|
||||
if data == "" || data == "[DONE]" {
|
||||
return line
|
||||
}
|
||||
|
||||
var event map[string]any
|
||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||
return line
|
||||
}
|
||||
|
||||
// Replace model in response
|
||||
if m, ok := event["model"].(string); ok && m == fromModel {
|
||||
event["model"] = toModel
|
||||
newData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
return "data: " + string(newData)
|
||||
}
|
||||
|
||||
// Check nested response
|
||||
if response, ok := event["response"].(map[string]any); ok {
|
||||
if m, ok := response["model"].(string); ok && m == fromModel {
|
||||
response["model"] = toModel
|
||||
newData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
return "data: " + string(newData)
|
||||
}
|
||||
}
|
||||
|
||||
return line
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
|
||||
// Parse response.completed event for usage (OpenAI Responses format)
|
||||
var event struct {
|
||||
Type string `json:"type"`
|
||||
Response struct {
|
||||
Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
InputTokenDetails struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
} `json:"input_tokens_details"`
|
||||
} `json:"usage"`
|
||||
} `json:"response"`
|
||||
}
|
||||
|
||||
if json.Unmarshal([]byte(data), &event) == nil && event.Type == "response.completed" {
|
||||
usage.InputTokens = event.Response.Usage.InputTokens
|
||||
usage.OutputTokens = event.Response.Usage.OutputTokens
|
||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokenDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse usage
|
||||
var response struct {
|
||||
Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
InputTokenDetails struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
} `json:"input_tokens_details"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &response); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
usage := &OpenAIUsage{
|
||||
InputTokens: response.Usage.InputTokens,
|
||||
OutputTokens: response.Usage.OutputTokens,
|
||||
CacheReadInputTokens: response.Usage.InputTokenDetails.CachedTokens,
|
||||
}
|
||||
|
||||
// Replace model in response if needed
|
||||
if originalModel != mappedModel {
|
||||
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
||||
}
|
||||
|
||||
// Pass through headers
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
c.Header(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
c.Data(resp.StatusCode, "application/json", body)
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
model, ok := resp["model"].(string)
|
||||
if !ok || model != fromModel {
|
||||
return body
|
||||
}
|
||||
|
||||
resp["model"] = toModel
|
||||
newBody, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
return newBody
|
||||
}
|
||||
|
||||
// OpenAIRecordUsageInput input for recording usage
|
||||
type OpenAIRecordUsageInput struct {
|
||||
Result *OpenAIForwardResult
|
||||
ApiKey *model.ApiKey
|
||||
User *model.User
|
||||
Account *model.Account
|
||||
Subscription *model.UserSubscription
|
||||
}
|
||||
|
||||
// RecordUsage records usage and deducts balance
|
||||
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
|
||||
result := input.Result
|
||||
apiKey := input.ApiKey
|
||||
user := input.User
|
||||
account := input.Account
|
||||
subscription := input.Subscription
|
||||
|
||||
// 计算实际的新输入token(减去缓存读取的token)
|
||||
// 因为 input_tokens 包含了 cache_read_tokens,而缓存读取的token不应按输入价格计费
|
||||
actualInputTokens := result.Usage.InputTokens - result.Usage.CacheReadInputTokens
|
||||
if actualInputTokens < 0 {
|
||||
actualInputTokens = 0
|
||||
}
|
||||
|
||||
// Calculate cost
|
||||
tokens := UsageTokens{
|
||||
InputTokens: actualInputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
}
|
||||
|
||||
// Get rate multiplier
|
||||
multiplier := s.cfg.Default.RateMultiplier
|
||||
if apiKey.GroupID != nil && apiKey.Group != nil {
|
||||
multiplier = apiKey.Group.RateMultiplier
|
||||
}
|
||||
|
||||
cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier)
|
||||
if err != nil {
|
||||
cost = &CostBreakdown{ActualCost: 0}
|
||||
}
|
||||
|
||||
// Determine billing type
|
||||
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
||||
billingType := model.BillingTypeBalance
|
||||
if isSubscriptionBilling {
|
||||
billingType = model.BillingTypeSubscription
|
||||
}
|
||||
|
||||
// Create usage log
|
||||
durationMs := int(result.Duration.Milliseconds())
|
||||
usageLog := &model.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
Model: result.Model,
|
||||
InputTokens: actualInputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
InputCost: cost.InputCost,
|
||||
OutputCost: cost.OutputCost,
|
||||
CacheCreationCost: cost.CacheCreationCost,
|
||||
CacheReadCost: cost.CacheReadCost,
|
||||
TotalCost: cost.TotalCost,
|
||||
ActualCost: cost.ActualCost,
|
||||
RateMultiplier: multiplier,
|
||||
BillingType: billingType,
|
||||
Stream: result.Stream,
|
||||
DurationMs: &durationMs,
|
||||
FirstTokenMs: result.FirstTokenMs,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if apiKey.GroupID != nil {
|
||||
usageLog.GroupID = apiKey.GroupID
|
||||
}
|
||||
if subscription != nil {
|
||||
usageLog.SubscriptionID = &subscription.ID
|
||||
}
|
||||
|
||||
_ = s.usageLogRepo.Create(ctx, usageLog)
|
||||
|
||||
// Deduct based on billing type
|
||||
if isSubscriptionBilling {
|
||||
if cost.TotalCost > 0 {
|
||||
_ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost)
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.billingCacheService.UpdateSubscriptionUsage(cacheCtx, user.ID, *apiKey.GroupID, cost.TotalCost)
|
||||
}()
|
||||
}
|
||||
} else {
|
||||
if cost.ActualCost > 0 {
|
||||
_ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost)
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.billingCacheService.DeductBalanceCache(cacheCtx, user.ID, cost.ActualCost)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// Update account last used
|
||||
_ = s.accountRepo.UpdateLastUsed(ctx, account.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractCodexUsageHeaders extracts Codex usage limits from response headers
|
||||
func extractCodexUsageHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
|
||||
snapshot := &OpenAICodexUsageSnapshot{}
|
||||
hasData := false
|
||||
|
||||
// Helper to parse float64 from header
|
||||
parseFloat := func(key string) *float64 {
|
||||
if v := headers.Get(key); v != "" {
|
||||
if f, err := strconv.ParseFloat(v, 64); err == nil {
|
||||
return &f
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper to parse int from header
|
||||
parseInt := func(key string) *int {
|
||||
if v := headers.Get(key); v != "" {
|
||||
if i, err := strconv.Atoi(v); err == nil {
|
||||
return &i
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Primary (weekly) limits
|
||||
if v := parseFloat("x-codex-primary-used-percent"); v != nil {
|
||||
snapshot.PrimaryUsedPercent = v
|
||||
hasData = true
|
||||
}
|
||||
if v := parseInt("x-codex-primary-reset-after-seconds"); v != nil {
|
||||
snapshot.PrimaryResetAfterSeconds = v
|
||||
hasData = true
|
||||
}
|
||||
if v := parseInt("x-codex-primary-window-minutes"); v != nil {
|
||||
snapshot.PrimaryWindowMinutes = v
|
||||
hasData = true
|
||||
}
|
||||
|
||||
// Secondary (5h) limits
|
||||
if v := parseFloat("x-codex-secondary-used-percent"); v != nil {
|
||||
snapshot.SecondaryUsedPercent = v
|
||||
hasData = true
|
||||
}
|
||||
if v := parseInt("x-codex-secondary-reset-after-seconds"); v != nil {
|
||||
snapshot.SecondaryResetAfterSeconds = v
|
||||
hasData = true
|
||||
}
|
||||
if v := parseInt("x-codex-secondary-window-minutes"); v != nil {
|
||||
snapshot.SecondaryWindowMinutes = v
|
||||
hasData = true
|
||||
}
|
||||
|
||||
// Overflow ratio
|
||||
if v := parseFloat("x-codex-primary-over-secondary-limit-percent"); v != nil {
|
||||
snapshot.PrimaryOverSecondaryPercent = v
|
||||
hasData = true
|
||||
}
|
||||
|
||||
if !hasData {
|
||||
return nil
|
||||
}
|
||||
|
||||
snapshot.UpdatedAt = time.Now().Format(time.RFC3339)
|
||||
return snapshot
|
||||
}
|
||||
|
||||
// updateCodexUsageSnapshot saves the Codex usage snapshot to account's Extra field
|
||||
func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, accountID int64, snapshot *OpenAICodexUsageSnapshot) {
|
||||
if snapshot == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Convert snapshot to map for merging into Extra
|
||||
updates := make(map[string]any)
|
||||
if snapshot.PrimaryUsedPercent != nil {
|
||||
updates["codex_primary_used_percent"] = *snapshot.PrimaryUsedPercent
|
||||
}
|
||||
if snapshot.PrimaryResetAfterSeconds != nil {
|
||||
updates["codex_primary_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds
|
||||
}
|
||||
if snapshot.PrimaryWindowMinutes != nil {
|
||||
updates["codex_primary_window_minutes"] = *snapshot.PrimaryWindowMinutes
|
||||
}
|
||||
if snapshot.SecondaryUsedPercent != nil {
|
||||
updates["codex_secondary_used_percent"] = *snapshot.SecondaryUsedPercent
|
||||
}
|
||||
if snapshot.SecondaryResetAfterSeconds != nil {
|
||||
updates["codex_secondary_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds
|
||||
}
|
||||
if snapshot.SecondaryWindowMinutes != nil {
|
||||
updates["codex_secondary_window_minutes"] = *snapshot.SecondaryWindowMinutes
|
||||
}
|
||||
if snapshot.PrimaryOverSecondaryPercent != nil {
|
||||
updates["codex_primary_over_secondary_percent"] = *snapshot.PrimaryOverSecondaryPercent
|
||||
}
|
||||
updates["codex_usage_updated_at"] = snapshot.UpdatedAt
|
||||
|
||||
// Update account's Extra field asynchronously
|
||||
go func() {
|
||||
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||
}()
|
||||
}
|
||||
257
backend/internal/service/openai_oauth_service.go
Normal file
257
backend/internal/service/openai_oauth_service.go
Normal file
@@ -0,0 +1,257 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/openai"
|
||||
"sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
// OpenAIOAuthService handles OpenAI OAuth authentication flows
|
||||
type OpenAIOAuthService struct {
|
||||
sessionStore *openai.SessionStore
|
||||
proxyRepo ports.ProxyRepository
|
||||
oauthClient ports.OpenAIOAuthClient
|
||||
}
|
||||
|
||||
// NewOpenAIOAuthService creates a new OpenAI OAuth service
|
||||
func NewOpenAIOAuthService(proxyRepo ports.ProxyRepository, oauthClient ports.OpenAIOAuthClient) *OpenAIOAuthService {
|
||||
return &OpenAIOAuthService{
|
||||
sessionStore: openai.NewSessionStore(),
|
||||
proxyRepo: proxyRepo,
|
||||
oauthClient: oauthClient,
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAIAuthURLResult contains the authorization URL and session info
|
||||
type OpenAIAuthURLResult struct {
|
||||
AuthURL string `json:"auth_url"`
|
||||
SessionID string `json:"session_id"`
|
||||
}
|
||||
|
||||
// GenerateAuthURL generates an OpenAI OAuth authorization URL
|
||||
func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI string) (*OpenAIAuthURLResult, error) {
|
||||
// Generate PKCE values
|
||||
state, err := openai.GenerateState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate state: %w", err)
|
||||
}
|
||||
|
||||
codeVerifier, err := openai.GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate code verifier: %w", err)
|
||||
}
|
||||
|
||||
codeChallenge := openai.GenerateCodeChallenge(codeVerifier)
|
||||
|
||||
// Generate session ID
|
||||
sessionID, err := openai.GenerateSessionID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate session ID: %w", err)
|
||||
}
|
||||
|
||||
// Get proxy URL if specified
|
||||
var proxyURL string
|
||||
if proxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
// Use default redirect URI if not specified
|
||||
if redirectURI == "" {
|
||||
redirectURI = openai.DefaultRedirectURI
|
||||
}
|
||||
|
||||
// Store session
|
||||
session := &openai.OAuthSession{
|
||||
State: state,
|
||||
CodeVerifier: codeVerifier,
|
||||
RedirectURI: redirectURI,
|
||||
ProxyURL: proxyURL,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
s.sessionStore.Set(sessionID, session)
|
||||
|
||||
// Build authorization URL
|
||||
authURL := openai.BuildAuthorizationURL(state, codeChallenge, redirectURI)
|
||||
|
||||
return &OpenAIAuthURLResult{
|
||||
AuthURL: authURL,
|
||||
SessionID: sessionID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// OpenAIExchangeCodeInput represents the input for code exchange
|
||||
type OpenAIExchangeCodeInput struct {
|
||||
SessionID string
|
||||
Code string
|
||||
RedirectURI string
|
||||
ProxyID *int64
|
||||
}
|
||||
|
||||
// OpenAITokenInfo represents the token information for OpenAI
|
||||
type OpenAITokenInfo struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
Email string `json:"email,omitempty"`
|
||||
ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"`
|
||||
ChatGPTUserID string `json:"chatgpt_user_id,omitempty"`
|
||||
OrganizationID string `json:"organization_id,omitempty"`
|
||||
}
|
||||
|
||||
// ExchangeCode exchanges authorization code for tokens
|
||||
func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExchangeCodeInput) (*OpenAITokenInfo, error) {
|
||||
// Get session
|
||||
session, ok := s.sessionStore.Get(input.SessionID)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("session not found or expired")
|
||||
}
|
||||
|
||||
// Get proxy URL
|
||||
proxyURL := session.ProxyURL
|
||||
if input.ProxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
// Use redirect URI from session or input
|
||||
redirectURI := session.RedirectURI
|
||||
if input.RedirectURI != "" {
|
||||
redirectURI = input.RedirectURI
|
||||
}
|
||||
|
||||
// Exchange code for token
|
||||
tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code: %w", err)
|
||||
}
|
||||
|
||||
// Parse ID token to get user info
|
||||
var userInfo *openai.UserInfo
|
||||
if tokenResp.IDToken != "" {
|
||||
claims, err := openai.ParseIDToken(tokenResp.IDToken)
|
||||
if err == nil {
|
||||
userInfo = claims.GetUserInfo()
|
||||
}
|
||||
}
|
||||
|
||||
// Delete session after successful exchange
|
||||
s.sessionStore.Delete(input.SessionID)
|
||||
|
||||
tokenInfo := &OpenAITokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
IDToken: tokenResp.IDToken,
|
||||
ExpiresIn: int64(tokenResp.ExpiresIn),
|
||||
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
|
||||
}
|
||||
|
||||
if userInfo != nil {
|
||||
tokenInfo.Email = userInfo.Email
|
||||
tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID
|
||||
tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID
|
||||
tokenInfo.OrganizationID = userInfo.OrganizationID
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an OpenAI OAuth token
|
||||
func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*OpenAITokenInfo, error) {
|
||||
tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse ID token to get user info
|
||||
var userInfo *openai.UserInfo
|
||||
if tokenResp.IDToken != "" {
|
||||
claims, err := openai.ParseIDToken(tokenResp.IDToken)
|
||||
if err == nil {
|
||||
userInfo = claims.GetUserInfo()
|
||||
}
|
||||
}
|
||||
|
||||
tokenInfo := &OpenAITokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
IDToken: tokenResp.IDToken,
|
||||
ExpiresIn: int64(tokenResp.ExpiresIn),
|
||||
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
|
||||
}
|
||||
|
||||
if userInfo != nil {
|
||||
tokenInfo.Email = userInfo.Email
|
||||
tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID
|
||||
tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID
|
||||
tokenInfo.OrganizationID = userInfo.OrganizationID
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for an OpenAI account
|
||||
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *model.Account) (*OpenAITokenInfo, error) {
|
||||
if !account.IsOpenAI() {
|
||||
return nil, fmt.Errorf("account is not an OpenAI account")
|
||||
}
|
||||
|
||||
refreshToken := account.GetOpenAIRefreshToken()
|
||||
if refreshToken == "" {
|
||||
return nil, fmt.Errorf("no refresh token available")
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
return s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
}
|
||||
|
||||
// BuildAccountCredentials builds credentials map from token info
|
||||
func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo) map[string]any {
|
||||
expiresAt := time.Unix(tokenInfo.ExpiresAt, 0).Format(time.RFC3339)
|
||||
|
||||
creds := map[string]any{
|
||||
"access_token": tokenInfo.AccessToken,
|
||||
"refresh_token": tokenInfo.RefreshToken,
|
||||
"expires_at": expiresAt,
|
||||
}
|
||||
|
||||
if tokenInfo.IDToken != "" {
|
||||
creds["id_token"] = tokenInfo.IDToken
|
||||
}
|
||||
if tokenInfo.Email != "" {
|
||||
creds["email"] = tokenInfo.Email
|
||||
}
|
||||
if tokenInfo.ChatGPTAccountID != "" {
|
||||
creds["chatgpt_account_id"] = tokenInfo.ChatGPTAccountID
|
||||
}
|
||||
if tokenInfo.ChatGPTUserID != "" {
|
||||
creds["chatgpt_user_id"] = tokenInfo.ChatGPTUserID
|
||||
}
|
||||
if tokenInfo.OrganizationID != "" {
|
||||
creds["organization_id"] = tokenInfo.OrganizationID
|
||||
}
|
||||
|
||||
return creds
|
||||
}
|
||||
|
||||
// Stop stops the session store cleanup goroutine
|
||||
func (s *OpenAIOAuthService) Stop() {
|
||||
s.sessionStore.Stop()
|
||||
}
|
||||
@@ -27,9 +27,12 @@ type AccountRepository interface {
|
||||
|
||||
ListSchedulable(ctx context.Context) ([]model.Account, error)
|
||||
ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error)
|
||||
ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error)
|
||||
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error)
|
||||
|
||||
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
|
||||
SetOverloaded(ctx context.Context, id int64, until time.Time) error
|
||||
ClearRateLimit(ctx context.Context, id int64) error
|
||||
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
|
||||
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
|
||||
}
|
||||
|
||||
9
backend/internal/service/ports/http_upstream.go
Normal file
9
backend/internal/service/ports/http_upstream.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package ports
|
||||
|
||||
import "net/http"
|
||||
|
||||
// HTTPUpstream interface for making HTTP requests to upstream APIs (Claude, OpenAI, etc.)
|
||||
// This is a generic interface that can be used for any HTTP-based upstream service.
|
||||
type HTTPUpstream interface {
|
||||
Do(req *http.Request, proxyURL string) (*http.Response, error)
|
||||
}
|
||||
13
backend/internal/service/ports/openai_oauth.go
Normal file
13
backend/internal/service/ports/openai_oauth.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"sub2api/internal/pkg/openai"
|
||||
)
|
||||
|
||||
// OpenAIOAuthClient interface for OpenAI OAuth operations
|
||||
type OpenAIOAuthClient interface {
|
||||
ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error)
|
||||
RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error)
|
||||
}
|
||||
@@ -11,6 +11,7 @@ type UserRepository interface {
|
||||
Create(ctx context.Context, user *model.User) error
|
||||
GetByID(ctx context.Context, id int64) (*model.User, error)
|
||||
GetByEmail(ctx context.Context, email string) (*model.User, error)
|
||||
GetFirstAdmin(ctx context.Context) (*model.User, error)
|
||||
Update(ctx context.Context, user *model.User) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
|
||||
@@ -9,11 +9,13 @@ import (
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/pkg/openai"
|
||||
)
|
||||
|
||||
// LiteLLMModelPricing LiteLLM价格数据结构
|
||||
@@ -419,8 +421,17 @@ func (s *PricingService) GetModelPricing(modelName string) *LiteLLMModelPricing
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 基于模型系列匹配
|
||||
return s.matchByModelFamily(modelLower)
|
||||
// 4. 基于模型系列匹配(Claude)
|
||||
if pricing := s.matchByModelFamily(modelLower); pricing != nil {
|
||||
return pricing
|
||||
}
|
||||
|
||||
// 5. OpenAI 模型回退策略
|
||||
if strings.HasPrefix(modelLower, "gpt-") {
|
||||
return s.matchOpenAIModel(modelLower)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractBaseName 提取基础模型名称(去掉日期版本号)
|
||||
@@ -514,12 +525,76 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
||||
return nil
|
||||
}
|
||||
|
||||
// matchOpenAIModel OpenAI 模型回退匹配策略
|
||||
// 回退顺序:
|
||||
// 1. gpt-5.2-codex -> gpt-5.2(去掉后缀如 -codex, -mini, -max 等)
|
||||
// 2. gpt-5.2-20251222 -> gpt-5.2(去掉日期版本号)
|
||||
// 3. 最终回退到 DefaultTestModel (gpt-5.1-codex)
|
||||
func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing {
|
||||
// 正则匹配日期后缀 (如 -20251222)
|
||||
datePattern := regexp.MustCompile(`-\d{8}$`)
|
||||
|
||||
// 尝试的回退变体
|
||||
variants := s.generateOpenAIModelVariants(model, datePattern)
|
||||
|
||||
for _, variant := range variants {
|
||||
if pricing, ok := s.pricingData[variant]; ok {
|
||||
log.Printf("[Pricing] OpenAI fallback matched %s -> %s", model, variant)
|
||||
return pricing
|
||||
}
|
||||
}
|
||||
|
||||
// 最终回退到 DefaultTestModel
|
||||
defaultModel := strings.ToLower(openai.DefaultTestModel)
|
||||
if pricing, ok := s.pricingData[defaultModel]; ok {
|
||||
log.Printf("[Pricing] OpenAI fallback to default model %s -> %s", model, defaultModel)
|
||||
return pricing
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateOpenAIModelVariants 生成 OpenAI 模型的回退变体列表
|
||||
func (s *PricingService) generateOpenAIModelVariants(model string, datePattern *regexp.Regexp) []string {
|
||||
seen := make(map[string]bool)
|
||||
var variants []string
|
||||
|
||||
addVariant := func(v string) {
|
||||
if v != model && !seen[v] {
|
||||
seen[v] = true
|
||||
variants = append(variants, v)
|
||||
}
|
||||
}
|
||||
|
||||
// 1. 去掉日期版本号: gpt-5.2-20251222 -> gpt-5.2
|
||||
withoutDate := datePattern.ReplaceAllString(model, "")
|
||||
if withoutDate != model {
|
||||
addVariant(withoutDate)
|
||||
}
|
||||
|
||||
// 2. 提取基础版本号: gpt-5.2-codex -> gpt-5.2
|
||||
// 只匹配纯数字版本号格式 gpt-X 或 gpt-X.Y,不匹配 gpt-4o 这种带字母后缀的
|
||||
basePattern := regexp.MustCompile(`^(gpt-\d+(?:\.\d+)?)(?:-|$)`)
|
||||
if matches := basePattern.FindStringSubmatch(model); len(matches) > 1 {
|
||||
addVariant(matches[1])
|
||||
}
|
||||
|
||||
// 3. 同时去掉日期后再提取基础版本号
|
||||
if withoutDate != model {
|
||||
if matches := basePattern.FindStringSubmatch(withoutDate); len(matches) > 1 {
|
||||
addVariant(matches[1])
|
||||
}
|
||||
}
|
||||
|
||||
return variants
|
||||
}
|
||||
|
||||
// GetStatus 获取服务状态
|
||||
func (s *PricingService) GetStatus() map[string]interface{} {
|
||||
func (s *PricingService) GetStatus() map[string]any {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
return map[string]interface{}{
|
||||
return map[string]any{
|
||||
"model_count": len(s.pricingData),
|
||||
"last_updated": s.lastUpdated,
|
||||
"local_hash": s.localHash[:min(8, len(s.localHash))],
|
||||
|
||||
@@ -254,7 +254,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
||||
_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -285,7 +285,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -359,12 +359,12 @@ func (s *RedeemService) Delete(ctx context.Context, id int64) error {
|
||||
}
|
||||
|
||||
// GetStats 获取兑换码统计信息
|
||||
func (s *RedeemService) GetStats(ctx context.Context) (map[string]interface{}, error) {
|
||||
func (s *RedeemService) GetStats(ctx context.Context) (map[string]any, error) {
|
||||
// TODO: 实现统计逻辑
|
||||
// 统计未使用、已使用的兑换码数量
|
||||
// 统计总面值等
|
||||
|
||||
stats := map[string]interface{}{
|
||||
stats := map[string]any{
|
||||
"total_codes": 0,
|
||||
"unused_codes": 0,
|
||||
"used_codes": 0,
|
||||
|
||||
@@ -2,30 +2,32 @@ package service
|
||||
|
||||
// Services 服务集合容器
|
||||
type Services struct {
|
||||
Auth *AuthService
|
||||
User *UserService
|
||||
ApiKey *ApiKeyService
|
||||
Group *GroupService
|
||||
Account *AccountService
|
||||
Proxy *ProxyService
|
||||
Redeem *RedeemService
|
||||
Usage *UsageService
|
||||
Pricing *PricingService
|
||||
Billing *BillingService
|
||||
BillingCache *BillingCacheService
|
||||
Admin AdminService
|
||||
Gateway *GatewayService
|
||||
OAuth *OAuthService
|
||||
RateLimit *RateLimitService
|
||||
AccountUsage *AccountUsageService
|
||||
AccountTest *AccountTestService
|
||||
Setting *SettingService
|
||||
Email *EmailService
|
||||
EmailQueue *EmailQueueService
|
||||
Turnstile *TurnstileService
|
||||
Subscription *SubscriptionService
|
||||
Concurrency *ConcurrencyService
|
||||
Identity *IdentityService
|
||||
Update *UpdateService
|
||||
TokenRefresh *TokenRefreshService
|
||||
Auth *AuthService
|
||||
User *UserService
|
||||
ApiKey *ApiKeyService
|
||||
Group *GroupService
|
||||
Account *AccountService
|
||||
Proxy *ProxyService
|
||||
Redeem *RedeemService
|
||||
Usage *UsageService
|
||||
Pricing *PricingService
|
||||
Billing *BillingService
|
||||
BillingCache *BillingCacheService
|
||||
Admin AdminService
|
||||
Gateway *GatewayService
|
||||
OpenAIGateway *OpenAIGatewayService
|
||||
OAuth *OAuthService
|
||||
OpenAIOAuth *OpenAIOAuthService
|
||||
RateLimit *RateLimitService
|
||||
AccountUsage *AccountUsageService
|
||||
AccountTest *AccountTestService
|
||||
Setting *SettingService
|
||||
Email *EmailService
|
||||
EmailQueue *EmailQueueService
|
||||
Turnstile *TurnstileService
|
||||
Subscription *SubscriptionService
|
||||
Concurrency *ConcurrencyService
|
||||
Identity *IdentityService
|
||||
Update *UpdateService
|
||||
TokenRefresh *TokenRefreshService
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
@@ -262,3 +264,63 @@ func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// GenerateAdminApiKey 生成新的管理员 API Key
|
||||
func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error) {
|
||||
// 生成 32 字节随机数 = 64 位十六进制字符
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", fmt.Errorf("generate random bytes: %w", err)
|
||||
}
|
||||
|
||||
key := model.AdminApiKeyPrefix + hex.EncodeToString(bytes)
|
||||
|
||||
// 存储到 settings 表
|
||||
if err := s.settingRepo.Set(ctx, model.SettingKeyAdminApiKey, key); err != nil {
|
||||
return "", fmt.Errorf("save admin api key: %w", err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// GetAdminApiKeyStatus 获取管理员 API Key 状态
|
||||
// 返回脱敏的 key、是否存在、错误
|
||||
func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
|
||||
key, err := s.settingRepo.GetValue(ctx, model.SettingKeyAdminApiKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return "", false, nil
|
||||
}
|
||||
return "", false, err
|
||||
}
|
||||
if key == "" {
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
// 脱敏:显示前 10 位和后 4 位
|
||||
if len(key) > 14 {
|
||||
maskedKey = key[:10] + "..." + key[len(key)-4:]
|
||||
} else {
|
||||
maskedKey = key
|
||||
}
|
||||
|
||||
return maskedKey, true, nil
|
||||
}
|
||||
|
||||
// GetAdminApiKey 获取完整的管理员 API Key(仅供内部验证使用)
|
||||
// 如果未配置返回空字符串和 nil 错误,只有数据库错误时才返回 error
|
||||
func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
|
||||
key, err := s.settingRepo.GetValue(ctx, model.SettingKeyAdminApiKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return "", nil // 未配置,返回空字符串
|
||||
}
|
||||
return "", err // 数据库错误
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// DeleteAdminApiKey 删除管理员 API Key
|
||||
func (s *SettingService) DeleteAdminApiKey(ctx context.Context) error {
|
||||
return s.settingRepo.Delete(ctx, model.SettingKeyAdminApiKey)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
@@ -78,7 +79,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -146,7 +147,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
}
|
||||
newNotes += input.Notes
|
||||
if err := s.userSubRepo.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil {
|
||||
// 备注更新失败不影响主流程
|
||||
log.Printf("update subscription notes failed: sub_id=%d err=%v", existingSub.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -156,7 +157,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -177,7 +178,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -278,7 +279,7 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -311,7 +312,7 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -334,24 +335,67 @@ func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID,
|
||||
|
||||
// ListUserSubscriptions 获取用户的所有订阅
|
||||
func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
|
||||
return s.userSubRepo.ListByUserID(ctx, userID)
|
||||
subs, err := s.userSubRepo.ListByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
normalizeExpiredWindows(subs)
|
||||
return subs, nil
|
||||
}
|
||||
|
||||
// ListActiveUserSubscriptions 获取用户的所有有效订阅
|
||||
func (s *SubscriptionService) ListActiveUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
|
||||
return s.userSubRepo.ListActiveByUserID(ctx, userID)
|
||||
subs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
normalizeExpiredWindows(subs)
|
||||
return subs, nil
|
||||
}
|
||||
|
||||
// ListGroupSubscriptions 获取分组的所有订阅
|
||||
func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupID int64, page, pageSize int) ([]model.UserSubscription, *pagination.PaginationResult, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
return s.userSubRepo.ListByGroupID(ctx, groupID, params)
|
||||
subs, pag, err := s.userSubRepo.ListByGroupID(ctx, groupID, params)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
normalizeExpiredWindows(subs)
|
||||
return subs, pag, nil
|
||||
}
|
||||
|
||||
// List 获取所有订阅(分页,支持筛选)
|
||||
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
return s.userSubRepo.List(ctx, params, userID, groupID, status)
|
||||
subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
normalizeExpiredWindows(subs)
|
||||
return subs, pag, nil
|
||||
}
|
||||
|
||||
// normalizeExpiredWindows 将已过期窗口的数据清零(仅影响返回数据,不影响数据库)
|
||||
// 这确保前端显示正确的当前窗口状态,而不是过期窗口的历史数据
|
||||
func normalizeExpiredWindows(subs []model.UserSubscription) {
|
||||
for i := range subs {
|
||||
sub := &subs[i]
|
||||
// 日窗口过期:清零展示数据
|
||||
if sub.NeedsDailyReset() {
|
||||
sub.DailyWindowStart = nil
|
||||
sub.DailyUsageUSD = 0
|
||||
}
|
||||
// 周窗口过期:清零展示数据
|
||||
if sub.NeedsWeeklyReset() {
|
||||
sub.WeeklyWindowStart = nil
|
||||
sub.WeeklyUsageUSD = 0
|
||||
}
|
||||
// 月窗口过期:清零展示数据
|
||||
if sub.NeedsMonthlyReset() {
|
||||
sub.MonthlyWindowStart = nil
|
||||
sub.MonthlyUsageUSD = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// startOfDay 返回给定时间所在日期的零点(保持原时区)
|
||||
|
||||
@@ -27,6 +27,7 @@ type TokenRefreshService struct {
|
||||
func NewTokenRefreshService(
|
||||
accountRepo ports.AccountRepository,
|
||||
oauthService *OAuthService,
|
||||
openaiOAuthService *OpenAIOAuthService,
|
||||
cfg *config.Config,
|
||||
) *TokenRefreshService {
|
||||
s := &TokenRefreshService{
|
||||
@@ -38,9 +39,7 @@ func NewTokenRefreshService(
|
||||
// 注册平台特定的刷新器
|
||||
s.refreshers = []TokenRefresher{
|
||||
NewClaudeTokenRefresher(oauthService),
|
||||
// 未来可以添加其他平台的刷新器:
|
||||
// NewOpenAITokenRefresher(...),
|
||||
// NewGeminiTokenRefresher(...),
|
||||
NewOpenAITokenRefresher(openaiOAuthService),
|
||||
}
|
||||
|
||||
return s
|
||||
|
||||
@@ -19,7 +19,7 @@ type TokenRefresher interface {
|
||||
|
||||
// Refresh 执行token刷新,返回更新后的credentials
|
||||
// 注意:返回的map应该保留原有credentials中的所有字段,只更新token相关字段
|
||||
Refresh(ctx context.Context, account *model.Account) (map[string]interface{}, error)
|
||||
Refresh(ctx context.Context, account *model.Account) (map[string]any, error)
|
||||
}
|
||||
|
||||
// ClaudeTokenRefresher 处理Anthropic/Claude OAuth token刷新
|
||||
@@ -61,14 +61,14 @@ func (r *ClaudeTokenRefresher) NeedsRefresh(account *model.Account, refreshWindo
|
||||
|
||||
// Refresh 执行token刷新
|
||||
// 保留原有credentials中的所有字段,只更新token相关字段
|
||||
func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *model.Account) (map[string]interface{}, error) {
|
||||
func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *model.Account) (map[string]any, error) {
|
||||
tokenInfo, err := r.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 保留现有credentials中的所有字段
|
||||
newCredentials := make(map[string]interface{})
|
||||
newCredentials := make(map[string]any)
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
@@ -88,3 +88,54 @@ func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *model.Accou
|
||||
|
||||
return newCredentials, nil
|
||||
}
|
||||
|
||||
// OpenAITokenRefresher 处理 OpenAI OAuth token刷新
|
||||
type OpenAITokenRefresher struct {
|
||||
openaiOAuthService *OpenAIOAuthService
|
||||
}
|
||||
|
||||
// NewOpenAITokenRefresher 创建 OpenAI token刷新器
|
||||
func NewOpenAITokenRefresher(openaiOAuthService *OpenAIOAuthService) *OpenAITokenRefresher {
|
||||
return &OpenAITokenRefresher{
|
||||
openaiOAuthService: openaiOAuthService,
|
||||
}
|
||||
}
|
||||
|
||||
// CanRefresh 检查是否能处理此账号
|
||||
// 只处理 openai 平台的 oauth 类型账号
|
||||
func (r *OpenAITokenRefresher) CanRefresh(account *model.Account) bool {
|
||||
return account.Platform == model.PlatformOpenAI &&
|
||||
account.Type == model.AccountTypeOAuth
|
||||
}
|
||||
|
||||
// NeedsRefresh 检查token是否需要刷新
|
||||
// 基于 expires_at 字段判断是否在刷新窗口内
|
||||
func (r *OpenAITokenRefresher) NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool {
|
||||
expiresAt := account.GetOpenAITokenExpiresAt()
|
||||
if expiresAt == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return time.Until(*expiresAt) < refreshWindow
|
||||
}
|
||||
|
||||
// Refresh 执行token刷新
|
||||
// 保留原有credentials中的所有字段,只更新token相关字段
|
||||
func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *model.Account) (map[string]any, error) {
|
||||
tokenInfo, err := r.openaiOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 使用服务提供的方法构建新凭证,并保留原有字段
|
||||
newCredentials := r.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
|
||||
// 保留原有credentials中非token相关字段
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return newCredentials, nil
|
||||
}
|
||||
|
||||
@@ -12,8 +12,6 @@ var (
|
||||
ErrTurnstileNotConfigured = errors.New("turnstile not configured")
|
||||
)
|
||||
|
||||
const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
|
||||
|
||||
// TurnstileVerifier 验证 Turnstile token 的接口
|
||||
type TurnstileVerifier interface {
|
||||
VerifyToken(ctx context.Context, secretKey, token, remoteIP string) (*TurnstileVerifyResponse, error)
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -190,7 +191,7 @@ func (s *UpdateService) PerformUpdate(ctx context.Context) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create temp dir: %w", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
defer func() { _ = os.RemoveAll(tempDir) }()
|
||||
|
||||
// Download archive
|
||||
archivePath := filepath.Join(tempDir, filepath.Base(downloadURL))
|
||||
@@ -223,7 +224,7 @@ func (s *UpdateService) PerformUpdate(ctx context.Context) error {
|
||||
backupPath := exePath + ".backup"
|
||||
|
||||
// Remove old backup if exists
|
||||
os.Remove(backupPath)
|
||||
_ = os.Remove(backupPath)
|
||||
|
||||
// Step 1: Move current binary to backup
|
||||
if err := os.Rename(exePath, backupPath); err != nil {
|
||||
@@ -349,7 +350,7 @@ func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumUR
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
h := sha256.New()
|
||||
if _, err := io.Copy(h, f); err != nil {
|
||||
@@ -379,7 +380,7 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
var reader io.Reader = f
|
||||
|
||||
@@ -389,7 +390,7 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer gzr.Close()
|
||||
defer func() { _ = gzr.Close() }()
|
||||
reader = gzr
|
||||
}
|
||||
|
||||
@@ -435,10 +436,12 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
|
||||
// Use LimitReader to prevent decompression bombs
|
||||
limited := io.LimitReader(tr, maxBinarySize)
|
||||
if _, err := io.Copy(out, limited); err != nil {
|
||||
out.Close()
|
||||
_ = out.Close()
|
||||
return err
|
||||
}
|
||||
if err := out.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
out.Close()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -451,11 +454,13 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
limited := io.LimitReader(reader, maxBinarySize)
|
||||
_, err = io.Copy(out, limited)
|
||||
return err
|
||||
if _, err := io.Copy(out, limited); err != nil {
|
||||
_ = out.Close()
|
||||
return err
|
||||
}
|
||||
return out.Close()
|
||||
}
|
||||
|
||||
func (s *UpdateService) getFromCache(ctx context.Context) (*UpdateInfo, error) {
|
||||
@@ -499,7 +504,7 @@ func (s *UpdateService) saveToCache(ctx context.Context, info *UpdateInfo) {
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(cacheData)
|
||||
s.cache.SetUpdateInfo(ctx, string(data), time.Duration(updateCacheTTL)*time.Second)
|
||||
_ = s.cache.SetUpdateInfo(ctx, string(data), time.Duration(updateCacheTTL)*time.Second)
|
||||
}
|
||||
|
||||
// compareVersions compares two semantic versions
|
||||
@@ -523,7 +528,9 @@ func parseVersion(v string) [3]int {
|
||||
parts := strings.Split(v, ".")
|
||||
result := [3]int{0, 0, 0}
|
||||
for i := 0; i < len(parts) && i < 3; i++ {
|
||||
fmt.Sscanf(parts[i], "%d", &result[i])
|
||||
if parsed, err := strconv.Atoi(parts[i]); err == nil {
|
||||
result[i] = parsed
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -195,7 +195,7 @@ func (s *UsageService) GetStatsByModel(ctx context.Context, modelName string, st
|
||||
}
|
||||
|
||||
// GetDailyStats 获取每日使用统计(最近N天)
|
||||
func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int) ([]map[string]interface{}, error) {
|
||||
func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int) ([]map[string]any, error) {
|
||||
endTime := time.Now()
|
||||
startTime := endTime.AddDate(0, 0, -days)
|
||||
|
||||
@@ -227,13 +227,13 @@ func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int
|
||||
}
|
||||
|
||||
// 计算平均值并转换为数组
|
||||
result := make([]map[string]interface{}, 0, len(dailyStats))
|
||||
result := make([]map[string]any, 0, len(dailyStats))
|
||||
for date, stats := range dailyStats {
|
||||
if stats.TotalRequests > 0 {
|
||||
stats.AverageDurationMs /= float64(stats.TotalRequests)
|
||||
}
|
||||
|
||||
result = append(result, map[string]interface{}{
|
||||
result = append(result, map[string]any{
|
||||
"date": date,
|
||||
"total_requests": stats.TotalRequests,
|
||||
"total_input_tokens": stats.TotalInputTokens,
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/service/ports"
|
||||
@@ -22,6 +21,8 @@ var (
|
||||
// UpdateProfileRequest 更新用户资料请求
|
||||
type UpdateProfileRequest struct {
|
||||
Email *string `json:"email"`
|
||||
Username *string `json:"username"`
|
||||
Wechat *string `json:"wechat"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
}
|
||||
|
||||
@@ -34,14 +35,12 @@ type ChangePasswordRequest struct {
|
||||
// UserService 用户服务
|
||||
type UserService struct {
|
||||
userRepo ports.UserRepository
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewUserService 创建用户服务实例
|
||||
func NewUserService(userRepo ports.UserRepository, cfg *config.Config) *UserService {
|
||||
func NewUserService(userRepo ports.UserRepository) *UserService {
|
||||
return &UserService{
|
||||
userRepo: userRepo,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,6 +79,14 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
|
||||
user.Email = *req.Email
|
||||
}
|
||||
|
||||
if req.Username != nil {
|
||||
user.Username = *req.Username
|
||||
}
|
||||
|
||||
if req.Wechat != nil {
|
||||
user.Wechat = *req.Wechat
|
||||
}
|
||||
|
||||
if req.Concurrency != nil {
|
||||
user.Concurrency = *req.Concurrency
|
||||
}
|
||||
|
||||
@@ -37,9 +37,10 @@ func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService {
|
||||
func ProvideTokenRefreshService(
|
||||
accountRepo ports.AccountRepository,
|
||||
oauthService *OAuthService,
|
||||
openaiOAuthService *OpenAIOAuthService,
|
||||
cfg *config.Config,
|
||||
) *TokenRefreshService {
|
||||
svc := NewTokenRefreshService(accountRepo, oauthService, cfg)
|
||||
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, cfg)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
@@ -60,7 +61,9 @@ var ProviderSet = wire.NewSet(
|
||||
NewBillingCacheService,
|
||||
NewAdminService,
|
||||
NewGatewayService,
|
||||
NewOpenAIGatewayService,
|
||||
NewOAuthService,
|
||||
NewOpenAIOAuthService,
|
||||
NewRateLimitService,
|
||||
NewAccountUsageService,
|
||||
NewAccountTestService,
|
||||
|
||||
@@ -352,4 +352,3 @@ func install(c *gin.Context) {
|
||||
"restart": true,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -14,9 +14,9 @@ import (
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gopkg.in/yaml.v3"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Config paths
|
||||
@@ -101,7 +101,14 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get db instance: %w", err)
|
||||
}
|
||||
defer sqlDB.Close()
|
||||
defer func() {
|
||||
if sqlDB == nil {
|
||||
return
|
||||
}
|
||||
if err := sqlDB.Close(); err != nil {
|
||||
log.Printf("failed to close postgres connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
@@ -129,7 +136,10 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error {
|
||||
}
|
||||
|
||||
// Now connect to the target database to verify
|
||||
sqlDB.Close()
|
||||
if err := sqlDB.Close(); err != nil {
|
||||
log.Printf("failed to close postgres connection: %v", err)
|
||||
}
|
||||
sqlDB = nil
|
||||
|
||||
targetDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
||||
@@ -145,7 +155,11 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get target db instance: %w", err)
|
||||
}
|
||||
defer targetSqlDB.Close()
|
||||
defer func() {
|
||||
if err := targetSqlDB.Close(); err != nil {
|
||||
log.Printf("failed to close postgres connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel2()
|
||||
@@ -164,7 +178,11 @@ func TestRedisConnection(cfg *RedisConfig) error {
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
})
|
||||
defer rdb.Close()
|
||||
defer func() {
|
||||
if err := rdb.Close(); err != nil {
|
||||
log.Printf("failed to close redis client: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
@@ -185,7 +203,11 @@ func Install(cfg *SetupConfig) error {
|
||||
|
||||
// Generate JWT secret if not provided
|
||||
if cfg.JWT.Secret == "" {
|
||||
cfg.JWT.Secret = generateSecret(32)
|
||||
secret, err := generateSecret(32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate jwt secret: %w", err)
|
||||
}
|
||||
cfg.JWT.Secret = secret
|
||||
}
|
||||
|
||||
// Test connections
|
||||
@@ -243,7 +265,11 @@ func initializeDatabase(cfg *SetupConfig) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sqlDB.Close()
|
||||
defer func() {
|
||||
if err := sqlDB.Close(); err != nil {
|
||||
log.Printf("failed to close postgres connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 使用 model 包的 AutoMigrate,确保模型定义统一
|
||||
return model.AutoMigrate(db)
|
||||
@@ -265,7 +291,11 @@ func createAdminUser(cfg *SetupConfig) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sqlDB.Close()
|
||||
defer func() {
|
||||
if err := sqlDB.Close(); err != nil {
|
||||
log.Printf("failed to close postgres connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Check if admin already exists
|
||||
var count int64
|
||||
@@ -352,10 +382,12 @@ func writeConfigFile(cfg *SetupConfig) error {
|
||||
return os.WriteFile(ConfigFile, data, 0600)
|
||||
}
|
||||
|
||||
func generateSecret(length int) string {
|
||||
func generateSecret(length int) (string, error) {
|
||||
bytes := make([]byte, length)
|
||||
rand.Read(bytes)
|
||||
return hex.EncodeToString(bytes)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
@@ -431,13 +463,21 @@ func AutoSetupFromEnv() error {
|
||||
|
||||
// Generate JWT secret if not provided
|
||||
if cfg.JWT.Secret == "" {
|
||||
cfg.JWT.Secret = generateSecret(32)
|
||||
secret, err := generateSecret(32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate jwt secret: %w", err)
|
||||
}
|
||||
cfg.JWT.Secret = secret
|
||||
log.Println("Generated JWT secret automatically")
|
||||
}
|
||||
|
||||
// Generate admin password if not provided
|
||||
if cfg.Admin.Password == "" {
|
||||
cfg.Admin.Password = generateSecret(16)
|
||||
password, err := generateSecret(16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate admin password: %w", err)
|
||||
}
|
||||
cfg.Admin.Password = password
|
||||
log.Printf("Generated admin password: %s", cfg.Admin.Password)
|
||||
log.Println("IMPORTANT: Save this password! It will not be shown again.")
|
||||
}
|
||||
|
||||
20
backend/internal/web/embed_off.go
Normal file
20
backend/internal/web/embed_off.go
Normal file
@@ -0,0 +1,20 @@
|
||||
//go:build !embed
|
||||
|
||||
package web
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func ServeEmbeddedFrontend() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.String(http.StatusNotFound, "Frontend not embedded. Build with -tags embed to include frontend.")
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
func HasEmbeddedFrontend() bool {
|
||||
return false
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build embed
|
||||
|
||||
package web
|
||||
|
||||
import (
|
||||
@@ -13,8 +15,6 @@ import (
|
||||
//go:embed all:dist
|
||||
var frontendFS embed.FS
|
||||
|
||||
// ServeEmbeddedFrontend returns a Gin handler that serves embedded frontend assets
|
||||
// and handles SPA routing by falling back to index.html for non-API routes.
|
||||
func ServeEmbeddedFrontend() gin.HandlerFunc {
|
||||
distFS, err := fs.Sub(frontendFS, "dist")
|
||||
if err != nil {
|
||||
@@ -25,7 +25,6 @@ func ServeEmbeddedFrontend() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
path := c.Request.URL.Path
|
||||
|
||||
// Skip API and gateway routes
|
||||
if strings.HasPrefix(path, "/api/") ||
|
||||
strings.HasPrefix(path, "/v1/") ||
|
||||
strings.HasPrefix(path, "/setup/") ||
|
||||
@@ -34,20 +33,18 @@ func ServeEmbeddedFrontend() gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
// Try to serve static file
|
||||
cleanPath := strings.TrimPrefix(path, "/")
|
||||
if cleanPath == "" {
|
||||
cleanPath = "index.html"
|
||||
}
|
||||
|
||||
if file, err := distFS.Open(cleanPath); err == nil {
|
||||
file.Close()
|
||||
_ = file.Close()
|
||||
fileServer.ServeHTTP(c.Writer, c.Request)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// SPA fallback: serve index.html for all other routes
|
||||
serveIndexHTML(c, distFS)
|
||||
}
|
||||
}
|
||||
@@ -59,7 +56,7 @@ func serveIndexHTML(c *gin.Context, fsys fs.FS) {
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
defer func() { _ = file.Close() }()
|
||||
|
||||
content, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
@@ -72,7 +69,6 @@ func serveIndexHTML(c *gin.Context, fsys fs.FS) {
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
// HasEmbeddedFrontend checks if frontend assets are embedded
|
||||
func HasEmbeddedFrontend() bool {
|
||||
_, err := frontendFS.ReadFile("dist/index.html")
|
||||
return err == nil
|
||||
6
backend/migrations/004_add_redeem_code_notes.sql
Normal file
6
backend/migrations/004_add_redeem_code_notes.sql
Normal file
@@ -0,0 +1,6 @@
|
||||
-- 为 redeem_codes 表添加备注字段
|
||||
|
||||
ALTER TABLE redeem_codes
|
||||
ADD COLUMN IF NOT EXISTS notes TEXT DEFAULT NULL;
|
||||
|
||||
COMMENT ON COLUMN redeem_codes.notes IS '备注说明(管理员调整时的原因说明)';
|
||||
@@ -12,6 +12,7 @@ import type {
|
||||
AccountUsageInfo,
|
||||
WindowStats,
|
||||
ClaudeModel,
|
||||
AccountUsageStatsResponse,
|
||||
} from '@/types';
|
||||
|
||||
/**
|
||||
@@ -126,27 +127,12 @@ export async function refreshCredentials(id: number): Promise<Account> {
|
||||
/**
|
||||
* Get account usage statistics
|
||||
* @param id - Account ID
|
||||
* @param period - Time period
|
||||
* @returns Account usage statistics
|
||||
* @param days - Number of days (default: 30)
|
||||
* @returns Account usage statistics with history, summary, and models
|
||||
*/
|
||||
export async function getStats(
|
||||
id: number,
|
||||
period: string = 'month'
|
||||
): Promise<{
|
||||
total_requests: number;
|
||||
successful_requests: number;
|
||||
failed_requests: number;
|
||||
total_tokens: number;
|
||||
average_response_time: number;
|
||||
}> {
|
||||
const { data } = await apiClient.get<{
|
||||
total_requests: number;
|
||||
successful_requests: number;
|
||||
failed_requests: number;
|
||||
total_tokens: number;
|
||||
average_response_time: number;
|
||||
}>(`/admin/accounts/${id}/stats`, {
|
||||
params: { period },
|
||||
export async function getStats(id: number, days: number = 30): Promise<AccountUsageStatsResponse> {
|
||||
const { data } = await apiClient.get<AccountUsageStatsResponse>(`/admin/accounts/${id}/stats`, {
|
||||
params: { days },
|
||||
});
|
||||
return data;
|
||||
}
|
||||
|
||||
@@ -99,11 +99,49 @@ export async function sendTestEmail(request: SendTestEmailRequest): Promise<{ me
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Admin API Key status response
|
||||
*/
|
||||
export interface AdminApiKeyStatus {
|
||||
exists: boolean;
|
||||
masked_key: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get admin API key status
|
||||
* @returns Status indicating if key exists and masked version
|
||||
*/
|
||||
export async function getAdminApiKey(): Promise<AdminApiKeyStatus> {
|
||||
const { data } = await apiClient.get<AdminApiKeyStatus>('/admin/settings/admin-api-key');
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Regenerate admin API key
|
||||
* @returns The new full API key (only shown once)
|
||||
*/
|
||||
export async function regenerateAdminApiKey(): Promise<{ key: string }> {
|
||||
const { data } = await apiClient.post<{ key: string }>('/admin/settings/admin-api-key/regenerate');
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete admin API key
|
||||
* @returns Success message
|
||||
*/
|
||||
export async function deleteAdminApiKey(): Promise<{ message: string }> {
|
||||
const { data } = await apiClient.delete<{ message: string }>('/admin/settings/admin-api-key');
|
||||
return data;
|
||||
}
|
||||
|
||||
export const settingsAPI = {
|
||||
getSettings,
|
||||
updateSettings,
|
||||
testSmtpConnection,
|
||||
sendTestEmail,
|
||||
getAdminApiKey,
|
||||
regenerateAdminApiKey,
|
||||
deleteAdminApiKey,
|
||||
};
|
||||
|
||||
export default settingsAPI;
|
||||
|
||||
@@ -84,16 +84,19 @@ export async function deleteUser(id: number): Promise<{ message: string }> {
|
||||
* @param id - User ID
|
||||
* @param balance - New balance
|
||||
* @param operation - Operation type ('set', 'add', 'subtract')
|
||||
* @param notes - Optional notes for the balance adjustment
|
||||
* @returns Updated user
|
||||
*/
|
||||
export async function updateBalance(
|
||||
id: number,
|
||||
balance: number,
|
||||
operation: 'set' | 'add' | 'subtract' = 'set'
|
||||
operation: 'set' | 'add' | 'subtract' = 'set',
|
||||
notes?: string
|
||||
): Promise<User> {
|
||||
const { data } = await apiClient.post<User>(`/admin/users/${id}/balance`, {
|
||||
balance,
|
||||
operation,
|
||||
notes: notes || '',
|
||||
});
|
||||
return data;
|
||||
}
|
||||
|
||||
@@ -11,7 +11,20 @@ import type { User, ChangePasswordRequest } from '@/types';
|
||||
* @returns User profile data
|
||||
*/
|
||||
export async function getProfile(): Promise<User> {
|
||||
const { data } = await apiClient.get<User>('/users/me');
|
||||
const { data } = await apiClient.get<User>('/user/profile');
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Update current user profile
|
||||
* @param profile - Profile data to update
|
||||
* @returns Updated user profile data
|
||||
*/
|
||||
export async function updateProfile(profile: {
|
||||
username?: string;
|
||||
wechat?: string;
|
||||
}): Promise<User> {
|
||||
const { data } = await apiClient.put<User>('/user', profile);
|
||||
return data;
|
||||
}
|
||||
|
||||
@@ -29,12 +42,13 @@ export async function changePassword(
|
||||
new_password: newPassword,
|
||||
};
|
||||
|
||||
const { data } = await apiClient.post<{ message: string }>('/users/me/password', payload);
|
||||
const { data } = await apiClient.put<{ message: string }>('/user/password', payload);
|
||||
return data;
|
||||
}
|
||||
|
||||
export const userAPI = {
|
||||
getProfile,
|
||||
updateProfile,
|
||||
changePassword,
|
||||
};
|
||||
|
||||
|
||||
546
frontend/src/components/account/AccountStatsModal.vue
Normal file
546
frontend/src/components/account/AccountStatsModal.vue
Normal file
@@ -0,0 +1,546 @@
|
||||
<template>
|
||||
<Modal
|
||||
:show="show"
|
||||
:title="t('admin.accounts.usageStatistics')"
|
||||
size="2xl"
|
||||
@close="handleClose"
|
||||
>
|
||||
<div class="space-y-6">
|
||||
<!-- Account Info Header -->
|
||||
<div v-if="account" class="flex items-center justify-between p-3 bg-gradient-to-r from-primary-50 to-primary-100 dark:from-primary-900/20 dark:to-primary-800/20 rounded-xl border border-primary-200 dark:border-primary-700/50">
|
||||
<div class="flex items-center gap-3">
|
||||
<div class="w-10 h-10 rounded-lg bg-gradient-to-br from-primary-500 to-primary-600 flex items-center justify-center">
|
||||
<svg class="w-5 h-5 text-white" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 19v-6a2 2 0 00-2-2H5a2 2 0 00-2 2v6a2 2 0 002 2h2a2 2 0 002-2zm0 0V9a2 2 0 012-2h2a2 2 0 012 2v10m-6 0a2 2 0 002 2h2a2 2 0 002-2m0 0V5a2 2 0 012-2h2a2 2 0 012 2v14a2 2 0 01-2 2h-2a2 2 0 01-2-2z" />
|
||||
</svg>
|
||||
</div>
|
||||
<div>
|
||||
<div class="font-semibold text-gray-900 dark:text-gray-100">{{ account.name }}</div>
|
||||
<div class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.last30DaysUsage') }}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<span
|
||||
:class="[
|
||||
'px-2.5 py-1 text-xs font-semibold rounded-full',
|
||||
account.status === 'active'
|
||||
? 'bg-green-100 text-green-700 dark:bg-green-500/20 dark:text-green-400'
|
||||
: 'bg-gray-100 text-gray-600 dark:bg-gray-700 dark:text-gray-400'
|
||||
]"
|
||||
>
|
||||
{{ account.status }}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Loading State -->
|
||||
<div v-if="loading" class="flex items-center justify-center py-12">
|
||||
<LoadingSpinner />
|
||||
</div>
|
||||
|
||||
<template v-else-if="stats">
|
||||
<!-- Row 1: Main Stats Cards -->
|
||||
<div class="grid grid-cols-2 gap-4 lg:grid-cols-4">
|
||||
<!-- 30-Day Total Cost -->
|
||||
<div class="card p-4 bg-gradient-to-br from-emerald-50 to-white dark:from-emerald-900/10 dark:to-dark-700 border-emerald-200 dark:border-emerald-800/30">
|
||||
<div class="flex items-center justify-between mb-2">
|
||||
<span class="text-xs font-medium text-gray-500 dark:text-gray-400">{{ t('admin.accounts.stats.totalCost') }}</span>
|
||||
<div class="p-1.5 rounded-lg bg-emerald-100 dark:bg-emerald-900/30">
|
||||
<svg class="w-4 h-4 text-emerald-600 dark:text-emerald-400" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 8c-1.657 0-3 .895-3 2s1.343 2 3 2 3 .895 3 2-1.343 2-3 2m0-8c1.11 0 2.08.402 2.599 1M12 8V7m0 1v8m0 0v1m0-1c-1.11 0-2.08-.402-2.599-1M21 12a9 9 0 11-18 0 9 9 0 0118 0z" />
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
<p class="text-2xl font-bold text-gray-900 dark:text-white">${{ formatCost(stats.summary.total_cost) }}</p>
|
||||
<p class="text-xs text-gray-500 dark:text-gray-400 mt-1">
|
||||
{{ t('admin.accounts.stats.accumulatedCost') }}
|
||||
<span class="text-gray-400 dark:text-gray-500">({{ t('admin.accounts.stats.standardCost') }}: ${{ formatCost(stats.summary.total_standard_cost) }})</span>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- 30-Day Total Requests -->
|
||||
<div class="card p-4 bg-gradient-to-br from-blue-50 to-white dark:from-blue-900/10 dark:to-dark-700 border-blue-200 dark:border-blue-800/30">
|
||||
<div class="flex items-center justify-between mb-2">
|
||||
<span class="text-xs font-medium text-gray-500 dark:text-gray-400">{{ t('admin.accounts.stats.totalRequests') }}</span>
|
||||
<div class="p-1.5 rounded-lg bg-blue-100 dark:bg-blue-900/30">
|
||||
<svg class="w-4 h-4 text-blue-600 dark:text-blue-400" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M13 10V3L4 14h7v7l9-11h-7z" />
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
<p class="text-2xl font-bold text-gray-900 dark:text-white">{{ formatNumber(stats.summary.total_requests) }}</p>
|
||||
<p class="text-xs text-gray-500 dark:text-gray-400 mt-1">{{ t('admin.accounts.stats.totalCalls') }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Daily Average Cost -->
|
||||
<div class="card p-4 bg-gradient-to-br from-amber-50 to-white dark:from-amber-900/10 dark:to-dark-700 border-amber-200 dark:border-amber-800/30">
|
||||
<div class="flex items-center justify-between mb-2">
|
||||
<span class="text-xs font-medium text-gray-500 dark:text-gray-400">{{ t('admin.accounts.stats.avgDailyCost') }}</span>
|
||||
<div class="p-1.5 rounded-lg bg-amber-100 dark:bg-amber-900/30">
|
||||
<svg class="w-4 h-4 text-amber-600 dark:text-amber-400" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 7h6m0 10v-3m-3 3h.01M9 17h.01M9 14h.01M12 14h.01M15 11h.01M12 11h.01M9 11h.01M7 21h10a2 2 0 002-2V5a2 2 0 00-2-2H7a2 2 0 00-2 2v14a2 2 0 002 2z" />
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
<p class="text-2xl font-bold text-gray-900 dark:text-white">${{ formatCost(stats.summary.avg_daily_cost) }}</p>
|
||||
<p class="text-xs text-gray-500 dark:text-gray-400 mt-1">{{ t('admin.accounts.stats.basedOnActualDays', { days: stats.summary.actual_days_used }) }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Daily Average Requests -->
|
||||
<div class="card p-4 bg-gradient-to-br from-purple-50 to-white dark:from-purple-900/10 dark:to-dark-700 border-purple-200 dark:border-purple-800/30">
|
||||
<div class="flex items-center justify-between mb-2">
|
||||
<span class="text-xs font-medium text-gray-500 dark:text-gray-400">{{ t('admin.accounts.stats.avgDailyRequests') }}</span>
|
||||
<div class="p-1.5 rounded-lg bg-purple-100 dark:bg-purple-900/30">
|
||||
<svg class="w-4 h-4 text-purple-600 dark:text-purple-400" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M7 12l3-3 3 3 4-4M8 21l4-4 4 4M3 4h18M4 4h16v12a1 1 0 01-1 1H5a1 1 0 01-1-1V4z" />
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
<p class="text-2xl font-bold text-gray-900 dark:text-white">{{ formatNumber(Math.round(stats.summary.avg_daily_requests)) }}</p>
|
||||
<p class="text-xs text-gray-500 dark:text-gray-400 mt-1">{{ t('admin.accounts.stats.avgDailyUsage') }}</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Row 2: Today, Highest Cost, Highest Requests -->
|
||||
<div class="grid grid-cols-1 gap-4 lg:grid-cols-3">
|
||||
<!-- Today Overview -->
|
||||
<div class="card p-4">
|
||||
<div class="flex items-center gap-2 mb-3">
|
||||
<div class="p-1.5 rounded-lg bg-cyan-100 dark:bg-cyan-900/30">
|
||||
<svg class="w-4 h-4 text-cyan-600 dark:text-cyan-400" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 8v4l3 3m6-3a9 9 0 11-18 0 9 9 0 0118 0z" />
|
||||
</svg>
|
||||
</div>
|
||||
<span class="text-sm font-semibold text-gray-900 dark:text-white">{{ t('admin.accounts.stats.todayOverview') }}</span>
|
||||
</div>
|
||||
<div class="space-y-2">
|
||||
<div class="flex justify-between items-center">
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.stats.cost') }}</span>
|
||||
<span class="text-sm font-semibold text-gray-900 dark:text-white">${{ formatCost(stats.summary.today?.cost || 0) }}</span>
|
||||
</div>
|
||||
<div class="flex justify-between items-center">
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.stats.requests') }}</span>
|
||||
<span class="text-sm font-semibold text-gray-900 dark:text-white">{{ formatNumber(stats.summary.today?.requests || 0) }}</span>
|
||||
</div>
|
||||
<div class="flex justify-between items-center">
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">Tokens</span>
|
||||
<span class="text-sm font-semibold text-gray-900 dark:text-white">{{ formatTokens(stats.summary.today?.tokens || 0) }}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Highest Cost Day -->
|
||||
<div class="card p-4">
|
||||
<div class="flex items-center gap-2 mb-3">
|
||||
<div class="p-1.5 rounded-lg bg-orange-100 dark:bg-orange-900/30">
|
||||
<svg class="w-4 h-4 text-orange-600 dark:text-orange-400" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M17.657 18.657A8 8 0 016.343 7.343S7 9 9 10c0-2 .5-5 2.986-7C14 5 16.09 5.777 17.656 7.343A7.975 7.975 0 0120 13a7.975 7.975 0 01-2.343 5.657z" />
|
||||
</svg>
|
||||
</div>
|
||||
<span class="text-sm font-semibold text-gray-900 dark:text-white">{{ t('admin.accounts.stats.highestCostDay') }}</span>
|
||||
</div>
|
||||
<div class="space-y-2">
|
||||
<div class="flex justify-between items-center">
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.stats.date') }}</span>
|
||||
<span class="text-sm font-semibold text-gray-900 dark:text-white">{{ stats.summary.highest_cost_day?.label || '-' }}</span>
|
||||
</div>
|
||||
<div class="flex justify-between items-center">
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.stats.cost') }}</span>
|
||||
<span class="text-sm font-semibold text-orange-600 dark:text-orange-400">${{ formatCost(stats.summary.highest_cost_day?.cost || 0) }}</span>
|
||||
</div>
|
||||
<div class="flex justify-between items-center">
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.stats.requests') }}</span>
|
||||
<span class="text-sm font-semibold text-gray-900 dark:text-white">{{ formatNumber(stats.summary.highest_cost_day?.requests || 0) }}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Highest Request Day -->
|
||||
<div class="card p-4">
|
||||
<div class="flex items-center gap-2 mb-3">
|
||||
<div class="p-1.5 rounded-lg bg-indigo-100 dark:bg-indigo-900/30">
|
||||
<svg class="w-4 h-4 text-indigo-600 dark:text-indigo-400" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M13 7h8m0 0v8m0-8l-8 8-4-4-6 6" />
|
||||
</svg>
|
||||
</div>
|
||||
<span class="text-sm font-semibold text-gray-900 dark:text-white">{{ t('admin.accounts.stats.highestRequestDay') }}</span>
|
||||
</div>
|
||||
<div class="space-y-2">
|
||||
<div class="flex justify-between items-center">
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.stats.date') }}</span>
|
||||
<span class="text-sm font-semibold text-gray-900 dark:text-white">{{ stats.summary.highest_request_day?.label || '-' }}</span>
|
||||
</div>
|
||||
<div class="flex justify-between items-center">
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.stats.requests') }}</span>
|
||||
<span class="text-sm font-semibold text-indigo-600 dark:text-indigo-400">{{ formatNumber(stats.summary.highest_request_day?.requests || 0) }}</span>
|
||||
</div>
|
||||
<div class="flex justify-between items-center">
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.stats.cost') }}</span>
|
||||
<span class="text-sm font-semibold text-gray-900 dark:text-white">${{ formatCost(stats.summary.highest_request_day?.cost || 0) }}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Row 3: Token Stats -->
|
||||
<div class="grid grid-cols-1 gap-4 lg:grid-cols-3">
|
||||
<!-- Accumulated Tokens -->
|
||||
<div class="card p-4">
|
||||
<div class="flex items-center gap-2 mb-3">
|
||||
<div class="p-1.5 rounded-lg bg-teal-100 dark:bg-teal-900/30">
|
||||
<svg class="w-4 h-4 text-teal-600 dark:text-teal-400" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M20 7l-8-4-8 4m16 0l-8 4m8-4v10l-8 4m0-10L4 7m8 4v10M4 7v10l8 4" />
|
||||
</svg>
|
||||
</div>
|
||||
<span class="text-sm font-semibold text-gray-900 dark:text-white">{{ t('admin.accounts.stats.accumulatedTokens') }}</span>
|
||||
</div>
|
||||
<div class="space-y-2">
|
||||
<div class="flex justify-between items-center">
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.stats.totalTokens') }}</span>
|
||||
<span class="text-sm font-semibold text-gray-900 dark:text-white">{{ formatTokens(stats.summary.total_tokens) }}</span>
|
||||
</div>
|
||||
<div class="flex justify-between items-center">
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.stats.dailyAvgTokens') }}</span>
|
||||
<span class="text-sm font-semibold text-gray-900 dark:text-white">{{ formatTokens(Math.round(stats.summary.avg_daily_tokens)) }}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Performance -->
|
||||
<div class="card p-4">
|
||||
<div class="flex items-center gap-2 mb-3">
|
||||
<div class="p-1.5 rounded-lg bg-rose-100 dark:bg-rose-900/30">
|
||||
<svg class="w-4 h-4 text-rose-600 dark:text-rose-400" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M13 10V3L4 14h7v7l9-11h-7z" />
|
||||
</svg>
|
||||
</div>
|
||||
<span class="text-sm font-semibold text-gray-900 dark:text-white">{{ t('admin.accounts.stats.performance') }}</span>
|
||||
</div>
|
||||
<div class="space-y-2">
|
||||
<div class="flex justify-between items-center">
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.stats.avgResponseTime') }}</span>
|
||||
<span class="text-sm font-semibold text-gray-900 dark:text-white">{{ formatDuration(stats.summary.avg_duration_ms) }}</span>
|
||||
</div>
|
||||
<div class="flex justify-between items-center">
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.stats.daysActive') }}</span>
|
||||
<span class="text-sm font-semibold text-gray-900 dark:text-white">{{ stats.summary.actual_days_used }} / {{ stats.summary.days }}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Recent Activity -->
|
||||
<div class="card p-4">
|
||||
<div class="flex items-center gap-2 mb-3">
|
||||
<div class="p-1.5 rounded-lg bg-lime-100 dark:bg-lime-900/30">
|
||||
<svg class="w-4 h-4 text-lime-600 dark:text-lime-400" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2" />
|
||||
</svg>
|
||||
</div>
|
||||
<span class="text-sm font-semibold text-gray-900 dark:text-white">{{ t('admin.accounts.stats.recentActivity') }}</span>
|
||||
</div>
|
||||
<div class="space-y-2">
|
||||
<div class="flex justify-between items-center">
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.stats.todayRequests') }}</span>
|
||||
<span class="text-sm font-semibold text-gray-900 dark:text-white">{{ formatNumber(stats.summary.today?.requests || 0) }}</span>
|
||||
</div>
|
||||
<div class="flex justify-between items-center">
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.stats.todayTokens') }}</span>
|
||||
<span class="text-sm font-semibold text-gray-900 dark:text-white">{{ formatTokens(stats.summary.today?.tokens || 0) }}</span>
|
||||
</div>
|
||||
<div class="flex justify-between items-center">
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.stats.todayCost') }}</span>
|
||||
<span class="text-sm font-semibold text-gray-900 dark:text-white">${{ formatCost(stats.summary.today?.cost || 0) }}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Usage Trend Chart -->
|
||||
<div class="card p-4">
|
||||
<h3 class="text-sm font-semibold text-gray-900 dark:text-white mb-4">{{ t('admin.accounts.stats.usageTrend') }}</h3>
|
||||
<div class="h-64">
|
||||
<Line v-if="trendChartData" :data="trendChartData" :options="lineChartOptions" />
|
||||
<div v-else class="flex items-center justify-center h-full text-gray-500 dark:text-gray-400 text-sm">
|
||||
{{ t('admin.dashboard.noDataAvailable') }}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Model Distribution -->
|
||||
<ModelDistributionChart
|
||||
:model-stats="stats.models"
|
||||
:loading="false"
|
||||
/>
|
||||
</template>
|
||||
|
||||
<!-- No Data State -->
|
||||
<div v-else-if="!loading" class="flex flex-col items-center justify-center py-12 text-gray-500 dark:text-gray-400">
|
||||
<svg class="w-12 h-12 mb-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="M9 19v-6a2 2 0 00-2-2H5a2 2 0 00-2 2v6a2 2 0 002 2h2a2 2 0 002-2zm0 0V9a2 2 0 012-2h2a2 2 0 012 2v10m-6 0a2 2 0 002 2h2a2 2 0 002-2m0 0V5a2 2 0 012-2h2a2 2 0 012 2v14a2 2 0 01-2 2h-2a2 2 0 01-2-2z" />
|
||||
</svg>
|
||||
<p class="text-sm">{{ t('admin.accounts.stats.noData') }}</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<template #footer>
|
||||
<div class="flex justify-end">
|
||||
<button
|
||||
@click="handleClose"
|
||||
class="px-4 py-2 text-sm font-medium text-gray-700 dark:text-gray-300 bg-gray-100 dark:bg-dark-600 hover:bg-gray-200 dark:hover:bg-dark-500 rounded-lg transition-colors"
|
||||
>
|
||||
{{ t('common.close') }}
|
||||
</button>
|
||||
</div>
|
||||
</template>
|
||||
</Modal>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, watch, computed } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import {
|
||||
Chart as ChartJS,
|
||||
CategoryScale,
|
||||
LinearScale,
|
||||
PointElement,
|
||||
LineElement,
|
||||
Title,
|
||||
Tooltip,
|
||||
Legend,
|
||||
Filler
|
||||
} from 'chart.js'
|
||||
import { Line } from 'vue-chartjs'
|
||||
import Modal from '@/components/common/Modal.vue'
|
||||
import LoadingSpinner from '@/components/common/LoadingSpinner.vue'
|
||||
import ModelDistributionChart from '@/components/charts/ModelDistributionChart.vue'
|
||||
import { adminAPI } from '@/api/admin'
|
||||
import type { Account, AccountUsageStatsResponse } from '@/types'
|
||||
|
||||
ChartJS.register(
|
||||
CategoryScale,
|
||||
LinearScale,
|
||||
PointElement,
|
||||
LineElement,
|
||||
Title,
|
||||
Tooltip,
|
||||
Legend,
|
||||
Filler
|
||||
)
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
const props = defineProps<{
|
||||
show: boolean
|
||||
account: Account | null
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
(e: 'close'): void
|
||||
}>()
|
||||
|
||||
const loading = ref(false)
|
||||
const stats = ref<AccountUsageStatsResponse | null>(null)
|
||||
|
||||
// Dark mode detection
|
||||
const isDarkMode = computed(() => {
|
||||
return document.documentElement.classList.contains('dark')
|
||||
})
|
||||
|
||||
// Chart colors
|
||||
const chartColors = computed(() => ({
|
||||
text: isDarkMode.value ? '#e5e7eb' : '#374151',
|
||||
grid: isDarkMode.value ? '#374151' : '#e5e7eb',
|
||||
}))
|
||||
|
||||
// Line chart data
|
||||
const trendChartData = computed(() => {
|
||||
if (!stats.value?.history?.length) return null
|
||||
|
||||
return {
|
||||
labels: stats.value.history.map(h => h.label),
|
||||
datasets: [
|
||||
{
|
||||
label: t('admin.accounts.stats.cost') + ' (USD)',
|
||||
data: stats.value.history.map(h => h.cost),
|
||||
borderColor: '#3b82f6',
|
||||
backgroundColor: 'rgba(59, 130, 246, 0.1)',
|
||||
fill: true,
|
||||
tension: 0.3,
|
||||
yAxisID: 'y',
|
||||
},
|
||||
{
|
||||
label: t('admin.accounts.stats.requests'),
|
||||
data: stats.value.history.map(h => h.requests),
|
||||
borderColor: '#f97316',
|
||||
backgroundColor: 'rgba(249, 115, 22, 0.1)',
|
||||
fill: false,
|
||||
tension: 0.3,
|
||||
yAxisID: 'y1',
|
||||
},
|
||||
],
|
||||
}
|
||||
})
|
||||
|
||||
// Line chart options with dual Y-axis
|
||||
const lineChartOptions = computed(() => ({
|
||||
responsive: true,
|
||||
maintainAspectRatio: false,
|
||||
interaction: {
|
||||
intersect: false,
|
||||
mode: 'index' as const,
|
||||
},
|
||||
plugins: {
|
||||
legend: {
|
||||
position: 'top' as const,
|
||||
labels: {
|
||||
color: chartColors.value.text,
|
||||
usePointStyle: true,
|
||||
pointStyle: 'circle',
|
||||
padding: 15,
|
||||
font: {
|
||||
size: 11,
|
||||
},
|
||||
},
|
||||
},
|
||||
tooltip: {
|
||||
callbacks: {
|
||||
label: (context: any) => {
|
||||
const label = context.dataset.label || ''
|
||||
const value = context.raw
|
||||
if (label.includes('USD')) {
|
||||
return `${label}: $${formatCost(value)}`
|
||||
}
|
||||
return `${label}: ${formatNumber(value)}`
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
scales: {
|
||||
x: {
|
||||
grid: {
|
||||
color: chartColors.value.grid,
|
||||
},
|
||||
ticks: {
|
||||
color: chartColors.value.text,
|
||||
font: {
|
||||
size: 10,
|
||||
},
|
||||
maxRotation: 45,
|
||||
minRotation: 0,
|
||||
},
|
||||
},
|
||||
y: {
|
||||
type: 'linear' as const,
|
||||
display: true,
|
||||
position: 'left' as const,
|
||||
grid: {
|
||||
color: chartColors.value.grid,
|
||||
},
|
||||
ticks: {
|
||||
color: '#3b82f6',
|
||||
font: {
|
||||
size: 10,
|
||||
},
|
||||
callback: (value: string | number) => '$' + formatCost(Number(value)),
|
||||
},
|
||||
title: {
|
||||
display: true,
|
||||
text: t('admin.accounts.stats.cost') + ' (USD)',
|
||||
color: '#3b82f6',
|
||||
font: {
|
||||
size: 11,
|
||||
},
|
||||
},
|
||||
},
|
||||
y1: {
|
||||
type: 'linear' as const,
|
||||
display: true,
|
||||
position: 'right' as const,
|
||||
grid: {
|
||||
drawOnChartArea: false,
|
||||
},
|
||||
ticks: {
|
||||
color: '#f97316',
|
||||
font: {
|
||||
size: 10,
|
||||
},
|
||||
callback: (value: string | number) => formatNumber(Number(value)),
|
||||
},
|
||||
title: {
|
||||
display: true,
|
||||
text: t('admin.accounts.stats.requests'),
|
||||
color: '#f97316',
|
||||
font: {
|
||||
size: 11,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}))
|
||||
|
||||
// Load stats when modal opens
|
||||
watch(() => props.show, async (newVal) => {
|
||||
if (newVal && props.account) {
|
||||
await loadStats()
|
||||
} else {
|
||||
stats.value = null
|
||||
}
|
||||
})
|
||||
|
||||
const loadStats = async () => {
|
||||
if (!props.account) return
|
||||
|
||||
loading.value = true
|
||||
try {
|
||||
stats.value = await adminAPI.accounts.getStats(props.account.id, 30)
|
||||
} catch (error) {
|
||||
console.error('Failed to load account stats:', error)
|
||||
stats.value = null
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const handleClose = () => {
|
||||
emit('close')
|
||||
}
|
||||
|
||||
// Format helpers
|
||||
const formatCost = (value: number): string => {
|
||||
if (value >= 1000) {
|
||||
return (value / 1000).toFixed(2) + 'K'
|
||||
} else if (value >= 1) {
|
||||
return value.toFixed(2)
|
||||
} else if (value >= 0.01) {
|
||||
return value.toFixed(3)
|
||||
}
|
||||
return value.toFixed(4)
|
||||
}
|
||||
|
||||
const formatNumber = (value: number): string => {
|
||||
if (value >= 1_000_000) {
|
||||
return (value / 1_000_000).toFixed(2) + 'M'
|
||||
} else if (value >= 1_000) {
|
||||
return (value / 1_000).toFixed(2) + 'K'
|
||||
}
|
||||
return value.toLocaleString()
|
||||
}
|
||||
|
||||
const formatTokens = (value: number): string => {
|
||||
if (value >= 1_000_000_000) {
|
||||
return `${(value / 1_000_000_000).toFixed(2)}B`
|
||||
} else if (value >= 1_000_000) {
|
||||
return `${(value / 1_000_000).toFixed(2)}M`
|
||||
} else if (value >= 1_000) {
|
||||
return `${(value / 1_000).toFixed(2)}K`
|
||||
}
|
||||
return value.toLocaleString()
|
||||
}
|
||||
|
||||
const formatDuration = (ms: number): string => {
|
||||
if (ms >= 1000) {
|
||||
return `${(ms / 1000).toFixed(2)}s`
|
||||
}
|
||||
return `${Math.round(ms)}ms`
|
||||
}
|
||||
</script>
|
||||
@@ -1,7 +1,7 @@
|
||||
<template>
|
||||
<div v-if="account.type === 'oauth' || account.type === 'setup-token'">
|
||||
<!-- OAuth accounts: fetch real usage data -->
|
||||
<template v-if="account.type === 'oauth'">
|
||||
<div v-if="showUsageWindows">
|
||||
<!-- Anthropic OAuth accounts: fetch real usage data -->
|
||||
<template v-if="account.platform === 'anthropic' && account.type === 'oauth'">
|
||||
<!-- Loading state -->
|
||||
<div v-if="loading" class="space-y-1.5">
|
||||
<div class="flex items-center gap-1">
|
||||
@@ -63,20 +63,49 @@
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<!-- Setup Token accounts: show time-based window progress -->
|
||||
<template v-else-if="account.type === 'setup-token'">
|
||||
<!-- Anthropic Setup Token accounts: show time-based window progress -->
|
||||
<template v-else-if="account.platform === 'anthropic' && account.type === 'setup-token'">
|
||||
<SetupTokenTimeWindow :account="account" />
|
||||
</template>
|
||||
|
||||
<!-- OpenAI OAuth accounts: show Codex usage from extra field -->
|
||||
<template v-else-if="account.platform === 'openai' && account.type === 'oauth'">
|
||||
<div v-if="hasCodexUsage" class="space-y-1">
|
||||
<!-- 5h Window (Secondary) -->
|
||||
<UsageProgressBar
|
||||
v-if="codexSecondaryUsedPercent !== null"
|
||||
label="5h"
|
||||
:utilization="codexSecondaryUsedPercent"
|
||||
:resets-at="codexSecondaryResetAt"
|
||||
color="indigo"
|
||||
/>
|
||||
|
||||
<!-- Weekly Window (Primary) -->
|
||||
<UsageProgressBar
|
||||
v-if="codexPrimaryUsedPercent !== null"
|
||||
label="7d"
|
||||
:utilization="codexPrimaryUsedPercent"
|
||||
:resets-at="codexPrimaryResetAt"
|
||||
color="emerald"
|
||||
/>
|
||||
</div>
|
||||
<div v-else class="text-xs text-gray-400">-</div>
|
||||
</template>
|
||||
|
||||
<!-- Other accounts: no usage window -->
|
||||
<template v-else>
|
||||
<div class="text-xs text-gray-400">-</div>
|
||||
</template>
|
||||
</div>
|
||||
|
||||
<!-- Non-OAuth accounts -->
|
||||
<!-- Non-OAuth/Setup-Token accounts -->
|
||||
<div v-else class="text-xs text-gray-400">
|
||||
-
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted } from 'vue'
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
import { adminAPI } from '@/api/admin'
|
||||
import type { Account, AccountUsageInfo } from '@/types'
|
||||
import UsageProgressBar from './UsageProgressBar.vue'
|
||||
@@ -90,9 +119,50 @@ const loading = ref(false)
|
||||
const error = ref<string | null>(null)
|
||||
const usageInfo = ref<AccountUsageInfo | null>(null)
|
||||
|
||||
// Show usage windows for OAuth and Setup Token accounts
|
||||
const showUsageWindows = computed(() =>
|
||||
props.account.type === 'oauth' || props.account.type === 'setup-token'
|
||||
)
|
||||
|
||||
// OpenAI Codex usage computed properties
|
||||
const hasCodexUsage = computed(() => {
|
||||
const extra = props.account.extra
|
||||
return extra && (
|
||||
extra.codex_primary_used_percent !== undefined ||
|
||||
extra.codex_secondary_used_percent !== undefined
|
||||
)
|
||||
})
|
||||
|
||||
const codexPrimaryUsedPercent = computed(() => {
|
||||
const extra = props.account.extra
|
||||
if (!extra || extra.codex_primary_used_percent === undefined) return null
|
||||
return extra.codex_primary_used_percent
|
||||
})
|
||||
|
||||
const codexSecondaryUsedPercent = computed(() => {
|
||||
const extra = props.account.extra
|
||||
if (!extra || extra.codex_secondary_used_percent === undefined) return null
|
||||
return extra.codex_secondary_used_percent
|
||||
})
|
||||
|
||||
const codexPrimaryResetAt = computed(() => {
|
||||
const extra = props.account.extra
|
||||
if (!extra || extra.codex_primary_reset_after_seconds === undefined) return null
|
||||
const resetTime = new Date(Date.now() + extra.codex_primary_reset_after_seconds * 1000)
|
||||
return resetTime.toISOString()
|
||||
})
|
||||
|
||||
const codexSecondaryResetAt = computed(() => {
|
||||
const extra = props.account.extra
|
||||
if (!extra || extra.codex_secondary_reset_after_seconds === undefined) return null
|
||||
const resetTime = new Date(Date.now() + extra.codex_secondary_reset_after_seconds * 1000)
|
||||
return resetTime.toISOString()
|
||||
})
|
||||
|
||||
const loadUsage = async () => {
|
||||
// Only fetch usage for OAuth accounts (Setup Token uses local time-based calculation)
|
||||
if (props.account.type !== 'oauth') return
|
||||
// Only fetch usage for Anthropic OAuth accounts
|
||||
// OpenAI usage comes from account.extra field (updated during forwarding)
|
||||
if (props.account.platform !== 'anthropic' || props.account.type !== 'oauth') return
|
||||
|
||||
loading.value = true
|
||||
error.value = null
|
||||
|
||||
@@ -47,83 +47,161 @@
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- Platform Selection - Segmented Control Style -->
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.accountType') }}</label>
|
||||
<div class="grid grid-cols-2 gap-3 mt-2">
|
||||
<label
|
||||
<label class="input-label">{{ t('admin.accounts.platform') }}</label>
|
||||
<div class="flex rounded-lg bg-gray-100 dark:bg-dark-700 p-1 mt-2">
|
||||
<button
|
||||
type="button"
|
||||
@click="form.platform = 'anthropic'"
|
||||
:class="[
|
||||
'relative flex cursor-pointer rounded-lg border-2 p-4 transition-all',
|
||||
accountCategory === 'oauth-based'
|
||||
? 'border-primary-500 bg-primary-50 dark:bg-primary-900/20'
|
||||
: 'border-gray-200 dark:border-dark-600 hover:border-primary-300'
|
||||
'flex-1 flex items-center justify-center gap-2 rounded-md px-4 py-2.5 text-sm font-medium transition-all',
|
||||
form.platform === 'anthropic'
|
||||
? 'bg-white dark:bg-dark-600 text-orange-600 dark:text-orange-400 shadow-sm'
|
||||
: 'text-gray-600 dark:text-gray-400 hover:text-gray-900 dark:hover:text-gray-200'
|
||||
]"
|
||||
>
|
||||
<input
|
||||
v-model="accountCategory"
|
||||
type="radio"
|
||||
value="oauth-based"
|
||||
class="sr-only"
|
||||
/>
|
||||
<div class="flex items-center gap-3">
|
||||
<div class="flex h-10 w-10 items-center justify-center rounded-lg bg-gradient-to-br from-orange-500 to-orange-600">
|
||||
<svg class="w-5 h-5 text-white" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="1.5">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M9.813 15.904L9 18.75l-.813-2.846a4.5 4.5 0 00-3.09-3.09L2.25 12l2.846-.813a4.5 4.5 0 003.09-3.09L9 5.25l.813 2.846a4.5 4.5 0 003.09 3.09L15.75 12l-2.846.813a4.5 4.5 0 00-3.09 3.09zM18.259 8.715L18 9.75l-.259-1.035a3.375 3.375 0 00-2.455-2.456L14.25 6l1.036-.259a3.375 3.375 0 002.455-2.456L18 2.25l.259 1.035a3.375 3.375 0 002.456 2.456L21.75 6l-1.035.259a3.375 3.375 0 00-2.456 2.456z" />
|
||||
</svg>
|
||||
</div>
|
||||
<div>
|
||||
<span class="block text-sm font-semibold text-gray-900 dark:text-white">{{ t('admin.accounts.claudeCode') }}</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.oauthSetupToken') }}</span>
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
v-if="accountCategory === 'oauth-based'"
|
||||
class="absolute right-2 top-2 flex h-5 w-5 items-center justify-center rounded-full bg-primary-500"
|
||||
>
|
||||
<svg class="w-3 h-3 text-white" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="3">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M4.5 12.75l6 6 9-13.5" />
|
||||
</svg>
|
||||
</div>
|
||||
</label>
|
||||
|
||||
<label
|
||||
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="1.5">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M9.813 15.904L9 18.75l-.813-2.846a4.5 4.5 0 00-3.09-3.09L2.25 12l2.846-.813a4.5 4.5 0 003.09-3.09L9 5.25l.813 2.846a4.5 4.5 0 003.09 3.09L15.75 12l-2.846.813a4.5 4.5 0 00-3.09 3.09zM18.259 8.715L18 9.75l-.259-1.035a3.375 3.375 0 00-2.455-2.456L14.25 6l1.036-.259a3.375 3.375 0 002.455-2.456L18 2.25l.259 1.035a3.375 3.375 0 002.456 2.456L21.75 6l-1.035.259a3.375 3.375 0 00-2.456 2.456z" />
|
||||
</svg>
|
||||
Anthropic
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
@click="form.platform = 'openai'"
|
||||
:class="[
|
||||
'relative flex cursor-pointer rounded-lg border-2 p-4 transition-all',
|
||||
accountCategory === 'apikey'
|
||||
? 'border-primary-500 bg-primary-50 dark:bg-primary-900/20'
|
||||
: 'border-gray-200 dark:border-dark-600 hover:border-primary-300'
|
||||
'flex-1 flex items-center justify-center gap-2 rounded-md px-4 py-2.5 text-sm font-medium transition-all',
|
||||
form.platform === 'openai'
|
||||
? 'bg-white dark:bg-dark-600 text-green-600 dark:text-green-400 shadow-sm'
|
||||
: 'text-gray-600 dark:text-gray-400 hover:text-gray-900 dark:hover:text-gray-200'
|
||||
]"
|
||||
>
|
||||
<input
|
||||
v-model="accountCategory"
|
||||
type="radio"
|
||||
value="apikey"
|
||||
class="sr-only"
|
||||
/>
|
||||
<div class="flex items-center gap-3">
|
||||
<div class="flex h-10 w-10 items-center justify-center rounded-lg bg-gradient-to-br from-purple-500 to-purple-600">
|
||||
<svg class="w-5 h-5 text-white" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="1.5">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M15.75 5.25a3 3 0 013 3m3 0a6 6 0 01-7.029 5.912c-.563-.097-1.159.026-1.563.43L10.5 17.25H8.25v2.25H6v2.25H2.25v-2.818c0-.597.237-1.17.659-1.591l6.499-6.499c.404-.404.527-1 .43-1.563A6 6 0 1121.75 8.25z" />
|
||||
</svg>
|
||||
</div>
|
||||
<div>
|
||||
<span class="block text-sm font-semibold text-gray-900 dark:text-white">{{ t('admin.accounts.claudeConsole') }}</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.apiKey') }}</span>
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
v-if="accountCategory === 'apikey'"
|
||||
class="absolute right-2 top-2 flex h-5 w-5 items-center justify-center rounded-full bg-primary-500"
|
||||
>
|
||||
<svg class="w-3 h-3 text-white" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="3">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M4.5 12.75l6 6 9-13.5" />
|
||||
</svg>
|
||||
</div>
|
||||
</label>
|
||||
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="1.5">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M3.75 13.5l10.5-11.25L12 10.5h8.25L9.75 21.75 12 13.5H3.75z" />
|
||||
</svg>
|
||||
OpenAI
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Add Method (only for OAuth-based type) -->
|
||||
<div v-if="isOAuthFlow">
|
||||
<!-- Account Type Selection (Anthropic) -->
|
||||
<div v-if="form.platform === 'anthropic'">
|
||||
<label class="input-label">{{ t('admin.accounts.accountType') }}</label>
|
||||
<div class="grid grid-cols-2 gap-3 mt-2">
|
||||
<button
|
||||
type="button"
|
||||
@click="accountCategory = 'oauth-based'"
|
||||
:class="[
|
||||
'flex items-center gap-3 rounded-lg border-2 p-3 transition-all text-left',
|
||||
accountCategory === 'oauth-based'
|
||||
? 'border-orange-500 bg-orange-50 dark:bg-orange-900/20'
|
||||
: 'border-gray-200 dark:border-dark-600 hover:border-orange-300 dark:hover:border-orange-700'
|
||||
]"
|
||||
>
|
||||
<div :class="[
|
||||
'flex h-8 w-8 items-center justify-center rounded-lg',
|
||||
accountCategory === 'oauth-based'
|
||||
? 'bg-orange-500 text-white'
|
||||
: 'bg-gray-100 dark:bg-dark-600 text-gray-500 dark:text-gray-400'
|
||||
]">
|
||||
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="1.5">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M9.813 15.904L9 18.75l-.813-2.846a4.5 4.5 0 00-3.09-3.09L2.25 12l2.846-.813a4.5 4.5 0 003.09-3.09L9 5.25l.813 2.846a4.5 4.5 0 003.09 3.09L15.75 12l-2.846.813a4.5 4.5 0 00-3.09 3.09zM18.259 8.715L18 9.75l-.259-1.035a3.375 3.375 0 00-2.455-2.456L14.25 6l1.036-.259a3.375 3.375 0 002.455-2.456L18 2.25l.259 1.035a3.375 3.375 0 002.456 2.456L21.75 6l-1.035.259a3.375 3.375 0 00-2.456 2.456z" />
|
||||
</svg>
|
||||
</div>
|
||||
<div>
|
||||
<span class="block text-sm font-medium text-gray-900 dark:text-white">{{ t('admin.accounts.claudeCode') }}</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.oauthSetupToken') }}</span>
|
||||
</div>
|
||||
</button>
|
||||
|
||||
<button
|
||||
type="button"
|
||||
@click="accountCategory = 'apikey'"
|
||||
:class="[
|
||||
'flex items-center gap-3 rounded-lg border-2 p-3 transition-all text-left',
|
||||
accountCategory === 'apikey'
|
||||
? 'border-purple-500 bg-purple-50 dark:bg-purple-900/20'
|
||||
: 'border-gray-200 dark:border-dark-600 hover:border-purple-300 dark:hover:border-purple-700'
|
||||
]"
|
||||
>
|
||||
<div :class="[
|
||||
'flex h-8 w-8 items-center justify-center rounded-lg',
|
||||
accountCategory === 'apikey'
|
||||
? 'bg-purple-500 text-white'
|
||||
: 'bg-gray-100 dark:bg-dark-600 text-gray-500 dark:text-gray-400'
|
||||
]">
|
||||
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="1.5">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M15.75 5.25a3 3 0 013 3m3 0a6 6 0 01-7.029 5.912c-.563-.097-1.159.026-1.563.43L10.5 17.25H8.25v2.25H6v2.25H2.25v-2.818c0-.597.237-1.17.659-1.591l6.499-6.499c.404-.404.527-1 .43-1.563A6 6 0 1121.75 8.25z" />
|
||||
</svg>
|
||||
</div>
|
||||
<div>
|
||||
<span class="block text-sm font-medium text-gray-900 dark:text-white">{{ t('admin.accounts.claudeConsole') }}</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.apiKey') }}</span>
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Account Type Selection (OpenAI) -->
|
||||
<div v-if="form.platform === 'openai'">
|
||||
<label class="input-label">{{ t('admin.accounts.accountType') }}</label>
|
||||
<div class="grid grid-cols-2 gap-3 mt-2">
|
||||
<button
|
||||
type="button"
|
||||
@click="accountCategory = 'oauth-based'"
|
||||
:class="[
|
||||
'flex items-center gap-3 rounded-lg border-2 p-3 transition-all text-left',
|
||||
accountCategory === 'oauth-based'
|
||||
? 'border-green-500 bg-green-50 dark:bg-green-900/20'
|
||||
: 'border-gray-200 dark:border-dark-600 hover:border-green-300 dark:hover:border-green-700'
|
||||
]"
|
||||
>
|
||||
<div :class="[
|
||||
'flex h-8 w-8 items-center justify-center rounded-lg',
|
||||
accountCategory === 'oauth-based'
|
||||
? 'bg-green-500 text-white'
|
||||
: 'bg-gray-100 dark:bg-dark-600 text-gray-500 dark:text-gray-400'
|
||||
]">
|
||||
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="1.5">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M15.75 5.25a3 3 0 013 3m3 0a6 6 0 01-7.029 5.912c-.563-.097-1.159.026-1.563.43L10.5 17.25H8.25v2.25H6v2.25H2.25v-2.818c0-.597.237-1.17.659-1.591l6.499-6.499c.404-.404.527-1 .43-1.563A6 6 0 1121.75 8.25z" />
|
||||
</svg>
|
||||
</div>
|
||||
<div>
|
||||
<span class="block text-sm font-medium text-gray-900 dark:text-white">OAuth</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">ChatGPT OAuth</span>
|
||||
</div>
|
||||
</button>
|
||||
|
||||
<button
|
||||
type="button"
|
||||
@click="accountCategory = 'apikey'"
|
||||
:class="[
|
||||
'flex items-center gap-3 rounded-lg border-2 p-3 transition-all text-left',
|
||||
accountCategory === 'apikey'
|
||||
? 'border-purple-500 bg-purple-50 dark:bg-purple-900/20'
|
||||
: 'border-gray-200 dark:border-dark-600 hover:border-purple-300 dark:hover:border-purple-700'
|
||||
]"
|
||||
>
|
||||
<div :class="[
|
||||
'flex h-8 w-8 items-center justify-center rounded-lg',
|
||||
accountCategory === 'apikey'
|
||||
? 'bg-purple-500 text-white'
|
||||
: 'bg-gray-100 dark:bg-dark-600 text-gray-500 dark:text-gray-400'
|
||||
]">
|
||||
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="1.5">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M15.75 5.25a3 3 0 013 3m3 0a6 6 0 01-7.029 5.912c-.563-.097-1.159.026-1.563.43L10.5 17.25H8.25v2.25H6v2.25H2.25v-2.818c0-.597.237-1.17.659-1.591l6.499-6.499c.404-.404.527-1 .43-1.563A6 6 0 1121.75 8.25z" />
|
||||
</svg>
|
||||
</div>
|
||||
<div>
|
||||
<span class="block text-sm font-medium text-gray-900 dark:text-white">API Key</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">Responses API</span>
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Add Method (only for Anthropic OAuth-based type) -->
|
||||
<div v-if="form.platform === 'anthropic' && isOAuthFlow">
|
||||
<label class="input-label">{{ t('admin.accounts.addMethod') }}</label>
|
||||
<div class="flex gap-4 mt-2">
|
||||
<label class="flex cursor-pointer items-center">
|
||||
@@ -155,7 +233,7 @@
|
||||
v-model="apiKeyBaseUrl"
|
||||
type="text"
|
||||
class="input"
|
||||
placeholder="https://api.anthropic.com"
|
||||
:placeholder="form.platform === 'openai' ? 'https://api.openai.com' : 'https://api.anthropic.com'"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.baseUrlHint') }}</p>
|
||||
</div>
|
||||
@@ -166,7 +244,7 @@
|
||||
type="password"
|
||||
required
|
||||
class="input font-mono"
|
||||
:placeholder="t('admin.accounts.apiKeyPlaceholder')"
|
||||
:placeholder="form.platform === 'openai' ? 'sk-proj-...' : 'sk-ant-...'"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.apiKeyHint') }}</p>
|
||||
</div>
|
||||
@@ -418,8 +496,8 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Intercept Warmup Requests (all account types) -->
|
||||
<div class="border-t border-gray-200 dark:border-dark-600 pt-4">
|
||||
<!-- Intercept Warmup Requests (Anthropic only) -->
|
||||
<div v-if="form.platform === 'anthropic'" class="border-t border-gray-200 dark:border-dark-600 pt-4">
|
||||
<div class="flex items-center justify-between">
|
||||
<div>
|
||||
<label class="input-label mb-0">{{ t('admin.accounts.interceptWarmupRequests') }}</label>
|
||||
@@ -477,6 +555,7 @@
|
||||
<GroupSelector
|
||||
v-model="form.group_ids"
|
||||
:groups="groups"
|
||||
:platform="form.platform"
|
||||
/>
|
||||
|
||||
<div class="flex justify-end gap-3 pt-4">
|
||||
@@ -510,14 +589,16 @@
|
||||
<div v-else class="space-y-5">
|
||||
<OAuthAuthorizationFlow
|
||||
ref="oauthFlowRef"
|
||||
:add-method="addMethod"
|
||||
:auth-url="oauth.authUrl.value"
|
||||
:session-id="oauth.sessionId.value"
|
||||
:loading="oauth.loading.value"
|
||||
:error="oauth.error.value"
|
||||
:show-help="true"
|
||||
:add-method="form.platform === 'openai' ? 'oauth' : addMethod"
|
||||
:auth-url="currentAuthUrl"
|
||||
:session-id="currentSessionId"
|
||||
:loading="currentOAuthLoading"
|
||||
:error="currentOAuthError"
|
||||
:show-help="form.platform !== 'openai'"
|
||||
:show-proxy-warning="!!form.proxy_id"
|
||||
:allow-multiple="true"
|
||||
:allow-multiple="form.platform !== 'openai'"
|
||||
:show-cookie-option="form.platform !== 'openai'"
|
||||
:platform="form.platform"
|
||||
@generate-url="handleGenerateUrl"
|
||||
@cookie-auth="handleCookieAuth"
|
||||
/>
|
||||
@@ -538,7 +619,7 @@
|
||||
@click="handleExchangeCode"
|
||||
>
|
||||
<svg
|
||||
v-if="oauth.loading.value"
|
||||
v-if="currentOAuthLoading"
|
||||
class="animate-spin -ml-1 mr-2 h-4 w-4"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
@@ -546,7 +627,7 @@
|
||||
<circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle>
|
||||
<path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path>
|
||||
</svg>
|
||||
{{ oauth.loading.value ? t('admin.accounts.oauth.verifying') : t('admin.accounts.oauth.completeAuth') }}
|
||||
{{ currentOAuthLoading ? t('admin.accounts.oauth.verifying') : t('admin.accounts.oauth.completeAuth') }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
@@ -559,6 +640,7 @@ import { useI18n } from 'vue-i18n'
|
||||
import { useAppStore } from '@/stores/app'
|
||||
import { adminAPI } from '@/api/admin'
|
||||
import { useAccountOAuth, type AddMethod, type AuthInputMethod } from '@/composables/useAccountOAuth'
|
||||
import { useOpenAIOAuth } from '@/composables/useOpenAIOAuth'
|
||||
import type { Proxy, Group, AccountPlatform, AccountType } from '@/types'
|
||||
import Modal from '@/components/common/Modal.vue'
|
||||
import ProxySelector from '@/components/common/ProxySelector.vue'
|
||||
@@ -590,8 +672,26 @@ const emit = defineEmits<{
|
||||
|
||||
const appStore = useAppStore()
|
||||
|
||||
// OAuth composable
|
||||
const oauth = useAccountOAuth()
|
||||
// OAuth composables
|
||||
const oauth = useAccountOAuth() // For Anthropic OAuth
|
||||
const openaiOAuth = useOpenAIOAuth() // For OpenAI OAuth
|
||||
|
||||
// Computed: current OAuth state for template binding
|
||||
const currentAuthUrl = computed(() => {
|
||||
return form.platform === 'openai' ? openaiOAuth.authUrl.value : oauth.authUrl.value
|
||||
})
|
||||
|
||||
const currentSessionId = computed(() => {
|
||||
return form.platform === 'openai' ? openaiOAuth.sessionId.value : oauth.sessionId.value
|
||||
})
|
||||
|
||||
const currentOAuthLoading = computed(() => {
|
||||
return form.platform === 'openai' ? openaiOAuth.loading.value : oauth.loading.value
|
||||
})
|
||||
|
||||
const currentOAuthError = computed(() => {
|
||||
return form.platform === 'openai' ? openaiOAuth.error.value : oauth.error.value
|
||||
})
|
||||
|
||||
// Refs
|
||||
const oauthFlowRef = ref<OAuthFlowExposed | null>(null)
|
||||
@@ -617,8 +717,8 @@ const selectedErrorCodes = ref<number[]>([])
|
||||
const customErrorCodeInput = ref<number | null>(null)
|
||||
const interceptWarmupRequests = ref(false)
|
||||
|
||||
// Common models for whitelist
|
||||
const commonModels = [
|
||||
// Common models for whitelist - Anthropic
|
||||
const anthropicModels = [
|
||||
{ value: 'claude-opus-4-5-20251101', label: 'Claude Opus 4.5' },
|
||||
{ value: 'claude-sonnet-4-20250514', label: 'Claude Sonnet 4' },
|
||||
{ value: 'claude-sonnet-4-5-20250929', label: 'Claude Sonnet 4.5' },
|
||||
@@ -629,8 +729,24 @@ const commonModels = [
|
||||
{ value: 'claude-3-haiku-20240307', label: 'Claude 3 Haiku' }
|
||||
]
|
||||
|
||||
// Preset mappings for quick add
|
||||
const presetMappings = [
|
||||
// Common models for whitelist - OpenAI
|
||||
const openaiModels = [
|
||||
{ value: 'gpt-5.2-2025-12-11', label: 'GPT-5.2' },
|
||||
{ value: 'gpt-5.2-codex', label: 'GPT-5.2 Codex' },
|
||||
{ value: 'gpt-5.1-codex-max', label: 'GPT-5.1 Codex Max' },
|
||||
{ value: 'gpt-5.1-codex', label: 'GPT-5.1 Codex' },
|
||||
{ value: 'gpt-5.1-2025-11-13', label: 'GPT-5.1' },
|
||||
{ value: 'gpt-5.1-codex-mini', label: 'GPT-5.1 Codex Mini' },
|
||||
{ value: 'gpt-5-2025-08-07', label: 'GPT-5' }
|
||||
]
|
||||
|
||||
// Computed: current models based on platform
|
||||
const commonModels = computed(() => {
|
||||
return form.platform === 'openai' ? openaiModels : anthropicModels
|
||||
})
|
||||
|
||||
// Preset mappings for quick add - Anthropic
|
||||
const anthropicPresetMappings = [
|
||||
{ label: 'Sonnet 4', from: 'claude-sonnet-4-20250514', to: 'claude-sonnet-4-20250514', color: 'bg-blue-100 text-blue-700 hover:bg-blue-200 dark:bg-blue-900/30 dark:text-blue-400' },
|
||||
{ label: 'Sonnet 4.5', from: 'claude-sonnet-4-5-20250929', to: 'claude-sonnet-4-5-20250929', color: 'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400' },
|
||||
{ label: 'Opus 4.5', from: 'claude-opus-4-5-20251101', to: 'claude-opus-4-5-20251101', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' },
|
||||
@@ -639,6 +755,21 @@ const presetMappings = [
|
||||
{ label: 'Opus->Sonnet', from: 'claude-opus-4-5-20251101', to: 'claude-sonnet-4-5-20250929', color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400' }
|
||||
]
|
||||
|
||||
// Preset mappings for quick add - OpenAI
|
||||
const openaiPresetMappings = [
|
||||
{ label: 'GPT-5.2', from: 'gpt-5.2-2025-12-11', to: 'gpt-5.2-2025-12-11', color: 'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 dark:text-green-400' },
|
||||
{ label: 'GPT-5.2 Codex', from: 'gpt-5.2-codex', to: 'gpt-5.2-codex', color: 'bg-blue-100 text-blue-700 hover:bg-blue-200 dark:bg-blue-900/30 dark:text-blue-400' },
|
||||
{ label: 'GPT-5.1 Codex', from: 'gpt-5.1-codex', to: 'gpt-5.1-codex', color: 'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400' },
|
||||
{ label: 'Codex Max', from: 'gpt-5.1-codex-max', to: 'gpt-5.1-codex-max', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' },
|
||||
{ label: 'Codex Mini', from: 'gpt-5.1-codex-mini', to: 'gpt-5.1-codex-mini', color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400' },
|
||||
{ label: 'Max->Codex', from: 'gpt-5.1-codex-max', to: 'gpt-5.1-codex', color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400' }
|
||||
]
|
||||
|
||||
// Computed: current preset mappings based on platform
|
||||
const presetMappings = computed(() => {
|
||||
return form.platform === 'openai' ? openaiPresetMappings : anthropicPresetMappings
|
||||
})
|
||||
|
||||
// Common HTTP error codes for quick selection
|
||||
const commonErrorCodes = [
|
||||
{ value: 401, label: 'Unauthorized' },
|
||||
@@ -670,6 +801,9 @@ const isManualInputMethod = computed(() => {
|
||||
|
||||
const canExchangeCode = computed(() => {
|
||||
const authCode = oauthFlowRef.value?.authCode || ''
|
||||
if (form.platform === 'openai') {
|
||||
return authCode.trim() && openaiOAuth.sessionId.value && !openaiOAuth.loading.value
|
||||
}
|
||||
return authCode.trim() && oauth.sessionId.value && !oauth.loading.value
|
||||
})
|
||||
|
||||
@@ -689,6 +823,20 @@ watch([accountCategory, addMethod], ([category, method]) => {
|
||||
}
|
||||
}, { immediate: true })
|
||||
|
||||
// Reset platform-specific settings when platform changes
|
||||
watch(() => form.platform, (newPlatform) => {
|
||||
// Reset base URL based on platform
|
||||
apiKeyBaseUrl.value = newPlatform === 'openai'
|
||||
? 'https://api.openai.com'
|
||||
: 'https://api.anthropic.com'
|
||||
// Clear model-related settings
|
||||
allowedModels.value = []
|
||||
modelMappings.value = []
|
||||
// Reset OAuth states
|
||||
oauth.resetState()
|
||||
openaiOAuth.resetState()
|
||||
})
|
||||
|
||||
// Model mapping helpers
|
||||
const addModelMapping = () => {
|
||||
modelMappings.value.push({ from: '', to: '' })
|
||||
@@ -786,6 +934,7 @@ const resetForm = () => {
|
||||
customErrorCodeInput.value = null
|
||||
interceptWarmupRequests.value = false
|
||||
oauth.resetState()
|
||||
openaiOAuth.resetState()
|
||||
oauthFlowRef.value?.reset()
|
||||
}
|
||||
|
||||
@@ -810,9 +959,14 @@ const handleSubmit = async () => {
|
||||
return
|
||||
}
|
||||
|
||||
// Determine default base URL based on platform
|
||||
const defaultBaseUrl = form.platform === 'openai'
|
||||
? 'https://api.openai.com'
|
||||
: 'https://api.anthropic.com'
|
||||
|
||||
// Build credentials with optional model mapping
|
||||
const credentials: Record<string, unknown> = {
|
||||
base_url: apiKeyBaseUrl.value.trim() || 'https://api.anthropic.com',
|
||||
base_url: apiKeyBaseUrl.value.trim() || defaultBaseUrl,
|
||||
api_key: apiKeyValue.value.trim()
|
||||
}
|
||||
|
||||
@@ -837,7 +991,10 @@ const handleSubmit = async () => {
|
||||
|
||||
submitting.value = true
|
||||
try {
|
||||
await adminAPI.accounts.create(form)
|
||||
await adminAPI.accounts.create({
|
||||
...form,
|
||||
group_ids: form.group_ids
|
||||
})
|
||||
appStore.showSuccess(t('admin.accounts.accountCreated'))
|
||||
emit('created')
|
||||
handleClose()
|
||||
@@ -851,15 +1008,72 @@ const handleSubmit = async () => {
|
||||
const goBackToBasicInfo = () => {
|
||||
step.value = 1
|
||||
oauth.resetState()
|
||||
openaiOAuth.resetState()
|
||||
oauthFlowRef.value?.reset()
|
||||
}
|
||||
|
||||
const handleGenerateUrl = async () => {
|
||||
await oauth.generateAuthUrl(addMethod.value, form.proxy_id)
|
||||
if (form.platform === 'openai') {
|
||||
await openaiOAuth.generateAuthUrl(form.proxy_id)
|
||||
} else {
|
||||
await oauth.generateAuthUrl(addMethod.value, form.proxy_id)
|
||||
}
|
||||
}
|
||||
|
||||
const handleExchangeCode = async () => {
|
||||
const authCode = oauthFlowRef.value?.authCode || ''
|
||||
|
||||
// For OpenAI
|
||||
if (form.platform === 'openai') {
|
||||
if (!authCode.trim() || !openaiOAuth.sessionId.value) return
|
||||
|
||||
openaiOAuth.loading.value = true
|
||||
openaiOAuth.error.value = ''
|
||||
|
||||
try {
|
||||
const tokenInfo = await openaiOAuth.exchangeAuthCode(
|
||||
authCode.trim(),
|
||||
openaiOAuth.sessionId.value,
|
||||
form.proxy_id
|
||||
)
|
||||
|
||||
if (!tokenInfo) {
|
||||
return // Error already handled by composable
|
||||
}
|
||||
|
||||
const credentials = openaiOAuth.buildCredentials(tokenInfo)
|
||||
const extra = openaiOAuth.buildExtraInfo(tokenInfo)
|
||||
|
||||
// Merge interceptWarmupRequests into credentials
|
||||
if (interceptWarmupRequests.value) {
|
||||
credentials.intercept_warmup_requests = true
|
||||
}
|
||||
|
||||
await adminAPI.accounts.create({
|
||||
name: form.name,
|
||||
platform: 'openai',
|
||||
type: 'oauth',
|
||||
credentials,
|
||||
extra,
|
||||
proxy_id: form.proxy_id,
|
||||
concurrency: form.concurrency,
|
||||
priority: form.priority,
|
||||
group_ids: form.group_ids
|
||||
})
|
||||
|
||||
appStore.showSuccess(t('admin.accounts.accountCreated'))
|
||||
emit('created')
|
||||
handleClose()
|
||||
} catch (error: any) {
|
||||
openaiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
||||
appStore.showError(openaiOAuth.error.value)
|
||||
} finally {
|
||||
openaiOAuth.loading.value = false
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// For Anthropic
|
||||
if (!authCode.trim() || !oauth.sessionId.value) return
|
||||
|
||||
oauth.loading.value = true
|
||||
@@ -893,7 +1107,8 @@ const handleExchangeCode = async () => {
|
||||
extra,
|
||||
proxy_id: form.proxy_id,
|
||||
concurrency: form.concurrency,
|
||||
priority: form.priority
|
||||
priority: form.priority,
|
||||
group_ids: form.group_ids
|
||||
})
|
||||
|
||||
appStore.showSuccess(t('admin.accounts.accountCreated'))
|
||||
|
||||
@@ -24,7 +24,7 @@
|
||||
v-model="editBaseUrl"
|
||||
type="text"
|
||||
class="input"
|
||||
placeholder="https://api.anthropic.com"
|
||||
:placeholder="account.platform === 'openai' ? 'https://api.openai.com' : 'https://api.anthropic.com'"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.baseUrlHint') }}</p>
|
||||
</div>
|
||||
@@ -34,7 +34,7 @@
|
||||
v-model="editApiKey"
|
||||
type="password"
|
||||
class="input font-mono"
|
||||
:placeholder="t('admin.accounts.leaveEmptyToKeep')"
|
||||
:placeholder="account.platform === 'openai' ? 'sk-proj-...' : 'sk-ant-...'"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.leaveEmptyToKeep') }}</p>
|
||||
</div>
|
||||
@@ -286,8 +286,8 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Intercept Warmup Requests (all account types) -->
|
||||
<div class="border-t border-gray-200 dark:border-dark-600 pt-4">
|
||||
<!-- Intercept Warmup Requests (Anthropic only) -->
|
||||
<div v-if="account?.platform === 'anthropic'" class="border-t border-gray-200 dark:border-dark-600 pt-4">
|
||||
<div class="flex items-center justify-between">
|
||||
<div>
|
||||
<label class="input-label mb-0">{{ t('admin.accounts.interceptWarmupRequests') }}</label>
|
||||
@@ -352,6 +352,7 @@
|
||||
<GroupSelector
|
||||
v-model="form.group_ids"
|
||||
:groups="groups"
|
||||
:platform="account?.platform"
|
||||
/>
|
||||
|
||||
<div class="flex justify-end gap-3 pt-4">
|
||||
@@ -428,8 +429,8 @@ const selectedErrorCodes = ref<number[]>([])
|
||||
const customErrorCodeInput = ref<number | null>(null)
|
||||
const interceptWarmupRequests = ref(false)
|
||||
|
||||
// Common models for whitelist
|
||||
const commonModels = [
|
||||
// Common models for whitelist - Anthropic
|
||||
const anthropicModels = [
|
||||
{ value: 'claude-opus-4-5-20251101', label: 'Claude Opus 4.5' },
|
||||
{ value: 'claude-sonnet-4-20250514', label: 'Claude Sonnet 4' },
|
||||
{ value: 'claude-sonnet-4-5-20250929', label: 'Claude Sonnet 4.5' },
|
||||
@@ -440,8 +441,24 @@ const commonModels = [
|
||||
{ value: 'claude-3-haiku-20240307', label: 'Claude 3 Haiku' }
|
||||
]
|
||||
|
||||
// Preset mappings for quick add
|
||||
const presetMappings = [
|
||||
// Common models for whitelist - OpenAI
|
||||
const openaiModels = [
|
||||
{ value: 'gpt-5.2-2025-12-11', label: 'GPT-5.2' },
|
||||
{ value: 'gpt-5.2-codex', label: 'GPT-5.2 Codex' },
|
||||
{ value: 'gpt-5.1-codex-max', label: 'GPT-5.1 Codex Max' },
|
||||
{ value: 'gpt-5.1-codex', label: 'GPT-5.1 Codex' },
|
||||
{ value: 'gpt-5.1-2025-11-13', label: 'GPT-5.1' },
|
||||
{ value: 'gpt-5.1-codex-mini', label: 'GPT-5.1 Codex Mini' },
|
||||
{ value: 'gpt-5-2025-08-07', label: 'GPT-5' }
|
||||
]
|
||||
|
||||
// Computed: current models based on platform
|
||||
const commonModels = computed(() => {
|
||||
return props.account?.platform === 'openai' ? openaiModels : anthropicModels
|
||||
})
|
||||
|
||||
// Preset mappings for quick add - Anthropic
|
||||
const anthropicPresetMappings = [
|
||||
{ label: 'Sonnet 4', from: 'claude-sonnet-4-20250514', to: 'claude-sonnet-4-20250514', color: 'bg-blue-100 text-blue-700 hover:bg-blue-200 dark:bg-blue-900/30 dark:text-blue-400' },
|
||||
{ label: 'Sonnet 4.5', from: 'claude-sonnet-4-5-20250929', to: 'claude-sonnet-4-5-20250929', color: 'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400' },
|
||||
{ label: 'Opus 4.5', from: 'claude-opus-4-5-20251101', to: 'claude-opus-4-5-20251101', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' },
|
||||
@@ -450,6 +467,26 @@ const presetMappings = [
|
||||
{ label: 'Opus->Sonnet', from: 'claude-opus-4-5-20251101', to: 'claude-sonnet-4-5-20250929', color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400' }
|
||||
]
|
||||
|
||||
// Preset mappings for quick add - OpenAI
|
||||
const openaiPresetMappings = [
|
||||
{ label: 'GPT-5.2', from: 'gpt-5.2-2025-12-11', to: 'gpt-5.2-2025-12-11', color: 'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 dark:text-green-400' },
|
||||
{ label: 'GPT-5.2 Codex', from: 'gpt-5.2-codex', to: 'gpt-5.2-codex', color: 'bg-blue-100 text-blue-700 hover:bg-blue-200 dark:bg-blue-900/30 dark:text-blue-400' },
|
||||
{ label: 'GPT-5.1 Codex', from: 'gpt-5.1-codex', to: 'gpt-5.1-codex', color: 'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400' },
|
||||
{ label: 'Codex Max', from: 'gpt-5.1-codex-max', to: 'gpt-5.1-codex-max', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' },
|
||||
{ label: 'Codex Mini', from: 'gpt-5.1-codex-mini', to: 'gpt-5.1-codex-mini', color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400' },
|
||||
{ label: 'Max->Codex', from: 'gpt-5.1-codex-max', to: 'gpt-5.1-codex', color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400' }
|
||||
]
|
||||
|
||||
// Computed: current preset mappings based on platform
|
||||
const presetMappings = computed(() => {
|
||||
return props.account?.platform === 'openai' ? openaiPresetMappings : anthropicPresetMappings
|
||||
})
|
||||
|
||||
// Computed: default base URL based on platform
|
||||
const defaultBaseUrl = computed(() => {
|
||||
return props.account?.platform === 'openai' ? 'https://api.openai.com' : 'https://api.anthropic.com'
|
||||
})
|
||||
|
||||
// Common HTTP error codes for quick selection
|
||||
const commonErrorCodes = [
|
||||
{ value: 401, label: 'Unauthorized' },
|
||||
@@ -492,7 +529,8 @@ watch(() => props.account, (newAccount) => {
|
||||
// Initialize API Key fields for apikey type
|
||||
if (newAccount.type === 'apikey' && newAccount.credentials) {
|
||||
const credentials = newAccount.credentials as Record<string, unknown>
|
||||
editBaseUrl.value = credentials.base_url as string || 'https://api.anthropic.com'
|
||||
const platformDefaultUrl = newAccount.platform === 'openai' ? 'https://api.openai.com' : 'https://api.anthropic.com'
|
||||
editBaseUrl.value = credentials.base_url as string || platformDefaultUrl
|
||||
|
||||
// Load model mappings and detect mode
|
||||
const existingMappings = credentials.model_mapping as Record<string, string> | undefined
|
||||
@@ -529,7 +567,8 @@ watch(() => props.account, (newAccount) => {
|
||||
selectedErrorCodes.value = []
|
||||
}
|
||||
} else {
|
||||
editBaseUrl.value = 'https://api.anthropic.com'
|
||||
const platformDefaultUrl = newAccount.platform === 'openai' ? 'https://api.openai.com' : 'https://api.anthropic.com'
|
||||
editBaseUrl.value = platformDefaultUrl
|
||||
modelRestrictionMode.value = 'whitelist'
|
||||
modelMappings.value = []
|
||||
allowedModels.value = []
|
||||
@@ -628,7 +667,7 @@ const handleSubmit = async () => {
|
||||
// For apikey type, handle credentials update
|
||||
if (props.account.type === 'apikey') {
|
||||
const currentCredentials = props.account.credentials as Record<string, unknown> || {}
|
||||
const newBaseUrl = editBaseUrl.value.trim() || 'https://api.anthropic.com'
|
||||
const newBaseUrl = editBaseUrl.value.trim() || defaultBaseUrl.value
|
||||
const modelMapping = buildModelMappingObject()
|
||||
|
||||
// Always update credentials for apikey type to handle model mapping changes
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user