mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-08 01:00:21 +08:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dacf3a2a6e | ||
|
|
e6add93ae3 | ||
|
|
b2273ec695 | ||
|
|
aa89777dda | ||
|
|
1e1f3c0c74 | ||
|
|
1fab9204eb | ||
|
|
dbd3e71637 | ||
|
|
974f67211b | ||
|
|
0338c83b90 | ||
|
|
c6b3de1199 |
38
.github/workflows/backend-ci.yml
vendored
Normal file
38
.github/workflows/backend-ci.yml
vendored
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
pull_request:
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version-file: backend/go.mod
|
||||||
|
check-latest: true
|
||||||
|
cache: true
|
||||||
|
- name: Run tests
|
||||||
|
working-directory: backend
|
||||||
|
run: go test ./...
|
||||||
|
|
||||||
|
golangci-lint:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version-file: backend/go.mod
|
||||||
|
check-latest: true
|
||||||
|
cache: true
|
||||||
|
- name: golangci-lint
|
||||||
|
uses: golangci/golangci-lint-action@v9
|
||||||
|
with:
|
||||||
|
version: v2.7
|
||||||
|
args: --timeout=5m
|
||||||
|
working-directory: backend
|
||||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -81,7 +81,12 @@ build/
|
|||||||
release/
|
release/
|
||||||
|
|
||||||
# 后端嵌入的前端构建产物
|
# 后端嵌入的前端构建产物
|
||||||
|
# Keep a placeholder file so `//go:embed all:dist` always has a match in CI/lint,
|
||||||
|
# while still ignoring generated frontend build outputs.
|
||||||
backend/internal/web/dist/
|
backend/internal/web/dist/
|
||||||
|
!backend/internal/web/dist/
|
||||||
|
backend/internal/web/dist/*
|
||||||
|
!backend/internal/web/dist/.keep
|
||||||
|
|
||||||
# 后端运行时缓存数据
|
# 后端运行时缓存数据
|
||||||
backend/data/
|
backend/data/
|
||||||
@@ -92,4 +97,4 @@ backend/data/
|
|||||||
tests
|
tests
|
||||||
CLAUDE.md
|
CLAUDE.md
|
||||||
.claude
|
.claude
|
||||||
scripts
|
scripts
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ builds:
|
|||||||
dir: backend
|
dir: backend
|
||||||
main: ./cmd/server
|
main: ./cmd/server
|
||||||
binary: sub2api
|
binary: sub2api
|
||||||
|
flags:
|
||||||
|
- -tags=embed
|
||||||
env:
|
env:
|
||||||
- CGO_ENABLED=0
|
- CGO_ENABLED=0
|
||||||
goos:
|
goos:
|
||||||
|
|||||||
11
Dockerfile
11
Dockerfile
@@ -40,14 +40,15 @@ WORKDIR /app/backend
|
|||||||
COPY backend/go.mod backend/go.sum ./
|
COPY backend/go.mod backend/go.sum ./
|
||||||
RUN go mod download
|
RUN go mod download
|
||||||
|
|
||||||
# Copy frontend dist from previous stage
|
# Copy backend source first
|
||||||
COPY --from=frontend-builder /app/frontend/../backend/internal/web/dist ./internal/web/dist
|
|
||||||
|
|
||||||
# Copy backend source
|
|
||||||
COPY backend/ ./
|
COPY backend/ ./
|
||||||
|
|
||||||
# Build the binary (BuildType=release for CI builds)
|
# Copy frontend dist from previous stage (must be after backend copy to avoid being overwritten)
|
||||||
|
COPY --from=frontend-builder /app/backend/internal/web/dist ./internal/web/dist
|
||||||
|
|
||||||
|
# Build the binary (BuildType=release for CI builds, embed frontend)
|
||||||
RUN CGO_ENABLED=0 GOOS=linux go build \
|
RUN CGO_ENABLED=0 GOOS=linux go build \
|
||||||
|
-tags embed \
|
||||||
-ldflags="-s -w -X main.Commit=${COMMIT} -X main.Date=${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)} -X main.BuildType=release" \
|
-ldflags="-s -w -X main.Commit=${COMMIT} -X main.Date=${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)} -X main.BuildType=release" \
|
||||||
-o /app/sub2api \
|
-o /app/sub2api \
|
||||||
./cmd/server
|
./cmd/server
|
||||||
|
|||||||
16
README.md
16
README.md
@@ -220,21 +220,21 @@ cd sub2api
|
|||||||
cd frontend
|
cd frontend
|
||||||
npm install
|
npm install
|
||||||
npm run build
|
npm run build
|
||||||
|
# Output will be in ../backend/internal/web/dist/
|
||||||
|
|
||||||
# 3. Copy frontend build to backend (for embedding)
|
# 3. Build backend with embedded frontend
|
||||||
cp -r dist ../backend/internal/web/
|
|
||||||
|
|
||||||
# 4. Build backend (requires frontend dist to be present)
|
|
||||||
cd ../backend
|
cd ../backend
|
||||||
go build -o sub2api ./cmd/server
|
go build -tags embed -o sub2api ./cmd/server
|
||||||
|
|
||||||
# 5. Create configuration file
|
# 4. Create configuration file
|
||||||
cp ../deploy/config.example.yaml ./config.yaml
|
cp ../deploy/config.example.yaml ./config.yaml
|
||||||
|
|
||||||
# 6. Edit configuration
|
# 5. Edit configuration
|
||||||
nano config.yaml
|
nano config.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
|
> **Note:** The `-tags embed` flag embeds the frontend into the binary. Without this flag, the binary will not serve the frontend UI.
|
||||||
|
|
||||||
**Key configuration in `config.yaml`:**
|
**Key configuration in `config.yaml`:**
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
@@ -265,7 +265,7 @@ default:
|
|||||||
```
|
```
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 7. Run the application
|
# 6. Run the application
|
||||||
./sub2api
|
./sub2api
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
16
README_CN.md
16
README_CN.md
@@ -220,21 +220,21 @@ cd sub2api
|
|||||||
cd frontend
|
cd frontend
|
||||||
npm install
|
npm install
|
||||||
npm run build
|
npm run build
|
||||||
|
# 构建产物输出到 ../backend/internal/web/dist/
|
||||||
|
|
||||||
# 3. 复制前端构建产物到后端(用于嵌入)
|
# 3. 编译后端(嵌入前端)
|
||||||
cp -r dist ../backend/internal/web/
|
|
||||||
|
|
||||||
# 4. 编译后端(需要前端 dist 目录存在)
|
|
||||||
cd ../backend
|
cd ../backend
|
||||||
go build -o sub2api ./cmd/server
|
go build -tags embed -o sub2api ./cmd/server
|
||||||
|
|
||||||
# 5. 创建配置文件
|
# 4. 创建配置文件
|
||||||
cp ../deploy/config.example.yaml ./config.yaml
|
cp ../deploy/config.example.yaml ./config.yaml
|
||||||
|
|
||||||
# 6. 编辑配置
|
# 5. 编辑配置
|
||||||
nano config.yaml
|
nano config.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
|
> **注意:** `-tags embed` 参数会将前端嵌入到二进制文件中。不使用此参数编译的程序将不包含前端界面。
|
||||||
|
|
||||||
**`config.yaml` 关键配置:**
|
**`config.yaml` 关键配置:**
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
@@ -265,7 +265,7 @@ default:
|
|||||||
```
|
```
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 7. 运行应用
|
# 6. 运行应用
|
||||||
./sub2api
|
./sub2api
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
587
backend/.golangci.yml
Normal file
587
backend/.golangci.yml
Normal file
@@ -0,0 +1,587 @@
|
|||||||
|
version: "2"
|
||||||
|
|
||||||
|
linters:
|
||||||
|
default: none
|
||||||
|
enable:
|
||||||
|
- depguard
|
||||||
|
- errcheck
|
||||||
|
- govet
|
||||||
|
- ineffassign
|
||||||
|
- staticcheck
|
||||||
|
- unused
|
||||||
|
|
||||||
|
settings:
|
||||||
|
depguard:
|
||||||
|
rules:
|
||||||
|
# Enforce: service must not depend on repository.
|
||||||
|
service-no-repository:
|
||||||
|
list-mode: original
|
||||||
|
files:
|
||||||
|
- internal/service/**
|
||||||
|
deny:
|
||||||
|
- pkg: sub2api/internal/repository
|
||||||
|
desc: "service must not import repository"
|
||||||
|
errcheck:
|
||||||
|
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
|
||||||
|
# Such cases aren't reported by default.
|
||||||
|
# Default: false
|
||||||
|
check-type-assertions: true
|
||||||
|
# report about assignment of errors to blank identifier: `num, _ := strconv.Atoi(numStr)`.
|
||||||
|
# Such cases aren't reported by default.
|
||||||
|
# Default: false
|
||||||
|
check-blank: false
|
||||||
|
# To disable the errcheck built-in exclude list.
|
||||||
|
# See `-excludeonly` option in https://github.com/kisielk/errcheck#excluding-functions for details.
|
||||||
|
# Default: false
|
||||||
|
disable-default-exclusions: true
|
||||||
|
# List of functions to exclude from checking, where each entry is a single function to exclude.
|
||||||
|
# See https://github.com/kisielk/errcheck#excluding-functions for details.
|
||||||
|
exclude-functions:
|
||||||
|
- io/ioutil.ReadFile
|
||||||
|
- io.Copy(*bytes.Buffer)
|
||||||
|
- io.Copy(os.Stdout)
|
||||||
|
- fmt.Println
|
||||||
|
- fmt.Print
|
||||||
|
- fmt.Printf
|
||||||
|
- fmt.Fprint
|
||||||
|
- fmt.Fprintf
|
||||||
|
- fmt.Fprintln
|
||||||
|
# Display function signature instead of selector.
|
||||||
|
# Default: false
|
||||||
|
verbose: true
|
||||||
|
ineffassign:
|
||||||
|
# Check escaping variables of type error, may cause false positives.
|
||||||
|
# Default: false
|
||||||
|
check-escaping-errors: true
|
||||||
|
staticcheck:
|
||||||
|
# https://staticcheck.dev/docs/configuration/options/#dot_import_whitelist
|
||||||
|
# Default: ["github.com/mmcloughlin/avo/build", "github.com/mmcloughlin/avo/operand", "github.com/mmcloughlin/avo/reg"]
|
||||||
|
dot-import-whitelist:
|
||||||
|
- fmt
|
||||||
|
# https://staticcheck.dev/docs/configuration/options/#initialisms
|
||||||
|
# Default: ["ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS"]
|
||||||
|
initialisms: [ "ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS" ]
|
||||||
|
# https://staticcheck.dev/docs/configuration/options/#http_status_code_whitelist
|
||||||
|
# Default: ["200", "400", "404", "500"]
|
||||||
|
http-status-code-whitelist: [ "200", "400", "404", "500" ]
|
||||||
|
# SAxxxx checks in https://staticcheck.dev/docs/configuration/options/#checks
|
||||||
|
# Example (to disable some checks): [ "all", "-SA1000", "-SA1001"]
|
||||||
|
# Run `GL_DEBUG=staticcheck golangci-lint run --enable=staticcheck` to see all available checks and enabled by config checks.
|
||||||
|
# Default: ["all", "-ST1000", "-ST1003", "-ST1016", "-ST1020", "-ST1021", "-ST1022"]
|
||||||
|
checks:
|
||||||
|
# Invalid regular expression.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1000
|
||||||
|
- SA1000
|
||||||
|
# Invalid template.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1001
|
||||||
|
- SA1001
|
||||||
|
# Invalid format in 'time.Parse'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1002
|
||||||
|
- SA1002
|
||||||
|
# Unsupported argument to functions in 'encoding/binary'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1003
|
||||||
|
- SA1003
|
||||||
|
# Suspiciously small untyped constant in 'time.Sleep'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1004
|
||||||
|
- SA1004
|
||||||
|
# Invalid first argument to 'exec.Command'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1005
|
||||||
|
- SA1005
|
||||||
|
# 'Printf' with dynamic first argument and no further arguments.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1006
|
||||||
|
- SA1006
|
||||||
|
# Invalid URL in 'net/url.Parse'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1007
|
||||||
|
- SA1007
|
||||||
|
# Non-canonical key in 'http.Header' map.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1008
|
||||||
|
- SA1008
|
||||||
|
# '(*regexp.Regexp).FindAll' called with 'n == 0', which will always return zero results.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1010
|
||||||
|
- SA1010
|
||||||
|
# Various methods in the "strings" package expect valid UTF-8, but invalid input is provided.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1011
|
||||||
|
- SA1011
|
||||||
|
# A nil 'context.Context' is being passed to a function, consider using 'context.TODO' instead.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1012
|
||||||
|
- SA1012
|
||||||
|
# 'io.Seeker.Seek' is being called with the whence constant as the first argument, but it should be the second.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1013
|
||||||
|
- SA1013
|
||||||
|
# Non-pointer value passed to 'Unmarshal' or 'Decode'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1014
|
||||||
|
- SA1014
|
||||||
|
# Using 'time.Tick' in a way that will leak. Consider using 'time.NewTicker', and only use 'time.Tick' in tests, commands and endless functions.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1015
|
||||||
|
- SA1015
|
||||||
|
# Trapping a signal that cannot be trapped.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1016
|
||||||
|
- SA1016
|
||||||
|
# Channels used with 'os/signal.Notify' should be buffered.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1017
|
||||||
|
- SA1017
|
||||||
|
# 'strings.Replace' called with 'n == 0', which does nothing.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1018
|
||||||
|
- SA1018
|
||||||
|
# Using a deprecated function, variable, constant or field.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1019
|
||||||
|
- SA1019
|
||||||
|
# Using an invalid host:port pair with a 'net.Listen'-related function.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1020
|
||||||
|
- SA1020
|
||||||
|
# Using 'bytes.Equal' to compare two 'net.IP'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1021
|
||||||
|
- SA1021
|
||||||
|
# Modifying the buffer in an 'io.Writer' implementation.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1023
|
||||||
|
- SA1023
|
||||||
|
# A string cutset contains duplicate characters.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1024
|
||||||
|
- SA1024
|
||||||
|
# It is not possible to use '(*time.Timer).Reset''s return value correctly.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1025
|
||||||
|
- SA1025
|
||||||
|
# Cannot marshal channels or functions.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1026
|
||||||
|
- SA1026
|
||||||
|
# Atomic access to 64-bit variable must be 64-bit aligned.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1027
|
||||||
|
- SA1027
|
||||||
|
# 'sort.Slice' can only be used on slices.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1028
|
||||||
|
- SA1028
|
||||||
|
# Inappropriate key in call to 'context.WithValue'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1029
|
||||||
|
- SA1029
|
||||||
|
# Invalid argument in call to a 'strconv' function.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1030
|
||||||
|
- SA1030
|
||||||
|
# Overlapping byte slices passed to an encoder.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1031
|
||||||
|
- SA1031
|
||||||
|
# Wrong order of arguments to 'errors.Is'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1032
|
||||||
|
- SA1032
|
||||||
|
# 'sync.WaitGroup.Add' called inside the goroutine, leading to a race condition.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA2000
|
||||||
|
- SA2000
|
||||||
|
# Empty critical section, did you mean to defer the unlock?.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA2001
|
||||||
|
- SA2001
|
||||||
|
# Called 'testing.T.FailNow' or 'SkipNow' in a goroutine, which isn't allowed.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA2002
|
||||||
|
- SA2002
|
||||||
|
# Deferred 'Lock' right after locking, likely meant to defer 'Unlock' instead.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA2003
|
||||||
|
- SA2003
|
||||||
|
# 'TestMain' doesn't call 'os.Exit', hiding test failures.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA3000
|
||||||
|
- SA3000
|
||||||
|
# Assigning to 'b.N' in benchmarks distorts the results.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA3001
|
||||||
|
- SA3001
|
||||||
|
# Binary operator has identical expressions on both sides.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4000
|
||||||
|
- SA4000
|
||||||
|
# '&*x' gets simplified to 'x', it does not copy 'x'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4001
|
||||||
|
- SA4001
|
||||||
|
# Comparing unsigned values against negative values is pointless.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4003
|
||||||
|
- SA4003
|
||||||
|
# The loop exits unconditionally after one iteration.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4004
|
||||||
|
- SA4004
|
||||||
|
# Field assignment that will never be observed. Did you mean to use a pointer receiver?.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4005
|
||||||
|
- SA4005
|
||||||
|
# A value assigned to a variable is never read before being overwritten. Forgotten error check or dead code?.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4006
|
||||||
|
- SA4006
|
||||||
|
# The variable in the loop condition never changes, are you incrementing the wrong variable?.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4008
|
||||||
|
- SA4008
|
||||||
|
# A function argument is overwritten before its first use.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4009
|
||||||
|
- SA4009
|
||||||
|
# The result of 'append' will never be observed anywhere.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4010
|
||||||
|
- SA4010
|
||||||
|
# Break statement with no effect. Did you mean to break out of an outer loop?.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4011
|
||||||
|
- SA4011
|
||||||
|
# Comparing a value against NaN even though no value is equal to NaN.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4012
|
||||||
|
- SA4012
|
||||||
|
# Negating a boolean twice ('!!b') is the same as writing 'b'. This is either redundant, or a typo.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4013
|
||||||
|
- SA4013
|
||||||
|
# An if/else if chain has repeated conditions and no side-effects; if the condition didn't match the first time, it won't match the second time, either.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4014
|
||||||
|
- SA4014
|
||||||
|
# Calling functions like 'math.Ceil' on floats converted from integers doesn't do anything useful.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4015
|
||||||
|
- SA4015
|
||||||
|
# Certain bitwise operations, such as 'x ^ 0', do not do anything useful.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4016
|
||||||
|
- SA4016
|
||||||
|
# Discarding the return values of a function without side effects, making the call pointless.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4017
|
||||||
|
- SA4017
|
||||||
|
# Self-assignment of variables.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4018
|
||||||
|
- SA4018
|
||||||
|
# Multiple, identical build constraints in the same file.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4019
|
||||||
|
- SA4019
|
||||||
|
# Unreachable case clause in a type switch.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4020
|
||||||
|
- SA4020
|
||||||
|
# "x = append(y)" is equivalent to "x = y".
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4021
|
||||||
|
- SA4021
|
||||||
|
# Comparing the address of a variable against nil.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4022
|
||||||
|
- SA4022
|
||||||
|
# Impossible comparison of interface value with untyped nil.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4023
|
||||||
|
- SA4023
|
||||||
|
# Checking for impossible return value from a builtin function.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4024
|
||||||
|
- SA4024
|
||||||
|
# Integer division of literals that results in zero.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4025
|
||||||
|
- SA4025
|
||||||
|
# Go constants cannot express negative zero.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4026
|
||||||
|
- SA4026
|
||||||
|
# '(*net/url.URL).Query' returns a copy, modifying it doesn't change the URL.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4027
|
||||||
|
- SA4027
|
||||||
|
# 'x % 1' is always zero.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4028
|
||||||
|
- SA4028
|
||||||
|
# Ineffective attempt at sorting slice.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4029
|
||||||
|
- SA4029
|
||||||
|
# Ineffective attempt at generating random number.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4030
|
||||||
|
- SA4030
|
||||||
|
# Checking never-nil value against nil.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4031
|
||||||
|
- SA4031
|
||||||
|
# Comparing 'runtime.GOOS' or 'runtime.GOARCH' against impossible value.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4032
|
||||||
|
- SA4032
|
||||||
|
# Assignment to nil map.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA5000
|
||||||
|
- SA5000
|
||||||
|
# Deferring 'Close' before checking for a possible error.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA5001
|
||||||
|
- SA5001
|
||||||
|
# The empty for loop ("for {}") spins and can block the scheduler.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA5002
|
||||||
|
- SA5002
|
||||||
|
# Defers in infinite loops will never execute.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA5003
|
||||||
|
- SA5003
|
||||||
|
# "for { select { ..." with an empty default branch spins.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA5004
|
||||||
|
- SA5004
|
||||||
|
# The finalizer references the finalized object, preventing garbage collection.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA5005
|
||||||
|
- SA5005
|
||||||
|
# Infinite recursive call.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA5007
|
||||||
|
- SA5007
|
||||||
|
# Invalid struct tag.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA5008
|
||||||
|
- SA5008
|
||||||
|
# Invalid Printf call.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA5009
|
||||||
|
- SA5009
|
||||||
|
# Impossible type assertion.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA5010
|
||||||
|
- SA5010
|
||||||
|
# Possible nil pointer dereference.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA5011
|
||||||
|
- SA5011
|
||||||
|
# Passing odd-sized slice to function expecting even size.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA5012
|
||||||
|
- SA5012
|
||||||
|
# Using 'regexp.Match' or related in a loop, should use 'regexp.Compile'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA6000
|
||||||
|
- SA6000
|
||||||
|
# Missing an optimization opportunity when indexing maps by byte slices.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA6001
|
||||||
|
- SA6001
|
||||||
|
# Storing non-pointer values in 'sync.Pool' allocates memory.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA6002
|
||||||
|
- SA6002
|
||||||
|
# Converting a string to a slice of runes before ranging over it.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA6003
|
||||||
|
- SA6003
|
||||||
|
# Inefficient string comparison with 'strings.ToLower' or 'strings.ToUpper'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA6005
|
||||||
|
- SA6005
|
||||||
|
# Using io.WriteString to write '[]byte'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA6006
|
||||||
|
- SA6006
|
||||||
|
# Defers in range loops may not run when you expect them to.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA9001
|
||||||
|
- SA9001
|
||||||
|
# Using a non-octal 'os.FileMode' that looks like it was meant to be in octal.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA9002
|
||||||
|
- SA9002
|
||||||
|
# Empty body in an if or else branch.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA9003
|
||||||
|
- SA9003
|
||||||
|
# Only the first constant has an explicit type.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA9004
|
||||||
|
- SA9004
|
||||||
|
# Trying to marshal a struct with no public fields nor custom marshaling.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA9005
|
||||||
|
- SA9005
|
||||||
|
# Dubious bit shifting of a fixed size integer value.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA9006
|
||||||
|
- SA9006
|
||||||
|
# Deleting a directory that shouldn't be deleted.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA9007
|
||||||
|
- SA9007
|
||||||
|
# 'else' branch of a type assertion is probably not reading the right value.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA9008
|
||||||
|
- SA9008
|
||||||
|
# Ineffectual Go compiler directive.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA9009
|
||||||
|
- SA9009
|
||||||
|
# Incorrect or missing package comment.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1000
|
||||||
|
- ST1000
|
||||||
|
# Dot imports are discouraged.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1001
|
||||||
|
- ST1001
|
||||||
|
# Poorly chosen identifier.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1003
|
||||||
|
- ST1003
|
||||||
|
# Incorrectly formatted error string.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1005
|
||||||
|
- ST1005
|
||||||
|
# Poorly chosen receiver name.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1006
|
||||||
|
- ST1006
|
||||||
|
# A function's error value should be its last return value.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1008
|
||||||
|
- ST1008
|
||||||
|
# Poorly chosen name for variable of type 'time.Duration'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1011
|
||||||
|
- ST1011
|
||||||
|
# Poorly chosen name for error variable.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1012
|
||||||
|
- ST1012
|
||||||
|
# Should use constants for HTTP error codes, not magic numbers.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1013
|
||||||
|
- ST1013
|
||||||
|
# A switch's default case should be the first or last case.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1015
|
||||||
|
- ST1015
|
||||||
|
# Use consistent method receiver names.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1016
|
||||||
|
- ST1016
|
||||||
|
# Don't use Yoda conditions.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1017
|
||||||
|
- ST1017
|
||||||
|
# Avoid zero-width and control characters in string literals.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1018
|
||||||
|
- ST1018
|
||||||
|
# Importing the same package multiple times.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1019
|
||||||
|
- ST1019
|
||||||
|
# The documentation of an exported function should start with the function's name.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1020
|
||||||
|
- ST1020
|
||||||
|
# The documentation of an exported type should start with type's name.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1021
|
||||||
|
- ST1021
|
||||||
|
# The documentation of an exported variable or constant should start with variable's name.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1022
|
||||||
|
- ST1022
|
||||||
|
# Redundant type in variable declaration.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1023
|
||||||
|
- ST1023
|
||||||
|
# Use plain channel send or receive instead of single-case select.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1000
|
||||||
|
- S1000
|
||||||
|
# Replace for loop with call to copy.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1001
|
||||||
|
- S1001
|
||||||
|
# Omit comparison with boolean constant.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1002
|
||||||
|
- S1002
|
||||||
|
# Replace call to 'strings.Index' with 'strings.Contains'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1003
|
||||||
|
- S1003
|
||||||
|
# Replace call to 'bytes.Compare' with 'bytes.Equal'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1004
|
||||||
|
- S1004
|
||||||
|
# Drop unnecessary use of the blank identifier.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1005
|
||||||
|
- S1005
|
||||||
|
# Use "for { ... }" for infinite loops.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1006
|
||||||
|
- S1006
|
||||||
|
# Simplify regular expression by using raw string literal.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1007
|
||||||
|
- S1007
|
||||||
|
# Simplify returning boolean expression.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1008
|
||||||
|
- S1008
|
||||||
|
# Omit redundant nil check on slices, maps, and channels.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1009
|
||||||
|
- S1009
|
||||||
|
# Omit default slice index.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1010
|
||||||
|
- S1010
|
||||||
|
# Use a single 'append' to concatenate two slices.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1011
|
||||||
|
- S1011
|
||||||
|
# Replace 'time.Now().Sub(x)' with 'time.Since(x)'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1012
|
||||||
|
- S1012
|
||||||
|
# Use a type conversion instead of manually copying struct fields.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1016
|
||||||
|
- S1016
|
||||||
|
# Replace manual trimming with 'strings.TrimPrefix'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1017
|
||||||
|
- S1017
|
||||||
|
# Use "copy" for sliding elements.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1018
|
||||||
|
- S1018
|
||||||
|
# Simplify "make" call by omitting redundant arguments.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1019
|
||||||
|
- S1019
|
||||||
|
# Omit redundant nil check in type assertion.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1020
|
||||||
|
- S1020
|
||||||
|
# Merge variable declaration and assignment.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1021
|
||||||
|
- S1021
|
||||||
|
# Omit redundant control flow.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1023
|
||||||
|
- S1023
|
||||||
|
# Replace 'x.Sub(time.Now())' with 'time.Until(x)'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1024
|
||||||
|
- S1024
|
||||||
|
# Don't use 'fmt.Sprintf("%s", x)' unnecessarily.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1025
|
||||||
|
- S1025
|
||||||
|
# Simplify error construction with 'fmt.Errorf'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1028
|
||||||
|
- S1028
|
||||||
|
# Range over the string directly.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1029
|
||||||
|
- S1029
|
||||||
|
# Use 'bytes.Buffer.String' or 'bytes.Buffer.Bytes'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1030
|
||||||
|
- S1030
|
||||||
|
# Omit redundant nil check around loop.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1031
|
||||||
|
- S1031
|
||||||
|
# Use 'sort.Ints(x)', 'sort.Float64s(x)', and 'sort.Strings(x)'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1032
|
||||||
|
- S1032
|
||||||
|
# Unnecessary guard around call to "delete".
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1033
|
||||||
|
- S1033
|
||||||
|
# Use result of type assertion to simplify cases.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1034
|
||||||
|
- S1034
|
||||||
|
# Redundant call to 'net/http.CanonicalHeaderKey' in method call on 'net/http.Header'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1035
|
||||||
|
- S1035
|
||||||
|
# Unnecessary guard around map access.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1036
|
||||||
|
- S1036
|
||||||
|
# Elaborate way of sleeping.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1037
|
||||||
|
- S1037
|
||||||
|
# Unnecessarily complex way of printing formatted string.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1038
|
||||||
|
- S1038
|
||||||
|
# Unnecessary use of 'fmt.Sprint'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1039
|
||||||
|
- S1039
|
||||||
|
# Type assertion to current type.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1040
|
||||||
|
- S1040
|
||||||
|
# Apply De Morgan's law.
|
||||||
|
# https://staticcheck.dev/docs/checks/#QF1001
|
||||||
|
- QF1001
|
||||||
|
# Convert untagged switch to tagged switch.
|
||||||
|
# https://staticcheck.dev/docs/checks/#QF1002
|
||||||
|
- QF1002
|
||||||
|
# Convert if/else-if chain to tagged switch.
|
||||||
|
# https://staticcheck.dev/docs/checks/#QF1003
|
||||||
|
- QF1003
|
||||||
|
# Use 'strings.ReplaceAll' instead of 'strings.Replace' with 'n == -1'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#QF1004
|
||||||
|
- QF1004
|
||||||
|
# Expand call to 'math.Pow'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#QF1005
|
||||||
|
- QF1005
|
||||||
|
# Lift 'if'+'break' into loop condition.
|
||||||
|
# https://staticcheck.dev/docs/checks/#QF1006
|
||||||
|
- QF1006
|
||||||
|
# Merge conditional assignment into variable declaration.
|
||||||
|
# https://staticcheck.dev/docs/checks/#QF1007
|
||||||
|
- QF1007
|
||||||
|
# Omit embedded fields from selector expression.
|
||||||
|
# https://staticcheck.dev/docs/checks/#QF1008
|
||||||
|
- QF1008
|
||||||
|
# Use 'time.Time.Equal' instead of '==' operator.
|
||||||
|
# https://staticcheck.dev/docs/checks/#QF1009
|
||||||
|
- QF1009
|
||||||
|
# Convert slice of bytes to string when printing it.
|
||||||
|
# https://staticcheck.dev/docs/checks/#QF1010
|
||||||
|
- QF1010
|
||||||
|
# Omit redundant type from variable declaration.
|
||||||
|
# https://staticcheck.dev/docs/checks/#QF1011
|
||||||
|
- QF1011
|
||||||
|
# Use 'fmt.Fprintf(x, ...)' instead of 'x.Write(fmt.Sprintf(...))'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#QF1012
|
||||||
|
- QF1012
|
||||||
|
unused:
|
||||||
|
# Mark all struct fields that have been written to as used.
|
||||||
|
# Default: true
|
||||||
|
field-writes-are-uses: false
|
||||||
|
# Treat IncDec statement (e.g. `i++` or `i--`) as both read and write operation instead of just write.
|
||||||
|
# Default: false
|
||||||
|
post-statements-are-reads: true
|
||||||
|
# Mark all exported fields as used.
|
||||||
|
# default: true
|
||||||
|
exported-fields-are-used: false
|
||||||
|
# Mark all function parameters as used.
|
||||||
|
# default: true
|
||||||
|
parameters-are-used: true
|
||||||
|
# Mark all local variables as used.
|
||||||
|
# default: true
|
||||||
|
local-variables-are-used: false
|
||||||
|
# Mark all identifiers inside generated files as used.
|
||||||
|
# Default: true
|
||||||
|
generated-is-used: false
|
||||||
|
|
||||||
|
formatters:
|
||||||
|
enable:
|
||||||
|
- gofmt
|
||||||
|
settings:
|
||||||
|
gofmt:
|
||||||
|
# Simplify code: gofmt with `-s` option.
|
||||||
|
# Default: true
|
||||||
|
simplify: false
|
||||||
|
# Apply the rewrite rules to the source before reformatting.
|
||||||
|
# https://pkg.go.dev/cmd/gofmt
|
||||||
|
# Default: []
|
||||||
|
rewrite-rules:
|
||||||
|
- pattern: 'interface{}'
|
||||||
|
replacement: 'any'
|
||||||
|
- pattern: 'a[b:len(a)]'
|
||||||
|
replacement: 'a[b:]'
|
||||||
@@ -1,6 +1,16 @@
|
|||||||
.PHONY: wire
|
.PHONY: wire build build-embed
|
||||||
|
|
||||||
wire:
|
wire:
|
||||||
@echo "生成 Wire 代码..."
|
@echo "生成 Wire 代码..."
|
||||||
@cd cmd/server && go generate
|
@cd cmd/server && go generate
|
||||||
@echo "Wire 代码生成完成"
|
@echo "Wire 代码生成完成"
|
||||||
|
|
||||||
|
build:
|
||||||
|
@echo "构建后端(不嵌入前端)..."
|
||||||
|
@go build -o bin/server ./cmd/server
|
||||||
|
@echo "构建完成: bin/server"
|
||||||
|
|
||||||
|
build-embed:
|
||||||
|
@echo "构建后端(嵌入前端)..."
|
||||||
|
@go build -tags embed -o bin/server ./cmd/server
|
||||||
|
@echo "构建完成: bin/server (with embedded frontend)"
|
||||||
@@ -48,7 +48,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
emailQueueService := service.ProvideEmailQueueService(emailService)
|
emailQueueService := service.ProvideEmailQueueService(emailService)
|
||||||
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
|
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
|
||||||
authHandler := handler.NewAuthHandler(authService)
|
authHandler := handler.NewAuthHandler(authService)
|
||||||
userService := service.NewUserService(userRepository, configConfig)
|
userService := service.NewUserService(userRepository)
|
||||||
userHandler := handler.NewUserHandler(userService)
|
userHandler := handler.NewUserHandler(userService)
|
||||||
apiKeyRepository := repository.NewApiKeyRepository(db)
|
apiKeyRepository := repository.NewApiKeyRepository(db)
|
||||||
groupRepository := repository.NewGroupRepository(db)
|
groupRepository := repository.NewGroupRepository(db)
|
||||||
@@ -67,22 +67,22 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService)
|
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService)
|
||||||
redeemHandler := handler.NewRedeemHandler(redeemService)
|
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||||
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
||||||
|
dashboardHandler := admin.NewDashboardHandler(usageLogRepository)
|
||||||
accountRepository := repository.NewAccountRepository(db)
|
accountRepository := repository.NewAccountRepository(db)
|
||||||
proxyRepository := repository.NewProxyRepository(db)
|
proxyRepository := repository.NewProxyRepository(db)
|
||||||
proxyExitInfoProber := repository.NewProxyExitInfoProber()
|
proxyExitInfoProber := repository.NewProxyExitInfoProber()
|
||||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, usageLogRepository, userSubscriptionRepository, billingCacheService, proxyExitInfoProber)
|
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber)
|
||||||
dashboardHandler := admin.NewDashboardHandler(adminService, usageLogRepository)
|
|
||||||
adminUserHandler := admin.NewUserHandler(adminService)
|
adminUserHandler := admin.NewUserHandler(adminService)
|
||||||
groupHandler := admin.NewGroupHandler(adminService)
|
groupHandler := admin.NewGroupHandler(adminService)
|
||||||
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
||||||
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
||||||
rateLimitService := service.NewRateLimitService(accountRepository, configConfig)
|
rateLimitService := service.NewRateLimitService(accountRepository, configConfig)
|
||||||
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
|
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
|
||||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, oAuthService, claudeUsageFetcher)
|
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher)
|
||||||
claudeUpstream := repository.NewClaudeUpstream(configConfig)
|
claudeUpstream := repository.NewClaudeUpstream(configConfig)
|
||||||
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, claudeUpstream)
|
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, claudeUpstream)
|
||||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, rateLimitService, accountUsageService, accountTestService)
|
accountHandler := admin.NewAccountHandler(adminService, oAuthService, rateLimitService, accountUsageService, accountTestService)
|
||||||
oAuthHandler := admin.NewOAuthHandler(oAuthService, adminService)
|
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
||||||
proxyHandler := admin.NewProxyHandler(adminService)
|
proxyHandler := admin.NewProxyHandler(adminService)
|
||||||
adminRedeemHandler := admin.NewRedeemHandler(adminService)
|
adminRedeemHandler := admin.NewRedeemHandler(adminService)
|
||||||
settingHandler := admin.NewSettingHandler(settingService, emailService)
|
settingHandler := admin.NewSettingHandler(settingService, emailService)
|
||||||
@@ -103,16 +103,16 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
billingService := service.NewBillingService(configConfig, pricingService)
|
billingService := service.NewBillingService(configConfig, pricingService)
|
||||||
identityCache := repository.NewIdentityCache(client)
|
identityCache := repository.NewIdentityCache(client)
|
||||||
identityService := service.NewIdentityService(identityCache)
|
identityService := service.NewIdentityService(identityCache)
|
||||||
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService, claudeUpstream)
|
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, claudeUpstream)
|
||||||
concurrencyCache := repository.NewConcurrencyCache(client)
|
concurrencyCache := repository.NewConcurrencyCache(client)
|
||||||
concurrencyService := service.NewConcurrencyService(concurrencyCache)
|
concurrencyService := service.NewConcurrencyService(concurrencyCache)
|
||||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, configConfig)
|
|
||||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
|
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
|
||||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, handlerSettingHandler)
|
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, handlerSettingHandler)
|
||||||
groupService := service.NewGroupService(groupRepository)
|
groupService := service.NewGroupService(groupRepository)
|
||||||
accountService := service.NewAccountService(accountRepository, groupRepository)
|
accountService := service.NewAccountService(accountRepository, groupRepository)
|
||||||
proxyService := service.NewProxyService(proxyRepository)
|
proxyService := service.NewProxyService(proxyRepository)
|
||||||
|
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, configConfig)
|
||||||
services := &service.Services{
|
services := &service.Services{
|
||||||
Auth: authService,
|
Auth: authService,
|
||||||
User: userService,
|
User: userService,
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ type PricingConfig struct {
|
|||||||
type ServerConfig struct {
|
type ServerConfig struct {
|
||||||
Host string `mapstructure:"host"`
|
Host string `mapstructure:"host"`
|
||||||
Port int `mapstructure:"port"`
|
Port int `mapstructure:"port"`
|
||||||
Mode string `mapstructure:"mode"` // debug/release
|
Mode string `mapstructure:"mode"` // debug/release
|
||||||
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
|
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
|
||||||
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
|
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
|
||||||
}
|
}
|
||||||
@@ -163,7 +163,7 @@ func setDefaults() {
|
|||||||
viper.SetDefault("server.port", 8080)
|
viper.SetDefault("server.port", 8080)
|
||||||
viper.SetDefault("server.mode", "debug")
|
viper.SetDefault("server.mode", "debug")
|
||||||
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
|
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
|
||||||
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
|
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
|
||||||
|
|
||||||
// Database
|
// Database
|
||||||
viper.SetDefault("database.host", "localhost")
|
viper.SetDefault("database.host", "localhost")
|
||||||
@@ -210,10 +210,10 @@ func setDefaults() {
|
|||||||
|
|
||||||
// TokenRefresh
|
// TokenRefresh
|
||||||
viper.SetDefault("token_refresh.enabled", true)
|
viper.SetDefault("token_refresh.enabled", true)
|
||||||
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
|
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
|
||||||
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 1.5) // 提前1.5小时刷新
|
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 1.5) // 提前1.5小时刷新
|
||||||
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
|
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
|
||||||
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
|
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) Validate() error {
|
func (c *Config) Validate() error {
|
||||||
|
|||||||
@@ -13,14 +13,12 @@ import (
|
|||||||
// OAuthHandler handles OAuth-related operations for accounts
|
// OAuthHandler handles OAuth-related operations for accounts
|
||||||
type OAuthHandler struct {
|
type OAuthHandler struct {
|
||||||
oauthService *service.OAuthService
|
oauthService *service.OAuthService
|
||||||
adminService service.AdminService
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOAuthHandler creates a new OAuth handler
|
// NewOAuthHandler creates a new OAuth handler
|
||||||
func NewOAuthHandler(oauthService *service.OAuthService, adminService service.AdminService) *OAuthHandler {
|
func NewOAuthHandler(oauthService *service.OAuthService) *OAuthHandler {
|
||||||
return &OAuthHandler{
|
return &OAuthHandler{
|
||||||
oauthService: oauthService,
|
oauthService: oauthService,
|
||||||
adminService: adminService,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -46,29 +44,29 @@ func NewAccountHandler(adminService service.AdminService, oauthService *service.
|
|||||||
|
|
||||||
// CreateAccountRequest represents create account request
|
// CreateAccountRequest represents create account request
|
||||||
type CreateAccountRequest struct {
|
type CreateAccountRequest struct {
|
||||||
Name string `json:"name" binding:"required"`
|
Name string `json:"name" binding:"required"`
|
||||||
Platform string `json:"platform" binding:"required"`
|
Platform string `json:"platform" binding:"required"`
|
||||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"`
|
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"`
|
||||||
Credentials map[string]interface{} `json:"credentials" binding:"required"`
|
Credentials map[string]any `json:"credentials" binding:"required"`
|
||||||
Extra map[string]interface{} `json:"extra"`
|
Extra map[string]any `json:"extra"`
|
||||||
ProxyID *int64 `json:"proxy_id"`
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
Concurrency int `json:"concurrency"`
|
Concurrency int `json:"concurrency"`
|
||||||
Priority int `json:"priority"`
|
Priority int `json:"priority"`
|
||||||
GroupIDs []int64 `json:"group_ids"`
|
GroupIDs []int64 `json:"group_ids"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateAccountRequest represents update account request
|
// UpdateAccountRequest represents update account request
|
||||||
// 使用指针类型来区分"未提供"和"设置为0"
|
// 使用指针类型来区分"未提供"和"设置为0"
|
||||||
type UpdateAccountRequest struct {
|
type UpdateAccountRequest struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
|
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
|
||||||
Credentials map[string]interface{} `json:"credentials"`
|
Credentials map[string]any `json:"credentials"`
|
||||||
Extra map[string]interface{} `json:"extra"`
|
Extra map[string]any `json:"extra"`
|
||||||
ProxyID *int64 `json:"proxy_id"`
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
Concurrency *int `json:"concurrency"`
|
Concurrency *int `json:"concurrency"`
|
||||||
Priority *int `json:"priority"`
|
Priority *int `json:"priority"`
|
||||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||||
GroupIDs *[]int64 `json:"group_ids"`
|
GroupIDs *[]int64 `json:"group_ids"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// List handles listing all accounts with pagination
|
// List handles listing all accounts with pagination
|
||||||
@@ -242,7 +240,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
|
// Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
|
||||||
newCredentials := make(map[string]interface{})
|
newCredentials := make(map[string]any)
|
||||||
for k, v := range account.Credentials {
|
for k, v := range account.Credentials {
|
||||||
newCredentials[k] = v
|
newCredentials[k] = v
|
||||||
}
|
}
|
||||||
@@ -573,7 +571,7 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
|||||||
|
|
||||||
// For API Key accounts: return models based on model_mapping
|
// For API Key accounts: return models based on model_mapping
|
||||||
mapping := account.GetModelMapping()
|
mapping := account.GetModelMapping()
|
||||||
if mapping == nil || len(mapping) == 0 {
|
if len(mapping) == 0 {
|
||||||
// No mapping configured, return default models
|
// No mapping configured, return default models
|
||||||
response.Success(c, claude.DefaultModels)
|
response.Success(c, claude.DefaultModels)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"sub2api/internal/pkg/response"
|
"sub2api/internal/pkg/response"
|
||||||
"sub2api/internal/pkg/timezone"
|
"sub2api/internal/pkg/timezone"
|
||||||
"sub2api/internal/repository"
|
"sub2api/internal/repository"
|
||||||
"sub2api/internal/service"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -13,17 +12,15 @@ import (
|
|||||||
|
|
||||||
// DashboardHandler handles admin dashboard statistics
|
// DashboardHandler handles admin dashboard statistics
|
||||||
type DashboardHandler struct {
|
type DashboardHandler struct {
|
||||||
adminService service.AdminService
|
usageRepo *repository.UsageLogRepository
|
||||||
usageRepo *repository.UsageLogRepository
|
startTime time.Time // Server start time for uptime calculation
|
||||||
startTime time.Time // Server start time for uptime calculation
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDashboardHandler creates a new admin dashboard handler
|
// NewDashboardHandler creates a new admin dashboard handler
|
||||||
func NewDashboardHandler(adminService service.AdminService, usageRepo *repository.UsageLogRepository) *DashboardHandler {
|
func NewDashboardHandler(usageRepo *repository.UsageLogRepository) *DashboardHandler {
|
||||||
return &DashboardHandler{
|
return &DashboardHandler{
|
||||||
adminService: adminService,
|
usageRepo: usageRepo,
|
||||||
usageRepo: usageRepo,
|
startTime: time.Now(),
|
||||||
startTime: time.Now(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -258,7 +255,7 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(req.UserIDs) == 0 {
|
if len(req.UserIDs) == 0 {
|
||||||
response.Success(c, gin.H{"stats": map[string]interface{}{}})
|
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -286,7 +283,7 @@ func (h *DashboardHandler) GetBatchApiKeysUsage(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(req.ApiKeyIDs) == 0 {
|
if len(req.ApiKeyIDs) == 0 {
|
||||||
response.Success(c, gin.H{"stats": map[string]interface{}{}})
|
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -236,7 +236,6 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
|
|||||||
response.Paginated(c, accounts, total, page, pageSize)
|
response.Paginated(c, accounts, total, page, pageSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// BatchCreateProxyItem represents a single proxy in batch create request
|
// BatchCreateProxyItem represents a single proxy in batch create request
|
||||||
type BatchCreateProxyItem struct {
|
type BatchCreateProxyItem struct {
|
||||||
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
|
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
|
||||||
|
|||||||
@@ -156,10 +156,10 @@ func (h *RedeemHandler) Expire(c *gin.Context) {
|
|||||||
func (h *RedeemHandler) GetStats(c *gin.Context) {
|
func (h *RedeemHandler) GetStats(c *gin.Context) {
|
||||||
// Return mock data for now
|
// Return mock data for now
|
||||||
response.Success(c, gin.H{
|
response.Success(c, gin.H{
|
||||||
"total_codes": 0,
|
"total_codes": 0,
|
||||||
"active_codes": 0,
|
"active_codes": 0,
|
||||||
"used_codes": 0,
|
"used_codes": 0,
|
||||||
"expired_codes": 0,
|
"expired_codes": 0,
|
||||||
"total_value_distributed": 0.0,
|
"total_value_distributed": 0.0,
|
||||||
"by_type": gin.H{
|
"by_type": gin.H{
|
||||||
"balance": 0,
|
"balance": 0,
|
||||||
@@ -187,7 +187,10 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
|||||||
writer := csv.NewWriter(&buf)
|
writer := csv.NewWriter(&buf)
|
||||||
|
|
||||||
// Write header
|
// Write header
|
||||||
writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"})
|
if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"}); err != nil {
|
||||||
|
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Write data rows
|
// Write data rows
|
||||||
for _, code := range codes {
|
for _, code := range codes {
|
||||||
@@ -199,7 +202,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
|||||||
if code.UsedAt != nil {
|
if code.UsedAt != nil {
|
||||||
usedAt = code.UsedAt.Format("2006-01-02 15:04:05")
|
usedAt = code.UsedAt.Format("2006-01-02 15:04:05")
|
||||||
}
|
}
|
||||||
writer.Write([]string{
|
if err := writer.Write([]string{
|
||||||
fmt.Sprintf("%d", code.ID),
|
fmt.Sprintf("%d", code.ID),
|
||||||
code.Code,
|
code.Code,
|
||||||
code.Type,
|
code.Type,
|
||||||
@@ -208,10 +211,17 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
|||||||
usedBy,
|
usedBy,
|
||||||
usedAt,
|
usedAt,
|
||||||
code.CreatedAt.Format("2006-01-02 15:04:05"),
|
code.CreatedAt.Format("2006-01-02 15:04:05"),
|
||||||
})
|
}); err != nil {
|
||||||
|
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
writer.Flush()
|
writer.Flush()
|
||||||
|
if err := writer.Error(); err != nil {
|
||||||
|
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
c.Header("Content-Type", "text/csv")
|
c.Header("Content-Type", "text/csv")
|
||||||
c.Header("Content-Disposition", "attachment; filename=redeem_codes.csv")
|
c.Header("Content-Disposition", "attachment; filename=redeem_codes.csv")
|
||||||
|
|||||||
@@ -193,7 +193,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
|||||||
func (h *UsageHandler) SearchUsers(c *gin.Context) {
|
func (h *UsageHandler) SearchUsers(c *gin.Context) {
|
||||||
keyword := c.Query("q")
|
keyword := c.Query("q")
|
||||||
if keyword == "" {
|
if keyword == "" {
|
||||||
response.Success(c, []interface{}{})
|
response.Success(c, []any{})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -268,7 +268,9 @@ func (h *GatewayHandler) waitForSlotWithPing(c *gin.Context, slotType string, id
|
|||||||
c.Header("X-Accel-Buffering", "no")
|
c.Header("X-Accel-Buffering", "no")
|
||||||
*streamStarted = true
|
*streamStarted = true
|
||||||
}
|
}
|
||||||
fmt.Fprintf(c.Writer, "data: {\"type\": \"ping\"}\n\n")
|
if _, err := fmt.Fprintf(c.Writer, "data: {\"type\": \"ping\"}\n\n"); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -414,7 +416,9 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
|
|||||||
if ok {
|
if ok {
|
||||||
// Send error event in SSE format
|
// Send error event in SSE format
|
||||||
errorEvent := fmt.Sprintf(`data: {"type": "error", "error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
|
errorEvent := fmt.Sprintf(`data: {"type": "error", "error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
|
||||||
fmt.Fprint(c.Writer, errorEvent)
|
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||||||
|
_ = c.Error(err)
|
||||||
|
}
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@@ -574,11 +578,11 @@ func sendMockWarmupStream(c *gin.Context, model string) {
|
|||||||
// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截)
|
// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截)
|
||||||
func sendMockWarmupResponse(c *gin.Context, model string) {
|
func sendMockWarmupResponse(c *gin.Context, model string) {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"id": "msg_mock_warmup",
|
"id": "msg_mock_warmup",
|
||||||
"type": "message",
|
"type": "message",
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"model": model,
|
"model": model,
|
||||||
"content": []gin.H{{"type": "text", "text": "New Conversation"}},
|
"content": []gin.H{{"type": "text", "text": "New Conversation"}},
|
||||||
"stop_reason": "end_turn",
|
"stop_reason": "end_turn",
|
||||||
"usage": gin.H{
|
"usage": gin.H{
|
||||||
"input_tokens": 10,
|
"input_tokens": 10,
|
||||||
|
|||||||
@@ -358,7 +358,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(req.ApiKeyIDs) == 0 {
|
if len(req.ApiKeyIDs) == 0 {
|
||||||
response.Success(c, gin.H{"stats": map[string]interface{}{}})
|
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -383,7 +383,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(validApiKeyIDs) == 0 {
|
if len(validApiKeyIDs) == 0 {
|
||||||
response.Success(c, gin.H{"stats": map[string]interface{}{}})
|
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// JSONB 用于存储JSONB数据
|
// JSONB 用于存储JSONB数据
|
||||||
type JSONB map[string]interface{}
|
type JSONB map[string]any
|
||||||
|
|
||||||
func (j JSONB) Value() (driver.Value, error) {
|
func (j JSONB) Value() (driver.Value, error) {
|
||||||
if j == nil {
|
if j == nil {
|
||||||
@@ -19,7 +19,7 @@ func (j JSONB) Value() (driver.Value, error) {
|
|||||||
return json.Marshal(j)
|
return json.Marshal(j)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (j *JSONB) Scan(value interface{}) error {
|
func (j *JSONB) Scan(value any) error {
|
||||||
if value == nil {
|
if value == nil {
|
||||||
*j = nil
|
*j = nil
|
||||||
return nil
|
return nil
|
||||||
@@ -40,8 +40,8 @@ type Account struct {
|
|||||||
Extra JSONB `gorm:"type:jsonb;default:'{}'" json:"extra"` // 扩展信息
|
Extra JSONB `gorm:"type:jsonb;default:'{}'" json:"extra"` // 扩展信息
|
||||||
ProxyID *int64 `gorm:"index" json:"proxy_id"`
|
ProxyID *int64 `gorm:"index" json:"proxy_id"`
|
||||||
Concurrency int `gorm:"default:3;not null" json:"concurrency"`
|
Concurrency int `gorm:"default:3;not null" json:"concurrency"`
|
||||||
Priority int `gorm:"default:50;not null" json:"priority"` // 1-100,越小越高
|
Priority int `gorm:"default:50;not null" json:"priority"` // 1-100,越小越高
|
||||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error
|
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error
|
||||||
ErrorMessage string `gorm:"type:text" json:"error_message"`
|
ErrorMessage string `gorm:"type:text" json:"error_message"`
|
||||||
LastUsedAt *time.Time `gorm:"index" json:"last_used_at"`
|
LastUsedAt *time.Time `gorm:"index" json:"last_used_at"`
|
||||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||||
@@ -145,7 +145,7 @@ func (a *Account) GetModelMapping() map[string]string {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// 处理map[string]interface{}类型
|
// 处理map[string]interface{}类型
|
||||||
if m, ok := raw.(map[string]interface{}); ok {
|
if m, ok := raw.(map[string]any); ok {
|
||||||
result := make(map[string]string)
|
result := make(map[string]string)
|
||||||
for k, v := range m {
|
for k, v := range m {
|
||||||
if s, ok := v.(string); ok {
|
if s, ok := v.(string); ok {
|
||||||
@@ -163,7 +163,7 @@ func (a *Account) GetModelMapping() map[string]string {
|
|||||||
// 如果没有设置模型映射,则支持所有模型
|
// 如果没有设置模型映射,则支持所有模型
|
||||||
func (a *Account) IsModelSupported(requestedModel string) bool {
|
func (a *Account) IsModelSupported(requestedModel string) bool {
|
||||||
mapping := a.GetModelMapping()
|
mapping := a.GetModelMapping()
|
||||||
if mapping == nil || len(mapping) == 0 {
|
if len(mapping) == 0 {
|
||||||
return true // 没有映射配置,支持所有模型
|
return true // 没有映射配置,支持所有模型
|
||||||
}
|
}
|
||||||
_, exists := mapping[requestedModel]
|
_, exists := mapping[requestedModel]
|
||||||
@@ -174,7 +174,7 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
|
|||||||
// 如果没有映射,返回原始模型名
|
// 如果没有映射,返回原始模型名
|
||||||
func (a *Account) GetMappedModel(requestedModel string) string {
|
func (a *Account) GetMappedModel(requestedModel string) string {
|
||||||
mapping := a.GetModelMapping()
|
mapping := a.GetModelMapping()
|
||||||
if mapping == nil || len(mapping) == 0 {
|
if len(mapping) == 0 {
|
||||||
return requestedModel
|
return requestedModel
|
||||||
}
|
}
|
||||||
if mappedModel, exists := mapping[requestedModel]; exists {
|
if mappedModel, exists := mapping[requestedModel]; exists {
|
||||||
@@ -231,7 +231,7 @@ func (a *Account) GetCustomErrorCodes() []int {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// 处理 []interface{} 类型(JSON反序列化后的格式)
|
// 处理 []interface{} 类型(JSON反序列化后的格式)
|
||||||
if arr, ok := raw.([]interface{}); ok {
|
if arr, ok := raw.([]any); ok {
|
||||||
result := make([]int, 0, len(arr))
|
result := make([]int, 0, len(arr))
|
||||||
for _, v := range arr {
|
for _, v := range arr {
|
||||||
// JSON 数字默认解析为 float64
|
// JSON 数字默认解析为 float64
|
||||||
|
|||||||
@@ -13,13 +13,13 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Group struct {
|
type Group struct {
|
||||||
ID int64 `gorm:"primaryKey" json:"id"`
|
ID int64 `gorm:"primaryKey" json:"id"`
|
||||||
Name string `gorm:"uniqueIndex;size:100;not null" json:"name"`
|
Name string `gorm:"uniqueIndex;size:100;not null" json:"name"`
|
||||||
Description string `gorm:"type:text" json:"description"`
|
Description string `gorm:"type:text" json:"description"`
|
||||||
Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini
|
Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini
|
||||||
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"`
|
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"`
|
||||||
IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"`
|
IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"`
|
||||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
|
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
|
||||||
|
|
||||||
// 订阅功能字段
|
// 订阅功能字段
|
||||||
SubscriptionType string `gorm:"size:20;default:standard;not null" json:"subscription_type"` // standard/subscription
|
SubscriptionType string `gorm:"size:20;default:standard;not null" json:"subscription_type"` // standard/subscription
|
||||||
|
|||||||
@@ -9,15 +9,15 @@ import (
|
|||||||
type RedeemCode struct {
|
type RedeemCode struct {
|
||||||
ID int64 `gorm:"primaryKey" json:"id"`
|
ID int64 `gorm:"primaryKey" json:"id"`
|
||||||
Code string `gorm:"uniqueIndex;size:32;not null" json:"code"`
|
Code string `gorm:"uniqueIndex;size:32;not null" json:"code"`
|
||||||
Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription
|
Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription
|
||||||
Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数
|
Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数
|
||||||
Status string `gorm:"size:20;default:unused;not null" json:"status"` // unused/used
|
Status string `gorm:"size:20;default:unused;not null" json:"status"` // unused/used
|
||||||
UsedBy *int64 `gorm:"index" json:"used_by"`
|
UsedBy *int64 `gorm:"index" json:"used_by"`
|
||||||
UsedAt *time.Time `json:"used_at"`
|
UsedAt *time.Time `json:"used_at"`
|
||||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||||
|
|
||||||
// 订阅类型专用字段
|
// 订阅类型专用字段
|
||||||
GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用)
|
GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用)
|
||||||
ValidityDays int `gorm:"default:30" json:"validity_days"` // 订阅有效天数 (仅subscription类型使用)
|
ValidityDays int `gorm:"default:30" json:"validity_days"` // 订阅有效天数 (仅subscription类型使用)
|
||||||
|
|
||||||
// 关联
|
// 关联
|
||||||
@@ -40,8 +40,10 @@ func (r *RedeemCode) CanUse() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GenerateRedeemCode 生成唯一的兑换码
|
// GenerateRedeemCode 生成唯一的兑换码
|
||||||
func GenerateRedeemCode() string {
|
func GenerateRedeemCode() (string, error) {
|
||||||
b := make([]byte, 16)
|
b := make([]byte, 16)
|
||||||
rand.Read(b)
|
if _, err := rand.Read(b); err != nil {
|
||||||
return hex.EncodeToString(b)
|
return "", err
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(b), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,17 +19,17 @@ func (Setting) TableName() string {
|
|||||||
// 设置Key常量
|
// 设置Key常量
|
||||||
const (
|
const (
|
||||||
// 注册设置
|
// 注册设置
|
||||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||||
|
|
||||||
// 邮件服务设置
|
// 邮件服务设置
|
||||||
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
|
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
|
||||||
SettingKeySmtpPort = "smtp_port" // SMTP端口
|
SettingKeySmtpPort = "smtp_port" // SMTP端口
|
||||||
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
|
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
|
||||||
SettingKeySmtpPassword = "smtp_password" // SMTP密码(加密存储)
|
SettingKeySmtpPassword = "smtp_password" // SMTP密码(加密存储)
|
||||||
SettingKeySmtpFrom = "smtp_from" // 发件人地址
|
SettingKeySmtpFrom = "smtp_from" // 发件人地址
|
||||||
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
|
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
|
||||||
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
|
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
|
||||||
|
|
||||||
// Cloudflare Turnstile 设置
|
// Cloudflare Turnstile 设置
|
||||||
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
|
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ type UsageLog struct {
|
|||||||
OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"output_cost"`
|
OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"output_cost"`
|
||||||
CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_creation_cost"`
|
CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_creation_cost"`
|
||||||
CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_read_cost"`
|
CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_read_cost"`
|
||||||
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用
|
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用
|
||||||
ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"actual_cost"` // 实际扣除费用
|
ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"actual_cost"` // 实际扣除费用
|
||||||
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null" json:"rate_multiplier"` // 计费倍率
|
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null" json:"rate_multiplier"` // 计费倍率
|
||||||
|
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
ID int64 `gorm:"primaryKey" json:"id"`
|
ID int64 `gorm:"primaryKey" json:"id"`
|
||||||
Email string `gorm:"uniqueIndex;size:255;not null" json:"email"`
|
Email string `gorm:"uniqueIndex;size:255;not null" json:"email"`
|
||||||
PasswordHash string `gorm:"size:255;not null" json:"-"`
|
PasswordHash string `gorm:"size:255;not null" json:"-"`
|
||||||
Role string `gorm:"size:20;default:user;not null" json:"role"` // admin/user
|
Role string `gorm:"size:20;default:user;not null" json:"role"` // admin/user
|
||||||
Balance float64 `gorm:"type:decimal(20,8);default:0;not null" json:"balance"`
|
Balance float64 `gorm:"type:decimal(20,8);default:0;not null" json:"balance"`
|
||||||
|
|||||||
@@ -9,22 +9,22 @@ import (
|
|||||||
|
|
||||||
// Response 标准API响应格式
|
// Response 标准API响应格式
|
||||||
type Response struct {
|
type Response struct {
|
||||||
Code int `json:"code"`
|
Code int `json:"code"`
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
Data interface{} `json:"data,omitempty"`
|
Data any `json:"data,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// PaginatedData 分页数据格式(匹配前端期望)
|
// PaginatedData 分页数据格式(匹配前端期望)
|
||||||
type PaginatedData struct {
|
type PaginatedData struct {
|
||||||
Items interface{} `json:"items"`
|
Items any `json:"items"`
|
||||||
Total int64 `json:"total"`
|
Total int64 `json:"total"`
|
||||||
Page int `json:"page"`
|
Page int `json:"page"`
|
||||||
PageSize int `json:"page_size"`
|
PageSize int `json:"page_size"`
|
||||||
Pages int `json:"pages"`
|
Pages int `json:"pages"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Success 返回成功响应
|
// Success 返回成功响应
|
||||||
func Success(c *gin.Context, data interface{}) {
|
func Success(c *gin.Context, data any) {
|
||||||
c.JSON(http.StatusOK, Response{
|
c.JSON(http.StatusOK, Response{
|
||||||
Code: 0,
|
Code: 0,
|
||||||
Message: "success",
|
Message: "success",
|
||||||
@@ -33,7 +33,7 @@ func Success(c *gin.Context, data interface{}) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Created 返回创建成功响应
|
// Created 返回创建成功响应
|
||||||
func Created(c *gin.Context, data interface{}) {
|
func Created(c *gin.Context, data any) {
|
||||||
c.JSON(http.StatusCreated, Response{
|
c.JSON(http.StatusCreated, Response{
|
||||||
Code: 0,
|
Code: 0,
|
||||||
Message: "success",
|
Message: "success",
|
||||||
@@ -75,7 +75,7 @@ func InternalError(c *gin.Context, message string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Paginated 返回分页数据
|
// Paginated 返回分页数据
|
||||||
func Paginated(c *gin.Context, items interface{}, total int64, page, pageSize int) {
|
func Paginated(c *gin.Context, items any, total int64, page, pageSize int) {
|
||||||
pages := int(math.Ceil(float64(total) / float64(pageSize)))
|
pages := int(math.Ceil(float64(total) / float64(pageSize)))
|
||||||
if pages < 1 {
|
if pages < 1 {
|
||||||
pages = 1
|
pages = 1
|
||||||
@@ -99,7 +99,7 @@ type PaginationResult struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// PaginatedWithResult 使用PaginationResult返回分页数据
|
// PaginatedWithResult 使用PaginationResult返回分页数据
|
||||||
func PaginatedWithResult(c *gin.Context, items interface{}, pagination *PaginationResult) {
|
func PaginatedWithResult(c *gin.Context, items any, pagination *PaginationResult) {
|
||||||
if pagination == nil {
|
if pagination == nil {
|
||||||
Success(c, PaginatedData{
|
Success(c, PaginatedData{
|
||||||
Items: items,
|
Items: items,
|
||||||
|
|||||||
@@ -37,11 +37,15 @@ func TestInitInvalidTimezone(t *testing.T) {
|
|||||||
|
|
||||||
func TestTimeNowAffected(t *testing.T) {
|
func TestTimeNowAffected(t *testing.T) {
|
||||||
// Reset to UTC first
|
// Reset to UTC first
|
||||||
Init("UTC")
|
if err := Init("UTC"); err != nil {
|
||||||
|
t.Fatalf("Init failed with UTC: %v", err)
|
||||||
|
}
|
||||||
utcNow := time.Now()
|
utcNow := time.Now()
|
||||||
|
|
||||||
// Switch to Shanghai (UTC+8)
|
// Switch to Shanghai (UTC+8)
|
||||||
Init("Asia/Shanghai")
|
if err := Init("Asia/Shanghai"); err != nil {
|
||||||
|
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||||
|
}
|
||||||
shanghaiNow := time.Now()
|
shanghaiNow := time.Now()
|
||||||
|
|
||||||
// The times should be the same instant, but different timezone representation
|
// The times should be the same instant, but different timezone representation
|
||||||
@@ -58,7 +62,9 @@ func TestTimeNowAffected(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestToday(t *testing.T) {
|
func TestToday(t *testing.T) {
|
||||||
Init("Asia/Shanghai")
|
if err := Init("Asia/Shanghai"); err != nil {
|
||||||
|
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
today := Today()
|
today := Today()
|
||||||
now := Now()
|
now := Now()
|
||||||
@@ -75,7 +81,9 @@ func TestToday(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStartOfDay(t *testing.T) {
|
func TestStartOfDay(t *testing.T) {
|
||||||
Init("Asia/Shanghai")
|
if err := Init("Asia/Shanghai"); err != nil {
|
||||||
|
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Create a time at 15:30:45
|
// Create a time at 15:30:45
|
||||||
testTime := time.Date(2024, 6, 15, 15, 30, 45, 123456789, Location())
|
testTime := time.Date(2024, 6, 15, 15, 30, 45, 123456789, Location())
|
||||||
@@ -91,7 +99,9 @@ func TestTruncateVsStartOfDay(t *testing.T) {
|
|||||||
// This test demonstrates why Truncate(24*time.Hour) can be problematic
|
// This test demonstrates why Truncate(24*time.Hour) can be problematic
|
||||||
// and why StartOfDay is more reliable for timezone-aware code
|
// and why StartOfDay is more reliable for timezone-aware code
|
||||||
|
|
||||||
Init("Asia/Shanghai")
|
if err := Init("Asia/Shanghai"); err != nil {
|
||||||
|
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
now := Now()
|
now := Now()
|
||||||
|
|
||||||
|
|||||||
@@ -131,7 +131,7 @@ func (r *AccountRepository) UpdateLastUsed(ctx context.Context, id int64) error
|
|||||||
|
|
||||||
func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
|
func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||||
Updates(map[string]interface{}{
|
Updates(map[string]any{
|
||||||
"status": model.StatusError,
|
"status": model.StatusError,
|
||||||
"error_message": errorMsg,
|
"error_message": errorMsg,
|
||||||
}).Error
|
}).Error
|
||||||
@@ -226,7 +226,7 @@ func (r *AccountRepository) ListSchedulableByGroupID(ctx context.Context, groupI
|
|||||||
func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||||
Updates(map[string]interface{}{
|
Updates(map[string]any{
|
||||||
"rate_limited_at": now,
|
"rate_limited_at": now,
|
||||||
"rate_limit_reset_at": resetAt,
|
"rate_limit_reset_at": resetAt,
|
||||||
}).Error
|
}).Error
|
||||||
@@ -241,7 +241,7 @@ func (r *AccountRepository) SetOverloaded(ctx context.Context, id int64, until t
|
|||||||
// ClearRateLimit 清除账号的限流状态
|
// ClearRateLimit 清除账号的限流状态
|
||||||
func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error {
|
func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error {
|
||||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||||
Updates(map[string]interface{}{
|
Updates(map[string]any{
|
||||||
"rate_limited_at": nil,
|
"rate_limited_at": nil,
|
||||||
"rate_limit_reset_at": nil,
|
"rate_limit_reset_at": nil,
|
||||||
"overload_until": nil,
|
"overload_until": nil,
|
||||||
@@ -250,7 +250,7 @@ func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error
|
|||||||
|
|
||||||
// UpdateSessionWindow 更新账号的5小时时间窗口信息
|
// UpdateSessionWindow 更新账号的5小时时间窗口信息
|
||||||
func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||||
updates := map[string]interface{}{
|
updates := map[string]any{
|
||||||
"session_window_status": status,
|
"session_window_status": status,
|
||||||
}
|
}
|
||||||
if start != nil {
|
if start != nil {
|
||||||
|
|||||||
@@ -143,7 +143,7 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID
|
|||||||
|
|
||||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||||
|
|
||||||
fields := map[string]interface{}{
|
fields := map[string]any{
|
||||||
subFieldStatus: data.Status,
|
subFieldStatus: data.Status,
|
||||||
subFieldExpiresAt: data.ExpiresAt.Unix(),
|
subFieldExpiresAt: data.ExpiresAt.Unix(),
|
||||||
subFieldDailyUsage: data.DailyUsage,
|
subFieldDailyUsage: data.DailyUsage,
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
|
|||||||
|
|
||||||
authURL := fmt.Sprintf("https://claude.ai/v1/oauth/%s/authorize", orgUUID)
|
authURL := fmt.Sprintf("https://claude.ai/v1/oauth/%s/authorize", orgUUID)
|
||||||
|
|
||||||
reqBody := map[string]interface{}{
|
reqBody := map[string]any{
|
||||||
"response_type": "code",
|
"response_type": "code",
|
||||||
"client_id": oauth.ClientID,
|
"client_id": oauth.ClientID,
|
||||||
"organization_uuid": orgUUID,
|
"organization_uuid": orgUUID,
|
||||||
@@ -155,7 +155,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
reqBody := map[string]interface{}{
|
reqBody := map[string]any{
|
||||||
"code": authCode,
|
"code": authCode,
|
||||||
"grant_type": "authorization_code",
|
"grant_type": "authorization_code",
|
||||||
"client_id": oauth.ClientID,
|
"client_id": oauth.ClientID,
|
||||||
|
|||||||
@@ -19,7 +19,11 @@ func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
|
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
|
||||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
transport, ok := http.DefaultTransport.(*http.Transport)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("failed to get default transport")
|
||||||
|
}
|
||||||
|
transport = transport.Clone()
|
||||||
if proxyURL != "" {
|
if proxyURL != "" {
|
||||||
if parsedURL, err := url.Parse(proxyURL); err == nil {
|
if parsedURL, err := url.Parse(proxyURL); err == nil {
|
||||||
transport.Proxy = http.ProxyURL(parsedURL)
|
transport.Proxy = http.ProxyURL(parsedURL)
|
||||||
@@ -43,7 +47,7 @@ func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyU
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("request failed: %w", err)
|
return nil, fmt.Errorf("request failed: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ func (c *githubReleaseClient) FetchLatestRelease(ctx context.Context, repo strin
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode)
|
return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode)
|
||||||
@@ -63,7 +63,7 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return fmt.Errorf("download returned %d", resp.StatusCode)
|
return fmt.Errorf("download returned %d", resp.StatusCode)
|
||||||
@@ -78,7 +78,7 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer out.Close()
|
defer func() { _ = out.Close() }()
|
||||||
|
|
||||||
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
|
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
|
||||||
limited := io.LimitReader(resp.Body, maxSize+1)
|
limited := io.LimitReader(resp.Body, maxSize+1)
|
||||||
@@ -89,7 +89,7 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
|
|||||||
|
|
||||||
// Check if we hit the limit (downloaded more than maxSize)
|
// Check if we hit the limit (downloaded more than maxSize)
|
||||||
if written > maxSize {
|
if written > maxSize {
|
||||||
os.Remove(dest) // Clean up partial file
|
_ = os.Remove(dest) // Clean up partial file (best-effort)
|
||||||
return fmt.Errorf("download exceeded maximum size of %d bytes", maxSize)
|
return fmt.Errorf("download exceeded maximum size of %d bytes", maxSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -106,7 +106,7 @@ func (c *githubReleaseClient) FetchChecksumFile(ctx context.Context, url string)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ func (c *pricingRemoteClient) FetchPricingJSON(ctx context.Context, url string)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||||
@@ -52,7 +52,7 @@ func (c *pricingRemoteClient) FetchHashText(ctx context.Context, url string) (st
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
|
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("proxy connection failed: %w", err)
|
return nil, 0, fmt.Errorf("proxy connection failed: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
latencyMs := time.Since(startTime).Milliseconds()
|
latencyMs := time.Since(startTime).Milliseconds()
|
||||||
|
|
||||||
|
|||||||
@@ -99,7 +99,7 @@ func (r *RedeemCodeRepository) Use(ctx context.Context, id, userID int64) error
|
|||||||
now := time.Now()
|
now := time.Now()
|
||||||
result := r.db.WithContext(ctx).Model(&model.RedeemCode{}).
|
result := r.db.WithContext(ctx).Model(&model.RedeemCode{}).
|
||||||
Where("id = ? AND status = ?", id, model.StatusUnused).
|
Where("id = ? AND status = ?", id, model.StatusUnused).
|
||||||
Updates(map[string]interface{}{
|
Updates(map[string]any{
|
||||||
"status": model.StatusUsed,
|
"status": model.StatusUsed,
|
||||||
"used_by": userID,
|
"used_by": userID,
|
||||||
"used_at": now,
|
"used_at": now,
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ func (v *turnstileVerifier) VerifyToken(ctx context.Context, secretKey, token, r
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("send request: %w", err)
|
return nil, fmt.Errorf("send request: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
var result service.TurnstileVerifyResponse
|
var result service.TurnstileVerifyResponse
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||||
|
|||||||
@@ -185,7 +185,7 @@ func (r *UserSubscriptionRepository) List(ctx context.Context, params pagination
|
|||||||
func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
|
func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
|
||||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||||
Where("id = ?", id).
|
Where("id = ?", id).
|
||||||
Updates(map[string]interface{}{
|
Updates(map[string]any{
|
||||||
"daily_usage_usd": gorm.Expr("daily_usage_usd + ?", costUSD),
|
"daily_usage_usd": gorm.Expr("daily_usage_usd + ?", costUSD),
|
||||||
"weekly_usage_usd": gorm.Expr("weekly_usage_usd + ?", costUSD),
|
"weekly_usage_usd": gorm.Expr("weekly_usage_usd + ?", costUSD),
|
||||||
"monthly_usage_usd": gorm.Expr("monthly_usage_usd + ?", costUSD),
|
"monthly_usage_usd": gorm.Expr("monthly_usage_usd + ?", costUSD),
|
||||||
@@ -197,7 +197,7 @@ func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
|
|||||||
func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||||
Where("id = ?", id).
|
Where("id = ?", id).
|
||||||
Updates(map[string]interface{}{
|
Updates(map[string]any{
|
||||||
"daily_usage_usd": 0,
|
"daily_usage_usd": 0,
|
||||||
"daily_window_start": newWindowStart,
|
"daily_window_start": newWindowStart,
|
||||||
"updated_at": time.Now(),
|
"updated_at": time.Now(),
|
||||||
@@ -208,7 +208,7 @@ func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int
|
|||||||
func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||||
Where("id = ?", id).
|
Where("id = ?", id).
|
||||||
Updates(map[string]interface{}{
|
Updates(map[string]any{
|
||||||
"weekly_usage_usd": 0,
|
"weekly_usage_usd": 0,
|
||||||
"weekly_window_start": newWindowStart,
|
"weekly_window_start": newWindowStart,
|
||||||
"updated_at": time.Now(),
|
"updated_at": time.Now(),
|
||||||
@@ -219,7 +219,7 @@ func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id in
|
|||||||
func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||||
Where("id = ?", id).
|
Where("id = ?", id).
|
||||||
Updates(map[string]interface{}{
|
Updates(map[string]any{
|
||||||
"monthly_usage_usd": 0,
|
"monthly_usage_usd": 0,
|
||||||
"monthly_window_start": newWindowStart,
|
"monthly_window_start": newWindowStart,
|
||||||
"updated_at": time.Now(),
|
"updated_at": time.Now(),
|
||||||
@@ -230,7 +230,7 @@ func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id i
|
|||||||
func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error {
|
func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error {
|
||||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||||
Where("id = ?", id).
|
Where("id = ?", id).
|
||||||
Updates(map[string]interface{}{
|
Updates(map[string]any{
|
||||||
"daily_window_start": activateTime,
|
"daily_window_start": activateTime,
|
||||||
"weekly_window_start": activateTime,
|
"weekly_window_start": activateTime,
|
||||||
"monthly_window_start": activateTime,
|
"monthly_window_start": activateTime,
|
||||||
@@ -242,7 +242,7 @@ func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int
|
|||||||
func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error {
|
func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error {
|
||||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||||
Where("id = ?", id).
|
Where("id = ?", id).
|
||||||
Updates(map[string]interface{}{
|
Updates(map[string]any{
|
||||||
"status": status,
|
"status": status,
|
||||||
"updated_at": time.Now(),
|
"updated_at": time.Now(),
|
||||||
}).Error
|
}).Error
|
||||||
@@ -252,7 +252,7 @@ func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64,
|
|||||||
func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error {
|
func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error {
|
||||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||||
Where("id = ?", id).
|
Where("id = ?", id).
|
||||||
Updates(map[string]interface{}{
|
Updates(map[string]any{
|
||||||
"expires_at": newExpiresAt,
|
"expires_at": newExpiresAt,
|
||||||
"updated_at": time.Now(),
|
"updated_at": time.Now(),
|
||||||
}).Error
|
}).Error
|
||||||
@@ -262,7 +262,7 @@ func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64,
|
|||||||
func (r *UserSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error {
|
func (r *UserSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error {
|
||||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||||
Where("id = ?", id).
|
Where("id = ?", id).
|
||||||
Updates(map[string]interface{}{
|
Updates(map[string]any{
|
||||||
"notes": notes,
|
"notes": notes,
|
||||||
"updated_at": time.Now(),
|
"updated_at": time.Now(),
|
||||||
}).Error
|
}).Error
|
||||||
@@ -281,7 +281,7 @@ func (r *UserSubscriptionRepository) ListExpired(ctx context.Context) ([]model.U
|
|||||||
func (r *UserSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
|
func (r *UserSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
|
||||||
result := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
result := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||||
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
|
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
|
||||||
Updates(map[string]interface{}{
|
Updates(map[string]any{
|
||||||
"status": model.SubscriptionStatusExpired,
|
"status": model.SubscriptionStatusExpired,
|
||||||
"updated_at": time.Now(),
|
"updated_at": time.Now(),
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -17,27 +17,27 @@ var (
|
|||||||
|
|
||||||
// CreateAccountRequest 创建账号请求
|
// CreateAccountRequest 创建账号请求
|
||||||
type CreateAccountRequest struct {
|
type CreateAccountRequest struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Platform string `json:"platform"`
|
Platform string `json:"platform"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Credentials map[string]interface{} `json:"credentials"`
|
Credentials map[string]any `json:"credentials"`
|
||||||
Extra map[string]interface{} `json:"extra"`
|
Extra map[string]any `json:"extra"`
|
||||||
ProxyID *int64 `json:"proxy_id"`
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
Concurrency int `json:"concurrency"`
|
Concurrency int `json:"concurrency"`
|
||||||
Priority int `json:"priority"`
|
Priority int `json:"priority"`
|
||||||
GroupIDs []int64 `json:"group_ids"`
|
GroupIDs []int64 `json:"group_ids"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateAccountRequest 更新账号请求
|
// UpdateAccountRequest 更新账号请求
|
||||||
type UpdateAccountRequest struct {
|
type UpdateAccountRequest struct {
|
||||||
Name *string `json:"name"`
|
Name *string `json:"name"`
|
||||||
Credentials *map[string]interface{} `json:"credentials"`
|
Credentials *map[string]any `json:"credentials"`
|
||||||
Extra *map[string]interface{} `json:"extra"`
|
Extra *map[string]any `json:"extra"`
|
||||||
ProxyID *int64 `json:"proxy_id"`
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
Concurrency *int `json:"concurrency"`
|
Concurrency *int `json:"concurrency"`
|
||||||
Priority *int `json:"priority"`
|
Priority *int `json:"priority"`
|
||||||
Status *string `json:"status"`
|
Status *string `json:"status"`
|
||||||
GroupIDs *[]int64 `json:"group_ids"`
|
GroupIDs *[]int64 `json:"group_ids"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// AccountService 账号管理服务
|
// AccountService 账号管理服务
|
||||||
|
|||||||
@@ -51,22 +51,29 @@ func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OA
|
|||||||
}
|
}
|
||||||
|
|
||||||
// generateSessionString generates a Claude Code style session string
|
// generateSessionString generates a Claude Code style session string
|
||||||
func generateSessionString() string {
|
func generateSessionString() (string, error) {
|
||||||
bytes := make([]byte, 32)
|
bytes := make([]byte, 32)
|
||||||
rand.Read(bytes)
|
if _, err := rand.Read(bytes); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
hex64 := hex.EncodeToString(bytes)
|
hex64 := hex.EncodeToString(bytes)
|
||||||
sessionUUID := uuid.New().String()
|
sessionUUID := uuid.New().String()
|
||||||
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID)
|
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// createTestPayload creates a Claude Code style test request payload
|
// createTestPayload creates a Claude Code style test request payload
|
||||||
func createTestPayload(modelID string) map[string]interface{} {
|
func createTestPayload(modelID string) (map[string]any, error) {
|
||||||
return map[string]interface{}{
|
sessionID, err := generateSessionString()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return map[string]any{
|
||||||
"model": modelID,
|
"model": modelID,
|
||||||
"messages": []map[string]interface{}{
|
"messages": []map[string]any{
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": []map[string]interface{}{
|
"content": []map[string]any{
|
||||||
{
|
{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": "hi",
|
"text": "hi",
|
||||||
@@ -77,7 +84,7 @@ func createTestPayload(modelID string) map[string]interface{} {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"system": []map[string]interface{}{
|
"system": []map[string]any{
|
||||||
{
|
{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
|
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||||
@@ -87,12 +94,12 @@ func createTestPayload(modelID string) map[string]interface{} {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
"metadata": map[string]string{
|
"metadata": map[string]string{
|
||||||
"user_id": generateSessionString(),
|
"user_id": sessionID,
|
||||||
},
|
},
|
||||||
"max_tokens": 1024,
|
"max_tokens": 1024,
|
||||||
"temperature": 1,
|
"temperature": 1,
|
||||||
"stream": true,
|
"stream": true,
|
||||||
}
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestAccountConnection tests an account's connection by sending a test request
|
// TestAccountConnection tests an account's connection by sending a test request
|
||||||
@@ -116,7 +123,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
|||||||
// For API Key accounts with model mapping, map the model
|
// For API Key accounts with model mapping, map the model
|
||||||
if account.Type == "apikey" {
|
if account.Type == "apikey" {
|
||||||
mapping := account.GetModelMapping()
|
mapping := account.GetModelMapping()
|
||||||
if mapping != nil && len(mapping) > 0 {
|
if len(mapping) > 0 {
|
||||||
if mappedModel, exists := mapping[testModelID]; exists {
|
if mappedModel, exists := mapping[testModelID]; exists {
|
||||||
testModelID = mappedModel
|
testModelID = mappedModel
|
||||||
}
|
}
|
||||||
@@ -178,7 +185,10 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
|||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
|
|
||||||
// Create Claude Code style payload (same for all account types)
|
// Create Claude Code style payload (same for all account types)
|
||||||
payload := createTestPayload(testModelID)
|
payload, err := createTestPayload(testModelID)
|
||||||
|
if err != nil {
|
||||||
|
return s.sendErrorAndEnd(c, "Failed to create test payload")
|
||||||
|
}
|
||||||
payloadBytes, _ := json.Marshal(payload)
|
payloadBytes, _ := json.Marshal(payload)
|
||||||
|
|
||||||
// Send test_start event
|
// Send test_start event
|
||||||
@@ -216,7 +226,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
@@ -252,7 +262,7 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var data map[string]interface{}
|
var data map[string]any
|
||||||
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
|
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -261,7 +271,7 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
|
|||||||
|
|
||||||
switch eventType {
|
switch eventType {
|
||||||
case "content_block_delta":
|
case "content_block_delta":
|
||||||
if delta, ok := data["delta"].(map[string]interface{}); ok {
|
if delta, ok := data["delta"].(map[string]any); ok {
|
||||||
if text, ok := delta["text"].(string); ok {
|
if text, ok := delta["text"].(string); ok {
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: text})
|
s.sendEvent(c, TestEvent{Type: "content", Text: text})
|
||||||
}
|
}
|
||||||
@@ -271,7 +281,7 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
|
|||||||
return nil
|
return nil
|
||||||
case "error":
|
case "error":
|
||||||
errorMsg := "Unknown error"
|
errorMsg := "Unknown error"
|
||||||
if errData, ok := data["error"].(map[string]interface{}); ok {
|
if errData, ok := data["error"].(map[string]any); ok {
|
||||||
if msg, ok := errData["message"].(string); ok {
|
if msg, ok := errData["message"].(string); ok {
|
||||||
errorMsg = msg
|
errorMsg = msg
|
||||||
}
|
}
|
||||||
@@ -284,7 +294,10 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
|
|||||||
// sendEvent sends a SSE event to the client
|
// sendEvent sends a SSE event to the client
|
||||||
func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
|
func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
|
||||||
eventJSON, _ := json.Marshal(event)
|
eventJSON, _ := json.Marshal(event)
|
||||||
fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON)
|
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil {
|
||||||
|
log.Printf("failed to write SSE event: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -70,16 +70,14 @@ type ClaudeUsageFetcher interface {
|
|||||||
type AccountUsageService struct {
|
type AccountUsageService struct {
|
||||||
accountRepo ports.AccountRepository
|
accountRepo ports.AccountRepository
|
||||||
usageLogRepo ports.UsageLogRepository
|
usageLogRepo ports.UsageLogRepository
|
||||||
oauthService *OAuthService
|
|
||||||
usageFetcher ClaudeUsageFetcher
|
usageFetcher ClaudeUsageFetcher
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAccountUsageService 创建AccountUsageService实例
|
// NewAccountUsageService 创建AccountUsageService实例
|
||||||
func NewAccountUsageService(accountRepo ports.AccountRepository, usageLogRepo ports.UsageLogRepository, oauthService *OAuthService, usageFetcher ClaudeUsageFetcher) *AccountUsageService {
|
func NewAccountUsageService(accountRepo ports.AccountRepository, usageLogRepo ports.UsageLogRepository, usageFetcher ClaudeUsageFetcher) *AccountUsageService {
|
||||||
return &AccountUsageService{
|
return &AccountUsageService{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
usageLogRepo: usageLogRepo,
|
usageLogRepo: usageLogRepo,
|
||||||
oauthService: oauthService,
|
|
||||||
usageFetcher: usageFetcher,
|
usageFetcher: usageFetcher,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -98,8 +96,10 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
|||||||
if account.CanGetUsage() {
|
if account.CanGetUsage() {
|
||||||
// 检查缓存
|
// 检查缓存
|
||||||
if cached, ok := usageCacheMap.Load(accountID); ok {
|
if cached, ok := usageCacheMap.Load(accountID); ok {
|
||||||
cache := cached.(*usageCache)
|
cache, ok := cached.(*usageCache)
|
||||||
if time.Since(cache.timestamp) < cacheTTL {
|
if !ok {
|
||||||
|
usageCacheMap.Delete(accountID)
|
||||||
|
} else if time.Since(cache.timestamp) < cacheTTL {
|
||||||
return cache.data, nil
|
return cache.data, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"sub2api/internal/model"
|
"sub2api/internal/model"
|
||||||
@@ -23,7 +24,7 @@ type AdminService interface {
|
|||||||
DeleteUser(ctx context.Context, id int64) error
|
DeleteUser(ctx context.Context, id int64) error
|
||||||
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string) (*model.User, error)
|
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string) (*model.User, error)
|
||||||
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error)
|
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error)
|
||||||
GetUserUsageStats(ctx context.Context, userID int64, period string) (interface{}, error)
|
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
|
||||||
|
|
||||||
// Group management
|
// Group management
|
||||||
ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error)
|
ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error)
|
||||||
@@ -113,8 +114,8 @@ type CreateAccountInput struct {
|
|||||||
Name string
|
Name string
|
||||||
Platform string
|
Platform string
|
||||||
Type string
|
Type string
|
||||||
Credentials map[string]interface{}
|
Credentials map[string]any
|
||||||
Extra map[string]interface{}
|
Extra map[string]any
|
||||||
ProxyID *int64
|
ProxyID *int64
|
||||||
Concurrency int
|
Concurrency int
|
||||||
Priority int
|
Priority int
|
||||||
@@ -124,8 +125,8 @@ type CreateAccountInput struct {
|
|||||||
type UpdateAccountInput struct {
|
type UpdateAccountInput struct {
|
||||||
Name string
|
Name string
|
||||||
Type string // Account type: oauth, setup-token, apikey
|
Type string // Account type: oauth, setup-token, apikey
|
||||||
Credentials map[string]interface{}
|
Credentials map[string]any
|
||||||
Extra map[string]interface{}
|
Extra map[string]any
|
||||||
ProxyID *int64
|
ProxyID *int64
|
||||||
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
||||||
Priority *int // 使用指针区分"未提供"和"设置为0"
|
Priority *int // 使用指针区分"未提供"和"设置为0"
|
||||||
@@ -192,8 +193,6 @@ type adminServiceImpl struct {
|
|||||||
proxyRepo ports.ProxyRepository
|
proxyRepo ports.ProxyRepository
|
||||||
apiKeyRepo ports.ApiKeyRepository
|
apiKeyRepo ports.ApiKeyRepository
|
||||||
redeemCodeRepo ports.RedeemCodeRepository
|
redeemCodeRepo ports.RedeemCodeRepository
|
||||||
usageLogRepo ports.UsageLogRepository
|
|
||||||
userSubRepo ports.UserSubscriptionRepository
|
|
||||||
billingCacheService *BillingCacheService
|
billingCacheService *BillingCacheService
|
||||||
proxyProber ProxyExitInfoProber
|
proxyProber ProxyExitInfoProber
|
||||||
}
|
}
|
||||||
@@ -206,8 +205,6 @@ func NewAdminService(
|
|||||||
proxyRepo ports.ProxyRepository,
|
proxyRepo ports.ProxyRepository,
|
||||||
apiKeyRepo ports.ApiKeyRepository,
|
apiKeyRepo ports.ApiKeyRepository,
|
||||||
redeemCodeRepo ports.RedeemCodeRepository,
|
redeemCodeRepo ports.RedeemCodeRepository,
|
||||||
usageLogRepo ports.UsageLogRepository,
|
|
||||||
userSubRepo ports.UserSubscriptionRepository,
|
|
||||||
billingCacheService *BillingCacheService,
|
billingCacheService *BillingCacheService,
|
||||||
proxyProber ProxyExitInfoProber,
|
proxyProber ProxyExitInfoProber,
|
||||||
) AdminService {
|
) AdminService {
|
||||||
@@ -218,8 +215,6 @@ func NewAdminService(
|
|||||||
proxyRepo: proxyRepo,
|
proxyRepo: proxyRepo,
|
||||||
apiKeyRepo: apiKeyRepo,
|
apiKeyRepo: apiKeyRepo,
|
||||||
redeemCodeRepo: redeemCodeRepo,
|
redeemCodeRepo: redeemCodeRepo,
|
||||||
usageLogRepo: usageLogRepo,
|
|
||||||
userSubRepo: userSubRepo,
|
|
||||||
billingCacheService: billingCacheService,
|
billingCacheService: billingCacheService,
|
||||||
proxyProber: proxyProber,
|
proxyProber: proxyProber,
|
||||||
}
|
}
|
||||||
@@ -309,7 +304,9 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
|||||||
go func() {
|
go func() {
|
||||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
s.billingCacheService.InvalidateUserBalance(cacheCtx, id)
|
if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, id); err != nil {
|
||||||
|
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", id, err)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -317,8 +314,13 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
|||||||
// Create adjustment records for balance/concurrency changes
|
// Create adjustment records for balance/concurrency changes
|
||||||
balanceDiff := user.Balance - oldBalance
|
balanceDiff := user.Balance - oldBalance
|
||||||
if balanceDiff != 0 {
|
if balanceDiff != 0 {
|
||||||
|
code, err := model.GenerateRedeemCode()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to generate adjustment redeem code: %v", err)
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
adjustmentRecord := &model.RedeemCode{
|
adjustmentRecord := &model.RedeemCode{
|
||||||
Code: model.GenerateRedeemCode(),
|
Code: code,
|
||||||
Type: model.AdjustmentTypeAdminBalance,
|
Type: model.AdjustmentTypeAdminBalance,
|
||||||
Value: balanceDiff,
|
Value: balanceDiff,
|
||||||
Status: model.StatusUsed,
|
Status: model.StatusUsed,
|
||||||
@@ -327,15 +329,19 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
|||||||
now := time.Now()
|
now := time.Now()
|
||||||
adjustmentRecord.UsedAt = &now
|
adjustmentRecord.UsedAt = &now
|
||||||
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
|
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
|
||||||
// Log error but don't fail the update
|
log.Printf("failed to create balance adjustment redeem code: %v", err)
|
||||||
// The user update has already succeeded
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
concurrencyDiff := user.Concurrency - oldConcurrency
|
concurrencyDiff := user.Concurrency - oldConcurrency
|
||||||
if concurrencyDiff != 0 {
|
if concurrencyDiff != 0 {
|
||||||
|
code, err := model.GenerateRedeemCode()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to generate adjustment redeem code: %v", err)
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
adjustmentRecord := &model.RedeemCode{
|
adjustmentRecord := &model.RedeemCode{
|
||||||
Code: model.GenerateRedeemCode(),
|
Code: code,
|
||||||
Type: model.AdjustmentTypeAdminConcurrency,
|
Type: model.AdjustmentTypeAdminConcurrency,
|
||||||
Value: float64(concurrencyDiff),
|
Value: float64(concurrencyDiff),
|
||||||
Status: model.StatusUsed,
|
Status: model.StatusUsed,
|
||||||
@@ -344,8 +350,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
|||||||
now := time.Now()
|
now := time.Now()
|
||||||
adjustmentRecord.UsedAt = &now
|
adjustmentRecord.UsedAt = &now
|
||||||
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
|
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
|
||||||
// Log error but don't fail the update
|
log.Printf("failed to create concurrency adjustment redeem code: %v", err)
|
||||||
// The user update has already succeeded
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -388,7 +393,9 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
|||||||
go func() {
|
go func() {
|
||||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, userID); err != nil {
|
||||||
|
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -404,9 +411,9 @@ func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, pag
|
|||||||
return keys, result.Total, nil
|
return keys, result.Total, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (interface{}, error) {
|
func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) {
|
||||||
// Return mock data for now
|
// Return mock data for now
|
||||||
return map[string]interface{}{
|
return map[string]any{
|
||||||
"period": period,
|
"period": period,
|
||||||
"total_requests": 0,
|
"total_requests": 0,
|
||||||
"total_cost": 0.0,
|
"total_cost": 0.0,
|
||||||
@@ -579,7 +586,9 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
|||||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
cacheCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
for _, userID := range affectedUserIDs {
|
for _, userID := range affectedUserIDs {
|
||||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
if err := s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID); err != nil {
|
||||||
|
log.Printf("invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, groupID, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
@@ -646,10 +655,10 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
|||||||
if input.Type != "" {
|
if input.Type != "" {
|
||||||
account.Type = input.Type
|
account.Type = input.Type
|
||||||
}
|
}
|
||||||
if input.Credentials != nil && len(input.Credentials) > 0 {
|
if len(input.Credentials) > 0 {
|
||||||
account.Credentials = model.JSONB(input.Credentials)
|
account.Credentials = model.JSONB(input.Credentials)
|
||||||
}
|
}
|
||||||
if input.Extra != nil && len(input.Extra) > 0 {
|
if len(input.Extra) > 0 {
|
||||||
account.Extra = model.JSONB(input.Extra)
|
account.Extra = model.JSONB(input.Extra)
|
||||||
}
|
}
|
||||||
if input.ProxyID != nil {
|
if input.ProxyID != nil {
|
||||||
@@ -831,8 +840,12 @@ func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *Gener
|
|||||||
|
|
||||||
codes := make([]model.RedeemCode, 0, input.Count)
|
codes := make([]model.RedeemCode, 0, input.Count)
|
||||||
for i := 0; i < input.Count; i++ {
|
for i := 0; i < input.Count; i++ {
|
||||||
|
codeValue, err := model.GenerateRedeemCode()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
code := model.RedeemCode{
|
code := model.RedeemCode{
|
||||||
Code: model.GenerateRedeemCode(),
|
Code: codeValue,
|
||||||
Type: input.Type,
|
Type: input.Type,
|
||||||
Value: input.Value,
|
Value: input.Value,
|
||||||
Status: model.StatusUnused,
|
Status: model.StatusUnused,
|
||||||
|
|||||||
@@ -100,10 +100,13 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error {
|
|||||||
|
|
||||||
// 检查字符:只允许字母、数字、下划线、连字符
|
// 检查字符:只允许字母、数字、下划线、连字符
|
||||||
for _, c := range key {
|
for _, c := range key {
|
||||||
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
|
if (c >= 'a' && c <= 'z') ||
|
||||||
(c >= '0' && c <= '9') || c == '_' || c == '-') {
|
(c >= 'A' && c <= 'Z') ||
|
||||||
return ErrApiKeyInvalidChars
|
(c >= '0' && c <= '9') ||
|
||||||
|
c == '_' || c == '-' {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
return ErrApiKeyInvalidChars
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -278,7 +278,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
|
|||||||
|
|
||||||
// ValidateToken 验证JWT token并返回用户声明
|
// ValidateToken 验证JWT token并返回用户声明
|
||||||
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
|
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
|
||||||
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
|
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (any, error) {
|
||||||
// 验证签名方法
|
// 验证签名方法
|
||||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||||
|
|||||||
@@ -259,11 +259,11 @@ func (s *BillingService) GetEstimatedCost(model string, estimatedInputTokens, es
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetPricingServiceStatus 获取价格服务状态
|
// GetPricingServiceStatus 获取价格服务状态
|
||||||
func (s *BillingService) GetPricingServiceStatus() map[string]interface{} {
|
func (s *BillingService) GetPricingServiceStatus() map[string]any {
|
||||||
if s.pricingService != nil {
|
if s.pricingService != nil {
|
||||||
return s.pricingService.GetStatus()
|
return s.pricingService.GetStatus()
|
||||||
}
|
}
|
||||||
return map[string]interface{}{
|
return map[string]any{
|
||||||
"model_count": len(s.fallbackPrices),
|
"model_count": len(s.fallbackPrices),
|
||||||
"last_updated": "using fallback",
|
"last_updated": "using fallback",
|
||||||
"local_hash": "N/A",
|
"local_hash": "N/A",
|
||||||
|
|||||||
@@ -9,12 +9,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// Wait polling interval
|
|
||||||
waitPollInterval = 100 * time.Millisecond
|
|
||||||
|
|
||||||
// Default max wait time
|
|
||||||
defaultMaxWait = 60 * time.Second
|
|
||||||
|
|
||||||
// Default extra wait slots beyond concurrency limit
|
// Default extra wait slots beyond concurrency limit
|
||||||
defaultExtraWaitSlots = 20
|
defaultExtraWaitSlots = 20
|
||||||
)
|
)
|
||||||
@@ -31,7 +25,7 @@ func NewConcurrencyService(cache ports.ConcurrencyCache) *ConcurrencyService {
|
|||||||
|
|
||||||
// AcquireResult represents the result of acquiring a concurrency slot
|
// AcquireResult represents the result of acquiring a concurrency slot
|
||||||
type AcquireResult struct {
|
type AcquireResult struct {
|
||||||
Acquired bool
|
Acquired bool
|
||||||
ReleaseFunc func() // Must be called when done (typically via defer)
|
ReleaseFunc func() // Must be called when done (typically via defer)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,7 +48,7 @@ func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID i
|
|||||||
|
|
||||||
if acquired {
|
if acquired {
|
||||||
return &AcquireResult{
|
return &AcquireResult{
|
||||||
Acquired: true,
|
Acquired: true,
|
||||||
ReleaseFunc: func() {
|
ReleaseFunc: func() {
|
||||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@@ -90,7 +84,7 @@ func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64,
|
|||||||
|
|
||||||
if acquired {
|
if acquired {
|
||||||
return &AcquireResult{
|
return &AcquireResult{
|
||||||
Acquired: true,
|
Acquired: true,
|
||||||
ReleaseFunc: func() {
|
ReleaseFunc: func() {
|
||||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|||||||
@@ -133,13 +133,13 @@ func (s *EmailService) sendMailTLS(addr string, auth smtp.Auth, from, to string,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("tls dial: %w", err)
|
return fmt.Errorf("tls dial: %w", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
client, err := smtp.NewClient(conn, host)
|
client, err := smtp.NewClient(conn, host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("new smtp client: %w", err)
|
return fmt.Errorf("new smtp client: %w", err)
|
||||||
}
|
}
|
||||||
defer client.Close()
|
defer func() { _ = client.Close() }()
|
||||||
|
|
||||||
if err = client.Auth(auth); err != nil {
|
if err = client.Auth(auth); err != nil {
|
||||||
return fmt.Errorf("smtp auth: %w", err)
|
return fmt.Errorf("smtp auth: %w", err)
|
||||||
@@ -303,13 +303,13 @@ func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("tls connection failed: %w", err)
|
return fmt.Errorf("tls connection failed: %w", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
client, err := smtp.NewClient(conn, config.Host)
|
client, err := smtp.NewClient(conn, config.Host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("smtp client creation failed: %w", err)
|
return fmt.Errorf("smtp client creation failed: %w", err)
|
||||||
}
|
}
|
||||||
defer client.Close()
|
defer func() { _ = client.Close() }()
|
||||||
|
|
||||||
auth := smtp.PlainAuth("", config.Username, config.Password, config.Host)
|
auth := smtp.PlainAuth("", config.Username, config.Password, config.Host)
|
||||||
if err = client.Auth(auth); err != nil {
|
if err = client.Auth(auth); err != nil {
|
||||||
@@ -324,7 +324,7 @@ func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("smtp connection failed: %w", err)
|
return fmt.Errorf("smtp connection failed: %w", err)
|
||||||
}
|
}
|
||||||
defer client.Close()
|
defer func() { _ = client.Close() }()
|
||||||
|
|
||||||
auth := smtp.PlainAuth("", config.Username, config.Password, config.Host)
|
auth := smtp.PlainAuth("", config.Username, config.Password, config.Host)
|
||||||
if err = client.Auth(auth); err != nil {
|
if err = client.Auth(auth); err != nil {
|
||||||
|
|||||||
@@ -53,7 +53,6 @@ var allowedHeaders = map[string]bool{
|
|||||||
"anthropic-beta": true,
|
"anthropic-beta": true,
|
||||||
"accept-language": true,
|
"accept-language": true,
|
||||||
"sec-fetch-mode": true,
|
"sec-fetch-mode": true,
|
||||||
"accept-encoding": true,
|
|
||||||
"user-agent": true,
|
"user-agent": true,
|
||||||
"content-type": true,
|
"content-type": true,
|
||||||
}
|
}
|
||||||
@@ -84,7 +83,6 @@ type GatewayService struct {
|
|||||||
userSubRepo ports.UserSubscriptionRepository
|
userSubRepo ports.UserSubscriptionRepository
|
||||||
cache ports.GatewayCache
|
cache ports.GatewayCache
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
oauthService *OAuthService
|
|
||||||
billingService *BillingService
|
billingService *BillingService
|
||||||
rateLimitService *RateLimitService
|
rateLimitService *RateLimitService
|
||||||
billingCacheService *BillingCacheService
|
billingCacheService *BillingCacheService
|
||||||
@@ -100,7 +98,6 @@ func NewGatewayService(
|
|||||||
userSubRepo ports.UserSubscriptionRepository,
|
userSubRepo ports.UserSubscriptionRepository,
|
||||||
cache ports.GatewayCache,
|
cache ports.GatewayCache,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
oauthService *OAuthService,
|
|
||||||
billingService *BillingService,
|
billingService *BillingService,
|
||||||
rateLimitService *RateLimitService,
|
rateLimitService *RateLimitService,
|
||||||
billingCacheService *BillingCacheService,
|
billingCacheService *BillingCacheService,
|
||||||
@@ -114,7 +111,6 @@ func NewGatewayService(
|
|||||||
userSubRepo: userSubRepo,
|
userSubRepo: userSubRepo,
|
||||||
cache: cache,
|
cache: cache,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
oauthService: oauthService,
|
|
||||||
billingService: billingService,
|
billingService: billingService,
|
||||||
rateLimitService: rateLimitService,
|
rateLimitService: rateLimitService,
|
||||||
billingCacheService: billingCacheService,
|
billingCacheService: billingCacheService,
|
||||||
@@ -125,13 +121,13 @@ func NewGatewayService(
|
|||||||
|
|
||||||
// GenerateSessionHash 从请求体计算粘性会话hash
|
// GenerateSessionHash 从请求体计算粘性会话hash
|
||||||
func (s *GatewayService) GenerateSessionHash(body []byte) string {
|
func (s *GatewayService) GenerateSessionHash(body []byte) string {
|
||||||
var req map[string]interface{}
|
var req map[string]any
|
||||||
if err := json.Unmarshal(body, &req); err != nil {
|
if err := json.Unmarshal(body, &req); err != nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// 1. 最高优先级:从metadata.user_id提取session_xxx
|
// 1. 最高优先级:从metadata.user_id提取session_xxx
|
||||||
if metadata, ok := req["metadata"].(map[string]interface{}); ok {
|
if metadata, ok := req["metadata"].(map[string]any); ok {
|
||||||
if userID, ok := metadata["user_id"].(string); ok {
|
if userID, ok := metadata["user_id"].(string); ok {
|
||||||
re := regexp.MustCompile(`session_([a-f0-9-]{36})`)
|
re := regexp.MustCompile(`session_([a-f0-9-]{36})`)
|
||||||
if match := re.FindStringSubmatch(userID); len(match) > 1 {
|
if match := re.FindStringSubmatch(userID); len(match) > 1 {
|
||||||
@@ -155,8 +151,8 @@ func (s *GatewayService) GenerateSessionHash(body []byte) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 4. 最后fallback: 使用第一条消息
|
// 4. 最后fallback: 使用第一条消息
|
||||||
if messages, ok := req["messages"].([]interface{}); ok && len(messages) > 0 {
|
if messages, ok := req["messages"].([]any); ok && len(messages) > 0 {
|
||||||
if firstMsg, ok := messages[0].(map[string]interface{}); ok {
|
if firstMsg, ok := messages[0].(map[string]any); ok {
|
||||||
msgText := s.extractTextFromContent(firstMsg["content"])
|
msgText := s.extractTextFromContent(firstMsg["content"])
|
||||||
if msgText != "" {
|
if msgText != "" {
|
||||||
return s.hashContent(msgText)
|
return s.hashContent(msgText)
|
||||||
@@ -167,14 +163,14 @@ func (s *GatewayService) GenerateSessionHash(body []byte) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) extractCacheableContent(req map[string]interface{}) string {
|
func (s *GatewayService) extractCacheableContent(req map[string]any) string {
|
||||||
var content string
|
var content string
|
||||||
|
|
||||||
// 检查system中的cacheable内容
|
// 检查system中的cacheable内容
|
||||||
if system, ok := req["system"].([]interface{}); ok {
|
if system, ok := req["system"].([]any); ok {
|
||||||
for _, part := range system {
|
for _, part := range system {
|
||||||
if partMap, ok := part.(map[string]interface{}); ok {
|
if partMap, ok := part.(map[string]any); ok {
|
||||||
if cc, ok := partMap["cache_control"].(map[string]interface{}); ok {
|
if cc, ok := partMap["cache_control"].(map[string]any); ok {
|
||||||
if cc["type"] == "ephemeral" {
|
if cc["type"] == "ephemeral" {
|
||||||
if text, ok := partMap["text"].(string); ok {
|
if text, ok := partMap["text"].(string); ok {
|
||||||
content += text
|
content += text
|
||||||
@@ -186,13 +182,13 @@ func (s *GatewayService) extractCacheableContent(req map[string]interface{}) str
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 检查messages中的cacheable内容
|
// 检查messages中的cacheable内容
|
||||||
if messages, ok := req["messages"].([]interface{}); ok {
|
if messages, ok := req["messages"].([]any); ok {
|
||||||
for _, msg := range messages {
|
for _, msg := range messages {
|
||||||
if msgMap, ok := msg.(map[string]interface{}); ok {
|
if msgMap, ok := msg.(map[string]any); ok {
|
||||||
if msgContent, ok := msgMap["content"].([]interface{}); ok {
|
if msgContent, ok := msgMap["content"].([]any); ok {
|
||||||
for _, part := range msgContent {
|
for _, part := range msgContent {
|
||||||
if partMap, ok := part.(map[string]interface{}); ok {
|
if partMap, ok := part.(map[string]any); ok {
|
||||||
if cc, ok := partMap["cache_control"].(map[string]interface{}); ok {
|
if cc, ok := partMap["cache_control"].(map[string]any); ok {
|
||||||
if cc["type"] == "ephemeral" {
|
if cc["type"] == "ephemeral" {
|
||||||
// 找到cacheable内容,提取第一条消息的文本
|
// 找到cacheable内容,提取第一条消息的文本
|
||||||
return s.extractTextFromContent(msgMap["content"])
|
return s.extractTextFromContent(msgMap["content"])
|
||||||
@@ -208,14 +204,14 @@ func (s *GatewayService) extractCacheableContent(req map[string]interface{}) str
|
|||||||
return content
|
return content
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) extractTextFromSystem(system interface{}) string {
|
func (s *GatewayService) extractTextFromSystem(system any) string {
|
||||||
switch v := system.(type) {
|
switch v := system.(type) {
|
||||||
case string:
|
case string:
|
||||||
return v
|
return v
|
||||||
case []interface{}:
|
case []any:
|
||||||
var texts []string
|
var texts []string
|
||||||
for _, part := range v {
|
for _, part := range v {
|
||||||
if partMap, ok := part.(map[string]interface{}); ok {
|
if partMap, ok := part.(map[string]any); ok {
|
||||||
if text, ok := partMap["text"].(string); ok {
|
if text, ok := partMap["text"].(string); ok {
|
||||||
texts = append(texts, text)
|
texts = append(texts, text)
|
||||||
}
|
}
|
||||||
@@ -226,14 +222,14 @@ func (s *GatewayService) extractTextFromSystem(system interface{}) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) extractTextFromContent(content interface{}) string {
|
func (s *GatewayService) extractTextFromContent(content any) string {
|
||||||
switch v := content.(type) {
|
switch v := content.(type) {
|
||||||
case string:
|
case string:
|
||||||
return v
|
return v
|
||||||
case []interface{}:
|
case []any:
|
||||||
var texts []string
|
var texts []string
|
||||||
for _, part := range v {
|
for _, part := range v {
|
||||||
if partMap, ok := part.(map[string]interface{}); ok {
|
if partMap, ok := part.(map[string]any); ok {
|
||||||
if partMap["type"] == "text" {
|
if partMap["type"] == "text" {
|
||||||
if text, ok := partMap["text"].(string); ok {
|
if text, ok := partMap["text"].(string); ok {
|
||||||
texts = append(texts, text)
|
texts = append(texts, text)
|
||||||
@@ -253,7 +249,7 @@ func (s *GatewayService) hashContent(content string) string {
|
|||||||
|
|
||||||
// replaceModelInBody 替换请求体中的model字段
|
// replaceModelInBody 替换请求体中的model字段
|
||||||
func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte {
|
func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte {
|
||||||
var req map[string]interface{}
|
var req map[string]any
|
||||||
if err := json.Unmarshal(body, &req); err != nil {
|
if err := json.Unmarshal(body, &req); err != nil {
|
||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
@@ -281,7 +277,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
|||||||
// 同时检查模型支持
|
// 同时检查模型支持
|
||||||
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||||
// 续期粘性会话
|
// 续期粘性会话
|
||||||
s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL)
|
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
|
||||||
|
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||||
|
}
|
||||||
return account, nil
|
return account, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -331,7 +329,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
|||||||
|
|
||||||
// 4. 建立粘性绑定
|
// 4. 建立粘性绑定
|
||||||
if sessionHash != "" {
|
if sessionHash != "" {
|
||||||
s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL)
|
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||||
|
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return selected, nil
|
return selected, nil
|
||||||
@@ -411,7 +411,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
// 处理错误响应(包括401,由后台TokenRefreshService维护token有效性)
|
// 处理错误响应(包括401,由后台TokenRefreshService维护token有效性)
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
@@ -557,7 +557,7 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
|
|||||||
|
|
||||||
// 客户端没传,根据模型生成
|
// 客户端没传,根据模型生成
|
||||||
var modelID string
|
var modelID string
|
||||||
var reqMap map[string]interface{}
|
var reqMap map[string]any
|
||||||
if json.Unmarshal(body, &reqMap) == nil {
|
if json.Unmarshal(body, &reqMap) == nil {
|
||||||
if m, ok := reqMap["model"].(string); ok {
|
if m, ok := reqMap["model"].(string); ok {
|
||||||
modelID = m
|
modelID = m
|
||||||
@@ -678,7 +678,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 转发行
|
// 转发行
|
||||||
fmt.Fprintf(w, "%s\n", line)
|
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||||
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||||
|
}
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
|
|
||||||
// 解析usage数据
|
// 解析usage数据
|
||||||
@@ -707,7 +709,7 @@ func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string)
|
|||||||
return line
|
return line
|
||||||
}
|
}
|
||||||
|
|
||||||
var event map[string]interface{}
|
var event map[string]any
|
||||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||||
return line
|
return line
|
||||||
}
|
}
|
||||||
@@ -717,7 +719,7 @@ func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string)
|
|||||||
return line
|
return line
|
||||||
}
|
}
|
||||||
|
|
||||||
msg, ok := event["message"].(map[string]interface{})
|
msg, ok := event["message"].(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return line
|
return line
|
||||||
}
|
}
|
||||||
@@ -799,7 +801,7 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
|||||||
|
|
||||||
// replaceModelInResponseBody 替换响应体中的model字段
|
// replaceModelInResponseBody 替换响应体中的model字段
|
||||||
func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
|
func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
|
||||||
var resp map[string]interface{}
|
var resp map[string]any
|
||||||
if err := json.Unmarshal(body, &resp); err != nil {
|
if err := json.Unmarshal(body, &resp); err != nil {
|
||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
@@ -985,7 +987,9 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
|||||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
|
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
|
||||||
return fmt.Errorf("upstream request failed: %w", err)
|
return fmt.Errorf("upstream request failed: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
// 读取响应体
|
// 读取响应体
|
||||||
respBody, err := io.ReadAll(resp.Body)
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
|||||||
@@ -167,7 +167,7 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetStats 获取分组统计信息
|
// GetStats 获取分组统计信息
|
||||||
func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]interface{}, error) {
|
func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]any, error) {
|
||||||
group, err := s.groupRepo.GetByID(ctx, id)
|
group, err := s.groupRepo.GetByID(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
@@ -182,7 +182,7 @@ func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]inter
|
|||||||
return nil, fmt.Errorf("get account count: %w", err)
|
return nil, fmt.Errorf("get account count: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
stats := map[string]interface{}{
|
stats := map[string]any{
|
||||||
"id": group.ID,
|
"id": group.ID,
|
||||||
"name": group.Name,
|
"name": group.Name,
|
||||||
"rate_multiplier": group.RateMultiplier,
|
"rate_multiplier": group.RateMultiplier,
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
// 预编译正则表达式(避免每次调用重新编译)
|
// 预编译正则表达式(避免每次调用重新编译)
|
||||||
var (
|
var (
|
||||||
// 匹配 user_id 格式: user_{64位hex}_account__session_{uuid}
|
// 匹配 user_id 格式: user_{64位hex}_account__session_{uuid}
|
||||||
@@ -150,12 +149,12 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 解析JSON
|
// 解析JSON
|
||||||
var reqMap map[string]interface{}
|
var reqMap map[string]any
|
||||||
if err := json.Unmarshal(body, &reqMap); err != nil {
|
if err := json.Unmarshal(body, &reqMap); err != nil {
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
metadata, ok := reqMap["metadata"].(map[string]interface{})
|
metadata, ok := reqMap["metadata"].(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -515,11 +515,11 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetStatus 获取服务状态
|
// GetStatus 获取服务状态
|
||||||
func (s *PricingService) GetStatus() map[string]interface{} {
|
func (s *PricingService) GetStatus() map[string]any {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
return map[string]interface{}{
|
return map[string]any{
|
||||||
"model_count": len(s.pricingData),
|
"model_count": len(s.pricingData),
|
||||||
"last_updated": s.lastUpdated,
|
"last_updated": s.lastUpdated,
|
||||||
"local_hash": s.localHash[:min(8, len(s.localHash))],
|
"local_hash": s.localHash[:min(8, len(s.localHash))],
|
||||||
|
|||||||
@@ -254,7 +254,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
|
|||||||
go func() {
|
go func() {
|
||||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -285,7 +285,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
|
|||||||
go func() {
|
go func() {
|
||||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -359,12 +359,12 @@ func (s *RedeemService) Delete(ctx context.Context, id int64) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetStats 获取兑换码统计信息
|
// GetStats 获取兑换码统计信息
|
||||||
func (s *RedeemService) GetStats(ctx context.Context) (map[string]interface{}, error) {
|
func (s *RedeemService) GetStats(ctx context.Context) (map[string]any, error) {
|
||||||
// TODO: 实现统计逻辑
|
// TODO: 实现统计逻辑
|
||||||
// 统计未使用、已使用的兑换码数量
|
// 统计未使用、已使用的兑换码数量
|
||||||
// 统计总面值等
|
// 统计总面值等
|
||||||
|
|
||||||
stats := map[string]interface{}{
|
stats := map[string]any{
|
||||||
"total_codes": 0,
|
"total_codes": 0,
|
||||||
"unused_codes": 0,
|
"unused_codes": 0,
|
||||||
"used_codes": 0,
|
"used_codes": 0,
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"sub2api/internal/model"
|
"sub2api/internal/model"
|
||||||
@@ -78,7 +79,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
|
|||||||
go func() {
|
go func() {
|
||||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -146,7 +147,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
|||||||
}
|
}
|
||||||
newNotes += input.Notes
|
newNotes += input.Notes
|
||||||
if err := s.userSubRepo.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil {
|
if err := s.userSubRepo.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil {
|
||||||
// 备注更新失败不影响主流程
|
log.Printf("update subscription notes failed: sub_id=%d err=%v", existingSub.ID, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -156,7 +157,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
|||||||
go func() {
|
go func() {
|
||||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -177,7 +178,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
|||||||
go func() {
|
go func() {
|
||||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -278,7 +279,7 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
|
|||||||
go func() {
|
go func() {
|
||||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -311,7 +312,7 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
|
|||||||
go func() {
|
go func() {
|
||||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ type TokenRefresher interface {
|
|||||||
|
|
||||||
// Refresh 执行token刷新,返回更新后的credentials
|
// Refresh 执行token刷新,返回更新后的credentials
|
||||||
// 注意:返回的map应该保留原有credentials中的所有字段,只更新token相关字段
|
// 注意:返回的map应该保留原有credentials中的所有字段,只更新token相关字段
|
||||||
Refresh(ctx context.Context, account *model.Account) (map[string]interface{}, error)
|
Refresh(ctx context.Context, account *model.Account) (map[string]any, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClaudeTokenRefresher 处理Anthropic/Claude OAuth token刷新
|
// ClaudeTokenRefresher 处理Anthropic/Claude OAuth token刷新
|
||||||
@@ -61,14 +61,14 @@ func (r *ClaudeTokenRefresher) NeedsRefresh(account *model.Account, refreshWindo
|
|||||||
|
|
||||||
// Refresh 执行token刷新
|
// Refresh 执行token刷新
|
||||||
// 保留原有credentials中的所有字段,只更新token相关字段
|
// 保留原有credentials中的所有字段,只更新token相关字段
|
||||||
func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *model.Account) (map[string]interface{}, error) {
|
func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *model.Account) (map[string]any, error) {
|
||||||
tokenInfo, err := r.oauthService.RefreshAccountToken(ctx, account)
|
tokenInfo, err := r.oauthService.RefreshAccountToken(ctx, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 保留现有credentials中的所有字段
|
// 保留现有credentials中的所有字段
|
||||||
newCredentials := make(map[string]interface{})
|
newCredentials := make(map[string]any)
|
||||||
for k, v := range account.Credentials {
|
for k, v := range account.Credentials {
|
||||||
newCredentials[k] = v
|
newCredentials[k] = v
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,8 +12,6 @@ var (
|
|||||||
ErrTurnstileNotConfigured = errors.New("turnstile not configured")
|
ErrTurnstileNotConfigured = errors.New("turnstile not configured")
|
||||||
)
|
)
|
||||||
|
|
||||||
const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
|
|
||||||
|
|
||||||
// TurnstileVerifier 验证 Turnstile token 的接口
|
// TurnstileVerifier 验证 Turnstile token 的接口
|
||||||
type TurnstileVerifier interface {
|
type TurnstileVerifier interface {
|
||||||
VerifyToken(ctx context.Context, secretKey, token, remoteIP string) (*TurnstileVerifyResponse, error)
|
VerifyToken(ctx context.Context, secretKey, token, remoteIP string) (*TurnstileVerifyResponse, error)
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -190,7 +191,7 @@ func (s *UpdateService) PerformUpdate(ctx context.Context) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create temp dir: %w", err)
|
return fmt.Errorf("failed to create temp dir: %w", err)
|
||||||
}
|
}
|
||||||
defer os.RemoveAll(tempDir)
|
defer func() { _ = os.RemoveAll(tempDir) }()
|
||||||
|
|
||||||
// Download archive
|
// Download archive
|
||||||
archivePath := filepath.Join(tempDir, filepath.Base(downloadURL))
|
archivePath := filepath.Join(tempDir, filepath.Base(downloadURL))
|
||||||
@@ -223,7 +224,7 @@ func (s *UpdateService) PerformUpdate(ctx context.Context) error {
|
|||||||
backupPath := exePath + ".backup"
|
backupPath := exePath + ".backup"
|
||||||
|
|
||||||
// Remove old backup if exists
|
// Remove old backup if exists
|
||||||
os.Remove(backupPath)
|
_ = os.Remove(backupPath)
|
||||||
|
|
||||||
// Step 1: Move current binary to backup
|
// Step 1: Move current binary to backup
|
||||||
if err := os.Rename(exePath, backupPath); err != nil {
|
if err := os.Rename(exePath, backupPath); err != nil {
|
||||||
@@ -349,7 +350,7 @@ func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumUR
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer f.Close()
|
defer func() { _ = f.Close() }()
|
||||||
|
|
||||||
h := sha256.New()
|
h := sha256.New()
|
||||||
if _, err := io.Copy(h, f); err != nil {
|
if _, err := io.Copy(h, f); err != nil {
|
||||||
@@ -379,7 +380,7 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer f.Close()
|
defer func() { _ = f.Close() }()
|
||||||
|
|
||||||
var reader io.Reader = f
|
var reader io.Reader = f
|
||||||
|
|
||||||
@@ -389,7 +390,7 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer gzr.Close()
|
defer func() { _ = gzr.Close() }()
|
||||||
reader = gzr
|
reader = gzr
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -435,10 +436,12 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
|
|||||||
// Use LimitReader to prevent decompression bombs
|
// Use LimitReader to prevent decompression bombs
|
||||||
limited := io.LimitReader(tr, maxBinarySize)
|
limited := io.LimitReader(tr, maxBinarySize)
|
||||||
if _, err := io.Copy(out, limited); err != nil {
|
if _, err := io.Copy(out, limited); err != nil {
|
||||||
out.Close()
|
_ = out.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := out.Close(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
out.Close()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -451,11 +454,13 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer out.Close()
|
|
||||||
|
|
||||||
limited := io.LimitReader(reader, maxBinarySize)
|
limited := io.LimitReader(reader, maxBinarySize)
|
||||||
_, err = io.Copy(out, limited)
|
if _, err := io.Copy(out, limited); err != nil {
|
||||||
return err
|
_ = out.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return out.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UpdateService) getFromCache(ctx context.Context) (*UpdateInfo, error) {
|
func (s *UpdateService) getFromCache(ctx context.Context) (*UpdateInfo, error) {
|
||||||
@@ -499,7 +504,7 @@ func (s *UpdateService) saveToCache(ctx context.Context, info *UpdateInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
data, _ := json.Marshal(cacheData)
|
data, _ := json.Marshal(cacheData)
|
||||||
s.cache.SetUpdateInfo(ctx, string(data), time.Duration(updateCacheTTL)*time.Second)
|
_ = s.cache.SetUpdateInfo(ctx, string(data), time.Duration(updateCacheTTL)*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
// compareVersions compares two semantic versions
|
// compareVersions compares two semantic versions
|
||||||
@@ -523,7 +528,9 @@ func parseVersion(v string) [3]int {
|
|||||||
parts := strings.Split(v, ".")
|
parts := strings.Split(v, ".")
|
||||||
result := [3]int{0, 0, 0}
|
result := [3]int{0, 0, 0}
|
||||||
for i := 0; i < len(parts) && i < 3; i++ {
|
for i := 0; i < len(parts) && i < 3; i++ {
|
||||||
fmt.Sscanf(parts[i], "%d", &result[i])
|
if parsed, err := strconv.Atoi(parts[i]); err == nil {
|
||||||
|
result[i] = parsed
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -195,7 +195,7 @@ func (s *UsageService) GetStatsByModel(ctx context.Context, modelName string, st
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetDailyStats 获取每日使用统计(最近N天)
|
// GetDailyStats 获取每日使用统计(最近N天)
|
||||||
func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int) ([]map[string]interface{}, error) {
|
func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int) ([]map[string]any, error) {
|
||||||
endTime := time.Now()
|
endTime := time.Now()
|
||||||
startTime := endTime.AddDate(0, 0, -days)
|
startTime := endTime.AddDate(0, 0, -days)
|
||||||
|
|
||||||
@@ -227,13 +227,13 @@ func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 计算平均值并转换为数组
|
// 计算平均值并转换为数组
|
||||||
result := make([]map[string]interface{}, 0, len(dailyStats))
|
result := make([]map[string]any, 0, len(dailyStats))
|
||||||
for date, stats := range dailyStats {
|
for date, stats := range dailyStats {
|
||||||
if stats.TotalRequests > 0 {
|
if stats.TotalRequests > 0 {
|
||||||
stats.AverageDurationMs /= float64(stats.TotalRequests)
|
stats.AverageDurationMs /= float64(stats.TotalRequests)
|
||||||
}
|
}
|
||||||
|
|
||||||
result = append(result, map[string]interface{}{
|
result = append(result, map[string]any{
|
||||||
"date": date,
|
"date": date,
|
||||||
"total_requests": stats.TotalRequests,
|
"total_requests": stats.TotalRequests,
|
||||||
"total_input_tokens": stats.TotalInputTokens,
|
"total_input_tokens": stats.TotalInputTokens,
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sub2api/internal/config"
|
|
||||||
"sub2api/internal/model"
|
"sub2api/internal/model"
|
||||||
"sub2api/internal/pkg/pagination"
|
"sub2api/internal/pkg/pagination"
|
||||||
"sub2api/internal/service/ports"
|
"sub2api/internal/service/ports"
|
||||||
@@ -34,14 +33,12 @@ type ChangePasswordRequest struct {
|
|||||||
// UserService 用户服务
|
// UserService 用户服务
|
||||||
type UserService struct {
|
type UserService struct {
|
||||||
userRepo ports.UserRepository
|
userRepo ports.UserRepository
|
||||||
cfg *config.Config
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUserService 创建用户服务实例
|
// NewUserService 创建用户服务实例
|
||||||
func NewUserService(userRepo ports.UserRepository, cfg *config.Config) *UserService {
|
func NewUserService(userRepo ports.UserRepository) *UserService {
|
||||||
return &UserService{
|
return &UserService{
|
||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
cfg: cfg,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -352,4 +352,3 @@ func install(c *gin.Context) {
|
|||||||
"restart": true,
|
"restart": true,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,9 +14,9 @@ import (
|
|||||||
|
|
||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
"gorm.io/driver/postgres"
|
"gorm.io/driver/postgres"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gopkg.in/yaml.v3"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config paths
|
// Config paths
|
||||||
@@ -101,7 +101,14 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get db instance: %w", err)
|
return fmt.Errorf("failed to get db instance: %w", err)
|
||||||
}
|
}
|
||||||
defer sqlDB.Close()
|
defer func() {
|
||||||
|
if sqlDB == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := sqlDB.Close(); err != nil {
|
||||||
|
log.Printf("failed to close postgres connection: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@@ -129,7 +136,10 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Now connect to the target database to verify
|
// Now connect to the target database to verify
|
||||||
sqlDB.Close()
|
if err := sqlDB.Close(); err != nil {
|
||||||
|
log.Printf("failed to close postgres connection: %v", err)
|
||||||
|
}
|
||||||
|
sqlDB = nil
|
||||||
|
|
||||||
targetDSN := fmt.Sprintf(
|
targetDSN := fmt.Sprintf(
|
||||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
||||||
@@ -145,7 +155,11 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get target db instance: %w", err)
|
return fmt.Errorf("failed to get target db instance: %w", err)
|
||||||
}
|
}
|
||||||
defer targetSqlDB.Close()
|
defer func() {
|
||||||
|
if err := targetSqlDB.Close(); err != nil {
|
||||||
|
log.Printf("failed to close postgres connection: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel2()
|
defer cancel2()
|
||||||
@@ -164,7 +178,11 @@ func TestRedisConnection(cfg *RedisConfig) error {
|
|||||||
Password: cfg.Password,
|
Password: cfg.Password,
|
||||||
DB: cfg.DB,
|
DB: cfg.DB,
|
||||||
})
|
})
|
||||||
defer rdb.Close()
|
defer func() {
|
||||||
|
if err := rdb.Close(); err != nil {
|
||||||
|
log.Printf("failed to close redis client: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@@ -185,7 +203,11 @@ func Install(cfg *SetupConfig) error {
|
|||||||
|
|
||||||
// Generate JWT secret if not provided
|
// Generate JWT secret if not provided
|
||||||
if cfg.JWT.Secret == "" {
|
if cfg.JWT.Secret == "" {
|
||||||
cfg.JWT.Secret = generateSecret(32)
|
secret, err := generateSecret(32)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate jwt secret: %w", err)
|
||||||
|
}
|
||||||
|
cfg.JWT.Secret = secret
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test connections
|
// Test connections
|
||||||
@@ -243,7 +265,11 @@ func initializeDatabase(cfg *SetupConfig) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer sqlDB.Close()
|
defer func() {
|
||||||
|
if err := sqlDB.Close(); err != nil {
|
||||||
|
log.Printf("failed to close postgres connection: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// 使用 model 包的 AutoMigrate,确保模型定义统一
|
// 使用 model 包的 AutoMigrate,确保模型定义统一
|
||||||
return model.AutoMigrate(db)
|
return model.AutoMigrate(db)
|
||||||
@@ -265,7 +291,11 @@ func createAdminUser(cfg *SetupConfig) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer sqlDB.Close()
|
defer func() {
|
||||||
|
if err := sqlDB.Close(); err != nil {
|
||||||
|
log.Printf("failed to close postgres connection: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// Check if admin already exists
|
// Check if admin already exists
|
||||||
var count int64
|
var count int64
|
||||||
@@ -352,10 +382,12 @@ func writeConfigFile(cfg *SetupConfig) error {
|
|||||||
return os.WriteFile(ConfigFile, data, 0600)
|
return os.WriteFile(ConfigFile, data, 0600)
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateSecret(length int) string {
|
func generateSecret(length int) (string, error) {
|
||||||
bytes := make([]byte, length)
|
bytes := make([]byte, length)
|
||||||
rand.Read(bytes)
|
if _, err := rand.Read(bytes); err != nil {
|
||||||
return hex.EncodeToString(bytes)
|
return "", err
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(bytes), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
@@ -431,13 +463,21 @@ func AutoSetupFromEnv() error {
|
|||||||
|
|
||||||
// Generate JWT secret if not provided
|
// Generate JWT secret if not provided
|
||||||
if cfg.JWT.Secret == "" {
|
if cfg.JWT.Secret == "" {
|
||||||
cfg.JWT.Secret = generateSecret(32)
|
secret, err := generateSecret(32)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate jwt secret: %w", err)
|
||||||
|
}
|
||||||
|
cfg.JWT.Secret = secret
|
||||||
log.Println("Generated JWT secret automatically")
|
log.Println("Generated JWT secret automatically")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate admin password if not provided
|
// Generate admin password if not provided
|
||||||
if cfg.Admin.Password == "" {
|
if cfg.Admin.Password == "" {
|
||||||
cfg.Admin.Password = generateSecret(16)
|
password, err := generateSecret(16)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate admin password: %w", err)
|
||||||
|
}
|
||||||
|
cfg.Admin.Password = password
|
||||||
log.Printf("Generated admin password: %s", cfg.Admin.Password)
|
log.Printf("Generated admin password: %s", cfg.Admin.Password)
|
||||||
log.Println("IMPORTANT: Save this password! It will not be shown again.")
|
log.Println("IMPORTANT: Save this password! It will not be shown again.")
|
||||||
}
|
}
|
||||||
|
|||||||
20
backend/internal/web/embed_off.go
Normal file
20
backend/internal/web/embed_off.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
//go:build !embed
|
||||||
|
|
||||||
|
package web
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ServeEmbeddedFrontend() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
c.String(http.StatusNotFound, "Frontend not embedded. Build with -tags embed to include frontend.")
|
||||||
|
c.Abort()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func HasEmbeddedFrontend() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build embed
|
||||||
|
|
||||||
package web
|
package web
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -13,8 +15,6 @@ import (
|
|||||||
//go:embed all:dist
|
//go:embed all:dist
|
||||||
var frontendFS embed.FS
|
var frontendFS embed.FS
|
||||||
|
|
||||||
// ServeEmbeddedFrontend returns a Gin handler that serves embedded frontend assets
|
|
||||||
// and handles SPA routing by falling back to index.html for non-API routes.
|
|
||||||
func ServeEmbeddedFrontend() gin.HandlerFunc {
|
func ServeEmbeddedFrontend() gin.HandlerFunc {
|
||||||
distFS, err := fs.Sub(frontendFS, "dist")
|
distFS, err := fs.Sub(frontendFS, "dist")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -25,7 +25,6 @@ func ServeEmbeddedFrontend() gin.HandlerFunc {
|
|||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
path := c.Request.URL.Path
|
path := c.Request.URL.Path
|
||||||
|
|
||||||
// Skip API and gateway routes
|
|
||||||
if strings.HasPrefix(path, "/api/") ||
|
if strings.HasPrefix(path, "/api/") ||
|
||||||
strings.HasPrefix(path, "/v1/") ||
|
strings.HasPrefix(path, "/v1/") ||
|
||||||
strings.HasPrefix(path, "/setup/") ||
|
strings.HasPrefix(path, "/setup/") ||
|
||||||
@@ -34,20 +33,18 @@ func ServeEmbeddedFrontend() gin.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to serve static file
|
|
||||||
cleanPath := strings.TrimPrefix(path, "/")
|
cleanPath := strings.TrimPrefix(path, "/")
|
||||||
if cleanPath == "" {
|
if cleanPath == "" {
|
||||||
cleanPath = "index.html"
|
cleanPath = "index.html"
|
||||||
}
|
}
|
||||||
|
|
||||||
if file, err := distFS.Open(cleanPath); err == nil {
|
if file, err := distFS.Open(cleanPath); err == nil {
|
||||||
file.Close()
|
_ = file.Close()
|
||||||
fileServer.ServeHTTP(c.Writer, c.Request)
|
fileServer.ServeHTTP(c.Writer, c.Request)
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// SPA fallback: serve index.html for all other routes
|
|
||||||
serveIndexHTML(c, distFS)
|
serveIndexHTML(c, distFS)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -59,7 +56,7 @@ func serveIndexHTML(c *gin.Context, fsys fs.FS) {
|
|||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer file.Close()
|
defer func() { _ = file.Close() }()
|
||||||
|
|
||||||
content, err := io.ReadAll(file)
|
content, err := io.ReadAll(file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -72,7 +69,6 @@ func serveIndexHTML(c *gin.Context, fsys fs.FS) {
|
|||||||
c.Abort()
|
c.Abort()
|
||||||
}
|
}
|
||||||
|
|
||||||
// HasEmbeddedFrontend checks if frontend assets are embedded
|
|
||||||
func HasEmbeddedFrontend() bool {
|
func HasEmbeddedFrontend() bool {
|
||||||
_, err := frontendFS.ReadFile("dist/index.html")
|
_, err := frontendFS.ReadFile("dist/index.html")
|
||||||
return err == nil
|
return err == nil
|
||||||
Reference in New Issue
Block a user