mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-06 08:20:23 +08:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dacf3a2a6e | ||
|
|
e6add93ae3 | ||
|
|
b2273ec695 | ||
|
|
aa89777dda | ||
|
|
1e1f3c0c74 | ||
|
|
1fab9204eb | ||
|
|
dbd3e71637 | ||
|
|
974f67211b | ||
|
|
0338c83b90 | ||
|
|
c6b3de1199 |
38
.github/workflows/backend-ci.yml
vendored
Normal file
38
.github/workflows/backend-ci.yml
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
pull_request:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: backend/go.mod
|
||||
check-latest: true
|
||||
cache: true
|
||||
- name: Run tests
|
||||
working-directory: backend
|
||||
run: go test ./...
|
||||
|
||||
golangci-lint:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: backend/go.mod
|
||||
check-latest: true
|
||||
cache: true
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v9
|
||||
with:
|
||||
version: v2.7
|
||||
args: --timeout=5m
|
||||
working-directory: backend
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -81,7 +81,12 @@ build/
|
||||
release/
|
||||
|
||||
# 后端嵌入的前端构建产物
|
||||
# Keep a placeholder file so `//go:embed all:dist` always has a match in CI/lint,
|
||||
# while still ignoring generated frontend build outputs.
|
||||
backend/internal/web/dist/
|
||||
!backend/internal/web/dist/
|
||||
backend/internal/web/dist/*
|
||||
!backend/internal/web/dist/.keep
|
||||
|
||||
# 后端运行时缓存数据
|
||||
backend/data/
|
||||
@@ -92,4 +97,4 @@ backend/data/
|
||||
tests
|
||||
CLAUDE.md
|
||||
.claude
|
||||
scripts
|
||||
scripts
|
||||
|
||||
@@ -11,6 +11,8 @@ builds:
|
||||
dir: backend
|
||||
main: ./cmd/server
|
||||
binary: sub2api
|
||||
flags:
|
||||
- -tags=embed
|
||||
env:
|
||||
- CGO_ENABLED=0
|
||||
goos:
|
||||
|
||||
11
Dockerfile
11
Dockerfile
@@ -40,14 +40,15 @@ WORKDIR /app/backend
|
||||
COPY backend/go.mod backend/go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
# Copy frontend dist from previous stage
|
||||
COPY --from=frontend-builder /app/frontend/../backend/internal/web/dist ./internal/web/dist
|
||||
|
||||
# Copy backend source
|
||||
# Copy backend source first
|
||||
COPY backend/ ./
|
||||
|
||||
# Build the binary (BuildType=release for CI builds)
|
||||
# Copy frontend dist from previous stage (must be after backend copy to avoid being overwritten)
|
||||
COPY --from=frontend-builder /app/backend/internal/web/dist ./internal/web/dist
|
||||
|
||||
# Build the binary (BuildType=release for CI builds, embed frontend)
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build \
|
||||
-tags embed \
|
||||
-ldflags="-s -w -X main.Commit=${COMMIT} -X main.Date=${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)} -X main.BuildType=release" \
|
||||
-o /app/sub2api \
|
||||
./cmd/server
|
||||
|
||||
16
README.md
16
README.md
@@ -220,21 +220,21 @@ cd sub2api
|
||||
cd frontend
|
||||
npm install
|
||||
npm run build
|
||||
# Output will be in ../backend/internal/web/dist/
|
||||
|
||||
# 3. Copy frontend build to backend (for embedding)
|
||||
cp -r dist ../backend/internal/web/
|
||||
|
||||
# 4. Build backend (requires frontend dist to be present)
|
||||
# 3. Build backend with embedded frontend
|
||||
cd ../backend
|
||||
go build -o sub2api ./cmd/server
|
||||
go build -tags embed -o sub2api ./cmd/server
|
||||
|
||||
# 5. Create configuration file
|
||||
# 4. Create configuration file
|
||||
cp ../deploy/config.example.yaml ./config.yaml
|
||||
|
||||
# 6. Edit configuration
|
||||
# 5. Edit configuration
|
||||
nano config.yaml
|
||||
```
|
||||
|
||||
> **Note:** The `-tags embed` flag embeds the frontend into the binary. Without this flag, the binary will not serve the frontend UI.
|
||||
|
||||
**Key configuration in `config.yaml`:**
|
||||
|
||||
```yaml
|
||||
@@ -265,7 +265,7 @@ default:
|
||||
```
|
||||
|
||||
```bash
|
||||
# 7. Run the application
|
||||
# 6. Run the application
|
||||
./sub2api
|
||||
```
|
||||
|
||||
|
||||
16
README_CN.md
16
README_CN.md
@@ -220,21 +220,21 @@ cd sub2api
|
||||
cd frontend
|
||||
npm install
|
||||
npm run build
|
||||
# 构建产物输出到 ../backend/internal/web/dist/
|
||||
|
||||
# 3. 复制前端构建产物到后端(用于嵌入)
|
||||
cp -r dist ../backend/internal/web/
|
||||
|
||||
# 4. 编译后端(需要前端 dist 目录存在)
|
||||
# 3. 编译后端(嵌入前端)
|
||||
cd ../backend
|
||||
go build -o sub2api ./cmd/server
|
||||
go build -tags embed -o sub2api ./cmd/server
|
||||
|
||||
# 5. 创建配置文件
|
||||
# 4. 创建配置文件
|
||||
cp ../deploy/config.example.yaml ./config.yaml
|
||||
|
||||
# 6. 编辑配置
|
||||
# 5. 编辑配置
|
||||
nano config.yaml
|
||||
```
|
||||
|
||||
> **注意:** `-tags embed` 参数会将前端嵌入到二进制文件中。不使用此参数编译的程序将不包含前端界面。
|
||||
|
||||
**`config.yaml` 关键配置:**
|
||||
|
||||
```yaml
|
||||
@@ -265,7 +265,7 @@ default:
|
||||
```
|
||||
|
||||
```bash
|
||||
# 7. 运行应用
|
||||
# 6. 运行应用
|
||||
./sub2api
|
||||
```
|
||||
|
||||
|
||||
587
backend/.golangci.yml
Normal file
587
backend/.golangci.yml
Normal file
@@ -0,0 +1,587 @@
|
||||
version: "2"
|
||||
|
||||
linters:
|
||||
default: none
|
||||
enable:
|
||||
- depguard
|
||||
- errcheck
|
||||
- govet
|
||||
- ineffassign
|
||||
- staticcheck
|
||||
- unused
|
||||
|
||||
settings:
|
||||
depguard:
|
||||
rules:
|
||||
# Enforce: service must not depend on repository.
|
||||
service-no-repository:
|
||||
list-mode: original
|
||||
files:
|
||||
- internal/service/**
|
||||
deny:
|
||||
- pkg: sub2api/internal/repository
|
||||
desc: "service must not import repository"
|
||||
errcheck:
|
||||
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
|
||||
# Such cases aren't reported by default.
|
||||
# Default: false
|
||||
check-type-assertions: true
|
||||
# report about assignment of errors to blank identifier: `num, _ := strconv.Atoi(numStr)`.
|
||||
# Such cases aren't reported by default.
|
||||
# Default: false
|
||||
check-blank: false
|
||||
# To disable the errcheck built-in exclude list.
|
||||
# See `-excludeonly` option in https://github.com/kisielk/errcheck#excluding-functions for details.
|
||||
# Default: false
|
||||
disable-default-exclusions: true
|
||||
# List of functions to exclude from checking, where each entry is a single function to exclude.
|
||||
# See https://github.com/kisielk/errcheck#excluding-functions for details.
|
||||
exclude-functions:
|
||||
- io/ioutil.ReadFile
|
||||
- io.Copy(*bytes.Buffer)
|
||||
- io.Copy(os.Stdout)
|
||||
- fmt.Println
|
||||
- fmt.Print
|
||||
- fmt.Printf
|
||||
- fmt.Fprint
|
||||
- fmt.Fprintf
|
||||
- fmt.Fprintln
|
||||
# Display function signature instead of selector.
|
||||
# Default: false
|
||||
verbose: true
|
||||
ineffassign:
|
||||
# Check escaping variables of type error, may cause false positives.
|
||||
# Default: false
|
||||
check-escaping-errors: true
|
||||
staticcheck:
|
||||
# https://staticcheck.dev/docs/configuration/options/#dot_import_whitelist
|
||||
# Default: ["github.com/mmcloughlin/avo/build", "github.com/mmcloughlin/avo/operand", "github.com/mmcloughlin/avo/reg"]
|
||||
dot-import-whitelist:
|
||||
- fmt
|
||||
# https://staticcheck.dev/docs/configuration/options/#initialisms
|
||||
# Default: ["ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS"]
|
||||
initialisms: [ "ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS" ]
|
||||
# https://staticcheck.dev/docs/configuration/options/#http_status_code_whitelist
|
||||
# Default: ["200", "400", "404", "500"]
|
||||
http-status-code-whitelist: [ "200", "400", "404", "500" ]
|
||||
# SAxxxx checks in https://staticcheck.dev/docs/configuration/options/#checks
|
||||
# Example (to disable some checks): [ "all", "-SA1000", "-SA1001"]
|
||||
# Run `GL_DEBUG=staticcheck golangci-lint run --enable=staticcheck` to see all available checks and enabled by config checks.
|
||||
# Default: ["all", "-ST1000", "-ST1003", "-ST1016", "-ST1020", "-ST1021", "-ST1022"]
|
||||
checks:
|
||||
# Invalid regular expression.
|
||||
# https://staticcheck.dev/docs/checks/#SA1000
|
||||
- SA1000
|
||||
# Invalid template.
|
||||
# https://staticcheck.dev/docs/checks/#SA1001
|
||||
- SA1001
|
||||
# Invalid format in 'time.Parse'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1002
|
||||
- SA1002
|
||||
# Unsupported argument to functions in 'encoding/binary'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1003
|
||||
- SA1003
|
||||
# Suspiciously small untyped constant in 'time.Sleep'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1004
|
||||
- SA1004
|
||||
# Invalid first argument to 'exec.Command'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1005
|
||||
- SA1005
|
||||
# 'Printf' with dynamic first argument and no further arguments.
|
||||
# https://staticcheck.dev/docs/checks/#SA1006
|
||||
- SA1006
|
||||
# Invalid URL in 'net/url.Parse'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1007
|
||||
- SA1007
|
||||
# Non-canonical key in 'http.Header' map.
|
||||
# https://staticcheck.dev/docs/checks/#SA1008
|
||||
- SA1008
|
||||
# '(*regexp.Regexp).FindAll' called with 'n == 0', which will always return zero results.
|
||||
# https://staticcheck.dev/docs/checks/#SA1010
|
||||
- SA1010
|
||||
# Various methods in the "strings" package expect valid UTF-8, but invalid input is provided.
|
||||
# https://staticcheck.dev/docs/checks/#SA1011
|
||||
- SA1011
|
||||
# A nil 'context.Context' is being passed to a function, consider using 'context.TODO' instead.
|
||||
# https://staticcheck.dev/docs/checks/#SA1012
|
||||
- SA1012
|
||||
# 'io.Seeker.Seek' is being called with the whence constant as the first argument, but it should be the second.
|
||||
# https://staticcheck.dev/docs/checks/#SA1013
|
||||
- SA1013
|
||||
# Non-pointer value passed to 'Unmarshal' or 'Decode'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1014
|
||||
- SA1014
|
||||
# Using 'time.Tick' in a way that will leak. Consider using 'time.NewTicker', and only use 'time.Tick' in tests, commands and endless functions.
|
||||
# https://staticcheck.dev/docs/checks/#SA1015
|
||||
- SA1015
|
||||
# Trapping a signal that cannot be trapped.
|
||||
# https://staticcheck.dev/docs/checks/#SA1016
|
||||
- SA1016
|
||||
# Channels used with 'os/signal.Notify' should be buffered.
|
||||
# https://staticcheck.dev/docs/checks/#SA1017
|
||||
- SA1017
|
||||
# 'strings.Replace' called with 'n == 0', which does nothing.
|
||||
# https://staticcheck.dev/docs/checks/#SA1018
|
||||
- SA1018
|
||||
# Using a deprecated function, variable, constant or field.
|
||||
# https://staticcheck.dev/docs/checks/#SA1019
|
||||
- SA1019
|
||||
# Using an invalid host:port pair with a 'net.Listen'-related function.
|
||||
# https://staticcheck.dev/docs/checks/#SA1020
|
||||
- SA1020
|
||||
# Using 'bytes.Equal' to compare two 'net.IP'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1021
|
||||
- SA1021
|
||||
# Modifying the buffer in an 'io.Writer' implementation.
|
||||
# https://staticcheck.dev/docs/checks/#SA1023
|
||||
- SA1023
|
||||
# A string cutset contains duplicate characters.
|
||||
# https://staticcheck.dev/docs/checks/#SA1024
|
||||
- SA1024
|
||||
# It is not possible to use '(*time.Timer).Reset''s return value correctly.
|
||||
# https://staticcheck.dev/docs/checks/#SA1025
|
||||
- SA1025
|
||||
# Cannot marshal channels or functions.
|
||||
# https://staticcheck.dev/docs/checks/#SA1026
|
||||
- SA1026
|
||||
# Atomic access to 64-bit variable must be 64-bit aligned.
|
||||
# https://staticcheck.dev/docs/checks/#SA1027
|
||||
- SA1027
|
||||
# 'sort.Slice' can only be used on slices.
|
||||
# https://staticcheck.dev/docs/checks/#SA1028
|
||||
- SA1028
|
||||
# Inappropriate key in call to 'context.WithValue'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1029
|
||||
- SA1029
|
||||
# Invalid argument in call to a 'strconv' function.
|
||||
# https://staticcheck.dev/docs/checks/#SA1030
|
||||
- SA1030
|
||||
# Overlapping byte slices passed to an encoder.
|
||||
# https://staticcheck.dev/docs/checks/#SA1031
|
||||
- SA1031
|
||||
# Wrong order of arguments to 'errors.Is'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1032
|
||||
- SA1032
|
||||
# 'sync.WaitGroup.Add' called inside the goroutine, leading to a race condition.
|
||||
# https://staticcheck.dev/docs/checks/#SA2000
|
||||
- SA2000
|
||||
# Empty critical section, did you mean to defer the unlock?.
|
||||
# https://staticcheck.dev/docs/checks/#SA2001
|
||||
- SA2001
|
||||
# Called 'testing.T.FailNow' or 'SkipNow' in a goroutine, which isn't allowed.
|
||||
# https://staticcheck.dev/docs/checks/#SA2002
|
||||
- SA2002
|
||||
# Deferred 'Lock' right after locking, likely meant to defer 'Unlock' instead.
|
||||
# https://staticcheck.dev/docs/checks/#SA2003
|
||||
- SA2003
|
||||
# 'TestMain' doesn't call 'os.Exit', hiding test failures.
|
||||
# https://staticcheck.dev/docs/checks/#SA3000
|
||||
- SA3000
|
||||
# Assigning to 'b.N' in benchmarks distorts the results.
|
||||
# https://staticcheck.dev/docs/checks/#SA3001
|
||||
- SA3001
|
||||
# Binary operator has identical expressions on both sides.
|
||||
# https://staticcheck.dev/docs/checks/#SA4000
|
||||
- SA4000
|
||||
# '&*x' gets simplified to 'x', it does not copy 'x'.
|
||||
# https://staticcheck.dev/docs/checks/#SA4001
|
||||
- SA4001
|
||||
# Comparing unsigned values against negative values is pointless.
|
||||
# https://staticcheck.dev/docs/checks/#SA4003
|
||||
- SA4003
|
||||
# The loop exits unconditionally after one iteration.
|
||||
# https://staticcheck.dev/docs/checks/#SA4004
|
||||
- SA4004
|
||||
# Field assignment that will never be observed. Did you mean to use a pointer receiver?.
|
||||
# https://staticcheck.dev/docs/checks/#SA4005
|
||||
- SA4005
|
||||
# A value assigned to a variable is never read before being overwritten. Forgotten error check or dead code?.
|
||||
# https://staticcheck.dev/docs/checks/#SA4006
|
||||
- SA4006
|
||||
# The variable in the loop condition never changes, are you incrementing the wrong variable?.
|
||||
# https://staticcheck.dev/docs/checks/#SA4008
|
||||
- SA4008
|
||||
# A function argument is overwritten before its first use.
|
||||
# https://staticcheck.dev/docs/checks/#SA4009
|
||||
- SA4009
|
||||
# The result of 'append' will never be observed anywhere.
|
||||
# https://staticcheck.dev/docs/checks/#SA4010
|
||||
- SA4010
|
||||
# Break statement with no effect. Did you mean to break out of an outer loop?.
|
||||
# https://staticcheck.dev/docs/checks/#SA4011
|
||||
- SA4011
|
||||
# Comparing a value against NaN even though no value is equal to NaN.
|
||||
# https://staticcheck.dev/docs/checks/#SA4012
|
||||
- SA4012
|
||||
# Negating a boolean twice ('!!b') is the same as writing 'b'. This is either redundant, or a typo.
|
||||
# https://staticcheck.dev/docs/checks/#SA4013
|
||||
- SA4013
|
||||
# An if/else if chain has repeated conditions and no side-effects; if the condition didn't match the first time, it won't match the second time, either.
|
||||
# https://staticcheck.dev/docs/checks/#SA4014
|
||||
- SA4014
|
||||
# Calling functions like 'math.Ceil' on floats converted from integers doesn't do anything useful.
|
||||
# https://staticcheck.dev/docs/checks/#SA4015
|
||||
- SA4015
|
||||
# Certain bitwise operations, such as 'x ^ 0', do not do anything useful.
|
||||
# https://staticcheck.dev/docs/checks/#SA4016
|
||||
- SA4016
|
||||
# Discarding the return values of a function without side effects, making the call pointless.
|
||||
# https://staticcheck.dev/docs/checks/#SA4017
|
||||
- SA4017
|
||||
# Self-assignment of variables.
|
||||
# https://staticcheck.dev/docs/checks/#SA4018
|
||||
- SA4018
|
||||
# Multiple, identical build constraints in the same file.
|
||||
# https://staticcheck.dev/docs/checks/#SA4019
|
||||
- SA4019
|
||||
# Unreachable case clause in a type switch.
|
||||
# https://staticcheck.dev/docs/checks/#SA4020
|
||||
- SA4020
|
||||
# "x = append(y)" is equivalent to "x = y".
|
||||
# https://staticcheck.dev/docs/checks/#SA4021
|
||||
- SA4021
|
||||
# Comparing the address of a variable against nil.
|
||||
# https://staticcheck.dev/docs/checks/#SA4022
|
||||
- SA4022
|
||||
# Impossible comparison of interface value with untyped nil.
|
||||
# https://staticcheck.dev/docs/checks/#SA4023
|
||||
- SA4023
|
||||
# Checking for impossible return value from a builtin function.
|
||||
# https://staticcheck.dev/docs/checks/#SA4024
|
||||
- SA4024
|
||||
# Integer division of literals that results in zero.
|
||||
# https://staticcheck.dev/docs/checks/#SA4025
|
||||
- SA4025
|
||||
# Go constants cannot express negative zero.
|
||||
# https://staticcheck.dev/docs/checks/#SA4026
|
||||
- SA4026
|
||||
# '(*net/url.URL).Query' returns a copy, modifying it doesn't change the URL.
|
||||
# https://staticcheck.dev/docs/checks/#SA4027
|
||||
- SA4027
|
||||
# 'x % 1' is always zero.
|
||||
# https://staticcheck.dev/docs/checks/#SA4028
|
||||
- SA4028
|
||||
# Ineffective attempt at sorting slice.
|
||||
# https://staticcheck.dev/docs/checks/#SA4029
|
||||
- SA4029
|
||||
# Ineffective attempt at generating random number.
|
||||
# https://staticcheck.dev/docs/checks/#SA4030
|
||||
- SA4030
|
||||
# Checking never-nil value against nil.
|
||||
# https://staticcheck.dev/docs/checks/#SA4031
|
||||
- SA4031
|
||||
# Comparing 'runtime.GOOS' or 'runtime.GOARCH' against impossible value.
|
||||
# https://staticcheck.dev/docs/checks/#SA4032
|
||||
- SA4032
|
||||
# Assignment to nil map.
|
||||
# https://staticcheck.dev/docs/checks/#SA5000
|
||||
- SA5000
|
||||
# Deferring 'Close' before checking for a possible error.
|
||||
# https://staticcheck.dev/docs/checks/#SA5001
|
||||
- SA5001
|
||||
# The empty for loop ("for {}") spins and can block the scheduler.
|
||||
# https://staticcheck.dev/docs/checks/#SA5002
|
||||
- SA5002
|
||||
# Defers in infinite loops will never execute.
|
||||
# https://staticcheck.dev/docs/checks/#SA5003
|
||||
- SA5003
|
||||
# "for { select { ..." with an empty default branch spins.
|
||||
# https://staticcheck.dev/docs/checks/#SA5004
|
||||
- SA5004
|
||||
# The finalizer references the finalized object, preventing garbage collection.
|
||||
# https://staticcheck.dev/docs/checks/#SA5005
|
||||
- SA5005
|
||||
# Infinite recursive call.
|
||||
# https://staticcheck.dev/docs/checks/#SA5007
|
||||
- SA5007
|
||||
# Invalid struct tag.
|
||||
# https://staticcheck.dev/docs/checks/#SA5008
|
||||
- SA5008
|
||||
# Invalid Printf call.
|
||||
# https://staticcheck.dev/docs/checks/#SA5009
|
||||
- SA5009
|
||||
# Impossible type assertion.
|
||||
# https://staticcheck.dev/docs/checks/#SA5010
|
||||
- SA5010
|
||||
# Possible nil pointer dereference.
|
||||
# https://staticcheck.dev/docs/checks/#SA5011
|
||||
- SA5011
|
||||
# Passing odd-sized slice to function expecting even size.
|
||||
# https://staticcheck.dev/docs/checks/#SA5012
|
||||
- SA5012
|
||||
# Using 'regexp.Match' or related in a loop, should use 'regexp.Compile'.
|
||||
# https://staticcheck.dev/docs/checks/#SA6000
|
||||
- SA6000
|
||||
# Missing an optimization opportunity when indexing maps by byte slices.
|
||||
# https://staticcheck.dev/docs/checks/#SA6001
|
||||
- SA6001
|
||||
# Storing non-pointer values in 'sync.Pool' allocates memory.
|
||||
# https://staticcheck.dev/docs/checks/#SA6002
|
||||
- SA6002
|
||||
# Converting a string to a slice of runes before ranging over it.
|
||||
# https://staticcheck.dev/docs/checks/#SA6003
|
||||
- SA6003
|
||||
# Inefficient string comparison with 'strings.ToLower' or 'strings.ToUpper'.
|
||||
# https://staticcheck.dev/docs/checks/#SA6005
|
||||
- SA6005
|
||||
# Using io.WriteString to write '[]byte'.
|
||||
# https://staticcheck.dev/docs/checks/#SA6006
|
||||
- SA6006
|
||||
# Defers in range loops may not run when you expect them to.
|
||||
# https://staticcheck.dev/docs/checks/#SA9001
|
||||
- SA9001
|
||||
# Using a non-octal 'os.FileMode' that looks like it was meant to be in octal.
|
||||
# https://staticcheck.dev/docs/checks/#SA9002
|
||||
- SA9002
|
||||
# Empty body in an if or else branch.
|
||||
# https://staticcheck.dev/docs/checks/#SA9003
|
||||
- SA9003
|
||||
# Only the first constant has an explicit type.
|
||||
# https://staticcheck.dev/docs/checks/#SA9004
|
||||
- SA9004
|
||||
# Trying to marshal a struct with no public fields nor custom marshaling.
|
||||
# https://staticcheck.dev/docs/checks/#SA9005
|
||||
- SA9005
|
||||
# Dubious bit shifting of a fixed size integer value.
|
||||
# https://staticcheck.dev/docs/checks/#SA9006
|
||||
- SA9006
|
||||
# Deleting a directory that shouldn't be deleted.
|
||||
# https://staticcheck.dev/docs/checks/#SA9007
|
||||
- SA9007
|
||||
# 'else' branch of a type assertion is probably not reading the right value.
|
||||
# https://staticcheck.dev/docs/checks/#SA9008
|
||||
- SA9008
|
||||
# Ineffectual Go compiler directive.
|
||||
# https://staticcheck.dev/docs/checks/#SA9009
|
||||
- SA9009
|
||||
# Incorrect or missing package comment.
|
||||
# https://staticcheck.dev/docs/checks/#ST1000
|
||||
- ST1000
|
||||
# Dot imports are discouraged.
|
||||
# https://staticcheck.dev/docs/checks/#ST1001
|
||||
- ST1001
|
||||
# Poorly chosen identifier.
|
||||
# https://staticcheck.dev/docs/checks/#ST1003
|
||||
- ST1003
|
||||
# Incorrectly formatted error string.
|
||||
# https://staticcheck.dev/docs/checks/#ST1005
|
||||
- ST1005
|
||||
# Poorly chosen receiver name.
|
||||
# https://staticcheck.dev/docs/checks/#ST1006
|
||||
- ST1006
|
||||
# A function's error value should be its last return value.
|
||||
# https://staticcheck.dev/docs/checks/#ST1008
|
||||
- ST1008
|
||||
# Poorly chosen name for variable of type 'time.Duration'.
|
||||
# https://staticcheck.dev/docs/checks/#ST1011
|
||||
- ST1011
|
||||
# Poorly chosen name for error variable.
|
||||
# https://staticcheck.dev/docs/checks/#ST1012
|
||||
- ST1012
|
||||
# Should use constants for HTTP error codes, not magic numbers.
|
||||
# https://staticcheck.dev/docs/checks/#ST1013
|
||||
- ST1013
|
||||
# A switch's default case should be the first or last case.
|
||||
# https://staticcheck.dev/docs/checks/#ST1015
|
||||
- ST1015
|
||||
# Use consistent method receiver names.
|
||||
# https://staticcheck.dev/docs/checks/#ST1016
|
||||
- ST1016
|
||||
# Don't use Yoda conditions.
|
||||
# https://staticcheck.dev/docs/checks/#ST1017
|
||||
- ST1017
|
||||
# Avoid zero-width and control characters in string literals.
|
||||
# https://staticcheck.dev/docs/checks/#ST1018
|
||||
- ST1018
|
||||
# Importing the same package multiple times.
|
||||
# https://staticcheck.dev/docs/checks/#ST1019
|
||||
- ST1019
|
||||
# The documentation of an exported function should start with the function's name.
|
||||
# https://staticcheck.dev/docs/checks/#ST1020
|
||||
- ST1020
|
||||
# The documentation of an exported type should start with type's name.
|
||||
# https://staticcheck.dev/docs/checks/#ST1021
|
||||
- ST1021
|
||||
# The documentation of an exported variable or constant should start with variable's name.
|
||||
# https://staticcheck.dev/docs/checks/#ST1022
|
||||
- ST1022
|
||||
# Redundant type in variable declaration.
|
||||
# https://staticcheck.dev/docs/checks/#ST1023
|
||||
- ST1023
|
||||
# Use plain channel send or receive instead of single-case select.
|
||||
# https://staticcheck.dev/docs/checks/#S1000
|
||||
- S1000
|
||||
# Replace for loop with call to copy.
|
||||
# https://staticcheck.dev/docs/checks/#S1001
|
||||
- S1001
|
||||
# Omit comparison with boolean constant.
|
||||
# https://staticcheck.dev/docs/checks/#S1002
|
||||
- S1002
|
||||
# Replace call to 'strings.Index' with 'strings.Contains'.
|
||||
# https://staticcheck.dev/docs/checks/#S1003
|
||||
- S1003
|
||||
# Replace call to 'bytes.Compare' with 'bytes.Equal'.
|
||||
# https://staticcheck.dev/docs/checks/#S1004
|
||||
- S1004
|
||||
# Drop unnecessary use of the blank identifier.
|
||||
# https://staticcheck.dev/docs/checks/#S1005
|
||||
- S1005
|
||||
# Use "for { ... }" for infinite loops.
|
||||
# https://staticcheck.dev/docs/checks/#S1006
|
||||
- S1006
|
||||
# Simplify regular expression by using raw string literal.
|
||||
# https://staticcheck.dev/docs/checks/#S1007
|
||||
- S1007
|
||||
# Simplify returning boolean expression.
|
||||
# https://staticcheck.dev/docs/checks/#S1008
|
||||
- S1008
|
||||
# Omit redundant nil check on slices, maps, and channels.
|
||||
# https://staticcheck.dev/docs/checks/#S1009
|
||||
- S1009
|
||||
# Omit default slice index.
|
||||
# https://staticcheck.dev/docs/checks/#S1010
|
||||
- S1010
|
||||
# Use a single 'append' to concatenate two slices.
|
||||
# https://staticcheck.dev/docs/checks/#S1011
|
||||
- S1011
|
||||
# Replace 'time.Now().Sub(x)' with 'time.Since(x)'.
|
||||
# https://staticcheck.dev/docs/checks/#S1012
|
||||
- S1012
|
||||
# Use a type conversion instead of manually copying struct fields.
|
||||
# https://staticcheck.dev/docs/checks/#S1016
|
||||
- S1016
|
||||
# Replace manual trimming with 'strings.TrimPrefix'.
|
||||
# https://staticcheck.dev/docs/checks/#S1017
|
||||
- S1017
|
||||
# Use "copy" for sliding elements.
|
||||
# https://staticcheck.dev/docs/checks/#S1018
|
||||
- S1018
|
||||
# Simplify "make" call by omitting redundant arguments.
|
||||
# https://staticcheck.dev/docs/checks/#S1019
|
||||
- S1019
|
||||
# Omit redundant nil check in type assertion.
|
||||
# https://staticcheck.dev/docs/checks/#S1020
|
||||
- S1020
|
||||
# Merge variable declaration and assignment.
|
||||
# https://staticcheck.dev/docs/checks/#S1021
|
||||
- S1021
|
||||
# Omit redundant control flow.
|
||||
# https://staticcheck.dev/docs/checks/#S1023
|
||||
- S1023
|
||||
# Replace 'x.Sub(time.Now())' with 'time.Until(x)'.
|
||||
# https://staticcheck.dev/docs/checks/#S1024
|
||||
- S1024
|
||||
# Don't use 'fmt.Sprintf("%s", x)' unnecessarily.
|
||||
# https://staticcheck.dev/docs/checks/#S1025
|
||||
- S1025
|
||||
# Simplify error construction with 'fmt.Errorf'.
|
||||
# https://staticcheck.dev/docs/checks/#S1028
|
||||
- S1028
|
||||
# Range over the string directly.
|
||||
# https://staticcheck.dev/docs/checks/#S1029
|
||||
- S1029
|
||||
# Use 'bytes.Buffer.String' or 'bytes.Buffer.Bytes'.
|
||||
# https://staticcheck.dev/docs/checks/#S1030
|
||||
- S1030
|
||||
# Omit redundant nil check around loop.
|
||||
# https://staticcheck.dev/docs/checks/#S1031
|
||||
- S1031
|
||||
# Use 'sort.Ints(x)', 'sort.Float64s(x)', and 'sort.Strings(x)'.
|
||||
# https://staticcheck.dev/docs/checks/#S1032
|
||||
- S1032
|
||||
# Unnecessary guard around call to "delete".
|
||||
# https://staticcheck.dev/docs/checks/#S1033
|
||||
- S1033
|
||||
# Use result of type assertion to simplify cases.
|
||||
# https://staticcheck.dev/docs/checks/#S1034
|
||||
- S1034
|
||||
# Redundant call to 'net/http.CanonicalHeaderKey' in method call on 'net/http.Header'.
|
||||
# https://staticcheck.dev/docs/checks/#S1035
|
||||
- S1035
|
||||
# Unnecessary guard around map access.
|
||||
# https://staticcheck.dev/docs/checks/#S1036
|
||||
- S1036
|
||||
# Elaborate way of sleeping.
|
||||
# https://staticcheck.dev/docs/checks/#S1037
|
||||
- S1037
|
||||
# Unnecessarily complex way of printing formatted string.
|
||||
# https://staticcheck.dev/docs/checks/#S1038
|
||||
- S1038
|
||||
# Unnecessary use of 'fmt.Sprint'.
|
||||
# https://staticcheck.dev/docs/checks/#S1039
|
||||
- S1039
|
||||
# Type assertion to current type.
|
||||
# https://staticcheck.dev/docs/checks/#S1040
|
||||
- S1040
|
||||
# Apply De Morgan's law.
|
||||
# https://staticcheck.dev/docs/checks/#QF1001
|
||||
- QF1001
|
||||
# Convert untagged switch to tagged switch.
|
||||
# https://staticcheck.dev/docs/checks/#QF1002
|
||||
- QF1002
|
||||
# Convert if/else-if chain to tagged switch.
|
||||
# https://staticcheck.dev/docs/checks/#QF1003
|
||||
- QF1003
|
||||
# Use 'strings.ReplaceAll' instead of 'strings.Replace' with 'n == -1'.
|
||||
# https://staticcheck.dev/docs/checks/#QF1004
|
||||
- QF1004
|
||||
# Expand call to 'math.Pow'.
|
||||
# https://staticcheck.dev/docs/checks/#QF1005
|
||||
- QF1005
|
||||
# Lift 'if'+'break' into loop condition.
|
||||
# https://staticcheck.dev/docs/checks/#QF1006
|
||||
- QF1006
|
||||
# Merge conditional assignment into variable declaration.
|
||||
# https://staticcheck.dev/docs/checks/#QF1007
|
||||
- QF1007
|
||||
# Omit embedded fields from selector expression.
|
||||
# https://staticcheck.dev/docs/checks/#QF1008
|
||||
- QF1008
|
||||
# Use 'time.Time.Equal' instead of '==' operator.
|
||||
# https://staticcheck.dev/docs/checks/#QF1009
|
||||
- QF1009
|
||||
# Convert slice of bytes to string when printing it.
|
||||
# https://staticcheck.dev/docs/checks/#QF1010
|
||||
- QF1010
|
||||
# Omit redundant type from variable declaration.
|
||||
# https://staticcheck.dev/docs/checks/#QF1011
|
||||
- QF1011
|
||||
# Use 'fmt.Fprintf(x, ...)' instead of 'x.Write(fmt.Sprintf(...))'.
|
||||
# https://staticcheck.dev/docs/checks/#QF1012
|
||||
- QF1012
|
||||
unused:
|
||||
# Mark all struct fields that have been written to as used.
|
||||
# Default: true
|
||||
field-writes-are-uses: false
|
||||
# Treat IncDec statement (e.g. `i++` or `i--`) as both read and write operation instead of just write.
|
||||
# Default: false
|
||||
post-statements-are-reads: true
|
||||
# Mark all exported fields as used.
|
||||
# default: true
|
||||
exported-fields-are-used: false
|
||||
# Mark all function parameters as used.
|
||||
# default: true
|
||||
parameters-are-used: true
|
||||
# Mark all local variables as used.
|
||||
# default: true
|
||||
local-variables-are-used: false
|
||||
# Mark all identifiers inside generated files as used.
|
||||
# Default: true
|
||||
generated-is-used: false
|
||||
|
||||
formatters:
|
||||
enable:
|
||||
- gofmt
|
||||
settings:
|
||||
gofmt:
|
||||
# Simplify code: gofmt with `-s` option.
|
||||
# Default: true
|
||||
simplify: false
|
||||
# Apply the rewrite rules to the source before reformatting.
|
||||
# https://pkg.go.dev/cmd/gofmt
|
||||
# Default: []
|
||||
rewrite-rules:
|
||||
- pattern: 'interface{}'
|
||||
replacement: 'any'
|
||||
- pattern: 'a[b:len(a)]'
|
||||
replacement: 'a[b:]'
|
||||
@@ -1,6 +1,16 @@
|
||||
.PHONY: wire
|
||||
.PHONY: wire build build-embed
|
||||
|
||||
wire:
|
||||
@echo "生成 Wire 代码..."
|
||||
@cd cmd/server && go generate
|
||||
@echo "Wire 代码生成完成"
|
||||
@echo "Wire 代码生成完成"
|
||||
|
||||
build:
|
||||
@echo "构建后端(不嵌入前端)..."
|
||||
@go build -o bin/server ./cmd/server
|
||||
@echo "构建完成: bin/server"
|
||||
|
||||
build-embed:
|
||||
@echo "构建后端(嵌入前端)..."
|
||||
@go build -tags embed -o bin/server ./cmd/server
|
||||
@echo "构建完成: bin/server (with embedded frontend)"
|
||||
@@ -48,7 +48,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
emailQueueService := service.ProvideEmailQueueService(emailService)
|
||||
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
|
||||
authHandler := handler.NewAuthHandler(authService)
|
||||
userService := service.NewUserService(userRepository, configConfig)
|
||||
userService := service.NewUserService(userRepository)
|
||||
userHandler := handler.NewUserHandler(userService)
|
||||
apiKeyRepository := repository.NewApiKeyRepository(db)
|
||||
groupRepository := repository.NewGroupRepository(db)
|
||||
@@ -67,22 +67,22 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService)
|
||||
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
||||
dashboardHandler := admin.NewDashboardHandler(usageLogRepository)
|
||||
accountRepository := repository.NewAccountRepository(db)
|
||||
proxyRepository := repository.NewProxyRepository(db)
|
||||
proxyExitInfoProber := repository.NewProxyExitInfoProber()
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, usageLogRepository, userSubscriptionRepository, billingCacheService, proxyExitInfoProber)
|
||||
dashboardHandler := admin.NewDashboardHandler(adminService, usageLogRepository)
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber)
|
||||
adminUserHandler := admin.NewUserHandler(adminService)
|
||||
groupHandler := admin.NewGroupHandler(adminService)
|
||||
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
||||
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
||||
rateLimitService := service.NewRateLimitService(accountRepository, configConfig)
|
||||
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, oAuthService, claudeUsageFetcher)
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher)
|
||||
claudeUpstream := repository.NewClaudeUpstream(configConfig)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, claudeUpstream)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, rateLimitService, accountUsageService, accountTestService)
|
||||
oAuthHandler := admin.NewOAuthHandler(oAuthService, adminService)
|
||||
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
||||
proxyHandler := admin.NewProxyHandler(adminService)
|
||||
adminRedeemHandler := admin.NewRedeemHandler(adminService)
|
||||
settingHandler := admin.NewSettingHandler(settingService, emailService)
|
||||
@@ -103,16 +103,16 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
billingService := service.NewBillingService(configConfig, pricingService)
|
||||
identityCache := repository.NewIdentityCache(client)
|
||||
identityService := service.NewIdentityService(identityCache)
|
||||
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService, claudeUpstream)
|
||||
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, claudeUpstream)
|
||||
concurrencyCache := repository.NewConcurrencyCache(client)
|
||||
concurrencyService := service.NewConcurrencyService(concurrencyCache)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, configConfig)
|
||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
|
||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, handlerSettingHandler)
|
||||
groupService := service.NewGroupService(groupRepository)
|
||||
accountService := service.NewAccountService(accountRepository, groupRepository)
|
||||
proxyService := service.NewProxyService(proxyRepository)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, configConfig)
|
||||
services := &service.Services{
|
||||
Auth: authService,
|
||||
User: userService,
|
||||
|
||||
@@ -52,7 +52,7 @@ type PricingConfig struct {
|
||||
type ServerConfig struct {
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
Mode string `mapstructure:"mode"` // debug/release
|
||||
Mode string `mapstructure:"mode"` // debug/release
|
||||
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
|
||||
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
|
||||
}
|
||||
@@ -163,7 +163,7 @@ func setDefaults() {
|
||||
viper.SetDefault("server.port", 8080)
|
||||
viper.SetDefault("server.mode", "debug")
|
||||
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
|
||||
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
|
||||
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
|
||||
|
||||
// Database
|
||||
viper.SetDefault("database.host", "localhost")
|
||||
@@ -210,10 +210,10 @@ func setDefaults() {
|
||||
|
||||
// TokenRefresh
|
||||
viper.SetDefault("token_refresh.enabled", true)
|
||||
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
|
||||
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
|
||||
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 1.5) // 提前1.5小时刷新
|
||||
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
|
||||
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
|
||||
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
|
||||
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
|
||||
}
|
||||
|
||||
func (c *Config) Validate() error {
|
||||
|
||||
@@ -13,14 +13,12 @@ import (
|
||||
// OAuthHandler handles OAuth-related operations for accounts
|
||||
type OAuthHandler struct {
|
||||
oauthService *service.OAuthService
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewOAuthHandler creates a new OAuth handler
|
||||
func NewOAuthHandler(oauthService *service.OAuthService, adminService service.AdminService) *OAuthHandler {
|
||||
func NewOAuthHandler(oauthService *service.OAuthService) *OAuthHandler {
|
||||
return &OAuthHandler{
|
||||
oauthService: oauthService,
|
||||
adminService: adminService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,29 +44,29 @@ func NewAccountHandler(adminService service.AdminService, oauthService *service.
|
||||
|
||||
// CreateAccountRequest represents create account request
|
||||
type CreateAccountRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Platform string `json:"platform" binding:"required"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"`
|
||||
Credentials map[string]interface{} `json:"credentials" binding:"required"`
|
||||
Extra map[string]interface{} `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
Platform string `json:"platform" binding:"required"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"`
|
||||
Credentials map[string]any `json:"credentials" binding:"required"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
}
|
||||
|
||||
// UpdateAccountRequest represents update account request
|
||||
// 使用指针类型来区分"未提供"和"设置为0"
|
||||
type UpdateAccountRequest struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
|
||||
Credentials map[string]interface{} `json:"credentials"`
|
||||
Extra map[string]interface{} `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
}
|
||||
|
||||
// List handles listing all accounts with pagination
|
||||
@@ -242,7 +240,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
|
||||
newCredentials := make(map[string]interface{})
|
||||
newCredentials := make(map[string]any)
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
@@ -573,7 +571,7 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
|
||||
// For API Key accounts: return models based on model_mapping
|
||||
mapping := account.GetModelMapping()
|
||||
if mapping == nil || len(mapping) == 0 {
|
||||
if len(mapping) == 0 {
|
||||
// No mapping configured, return default models
|
||||
response.Success(c, claude.DefaultModels)
|
||||
return
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -13,17 +12,15 @@ import (
|
||||
|
||||
// DashboardHandler handles admin dashboard statistics
|
||||
type DashboardHandler struct {
|
||||
adminService service.AdminService
|
||||
usageRepo *repository.UsageLogRepository
|
||||
startTime time.Time // Server start time for uptime calculation
|
||||
usageRepo *repository.UsageLogRepository
|
||||
startTime time.Time // Server start time for uptime calculation
|
||||
}
|
||||
|
||||
// NewDashboardHandler creates a new admin dashboard handler
|
||||
func NewDashboardHandler(adminService service.AdminService, usageRepo *repository.UsageLogRepository) *DashboardHandler {
|
||||
func NewDashboardHandler(usageRepo *repository.UsageLogRepository) *DashboardHandler {
|
||||
return &DashboardHandler{
|
||||
adminService: adminService,
|
||||
usageRepo: usageRepo,
|
||||
startTime: time.Now(),
|
||||
usageRepo: usageRepo,
|
||||
startTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -258,7 +255,7 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||
}
|
||||
|
||||
if len(req.UserIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]interface{}{}})
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -286,7 +283,7 @@ func (h *DashboardHandler) GetBatchApiKeysUsage(c *gin.Context) {
|
||||
}
|
||||
|
||||
if len(req.ApiKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]interface{}{}})
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -236,7 +236,6 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
|
||||
response.Paginated(c, accounts, total, page, pageSize)
|
||||
}
|
||||
|
||||
|
||||
// BatchCreateProxyItem represents a single proxy in batch create request
|
||||
type BatchCreateProxyItem struct {
|
||||
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
|
||||
|
||||
@@ -156,10 +156,10 @@ func (h *RedeemHandler) Expire(c *gin.Context) {
|
||||
func (h *RedeemHandler) GetStats(c *gin.Context) {
|
||||
// Return mock data for now
|
||||
response.Success(c, gin.H{
|
||||
"total_codes": 0,
|
||||
"active_codes": 0,
|
||||
"used_codes": 0,
|
||||
"expired_codes": 0,
|
||||
"total_codes": 0,
|
||||
"active_codes": 0,
|
||||
"used_codes": 0,
|
||||
"expired_codes": 0,
|
||||
"total_value_distributed": 0.0,
|
||||
"by_type": gin.H{
|
||||
"balance": 0,
|
||||
@@ -187,7 +187,10 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
||||
writer := csv.NewWriter(&buf)
|
||||
|
||||
// Write header
|
||||
writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"})
|
||||
if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"}); err != nil {
|
||||
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Write data rows
|
||||
for _, code := range codes {
|
||||
@@ -199,7 +202,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
||||
if code.UsedAt != nil {
|
||||
usedAt = code.UsedAt.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
writer.Write([]string{
|
||||
if err := writer.Write([]string{
|
||||
fmt.Sprintf("%d", code.ID),
|
||||
code.Code,
|
||||
code.Type,
|
||||
@@ -208,10 +211,17 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
||||
usedBy,
|
||||
usedAt,
|
||||
code.CreatedAt.Format("2006-01-02 15:04:05"),
|
||||
})
|
||||
}); err != nil {
|
||||
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
writer.Flush()
|
||||
if err := writer.Error(); err != nil {
|
||||
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/csv")
|
||||
c.Header("Content-Disposition", "attachment; filename=redeem_codes.csv")
|
||||
|
||||
@@ -193,7 +193,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
func (h *UsageHandler) SearchUsers(c *gin.Context) {
|
||||
keyword := c.Query("q")
|
||||
if keyword == "" {
|
||||
response.Success(c, []interface{}{})
|
||||
response.Success(c, []any{})
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -268,7 +268,9 @@ func (h *GatewayHandler) waitForSlotWithPing(c *gin.Context, slotType string, id
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
*streamStarted = true
|
||||
}
|
||||
fmt.Fprintf(c.Writer, "data: {\"type\": \"ping\"}\n\n")
|
||||
if _, err := fmt.Fprintf(c.Writer, "data: {\"type\": \"ping\"}\n\n"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
@@ -414,7 +416,9 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
|
||||
if ok {
|
||||
// Send error event in SSE format
|
||||
errorEvent := fmt.Sprintf(`data: {"type": "error", "error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
|
||||
fmt.Fprint(c.Writer, errorEvent)
|
||||
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||||
_ = c.Error(err)
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
return
|
||||
@@ -574,11 +578,11 @@ func sendMockWarmupStream(c *gin.Context, model string) {
|
||||
// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截)
|
||||
func sendMockWarmupResponse(c *gin.Context, model string) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"id": "msg_mock_warmup",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": []gin.H{{"type": "text", "text": "New Conversation"}},
|
||||
"id": "msg_mock_warmup",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": []gin.H{{"type": "text", "text": "New Conversation"}},
|
||||
"stop_reason": "end_turn",
|
||||
"usage": gin.H{
|
||||
"input_tokens": 10,
|
||||
|
||||
@@ -358,7 +358,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
|
||||
}
|
||||
|
||||
if len(req.ApiKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]interface{}{}})
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -383,7 +383,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
|
||||
}
|
||||
|
||||
if len(validApiKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]interface{}{}})
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
// JSONB 用于存储JSONB数据
|
||||
type JSONB map[string]interface{}
|
||||
type JSONB map[string]any
|
||||
|
||||
func (j JSONB) Value() (driver.Value, error) {
|
||||
if j == nil {
|
||||
@@ -19,7 +19,7 @@ func (j JSONB) Value() (driver.Value, error) {
|
||||
return json.Marshal(j)
|
||||
}
|
||||
|
||||
func (j *JSONB) Scan(value interface{}) error {
|
||||
func (j *JSONB) Scan(value any) error {
|
||||
if value == nil {
|
||||
*j = nil
|
||||
return nil
|
||||
@@ -40,8 +40,8 @@ type Account struct {
|
||||
Extra JSONB `gorm:"type:jsonb;default:'{}'" json:"extra"` // 扩展信息
|
||||
ProxyID *int64 `gorm:"index" json:"proxy_id"`
|
||||
Concurrency int `gorm:"default:3;not null" json:"concurrency"`
|
||||
Priority int `gorm:"default:50;not null" json:"priority"` // 1-100,越小越高
|
||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error
|
||||
Priority int `gorm:"default:50;not null" json:"priority"` // 1-100,越小越高
|
||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error
|
||||
ErrorMessage string `gorm:"type:text" json:"error_message"`
|
||||
LastUsedAt *time.Time `gorm:"index" json:"last_used_at"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
@@ -145,7 +145,7 @@ func (a *Account) GetModelMapping() map[string]string {
|
||||
return nil
|
||||
}
|
||||
// 处理map[string]interface{}类型
|
||||
if m, ok := raw.(map[string]interface{}); ok {
|
||||
if m, ok := raw.(map[string]any); ok {
|
||||
result := make(map[string]string)
|
||||
for k, v := range m {
|
||||
if s, ok := v.(string); ok {
|
||||
@@ -163,7 +163,7 @@ func (a *Account) GetModelMapping() map[string]string {
|
||||
// 如果没有设置模型映射,则支持所有模型
|
||||
func (a *Account) IsModelSupported(requestedModel string) bool {
|
||||
mapping := a.GetModelMapping()
|
||||
if mapping == nil || len(mapping) == 0 {
|
||||
if len(mapping) == 0 {
|
||||
return true // 没有映射配置,支持所有模型
|
||||
}
|
||||
_, exists := mapping[requestedModel]
|
||||
@@ -174,7 +174,7 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
|
||||
// 如果没有映射,返回原始模型名
|
||||
func (a *Account) GetMappedModel(requestedModel string) string {
|
||||
mapping := a.GetModelMapping()
|
||||
if mapping == nil || len(mapping) == 0 {
|
||||
if len(mapping) == 0 {
|
||||
return requestedModel
|
||||
}
|
||||
if mappedModel, exists := mapping[requestedModel]; exists {
|
||||
@@ -231,7 +231,7 @@ func (a *Account) GetCustomErrorCodes() []int {
|
||||
return nil
|
||||
}
|
||||
// 处理 []interface{} 类型(JSON反序列化后的格式)
|
||||
if arr, ok := raw.([]interface{}); ok {
|
||||
if arr, ok := raw.([]any); ok {
|
||||
result := make([]int, 0, len(arr))
|
||||
for _, v := range arr {
|
||||
// JSON 数字默认解析为 float64
|
||||
|
||||
@@ -13,13 +13,13 @@ const (
|
||||
)
|
||||
|
||||
type Group struct {
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"uniqueIndex;size:100;not null" json:"name"`
|
||||
Description string `gorm:"type:text" json:"description"`
|
||||
Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini
|
||||
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"`
|
||||
IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"`
|
||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"uniqueIndex;size:100;not null" json:"name"`
|
||||
Description string `gorm:"type:text" json:"description"`
|
||||
Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini
|
||||
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"`
|
||||
IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"`
|
||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
|
||||
|
||||
// 订阅功能字段
|
||||
SubscriptionType string `gorm:"size:20;default:standard;not null" json:"subscription_type"` // standard/subscription
|
||||
|
||||
@@ -9,15 +9,15 @@ import (
|
||||
type RedeemCode struct {
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
Code string `gorm:"uniqueIndex;size:32;not null" json:"code"`
|
||||
Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription
|
||||
Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数
|
||||
Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription
|
||||
Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数
|
||||
Status string `gorm:"size:20;default:unused;not null" json:"status"` // unused/used
|
||||
UsedBy *int64 `gorm:"index" json:"used_by"`
|
||||
UsedAt *time.Time `json:"used_at"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
|
||||
// 订阅类型专用字段
|
||||
GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用)
|
||||
GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用)
|
||||
ValidityDays int `gorm:"default:30" json:"validity_days"` // 订阅有效天数 (仅subscription类型使用)
|
||||
|
||||
// 关联
|
||||
@@ -40,8 +40,10 @@ func (r *RedeemCode) CanUse() bool {
|
||||
}
|
||||
|
||||
// GenerateRedeemCode 生成唯一的兑换码
|
||||
func GenerateRedeemCode() string {
|
||||
func GenerateRedeemCode() (string, error) {
|
||||
b := make([]byte, 16)
|
||||
rand.Read(b)
|
||||
return hex.EncodeToString(b)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
@@ -19,17 +19,17 @@ func (Setting) TableName() string {
|
||||
// 设置Key常量
|
||||
const (
|
||||
// 注册设置
|
||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||
|
||||
// 邮件服务设置
|
||||
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
|
||||
SettingKeySmtpPort = "smtp_port" // SMTP端口
|
||||
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
|
||||
SettingKeySmtpPassword = "smtp_password" // SMTP密码(加密存储)
|
||||
SettingKeySmtpFrom = "smtp_from" // 发件人地址
|
||||
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
|
||||
SettingKeySmtpPort = "smtp_port" // SMTP端口
|
||||
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
|
||||
SettingKeySmtpPassword = "smtp_password" // SMTP密码(加密存储)
|
||||
SettingKeySmtpFrom = "smtp_from" // 发件人地址
|
||||
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
|
||||
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
|
||||
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
|
||||
|
||||
// Cloudflare Turnstile 设置
|
||||
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
|
||||
|
||||
@@ -37,7 +37,7 @@ type UsageLog struct {
|
||||
OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"output_cost"`
|
||||
CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_creation_cost"`
|
||||
CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_read_cost"`
|
||||
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用
|
||||
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用
|
||||
ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"actual_cost"` // 实际扣除费用
|
||||
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null" json:"rate_multiplier"` // 计费倍率
|
||||
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
Email string `gorm:"uniqueIndex;size:255;not null" json:"email"`
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
Email string `gorm:"uniqueIndex;size:255;not null" json:"email"`
|
||||
PasswordHash string `gorm:"size:255;not null" json:"-"`
|
||||
Role string `gorm:"size:20;default:user;not null" json:"role"` // admin/user
|
||||
Balance float64 `gorm:"type:decimal(20,8);default:0;not null" json:"balance"`
|
||||
|
||||
@@ -9,22 +9,22 @@ import (
|
||||
|
||||
// Response 标准API响应格式
|
||||
type Response struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data any `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// PaginatedData 分页数据格式(匹配前端期望)
|
||||
type PaginatedData struct {
|
||||
Items interface{} `json:"items"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
Pages int `json:"pages"`
|
||||
Items any `json:"items"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
Pages int `json:"pages"`
|
||||
}
|
||||
|
||||
// Success 返回成功响应
|
||||
func Success(c *gin.Context, data interface{}) {
|
||||
func Success(c *gin.Context, data any) {
|
||||
c.JSON(http.StatusOK, Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
@@ -33,7 +33,7 @@ func Success(c *gin.Context, data interface{}) {
|
||||
}
|
||||
|
||||
// Created 返回创建成功响应
|
||||
func Created(c *gin.Context, data interface{}) {
|
||||
func Created(c *gin.Context, data any) {
|
||||
c.JSON(http.StatusCreated, Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
@@ -75,7 +75,7 @@ func InternalError(c *gin.Context, message string) {
|
||||
}
|
||||
|
||||
// Paginated 返回分页数据
|
||||
func Paginated(c *gin.Context, items interface{}, total int64, page, pageSize int) {
|
||||
func Paginated(c *gin.Context, items any, total int64, page, pageSize int) {
|
||||
pages := int(math.Ceil(float64(total) / float64(pageSize)))
|
||||
if pages < 1 {
|
||||
pages = 1
|
||||
@@ -99,7 +99,7 @@ type PaginationResult struct {
|
||||
}
|
||||
|
||||
// PaginatedWithResult 使用PaginationResult返回分页数据
|
||||
func PaginatedWithResult(c *gin.Context, items interface{}, pagination *PaginationResult) {
|
||||
func PaginatedWithResult(c *gin.Context, items any, pagination *PaginationResult) {
|
||||
if pagination == nil {
|
||||
Success(c, PaginatedData{
|
||||
Items: items,
|
||||
|
||||
@@ -37,11 +37,15 @@ func TestInitInvalidTimezone(t *testing.T) {
|
||||
|
||||
func TestTimeNowAffected(t *testing.T) {
|
||||
// Reset to UTC first
|
||||
Init("UTC")
|
||||
if err := Init("UTC"); err != nil {
|
||||
t.Fatalf("Init failed with UTC: %v", err)
|
||||
}
|
||||
utcNow := time.Now()
|
||||
|
||||
// Switch to Shanghai (UTC+8)
|
||||
Init("Asia/Shanghai")
|
||||
if err := Init("Asia/Shanghai"); err != nil {
|
||||
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||
}
|
||||
shanghaiNow := time.Now()
|
||||
|
||||
// The times should be the same instant, but different timezone representation
|
||||
@@ -58,7 +62,9 @@ func TestTimeNowAffected(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestToday(t *testing.T) {
|
||||
Init("Asia/Shanghai")
|
||||
if err := Init("Asia/Shanghai"); err != nil {
|
||||
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||
}
|
||||
|
||||
today := Today()
|
||||
now := Now()
|
||||
@@ -75,7 +81,9 @@ func TestToday(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStartOfDay(t *testing.T) {
|
||||
Init("Asia/Shanghai")
|
||||
if err := Init("Asia/Shanghai"); err != nil {
|
||||
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||
}
|
||||
|
||||
// Create a time at 15:30:45
|
||||
testTime := time.Date(2024, 6, 15, 15, 30, 45, 123456789, Location())
|
||||
@@ -91,7 +99,9 @@ func TestTruncateVsStartOfDay(t *testing.T) {
|
||||
// This test demonstrates why Truncate(24*time.Hour) can be problematic
|
||||
// and why StartOfDay is more reliable for timezone-aware code
|
||||
|
||||
Init("Asia/Shanghai")
|
||||
if err := Init("Asia/Shanghai"); err != nil {
|
||||
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||
}
|
||||
|
||||
now := Now()
|
||||
|
||||
|
||||
@@ -131,7 +131,7 @@ func (r *AccountRepository) UpdateLastUsed(ctx context.Context, id int64) error
|
||||
|
||||
func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
Updates(map[string]any{
|
||||
"status": model.StatusError,
|
||||
"error_message": errorMsg,
|
||||
}).Error
|
||||
@@ -226,7 +226,7 @@ func (r *AccountRepository) ListSchedulableByGroupID(ctx context.Context, groupI
|
||||
func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
now := time.Now()
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
Updates(map[string]any{
|
||||
"rate_limited_at": now,
|
||||
"rate_limit_reset_at": resetAt,
|
||||
}).Error
|
||||
@@ -241,7 +241,7 @@ func (r *AccountRepository) SetOverloaded(ctx context.Context, id int64, until t
|
||||
// ClearRateLimit 清除账号的限流状态
|
||||
func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
Updates(map[string]any{
|
||||
"rate_limited_at": nil,
|
||||
"rate_limit_reset_at": nil,
|
||||
"overload_until": nil,
|
||||
@@ -250,7 +250,7 @@ func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error
|
||||
|
||||
// UpdateSessionWindow 更新账号的5小时时间窗口信息
|
||||
func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
updates := map[string]interface{}{
|
||||
updates := map[string]any{
|
||||
"session_window_status": status,
|
||||
}
|
||||
if start != nil {
|
||||
|
||||
@@ -143,7 +143,7 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID
|
||||
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
fields := map[string]interface{}{
|
||||
fields := map[string]any{
|
||||
subFieldStatus: data.Status,
|
||||
subFieldExpiresAt: data.ExpiresAt.Unix(),
|
||||
subFieldDailyUsage: data.DailyUsage,
|
||||
|
||||
@@ -64,7 +64,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
|
||||
|
||||
authURL := fmt.Sprintf("https://claude.ai/v1/oauth/%s/authorize", orgUUID)
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
reqBody := map[string]any{
|
||||
"response_type": "code",
|
||||
"client_id": oauth.ClientID,
|
||||
"organization_uuid": orgUUID,
|
||||
@@ -155,7 +155,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
|
||||
}
|
||||
}
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
reqBody := map[string]any{
|
||||
"code": authCode,
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": oauth.ClientID,
|
||||
|
||||
@@ -19,7 +19,11 @@ func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
|
||||
}
|
||||
|
||||
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport, ok := http.DefaultTransport.(*http.Transport)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to get default transport")
|
||||
}
|
||||
transport = transport.Clone()
|
||||
if proxyURL != "" {
|
||||
if parsedURL, err := url.Parse(proxyURL); err == nil {
|
||||
transport.Proxy = http.ProxyURL(parsedURL)
|
||||
@@ -43,7 +47,7 @@ func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyU
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
@@ -38,7 +38,7 @@ func (c *githubReleaseClient) FetchLatestRelease(ctx context.Context, repo strin
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode)
|
||||
@@ -63,7 +63,7 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("download returned %d", resp.StatusCode)
|
||||
@@ -78,7 +78,7 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
defer func() { _ = out.Close() }()
|
||||
|
||||
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
|
||||
limited := io.LimitReader(resp.Body, maxSize+1)
|
||||
@@ -89,7 +89,7 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
|
||||
|
||||
// Check if we hit the limit (downloaded more than maxSize)
|
||||
if written > maxSize {
|
||||
os.Remove(dest) // Clean up partial file
|
||||
_ = os.Remove(dest) // Clean up partial file (best-effort)
|
||||
return fmt.Errorf("download exceeded maximum size of %d bytes", maxSize)
|
||||
}
|
||||
|
||||
@@ -106,7 +106,7 @@ func (c *githubReleaseClient) FetchChecksumFile(ctx context.Context, url string)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
|
||||
@@ -33,7 +33,7 @@ func (c *pricingRemoteClient) FetchPricingJSON(ctx context.Context, url string)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
@@ -52,7 +52,7 @@ func (c *pricingRemoteClient) FetchHashText(ctx context.Context, url string) (st
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
|
||||
@@ -43,7 +43,7 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("proxy connection failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
latencyMs := time.Since(startTime).Milliseconds()
|
||||
|
||||
|
||||
@@ -99,7 +99,7 @@ func (r *RedeemCodeRepository) Use(ctx context.Context, id, userID int64) error
|
||||
now := time.Now()
|
||||
result := r.db.WithContext(ctx).Model(&model.RedeemCode{}).
|
||||
Where("id = ? AND status = ?", id, model.StatusUnused).
|
||||
Updates(map[string]interface{}{
|
||||
Updates(map[string]any{
|
||||
"status": model.StatusUsed,
|
||||
"used_by": userID,
|
||||
"used_at": now,
|
||||
|
||||
@@ -44,7 +44,7 @@ func (v *turnstileVerifier) VerifyToken(ctx context.Context, secretKey, token, r
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
var result service.TurnstileVerifyResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
|
||||
@@ -185,7 +185,7 @@ func (r *UserSubscriptionRepository) List(ctx context.Context, params pagination
|
||||
func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
Updates(map[string]any{
|
||||
"daily_usage_usd": gorm.Expr("daily_usage_usd + ?", costUSD),
|
||||
"weekly_usage_usd": gorm.Expr("weekly_usage_usd + ?", costUSD),
|
||||
"monthly_usage_usd": gorm.Expr("monthly_usage_usd + ?", costUSD),
|
||||
@@ -197,7 +197,7 @@ func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
|
||||
func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
Updates(map[string]any{
|
||||
"daily_usage_usd": 0,
|
||||
"daily_window_start": newWindowStart,
|
||||
"updated_at": time.Now(),
|
||||
@@ -208,7 +208,7 @@ func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int
|
||||
func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
Updates(map[string]any{
|
||||
"weekly_usage_usd": 0,
|
||||
"weekly_window_start": newWindowStart,
|
||||
"updated_at": time.Now(),
|
||||
@@ -219,7 +219,7 @@ func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id in
|
||||
func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
Updates(map[string]any{
|
||||
"monthly_usage_usd": 0,
|
||||
"monthly_window_start": newWindowStart,
|
||||
"updated_at": time.Now(),
|
||||
@@ -230,7 +230,7 @@ func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id i
|
||||
func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
Updates(map[string]any{
|
||||
"daily_window_start": activateTime,
|
||||
"weekly_window_start": activateTime,
|
||||
"monthly_window_start": activateTime,
|
||||
@@ -242,7 +242,7 @@ func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int
|
||||
func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
Updates(map[string]any{
|
||||
"status": status,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
@@ -252,7 +252,7 @@ func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64,
|
||||
func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
Updates(map[string]any{
|
||||
"expires_at": newExpiresAt,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
@@ -262,7 +262,7 @@ func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64,
|
||||
func (r *UserSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
Updates(map[string]any{
|
||||
"notes": notes,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
@@ -281,7 +281,7 @@ func (r *UserSubscriptionRepository) ListExpired(ctx context.Context) ([]model.U
|
||||
func (r *UserSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
|
||||
result := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
|
||||
Updates(map[string]interface{}{
|
||||
Updates(map[string]any{
|
||||
"status": model.SubscriptionStatusExpired,
|
||||
"updated_at": time.Now(),
|
||||
})
|
||||
|
||||
@@ -17,27 +17,27 @@ var (
|
||||
|
||||
// CreateAccountRequest 创建账号请求
|
||||
type CreateAccountRequest struct {
|
||||
Name string `json:"name"`
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
Credentials map[string]interface{} `json:"credentials"`
|
||||
Extra map[string]interface{} `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
Name string `json:"name"`
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
}
|
||||
|
||||
// UpdateAccountRequest 更新账号请求
|
||||
type UpdateAccountRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Credentials *map[string]interface{} `json:"credentials"`
|
||||
Extra *map[string]interface{} `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
Status *string `json:"status"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
Name *string `json:"name"`
|
||||
Credentials *map[string]any `json:"credentials"`
|
||||
Extra *map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
Status *string `json:"status"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
}
|
||||
|
||||
// AccountService 账号管理服务
|
||||
|
||||
@@ -51,22 +51,29 @@ func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OA
|
||||
}
|
||||
|
||||
// generateSessionString generates a Claude Code style session string
|
||||
func generateSessionString() string {
|
||||
func generateSessionString() (string, error) {
|
||||
bytes := make([]byte, 32)
|
||||
rand.Read(bytes)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
hex64 := hex.EncodeToString(bytes)
|
||||
sessionUUID := uuid.New().String()
|
||||
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID)
|
||||
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID), nil
|
||||
}
|
||||
|
||||
// createTestPayload creates a Claude Code style test request payload
|
||||
func createTestPayload(modelID string) map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
func createTestPayload(modelID string) (map[string]any, error) {
|
||||
sessionID, err := generateSessionString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"model": modelID,
|
||||
"messages": []map[string]interface{}{
|
||||
"messages": []map[string]any{
|
||||
{
|
||||
"role": "user",
|
||||
"content": []map[string]interface{}{
|
||||
"content": []map[string]any{
|
||||
{
|
||||
"type": "text",
|
||||
"text": "hi",
|
||||
@@ -77,7 +84,7 @@ func createTestPayload(modelID string) map[string]interface{} {
|
||||
},
|
||||
},
|
||||
},
|
||||
"system": []map[string]interface{}{
|
||||
"system": []map[string]any{
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
@@ -87,12 +94,12 @@ func createTestPayload(modelID string) map[string]interface{} {
|
||||
},
|
||||
},
|
||||
"metadata": map[string]string{
|
||||
"user_id": generateSessionString(),
|
||||
"user_id": sessionID,
|
||||
},
|
||||
"max_tokens": 1024,
|
||||
"temperature": 1,
|
||||
"stream": true,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestAccountConnection tests an account's connection by sending a test request
|
||||
@@ -116,7 +123,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
// For API Key accounts with model mapping, map the model
|
||||
if account.Type == "apikey" {
|
||||
mapping := account.GetModelMapping()
|
||||
if mapping != nil && len(mapping) > 0 {
|
||||
if len(mapping) > 0 {
|
||||
if mappedModel, exists := mapping[testModelID]; exists {
|
||||
testModelID = mappedModel
|
||||
}
|
||||
@@ -178,7 +185,10 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
c.Writer.Flush()
|
||||
|
||||
// Create Claude Code style payload (same for all account types)
|
||||
payload := createTestPayload(testModelID)
|
||||
payload, err := createTestPayload(testModelID)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, "Failed to create test payload")
|
||||
}
|
||||
payloadBytes, _ := json.Marshal(payload)
|
||||
|
||||
// Send test_start event
|
||||
@@ -216,7 +226,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -252,7 +262,7 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
|
||||
return nil
|
||||
}
|
||||
|
||||
var data map[string]interface{}
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
|
||||
continue
|
||||
}
|
||||
@@ -261,7 +271,7 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
|
||||
|
||||
switch eventType {
|
||||
case "content_block_delta":
|
||||
if delta, ok := data["delta"].(map[string]interface{}); ok {
|
||||
if delta, ok := data["delta"].(map[string]any); ok {
|
||||
if text, ok := delta["text"].(string); ok {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: text})
|
||||
}
|
||||
@@ -271,7 +281,7 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
|
||||
return nil
|
||||
case "error":
|
||||
errorMsg := "Unknown error"
|
||||
if errData, ok := data["error"].(map[string]interface{}); ok {
|
||||
if errData, ok := data["error"].(map[string]any); ok {
|
||||
if msg, ok := errData["message"].(string); ok {
|
||||
errorMsg = msg
|
||||
}
|
||||
@@ -284,7 +294,10 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
|
||||
// sendEvent sends a SSE event to the client
|
||||
func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
|
||||
eventJSON, _ := json.Marshal(event)
|
||||
fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON)
|
||||
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil {
|
||||
log.Printf("failed to write SSE event: %v", err)
|
||||
return
|
||||
}
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
|
||||
@@ -70,16 +70,14 @@ type ClaudeUsageFetcher interface {
|
||||
type AccountUsageService struct {
|
||||
accountRepo ports.AccountRepository
|
||||
usageLogRepo ports.UsageLogRepository
|
||||
oauthService *OAuthService
|
||||
usageFetcher ClaudeUsageFetcher
|
||||
}
|
||||
|
||||
// NewAccountUsageService 创建AccountUsageService实例
|
||||
func NewAccountUsageService(accountRepo ports.AccountRepository, usageLogRepo ports.UsageLogRepository, oauthService *OAuthService, usageFetcher ClaudeUsageFetcher) *AccountUsageService {
|
||||
func NewAccountUsageService(accountRepo ports.AccountRepository, usageLogRepo ports.UsageLogRepository, usageFetcher ClaudeUsageFetcher) *AccountUsageService {
|
||||
return &AccountUsageService{
|
||||
accountRepo: accountRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
oauthService: oauthService,
|
||||
usageFetcher: usageFetcher,
|
||||
}
|
||||
}
|
||||
@@ -98,8 +96,10 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
||||
if account.CanGetUsage() {
|
||||
// 检查缓存
|
||||
if cached, ok := usageCacheMap.Load(accountID); ok {
|
||||
cache := cached.(*usageCache)
|
||||
if time.Since(cache.timestamp) < cacheTTL {
|
||||
cache, ok := cached.(*usageCache)
|
||||
if !ok {
|
||||
usageCacheMap.Delete(accountID)
|
||||
} else if time.Since(cache.timestamp) < cacheTTL {
|
||||
return cache.data, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
@@ -23,7 +24,7 @@ type AdminService interface {
|
||||
DeleteUser(ctx context.Context, id int64) error
|
||||
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string) (*model.User, error)
|
||||
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error)
|
||||
GetUserUsageStats(ctx context.Context, userID int64, period string) (interface{}, error)
|
||||
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
|
||||
|
||||
// Group management
|
||||
ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error)
|
||||
@@ -113,8 +114,8 @@ type CreateAccountInput struct {
|
||||
Name string
|
||||
Platform string
|
||||
Type string
|
||||
Credentials map[string]interface{}
|
||||
Extra map[string]interface{}
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
ProxyID *int64
|
||||
Concurrency int
|
||||
Priority int
|
||||
@@ -124,8 +125,8 @@ type CreateAccountInput struct {
|
||||
type UpdateAccountInput struct {
|
||||
Name string
|
||||
Type string // Account type: oauth, setup-token, apikey
|
||||
Credentials map[string]interface{}
|
||||
Extra map[string]interface{}
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
ProxyID *int64
|
||||
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
||||
Priority *int // 使用指针区分"未提供"和"设置为0"
|
||||
@@ -192,8 +193,6 @@ type adminServiceImpl struct {
|
||||
proxyRepo ports.ProxyRepository
|
||||
apiKeyRepo ports.ApiKeyRepository
|
||||
redeemCodeRepo ports.RedeemCodeRepository
|
||||
usageLogRepo ports.UsageLogRepository
|
||||
userSubRepo ports.UserSubscriptionRepository
|
||||
billingCacheService *BillingCacheService
|
||||
proxyProber ProxyExitInfoProber
|
||||
}
|
||||
@@ -206,8 +205,6 @@ func NewAdminService(
|
||||
proxyRepo ports.ProxyRepository,
|
||||
apiKeyRepo ports.ApiKeyRepository,
|
||||
redeemCodeRepo ports.RedeemCodeRepository,
|
||||
usageLogRepo ports.UsageLogRepository,
|
||||
userSubRepo ports.UserSubscriptionRepository,
|
||||
billingCacheService *BillingCacheService,
|
||||
proxyProber ProxyExitInfoProber,
|
||||
) AdminService {
|
||||
@@ -218,8 +215,6 @@ func NewAdminService(
|
||||
proxyRepo: proxyRepo,
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
redeemCodeRepo: redeemCodeRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
billingCacheService: billingCacheService,
|
||||
proxyProber: proxyProber,
|
||||
}
|
||||
@@ -309,7 +304,9 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateUserBalance(cacheCtx, id)
|
||||
if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, id); err != nil {
|
||||
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", id, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
@@ -317,8 +314,13 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
// Create adjustment records for balance/concurrency changes
|
||||
balanceDiff := user.Balance - oldBalance
|
||||
if balanceDiff != 0 {
|
||||
code, err := model.GenerateRedeemCode()
|
||||
if err != nil {
|
||||
log.Printf("failed to generate adjustment redeem code: %v", err)
|
||||
return user, nil
|
||||
}
|
||||
adjustmentRecord := &model.RedeemCode{
|
||||
Code: model.GenerateRedeemCode(),
|
||||
Code: code,
|
||||
Type: model.AdjustmentTypeAdminBalance,
|
||||
Value: balanceDiff,
|
||||
Status: model.StatusUsed,
|
||||
@@ -327,15 +329,19 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
now := time.Now()
|
||||
adjustmentRecord.UsedAt = &now
|
||||
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
|
||||
// Log error but don't fail the update
|
||||
// The user update has already succeeded
|
||||
log.Printf("failed to create balance adjustment redeem code: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
concurrencyDiff := user.Concurrency - oldConcurrency
|
||||
if concurrencyDiff != 0 {
|
||||
code, err := model.GenerateRedeemCode()
|
||||
if err != nil {
|
||||
log.Printf("failed to generate adjustment redeem code: %v", err)
|
||||
return user, nil
|
||||
}
|
||||
adjustmentRecord := &model.RedeemCode{
|
||||
Code: model.GenerateRedeemCode(),
|
||||
Code: code,
|
||||
Type: model.AdjustmentTypeAdminConcurrency,
|
||||
Value: float64(concurrencyDiff),
|
||||
Status: model.StatusUsed,
|
||||
@@ -344,8 +350,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
now := time.Now()
|
||||
adjustmentRecord.UsedAt = &now
|
||||
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
|
||||
// Log error but don't fail the update
|
||||
// The user update has already succeeded
|
||||
log.Printf("failed to create concurrency adjustment redeem code: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -388,7 +393,9 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
||||
if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, userID); err != nil {
|
||||
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -404,9 +411,9 @@ func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, pag
|
||||
return keys, result.Total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (interface{}, error) {
|
||||
func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) {
|
||||
// Return mock data for now
|
||||
return map[string]interface{}{
|
||||
return map[string]any{
|
||||
"period": period,
|
||||
"total_requests": 0,
|
||||
"total_cost": 0.0,
|
||||
@@ -579,7 +586,9 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
for _, userID := range affectedUserIDs {
|
||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
if err := s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID); err != nil {
|
||||
log.Printf("invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, groupID, err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -646,10 +655,10 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
if input.Type != "" {
|
||||
account.Type = input.Type
|
||||
}
|
||||
if input.Credentials != nil && len(input.Credentials) > 0 {
|
||||
if len(input.Credentials) > 0 {
|
||||
account.Credentials = model.JSONB(input.Credentials)
|
||||
}
|
||||
if input.Extra != nil && len(input.Extra) > 0 {
|
||||
if len(input.Extra) > 0 {
|
||||
account.Extra = model.JSONB(input.Extra)
|
||||
}
|
||||
if input.ProxyID != nil {
|
||||
@@ -831,8 +840,12 @@ func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *Gener
|
||||
|
||||
codes := make([]model.RedeemCode, 0, input.Count)
|
||||
for i := 0; i < input.Count; i++ {
|
||||
codeValue, err := model.GenerateRedeemCode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
code := model.RedeemCode{
|
||||
Code: model.GenerateRedeemCode(),
|
||||
Code: codeValue,
|
||||
Type: input.Type,
|
||||
Value: input.Value,
|
||||
Status: model.StatusUnused,
|
||||
|
||||
@@ -100,10 +100,13 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error {
|
||||
|
||||
// 检查字符:只允许字母、数字、下划线、连字符
|
||||
for _, c := range key {
|
||||
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
|
||||
(c >= '0' && c <= '9') || c == '_' || c == '-') {
|
||||
return ErrApiKeyInvalidChars
|
||||
if (c >= 'a' && c <= 'z') ||
|
||||
(c >= 'A' && c <= 'Z') ||
|
||||
(c >= '0' && c <= '9') ||
|
||||
c == '_' || c == '-' {
|
||||
continue
|
||||
}
|
||||
return ErrApiKeyInvalidChars
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -278,7 +278,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
|
||||
|
||||
// ValidateToken 验证JWT token并返回用户声明
|
||||
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (any, error) {
|
||||
// 验证签名方法
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
|
||||
@@ -259,11 +259,11 @@ func (s *BillingService) GetEstimatedCost(model string, estimatedInputTokens, es
|
||||
}
|
||||
|
||||
// GetPricingServiceStatus 获取价格服务状态
|
||||
func (s *BillingService) GetPricingServiceStatus() map[string]interface{} {
|
||||
func (s *BillingService) GetPricingServiceStatus() map[string]any {
|
||||
if s.pricingService != nil {
|
||||
return s.pricingService.GetStatus()
|
||||
}
|
||||
return map[string]interface{}{
|
||||
return map[string]any{
|
||||
"model_count": len(s.fallbackPrices),
|
||||
"last_updated": "using fallback",
|
||||
"local_hash": "N/A",
|
||||
|
||||
@@ -9,12 +9,6 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// Wait polling interval
|
||||
waitPollInterval = 100 * time.Millisecond
|
||||
|
||||
// Default max wait time
|
||||
defaultMaxWait = 60 * time.Second
|
||||
|
||||
// Default extra wait slots beyond concurrency limit
|
||||
defaultExtraWaitSlots = 20
|
||||
)
|
||||
@@ -31,7 +25,7 @@ func NewConcurrencyService(cache ports.ConcurrencyCache) *ConcurrencyService {
|
||||
|
||||
// AcquireResult represents the result of acquiring a concurrency slot
|
||||
type AcquireResult struct {
|
||||
Acquired bool
|
||||
Acquired bool
|
||||
ReleaseFunc func() // Must be called when done (typically via defer)
|
||||
}
|
||||
|
||||
@@ -54,7 +48,7 @@ func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID i
|
||||
|
||||
if acquired {
|
||||
return &AcquireResult{
|
||||
Acquired: true,
|
||||
Acquired: true,
|
||||
ReleaseFunc: func() {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
@@ -90,7 +84,7 @@ func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64,
|
||||
|
||||
if acquired {
|
||||
return &AcquireResult{
|
||||
Acquired: true,
|
||||
Acquired: true,
|
||||
ReleaseFunc: func() {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -133,13 +133,13 @@ func (s *EmailService) sendMailTLS(addr string, auth smtp.Auth, from, to string,
|
||||
if err != nil {
|
||||
return fmt.Errorf("tls dial: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
client, err := smtp.NewClient(conn, host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("new smtp client: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
defer func() { _ = client.Close() }()
|
||||
|
||||
if err = client.Auth(auth); err != nil {
|
||||
return fmt.Errorf("smtp auth: %w", err)
|
||||
@@ -303,13 +303,13 @@ func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("tls connection failed: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
client, err := smtp.NewClient(conn, config.Host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("smtp client creation failed: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
defer func() { _ = client.Close() }()
|
||||
|
||||
auth := smtp.PlainAuth("", config.Username, config.Password, config.Host)
|
||||
if err = client.Auth(auth); err != nil {
|
||||
@@ -324,7 +324,7 @@ func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("smtp connection failed: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
defer func() { _ = client.Close() }()
|
||||
|
||||
auth := smtp.PlainAuth("", config.Username, config.Password, config.Host)
|
||||
if err = client.Auth(auth); err != nil {
|
||||
|
||||
@@ -53,7 +53,6 @@ var allowedHeaders = map[string]bool{
|
||||
"anthropic-beta": true,
|
||||
"accept-language": true,
|
||||
"sec-fetch-mode": true,
|
||||
"accept-encoding": true,
|
||||
"user-agent": true,
|
||||
"content-type": true,
|
||||
}
|
||||
@@ -84,7 +83,6 @@ type GatewayService struct {
|
||||
userSubRepo ports.UserSubscriptionRepository
|
||||
cache ports.GatewayCache
|
||||
cfg *config.Config
|
||||
oauthService *OAuthService
|
||||
billingService *BillingService
|
||||
rateLimitService *RateLimitService
|
||||
billingCacheService *BillingCacheService
|
||||
@@ -100,7 +98,6 @@ func NewGatewayService(
|
||||
userSubRepo ports.UserSubscriptionRepository,
|
||||
cache ports.GatewayCache,
|
||||
cfg *config.Config,
|
||||
oauthService *OAuthService,
|
||||
billingService *BillingService,
|
||||
rateLimitService *RateLimitService,
|
||||
billingCacheService *BillingCacheService,
|
||||
@@ -114,7 +111,6 @@ func NewGatewayService(
|
||||
userSubRepo: userSubRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
oauthService: oauthService,
|
||||
billingService: billingService,
|
||||
rateLimitService: rateLimitService,
|
||||
billingCacheService: billingCacheService,
|
||||
@@ -125,13 +121,13 @@ func NewGatewayService(
|
||||
|
||||
// GenerateSessionHash 从请求体计算粘性会话hash
|
||||
func (s *GatewayService) GenerateSessionHash(body []byte) string {
|
||||
var req map[string]interface{}
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 1. 最高优先级:从metadata.user_id提取session_xxx
|
||||
if metadata, ok := req["metadata"].(map[string]interface{}); ok {
|
||||
if metadata, ok := req["metadata"].(map[string]any); ok {
|
||||
if userID, ok := metadata["user_id"].(string); ok {
|
||||
re := regexp.MustCompile(`session_([a-f0-9-]{36})`)
|
||||
if match := re.FindStringSubmatch(userID); len(match) > 1 {
|
||||
@@ -155,8 +151,8 @@ func (s *GatewayService) GenerateSessionHash(body []byte) string {
|
||||
}
|
||||
|
||||
// 4. 最后fallback: 使用第一条消息
|
||||
if messages, ok := req["messages"].([]interface{}); ok && len(messages) > 0 {
|
||||
if firstMsg, ok := messages[0].(map[string]interface{}); ok {
|
||||
if messages, ok := req["messages"].([]any); ok && len(messages) > 0 {
|
||||
if firstMsg, ok := messages[0].(map[string]any); ok {
|
||||
msgText := s.extractTextFromContent(firstMsg["content"])
|
||||
if msgText != "" {
|
||||
return s.hashContent(msgText)
|
||||
@@ -167,14 +163,14 @@ func (s *GatewayService) GenerateSessionHash(body []byte) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *GatewayService) extractCacheableContent(req map[string]interface{}) string {
|
||||
func (s *GatewayService) extractCacheableContent(req map[string]any) string {
|
||||
var content string
|
||||
|
||||
// 检查system中的cacheable内容
|
||||
if system, ok := req["system"].([]interface{}); ok {
|
||||
if system, ok := req["system"].([]any); ok {
|
||||
for _, part := range system {
|
||||
if partMap, ok := part.(map[string]interface{}); ok {
|
||||
if cc, ok := partMap["cache_control"].(map[string]interface{}); ok {
|
||||
if partMap, ok := part.(map[string]any); ok {
|
||||
if cc, ok := partMap["cache_control"].(map[string]any); ok {
|
||||
if cc["type"] == "ephemeral" {
|
||||
if text, ok := partMap["text"].(string); ok {
|
||||
content += text
|
||||
@@ -186,13 +182,13 @@ func (s *GatewayService) extractCacheableContent(req map[string]interface{}) str
|
||||
}
|
||||
|
||||
// 检查messages中的cacheable内容
|
||||
if messages, ok := req["messages"].([]interface{}); ok {
|
||||
if messages, ok := req["messages"].([]any); ok {
|
||||
for _, msg := range messages {
|
||||
if msgMap, ok := msg.(map[string]interface{}); ok {
|
||||
if msgContent, ok := msgMap["content"].([]interface{}); ok {
|
||||
if msgMap, ok := msg.(map[string]any); ok {
|
||||
if msgContent, ok := msgMap["content"].([]any); ok {
|
||||
for _, part := range msgContent {
|
||||
if partMap, ok := part.(map[string]interface{}); ok {
|
||||
if cc, ok := partMap["cache_control"].(map[string]interface{}); ok {
|
||||
if partMap, ok := part.(map[string]any); ok {
|
||||
if cc, ok := partMap["cache_control"].(map[string]any); ok {
|
||||
if cc["type"] == "ephemeral" {
|
||||
// 找到cacheable内容,提取第一条消息的文本
|
||||
return s.extractTextFromContent(msgMap["content"])
|
||||
@@ -208,14 +204,14 @@ func (s *GatewayService) extractCacheableContent(req map[string]interface{}) str
|
||||
return content
|
||||
}
|
||||
|
||||
func (s *GatewayService) extractTextFromSystem(system interface{}) string {
|
||||
func (s *GatewayService) extractTextFromSystem(system any) string {
|
||||
switch v := system.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []interface{}:
|
||||
case []any:
|
||||
var texts []string
|
||||
for _, part := range v {
|
||||
if partMap, ok := part.(map[string]interface{}); ok {
|
||||
if partMap, ok := part.(map[string]any); ok {
|
||||
if text, ok := partMap["text"].(string); ok {
|
||||
texts = append(texts, text)
|
||||
}
|
||||
@@ -226,14 +222,14 @@ func (s *GatewayService) extractTextFromSystem(system interface{}) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *GatewayService) extractTextFromContent(content interface{}) string {
|
||||
func (s *GatewayService) extractTextFromContent(content any) string {
|
||||
switch v := content.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []interface{}:
|
||||
case []any:
|
||||
var texts []string
|
||||
for _, part := range v {
|
||||
if partMap, ok := part.(map[string]interface{}); ok {
|
||||
if partMap, ok := part.(map[string]any); ok {
|
||||
if partMap["type"] == "text" {
|
||||
if text, ok := partMap["text"].(string); ok {
|
||||
texts = append(texts, text)
|
||||
@@ -253,7 +249,7 @@ func (s *GatewayService) hashContent(content string) string {
|
||||
|
||||
// replaceModelInBody 替换请求体中的model字段
|
||||
func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte {
|
||||
var req map[string]interface{}
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return body
|
||||
}
|
||||
@@ -281,7 +277,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
||||
// 同时检查模型支持
|
||||
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
// 续期粘性会话
|
||||
s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL)
|
||||
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
@@ -331,7 +329,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
||||
|
||||
// 4. 建立粘性绑定
|
||||
if sessionHash != "" {
|
||||
s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL)
|
||||
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
@@ -411,7 +411,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// 处理错误响应(包括401,由后台TokenRefreshService维护token有效性)
|
||||
if resp.StatusCode >= 400 {
|
||||
@@ -557,7 +557,7 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
|
||||
|
||||
// 客户端没传,根据模型生成
|
||||
var modelID string
|
||||
var reqMap map[string]interface{}
|
||||
var reqMap map[string]any
|
||||
if json.Unmarshal(body, &reqMap) == nil {
|
||||
if m, ok := reqMap["model"].(string); ok {
|
||||
modelID = m
|
||||
@@ -678,7 +678,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
|
||||
// 转发行
|
||||
fmt.Fprintf(w, "%s\n", line)
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
// 解析usage数据
|
||||
@@ -707,7 +709,7 @@ func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string)
|
||||
return line
|
||||
}
|
||||
|
||||
var event map[string]interface{}
|
||||
var event map[string]any
|
||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||
return line
|
||||
}
|
||||
@@ -717,7 +719,7 @@ func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string)
|
||||
return line
|
||||
}
|
||||
|
||||
msg, ok := event["message"].(map[string]interface{})
|
||||
msg, ok := event["message"].(map[string]any)
|
||||
if !ok {
|
||||
return line
|
||||
}
|
||||
@@ -799,7 +801,7 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
||||
|
||||
// replaceModelInResponseBody 替换响应体中的model字段
|
||||
func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
|
||||
var resp map[string]interface{}
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return body
|
||||
}
|
||||
@@ -985,7 +987,9 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
|
||||
return fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
// 读取响应体
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
|
||||
@@ -167,7 +167,7 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error {
|
||||
}
|
||||
|
||||
// GetStats 获取分组统计信息
|
||||
func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]interface{}, error) {
|
||||
func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]any, error) {
|
||||
group, err := s.groupRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
@@ -182,7 +182,7 @@ func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]inter
|
||||
return nil, fmt.Errorf("get account count: %w", err)
|
||||
}
|
||||
|
||||
stats := map[string]interface{}{
|
||||
stats := map[string]any{
|
||||
"id": group.ID,
|
||||
"name": group.Name,
|
||||
"rate_multiplier": group.RateMultiplier,
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
|
||||
// 预编译正则表达式(避免每次调用重新编译)
|
||||
var (
|
||||
// 匹配 user_id 格式: user_{64位hex}_account__session_{uuid}
|
||||
@@ -150,12 +149,12 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
|
||||
}
|
||||
|
||||
// 解析JSON
|
||||
var reqMap map[string]interface{}
|
||||
var reqMap map[string]any
|
||||
if err := json.Unmarshal(body, &reqMap); err != nil {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
metadata, ok := reqMap["metadata"].(map[string]interface{})
|
||||
metadata, ok := reqMap["metadata"].(map[string]any)
|
||||
if !ok {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
@@ -515,11 +515,11 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
||||
}
|
||||
|
||||
// GetStatus 获取服务状态
|
||||
func (s *PricingService) GetStatus() map[string]interface{} {
|
||||
func (s *PricingService) GetStatus() map[string]any {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
return map[string]interface{}{
|
||||
return map[string]any{
|
||||
"model_count": len(s.pricingData),
|
||||
"last_updated": s.lastUpdated,
|
||||
"local_hash": s.localHash[:min(8, len(s.localHash))],
|
||||
|
||||
@@ -254,7 +254,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
||||
_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -285,7 +285,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -359,12 +359,12 @@ func (s *RedeemService) Delete(ctx context.Context, id int64) error {
|
||||
}
|
||||
|
||||
// GetStats 获取兑换码统计信息
|
||||
func (s *RedeemService) GetStats(ctx context.Context) (map[string]interface{}, error) {
|
||||
func (s *RedeemService) GetStats(ctx context.Context) (map[string]any, error) {
|
||||
// TODO: 实现统计逻辑
|
||||
// 统计未使用、已使用的兑换码数量
|
||||
// 统计总面值等
|
||||
|
||||
stats := map[string]interface{}{
|
||||
stats := map[string]any{
|
||||
"total_codes": 0,
|
||||
"unused_codes": 0,
|
||||
"used_codes": 0,
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
@@ -78,7 +79,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -146,7 +147,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
}
|
||||
newNotes += input.Notes
|
||||
if err := s.userSubRepo.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil {
|
||||
// 备注更新失败不影响主流程
|
||||
log.Printf("update subscription notes failed: sub_id=%d err=%v", existingSub.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -156,7 +157,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -177,7 +178,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -278,7 +279,7 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -311,7 +312,7 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
}()
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ type TokenRefresher interface {
|
||||
|
||||
// Refresh 执行token刷新,返回更新后的credentials
|
||||
// 注意:返回的map应该保留原有credentials中的所有字段,只更新token相关字段
|
||||
Refresh(ctx context.Context, account *model.Account) (map[string]interface{}, error)
|
||||
Refresh(ctx context.Context, account *model.Account) (map[string]any, error)
|
||||
}
|
||||
|
||||
// ClaudeTokenRefresher 处理Anthropic/Claude OAuth token刷新
|
||||
@@ -61,14 +61,14 @@ func (r *ClaudeTokenRefresher) NeedsRefresh(account *model.Account, refreshWindo
|
||||
|
||||
// Refresh 执行token刷新
|
||||
// 保留原有credentials中的所有字段,只更新token相关字段
|
||||
func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *model.Account) (map[string]interface{}, error) {
|
||||
func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *model.Account) (map[string]any, error) {
|
||||
tokenInfo, err := r.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 保留现有credentials中的所有字段
|
||||
newCredentials := make(map[string]interface{})
|
||||
newCredentials := make(map[string]any)
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
|
||||
@@ -12,8 +12,6 @@ var (
|
||||
ErrTurnstileNotConfigured = errors.New("turnstile not configured")
|
||||
)
|
||||
|
||||
const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
|
||||
|
||||
// TurnstileVerifier 验证 Turnstile token 的接口
|
||||
type TurnstileVerifier interface {
|
||||
VerifyToken(ctx context.Context, secretKey, token, remoteIP string) (*TurnstileVerifyResponse, error)
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -190,7 +191,7 @@ func (s *UpdateService) PerformUpdate(ctx context.Context) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create temp dir: %w", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
defer func() { _ = os.RemoveAll(tempDir) }()
|
||||
|
||||
// Download archive
|
||||
archivePath := filepath.Join(tempDir, filepath.Base(downloadURL))
|
||||
@@ -223,7 +224,7 @@ func (s *UpdateService) PerformUpdate(ctx context.Context) error {
|
||||
backupPath := exePath + ".backup"
|
||||
|
||||
// Remove old backup if exists
|
||||
os.Remove(backupPath)
|
||||
_ = os.Remove(backupPath)
|
||||
|
||||
// Step 1: Move current binary to backup
|
||||
if err := os.Rename(exePath, backupPath); err != nil {
|
||||
@@ -349,7 +350,7 @@ func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumUR
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
h := sha256.New()
|
||||
if _, err := io.Copy(h, f); err != nil {
|
||||
@@ -379,7 +380,7 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
var reader io.Reader = f
|
||||
|
||||
@@ -389,7 +390,7 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer gzr.Close()
|
||||
defer func() { _ = gzr.Close() }()
|
||||
reader = gzr
|
||||
}
|
||||
|
||||
@@ -435,10 +436,12 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
|
||||
// Use LimitReader to prevent decompression bombs
|
||||
limited := io.LimitReader(tr, maxBinarySize)
|
||||
if _, err := io.Copy(out, limited); err != nil {
|
||||
out.Close()
|
||||
_ = out.Close()
|
||||
return err
|
||||
}
|
||||
if err := out.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
out.Close()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -451,11 +454,13 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
limited := io.LimitReader(reader, maxBinarySize)
|
||||
_, err = io.Copy(out, limited)
|
||||
return err
|
||||
if _, err := io.Copy(out, limited); err != nil {
|
||||
_ = out.Close()
|
||||
return err
|
||||
}
|
||||
return out.Close()
|
||||
}
|
||||
|
||||
func (s *UpdateService) getFromCache(ctx context.Context) (*UpdateInfo, error) {
|
||||
@@ -499,7 +504,7 @@ func (s *UpdateService) saveToCache(ctx context.Context, info *UpdateInfo) {
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(cacheData)
|
||||
s.cache.SetUpdateInfo(ctx, string(data), time.Duration(updateCacheTTL)*time.Second)
|
||||
_ = s.cache.SetUpdateInfo(ctx, string(data), time.Duration(updateCacheTTL)*time.Second)
|
||||
}
|
||||
|
||||
// compareVersions compares two semantic versions
|
||||
@@ -523,7 +528,9 @@ func parseVersion(v string) [3]int {
|
||||
parts := strings.Split(v, ".")
|
||||
result := [3]int{0, 0, 0}
|
||||
for i := 0; i < len(parts) && i < 3; i++ {
|
||||
fmt.Sscanf(parts[i], "%d", &result[i])
|
||||
if parsed, err := strconv.Atoi(parts[i]); err == nil {
|
||||
result[i] = parsed
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -195,7 +195,7 @@ func (s *UsageService) GetStatsByModel(ctx context.Context, modelName string, st
|
||||
}
|
||||
|
||||
// GetDailyStats 获取每日使用统计(最近N天)
|
||||
func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int) ([]map[string]interface{}, error) {
|
||||
func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int) ([]map[string]any, error) {
|
||||
endTime := time.Now()
|
||||
startTime := endTime.AddDate(0, 0, -days)
|
||||
|
||||
@@ -227,13 +227,13 @@ func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int
|
||||
}
|
||||
|
||||
// 计算平均值并转换为数组
|
||||
result := make([]map[string]interface{}, 0, len(dailyStats))
|
||||
result := make([]map[string]any, 0, len(dailyStats))
|
||||
for date, stats := range dailyStats {
|
||||
if stats.TotalRequests > 0 {
|
||||
stats.AverageDurationMs /= float64(stats.TotalRequests)
|
||||
}
|
||||
|
||||
result = append(result, map[string]interface{}{
|
||||
result = append(result, map[string]any{
|
||||
"date": date,
|
||||
"total_requests": stats.TotalRequests,
|
||||
"total_input_tokens": stats.TotalInputTokens,
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/service/ports"
|
||||
@@ -34,14 +33,12 @@ type ChangePasswordRequest struct {
|
||||
// UserService 用户服务
|
||||
type UserService struct {
|
||||
userRepo ports.UserRepository
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewUserService 创建用户服务实例
|
||||
func NewUserService(userRepo ports.UserRepository, cfg *config.Config) *UserService {
|
||||
func NewUserService(userRepo ports.UserRepository) *UserService {
|
||||
return &UserService{
|
||||
userRepo: userRepo,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -352,4 +352,3 @@ func install(c *gin.Context) {
|
||||
"restart": true,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -14,9 +14,9 @@ import (
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gopkg.in/yaml.v3"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Config paths
|
||||
@@ -101,7 +101,14 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get db instance: %w", err)
|
||||
}
|
||||
defer sqlDB.Close()
|
||||
defer func() {
|
||||
if sqlDB == nil {
|
||||
return
|
||||
}
|
||||
if err := sqlDB.Close(); err != nil {
|
||||
log.Printf("failed to close postgres connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
@@ -129,7 +136,10 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error {
|
||||
}
|
||||
|
||||
// Now connect to the target database to verify
|
||||
sqlDB.Close()
|
||||
if err := sqlDB.Close(); err != nil {
|
||||
log.Printf("failed to close postgres connection: %v", err)
|
||||
}
|
||||
sqlDB = nil
|
||||
|
||||
targetDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
||||
@@ -145,7 +155,11 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get target db instance: %w", err)
|
||||
}
|
||||
defer targetSqlDB.Close()
|
||||
defer func() {
|
||||
if err := targetSqlDB.Close(); err != nil {
|
||||
log.Printf("failed to close postgres connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel2()
|
||||
@@ -164,7 +178,11 @@ func TestRedisConnection(cfg *RedisConfig) error {
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
})
|
||||
defer rdb.Close()
|
||||
defer func() {
|
||||
if err := rdb.Close(); err != nil {
|
||||
log.Printf("failed to close redis client: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
@@ -185,7 +203,11 @@ func Install(cfg *SetupConfig) error {
|
||||
|
||||
// Generate JWT secret if not provided
|
||||
if cfg.JWT.Secret == "" {
|
||||
cfg.JWT.Secret = generateSecret(32)
|
||||
secret, err := generateSecret(32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate jwt secret: %w", err)
|
||||
}
|
||||
cfg.JWT.Secret = secret
|
||||
}
|
||||
|
||||
// Test connections
|
||||
@@ -243,7 +265,11 @@ func initializeDatabase(cfg *SetupConfig) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sqlDB.Close()
|
||||
defer func() {
|
||||
if err := sqlDB.Close(); err != nil {
|
||||
log.Printf("failed to close postgres connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 使用 model 包的 AutoMigrate,确保模型定义统一
|
||||
return model.AutoMigrate(db)
|
||||
@@ -265,7 +291,11 @@ func createAdminUser(cfg *SetupConfig) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sqlDB.Close()
|
||||
defer func() {
|
||||
if err := sqlDB.Close(); err != nil {
|
||||
log.Printf("failed to close postgres connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Check if admin already exists
|
||||
var count int64
|
||||
@@ -352,10 +382,12 @@ func writeConfigFile(cfg *SetupConfig) error {
|
||||
return os.WriteFile(ConfigFile, data, 0600)
|
||||
}
|
||||
|
||||
func generateSecret(length int) string {
|
||||
func generateSecret(length int) (string, error) {
|
||||
bytes := make([]byte, length)
|
||||
rand.Read(bytes)
|
||||
return hex.EncodeToString(bytes)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
@@ -431,13 +463,21 @@ func AutoSetupFromEnv() error {
|
||||
|
||||
// Generate JWT secret if not provided
|
||||
if cfg.JWT.Secret == "" {
|
||||
cfg.JWT.Secret = generateSecret(32)
|
||||
secret, err := generateSecret(32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate jwt secret: %w", err)
|
||||
}
|
||||
cfg.JWT.Secret = secret
|
||||
log.Println("Generated JWT secret automatically")
|
||||
}
|
||||
|
||||
// Generate admin password if not provided
|
||||
if cfg.Admin.Password == "" {
|
||||
cfg.Admin.Password = generateSecret(16)
|
||||
password, err := generateSecret(16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate admin password: %w", err)
|
||||
}
|
||||
cfg.Admin.Password = password
|
||||
log.Printf("Generated admin password: %s", cfg.Admin.Password)
|
||||
log.Println("IMPORTANT: Save this password! It will not be shown again.")
|
||||
}
|
||||
|
||||
20
backend/internal/web/embed_off.go
Normal file
20
backend/internal/web/embed_off.go
Normal file
@@ -0,0 +1,20 @@
|
||||
//go:build !embed
|
||||
|
||||
package web
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func ServeEmbeddedFrontend() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.String(http.StatusNotFound, "Frontend not embedded. Build with -tags embed to include frontend.")
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
func HasEmbeddedFrontend() bool {
|
||||
return false
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build embed
|
||||
|
||||
package web
|
||||
|
||||
import (
|
||||
@@ -13,8 +15,6 @@ import (
|
||||
//go:embed all:dist
|
||||
var frontendFS embed.FS
|
||||
|
||||
// ServeEmbeddedFrontend returns a Gin handler that serves embedded frontend assets
|
||||
// and handles SPA routing by falling back to index.html for non-API routes.
|
||||
func ServeEmbeddedFrontend() gin.HandlerFunc {
|
||||
distFS, err := fs.Sub(frontendFS, "dist")
|
||||
if err != nil {
|
||||
@@ -25,7 +25,6 @@ func ServeEmbeddedFrontend() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
path := c.Request.URL.Path
|
||||
|
||||
// Skip API and gateway routes
|
||||
if strings.HasPrefix(path, "/api/") ||
|
||||
strings.HasPrefix(path, "/v1/") ||
|
||||
strings.HasPrefix(path, "/setup/") ||
|
||||
@@ -34,20 +33,18 @@ func ServeEmbeddedFrontend() gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
// Try to serve static file
|
||||
cleanPath := strings.TrimPrefix(path, "/")
|
||||
if cleanPath == "" {
|
||||
cleanPath = "index.html"
|
||||
}
|
||||
|
||||
if file, err := distFS.Open(cleanPath); err == nil {
|
||||
file.Close()
|
||||
_ = file.Close()
|
||||
fileServer.ServeHTTP(c.Writer, c.Request)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// SPA fallback: serve index.html for all other routes
|
||||
serveIndexHTML(c, distFS)
|
||||
}
|
||||
}
|
||||
@@ -59,7 +56,7 @@ func serveIndexHTML(c *gin.Context, fsys fs.FS) {
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
defer func() { _ = file.Close() }()
|
||||
|
||||
content, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
@@ -72,7 +69,6 @@ func serveIndexHTML(c *gin.Context, fsys fs.FS) {
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
// HasEmbeddedFrontend checks if frontend assets are embedded
|
||||
func HasEmbeddedFrontend() bool {
|
||||
_, err := frontendFS.ReadFile("dist/index.html")
|
||||
return err == nil
|
||||
Reference in New Issue
Block a user