mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-06 00:10:21 +08:00
Compare commits
64 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b3463769dc | ||
|
|
d9e6cfc44d | ||
|
|
57fd172287 | ||
|
|
8d7a497553 | ||
|
|
b31698b9f2 | ||
|
|
eeaff85e47 | ||
|
|
f51ad2e126 | ||
|
|
f57f12c6cc | ||
|
|
5fca2d10b9 | ||
|
|
8fbe1ad70d | ||
|
|
25a304c231 | ||
|
|
9d30ceae8d | ||
|
|
60f6ed6bf6 | ||
|
|
4a2f7d4a99 | ||
|
|
c19a393be9 | ||
|
|
938ffb002e | ||
|
|
372a01290b | ||
|
|
8b163ca49b | ||
|
|
d23810dc53 | ||
|
|
62ed5422dd | ||
|
|
2e76302af7 | ||
|
|
6553828008 | ||
|
|
adcb7bf00e | ||
|
|
876e85e7ad | ||
|
|
2e7818d688 | ||
|
|
836c4dda2b | ||
|
|
e65e9587b4 | ||
|
|
aaadd6ed04 | ||
|
|
870b21916c | ||
|
|
fb119f9a67 | ||
|
|
ad54795a24 | ||
|
|
0abe322cca | ||
|
|
b071511676 | ||
|
|
7d9a757a26 | ||
|
|
bbf4024dc7 | ||
|
|
5831eb8a6a | ||
|
|
61838cdb3d | ||
|
|
50dba656fd | ||
|
|
0e2821456c | ||
|
|
f25ac3aff5 | ||
|
|
f6341b7f2b | ||
|
|
4e257512b9 | ||
|
|
e53b34f321 | ||
|
|
12ddae0184 | ||
|
|
7b9c3f165e | ||
|
|
0b8e84f942 | ||
|
|
d9e27df9af | ||
|
|
f0fabf89a1 | ||
|
|
5bbfbcdae9 | ||
|
|
eb55947ec4 | ||
|
|
5f7e5184eb | ||
|
|
008a111268 | ||
|
|
fda753278c | ||
|
|
6c469b42ed | ||
|
|
dacf3a2a6e | ||
|
|
e6add93ae3 | ||
|
|
b2273ec695 | ||
|
|
aa89777dda | ||
|
|
1e1f3c0c74 | ||
|
|
1fab9204eb | ||
|
|
dbd3e71637 | ||
|
|
974f67211b | ||
|
|
0338c83b90 | ||
|
|
c6b3de1199 |
41
.github/workflows/backend-ci.yml
vendored
Normal file
41
.github/workflows/backend-ci.yml
vendored
Normal file
@@ -0,0 +1,41 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
pull_request:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: backend/go.mod
|
||||
check-latest: true
|
||||
cache: true
|
||||
- name: Unit tests
|
||||
working-directory: backend
|
||||
run: make test-unit
|
||||
- name: Integration tests
|
||||
working-directory: backend
|
||||
run: make test-integration
|
||||
|
||||
golangci-lint:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: backend/go.mod
|
||||
check-latest: true
|
||||
cache: true
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v9
|
||||
with:
|
||||
version: v2.7
|
||||
args: --timeout=5m
|
||||
working-directory: backend
|
||||
94
.github/workflows/release.yml
vendored
94
.github/workflows/release.yml
vendored
@@ -85,6 +85,19 @@ jobs:
|
||||
go-version: '1.24'
|
||||
cache-dependency-path: backend/go.sum
|
||||
|
||||
# Docker setup for GoReleaser
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Fetch tags with annotations
|
||||
run: |
|
||||
# 确保获取完整的 annotated tag 信息
|
||||
@@ -117,87 +130,16 @@ jobs:
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
TAG_MESSAGE: ${{ steps.tag_message.outputs.message }}
|
||||
GITHUB_REPO_OWNER: ${{ github.repository_owner }}
|
||||
GITHUB_REPO_NAME: ${{ github.event.repository.name }}
|
||||
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
|
||||
# ===========================================================================
|
||||
# Docker Build and Push
|
||||
# ===========================================================================
|
||||
docker:
|
||||
needs: [update-version, build-frontend]
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Download VERSION artifact
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: version-file
|
||||
path: backend/cmd/server/
|
||||
|
||||
- name: Download frontend artifact
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: frontend-dist
|
||||
path: backend/internal/web/dist/
|
||||
|
||||
# Extract version from tag
|
||||
- name: Extract version
|
||||
id: version
|
||||
run: |
|
||||
VERSION=${GITHUB_REF#refs/tags/v}
|
||||
echo "version=$VERSION" >> $GITHUB_OUTPUT
|
||||
echo "Version: $VERSION"
|
||||
|
||||
# Set up Docker Buildx for multi-platform builds
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
# Login to DockerHub
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
# Extract metadata for Docker
|
||||
- name: Extract Docker metadata
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
weishaw/sub2api
|
||||
tags: |
|
||||
type=semver,pattern={{version}}
|
||||
type=semver,pattern={{major}}.{{minor}}
|
||||
type=semver,pattern={{major}}
|
||||
type=raw,value=latest,enable={{is_default_branch}}
|
||||
|
||||
# Build and push Docker image
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
file: ./Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
build-args: |
|
||||
VERSION=${{ steps.version.outputs.version }}
|
||||
COMMIT=${{ github.sha }}
|
||||
DATE=${{ github.event.head_commit.timestamp }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
# Update DockerHub description (optional)
|
||||
# Update DockerHub description
|
||||
- name: Update DockerHub description
|
||||
uses: peter-evans/dockerhub-description@v4
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
repository: weishaw/sub2api
|
||||
repository: ${{ secrets.DOCKERHUB_USERNAME }}/sub2api
|
||||
short-description: "Sub2API - AI API Gateway Platform"
|
||||
readme-filepath: ./deploy/DOCKER.md
|
||||
|
||||
15
.gitignore
vendored
15
.gitignore
vendored
@@ -28,6 +28,7 @@ node_modules/
|
||||
frontend/node_modules/
|
||||
frontend/dist/
|
||||
*.local
|
||||
*.tsbuildinfo
|
||||
|
||||
# 日志
|
||||
npm-debug.log*
|
||||
@@ -81,15 +82,27 @@ build/
|
||||
release/
|
||||
|
||||
# 后端嵌入的前端构建产物
|
||||
# Keep a placeholder file so `//go:embed all:dist` always has a match in CI/lint,
|
||||
# while still ignoring generated frontend build outputs.
|
||||
backend/internal/web/dist/
|
||||
!backend/internal/web/dist/
|
||||
backend/internal/web/dist/*
|
||||
!backend/internal/web/dist/.keep
|
||||
|
||||
# 后端运行时缓存数据
|
||||
backend/data/
|
||||
|
||||
# ===================
|
||||
# 本地配置文件(包含敏感信息)
|
||||
# ===================
|
||||
backend/config.yaml
|
||||
deploy/config.yaml
|
||||
backend/.installed
|
||||
|
||||
# ===================
|
||||
# 其他
|
||||
# ===================
|
||||
tests
|
||||
CLAUDE.md
|
||||
.claude
|
||||
scripts
|
||||
scripts
|
||||
|
||||
@@ -11,6 +11,8 @@ builds:
|
||||
dir: backend
|
||||
main: ./cmd/server
|
||||
binary: sub2api
|
||||
flags:
|
||||
- -tags=embed
|
||||
env:
|
||||
- CGO_ENABLED=0
|
||||
goos:
|
||||
@@ -50,10 +52,58 @@ changelog:
|
||||
# 禁用自动 changelog,完全使用 tag 消息
|
||||
disable: true
|
||||
|
||||
# Docker images
|
||||
dockers:
|
||||
- id: amd64
|
||||
goos: linux
|
||||
goarch: amd64
|
||||
image_templates:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||
dockerfile: Dockerfile.goreleaser
|
||||
use: buildx
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||
- "--label=org.opencontainers.image.revision={{ .Commit }}"
|
||||
|
||||
- id: arm64
|
||||
goos: linux
|
||||
goarch: arm64
|
||||
image_templates:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||
dockerfile: Dockerfile.goreleaser
|
||||
use: buildx
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||
- "--label=org.opencontainers.image.revision={{ .Commit }}"
|
||||
|
||||
# Docker manifests for multi-arch support
|
||||
docker_manifests:
|
||||
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}"
|
||||
image_templates:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||
|
||||
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:latest"
|
||||
image_templates:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||
|
||||
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Major }}.{{ .Minor }}"
|
||||
image_templates:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||
|
||||
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Major }}"
|
||||
image_templates:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||
|
||||
release:
|
||||
github:
|
||||
owner: Wei-Shaw
|
||||
name: sub2api
|
||||
owner: "{{ .Env.GITHUB_REPO_OWNER }}"
|
||||
name: "{{ .Env.GITHUB_REPO_NAME }}"
|
||||
draft: false
|
||||
prerelease: auto
|
||||
name_template: "Sub2API {{.Version}}"
|
||||
@@ -71,7 +121,7 @@ release:
|
||||
|
||||
**One-line install (Linux):**
|
||||
```bash
|
||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash
|
||||
curl -sSL https://raw.githubusercontent.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }}/main/deploy/install.sh | sudo bash
|
||||
```
|
||||
|
||||
**Manual download:**
|
||||
@@ -79,5 +129,5 @@ release:
|
||||
|
||||
## 📚 Documentation
|
||||
|
||||
- [GitHub Repository](https://github.com/Wei-Shaw/sub2api)
|
||||
- [Installation Guide](https://github.com/Wei-Shaw/sub2api/blob/main/deploy/README.md)
|
||||
- [GitHub Repository](https://github.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }})
|
||||
- [Installation Guide](https://github.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }}/blob/main/deploy/README.md)
|
||||
|
||||
11
Dockerfile
11
Dockerfile
@@ -40,14 +40,15 @@ WORKDIR /app/backend
|
||||
COPY backend/go.mod backend/go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
# Copy frontend dist from previous stage
|
||||
COPY --from=frontend-builder /app/frontend/../backend/internal/web/dist ./internal/web/dist
|
||||
|
||||
# Copy backend source
|
||||
# Copy backend source first
|
||||
COPY backend/ ./
|
||||
|
||||
# Build the binary (BuildType=release for CI builds)
|
||||
# Copy frontend dist from previous stage (must be after backend copy to avoid being overwritten)
|
||||
COPY --from=frontend-builder /app/backend/internal/web/dist ./internal/web/dist
|
||||
|
||||
# Build the binary (BuildType=release for CI builds, embed frontend)
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build \
|
||||
-tags embed \
|
||||
-ldflags="-s -w -X main.Commit=${COMMIT} -X main.Date=${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)} -X main.BuildType=release" \
|
||||
-o /app/sub2api \
|
||||
./cmd/server
|
||||
|
||||
40
Dockerfile.goreleaser
Normal file
40
Dockerfile.goreleaser
Normal file
@@ -0,0 +1,40 @@
|
||||
# =============================================================================
|
||||
# Sub2API Dockerfile for GoReleaser
|
||||
# =============================================================================
|
||||
# This Dockerfile is used by GoReleaser to build Docker images.
|
||||
# It only packages the pre-built binary, no compilation needed.
|
||||
# =============================================================================
|
||||
|
||||
FROM alpine:3.19
|
||||
|
||||
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
|
||||
LABEL description="Sub2API - AI API Gateway Platform"
|
||||
LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
|
||||
|
||||
# Install runtime dependencies
|
||||
RUN apk add --no-cache \
|
||||
ca-certificates \
|
||||
tzdata \
|
||||
curl \
|
||||
&& rm -rf /var/cache/apk/*
|
||||
|
||||
# Create non-root user
|
||||
RUN addgroup -g 1000 sub2api && \
|
||||
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy pre-built binary from GoReleaser
|
||||
COPY sub2api /app/sub2api
|
||||
|
||||
# Create data directory
|
||||
RUN mkdir -p /app/data && chown -R sub2api:sub2api /app
|
||||
|
||||
USER sub2api
|
||||
|
||||
EXPOSE 8080
|
||||
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
||||
CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
||||
|
||||
ENTRYPOINT ["/app/sub2api"]
|
||||
16
README.md
16
README.md
@@ -220,21 +220,21 @@ cd sub2api
|
||||
cd frontend
|
||||
npm install
|
||||
npm run build
|
||||
# Output will be in ../backend/internal/web/dist/
|
||||
|
||||
# 3. Copy frontend build to backend (for embedding)
|
||||
cp -r dist ../backend/internal/web/
|
||||
|
||||
# 4. Build backend (requires frontend dist to be present)
|
||||
# 3. Build backend with embedded frontend
|
||||
cd ../backend
|
||||
go build -o sub2api ./cmd/server
|
||||
go build -tags embed -o sub2api ./cmd/server
|
||||
|
||||
# 5. Create configuration file
|
||||
# 4. Create configuration file
|
||||
cp ../deploy/config.example.yaml ./config.yaml
|
||||
|
||||
# 6. Edit configuration
|
||||
# 5. Edit configuration
|
||||
nano config.yaml
|
||||
```
|
||||
|
||||
> **Note:** The `-tags embed` flag embeds the frontend into the binary. Without this flag, the binary will not serve the frontend UI.
|
||||
|
||||
**Key configuration in `config.yaml`:**
|
||||
|
||||
```yaml
|
||||
@@ -265,7 +265,7 @@ default:
|
||||
```
|
||||
|
||||
```bash
|
||||
# 7. Run the application
|
||||
# 6. Run the application
|
||||
./sub2api
|
||||
```
|
||||
|
||||
|
||||
16
README_CN.md
16
README_CN.md
@@ -220,21 +220,21 @@ cd sub2api
|
||||
cd frontend
|
||||
npm install
|
||||
npm run build
|
||||
# 构建产物输出到 ../backend/internal/web/dist/
|
||||
|
||||
# 3. 复制前端构建产物到后端(用于嵌入)
|
||||
cp -r dist ../backend/internal/web/
|
||||
|
||||
# 4. 编译后端(需要前端 dist 目录存在)
|
||||
# 3. 编译后端(嵌入前端)
|
||||
cd ../backend
|
||||
go build -o sub2api ./cmd/server
|
||||
go build -tags embed -o sub2api ./cmd/server
|
||||
|
||||
# 5. 创建配置文件
|
||||
# 4. 创建配置文件
|
||||
cp ../deploy/config.example.yaml ./config.yaml
|
||||
|
||||
# 6. 编辑配置
|
||||
# 5. 编辑配置
|
||||
nano config.yaml
|
||||
```
|
||||
|
||||
> **注意:** `-tags embed` 参数会将前端嵌入到二进制文件中。不使用此参数编译的程序将不包含前端界面。
|
||||
|
||||
**`config.yaml` 关键配置:**
|
||||
|
||||
```yaml
|
||||
@@ -265,7 +265,7 @@ default:
|
||||
```
|
||||
|
||||
```bash
|
||||
# 7. 运行应用
|
||||
# 6. 运行应用
|
||||
./sub2api
|
||||
```
|
||||
|
||||
|
||||
596
backend/.golangci.yml
Normal file
596
backend/.golangci.yml
Normal file
@@ -0,0 +1,596 @@
|
||||
version: "2"
|
||||
|
||||
linters:
|
||||
default: none
|
||||
enable:
|
||||
- depguard
|
||||
- errcheck
|
||||
- govet
|
||||
- ineffassign
|
||||
- staticcheck
|
||||
- unused
|
||||
|
||||
settings:
|
||||
depguard:
|
||||
rules:
|
||||
# Enforce: service must not depend on repository.
|
||||
service-no-repository:
|
||||
list-mode: original
|
||||
files:
|
||||
- "**/internal/service/**"
|
||||
deny:
|
||||
- pkg: github.com/Wei-Shaw/sub2api/internal/repository
|
||||
desc: "service must not import repository"
|
||||
- pkg: gorm.io/gorm
|
||||
desc: "service must not import gorm"
|
||||
handler-no-repository:
|
||||
list-mode: original
|
||||
files:
|
||||
- "**/internal/handler/**"
|
||||
deny:
|
||||
- pkg: github.com/Wei-Shaw/sub2api/internal/repository
|
||||
desc: "handler must not import repository"
|
||||
errcheck:
|
||||
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
|
||||
# Such cases aren't reported by default.
|
||||
# Default: false
|
||||
check-type-assertions: true
|
||||
# report about assignment of errors to blank identifier: `num, _ := strconv.Atoi(numStr)`.
|
||||
# Such cases aren't reported by default.
|
||||
# Default: false
|
||||
check-blank: false
|
||||
# To disable the errcheck built-in exclude list.
|
||||
# See `-excludeonly` option in https://github.com/kisielk/errcheck#excluding-functions for details.
|
||||
# Default: false
|
||||
disable-default-exclusions: true
|
||||
# List of functions to exclude from checking, where each entry is a single function to exclude.
|
||||
# See https://github.com/kisielk/errcheck#excluding-functions for details.
|
||||
exclude-functions:
|
||||
- io/ioutil.ReadFile
|
||||
- io.Copy(*bytes.Buffer)
|
||||
- io.Copy(os.Stdout)
|
||||
- fmt.Println
|
||||
- fmt.Print
|
||||
- fmt.Printf
|
||||
- fmt.Fprint
|
||||
- fmt.Fprintf
|
||||
- fmt.Fprintln
|
||||
# Display function signature instead of selector.
|
||||
# Default: false
|
||||
verbose: true
|
||||
ineffassign:
|
||||
# Check escaping variables of type error, may cause false positives.
|
||||
# Default: false
|
||||
check-escaping-errors: true
|
||||
staticcheck:
|
||||
# https://staticcheck.dev/docs/configuration/options/#dot_import_whitelist
|
||||
# Default: ["github.com/mmcloughlin/avo/build", "github.com/mmcloughlin/avo/operand", "github.com/mmcloughlin/avo/reg"]
|
||||
dot-import-whitelist:
|
||||
- fmt
|
||||
# https://staticcheck.dev/docs/configuration/options/#initialisms
|
||||
# Default: ["ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS"]
|
||||
initialisms: [ "ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS" ]
|
||||
# https://staticcheck.dev/docs/configuration/options/#http_status_code_whitelist
|
||||
# Default: ["200", "400", "404", "500"]
|
||||
http-status-code-whitelist: [ "200", "400", "404", "500" ]
|
||||
# SAxxxx checks in https://staticcheck.dev/docs/configuration/options/#checks
|
||||
# Example (to disable some checks): [ "all", "-SA1000", "-SA1001"]
|
||||
# Run `GL_DEBUG=staticcheck golangci-lint run --enable=staticcheck` to see all available checks and enabled by config checks.
|
||||
# Default: ["all", "-ST1000", "-ST1003", "-ST1016", "-ST1020", "-ST1021", "-ST1022"]
|
||||
checks:
|
||||
# Invalid regular expression.
|
||||
# https://staticcheck.dev/docs/checks/#SA1000
|
||||
- SA1000
|
||||
# Invalid template.
|
||||
# https://staticcheck.dev/docs/checks/#SA1001
|
||||
- SA1001
|
||||
# Invalid format in 'time.Parse'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1002
|
||||
- SA1002
|
||||
# Unsupported argument to functions in 'encoding/binary'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1003
|
||||
- SA1003
|
||||
# Suspiciously small untyped constant in 'time.Sleep'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1004
|
||||
- SA1004
|
||||
# Invalid first argument to 'exec.Command'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1005
|
||||
- SA1005
|
||||
# 'Printf' with dynamic first argument and no further arguments.
|
||||
# https://staticcheck.dev/docs/checks/#SA1006
|
||||
- SA1006
|
||||
# Invalid URL in 'net/url.Parse'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1007
|
||||
- SA1007
|
||||
# Non-canonical key in 'http.Header' map.
|
||||
# https://staticcheck.dev/docs/checks/#SA1008
|
||||
- SA1008
|
||||
# '(*regexp.Regexp).FindAll' called with 'n == 0', which will always return zero results.
|
||||
# https://staticcheck.dev/docs/checks/#SA1010
|
||||
- SA1010
|
||||
# Various methods in the "strings" package expect valid UTF-8, but invalid input is provided.
|
||||
# https://staticcheck.dev/docs/checks/#SA1011
|
||||
- SA1011
|
||||
# A nil 'context.Context' is being passed to a function, consider using 'context.TODO' instead.
|
||||
# https://staticcheck.dev/docs/checks/#SA1012
|
||||
- SA1012
|
||||
# 'io.Seeker.Seek' is being called with the whence constant as the first argument, but it should be the second.
|
||||
# https://staticcheck.dev/docs/checks/#SA1013
|
||||
- SA1013
|
||||
# Non-pointer value passed to 'Unmarshal' or 'Decode'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1014
|
||||
- SA1014
|
||||
# Using 'time.Tick' in a way that will leak. Consider using 'time.NewTicker', and only use 'time.Tick' in tests, commands and endless functions.
|
||||
# https://staticcheck.dev/docs/checks/#SA1015
|
||||
- SA1015
|
||||
# Trapping a signal that cannot be trapped.
|
||||
# https://staticcheck.dev/docs/checks/#SA1016
|
||||
- SA1016
|
||||
# Channels used with 'os/signal.Notify' should be buffered.
|
||||
# https://staticcheck.dev/docs/checks/#SA1017
|
||||
- SA1017
|
||||
# 'strings.Replace' called with 'n == 0', which does nothing.
|
||||
# https://staticcheck.dev/docs/checks/#SA1018
|
||||
- SA1018
|
||||
# Using a deprecated function, variable, constant or field.
|
||||
# https://staticcheck.dev/docs/checks/#SA1019
|
||||
- SA1019
|
||||
# Using an invalid host:port pair with a 'net.Listen'-related function.
|
||||
# https://staticcheck.dev/docs/checks/#SA1020
|
||||
- SA1020
|
||||
# Using 'bytes.Equal' to compare two 'net.IP'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1021
|
||||
- SA1021
|
||||
# Modifying the buffer in an 'io.Writer' implementation.
|
||||
# https://staticcheck.dev/docs/checks/#SA1023
|
||||
- SA1023
|
||||
# A string cutset contains duplicate characters.
|
||||
# https://staticcheck.dev/docs/checks/#SA1024
|
||||
- SA1024
|
||||
# It is not possible to use '(*time.Timer).Reset''s return value correctly.
|
||||
# https://staticcheck.dev/docs/checks/#SA1025
|
||||
- SA1025
|
||||
# Cannot marshal channels or functions.
|
||||
# https://staticcheck.dev/docs/checks/#SA1026
|
||||
- SA1026
|
||||
# Atomic access to 64-bit variable must be 64-bit aligned.
|
||||
# https://staticcheck.dev/docs/checks/#SA1027
|
||||
- SA1027
|
||||
# 'sort.Slice' can only be used on slices.
|
||||
# https://staticcheck.dev/docs/checks/#SA1028
|
||||
- SA1028
|
||||
# Inappropriate key in call to 'context.WithValue'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1029
|
||||
- SA1029
|
||||
# Invalid argument in call to a 'strconv' function.
|
||||
# https://staticcheck.dev/docs/checks/#SA1030
|
||||
- SA1030
|
||||
# Overlapping byte slices passed to an encoder.
|
||||
# https://staticcheck.dev/docs/checks/#SA1031
|
||||
- SA1031
|
||||
# Wrong order of arguments to 'errors.Is'.
|
||||
# https://staticcheck.dev/docs/checks/#SA1032
|
||||
- SA1032
|
||||
# 'sync.WaitGroup.Add' called inside the goroutine, leading to a race condition.
|
||||
# https://staticcheck.dev/docs/checks/#SA2000
|
||||
- SA2000
|
||||
# Empty critical section, did you mean to defer the unlock?.
|
||||
# https://staticcheck.dev/docs/checks/#SA2001
|
||||
- SA2001
|
||||
# Called 'testing.T.FailNow' or 'SkipNow' in a goroutine, which isn't allowed.
|
||||
# https://staticcheck.dev/docs/checks/#SA2002
|
||||
- SA2002
|
||||
# Deferred 'Lock' right after locking, likely meant to defer 'Unlock' instead.
|
||||
# https://staticcheck.dev/docs/checks/#SA2003
|
||||
- SA2003
|
||||
# 'TestMain' doesn't call 'os.Exit', hiding test failures.
|
||||
# https://staticcheck.dev/docs/checks/#SA3000
|
||||
- SA3000
|
||||
# Assigning to 'b.N' in benchmarks distorts the results.
|
||||
# https://staticcheck.dev/docs/checks/#SA3001
|
||||
- SA3001
|
||||
# Binary operator has identical expressions on both sides.
|
||||
# https://staticcheck.dev/docs/checks/#SA4000
|
||||
- SA4000
|
||||
# '&*x' gets simplified to 'x', it does not copy 'x'.
|
||||
# https://staticcheck.dev/docs/checks/#SA4001
|
||||
- SA4001
|
||||
# Comparing unsigned values against negative values is pointless.
|
||||
# https://staticcheck.dev/docs/checks/#SA4003
|
||||
- SA4003
|
||||
# The loop exits unconditionally after one iteration.
|
||||
# https://staticcheck.dev/docs/checks/#SA4004
|
||||
- SA4004
|
||||
# Field assignment that will never be observed. Did you mean to use a pointer receiver?.
|
||||
# https://staticcheck.dev/docs/checks/#SA4005
|
||||
- SA4005
|
||||
# A value assigned to a variable is never read before being overwritten. Forgotten error check or dead code?.
|
||||
# https://staticcheck.dev/docs/checks/#SA4006
|
||||
- SA4006
|
||||
# The variable in the loop condition never changes, are you incrementing the wrong variable?.
|
||||
# https://staticcheck.dev/docs/checks/#SA4008
|
||||
- SA4008
|
||||
# A function argument is overwritten before its first use.
|
||||
# https://staticcheck.dev/docs/checks/#SA4009
|
||||
- SA4009
|
||||
# The result of 'append' will never be observed anywhere.
|
||||
# https://staticcheck.dev/docs/checks/#SA4010
|
||||
- SA4010
|
||||
# Break statement with no effect. Did you mean to break out of an outer loop?.
|
||||
# https://staticcheck.dev/docs/checks/#SA4011
|
||||
- SA4011
|
||||
# Comparing a value against NaN even though no value is equal to NaN.
|
||||
# https://staticcheck.dev/docs/checks/#SA4012
|
||||
- SA4012
|
||||
# Negating a boolean twice ('!!b') is the same as writing 'b'. This is either redundant, or a typo.
|
||||
# https://staticcheck.dev/docs/checks/#SA4013
|
||||
- SA4013
|
||||
# An if/else if chain has repeated conditions and no side-effects; if the condition didn't match the first time, it won't match the second time, either.
|
||||
# https://staticcheck.dev/docs/checks/#SA4014
|
||||
- SA4014
|
||||
# Calling functions like 'math.Ceil' on floats converted from integers doesn't do anything useful.
|
||||
# https://staticcheck.dev/docs/checks/#SA4015
|
||||
- SA4015
|
||||
# Certain bitwise operations, such as 'x ^ 0', do not do anything useful.
|
||||
# https://staticcheck.dev/docs/checks/#SA4016
|
||||
- SA4016
|
||||
# Discarding the return values of a function without side effects, making the call pointless.
|
||||
# https://staticcheck.dev/docs/checks/#SA4017
|
||||
- SA4017
|
||||
# Self-assignment of variables.
|
||||
# https://staticcheck.dev/docs/checks/#SA4018
|
||||
- SA4018
|
||||
# Multiple, identical build constraints in the same file.
|
||||
# https://staticcheck.dev/docs/checks/#SA4019
|
||||
- SA4019
|
||||
# Unreachable case clause in a type switch.
|
||||
# https://staticcheck.dev/docs/checks/#SA4020
|
||||
- SA4020
|
||||
# "x = append(y)" is equivalent to "x = y".
|
||||
# https://staticcheck.dev/docs/checks/#SA4021
|
||||
- SA4021
|
||||
# Comparing the address of a variable against nil.
|
||||
# https://staticcheck.dev/docs/checks/#SA4022
|
||||
- SA4022
|
||||
# Impossible comparison of interface value with untyped nil.
|
||||
# https://staticcheck.dev/docs/checks/#SA4023
|
||||
- SA4023
|
||||
# Checking for impossible return value from a builtin function.
|
||||
# https://staticcheck.dev/docs/checks/#SA4024
|
||||
- SA4024
|
||||
# Integer division of literals that results in zero.
|
||||
# https://staticcheck.dev/docs/checks/#SA4025
|
||||
- SA4025
|
||||
# Go constants cannot express negative zero.
|
||||
# https://staticcheck.dev/docs/checks/#SA4026
|
||||
- SA4026
|
||||
# '(*net/url.URL).Query' returns a copy, modifying it doesn't change the URL.
|
||||
# https://staticcheck.dev/docs/checks/#SA4027
|
||||
- SA4027
|
||||
# 'x % 1' is always zero.
|
||||
# https://staticcheck.dev/docs/checks/#SA4028
|
||||
- SA4028
|
||||
# Ineffective attempt at sorting slice.
|
||||
# https://staticcheck.dev/docs/checks/#SA4029
|
||||
- SA4029
|
||||
# Ineffective attempt at generating random number.
|
||||
# https://staticcheck.dev/docs/checks/#SA4030
|
||||
- SA4030
|
||||
# Checking never-nil value against nil.
|
||||
# https://staticcheck.dev/docs/checks/#SA4031
|
||||
- SA4031
|
||||
# Comparing 'runtime.GOOS' or 'runtime.GOARCH' against impossible value.
|
||||
# https://staticcheck.dev/docs/checks/#SA4032
|
||||
- SA4032
|
||||
# Assignment to nil map.
|
||||
# https://staticcheck.dev/docs/checks/#SA5000
|
||||
- SA5000
|
||||
# Deferring 'Close' before checking for a possible error.
|
||||
# https://staticcheck.dev/docs/checks/#SA5001
|
||||
- SA5001
|
||||
# The empty for loop ("for {}") spins and can block the scheduler.
|
||||
# https://staticcheck.dev/docs/checks/#SA5002
|
||||
- SA5002
|
||||
# Defers in infinite loops will never execute.
|
||||
# https://staticcheck.dev/docs/checks/#SA5003
|
||||
- SA5003
|
||||
# "for { select { ..." with an empty default branch spins.
|
||||
# https://staticcheck.dev/docs/checks/#SA5004
|
||||
- SA5004
|
||||
# The finalizer references the finalized object, preventing garbage collection.
|
||||
# https://staticcheck.dev/docs/checks/#SA5005
|
||||
- SA5005
|
||||
# Infinite recursive call.
|
||||
# https://staticcheck.dev/docs/checks/#SA5007
|
||||
- SA5007
|
||||
# Invalid struct tag.
|
||||
# https://staticcheck.dev/docs/checks/#SA5008
|
||||
- SA5008
|
||||
# Invalid Printf call.
|
||||
# https://staticcheck.dev/docs/checks/#SA5009
|
||||
- SA5009
|
||||
# Impossible type assertion.
|
||||
# https://staticcheck.dev/docs/checks/#SA5010
|
||||
- SA5010
|
||||
# Possible nil pointer dereference.
|
||||
# https://staticcheck.dev/docs/checks/#SA5011
|
||||
- SA5011
|
||||
# Passing odd-sized slice to function expecting even size.
|
||||
# https://staticcheck.dev/docs/checks/#SA5012
|
||||
- SA5012
|
||||
# Using 'regexp.Match' or related in a loop, should use 'regexp.Compile'.
|
||||
# https://staticcheck.dev/docs/checks/#SA6000
|
||||
- SA6000
|
||||
# Missing an optimization opportunity when indexing maps by byte slices.
|
||||
# https://staticcheck.dev/docs/checks/#SA6001
|
||||
- SA6001
|
||||
# Storing non-pointer values in 'sync.Pool' allocates memory.
|
||||
# https://staticcheck.dev/docs/checks/#SA6002
|
||||
- SA6002
|
||||
# Converting a string to a slice of runes before ranging over it.
|
||||
# https://staticcheck.dev/docs/checks/#SA6003
|
||||
- SA6003
|
||||
# Inefficient string comparison with 'strings.ToLower' or 'strings.ToUpper'.
|
||||
# https://staticcheck.dev/docs/checks/#SA6005
|
||||
- SA6005
|
||||
# Using io.WriteString to write '[]byte'.
|
||||
# https://staticcheck.dev/docs/checks/#SA6006
|
||||
- SA6006
|
||||
# Defers in range loops may not run when you expect them to.
|
||||
# https://staticcheck.dev/docs/checks/#SA9001
|
||||
- SA9001
|
||||
# Using a non-octal 'os.FileMode' that looks like it was meant to be in octal.
|
||||
# https://staticcheck.dev/docs/checks/#SA9002
|
||||
- SA9002
|
||||
# Empty body in an if or else branch.
|
||||
# https://staticcheck.dev/docs/checks/#SA9003
|
||||
- SA9003
|
||||
# Only the first constant has an explicit type.
|
||||
# https://staticcheck.dev/docs/checks/#SA9004
|
||||
- SA9004
|
||||
# Trying to marshal a struct with no public fields nor custom marshaling.
|
||||
# https://staticcheck.dev/docs/checks/#SA9005
|
||||
- SA9005
|
||||
# Dubious bit shifting of a fixed size integer value.
|
||||
# https://staticcheck.dev/docs/checks/#SA9006
|
||||
- SA9006
|
||||
# Deleting a directory that shouldn't be deleted.
|
||||
# https://staticcheck.dev/docs/checks/#SA9007
|
||||
- SA9007
|
||||
# 'else' branch of a type assertion is probably not reading the right value.
|
||||
# https://staticcheck.dev/docs/checks/#SA9008
|
||||
- SA9008
|
||||
# Ineffectual Go compiler directive.
|
||||
# https://staticcheck.dev/docs/checks/#SA9009
|
||||
- SA9009
|
||||
# Incorrect or missing package comment.
|
||||
# https://staticcheck.dev/docs/checks/#ST1000
|
||||
- ST1000
|
||||
# Dot imports are discouraged.
|
||||
# https://staticcheck.dev/docs/checks/#ST1001
|
||||
- ST1001
|
||||
# Poorly chosen identifier.
|
||||
# https://staticcheck.dev/docs/checks/#ST1003
|
||||
- ST1003
|
||||
# Incorrectly formatted error string.
|
||||
# https://staticcheck.dev/docs/checks/#ST1005
|
||||
- ST1005
|
||||
# Poorly chosen receiver name.
|
||||
# https://staticcheck.dev/docs/checks/#ST1006
|
||||
- ST1006
|
||||
# A function's error value should be its last return value.
|
||||
# https://staticcheck.dev/docs/checks/#ST1008
|
||||
- ST1008
|
||||
# Poorly chosen name for variable of type 'time.Duration'.
|
||||
# https://staticcheck.dev/docs/checks/#ST1011
|
||||
- ST1011
|
||||
# Poorly chosen name for error variable.
|
||||
# https://staticcheck.dev/docs/checks/#ST1012
|
||||
- ST1012
|
||||
# Should use constants for HTTP error codes, not magic numbers.
|
||||
# https://staticcheck.dev/docs/checks/#ST1013
|
||||
- ST1013
|
||||
# A switch's default case should be the first or last case.
|
||||
# https://staticcheck.dev/docs/checks/#ST1015
|
||||
- ST1015
|
||||
# Use consistent method receiver names.
|
||||
# https://staticcheck.dev/docs/checks/#ST1016
|
||||
- ST1016
|
||||
# Don't use Yoda conditions.
|
||||
# https://staticcheck.dev/docs/checks/#ST1017
|
||||
- ST1017
|
||||
# Avoid zero-width and control characters in string literals.
|
||||
# https://staticcheck.dev/docs/checks/#ST1018
|
||||
- ST1018
|
||||
# Importing the same package multiple times.
|
||||
# https://staticcheck.dev/docs/checks/#ST1019
|
||||
- ST1019
|
||||
# The documentation of an exported function should start with the function's name.
|
||||
# https://staticcheck.dev/docs/checks/#ST1020
|
||||
- ST1020
|
||||
# The documentation of an exported type should start with type's name.
|
||||
# https://staticcheck.dev/docs/checks/#ST1021
|
||||
- ST1021
|
||||
# The documentation of an exported variable or constant should start with variable's name.
|
||||
# https://staticcheck.dev/docs/checks/#ST1022
|
||||
- ST1022
|
||||
# Redundant type in variable declaration.
|
||||
# https://staticcheck.dev/docs/checks/#ST1023
|
||||
- ST1023
|
||||
# Use plain channel send or receive instead of single-case select.
|
||||
# https://staticcheck.dev/docs/checks/#S1000
|
||||
- S1000
|
||||
# Replace for loop with call to copy.
|
||||
# https://staticcheck.dev/docs/checks/#S1001
|
||||
- S1001
|
||||
# Omit comparison with boolean constant.
|
||||
# https://staticcheck.dev/docs/checks/#S1002
|
||||
- S1002
|
||||
# Replace call to 'strings.Index' with 'strings.Contains'.
|
||||
# https://staticcheck.dev/docs/checks/#S1003
|
||||
- S1003
|
||||
# Replace call to 'bytes.Compare' with 'bytes.Equal'.
|
||||
# https://staticcheck.dev/docs/checks/#S1004
|
||||
- S1004
|
||||
# Drop unnecessary use of the blank identifier.
|
||||
# https://staticcheck.dev/docs/checks/#S1005
|
||||
- S1005
|
||||
# Use "for { ... }" for infinite loops.
|
||||
# https://staticcheck.dev/docs/checks/#S1006
|
||||
- S1006
|
||||
# Simplify regular expression by using raw string literal.
|
||||
# https://staticcheck.dev/docs/checks/#S1007
|
||||
- S1007
|
||||
# Simplify returning boolean expression.
|
||||
# https://staticcheck.dev/docs/checks/#S1008
|
||||
- S1008
|
||||
# Omit redundant nil check on slices, maps, and channels.
|
||||
# https://staticcheck.dev/docs/checks/#S1009
|
||||
- S1009
|
||||
# Omit default slice index.
|
||||
# https://staticcheck.dev/docs/checks/#S1010
|
||||
- S1010
|
||||
# Use a single 'append' to concatenate two slices.
|
||||
# https://staticcheck.dev/docs/checks/#S1011
|
||||
- S1011
|
||||
# Replace 'time.Now().Sub(x)' with 'time.Since(x)'.
|
||||
# https://staticcheck.dev/docs/checks/#S1012
|
||||
- S1012
|
||||
# Use a type conversion instead of manually copying struct fields.
|
||||
# https://staticcheck.dev/docs/checks/#S1016
|
||||
- S1016
|
||||
# Replace manual trimming with 'strings.TrimPrefix'.
|
||||
# https://staticcheck.dev/docs/checks/#S1017
|
||||
- S1017
|
||||
# Use "copy" for sliding elements.
|
||||
# https://staticcheck.dev/docs/checks/#S1018
|
||||
- S1018
|
||||
# Simplify "make" call by omitting redundant arguments.
|
||||
# https://staticcheck.dev/docs/checks/#S1019
|
||||
- S1019
|
||||
# Omit redundant nil check in type assertion.
|
||||
# https://staticcheck.dev/docs/checks/#S1020
|
||||
- S1020
|
||||
# Merge variable declaration and assignment.
|
||||
# https://staticcheck.dev/docs/checks/#S1021
|
||||
- S1021
|
||||
# Omit redundant control flow.
|
||||
# https://staticcheck.dev/docs/checks/#S1023
|
||||
- S1023
|
||||
# Replace 'x.Sub(time.Now())' with 'time.Until(x)'.
|
||||
# https://staticcheck.dev/docs/checks/#S1024
|
||||
- S1024
|
||||
# Don't use 'fmt.Sprintf("%s", x)' unnecessarily.
|
||||
# https://staticcheck.dev/docs/checks/#S1025
|
||||
- S1025
|
||||
# Simplify error construction with 'fmt.Errorf'.
|
||||
# https://staticcheck.dev/docs/checks/#S1028
|
||||
- S1028
|
||||
# Range over the string directly.
|
||||
# https://staticcheck.dev/docs/checks/#S1029
|
||||
- S1029
|
||||
# Use 'bytes.Buffer.String' or 'bytes.Buffer.Bytes'.
|
||||
# https://staticcheck.dev/docs/checks/#S1030
|
||||
- S1030
|
||||
# Omit redundant nil check around loop.
|
||||
# https://staticcheck.dev/docs/checks/#S1031
|
||||
- S1031
|
||||
# Use 'sort.Ints(x)', 'sort.Float64s(x)', and 'sort.Strings(x)'.
|
||||
# https://staticcheck.dev/docs/checks/#S1032
|
||||
- S1032
|
||||
# Unnecessary guard around call to "delete".
|
||||
# https://staticcheck.dev/docs/checks/#S1033
|
||||
- S1033
|
||||
# Use result of type assertion to simplify cases.
|
||||
# https://staticcheck.dev/docs/checks/#S1034
|
||||
- S1034
|
||||
# Redundant call to 'net/http.CanonicalHeaderKey' in method call on 'net/http.Header'.
|
||||
# https://staticcheck.dev/docs/checks/#S1035
|
||||
- S1035
|
||||
# Unnecessary guard around map access.
|
||||
# https://staticcheck.dev/docs/checks/#S1036
|
||||
- S1036
|
||||
# Elaborate way of sleeping.
|
||||
# https://staticcheck.dev/docs/checks/#S1037
|
||||
- S1037
|
||||
# Unnecessarily complex way of printing formatted string.
|
||||
# https://staticcheck.dev/docs/checks/#S1038
|
||||
- S1038
|
||||
# Unnecessary use of 'fmt.Sprint'.
|
||||
# https://staticcheck.dev/docs/checks/#S1039
|
||||
- S1039
|
||||
# Type assertion to current type.
|
||||
# https://staticcheck.dev/docs/checks/#S1040
|
||||
- S1040
|
||||
# Apply De Morgan's law.
|
||||
# https://staticcheck.dev/docs/checks/#QF1001
|
||||
- QF1001
|
||||
# Convert untagged switch to tagged switch.
|
||||
# https://staticcheck.dev/docs/checks/#QF1002
|
||||
- QF1002
|
||||
# Convert if/else-if chain to tagged switch.
|
||||
# https://staticcheck.dev/docs/checks/#QF1003
|
||||
- QF1003
|
||||
# Use 'strings.ReplaceAll' instead of 'strings.Replace' with 'n == -1'.
|
||||
# https://staticcheck.dev/docs/checks/#QF1004
|
||||
- QF1004
|
||||
# Expand call to 'math.Pow'.
|
||||
# https://staticcheck.dev/docs/checks/#QF1005
|
||||
- QF1005
|
||||
# Lift 'if'+'break' into loop condition.
|
||||
# https://staticcheck.dev/docs/checks/#QF1006
|
||||
- QF1006
|
||||
# Merge conditional assignment into variable declaration.
|
||||
# https://staticcheck.dev/docs/checks/#QF1007
|
||||
- QF1007
|
||||
# Omit embedded fields from selector expression.
|
||||
# https://staticcheck.dev/docs/checks/#QF1008
|
||||
- QF1008
|
||||
# Use 'time.Time.Equal' instead of '==' operator.
|
||||
# https://staticcheck.dev/docs/checks/#QF1009
|
||||
- QF1009
|
||||
# Convert slice of bytes to string when printing it.
|
||||
# https://staticcheck.dev/docs/checks/#QF1010
|
||||
- QF1010
|
||||
# Omit redundant type from variable declaration.
|
||||
# https://staticcheck.dev/docs/checks/#QF1011
|
||||
- QF1011
|
||||
# Use 'fmt.Fprintf(x, ...)' instead of 'x.Write(fmt.Sprintf(...))'.
|
||||
# https://staticcheck.dev/docs/checks/#QF1012
|
||||
- QF1012
|
||||
unused:
|
||||
# Mark all struct fields that have been written to as used.
|
||||
# Default: true
|
||||
field-writes-are-uses: false
|
||||
# Treat IncDec statement (e.g. `i++` or `i--`) as both read and write operation instead of just write.
|
||||
# Default: false
|
||||
post-statements-are-reads: true
|
||||
# Mark all exported fields as used.
|
||||
# default: true
|
||||
exported-fields-are-used: false
|
||||
# Mark all function parameters as used.
|
||||
# default: true
|
||||
parameters-are-used: true
|
||||
# Mark all local variables as used.
|
||||
# default: true
|
||||
local-variables-are-used: false
|
||||
# Mark all identifiers inside generated files as used.
|
||||
# Default: true
|
||||
generated-is-used: false
|
||||
|
||||
formatters:
|
||||
enable:
|
||||
- gofmt
|
||||
settings:
|
||||
gofmt:
|
||||
# Simplify code: gofmt with `-s` option.
|
||||
# Default: true
|
||||
simplify: false
|
||||
# Apply the rewrite rules to the source before reformatting.
|
||||
# https://pkg.go.dev/cmd/gofmt
|
||||
# Default: []
|
||||
rewrite-rules:
|
||||
- pattern: 'interface{}'
|
||||
replacement: 'any'
|
||||
- pattern: 'a[b:len(a)]'
|
||||
replacement: 'a[b:]'
|
||||
@@ -1,6 +1,33 @@
|
||||
.PHONY: wire
|
||||
.PHONY: wire build build-embed test-unit test-integration test-cover-integration clean-coverage
|
||||
|
||||
wire:
|
||||
@echo "生成 Wire 代码..."
|
||||
@cd cmd/server && go generate
|
||||
@echo "Wire 代码生成完成"
|
||||
@echo "Wire 代码生成完成"
|
||||
|
||||
build:
|
||||
@echo "构建后端(不嵌入前端)..."
|
||||
@go build -o bin/server ./cmd/server
|
||||
@echo "构建完成: bin/server"
|
||||
|
||||
build-embed:
|
||||
@echo "构建后端(嵌入前端)..."
|
||||
@go build -tags embed -o bin/server ./cmd/server
|
||||
@echo "构建完成: bin/server (with embedded frontend)"
|
||||
|
||||
test-unit:
|
||||
@go test -tags unit ./... -count=1
|
||||
|
||||
test-integration:
|
||||
@go test -tags integration ./... -count=1 -race -parallel=8
|
||||
|
||||
test-cover-integration:
|
||||
@echo "运行集成测试并生成覆盖率报告..."
|
||||
@go test -tags=integration -cover -coverprofile=coverage.out -count=1 -race -parallel=8 ./...
|
||||
@go tool cover -func=coverage.out | tail -1
|
||||
@go tool cover -html=coverage.out -o coverage.html
|
||||
@echo "覆盖率报告已生成: coverage.html"
|
||||
|
||||
clean-coverage:
|
||||
@rm -f coverage.out coverage.html
|
||||
@echo "覆盖率文件已清理"
|
||||
@@ -15,11 +15,11 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/handler"
|
||||
"sub2api/internal/middleware"
|
||||
"sub2api/internal/setup"
|
||||
"sub2api/internal/web"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/setup"
|
||||
"github.com/Wei-Shaw/sub2api/internal/web"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -84,7 +84,7 @@ func main() {
|
||||
|
||||
func runSetupServer() {
|
||||
r := gin.New()
|
||||
r.Use(gin.Recovery())
|
||||
r.Use(middleware.Recovery())
|
||||
r.Use(middleware.CORS())
|
||||
|
||||
// Register setup routes
|
||||
|
||||
@@ -4,18 +4,19 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/handler"
|
||||
"sub2api/internal/infrastructure"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/server"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
|
||||
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/google/wire"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
@@ -35,6 +36,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
// 业务层 ProviderSets
|
||||
repository.ProviderSet,
|
||||
service.ProviderSet,
|
||||
middleware.ProviderSet,
|
||||
handler.ProviderSet,
|
||||
|
||||
// 服务器层 ProviderSet
|
||||
@@ -62,7 +64,11 @@ func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||
func provideCleanup(
|
||||
db *gorm.DB,
|
||||
rdb *redis.Client,
|
||||
services *service.Services,
|
||||
tokenRefresh *service.TokenRefreshService,
|
||||
pricing *service.PricingService,
|
||||
emailQueue *service.EmailQueueService,
|
||||
oauth *service.OAuthService,
|
||||
openaiOAuth *service.OpenAIOAuthService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -74,15 +80,23 @@ func provideCleanup(
|
||||
fn func() error
|
||||
}{
|
||||
{"TokenRefreshService", func() error {
|
||||
services.TokenRefresh.Stop()
|
||||
tokenRefresh.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"PricingService", func() error {
|
||||
services.Pricing.Stop()
|
||||
pricing.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"EmailQueueService", func() error {
|
||||
services.EmailQueue.Stop()
|
||||
emailQueue.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"OAuthService", func() error {
|
||||
oauth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"OpenAIOAuthService", func() error {
|
||||
openaiOAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"Redis", func() error {
|
||||
|
||||
@@ -8,17 +8,18 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
|
||||
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
|
||||
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
"log"
|
||||
"net/http"
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/handler"
|
||||
"sub2api/internal/handler/admin"
|
||||
"sub2api/internal/infrastructure"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/server"
|
||||
"sub2api/internal/service"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -48,7 +49,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
emailQueueService := service.ProvideEmailQueueService(emailService)
|
||||
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
|
||||
authHandler := handler.NewAuthHandler(authService)
|
||||
userService := service.NewUserService(userRepository, configConfig)
|
||||
userService := service.NewUserService(userRepository)
|
||||
userHandler := handler.NewUserHandler(userService)
|
||||
apiKeyRepository := repository.NewApiKeyRepository(db)
|
||||
groupRepository := repository.NewGroupRepository(db)
|
||||
@@ -58,7 +59,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageLogRepository := repository.NewUsageLogRepository(db)
|
||||
usageService := service.NewUsageService(usageLogRepository, userRepository)
|
||||
usageHandler := handler.NewUsageHandler(usageService, usageLogRepository, apiKeyService)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
redeemCodeRepository := repository.NewRedeemCodeRepository(db)
|
||||
billingCache := repository.NewBillingCache(client)
|
||||
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository)
|
||||
@@ -67,22 +68,29 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService)
|
||||
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
||||
dashboardService := service.NewDashboardService(usageLogRepository)
|
||||
dashboardHandler := admin.NewDashboardHandler(dashboardService)
|
||||
accountRepository := repository.NewAccountRepository(db)
|
||||
proxyRepository := repository.NewProxyRepository(db)
|
||||
proxyExitInfoProber := repository.NewProxyExitInfoProber()
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, usageLogRepository, userSubscriptionRepository, billingCacheService, proxyExitInfoProber)
|
||||
dashboardHandler := admin.NewDashboardHandler(adminService, usageLogRepository)
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber)
|
||||
adminUserHandler := admin.NewUserHandler(adminService)
|
||||
groupHandler := admin.NewGroupHandler(adminService)
|
||||
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
||||
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
||||
openAIOAuthClient := repository.NewOpenAIOAuthClient()
|
||||
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
|
||||
rateLimitService := service.NewRateLimitService(accountRepository, configConfig)
|
||||
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, oAuthService, claudeUsageFetcher)
|
||||
claudeUpstream := repository.NewClaudeUpstream(configConfig)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, claudeUpstream)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, rateLimitService, accountUsageService, accountTestService)
|
||||
oAuthHandler := admin.NewOAuthHandler(oAuthService, adminService)
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher)
|
||||
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, httpUpstream)
|
||||
concurrencyCache := repository.NewConcurrencyCache(client)
|
||||
concurrencyService := service.NewConcurrencyService(concurrencyCache)
|
||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
|
||||
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
||||
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
||||
proxyHandler := admin.NewProxyHandler(adminService)
|
||||
adminRedeemHandler := admin.NewRedeemHandler(adminService)
|
||||
settingHandler := admin.NewSettingHandler(settingService, emailService)
|
||||
@@ -92,8 +100,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo)
|
||||
systemHandler := handler.ProvideSystemHandler(updateService)
|
||||
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
|
||||
adminUsageHandler := admin.NewUsageHandler(usageLogRepository, apiKeyRepository, usageService, adminService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
|
||||
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
|
||||
gatewayCache := repository.NewGatewayCache(client)
|
||||
pricingRemoteClient := repository.NewPricingRemoteClient()
|
||||
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
|
||||
@@ -103,58 +111,19 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
billingService := service.NewBillingService(configConfig, pricingService)
|
||||
identityCache := repository.NewIdentityCache(client)
|
||||
identityService := service.NewIdentityService(identityCache)
|
||||
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService, claudeUpstream)
|
||||
concurrencyCache := repository.NewConcurrencyCache(client)
|
||||
concurrencyService := service.NewConcurrencyService(concurrencyCache)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, configConfig)
|
||||
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream)
|
||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream)
|
||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService)
|
||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, handlerSettingHandler)
|
||||
groupService := service.NewGroupService(groupRepository)
|
||||
accountService := service.NewAccountService(accountRepository, groupRepository)
|
||||
proxyService := service.NewProxyService(proxyRepository)
|
||||
services := &service.Services{
|
||||
Auth: authService,
|
||||
User: userService,
|
||||
ApiKey: apiKeyService,
|
||||
Group: groupService,
|
||||
Account: accountService,
|
||||
Proxy: proxyService,
|
||||
Redeem: redeemService,
|
||||
Usage: usageService,
|
||||
Pricing: pricingService,
|
||||
Billing: billingService,
|
||||
BillingCache: billingCacheService,
|
||||
Admin: adminService,
|
||||
Gateway: gatewayService,
|
||||
OAuth: oAuthService,
|
||||
RateLimit: rateLimitService,
|
||||
AccountUsage: accountUsageService,
|
||||
AccountTest: accountTestService,
|
||||
Setting: settingService,
|
||||
Email: emailService,
|
||||
EmailQueue: emailQueueService,
|
||||
Turnstile: turnstileService,
|
||||
Subscription: subscriptionService,
|
||||
Concurrency: concurrencyService,
|
||||
Identity: identityService,
|
||||
Update: updateService,
|
||||
TokenRefresh: tokenRefreshService,
|
||||
}
|
||||
repositories := &repository.Repositories{
|
||||
User: userRepository,
|
||||
ApiKey: apiKeyRepository,
|
||||
Group: groupRepository,
|
||||
Account: accountRepository,
|
||||
Proxy: proxyRepository,
|
||||
RedeemCode: redeemCodeRepository,
|
||||
UsageLog: usageLogRepository,
|
||||
Setting: settingRepository,
|
||||
UserSubscription: userSubscriptionRepository,
|
||||
}
|
||||
engine := server.ProvideRouter(configConfig, handlers, services, repositories)
|
||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
|
||||
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
|
||||
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
|
||||
apiKeyAuthMiddleware := middleware.NewApiKeyAuthMiddleware(apiKeyService, subscriptionService)
|
||||
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware)
|
||||
httpServer := server.ProvideHTTPServer(configConfig, engine)
|
||||
v := provideCleanup(db, client, services)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, configConfig)
|
||||
v := provideCleanup(db, client, tokenRefreshService, pricingService, emailQueueService, oAuthService, openAIOAuthService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
@@ -179,7 +148,11 @@ func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||
func provideCleanup(
|
||||
db *gorm.DB,
|
||||
rdb *redis.Client,
|
||||
services *service.Services,
|
||||
tokenRefresh *service.TokenRefreshService,
|
||||
pricing *service.PricingService,
|
||||
emailQueue *service.EmailQueueService,
|
||||
oauth *service.OAuthService,
|
||||
openaiOAuth *service.OpenAIOAuthService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -190,15 +163,23 @@ func provideCleanup(
|
||||
fn func() error
|
||||
}{
|
||||
{"TokenRefreshService", func() error {
|
||||
services.TokenRefresh.Stop()
|
||||
tokenRefresh.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"PricingService", func() error {
|
||||
services.Pricing.Stop()
|
||||
pricing.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"EmailQueueService", func() error {
|
||||
services.EmailQueue.Stop()
|
||||
emailQueue.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"OAuthService", func() error {
|
||||
oauth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"OpenAIOAuthService", func() error {
|
||||
openaiOAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"Redis", func() error {
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
server:
|
||||
host: "0.0.0.0"
|
||||
port: 8080
|
||||
mode: "debug" # debug/release
|
||||
|
||||
database:
|
||||
host: "127.0.0.1"
|
||||
port: 5432
|
||||
user: "postgres"
|
||||
password: "XZeRr7nkjHWhm8fw"
|
||||
dbname: "sub2api"
|
||||
sslmode: "disable"
|
||||
|
||||
redis:
|
||||
host: "127.0.0.1"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
|
||||
jwt:
|
||||
secret: "your-secret-key-change-in-production"
|
||||
expire_hour: 24
|
||||
|
||||
default:
|
||||
admin_email: "admin@sub2api.com"
|
||||
admin_password: "admin123"
|
||||
user_concurrency: 5
|
||||
user_balance: 0
|
||||
api_key_prefix: "sk-"
|
||||
rate_multiplier: 1.0
|
||||
|
||||
# Timezone configuration (similar to PHP's date_default_timezone_set)
|
||||
# This affects ALL time operations:
|
||||
# - Database timestamps
|
||||
# - Usage statistics "today" boundary
|
||||
# - Subscription expiry times
|
||||
# Common values: Asia/Shanghai, America/New_York, Europe/London, UTC
|
||||
timezone: "Asia/Shanghai"
|
||||
@@ -1,4 +1,4 @@
|
||||
module sub2api
|
||||
module github.com/Wei-Shaw/sub2api
|
||||
|
||||
go 1.24.0
|
||||
|
||||
@@ -8,10 +8,16 @@ require (
|
||||
github.com/gin-gonic/gin v1.9.1
|
||||
github.com/golang-jwt/jwt/v5 v5.2.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/google/wire v0.7.0
|
||||
github.com/imroc/req/v3 v3.56.0
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/redis/go-redis/v9 v9.3.0
|
||||
github.com/redis/go-redis/v9 v9.7.3
|
||||
github.com/spf13/viper v1.18.2
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0
|
||||
github.com/testcontainers/testcontainers-go/modules/redis v0.40.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
golang.org/x/crypto v0.44.0
|
||||
golang.org/x/net v0.47.0
|
||||
golang.org/x/term v0.37.0
|
||||
@@ -21,51 +27,99 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
dario.cat/mergo v1.0.2 // indirect
|
||||
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/andybalholm/brotli v1.2.0 // indirect
|
||||
github.com/bytedance/sonic v1.9.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||
github.com/containerd/errdefs v1.0.0 // indirect
|
||||
github.com/containerd/errdefs/pkg v0.3.0 // indirect
|
||||
github.com/containerd/log v0.1.0 // indirect
|
||||
github.com/containerd/platforms v0.2.1 // indirect
|
||||
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/distribution/reference v0.6.0 // indirect
|
||||
github.com/docker/docker v28.5.1+incompatible // indirect
|
||||
github.com/docker/go-connections v0.6.0 // indirect
|
||||
github.com/docker/go-units v0.5.0 // indirect
|
||||
github.com/ebitengine/purego v0.8.4 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/fsnotify/fsnotify v1.7.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/go-querystring v1.1.0 // indirect
|
||||
github.com/google/subcommands v1.2.0 // indirect
|
||||
github.com/google/wire v0.7.0 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/icholy/digest v1.1.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||
github.com/jackc/pgx/v5 v5.4.3 // indirect
|
||||
github.com/jackc/pgx/v5 v5.5.4 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.1 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/compress v1.18.1 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||
github.com/leodido/go-urn v1.2.4 // indirect
|
||||
github.com/magiconair/properties v1.8.7 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||
github.com/magiconair/properties v1.8.10 // indirect
|
||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||
github.com/mdelapenya/tlscert v0.2.0 // indirect
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
github.com/moby/docker-image-spec v1.3.1 // indirect
|
||||
github.com/moby/go-archive v0.1.0 // indirect
|
||||
github.com/moby/patternmatcher v0.6.0 // indirect
|
||||
github.com/moby/sys/sequential v0.6.0 // indirect
|
||||
github.com/moby/sys/user v0.4.0 // indirect
|
||||
github.com/moby/sys/userns v0.1.0 // indirect
|
||||
github.com/moby/term v0.5.0 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/morikuni/aec v1.0.0 // indirect
|
||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.1.0 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||
github.com/quic-go/qpack v0.5.1 // indirect
|
||||
github.com/quic-go/quic-go v0.56.0 // indirect
|
||||
github.com/refraction-networking/utls v1.8.1 // indirect
|
||||
github.com/sagikazarmark/locafero v0.4.0 // indirect
|
||||
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
|
||||
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
|
||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||
github.com/sourcegraph/conc v0.3.0 // indirect
|
||||
github.com/spf13/afero v1.11.0 // indirect
|
||||
github.com/spf13/cast v1.6.0 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/testcontainers/testcontainers-go v0.40.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
|
||||
go.opentelemetry.io/otel v1.37.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.37.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk v1.37.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.37.0 // indirect
|
||||
go.uber.org/atomic v1.9.0 // indirect
|
||||
go.uber.org/multierr v1.9.0 // indirect
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
@@ -75,6 +129,7 @@ require (
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
golang.org/x/tools v0.38.0 // indirect
|
||||
google.golang.org/protobuf v1.31.0 // indirect
|
||||
google.golang.org/grpc v1.75.1 // indirect
|
||||
google.golang.org/protobuf v1.36.10 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
)
|
||||
|
||||
169
backend/go.sum
169
backend/go.sum
@@ -1,3 +1,11 @@
|
||||
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
|
||||
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
|
||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk=
|
||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
|
||||
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8=
|
||||
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
|
||||
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
@@ -7,17 +15,43 @@ github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0
|
||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
||||
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
|
||||
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
||||
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
|
||||
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
|
||||
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
|
||||
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
|
||||
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
|
||||
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
|
||||
github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
|
||||
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
|
||||
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
|
||||
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
|
||||
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
|
||||
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
|
||||
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
|
||||
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
||||
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||
github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw=
|
||||
github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
||||
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
|
||||
@@ -28,6 +62,13 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE
|
||||
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
|
||||
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
@@ -40,9 +81,8 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
|
||||
@@ -54,6 +94,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
|
||||
github.com/google/wire v0.7.0/go.mod h1:n6YbUQD9cPKTnHXEBN2DXlOp/mVADhVErcMFb0v3J18=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4=
|
||||
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
|
||||
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
|
||||
github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
|
||||
@@ -64,8 +106,10 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY=
|
||||
github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
|
||||
github.com/jackc/pgx/v5 v5.5.4 h1:Xp2aQS8uXButQdnCMWNmvx6UysWQQC+u1EoizjguY+8=
|
||||
github.com/jackc/pgx/v5 v5.5.4/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
|
||||
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
|
||||
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
@@ -85,36 +129,70 @@ github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
|
||||
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
|
||||
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
|
||||
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
||||
github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o=
|
||||
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
||||
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
|
||||
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
|
||||
github.com/moby/go-archive v0.1.0 h1:Kk/5rdW/g+H8NHdJW2gsXyZ7UnzvJNOy6VKJqueWdcQ=
|
||||
github.com/moby/go-archive v0.1.0/go.mod h1:G9B+YoujNohJmrIYFBpSd54GTUB4lt9S+xVQvsJyFuo=
|
||||
github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk=
|
||||
github.com/moby/patternmatcher v0.6.0/go.mod h1:hDPoyOpDY7OrrMDLaYoY3hf52gNCR/YOUYxkhApJIxc=
|
||||
github.com/moby/sys/atomicwriter v0.1.0 h1:kw5D/EqkBwsBFi0ss9v1VG3wIkVhzGvLklJ+w3A14Sw=
|
||||
github.com/moby/sys/atomicwriter v0.1.0/go.mod h1:Ul8oqv2ZMNHOceF643P6FKPXeCmYtlQMvpizfsSoaWs=
|
||||
github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU=
|
||||
github.com/moby/sys/sequential v0.6.0/go.mod h1:uyv8EUTrca5PnDsdMGXhZe6CCe8U/UiTWd+lL+7b/Ko=
|
||||
github.com/moby/sys/user v0.4.0 h1:jhcMKit7SA80hivmFJcbB1vqmw//wU61Zdui2eQXuMs=
|
||||
github.com/moby/sys/user v0.4.0/go.mod h1:bG+tYYYJgaMtRKgEmuueC0hJEAZWwtIbZTB+85uoHjs=
|
||||
github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g=
|
||||
github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28=
|
||||
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
|
||||
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
|
||||
github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4=
|
||||
github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
|
||||
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
|
||||
github.com/quic-go/quic-go v0.56.0 h1:q/TW+OLismmXAehgFLczhCDTYB3bFmua4D9lsNBWxvY=
|
||||
github.com/quic-go/quic-go v0.56.0/go.mod h1:9gx5KsFQtw2oZ6GZTyh+7YEvOxWCL9WZAepnHxgAo6c=
|
||||
github.com/redis/go-redis/v9 v9.3.0 h1:RiVDjmig62jIWp7Kk4XVLs0hzV6pI3PyTnnL0cnn0u0=
|
||||
github.com/redis/go-redis/v9 v9.3.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M=
|
||||
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
|
||||
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
|
||||
github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo=
|
||||
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||
github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ=
|
||||
github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4=
|
||||
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
|
||||
github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
|
||||
github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs=
|
||||
github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
|
||||
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
|
||||
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
||||
@@ -128,6 +206,8 @@ github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMV
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
@@ -135,16 +215,55 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||
github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+aMnb6JfVz1mxk7OeMU=
|
||||
github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3APNyLkkap35zNLxukw9oBi/MY=
|
||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0 h1:s2bIayFXlbDFexo96y+htn7FzuhpXLYJNnIuglNKqOk=
|
||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0/go.mod h1:h+u/2KoREGTnTl9UwrQ/g+XhasAT8E6dClclAADeXoQ=
|
||||
github.com/testcontainers/testcontainers-go/modules/redis v0.40.0 h1:OG4qwcxp2O0re7V7M9lY9w0v6wWgWf7j7rtkpAnGMd0=
|
||||
github.com/testcontainers/testcontainers-go/modules/redis v0.40.0/go.mod h1:Bc+EDhKMo5zI5V5zdBkHiMVzeAXbtI4n5isS/nzf6zw=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
|
||||
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
|
||||
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
|
||||
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw=
|
||||
go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ=
|
||||
go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU=
|
||||
go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE=
|
||||
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
|
||||
go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI=
|
||||
go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg=
|
||||
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
|
||||
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
|
||||
go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I=
|
||||
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
|
||||
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
|
||||
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
||||
@@ -164,8 +283,14 @@ golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
|
||||
@@ -177,9 +302,15 @@ golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
|
||||
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4/go.mod h1:NnuHhy+bxcg30o7FnVAZbXsPHUDQ9qKWAQKCD7VxFtk=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 h1:i8QOKZfYg6AbGVZzUAY3LrNWCKF8O6zFisU9Wl9RER4=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4/go.mod h1:HSkG/KdJWusxU1F6CNrwNDjBMgisKxGnc5dAZfT0mjQ=
|
||||
google.golang.org/grpc v1.75.1 h1:/ODCNEuf9VghjgO3rqLcfg8fiOP0nSluljWFlDxELLI=
|
||||
google.golang.org/grpc v1.75.1/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ=
|
||||
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
|
||||
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
@@ -192,6 +323,6 @@ gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo=
|
||||
gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0=
|
||||
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
|
||||
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||
gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU=
|
||||
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
||||
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
|
||||
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
|
||||
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
||||
|
||||
@@ -52,7 +52,7 @@ type PricingConfig struct {
|
||||
type ServerConfig struct {
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
Mode string `mapstructure:"mode"` // debug/release
|
||||
Mode string `mapstructure:"mode"` // debug/release
|
||||
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
|
||||
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
|
||||
}
|
||||
@@ -163,7 +163,7 @@ func setDefaults() {
|
||||
viper.SetDefault("server.port", 8080)
|
||||
viper.SetDefault("server.mode", "debug")
|
||||
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
|
||||
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
|
||||
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
|
||||
|
||||
// Database
|
||||
viper.SetDefault("database.host", "localhost")
|
||||
@@ -210,10 +210,10 @@ func setDefaults() {
|
||||
|
||||
// TokenRefresh
|
||||
viper.SetDefault("token_refresh.enabled", true)
|
||||
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
|
||||
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
|
||||
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 1.5) // 提前1.5小时刷新
|
||||
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
|
||||
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
|
||||
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
|
||||
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
|
||||
}
|
||||
|
||||
func (c *Config) Validate() error {
|
||||
|
||||
@@ -3,9 +3,12 @@ package admin
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"sub2api/internal/pkg/claude"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -13,14 +16,12 @@ import (
|
||||
// OAuthHandler handles OAuth-related operations for accounts
|
||||
type OAuthHandler struct {
|
||||
oauthService *service.OAuthService
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewOAuthHandler creates a new OAuth handler
|
||||
func NewOAuthHandler(oauthService *service.OAuthService, adminService service.AdminService) *OAuthHandler {
|
||||
func NewOAuthHandler(oauthService *service.OAuthService) *OAuthHandler {
|
||||
return &OAuthHandler{
|
||||
oauthService: oauthService,
|
||||
adminService: adminService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,47 +29,81 @@ func NewOAuthHandler(oauthService *service.OAuthService, adminService service.Ad
|
||||
type AccountHandler struct {
|
||||
adminService service.AdminService
|
||||
oauthService *service.OAuthService
|
||||
openaiOAuthService *service.OpenAIOAuthService
|
||||
rateLimitService *service.RateLimitService
|
||||
accountUsageService *service.AccountUsageService
|
||||
accountTestService *service.AccountTestService
|
||||
concurrencyService *service.ConcurrencyService
|
||||
crsSyncService *service.CRSSyncService
|
||||
}
|
||||
|
||||
// NewAccountHandler creates a new admin account handler
|
||||
func NewAccountHandler(adminService service.AdminService, oauthService *service.OAuthService, rateLimitService *service.RateLimitService, accountUsageService *service.AccountUsageService, accountTestService *service.AccountTestService) *AccountHandler {
|
||||
func NewAccountHandler(
|
||||
adminService service.AdminService,
|
||||
oauthService *service.OAuthService,
|
||||
openaiOAuthService *service.OpenAIOAuthService,
|
||||
rateLimitService *service.RateLimitService,
|
||||
accountUsageService *service.AccountUsageService,
|
||||
accountTestService *service.AccountTestService,
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
crsSyncService *service.CRSSyncService,
|
||||
) *AccountHandler {
|
||||
return &AccountHandler{
|
||||
adminService: adminService,
|
||||
oauthService: oauthService,
|
||||
openaiOAuthService: openaiOAuthService,
|
||||
rateLimitService: rateLimitService,
|
||||
accountUsageService: accountUsageService,
|
||||
accountTestService: accountTestService,
|
||||
concurrencyService: concurrencyService,
|
||||
crsSyncService: crsSyncService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateAccountRequest represents create account request
|
||||
type CreateAccountRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Platform string `json:"platform" binding:"required"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"`
|
||||
Credentials map[string]interface{} `json:"credentials" binding:"required"`
|
||||
Extra map[string]interface{} `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
Platform string `json:"platform" binding:"required"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"`
|
||||
Credentials map[string]any `json:"credentials" binding:"required"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
}
|
||||
|
||||
// UpdateAccountRequest represents update account request
|
||||
// 使用指针类型来区分"未提供"和"设置为0"
|
||||
type UpdateAccountRequest struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
|
||||
Credentials map[string]interface{} `json:"credentials"`
|
||||
Extra map[string]interface{} `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
}
|
||||
|
||||
// BulkUpdateAccountsRequest represents the payload for bulk editing accounts
|
||||
type BulkUpdateAccountsRequest struct {
|
||||
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
|
||||
Name string `json:"name"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
}
|
||||
|
||||
// AccountWithConcurrency extends Account with real-time concurrency info
|
||||
type AccountWithConcurrency struct {
|
||||
*model.Account
|
||||
CurrentConcurrency int `json:"current_concurrency"`
|
||||
}
|
||||
|
||||
// List handles listing all accounts with pagination
|
||||
@@ -82,11 +117,32 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
|
||||
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list accounts: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, accounts, total, page, pageSize)
|
||||
// Get current concurrency counts for all accounts
|
||||
accountIDs := make([]int64, len(accounts))
|
||||
for i, acc := range accounts {
|
||||
accountIDs[i] = acc.ID
|
||||
}
|
||||
|
||||
concurrencyCounts, err := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs)
|
||||
if err != nil {
|
||||
// Log error but don't fail the request, just use 0 for all
|
||||
concurrencyCounts = make(map[int64]int)
|
||||
}
|
||||
|
||||
// Build response with concurrency info
|
||||
result := make([]AccountWithConcurrency, len(accounts))
|
||||
for i := range accounts {
|
||||
result[i] = AccountWithConcurrency{
|
||||
Account: &accounts[i],
|
||||
CurrentConcurrency: concurrencyCounts[accounts[i].ID],
|
||||
}
|
||||
}
|
||||
|
||||
response.Paginated(c, result, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetByID handles getting an account by ID
|
||||
@@ -100,7 +156,7 @@ func (h *AccountHandler) GetByID(c *gin.Context) {
|
||||
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Account not found")
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -128,7 +184,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
GroupIDs: req.GroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to create account: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -162,7 +218,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
|
||||
GroupIDs: req.GroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update account: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -180,7 +236,7 @@ func (h *AccountHandler) Delete(c *gin.Context) {
|
||||
|
||||
err = h.adminService.DeleteAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to delete account: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -192,6 +248,13 @@ type TestAccountRequest struct {
|
||||
ModelID string `json:"model_id"`
|
||||
}
|
||||
|
||||
type SyncFromCRSRequest struct {
|
||||
BaseURL string `json:"base_url" binding:"required"`
|
||||
Username string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
SyncProxies *bool `json:"sync_proxies"`
|
||||
}
|
||||
|
||||
// Test handles testing account connectivity with SSE streaming
|
||||
// POST /api/v1/admin/accounts/:id/test
|
||||
func (h *AccountHandler) Test(c *gin.Context) {
|
||||
@@ -212,6 +275,35 @@ func (h *AccountHandler) Test(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// SyncFromCRS handles syncing accounts from claude-relay-service (CRS)
|
||||
// POST /api/v1/admin/accounts/sync/crs
|
||||
func (h *AccountHandler) SyncFromCRS(c *gin.Context) {
|
||||
var req SyncFromCRSRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Default to syncing proxies (can be disabled by explicitly setting false)
|
||||
syncProxies := true
|
||||
if req.SyncProxies != nil {
|
||||
syncProxies = *req.SyncProxies
|
||||
}
|
||||
|
||||
result, err := h.crsSyncService.SyncFromCRS(c.Request.Context(), service.SyncFromCRSInput{
|
||||
BaseURL: req.BaseURL,
|
||||
Username: req.Username,
|
||||
Password: req.Password,
|
||||
SyncProxies: syncProxies,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// Refresh handles refreshing account credentials
|
||||
// POST /api/v1/admin/accounts/:id/refresh
|
||||
func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
@@ -234,32 +326,53 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Use OAuth service to refresh token
|
||||
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
|
||||
return
|
||||
}
|
||||
var newCredentials map[string]any
|
||||
|
||||
// Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
|
||||
newCredentials := make(map[string]interface{})
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
if account.IsOpenAI() {
|
||||
// Use OpenAI OAuth service to refresh token
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Update token-related fields
|
||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||
newCredentials["token_type"] = tokenInfo.TokenType
|
||||
newCredentials["expires_in"] = tokenInfo.ExpiresIn
|
||||
newCredentials["expires_at"] = tokenInfo.ExpiresAt
|
||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
newCredentials["scope"] = tokenInfo.Scope
|
||||
// Build new credentials from token info
|
||||
newCredentials = h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
|
||||
// Preserve non-token settings from existing credentials
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Use Anthropic/Claude OAuth service to refresh token
|
||||
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
|
||||
newCredentials = make(map[string]any)
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
|
||||
// Update token-related fields
|
||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||
newCredentials["token_type"] = tokenInfo.TokenType
|
||||
newCredentials["expires_in"] = tokenInfo.ExpiresIn
|
||||
newCredentials["expires_at"] = tokenInfo.ExpiresAt
|
||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
newCredentials["scope"] = tokenInfo.Scope
|
||||
}
|
||||
|
||||
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
Credentials: newCredentials,
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update account credentials: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -275,15 +388,26 @@ func (h *AccountHandler) GetStats(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Return mock data for now
|
||||
_ = accountID
|
||||
response.Success(c, gin.H{
|
||||
"total_requests": 0,
|
||||
"successful_requests": 0,
|
||||
"failed_requests": 0,
|
||||
"total_tokens": 0,
|
||||
"average_response_time": 0,
|
||||
})
|
||||
// Parse days parameter (default 30)
|
||||
days := 30
|
||||
if daysStr := c.Query("days"); daysStr != "" {
|
||||
if d, err := strconv.Atoi(daysStr); err == nil && d > 0 && d <= 90 {
|
||||
days = d
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate time range
|
||||
now := timezone.Now()
|
||||
endTime := timezone.StartOfDay(now.AddDate(0, 0, 1))
|
||||
startTime := timezone.StartOfDay(now.AddDate(0, 0, -days+1))
|
||||
|
||||
stats, err := h.accountUsageService.GetAccountUsageStats(c.Request.Context(), accountID, startTime, endTime)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, stats)
|
||||
}
|
||||
|
||||
// ClearError handles clearing account error
|
||||
@@ -297,7 +421,7 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
|
||||
|
||||
account, err := h.adminService.ClearAccountError(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to clear error: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -323,6 +447,136 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// BatchUpdateCredentialsRequest represents batch credentials update request
|
||||
type BatchUpdateCredentialsRequest struct {
|
||||
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
|
||||
Field string `json:"field" binding:"required,oneof=account_uuid org_uuid intercept_warmup_requests"`
|
||||
Value any `json:"value"`
|
||||
}
|
||||
|
||||
// BatchUpdateCredentials handles batch updating credentials fields
|
||||
// POST /api/v1/admin/accounts/batch-update-credentials
|
||||
func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) {
|
||||
var req BatchUpdateCredentialsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Validate value type based on field
|
||||
if req.Field == "intercept_warmup_requests" {
|
||||
// Must be boolean
|
||||
if _, ok := req.Value.(bool); !ok {
|
||||
response.BadRequest(c, "intercept_warmup_requests must be boolean")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// account_uuid and org_uuid can be string or null
|
||||
if req.Value != nil {
|
||||
if _, ok := req.Value.(string); !ok {
|
||||
response.BadRequest(c, req.Field+" must be string or null")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
success := 0
|
||||
failed := 0
|
||||
results := []gin.H{}
|
||||
|
||||
for _, accountID := range req.AccountIDs {
|
||||
// Get account
|
||||
account, err := h.adminService.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
failed++
|
||||
results = append(results, gin.H{
|
||||
"account_id": accountID,
|
||||
"success": false,
|
||||
"error": "Account not found",
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Update credentials field
|
||||
if account.Credentials == nil {
|
||||
account.Credentials = make(map[string]any)
|
||||
}
|
||||
|
||||
account.Credentials[req.Field] = req.Value
|
||||
|
||||
// Update account
|
||||
updateInput := &service.UpdateAccountInput{
|
||||
Credentials: account.Credentials,
|
||||
}
|
||||
|
||||
_, err = h.adminService.UpdateAccount(ctx, accountID, updateInput)
|
||||
if err != nil {
|
||||
failed++
|
||||
results = append(results, gin.H{
|
||||
"account_id": accountID,
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
success++
|
||||
results = append(results, gin.H{
|
||||
"account_id": accountID,
|
||||
"success": true,
|
||||
})
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"results": results,
|
||||
})
|
||||
}
|
||||
|
||||
// BulkUpdate handles bulk updating accounts with selected fields/credentials.
|
||||
// POST /api/v1/admin/accounts/bulk-update
|
||||
func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
var req BulkUpdateAccountsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
hasUpdates := req.Name != "" ||
|
||||
req.ProxyID != nil ||
|
||||
req.Concurrency != nil ||
|
||||
req.Priority != nil ||
|
||||
req.Status != "" ||
|
||||
req.GroupIDs != nil ||
|
||||
len(req.Credentials) > 0 ||
|
||||
len(req.Extra) > 0
|
||||
|
||||
if !hasUpdates {
|
||||
response.BadRequest(c, "No updates provided")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.adminService.BulkUpdateAccounts(c.Request.Context(), &service.BulkUpdateAccountsInput{
|
||||
AccountIDs: req.AccountIDs,
|
||||
Name: req.Name,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
Status: req.Status,
|
||||
GroupIDs: req.GroupIDs,
|
||||
Credentials: req.Credentials,
|
||||
Extra: req.Extra,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// ========== OAuth Handlers ==========
|
||||
|
||||
// GenerateAuthURLRequest represents the request for generating auth URL
|
||||
@@ -341,7 +595,7 @@ func (h *OAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||
|
||||
result, err := h.oauthService.GenerateAuthURL(c.Request.Context(), req.ProxyID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to generate auth URL: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -359,7 +613,7 @@ func (h *OAuthHandler) GenerateSetupTokenURL(c *gin.Context) {
|
||||
|
||||
result, err := h.oauthService.GenerateSetupTokenURL(c.Request.Context(), req.ProxyID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to generate setup token URL: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -388,7 +642,7 @@ func (h *OAuthHandler) ExchangeCode(c *gin.Context) {
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to exchange code: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -410,7 +664,7 @@ func (h *OAuthHandler) ExchangeSetupTokenCode(c *gin.Context) {
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to exchange code: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -438,7 +692,7 @@ func (h *OAuthHandler) CookieAuth(c *gin.Context) {
|
||||
Scope: "full",
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Cookie auth failed: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -460,7 +714,7 @@ func (h *OAuthHandler) SetupTokenCookieAuth(c *gin.Context) {
|
||||
Scope: "inference",
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Cookie auth failed: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -478,7 +732,7 @@ func (h *AccountHandler) GetUsage(c *gin.Context) {
|
||||
|
||||
usage, err := h.accountUsageService.GetUsage(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get usage: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -496,7 +750,7 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
|
||||
|
||||
err = h.rateLimitService.ClearRateLimit(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to clear rate limit: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -514,7 +768,7 @@ func (h *AccountHandler) GetTodayStats(c *gin.Context) {
|
||||
|
||||
stats, err := h.accountUsageService.GetTodayStats(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get today stats: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -543,7 +797,7 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) {
|
||||
|
||||
account, err := h.adminService.SetAccountSchedulable(c.Request.Context(), accountID, req.Schedulable)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update schedulable status: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -565,6 +819,46 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Handle OpenAI accounts
|
||||
if account.IsOpenAI() {
|
||||
// For OAuth accounts: return default OpenAI models
|
||||
if account.IsOAuth() {
|
||||
response.Success(c, openai.DefaultModels)
|
||||
return
|
||||
}
|
||||
|
||||
// For API Key accounts: check model_mapping
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) == 0 {
|
||||
response.Success(c, openai.DefaultModels)
|
||||
return
|
||||
}
|
||||
|
||||
// Return mapped models
|
||||
var models []openai.Model
|
||||
for requestedModel := range mapping {
|
||||
var found bool
|
||||
for _, dm := range openai.DefaultModels {
|
||||
if dm.ID == requestedModel {
|
||||
models = append(models, dm)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
models = append(models, openai.Model{
|
||||
ID: requestedModel,
|
||||
Object: "model",
|
||||
Type: "model",
|
||||
DisplayName: requestedModel,
|
||||
})
|
||||
}
|
||||
}
|
||||
response.Success(c, models)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle Claude/Anthropic accounts
|
||||
// For OAuth and Setup-Token accounts: return default models
|
||||
if account.IsOAuth() {
|
||||
response.Success(c, claude.DefaultModels)
|
||||
@@ -573,7 +867,7 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
|
||||
// For API Key accounts: return models based on model_mapping
|
||||
mapping := account.GetModelMapping()
|
||||
if mapping == nil || len(mapping) == 0 {
|
||||
if len(mapping) == 0 {
|
||||
// No mapping configured, return default models
|
||||
response.Success(c, claude.DefaultModels)
|
||||
return
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"strconv"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -13,17 +12,15 @@ import (
|
||||
|
||||
// DashboardHandler handles admin dashboard statistics
|
||||
type DashboardHandler struct {
|
||||
adminService service.AdminService
|
||||
usageRepo *repository.UsageLogRepository
|
||||
startTime time.Time // Server start time for uptime calculation
|
||||
dashboardService *service.DashboardService
|
||||
startTime time.Time // Server start time for uptime calculation
|
||||
}
|
||||
|
||||
// NewDashboardHandler creates a new admin dashboard handler
|
||||
func NewDashboardHandler(adminService service.AdminService, usageRepo *repository.UsageLogRepository) *DashboardHandler {
|
||||
func NewDashboardHandler(dashboardService *service.DashboardService) *DashboardHandler {
|
||||
return &DashboardHandler{
|
||||
adminService: adminService,
|
||||
usageRepo: usageRepo,
|
||||
startTime: time.Now(),
|
||||
dashboardService: dashboardService,
|
||||
startTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,7 +58,7 @@ func parseTimeRange(c *gin.Context) (time.Time, time.Time) {
|
||||
// GetStats handles getting dashboard statistics
|
||||
// GET /api/v1/admin/dashboard/stats
|
||||
func (h *DashboardHandler) GetStats(c *gin.Context) {
|
||||
stats, err := h.usageRepo.GetDashboardStats(c.Request.Context())
|
||||
stats, err := h.dashboardService.GetDashboardStats(c.Request.Context())
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get dashboard statistics")
|
||||
return
|
||||
@@ -110,6 +107,10 @@ func (h *DashboardHandler) GetStats(c *gin.Context) {
|
||||
// 系统运行统计
|
||||
"average_duration_ms": stats.AverageDurationMs,
|
||||
"uptime": uptime,
|
||||
|
||||
// 性能指标
|
||||
"rpm": stats.Rpm,
|
||||
"tpm": stats.Tpm,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -145,7 +146,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
trend, err := h.usageRepo.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID)
|
||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get usage trend")
|
||||
return
|
||||
@@ -178,7 +179,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := h.usageRepo.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID)
|
||||
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get model statistics")
|
||||
return
|
||||
@@ -203,7 +204,7 @@ func (h *DashboardHandler) GetApiKeyUsageTrend(c *gin.Context) {
|
||||
limit = 5
|
||||
}
|
||||
|
||||
trend, err := h.usageRepo.GetApiKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
trend, err := h.dashboardService.GetApiKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get API key usage trend")
|
||||
return
|
||||
@@ -229,7 +230,7 @@ func (h *DashboardHandler) GetUserUsageTrend(c *gin.Context) {
|
||||
limit = 12
|
||||
}
|
||||
|
||||
trend, err := h.usageRepo.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
trend, err := h.dashboardService.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user usage trend")
|
||||
return
|
||||
@@ -258,11 +259,11 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||
}
|
||||
|
||||
if len(req.UserIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]interface{}{}})
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.usageRepo.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs)
|
||||
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user usage stats")
|
||||
return
|
||||
@@ -286,11 +287,11 @@ func (h *DashboardHandler) GetBatchApiKeysUsage(c *gin.Context) {
|
||||
}
|
||||
|
||||
if len(req.ApiKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]interface{}{}})
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.usageRepo.GetBatchApiKeyUsageStats(c.Request.Context(), req.ApiKeyIDs)
|
||||
stats, err := h.dashboardService.GetBatchApiKeyUsageStats(c.Request.Context(), req.ApiKeyIDs)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get API key usage stats")
|
||||
return
|
||||
|
||||
@@ -3,9 +3,9 @@ package admin
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -65,7 +65,7 @@ func (h *GroupHandler) List(c *gin.Context) {
|
||||
|
||||
groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, isExclusive)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list groups: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ func (h *GroupHandler) GetAll(c *gin.Context) {
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get groups: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -105,7 +105,7 @@ func (h *GroupHandler) GetByID(c *gin.Context) {
|
||||
|
||||
group, err := h.adminService.GetGroup(c.Request.Context(), groupID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Group not found")
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -133,7 +133,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to create group: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -168,7 +168,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update group: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -186,7 +186,7 @@ func (h *GroupHandler) Delete(c *gin.Context) {
|
||||
|
||||
err = h.adminService.DeleteGroup(c.Request.Context(), groupID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to delete group: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -225,7 +225,7 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
||||
|
||||
keys, total, err := h.adminService.GetGroupAPIKeys(c.Request.Context(), groupID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get group API keys: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
228
backend/internal/handler/admin/openai_oauth_handler.go
Normal file
228
backend/internal/handler/admin/openai_oauth_handler.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// OpenAIOAuthHandler handles OpenAI OAuth-related operations
|
||||
type OpenAIOAuthHandler struct {
|
||||
openaiOAuthService *service.OpenAIOAuthService
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewOpenAIOAuthHandler creates a new OpenAI OAuth handler
|
||||
func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler {
|
||||
return &OpenAIOAuthHandler{
|
||||
openaiOAuthService: openaiOAuthService,
|
||||
adminService: adminService,
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAIGenerateAuthURLRequest represents the request for generating OpenAI auth URL
|
||||
type OpenAIGenerateAuthURLRequest struct {
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
}
|
||||
|
||||
// GenerateAuthURL generates OpenAI OAuth authorization URL
|
||||
// POST /api/v1/admin/openai/generate-auth-url
|
||||
func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||
var req OpenAIGenerateAuthURLRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
// Allow empty body
|
||||
req = OpenAIGenerateAuthURLRequest{}
|
||||
}
|
||||
|
||||
result, err := h.openaiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// OpenAIExchangeCodeRequest represents the request for exchanging OpenAI auth code
|
||||
type OpenAIExchangeCodeRequest struct {
|
||||
SessionID string `json:"session_id" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
|
||||
// ExchangeCode exchanges OpenAI authorization code for tokens
|
||||
// POST /api/v1/admin/openai/exchange-code
|
||||
func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
|
||||
var req OpenAIExchangeCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
|
||||
SessionID: req.SessionID,
|
||||
Code: req.Code,
|
||||
RedirectURI: req.RedirectURI,
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
// OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token
|
||||
type OpenAIRefreshTokenRequest struct {
|
||||
RefreshToken string `json:"refresh_token" binding:"required"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an OpenAI OAuth token
|
||||
// POST /api/v1/admin/openai/refresh-token
|
||||
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
||||
var req OpenAIRefreshTokenRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
if req.ProxyID != nil {
|
||||
proxy, err := h.adminService.GetProxy(c.Request.Context(), *req.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for a specific OpenAI account
|
||||
// POST /api/v1/admin/openai/accounts/:id/refresh
|
||||
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Get account
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure account is OpenAI platform
|
||||
if !account.IsOpenAI() {
|
||||
response.BadRequest(c, "Account is not an OpenAI account")
|
||||
return
|
||||
}
|
||||
|
||||
// Only refresh OAuth-based accounts
|
||||
if !account.IsOAuth() {
|
||||
response.BadRequest(c, "Cannot refresh non-OAuth account credentials")
|
||||
return
|
||||
}
|
||||
|
||||
// Use OpenAI OAuth service to refresh token
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Build new credentials from token info
|
||||
newCredentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
|
||||
// Preserve non-token settings from existing credentials
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
Credentials: newCredentials,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, updatedAccount)
|
||||
}
|
||||
|
||||
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info
|
||||
// POST /api/v1/admin/openai/create-from-oauth
|
||||
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||
var req struct {
|
||||
SessionID string `json:"session_id" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Name string `json:"name"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Exchange code for tokens
|
||||
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
|
||||
SessionID: req.SessionID,
|
||||
Code: req.Code,
|
||||
RedirectURI: req.RedirectURI,
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Build credentials from token info
|
||||
credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
|
||||
// Use email as default name if not provided
|
||||
name := req.Name
|
||||
if name == "" && tokenInfo.Email != "" {
|
||||
name = tokenInfo.Email
|
||||
}
|
||||
if name == "" {
|
||||
name = "OpenAI OAuth Account"
|
||||
}
|
||||
|
||||
// Create account
|
||||
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
|
||||
Name: name,
|
||||
Platform: "openai",
|
||||
Type: "oauth",
|
||||
Credentials: credentials,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
GroupIDs: req.GroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, account)
|
||||
}
|
||||
@@ -4,8 +4,8 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -53,7 +53,7 @@ func (h *ProxyHandler) List(c *gin.Context) {
|
||||
|
||||
proxies, total, err := h.adminService.ListProxies(c.Request.Context(), page, pageSize, protocol, status, search)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list proxies: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -69,7 +69,7 @@ func (h *ProxyHandler) GetAll(c *gin.Context) {
|
||||
if withCount {
|
||||
proxies, err := h.adminService.GetAllProxiesWithAccountCount(c.Request.Context())
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get proxies: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, proxies)
|
||||
@@ -78,7 +78,7 @@ func (h *ProxyHandler) GetAll(c *gin.Context) {
|
||||
|
||||
proxies, err := h.adminService.GetAllProxies(c.Request.Context())
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get proxies: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -96,7 +96,7 @@ func (h *ProxyHandler) GetByID(c *gin.Context) {
|
||||
|
||||
proxy, err := h.adminService.GetProxy(c.Request.Context(), proxyID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Proxy not found")
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -121,7 +121,7 @@ func (h *ProxyHandler) Create(c *gin.Context) {
|
||||
Password: strings.TrimSpace(req.Password),
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to create proxy: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -153,7 +153,7 @@ func (h *ProxyHandler) Update(c *gin.Context) {
|
||||
Status: strings.TrimSpace(req.Status),
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update proxy: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -171,7 +171,7 @@ func (h *ProxyHandler) Delete(c *gin.Context) {
|
||||
|
||||
err = h.adminService.DeleteProxy(c.Request.Context(), proxyID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to delete proxy: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -189,7 +189,7 @@ func (h *ProxyHandler) Test(c *gin.Context) {
|
||||
|
||||
result, err := h.adminService.TestProxy(c.Request.Context(), proxyID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to test proxy: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -229,14 +229,13 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
|
||||
|
||||
accounts, total, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get proxy accounts: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, accounts, total, page, pageSize)
|
||||
}
|
||||
|
||||
|
||||
// BatchCreateProxyItem represents a single proxy in batch create request
|
||||
type BatchCreateProxyItem struct {
|
||||
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
|
||||
@@ -273,7 +272,7 @@ func (h *ProxyHandler) BatchCreate(c *gin.Context) {
|
||||
// Check for duplicates (same host, port, username, password)
|
||||
exists, err := h.adminService.CheckProxyExists(c.Request.Context(), host, item.Port, username, password)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to check proxy existence: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@ import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -43,7 +43,7 @@ func (h *RedeemHandler) List(c *gin.Context) {
|
||||
|
||||
codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list redeem codes: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ func (h *RedeemHandler) GetByID(c *gin.Context) {
|
||||
|
||||
code, err := h.adminService.GetRedeemCode(c.Request.Context(), codeID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Redeem code not found")
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -85,7 +85,7 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
|
||||
ValidityDays: req.ValidityDays,
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to generate redeem codes: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -103,7 +103,7 @@ func (h *RedeemHandler) Delete(c *gin.Context) {
|
||||
|
||||
err = h.adminService.DeleteRedeemCode(c.Request.Context(), codeID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to delete redeem code: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -123,7 +123,7 @@ func (h *RedeemHandler) BatchDelete(c *gin.Context) {
|
||||
|
||||
deleted, err := h.adminService.BatchDeleteRedeemCodes(c.Request.Context(), req.IDs)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to batch delete redeem codes: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -144,7 +144,7 @@ func (h *RedeemHandler) Expire(c *gin.Context) {
|
||||
|
||||
code, err := h.adminService.ExpireRedeemCode(c.Request.Context(), codeID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to expire redeem code: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -156,10 +156,10 @@ func (h *RedeemHandler) Expire(c *gin.Context) {
|
||||
func (h *RedeemHandler) GetStats(c *gin.Context) {
|
||||
// Return mock data for now
|
||||
response.Success(c, gin.H{
|
||||
"total_codes": 0,
|
||||
"active_codes": 0,
|
||||
"used_codes": 0,
|
||||
"expired_codes": 0,
|
||||
"total_codes": 0,
|
||||
"active_codes": 0,
|
||||
"used_codes": 0,
|
||||
"expired_codes": 0,
|
||||
"total_value_distributed": 0.0,
|
||||
"by_type": gin.H{
|
||||
"balance": 0,
|
||||
@@ -178,7 +178,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
||||
// Get all codes without pagination (use large page size)
|
||||
codes, _, err := h.adminService.ListRedeemCodes(c.Request.Context(), 1, 10000, codeType, status, "")
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -187,7 +187,10 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
||||
writer := csv.NewWriter(&buf)
|
||||
|
||||
// Write header
|
||||
writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"})
|
||||
if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"}); err != nil {
|
||||
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Write data rows
|
||||
for _, code := range codes {
|
||||
@@ -199,7 +202,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
||||
if code.UsedAt != nil {
|
||||
usedAt = code.UsedAt.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
writer.Write([]string{
|
||||
if err := writer.Write([]string{
|
||||
fmt.Sprintf("%d", code.ID),
|
||||
code.Code,
|
||||
code.Type,
|
||||
@@ -208,10 +211,17 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
||||
usedBy,
|
||||
usedAt,
|
||||
code.CreatedAt.Format("2006-01-02 15:04:05"),
|
||||
})
|
||||
}); err != nil {
|
||||
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
writer.Flush()
|
||||
if err := writer.Error(); err != nil {
|
||||
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/csv")
|
||||
c.Header("Content-Disposition", "attachment; filename=redeem_codes.csv")
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -27,7 +27,7 @@ func NewSettingHandler(settingService *service.SettingService, emailService *ser
|
||||
func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetAllSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get settings: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -60,6 +60,7 @@ type UpdateSettingsRequest struct {
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
ApiBaseUrl string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocUrl string `json:"doc_url"`
|
||||
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
@@ -104,19 +105,20 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
ApiBaseUrl: req.ApiBaseUrl,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocUrl: req.DocUrl,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
}
|
||||
|
||||
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
|
||||
response.InternalError(c, "Failed to update settings: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 重新获取设置返回
|
||||
updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get updated settings: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -164,7 +166,7 @@ func (h *SettingHandler) TestSmtpConnection(c *gin.Context) {
|
||||
|
||||
err := h.emailService.TestSmtpConnectionWithConfig(config)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "SMTP connection test failed: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -250,7 +252,7 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
|
||||
`
|
||||
|
||||
if err := h.emailService.SendEmailWithConfig(config, req.Email, subject, body); err != nil {
|
||||
response.BadRequest(c, "Failed to send test email: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -262,7 +264,7 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
|
||||
func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
|
||||
maskedKey, exists, err := h.settingService.GetAdminApiKeyStatus(c.Request.Context())
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get admin API key status: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -277,7 +279,7 @@ func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
|
||||
func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
|
||||
key, err := h.settingService.GenerateAdminApiKey(c.Request.Context())
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to generate admin API key: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -290,7 +292,7 @@ func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
|
||||
// DELETE /api/v1/admin/settings/admin-api-key
|
||||
func (h *SettingHandler) DeleteAdminApiKey(c *gin.Context) {
|
||||
if err := h.settingService.DeleteAdminApiKey(c.Request.Context()); err != nil {
|
||||
response.InternalError(c, "Failed to delete admin API key: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -3,10 +3,10 @@ package admin
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -78,7 +78,7 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
|
||||
|
||||
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list subscriptions: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -96,7 +96,7 @@ func (h *SubscriptionHandler) GetByID(c *gin.Context) {
|
||||
|
||||
subscription, err := h.subscriptionService.GetByID(c.Request.Context(), subscriptionID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Subscription not found")
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -141,7 +141,7 @@ func (h *SubscriptionHandler) Assign(c *gin.Context) {
|
||||
Notes: req.Notes,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to assign subscription: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -168,7 +168,7 @@ func (h *SubscriptionHandler) BulkAssign(c *gin.Context) {
|
||||
Notes: req.Notes,
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to bulk assign subscriptions: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -192,7 +192,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||
|
||||
subscription, err := h.subscriptionService.ExtendSubscription(c.Request.Context(), subscriptionID, req.Days)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to extend subscription: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -210,7 +210,7 @@ func (h *SubscriptionHandler) Revoke(c *gin.Context) {
|
||||
|
||||
err = h.subscriptionService.RevokeSubscription(c.Request.Context(), subscriptionID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to revoke subscription: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -230,7 +230,7 @@ func (h *SubscriptionHandler) ListByGroup(c *gin.Context) {
|
||||
|
||||
subscriptions, pagination, err := h.subscriptionService.ListGroupSubscriptions(c.Request.Context(), groupID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list group subscriptions: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -248,7 +248,7 @@ func (h *SubscriptionHandler) ListByUser(c *gin.Context) {
|
||||
|
||||
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list user subscriptions: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -4,9 +4,9 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/pkg/sysutil"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/sysutil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -4,35 +4,32 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// UsageHandler handles admin usage-related requests
|
||||
type UsageHandler struct {
|
||||
usageRepo *repository.UsageLogRepository
|
||||
apiKeyRepo *repository.ApiKeyRepository
|
||||
usageService *service.UsageService
|
||||
adminService service.AdminService
|
||||
usageService *service.UsageService
|
||||
apiKeyService *service.ApiKeyService
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewUsageHandler creates a new admin usage handler
|
||||
func NewUsageHandler(
|
||||
usageRepo *repository.UsageLogRepository,
|
||||
apiKeyRepo *repository.ApiKeyRepository,
|
||||
usageService *service.UsageService,
|
||||
apiKeyService *service.ApiKeyService,
|
||||
adminService service.AdminService,
|
||||
) *UsageHandler {
|
||||
return &UsageHandler{
|
||||
usageRepo: usageRepo,
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
usageService: usageService,
|
||||
adminService: adminService,
|
||||
usageService: usageService,
|
||||
apiKeyService: apiKeyService,
|
||||
adminService: adminService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -84,16 +81,16 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
}
|
||||
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
filters := repository.UsageLogFilters{
|
||||
filters := usagestats.UsageLogFilters{
|
||||
UserID: userID,
|
||||
ApiKeyID: apiKeyID,
|
||||
StartTime: startTime,
|
||||
EndTime: endTime,
|
||||
}
|
||||
|
||||
records, result, err := h.usageRepo.ListWithFilters(c.Request.Context(), params, filters)
|
||||
records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list usage records: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -161,7 +158,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
if apiKeyID > 0 {
|
||||
stats, err := h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, stats)
|
||||
@@ -171,7 +168,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
if userID > 0 {
|
||||
stats, err := h.usageService.GetStatsByUser(c.Request.Context(), userID, startTime, endTime)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, stats)
|
||||
@@ -179,9 +176,9 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Get global stats
|
||||
stats, err := h.usageRepo.GetGlobalStats(c.Request.Context(), startTime, endTime)
|
||||
stats, err := h.usageService.GetGlobalStats(c.Request.Context(), startTime, endTime)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -193,14 +190,14 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
func (h *UsageHandler) SearchUsers(c *gin.Context) {
|
||||
keyword := c.Query("q")
|
||||
if keyword == "" {
|
||||
response.Success(c, []interface{}{})
|
||||
response.Success(c, []any{})
|
||||
return
|
||||
}
|
||||
|
||||
// Limit to 30 results
|
||||
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, "", "", keyword)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to search users: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -237,9 +234,9 @@ func (h *UsageHandler) SearchApiKeys(c *gin.Context) {
|
||||
userID = id
|
||||
}
|
||||
|
||||
keys, err := h.apiKeyRepo.SearchApiKeys(c.Request.Context(), userID, keyword, 30)
|
||||
keys, err := h.apiKeyService.SearchApiKeys(c.Request.Context(), userID, keyword, 30)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to search API keys: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ package admin
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -25,6 +25,9 @@ func NewUserHandler(adminService service.AdminService) *UserHandler {
|
||||
type CreateUserRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
Username string `json:"username"`
|
||||
Wechat string `json:"wechat"`
|
||||
Notes string `json:"notes"`
|
||||
Balance float64 `json:"balance"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
AllowedGroups []int64 `json:"allowed_groups"`
|
||||
@@ -35,6 +38,9 @@ type CreateUserRequest struct {
|
||||
type UpdateUserRequest struct {
|
||||
Email string `json:"email" binding:"omitempty,email"`
|
||||
Password string `json:"password" binding:"omitempty,min=6"`
|
||||
Username *string `json:"username"`
|
||||
Wechat *string `json:"wechat"`
|
||||
Notes *string `json:"notes"`
|
||||
Balance *float64 `json:"balance"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
|
||||
@@ -43,8 +49,9 @@ type UpdateUserRequest struct {
|
||||
|
||||
// UpdateBalanceRequest represents balance update request
|
||||
type UpdateBalanceRequest struct {
|
||||
Balance float64 `json:"balance" binding:"required"`
|
||||
Balance float64 `json:"balance" binding:"required,gt=0"`
|
||||
Operation string `json:"operation" binding:"required,oneof=set add subtract"`
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
// List handles listing all users with pagination
|
||||
@@ -57,7 +64,7 @@ func (h *UserHandler) List(c *gin.Context) {
|
||||
|
||||
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, status, role, search)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list users: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -75,7 +82,7 @@ func (h *UserHandler) GetByID(c *gin.Context) {
|
||||
|
||||
user, err := h.adminService.GetUser(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "User not found")
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -94,12 +101,15 @@ func (h *UserHandler) Create(c *gin.Context) {
|
||||
user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
Username: req.Username,
|
||||
Wechat: req.Wechat,
|
||||
Notes: req.Notes,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to create user: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -125,13 +135,16 @@ func (h *UserHandler) Update(c *gin.Context) {
|
||||
user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
Username: req.Username,
|
||||
Wechat: req.Wechat,
|
||||
Notes: req.Notes,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
Status: req.Status,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update user: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -149,7 +162,7 @@ func (h *UserHandler) Delete(c *gin.Context) {
|
||||
|
||||
err = h.adminService.DeleteUser(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to delete user: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -171,9 +184,9 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation)
|
||||
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update balance: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -193,7 +206,7 @@ func (h *UserHandler) GetUserAPIKeys(c *gin.Context) {
|
||||
|
||||
keys, total, err := h.adminService.GetUserAPIKeys(c.Request.Context(), userID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get user API keys: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -213,7 +226,7 @@ func (h *UserHandler) GetUserUsage(c *gin.Context) {
|
||||
|
||||
stats, err := h.adminService.GetUserUsageStats(c.Request.Context(), userID, period)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get user usage: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -3,10 +3,10 @@ package handler
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -57,7 +57,7 @@ func (h *APIKeyHandler) List(c *gin.Context) {
|
||||
|
||||
keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list API keys: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ func (h *APIKeyHandler) GetByID(c *gin.Context) {
|
||||
|
||||
key, err := h.apiKeyService.GetByID(c.Request.Context(), keyID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "API key not found")
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -128,7 +128,7 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
|
||||
}
|
||||
key, err := h.apiKeyService.Create(c.Request.Context(), user.ID, svcReq)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to create API key: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -173,7 +173,7 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
|
||||
|
||||
key, err := h.apiKeyService.Update(c.Request.Context(), keyID, user.ID, svcReq)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update API key: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -203,7 +203,7 @@ func (h *APIKeyHandler) Delete(c *gin.Context) {
|
||||
|
||||
err = h.apiKeyService.Delete(c.Request.Context(), keyID, user.ID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to delete API key: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -227,7 +227,7 @@ func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) {
|
||||
|
||||
groups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), user.ID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get available groups: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -66,14 +66,14 @@ func (h *AuthHandler) Register(c *gin.Context) {
|
||||
// Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过)
|
||||
if req.VerifyCode == "" {
|
||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
|
||||
response.BadRequest(c, "Turnstile verification failed: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Registration failed: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -95,13 +95,13 @@ func (h *AuthHandler) SendVerifyCode(c *gin.Context) {
|
||||
|
||||
// Turnstile 验证
|
||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
|
||||
response.BadRequest(c, "Turnstile verification failed: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.authService.SendVerifyCodeAsync(c.Request.Context(), req.Email)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to send verification code: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -122,13 +122,13 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
|
||||
// Turnstile 验证
|
||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
|
||||
response.BadRequest(c, "Turnstile verification failed: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
token, user, err := h.authService.Login(c.Request.Context(), req.Email, req.Password)
|
||||
if err != nil {
|
||||
response.Unauthorized(c, "Login failed: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -10,27 +10,21 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/middleware"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/claude"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
// Maximum wait time for concurrency slot
|
||||
maxConcurrencyWait = 60 * time.Second
|
||||
// Ping interval during wait
|
||||
pingInterval = 5 * time.Second
|
||||
)
|
||||
|
||||
// GatewayHandler handles API gateway requests
|
||||
type GatewayHandler struct {
|
||||
gatewayService *service.GatewayService
|
||||
userService *service.UserService
|
||||
concurrencyService *service.ConcurrencyService
|
||||
billingCacheService *service.BillingCacheService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
}
|
||||
|
||||
// NewGatewayHandler creates a new GatewayHandler
|
||||
@@ -38,8 +32,8 @@ func NewGatewayHandler(gatewayService *service.GatewayService, userService *serv
|
||||
return &GatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
userService: userService,
|
||||
concurrencyService: concurrencyService,
|
||||
billingCacheService: billingCacheService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,13 +41,13 @@ func NewGatewayHandler(gatewayService *service.GatewayService, userService *serv
|
||||
// POST /v1/messages
|
||||
func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 从context获取apiKey和user(ApiKeyAuth中间件已设置)
|
||||
apiKey, ok := middleware.GetApiKeyFromContext(c)
|
||||
apiKey, ok := middleware2.GetApiKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := middleware.GetUserFromContext(c)
|
||||
user, ok := middleware2.GetUserFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
@@ -85,11 +79,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
streamStarted := false
|
||||
|
||||
// 获取订阅信息(可能为nil)- 提前获取用于后续检查
|
||||
subscription, _ := middleware.GetSubscriptionFromContext(c)
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
// 0. 检查wait队列是否已满
|
||||
maxWait := service.CalculateMaxWait(user.Concurrency)
|
||||
canWait, err := h.concurrencyService.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
|
||||
if err != nil {
|
||||
log.Printf("Increment wait count failed: %v", err)
|
||||
// On error, allow request to proceed
|
||||
@@ -98,10 +92,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
// 确保在函数退出时减少wait计数
|
||||
defer h.concurrencyService.DecrementWaitCount(c.Request.Context(), user.ID)
|
||||
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), user.ID)
|
||||
|
||||
// 1. 首先获取用户并发槽位
|
||||
userReleaseFunc, err := h.acquireUserSlotWithWait(c, user, req.Stream, &streamStarted)
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, user, req.Stream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("User concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
@@ -139,7 +133,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
accountReleaseFunc, err := h.acquireAccountSlotWithWait(c, account, req.Stream, &streamStarted)
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account, req.Stream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
@@ -173,146 +167,38 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}()
|
||||
}
|
||||
|
||||
// acquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary
|
||||
// For streaming requests, sends ping events during the wait
|
||||
// streamStarted is updated if streaming response has begun
|
||||
func (h *GatewayHandler) acquireUserSlotWithWait(c *gin.Context, user *model.User, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Try to acquire immediately
|
||||
result, err := h.concurrencyService.AcquireUserSlot(ctx, user.ID, user.Concurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
|
||||
// Need to wait - handle streaming ping if needed
|
||||
return h.waitForSlotWithPing(c, "user", user.ID, user.Concurrency, isStream, streamStarted)
|
||||
}
|
||||
|
||||
// acquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary
|
||||
// For streaming requests, sends ping events during the wait
|
||||
// streamStarted is updated if streaming response has begun
|
||||
func (h *GatewayHandler) acquireAccountSlotWithWait(c *gin.Context, account *model.Account, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Try to acquire immediately
|
||||
result, err := h.concurrencyService.AcquireAccountSlot(ctx, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
|
||||
// Need to wait - handle streaming ping if needed
|
||||
return h.waitForSlotWithPing(c, "account", account.ID, account.Concurrency, isStream, streamStarted)
|
||||
}
|
||||
|
||||
// concurrencyError represents a concurrency limit error with context
|
||||
type concurrencyError struct {
|
||||
SlotType string
|
||||
IsTimeout bool
|
||||
}
|
||||
|
||||
func (e *concurrencyError) Error() string {
|
||||
if e.IsTimeout {
|
||||
return fmt.Sprintf("timeout waiting for %s concurrency slot", e.SlotType)
|
||||
}
|
||||
return fmt.Sprintf("%s concurrency limit reached", e.SlotType)
|
||||
}
|
||||
|
||||
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests
|
||||
// Note: For streaming requests, we send ping to keep the connection alive.
|
||||
// streamStarted pointer is updated when streaming begins (for proper error handling by caller)
|
||||
func (h *GatewayHandler) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait)
|
||||
defer cancel()
|
||||
|
||||
// For streaming requests, set up SSE headers for ping
|
||||
var flusher http.Flusher
|
||||
if isStream {
|
||||
var ok bool
|
||||
flusher, ok = c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("streaming not supported")
|
||||
}
|
||||
}
|
||||
|
||||
pingTicker := time.NewTicker(pingInterval)
|
||||
defer pingTicker.Stop()
|
||||
|
||||
pollTicker := time.NewTicker(100 * time.Millisecond)
|
||||
defer pollTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, &concurrencyError{
|
||||
SlotType: slotType,
|
||||
IsTimeout: true,
|
||||
}
|
||||
|
||||
case <-pingTicker.C:
|
||||
// Send ping for streaming requests to keep connection alive
|
||||
if isStream && flusher != nil {
|
||||
// Set headers on first ping (lazy initialization)
|
||||
if !*streamStarted {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
*streamStarted = true
|
||||
}
|
||||
fmt.Fprintf(c.Writer, "data: {\"type\": \"ping\"}\n\n")
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
case <-pollTicker.C:
|
||||
// Try to acquire slot
|
||||
var result *service.AcquireResult
|
||||
var err error
|
||||
|
||||
if slotType == "user" {
|
||||
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
|
||||
} else {
|
||||
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Models handles listing available models
|
||||
// GET /v1/models
|
||||
// Returns different model lists based on the API key's group platform
|
||||
func (h *GatewayHandler) Models(c *gin.Context) {
|
||||
apiKey, _ := middleware2.GetApiKeyFromContext(c)
|
||||
|
||||
// Return OpenAI models for OpenAI platform groups
|
||||
if apiKey != nil && apiKey.Group != nil && apiKey.Group.Platform == "openai" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"object": "list",
|
||||
"data": openai.DefaultModels,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Default: Claude models
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"data": claude.DefaultModels,
|
||||
"object": "list",
|
||||
"data": claude.DefaultModels,
|
||||
})
|
||||
}
|
||||
|
||||
// Usage handles getting account balance for CC Switch integration
|
||||
// GET /v1/usage
|
||||
func (h *GatewayHandler) Usage(c *gin.Context) {
|
||||
apiKey, ok := middleware.GetApiKeyFromContext(c)
|
||||
apiKey, ok := middleware2.GetApiKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := middleware.GetUserFromContext(c)
|
||||
user, ok := middleware2.GetUserFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
@@ -320,7 +206,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
|
||||
|
||||
// 订阅模式:返回订阅限额信息
|
||||
if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() {
|
||||
subscription, ok := middleware.GetSubscriptionFromContext(c)
|
||||
subscription, ok := middleware2.GetSubscriptionFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusForbidden, "subscription_error", "No active subscription")
|
||||
return
|
||||
@@ -414,7 +300,9 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
|
||||
if ok {
|
||||
// Send error event in SSE format
|
||||
errorEvent := fmt.Sprintf(`data: {"type": "error", "error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
|
||||
fmt.Fprint(c.Writer, errorEvent)
|
||||
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||||
_ = c.Error(err)
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
return
|
||||
@@ -440,13 +328,13 @@ func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, mess
|
||||
// 特点:校验订阅/余额,但不计算并发、不记录使用量
|
||||
func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
// 从context获取apiKey和user(ApiKeyAuth中间件已设置)
|
||||
apiKey, ok := middleware.GetApiKeyFromContext(c)
|
||||
apiKey, ok := middleware2.GetApiKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := middleware.GetUserFromContext(c)
|
||||
user, ok := middleware2.GetUserFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
@@ -474,7 +362,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 获取订阅信息(可能为nil)
|
||||
subscription, _ := middleware.GetSubscriptionFromContext(c)
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
// 校验 billing eligibility(订阅/余额)
|
||||
// 【注意】不计算并发,但需要校验订阅/余额
|
||||
@@ -574,11 +462,11 @@ func sendMockWarmupStream(c *gin.Context, model string) {
|
||||
// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截)
|
||||
func sendMockWarmupResponse(c *gin.Context, model string) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"id": "msg_mock_warmup",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": []gin.H{{"type": "text", "text": "New Conversation"}},
|
||||
"id": "msg_mock_warmup",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": []gin.H{{"type": "text", "text": "New Conversation"}},
|
||||
"stop_reason": "end_turn",
|
||||
"usage": gin.H{
|
||||
"input_tokens": 10,
|
||||
|
||||
180
backend/internal/handler/gateway_helper.go
Normal file
180
backend/internal/handler/gateway_helper.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxConcurrencyWait is the maximum time to wait for a concurrency slot
|
||||
maxConcurrencyWait = 30 * time.Second
|
||||
// pingInterval is the interval for sending ping events during slot wait
|
||||
pingInterval = 15 * time.Second
|
||||
)
|
||||
|
||||
// SSEPingFormat defines the format of SSE ping events for different platforms
|
||||
type SSEPingFormat string
|
||||
|
||||
const (
|
||||
// SSEPingFormatClaude is the Claude/Anthropic SSE ping format
|
||||
SSEPingFormatClaude SSEPingFormat = "data: {\"type\": \"ping\"}\n\n"
|
||||
// SSEPingFormatNone indicates no ping should be sent (e.g., OpenAI has no ping spec)
|
||||
SSEPingFormatNone SSEPingFormat = ""
|
||||
)
|
||||
|
||||
// ConcurrencyError represents a concurrency limit error with context
|
||||
type ConcurrencyError struct {
|
||||
SlotType string
|
||||
IsTimeout bool
|
||||
}
|
||||
|
||||
func (e *ConcurrencyError) Error() string {
|
||||
if e.IsTimeout {
|
||||
return fmt.Sprintf("timeout waiting for %s concurrency slot", e.SlotType)
|
||||
}
|
||||
return fmt.Sprintf("%s concurrency limit reached", e.SlotType)
|
||||
}
|
||||
|
||||
// ConcurrencyHelper provides common concurrency slot management for gateway handlers
|
||||
type ConcurrencyHelper struct {
|
||||
concurrencyService *service.ConcurrencyService
|
||||
pingFormat SSEPingFormat
|
||||
}
|
||||
|
||||
// NewConcurrencyHelper creates a new ConcurrencyHelper
|
||||
func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat) *ConcurrencyHelper {
|
||||
return &ConcurrencyHelper{
|
||||
concurrencyService: concurrencyService,
|
||||
pingFormat: pingFormat,
|
||||
}
|
||||
}
|
||||
|
||||
// IncrementWaitCount increments the wait count for a user
|
||||
func (h *ConcurrencyHelper) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
return h.concurrencyService.IncrementWaitCount(ctx, userID, maxWait)
|
||||
}
|
||||
|
||||
// DecrementWaitCount decrements the wait count for a user
|
||||
func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64) {
|
||||
h.concurrencyService.DecrementWaitCount(ctx, userID)
|
||||
}
|
||||
|
||||
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
|
||||
// For streaming requests, sends ping events during the wait.
|
||||
// streamStarted is updated if streaming response has begun.
|
||||
func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, user *model.User, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Try to acquire immediately
|
||||
result, err := h.concurrencyService.AcquireUserSlot(ctx, user.ID, user.Concurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
|
||||
// Need to wait - handle streaming ping if needed
|
||||
return h.waitForSlotWithPing(c, "user", user.ID, user.Concurrency, isStream, streamStarted)
|
||||
}
|
||||
|
||||
// AcquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary.
|
||||
// For streaming requests, sends ping events during the wait.
|
||||
// streamStarted is updated if streaming response has begun.
|
||||
func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, account *model.Account, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Try to acquire immediately
|
||||
result, err := h.concurrencyService.AcquireAccountSlot(ctx, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
|
||||
// Need to wait - handle streaming ping if needed
|
||||
return h.waitForSlotWithPing(c, "account", account.ID, account.Concurrency, isStream, streamStarted)
|
||||
}
|
||||
|
||||
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
|
||||
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
|
||||
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait)
|
||||
defer cancel()
|
||||
|
||||
// Determine if ping is needed (streaming + ping format defined)
|
||||
needPing := isStream && h.pingFormat != ""
|
||||
|
||||
var flusher http.Flusher
|
||||
if needPing {
|
||||
var ok bool
|
||||
flusher, ok = c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("streaming not supported")
|
||||
}
|
||||
}
|
||||
|
||||
// Only create ping ticker if ping is needed
|
||||
var pingCh <-chan time.Time
|
||||
if needPing {
|
||||
pingTicker := time.NewTicker(pingInterval)
|
||||
defer pingTicker.Stop()
|
||||
pingCh = pingTicker.C
|
||||
}
|
||||
|
||||
pollTicker := time.NewTicker(100 * time.Millisecond)
|
||||
defer pollTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, &ConcurrencyError{
|
||||
SlotType: slotType,
|
||||
IsTimeout: true,
|
||||
}
|
||||
|
||||
case <-pingCh:
|
||||
// Send ping to keep connection alive
|
||||
if !*streamStarted {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
*streamStarted = true
|
||||
}
|
||||
if _, err := fmt.Fprint(c.Writer, string(h.pingFormat)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
case <-pollTicker.C:
|
||||
// Try to acquire slot
|
||||
var result *service.AcquireResult
|
||||
var err error
|
||||
|
||||
if slotType == "user" {
|
||||
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
|
||||
} else {
|
||||
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"sub2api/internal/handler/admin"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
|
||||
)
|
||||
|
||||
// AdminHandlers contains all admin-related HTTP handlers
|
||||
@@ -11,6 +11,7 @@ type AdminHandlers struct {
|
||||
Group *admin.GroupHandler
|
||||
Account *admin.AccountHandler
|
||||
OAuth *admin.OAuthHandler
|
||||
OpenAIOAuth *admin.OpenAIOAuthHandler
|
||||
Proxy *admin.ProxyHandler
|
||||
Redeem *admin.RedeemHandler
|
||||
Setting *admin.SettingHandler
|
||||
@@ -21,15 +22,16 @@ type AdminHandlers struct {
|
||||
|
||||
// Handlers contains all HTTP handlers
|
||||
type Handlers struct {
|
||||
Auth *AuthHandler
|
||||
User *UserHandler
|
||||
APIKey *APIKeyHandler
|
||||
Usage *UsageHandler
|
||||
Redeem *RedeemHandler
|
||||
Subscription *SubscriptionHandler
|
||||
Admin *AdminHandlers
|
||||
Gateway *GatewayHandler
|
||||
Setting *SettingHandler
|
||||
Auth *AuthHandler
|
||||
User *UserHandler
|
||||
APIKey *APIKeyHandler
|
||||
Usage *UsageHandler
|
||||
Redeem *RedeemHandler
|
||||
Subscription *SubscriptionHandler
|
||||
Admin *AdminHandlers
|
||||
Gateway *GatewayHandler
|
||||
OpenAIGateway *OpenAIGatewayHandler
|
||||
Setting *SettingHandler
|
||||
}
|
||||
|
||||
// BuildInfo contains build-time information
|
||||
|
||||
209
backend/internal/handler/openai_gateway_handler.go
Normal file
209
backend/internal/handler/openai_gateway_handler.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// OpenAIGatewayHandler handles OpenAI API gateway requests
|
||||
type OpenAIGatewayHandler struct {
|
||||
gatewayService *service.OpenAIGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
}
|
||||
|
||||
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
||||
func NewOpenAIGatewayHandler(
|
||||
gatewayService *service.OpenAIGatewayService,
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
billingCacheService *service.BillingCacheService,
|
||||
) *OpenAIGatewayHandler {
|
||||
return &OpenAIGatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatNone),
|
||||
}
|
||||
}
|
||||
|
||||
// Responses handles OpenAI Responses API endpoint
|
||||
// POST /openai/v1/responses
|
||||
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// Get apiKey and user from context (set by ApiKeyAuth middleware)
|
||||
apiKey, ok := middleware2.GetApiKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := middleware2.GetUserFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Read request body
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
|
||||
if len(body) == 0 {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse request body to map for potential modification
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(body, &reqBody); err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
// Extract model and stream
|
||||
reqModel, _ := reqBody["model"].(string)
|
||||
reqStream, _ := reqBody["stream"].(bool)
|
||||
|
||||
// For non-Codex CLI requests, set default instructions
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
if !openai.IsCodexCLIRequest(userAgent) {
|
||||
reqBody["instructions"] = openai.DefaultInstructions
|
||||
// Re-serialize body
|
||||
body, err = json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Track if we've started streaming (for error handling)
|
||||
streamStarted := false
|
||||
|
||||
// Get subscription info (may be nil)
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
// 0. Check if wait queue is full
|
||||
maxWait := service.CalculateMaxWait(user.Concurrency)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
|
||||
if err != nil {
|
||||
log.Printf("Increment wait count failed: %v", err)
|
||||
// On error, allow request to proceed
|
||||
} else if !canWait {
|
||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
// Ensure wait count is decremented when function exits
|
||||
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), user.ID)
|
||||
|
||||
// 1. First acquire user concurrency slot
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, user, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("User concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
return
|
||||
}
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
// 2. Re-check billing eligibility after wait
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
|
||||
log.Printf("Billing eligibility check failed after wait: %v", err)
|
||||
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate session hash (from header for OpenAI)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c)
|
||||
|
||||
// Select account supporting the requested model
|
||||
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
|
||||
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel)
|
||||
if err != nil {
|
||||
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
|
||||
|
||||
// 3. Acquire account concurrency slot
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if accountReleaseFunc != nil {
|
||||
defer accountReleaseFunc()
|
||||
}
|
||||
|
||||
// Forward request
|
||||
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
||||
if err != nil {
|
||||
// Error response already handled in Forward, just log
|
||||
log.Printf("Forward request failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Async record usage
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
User: user,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// handleConcurrencyError handles concurrency-related errors with proper 429 response
|
||||
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
|
||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||
}
|
||||
|
||||
// handleStreamingAwareError handles errors that may occur after streaming has started
|
||||
func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
// Stream already started, send error as SSE event then close
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if ok {
|
||||
// Send error event in OpenAI SSE format
|
||||
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
|
||||
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||||
_ = c.Error(err)
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Normal case: return JSON response with proper status code
|
||||
h.errorResponse(c, status, errType, message)
|
||||
}
|
||||
|
||||
// errorResponse returns OpenAI API format error response
|
||||
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -1,9 +1,9 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -57,7 +57,7 @@ func (h *RedeemHandler) Redeem(c *gin.Context) {
|
||||
|
||||
result, err := h.redeemService.Redeem(c.Request.Context(), user.ID, req.Code)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to redeem code: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -84,7 +84,7 @@ func (h *RedeemHandler) GetHistory(c *gin.Context) {
|
||||
|
||||
codes, err := h.redeemService.GetUserHistory(c.Request.Context(), user.ID, limit)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get history: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -26,7 +26,7 @@ func NewSettingHandler(settingService *service.SettingService, version string) *
|
||||
func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetPublicSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get settings: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -58,7 +58,7 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
|
||||
|
||||
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), u.ID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list subscriptions: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -82,7 +82,7 @@ func (h *SubscriptionHandler) GetActive(c *gin.Context) {
|
||||
|
||||
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get active subscriptions: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -107,7 +107,7 @@ func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
|
||||
// Get all active subscriptions with progress
|
||||
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get subscriptions: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -146,7 +146,7 @@ func (h *SubscriptionHandler) GetSummary(c *gin.Context) {
|
||||
// Get all active subscriptions
|
||||
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get subscriptions: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -4,12 +4,11 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -17,15 +16,13 @@ import (
|
||||
// UsageHandler handles usage-related requests
|
||||
type UsageHandler struct {
|
||||
usageService *service.UsageService
|
||||
usageRepo *repository.UsageLogRepository
|
||||
apiKeyService *service.ApiKeyService
|
||||
}
|
||||
|
||||
// NewUsageHandler creates a new UsageHandler
|
||||
func NewUsageHandler(usageService *service.UsageService, usageRepo *repository.UsageLogRepository, apiKeyService *service.ApiKeyService) *UsageHandler {
|
||||
func NewUsageHandler(usageService *service.UsageService, apiKeyService *service.ApiKeyService) *UsageHandler {
|
||||
return &UsageHandler{
|
||||
usageService: usageService,
|
||||
usageRepo: usageRepo,
|
||||
apiKeyService: apiKeyService,
|
||||
}
|
||||
}
|
||||
@@ -58,7 +55,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
// [Security Fix] Verify API Key ownership to prevent horizontal privilege escalation
|
||||
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.NotFound(c, "API key not found")
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if apiKey.UserID != user.ID {
|
||||
@@ -80,7 +77,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
records, result, err = h.usageService.ListByUser(c.Request.Context(), user.ID, params)
|
||||
}
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list usage records: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -110,7 +107,7 @@ func (h *UsageHandler) GetByID(c *gin.Context) {
|
||||
|
||||
record, err := h.usageService.GetByID(c.Request.Context(), usageID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Usage record not found")
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -207,7 +204,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
stats, err = h.usageService.GetStatsByUser(c.Request.Context(), user.ID, startTime, endTime)
|
||||
}
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -260,9 +257,9 @@ func (h *UsageHandler) DashboardStats(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.usageRepo.GetUserDashboardStats(c.Request.Context(), user.ID)
|
||||
stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), user.ID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get dashboard statistics")
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -287,9 +284,9 @@ func (h *UsageHandler) DashboardTrend(c *gin.Context) {
|
||||
startTime, endTime := parseUserTimeRange(c)
|
||||
granularity := c.DefaultQuery("granularity", "day")
|
||||
|
||||
trend, err := h.usageRepo.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity)
|
||||
trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get usage trend")
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -318,9 +315,9 @@ func (h *UsageHandler) DashboardModels(c *gin.Context) {
|
||||
|
||||
startTime, endTime := parseUserTimeRange(c)
|
||||
|
||||
stats, err := h.usageRepo.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime)
|
||||
stats, err := h.usageService.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get model statistics")
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -358,14 +355,14 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
|
||||
}
|
||||
|
||||
if len(req.ApiKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]interface{}{}})
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
// Verify ownership of all requested API keys
|
||||
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, pagination.PaginationParams{Page: 1, PageSize: 1000})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to verify API key ownership")
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -383,13 +380,13 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
|
||||
}
|
||||
|
||||
if len(validApiKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]interface{}{}})
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.usageRepo.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
|
||||
stats, err := h.usageService.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get API key usage stats")
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -26,6 +26,12 @@ type ChangePasswordRequest struct {
|
||||
NewPassword string `json:"new_password" binding:"required,min=6"`
|
||||
}
|
||||
|
||||
// UpdateProfileRequest represents the update profile request payload
|
||||
type UpdateProfileRequest struct {
|
||||
Username *string `json:"username"`
|
||||
Wechat *string `json:"wechat"`
|
||||
}
|
||||
|
||||
// GetProfile handles getting user profile
|
||||
// GET /api/v1/users/me
|
||||
func (h *UserHandler) GetProfile(c *gin.Context) {
|
||||
@@ -43,10 +49,13 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
|
||||
|
||||
userData, err := h.userService.GetByID(c.Request.Context(), user.ID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get user profile: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 清空notes字段,普通用户不应看到备注
|
||||
userData.Notes = ""
|
||||
|
||||
response.Success(c, userData)
|
||||
}
|
||||
|
||||
@@ -77,9 +86,46 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
|
||||
}
|
||||
err := h.userService.ChangePassword(c.Request.Context(), user.ID, svcReq)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to change password: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Password changed successfully"})
|
||||
}
|
||||
|
||||
// UpdateProfile handles updating user profile
|
||||
// PUT /api/v1/users/me
|
||||
func (h *UserHandler) UpdateProfile(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
svcReq := service.UpdateProfileRequest{
|
||||
Username: req.Username,
|
||||
Wechat: req.Wechat,
|
||||
}
|
||||
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), user.ID, svcReq)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 清空notes字段,普通用户不应看到备注
|
||||
updatedUser.Notes = ""
|
||||
|
||||
response.Success(c, updatedUser)
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"sub2api/internal/handler/admin"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/google/wire"
|
||||
)
|
||||
@@ -14,6 +14,7 @@ func ProvideAdminHandlers(
|
||||
groupHandler *admin.GroupHandler,
|
||||
accountHandler *admin.AccountHandler,
|
||||
oauthHandler *admin.OAuthHandler,
|
||||
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
||||
proxyHandler *admin.ProxyHandler,
|
||||
redeemHandler *admin.RedeemHandler,
|
||||
settingHandler *admin.SettingHandler,
|
||||
@@ -27,6 +28,7 @@ func ProvideAdminHandlers(
|
||||
Group: groupHandler,
|
||||
Account: accountHandler,
|
||||
OAuth: oauthHandler,
|
||||
OpenAIOAuth: openaiOAuthHandler,
|
||||
Proxy: proxyHandler,
|
||||
Redeem: redeemHandler,
|
||||
Setting: settingHandler,
|
||||
@@ -56,18 +58,20 @@ func ProvideHandlers(
|
||||
subscriptionHandler *SubscriptionHandler,
|
||||
adminHandlers *AdminHandlers,
|
||||
gatewayHandler *GatewayHandler,
|
||||
openaiGatewayHandler *OpenAIGatewayHandler,
|
||||
settingHandler *SettingHandler,
|
||||
) *Handlers {
|
||||
return &Handlers{
|
||||
Auth: authHandler,
|
||||
User: userHandler,
|
||||
APIKey: apiKeyHandler,
|
||||
Usage: usageHandler,
|
||||
Redeem: redeemHandler,
|
||||
Subscription: subscriptionHandler,
|
||||
Admin: adminHandlers,
|
||||
Gateway: gatewayHandler,
|
||||
Setting: settingHandler,
|
||||
Auth: authHandler,
|
||||
User: userHandler,
|
||||
APIKey: apiKeyHandler,
|
||||
Usage: usageHandler,
|
||||
Redeem: redeemHandler,
|
||||
Subscription: subscriptionHandler,
|
||||
Admin: adminHandlers,
|
||||
Gateway: gatewayHandler,
|
||||
OpenAIGateway: openaiGatewayHandler,
|
||||
Setting: settingHandler,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,6 +85,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewRedeemHandler,
|
||||
NewSubscriptionHandler,
|
||||
NewGatewayHandler,
|
||||
NewOpenAIGatewayHandler,
|
||||
ProvideSettingHandler,
|
||||
|
||||
// Admin handlers
|
||||
@@ -89,6 +94,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewGroupHandler,
|
||||
admin.NewAccountHandler,
|
||||
admin.NewOAuthHandler,
|
||||
admin.NewOpenAIOAuthHandler,
|
||||
admin.NewProxyHandler,
|
||||
admin.NewRedeemHandler,
|
||||
admin.NewSettingHandler,
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package infrastructure
|
||||
|
||||
import (
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
|
||||
158
backend/internal/infrastructure/errors/errors.go
Normal file
158
backend/internal/infrastructure/errors/errors.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package errors
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
UnknownCode = http.StatusInternalServerError
|
||||
UnknownReason = ""
|
||||
UnknownMessage = "internal error"
|
||||
)
|
||||
|
||||
type Status struct {
|
||||
Code int32 `json:"code"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
Message string `json:"message"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// ApplicationError is the standard error type used to control HTTP responses.
|
||||
//
|
||||
// Code is expected to be an HTTP status code (e.g. 400/401/403/404/409/500).
|
||||
type ApplicationError struct {
|
||||
Status
|
||||
cause error
|
||||
}
|
||||
|
||||
// Error is kept for backwards compatibility within this package.
|
||||
type Error = ApplicationError
|
||||
|
||||
func (e *ApplicationError) Error() string {
|
||||
if e == nil {
|
||||
return "<nil>"
|
||||
}
|
||||
if e.cause == nil {
|
||||
return fmt.Sprintf("error: code=%d reason=%q message=%q metadata=%v", e.Code, e.Reason, e.Message, e.Metadata)
|
||||
}
|
||||
return fmt.Sprintf("error: code=%d reason=%q message=%q metadata=%v cause=%v", e.Code, e.Reason, e.Message, e.Metadata, e.cause)
|
||||
}
|
||||
|
||||
// Unwrap provides compatibility for Go 1.13 error chains.
|
||||
func (e *ApplicationError) Unwrap() error { return e.cause }
|
||||
|
||||
// Is matches each error in the chain with the target value.
|
||||
func (e *ApplicationError) Is(err error) bool {
|
||||
if se := new(ApplicationError); errors.As(err, &se) {
|
||||
return se.Code == e.Code && se.Reason == e.Reason
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// WithCause attaches the underlying cause of the error.
|
||||
func (e *ApplicationError) WithCause(cause error) *ApplicationError {
|
||||
err := Clone(e)
|
||||
err.cause = cause
|
||||
return err
|
||||
}
|
||||
|
||||
// WithMetadata deep-copies the given metadata map.
|
||||
func (e *ApplicationError) WithMetadata(md map[string]string) *ApplicationError {
|
||||
err := Clone(e)
|
||||
if md == nil {
|
||||
err.Metadata = nil
|
||||
return err
|
||||
}
|
||||
err.Metadata = make(map[string]string, len(md))
|
||||
for k, v := range md {
|
||||
err.Metadata[k] = v
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// New returns an error object for the code, message.
|
||||
func New(code int, reason, message string) *ApplicationError {
|
||||
return &ApplicationError{
|
||||
Status: Status{
|
||||
Code: int32(code),
|
||||
Message: message,
|
||||
Reason: reason,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Newf New(code fmt.Sprintf(format, a...))
|
||||
func Newf(code int, reason, format string, a ...any) *ApplicationError {
|
||||
return New(code, reason, fmt.Sprintf(format, a...))
|
||||
}
|
||||
|
||||
// Errorf returns an error object for the code, message and error info.
|
||||
func Errorf(code int, reason, format string, a ...any) error {
|
||||
return New(code, reason, fmt.Sprintf(format, a...))
|
||||
}
|
||||
|
||||
// Code returns the http code for an error.
|
||||
// It supports wrapped errors.
|
||||
func Code(err error) int {
|
||||
if err == nil {
|
||||
return http.StatusOK
|
||||
}
|
||||
return int(FromError(err).Code)
|
||||
}
|
||||
|
||||
// Reason returns the reason for a particular error.
|
||||
// It supports wrapped errors.
|
||||
func Reason(err error) string {
|
||||
if err == nil {
|
||||
return UnknownReason
|
||||
}
|
||||
return FromError(err).Reason
|
||||
}
|
||||
|
||||
// Message returns the message for a particular error.
|
||||
// It supports wrapped errors.
|
||||
func Message(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
return FromError(err).Message
|
||||
}
|
||||
|
||||
// Clone deep clone error to a new error.
|
||||
func Clone(err *ApplicationError) *ApplicationError {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
var metadata map[string]string
|
||||
if err.Metadata != nil {
|
||||
metadata = make(map[string]string, len(err.Metadata))
|
||||
for k, v := range err.Metadata {
|
||||
metadata[k] = v
|
||||
}
|
||||
}
|
||||
return &ApplicationError{
|
||||
cause: err.cause,
|
||||
Status: Status{
|
||||
Code: err.Code,
|
||||
Reason: err.Reason,
|
||||
Message: err.Message,
|
||||
Metadata: metadata,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// FromError tries to convert an error to *ApplicationError.
|
||||
// It supports wrapped errors.
|
||||
func FromError(err error) *ApplicationError {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if se := new(ApplicationError); errors.As(err, &se) {
|
||||
return se
|
||||
}
|
||||
|
||||
// Fall back to a generic internal error.
|
||||
return New(UnknownCode, UnknownReason, UnknownMessage).WithCause(err)
|
||||
}
|
||||
168
backend/internal/infrastructure/errors/errors_test.go
Normal file
168
backend/internal/infrastructure/errors/errors_test.go
Normal file
@@ -0,0 +1,168 @@
|
||||
//go:build unit
|
||||
|
||||
package errors
|
||||
|
||||
import (
|
||||
stderrors "errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestApplicationError_Basics(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err *ApplicationError
|
||||
want Status
|
||||
wantIs bool
|
||||
target error
|
||||
wrapped error
|
||||
}{
|
||||
{
|
||||
name: "new",
|
||||
err: New(400, "BAD_REQUEST", "invalid input"),
|
||||
want: Status{
|
||||
Code: 400,
|
||||
Reason: "BAD_REQUEST",
|
||||
Message: "invalid input",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "is_matches_code_and_reason",
|
||||
err: New(401, "UNAUTHORIZED", "nope"),
|
||||
want: Status{Code: 401, Reason: "UNAUTHORIZED", Message: "nope"},
|
||||
target: New(401, "UNAUTHORIZED", "ignored message"),
|
||||
wantIs: true,
|
||||
},
|
||||
{
|
||||
name: "is_does_not_match_reason",
|
||||
err: New(401, "UNAUTHORIZED", "nope"),
|
||||
want: Status{Code: 401, Reason: "UNAUTHORIZED", Message: "nope"},
|
||||
target: New(401, "DIFFERENT", "ignored message"),
|
||||
wantIs: false,
|
||||
},
|
||||
{
|
||||
name: "from_error_unwraps_wrapped_application_error",
|
||||
err: New(404, "NOT_FOUND", "missing"),
|
||||
wrapped: fmt.Errorf("wrap: %w", New(404, "NOT_FOUND", "missing")),
|
||||
want: Status{
|
||||
Code: 404,
|
||||
Reason: "NOT_FOUND",
|
||||
Message: "missing",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.err != nil {
|
||||
require.Equal(t, tt.want, tt.err.Status)
|
||||
}
|
||||
|
||||
if tt.target != nil {
|
||||
require.Equal(t, tt.wantIs, stderrors.Is(tt.err, tt.target))
|
||||
}
|
||||
|
||||
if tt.wrapped != nil {
|
||||
got := FromError(tt.wrapped)
|
||||
require.Equal(t, tt.want, got.Status)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplicationError_WithMetadataDeepCopy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
md map[string]string
|
||||
}{
|
||||
{name: "non_nil", md: map[string]string{"a": "1"}},
|
||||
{name: "nil", md: nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
appErr := BadRequest("BAD_REQUEST", "invalid input").WithMetadata(tt.md)
|
||||
|
||||
if tt.md == nil {
|
||||
require.Nil(t, appErr.Metadata)
|
||||
return
|
||||
}
|
||||
|
||||
tt.md["a"] = "changed"
|
||||
require.Equal(t, "1", appErr.Metadata["a"])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromError_Generic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantCode int32
|
||||
wantReason string
|
||||
wantMsg string
|
||||
}{
|
||||
{
|
||||
name: "plain_error",
|
||||
err: stderrors.New("boom"),
|
||||
wantCode: UnknownCode,
|
||||
wantReason: UnknownReason,
|
||||
wantMsg: UnknownMessage,
|
||||
},
|
||||
{
|
||||
name: "wrapped_plain_error",
|
||||
err: fmt.Errorf("wrap: %w", io.EOF),
|
||||
wantCode: UnknownCode,
|
||||
wantReason: UnknownReason,
|
||||
wantMsg: UnknownMessage,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := FromError(tt.err)
|
||||
require.Equal(t, tt.wantCode, got.Code)
|
||||
require.Equal(t, tt.wantReason, got.Reason)
|
||||
require.Equal(t, tt.wantMsg, got.Message)
|
||||
require.Equal(t, tt.err, got.Unwrap())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToHTTP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantStatusCode int
|
||||
wantBody Status
|
||||
}{
|
||||
{
|
||||
name: "nil_error",
|
||||
err: nil,
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantBody: Status{Code: int32(http.StatusOK)},
|
||||
},
|
||||
{
|
||||
name: "application_error",
|
||||
err: Forbidden("FORBIDDEN", "no access"),
|
||||
wantStatusCode: http.StatusForbidden,
|
||||
wantBody: Status{
|
||||
Code: int32(http.StatusForbidden),
|
||||
Reason: "FORBIDDEN",
|
||||
Message: "no access",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
code, body := ToHTTP(tt.err)
|
||||
require.Equal(t, tt.wantStatusCode, code)
|
||||
require.Equal(t, tt.wantBody, body)
|
||||
})
|
||||
}
|
||||
}
|
||||
21
backend/internal/infrastructure/errors/http.go
Normal file
21
backend/internal/infrastructure/errors/http.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package errors
|
||||
|
||||
import "net/http"
|
||||
|
||||
// ToHTTP converts an error into an HTTP status code and a JSON-serializable body.
|
||||
//
|
||||
// The returned body matches the project's Status shape:
|
||||
// { code, reason, message, metadata }.
|
||||
func ToHTTP(err error) (statusCode int, body Status) {
|
||||
if err == nil {
|
||||
return http.StatusOK, Status{Code: int32(http.StatusOK)}
|
||||
}
|
||||
|
||||
appErr := FromError(err)
|
||||
if appErr == nil {
|
||||
return http.StatusOK, Status{Code: int32(http.StatusOK)}
|
||||
}
|
||||
|
||||
cloned := Clone(appErr)
|
||||
return int(cloned.Code), cloned.Status
|
||||
}
|
||||
114
backend/internal/infrastructure/errors/types.go
Normal file
114
backend/internal/infrastructure/errors/types.go
Normal file
@@ -0,0 +1,114 @@
|
||||
// nolint:mnd
|
||||
package errors
|
||||
|
||||
import "net/http"
|
||||
|
||||
// BadRequest new BadRequest error that is mapped to a 400 response.
|
||||
func BadRequest(reason, message string) *ApplicationError {
|
||||
return New(http.StatusBadRequest, reason, message)
|
||||
}
|
||||
|
||||
// IsBadRequest determines if err is an error which indicates a BadRequest error.
|
||||
// It supports wrapped errors.
|
||||
func IsBadRequest(err error) bool {
|
||||
return Code(err) == http.StatusBadRequest
|
||||
}
|
||||
|
||||
// TooManyRequests new TooManyRequests error that is mapped to a 429 response.
|
||||
func TooManyRequests(reason, message string) *ApplicationError {
|
||||
return New(http.StatusTooManyRequests, reason, message)
|
||||
}
|
||||
|
||||
// IsTooManyRequests determines if err is an error which indicates a TooManyRequests error.
|
||||
// It supports wrapped errors.
|
||||
func IsTooManyRequests(err error) bool {
|
||||
return Code(err) == http.StatusTooManyRequests
|
||||
}
|
||||
|
||||
// Unauthorized new Unauthorized error that is mapped to a 401 response.
|
||||
func Unauthorized(reason, message string) *ApplicationError {
|
||||
return New(http.StatusUnauthorized, reason, message)
|
||||
}
|
||||
|
||||
// IsUnauthorized determines if err is an error which indicates an Unauthorized error.
|
||||
// It supports wrapped errors.
|
||||
func IsUnauthorized(err error) bool {
|
||||
return Code(err) == http.StatusUnauthorized
|
||||
}
|
||||
|
||||
// Forbidden new Forbidden error that is mapped to a 403 response.
|
||||
func Forbidden(reason, message string) *ApplicationError {
|
||||
return New(http.StatusForbidden, reason, message)
|
||||
}
|
||||
|
||||
// IsForbidden determines if err is an error which indicates a Forbidden error.
|
||||
// It supports wrapped errors.
|
||||
func IsForbidden(err error) bool {
|
||||
return Code(err) == http.StatusForbidden
|
||||
}
|
||||
|
||||
// NotFound new NotFound error that is mapped to a 404 response.
|
||||
func NotFound(reason, message string) *ApplicationError {
|
||||
return New(http.StatusNotFound, reason, message)
|
||||
}
|
||||
|
||||
// IsNotFound determines if err is an error which indicates an NotFound error.
|
||||
// It supports wrapped errors.
|
||||
func IsNotFound(err error) bool {
|
||||
return Code(err) == http.StatusNotFound
|
||||
}
|
||||
|
||||
// Conflict new Conflict error that is mapped to a 409 response.
|
||||
func Conflict(reason, message string) *ApplicationError {
|
||||
return New(http.StatusConflict, reason, message)
|
||||
}
|
||||
|
||||
// IsConflict determines if err is an error which indicates a Conflict error.
|
||||
// It supports wrapped errors.
|
||||
func IsConflict(err error) bool {
|
||||
return Code(err) == http.StatusConflict
|
||||
}
|
||||
|
||||
// InternalServer new InternalServer error that is mapped to a 500 response.
|
||||
func InternalServer(reason, message string) *ApplicationError {
|
||||
return New(http.StatusInternalServerError, reason, message)
|
||||
}
|
||||
|
||||
// IsInternalServer determines if err is an error which indicates an Internal error.
|
||||
// It supports wrapped errors.
|
||||
func IsInternalServer(err error) bool {
|
||||
return Code(err) == http.StatusInternalServerError
|
||||
}
|
||||
|
||||
// ServiceUnavailable new ServiceUnavailable error that is mapped to an HTTP 503 response.
|
||||
func ServiceUnavailable(reason, message string) *ApplicationError {
|
||||
return New(http.StatusServiceUnavailable, reason, message)
|
||||
}
|
||||
|
||||
// IsServiceUnavailable determines if err is an error which indicates an Unavailable error.
|
||||
// It supports wrapped errors.
|
||||
func IsServiceUnavailable(err error) bool {
|
||||
return Code(err) == http.StatusServiceUnavailable
|
||||
}
|
||||
|
||||
// GatewayTimeout new GatewayTimeout error that is mapped to an HTTP 504 response.
|
||||
func GatewayTimeout(reason, message string) *ApplicationError {
|
||||
return New(http.StatusGatewayTimeout, reason, message)
|
||||
}
|
||||
|
||||
// IsGatewayTimeout determines if err is an error which indicates a GatewayTimeout error.
|
||||
// It supports wrapped errors.
|
||||
func IsGatewayTimeout(err error) bool {
|
||||
return Code(err) == http.StatusGatewayTimeout
|
||||
}
|
||||
|
||||
// ClientClosed new ClientClosed error that is mapped to an HTTP 499 response.
|
||||
func ClientClosed(reason, message string) *ApplicationError {
|
||||
return New(499, reason, message)
|
||||
}
|
||||
|
||||
// IsClientClosed determines if err is an error which indicates a IsClientClosed error.
|
||||
// It supports wrapped errors.
|
||||
func IsClientClosed(err error) bool {
|
||||
return Code(err) == 499
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
package infrastructure
|
||||
|
||||
import (
|
||||
"sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package infrastructure
|
||||
|
||||
import (
|
||||
"sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
|
||||
"github.com/google/wire"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
// JSONB 用于存储JSONB数据
|
||||
type JSONB map[string]interface{}
|
||||
type JSONB map[string]any
|
||||
|
||||
func (j JSONB) Value() (driver.Value, error) {
|
||||
if j == nil {
|
||||
@@ -19,7 +19,7 @@ func (j JSONB) Value() (driver.Value, error) {
|
||||
return json.Marshal(j)
|
||||
}
|
||||
|
||||
func (j *JSONB) Scan(value interface{}) error {
|
||||
func (j *JSONB) Scan(value any) error {
|
||||
if value == nil {
|
||||
*j = nil
|
||||
return nil
|
||||
@@ -40,8 +40,8 @@ type Account struct {
|
||||
Extra JSONB `gorm:"type:jsonb;default:'{}'" json:"extra"` // 扩展信息
|
||||
ProxyID *int64 `gorm:"index" json:"proxy_id"`
|
||||
Concurrency int `gorm:"default:3;not null" json:"concurrency"`
|
||||
Priority int `gorm:"default:50;not null" json:"priority"` // 1-100,越小越高
|
||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error
|
||||
Priority int `gorm:"default:50;not null" json:"priority"` // 1-100,越小越高
|
||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error
|
||||
ErrorMessage string `gorm:"type:text" json:"error_message"`
|
||||
LastUsedAt *time.Time `gorm:"index" json:"last_used_at"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
@@ -68,7 +68,8 @@ type Account struct {
|
||||
AccountGroups []AccountGroup `gorm:"foreignKey:AccountID" json:"account_groups,omitempty"`
|
||||
|
||||
// 虚拟字段 (不存储到数据库)
|
||||
GroupIDs []int64 `gorm:"-" json:"group_ids,omitempty"`
|
||||
GroupIDs []int64 `gorm:"-" json:"group_ids,omitempty"`
|
||||
Groups []*Group `gorm:"-" json:"groups,omitempty"`
|
||||
}
|
||||
|
||||
func (Account) TableName() string {
|
||||
@@ -145,7 +146,7 @@ func (a *Account) GetModelMapping() map[string]string {
|
||||
return nil
|
||||
}
|
||||
// 处理map[string]interface{}类型
|
||||
if m, ok := raw.(map[string]interface{}); ok {
|
||||
if m, ok := raw.(map[string]any); ok {
|
||||
result := make(map[string]string)
|
||||
for k, v := range m {
|
||||
if s, ok := v.(string); ok {
|
||||
@@ -163,7 +164,7 @@ func (a *Account) GetModelMapping() map[string]string {
|
||||
// 如果没有设置模型映射,则支持所有模型
|
||||
func (a *Account) IsModelSupported(requestedModel string) bool {
|
||||
mapping := a.GetModelMapping()
|
||||
if mapping == nil || len(mapping) == 0 {
|
||||
if len(mapping) == 0 {
|
||||
return true // 没有映射配置,支持所有模型
|
||||
}
|
||||
_, exists := mapping[requestedModel]
|
||||
@@ -174,7 +175,7 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
|
||||
// 如果没有映射,返回原始模型名
|
||||
func (a *Account) GetMappedModel(requestedModel string) string {
|
||||
mapping := a.GetModelMapping()
|
||||
if mapping == nil || len(mapping) == 0 {
|
||||
if len(mapping) == 0 {
|
||||
return requestedModel
|
||||
}
|
||||
if mappedModel, exists := mapping[requestedModel]; exists {
|
||||
@@ -231,7 +232,7 @@ func (a *Account) GetCustomErrorCodes() []int {
|
||||
return nil
|
||||
}
|
||||
// 处理 []interface{} 类型(JSON反序列化后的格式)
|
||||
if arr, ok := raw.([]interface{}); ok {
|
||||
if arr, ok := raw.([]any); ok {
|
||||
result := make([]int, 0, len(arr))
|
||||
for _, v := range arr {
|
||||
// JSON 数字默认解析为 float64
|
||||
@@ -277,3 +278,138 @@ func (a *Account) IsInterceptWarmupEnabled() bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// =============== OpenAI 相关方法 ===============
|
||||
|
||||
// IsOpenAI 检查是否为 OpenAI 平台账号
|
||||
func (a *Account) IsOpenAI() bool {
|
||||
return a.Platform == PlatformOpenAI
|
||||
}
|
||||
|
||||
// IsAnthropic 检查是否为 Anthropic 平台账号
|
||||
func (a *Account) IsAnthropic() bool {
|
||||
return a.Platform == PlatformAnthropic
|
||||
}
|
||||
|
||||
// IsOpenAIOAuth 检查是否为 OpenAI OAuth 类型账号
|
||||
func (a *Account) IsOpenAIOAuth() bool {
|
||||
return a.IsOpenAI() && a.Type == AccountTypeOAuth
|
||||
}
|
||||
|
||||
// IsOpenAIApiKey 检查是否为 OpenAI API Key 类型账号(Response 账号)
|
||||
func (a *Account) IsOpenAIApiKey() bool {
|
||||
return a.IsOpenAI() && a.Type == AccountTypeApiKey
|
||||
}
|
||||
|
||||
// GetOpenAIBaseURL 获取 OpenAI API 基础 URL
|
||||
// 对于 API Key 类型账号,从 credentials 中获取 base_url
|
||||
// 对于 OAuth 类型账号,返回默认的 OpenAI API URL
|
||||
func (a *Account) GetOpenAIBaseURL() string {
|
||||
if !a.IsOpenAI() {
|
||||
return ""
|
||||
}
|
||||
if a.Type == AccountTypeApiKey {
|
||||
baseURL := a.GetCredential("base_url")
|
||||
if baseURL != "" {
|
||||
return baseURL
|
||||
}
|
||||
}
|
||||
return "https://api.openai.com" // OpenAI 默认 API URL
|
||||
}
|
||||
|
||||
// GetOpenAIAccessToken 获取 OpenAI 访问令牌
|
||||
func (a *Account) GetOpenAIAccessToken() string {
|
||||
if !a.IsOpenAI() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("access_token")
|
||||
}
|
||||
|
||||
// GetOpenAIRefreshToken 获取 OpenAI 刷新令牌
|
||||
func (a *Account) GetOpenAIRefreshToken() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("refresh_token")
|
||||
}
|
||||
|
||||
// GetOpenAIIDToken 获取 OpenAI ID Token(JWT,包含用户信息)
|
||||
func (a *Account) GetOpenAIIDToken() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("id_token")
|
||||
}
|
||||
|
||||
// GetOpenAIApiKey 获取 OpenAI API Key(用于 Response 账号)
|
||||
func (a *Account) GetOpenAIApiKey() string {
|
||||
if !a.IsOpenAIApiKey() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("api_key")
|
||||
}
|
||||
|
||||
// GetOpenAIUserAgent 获取 OpenAI 自定义 User-Agent
|
||||
// 返回空字符串表示透传原始 User-Agent
|
||||
func (a *Account) GetOpenAIUserAgent() string {
|
||||
if !a.IsOpenAI() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("user_agent")
|
||||
}
|
||||
|
||||
// GetChatGPTAccountID 获取 ChatGPT 账号 ID(从 ID Token 解析)
|
||||
func (a *Account) GetChatGPTAccountID() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("chatgpt_account_id")
|
||||
}
|
||||
|
||||
// GetChatGPTUserID 获取 ChatGPT 用户 ID(从 ID Token 解析)
|
||||
func (a *Account) GetChatGPTUserID() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("chatgpt_user_id")
|
||||
}
|
||||
|
||||
// GetOpenAIOrganizationID 获取 OpenAI 组织 ID
|
||||
func (a *Account) GetOpenAIOrganizationID() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("organization_id")
|
||||
}
|
||||
|
||||
// GetOpenAITokenExpiresAt 获取 OpenAI Token 过期时间
|
||||
func (a *Account) GetOpenAITokenExpiresAt() *time.Time {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return nil
|
||||
}
|
||||
expiresAtStr := a.GetCredential("expires_at")
|
||||
if expiresAtStr == "" {
|
||||
return nil
|
||||
}
|
||||
// 尝试解析时间
|
||||
t, err := time.Parse(time.RFC3339, expiresAtStr)
|
||||
if err != nil {
|
||||
// 尝试解析为 Unix 时间戳
|
||||
if v, ok := a.Credentials["expires_at"].(float64); ok {
|
||||
t = time.Unix(int64(v), 0)
|
||||
return &t
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return &t
|
||||
}
|
||||
|
||||
// IsOpenAITokenExpired 检查 OpenAI Token 是否过期
|
||||
func (a *Account) IsOpenAITokenExpired() bool {
|
||||
expiresAt := a.GetOpenAITokenExpiresAt()
|
||||
if expiresAt == nil {
|
||||
return false // 没有过期时间信息,假设未过期
|
||||
}
|
||||
// 提前 60 秒认为过期,便于刷新
|
||||
return time.Now().Add(60 * time.Second).After(*expiresAt)
|
||||
}
|
||||
|
||||
@@ -13,13 +13,13 @@ const (
|
||||
)
|
||||
|
||||
type Group struct {
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"uniqueIndex;size:100;not null" json:"name"`
|
||||
Description string `gorm:"type:text" json:"description"`
|
||||
Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini
|
||||
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"`
|
||||
IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"`
|
||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"uniqueIndex;size:100;not null" json:"name"`
|
||||
Description string `gorm:"type:text" json:"description"`
|
||||
Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini
|
||||
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"`
|
||||
IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"`
|
||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
|
||||
|
||||
// 订阅功能字段
|
||||
SubscriptionType string `gorm:"size:20;default:standard;not null" json:"subscription_type"` // standard/subscription
|
||||
|
||||
@@ -9,15 +9,16 @@ import (
|
||||
type RedeemCode struct {
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
Code string `gorm:"uniqueIndex;size:32;not null" json:"code"`
|
||||
Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription
|
||||
Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数
|
||||
Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription
|
||||
Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数
|
||||
Status string `gorm:"size:20;default:unused;not null" json:"status"` // unused/used
|
||||
UsedBy *int64 `gorm:"index" json:"used_by"`
|
||||
UsedAt *time.Time `json:"used_at"`
|
||||
Notes string `gorm:"type:text" json:"notes"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
|
||||
// 订阅类型专用字段
|
||||
GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用)
|
||||
GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用)
|
||||
ValidityDays int `gorm:"default:30" json:"validity_days"` // 订阅有效天数 (仅subscription类型使用)
|
||||
|
||||
// 关联
|
||||
@@ -40,8 +41,10 @@ func (r *RedeemCode) CanUse() bool {
|
||||
}
|
||||
|
||||
// GenerateRedeemCode 生成唯一的兑换码
|
||||
func GenerateRedeemCode() string {
|
||||
func GenerateRedeemCode() (string, error) {
|
||||
b := make([]byte, 16)
|
||||
rand.Read(b)
|
||||
return hex.EncodeToString(b)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
@@ -19,17 +19,17 @@ func (Setting) TableName() string {
|
||||
// 设置Key常量
|
||||
const (
|
||||
// 注册设置
|
||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||
|
||||
// 邮件服务设置
|
||||
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
|
||||
SettingKeySmtpPort = "smtp_port" // SMTP端口
|
||||
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
|
||||
SettingKeySmtpPassword = "smtp_password" // SMTP密码(加密存储)
|
||||
SettingKeySmtpFrom = "smtp_from" // 发件人地址
|
||||
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
|
||||
SettingKeySmtpPort = "smtp_port" // SMTP端口
|
||||
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
|
||||
SettingKeySmtpPassword = "smtp_password" // SMTP密码(加密存储)
|
||||
SettingKeySmtpFrom = "smtp_from" // 发件人地址
|
||||
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
|
||||
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
|
||||
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
|
||||
|
||||
// Cloudflare Turnstile 设置
|
||||
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
|
||||
@@ -42,6 +42,7 @@ const (
|
||||
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
|
||||
SettingKeyApiBaseUrl = "api_base_url" // API端点地址(用于客户端配置和导入)
|
||||
SettingKeyContactInfo = "contact_info" // 客服联系方式
|
||||
SettingKeyDocUrl = "doc_url" // 文档链接
|
||||
|
||||
// 默认配置
|
||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||
@@ -80,6 +81,7 @@ type SystemSettings struct {
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
ApiBaseUrl string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocUrl string `json:"doc_url"`
|
||||
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
@@ -97,5 +99,6 @@ type PublicSettings struct {
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
ApiBaseUrl string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocUrl string `json:"doc_url"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
@@ -37,7 +37,7 @@ type UsageLog struct {
|
||||
OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"output_cost"`
|
||||
CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_creation_cost"`
|
||||
CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_read_cost"`
|
||||
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用
|
||||
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用
|
||||
ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"actual_cost"` // 实际扣除费用
|
||||
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null" json:"rate_multiplier"` // 计费倍率
|
||||
|
||||
|
||||
@@ -9,8 +9,11 @@ import (
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
Email string `gorm:"uniqueIndex;size:255;not null" json:"email"`
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
Email string `gorm:"uniqueIndex;size:255;not null" json:"email"`
|
||||
Username string `gorm:"size:100;default:''" json:"username"`
|
||||
Wechat string `gorm:"size:100;default:''" json:"wechat"`
|
||||
Notes string `gorm:"type:text;default:''" json:"notes"`
|
||||
PasswordHash string `gorm:"size:255;not null" json:"-"`
|
||||
Role string `gorm:"size:20;default:user;not null" json:"role"` // admin/user
|
||||
Balance float64 `gorm:"type:decimal(20,8);default:0;not null" json:"balance"`
|
||||
@@ -22,7 +25,8 @@ type User struct {
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
|
||||
|
||||
// 关联
|
||||
ApiKeys []ApiKey `gorm:"foreignKey:UserID" json:"api_keys,omitempty"`
|
||||
ApiKeys []ApiKey `gorm:"foreignKey:UserID" json:"api_keys,omitempty"`
|
||||
Subscriptions []UserSubscription `gorm:"foreignKey:UserID" json:"subscriptions,omitempty"`
|
||||
}
|
||||
|
||||
func (User) TableName() string {
|
||||
|
||||
@@ -43,18 +43,25 @@ type OAuthSession struct {
|
||||
type SessionStore struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*OAuthSession
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// NewSessionStore creates a new session store
|
||||
func NewSessionStore() *SessionStore {
|
||||
store := &SessionStore{
|
||||
sessions: make(map[string]*OAuthSession),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
// Start cleanup goroutine
|
||||
go store.cleanup()
|
||||
return store
|
||||
}
|
||||
|
||||
// Stop stops the cleanup goroutine
|
||||
func (s *SessionStore) Stop() {
|
||||
close(s.stopCh)
|
||||
}
|
||||
|
||||
// Set stores a session
|
||||
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
|
||||
s.mu.Lock()
|
||||
@@ -87,14 +94,20 @@ func (s *SessionStore) Delete(sessionID string) {
|
||||
// cleanup removes expired sessions periodically
|
||||
func (s *SessionStore) cleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
for range ticker.C {
|
||||
s.mu.Lock()
|
||||
for id, session := range s.sessions {
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
delete(s.sessions, id)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.mu.Lock()
|
||||
for id, session := range s.sessions {
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
delete(s.sessions, id)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
42
backend/internal/pkg/openai/constants.go
Normal file
42
backend/internal/pkg/openai/constants.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package openai
|
||||
|
||||
import _ "embed"
|
||||
|
||||
// Model represents an OpenAI model
|
||||
type Model struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
Type string `json:"type"`
|
||||
DisplayName string `json:"display_name"`
|
||||
}
|
||||
|
||||
// DefaultModels OpenAI models list
|
||||
var DefaultModels = []Model{
|
||||
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
|
||||
{ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"},
|
||||
{ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"},
|
||||
{ID: "gpt-5.1-codex", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex"},
|
||||
{ID: "gpt-5.1", Object: "model", Created: 1731456000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1"},
|
||||
{ID: "gpt-5.1-codex-mini", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Mini"},
|
||||
{ID: "gpt-5", Object: "model", Created: 1722988800, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5"},
|
||||
}
|
||||
|
||||
// DefaultModelIDs returns the default model ID list
|
||||
func DefaultModelIDs() []string {
|
||||
ids := make([]string, len(DefaultModels))
|
||||
for i, m := range DefaultModels {
|
||||
ids[i] = m.ID
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// DefaultTestModel default model for testing OpenAI accounts
|
||||
const DefaultTestModel = "gpt-5.1-codex"
|
||||
|
||||
// DefaultInstructions default instructions for non-Codex CLI requests
|
||||
// Content loaded from instructions.txt at compile time
|
||||
//
|
||||
//go:embed instructions.txt
|
||||
var DefaultInstructions string
|
||||
118
backend/internal/pkg/openai/instructions.txt
Normal file
118
backend/internal/pkg/openai/instructions.txt
Normal file
@@ -0,0 +1,118 @@
|
||||
You are Codex, based on GPT-5. You are running as a coding agent in the Codex CLI on a user's computer.
|
||||
|
||||
## General
|
||||
|
||||
- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.)
|
||||
|
||||
## Editing constraints
|
||||
|
||||
- Default to ASCII when editing or creating files. Only introduce non-ASCII or other Unicode characters when there is a clear justification and the file already uses them.
|
||||
- Add succinct code comments that explain what is going on if code is not self-explanatory. You should not add comments like \"Assigns the value to the variable\", but a brief comment might be useful ahead of a complex code block that the user would otherwise have to spend time parsing out. Usage of these comments should be rare.
|
||||
- Try to use apply_patch for single file edits, but it is fine to explore other options to make the edit if it does not work well. Do not use apply_patch for changes that are auto-generated (i.e. generating package.json or running a lint or format command like gofmt) or when scripting is more efficient (such as search and replacing a string across a codebase).
|
||||
- You may be in a dirty git worktree.
|
||||
* NEVER revert existing changes you did not make unless explicitly requested, since these changes were made by the user.
|
||||
* If asked to make a commit or code edits and there are unrelated changes to your work or changes that you didn't make in those files, don't revert those changes.
|
||||
* If the changes are in files you've touched recently, you should read carefully and understand how you can work with the changes rather than reverting them.
|
||||
* If the changes are in unrelated files, just ignore them and don't revert them.
|
||||
- Do not amend a commit unless explicitly requested to do so.
|
||||
- While you are working, you might notice unexpected changes that you didn't make. If this happens, STOP IMMEDIATELY and ask the user how they would like to proceed.
|
||||
- **NEVER** use destructive commands like `git reset --hard` or `git checkout --` unless specifically requested or approved by the user.
|
||||
|
||||
## Plan tool
|
||||
|
||||
When using the planning tool:
|
||||
- Skip using the planning tool for straightforward tasks (roughly the easiest 25%).
|
||||
- Do not make single-step plans.
|
||||
- When you made a plan, update it after having performed one of the sub-tasks that you shared on the plan.
|
||||
|
||||
## Codex CLI harness, sandboxing, and approvals
|
||||
|
||||
The Codex CLI harness supports several different configurations for sandboxing and escalation approvals that the user can choose from.
|
||||
|
||||
Filesystem sandboxing defines which files can be read or written. The options for `sandbox_mode` are:
|
||||
- **read-only**: The sandbox only permits reading files.
|
||||
- **workspace-write**: The sandbox permits reading files, and editing files in `cwd` and `writable_roots`. Editing files in other directories requires approval.
|
||||
- **danger-full-access**: No filesystem sandboxing - all commands are permitted.
|
||||
|
||||
Network sandboxing defines whether network can be accessed without approval. Options for `network_access` are:
|
||||
- **restricted**: Requires approval
|
||||
- **enabled**: No approval needed
|
||||
|
||||
Approvals are your mechanism to get user consent to run shell commands without the sandbox. Possible configuration options for `approval_policy` are
|
||||
- **untrusted**: The harness will escalate most commands for user approval, apart from a limited allowlist of safe \"read\" commands.
|
||||
- **on-failure**: The harness will allow all commands to run in the sandbox (if enabled), and failures will be escalated to the user for approval to run again without the sandbox.
|
||||
- **on-request**: Commands will be run in the sandbox by default, and you can specify in your tool call if you want to escalate a command to run without sandboxing. (Note that this mode is not always available. If it is, you'll see parameters for it in the `shell` command description.)
|
||||
- **never**: This is a non-interactive mode where you may NEVER ask the user for approval to run commands. Instead, you must always persist and work around constraints to solve the task for the user. You MUST do your utmost best to finish the task and validate your work before yielding. If this mode is paired with `danger-full-access`, take advantage of it to deliver the best outcome for the user. Further, in this mode, your default testing philosophy is overridden: Even if you don't see local patterns for testing, you may add tests and scripts to validate your work. Just remove them before yielding.
|
||||
|
||||
When you are running with `approval_policy == on-request`, and sandboxing enabled, here are scenarios where you'll need to request approval:
|
||||
- You need to run a command that writes to a directory that requires it (e.g. running tests that write to /var)
|
||||
- You need to run a GUI app (e.g., open/xdg-open/osascript) to open browsers or files.
|
||||
- You are running sandboxed and need to run a command that requires network access (e.g. installing packages)
|
||||
- If you run a command that is important to solving the user's query, but it fails because of sandboxing, rerun the command with approval. ALWAYS proceed to use the `sandbox_permissions` and `justification` parameters - do not message the user before requesting approval for the command.
|
||||
- You are about to take a potentially destructive action such as an `rm` or `git reset` that the user did not explicitly ask for
|
||||
- (for all of these, you should weigh alternative paths that do not require approval)
|
||||
|
||||
When `sandbox_mode` is set to read-only, you'll need to request approval for any command that isn't a read.
|
||||
|
||||
You will be told what filesystem sandboxing, network sandboxing, and approval mode are active in a developer or user message. If you are not told about this, assume that you are running with workspace-write, network sandboxing enabled, and approval on-failure.
|
||||
|
||||
Although they introduce friction to the user because your work is paused until the user responds, you should leverage them when necessary to accomplish important work. If the completing the task requires escalated permissions, Do not let these settings or the sandbox deter you from attempting to accomplish the user's task unless it is set to \"never\", in which case never ask for approvals.
|
||||
|
||||
When requesting approval to execute a command that will require escalated privileges:
|
||||
- Provide the `sandbox_permissions` parameter with the value `\"require_escalated\"`
|
||||
- Include a short, 1 sentence explanation for why you need escalated permissions in the justification parameter
|
||||
|
||||
## Special user requests
|
||||
|
||||
- If the user makes a simple request (such as asking for the time) which you can fulfill by running a terminal command (such as `date`), you should do so.
|
||||
- If the user asks for a \"review\", default to a code review mindset: prioritise identifying bugs, risks, behavioural regressions, and missing tests. Findings must be the primary focus of the response - keep summaries or overviews brief and only after enumerating the issues. Present findings first (ordered by severity with file/line references), follow with open questions or assumptions, and offer a change-summary only as a secondary detail. If no findings are discovered, state that explicitly and mention any residual risks or testing gaps.
|
||||
|
||||
## Frontend tasks
|
||||
When doing frontend design tasks, avoid collapsing into \"AI slop\" or safe, average-looking layouts.
|
||||
Aim for interfaces that feel intentional, bold, and a bit surprising.
|
||||
- Typography: Use expressive, purposeful fonts and avoid default stacks (Inter, Roboto, Arial, system).
|
||||
- Color & Look: Choose a clear visual direction; define CSS variables; avoid purple-on-white defaults. No purple bias or dark mode bias.
|
||||
- Motion: Use a few meaningful animations (page-load, staggered reveals) instead of generic micro-motions.
|
||||
- Background: Don't rely on flat, single-color backgrounds; use gradients, shapes, or subtle patterns to build atmosphere.
|
||||
- Overall: Avoid boilerplate layouts and interchangeable UI patterns. Vary themes, type families, and visual languages across outputs.
|
||||
- Ensure the page loads properly on both desktop and mobile
|
||||
|
||||
Exception: If working within an existing website or design system, preserve the established patterns, structure, and visual language.
|
||||
|
||||
## Presenting your work and final message
|
||||
|
||||
You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.
|
||||
|
||||
- Default: be very concise; friendly coding teammate tone.
|
||||
- Ask only when needed; suggest ideas; mirror the user's style.
|
||||
- For substantial work, summarize clearly; follow final‑answer formatting.
|
||||
- Skip heavy formatting for simple confirmations.
|
||||
- Don't dump large files you've written; reference paths only.
|
||||
- No \"save/copy this file\" - User is on the same machine.
|
||||
- Offer logical next steps (tests, commits, build) briefly; add verify steps if you couldn't do something.
|
||||
- For code changes:
|
||||
* Lead with a quick explanation of the change, and then give more details on the context covering where and why a change was made. Do not start this explanation with \"summary\", just jump right in.
|
||||
* If there are natural next steps the user may want to take, suggest them at the end of your response. Do not make suggestions if there are no natural next steps.
|
||||
* When suggesting multiple options, use numeric lists for the suggestions so the user can quickly respond with a single number.
|
||||
- The user does not command execution outputs. When asked to show the output of a command (e.g. `git show`), relay the important details in your answer or summarize the key lines so the user understands the result.
|
||||
|
||||
### Final answer structure and style guidelines
|
||||
|
||||
- Plain text; CLI handles styling. Use structure only when it helps scanability.
|
||||
- Headers: optional; short Title Case (1-3 words) wrapped in **…**; no blank line before the first bullet; add only if they truly help.
|
||||
- Bullets: use - ; merge related points; keep to one line when possible; 4–6 per list ordered by importance; keep phrasing consistent.
|
||||
- Monospace: backticks for commands/paths/env vars/code ids and inline examples; use for literal keyword bullets; never combine with **.
|
||||
- Code samples or multi-line snippets should be wrapped in fenced code blocks; include an info string as often as possible.
|
||||
- Structure: group related bullets; order sections general → specific → supporting; for subsections, start with a bolded keyword bullet, then items; match complexity to the task.
|
||||
- Tone: collaborative, concise, factual; present tense, active voice; self‑contained; no \"above/below\"; parallel wording.
|
||||
- Don'ts: no nested bullets/hierarchies; no ANSI codes; don't cram unrelated keywords; keep keyword lists short—wrap/reformat if long; avoid naming formatting styles in answers.
|
||||
- Adaptation: code explanations → precise, structured with code refs; simple tasks → lead with outcome; big changes → logical walkthrough + rationale + next actions; casual one-offs → plain sentences, no headers/bullets.
|
||||
- File References: When referencing files in your response follow the below rules:
|
||||
* Use inline code to make file paths clickable.
|
||||
* Each reference should have a stand alone path. Even if it's the same file.
|
||||
* Accepted: absolute, workspace‑relative, a/ or b/ diff prefixes, or bare filename/suffix.
|
||||
* Optionally include line/column (1‑based): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
|
||||
* Do not use URIs like file://, vscode://, or https://.
|
||||
* Do not provide range of lines
|
||||
* Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\\repo\\project\\main.rs:12:5
|
||||
|
||||
366
backend/internal/pkg/openai/oauth.go
Normal file
366
backend/internal/pkg/openai/oauth.go
Normal file
@@ -0,0 +1,366 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OpenAI OAuth Constants (from CRS project - Codex CLI client)
|
||||
const (
|
||||
// OAuth Client ID for OpenAI (Codex CLI official)
|
||||
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
|
||||
// OAuth endpoints
|
||||
AuthorizeURL = "https://auth.openai.com/oauth/authorize"
|
||||
TokenURL = "https://auth.openai.com/oauth/token"
|
||||
|
||||
// Default redirect URI (can be customized)
|
||||
DefaultRedirectURI = "http://localhost:1455/auth/callback"
|
||||
|
||||
// Scopes
|
||||
DefaultScopes = "openid profile email offline_access"
|
||||
// RefreshScopes - scope for token refresh (without offline_access, aligned with CRS project)
|
||||
RefreshScopes = "openid profile email"
|
||||
|
||||
// Session TTL
|
||||
SessionTTL = 30 * time.Minute
|
||||
)
|
||||
|
||||
// OAuthSession stores OAuth flow state for OpenAI
|
||||
type OAuthSession struct {
|
||||
State string `json:"state"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
ProxyURL string `json:"proxy_url,omitempty"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// SessionStore manages OAuth sessions in memory
|
||||
type SessionStore struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*OAuthSession
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// NewSessionStore creates a new session store
|
||||
func NewSessionStore() *SessionStore {
|
||||
store := &SessionStore{
|
||||
sessions: make(map[string]*OAuthSession),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
// Start cleanup goroutine
|
||||
go store.cleanup()
|
||||
return store
|
||||
}
|
||||
|
||||
// Set stores a session
|
||||
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sessions[sessionID] = session
|
||||
}
|
||||
|
||||
// Get retrieves a session
|
||||
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
session, ok := s.sessions[sessionID]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
// Check if expired
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
return nil, false
|
||||
}
|
||||
return session, true
|
||||
}
|
||||
|
||||
// Delete removes a session
|
||||
func (s *SessionStore) Delete(sessionID string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.sessions, sessionID)
|
||||
}
|
||||
|
||||
// Stop stops the cleanup goroutine
|
||||
func (s *SessionStore) Stop() {
|
||||
close(s.stopCh)
|
||||
}
|
||||
|
||||
// cleanup removes expired sessions periodically
|
||||
func (s *SessionStore) cleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.mu.Lock()
|
||||
for id, session := range s.sessions {
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
delete(s.sessions, id)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateRandomBytes generates cryptographically secure random bytes
|
||||
func GenerateRandomBytes(n int) ([]byte, error) {
|
||||
b := make([]byte, n)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// GenerateState generates a random state string for OAuth
|
||||
func GenerateState() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateSessionID generates a unique session ID
|
||||
func GenerateSessionID() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(16)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateCodeVerifier generates a PKCE code verifier (64 bytes -> hex for OpenAI)
|
||||
// OpenAI uses hex encoding instead of base64url
|
||||
func GenerateCodeVerifier() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(64)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
|
||||
// Uses base64url encoding as per RFC 7636
|
||||
func GenerateCodeChallenge(verifier string) string {
|
||||
hash := sha256.Sum256([]byte(verifier))
|
||||
return base64URLEncode(hash[:])
|
||||
}
|
||||
|
||||
// base64URLEncode encodes bytes to base64url without padding
|
||||
func base64URLEncode(data []byte) string {
|
||||
encoded := base64.URLEncoding.EncodeToString(data)
|
||||
// Remove padding
|
||||
return strings.TrimRight(encoded, "=")
|
||||
}
|
||||
|
||||
// BuildAuthorizationURL builds the OpenAI OAuth authorization URL
|
||||
func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string {
|
||||
if redirectURI == "" {
|
||||
redirectURI = DefaultRedirectURI
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("response_type", "code")
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("redirect_uri", redirectURI)
|
||||
params.Set("scope", DefaultScopes)
|
||||
params.Set("state", state)
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
// OpenAI specific parameters
|
||||
params.Set("id_token_add_organizations", "true")
|
||||
params.Set("codex_cli_simplified_flow", "true")
|
||||
|
||||
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
|
||||
}
|
||||
|
||||
// TokenRequest represents the token exchange request body
|
||||
type TokenRequest struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
ClientID string `json:"client_id"`
|
||||
Code string `json:"code"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
}
|
||||
|
||||
// TokenResponse represents the token response from OpenAI OAuth
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
// RefreshTokenRequest represents the refresh token request
|
||||
type RefreshTokenRequest struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ClientID string `json:"client_id"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
// IDTokenClaims represents the claims from OpenAI ID Token
|
||||
type IDTokenClaims struct {
|
||||
// Standard claims
|
||||
Sub string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
Iss string `json:"iss"`
|
||||
Aud []string `json:"aud"` // OpenAI returns aud as an array
|
||||
Exp int64 `json:"exp"`
|
||||
Iat int64 `json:"iat"`
|
||||
|
||||
// OpenAI specific claims (nested under https://api.openai.com/auth)
|
||||
OpenAIAuth *OpenAIAuthClaims `json:"https://api.openai.com/auth,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIAuthClaims represents the OpenAI specific auth claims
|
||||
type OpenAIAuthClaims struct {
|
||||
ChatGPTAccountID string `json:"chatgpt_account_id"`
|
||||
ChatGPTUserID string `json:"chatgpt_user_id"`
|
||||
UserID string `json:"user_id"`
|
||||
Organizations []OrganizationClaim `json:"organizations"`
|
||||
}
|
||||
|
||||
// OrganizationClaim represents an organization in the ID Token
|
||||
type OrganizationClaim struct {
|
||||
ID string `json:"id"`
|
||||
Role string `json:"role"`
|
||||
Title string `json:"title"`
|
||||
IsDefault bool `json:"is_default"`
|
||||
}
|
||||
|
||||
// BuildTokenRequest creates a token exchange request for OpenAI
|
||||
func BuildTokenRequest(code, codeVerifier, redirectURI string) *TokenRequest {
|
||||
if redirectURI == "" {
|
||||
redirectURI = DefaultRedirectURI
|
||||
}
|
||||
return &TokenRequest{
|
||||
GrantType: "authorization_code",
|
||||
ClientID: ClientID,
|
||||
Code: code,
|
||||
RedirectURI: redirectURI,
|
||||
CodeVerifier: codeVerifier,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildRefreshTokenRequest creates a refresh token request for OpenAI
|
||||
func BuildRefreshTokenRequest(refreshToken string) *RefreshTokenRequest {
|
||||
return &RefreshTokenRequest{
|
||||
GrantType: "refresh_token",
|
||||
RefreshToken: refreshToken,
|
||||
ClientID: ClientID,
|
||||
Scope: RefreshScopes,
|
||||
}
|
||||
}
|
||||
|
||||
// ToFormData converts TokenRequest to URL-encoded form data
|
||||
func (r *TokenRequest) ToFormData() string {
|
||||
params := url.Values{}
|
||||
params.Set("grant_type", r.GrantType)
|
||||
params.Set("client_id", r.ClientID)
|
||||
params.Set("code", r.Code)
|
||||
params.Set("redirect_uri", r.RedirectURI)
|
||||
params.Set("code_verifier", r.CodeVerifier)
|
||||
return params.Encode()
|
||||
}
|
||||
|
||||
// ToFormData converts RefreshTokenRequest to URL-encoded form data
|
||||
func (r *RefreshTokenRequest) ToFormData() string {
|
||||
params := url.Values{}
|
||||
params.Set("grant_type", r.GrantType)
|
||||
params.Set("client_id", r.ClientID)
|
||||
params.Set("refresh_token", r.RefreshToken)
|
||||
params.Set("scope", r.Scope)
|
||||
return params.Encode()
|
||||
}
|
||||
|
||||
// ParseIDToken parses the ID Token JWT and extracts claims
|
||||
// Note: This does NOT verify the signature - it only decodes the payload
|
||||
// For production, you should verify the token signature using OpenAI's public keys
|
||||
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
parts := strings.Split(idToken, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
// Decode payload (second part)
|
||||
payload := parts[1]
|
||||
// Add padding if necessary
|
||||
switch len(payload) % 4 {
|
||||
case 2:
|
||||
payload += "=="
|
||||
case 3:
|
||||
payload += "="
|
||||
}
|
||||
|
||||
decoded, err := base64.URLEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
// Try standard encoding
|
||||
decoded, err = base64.StdEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode JWT payload: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var claims IDTokenClaims
|
||||
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
// ExtractUserInfo extracts user information from ID Token claims
|
||||
type UserInfo struct {
|
||||
Email string
|
||||
ChatGPTAccountID string
|
||||
ChatGPTUserID string
|
||||
UserID string
|
||||
OrganizationID string
|
||||
Organizations []OrganizationClaim
|
||||
}
|
||||
|
||||
// GetUserInfo extracts user info from ID Token claims
|
||||
func (c *IDTokenClaims) GetUserInfo() *UserInfo {
|
||||
info := &UserInfo{
|
||||
Email: c.Email,
|
||||
}
|
||||
|
||||
if c.OpenAIAuth != nil {
|
||||
info.ChatGPTAccountID = c.OpenAIAuth.ChatGPTAccountID
|
||||
info.ChatGPTUserID = c.OpenAIAuth.ChatGPTUserID
|
||||
info.UserID = c.OpenAIAuth.UserID
|
||||
info.Organizations = c.OpenAIAuth.Organizations
|
||||
|
||||
// Get default organization ID
|
||||
for _, org := range c.OpenAIAuth.Organizations {
|
||||
if org.IsDefault {
|
||||
info.OrganizationID = org.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
// If no default, use first org
|
||||
if info.OrganizationID == "" && len(c.OpenAIAuth.Organizations) > 0 {
|
||||
info.OrganizationID = c.OpenAIAuth.Organizations[0].ID
|
||||
}
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
18
backend/internal/pkg/openai/request.go
Normal file
18
backend/internal/pkg/openai/request.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package openai
|
||||
|
||||
// CodexCLIUserAgentPrefixes matches Codex CLI User-Agent patterns
|
||||
// Examples: "codex_vscode/1.0.0", "codex_cli_rs/0.1.2"
|
||||
var CodexCLIUserAgentPrefixes = []string{
|
||||
"codex_vscode/",
|
||||
"codex_cli_rs/",
|
||||
}
|
||||
|
||||
// IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request
|
||||
func IsCodexCLIRequest(userAgent string) bool {
|
||||
for _, prefix := range CodexCLIUserAgentPrefixes {
|
||||
if len(userAgent) >= len(prefix) && userAgent[:len(prefix)] == prefix {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -4,27 +4,30 @@ import (
|
||||
"math"
|
||||
"net/http"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Response 标准API响应格式
|
||||
type Response struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
Data any `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// PaginatedData 分页数据格式(匹配前端期望)
|
||||
type PaginatedData struct {
|
||||
Items interface{} `json:"items"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
Pages int `json:"pages"`
|
||||
Items any `json:"items"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
Pages int `json:"pages"`
|
||||
}
|
||||
|
||||
// Success 返回成功响应
|
||||
func Success(c *gin.Context, data interface{}) {
|
||||
func Success(c *gin.Context, data any) {
|
||||
c.JSON(http.StatusOK, Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
@@ -33,7 +36,7 @@ func Success(c *gin.Context, data interface{}) {
|
||||
}
|
||||
|
||||
// Created 返回创建成功响应
|
||||
func Created(c *gin.Context, data interface{}) {
|
||||
func Created(c *gin.Context, data any) {
|
||||
c.JSON(http.StatusCreated, Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
@@ -44,11 +47,36 @@ func Created(c *gin.Context, data interface{}) {
|
||||
// Error 返回错误响应
|
||||
func Error(c *gin.Context, statusCode int, message string) {
|
||||
c.JSON(statusCode, Response{
|
||||
Code: statusCode,
|
||||
Message: message,
|
||||
Code: statusCode,
|
||||
Message: message,
|
||||
Reason: "",
|
||||
Metadata: nil,
|
||||
})
|
||||
}
|
||||
|
||||
// ErrorWithDetails returns an error response compatible with the existing envelope while
|
||||
// optionally providing structured error fields (reason/metadata).
|
||||
func ErrorWithDetails(c *gin.Context, statusCode int, message, reason string, metadata map[string]string) {
|
||||
c.JSON(statusCode, Response{
|
||||
Code: statusCode,
|
||||
Message: message,
|
||||
Reason: reason,
|
||||
Metadata: metadata,
|
||||
})
|
||||
}
|
||||
|
||||
// ErrorFrom converts an ApplicationError (or any error) into the envelope-compatible error response.
|
||||
// It returns true if an error was written.
|
||||
func ErrorFrom(c *gin.Context, err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
statusCode, status := infraerrors.ToHTTP(err)
|
||||
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
|
||||
return true
|
||||
}
|
||||
|
||||
// BadRequest 返回400错误
|
||||
func BadRequest(c *gin.Context, message string) {
|
||||
Error(c, http.StatusBadRequest, message)
|
||||
@@ -75,7 +103,7 @@ func InternalError(c *gin.Context, message string) {
|
||||
}
|
||||
|
||||
// Paginated 返回分页数据
|
||||
func Paginated(c *gin.Context, items interface{}, total int64, page, pageSize int) {
|
||||
func Paginated(c *gin.Context, items any, total int64, page, pageSize int) {
|
||||
pages := int(math.Ceil(float64(total) / float64(pageSize)))
|
||||
if pages < 1 {
|
||||
pages = 1
|
||||
@@ -99,7 +127,7 @@ type PaginationResult struct {
|
||||
}
|
||||
|
||||
// PaginatedWithResult 使用PaginationResult返回分页数据
|
||||
func PaginatedWithResult(c *gin.Context, items interface{}, pagination *PaginationResult) {
|
||||
func PaginatedWithResult(c *gin.Context, items any, pagination *PaginationResult) {
|
||||
if pagination == nil {
|
||||
Success(c, PaginatedData{
|
||||
Items: items,
|
||||
|
||||
171
backend/internal/pkg/response/response_test.go
Normal file
171
backend/internal/pkg/response/response_test.go
Normal file
@@ -0,0 +1,171 @@
|
||||
//go:build unit
|
||||
|
||||
package response
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestErrorWithDetails(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
message string
|
||||
reason string
|
||||
metadata map[string]string
|
||||
want Response
|
||||
}{
|
||||
{
|
||||
name: "plain_error",
|
||||
statusCode: http.StatusBadRequest,
|
||||
message: "invalid request",
|
||||
want: Response{
|
||||
Code: http.StatusBadRequest,
|
||||
Message: "invalid request",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "structured_error",
|
||||
statusCode: http.StatusForbidden,
|
||||
message: "no access",
|
||||
reason: "FORBIDDEN",
|
||||
metadata: map[string]string{"k": "v"},
|
||||
want: Response{
|
||||
Code: http.StatusForbidden,
|
||||
Message: "no access",
|
||||
Reason: "FORBIDDEN",
|
||||
Metadata: map[string]string{"k": "v"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
ErrorWithDetails(c, tt.statusCode, tt.message, tt.reason, tt.metadata)
|
||||
|
||||
require.Equal(t, tt.statusCode, w.Code)
|
||||
|
||||
var got Response
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorFrom(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantWritten bool
|
||||
wantHTTPCode int
|
||||
wantBody Response
|
||||
}{
|
||||
{
|
||||
name: "nil_error",
|
||||
err: nil,
|
||||
wantWritten: false,
|
||||
},
|
||||
{
|
||||
name: "application_error",
|
||||
err: infraerrors.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusForbidden,
|
||||
wantBody: Response{
|
||||
Code: http.StatusForbidden,
|
||||
Message: "no access",
|
||||
Reason: "FORBIDDEN",
|
||||
Metadata: map[string]string{"scope": "admin"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bad_request_error",
|
||||
err: infraerrors.BadRequest("INVALID_REQUEST", "invalid request"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusBadRequest,
|
||||
wantBody: Response{
|
||||
Code: http.StatusBadRequest,
|
||||
Message: "invalid request",
|
||||
Reason: "INVALID_REQUEST",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unauthorized_error",
|
||||
err: infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusUnauthorized,
|
||||
wantBody: Response{
|
||||
Code: http.StatusUnauthorized,
|
||||
Message: "unauthorized",
|
||||
Reason: "UNAUTHORIZED",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "not_found_error",
|
||||
err: infraerrors.NotFound("NOT_FOUND", "not found"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusNotFound,
|
||||
wantBody: Response{
|
||||
Code: http.StatusNotFound,
|
||||
Message: "not found",
|
||||
Reason: "NOT_FOUND",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "conflict_error",
|
||||
err: infraerrors.Conflict("CONFLICT", "conflict"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusConflict,
|
||||
wantBody: Response{
|
||||
Code: http.StatusConflict,
|
||||
Message: "conflict",
|
||||
Reason: "CONFLICT",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unknown_error_defaults_to_500",
|
||||
err: errors.New("boom"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusInternalServerError,
|
||||
wantBody: Response{
|
||||
Code: http.StatusInternalServerError,
|
||||
Message: infraerrors.UnknownMessage,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
written := ErrorFrom(c, tt.err)
|
||||
require.Equal(t, tt.wantWritten, written)
|
||||
|
||||
if !tt.wantWritten {
|
||||
require.Equal(t, 200, w.Code)
|
||||
require.Empty(t, w.Body.String())
|
||||
return
|
||||
}
|
||||
|
||||
require.Equal(t, tt.wantHTTPCode, w.Code)
|
||||
var got Response
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
|
||||
require.Equal(t, tt.wantBody, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -37,11 +37,15 @@ func TestInitInvalidTimezone(t *testing.T) {
|
||||
|
||||
func TestTimeNowAffected(t *testing.T) {
|
||||
// Reset to UTC first
|
||||
Init("UTC")
|
||||
if err := Init("UTC"); err != nil {
|
||||
t.Fatalf("Init failed with UTC: %v", err)
|
||||
}
|
||||
utcNow := time.Now()
|
||||
|
||||
// Switch to Shanghai (UTC+8)
|
||||
Init("Asia/Shanghai")
|
||||
if err := Init("Asia/Shanghai"); err != nil {
|
||||
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||
}
|
||||
shanghaiNow := time.Now()
|
||||
|
||||
// The times should be the same instant, but different timezone representation
|
||||
@@ -58,7 +62,9 @@ func TestTimeNowAffected(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestToday(t *testing.T) {
|
||||
Init("Asia/Shanghai")
|
||||
if err := Init("Asia/Shanghai"); err != nil {
|
||||
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||
}
|
||||
|
||||
today := Today()
|
||||
now := Now()
|
||||
@@ -75,7 +81,9 @@ func TestToday(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStartOfDay(t *testing.T) {
|
||||
Init("Asia/Shanghai")
|
||||
if err := Init("Asia/Shanghai"); err != nil {
|
||||
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||
}
|
||||
|
||||
// Create a time at 15:30:45
|
||||
testTime := time.Date(2024, 6, 15, 15, 30, 45, 123456789, Location())
|
||||
@@ -91,7 +99,9 @@ func TestTruncateVsStartOfDay(t *testing.T) {
|
||||
// This test demonstrates why Truncate(24*time.Hour) can be problematic
|
||||
// and why StartOfDay is more reliable for timezone-aware code
|
||||
|
||||
Init("Asia/Shanghai")
|
||||
if err := Init("Asia/Shanghai"); err != nil {
|
||||
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||
}
|
||||
|
||||
now := Now()
|
||||
|
||||
|
||||
209
backend/internal/pkg/usagestats/usage_log_types.go
Normal file
209
backend/internal/pkg/usagestats/usage_log_types.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package usagestats
|
||||
|
||||
import "time"
|
||||
|
||||
// DashboardStats 仪表盘统计
|
||||
type DashboardStats struct {
|
||||
// 用户统计
|
||||
TotalUsers int64 `json:"total_users"`
|
||||
TodayNewUsers int64 `json:"today_new_users"` // 今日新增用户数
|
||||
ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数
|
||||
|
||||
// API Key 统计
|
||||
TotalApiKeys int64 `json:"total_api_keys"`
|
||||
ActiveApiKeys int64 `json:"active_api_keys"` // 状态为 active 的 API Key 数
|
||||
|
||||
// 账户统计
|
||||
TotalAccounts int64 `json:"total_accounts"`
|
||||
NormalAccounts int64 `json:"normal_accounts"` // 正常账户数 (schedulable=true, status=active)
|
||||
ErrorAccounts int64 `json:"error_accounts"` // 异常账户数 (status=error)
|
||||
RateLimitAccounts int64 `json:"ratelimit_accounts"` // 限流账户数
|
||||
OverloadAccounts int64 `json:"overload_accounts"` // 过载账户数
|
||||
|
||||
// 累计 Token 使用统计
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
|
||||
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"` // 累计标准计费
|
||||
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
|
||||
|
||||
// 今日 Token 使用统计
|
||||
TodayRequests int64 `json:"today_requests"`
|
||||
TodayInputTokens int64 `json:"today_input_tokens"`
|
||||
TodayOutputTokens int64 `json:"today_output_tokens"`
|
||||
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
|
||||
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
|
||||
TodayTokens int64 `json:"today_tokens"`
|
||||
TodayCost float64 `json:"today_cost"` // 今日标准计费
|
||||
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
|
||||
|
||||
// 系统运行统计
|
||||
AverageDurationMs float64 `json:"average_duration_ms"` // 平均响应时间
|
||||
|
||||
// 性能指标
|
||||
Rpm int64 `json:"rpm"` // 近5分钟平均每分钟请求数
|
||||
Tpm int64 `json:"tpm"` // 近5分钟平均每分钟Token数
|
||||
}
|
||||
|
||||
// TrendDataPoint represents a single point in trend data
|
||||
type TrendDataPoint struct {
|
||||
Date string `json:"date"`
|
||||
Requests int64 `json:"requests"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
CacheTokens int64 `json:"cache_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// ModelStat represents usage statistics for a single model
|
||||
type ModelStat struct {
|
||||
Model string `json:"model"`
|
||||
Requests int64 `json:"requests"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// UserUsageTrendPoint represents user usage trend data point
|
||||
type UserUsageTrendPoint struct {
|
||||
Date string `json:"date"`
|
||||
UserID int64 `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// ApiKeyUsageTrendPoint represents API key usage trend data point
|
||||
type ApiKeyUsageTrendPoint struct {
|
||||
Date string `json:"date"`
|
||||
ApiKeyID int64 `json:"api_key_id"`
|
||||
KeyName string `json:"key_name"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
}
|
||||
|
||||
// UserDashboardStats 用户仪表盘统计
|
||||
type UserDashboardStats struct {
|
||||
// API Key 统计
|
||||
TotalApiKeys int64 `json:"total_api_keys"`
|
||||
ActiveApiKeys int64 `json:"active_api_keys"`
|
||||
|
||||
// 累计 Token 使用统计
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
|
||||
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"` // 累计标准计费
|
||||
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
|
||||
|
||||
// 今日 Token 使用统计
|
||||
TodayRequests int64 `json:"today_requests"`
|
||||
TodayInputTokens int64 `json:"today_input_tokens"`
|
||||
TodayOutputTokens int64 `json:"today_output_tokens"`
|
||||
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
|
||||
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
|
||||
TodayTokens int64 `json:"today_tokens"`
|
||||
TodayCost float64 `json:"today_cost"` // 今日标准计费
|
||||
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
|
||||
|
||||
// 性能统计
|
||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||
|
||||
// 性能指标
|
||||
Rpm int64 `json:"rpm"` // 近5分钟平均每分钟请求数
|
||||
Tpm int64 `json:"tpm"` // 近5分钟平均每分钟Token数
|
||||
}
|
||||
|
||||
// UsageLogFilters represents filters for usage log queries
|
||||
type UsageLogFilters struct {
|
||||
UserID int64
|
||||
ApiKeyID int64
|
||||
StartTime *time.Time
|
||||
EndTime *time.Time
|
||||
}
|
||||
|
||||
// UsageStats represents usage statistics
|
||||
type UsageStats struct {
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||
}
|
||||
|
||||
// BatchUserUsageStats represents usage stats for a single user
|
||||
type BatchUserUsageStats struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
TodayActualCost float64 `json:"today_actual_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
}
|
||||
|
||||
// BatchApiKeyUsageStats represents usage stats for a single API key
|
||||
type BatchApiKeyUsageStats struct {
|
||||
ApiKeyID int64 `json:"api_key_id"`
|
||||
TodayActualCost float64 `json:"today_actual_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
}
|
||||
|
||||
// AccountUsageHistory represents daily usage history for an account
|
||||
type AccountUsageHistory struct {
|
||||
Date string `json:"date"`
|
||||
Label string `json:"label"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
ActualCost float64 `json:"actual_cost"`
|
||||
}
|
||||
|
||||
// AccountUsageSummary represents summary statistics for an account
|
||||
type AccountUsageSummary struct {
|
||||
Days int `json:"days"`
|
||||
ActualDaysUsed int `json:"actual_days_used"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
TotalStandardCost float64 `json:"total_standard_cost"`
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
AvgDailyCost float64 `json:"avg_daily_cost"`
|
||||
AvgDailyRequests float64 `json:"avg_daily_requests"`
|
||||
AvgDailyTokens float64 `json:"avg_daily_tokens"`
|
||||
AvgDurationMs float64 `json:"avg_duration_ms"`
|
||||
Today *struct {
|
||||
Date string `json:"date"`
|
||||
Cost float64 `json:"cost"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
} `json:"today"`
|
||||
HighestCostDay *struct {
|
||||
Date string `json:"date"`
|
||||
Label string `json:"label"`
|
||||
Cost float64 `json:"cost"`
|
||||
Requests int64 `json:"requests"`
|
||||
} `json:"highest_cost_day"`
|
||||
HighestRequestDay *struct {
|
||||
Date string `json:"date"`
|
||||
Label string `json:"label"`
|
||||
Requests int64 `json:"requests"`
|
||||
Cost float64 `json:"cost"`
|
||||
} `json:"highest_request_day"`
|
||||
}
|
||||
|
||||
// AccountUsageStatsResponse represents the full usage statistics response for an account
|
||||
type AccountUsageStatsResponse struct {
|
||||
History []AccountUsageHistory `json:"history"`
|
||||
Summary AccountUsageSummary `json:"summary"`
|
||||
Models []ModelStat `json:"models"`
|
||||
}
|
||||
@@ -2,44 +2,68 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type AccountRepository struct {
|
||||
type accountRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewAccountRepository(db *gorm.DB) *AccountRepository {
|
||||
return &AccountRepository{db: db}
|
||||
func NewAccountRepository(db *gorm.DB) service.AccountRepository {
|
||||
return &accountRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *AccountRepository) Create(ctx context.Context, account *model.Account) error {
|
||||
func (r *accountRepository) Create(ctx context.Context, account *model.Account) error {
|
||||
return r.db.WithContext(ctx).Create(account).Error
|
||||
}
|
||||
|
||||
func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) {
|
||||
func (r *accountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) {
|
||||
var account model.Account
|
||||
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups").First(&account, id).Error
|
||||
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&account, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil)
|
||||
}
|
||||
// 填充 GroupIDs 虚拟字段
|
||||
// 填充 GroupIDs 和 Groups 虚拟字段
|
||||
account.GroupIDs = make([]int64, 0, len(account.AccountGroups))
|
||||
account.Groups = make([]*model.Group, 0, len(account.AccountGroups))
|
||||
for _, ag := range account.AccountGroups {
|
||||
account.GroupIDs = append(account.GroupIDs, ag.GroupID)
|
||||
if ag.Group != nil {
|
||||
account.Groups = append(account.Groups, ag.Group)
|
||||
}
|
||||
}
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (r *AccountRepository) Update(ctx context.Context, account *model.Account) error {
|
||||
func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) {
|
||||
if crsAccountID == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var account model.Account
|
||||
err := r.db.WithContext(ctx).Where("extra->>'crs_account_id' = ?", crsAccountID).First(&account).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) Update(ctx context.Context, account *model.Account) error {
|
||||
return r.db.WithContext(ctx).Save(account).Error
|
||||
}
|
||||
|
||||
func (r *AccountRepository) Delete(ctx context.Context, id int64) error {
|
||||
func (r *accountRepository) Delete(ctx context.Context, id int64) error {
|
||||
// 先删除账号与分组的绑定关系
|
||||
if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
|
||||
return err
|
||||
@@ -48,12 +72,12 @@ func (r *AccountRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error
|
||||
}
|
||||
|
||||
func (r *AccountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
|
||||
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", "", "")
|
||||
}
|
||||
|
||||
// ListWithFilters lists accounts with optional filtering by platform, type, status, and search query
|
||||
func (r *AccountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) {
|
||||
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) {
|
||||
var accounts []model.Account
|
||||
var total int64
|
||||
|
||||
@@ -78,15 +102,19 @@ func (r *AccountRepository) ListWithFilters(ctx context.Context, params paginati
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := db.Preload("Proxy").Preload("AccountGroups").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&accounts).Error; err != nil {
|
||||
if err := db.Preload("Proxy").Preload("AccountGroups.Group").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&accounts).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 填充每个 Account 的 GroupIDs 虚拟字段
|
||||
// 填充每个 Account 的虚拟字段(GroupIDs 和 Groups)
|
||||
for i := range accounts {
|
||||
accounts[i].GroupIDs = make([]int64, 0, len(accounts[i].AccountGroups))
|
||||
accounts[i].Groups = make([]*model.Group, 0, len(accounts[i].AccountGroups))
|
||||
for _, ag := range accounts[i].AccountGroups {
|
||||
accounts[i].GroupIDs = append(accounts[i].GroupIDs, ag.GroupID)
|
||||
if ag.Group != nil {
|
||||
accounts[i].Groups = append(accounts[i].Groups, ag.Group)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,7 +131,7 @@ func (r *AccountRepository) ListWithFilters(ctx context.Context, params paginati
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *AccountRepository) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) {
|
||||
func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
err := r.db.WithContext(ctx).
|
||||
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
|
||||
@@ -114,7 +142,7 @@ func (r *AccountRepository) ListByGroup(ctx context.Context, groupID int64) ([]m
|
||||
return accounts, err
|
||||
}
|
||||
|
||||
func (r *AccountRepository) ListActive(ctx context.Context) ([]model.Account, error) {
|
||||
func (r *accountRepository) ListActive(ctx context.Context) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status = ?", model.StatusActive).
|
||||
@@ -124,20 +152,20 @@ func (r *AccountRepository) ListActive(ctx context.Context) ([]model.Account, er
|
||||
return accounts, err
|
||||
}
|
||||
|
||||
func (r *AccountRepository) UpdateLastUsed(ctx context.Context, id int64) error {
|
||||
func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error {
|
||||
now := time.Now()
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Update("last_used_at", now).Error
|
||||
}
|
||||
|
||||
func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
Updates(map[string]any{
|
||||
"status": model.StatusError,
|
||||
"error_message": errorMsg,
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (r *AccountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
|
||||
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
|
||||
ag := &model.AccountGroup{
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
@@ -146,12 +174,12 @@ func (r *AccountRepository) AddToGroup(ctx context.Context, accountID, groupID i
|
||||
return r.db.WithContext(ctx).Create(ag).Error
|
||||
}
|
||||
|
||||
func (r *AccountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
|
||||
func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
|
||||
return r.db.WithContext(ctx).Where("account_id = ? AND group_id = ?", accountID, groupID).
|
||||
Delete(&model.AccountGroup{}).Error
|
||||
}
|
||||
|
||||
func (r *AccountRepository) GetGroups(ctx context.Context, accountID int64) ([]model.Group, error) {
|
||||
func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]model.Group, error) {
|
||||
var groups []model.Group
|
||||
err := r.db.WithContext(ctx).
|
||||
Joins("JOIN account_groups ON account_groups.group_id = groups.id").
|
||||
@@ -160,7 +188,7 @@ func (r *AccountRepository) GetGroups(ctx context.Context, accountID int64) ([]m
|
||||
return groups, err
|
||||
}
|
||||
|
||||
func (r *AccountRepository) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
|
||||
func (r *accountRepository) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("platform = ? AND status = ?", platform, model.StatusActive).
|
||||
@@ -170,7 +198,7 @@ func (r *AccountRepository) ListByPlatform(ctx context.Context, platform string)
|
||||
return accounts, err
|
||||
}
|
||||
|
||||
func (r *AccountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
||||
func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
||||
// 删除现有绑定
|
||||
if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&model.AccountGroup{}).Error; err != nil {
|
||||
return err
|
||||
@@ -193,7 +221,7 @@ func (r *AccountRepository) BindGroups(ctx context.Context, accountID int64, gro
|
||||
}
|
||||
|
||||
// ListSchedulable 获取所有可调度的账号
|
||||
func (r *AccountRepository) ListSchedulable(ctx context.Context) ([]model.Account, error) {
|
||||
func (r *accountRepository) ListSchedulable(ctx context.Context) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
now := time.Now()
|
||||
err := r.db.WithContext(ctx).
|
||||
@@ -207,7 +235,7 @@ func (r *AccountRepository) ListSchedulable(ctx context.Context) ([]model.Accoun
|
||||
}
|
||||
|
||||
// ListSchedulableByGroupID 按组获取可调度的账号
|
||||
func (r *AccountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) {
|
||||
func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
now := time.Now()
|
||||
err := r.db.WithContext(ctx).
|
||||
@@ -222,26 +250,58 @@ func (r *AccountRepository) ListSchedulableByGroupID(ctx context.Context, groupI
|
||||
return accounts, err
|
||||
}
|
||||
|
||||
// ListSchedulableByPlatform 按平台获取可调度的账号
|
||||
func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
now := time.Now()
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("platform = ?", platform).
|
||||
Where("status = ? AND schedulable = ?", model.StatusActive, true).
|
||||
Where("(overload_until IS NULL OR overload_until <= ?)", now).
|
||||
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now).
|
||||
Preload("Proxy").
|
||||
Order("priority ASC").
|
||||
Find(&accounts).Error
|
||||
return accounts, err
|
||||
}
|
||||
|
||||
// ListSchedulableByGroupIDAndPlatform 按组和平台获取可调度的账号
|
||||
func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
now := time.Now()
|
||||
err := r.db.WithContext(ctx).
|
||||
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
|
||||
Where("account_groups.group_id = ?", groupID).
|
||||
Where("accounts.platform = ?", platform).
|
||||
Where("accounts.status = ? AND accounts.schedulable = ?", model.StatusActive, true).
|
||||
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
|
||||
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
|
||||
Preload("Proxy").
|
||||
Order("account_groups.priority ASC, accounts.priority ASC").
|
||||
Find(&accounts).Error
|
||||
return accounts, err
|
||||
}
|
||||
|
||||
// SetRateLimited 标记账号为限流状态(429)
|
||||
func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
now := time.Now()
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
Updates(map[string]any{
|
||||
"rate_limited_at": now,
|
||||
"rate_limit_reset_at": resetAt,
|
||||
}).Error
|
||||
}
|
||||
|
||||
// SetOverloaded 标记账号为过载状态(529)
|
||||
func (r *AccountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
Update("overload_until", until).Error
|
||||
}
|
||||
|
||||
// ClearRateLimit 清除账号的限流状态
|
||||
func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error {
|
||||
func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
Updates(map[string]any{
|
||||
"rate_limited_at": nil,
|
||||
"rate_limit_reset_at": nil,
|
||||
"overload_until": nil,
|
||||
@@ -249,8 +309,8 @@ func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error
|
||||
}
|
||||
|
||||
// UpdateSessionWindow 更新账号的5小时时间窗口信息
|
||||
func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
updates := map[string]interface{}{
|
||||
func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
updates := map[string]any{
|
||||
"session_window_status": status,
|
||||
}
|
||||
if start != nil {
|
||||
@@ -263,7 +323,79 @@ func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
|
||||
}
|
||||
|
||||
// SetSchedulable 设置账号的调度开关
|
||||
func (r *AccountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
Update("schedulable", schedulable).Error
|
||||
}
|
||||
|
||||
// UpdateExtra updates specific fields in account's Extra JSONB field
|
||||
// It merges the updates into existing Extra data without overwriting other fields
|
||||
func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
|
||||
if len(updates) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get current account to preserve existing Extra data
|
||||
var account model.Account
|
||||
if err := r.db.WithContext(ctx).Select("extra").Where("id = ?", id).First(&account).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize Extra if nil
|
||||
if account.Extra == nil {
|
||||
account.Extra = make(model.JSONB)
|
||||
}
|
||||
|
||||
// Merge updates into existing Extra
|
||||
for k, v := range updates {
|
||||
account.Extra[k] = v
|
||||
}
|
||||
|
||||
// Save updated Extra
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
Update("extra", account.Extra).Error
|
||||
}
|
||||
|
||||
// BulkUpdate updates multiple accounts with the provided fields.
|
||||
// It merges credentials/extra JSONB fields instead of overwriting them.
|
||||
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
|
||||
if len(ids) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
updateMap := map[string]any{}
|
||||
|
||||
if updates.Name != nil {
|
||||
updateMap["name"] = *updates.Name
|
||||
}
|
||||
if updates.ProxyID != nil {
|
||||
updateMap["proxy_id"] = updates.ProxyID
|
||||
}
|
||||
if updates.Concurrency != nil {
|
||||
updateMap["concurrency"] = *updates.Concurrency
|
||||
}
|
||||
if updates.Priority != nil {
|
||||
updateMap["priority"] = *updates.Priority
|
||||
}
|
||||
if updates.Status != nil {
|
||||
updateMap["status"] = *updates.Status
|
||||
}
|
||||
if len(updates.Credentials) > 0 {
|
||||
updateMap["credentials"] = gorm.Expr("COALESCE(credentials,'{}') || ?", updates.Credentials)
|
||||
}
|
||||
if len(updates.Extra) > 0 {
|
||||
updateMap["extra"] = gorm.Expr("COALESCE(extra,'{}') || ?", updates.Extra)
|
||||
}
|
||||
|
||||
if len(updateMap) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).
|
||||
Model(&model.Account{}).
|
||||
Where("id IN ?", ids).
|
||||
Clauses(clause.Returning{}).
|
||||
Updates(updateMap)
|
||||
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
580
backend/internal/repository/account_repo_integration_test.go
Normal file
580
backend/internal/repository/account_repo_integration_test.go
Normal file
@@ -0,0 +1,580 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type AccountRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
repo *accountRepository
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewAccountRepository(s.db).(*accountRepository)
|
||||
}
|
||||
|
||||
func TestAccountRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(AccountRepoSuite))
|
||||
}
|
||||
|
||||
// --- Create / GetByID / Update / Delete ---
|
||||
|
||||
func (s *AccountRepoSuite) TestCreate() {
|
||||
account := &model.Account{
|
||||
Name: "test-create",
|
||||
Platform: model.PlatformAnthropic,
|
||||
Type: model.AccountTypeOAuth,
|
||||
Status: model.StatusActive,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, account)
|
||||
s.Require().NoError(err, "Create")
|
||||
s.Require().NotZero(account.ID, "expected ID to be set")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("test-create", got.Name)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdate() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "original"})
|
||||
|
||||
account.Name = "updated"
|
||||
err := s.repo.Update(s.ctx, account)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetByID after update")
|
||||
s.Require().Equal("updated", got.Name)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestDelete() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "to-delete"})
|
||||
|
||||
err := s.repo.Delete(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
|
||||
_, err = s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestDelete_WithGroupBindings() {
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-del"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-del"})
|
||||
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
|
||||
|
||||
err := s.repo.Delete(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "Delete should cascade remove bindings")
|
||||
|
||||
var count int64
|
||||
s.db.Model(&model.AccountGroup{}).Where("account_id = ?", account.ID).Count(&count)
|
||||
s.Require().Zero(count, "expected bindings to be removed")
|
||||
}
|
||||
|
||||
// --- List / ListWithFilters ---
|
||||
|
||||
func (s *AccountRepoSuite) TestList() {
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc1"})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc2"})
|
||||
|
||||
accounts, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List")
|
||||
s.Require().Len(accounts, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(db *gorm.DB)
|
||||
platform string
|
||||
accType string
|
||||
status string
|
||||
search string
|
||||
wantCount int
|
||||
validate func(accounts []model.Account)
|
||||
}{
|
||||
{
|
||||
name: "filter_by_platform",
|
||||
setup: func(db *gorm.DB) {
|
||||
mustCreateAccount(s.T(), db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic})
|
||||
mustCreateAccount(s.T(), db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI})
|
||||
},
|
||||
platform: model.PlatformOpenAI,
|
||||
wantCount: 1,
|
||||
validate: func(accounts []model.Account) {
|
||||
s.Require().Equal(model.PlatformOpenAI, accounts[0].Platform)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter_by_type",
|
||||
setup: func(db *gorm.DB) {
|
||||
mustCreateAccount(s.T(), db, &model.Account{Name: "t1", Type: model.AccountTypeOAuth})
|
||||
mustCreateAccount(s.T(), db, &model.Account{Name: "t2", Type: model.AccountTypeApiKey})
|
||||
},
|
||||
accType: model.AccountTypeApiKey,
|
||||
wantCount: 1,
|
||||
validate: func(accounts []model.Account) {
|
||||
s.Require().Equal(model.AccountTypeApiKey, accounts[0].Type)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter_by_status",
|
||||
setup: func(db *gorm.DB) {
|
||||
mustCreateAccount(s.T(), db, &model.Account{Name: "s1", Status: model.StatusActive})
|
||||
mustCreateAccount(s.T(), db, &model.Account{Name: "s2", Status: model.StatusDisabled})
|
||||
},
|
||||
status: model.StatusDisabled,
|
||||
wantCount: 1,
|
||||
validate: func(accounts []model.Account) {
|
||||
s.Require().Equal(model.StatusDisabled, accounts[0].Status)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter_by_search",
|
||||
setup: func(db *gorm.DB) {
|
||||
mustCreateAccount(s.T(), db, &model.Account{Name: "alpha-account"})
|
||||
mustCreateAccount(s.T(), db, &model.Account{Name: "beta-account"})
|
||||
},
|
||||
search: "alpha",
|
||||
wantCount: 1,
|
||||
validate: func(accounts []model.Account) {
|
||||
s.Require().Contains(accounts[0].Name, "alpha")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
// 每个 case 重新获取隔离资源
|
||||
db := testTx(s.T())
|
||||
repo := NewAccountRepository(db).(*accountRepository)
|
||||
ctx := context.Background()
|
||||
|
||||
tt.setup(db)
|
||||
|
||||
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, tt.wantCount)
|
||||
if tt.validate != nil {
|
||||
tt.validate(accounts)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- ListByGroup / ListActive / ListByPlatform ---
|
||||
|
||||
func (s *AccountRepoSuite) TestListByGroup() {
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"})
|
||||
acc1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Status: model.StatusActive})
|
||||
acc2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Status: model.StatusActive})
|
||||
mustBindAccountToGroup(s.T(), s.db, acc1.ID, group.ID, 2)
|
||||
mustBindAccountToGroup(s.T(), s.db, acc2.ID, group.ID, 1)
|
||||
|
||||
accounts, err := s.repo.ListByGroup(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "ListByGroup")
|
||||
s.Require().Len(accounts, 2)
|
||||
// Should be ordered by priority
|
||||
s.Require().Equal(acc2.ID, accounts[0].ID, "expected acc2 first (priority=1)")
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListActive() {
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "active1", Status: model.StatusActive})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "inactive1", Status: model.StatusDisabled})
|
||||
|
||||
accounts, err := s.repo.ListActive(s.ctx)
|
||||
s.Require().NoError(err, "ListActive")
|
||||
s.Require().Len(accounts, 1)
|
||||
s.Require().Equal("active1", accounts[0].Name)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListByPlatform() {
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "p1", Platform: model.PlatformAnthropic, Status: model.StatusActive})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "p2", Platform: model.PlatformOpenAI, Status: model.StatusActive})
|
||||
|
||||
accounts, err := s.repo.ListByPlatform(s.ctx, model.PlatformAnthropic)
|
||||
s.Require().NoError(err, "ListByPlatform")
|
||||
s.Require().Len(accounts, 1)
|
||||
s.Require().Equal(model.PlatformAnthropic, accounts[0].Platform)
|
||||
}
|
||||
|
||||
// --- Preload and VirtualFields ---
|
||||
|
||||
func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
|
||||
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"})
|
||||
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{
|
||||
Name: "acc1",
|
||||
ProxyID: &proxy.ID,
|
||||
})
|
||||
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().NotNil(got.Proxy, "expected Proxy preload")
|
||||
s.Require().Equal(proxy.ID, got.Proxy.ID)
|
||||
s.Require().Len(got.GroupIDs, 1, "expected GroupIDs to be populated")
|
||||
s.Require().Equal(group.ID, got.GroupIDs[0])
|
||||
s.Require().Len(got.Groups, 1, "expected Groups to be populated")
|
||||
s.Require().Equal(group.ID, got.Groups[0].ID)
|
||||
|
||||
accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc")
|
||||
s.Require().NoError(err, "ListWithFilters")
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
s.Require().Len(accounts, 1)
|
||||
s.Require().NotNil(accounts[0].Proxy, "expected Proxy preload in list")
|
||||
s.Require().Equal(proxy.ID, accounts[0].Proxy.ID)
|
||||
s.Require().Len(accounts[0].GroupIDs, 1, "expected GroupIDs in list")
|
||||
s.Require().Equal(group.ID, accounts[0].GroupIDs[0])
|
||||
}
|
||||
|
||||
// --- GroupBinding / AddToGroup / RemoveFromGroup / BindGroups / GetGroups ---
|
||||
|
||||
func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
|
||||
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"})
|
||||
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc"})
|
||||
|
||||
s.Require().NoError(s.repo.AddToGroup(s.ctx, account.ID, g1.ID, 10), "AddToGroup")
|
||||
groups, err := s.repo.GetGroups(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetGroups")
|
||||
s.Require().Len(groups, 1, "expected 1 group")
|
||||
s.Require().Equal(g1.ID, groups[0].ID)
|
||||
|
||||
s.Require().NoError(s.repo.RemoveFromGroup(s.ctx, account.ID, g1.ID), "RemoveFromGroup")
|
||||
groups, err = s.repo.GetGroups(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetGroups after remove")
|
||||
s.Require().Empty(groups, "expected 0 groups after remove")
|
||||
|
||||
s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{g1.ID, g2.ID}), "BindGroups")
|
||||
groups, err = s.repo.GetGroups(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetGroups after bind")
|
||||
s.Require().Len(groups, 2, "expected 2 groups after bind")
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-empty"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-empty"})
|
||||
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
|
||||
|
||||
s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{}), "BindGroups empty")
|
||||
|
||||
groups, err := s.repo.GetGroups(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Empty(groups, "expected 0 groups after binding empty list")
|
||||
}
|
||||
|
||||
// --- Schedulable ---
|
||||
|
||||
func (s *AccountRepoSuite) TestListSchedulable() {
|
||||
now := time.Now()
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sched"})
|
||||
|
||||
okAcc := mustCreateAccount(s.T(), s.db, &model.Account{Name: "ok", Schedulable: true})
|
||||
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
|
||||
|
||||
future := now.Add(10 * time.Minute)
|
||||
overloaded := mustCreateAccount(s.T(), s.db, &model.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
|
||||
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
|
||||
|
||||
sched, err := s.repo.ListSchedulable(s.ctx)
|
||||
s.Require().NoError(err, "ListSchedulable")
|
||||
ids := idsOfAccounts(sched)
|
||||
s.Require().Contains(ids, okAcc.ID)
|
||||
s.Require().NotContains(ids, overloaded.ID)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_StatusUpdates() {
|
||||
now := time.Now()
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sched"})
|
||||
|
||||
okAcc := mustCreateAccount(s.T(), s.db, &model.Account{Name: "ok", Schedulable: true})
|
||||
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
|
||||
|
||||
future := now.Add(10 * time.Minute)
|
||||
overloaded := mustCreateAccount(s.T(), s.db, &model.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
|
||||
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
|
||||
|
||||
rateLimited := mustCreateAccount(s.T(), s.db, &model.Account{Name: "rl", Schedulable: true})
|
||||
mustBindAccountToGroup(s.T(), s.db, rateLimited.ID, group.ID, 1)
|
||||
s.Require().NoError(s.repo.SetRateLimited(s.ctx, rateLimited.ID, now.Add(10*time.Minute)), "SetRateLimited")
|
||||
|
||||
s.Require().NoError(s.repo.SetError(s.ctx, overloaded.ID, "boom"), "SetError")
|
||||
|
||||
sched, err := s.repo.ListSchedulableByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "ListSchedulableByGroupID")
|
||||
s.Require().Len(sched, 1, "expected only ok account schedulable")
|
||||
s.Require().Equal(okAcc.ID, sched[0].ID)
|
||||
|
||||
s.Require().NoError(s.repo.ClearRateLimit(s.ctx, rateLimited.ID), "ClearRateLimit")
|
||||
sched2, err := s.repo.ListSchedulableByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "ListSchedulableByGroupID after ClearRateLimit")
|
||||
s.Require().Len(sched2, 2, "expected 2 schedulable accounts after ClearRateLimit")
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListSchedulableByPlatform() {
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic, Schedulable: true})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI, Schedulable: true})
|
||||
|
||||
accounts, err := s.repo.ListSchedulableByPlatform(s.ctx, model.PlatformAnthropic)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, 1)
|
||||
s.Require().Equal(model.PlatformAnthropic, accounts[0].Platform)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sp"})
|
||||
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic, Schedulable: true})
|
||||
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI, Schedulable: true})
|
||||
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1)
|
||||
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2)
|
||||
|
||||
accounts, err := s.repo.ListSchedulableByGroupIDAndPlatform(s.ctx, group.ID, model.PlatformAnthropic)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, 1)
|
||||
s.Require().Equal(a1.ID, accounts[0].ID)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestSetSchedulable() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-sched", Schedulable: true})
|
||||
|
||||
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().False(got.Schedulable)
|
||||
}
|
||||
|
||||
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---
|
||||
|
||||
func (s *AccountRepoSuite) TestSetOverloaded() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-over"})
|
||||
until := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(got.OverloadUntil)
|
||||
s.Require().WithinDuration(until, *got.OverloadUntil, time.Second)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestSetRateLimited() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-rl"})
|
||||
resetAt := time.Date(2025, 6, 15, 14, 0, 0, 0, time.UTC)
|
||||
|
||||
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, resetAt))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(got.RateLimitedAt)
|
||||
s.Require().NotNil(got.RateLimitResetAt)
|
||||
s.Require().WithinDuration(resetAt, *got.RateLimitResetAt, time.Second)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestClearRateLimit() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-clear"})
|
||||
until := time.Now().Add(1 * time.Hour)
|
||||
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
|
||||
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, until))
|
||||
|
||||
s.Require().NoError(s.repo.ClearRateLimit(s.ctx, account.ID))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Nil(got.RateLimitedAt)
|
||||
s.Require().Nil(got.RateLimitResetAt)
|
||||
s.Require().Nil(got.OverloadUntil)
|
||||
}
|
||||
|
||||
// --- UpdateLastUsed ---
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateLastUsed() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-used"})
|
||||
s.Require().Nil(account.LastUsedAt)
|
||||
|
||||
s.Require().NoError(s.repo.UpdateLastUsed(s.ctx, account.ID))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(got.LastUsedAt)
|
||||
}
|
||||
|
||||
// --- SetError ---
|
||||
|
||||
func (s *AccountRepoSuite) TestSetError() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-err", Status: model.StatusActive})
|
||||
|
||||
s.Require().NoError(s.repo.SetError(s.ctx, account.ID, "something went wrong"))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(model.StatusError, got.Status)
|
||||
s.Require().Equal("something went wrong", got.ErrorMessage)
|
||||
}
|
||||
|
||||
// --- UpdateSessionWindow ---
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateSessionWindow() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-win"})
|
||||
start := time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC)
|
||||
end := time.Date(2025, 6, 15, 15, 0, 0, 0, time.UTC)
|
||||
|
||||
s.Require().NoError(s.repo.UpdateSessionWindow(s.ctx, account.ID, &start, &end, "active"))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(got.SessionWindowStart)
|
||||
s.Require().NotNil(got.SessionWindowEnd)
|
||||
s.Require().Equal("active", got.SessionWindowStatus)
|
||||
}
|
||||
|
||||
// --- UpdateExtra ---
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{
|
||||
Name: "acc-extra",
|
||||
Extra: model.JSONB{"a": "1"},
|
||||
})
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"b": "2"}), "UpdateExtra")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("1", got.Extra["a"])
|
||||
s.Require().Equal("2", got.Extra["b"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_EmptyUpdates() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-extra-empty"})
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{}))
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-nil-extra", Extra: nil})
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"key": "val"}))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal("val", got.Extra["key"])
|
||||
}
|
||||
|
||||
// --- GetByCRSAccountID ---
|
||||
|
||||
func (s *AccountRepoSuite) TestGetByCRSAccountID() {
|
||||
crsID := "crs-12345"
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{
|
||||
Name: "acc-crs",
|
||||
Extra: model.JSONB{"crs_account_id": crsID},
|
||||
})
|
||||
|
||||
got, err := s.repo.GetByCRSAccountID(s.ctx, crsID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(got)
|
||||
s.Require().Equal("acc-crs", got.Name)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestGetByCRSAccountID_NotFound() {
|
||||
got, err := s.repo.GetByCRSAccountID(s.ctx, "non-existent")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Nil(got)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestGetByCRSAccountID_EmptyString() {
|
||||
got, err := s.repo.GetByCRSAccountID(s.ctx, "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Nil(got)
|
||||
}
|
||||
|
||||
// --- BulkUpdate ---
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate() {
|
||||
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk1", Priority: 1})
|
||||
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk2", Priority: 1})
|
||||
|
||||
newPriority := 99
|
||||
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID, a2.ID}, service.AccountBulkUpdate{
|
||||
Priority: &newPriority,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.Require().GreaterOrEqual(affected, int64(1), "expected at least one affected row")
|
||||
|
||||
got1, _ := s.repo.GetByID(s.ctx, a1.ID)
|
||||
got2, _ := s.repo.GetByID(s.ctx, a2.ID)
|
||||
s.Require().Equal(99, got1.Priority)
|
||||
s.Require().Equal(99, got2.Priority)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
|
||||
a1 := mustCreateAccount(s.T(), s.db, &model.Account{
|
||||
Name: "bulk-cred",
|
||||
Credentials: model.JSONB{"existing": "value"},
|
||||
})
|
||||
|
||||
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
|
||||
Credentials: model.JSONB{"new_key": "new_value"},
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
got, _ := s.repo.GetByID(s.ctx, a1.ID)
|
||||
s.Require().Equal("value", got.Credentials["existing"])
|
||||
s.Require().Equal("new_value", got.Credentials["new_key"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate_MergeExtra() {
|
||||
a1 := mustCreateAccount(s.T(), s.db, &model.Account{
|
||||
Name: "bulk-extra",
|
||||
Extra: model.JSONB{"existing": "val"},
|
||||
})
|
||||
|
||||
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
|
||||
Extra: model.JSONB{"new_key": "new_val"},
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
got, _ := s.repo.GetByID(s.ctx, a1.ID)
|
||||
s.Require().Equal("val", got.Extra["existing"])
|
||||
s.Require().Equal("new_val", got.Extra["new_key"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() {
|
||||
affected, err := s.repo.BulkUpdate(s.ctx, []int64{}, service.AccountBulkUpdate{})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(affected)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate_EmptyUpdates() {
|
||||
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk-empty"})
|
||||
|
||||
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(affected)
|
||||
}
|
||||
|
||||
func idsOfAccounts(accounts []model.Account) []int64 {
|
||||
out := make([]int64, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
out = append(out, accounts[i].ID)
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -5,8 +5,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
@@ -19,7 +18,7 @@ type apiKeyCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewApiKeyCache(rdb *redis.Client) ports.ApiKeyCache {
|
||||
func NewApiKeyCache(rdb *redis.Client) service.ApiKeyCache {
|
||||
return &apiKeyCache{rdb: rdb}
|
||||
}
|
||||
|
||||
|
||||
125
backend/internal/repository/api_key_cache_integration_test.go
Normal file
125
backend/internal/repository/api_key_cache_integration_test.go
Normal file
@@ -0,0 +1,125 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ApiKeyCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
}
|
||||
|
||||
func (s *ApiKeyCacheSuite) TestCreateAttemptCount() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
|
||||
}{
|
||||
{
|
||||
name: "missing_key_returns_redis_nil",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||
userID := int64(1)
|
||||
|
||||
_, err := cache.GetCreateAttemptCount(ctx, userID)
|
||||
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing key")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "increment_increases_count_and_sets_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||
userID := int64(1)
|
||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID), "IncrementCreateAttemptCount")
|
||||
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID), "IncrementCreateAttemptCount 2")
|
||||
|
||||
count, err := cache.GetCreateAttemptCount(ctx, userID)
|
||||
require.NoError(s.T(), err, "GetCreateAttemptCount")
|
||||
require.Equal(s.T(), 2, count, "count mismatch")
|
||||
|
||||
ttl, err := rdb.TTL(ctx, key).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, apiKeyRateLimitDuration)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "delete_removes_key",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||
userID := int64(1)
|
||||
|
||||
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID))
|
||||
require.NoError(s.T(), cache.DeleteCreateAttemptCount(ctx, userID), "DeleteCreateAttemptCount")
|
||||
|
||||
_, err := cache.GetCreateAttemptCount(ctx, userID)
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after delete")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
// 每个 case 重新获取隔离资源
|
||||
rdb := testRedis(s.T())
|
||||
cache := &apiKeyCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
tt.fn(ctx, rdb, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ApiKeyCacheSuite) TestDailyUsage() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
|
||||
}{
|
||||
{
|
||||
name: "increment_increases_count",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||
dailyKey := "daily:sk-test"
|
||||
|
||||
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey), "IncrementDailyUsage")
|
||||
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey), "IncrementDailyUsage 2")
|
||||
|
||||
n, err := rdb.Get(ctx, dailyKey).Int()
|
||||
require.NoError(s.T(), err, "Get dailyKey")
|
||||
require.Equal(s.T(), 2, n, "expected daily usage=2")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set_expiry_sets_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||
dailyKey := "daily:sk-test-expiry"
|
||||
|
||||
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey))
|
||||
require.NoError(s.T(), cache.SetDailyUsageExpiry(ctx, dailyKey, 1*time.Hour), "SetDailyUsageExpiry")
|
||||
|
||||
ttl, err := rdb.TTL(ctx, dailyKey).Result()
|
||||
require.NoError(s.T(), err, "TTL dailyKey")
|
||||
require.Greater(s.T(), ttl, time.Duration(0), "expected ttl > 0")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := &apiKeyCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
tt.fn(ctx, rdb, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApiKeyCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(ApiKeyCacheSuite))
|
||||
}
|
||||
@@ -2,51 +2,55 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ApiKeyRepository struct {
|
||||
type apiKeyRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewApiKeyRepository(db *gorm.DB) *ApiKeyRepository {
|
||||
return &ApiKeyRepository{db: db}
|
||||
func NewApiKeyRepository(db *gorm.DB) service.ApiKeyRepository {
|
||||
return &apiKeyRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) Create(ctx context.Context, key *model.ApiKey) error {
|
||||
return r.db.WithContext(ctx).Create(key).Error
|
||||
func (r *apiKeyRepository) Create(ctx context.Context, key *model.ApiKey) error {
|
||||
err := r.db.WithContext(ctx).Create(key).Error
|
||||
return translatePersistenceError(err, nil, service.ErrApiKeyExists)
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
|
||||
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
|
||||
var key model.ApiKey
|
||||
err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&key, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
|
||||
}
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) {
|
||||
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) {
|
||||
var apiKey model.ApiKey
|
||||
err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&apiKey).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
|
||||
}
|
||||
return &apiKey, nil
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) Update(ctx context.Context, key *model.ApiKey) error {
|
||||
func (r *apiKeyRepository) Update(ctx context.Context, key *model.ApiKey) error {
|
||||
return r.db.WithContext(ctx).Model(key).Select("name", "group_id", "status", "updated_at").Updates(key).Error
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) Delete(ctx context.Context, id int64) error {
|
||||
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
|
||||
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
|
||||
var keys []model.ApiKey
|
||||
var total int64
|
||||
|
||||
@@ -73,19 +77,19 @@ func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("key = ?", key).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
|
||||
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
|
||||
var keys []model.ApiKey
|
||||
var total int64
|
||||
|
||||
@@ -113,7 +117,7 @@ func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
|
||||
}
|
||||
|
||||
// SearchApiKeys searches API keys by user ID and/or keyword (name)
|
||||
func (r *ApiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) {
|
||||
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) {
|
||||
var keys []model.ApiKey
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.ApiKey{})
|
||||
@@ -135,7 +139,7 @@ func (r *ApiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
|
||||
}
|
||||
|
||||
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
|
||||
func (r *ApiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
result := r.db.WithContext(ctx).Model(&model.ApiKey{}).
|
||||
Where("group_id = ?", groupID).
|
||||
Update("group_id", nil)
|
||||
@@ -143,7 +147,7 @@ func (r *ApiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID in
|
||||
}
|
||||
|
||||
// CountByGroupID 获取分组的 API Key 数量
|
||||
func (r *ApiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID).Count(&count).Error
|
||||
return count, err
|
||||
|
||||
355
backend/internal/repository/api_key_repo_integration_test.go
Normal file
355
backend/internal/repository/api_key_repo_integration_test.go
Normal file
@@ -0,0 +1,355 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ApiKeyRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
repo *apiKeyRepository
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewApiKeyRepository(s.db).(*apiKeyRepository)
|
||||
}
|
||||
|
||||
func TestApiKeyRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(ApiKeyRepoSuite))
|
||||
}
|
||||
|
||||
// --- Create / GetByID / GetByKey ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCreate() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "create@test.com"})
|
||||
|
||||
key := &model.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-create-test",
|
||||
Name: "Test Key",
|
||||
Status: model.StatusActive,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, key)
|
||||
s.Require().NoError(err, "Create")
|
||||
s.Require().NotZero(key.ID, "expected ID to be set")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("sk-create-test", got.Key)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestGetByKey() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "getbykey@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-key"})
|
||||
|
||||
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-getbykey",
|
||||
Name: "My Key",
|
||||
GroupID: &group.ID,
|
||||
Status: model.StatusActive,
|
||||
})
|
||||
|
||||
got, err := s.repo.GetByKey(s.ctx, key.Key)
|
||||
s.Require().NoError(err, "GetByKey")
|
||||
s.Require().Equal(key.ID, got.ID)
|
||||
s.Require().NotNil(got.User, "expected User preload")
|
||||
s.Require().Equal(user.ID, got.User.ID)
|
||||
s.Require().NotNil(got.Group, "expected Group preload")
|
||||
s.Require().Equal(group.ID, got.Group.ID)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
|
||||
_, err := s.repo.GetByKey(s.ctx, "non-existent-key")
|
||||
s.Require().Error(err, "expected error for non-existent key")
|
||||
}
|
||||
|
||||
// --- Update ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestUpdate() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com"})
|
||||
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-update",
|
||||
Name: "Original",
|
||||
Status: model.StatusActive,
|
||||
})
|
||||
|
||||
key.Name = "Renamed"
|
||||
key.Status = model.StatusDisabled
|
||||
err := s.repo.Update(s.ctx, key)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().NoError(err, "GetByID after update")
|
||||
s.Require().Equal("sk-update", got.Key, "Update should not change key")
|
||||
s.Require().Equal(user.ID, got.UserID, "Update should not change user_id")
|
||||
s.Require().Equal("Renamed", got.Name)
|
||||
s.Require().Equal(model.StatusDisabled, got.Status)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "cleargroup@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-clear"})
|
||||
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-clear-group",
|
||||
Name: "Group Key",
|
||||
GroupID: &group.ID,
|
||||
})
|
||||
|
||||
key.GroupID = nil
|
||||
err := s.repo.Update(s.ctx, key)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Nil(got.GroupID, "expected GroupID to be cleared")
|
||||
}
|
||||
|
||||
// --- Delete ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestDelete() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
|
||||
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-delete",
|
||||
Name: "Delete Me",
|
||||
})
|
||||
|
||||
err := s.repo.Delete(s.ctx, key.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
|
||||
_, err = s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
}
|
||||
|
||||
// --- ListByUserID / CountByUserID ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestListByUserID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyuser@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-list-1", Name: "Key 1"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-list-2", Name: "Key 2"})
|
||||
|
||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByUserID")
|
||||
s.Require().Len(keys, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "paging@test.com"})
|
||||
for i := 0; i < 5; i++ {
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-page-" + string(rune('a'+i)),
|
||||
Name: "Key",
|
||||
})
|
||||
}
|
||||
|
||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(keys, 2)
|
||||
s.Require().Equal(int64(5), page.Total)
|
||||
s.Require().Equal(3, page.Pages)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCountByUserID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "count@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-count-1", Name: "K1"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-count-2", Name: "K2"})
|
||||
|
||||
count, err := s.repo.CountByUserID(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "CountByUserID")
|
||||
s.Require().Equal(int64(2), count)
|
||||
}
|
||||
|
||||
// --- ListByGroupID / CountByGroupID ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestListByGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbygroup@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"})
|
||||
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-1", Name: "K1", GroupID: &group.ID})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-2", Name: "K2", GroupID: &group.ID})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-3", Name: "K3"}) // no group
|
||||
|
||||
keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByGroupID")
|
||||
s.Require().Len(keys, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
// User preloaded
|
||||
s.Require().NotNil(keys[0].User)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCountByGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "countgroup@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"})
|
||||
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-gc-1", Name: "K1", GroupID: &group.ID})
|
||||
|
||||
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "CountByGroupID")
|
||||
s.Require().Equal(int64(1), count)
|
||||
}
|
||||
|
||||
// --- ExistsByKey ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestExistsByKey() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-exists", Name: "K"})
|
||||
|
||||
exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists")
|
||||
s.Require().NoError(err, "ExistsByKey")
|
||||
s.Require().True(exists)
|
||||
|
||||
notExists, err := s.repo.ExistsByKey(s.ctx, "sk-not-exists")
|
||||
s.Require().NoError(err)
|
||||
s.Require().False(notExists)
|
||||
}
|
||||
|
||||
// --- SearchApiKeys ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "search@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-search-1", Name: "Production Key"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-search-2", Name: "Development Key"})
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10)
|
||||
s.Require().NoError(err, "SearchApiKeys")
|
||||
s.Require().Len(found, 1)
|
||||
s.Require().Contains(found[0].Name, "Production")
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "searchnokw@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nk-1", Name: "K1"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nk-2", Name: "K2"})
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(found, 2)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "searchnouid@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nu-1", Name: "TestKey"})
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(found, 1)
|
||||
}
|
||||
|
||||
// --- ClearGroupIDByGroupID ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "cleargrp@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-clear-bulk"})
|
||||
|
||||
k1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-1", Name: "K1", GroupID: &group.ID})
|
||||
k2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-2", Name: "K2", GroupID: &group.ID})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-3", Name: "K3"}) // no group
|
||||
|
||||
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "ClearGroupIDByGroupID")
|
||||
s.Require().Equal(int64(2), affected)
|
||||
|
||||
got1, _ := s.repo.GetByID(s.ctx, k1.ID)
|
||||
got2, _ := s.repo.GetByID(s.ctx, k2.ID)
|
||||
s.Require().Nil(got1.GroupID)
|
||||
s.Require().Nil(got2.GroupID)
|
||||
|
||||
count, _ := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
s.Require().Zero(count)
|
||||
}
|
||||
|
||||
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "k@example.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-k"})
|
||||
|
||||
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-test-1",
|
||||
Name: "My Key",
|
||||
GroupID: &group.ID,
|
||||
Status: model.StatusActive,
|
||||
})
|
||||
|
||||
got, err := s.repo.GetByKey(s.ctx, key.Key)
|
||||
s.Require().NoError(err, "GetByKey")
|
||||
s.Require().Equal(key.ID, got.ID)
|
||||
s.Require().NotNil(got.User)
|
||||
s.Require().Equal(user.ID, got.User.ID)
|
||||
s.Require().NotNil(got.Group)
|
||||
s.Require().Equal(group.ID, got.Group.ID)
|
||||
|
||||
key.Name = "Renamed"
|
||||
key.Status = model.StatusDisabled
|
||||
key.GroupID = nil
|
||||
s.Require().NoError(s.repo.Update(s.ctx, key), "Update")
|
||||
|
||||
got2, err := s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("sk-test-1", got2.Key, "Update should not change key")
|
||||
s.Require().Equal(user.ID, got2.UserID, "Update should not change user_id")
|
||||
s.Require().Equal("Renamed", got2.Name)
|
||||
s.Require().Equal(model.StatusDisabled, got2.Status)
|
||||
s.Require().Nil(got2.GroupID)
|
||||
|
||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByUserID")
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
s.Require().Len(keys, 1)
|
||||
|
||||
exists, err := s.repo.ExistsByKey(s.ctx, "sk-test-1")
|
||||
s.Require().NoError(err, "ExistsByKey")
|
||||
s.Require().True(exists, "expected key to exist")
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "renam", 10)
|
||||
s.Require().NoError(err, "SearchApiKeys")
|
||||
s.Require().Len(found, 1)
|
||||
s.Require().Equal(key.ID, found[0].ID)
|
||||
|
||||
// ClearGroupIDByGroupID
|
||||
k2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-test-2",
|
||||
Name: "Group Key",
|
||||
GroupID: &group.ID,
|
||||
})
|
||||
|
||||
countBefore, err := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "CountByGroupID")
|
||||
s.Require().Equal(int64(1), countBefore, "expected 1 key in group before clear")
|
||||
|
||||
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "ClearGroupIDByGroupID")
|
||||
s.Require().Equal(int64(1), affected, "expected 1 affected row")
|
||||
|
||||
got3, err := s.repo.GetByID(s.ctx, k2.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Nil(got3.GroupID, "expected GroupID cleared")
|
||||
|
||||
countAfter, err := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "CountByGroupID after clear")
|
||||
s.Require().Equal(int64(0), countAfter, "expected 0 keys in group after clear")
|
||||
}
|
||||
@@ -8,8 +8,7 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
@@ -58,7 +57,7 @@ type billingCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewBillingCache(rdb *redis.Client) ports.BillingCache {
|
||||
func NewBillingCache(rdb *redis.Client) service.BillingCache {
|
||||
return &billingCache{rdb: rdb}
|
||||
}
|
||||
|
||||
@@ -90,7 +89,7 @@ func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*ports.SubscriptionCacheData, error) {
|
||||
func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*service.SubscriptionCacheData, error) {
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
result, err := c.rdb.HGetAll(ctx, key).Result()
|
||||
if err != nil {
|
||||
@@ -102,8 +101,8 @@ func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID
|
||||
return c.parseSubscriptionCache(result)
|
||||
}
|
||||
|
||||
func (c *billingCache) parseSubscriptionCache(data map[string]string) (*ports.SubscriptionCacheData, error) {
|
||||
result := &ports.SubscriptionCacheData{}
|
||||
func (c *billingCache) parseSubscriptionCache(data map[string]string) (*service.SubscriptionCacheData, error) {
|
||||
result := &service.SubscriptionCacheData{}
|
||||
|
||||
result.Status = data[subFieldStatus]
|
||||
if result.Status == "" {
|
||||
@@ -136,14 +135,14 @@ func (c *billingCache) parseSubscriptionCache(data map[string]string) (*ports.Su
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *ports.SubscriptionCacheData) error {
|
||||
func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *service.SubscriptionCacheData) error {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
fields := map[string]interface{}{
|
||||
fields := map[string]any{
|
||||
subFieldStatus: data.Status,
|
||||
subFieldExpiresAt: data.ExpiresAt.Unix(),
|
||||
subFieldDailyUsage: data.DailyUsage,
|
||||
|
||||
283
backend/internal/repository/billing_cache_integration_test.go
Normal file
283
backend/internal/repository/billing_cache_integration_test.go
Normal file
@@ -0,0 +1,283 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type BillingCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
}
|
||||
|
||||
func (s *BillingCacheSuite) TestUserBalance() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache service.BillingCache)
|
||||
}{
|
||||
{
|
||||
name: "missing_key_returns_redis_nil",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
_, err := cache.GetUserBalance(ctx, 1)
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing balance key")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "deduct_on_nonexistent_is_noop",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(1)
|
||||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 1), "DeductUserBalance should not error")
|
||||
|
||||
_, err := rdb.Get(ctx, balanceKey).Result()
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected missing key after deduct on non-existent")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set_and_get_with_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(2)
|
||||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
|
||||
|
||||
got, err := cache.GetUserBalance(ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserBalance")
|
||||
require.Equal(s.T(), 10.5, got, "balance mismatch")
|
||||
|
||||
ttl, err := rdb.TTL(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "deduct_reduces_balance",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(3)
|
||||
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
|
||||
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 2.25), "DeductUserBalance")
|
||||
|
||||
got, err := cache.GetUserBalance(ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserBalance after deduct")
|
||||
require.Equal(s.T(), 8.25, got, "deduct mismatch")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalidate_removes_key",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(100)
|
||||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 50.0), "SetUserBalance")
|
||||
|
||||
exists, err := rdb.Exists(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "Exists")
|
||||
require.Equal(s.T(), int64(1), exists, "expected balance key to exist")
|
||||
|
||||
require.NoError(s.T(), cache.InvalidateUserBalance(ctx, userID), "InvalidateUserBalance")
|
||||
|
||||
exists, err = rdb.Exists(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "Exists after invalidate")
|
||||
require.Equal(s.T(), int64(0), exists, "expected balance key to be removed after invalidate")
|
||||
|
||||
_, err = cache.GetUserBalance(ctx, userID)
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "deduct_refreshes_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(103)
|
||||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 100.0), "SetUserBalance")
|
||||
|
||||
ttl1, err := rdb.TTL(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "TTL before deduct")
|
||||
s.AssertTTLWithin(ttl1, 1*time.Second, billingCacheTTL)
|
||||
|
||||
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 25.0), "DeductUserBalance")
|
||||
|
||||
balance, err := cache.GetUserBalance(ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserBalance")
|
||||
require.Equal(s.T(), 75.0, balance, "expected balance 75.0")
|
||||
|
||||
ttl2, err := rdb.TTL(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "TTL after deduct")
|
||||
s.AssertTTLWithin(ttl2, 1*time.Second, billingCacheTTL)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
|
||||
tt.fn(ctx, rdb, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BillingCacheSuite) TestSubscriptionCache() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache service.BillingCache)
|
||||
}{
|
||||
{
|
||||
name: "missing_key_returns_redis_nil",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(10)
|
||||
groupID := int64(20)
|
||||
|
||||
_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing subscription key")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "update_usage_on_nonexistent_is_noop",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(11)
|
||||
groupID := int64(21)
|
||||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 1.0), "UpdateSubscriptionUsage should not error")
|
||||
|
||||
exists, err := rdb.Exists(ctx, subKey).Result()
|
||||
require.NoError(s.T(), err, "Exists")
|
||||
require.Equal(s.T(), int64(0), exists, "expected missing subscription key after UpdateSubscriptionUsage on non-existent")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set_and_get_with_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(12)
|
||||
groupID := int64(22)
|
||||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
data := &service.SubscriptionCacheData{
|
||||
Status: "active",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
DailyUsage: 1.0,
|
||||
WeeklyUsage: 2.0,
|
||||
MonthlyUsage: 3.0,
|
||||
Version: 7,
|
||||
}
|
||||
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
|
||||
|
||||
gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.NoError(s.T(), err, "GetSubscriptionCache")
|
||||
require.Equal(s.T(), "active", gotSub.Status)
|
||||
require.Equal(s.T(), int64(7), gotSub.Version)
|
||||
require.Equal(s.T(), 1.0, gotSub.DailyUsage)
|
||||
|
||||
ttl, err := rdb.TTL(ctx, subKey).Result()
|
||||
require.NoError(s.T(), err, "TTL subKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "update_usage_increments_all_fields",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(13)
|
||||
groupID := int64(23)
|
||||
|
||||
data := &service.SubscriptionCacheData{
|
||||
Status: "active",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
DailyUsage: 1.0,
|
||||
WeeklyUsage: 2.0,
|
||||
MonthlyUsage: 3.0,
|
||||
Version: 1,
|
||||
}
|
||||
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
|
||||
|
||||
require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 0.5), "UpdateSubscriptionUsage")
|
||||
|
||||
gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.NoError(s.T(), err, "GetSubscriptionCache after update")
|
||||
require.Equal(s.T(), 1.5, gotSub.DailyUsage)
|
||||
require.Equal(s.T(), 2.5, gotSub.WeeklyUsage)
|
||||
require.Equal(s.T(), 3.5, gotSub.MonthlyUsage)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalidate_removes_key",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(101)
|
||||
groupID := int64(10)
|
||||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
data := &service.SubscriptionCacheData{
|
||||
Status: "active",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
DailyUsage: 1.0,
|
||||
WeeklyUsage: 2.0,
|
||||
MonthlyUsage: 3.0,
|
||||
Version: 1,
|
||||
}
|
||||
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
|
||||
|
||||
exists, err := rdb.Exists(ctx, subKey).Result()
|
||||
require.NoError(s.T(), err, "Exists")
|
||||
require.Equal(s.T(), int64(1), exists, "expected subscription key to exist")
|
||||
|
||||
require.NoError(s.T(), cache.InvalidateSubscriptionCache(ctx, userID, groupID), "InvalidateSubscriptionCache")
|
||||
|
||||
exists, err = rdb.Exists(ctx, subKey).Result()
|
||||
require.NoError(s.T(), err, "Exists after invalidate")
|
||||
require.Equal(s.T(), int64(0), exists, "expected subscription key to be removed after invalidate")
|
||||
|
||||
_, err = cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing_status_returns_parsing_error",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(102)
|
||||
groupID := int64(11)
|
||||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
fields := map[string]any{
|
||||
"expires_at": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"daily_usage": 1.0,
|
||||
"weekly_usage": 2.0,
|
||||
"monthly_usage": 3.0,
|
||||
"version": 1,
|
||||
}
|
||||
require.NoError(s.T(), rdb.HSet(ctx, subKey, fields).Err(), "HSet")
|
||||
|
||||
_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.Error(s.T(), err, "expected error for missing status field")
|
||||
require.NotErrorIs(s.T(), err, redis.Nil, "expected parsing error, not redis.Nil")
|
||||
require.Equal(s.T(), "invalid cache: missing status", err.Error())
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
|
||||
tt.fn(ctx, rdb, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(BillingCacheSuite))
|
||||
}
|
||||
@@ -7,28 +7,37 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/pkg/oauth"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
type claudeOAuthService struct{}
|
||||
|
||||
func NewClaudeOAuthClient() service.ClaudeOAuthClient {
|
||||
return &claudeOAuthService{}
|
||||
return &claudeOAuthService{
|
||||
baseURL: "https://claude.ai",
|
||||
tokenURL: oauth.TokenURL,
|
||||
clientFactory: createReqClient,
|
||||
}
|
||||
}
|
||||
|
||||
type claudeOAuthService struct {
|
||||
baseURL string
|
||||
tokenURL string
|
||||
clientFactory func(proxyURL string) *req.Client
|
||||
}
|
||||
|
||||
func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
|
||||
client := createReqClient(proxyURL)
|
||||
client := s.clientFactory(proxyURL)
|
||||
|
||||
var orgs []struct {
|
||||
UUID string `json:"uuid"`
|
||||
}
|
||||
|
||||
targetURL := "https://claude.ai/api/organizations"
|
||||
targetURL := s.baseURL + "/api/organizations"
|
||||
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
|
||||
|
||||
resp, err := client.R().
|
||||
@@ -60,11 +69,11 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
|
||||
}
|
||||
|
||||
func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
|
||||
client := createReqClient(proxyURL)
|
||||
client := s.clientFactory(proxyURL)
|
||||
|
||||
authURL := fmt.Sprintf("https://claude.ai/v1/oauth/%s/authorize", orgUUID)
|
||||
authURL := fmt.Sprintf("%s/v1/oauth/%s/authorize", s.baseURL, orgUUID)
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
reqBody := map[string]any{
|
||||
"response_type": "code",
|
||||
"client_id": oauth.ClientID,
|
||||
"organization_uuid": orgUUID,
|
||||
@@ -132,30 +141,22 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
|
||||
fullCode = authCode + "#" + responseState
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", authCode[:20])
|
||||
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", prefix(authCode, 20))
|
||||
return fullCode, nil
|
||||
}
|
||||
|
||||
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) {
|
||||
client := createReqClient(proxyURL)
|
||||
client := s.clientFactory(proxyURL)
|
||||
|
||||
// Parse code which may contain state in format "authCode#state"
|
||||
authCode := code
|
||||
codeState := ""
|
||||
if len(code) > 0 {
|
||||
parts := make([]string, 0, 2)
|
||||
for i, part := range []rune(code) {
|
||||
if part == '#' {
|
||||
authCode = code[:i]
|
||||
codeState = code[i+1:]
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(parts) == 0 {
|
||||
authCode = code
|
||||
}
|
||||
if idx := strings.Index(code, "#"); idx != -1 {
|
||||
authCode = code[:idx]
|
||||
codeState = code[idx+1:]
|
||||
}
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
reqBody := map[string]any{
|
||||
"code": authCode,
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": oauth.ClientID,
|
||||
@@ -168,7 +169,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
|
||||
}
|
||||
|
||||
reqBodyJSON, _ := json.Marshal(reqBody)
|
||||
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", oauth.TokenURL)
|
||||
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
|
||||
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
|
||||
|
||||
var tokenResp oauth.TokenResponse
|
||||
@@ -178,7 +179,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(oauth.TokenURL)
|
||||
Post(s.tokenURL)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err)
|
||||
@@ -196,7 +197,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
|
||||
}
|
||||
|
||||
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
|
||||
client := createReqClient(proxyURL)
|
||||
client := s.clientFactory(proxyURL)
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "refresh_token")
|
||||
@@ -209,7 +210,7 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
|
||||
SetContext(ctx).
|
||||
SetFormDataFromValues(formData).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(oauth.TokenURL)
|
||||
Post(s.tokenURL)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
@@ -233,3 +234,13 @@ func createReqClient(proxyURL string) *req.Client {
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
func prefix(s string, n int) string {
|
||||
if n <= 0 {
|
||||
return ""
|
||||
}
|
||||
if len(s) <= n {
|
||||
return s
|
||||
}
|
||||
return s[:n]
|
||||
}
|
||||
|
||||
343
backend/internal/repository/claude_oauth_service_test.go
Normal file
343
backend/internal/repository/claude_oauth_service_test.go
Normal file
@@ -0,0 +1,343 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ClaudeOAuthServiceSuite struct {
|
||||
suite.Suite
|
||||
srv *httptest.Server
|
||||
client *claudeOAuthService
|
||||
}
|
||||
|
||||
func (s *ClaudeOAuthServiceSuite) TearDownTest() {
|
||||
if s.srv != nil {
|
||||
s.srv.Close()
|
||||
s.srv = nil
|
||||
}
|
||||
}
|
||||
|
||||
// requestCapture holds captured request data for assertions in the main goroutine.
|
||||
type requestCapture struct {
|
||||
path string
|
||||
method string
|
||||
cookies []*http.Cookie
|
||||
body []byte
|
||||
formValues url.Values
|
||||
bodyJSON map[string]any
|
||||
contentType string
|
||||
}
|
||||
|
||||
func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() {
|
||||
tests := []struct {
|
||||
name string
|
||||
handler http.HandlerFunc
|
||||
wantErr bool
|
||||
errContain string
|
||||
wantUUID string
|
||||
validate func(captured requestCapture)
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`[{"uuid":"org-1"}]`))
|
||||
},
|
||||
wantUUID: "org-1",
|
||||
validate: func(captured requestCapture) {
|
||||
require.Equal(s.T(), "/api/organizations", captured.path, "unexpected path")
|
||||
require.Len(s.T(), captured.cookies, 1, "expected 1 cookie")
|
||||
require.Equal(s.T(), "sessionKey", captured.cookies[0].Name)
|
||||
require.Equal(s.T(), "sess", captured.cookies[0].Value)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non_200_returns_error",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte("unauthorized"))
|
||||
},
|
||||
wantErr: true,
|
||||
errContain: "401",
|
||||
},
|
||||
{
|
||||
name: "invalid_json_returns_error",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte("not-json"))
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
var captured requestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.path = r.URL.Path
|
||||
captured.cookies = r.Cookies()
|
||||
tt.handler(w, r)
|
||||
}))
|
||||
defer s.srv.Close()
|
||||
|
||||
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client.baseURL = s.srv.URL
|
||||
|
||||
got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "")
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(s.T(), err)
|
||||
if tt.errContain != "" {
|
||||
require.ErrorContains(s.T(), err, tt.errContain)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), tt.wantUUID, got)
|
||||
if tt.validate != nil {
|
||||
tt.validate(captured)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
|
||||
tests := []struct {
|
||||
name string
|
||||
handler http.HandlerFunc
|
||||
wantErr bool
|
||||
wantCode string
|
||||
validate func(captured requestCapture)
|
||||
}{
|
||||
{
|
||||
name: "parses_redirect_uri",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"redirect_uri": oauth.RedirectURI + "?code=AUTH&state=STATE",
|
||||
})
|
||||
},
|
||||
wantCode: "AUTH#STATE",
|
||||
validate: func(captured requestCapture) {
|
||||
require.True(s.T(), strings.HasPrefix(captured.path, "/v1/oauth/") && strings.HasSuffix(captured.path, "/authorize"), "unexpected path: %s", captured.path)
|
||||
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
|
||||
require.Len(s.T(), captured.cookies, 1, "expected 1 cookie")
|
||||
require.Equal(s.T(), "sess", captured.cookies[0].Value)
|
||||
require.Equal(s.T(), "org-1", captured.bodyJSON["organization_uuid"])
|
||||
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
|
||||
require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"])
|
||||
require.Equal(s.T(), "st", captured.bodyJSON["state"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing_code_returns_error",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"redirect_uri": oauth.RedirectURI + "?state=STATE", // no code
|
||||
})
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
var captured requestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.path = r.URL.Path
|
||||
captured.method = r.Method
|
||||
captured.cookies = r.Cookies()
|
||||
captured.body, _ = io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
|
||||
tt.handler(w, r)
|
||||
}))
|
||||
defer s.srv.Close()
|
||||
|
||||
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client.baseURL = s.srv.URL
|
||||
|
||||
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeProfile, "cc", "st", "")
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(s.T(), err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), tt.wantCode, code)
|
||||
if tt.validate != nil {
|
||||
tt.validate(captured)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
|
||||
tests := []struct {
|
||||
name string
|
||||
handler http.HandlerFunc
|
||||
code string
|
||||
wantErr bool
|
||||
wantResp *oauth.TokenResponse
|
||||
validate func(captured requestCapture)
|
||||
}{
|
||||
{
|
||||
name: "sends_state_when_embedded",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
|
||||
AccessToken: "at",
|
||||
TokenType: "bearer",
|
||||
ExpiresIn: 3600,
|
||||
RefreshToken: "rt",
|
||||
Scope: "s",
|
||||
})
|
||||
},
|
||||
code: "AUTH#STATE2",
|
||||
wantResp: &oauth.TokenResponse{
|
||||
AccessToken: "at",
|
||||
RefreshToken: "rt",
|
||||
},
|
||||
validate: func(captured requestCapture) {
|
||||
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
|
||||
require.True(s.T(), strings.HasPrefix(captured.contentType, "application/json"), "unexpected content-type")
|
||||
require.Equal(s.T(), "AUTH", captured.bodyJSON["code"])
|
||||
require.Equal(s.T(), "STATE2", captured.bodyJSON["state"])
|
||||
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
|
||||
require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"])
|
||||
require.Equal(s.T(), "ver", captured.bodyJSON["code_verifier"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non_200_returns_error",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte("bad request"))
|
||||
},
|
||||
code: "AUTH",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
var captured requestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.method = r.Method
|
||||
captured.contentType = r.Header.Get("Content-Type")
|
||||
captured.body, _ = io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
|
||||
tt.handler(w, r)
|
||||
}))
|
||||
defer s.srv.Close()
|
||||
|
||||
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client.tokenURL = s.srv.URL
|
||||
|
||||
resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "")
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(s.T(), err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken)
|
||||
require.Equal(s.T(), tt.wantResp.RefreshToken, resp.RefreshToken)
|
||||
if tt.validate != nil {
|
||||
tt.validate(captured)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
|
||||
tests := []struct {
|
||||
name string
|
||||
handler http.HandlerFunc
|
||||
wantErr bool
|
||||
wantResp *oauth.TokenResponse
|
||||
validate func(captured requestCapture)
|
||||
}{
|
||||
{
|
||||
name: "sends_form",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{AccessToken: "at2", TokenType: "bearer", ExpiresIn: 3600})
|
||||
},
|
||||
wantResp: &oauth.TokenResponse{AccessToken: "at2"},
|
||||
validate: func(captured requestCapture) {
|
||||
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
|
||||
require.Equal(s.T(), "refresh_token", captured.formValues.Get("grant_type"))
|
||||
require.Equal(s.T(), "rt", captured.formValues.Get("refresh_token"))
|
||||
require.Equal(s.T(), oauth.ClientID, captured.formValues.Get("client_id"))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non_200_returns_error",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte("unauthorized"))
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
var captured requestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.method = r.Method
|
||||
captured.body, _ = io.ReadAll(r.Body)
|
||||
captured.formValues, _ = url.ParseQuery(string(captured.body))
|
||||
tt.handler(w, r)
|
||||
}))
|
||||
defer s.srv.Close()
|
||||
|
||||
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client.tokenURL = s.srv.URL
|
||||
|
||||
resp, err := s.client.RefreshToken(context.Background(), "rt", "")
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(s.T(), err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken)
|
||||
if tt.validate != nil {
|
||||
tt.validate(captured)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeOAuthServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(ClaudeOAuthServiceSuite))
|
||||
}
|
||||
@@ -9,17 +9,25 @@ import (
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type claudeUsageService struct{}
|
||||
const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
|
||||
|
||||
type claudeUsageService struct {
|
||||
usageURL string
|
||||
}
|
||||
|
||||
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
|
||||
return &claudeUsageService{}
|
||||
return &claudeUsageService{usageURL: defaultClaudeUsageURL}
|
||||
}
|
||||
|
||||
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport, ok := http.DefaultTransport.(*http.Transport)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to get default transport")
|
||||
}
|
||||
transport = transport.Clone()
|
||||
if proxyURL != "" {
|
||||
if parsedURL, err := url.Parse(proxyURL); err == nil {
|
||||
transport.Proxy = http.ProxyURL(parsedURL)
|
||||
@@ -31,7 +39,7 @@ func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyU
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.anthropic.com/api/oauth/usage", nil)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request failed: %w", err)
|
||||
}
|
||||
@@ -43,7 +51,7 @@ func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyU
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
105
backend/internal/repository/claude_usage_service_test.go
Normal file
105
backend/internal/repository/claude_usage_service_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ClaudeUsageServiceSuite struct {
|
||||
suite.Suite
|
||||
srv *httptest.Server
|
||||
fetcher *claudeUsageService
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TearDownTest() {
|
||||
if s.srv != nil {
|
||||
s.srv.Close()
|
||||
s.srv = nil
|
||||
}
|
||||
}
|
||||
|
||||
// usageRequestCapture holds captured request data for assertions in the main goroutine.
|
||||
type usageRequestCapture struct {
|
||||
authorization string
|
||||
anthropicBeta string
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
|
||||
var captured usageRequestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.authorization = r.Header.Get("Authorization")
|
||||
captured.anthropicBeta = r.Header.Get("anthropic-beta")
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{
|
||||
"five_hour": {"utilization": 12.5, "resets_at": "2025-01-01T00:00:00Z"},
|
||||
"seven_day": {"utilization": 34.0, "resets_at": "2025-01-08T00:00:00Z"},
|
||||
"seven_day_sonnet": {"utilization": 56.0, "resets_at": "2025-01-08T00:00:00Z"}
|
||||
}`)
|
||||
}))
|
||||
|
||||
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
|
||||
|
||||
resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
|
||||
require.NoError(s.T(), err, "FetchUsage")
|
||||
require.Equal(s.T(), 12.5, resp.FiveHour.Utilization, "FiveHour utilization mismatch")
|
||||
require.Equal(s.T(), 34.0, resp.SevenDay.Utilization, "SevenDay utilization mismatch")
|
||||
require.Equal(s.T(), 56.0, resp.SevenDaySonnet.Utilization, "SevenDaySonnet utilization mismatch")
|
||||
|
||||
// Assertions on captured request data
|
||||
require.Equal(s.T(), "Bearer at", captured.authorization, "Authorization header mismatch")
|
||||
require.Equal(s.T(), "oauth-2025-04-20", captured.anthropicBeta, "anthropic-beta header mismatch")
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = io.WriteString(w, "nope")
|
||||
}))
|
||||
|
||||
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
|
||||
|
||||
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "status 401")
|
||||
require.ErrorContains(s.T(), err, "nope")
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, "not-json")
|
||||
}))
|
||||
|
||||
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
|
||||
|
||||
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "decode response failed")
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Never respond - simulate slow server
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
_, err := s.fetcher.FetchUsage(ctx, "at", "")
|
||||
require.Error(s.T(), err, "expected error for cancelled context")
|
||||
}
|
||||
|
||||
func TestClaudeUsageServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(ClaudeUsageServiceSuite))
|
||||
}
|
||||
@@ -5,60 +5,94 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
accountConcurrencyKeyPrefix = "concurrency:account:"
|
||||
userConcurrencyKeyPrefix = "concurrency:user:"
|
||||
waitQueueKeyPrefix = "concurrency:wait:"
|
||||
concurrencyTTL = 5 * time.Minute
|
||||
// Key prefixes for independent slot keys
|
||||
// Format: concurrency:account:{accountID}:{requestID}
|
||||
accountSlotKeyPrefix = "concurrency:account:"
|
||||
// Format: concurrency:user:{userID}:{requestID}
|
||||
userSlotKeyPrefix = "concurrency:user:"
|
||||
// Wait queue keeps counter format: concurrency:wait:{userID}
|
||||
waitQueueKeyPrefix = "concurrency:wait:"
|
||||
|
||||
// Slot TTL - each slot expires independently
|
||||
slotTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
var (
|
||||
// acquireScript uses SCAN to count existing slots and creates new slot if under limit
|
||||
// KEYS[1] = pattern for SCAN (e.g., "concurrency:account:2:*")
|
||||
// KEYS[2] = full slot key (e.g., "concurrency:account:2:req_xxx")
|
||||
// ARGV[1] = maxConcurrency
|
||||
// ARGV[2] = TTL in seconds
|
||||
acquireScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
current = 0
|
||||
else
|
||||
current = tonumber(current)
|
||||
end
|
||||
if current < tonumber(ARGV[1]) then
|
||||
redis.call('INCR', KEYS[1])
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
local pattern = KEYS[1]
|
||||
local slotKey = KEYS[2]
|
||||
local maxConcurrency = tonumber(ARGV[1])
|
||||
local ttl = tonumber(ARGV[2])
|
||||
|
||||
-- Count existing slots using SCAN
|
||||
local cursor = "0"
|
||||
local count = 0
|
||||
repeat
|
||||
local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100)
|
||||
cursor = result[1]
|
||||
count = count + #result[2]
|
||||
until cursor == "0"
|
||||
|
||||
-- Check if we can acquire a slot
|
||||
if count < maxConcurrency then
|
||||
redis.call('SET', slotKey, '1', 'EX', ttl)
|
||||
return 1
|
||||
end
|
||||
|
||||
return 0
|
||||
`)
|
||||
|
||||
releaseScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current ~= false and tonumber(current) > 0 then
|
||||
redis.call('DECR', KEYS[1])
|
||||
end
|
||||
return 1
|
||||
// getCountScript counts slots using SCAN
|
||||
// KEYS[1] = pattern for SCAN
|
||||
getCountScript = redis.NewScript(`
|
||||
local pattern = KEYS[1]
|
||||
local cursor = "0"
|
||||
local count = 0
|
||||
repeat
|
||||
local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100)
|
||||
cursor = result[1]
|
||||
count = count + #result[2]
|
||||
until cursor == "0"
|
||||
return count
|
||||
`)
|
||||
|
||||
// incrementWaitScript - only sets TTL on first creation to avoid refreshing
|
||||
// KEYS[1] = wait queue key
|
||||
// ARGV[1] = maxWait
|
||||
// ARGV[2] = TTL in seconds
|
||||
incrementWaitScript = redis.NewScript(`
|
||||
local waitKey = KEYS[1]
|
||||
local maxWait = tonumber(ARGV[1])
|
||||
local ttl = tonumber(ARGV[2])
|
||||
local current = redis.call('GET', waitKey)
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
current = 0
|
||||
else
|
||||
current = tonumber(current)
|
||||
end
|
||||
if current >= maxWait then
|
||||
|
||||
if current >= tonumber(ARGV[1]) then
|
||||
return 0
|
||||
end
|
||||
redis.call('INCR', waitKey)
|
||||
redis.call('EXPIRE', waitKey, ttl)
|
||||
|
||||
local newVal = redis.call('INCR', KEYS[1])
|
||||
|
||||
-- Only set TTL on first creation to avoid refreshing zombie data
|
||||
if newVal == 1 then
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
end
|
||||
|
||||
return 1
|
||||
`)
|
||||
|
||||
// decrementWaitScript - same as before
|
||||
decrementWaitScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current ~= false and tonumber(current) > 0 then
|
||||
@@ -72,53 +106,90 @@ type concurrencyCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewConcurrencyCache(rdb *redis.Client) ports.ConcurrencyCache {
|
||||
func NewConcurrencyCache(rdb *redis.Client) service.ConcurrencyCache {
|
||||
return &concurrencyCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (bool, error) {
|
||||
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
|
||||
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, int(concurrencyTTL.Seconds())).Int()
|
||||
// Helper functions for key generation
|
||||
func accountSlotKey(accountID int64, requestID string) string {
|
||||
return fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, requestID)
|
||||
}
|
||||
|
||||
func accountSlotPattern(accountID int64) string {
|
||||
return fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID)
|
||||
}
|
||||
|
||||
func userSlotKey(userID int64, requestID string) string {
|
||||
return fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, requestID)
|
||||
}
|
||||
|
||||
func userSlotPattern(userID int64) string {
|
||||
return fmt.Sprintf("%s%d:*", userSlotKeyPrefix, userID)
|
||||
}
|
||||
|
||||
func waitQueueKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
}
|
||||
|
||||
// Account slot operations
|
||||
|
||||
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
pattern := accountSlotPattern(accountID)
|
||||
slotKey := accountSlotKey(accountID, requestID)
|
||||
|
||||
result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64) error {
|
||||
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
|
||||
_, err := releaseScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||
return err
|
||||
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
|
||||
slotKey := accountSlotKey(accountID, requestID)
|
||||
return c.rdb.Del(ctx, slotKey).Err()
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
|
||||
return c.rdb.Get(ctx, key).Int()
|
||||
pattern := accountSlotPattern(accountID)
|
||||
result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (bool, error) {
|
||||
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
|
||||
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, int(concurrencyTTL.Seconds())).Int()
|
||||
// User slot operations
|
||||
|
||||
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
pattern := userSlotPattern(userID)
|
||||
slotKey := userSlotKey(userID, requestID)
|
||||
|
||||
result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
|
||||
_, err := releaseScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||
return err
|
||||
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
|
||||
slotKey := userSlotKey(userID, requestID)
|
||||
return c.rdb.Del(ctx, slotKey).Err()
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
|
||||
return c.rdb.Get(ctx, key).Int()
|
||||
pattern := userSlotPattern(userID)
|
||||
result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Wait queue operations
|
||||
|
||||
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
key := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(concurrencyTTL.Seconds())).Int()
|
||||
key := waitQueueKey(userID)
|
||||
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(slotTTL.Seconds())).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -126,7 +197,7 @@ func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64,
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
key := waitQueueKey(userID)
|
||||
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -0,0 +1,231 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ConcurrencyCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache service.ConcurrencyCache
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewConcurrencyCache(s.rdb)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
|
||||
accountID := int64(10)
|
||||
reqID1, reqID2, reqID3 := "req1", "req2", "req3"
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID1)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot 1")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID2)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot 2")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID3)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot 3")
|
||||
require.False(s.T(), ok, "expected third acquire to fail")
|
||||
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err, "GetAccountConcurrency")
|
||||
require.Equal(s.T(), 2, cur, "concurrency mismatch")
|
||||
|
||||
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID1), "ReleaseAccountSlot")
|
||||
|
||||
cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err, "GetAccountConcurrency after release")
|
||||
require.Equal(s.T(), 1, cur, "expected 1 after release")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() {
|
||||
accountID := int64(11)
|
||||
reqID := "req_ttl_test"
|
||||
slotKey := fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, reqID)
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, reqID)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_DuplicateReqID() {
|
||||
accountID := int64(12)
|
||||
reqID := "dup-req"
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID)
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Acquiring with same reqID should be idempotent
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID)
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 1, cur, "expected concurrency=1 (idempotent)")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_ReleaseIdempotent() {
|
||||
accountID := int64(13)
|
||||
reqID := "release-test"
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 1, reqID)
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot")
|
||||
// Releasing again should not error
|
||||
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot again")
|
||||
// Releasing non-existent should not error
|
||||
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, "non-existent"), "ReleaseAccountSlot non-existent")
|
||||
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 0, cur)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_MaxZero() {
|
||||
accountID := int64(14)
|
||||
reqID := "max-zero-test"
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 0, reqID)
|
||||
require.NoError(s.T(), err)
|
||||
require.False(s.T(), ok, "expected acquire to fail with max=0")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestUserSlot_AcquireAndRelease() {
|
||||
userID := int64(42)
|
||||
reqID1, reqID2 := "req1", "req2"
|
||||
|
||||
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID1)
|
||||
require.NoError(s.T(), err, "AcquireUserSlot")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID2)
|
||||
require.NoError(s.T(), err, "AcquireUserSlot 2")
|
||||
require.False(s.T(), ok, "expected second acquire to fail at max=1")
|
||||
|
||||
cur, err := s.cache.GetUserConcurrency(s.ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserConcurrency")
|
||||
require.Equal(s.T(), 1, cur, "expected concurrency=1")
|
||||
|
||||
require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, reqID1), "ReleaseUserSlot")
|
||||
// Releasing a non-existent slot should not error
|
||||
require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, "non-existent"), "ReleaseUserSlot non-existent")
|
||||
|
||||
cur, err = s.cache.GetUserConcurrency(s.ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserConcurrency after release")
|
||||
require.Equal(s.T(), 0, cur, "expected concurrency=0 after release")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
|
||||
userID := int64(200)
|
||||
reqID := "req_ttl_test"
|
||||
slotKey := fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, reqID)
|
||||
|
||||
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 5, reqID)
|
||||
require.NoError(s.T(), err, "AcquireUserSlot")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
|
||||
userID := int64(20)
|
||||
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
|
||||
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 2)
|
||||
require.NoError(s.T(), err, "IncrementWaitCount 1")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2)
|
||||
require.NoError(s.T(), err, "IncrementWaitCount 2")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2)
|
||||
require.NoError(s.T(), err, "IncrementWaitCount 3")
|
||||
require.False(s.T(), ok, "expected wait increment over max to fail")
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
|
||||
require.NoError(s.T(), err, "TTL waitKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
|
||||
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
|
||||
|
||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey")
|
||||
}
|
||||
require.Equal(s.T(), 1, val, "expected wait count 1")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
|
||||
userID := int64(300)
|
||||
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
|
||||
// Test decrement on non-existent key - should not error and should not create negative value
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on non-existent key")
|
||||
|
||||
// Verify no key was created or it's not negative
|
||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey")
|
||||
}
|
||||
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count after decrement on empty")
|
||||
|
||||
// Set count to 1, then decrement twice
|
||||
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 5)
|
||||
require.NoError(s.T(), err, "IncrementWaitCount")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Decrement once (1 -> 0)
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
|
||||
|
||||
// Decrement again on 0 - should not go negative
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on zero")
|
||||
|
||||
// Verify count is 0, not negative
|
||||
val, err = s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey after double decrement")
|
||||
}
|
||||
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
|
||||
// When no slots exist, GetAccountConcurrency should return 0
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, 999)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 0, cur)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() {
|
||||
// When no slots exist, GetUserConcurrency should return 0
|
||||
cur, err := s.cache.GetUserConcurrency(s.ctx, 999)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 0, cur)
|
||||
}
|
||||
|
||||
func TestConcurrencyCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(ConcurrencyCacheSuite))
|
||||
}
|
||||
@@ -5,8 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
@@ -16,24 +15,24 @@ type emailCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewEmailCache(rdb *redis.Client) ports.EmailCache {
|
||||
func NewEmailCache(rdb *redis.Client) service.EmailCache {
|
||||
return &emailCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*ports.VerificationCodeData, error) {
|
||||
func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*service.VerificationCodeData, error) {
|
||||
key := verifyCodeKeyPrefix + email
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var data ports.VerificationCodeData
|
||||
var data service.VerificationCodeData
|
||||
if err := json.Unmarshal([]byte(val), &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *ports.VerificationCodeData, ttl time.Duration) error {
|
||||
func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *service.VerificationCodeData, ttl time.Duration) error {
|
||||
key := verifyCodeKeyPrefix + email
|
||||
val, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
|
||||
92
backend/internal/repository/email_cache_integration_test.go
Normal file
92
backend/internal/repository/email_cache_integration_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type EmailCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache service.EmailCache
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewEmailCache(s.rdb)
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestGetVerificationCode_Missing() {
|
||||
_, err := s.cache.GetVerificationCode(s.ctx, "nonexistent@example.com")
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing verification code")
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestSetAndGetVerificationCode() {
|
||||
email := "a@example.com"
|
||||
emailTTL := 2 * time.Minute
|
||||
data := &service.VerificationCodeData{Code: "123456", Attempts: 1, CreatedAt: time.Now()}
|
||||
|
||||
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")
|
||||
|
||||
got, err := s.cache.GetVerificationCode(s.ctx, email)
|
||||
require.NoError(s.T(), err, "GetVerificationCode")
|
||||
require.Equal(s.T(), "123456", got.Code)
|
||||
require.Equal(s.T(), 1, got.Attempts)
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestVerificationCode_TTL() {
|
||||
email := "ttl@example.com"
|
||||
emailTTL := 2 * time.Minute
|
||||
data := &service.VerificationCodeData{Code: "654321", Attempts: 0, CreatedAt: time.Now()}
|
||||
|
||||
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")
|
||||
|
||||
emailKey := verifyCodeKeyPrefix + email
|
||||
ttl, err := s.rdb.TTL(s.ctx, emailKey).Result()
|
||||
require.NoError(s.T(), err, "TTL emailKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, emailTTL)
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestDeleteVerificationCode() {
|
||||
email := "delete@example.com"
|
||||
data := &service.VerificationCodeData{Code: "999999", Attempts: 0, CreatedAt: time.Now()}
|
||||
|
||||
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, 2*time.Minute), "SetVerificationCode")
|
||||
|
||||
// Verify it exists
|
||||
_, err := s.cache.GetVerificationCode(s.ctx, email)
|
||||
require.NoError(s.T(), err, "GetVerificationCode before delete")
|
||||
|
||||
// Delete
|
||||
require.NoError(s.T(), s.cache.DeleteVerificationCode(s.ctx, email), "DeleteVerificationCode")
|
||||
|
||||
// Verify it's gone
|
||||
_, err = s.cache.GetVerificationCode(s.ctx, email)
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestDeleteVerificationCode_NonExistent() {
|
||||
// Deleting a non-existent key should not error
|
||||
require.NoError(s.T(), s.cache.DeleteVerificationCode(s.ctx, "nonexistent@example.com"), "DeleteVerificationCode non-existent")
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestGetVerificationCode_JSONCorruption() {
|
||||
emailKey := verifyCodeKeyPrefix + "corrupted@example.com"
|
||||
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, emailKey, "not-json", 1*time.Minute).Err(), "Set invalid JSON")
|
||||
|
||||
_, err := s.cache.GetVerificationCode(s.ctx, "corrupted@example.com")
|
||||
require.Error(s.T(), err, "expected error for corrupted JSON")
|
||||
require.False(s.T(), errors.Is(err, redis.Nil), "expected decoding error, not redis.Nil")
|
||||
}
|
||||
|
||||
func TestEmailCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(EmailCacheSuite))
|
||||
}
|
||||
40
backend/internal/repository/error_translate.go
Normal file
40
backend/internal/repository/error_translate.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func translatePersistenceError(err error, notFound, conflict *infraerrors.ApplicationError) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if notFound != nil && errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return notFound.WithCause(err)
|
||||
}
|
||||
|
||||
if conflict != nil && isUniqueConstraintViolation(err) {
|
||||
return conflict.WithCause(err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func isUniqueConstraintViolation(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
return true
|
||||
}
|
||||
|
||||
msg := strings.ToLower(err.Error())
|
||||
return strings.Contains(msg, "duplicate key") ||
|
||||
strings.Contains(msg, "unique constraint") ||
|
||||
strings.Contains(msg, "duplicate entry")
|
||||
}
|
||||
172
backend/internal/repository/fixtures_integration_test.go
Normal file
172
backend/internal/repository/fixtures_integration_test.go
Normal file
@@ -0,0 +1,172 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func mustCreateUser(t *testing.T, db *gorm.DB, u *model.User) *model.User {
|
||||
t.Helper()
|
||||
if u.PasswordHash == "" {
|
||||
u.PasswordHash = "test-password-hash"
|
||||
}
|
||||
if u.Role == "" {
|
||||
u.Role = model.RoleUser
|
||||
}
|
||||
if u.Status == "" {
|
||||
u.Status = model.StatusActive
|
||||
}
|
||||
if u.CreatedAt.IsZero() {
|
||||
u.CreatedAt = time.Now()
|
||||
}
|
||||
if u.UpdatedAt.IsZero() {
|
||||
u.UpdatedAt = u.CreatedAt
|
||||
}
|
||||
require.NoError(t, db.Create(u).Error, "create user")
|
||||
return u
|
||||
}
|
||||
|
||||
func mustCreateGroup(t *testing.T, db *gorm.DB, g *model.Group) *model.Group {
|
||||
t.Helper()
|
||||
if g.Platform == "" {
|
||||
g.Platform = model.PlatformAnthropic
|
||||
}
|
||||
if g.Status == "" {
|
||||
g.Status = model.StatusActive
|
||||
}
|
||||
if g.SubscriptionType == "" {
|
||||
g.SubscriptionType = model.SubscriptionTypeStandard
|
||||
}
|
||||
if g.CreatedAt.IsZero() {
|
||||
g.CreatedAt = time.Now()
|
||||
}
|
||||
if g.UpdatedAt.IsZero() {
|
||||
g.UpdatedAt = g.CreatedAt
|
||||
}
|
||||
require.NoError(t, db.Create(g).Error, "create group")
|
||||
return g
|
||||
}
|
||||
|
||||
func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy {
|
||||
t.Helper()
|
||||
if p.Protocol == "" {
|
||||
p.Protocol = "http"
|
||||
}
|
||||
if p.Host == "" {
|
||||
p.Host = "127.0.0.1"
|
||||
}
|
||||
if p.Port == 0 {
|
||||
p.Port = 8080
|
||||
}
|
||||
if p.Status == "" {
|
||||
p.Status = model.StatusActive
|
||||
}
|
||||
if p.CreatedAt.IsZero() {
|
||||
p.CreatedAt = time.Now()
|
||||
}
|
||||
if p.UpdatedAt.IsZero() {
|
||||
p.UpdatedAt = p.CreatedAt
|
||||
}
|
||||
require.NoError(t, db.Create(p).Error, "create proxy")
|
||||
return p
|
||||
}
|
||||
|
||||
func mustCreateAccount(t *testing.T, db *gorm.DB, a *model.Account) *model.Account {
|
||||
t.Helper()
|
||||
if a.Platform == "" {
|
||||
a.Platform = model.PlatformAnthropic
|
||||
}
|
||||
if a.Type == "" {
|
||||
a.Type = model.AccountTypeOAuth
|
||||
}
|
||||
if a.Status == "" {
|
||||
a.Status = model.StatusActive
|
||||
}
|
||||
if !a.Schedulable {
|
||||
a.Schedulable = true
|
||||
}
|
||||
if a.Credentials == nil {
|
||||
a.Credentials = model.JSONB{}
|
||||
}
|
||||
if a.Extra == nil {
|
||||
a.Extra = model.JSONB{}
|
||||
}
|
||||
if a.CreatedAt.IsZero() {
|
||||
a.CreatedAt = time.Now()
|
||||
}
|
||||
if a.UpdatedAt.IsZero() {
|
||||
a.UpdatedAt = a.CreatedAt
|
||||
}
|
||||
require.NoError(t, db.Create(a).Error, "create account")
|
||||
return a
|
||||
}
|
||||
|
||||
func mustCreateApiKey(t *testing.T, db *gorm.DB, k *model.ApiKey) *model.ApiKey {
|
||||
t.Helper()
|
||||
if k.Status == "" {
|
||||
k.Status = model.StatusActive
|
||||
}
|
||||
if k.CreatedAt.IsZero() {
|
||||
k.CreatedAt = time.Now()
|
||||
}
|
||||
if k.UpdatedAt.IsZero() {
|
||||
k.UpdatedAt = k.CreatedAt
|
||||
}
|
||||
require.NoError(t, db.Create(k).Error, "create api key")
|
||||
return k
|
||||
}
|
||||
|
||||
func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *model.RedeemCode) *model.RedeemCode {
|
||||
t.Helper()
|
||||
if c.Status == "" {
|
||||
c.Status = model.StatusUnused
|
||||
}
|
||||
if c.Type == "" {
|
||||
c.Type = model.RedeemTypeBalance
|
||||
}
|
||||
if c.CreatedAt.IsZero() {
|
||||
c.CreatedAt = time.Now()
|
||||
}
|
||||
require.NoError(t, db.Create(c).Error, "create redeem code")
|
||||
return c
|
||||
}
|
||||
|
||||
func mustCreateSubscription(t *testing.T, db *gorm.DB, s *model.UserSubscription) *model.UserSubscription {
|
||||
t.Helper()
|
||||
if s.Status == "" {
|
||||
s.Status = model.SubscriptionStatusActive
|
||||
}
|
||||
now := time.Now()
|
||||
if s.StartsAt.IsZero() {
|
||||
s.StartsAt = now.Add(-1 * time.Hour)
|
||||
}
|
||||
if s.ExpiresAt.IsZero() {
|
||||
s.ExpiresAt = now.Add(24 * time.Hour)
|
||||
}
|
||||
if s.AssignedAt.IsZero() {
|
||||
s.AssignedAt = now
|
||||
}
|
||||
if s.CreatedAt.IsZero() {
|
||||
s.CreatedAt = now
|
||||
}
|
||||
if s.UpdatedAt.IsZero() {
|
||||
s.UpdatedAt = now
|
||||
}
|
||||
require.NoError(t, db.Create(s).Error, "create user subscription")
|
||||
return s
|
||||
}
|
||||
|
||||
func mustBindAccountToGroup(t *testing.T, db *gorm.DB, accountID, groupID int64, priority int) {
|
||||
t.Helper()
|
||||
require.NoError(t, db.Create(&model.AccountGroup{
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
Priority: priority,
|
||||
}).Error, "create account_group")
|
||||
}
|
||||
@@ -4,8 +4,7 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
@@ -15,7 +14,7 @@ type gatewayCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewGatewayCache(rdb *redis.Client) ports.GatewayCache {
|
||||
func NewGatewayCache(rdb *redis.Client) service.GatewayCache {
|
||||
return &gatewayCache{rdb: rdb}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type GatewayCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache service.GatewayCache
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewGatewayCache(s.rdb)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGetSessionAccountID_Missing() {
|
||||
_, err := s.cache.GetSessionAccountID(s.ctx, "nonexistent")
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing session")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() {
|
||||
sessionID := "s1"
|
||||
accountID := int64(99)
|
||||
sessionTTL := 1 * time.Minute
|
||||
|
||||
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, sessionTTL), "SetSessionAccountID")
|
||||
|
||||
sid, err := s.cache.GetSessionAccountID(s.ctx, sessionID)
|
||||
require.NoError(s.T(), err, "GetSessionAccountID")
|
||||
require.Equal(s.T(), accountID, sid, "session id mismatch")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestSessionAccountID_TTL() {
|
||||
sessionID := "s2"
|
||||
accountID := int64(100)
|
||||
sessionTTL := 1 * time.Minute
|
||||
|
||||
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, sessionTTL), "SetSessionAccountID")
|
||||
|
||||
sessionKey := stickySessionPrefix + sessionID
|
||||
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
|
||||
require.NoError(s.T(), err, "TTL sessionKey after Set")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, sessionTTL)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestRefreshSessionTTL() {
|
||||
sessionID := "s3"
|
||||
accountID := int64(101)
|
||||
initialTTL := 1 * time.Minute
|
||||
refreshTTL := 3 * time.Minute
|
||||
|
||||
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, initialTTL), "SetSessionAccountID")
|
||||
|
||||
require.NoError(s.T(), s.cache.RefreshSessionTTL(s.ctx, sessionID, refreshTTL), "RefreshSessionTTL")
|
||||
|
||||
sessionKey := stickySessionPrefix + sessionID
|
||||
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
|
||||
require.NoError(s.T(), err, "TTL after Refresh")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, refreshTTL)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() {
|
||||
// RefreshSessionTTL on a missing key should not error (no-op)
|
||||
err := s.cache.RefreshSessionTTL(s.ctx, "missing-session", 1*time.Minute)
|
||||
require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
|
||||
sessionID := "corrupted"
|
||||
sessionKey := stickySessionPrefix + sessionID
|
||||
|
||||
// Set a non-integer value
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, sessionKey, "not-a-number", 1*time.Minute).Err(), "Set invalid value")
|
||||
|
||||
_, err := s.cache.GetSessionAccountID(s.ctx, sessionID)
|
||||
require.Error(s.T(), err, "expected error for corrupted value")
|
||||
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
|
||||
}
|
||||
|
||||
func TestGatewayCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(GatewayCacheSuite))
|
||||
}
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type githubReleaseClient struct {
|
||||
@@ -38,7 +38,7 @@ func (c *githubReleaseClient) FetchLatestRelease(ctx context.Context, repo strin
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode)
|
||||
@@ -63,7 +63,7 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("download returned %d", resp.StatusCode)
|
||||
@@ -78,7 +78,7 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
defer func() { _ = out.Close() }()
|
||||
|
||||
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
|
||||
limited := io.LimitReader(resp.Body, maxSize+1)
|
||||
@@ -89,7 +89,7 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
|
||||
|
||||
// Check if we hit the limit (downloaded more than maxSize)
|
||||
if written > maxSize {
|
||||
os.Remove(dest) // Clean up partial file
|
||||
_ = os.Remove(dest) // Clean up partial file (best-effort)
|
||||
return fmt.Errorf("download exceeded maximum size of %d bytes", maxSize)
|
||||
}
|
||||
|
||||
@@ -106,7 +106,7 @@ func (c *githubReleaseClient) FetchChecksumFile(ctx context.Context, url string)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
|
||||
328
backend/internal/repository/github_release_service_test.go
Normal file
328
backend/internal/repository/github_release_service_test.go
Normal file
@@ -0,0 +1,328 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type GitHubReleaseServiceSuite struct {
|
||||
suite.Suite
|
||||
srv *httptest.Server
|
||||
client *githubReleaseClient
|
||||
tempDir string
|
||||
}
|
||||
|
||||
// testTransport redirects requests to the test server
|
||||
type testTransport struct {
|
||||
testServerURL string
|
||||
}
|
||||
|
||||
func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Rewrite the URL to point to our test server
|
||||
testURL := t.testServerURL + req.URL.Path
|
||||
newReq, err := http.NewRequestWithContext(req.Context(), req.Method, testURL, req.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newReq.Header = req.Header
|
||||
return http.DefaultTransport.RoundTrip(newReq)
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) SetupTest() {
|
||||
s.tempDir = s.T().TempDir()
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TearDownTest() {
|
||||
if s.srv != nil {
|
||||
s.srv.Close()
|
||||
s.srv = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLength() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Length", "100")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(bytes.Repeat([]byte("a"), 100))
|
||||
}))
|
||||
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
|
||||
dest := filepath.Join(s.tempDir, "file1.bin")
|
||||
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
|
||||
require.Error(s.T(), err, "expected error for oversized download with Content-Length")
|
||||
|
||||
_, statErr := os.Stat(dest)
|
||||
require.Error(s.T(), statErr, "expected file to not exist for rejected download")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Force chunked encoding (unknown Content-Length) by flushing headers before writing.
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if fl, ok := w.(http.Flusher); ok {
|
||||
fl.Flush()
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
_, _ = w.Write(bytes.Repeat([]byte("b"), 10))
|
||||
if fl, ok := w.(http.Flusher); ok {
|
||||
fl.Flush()
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
|
||||
dest := filepath.Join(s.tempDir, "file2.bin")
|
||||
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
|
||||
require.Error(s.T(), err, "expected error for oversized chunked download")
|
||||
|
||||
_, statErr := os.Stat(dest)
|
||||
require.Error(s.T(), statErr, "expected file to be cleaned up for oversized chunked download")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if fl, ok := w.(http.Flusher); ok {
|
||||
fl.Flush()
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
_, _ = w.Write(bytes.Repeat([]byte("b"), 10))
|
||||
if fl, ok := w.(http.Flusher); ok {
|
||||
fl.Flush()
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
|
||||
dest := filepath.Join(s.tempDir, "file3.bin")
|
||||
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 200)
|
||||
require.NoError(s.T(), err, "expected success")
|
||||
|
||||
b, err := os.ReadFile(dest)
|
||||
require.NoError(s.T(), err, "read")
|
||||
require.True(s.T(), strings.HasPrefix(string(b), "b"), "downloaded content should start with 'b'")
|
||||
require.Len(s.T(), b, 100, "downloaded content length mismatch")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
|
||||
dest := filepath.Join(s.tempDir, "notfound.bin")
|
||||
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100)
|
||||
require.Error(s.T(), err, "expected error for 404")
|
||||
|
||||
_, statErr := os.Stat(dest)
|
||||
require.Error(s.T(), statErr, "expected file to not exist for 404")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("sum"))
|
||||
}))
|
||||
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
|
||||
body, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
|
||||
require.NoError(s.T(), err, "FetchChecksumFile")
|
||||
require.Equal(s.T(), "sum", string(body), "checksum body mismatch")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
|
||||
_, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
|
||||
require.Error(s.T(), err, "expected error for non-200")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
dest := filepath.Join(s.tempDir, "cancelled.bin")
|
||||
err := s.client.DownloadFile(ctx, s.srv.URL, dest, 100)
|
||||
require.Error(s.T(), err, "expected error for cancelled context")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() {
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
|
||||
dest := filepath.Join(s.tempDir, "invalid.bin")
|
||||
err := s.client.DownloadFile(context.Background(), "://invalid-url", dest, 100)
|
||||
require.Error(s.T(), err, "expected error for invalid URL")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("content"))
|
||||
}))
|
||||
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
|
||||
// Use a path that cannot be created (directory doesn't exist)
|
||||
dest := filepath.Join(s.tempDir, "nonexistent", "subdir", "file.bin")
|
||||
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100)
|
||||
require.Error(s.T(), err, "expected error for invalid destination path")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_InvalidURL() {
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
|
||||
_, err := s.client.FetchChecksumFile(context.Background(), "://invalid-url")
|
||||
require.Error(s.T(), err, "expected error for invalid URL")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
|
||||
releaseJSON := `{
|
||||
"tag_name": "v1.0.0",
|
||||
"name": "Release 1.0.0",
|
||||
"body": "Release notes",
|
||||
"html_url": "https://github.com/test/repo/releases/v1.0.0",
|
||||
"assets": [
|
||||
{
|
||||
"name": "app-linux-amd64.tar.gz",
|
||||
"browser_download_url": "https://github.com/test/repo/releases/download/v1.0.0/app-linux-amd64.tar.gz"
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(s.T(), "/repos/test/repo/releases/latest", r.URL.Path)
|
||||
require.Equal(s.T(), "application/vnd.github.v3+json", r.Header.Get("Accept"))
|
||||
require.Equal(s.T(), "Sub2API-Updater", r.Header.Get("User-Agent"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(releaseJSON))
|
||||
}))
|
||||
|
||||
// Use custom transport to redirect requests to test server
|
||||
s.client = &githubReleaseClient{
|
||||
httpClient: &http.Client{
|
||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||
},
|
||||
}
|
||||
|
||||
release, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), "v1.0.0", release.TagName)
|
||||
require.Equal(s.T(), "Release 1.0.0", release.Name)
|
||||
require.Len(s.T(), release.Assets, 1)
|
||||
require.Equal(s.T(), "app-linux-amd64.tar.gz", release.Assets[0].Name)
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
|
||||
s.client = &githubReleaseClient{
|
||||
httpClient: &http.Client{
|
||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
|
||||
require.Error(s.T(), err)
|
||||
require.Contains(s.T(), err.Error(), "404")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("not valid json"))
|
||||
}))
|
||||
|
||||
s.client = &githubReleaseClient{
|
||||
httpClient: &http.Client{
|
||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
|
||||
require.Error(s.T(), err)
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
s.client = &githubReleaseClient{
|
||||
httpClient: &http.Client{
|
||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
_, err := s.client.FetchLatestRelease(ctx, "test/repo")
|
||||
require.Error(s.T(), err)
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
_, err := s.client.FetchChecksumFile(ctx, s.srv.URL)
|
||||
require.Error(s.T(), err)
|
||||
}
|
||||
|
||||
func TestGitHubReleaseServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(GitHubReleaseServiceSuite))
|
||||
}
|
||||
@@ -2,47 +2,52 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type GroupRepository struct {
|
||||
type groupRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewGroupRepository(db *gorm.DB) *GroupRepository {
|
||||
return &GroupRepository{db: db}
|
||||
func NewGroupRepository(db *gorm.DB) service.GroupRepository {
|
||||
return &groupRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *GroupRepository) Create(ctx context.Context, group *model.Group) error {
|
||||
return r.db.WithContext(ctx).Create(group).Error
|
||||
func (r *groupRepository) Create(ctx context.Context, group *model.Group) error {
|
||||
err := r.db.WithContext(ctx).Create(group).Error
|
||||
return translatePersistenceError(err, nil, service.ErrGroupExists)
|
||||
}
|
||||
|
||||
func (r *GroupRepository) GetByID(ctx context.Context, id int64) (*model.Group, error) {
|
||||
func (r *groupRepository) GetByID(ctx context.Context, id int64) (*model.Group, error) {
|
||||
var group model.Group
|
||||
err := r.db.WithContext(ctx).First(&group, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
|
||||
}
|
||||
return &group, nil
|
||||
}
|
||||
|
||||
func (r *GroupRepository) Update(ctx context.Context, group *model.Group) error {
|
||||
func (r *groupRepository) Update(ctx context.Context, group *model.Group) error {
|
||||
return r.db.WithContext(ctx).Save(group).Error
|
||||
}
|
||||
|
||||
func (r *GroupRepository) Delete(ctx context.Context, id int64) error {
|
||||
func (r *groupRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error
|
||||
}
|
||||
|
||||
func (r *GroupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) {
|
||||
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", nil)
|
||||
}
|
||||
|
||||
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
|
||||
func (r *GroupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) {
|
||||
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) {
|
||||
var groups []model.Group
|
||||
var total int64
|
||||
|
||||
@@ -86,7 +91,7 @@ func (r *GroupRepository) ListWithFilters(ctx context.Context, params pagination
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *GroupRepository) ListActive(ctx context.Context) ([]model.Group, error) {
|
||||
func (r *groupRepository) ListActive(ctx context.Context) ([]model.Group, error) {
|
||||
var groups []model.Group
|
||||
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Order("id ASC").Find(&groups).Error
|
||||
if err != nil {
|
||||
@@ -100,7 +105,7 @@ func (r *GroupRepository) ListActive(ctx context.Context) ([]model.Group, error)
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func (r *GroupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error) {
|
||||
func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error) {
|
||||
var groups []model.Group
|
||||
err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", model.StatusActive, platform).Order("id ASC").Find(&groups).Error
|
||||
if err != nil {
|
||||
@@ -114,25 +119,80 @@ func (r *GroupRepository) ListActiveByPlatform(ctx context.Context, platform str
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func (r *GroupRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||
func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.Group{}).Where("name = ?", name).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *GroupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.AccountGroup{}).Where("group_id = ?", groupID).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// DeleteAccountGroupsByGroupID 删除分组与账号的关联关系
|
||||
func (r *GroupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.AccountGroup{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// DB 返回底层数据库连接,用于事务处理
|
||||
func (r *GroupRepository) DB() *gorm.DB {
|
||||
return r.db
|
||||
func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
|
||||
group, err := r.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var affectedUserIDs []int64
|
||||
if group.IsSubscriptionType() {
|
||||
var subscriptions []model.UserSubscription
|
||||
if err := r.db.WithContext(ctx).
|
||||
Model(&model.UserSubscription{}).
|
||||
Where("group_id = ?", id).
|
||||
Select("user_id").
|
||||
Find(&subscriptions).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, sub := range subscriptions {
|
||||
affectedUserIDs = append(affectedUserIDs, sub.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
err = r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// 1. 删除订阅类型分组的订阅记录
|
||||
if group.IsSubscriptionType() {
|
||||
if err := tx.Where("group_id = ?", id).Delete(&model.UserSubscription{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 将 api_keys 中绑定该分组的 group_id 设为 nil
|
||||
if err := tx.Model(&model.ApiKey{}).Where("group_id = ?", id).Update("group_id", nil).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 3. 从 users.allowed_groups 数组中移除该分组 ID
|
||||
if err := tx.Model(&model.User{}).
|
||||
Where("? = ANY(allowed_groups)", id).
|
||||
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", id)).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 4. 删除 account_groups 中间表的数据
|
||||
if err := tx.Where("group_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 5. 删除分组本身(带锁,避免并发写)
|
||||
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Delete(&model.Group{}, id).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return affectedUserIDs, nil
|
||||
}
|
||||
|
||||
236
backend/internal/repository/group_repo_integration_test.go
Normal file
236
backend/internal/repository/group_repo_integration_test.go
Normal file
@@ -0,0 +1,236 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type GroupRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
repo *groupRepository
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewGroupRepository(s.db).(*groupRepository)
|
||||
}
|
||||
|
||||
func TestGroupRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(GroupRepoSuite))
|
||||
}
|
||||
|
||||
// --- Create / GetByID / Update / Delete ---
|
||||
|
||||
func (s *GroupRepoSuite) TestCreate() {
|
||||
group := &model.Group{
|
||||
Name: "test-create",
|
||||
Platform: model.PlatformAnthropic,
|
||||
Status: model.StatusActive,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, group)
|
||||
s.Require().NoError(err, "Create")
|
||||
s.Require().NotZero(group.ID, "expected ID to be set")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("test-create", got.Name)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestUpdate() {
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "original"})
|
||||
|
||||
group.Name = "updated"
|
||||
err := s.repo.Update(s.ctx, group)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "GetByID after update")
|
||||
s.Require().Equal("updated", got.Name)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestDelete() {
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "to-delete"})
|
||||
|
||||
err := s.repo.Delete(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
|
||||
_, err = s.repo.GetByID(s.ctx, group.ID)
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
}
|
||||
|
||||
// --- List / ListWithFilters ---
|
||||
|
||||
func (s *GroupRepoSuite) TestList() {
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"})
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2"})
|
||||
|
||||
groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List")
|
||||
s.Require().Len(groups, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_Platform() {
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Platform: model.PlatformAnthropic})
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Platform: model.PlatformOpenAI})
|
||||
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.PlatformOpenAI, "", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(groups, 1)
|
||||
s.Require().Equal(model.PlatformOpenAI, groups[0].Platform)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_Status() {
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Status: model.StatusActive})
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Status: model.StatusDisabled})
|
||||
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusDisabled, nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(groups, 1)
|
||||
s.Require().Equal(model.StatusDisabled, groups[0].Status)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", IsExclusive: false})
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", IsExclusive: true})
|
||||
|
||||
isExclusive := true
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(groups, 1)
|
||||
s.Require().True(groups[0].IsExclusive)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
|
||||
g1 := mustCreateGroup(s.T(), s.db, &model.Group{
|
||||
Name: "g1",
|
||||
Platform: model.PlatformAnthropic,
|
||||
Status: model.StatusActive,
|
||||
})
|
||||
g2 := mustCreateGroup(s.T(), s.db, &model.Group{
|
||||
Name: "g2",
|
||||
Platform: model.PlatformAnthropic,
|
||||
Status: model.StatusActive,
|
||||
IsExclusive: true,
|
||||
})
|
||||
|
||||
a := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc1"})
|
||||
mustBindAccountToGroup(s.T(), s.db, a.ID, g1.ID, 1)
|
||||
mustBindAccountToGroup(s.T(), s.db, a.ID, g2.ID, 1)
|
||||
|
||||
isExclusive := true
|
||||
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.PlatformAnthropic, model.StatusActive, &isExclusive)
|
||||
s.Require().NoError(err, "ListWithFilters")
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
s.Require().Len(groups, 1)
|
||||
s.Require().Equal(g2.ID, groups[0].ID, "ListWithFilters returned wrong group")
|
||||
s.Require().Equal(int64(1), groups[0].AccountCount, "AccountCount mismatch")
|
||||
}
|
||||
|
||||
// --- ListActive / ListActiveByPlatform ---
|
||||
|
||||
func (s *GroupRepoSuite) TestListActive() {
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "active1", Status: model.StatusActive})
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "inactive1", Status: model.StatusDisabled})
|
||||
|
||||
groups, err := s.repo.ListActive(s.ctx)
|
||||
s.Require().NoError(err, "ListActive")
|
||||
s.Require().Len(groups, 1)
|
||||
s.Require().Equal("active1", groups[0].Name)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListActiveByPlatform() {
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Platform: model.PlatformAnthropic, Status: model.StatusActive})
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Platform: model.PlatformOpenAI, Status: model.StatusActive})
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g3", Platform: model.PlatformAnthropic, Status: model.StatusDisabled})
|
||||
|
||||
groups, err := s.repo.ListActiveByPlatform(s.ctx, model.PlatformAnthropic)
|
||||
s.Require().NoError(err, "ListActiveByPlatform")
|
||||
s.Require().Len(groups, 1)
|
||||
s.Require().Equal("g1", groups[0].Name)
|
||||
}
|
||||
|
||||
// --- ExistsByName ---
|
||||
|
||||
func (s *GroupRepoSuite) TestExistsByName() {
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "existing-group"})
|
||||
|
||||
exists, err := s.repo.ExistsByName(s.ctx, "existing-group")
|
||||
s.Require().NoError(err, "ExistsByName")
|
||||
s.Require().True(exists)
|
||||
|
||||
notExists, err := s.repo.ExistsByName(s.ctx, "non-existing")
|
||||
s.Require().NoError(err)
|
||||
s.Require().False(notExists)
|
||||
}
|
||||
|
||||
// --- GetAccountCount ---
|
||||
|
||||
func (s *GroupRepoSuite) TestGetAccountCount() {
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"})
|
||||
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1"})
|
||||
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2"})
|
||||
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1)
|
||||
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2)
|
||||
|
||||
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "GetAccountCount")
|
||||
s.Require().Equal(int64(2), count)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-empty"})
|
||||
|
||||
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(count)
|
||||
}
|
||||
|
||||
// --- DeleteAccountGroupsByGroupID ---
|
||||
|
||||
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
|
||||
g := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-del"})
|
||||
a := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-del"})
|
||||
mustBindAccountToGroup(s.T(), s.db, a.ID, g.ID, 1)
|
||||
|
||||
affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
|
||||
s.Require().NoError(err, "DeleteAccountGroupsByGroupID")
|
||||
s.Require().Equal(int64(1), affected, "expected 1 affected row")
|
||||
|
||||
count, err := s.repo.GetAccountCount(s.ctx, g.ID)
|
||||
s.Require().NoError(err, "GetAccountCount")
|
||||
s.Require().Equal(int64(0), count, "expected 0 account groups")
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
|
||||
g := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-multi"})
|
||||
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1"})
|
||||
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2"})
|
||||
a3 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3"})
|
||||
mustBindAccountToGroup(s.T(), s.db, a1.ID, g.ID, 1)
|
||||
mustBindAccountToGroup(s.T(), s.db, a2.ID, g.ID, 2)
|
||||
mustBindAccountToGroup(s.T(), s.db, a3.ID, g.ID, 3)
|
||||
|
||||
affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(int64(3), affected)
|
||||
|
||||
count, _ := s.repo.GetAccountCount(s.ctx, g.ID)
|
||||
s.Require().Zero(count)
|
||||
}
|
||||
@@ -5,16 +5,19 @@ import (
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type claudeUpstreamService struct {
|
||||
// httpUpstreamService is a generic HTTP upstream service that can be used for
|
||||
// making requests to any HTTP API (Claude, OpenAI, etc.) with optional proxy support.
|
||||
type httpUpstreamService struct {
|
||||
defaultClient *http.Client
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func NewClaudeUpstream(cfg *config.Config) service.ClaudeUpstream {
|
||||
// NewHTTPUpstream creates a new generic HTTP upstream service
|
||||
func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream {
|
||||
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
||||
if responseHeaderTimeout == 0 {
|
||||
responseHeaderTimeout = 300 * time.Second
|
||||
@@ -27,13 +30,13 @@ func NewClaudeUpstream(cfg *config.Config) service.ClaudeUpstream {
|
||||
ResponseHeaderTimeout: responseHeaderTimeout,
|
||||
}
|
||||
|
||||
return &claudeUpstreamService{
|
||||
return &httpUpstreamService{
|
||||
defaultClient: &http.Client{Transport: transport},
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *claudeUpstreamService) Do(req *http.Request, proxyURL string) (*http.Response, error) {
|
||||
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string) (*http.Response, error) {
|
||||
if proxyURL == "" {
|
||||
return s.defaultClient.Do(req)
|
||||
}
|
||||
@@ -41,7 +44,7 @@ func (s *claudeUpstreamService) Do(req *http.Request, proxyURL string) (*http.Re
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (s *claudeUpstreamService) createProxyClient(proxyURL string) *http.Client {
|
||||
func (s *httpUpstreamService) createProxyClient(proxyURL string) *http.Client {
|
||||
parsedURL, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
return s.defaultClient
|
||||
115
backend/internal/repository/http_upstream_test.go
Normal file
115
backend/internal/repository/http_upstream_test.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type HTTPUpstreamSuite struct {
|
||||
suite.Suite
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func (s *HTTPUpstreamSuite) SetupTest() {
|
||||
s.cfg = &config.Config{}
|
||||
}
|
||||
|
||||
func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() {
|
||||
up := NewHTTPUpstream(s.cfg)
|
||||
svc, ok := up.(*httpUpstreamService)
|
||||
require.True(s.T(), ok, "expected *httpUpstreamService")
|
||||
transport, ok := svc.defaultClient.Transport.(*http.Transport)
|
||||
require.True(s.T(), ok, "expected *http.Transport")
|
||||
require.Equal(s.T(), 300*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
|
||||
}
|
||||
|
||||
func (s *HTTPUpstreamSuite) TestCustomResponseHeaderTimeout() {
|
||||
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 7}
|
||||
up := NewHTTPUpstream(s.cfg)
|
||||
svc, ok := up.(*httpUpstreamService)
|
||||
require.True(s.T(), ok, "expected *httpUpstreamService")
|
||||
transport, ok := svc.defaultClient.Transport.(*http.Transport)
|
||||
require.True(s.T(), ok, "expected *http.Transport")
|
||||
require.Equal(s.T(), 7*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
|
||||
}
|
||||
|
||||
func (s *HTTPUpstreamSuite) TestCreateProxyClient_InvalidURLFallsBackToDefault() {
|
||||
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 5}
|
||||
up := NewHTTPUpstream(s.cfg)
|
||||
svc, ok := up.(*httpUpstreamService)
|
||||
require.True(s.T(), ok, "expected *httpUpstreamService")
|
||||
|
||||
got := svc.createProxyClient("://bad-proxy-url")
|
||||
require.Equal(s.T(), svc.defaultClient, got, "expected defaultClient fallback")
|
||||
}
|
||||
|
||||
func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() {
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "direct")
|
||||
}))
|
||||
s.T().Cleanup(upstream.Close)
|
||||
|
||||
up := NewHTTPUpstream(s.cfg)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, upstream.URL+"/x", nil)
|
||||
require.NoError(s.T(), err, "NewRequest")
|
||||
resp, err := up.Do(req, "")
|
||||
require.NoError(s.T(), err, "Do")
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
require.Equal(s.T(), "direct", string(b), "unexpected body")
|
||||
}
|
||||
|
||||
func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
|
||||
seen := make(chan string, 1)
|
||||
proxySrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
seen <- r.RequestURI
|
||||
_, _ = io.WriteString(w, "proxied")
|
||||
}))
|
||||
s.T().Cleanup(proxySrv.Close)
|
||||
|
||||
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 1}
|
||||
up := NewHTTPUpstream(s.cfg)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil)
|
||||
require.NoError(s.T(), err, "NewRequest")
|
||||
resp, err := up.Do(req, proxySrv.URL)
|
||||
require.NoError(s.T(), err, "Do")
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
require.Equal(s.T(), "proxied", string(b), "unexpected body")
|
||||
|
||||
select {
|
||||
case uri := <-seen:
|
||||
require.Equal(s.T(), "http://example.com/test", uri, "expected absolute-form request URI")
|
||||
default:
|
||||
require.Fail(s.T(), "expected proxy to receive request")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HTTPUpstreamSuite) TestDo_EmptyProxy_UsesDirect() {
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "direct-empty")
|
||||
}))
|
||||
s.T().Cleanup(upstream.Close)
|
||||
|
||||
up := NewHTTPUpstream(s.cfg)
|
||||
req, err := http.NewRequest(http.MethodGet, upstream.URL+"/y", nil)
|
||||
require.NoError(s.T(), err, "NewRequest")
|
||||
resp, err := up.Do(req, "")
|
||||
require.NoError(s.T(), err, "Do with empty proxy")
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
require.Equal(s.T(), "direct-empty", string(b))
|
||||
}
|
||||
|
||||
func TestHTTPUpstreamSuite(t *testing.T) {
|
||||
suite.Run(t, new(HTTPUpstreamSuite))
|
||||
}
|
||||
@@ -6,8 +6,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
@@ -20,24 +19,24 @@ type identityCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewIdentityCache(rdb *redis.Client) ports.IdentityCache {
|
||||
func NewIdentityCache(rdb *redis.Client) service.IdentityCache {
|
||||
return &identityCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*ports.Fingerprint, error) {
|
||||
func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*service.Fingerprint, error) {
|
||||
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var fp ports.Fingerprint
|
||||
var fp service.Fingerprint
|
||||
if err := json.Unmarshal([]byte(val), &fp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &fp, nil
|
||||
}
|
||||
|
||||
func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp *ports.Fingerprint) error {
|
||||
func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp *service.Fingerprint) error {
|
||||
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
|
||||
val, err := json.Marshal(fp)
|
||||
if err != nil {
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type IdentityCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache *identityCache
|
||||
}
|
||||
|
||||
func (s *IdentityCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewIdentityCache(s.rdb).(*identityCache)
|
||||
}
|
||||
|
||||
func (s *IdentityCacheSuite) TestGetFingerprint_Missing() {
|
||||
_, err := s.cache.GetFingerprint(s.ctx, 1)
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing fingerprint")
|
||||
}
|
||||
|
||||
func (s *IdentityCacheSuite) TestSetAndGetFingerprint() {
|
||||
fp := &service.Fingerprint{ClientID: "c1", UserAgent: "ua"}
|
||||
require.NoError(s.T(), s.cache.SetFingerprint(s.ctx, 1, fp), "SetFingerprint")
|
||||
gotFP, err := s.cache.GetFingerprint(s.ctx, 1)
|
||||
require.NoError(s.T(), err, "GetFingerprint")
|
||||
require.Equal(s.T(), "c1", gotFP.ClientID)
|
||||
require.Equal(s.T(), "ua", gotFP.UserAgent)
|
||||
}
|
||||
|
||||
func (s *IdentityCacheSuite) TestFingerprint_TTL() {
|
||||
fp := &service.Fingerprint{ClientID: "c1", UserAgent: "ua"}
|
||||
require.NoError(s.T(), s.cache.SetFingerprint(s.ctx, 2, fp))
|
||||
|
||||
fpKey := fmt.Sprintf("%s%d", fingerprintKeyPrefix, 2)
|
||||
ttl, err := s.rdb.TTL(s.ctx, fpKey).Result()
|
||||
require.NoError(s.T(), err, "TTL fpKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, fingerprintTTL)
|
||||
}
|
||||
|
||||
func (s *IdentityCacheSuite) TestGetFingerprint_JSONCorruption() {
|
||||
fpKey := fmt.Sprintf("%s%d", fingerprintKeyPrefix, 999)
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, fpKey, "invalid-json-data", 1*time.Minute).Err(), "Set invalid JSON")
|
||||
|
||||
_, err := s.cache.GetFingerprint(s.ctx, 999)
|
||||
require.Error(s.T(), err, "expected error for corrupted JSON")
|
||||
require.False(s.T(), errors.Is(err, redis.Nil), "expected decoding error, not redis.Nil")
|
||||
}
|
||||
|
||||
func (s *IdentityCacheSuite) TestSetFingerprint_Nil() {
|
||||
err := s.cache.SetFingerprint(s.ctx, 100, nil)
|
||||
require.NoError(s.T(), err, "SetFingerprint(nil) should succeed")
|
||||
}
|
||||
|
||||
func TestIdentityCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(IdentityCacheSuite))
|
||||
}
|
||||
369
backend/internal/repository/integration_harness_test.go
Normal file
369
backend/internal/repository/integration_harness_test.go
Normal file
@@ -0,0 +1,369 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
redisclient "github.com/redis/go-redis/v9"
|
||||
tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
|
||||
gormpostgres "gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
redisImageTag = "redis:8.4-alpine"
|
||||
postgresImageTag = "postgres:18.1-alpine3.23"
|
||||
)
|
||||
|
||||
var (
|
||||
integrationDB *gorm.DB
|
||||
integrationRedis *redisclient.Client
|
||||
|
||||
redisNamespaceSeq uint64
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
ctx := context.Background()
|
||||
|
||||
if err := timezone.Init("UTC"); err != nil {
|
||||
log.Printf("failed to init timezone: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if !dockerIsAvailable(ctx) {
|
||||
// In CI we expect Docker to be available so integration tests should fail loudly.
|
||||
if os.Getenv("CI") != "" {
|
||||
log.Printf("docker is not available (CI=true); failing integration tests")
|
||||
os.Exit(1)
|
||||
}
|
||||
log.Printf("docker is not available; skipping integration tests (start Docker to enable)")
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
postgresImage := selectDockerImage(ctx, postgresImageTag)
|
||||
pgContainer, err := tcpostgres.Run(
|
||||
ctx,
|
||||
postgresImage,
|
||||
tcpostgres.WithDatabase("sub2api_test"),
|
||||
tcpostgres.WithUsername("postgres"),
|
||||
tcpostgres.WithPassword("postgres"),
|
||||
tcpostgres.BasicWaitStrategies(),
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("failed to start postgres container: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer func() { _ = pgContainer.Terminate(ctx) }()
|
||||
|
||||
redisContainer, err := tcredis.Run(
|
||||
ctx,
|
||||
redisImageTag,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("failed to start redis container: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer func() { _ = redisContainer.Terminate(ctx) }()
|
||||
|
||||
dsn, err := pgContainer.ConnectionString(ctx, "sslmode=disable", "TimeZone=UTC")
|
||||
if err != nil {
|
||||
log.Printf("failed to get postgres dsn: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
integrationDB, err = openGormWithRetry(ctx, dsn, 30*time.Second)
|
||||
if err != nil {
|
||||
log.Printf("failed to open gorm db: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if err := model.AutoMigrate(integrationDB); err != nil {
|
||||
log.Printf("failed to automigrate db: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
redisHost, err := redisContainer.Host(ctx)
|
||||
if err != nil {
|
||||
log.Printf("failed to get redis host: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp")
|
||||
if err != nil {
|
||||
log.Printf("failed to get redis port: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
integrationRedis = redisclient.NewClient(&redisclient.Options{
|
||||
Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()),
|
||||
DB: 0,
|
||||
})
|
||||
if err := integrationRedis.Ping(ctx).Err(); err != nil {
|
||||
log.Printf("failed to ping redis: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
code := m.Run()
|
||||
|
||||
_ = integrationRedis.Close()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func dockerIsAvailable(ctx context.Context) bool {
|
||||
cmd := exec.CommandContext(ctx, "docker", "info")
|
||||
cmd.Env = os.Environ()
|
||||
return cmd.Run() == nil
|
||||
}
|
||||
|
||||
func selectDockerImage(ctx context.Context, preferred string) string {
|
||||
if dockerImageExists(ctx, preferred) {
|
||||
return preferred
|
||||
}
|
||||
|
||||
return preferred
|
||||
}
|
||||
|
||||
func dockerImageExists(ctx context.Context, image string) bool {
|
||||
cmd := exec.CommandContext(ctx, "docker", "image", "inspect", image)
|
||||
cmd.Env = os.Environ()
|
||||
cmd.Stdout = nil
|
||||
cmd.Stderr = nil
|
||||
return cmd.Run() == nil
|
||||
}
|
||||
|
||||
func openGormWithRetry(ctx context.Context, dsn string, timeout time.Duration) (*gorm.DB, error) {
|
||||
deadline := time.Now().Add(timeout)
|
||||
var lastErr error
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
db, err := gorm.Open(gormpostgres.Open(dsn), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := pingWithTimeout(ctx, sqlDB, 2*time.Second); err != nil {
|
||||
lastErr = err
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("db not ready after %s: %w", timeout, lastErr)
|
||||
}
|
||||
|
||||
func pingWithTimeout(ctx context.Context, db *sql.DB, timeout time.Duration) error {
|
||||
pingCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
return db.PingContext(pingCtx)
|
||||
}
|
||||
|
||||
func testTx(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
tx := integrationDB.Begin()
|
||||
require.NoError(t, tx.Error, "begin tx")
|
||||
t.Cleanup(func() {
|
||||
_ = tx.Rollback().Error
|
||||
})
|
||||
return tx
|
||||
}
|
||||
|
||||
func testRedis(t *testing.T) *redisclient.Client {
|
||||
t.Helper()
|
||||
|
||||
prefix := fmt.Sprintf(
|
||||
"it:%s:%d:%d:",
|
||||
sanitizeRedisNamespace(t.Name()),
|
||||
time.Now().UnixNano(),
|
||||
atomic.AddUint64(&redisNamespaceSeq, 1),
|
||||
)
|
||||
|
||||
opts := *integrationRedis.Options()
|
||||
rdb := redisclient.NewClient(&opts)
|
||||
rdb.AddHook(prefixHook{prefix: prefix})
|
||||
|
||||
t.Cleanup(func() {
|
||||
ctx := context.Background()
|
||||
|
||||
var cursor uint64
|
||||
for {
|
||||
keys, nextCursor, err := integrationRedis.Scan(ctx, cursor, prefix+"*", 500).Result()
|
||||
require.NoError(t, err, "scan redis keys for cleanup")
|
||||
if len(keys) > 0 {
|
||||
require.NoError(t, integrationRedis.Unlink(ctx, keys...).Err(), "unlink redis keys for cleanup")
|
||||
}
|
||||
|
||||
cursor = nextCursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
_ = rdb.Close()
|
||||
})
|
||||
|
||||
return rdb
|
||||
}
|
||||
|
||||
func assertTTLWithin(t *testing.T, ttl time.Duration, min, max time.Duration) {
|
||||
t.Helper()
|
||||
require.GreaterOrEqual(t, ttl, min, "ttl should be >= min")
|
||||
require.LessOrEqual(t, ttl, max, "ttl should be <= max")
|
||||
}
|
||||
|
||||
func sanitizeRedisNamespace(name string) string {
|
||||
name = strings.ReplaceAll(name, "/", "_")
|
||||
name = strings.ReplaceAll(name, " ", "_")
|
||||
return name
|
||||
}
|
||||
|
||||
type prefixHook struct {
|
||||
prefix string
|
||||
}
|
||||
|
||||
func (h prefixHook) DialHook(next redisclient.DialHook) redisclient.DialHook { return next }
|
||||
|
||||
func (h prefixHook) ProcessHook(next redisclient.ProcessHook) redisclient.ProcessHook {
|
||||
return func(ctx context.Context, cmd redisclient.Cmder) error {
|
||||
h.prefixCmd(cmd)
|
||||
return next(ctx, cmd)
|
||||
}
|
||||
}
|
||||
|
||||
func (h prefixHook) ProcessPipelineHook(next redisclient.ProcessPipelineHook) redisclient.ProcessPipelineHook {
|
||||
return func(ctx context.Context, cmds []redisclient.Cmder) error {
|
||||
for _, cmd := range cmds {
|
||||
h.prefixCmd(cmd)
|
||||
}
|
||||
return next(ctx, cmds)
|
||||
}
|
||||
}
|
||||
|
||||
func (h prefixHook) prefixCmd(cmd redisclient.Cmder) {
|
||||
args := cmd.Args()
|
||||
if len(args) < 2 {
|
||||
return
|
||||
}
|
||||
|
||||
prefixOne := func(i int) {
|
||||
if i < 0 || i >= len(args) {
|
||||
return
|
||||
}
|
||||
|
||||
switch v := args[i].(type) {
|
||||
case string:
|
||||
if v != "" && !strings.HasPrefix(v, h.prefix) {
|
||||
args[i] = h.prefix + v
|
||||
}
|
||||
case []byte:
|
||||
s := string(v)
|
||||
if s != "" && !strings.HasPrefix(s, h.prefix) {
|
||||
args[i] = []byte(h.prefix + s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch strings.ToLower(cmd.Name()) {
|
||||
case "get", "set", "setnx", "setex", "psetex", "incr", "decr", "incrby", "expire", "pexpire", "ttl", "pttl",
|
||||
"hgetall", "hget", "hset", "hdel", "hincrbyfloat", "exists":
|
||||
prefixOne(1)
|
||||
case "del", "unlink":
|
||||
for i := 1; i < len(args); i++ {
|
||||
prefixOne(i)
|
||||
}
|
||||
case "eval", "evalsha", "eval_ro", "evalsha_ro":
|
||||
if len(args) < 3 {
|
||||
return
|
||||
}
|
||||
numKeys, err := strconv.Atoi(fmt.Sprint(args[2]))
|
||||
if err != nil || numKeys <= 0 {
|
||||
return
|
||||
}
|
||||
for i := 0; i < numKeys && 3+i < len(args); i++ {
|
||||
prefixOne(3 + i)
|
||||
}
|
||||
case "scan":
|
||||
for i := 2; i+1 < len(args); i++ {
|
||||
if strings.EqualFold(fmt.Sprint(args[i]), "match") {
|
||||
prefixOne(i + 1)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IntegrationRedisSuite provides a base suite for Redis integration tests.
|
||||
// Embedding suites should call SetupTest to initialize ctx and rdb.
|
||||
type IntegrationRedisSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
rdb *redisclient.Client
|
||||
}
|
||||
|
||||
// SetupTest initializes ctx and rdb for each test method.
|
||||
func (s *IntegrationRedisSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.rdb = testRedis(s.T())
|
||||
}
|
||||
|
||||
// RequireNoError is a convenience method wrapping require.NoError with s.T().
|
||||
func (s *IntegrationRedisSuite) RequireNoError(err error, msgAndArgs ...any) {
|
||||
s.T().Helper()
|
||||
require.NoError(s.T(), err, msgAndArgs...)
|
||||
}
|
||||
|
||||
// AssertTTLWithin asserts that ttl is within [min, max].
|
||||
func (s *IntegrationRedisSuite) AssertTTLWithin(ttl, min, max time.Duration) {
|
||||
s.T().Helper()
|
||||
assertTTLWithin(s.T(), ttl, min, max)
|
||||
}
|
||||
|
||||
// IntegrationDBSuite provides a base suite for DB (Gorm) integration tests.
|
||||
// Embedding suites should call SetupTest to initialize ctx and db.
|
||||
type IntegrationDBSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// SetupTest initializes ctx and db for each test method.
|
||||
func (s *IntegrationDBSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
}
|
||||
|
||||
// RequireNoError is a convenience method wrapping require.NoError with s.T().
|
||||
func (s *IntegrationDBSuite) RequireNoError(err error, msgAndArgs ...any) {
|
||||
s.T().Helper()
|
||||
require.NoError(s.T(), err, msgAndArgs...)
|
||||
}
|
||||
93
backend/internal/repository/openai_oauth_service.go
Normal file
93
backend/internal/repository/openai_oauth_service.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
// NewOpenAIOAuthClient creates a new OpenAI OAuth client
|
||||
func NewOpenAIOAuthClient() service.OpenAIOAuthClient {
|
||||
return &openaiOAuthService{tokenURL: openai.TokenURL}
|
||||
}
|
||||
|
||||
type openaiOAuthService struct {
|
||||
tokenURL string
|
||||
}
|
||||
|
||||
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
|
||||
client := createOpenAIReqClient(proxyURL)
|
||||
|
||||
if redirectURI == "" {
|
||||
redirectURI = openai.DefaultRedirectURI
|
||||
}
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "authorization_code")
|
||||
formData.Set("client_id", openai.ClientID)
|
||||
formData.Set("code", code)
|
||||
formData.Set("redirect_uri", redirectURI)
|
||||
formData.Set("code_verifier", codeVerifier)
|
||||
|
||||
var tokenResp openai.TokenResponse
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetFormDataFromValues(formData).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(s.tokenURL)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
||||
client := createOpenAIReqClient(proxyURL)
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "refresh_token")
|
||||
formData.Set("refresh_token", refreshToken)
|
||||
formData.Set("client_id", openai.ClientID)
|
||||
formData.Set("scope", openai.RefreshScopes)
|
||||
|
||||
var tokenResp openai.TokenResponse
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetFormDataFromValues(formData).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(s.tokenURL)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func createOpenAIReqClient(proxyURL string) *req.Client {
|
||||
client := req.C().
|
||||
SetTimeout(60 * time.Second)
|
||||
|
||||
if proxyURL != "" {
|
||||
client.SetProxyURL(proxyURL)
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
249
backend/internal/repository/openai_oauth_service_test.go
Normal file
249
backend/internal/repository/openai_oauth_service_test.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type OpenAIOAuthServiceSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
srv *httptest.Server
|
||||
svc *openaiOAuthService
|
||||
received chan url.Values
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.received = make(chan url.Values, 1)
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TearDownTest() {
|
||||
if s.srv != nil {
|
||||
s.srv.Close()
|
||||
s.srv = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) setupServer(handler http.HandlerFunc) {
|
||||
s.srv = httptest.NewServer(handler)
|
||||
s.svc = &openaiOAuthService{tokenURL: s.srv.URL}
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_DefaultRedirectURI() {
|
||||
errCh := make(chan string, 1)
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
errCh <- "method mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
errCh <- "ParseForm failed"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("grant_type"); got != "authorization_code" {
|
||||
errCh <- "grant_type mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("client_id"); got != openai.ClientID {
|
||||
errCh <- "client_id mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("code"); got != "code" {
|
||||
errCh <- "code mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("redirect_uri"); got != openai.DefaultRedirectURI {
|
||||
errCh <- "redirect_uri mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("code_verifier"); got != "ver" {
|
||||
errCh <- "code_verifier mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`)
|
||||
}))
|
||||
|
||||
resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "")
|
||||
require.NoError(s.T(), err, "ExchangeCode")
|
||||
select {
|
||||
case msg := <-errCh:
|
||||
require.Fail(s.T(), msg)
|
||||
default:
|
||||
}
|
||||
require.Equal(s.T(), "at", resp.AccessToken)
|
||||
require.Equal(s.T(), "rt", resp.RefreshToken)
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
|
||||
errCh := make(chan string, 1)
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
errCh <- "ParseForm failed"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("grant_type"); got != "refresh_token" {
|
||||
errCh <- "grant_type mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("refresh_token"); got != "rt" {
|
||||
errCh <- "refresh_token mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("client_id"); got != openai.ClientID {
|
||||
errCh <- "client_id mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("scope"); got != openai.RefreshScopes {
|
||||
errCh <- "scope mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at2","refresh_token":"rt2","token_type":"bearer","expires_in":3600}`)
|
||||
}))
|
||||
|
||||
resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
|
||||
require.NoError(s.T(), err, "RefreshToken")
|
||||
select {
|
||||
case msg := <-errCh:
|
||||
require.Fail(s.T(), msg)
|
||||
default:
|
||||
}
|
||||
require.Equal(s.T(), "at2", resp.AccessToken)
|
||||
require.Equal(s.T(), "rt2", resp.RefreshToken)
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = io.WriteString(w, "bad")
|
||||
}))
|
||||
|
||||
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "status 400")
|
||||
require.ErrorContains(s.T(), err, "bad")
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
s.srv.Close()
|
||||
|
||||
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "request failed")
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestContextCancel() {
|
||||
started := make(chan struct{})
|
||||
block := make(chan struct{})
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
close(started)
|
||||
<-block
|
||||
}))
|
||||
|
||||
ctx, cancel := context.WithCancel(s.ctx)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "")
|
||||
done <- err
|
||||
}()
|
||||
|
||||
<-started
|
||||
cancel()
|
||||
close(block)
|
||||
|
||||
err := <-done
|
||||
require.Error(s.T(), err)
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
|
||||
want := "http://localhost:9999/cb"
|
||||
errCh := make(chan string, 1)
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = r.ParseForm()
|
||||
if got := r.PostForm.Get("redirect_uri"); got != want {
|
||||
errCh <- "redirect_uri mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
|
||||
}))
|
||||
|
||||
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "")
|
||||
require.NoError(s.T(), err, "ExchangeCode")
|
||||
select {
|
||||
case msg := <-errCh:
|
||||
require.Fail(s.T(), msg)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestTokenURL_CanBeOverriddenWithQuery() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = r.ParseForm()
|
||||
s.received <- r.PostForm
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
|
||||
}))
|
||||
s.svc.tokenURL = s.srv.URL + "?x=1"
|
||||
|
||||
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
|
||||
require.NoError(s.T(), err, "ExchangeCode")
|
||||
select {
|
||||
case <-s.received:
|
||||
default:
|
||||
require.Fail(s.T(), "expected server to receive request")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_SuccessButInvalidJSON() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = io.WriteString(w, "not-valid-json")
|
||||
}))
|
||||
|
||||
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
|
||||
require.Error(s.T(), err, "expected error for invalid JSON response")
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = io.WriteString(w, "unauthorized")
|
||||
}))
|
||||
|
||||
_, err := s.svc.RefreshToken(s.ctx, "rt", "")
|
||||
require.Error(s.T(), err, "expected error for non-2xx status")
|
||||
require.ErrorContains(s.T(), err, "status 401")
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(OpenAIOAuthServiceSuite))
|
||||
}
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type pricingRemoteClient struct {
|
||||
@@ -33,7 +33,7 @@ func (c *pricingRemoteClient) FetchPricingJSON(ctx context.Context, url string)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
@@ -52,7 +52,7 @@ func (c *pricingRemoteClient) FetchHashText(ctx context.Context, url string) (st
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
|
||||
147
backend/internal/repository/pricing_service_test.go
Normal file
147
backend/internal/repository/pricing_service_test.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type PricingServiceSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
srv *httptest.Server
|
||||
client *pricingRemoteClient
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
client, ok := NewPricingRemoteClient().(*pricingRemoteClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TearDownTest() {
|
||||
if s.srv != nil {
|
||||
s.srv.Close()
|
||||
s.srv = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) setupServer(handler http.HandlerFunc) {
|
||||
s.srv = httptest.NewServer(handler)
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchPricingJSON_Success() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/ok" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"ok":true}`))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
|
||||
body, err := s.client.FetchPricingJSON(s.ctx, s.srv.URL+"/ok")
|
||||
require.NoError(s.T(), err, "FetchPricingJSON")
|
||||
require.Equal(s.T(), `{"ok":true}`, string(body), "body mismatch")
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchPricingJSON_NonOKStatus() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
|
||||
_, err := s.client.FetchPricingJSON(s.ctx, s.srv.URL+"/err")
|
||||
require.Error(s.T(), err, "expected error for non-200 status")
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchHashText_ParsesFields() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/hashfile":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("abc123 model_prices.json\n"))
|
||||
case "/hashonly":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("def456\n"))
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
|
||||
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/hashfile")
|
||||
require.NoError(s.T(), err, "FetchHashText")
|
||||
require.Equal(s.T(), "abc123", hash, "hash mismatch")
|
||||
|
||||
hash2, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/hashonly")
|
||||
require.NoError(s.T(), err, "FetchHashText")
|
||||
require.Equal(s.T(), "def456", hash2, "hash mismatch")
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchHashText_NonOKStatus() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
|
||||
_, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/nope")
|
||||
require.Error(s.T(), err, "expected error for non-200 status")
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchPricingJSON_InvalidURL() {
|
||||
_, err := s.client.FetchPricingJSON(s.ctx, "://invalid-url")
|
||||
require.Error(s.T(), err, "expected error for invalid URL")
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchHashText_EmptyBody() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
// empty body
|
||||
}))
|
||||
|
||||
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/empty")
|
||||
require.NoError(s.T(), err, "FetchHashText empty body should not error")
|
||||
require.Equal(s.T(), "", hash, "expected empty hash")
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchHashText_WhitespaceOnly() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(" \n"))
|
||||
}))
|
||||
|
||||
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/ws")
|
||||
require.NoError(s.T(), err, "FetchHashText whitespace body should not error")
|
||||
require.Equal(s.T(), "", hash, "expected empty hash after trimming")
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() {
|
||||
started := make(chan struct{})
|
||||
block := make(chan struct{})
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
close(started)
|
||||
<-block
|
||||
}))
|
||||
|
||||
ctx, cancel := context.WithCancel(s.ctx)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := s.client.FetchPricingJSON(ctx, s.srv.URL+"/block")
|
||||
done <- err
|
||||
}()
|
||||
|
||||
<-started
|
||||
cancel()
|
||||
close(block)
|
||||
|
||||
err := <-done
|
||||
require.Error(s.T(), err)
|
||||
}
|
||||
|
||||
func TestPricingServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(PricingServiceSuite))
|
||||
}
|
||||
@@ -11,15 +11,19 @@ import (
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
type proxyProbeService struct{}
|
||||
|
||||
func NewProxyExitInfoProber() service.ProxyExitInfoProber {
|
||||
return &proxyProbeService{}
|
||||
return &proxyProbeService{ipInfoURL: defaultIPInfoURL}
|
||||
}
|
||||
|
||||
const defaultIPInfoURL = "https://ipinfo.io/json"
|
||||
|
||||
type proxyProbeService struct {
|
||||
ipInfoURL string
|
||||
}
|
||||
|
||||
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
|
||||
@@ -34,7 +38,7 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://ipinfo.io/json", nil)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", s.ipInfoURL, nil)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
@@ -43,7 +47,7 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("proxy connection failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
latencyMs := time.Since(startTime).Milliseconds()
|
||||
|
||||
|
||||
121
backend/internal/repository/proxy_probe_service_test.go
Normal file
121
backend/internal/repository/proxy_probe_service_test.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ProxyProbeServiceSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
proxySrv *httptest.Server
|
||||
prober *proxyProbeService
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.prober = &proxyProbeService{ipInfoURL: "http://ipinfo.test/json"}
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TearDownTest() {
|
||||
if s.proxySrv != nil {
|
||||
s.proxySrv.Close()
|
||||
s.proxySrv = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) setupProxyServer(handler http.HandlerFunc) {
|
||||
s.proxySrv = httptest.NewServer(handler)
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_InvalidURL() {
|
||||
_, err := createProxyTransport("://bad")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "invalid proxy URL")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_UnsupportedScheme() {
|
||||
_, err := createProxyTransport("ftp://127.0.0.1:1")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "unsupported proxy protocol")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_Socks5SetsDialer() {
|
||||
tr, err := createProxyTransport("socks5://127.0.0.1:1080")
|
||||
require.NoError(s.T(), err, "createProxyTransport")
|
||||
require.NotNil(s.T(), tr.DialContext, "expected DialContext to be set for socks5")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
|
||||
seen := make(chan string, 1)
|
||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
seen <- r.RequestURI
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"ip":"1.2.3.4","city":"c","region":"r","country":"cc"}`)
|
||||
}))
|
||||
|
||||
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||
require.NoError(s.T(), err, "ProbeProxy")
|
||||
require.GreaterOrEqual(s.T(), latencyMs, int64(0), "unexpected latency")
|
||||
require.Equal(s.T(), "1.2.3.4", info.IP)
|
||||
require.Equal(s.T(), "c", info.City)
|
||||
require.Equal(s.T(), "r", info.Region)
|
||||
require.Equal(s.T(), "cc", info.Country)
|
||||
|
||||
// Verify proxy received the request
|
||||
select {
|
||||
case uri := <-seen:
|
||||
require.Contains(s.T(), uri, "ipinfo.test", "expected request to go through proxy")
|
||||
default:
|
||||
require.Fail(s.T(), "expected proxy to receive request")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_NonOKStatus() {
|
||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}))
|
||||
|
||||
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "status: 503")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidJSON() {
|
||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, "not-json")
|
||||
}))
|
||||
|
||||
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "failed to parse response")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidIPInfoURL() {
|
||||
s.prober.ipInfoURL = "://invalid-url"
|
||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||
require.Error(s.T(), err, "expected error for invalid ipInfoURL")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_ProxyServerClosed() {
|
||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
s.proxySrv.Close()
|
||||
|
||||
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||
require.Error(s.T(), err, "expected error when proxy server is closed")
|
||||
}
|
||||
|
||||
func TestProxyProbeServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(ProxyProbeServiceSuite))
|
||||
}
|
||||
@@ -2,47 +2,50 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ProxyRepository struct {
|
||||
type proxyRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewProxyRepository(db *gorm.DB) *ProxyRepository {
|
||||
return &ProxyRepository{db: db}
|
||||
func NewProxyRepository(db *gorm.DB) service.ProxyRepository {
|
||||
return &proxyRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *ProxyRepository) Create(ctx context.Context, proxy *model.Proxy) error {
|
||||
func (r *proxyRepository) Create(ctx context.Context, proxy *model.Proxy) error {
|
||||
return r.db.WithContext(ctx).Create(proxy).Error
|
||||
}
|
||||
|
||||
func (r *ProxyRepository) GetByID(ctx context.Context, id int64) (*model.Proxy, error) {
|
||||
func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*model.Proxy, error) {
|
||||
var proxy model.Proxy
|
||||
err := r.db.WithContext(ctx).First(&proxy, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, translatePersistenceError(err, service.ErrProxyNotFound, nil)
|
||||
}
|
||||
return &proxy, nil
|
||||
}
|
||||
|
||||
func (r *ProxyRepository) Update(ctx context.Context, proxy *model.Proxy) error {
|
||||
func (r *proxyRepository) Update(ctx context.Context, proxy *model.Proxy) error {
|
||||
return r.db.WithContext(ctx).Save(proxy).Error
|
||||
}
|
||||
|
||||
func (r *ProxyRepository) Delete(ctx context.Context, id int64) error {
|
||||
func (r *proxyRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error
|
||||
}
|
||||
|
||||
func (r *ProxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) {
|
||||
func (r *proxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", "")
|
||||
}
|
||||
|
||||
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query
|
||||
func (r *ProxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) {
|
||||
func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) {
|
||||
var proxies []model.Proxy
|
||||
var total int64
|
||||
|
||||
@@ -81,14 +84,14 @@ func (r *ProxyRepository) ListWithFilters(ctx context.Context, params pagination
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *ProxyRepository) ListActive(ctx context.Context) ([]model.Proxy, error) {
|
||||
func (r *proxyRepository) ListActive(ctx context.Context) ([]model.Proxy, error) {
|
||||
var proxies []model.Proxy
|
||||
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Find(&proxies).Error
|
||||
return proxies, err
|
||||
}
|
||||
|
||||
// ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists
|
||||
func (r *ProxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
|
||||
func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.Proxy{}).
|
||||
Where("host = ? AND port = ? AND username = ? AND password = ?", host, port, username, password).
|
||||
@@ -100,7 +103,7 @@ func (r *ProxyRepository) ExistsByHostPortAuth(ctx context.Context, host string,
|
||||
}
|
||||
|
||||
// CountAccountsByProxyID returns the number of accounts using a specific proxy
|
||||
func (r *ProxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
|
||||
func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.Account{}).
|
||||
Where("proxy_id = ?", proxyID).
|
||||
@@ -109,7 +112,7 @@ func (r *ProxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID in
|
||||
}
|
||||
|
||||
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
|
||||
func (r *ProxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[int64]int64, error) {
|
||||
func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[int64]int64, error) {
|
||||
type result struct {
|
||||
ProxyID int64 `gorm:"column:proxy_id"`
|
||||
Count int64 `gorm:"column:count"`
|
||||
@@ -133,7 +136,7 @@ func (r *ProxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[i
|
||||
}
|
||||
|
||||
// ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending
|
||||
func (r *ProxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) {
|
||||
func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) {
|
||||
var proxies []model.Proxy
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status = ?", model.StatusActive).
|
||||
|
||||
302
backend/internal/repository/proxy_repo_integration_test.go
Normal file
302
backend/internal/repository/proxy_repo_integration_test.go
Normal file
@@ -0,0 +1,302 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ProxyRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
repo *proxyRepository
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewProxyRepository(s.db).(*proxyRepository)
|
||||
}
|
||||
|
||||
func TestProxyRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(ProxyRepoSuite))
|
||||
}
|
||||
|
||||
// --- Create / GetByID / Update / Delete ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestCreate() {
|
||||
proxy := &model.Proxy{
|
||||
Name: "test-create",
|
||||
Protocol: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Status: model.StatusActive,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, proxy)
|
||||
s.Require().NoError(err, "Create")
|
||||
s.Require().NotZero(proxy.ID, "expected ID to be set")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, proxy.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("test-create", got.Name)
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestUpdate() {
|
||||
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "original"})
|
||||
|
||||
proxy.Name = "updated"
|
||||
err := s.repo.Update(s.ctx, proxy)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, proxy.ID)
|
||||
s.Require().NoError(err, "GetByID after update")
|
||||
s.Require().Equal("updated", got.Name)
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestDelete() {
|
||||
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "to-delete"})
|
||||
|
||||
err := s.repo.Delete(s.ctx, proxy.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
|
||||
_, err = s.repo.GetByID(s.ctx, proxy.ID)
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
}
|
||||
|
||||
// --- List / ListWithFilters ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestList() {
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"})
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"})
|
||||
|
||||
proxies, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List")
|
||||
s.Require().Len(proxies, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestListWithFilters_Protocol() {
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Protocol: "http"})
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Protocol: "socks5"})
|
||||
|
||||
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "socks5", "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(proxies, 1)
|
||||
s.Require().Equal("socks5", proxies[0].Protocol)
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestListWithFilters_Status() {
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Status: model.StatusActive})
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Status: model.StatusDisabled})
|
||||
|
||||
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusDisabled, "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(proxies, 1)
|
||||
s.Require().Equal(model.StatusDisabled, proxies[0].Status)
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestListWithFilters_Search() {
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "production-proxy"})
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "dev-proxy"})
|
||||
|
||||
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "prod")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(proxies, 1)
|
||||
s.Require().Contains(proxies[0].Name, "production")
|
||||
}
|
||||
|
||||
// --- ListActive ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestListActive() {
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "active1", Status: model.StatusActive})
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "inactive1", Status: model.StatusDisabled})
|
||||
|
||||
proxies, err := s.repo.ListActive(s.ctx)
|
||||
s.Require().NoError(err, "ListActive")
|
||||
s.Require().Len(proxies, 1)
|
||||
s.Require().Equal("active1", proxies[0].Name)
|
||||
}
|
||||
|
||||
// --- ExistsByHostPortAuth ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestExistsByHostPortAuth() {
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{
|
||||
Name: "p1",
|
||||
Protocol: "http",
|
||||
Host: "1.2.3.4",
|
||||
Port: 8080,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
})
|
||||
|
||||
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "user", "pass")
|
||||
s.Require().NoError(err, "ExistsByHostPortAuth")
|
||||
s.Require().True(exists)
|
||||
|
||||
notExists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "wrong", "creds")
|
||||
s.Require().NoError(err)
|
||||
s.Require().False(notExists)
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() {
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{
|
||||
Name: "p-noauth",
|
||||
Protocol: "http",
|
||||
Host: "5.6.7.8",
|
||||
Port: 8081,
|
||||
Username: "",
|
||||
Password: "",
|
||||
})
|
||||
|
||||
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "5.6.7.8", 8081, "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(exists)
|
||||
}
|
||||
|
||||
// --- CountAccountsByProxyID ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestCountAccountsByProxyID() {
|
||||
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-count"})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &proxy.ID})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &proxy.ID})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3"}) // no proxy
|
||||
|
||||
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
|
||||
s.Require().NoError(err, "CountAccountsByProxyID")
|
||||
s.Require().Equal(int64(2), count)
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() {
|
||||
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-zero"})
|
||||
|
||||
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(count)
|
||||
}
|
||||
|
||||
// --- GetAccountCountsForProxies ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestGetAccountCountsForProxies() {
|
||||
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"})
|
||||
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"})
|
||||
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
|
||||
|
||||
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
|
||||
s.Require().NoError(err, "GetAccountCountsForProxies")
|
||||
s.Require().Equal(int64(2), counts[p1.ID])
|
||||
s.Require().Equal(int64(1), counts[p2.ID])
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestGetAccountCountsForProxies_Empty() {
|
||||
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Empty(counts)
|
||||
}
|
||||
|
||||
// --- ListActiveWithAccountCount ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestListActiveWithAccountCount() {
|
||||
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{
|
||||
Name: "p1",
|
||||
Status: model.StatusActive,
|
||||
CreatedAt: base.Add(-1 * time.Hour),
|
||||
})
|
||||
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{
|
||||
Name: "p2",
|
||||
Status: model.StatusActive,
|
||||
CreatedAt: base,
|
||||
})
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{
|
||||
Name: "p3-inactive",
|
||||
Status: model.StatusDisabled,
|
||||
})
|
||||
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
|
||||
|
||||
withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx)
|
||||
s.Require().NoError(err, "ListActiveWithAccountCount")
|
||||
s.Require().Len(withCounts, 2, "expected 2 active proxies")
|
||||
|
||||
// Sorted by created_at DESC, so p2 first
|
||||
s.Require().Equal(p2.ID, withCounts[0].ID)
|
||||
s.Require().Equal(int64(1), withCounts[0].AccountCount)
|
||||
s.Require().Equal(p1.ID, withCounts[1].ID)
|
||||
s.Require().Equal(int64(2), withCounts[1].AccountCount)
|
||||
}
|
||||
|
||||
// --- Combined original test ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
|
||||
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{
|
||||
Name: "p1",
|
||||
Protocol: "http",
|
||||
Host: "1.2.3.4",
|
||||
Port: 8080,
|
||||
Username: "u",
|
||||
Password: "p",
|
||||
CreatedAt: time.Now().Add(-1 * time.Hour),
|
||||
UpdatedAt: time.Now().Add(-1 * time.Hour),
|
||||
})
|
||||
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{
|
||||
Name: "p2",
|
||||
Protocol: "http",
|
||||
Host: "5.6.7.8",
|
||||
Port: 8081,
|
||||
Username: "",
|
||||
Password: "",
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
})
|
||||
|
||||
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "u", "p")
|
||||
s.Require().NoError(err, "ExistsByHostPortAuth")
|
||||
s.Require().True(exists, "expected proxy to exist")
|
||||
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
|
||||
|
||||
count1, err := s.repo.CountAccountsByProxyID(s.ctx, p1.ID)
|
||||
s.Require().NoError(err, "CountAccountsByProxyID")
|
||||
s.Require().Equal(int64(2), count1, "expected 2 accounts for p1")
|
||||
|
||||
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
|
||||
s.Require().NoError(err, "GetAccountCountsForProxies")
|
||||
s.Require().Equal(int64(2), counts[p1.ID])
|
||||
s.Require().Equal(int64(1), counts[p2.ID])
|
||||
|
||||
withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx)
|
||||
s.Require().NoError(err, "ListActiveWithAccountCount")
|
||||
s.Require().Len(withCounts, 2, "expected 2 proxies")
|
||||
for _, pc := range withCounts {
|
||||
switch pc.ID {
|
||||
case p1.ID:
|
||||
s.Require().Equal(int64(2), pc.AccountCount, "p1 count mismatch")
|
||||
case p2.ID:
|
||||
s.Require().Equal(int64(1), pc.AccountCount, "p2 count mismatch")
|
||||
default:
|
||||
s.Require().Fail("unexpected proxy id", pc.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -5,8 +5,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
@@ -20,7 +19,7 @@ type redeemCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewRedeemCache(rdb *redis.Client) ports.RedeemCache {
|
||||
func NewRedeemCache(rdb *redis.Client) service.RedeemCache {
|
||||
return &redeemCache{rdb: rdb}
|
||||
}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user