mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-07 00:40:22 +08:00
Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f57f12c6cc | ||
|
|
5fca2d10b9 | ||
|
|
8fbe1ad70d | ||
|
|
25a304c231 | ||
|
|
9d30ceae8d | ||
|
|
60f6ed6bf6 | ||
|
|
4a2f7d4a99 | ||
|
|
c19a393be9 | ||
|
|
938ffb002e | ||
|
|
372a01290b | ||
|
|
8b163ca49b | ||
|
|
d23810dc53 | ||
|
|
62ed5422dd | ||
|
|
2e76302af7 | ||
|
|
6553828008 | ||
|
|
adcb7bf00e |
7
.github/workflows/backend-ci.yml
vendored
7
.github/workflows/backend-ci.yml
vendored
@@ -17,9 +17,12 @@ jobs:
|
||||
go-version-file: backend/go.mod
|
||||
check-latest: true
|
||||
cache: true
|
||||
- name: Run tests
|
||||
- name: Unit tests
|
||||
working-directory: backend
|
||||
run: go test ./...
|
||||
run: make test-unit
|
||||
- name: Integration tests
|
||||
working-directory: backend
|
||||
run: make test-integration
|
||||
|
||||
golangci-lint:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
94
.github/workflows/release.yml
vendored
94
.github/workflows/release.yml
vendored
@@ -85,6 +85,19 @@ jobs:
|
||||
go-version: '1.24'
|
||||
cache-dependency-path: backend/go.sum
|
||||
|
||||
# Docker setup for GoReleaser
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Fetch tags with annotations
|
||||
run: |
|
||||
# 确保获取完整的 annotated tag 信息
|
||||
@@ -117,87 +130,16 @@ jobs:
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
TAG_MESSAGE: ${{ steps.tag_message.outputs.message }}
|
||||
GITHUB_REPO_OWNER: ${{ github.repository_owner }}
|
||||
GITHUB_REPO_NAME: ${{ github.event.repository.name }}
|
||||
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
|
||||
# ===========================================================================
|
||||
# Docker Build and Push
|
||||
# ===========================================================================
|
||||
docker:
|
||||
needs: [update-version, build-frontend]
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Download VERSION artifact
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: version-file
|
||||
path: backend/cmd/server/
|
||||
|
||||
- name: Download frontend artifact
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: frontend-dist
|
||||
path: backend/internal/web/dist/
|
||||
|
||||
# Extract version from tag
|
||||
- name: Extract version
|
||||
id: version
|
||||
run: |
|
||||
VERSION=${GITHUB_REF#refs/tags/v}
|
||||
echo "version=$VERSION" >> $GITHUB_OUTPUT
|
||||
echo "Version: $VERSION"
|
||||
|
||||
# Set up Docker Buildx for multi-platform builds
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
# Login to DockerHub
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
# Extract metadata for Docker
|
||||
- name: Extract Docker metadata
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
weishaw/sub2api
|
||||
tags: |
|
||||
type=semver,pattern={{version}}
|
||||
type=semver,pattern={{major}}.{{minor}}
|
||||
type=semver,pattern={{major}}
|
||||
type=raw,value=latest,enable={{is_default_branch}}
|
||||
|
||||
# Build and push Docker image
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
file: ./Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
build-args: |
|
||||
VERSION=${{ steps.version.outputs.version }}
|
||||
COMMIT=${{ github.sha }}
|
||||
DATE=${{ github.event.head_commit.timestamp }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
# Update DockerHub description (optional)
|
||||
# Update DockerHub description
|
||||
- name: Update DockerHub description
|
||||
uses: peter-evans/dockerhub-description@v4
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
repository: weishaw/sub2api
|
||||
repository: ${{ secrets.DOCKERHUB_USERNAME }}/sub2api
|
||||
short-description: "Sub2API - AI API Gateway Platform"
|
||||
readme-filepath: ./deploy/DOCKER.md
|
||||
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -92,6 +92,13 @@ backend/internal/web/dist/*
|
||||
# 后端运行时缓存数据
|
||||
backend/data/
|
||||
|
||||
# ===================
|
||||
# 本地配置文件(包含敏感信息)
|
||||
# ===================
|
||||
backend/config.yaml
|
||||
deploy/config.yaml
|
||||
backend/.installed
|
||||
|
||||
# ===================
|
||||
# 其他
|
||||
# ===================
|
||||
|
||||
@@ -52,10 +52,58 @@ changelog:
|
||||
# 禁用自动 changelog,完全使用 tag 消息
|
||||
disable: true
|
||||
|
||||
# Docker images
|
||||
dockers:
|
||||
- id: amd64
|
||||
goos: linux
|
||||
goarch: amd64
|
||||
image_templates:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||
dockerfile: Dockerfile.goreleaser
|
||||
use: buildx
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||
- "--label=org.opencontainers.image.revision={{ .Commit }}"
|
||||
|
||||
- id: arm64
|
||||
goos: linux
|
||||
goarch: arm64
|
||||
image_templates:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||
dockerfile: Dockerfile.goreleaser
|
||||
use: buildx
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||
- "--label=org.opencontainers.image.revision={{ .Commit }}"
|
||||
|
||||
# Docker manifests for multi-arch support
|
||||
docker_manifests:
|
||||
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}"
|
||||
image_templates:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||
|
||||
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:latest"
|
||||
image_templates:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||
|
||||
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Major }}.{{ .Minor }}"
|
||||
image_templates:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||
|
||||
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Major }}"
|
||||
image_templates:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||
|
||||
release:
|
||||
github:
|
||||
owner: Wei-Shaw
|
||||
name: sub2api
|
||||
owner: "{{ .Env.GITHUB_REPO_OWNER }}"
|
||||
name: "{{ .Env.GITHUB_REPO_NAME }}"
|
||||
draft: false
|
||||
prerelease: auto
|
||||
name_template: "Sub2API {{.Version}}"
|
||||
@@ -73,7 +121,7 @@ release:
|
||||
|
||||
**One-line install (Linux):**
|
||||
```bash
|
||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash
|
||||
curl -sSL https://raw.githubusercontent.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }}/main/deploy/install.sh | sudo bash
|
||||
```
|
||||
|
||||
**Manual download:**
|
||||
@@ -81,5 +129,5 @@ release:
|
||||
|
||||
## 📚 Documentation
|
||||
|
||||
- [GitHub Repository](https://github.com/Wei-Shaw/sub2api)
|
||||
- [Installation Guide](https://github.com/Wei-Shaw/sub2api/blob/main/deploy/README.md)
|
||||
- [GitHub Repository](https://github.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }})
|
||||
- [Installation Guide](https://github.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }}/blob/main/deploy/README.md)
|
||||
|
||||
40
Dockerfile.goreleaser
Normal file
40
Dockerfile.goreleaser
Normal file
@@ -0,0 +1,40 @@
|
||||
# =============================================================================
|
||||
# Sub2API Dockerfile for GoReleaser
|
||||
# =============================================================================
|
||||
# This Dockerfile is used by GoReleaser to build Docker images.
|
||||
# It only packages the pre-built binary, no compilation needed.
|
||||
# =============================================================================
|
||||
|
||||
FROM alpine:3.19
|
||||
|
||||
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
|
||||
LABEL description="Sub2API - AI API Gateway Platform"
|
||||
LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
|
||||
|
||||
# Install runtime dependencies
|
||||
RUN apk add --no-cache \
|
||||
ca-certificates \
|
||||
tzdata \
|
||||
curl \
|
||||
&& rm -rf /var/cache/apk/*
|
||||
|
||||
# Create non-root user
|
||||
RUN addgroup -g 1000 sub2api && \
|
||||
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy pre-built binary from GoReleaser
|
||||
COPY sub2api /app/sub2api
|
||||
|
||||
# Create data directory
|
||||
RUN mkdir -p /app/data && chown -R sub2api:sub2api /app
|
||||
|
||||
USER sub2api
|
||||
|
||||
EXPOSE 8080
|
||||
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
||||
CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
||||
|
||||
ENTRYPOINT ["/app/sub2api"]
|
||||
@@ -1,4 +1,4 @@
|
||||
.PHONY: wire build build-embed
|
||||
.PHONY: wire build build-embed test-unit test-integration test-cover-integration clean-coverage
|
||||
|
||||
wire:
|
||||
@echo "生成 Wire 代码..."
|
||||
@@ -13,4 +13,21 @@ build:
|
||||
build-embed:
|
||||
@echo "构建后端(嵌入前端)..."
|
||||
@go build -tags embed -o bin/server ./cmd/server
|
||||
@echo "构建完成: bin/server (with embedded frontend)"
|
||||
@echo "构建完成: bin/server (with embedded frontend)"
|
||||
|
||||
test-unit:
|
||||
@go test ./... $(TEST_ARGS)
|
||||
|
||||
test-integration:
|
||||
@go test -tags integration ./internal/repository -count=1 -race -parallel=8
|
||||
|
||||
test-cover-integration:
|
||||
@echo "运行集成测试并生成覆盖率报告..."
|
||||
@go test -tags=integration -cover -coverprofile=coverage.out -count=1 -race -parallel=8 ./internal/repository/...
|
||||
@go tool cover -func=coverage.out | tail -1
|
||||
@go tool cover -html=coverage.out -o coverage.html
|
||||
@echo "覆盖率报告已生成: coverage.html"
|
||||
|
||||
clean-coverage:
|
||||
@rm -f coverage.out coverage.html
|
||||
@echo "覆盖率文件已清理"
|
||||
@@ -86,7 +86,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, httpUpstream)
|
||||
concurrencyCache := repository.NewConcurrencyCache(client)
|
||||
concurrencyService := service.NewConcurrencyService(concurrencyCache)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService)
|
||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
|
||||
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
||||
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
||||
proxyHandler := admin.NewProxyHandler(adminService)
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
server:
|
||||
host: "0.0.0.0"
|
||||
port: 8080
|
||||
mode: "debug" # debug/release
|
||||
|
||||
database:
|
||||
host: "127.0.0.1"
|
||||
port: 5432
|
||||
user: "postgres"
|
||||
password: "XZeRr7nkjHWhm8fw"
|
||||
dbname: "sub2api"
|
||||
sslmode: "disable"
|
||||
|
||||
redis:
|
||||
host: "127.0.0.1"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
|
||||
jwt:
|
||||
secret: "your-secret-key-change-in-production"
|
||||
expire_hour: 24
|
||||
|
||||
default:
|
||||
admin_email: "admin@sub2api.com"
|
||||
admin_password: "admin123"
|
||||
user_concurrency: 5
|
||||
user_balance: 0
|
||||
api_key_prefix: "sk-"
|
||||
rate_multiplier: 1.0
|
||||
|
||||
# Timezone configuration (similar to PHP's date_default_timezone_set)
|
||||
# This affects ALL time operations:
|
||||
# - Database timestamps
|
||||
# - Usage statistics "today" boundary
|
||||
# - Subscription expiry times
|
||||
# Common values: Asia/Shanghai, America/New_York, Europe/London, UTC
|
||||
timezone: "Asia/Shanghai"
|
||||
@@ -8,10 +8,16 @@ require (
|
||||
github.com/gin-gonic/gin v1.9.1
|
||||
github.com/golang-jwt/jwt/v5 v5.2.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/google/wire v0.7.0
|
||||
github.com/imroc/req/v3 v3.56.0
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/redis/go-redis/v9 v9.3.0
|
||||
github.com/redis/go-redis/v9 v9.7.3
|
||||
github.com/spf13/viper v1.18.2
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0
|
||||
github.com/testcontainers/testcontainers-go/modules/redis v0.40.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
golang.org/x/crypto v0.44.0
|
||||
golang.org/x/net v0.47.0
|
||||
golang.org/x/term v0.37.0
|
||||
@@ -21,51 +27,99 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
dario.cat/mergo v1.0.2 // indirect
|
||||
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/andybalholm/brotli v1.2.0 // indirect
|
||||
github.com/bytedance/sonic v1.9.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||
github.com/containerd/errdefs v1.0.0 // indirect
|
||||
github.com/containerd/errdefs/pkg v0.3.0 // indirect
|
||||
github.com/containerd/log v0.1.0 // indirect
|
||||
github.com/containerd/platforms v0.2.1 // indirect
|
||||
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/distribution/reference v0.6.0 // indirect
|
||||
github.com/docker/docker v28.5.1+incompatible // indirect
|
||||
github.com/docker/go-connections v0.6.0 // indirect
|
||||
github.com/docker/go-units v0.5.0 // indirect
|
||||
github.com/ebitengine/purego v0.8.4 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/fsnotify/fsnotify v1.7.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/go-querystring v1.1.0 // indirect
|
||||
github.com/google/subcommands v1.2.0 // indirect
|
||||
github.com/google/wire v0.7.0 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/icholy/digest v1.1.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||
github.com/jackc/pgx/v5 v5.4.3 // indirect
|
||||
github.com/jackc/pgx/v5 v5.5.4 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.1 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/compress v1.18.1 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||
github.com/leodido/go-urn v1.2.4 // indirect
|
||||
github.com/magiconair/properties v1.8.7 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||
github.com/magiconair/properties v1.8.10 // indirect
|
||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||
github.com/mdelapenya/tlscert v0.2.0 // indirect
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
github.com/moby/docker-image-spec v1.3.1 // indirect
|
||||
github.com/moby/go-archive v0.1.0 // indirect
|
||||
github.com/moby/patternmatcher v0.6.0 // indirect
|
||||
github.com/moby/sys/sequential v0.6.0 // indirect
|
||||
github.com/moby/sys/user v0.4.0 // indirect
|
||||
github.com/moby/sys/userns v0.1.0 // indirect
|
||||
github.com/moby/term v0.5.0 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/morikuni/aec v1.0.0 // indirect
|
||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.1.0 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||
github.com/quic-go/qpack v0.5.1 // indirect
|
||||
github.com/quic-go/quic-go v0.56.0 // indirect
|
||||
github.com/refraction-networking/utls v1.8.1 // indirect
|
||||
github.com/sagikazarmark/locafero v0.4.0 // indirect
|
||||
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
|
||||
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
|
||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||
github.com/sourcegraph/conc v0.3.0 // indirect
|
||||
github.com/spf13/afero v1.11.0 // indirect
|
||||
github.com/spf13/cast v1.6.0 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/testcontainers/testcontainers-go v0.40.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
|
||||
go.opentelemetry.io/otel v1.37.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.37.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk v1.37.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.37.0 // indirect
|
||||
go.uber.org/atomic v1.9.0 // indirect
|
||||
go.uber.org/multierr v1.9.0 // indirect
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
@@ -75,6 +129,7 @@ require (
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
golang.org/x/tools v0.38.0 // indirect
|
||||
google.golang.org/protobuf v1.31.0 // indirect
|
||||
google.golang.org/grpc v1.75.1 // indirect
|
||||
google.golang.org/protobuf v1.36.10 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
)
|
||||
|
||||
169
backend/go.sum
169
backend/go.sum
@@ -1,3 +1,11 @@
|
||||
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
|
||||
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
|
||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk=
|
||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
|
||||
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8=
|
||||
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
|
||||
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
@@ -7,17 +15,43 @@ github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0
|
||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
||||
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
|
||||
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
||||
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
|
||||
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
|
||||
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
|
||||
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
|
||||
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
|
||||
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
|
||||
github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
|
||||
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
|
||||
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
|
||||
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
|
||||
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
|
||||
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
|
||||
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
|
||||
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
||||
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||
github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw=
|
||||
github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
||||
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
|
||||
@@ -28,6 +62,13 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE
|
||||
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
|
||||
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
@@ -40,9 +81,8 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
|
||||
@@ -54,6 +94,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
|
||||
github.com/google/wire v0.7.0/go.mod h1:n6YbUQD9cPKTnHXEBN2DXlOp/mVADhVErcMFb0v3J18=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4=
|
||||
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
|
||||
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
|
||||
github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
|
||||
@@ -64,8 +106,10 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY=
|
||||
github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
|
||||
github.com/jackc/pgx/v5 v5.5.4 h1:Xp2aQS8uXButQdnCMWNmvx6UysWQQC+u1EoizjguY+8=
|
||||
github.com/jackc/pgx/v5 v5.5.4/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
|
||||
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
|
||||
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
@@ -85,36 +129,70 @@ github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
|
||||
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
|
||||
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
|
||||
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
||||
github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o=
|
||||
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
||||
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
|
||||
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
|
||||
github.com/moby/go-archive v0.1.0 h1:Kk/5rdW/g+H8NHdJW2gsXyZ7UnzvJNOy6VKJqueWdcQ=
|
||||
github.com/moby/go-archive v0.1.0/go.mod h1:G9B+YoujNohJmrIYFBpSd54GTUB4lt9S+xVQvsJyFuo=
|
||||
github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk=
|
||||
github.com/moby/patternmatcher v0.6.0/go.mod h1:hDPoyOpDY7OrrMDLaYoY3hf52gNCR/YOUYxkhApJIxc=
|
||||
github.com/moby/sys/atomicwriter v0.1.0 h1:kw5D/EqkBwsBFi0ss9v1VG3wIkVhzGvLklJ+w3A14Sw=
|
||||
github.com/moby/sys/atomicwriter v0.1.0/go.mod h1:Ul8oqv2ZMNHOceF643P6FKPXeCmYtlQMvpizfsSoaWs=
|
||||
github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU=
|
||||
github.com/moby/sys/sequential v0.6.0/go.mod h1:uyv8EUTrca5PnDsdMGXhZe6CCe8U/UiTWd+lL+7b/Ko=
|
||||
github.com/moby/sys/user v0.4.0 h1:jhcMKit7SA80hivmFJcbB1vqmw//wU61Zdui2eQXuMs=
|
||||
github.com/moby/sys/user v0.4.0/go.mod h1:bG+tYYYJgaMtRKgEmuueC0hJEAZWwtIbZTB+85uoHjs=
|
||||
github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g=
|
||||
github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28=
|
||||
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
|
||||
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
|
||||
github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4=
|
||||
github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
|
||||
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
|
||||
github.com/quic-go/quic-go v0.56.0 h1:q/TW+OLismmXAehgFLczhCDTYB3bFmua4D9lsNBWxvY=
|
||||
github.com/quic-go/quic-go v0.56.0/go.mod h1:9gx5KsFQtw2oZ6GZTyh+7YEvOxWCL9WZAepnHxgAo6c=
|
||||
github.com/redis/go-redis/v9 v9.3.0 h1:RiVDjmig62jIWp7Kk4XVLs0hzV6pI3PyTnnL0cnn0u0=
|
||||
github.com/redis/go-redis/v9 v9.3.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M=
|
||||
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
|
||||
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
|
||||
github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo=
|
||||
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||
github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ=
|
||||
github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4=
|
||||
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
|
||||
github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
|
||||
github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs=
|
||||
github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
|
||||
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
|
||||
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
||||
@@ -128,6 +206,8 @@ github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMV
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
@@ -135,16 +215,55 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||
github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+aMnb6JfVz1mxk7OeMU=
|
||||
github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3APNyLkkap35zNLxukw9oBi/MY=
|
||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0 h1:s2bIayFXlbDFexo96y+htn7FzuhpXLYJNnIuglNKqOk=
|
||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0/go.mod h1:h+u/2KoREGTnTl9UwrQ/g+XhasAT8E6dClclAADeXoQ=
|
||||
github.com/testcontainers/testcontainers-go/modules/redis v0.40.0 h1:OG4qwcxp2O0re7V7M9lY9w0v6wWgWf7j7rtkpAnGMd0=
|
||||
github.com/testcontainers/testcontainers-go/modules/redis v0.40.0/go.mod h1:Bc+EDhKMo5zI5V5zdBkHiMVzeAXbtI4n5isS/nzf6zw=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
|
||||
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
|
||||
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
|
||||
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw=
|
||||
go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ=
|
||||
go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU=
|
||||
go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE=
|
||||
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
|
||||
go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI=
|
||||
go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg=
|
||||
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
|
||||
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
|
||||
go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I=
|
||||
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
|
||||
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
|
||||
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
||||
@@ -164,8 +283,14 @@ golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
|
||||
@@ -177,9 +302,15 @@ golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
|
||||
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4/go.mod h1:NnuHhy+bxcg30o7FnVAZbXsPHUDQ9qKWAQKCD7VxFtk=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 h1:i8QOKZfYg6AbGVZzUAY3LrNWCKF8O6zFisU9Wl9RER4=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4/go.mod h1:HSkG/KdJWusxU1F6CNrwNDjBMgisKxGnc5dAZfT0mjQ=
|
||||
google.golang.org/grpc v1.75.1 h1:/ODCNEuf9VghjgO3rqLcfg8fiOP0nSluljWFlDxELLI=
|
||||
google.golang.org/grpc v1.75.1/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ=
|
||||
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
|
||||
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
@@ -192,6 +323,6 @@ gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo=
|
||||
gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0=
|
||||
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
|
||||
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||
gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU=
|
||||
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
||||
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
|
||||
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
|
||||
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
||||
|
||||
@@ -34,10 +34,20 @@ type AccountHandler struct {
|
||||
accountUsageService *service.AccountUsageService
|
||||
accountTestService *service.AccountTestService
|
||||
concurrencyService *service.ConcurrencyService
|
||||
crsSyncService *service.CRSSyncService
|
||||
}
|
||||
|
||||
// NewAccountHandler creates a new admin account handler
|
||||
func NewAccountHandler(adminService service.AdminService, oauthService *service.OAuthService, openaiOAuthService *service.OpenAIOAuthService, rateLimitService *service.RateLimitService, accountUsageService *service.AccountUsageService, accountTestService *service.AccountTestService, concurrencyService *service.ConcurrencyService) *AccountHandler {
|
||||
func NewAccountHandler(
|
||||
adminService service.AdminService,
|
||||
oauthService *service.OAuthService,
|
||||
openaiOAuthService *service.OpenAIOAuthService,
|
||||
rateLimitService *service.RateLimitService,
|
||||
accountUsageService *service.AccountUsageService,
|
||||
accountTestService *service.AccountTestService,
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
crsSyncService *service.CRSSyncService,
|
||||
) *AccountHandler {
|
||||
return &AccountHandler{
|
||||
adminService: adminService,
|
||||
oauthService: oauthService,
|
||||
@@ -46,6 +56,7 @@ func NewAccountHandler(adminService service.AdminService, oauthService *service.
|
||||
accountUsageService: accountUsageService,
|
||||
accountTestService: accountTestService,
|
||||
concurrencyService: concurrencyService,
|
||||
crsSyncService: crsSyncService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,6 +87,19 @@ type UpdateAccountRequest struct {
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
}
|
||||
|
||||
// BulkUpdateAccountsRequest represents the payload for bulk editing accounts
|
||||
type BulkUpdateAccountsRequest struct {
|
||||
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
|
||||
Name string `json:"name"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
}
|
||||
|
||||
// AccountWithConcurrency extends Account with real-time concurrency info
|
||||
type AccountWithConcurrency struct {
|
||||
*model.Account
|
||||
@@ -224,6 +248,13 @@ type TestAccountRequest struct {
|
||||
ModelID string `json:"model_id"`
|
||||
}
|
||||
|
||||
type SyncFromCRSRequest struct {
|
||||
BaseURL string `json:"base_url" binding:"required"`
|
||||
Username string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
SyncProxies *bool `json:"sync_proxies"`
|
||||
}
|
||||
|
||||
// Test handles testing account connectivity with SSE streaming
|
||||
// POST /api/v1/admin/accounts/:id/test
|
||||
func (h *AccountHandler) Test(c *gin.Context) {
|
||||
@@ -244,6 +275,35 @@ func (h *AccountHandler) Test(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// SyncFromCRS handles syncing accounts from claude-relay-service (CRS)
|
||||
// POST /api/v1/admin/accounts/sync/crs
|
||||
func (h *AccountHandler) SyncFromCRS(c *gin.Context) {
|
||||
var req SyncFromCRSRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Default to syncing proxies (can be disabled by explicitly setting false)
|
||||
syncProxies := true
|
||||
if req.SyncProxies != nil {
|
||||
syncProxies = *req.SyncProxies
|
||||
}
|
||||
|
||||
result, err := h.crsSyncService.SyncFromCRS(c.Request.Context(), service.SyncFromCRSInput{
|
||||
BaseURL: req.BaseURL,
|
||||
Username: req.Username,
|
||||
Password: req.Password,
|
||||
SyncProxies: syncProxies,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Sync failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// Refresh handles refreshing account credentials
|
||||
// POST /api/v1/admin/accounts/:id/refresh
|
||||
func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
@@ -387,6 +447,136 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// BatchUpdateCredentialsRequest represents batch credentials update request
|
||||
type BatchUpdateCredentialsRequest struct {
|
||||
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
|
||||
Field string `json:"field" binding:"required,oneof=account_uuid org_uuid intercept_warmup_requests"`
|
||||
Value any `json:"value"`
|
||||
}
|
||||
|
||||
// BatchUpdateCredentials handles batch updating credentials fields
|
||||
// POST /api/v1/admin/accounts/batch-update-credentials
|
||||
func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) {
|
||||
var req BatchUpdateCredentialsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Validate value type based on field
|
||||
if req.Field == "intercept_warmup_requests" {
|
||||
// Must be boolean
|
||||
if _, ok := req.Value.(bool); !ok {
|
||||
response.BadRequest(c, "intercept_warmup_requests must be boolean")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// account_uuid and org_uuid can be string or null
|
||||
if req.Value != nil {
|
||||
if _, ok := req.Value.(string); !ok {
|
||||
response.BadRequest(c, req.Field+" must be string or null")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
success := 0
|
||||
failed := 0
|
||||
results := []gin.H{}
|
||||
|
||||
for _, accountID := range req.AccountIDs {
|
||||
// Get account
|
||||
account, err := h.adminService.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
failed++
|
||||
results = append(results, gin.H{
|
||||
"account_id": accountID,
|
||||
"success": false,
|
||||
"error": "Account not found",
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Update credentials field
|
||||
if account.Credentials == nil {
|
||||
account.Credentials = make(map[string]any)
|
||||
}
|
||||
|
||||
account.Credentials[req.Field] = req.Value
|
||||
|
||||
// Update account
|
||||
updateInput := &service.UpdateAccountInput{
|
||||
Credentials: account.Credentials,
|
||||
}
|
||||
|
||||
_, err = h.adminService.UpdateAccount(ctx, accountID, updateInput)
|
||||
if err != nil {
|
||||
failed++
|
||||
results = append(results, gin.H{
|
||||
"account_id": accountID,
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
success++
|
||||
results = append(results, gin.H{
|
||||
"account_id": accountID,
|
||||
"success": true,
|
||||
})
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"results": results,
|
||||
})
|
||||
}
|
||||
|
||||
// BulkUpdate handles bulk updating accounts with selected fields/credentials.
|
||||
// POST /api/v1/admin/accounts/bulk-update
|
||||
func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
var req BulkUpdateAccountsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
hasUpdates := req.Name != "" ||
|
||||
req.ProxyID != nil ||
|
||||
req.Concurrency != nil ||
|
||||
req.Priority != nil ||
|
||||
req.Status != "" ||
|
||||
req.GroupIDs != nil ||
|
||||
len(req.Credentials) > 0 ||
|
||||
len(req.Extra) > 0
|
||||
|
||||
if !hasUpdates {
|
||||
response.BadRequest(c, "No updates provided")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.adminService.BulkUpdateAccounts(c.Request.Context(), &service.BulkUpdateAccountsInput{
|
||||
AccountIDs: req.AccountIDs,
|
||||
Name: req.Name,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
Status: req.Status,
|
||||
GroupIDs: req.GroupIDs,
|
||||
Credentials: req.Credentials,
|
||||
Extra: req.Extra,
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to bulk update accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// ========== OAuth Handlers ==========
|
||||
|
||||
// GenerateAuthURLRequest represents the request for generating auth URL
|
||||
|
||||
@@ -2,11 +2,14 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type AccountRepository struct {
|
||||
@@ -39,6 +42,22 @@ func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Accou
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (r *AccountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) {
|
||||
if crsAccountID == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var account model.Account
|
||||
err := r.db.WithContext(ctx).Where("extra->>'crs_account_id' = ?", crsAccountID).First(&account).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (r *AccountRepository) Update(ctx context.Context, account *model.Account) error {
|
||||
return r.db.WithContext(ctx).Save(account).Error
|
||||
}
|
||||
@@ -335,3 +354,47 @@ func (r *AccountRepository) UpdateExtra(ctx context.Context, id int64, updates m
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
Update("extra", account.Extra).Error
|
||||
}
|
||||
|
||||
// BulkUpdate updates multiple accounts with the provided fields.
|
||||
// It merges credentials/extra JSONB fields instead of overwriting them.
|
||||
func (r *AccountRepository) BulkUpdate(ctx context.Context, ids []int64, updates ports.AccountBulkUpdate) (int64, error) {
|
||||
if len(ids) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
updateMap := map[string]any{}
|
||||
|
||||
if updates.Name != nil {
|
||||
updateMap["name"] = *updates.Name
|
||||
}
|
||||
if updates.ProxyID != nil {
|
||||
updateMap["proxy_id"] = updates.ProxyID
|
||||
}
|
||||
if updates.Concurrency != nil {
|
||||
updateMap["concurrency"] = *updates.Concurrency
|
||||
}
|
||||
if updates.Priority != nil {
|
||||
updateMap["priority"] = *updates.Priority
|
||||
}
|
||||
if updates.Status != nil {
|
||||
updateMap["status"] = *updates.Status
|
||||
}
|
||||
if len(updates.Credentials) > 0 {
|
||||
updateMap["credentials"] = gorm.Expr("COALESCE(credentials,'{}') || ?", updates.Credentials)
|
||||
}
|
||||
if len(updates.Extra) > 0 {
|
||||
updateMap["extra"] = gorm.Expr("COALESCE(extra,'{}') || ?", updates.Extra)
|
||||
}
|
||||
|
||||
if len(updateMap) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).
|
||||
Model(&model.Account{}).
|
||||
Where("id IN ?", ids).
|
||||
Clauses(clause.Returning{}).
|
||||
Updates(updateMap)
|
||||
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
580
backend/internal/repository/account_repo_integration_test.go
Normal file
580
backend/internal/repository/account_repo_integration_test.go
Normal file
@@ -0,0 +1,580 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type AccountRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
repo *AccountRepository
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewAccountRepository(s.db)
|
||||
}
|
||||
|
||||
func TestAccountRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(AccountRepoSuite))
|
||||
}
|
||||
|
||||
// --- Create / GetByID / Update / Delete ---
|
||||
|
||||
func (s *AccountRepoSuite) TestCreate() {
|
||||
account := &model.Account{
|
||||
Name: "test-create",
|
||||
Platform: model.PlatformAnthropic,
|
||||
Type: model.AccountTypeOAuth,
|
||||
Status: model.StatusActive,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, account)
|
||||
s.Require().NoError(err, "Create")
|
||||
s.Require().NotZero(account.ID, "expected ID to be set")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("test-create", got.Name)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdate() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "original"})
|
||||
|
||||
account.Name = "updated"
|
||||
err := s.repo.Update(s.ctx, account)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetByID after update")
|
||||
s.Require().Equal("updated", got.Name)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestDelete() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "to-delete"})
|
||||
|
||||
err := s.repo.Delete(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
|
||||
_, err = s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestDelete_WithGroupBindings() {
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-del"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-del"})
|
||||
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
|
||||
|
||||
err := s.repo.Delete(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "Delete should cascade remove bindings")
|
||||
|
||||
var count int64
|
||||
s.db.Model(&model.AccountGroup{}).Where("account_id = ?", account.ID).Count(&count)
|
||||
s.Require().Zero(count, "expected bindings to be removed")
|
||||
}
|
||||
|
||||
// --- List / ListWithFilters ---
|
||||
|
||||
func (s *AccountRepoSuite) TestList() {
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc1"})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc2"})
|
||||
|
||||
accounts, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List")
|
||||
s.Require().Len(accounts, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(db *gorm.DB)
|
||||
platform string
|
||||
accType string
|
||||
status string
|
||||
search string
|
||||
wantCount int
|
||||
validate func(accounts []model.Account)
|
||||
}{
|
||||
{
|
||||
name: "filter_by_platform",
|
||||
setup: func(db *gorm.DB) {
|
||||
mustCreateAccount(s.T(), db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic})
|
||||
mustCreateAccount(s.T(), db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI})
|
||||
},
|
||||
platform: model.PlatformOpenAI,
|
||||
wantCount: 1,
|
||||
validate: func(accounts []model.Account) {
|
||||
s.Require().Equal(model.PlatformOpenAI, accounts[0].Platform)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter_by_type",
|
||||
setup: func(db *gorm.DB) {
|
||||
mustCreateAccount(s.T(), db, &model.Account{Name: "t1", Type: model.AccountTypeOAuth})
|
||||
mustCreateAccount(s.T(), db, &model.Account{Name: "t2", Type: model.AccountTypeApiKey})
|
||||
},
|
||||
accType: model.AccountTypeApiKey,
|
||||
wantCount: 1,
|
||||
validate: func(accounts []model.Account) {
|
||||
s.Require().Equal(model.AccountTypeApiKey, accounts[0].Type)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter_by_status",
|
||||
setup: func(db *gorm.DB) {
|
||||
mustCreateAccount(s.T(), db, &model.Account{Name: "s1", Status: model.StatusActive})
|
||||
mustCreateAccount(s.T(), db, &model.Account{Name: "s2", Status: model.StatusDisabled})
|
||||
},
|
||||
status: model.StatusDisabled,
|
||||
wantCount: 1,
|
||||
validate: func(accounts []model.Account) {
|
||||
s.Require().Equal(model.StatusDisabled, accounts[0].Status)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter_by_search",
|
||||
setup: func(db *gorm.DB) {
|
||||
mustCreateAccount(s.T(), db, &model.Account{Name: "alpha-account"})
|
||||
mustCreateAccount(s.T(), db, &model.Account{Name: "beta-account"})
|
||||
},
|
||||
search: "alpha",
|
||||
wantCount: 1,
|
||||
validate: func(accounts []model.Account) {
|
||||
s.Require().Contains(accounts[0].Name, "alpha")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
// 每个 case 重新获取隔离资源
|
||||
db := testTx(s.T())
|
||||
repo := NewAccountRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
tt.setup(db)
|
||||
|
||||
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, tt.wantCount)
|
||||
if tt.validate != nil {
|
||||
tt.validate(accounts)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- ListByGroup / ListActive / ListByPlatform ---
|
||||
|
||||
func (s *AccountRepoSuite) TestListByGroup() {
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"})
|
||||
acc1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Status: model.StatusActive})
|
||||
acc2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Status: model.StatusActive})
|
||||
mustBindAccountToGroup(s.T(), s.db, acc1.ID, group.ID, 2)
|
||||
mustBindAccountToGroup(s.T(), s.db, acc2.ID, group.ID, 1)
|
||||
|
||||
accounts, err := s.repo.ListByGroup(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "ListByGroup")
|
||||
s.Require().Len(accounts, 2)
|
||||
// Should be ordered by priority
|
||||
s.Require().Equal(acc2.ID, accounts[0].ID, "expected acc2 first (priority=1)")
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListActive() {
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "active1", Status: model.StatusActive})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "inactive1", Status: model.StatusDisabled})
|
||||
|
||||
accounts, err := s.repo.ListActive(s.ctx)
|
||||
s.Require().NoError(err, "ListActive")
|
||||
s.Require().Len(accounts, 1)
|
||||
s.Require().Equal("active1", accounts[0].Name)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListByPlatform() {
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "p1", Platform: model.PlatformAnthropic, Status: model.StatusActive})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "p2", Platform: model.PlatformOpenAI, Status: model.StatusActive})
|
||||
|
||||
accounts, err := s.repo.ListByPlatform(s.ctx, model.PlatformAnthropic)
|
||||
s.Require().NoError(err, "ListByPlatform")
|
||||
s.Require().Len(accounts, 1)
|
||||
s.Require().Equal(model.PlatformAnthropic, accounts[0].Platform)
|
||||
}
|
||||
|
||||
// --- Preload and VirtualFields ---
|
||||
|
||||
func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
|
||||
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"})
|
||||
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{
|
||||
Name: "acc1",
|
||||
ProxyID: &proxy.ID,
|
||||
})
|
||||
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().NotNil(got.Proxy, "expected Proxy preload")
|
||||
s.Require().Equal(proxy.ID, got.Proxy.ID)
|
||||
s.Require().Len(got.GroupIDs, 1, "expected GroupIDs to be populated")
|
||||
s.Require().Equal(group.ID, got.GroupIDs[0])
|
||||
s.Require().Len(got.Groups, 1, "expected Groups to be populated")
|
||||
s.Require().Equal(group.ID, got.Groups[0].ID)
|
||||
|
||||
accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc")
|
||||
s.Require().NoError(err, "ListWithFilters")
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
s.Require().Len(accounts, 1)
|
||||
s.Require().NotNil(accounts[0].Proxy, "expected Proxy preload in list")
|
||||
s.Require().Equal(proxy.ID, accounts[0].Proxy.ID)
|
||||
s.Require().Len(accounts[0].GroupIDs, 1, "expected GroupIDs in list")
|
||||
s.Require().Equal(group.ID, accounts[0].GroupIDs[0])
|
||||
}
|
||||
|
||||
// --- GroupBinding / AddToGroup / RemoveFromGroup / BindGroups / GetGroups ---
|
||||
|
||||
func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
|
||||
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"})
|
||||
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc"})
|
||||
|
||||
s.Require().NoError(s.repo.AddToGroup(s.ctx, account.ID, g1.ID, 10), "AddToGroup")
|
||||
groups, err := s.repo.GetGroups(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetGroups")
|
||||
s.Require().Len(groups, 1, "expected 1 group")
|
||||
s.Require().Equal(g1.ID, groups[0].ID)
|
||||
|
||||
s.Require().NoError(s.repo.RemoveFromGroup(s.ctx, account.ID, g1.ID), "RemoveFromGroup")
|
||||
groups, err = s.repo.GetGroups(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetGroups after remove")
|
||||
s.Require().Empty(groups, "expected 0 groups after remove")
|
||||
|
||||
s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{g1.ID, g2.ID}), "BindGroups")
|
||||
groups, err = s.repo.GetGroups(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetGroups after bind")
|
||||
s.Require().Len(groups, 2, "expected 2 groups after bind")
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-empty"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-empty"})
|
||||
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
|
||||
|
||||
s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{}), "BindGroups empty")
|
||||
|
||||
groups, err := s.repo.GetGroups(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Empty(groups, "expected 0 groups after binding empty list")
|
||||
}
|
||||
|
||||
// --- Schedulable ---
|
||||
|
||||
func (s *AccountRepoSuite) TestListSchedulable() {
|
||||
now := time.Now()
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sched"})
|
||||
|
||||
okAcc := mustCreateAccount(s.T(), s.db, &model.Account{Name: "ok", Schedulable: true})
|
||||
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
|
||||
|
||||
future := now.Add(10 * time.Minute)
|
||||
overloaded := mustCreateAccount(s.T(), s.db, &model.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
|
||||
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
|
||||
|
||||
sched, err := s.repo.ListSchedulable(s.ctx)
|
||||
s.Require().NoError(err, "ListSchedulable")
|
||||
ids := idsOfAccounts(sched)
|
||||
s.Require().Contains(ids, okAcc.ID)
|
||||
s.Require().NotContains(ids, overloaded.ID)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_StatusUpdates() {
|
||||
now := time.Now()
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sched"})
|
||||
|
||||
okAcc := mustCreateAccount(s.T(), s.db, &model.Account{Name: "ok", Schedulable: true})
|
||||
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
|
||||
|
||||
future := now.Add(10 * time.Minute)
|
||||
overloaded := mustCreateAccount(s.T(), s.db, &model.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
|
||||
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
|
||||
|
||||
rateLimited := mustCreateAccount(s.T(), s.db, &model.Account{Name: "rl", Schedulable: true})
|
||||
mustBindAccountToGroup(s.T(), s.db, rateLimited.ID, group.ID, 1)
|
||||
s.Require().NoError(s.repo.SetRateLimited(s.ctx, rateLimited.ID, now.Add(10*time.Minute)), "SetRateLimited")
|
||||
|
||||
s.Require().NoError(s.repo.SetError(s.ctx, overloaded.ID, "boom"), "SetError")
|
||||
|
||||
sched, err := s.repo.ListSchedulableByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "ListSchedulableByGroupID")
|
||||
s.Require().Len(sched, 1, "expected only ok account schedulable")
|
||||
s.Require().Equal(okAcc.ID, sched[0].ID)
|
||||
|
||||
s.Require().NoError(s.repo.ClearRateLimit(s.ctx, rateLimited.ID), "ClearRateLimit")
|
||||
sched2, err := s.repo.ListSchedulableByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "ListSchedulableByGroupID after ClearRateLimit")
|
||||
s.Require().Len(sched2, 2, "expected 2 schedulable accounts after ClearRateLimit")
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListSchedulableByPlatform() {
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic, Schedulable: true})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI, Schedulable: true})
|
||||
|
||||
accounts, err := s.repo.ListSchedulableByPlatform(s.ctx, model.PlatformAnthropic)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, 1)
|
||||
s.Require().Equal(model.PlatformAnthropic, accounts[0].Platform)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sp"})
|
||||
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic, Schedulable: true})
|
||||
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI, Schedulable: true})
|
||||
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1)
|
||||
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2)
|
||||
|
||||
accounts, err := s.repo.ListSchedulableByGroupIDAndPlatform(s.ctx, group.ID, model.PlatformAnthropic)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, 1)
|
||||
s.Require().Equal(a1.ID, accounts[0].ID)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestSetSchedulable() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-sched", Schedulable: true})
|
||||
|
||||
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().False(got.Schedulable)
|
||||
}
|
||||
|
||||
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---
|
||||
|
||||
func (s *AccountRepoSuite) TestSetOverloaded() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-over"})
|
||||
until := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(got.OverloadUntil)
|
||||
s.Require().WithinDuration(until, *got.OverloadUntil, time.Second)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestSetRateLimited() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-rl"})
|
||||
resetAt := time.Date(2025, 6, 15, 14, 0, 0, 0, time.UTC)
|
||||
|
||||
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, resetAt))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(got.RateLimitedAt)
|
||||
s.Require().NotNil(got.RateLimitResetAt)
|
||||
s.Require().WithinDuration(resetAt, *got.RateLimitResetAt, time.Second)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestClearRateLimit() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-clear"})
|
||||
until := time.Now().Add(1 * time.Hour)
|
||||
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
|
||||
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, until))
|
||||
|
||||
s.Require().NoError(s.repo.ClearRateLimit(s.ctx, account.ID))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Nil(got.RateLimitedAt)
|
||||
s.Require().Nil(got.RateLimitResetAt)
|
||||
s.Require().Nil(got.OverloadUntil)
|
||||
}
|
||||
|
||||
// --- UpdateLastUsed ---
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateLastUsed() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-used"})
|
||||
s.Require().Nil(account.LastUsedAt)
|
||||
|
||||
s.Require().NoError(s.repo.UpdateLastUsed(s.ctx, account.ID))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(got.LastUsedAt)
|
||||
}
|
||||
|
||||
// --- SetError ---
|
||||
|
||||
func (s *AccountRepoSuite) TestSetError() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-err", Status: model.StatusActive})
|
||||
|
||||
s.Require().NoError(s.repo.SetError(s.ctx, account.ID, "something went wrong"))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(model.StatusError, got.Status)
|
||||
s.Require().Equal("something went wrong", got.ErrorMessage)
|
||||
}
|
||||
|
||||
// --- UpdateSessionWindow ---
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateSessionWindow() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-win"})
|
||||
start := time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC)
|
||||
end := time.Date(2025, 6, 15, 15, 0, 0, 0, time.UTC)
|
||||
|
||||
s.Require().NoError(s.repo.UpdateSessionWindow(s.ctx, account.ID, &start, &end, "active"))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(got.SessionWindowStart)
|
||||
s.Require().NotNil(got.SessionWindowEnd)
|
||||
s.Require().Equal("active", got.SessionWindowStatus)
|
||||
}
|
||||
|
||||
// --- UpdateExtra ---
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{
|
||||
Name: "acc-extra",
|
||||
Extra: model.JSONB{"a": "1"},
|
||||
})
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"b": "2"}), "UpdateExtra")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("1", got.Extra["a"])
|
||||
s.Require().Equal("2", got.Extra["b"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_EmptyUpdates() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-extra-empty"})
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{}))
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-nil-extra", Extra: nil})
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"key": "val"}))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal("val", got.Extra["key"])
|
||||
}
|
||||
|
||||
// --- GetByCRSAccountID ---
|
||||
|
||||
func (s *AccountRepoSuite) TestGetByCRSAccountID() {
|
||||
crsID := "crs-12345"
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{
|
||||
Name: "acc-crs",
|
||||
Extra: model.JSONB{"crs_account_id": crsID},
|
||||
})
|
||||
|
||||
got, err := s.repo.GetByCRSAccountID(s.ctx, crsID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(got)
|
||||
s.Require().Equal("acc-crs", got.Name)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestGetByCRSAccountID_NotFound() {
|
||||
got, err := s.repo.GetByCRSAccountID(s.ctx, "non-existent")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Nil(got)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestGetByCRSAccountID_EmptyString() {
|
||||
got, err := s.repo.GetByCRSAccountID(s.ctx, "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Nil(got)
|
||||
}
|
||||
|
||||
// --- BulkUpdate ---
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate() {
|
||||
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk1", Priority: 1})
|
||||
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk2", Priority: 1})
|
||||
|
||||
newPriority := 99
|
||||
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID, a2.ID}, ports.AccountBulkUpdate{
|
||||
Priority: &newPriority,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.Require().GreaterOrEqual(affected, int64(1), "expected at least one affected row")
|
||||
|
||||
got1, _ := s.repo.GetByID(s.ctx, a1.ID)
|
||||
got2, _ := s.repo.GetByID(s.ctx, a2.ID)
|
||||
s.Require().Equal(99, got1.Priority)
|
||||
s.Require().Equal(99, got2.Priority)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
|
||||
a1 := mustCreateAccount(s.T(), s.db, &model.Account{
|
||||
Name: "bulk-cred",
|
||||
Credentials: model.JSONB{"existing": "value"},
|
||||
})
|
||||
|
||||
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, ports.AccountBulkUpdate{
|
||||
Credentials: model.JSONB{"new_key": "new_value"},
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
got, _ := s.repo.GetByID(s.ctx, a1.ID)
|
||||
s.Require().Equal("value", got.Credentials["existing"])
|
||||
s.Require().Equal("new_value", got.Credentials["new_key"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate_MergeExtra() {
|
||||
a1 := mustCreateAccount(s.T(), s.db, &model.Account{
|
||||
Name: "bulk-extra",
|
||||
Extra: model.JSONB{"existing": "val"},
|
||||
})
|
||||
|
||||
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, ports.AccountBulkUpdate{
|
||||
Extra: model.JSONB{"new_key": "new_val"},
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
got, _ := s.repo.GetByID(s.ctx, a1.ID)
|
||||
s.Require().Equal("val", got.Extra["existing"])
|
||||
s.Require().Equal("new_val", got.Extra["new_key"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() {
|
||||
affected, err := s.repo.BulkUpdate(s.ctx, []int64{}, ports.AccountBulkUpdate{})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(affected)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate_EmptyUpdates() {
|
||||
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk-empty"})
|
||||
|
||||
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, ports.AccountBulkUpdate{})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(affected)
|
||||
}
|
||||
|
||||
func idsOfAccounts(accounts []model.Account) []int64 {
|
||||
out := make([]int64, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
out = append(out, accounts[i].ID)
|
||||
}
|
||||
return out
|
||||
}
|
||||
125
backend/internal/repository/api_key_cache_integration_test.go
Normal file
125
backend/internal/repository/api_key_cache_integration_test.go
Normal file
@@ -0,0 +1,125 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ApiKeyCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
}
|
||||
|
||||
func (s *ApiKeyCacheSuite) TestCreateAttemptCount() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
|
||||
}{
|
||||
{
|
||||
name: "missing_key_returns_redis_nil",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||
userID := int64(1)
|
||||
|
||||
_, err := cache.GetCreateAttemptCount(ctx, userID)
|
||||
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing key")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "increment_increases_count_and_sets_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||
userID := int64(1)
|
||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID), "IncrementCreateAttemptCount")
|
||||
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID), "IncrementCreateAttemptCount 2")
|
||||
|
||||
count, err := cache.GetCreateAttemptCount(ctx, userID)
|
||||
require.NoError(s.T(), err, "GetCreateAttemptCount")
|
||||
require.Equal(s.T(), 2, count, "count mismatch")
|
||||
|
||||
ttl, err := rdb.TTL(ctx, key).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, apiKeyRateLimitDuration)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "delete_removes_key",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||
userID := int64(1)
|
||||
|
||||
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID))
|
||||
require.NoError(s.T(), cache.DeleteCreateAttemptCount(ctx, userID), "DeleteCreateAttemptCount")
|
||||
|
||||
_, err := cache.GetCreateAttemptCount(ctx, userID)
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after delete")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
// 每个 case 重新获取隔离资源
|
||||
rdb := testRedis(s.T())
|
||||
cache := &apiKeyCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
tt.fn(ctx, rdb, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ApiKeyCacheSuite) TestDailyUsage() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
|
||||
}{
|
||||
{
|
||||
name: "increment_increases_count",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||
dailyKey := "daily:sk-test"
|
||||
|
||||
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey), "IncrementDailyUsage")
|
||||
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey), "IncrementDailyUsage 2")
|
||||
|
||||
n, err := rdb.Get(ctx, dailyKey).Int()
|
||||
require.NoError(s.T(), err, "Get dailyKey")
|
||||
require.Equal(s.T(), 2, n, "expected daily usage=2")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set_expiry_sets_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||
dailyKey := "daily:sk-test-expiry"
|
||||
|
||||
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey))
|
||||
require.NoError(s.T(), cache.SetDailyUsageExpiry(ctx, dailyKey, 1*time.Hour), "SetDailyUsageExpiry")
|
||||
|
||||
ttl, err := rdb.TTL(ctx, dailyKey).Result()
|
||||
require.NoError(s.T(), err, "TTL dailyKey")
|
||||
require.Greater(s.T(), ttl, time.Duration(0), "expected ttl > 0")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := &apiKeyCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
tt.fn(ctx, rdb, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApiKeyCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(ApiKeyCacheSuite))
|
||||
}
|
||||
355
backend/internal/repository/api_key_repo_integration_test.go
Normal file
355
backend/internal/repository/api_key_repo_integration_test.go
Normal file
@@ -0,0 +1,355 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ApiKeyRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
repo *ApiKeyRepository
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewApiKeyRepository(s.db)
|
||||
}
|
||||
|
||||
func TestApiKeyRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(ApiKeyRepoSuite))
|
||||
}
|
||||
|
||||
// --- Create / GetByID / GetByKey ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCreate() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "create@test.com"})
|
||||
|
||||
key := &model.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-create-test",
|
||||
Name: "Test Key",
|
||||
Status: model.StatusActive,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, key)
|
||||
s.Require().NoError(err, "Create")
|
||||
s.Require().NotZero(key.ID, "expected ID to be set")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("sk-create-test", got.Key)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestGetByKey() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "getbykey@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-key"})
|
||||
|
||||
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-getbykey",
|
||||
Name: "My Key",
|
||||
GroupID: &group.ID,
|
||||
Status: model.StatusActive,
|
||||
})
|
||||
|
||||
got, err := s.repo.GetByKey(s.ctx, key.Key)
|
||||
s.Require().NoError(err, "GetByKey")
|
||||
s.Require().Equal(key.ID, got.ID)
|
||||
s.Require().NotNil(got.User, "expected User preload")
|
||||
s.Require().Equal(user.ID, got.User.ID)
|
||||
s.Require().NotNil(got.Group, "expected Group preload")
|
||||
s.Require().Equal(group.ID, got.Group.ID)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
|
||||
_, err := s.repo.GetByKey(s.ctx, "non-existent-key")
|
||||
s.Require().Error(err, "expected error for non-existent key")
|
||||
}
|
||||
|
||||
// --- Update ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestUpdate() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com"})
|
||||
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-update",
|
||||
Name: "Original",
|
||||
Status: model.StatusActive,
|
||||
})
|
||||
|
||||
key.Name = "Renamed"
|
||||
key.Status = model.StatusDisabled
|
||||
err := s.repo.Update(s.ctx, key)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().NoError(err, "GetByID after update")
|
||||
s.Require().Equal("sk-update", got.Key, "Update should not change key")
|
||||
s.Require().Equal(user.ID, got.UserID, "Update should not change user_id")
|
||||
s.Require().Equal("Renamed", got.Name)
|
||||
s.Require().Equal(model.StatusDisabled, got.Status)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "cleargroup@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-clear"})
|
||||
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-clear-group",
|
||||
Name: "Group Key",
|
||||
GroupID: &group.ID,
|
||||
})
|
||||
|
||||
key.GroupID = nil
|
||||
err := s.repo.Update(s.ctx, key)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Nil(got.GroupID, "expected GroupID to be cleared")
|
||||
}
|
||||
|
||||
// --- Delete ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestDelete() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
|
||||
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-delete",
|
||||
Name: "Delete Me",
|
||||
})
|
||||
|
||||
err := s.repo.Delete(s.ctx, key.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
|
||||
_, err = s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
}
|
||||
|
||||
// --- ListByUserID / CountByUserID ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestListByUserID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyuser@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-list-1", Name: "Key 1"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-list-2", Name: "Key 2"})
|
||||
|
||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByUserID")
|
||||
s.Require().Len(keys, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "paging@test.com"})
|
||||
for i := 0; i < 5; i++ {
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-page-" + string(rune('a'+i)),
|
||||
Name: "Key",
|
||||
})
|
||||
}
|
||||
|
||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(keys, 2)
|
||||
s.Require().Equal(int64(5), page.Total)
|
||||
s.Require().Equal(3, page.Pages)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCountByUserID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "count@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-count-1", Name: "K1"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-count-2", Name: "K2"})
|
||||
|
||||
count, err := s.repo.CountByUserID(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "CountByUserID")
|
||||
s.Require().Equal(int64(2), count)
|
||||
}
|
||||
|
||||
// --- ListByGroupID / CountByGroupID ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestListByGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbygroup@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"})
|
||||
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-1", Name: "K1", GroupID: &group.ID})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-2", Name: "K2", GroupID: &group.ID})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-3", Name: "K3"}) // no group
|
||||
|
||||
keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByGroupID")
|
||||
s.Require().Len(keys, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
// User preloaded
|
||||
s.Require().NotNil(keys[0].User)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCountByGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "countgroup@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"})
|
||||
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-gc-1", Name: "K1", GroupID: &group.ID})
|
||||
|
||||
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "CountByGroupID")
|
||||
s.Require().Equal(int64(1), count)
|
||||
}
|
||||
|
||||
// --- ExistsByKey ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestExistsByKey() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-exists", Name: "K"})
|
||||
|
||||
exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists")
|
||||
s.Require().NoError(err, "ExistsByKey")
|
||||
s.Require().True(exists)
|
||||
|
||||
notExists, err := s.repo.ExistsByKey(s.ctx, "sk-not-exists")
|
||||
s.Require().NoError(err)
|
||||
s.Require().False(notExists)
|
||||
}
|
||||
|
||||
// --- SearchApiKeys ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "search@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-search-1", Name: "Production Key"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-search-2", Name: "Development Key"})
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10)
|
||||
s.Require().NoError(err, "SearchApiKeys")
|
||||
s.Require().Len(found, 1)
|
||||
s.Require().Contains(found[0].Name, "Production")
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "searchnokw@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nk-1", Name: "K1"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nk-2", Name: "K2"})
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(found, 2)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "searchnouid@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nu-1", Name: "TestKey"})
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(found, 1)
|
||||
}
|
||||
|
||||
// --- ClearGroupIDByGroupID ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "cleargrp@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-clear-bulk"})
|
||||
|
||||
k1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-1", Name: "K1", GroupID: &group.ID})
|
||||
k2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-2", Name: "K2", GroupID: &group.ID})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-3", Name: "K3"}) // no group
|
||||
|
||||
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "ClearGroupIDByGroupID")
|
||||
s.Require().Equal(int64(2), affected)
|
||||
|
||||
got1, _ := s.repo.GetByID(s.ctx, k1.ID)
|
||||
got2, _ := s.repo.GetByID(s.ctx, k2.ID)
|
||||
s.Require().Nil(got1.GroupID)
|
||||
s.Require().Nil(got2.GroupID)
|
||||
|
||||
count, _ := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
s.Require().Zero(count)
|
||||
}
|
||||
|
||||
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "k@example.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-k"})
|
||||
|
||||
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-test-1",
|
||||
Name: "My Key",
|
||||
GroupID: &group.ID,
|
||||
Status: model.StatusActive,
|
||||
})
|
||||
|
||||
got, err := s.repo.GetByKey(s.ctx, key.Key)
|
||||
s.Require().NoError(err, "GetByKey")
|
||||
s.Require().Equal(key.ID, got.ID)
|
||||
s.Require().NotNil(got.User)
|
||||
s.Require().Equal(user.ID, got.User.ID)
|
||||
s.Require().NotNil(got.Group)
|
||||
s.Require().Equal(group.ID, got.Group.ID)
|
||||
|
||||
key.Name = "Renamed"
|
||||
key.Status = model.StatusDisabled
|
||||
key.GroupID = nil
|
||||
s.Require().NoError(s.repo.Update(s.ctx, key), "Update")
|
||||
|
||||
got2, err := s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("sk-test-1", got2.Key, "Update should not change key")
|
||||
s.Require().Equal(user.ID, got2.UserID, "Update should not change user_id")
|
||||
s.Require().Equal("Renamed", got2.Name)
|
||||
s.Require().Equal(model.StatusDisabled, got2.Status)
|
||||
s.Require().Nil(got2.GroupID)
|
||||
|
||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByUserID")
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
s.Require().Len(keys, 1)
|
||||
|
||||
exists, err := s.repo.ExistsByKey(s.ctx, "sk-test-1")
|
||||
s.Require().NoError(err, "ExistsByKey")
|
||||
s.Require().True(exists, "expected key to exist")
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "renam", 10)
|
||||
s.Require().NoError(err, "SearchApiKeys")
|
||||
s.Require().Len(found, 1)
|
||||
s.Require().Equal(key.ID, found[0].ID)
|
||||
|
||||
// ClearGroupIDByGroupID
|
||||
k2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-test-2",
|
||||
Name: "Group Key",
|
||||
GroupID: &group.ID,
|
||||
})
|
||||
|
||||
countBefore, err := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "CountByGroupID")
|
||||
s.Require().Equal(int64(1), countBefore, "expected 1 key in group before clear")
|
||||
|
||||
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "ClearGroupIDByGroupID")
|
||||
s.Require().Equal(int64(1), affected, "expected 1 affected row")
|
||||
|
||||
got3, err := s.repo.GetByID(s.ctx, k2.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Nil(got3.GroupID, "expected GroupID cleared")
|
||||
|
||||
countAfter, err := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "CountByGroupID after clear")
|
||||
s.Require().Equal(int64(0), countAfter, "expected 0 keys in group after clear")
|
||||
}
|
||||
283
backend/internal/repository/billing_cache_integration_test.go
Normal file
283
backend/internal/repository/billing_cache_integration_test.go
Normal file
@@ -0,0 +1,283 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type BillingCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
}
|
||||
|
||||
func (s *BillingCacheSuite) TestUserBalance() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache)
|
||||
}{
|
||||
{
|
||||
name: "missing_key_returns_redis_nil",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||
_, err := cache.GetUserBalance(ctx, 1)
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing balance key")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "deduct_on_nonexistent_is_noop",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||
userID := int64(1)
|
||||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 1), "DeductUserBalance should not error")
|
||||
|
||||
_, err := rdb.Get(ctx, balanceKey).Result()
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected missing key after deduct on non-existent")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set_and_get_with_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||
userID := int64(2)
|
||||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
|
||||
|
||||
got, err := cache.GetUserBalance(ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserBalance")
|
||||
require.Equal(s.T(), 10.5, got, "balance mismatch")
|
||||
|
||||
ttl, err := rdb.TTL(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "deduct_reduces_balance",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||
userID := int64(3)
|
||||
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
|
||||
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 2.25), "DeductUserBalance")
|
||||
|
||||
got, err := cache.GetUserBalance(ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserBalance after deduct")
|
||||
require.Equal(s.T(), 8.25, got, "deduct mismatch")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalidate_removes_key",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||
userID := int64(100)
|
||||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 50.0), "SetUserBalance")
|
||||
|
||||
exists, err := rdb.Exists(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "Exists")
|
||||
require.Equal(s.T(), int64(1), exists, "expected balance key to exist")
|
||||
|
||||
require.NoError(s.T(), cache.InvalidateUserBalance(ctx, userID), "InvalidateUserBalance")
|
||||
|
||||
exists, err = rdb.Exists(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "Exists after invalidate")
|
||||
require.Equal(s.T(), int64(0), exists, "expected balance key to be removed after invalidate")
|
||||
|
||||
_, err = cache.GetUserBalance(ctx, userID)
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "deduct_refreshes_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||
userID := int64(103)
|
||||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 100.0), "SetUserBalance")
|
||||
|
||||
ttl1, err := rdb.TTL(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "TTL before deduct")
|
||||
s.AssertTTLWithin(ttl1, 1*time.Second, billingCacheTTL)
|
||||
|
||||
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 25.0), "DeductUserBalance")
|
||||
|
||||
balance, err := cache.GetUserBalance(ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserBalance")
|
||||
require.Equal(s.T(), 75.0, balance, "expected balance 75.0")
|
||||
|
||||
ttl2, err := rdb.TTL(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "TTL after deduct")
|
||||
s.AssertTTLWithin(ttl2, 1*time.Second, billingCacheTTL)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
|
||||
tt.fn(ctx, rdb, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BillingCacheSuite) TestSubscriptionCache() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache)
|
||||
}{
|
||||
{
|
||||
name: "missing_key_returns_redis_nil",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||
userID := int64(10)
|
||||
groupID := int64(20)
|
||||
|
||||
_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing subscription key")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "update_usage_on_nonexistent_is_noop",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||
userID := int64(11)
|
||||
groupID := int64(21)
|
||||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 1.0), "UpdateSubscriptionUsage should not error")
|
||||
|
||||
exists, err := rdb.Exists(ctx, subKey).Result()
|
||||
require.NoError(s.T(), err, "Exists")
|
||||
require.Equal(s.T(), int64(0), exists, "expected missing subscription key after UpdateSubscriptionUsage on non-existent")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set_and_get_with_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||
userID := int64(12)
|
||||
groupID := int64(22)
|
||||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
data := &ports.SubscriptionCacheData{
|
||||
Status: "active",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
DailyUsage: 1.0,
|
||||
WeeklyUsage: 2.0,
|
||||
MonthlyUsage: 3.0,
|
||||
Version: 7,
|
||||
}
|
||||
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
|
||||
|
||||
gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.NoError(s.T(), err, "GetSubscriptionCache")
|
||||
require.Equal(s.T(), "active", gotSub.Status)
|
||||
require.Equal(s.T(), int64(7), gotSub.Version)
|
||||
require.Equal(s.T(), 1.0, gotSub.DailyUsage)
|
||||
|
||||
ttl, err := rdb.TTL(ctx, subKey).Result()
|
||||
require.NoError(s.T(), err, "TTL subKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "update_usage_increments_all_fields",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||
userID := int64(13)
|
||||
groupID := int64(23)
|
||||
|
||||
data := &ports.SubscriptionCacheData{
|
||||
Status: "active",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
DailyUsage: 1.0,
|
||||
WeeklyUsage: 2.0,
|
||||
MonthlyUsage: 3.0,
|
||||
Version: 1,
|
||||
}
|
||||
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
|
||||
|
||||
require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 0.5), "UpdateSubscriptionUsage")
|
||||
|
||||
gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.NoError(s.T(), err, "GetSubscriptionCache after update")
|
||||
require.Equal(s.T(), 1.5, gotSub.DailyUsage)
|
||||
require.Equal(s.T(), 2.5, gotSub.WeeklyUsage)
|
||||
require.Equal(s.T(), 3.5, gotSub.MonthlyUsage)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalidate_removes_key",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||
userID := int64(101)
|
||||
groupID := int64(10)
|
||||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
data := &ports.SubscriptionCacheData{
|
||||
Status: "active",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
DailyUsage: 1.0,
|
||||
WeeklyUsage: 2.0,
|
||||
MonthlyUsage: 3.0,
|
||||
Version: 1,
|
||||
}
|
||||
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
|
||||
|
||||
exists, err := rdb.Exists(ctx, subKey).Result()
|
||||
require.NoError(s.T(), err, "Exists")
|
||||
require.Equal(s.T(), int64(1), exists, "expected subscription key to exist")
|
||||
|
||||
require.NoError(s.T(), cache.InvalidateSubscriptionCache(ctx, userID, groupID), "InvalidateSubscriptionCache")
|
||||
|
||||
exists, err = rdb.Exists(ctx, subKey).Result()
|
||||
require.NoError(s.T(), err, "Exists after invalidate")
|
||||
require.Equal(s.T(), int64(0), exists, "expected subscription key to be removed after invalidate")
|
||||
|
||||
_, err = cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing_status_returns_parsing_error",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||
userID := int64(102)
|
||||
groupID := int64(11)
|
||||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
fields := map[string]any{
|
||||
"expires_at": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"daily_usage": 1.0,
|
||||
"weekly_usage": 2.0,
|
||||
"monthly_usage": 3.0,
|
||||
"version": 1,
|
||||
}
|
||||
require.NoError(s.T(), rdb.HSet(ctx, subKey, fields).Err(), "HSet")
|
||||
|
||||
_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.Error(s.T(), err, "expected error for missing status field")
|
||||
require.NotErrorIs(s.T(), err, redis.Nil, "expected parsing error, not redis.Nil")
|
||||
require.Equal(s.T(), "invalid cache: missing status", err.Error())
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
|
||||
tt.fn(ctx, rdb, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(BillingCacheSuite))
|
||||
}
|
||||
@@ -16,20 +16,28 @@ import (
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
type claudeOAuthService struct{}
|
||||
|
||||
func NewClaudeOAuthClient() service.ClaudeOAuthClient {
|
||||
return &claudeOAuthService{}
|
||||
return &claudeOAuthService{
|
||||
baseURL: "https://claude.ai",
|
||||
tokenURL: oauth.TokenURL,
|
||||
clientFactory: createReqClient,
|
||||
}
|
||||
}
|
||||
|
||||
type claudeOAuthService struct {
|
||||
baseURL string
|
||||
tokenURL string
|
||||
clientFactory func(proxyURL string) *req.Client
|
||||
}
|
||||
|
||||
func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
|
||||
client := createReqClient(proxyURL)
|
||||
client := s.clientFactory(proxyURL)
|
||||
|
||||
var orgs []struct {
|
||||
UUID string `json:"uuid"`
|
||||
}
|
||||
|
||||
targetURL := "https://claude.ai/api/organizations"
|
||||
targetURL := s.baseURL + "/api/organizations"
|
||||
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
|
||||
|
||||
resp, err := client.R().
|
||||
@@ -61,9 +69,9 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
|
||||
}
|
||||
|
||||
func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
|
||||
client := createReqClient(proxyURL)
|
||||
client := s.clientFactory(proxyURL)
|
||||
|
||||
authURL := fmt.Sprintf("https://claude.ai/v1/oauth/%s/authorize", orgUUID)
|
||||
authURL := fmt.Sprintf("%s/v1/oauth/%s/authorize", s.baseURL, orgUUID)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"response_type": "code",
|
||||
@@ -133,12 +141,12 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
|
||||
fullCode = authCode + "#" + responseState
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", authCode[:20])
|
||||
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", prefix(authCode, 20))
|
||||
return fullCode, nil
|
||||
}
|
||||
|
||||
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) {
|
||||
client := createReqClient(proxyURL)
|
||||
client := s.clientFactory(proxyURL)
|
||||
|
||||
// Parse code which may contain state in format "authCode#state"
|
||||
authCode := code
|
||||
@@ -161,7 +169,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
|
||||
}
|
||||
|
||||
reqBodyJSON, _ := json.Marshal(reqBody)
|
||||
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", oauth.TokenURL)
|
||||
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
|
||||
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
|
||||
|
||||
var tokenResp oauth.TokenResponse
|
||||
@@ -171,7 +179,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(oauth.TokenURL)
|
||||
Post(s.tokenURL)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err)
|
||||
@@ -189,7 +197,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
|
||||
}
|
||||
|
||||
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
|
||||
client := createReqClient(proxyURL)
|
||||
client := s.clientFactory(proxyURL)
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "refresh_token")
|
||||
@@ -202,7 +210,7 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
|
||||
SetContext(ctx).
|
||||
SetFormDataFromValues(formData).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(oauth.TokenURL)
|
||||
Post(s.tokenURL)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
@@ -226,3 +234,13 @@ func createReqClient(proxyURL string) *req.Client {
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
func prefix(s string, n int) string {
|
||||
if n <= 0 {
|
||||
return ""
|
||||
}
|
||||
if len(s) <= n {
|
||||
return s
|
||||
}
|
||||
return s[:n]
|
||||
}
|
||||
|
||||
343
backend/internal/repository/claude_oauth_service_test.go
Normal file
343
backend/internal/repository/claude_oauth_service_test.go
Normal file
@@ -0,0 +1,343 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ClaudeOAuthServiceSuite struct {
|
||||
suite.Suite
|
||||
srv *httptest.Server
|
||||
client *claudeOAuthService
|
||||
}
|
||||
|
||||
func (s *ClaudeOAuthServiceSuite) TearDownTest() {
|
||||
if s.srv != nil {
|
||||
s.srv.Close()
|
||||
s.srv = nil
|
||||
}
|
||||
}
|
||||
|
||||
// requestCapture holds captured request data for assertions in the main goroutine.
|
||||
type requestCapture struct {
|
||||
path string
|
||||
method string
|
||||
cookies []*http.Cookie
|
||||
body []byte
|
||||
formValues url.Values
|
||||
bodyJSON map[string]any
|
||||
contentType string
|
||||
}
|
||||
|
||||
func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() {
|
||||
tests := []struct {
|
||||
name string
|
||||
handler http.HandlerFunc
|
||||
wantErr bool
|
||||
errContain string
|
||||
wantUUID string
|
||||
validate func(captured requestCapture)
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`[{"uuid":"org-1"}]`))
|
||||
},
|
||||
wantUUID: "org-1",
|
||||
validate: func(captured requestCapture) {
|
||||
require.Equal(s.T(), "/api/organizations", captured.path, "unexpected path")
|
||||
require.Len(s.T(), captured.cookies, 1, "expected 1 cookie")
|
||||
require.Equal(s.T(), "sessionKey", captured.cookies[0].Name)
|
||||
require.Equal(s.T(), "sess", captured.cookies[0].Value)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non_200_returns_error",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte("unauthorized"))
|
||||
},
|
||||
wantErr: true,
|
||||
errContain: "401",
|
||||
},
|
||||
{
|
||||
name: "invalid_json_returns_error",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte("not-json"))
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
var captured requestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.path = r.URL.Path
|
||||
captured.cookies = r.Cookies()
|
||||
tt.handler(w, r)
|
||||
}))
|
||||
defer s.srv.Close()
|
||||
|
||||
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client.baseURL = s.srv.URL
|
||||
|
||||
got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "")
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(s.T(), err)
|
||||
if tt.errContain != "" {
|
||||
require.ErrorContains(s.T(), err, tt.errContain)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), tt.wantUUID, got)
|
||||
if tt.validate != nil {
|
||||
tt.validate(captured)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
|
||||
tests := []struct {
|
||||
name string
|
||||
handler http.HandlerFunc
|
||||
wantErr bool
|
||||
wantCode string
|
||||
validate func(captured requestCapture)
|
||||
}{
|
||||
{
|
||||
name: "parses_redirect_uri",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"redirect_uri": oauth.RedirectURI + "?code=AUTH&state=STATE",
|
||||
})
|
||||
},
|
||||
wantCode: "AUTH#STATE",
|
||||
validate: func(captured requestCapture) {
|
||||
require.True(s.T(), strings.HasPrefix(captured.path, "/v1/oauth/") && strings.HasSuffix(captured.path, "/authorize"), "unexpected path: %s", captured.path)
|
||||
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
|
||||
require.Len(s.T(), captured.cookies, 1, "expected 1 cookie")
|
||||
require.Equal(s.T(), "sess", captured.cookies[0].Value)
|
||||
require.Equal(s.T(), "org-1", captured.bodyJSON["organization_uuid"])
|
||||
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
|
||||
require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"])
|
||||
require.Equal(s.T(), "st", captured.bodyJSON["state"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing_code_returns_error",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"redirect_uri": oauth.RedirectURI + "?state=STATE", // no code
|
||||
})
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
var captured requestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.path = r.URL.Path
|
||||
captured.method = r.Method
|
||||
captured.cookies = r.Cookies()
|
||||
captured.body, _ = io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
|
||||
tt.handler(w, r)
|
||||
}))
|
||||
defer s.srv.Close()
|
||||
|
||||
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client.baseURL = s.srv.URL
|
||||
|
||||
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeProfile, "cc", "st", "")
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(s.T(), err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), tt.wantCode, code)
|
||||
if tt.validate != nil {
|
||||
tt.validate(captured)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
|
||||
tests := []struct {
|
||||
name string
|
||||
handler http.HandlerFunc
|
||||
code string
|
||||
wantErr bool
|
||||
wantResp *oauth.TokenResponse
|
||||
validate func(captured requestCapture)
|
||||
}{
|
||||
{
|
||||
name: "sends_state_when_embedded",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
|
||||
AccessToken: "at",
|
||||
TokenType: "bearer",
|
||||
ExpiresIn: 3600,
|
||||
RefreshToken: "rt",
|
||||
Scope: "s",
|
||||
})
|
||||
},
|
||||
code: "AUTH#STATE2",
|
||||
wantResp: &oauth.TokenResponse{
|
||||
AccessToken: "at",
|
||||
RefreshToken: "rt",
|
||||
},
|
||||
validate: func(captured requestCapture) {
|
||||
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
|
||||
require.True(s.T(), strings.HasPrefix(captured.contentType, "application/json"), "unexpected content-type")
|
||||
require.Equal(s.T(), "AUTH", captured.bodyJSON["code"])
|
||||
require.Equal(s.T(), "STATE2", captured.bodyJSON["state"])
|
||||
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
|
||||
require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"])
|
||||
require.Equal(s.T(), "ver", captured.bodyJSON["code_verifier"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non_200_returns_error",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte("bad request"))
|
||||
},
|
||||
code: "AUTH",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
var captured requestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.method = r.Method
|
||||
captured.contentType = r.Header.Get("Content-Type")
|
||||
captured.body, _ = io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
|
||||
tt.handler(w, r)
|
||||
}))
|
||||
defer s.srv.Close()
|
||||
|
||||
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client.tokenURL = s.srv.URL
|
||||
|
||||
resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "")
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(s.T(), err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken)
|
||||
require.Equal(s.T(), tt.wantResp.RefreshToken, resp.RefreshToken)
|
||||
if tt.validate != nil {
|
||||
tt.validate(captured)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
|
||||
tests := []struct {
|
||||
name string
|
||||
handler http.HandlerFunc
|
||||
wantErr bool
|
||||
wantResp *oauth.TokenResponse
|
||||
validate func(captured requestCapture)
|
||||
}{
|
||||
{
|
||||
name: "sends_form",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{AccessToken: "at2", TokenType: "bearer", ExpiresIn: 3600})
|
||||
},
|
||||
wantResp: &oauth.TokenResponse{AccessToken: "at2"},
|
||||
validate: func(captured requestCapture) {
|
||||
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
|
||||
require.Equal(s.T(), "refresh_token", captured.formValues.Get("grant_type"))
|
||||
require.Equal(s.T(), "rt", captured.formValues.Get("refresh_token"))
|
||||
require.Equal(s.T(), oauth.ClientID, captured.formValues.Get("client_id"))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non_200_returns_error",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte("unauthorized"))
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
var captured requestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.method = r.Method
|
||||
captured.body, _ = io.ReadAll(r.Body)
|
||||
captured.formValues, _ = url.ParseQuery(string(captured.body))
|
||||
tt.handler(w, r)
|
||||
}))
|
||||
defer s.srv.Close()
|
||||
|
||||
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client.tokenURL = s.srv.URL
|
||||
|
||||
resp, err := s.client.RefreshToken(context.Background(), "rt", "")
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(s.T(), err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken)
|
||||
if tt.validate != nil {
|
||||
tt.validate(captured)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeOAuthServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(ClaudeOAuthServiceSuite))
|
||||
}
|
||||
@@ -12,10 +12,14 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type claudeUsageService struct{}
|
||||
const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
|
||||
|
||||
type claudeUsageService struct {
|
||||
usageURL string
|
||||
}
|
||||
|
||||
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
|
||||
return &claudeUsageService{}
|
||||
return &claudeUsageService{usageURL: defaultClaudeUsageURL}
|
||||
}
|
||||
|
||||
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
|
||||
@@ -35,7 +39,7 @@ func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyU
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.anthropic.com/api/oauth/usage", nil)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request failed: %w", err)
|
||||
}
|
||||
|
||||
105
backend/internal/repository/claude_usage_service_test.go
Normal file
105
backend/internal/repository/claude_usage_service_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ClaudeUsageServiceSuite struct {
|
||||
suite.Suite
|
||||
srv *httptest.Server
|
||||
fetcher *claudeUsageService
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TearDownTest() {
|
||||
if s.srv != nil {
|
||||
s.srv.Close()
|
||||
s.srv = nil
|
||||
}
|
||||
}
|
||||
|
||||
// usageRequestCapture holds captured request data for assertions in the main goroutine.
|
||||
type usageRequestCapture struct {
|
||||
authorization string
|
||||
anthropicBeta string
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
|
||||
var captured usageRequestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.authorization = r.Header.Get("Authorization")
|
||||
captured.anthropicBeta = r.Header.Get("anthropic-beta")
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{
|
||||
"five_hour": {"utilization": 12.5, "resets_at": "2025-01-01T00:00:00Z"},
|
||||
"seven_day": {"utilization": 34.0, "resets_at": "2025-01-08T00:00:00Z"},
|
||||
"seven_day_sonnet": {"utilization": 56.0, "resets_at": "2025-01-08T00:00:00Z"}
|
||||
}`)
|
||||
}))
|
||||
|
||||
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
|
||||
|
||||
resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
|
||||
require.NoError(s.T(), err, "FetchUsage")
|
||||
require.Equal(s.T(), 12.5, resp.FiveHour.Utilization, "FiveHour utilization mismatch")
|
||||
require.Equal(s.T(), 34.0, resp.SevenDay.Utilization, "SevenDay utilization mismatch")
|
||||
require.Equal(s.T(), 56.0, resp.SevenDaySonnet.Utilization, "SevenDaySonnet utilization mismatch")
|
||||
|
||||
// Assertions on captured request data
|
||||
require.Equal(s.T(), "Bearer at", captured.authorization, "Authorization header mismatch")
|
||||
require.Equal(s.T(), "oauth-2025-04-20", captured.anthropicBeta, "anthropic-beta header mismatch")
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = io.WriteString(w, "nope")
|
||||
}))
|
||||
|
||||
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
|
||||
|
||||
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "status 401")
|
||||
require.ErrorContains(s.T(), err, "nope")
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, "not-json")
|
||||
}))
|
||||
|
||||
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
|
||||
|
||||
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "decode response failed")
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Never respond - simulate slow server
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
_, err := s.fetcher.FetchUsage(ctx, "at", "")
|
||||
require.Error(s.T(), err, "expected error for cancelled context")
|
||||
}
|
||||
|
||||
func TestClaudeUsageServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(ClaudeUsageServiceSuite))
|
||||
}
|
||||
@@ -0,0 +1,231 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ConcurrencyCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache ports.ConcurrencyCache
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewConcurrencyCache(s.rdb)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
|
||||
accountID := int64(10)
|
||||
reqID1, reqID2, reqID3 := "req1", "req2", "req3"
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID1)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot 1")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID2)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot 2")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID3)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot 3")
|
||||
require.False(s.T(), ok, "expected third acquire to fail")
|
||||
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err, "GetAccountConcurrency")
|
||||
require.Equal(s.T(), 2, cur, "concurrency mismatch")
|
||||
|
||||
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID1), "ReleaseAccountSlot")
|
||||
|
||||
cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err, "GetAccountConcurrency after release")
|
||||
require.Equal(s.T(), 1, cur, "expected 1 after release")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() {
|
||||
accountID := int64(11)
|
||||
reqID := "req_ttl_test"
|
||||
slotKey := fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, reqID)
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, reqID)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_DuplicateReqID() {
|
||||
accountID := int64(12)
|
||||
reqID := "dup-req"
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID)
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Acquiring with same reqID should be idempotent
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID)
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 1, cur, "expected concurrency=1 (idempotent)")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_ReleaseIdempotent() {
|
||||
accountID := int64(13)
|
||||
reqID := "release-test"
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 1, reqID)
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot")
|
||||
// Releasing again should not error
|
||||
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot again")
|
||||
// Releasing non-existent should not error
|
||||
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, "non-existent"), "ReleaseAccountSlot non-existent")
|
||||
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 0, cur)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_MaxZero() {
|
||||
accountID := int64(14)
|
||||
reqID := "max-zero-test"
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 0, reqID)
|
||||
require.NoError(s.T(), err)
|
||||
require.False(s.T(), ok, "expected acquire to fail with max=0")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestUserSlot_AcquireAndRelease() {
|
||||
userID := int64(42)
|
||||
reqID1, reqID2 := "req1", "req2"
|
||||
|
||||
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID1)
|
||||
require.NoError(s.T(), err, "AcquireUserSlot")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID2)
|
||||
require.NoError(s.T(), err, "AcquireUserSlot 2")
|
||||
require.False(s.T(), ok, "expected second acquire to fail at max=1")
|
||||
|
||||
cur, err := s.cache.GetUserConcurrency(s.ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserConcurrency")
|
||||
require.Equal(s.T(), 1, cur, "expected concurrency=1")
|
||||
|
||||
require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, reqID1), "ReleaseUserSlot")
|
||||
// Releasing a non-existent slot should not error
|
||||
require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, "non-existent"), "ReleaseUserSlot non-existent")
|
||||
|
||||
cur, err = s.cache.GetUserConcurrency(s.ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserConcurrency after release")
|
||||
require.Equal(s.T(), 0, cur, "expected concurrency=0 after release")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
|
||||
userID := int64(200)
|
||||
reqID := "req_ttl_test"
|
||||
slotKey := fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, reqID)
|
||||
|
||||
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 5, reqID)
|
||||
require.NoError(s.T(), err, "AcquireUserSlot")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
|
||||
userID := int64(20)
|
||||
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
|
||||
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 2)
|
||||
require.NoError(s.T(), err, "IncrementWaitCount 1")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2)
|
||||
require.NoError(s.T(), err, "IncrementWaitCount 2")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2)
|
||||
require.NoError(s.T(), err, "IncrementWaitCount 3")
|
||||
require.False(s.T(), ok, "expected wait increment over max to fail")
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
|
||||
require.NoError(s.T(), err, "TTL waitKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
|
||||
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
|
||||
|
||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey")
|
||||
}
|
||||
require.Equal(s.T(), 1, val, "expected wait count 1")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
|
||||
userID := int64(300)
|
||||
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
|
||||
// Test decrement on non-existent key - should not error and should not create negative value
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on non-existent key")
|
||||
|
||||
// Verify no key was created or it's not negative
|
||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey")
|
||||
}
|
||||
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count after decrement on empty")
|
||||
|
||||
// Set count to 1, then decrement twice
|
||||
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 5)
|
||||
require.NoError(s.T(), err, "IncrementWaitCount")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Decrement once (1 -> 0)
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
|
||||
|
||||
// Decrement again on 0 - should not go negative
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on zero")
|
||||
|
||||
// Verify count is 0, not negative
|
||||
val, err = s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey after double decrement")
|
||||
}
|
||||
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
|
||||
// When no slots exist, GetAccountConcurrency should return 0
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, 999)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 0, cur)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() {
|
||||
// When no slots exist, GetUserConcurrency should return 0
|
||||
cur, err := s.cache.GetUserConcurrency(s.ctx, 999)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 0, cur)
|
||||
}
|
||||
|
||||
func TestConcurrencyCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(ConcurrencyCacheSuite))
|
||||
}
|
||||
92
backend/internal/repository/email_cache_integration_test.go
Normal file
92
backend/internal/repository/email_cache_integration_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type EmailCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache ports.EmailCache
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewEmailCache(s.rdb)
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestGetVerificationCode_Missing() {
|
||||
_, err := s.cache.GetVerificationCode(s.ctx, "nonexistent@example.com")
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing verification code")
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestSetAndGetVerificationCode() {
|
||||
email := "a@example.com"
|
||||
emailTTL := 2 * time.Minute
|
||||
data := &ports.VerificationCodeData{Code: "123456", Attempts: 1, CreatedAt: time.Now()}
|
||||
|
||||
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")
|
||||
|
||||
got, err := s.cache.GetVerificationCode(s.ctx, email)
|
||||
require.NoError(s.T(), err, "GetVerificationCode")
|
||||
require.Equal(s.T(), "123456", got.Code)
|
||||
require.Equal(s.T(), 1, got.Attempts)
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestVerificationCode_TTL() {
|
||||
email := "ttl@example.com"
|
||||
emailTTL := 2 * time.Minute
|
||||
data := &ports.VerificationCodeData{Code: "654321", Attempts: 0, CreatedAt: time.Now()}
|
||||
|
||||
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")
|
||||
|
||||
emailKey := verifyCodeKeyPrefix + email
|
||||
ttl, err := s.rdb.TTL(s.ctx, emailKey).Result()
|
||||
require.NoError(s.T(), err, "TTL emailKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, emailTTL)
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestDeleteVerificationCode() {
|
||||
email := "delete@example.com"
|
||||
data := &ports.VerificationCodeData{Code: "999999", Attempts: 0, CreatedAt: time.Now()}
|
||||
|
||||
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, 2*time.Minute), "SetVerificationCode")
|
||||
|
||||
// Verify it exists
|
||||
_, err := s.cache.GetVerificationCode(s.ctx, email)
|
||||
require.NoError(s.T(), err, "GetVerificationCode before delete")
|
||||
|
||||
// Delete
|
||||
require.NoError(s.T(), s.cache.DeleteVerificationCode(s.ctx, email), "DeleteVerificationCode")
|
||||
|
||||
// Verify it's gone
|
||||
_, err = s.cache.GetVerificationCode(s.ctx, email)
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestDeleteVerificationCode_NonExistent() {
|
||||
// Deleting a non-existent key should not error
|
||||
require.NoError(s.T(), s.cache.DeleteVerificationCode(s.ctx, "nonexistent@example.com"), "DeleteVerificationCode non-existent")
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestGetVerificationCode_JSONCorruption() {
|
||||
emailKey := verifyCodeKeyPrefix + "corrupted@example.com"
|
||||
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, emailKey, "not-json", 1*time.Minute).Err(), "Set invalid JSON")
|
||||
|
||||
_, err := s.cache.GetVerificationCode(s.ctx, "corrupted@example.com")
|
||||
require.Error(s.T(), err, "expected error for corrupted JSON")
|
||||
require.False(s.T(), errors.Is(err, redis.Nil), "expected decoding error, not redis.Nil")
|
||||
}
|
||||
|
||||
func TestEmailCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(EmailCacheSuite))
|
||||
}
|
||||
172
backend/internal/repository/fixtures_integration_test.go
Normal file
172
backend/internal/repository/fixtures_integration_test.go
Normal file
@@ -0,0 +1,172 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func mustCreateUser(t *testing.T, db *gorm.DB, u *model.User) *model.User {
|
||||
t.Helper()
|
||||
if u.PasswordHash == "" {
|
||||
u.PasswordHash = "test-password-hash"
|
||||
}
|
||||
if u.Role == "" {
|
||||
u.Role = model.RoleUser
|
||||
}
|
||||
if u.Status == "" {
|
||||
u.Status = model.StatusActive
|
||||
}
|
||||
if u.CreatedAt.IsZero() {
|
||||
u.CreatedAt = time.Now()
|
||||
}
|
||||
if u.UpdatedAt.IsZero() {
|
||||
u.UpdatedAt = u.CreatedAt
|
||||
}
|
||||
require.NoError(t, db.Create(u).Error, "create user")
|
||||
return u
|
||||
}
|
||||
|
||||
func mustCreateGroup(t *testing.T, db *gorm.DB, g *model.Group) *model.Group {
|
||||
t.Helper()
|
||||
if g.Platform == "" {
|
||||
g.Platform = model.PlatformAnthropic
|
||||
}
|
||||
if g.Status == "" {
|
||||
g.Status = model.StatusActive
|
||||
}
|
||||
if g.SubscriptionType == "" {
|
||||
g.SubscriptionType = model.SubscriptionTypeStandard
|
||||
}
|
||||
if g.CreatedAt.IsZero() {
|
||||
g.CreatedAt = time.Now()
|
||||
}
|
||||
if g.UpdatedAt.IsZero() {
|
||||
g.UpdatedAt = g.CreatedAt
|
||||
}
|
||||
require.NoError(t, db.Create(g).Error, "create group")
|
||||
return g
|
||||
}
|
||||
|
||||
func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy {
|
||||
t.Helper()
|
||||
if p.Protocol == "" {
|
||||
p.Protocol = "http"
|
||||
}
|
||||
if p.Host == "" {
|
||||
p.Host = "127.0.0.1"
|
||||
}
|
||||
if p.Port == 0 {
|
||||
p.Port = 8080
|
||||
}
|
||||
if p.Status == "" {
|
||||
p.Status = model.StatusActive
|
||||
}
|
||||
if p.CreatedAt.IsZero() {
|
||||
p.CreatedAt = time.Now()
|
||||
}
|
||||
if p.UpdatedAt.IsZero() {
|
||||
p.UpdatedAt = p.CreatedAt
|
||||
}
|
||||
require.NoError(t, db.Create(p).Error, "create proxy")
|
||||
return p
|
||||
}
|
||||
|
||||
func mustCreateAccount(t *testing.T, db *gorm.DB, a *model.Account) *model.Account {
|
||||
t.Helper()
|
||||
if a.Platform == "" {
|
||||
a.Platform = model.PlatformAnthropic
|
||||
}
|
||||
if a.Type == "" {
|
||||
a.Type = model.AccountTypeOAuth
|
||||
}
|
||||
if a.Status == "" {
|
||||
a.Status = model.StatusActive
|
||||
}
|
||||
if !a.Schedulable {
|
||||
a.Schedulable = true
|
||||
}
|
||||
if a.Credentials == nil {
|
||||
a.Credentials = model.JSONB{}
|
||||
}
|
||||
if a.Extra == nil {
|
||||
a.Extra = model.JSONB{}
|
||||
}
|
||||
if a.CreatedAt.IsZero() {
|
||||
a.CreatedAt = time.Now()
|
||||
}
|
||||
if a.UpdatedAt.IsZero() {
|
||||
a.UpdatedAt = a.CreatedAt
|
||||
}
|
||||
require.NoError(t, db.Create(a).Error, "create account")
|
||||
return a
|
||||
}
|
||||
|
||||
func mustCreateApiKey(t *testing.T, db *gorm.DB, k *model.ApiKey) *model.ApiKey {
|
||||
t.Helper()
|
||||
if k.Status == "" {
|
||||
k.Status = model.StatusActive
|
||||
}
|
||||
if k.CreatedAt.IsZero() {
|
||||
k.CreatedAt = time.Now()
|
||||
}
|
||||
if k.UpdatedAt.IsZero() {
|
||||
k.UpdatedAt = k.CreatedAt
|
||||
}
|
||||
require.NoError(t, db.Create(k).Error, "create api key")
|
||||
return k
|
||||
}
|
||||
|
||||
func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *model.RedeemCode) *model.RedeemCode {
|
||||
t.Helper()
|
||||
if c.Status == "" {
|
||||
c.Status = model.StatusUnused
|
||||
}
|
||||
if c.Type == "" {
|
||||
c.Type = model.RedeemTypeBalance
|
||||
}
|
||||
if c.CreatedAt.IsZero() {
|
||||
c.CreatedAt = time.Now()
|
||||
}
|
||||
require.NoError(t, db.Create(c).Error, "create redeem code")
|
||||
return c
|
||||
}
|
||||
|
||||
func mustCreateSubscription(t *testing.T, db *gorm.DB, s *model.UserSubscription) *model.UserSubscription {
|
||||
t.Helper()
|
||||
if s.Status == "" {
|
||||
s.Status = model.SubscriptionStatusActive
|
||||
}
|
||||
now := time.Now()
|
||||
if s.StartsAt.IsZero() {
|
||||
s.StartsAt = now.Add(-1 * time.Hour)
|
||||
}
|
||||
if s.ExpiresAt.IsZero() {
|
||||
s.ExpiresAt = now.Add(24 * time.Hour)
|
||||
}
|
||||
if s.AssignedAt.IsZero() {
|
||||
s.AssignedAt = now
|
||||
}
|
||||
if s.CreatedAt.IsZero() {
|
||||
s.CreatedAt = now
|
||||
}
|
||||
if s.UpdatedAt.IsZero() {
|
||||
s.UpdatedAt = now
|
||||
}
|
||||
require.NoError(t, db.Create(s).Error, "create user subscription")
|
||||
return s
|
||||
}
|
||||
|
||||
func mustBindAccountToGroup(t *testing.T, db *gorm.DB, accountID, groupID int64, priority int) {
|
||||
t.Helper()
|
||||
require.NoError(t, db.Create(&model.AccountGroup{
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
Priority: priority,
|
||||
}).Error, "create account_group")
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type GatewayCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache ports.GatewayCache
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewGatewayCache(s.rdb)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGetSessionAccountID_Missing() {
|
||||
_, err := s.cache.GetSessionAccountID(s.ctx, "nonexistent")
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing session")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() {
|
||||
sessionID := "s1"
|
||||
accountID := int64(99)
|
||||
sessionTTL := 1 * time.Minute
|
||||
|
||||
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, sessionTTL), "SetSessionAccountID")
|
||||
|
||||
sid, err := s.cache.GetSessionAccountID(s.ctx, sessionID)
|
||||
require.NoError(s.T(), err, "GetSessionAccountID")
|
||||
require.Equal(s.T(), accountID, sid, "session id mismatch")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestSessionAccountID_TTL() {
|
||||
sessionID := "s2"
|
||||
accountID := int64(100)
|
||||
sessionTTL := 1 * time.Minute
|
||||
|
||||
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, sessionTTL), "SetSessionAccountID")
|
||||
|
||||
sessionKey := stickySessionPrefix + sessionID
|
||||
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
|
||||
require.NoError(s.T(), err, "TTL sessionKey after Set")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, sessionTTL)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestRefreshSessionTTL() {
|
||||
sessionID := "s3"
|
||||
accountID := int64(101)
|
||||
initialTTL := 1 * time.Minute
|
||||
refreshTTL := 3 * time.Minute
|
||||
|
||||
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, initialTTL), "SetSessionAccountID")
|
||||
|
||||
require.NoError(s.T(), s.cache.RefreshSessionTTL(s.ctx, sessionID, refreshTTL), "RefreshSessionTTL")
|
||||
|
||||
sessionKey := stickySessionPrefix + sessionID
|
||||
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
|
||||
require.NoError(s.T(), err, "TTL after Refresh")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, refreshTTL)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() {
|
||||
// RefreshSessionTTL on a missing key should not error (no-op)
|
||||
err := s.cache.RefreshSessionTTL(s.ctx, "missing-session", 1*time.Minute)
|
||||
require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
|
||||
sessionID := "corrupted"
|
||||
sessionKey := stickySessionPrefix + sessionID
|
||||
|
||||
// Set a non-integer value
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, sessionKey, "not-a-number", 1*time.Minute).Err(), "Set invalid value")
|
||||
|
||||
_, err := s.cache.GetSessionAccountID(s.ctx, sessionID)
|
||||
require.Error(s.T(), err, "expected error for corrupted value")
|
||||
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
|
||||
}
|
||||
|
||||
func TestGatewayCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(GatewayCacheSuite))
|
||||
}
|
||||
328
backend/internal/repository/github_release_service_test.go
Normal file
328
backend/internal/repository/github_release_service_test.go
Normal file
@@ -0,0 +1,328 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type GitHubReleaseServiceSuite struct {
|
||||
suite.Suite
|
||||
srv *httptest.Server
|
||||
client *githubReleaseClient
|
||||
tempDir string
|
||||
}
|
||||
|
||||
// testTransport redirects requests to the test server
|
||||
type testTransport struct {
|
||||
testServerURL string
|
||||
}
|
||||
|
||||
func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Rewrite the URL to point to our test server
|
||||
testURL := t.testServerURL + req.URL.Path
|
||||
newReq, err := http.NewRequestWithContext(req.Context(), req.Method, testURL, req.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newReq.Header = req.Header
|
||||
return http.DefaultTransport.RoundTrip(newReq)
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) SetupTest() {
|
||||
s.tempDir = s.T().TempDir()
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TearDownTest() {
|
||||
if s.srv != nil {
|
||||
s.srv.Close()
|
||||
s.srv = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLength() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Length", "100")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(bytes.Repeat([]byte("a"), 100))
|
||||
}))
|
||||
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
|
||||
dest := filepath.Join(s.tempDir, "file1.bin")
|
||||
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
|
||||
require.Error(s.T(), err, "expected error for oversized download with Content-Length")
|
||||
|
||||
_, statErr := os.Stat(dest)
|
||||
require.Error(s.T(), statErr, "expected file to not exist for rejected download")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Force chunked encoding (unknown Content-Length) by flushing headers before writing.
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if fl, ok := w.(http.Flusher); ok {
|
||||
fl.Flush()
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
_, _ = w.Write(bytes.Repeat([]byte("b"), 10))
|
||||
if fl, ok := w.(http.Flusher); ok {
|
||||
fl.Flush()
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
|
||||
dest := filepath.Join(s.tempDir, "file2.bin")
|
||||
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
|
||||
require.Error(s.T(), err, "expected error for oversized chunked download")
|
||||
|
||||
_, statErr := os.Stat(dest)
|
||||
require.Error(s.T(), statErr, "expected file to be cleaned up for oversized chunked download")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if fl, ok := w.(http.Flusher); ok {
|
||||
fl.Flush()
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
_, _ = w.Write(bytes.Repeat([]byte("b"), 10))
|
||||
if fl, ok := w.(http.Flusher); ok {
|
||||
fl.Flush()
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
|
||||
dest := filepath.Join(s.tempDir, "file3.bin")
|
||||
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 200)
|
||||
require.NoError(s.T(), err, "expected success")
|
||||
|
||||
b, err := os.ReadFile(dest)
|
||||
require.NoError(s.T(), err, "read")
|
||||
require.True(s.T(), strings.HasPrefix(string(b), "b"), "downloaded content should start with 'b'")
|
||||
require.Len(s.T(), b, 100, "downloaded content length mismatch")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
|
||||
dest := filepath.Join(s.tempDir, "notfound.bin")
|
||||
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100)
|
||||
require.Error(s.T(), err, "expected error for 404")
|
||||
|
||||
_, statErr := os.Stat(dest)
|
||||
require.Error(s.T(), statErr, "expected file to not exist for 404")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("sum"))
|
||||
}))
|
||||
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
|
||||
body, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
|
||||
require.NoError(s.T(), err, "FetchChecksumFile")
|
||||
require.Equal(s.T(), "sum", string(body), "checksum body mismatch")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
|
||||
_, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
|
||||
require.Error(s.T(), err, "expected error for non-200")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
dest := filepath.Join(s.tempDir, "cancelled.bin")
|
||||
err := s.client.DownloadFile(ctx, s.srv.URL, dest, 100)
|
||||
require.Error(s.T(), err, "expected error for cancelled context")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() {
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
|
||||
dest := filepath.Join(s.tempDir, "invalid.bin")
|
||||
err := s.client.DownloadFile(context.Background(), "://invalid-url", dest, 100)
|
||||
require.Error(s.T(), err, "expected error for invalid URL")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("content"))
|
||||
}))
|
||||
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
|
||||
// Use a path that cannot be created (directory doesn't exist)
|
||||
dest := filepath.Join(s.tempDir, "nonexistent", "subdir", "file.bin")
|
||||
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100)
|
||||
require.Error(s.T(), err, "expected error for invalid destination path")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_InvalidURL() {
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
|
||||
_, err := s.client.FetchChecksumFile(context.Background(), "://invalid-url")
|
||||
require.Error(s.T(), err, "expected error for invalid URL")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
|
||||
releaseJSON := `{
|
||||
"tag_name": "v1.0.0",
|
||||
"name": "Release 1.0.0",
|
||||
"body": "Release notes",
|
||||
"html_url": "https://github.com/test/repo/releases/v1.0.0",
|
||||
"assets": [
|
||||
{
|
||||
"name": "app-linux-amd64.tar.gz",
|
||||
"browser_download_url": "https://github.com/test/repo/releases/download/v1.0.0/app-linux-amd64.tar.gz"
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(s.T(), "/repos/test/repo/releases/latest", r.URL.Path)
|
||||
require.Equal(s.T(), "application/vnd.github.v3+json", r.Header.Get("Accept"))
|
||||
require.Equal(s.T(), "Sub2API-Updater", r.Header.Get("User-Agent"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(releaseJSON))
|
||||
}))
|
||||
|
||||
// Use custom transport to redirect requests to test server
|
||||
s.client = &githubReleaseClient{
|
||||
httpClient: &http.Client{
|
||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||
},
|
||||
}
|
||||
|
||||
release, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), "v1.0.0", release.TagName)
|
||||
require.Equal(s.T(), "Release 1.0.0", release.Name)
|
||||
require.Len(s.T(), release.Assets, 1)
|
||||
require.Equal(s.T(), "app-linux-amd64.tar.gz", release.Assets[0].Name)
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
|
||||
s.client = &githubReleaseClient{
|
||||
httpClient: &http.Client{
|
||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
|
||||
require.Error(s.T(), err)
|
||||
require.Contains(s.T(), err.Error(), "404")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("not valid json"))
|
||||
}))
|
||||
|
||||
s.client = &githubReleaseClient{
|
||||
httpClient: &http.Client{
|
||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
|
||||
require.Error(s.T(), err)
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
s.client = &githubReleaseClient{
|
||||
httpClient: &http.Client{
|
||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
_, err := s.client.FetchLatestRelease(ctx, "test/repo")
|
||||
require.Error(s.T(), err)
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
_, err := s.client.FetchChecksumFile(ctx, s.srv.URL)
|
||||
require.Error(s.T(), err)
|
||||
}
|
||||
|
||||
func TestGitHubReleaseServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(GitHubReleaseServiceSuite))
|
||||
}
|
||||
244
backend/internal/repository/group_repo_integration_test.go
Normal file
244
backend/internal/repository/group_repo_integration_test.go
Normal file
@@ -0,0 +1,244 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type GroupRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
repo *GroupRepository
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewGroupRepository(s.db)
|
||||
}
|
||||
|
||||
func TestGroupRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(GroupRepoSuite))
|
||||
}
|
||||
|
||||
// --- Create / GetByID / Update / Delete ---
|
||||
|
||||
func (s *GroupRepoSuite) TestCreate() {
|
||||
group := &model.Group{
|
||||
Name: "test-create",
|
||||
Platform: model.PlatformAnthropic,
|
||||
Status: model.StatusActive,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, group)
|
||||
s.Require().NoError(err, "Create")
|
||||
s.Require().NotZero(group.ID, "expected ID to be set")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("test-create", got.Name)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestUpdate() {
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "original"})
|
||||
|
||||
group.Name = "updated"
|
||||
err := s.repo.Update(s.ctx, group)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "GetByID after update")
|
||||
s.Require().Equal("updated", got.Name)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestDelete() {
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "to-delete"})
|
||||
|
||||
err := s.repo.Delete(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
|
||||
_, err = s.repo.GetByID(s.ctx, group.ID)
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
}
|
||||
|
||||
// --- List / ListWithFilters ---
|
||||
|
||||
func (s *GroupRepoSuite) TestList() {
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"})
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2"})
|
||||
|
||||
groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List")
|
||||
s.Require().Len(groups, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_Platform() {
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Platform: model.PlatformAnthropic})
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Platform: model.PlatformOpenAI})
|
||||
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.PlatformOpenAI, "", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(groups, 1)
|
||||
s.Require().Equal(model.PlatformOpenAI, groups[0].Platform)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_Status() {
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Status: model.StatusActive})
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Status: model.StatusDisabled})
|
||||
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusDisabled, nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(groups, 1)
|
||||
s.Require().Equal(model.StatusDisabled, groups[0].Status)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", IsExclusive: false})
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", IsExclusive: true})
|
||||
|
||||
isExclusive := true
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(groups, 1)
|
||||
s.Require().True(groups[0].IsExclusive)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
|
||||
g1 := mustCreateGroup(s.T(), s.db, &model.Group{
|
||||
Name: "g1",
|
||||
Platform: model.PlatformAnthropic,
|
||||
Status: model.StatusActive,
|
||||
})
|
||||
g2 := mustCreateGroup(s.T(), s.db, &model.Group{
|
||||
Name: "g2",
|
||||
Platform: model.PlatformAnthropic,
|
||||
Status: model.StatusActive,
|
||||
IsExclusive: true,
|
||||
})
|
||||
|
||||
a := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc1"})
|
||||
mustBindAccountToGroup(s.T(), s.db, a.ID, g1.ID, 1)
|
||||
mustBindAccountToGroup(s.T(), s.db, a.ID, g2.ID, 1)
|
||||
|
||||
isExclusive := true
|
||||
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.PlatformAnthropic, model.StatusActive, &isExclusive)
|
||||
s.Require().NoError(err, "ListWithFilters")
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
s.Require().Len(groups, 1)
|
||||
s.Require().Equal(g2.ID, groups[0].ID, "ListWithFilters returned wrong group")
|
||||
s.Require().Equal(int64(1), groups[0].AccountCount, "AccountCount mismatch")
|
||||
}
|
||||
|
||||
// --- ListActive / ListActiveByPlatform ---
|
||||
|
||||
func (s *GroupRepoSuite) TestListActive() {
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "active1", Status: model.StatusActive})
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "inactive1", Status: model.StatusDisabled})
|
||||
|
||||
groups, err := s.repo.ListActive(s.ctx)
|
||||
s.Require().NoError(err, "ListActive")
|
||||
s.Require().Len(groups, 1)
|
||||
s.Require().Equal("active1", groups[0].Name)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListActiveByPlatform() {
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Platform: model.PlatformAnthropic, Status: model.StatusActive})
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Platform: model.PlatformOpenAI, Status: model.StatusActive})
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g3", Platform: model.PlatformAnthropic, Status: model.StatusDisabled})
|
||||
|
||||
groups, err := s.repo.ListActiveByPlatform(s.ctx, model.PlatformAnthropic)
|
||||
s.Require().NoError(err, "ListActiveByPlatform")
|
||||
s.Require().Len(groups, 1)
|
||||
s.Require().Equal("g1", groups[0].Name)
|
||||
}
|
||||
|
||||
// --- ExistsByName ---
|
||||
|
||||
func (s *GroupRepoSuite) TestExistsByName() {
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "existing-group"})
|
||||
|
||||
exists, err := s.repo.ExistsByName(s.ctx, "existing-group")
|
||||
s.Require().NoError(err, "ExistsByName")
|
||||
s.Require().True(exists)
|
||||
|
||||
notExists, err := s.repo.ExistsByName(s.ctx, "non-existing")
|
||||
s.Require().NoError(err)
|
||||
s.Require().False(notExists)
|
||||
}
|
||||
|
||||
// --- GetAccountCount ---
|
||||
|
||||
func (s *GroupRepoSuite) TestGetAccountCount() {
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"})
|
||||
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1"})
|
||||
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2"})
|
||||
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1)
|
||||
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2)
|
||||
|
||||
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "GetAccountCount")
|
||||
s.Require().Equal(int64(2), count)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-empty"})
|
||||
|
||||
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(count)
|
||||
}
|
||||
|
||||
// --- DeleteAccountGroupsByGroupID ---
|
||||
|
||||
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
|
||||
g := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-del"})
|
||||
a := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-del"})
|
||||
mustBindAccountToGroup(s.T(), s.db, a.ID, g.ID, 1)
|
||||
|
||||
affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
|
||||
s.Require().NoError(err, "DeleteAccountGroupsByGroupID")
|
||||
s.Require().Equal(int64(1), affected, "expected 1 affected row")
|
||||
|
||||
count, err := s.repo.GetAccountCount(s.ctx, g.ID)
|
||||
s.Require().NoError(err, "GetAccountCount")
|
||||
s.Require().Equal(int64(0), count, "expected 0 account groups")
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
|
||||
g := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-multi"})
|
||||
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1"})
|
||||
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2"})
|
||||
a3 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3"})
|
||||
mustBindAccountToGroup(s.T(), s.db, a1.ID, g.ID, 1)
|
||||
mustBindAccountToGroup(s.T(), s.db, a2.ID, g.ID, 2)
|
||||
mustBindAccountToGroup(s.T(), s.db, a3.ID, g.ID, 3)
|
||||
|
||||
affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(int64(3), affected)
|
||||
|
||||
count, _ := s.repo.GetAccountCount(s.ctx, g.ID)
|
||||
s.Require().Zero(count)
|
||||
}
|
||||
|
||||
// --- DB ---
|
||||
|
||||
func (s *GroupRepoSuite) TestDB() {
|
||||
db := s.repo.DB()
|
||||
s.Require().NotNil(db, "DB should return non-nil")
|
||||
s.Require().Equal(s.db, db, "DB should return the underlying gorm.DB")
|
||||
}
|
||||
115
backend/internal/repository/http_upstream_test.go
Normal file
115
backend/internal/repository/http_upstream_test.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type HTTPUpstreamSuite struct {
|
||||
suite.Suite
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func (s *HTTPUpstreamSuite) SetupTest() {
|
||||
s.cfg = &config.Config{}
|
||||
}
|
||||
|
||||
func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() {
|
||||
up := NewHTTPUpstream(s.cfg)
|
||||
svc, ok := up.(*httpUpstreamService)
|
||||
require.True(s.T(), ok, "expected *httpUpstreamService")
|
||||
transport, ok := svc.defaultClient.Transport.(*http.Transport)
|
||||
require.True(s.T(), ok, "expected *http.Transport")
|
||||
require.Equal(s.T(), 300*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
|
||||
}
|
||||
|
||||
func (s *HTTPUpstreamSuite) TestCustomResponseHeaderTimeout() {
|
||||
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 7}
|
||||
up := NewHTTPUpstream(s.cfg)
|
||||
svc, ok := up.(*httpUpstreamService)
|
||||
require.True(s.T(), ok, "expected *httpUpstreamService")
|
||||
transport, ok := svc.defaultClient.Transport.(*http.Transport)
|
||||
require.True(s.T(), ok, "expected *http.Transport")
|
||||
require.Equal(s.T(), 7*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
|
||||
}
|
||||
|
||||
func (s *HTTPUpstreamSuite) TestCreateProxyClient_InvalidURLFallsBackToDefault() {
|
||||
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 5}
|
||||
up := NewHTTPUpstream(s.cfg)
|
||||
svc, ok := up.(*httpUpstreamService)
|
||||
require.True(s.T(), ok, "expected *httpUpstreamService")
|
||||
|
||||
got := svc.createProxyClient("://bad-proxy-url")
|
||||
require.Equal(s.T(), svc.defaultClient, got, "expected defaultClient fallback")
|
||||
}
|
||||
|
||||
func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() {
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "direct")
|
||||
}))
|
||||
s.T().Cleanup(upstream.Close)
|
||||
|
||||
up := NewHTTPUpstream(s.cfg)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, upstream.URL+"/x", nil)
|
||||
require.NoError(s.T(), err, "NewRequest")
|
||||
resp, err := up.Do(req, "")
|
||||
require.NoError(s.T(), err, "Do")
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
require.Equal(s.T(), "direct", string(b), "unexpected body")
|
||||
}
|
||||
|
||||
func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
|
||||
seen := make(chan string, 1)
|
||||
proxySrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
seen <- r.RequestURI
|
||||
_, _ = io.WriteString(w, "proxied")
|
||||
}))
|
||||
s.T().Cleanup(proxySrv.Close)
|
||||
|
||||
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 1}
|
||||
up := NewHTTPUpstream(s.cfg)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil)
|
||||
require.NoError(s.T(), err, "NewRequest")
|
||||
resp, err := up.Do(req, proxySrv.URL)
|
||||
require.NoError(s.T(), err, "Do")
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
require.Equal(s.T(), "proxied", string(b), "unexpected body")
|
||||
|
||||
select {
|
||||
case uri := <-seen:
|
||||
require.Equal(s.T(), "http://example.com/test", uri, "expected absolute-form request URI")
|
||||
default:
|
||||
require.Fail(s.T(), "expected proxy to receive request")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HTTPUpstreamSuite) TestDo_EmptyProxy_UsesDirect() {
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "direct-empty")
|
||||
}))
|
||||
s.T().Cleanup(upstream.Close)
|
||||
|
||||
up := NewHTTPUpstream(s.cfg)
|
||||
req, err := http.NewRequest(http.MethodGet, upstream.URL+"/y", nil)
|
||||
require.NoError(s.T(), err, "NewRequest")
|
||||
resp, err := up.Do(req, "")
|
||||
require.NoError(s.T(), err, "Do with empty proxy")
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
require.Equal(s.T(), "direct-empty", string(b))
|
||||
}
|
||||
|
||||
func TestHTTPUpstreamSuite(t *testing.T) {
|
||||
suite.Run(t, new(HTTPUpstreamSuite))
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type IdentityCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache *identityCache
|
||||
}
|
||||
|
||||
func (s *IdentityCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewIdentityCache(s.rdb).(*identityCache)
|
||||
}
|
||||
|
||||
func (s *IdentityCacheSuite) TestGetFingerprint_Missing() {
|
||||
_, err := s.cache.GetFingerprint(s.ctx, 1)
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing fingerprint")
|
||||
}
|
||||
|
||||
func (s *IdentityCacheSuite) TestSetAndGetFingerprint() {
|
||||
fp := &ports.Fingerprint{ClientID: "c1", UserAgent: "ua"}
|
||||
require.NoError(s.T(), s.cache.SetFingerprint(s.ctx, 1, fp), "SetFingerprint")
|
||||
gotFP, err := s.cache.GetFingerprint(s.ctx, 1)
|
||||
require.NoError(s.T(), err, "GetFingerprint")
|
||||
require.Equal(s.T(), "c1", gotFP.ClientID)
|
||||
require.Equal(s.T(), "ua", gotFP.UserAgent)
|
||||
}
|
||||
|
||||
func (s *IdentityCacheSuite) TestFingerprint_TTL() {
|
||||
fp := &ports.Fingerprint{ClientID: "c1", UserAgent: "ua"}
|
||||
require.NoError(s.T(), s.cache.SetFingerprint(s.ctx, 2, fp))
|
||||
|
||||
fpKey := fmt.Sprintf("%s%d", fingerprintKeyPrefix, 2)
|
||||
ttl, err := s.rdb.TTL(s.ctx, fpKey).Result()
|
||||
require.NoError(s.T(), err, "TTL fpKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, fingerprintTTL)
|
||||
}
|
||||
|
||||
func (s *IdentityCacheSuite) TestGetFingerprint_JSONCorruption() {
|
||||
fpKey := fmt.Sprintf("%s%d", fingerprintKeyPrefix, 999)
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, fpKey, "invalid-json-data", 1*time.Minute).Err(), "Set invalid JSON")
|
||||
|
||||
_, err := s.cache.GetFingerprint(s.ctx, 999)
|
||||
require.Error(s.T(), err, "expected error for corrupted JSON")
|
||||
require.False(s.T(), errors.Is(err, redis.Nil), "expected decoding error, not redis.Nil")
|
||||
}
|
||||
|
||||
func (s *IdentityCacheSuite) TestSetFingerprint_Nil() {
|
||||
err := s.cache.SetFingerprint(s.ctx, 100, nil)
|
||||
require.NoError(s.T(), err, "SetFingerprint(nil) should succeed")
|
||||
}
|
||||
|
||||
func TestIdentityCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(IdentityCacheSuite))
|
||||
}
|
||||
369
backend/internal/repository/integration_harness_test.go
Normal file
369
backend/internal/repository/integration_harness_test.go
Normal file
@@ -0,0 +1,369 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
redisclient "github.com/redis/go-redis/v9"
|
||||
tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
|
||||
gormpostgres "gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
redisImageTag = "redis:8.4-alpine"
|
||||
postgresImageTag = "postgres:18.1-alpine3.23"
|
||||
)
|
||||
|
||||
var (
|
||||
integrationDB *gorm.DB
|
||||
integrationRedis *redisclient.Client
|
||||
|
||||
redisNamespaceSeq uint64
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
ctx := context.Background()
|
||||
|
||||
if err := timezone.Init("UTC"); err != nil {
|
||||
log.Printf("failed to init timezone: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if !dockerIsAvailable(ctx) {
|
||||
// In CI we expect Docker to be available so integration tests should fail loudly.
|
||||
if os.Getenv("CI") != "" {
|
||||
log.Printf("docker is not available (CI=true); failing integration tests")
|
||||
os.Exit(1)
|
||||
}
|
||||
log.Printf("docker is not available; skipping integration tests (start Docker to enable)")
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
postgresImage := selectDockerImage(ctx, postgresImageTag)
|
||||
pgContainer, err := tcpostgres.Run(
|
||||
ctx,
|
||||
postgresImage,
|
||||
tcpostgres.WithDatabase("sub2api_test"),
|
||||
tcpostgres.WithUsername("postgres"),
|
||||
tcpostgres.WithPassword("postgres"),
|
||||
tcpostgres.BasicWaitStrategies(),
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("failed to start postgres container: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer func() { _ = pgContainer.Terminate(ctx) }()
|
||||
|
||||
redisContainer, err := tcredis.Run(
|
||||
ctx,
|
||||
redisImageTag,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("failed to start redis container: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer func() { _ = redisContainer.Terminate(ctx) }()
|
||||
|
||||
dsn, err := pgContainer.ConnectionString(ctx, "sslmode=disable", "TimeZone=UTC")
|
||||
if err != nil {
|
||||
log.Printf("failed to get postgres dsn: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
integrationDB, err = openGormWithRetry(ctx, dsn, 30*time.Second)
|
||||
if err != nil {
|
||||
log.Printf("failed to open gorm db: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if err := model.AutoMigrate(integrationDB); err != nil {
|
||||
log.Printf("failed to automigrate db: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
redisHost, err := redisContainer.Host(ctx)
|
||||
if err != nil {
|
||||
log.Printf("failed to get redis host: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp")
|
||||
if err != nil {
|
||||
log.Printf("failed to get redis port: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
integrationRedis = redisclient.NewClient(&redisclient.Options{
|
||||
Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()),
|
||||
DB: 0,
|
||||
})
|
||||
if err := integrationRedis.Ping(ctx).Err(); err != nil {
|
||||
log.Printf("failed to ping redis: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
code := m.Run()
|
||||
|
||||
_ = integrationRedis.Close()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func dockerIsAvailable(ctx context.Context) bool {
|
||||
cmd := exec.CommandContext(ctx, "docker", "info")
|
||||
cmd.Env = os.Environ()
|
||||
return cmd.Run() == nil
|
||||
}
|
||||
|
||||
func selectDockerImage(ctx context.Context, preferred string) string {
|
||||
if dockerImageExists(ctx, preferred) {
|
||||
return preferred
|
||||
}
|
||||
|
||||
return preferred
|
||||
}
|
||||
|
||||
func dockerImageExists(ctx context.Context, image string) bool {
|
||||
cmd := exec.CommandContext(ctx, "docker", "image", "inspect", image)
|
||||
cmd.Env = os.Environ()
|
||||
cmd.Stdout = nil
|
||||
cmd.Stderr = nil
|
||||
return cmd.Run() == nil
|
||||
}
|
||||
|
||||
func openGormWithRetry(ctx context.Context, dsn string, timeout time.Duration) (*gorm.DB, error) {
|
||||
deadline := time.Now().Add(timeout)
|
||||
var lastErr error
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
db, err := gorm.Open(gormpostgres.Open(dsn), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := pingWithTimeout(ctx, sqlDB, 2*time.Second); err != nil {
|
||||
lastErr = err
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("db not ready after %s: %w", timeout, lastErr)
|
||||
}
|
||||
|
||||
func pingWithTimeout(ctx context.Context, db *sql.DB, timeout time.Duration) error {
|
||||
pingCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
return db.PingContext(pingCtx)
|
||||
}
|
||||
|
||||
func testTx(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
tx := integrationDB.Begin()
|
||||
require.NoError(t, tx.Error, "begin tx")
|
||||
t.Cleanup(func() {
|
||||
_ = tx.Rollback().Error
|
||||
})
|
||||
return tx
|
||||
}
|
||||
|
||||
func testRedis(t *testing.T) *redisclient.Client {
|
||||
t.Helper()
|
||||
|
||||
prefix := fmt.Sprintf(
|
||||
"it:%s:%d:%d:",
|
||||
sanitizeRedisNamespace(t.Name()),
|
||||
time.Now().UnixNano(),
|
||||
atomic.AddUint64(&redisNamespaceSeq, 1),
|
||||
)
|
||||
|
||||
opts := *integrationRedis.Options()
|
||||
rdb := redisclient.NewClient(&opts)
|
||||
rdb.AddHook(prefixHook{prefix: prefix})
|
||||
|
||||
t.Cleanup(func() {
|
||||
ctx := context.Background()
|
||||
|
||||
var cursor uint64
|
||||
for {
|
||||
keys, nextCursor, err := integrationRedis.Scan(ctx, cursor, prefix+"*", 500).Result()
|
||||
require.NoError(t, err, "scan redis keys for cleanup")
|
||||
if len(keys) > 0 {
|
||||
require.NoError(t, integrationRedis.Unlink(ctx, keys...).Err(), "unlink redis keys for cleanup")
|
||||
}
|
||||
|
||||
cursor = nextCursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
_ = rdb.Close()
|
||||
})
|
||||
|
||||
return rdb
|
||||
}
|
||||
|
||||
func assertTTLWithin(t *testing.T, ttl time.Duration, min, max time.Duration) {
|
||||
t.Helper()
|
||||
require.GreaterOrEqual(t, ttl, min, "ttl should be >= min")
|
||||
require.LessOrEqual(t, ttl, max, "ttl should be <= max")
|
||||
}
|
||||
|
||||
func sanitizeRedisNamespace(name string) string {
|
||||
name = strings.ReplaceAll(name, "/", "_")
|
||||
name = strings.ReplaceAll(name, " ", "_")
|
||||
return name
|
||||
}
|
||||
|
||||
type prefixHook struct {
|
||||
prefix string
|
||||
}
|
||||
|
||||
func (h prefixHook) DialHook(next redisclient.DialHook) redisclient.DialHook { return next }
|
||||
|
||||
func (h prefixHook) ProcessHook(next redisclient.ProcessHook) redisclient.ProcessHook {
|
||||
return func(ctx context.Context, cmd redisclient.Cmder) error {
|
||||
h.prefixCmd(cmd)
|
||||
return next(ctx, cmd)
|
||||
}
|
||||
}
|
||||
|
||||
func (h prefixHook) ProcessPipelineHook(next redisclient.ProcessPipelineHook) redisclient.ProcessPipelineHook {
|
||||
return func(ctx context.Context, cmds []redisclient.Cmder) error {
|
||||
for _, cmd := range cmds {
|
||||
h.prefixCmd(cmd)
|
||||
}
|
||||
return next(ctx, cmds)
|
||||
}
|
||||
}
|
||||
|
||||
func (h prefixHook) prefixCmd(cmd redisclient.Cmder) {
|
||||
args := cmd.Args()
|
||||
if len(args) < 2 {
|
||||
return
|
||||
}
|
||||
|
||||
prefixOne := func(i int) {
|
||||
if i < 0 || i >= len(args) {
|
||||
return
|
||||
}
|
||||
|
||||
switch v := args[i].(type) {
|
||||
case string:
|
||||
if v != "" && !strings.HasPrefix(v, h.prefix) {
|
||||
args[i] = h.prefix + v
|
||||
}
|
||||
case []byte:
|
||||
s := string(v)
|
||||
if s != "" && !strings.HasPrefix(s, h.prefix) {
|
||||
args[i] = []byte(h.prefix + s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch strings.ToLower(cmd.Name()) {
|
||||
case "get", "set", "setnx", "setex", "psetex", "incr", "decr", "incrby", "expire", "pexpire", "ttl", "pttl",
|
||||
"hgetall", "hget", "hset", "hdel", "hincrbyfloat", "exists":
|
||||
prefixOne(1)
|
||||
case "del", "unlink":
|
||||
for i := 1; i < len(args); i++ {
|
||||
prefixOne(i)
|
||||
}
|
||||
case "eval", "evalsha", "eval_ro", "evalsha_ro":
|
||||
if len(args) < 3 {
|
||||
return
|
||||
}
|
||||
numKeys, err := strconv.Atoi(fmt.Sprint(args[2]))
|
||||
if err != nil || numKeys <= 0 {
|
||||
return
|
||||
}
|
||||
for i := 0; i < numKeys && 3+i < len(args); i++ {
|
||||
prefixOne(3 + i)
|
||||
}
|
||||
case "scan":
|
||||
for i := 2; i+1 < len(args); i++ {
|
||||
if strings.EqualFold(fmt.Sprint(args[i]), "match") {
|
||||
prefixOne(i + 1)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IntegrationRedisSuite provides a base suite for Redis integration tests.
|
||||
// Embedding suites should call SetupTest to initialize ctx and rdb.
|
||||
type IntegrationRedisSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
rdb *redisclient.Client
|
||||
}
|
||||
|
||||
// SetupTest initializes ctx and rdb for each test method.
|
||||
func (s *IntegrationRedisSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.rdb = testRedis(s.T())
|
||||
}
|
||||
|
||||
// RequireNoError is a convenience method wrapping require.NoError with s.T().
|
||||
func (s *IntegrationRedisSuite) RequireNoError(err error, msgAndArgs ...any) {
|
||||
s.T().Helper()
|
||||
require.NoError(s.T(), err, msgAndArgs...)
|
||||
}
|
||||
|
||||
// AssertTTLWithin asserts that ttl is within [min, max].
|
||||
func (s *IntegrationRedisSuite) AssertTTLWithin(ttl, min, max time.Duration) {
|
||||
s.T().Helper()
|
||||
assertTTLWithin(s.T(), ttl, min, max)
|
||||
}
|
||||
|
||||
// IntegrationDBSuite provides a base suite for DB (Gorm) integration tests.
|
||||
// Embedding suites should call SetupTest to initialize ctx and db.
|
||||
type IntegrationDBSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// SetupTest initializes ctx and db for each test method.
|
||||
func (s *IntegrationDBSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
}
|
||||
|
||||
// RequireNoError is a convenience method wrapping require.NoError with s.T().
|
||||
func (s *IntegrationDBSuite) RequireNoError(err error, msgAndArgs ...any) {
|
||||
s.T().Helper()
|
||||
require.NoError(s.T(), err, msgAndArgs...)
|
||||
}
|
||||
@@ -12,11 +12,13 @@ import (
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
type openaiOAuthService struct{}
|
||||
|
||||
// NewOpenAIOAuthClient creates a new OpenAI OAuth client
|
||||
func NewOpenAIOAuthClient() ports.OpenAIOAuthClient {
|
||||
return &openaiOAuthService{}
|
||||
return &openaiOAuthService{tokenURL: openai.TokenURL}
|
||||
}
|
||||
|
||||
type openaiOAuthService struct {
|
||||
tokenURL string
|
||||
}
|
||||
|
||||
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
|
||||
@@ -39,7 +41,7 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
|
||||
SetContext(ctx).
|
||||
SetFormDataFromValues(formData).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(openai.TokenURL)
|
||||
Post(s.tokenURL)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
@@ -67,7 +69,7 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
|
||||
SetContext(ctx).
|
||||
SetFormDataFromValues(formData).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(openai.TokenURL)
|
||||
Post(s.tokenURL)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
|
||||
249
backend/internal/repository/openai_oauth_service_test.go
Normal file
249
backend/internal/repository/openai_oauth_service_test.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type OpenAIOAuthServiceSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
srv *httptest.Server
|
||||
svc *openaiOAuthService
|
||||
received chan url.Values
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.received = make(chan url.Values, 1)
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TearDownTest() {
|
||||
if s.srv != nil {
|
||||
s.srv.Close()
|
||||
s.srv = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) setupServer(handler http.HandlerFunc) {
|
||||
s.srv = httptest.NewServer(handler)
|
||||
s.svc = &openaiOAuthService{tokenURL: s.srv.URL}
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_DefaultRedirectURI() {
|
||||
errCh := make(chan string, 1)
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
errCh <- "method mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
errCh <- "ParseForm failed"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("grant_type"); got != "authorization_code" {
|
||||
errCh <- "grant_type mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("client_id"); got != openai.ClientID {
|
||||
errCh <- "client_id mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("code"); got != "code" {
|
||||
errCh <- "code mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("redirect_uri"); got != openai.DefaultRedirectURI {
|
||||
errCh <- "redirect_uri mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("code_verifier"); got != "ver" {
|
||||
errCh <- "code_verifier mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`)
|
||||
}))
|
||||
|
||||
resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "")
|
||||
require.NoError(s.T(), err, "ExchangeCode")
|
||||
select {
|
||||
case msg := <-errCh:
|
||||
require.Fail(s.T(), msg)
|
||||
default:
|
||||
}
|
||||
require.Equal(s.T(), "at", resp.AccessToken)
|
||||
require.Equal(s.T(), "rt", resp.RefreshToken)
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
|
||||
errCh := make(chan string, 1)
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
errCh <- "ParseForm failed"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("grant_type"); got != "refresh_token" {
|
||||
errCh <- "grant_type mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("refresh_token"); got != "rt" {
|
||||
errCh <- "refresh_token mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("client_id"); got != openai.ClientID {
|
||||
errCh <- "client_id mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("scope"); got != openai.RefreshScopes {
|
||||
errCh <- "scope mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at2","refresh_token":"rt2","token_type":"bearer","expires_in":3600}`)
|
||||
}))
|
||||
|
||||
resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
|
||||
require.NoError(s.T(), err, "RefreshToken")
|
||||
select {
|
||||
case msg := <-errCh:
|
||||
require.Fail(s.T(), msg)
|
||||
default:
|
||||
}
|
||||
require.Equal(s.T(), "at2", resp.AccessToken)
|
||||
require.Equal(s.T(), "rt2", resp.RefreshToken)
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = io.WriteString(w, "bad")
|
||||
}))
|
||||
|
||||
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "status 400")
|
||||
require.ErrorContains(s.T(), err, "bad")
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
s.srv.Close()
|
||||
|
||||
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "request failed")
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestContextCancel() {
|
||||
started := make(chan struct{})
|
||||
block := make(chan struct{})
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
close(started)
|
||||
<-block
|
||||
}))
|
||||
|
||||
ctx, cancel := context.WithCancel(s.ctx)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "")
|
||||
done <- err
|
||||
}()
|
||||
|
||||
<-started
|
||||
cancel()
|
||||
close(block)
|
||||
|
||||
err := <-done
|
||||
require.Error(s.T(), err)
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
|
||||
want := "http://localhost:9999/cb"
|
||||
errCh := make(chan string, 1)
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = r.ParseForm()
|
||||
if got := r.PostForm.Get("redirect_uri"); got != want {
|
||||
errCh <- "redirect_uri mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
|
||||
}))
|
||||
|
||||
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "")
|
||||
require.NoError(s.T(), err, "ExchangeCode")
|
||||
select {
|
||||
case msg := <-errCh:
|
||||
require.Fail(s.T(), msg)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestTokenURL_CanBeOverriddenWithQuery() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = r.ParseForm()
|
||||
s.received <- r.PostForm
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
|
||||
}))
|
||||
s.svc.tokenURL = s.srv.URL + "?x=1"
|
||||
|
||||
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
|
||||
require.NoError(s.T(), err, "ExchangeCode")
|
||||
select {
|
||||
case <-s.received:
|
||||
default:
|
||||
require.Fail(s.T(), "expected server to receive request")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_SuccessButInvalidJSON() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = io.WriteString(w, "not-valid-json")
|
||||
}))
|
||||
|
||||
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
|
||||
require.Error(s.T(), err, "expected error for invalid JSON response")
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = io.WriteString(w, "unauthorized")
|
||||
}))
|
||||
|
||||
_, err := s.svc.RefreshToken(s.ctx, "rt", "")
|
||||
require.Error(s.T(), err, "expected error for non-2xx status")
|
||||
require.ErrorContains(s.T(), err, "status 401")
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(OpenAIOAuthServiceSuite))
|
||||
}
|
||||
147
backend/internal/repository/pricing_service_test.go
Normal file
147
backend/internal/repository/pricing_service_test.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type PricingServiceSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
srv *httptest.Server
|
||||
client *pricingRemoteClient
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
client, ok := NewPricingRemoteClient().(*pricingRemoteClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TearDownTest() {
|
||||
if s.srv != nil {
|
||||
s.srv.Close()
|
||||
s.srv = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) setupServer(handler http.HandlerFunc) {
|
||||
s.srv = httptest.NewServer(handler)
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchPricingJSON_Success() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/ok" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"ok":true}`))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
|
||||
body, err := s.client.FetchPricingJSON(s.ctx, s.srv.URL+"/ok")
|
||||
require.NoError(s.T(), err, "FetchPricingJSON")
|
||||
require.Equal(s.T(), `{"ok":true}`, string(body), "body mismatch")
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchPricingJSON_NonOKStatus() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
|
||||
_, err := s.client.FetchPricingJSON(s.ctx, s.srv.URL+"/err")
|
||||
require.Error(s.T(), err, "expected error for non-200 status")
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchHashText_ParsesFields() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/hashfile":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("abc123 model_prices.json\n"))
|
||||
case "/hashonly":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("def456\n"))
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
|
||||
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/hashfile")
|
||||
require.NoError(s.T(), err, "FetchHashText")
|
||||
require.Equal(s.T(), "abc123", hash, "hash mismatch")
|
||||
|
||||
hash2, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/hashonly")
|
||||
require.NoError(s.T(), err, "FetchHashText")
|
||||
require.Equal(s.T(), "def456", hash2, "hash mismatch")
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchHashText_NonOKStatus() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
|
||||
_, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/nope")
|
||||
require.Error(s.T(), err, "expected error for non-200 status")
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchPricingJSON_InvalidURL() {
|
||||
_, err := s.client.FetchPricingJSON(s.ctx, "://invalid-url")
|
||||
require.Error(s.T(), err, "expected error for invalid URL")
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchHashText_EmptyBody() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
// empty body
|
||||
}))
|
||||
|
||||
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/empty")
|
||||
require.NoError(s.T(), err, "FetchHashText empty body should not error")
|
||||
require.Equal(s.T(), "", hash, "expected empty hash")
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchHashText_WhitespaceOnly() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(" \n"))
|
||||
}))
|
||||
|
||||
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/ws")
|
||||
require.NoError(s.T(), err, "FetchHashText whitespace body should not error")
|
||||
require.Equal(s.T(), "", hash, "expected empty hash after trimming")
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() {
|
||||
started := make(chan struct{})
|
||||
block := make(chan struct{})
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
close(started)
|
||||
<-block
|
||||
}))
|
||||
|
||||
ctx, cancel := context.WithCancel(s.ctx)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := s.client.FetchPricingJSON(ctx, s.srv.URL+"/block")
|
||||
done <- err
|
||||
}()
|
||||
|
||||
<-started
|
||||
cancel()
|
||||
close(block)
|
||||
|
||||
err := <-done
|
||||
require.Error(s.T(), err)
|
||||
}
|
||||
|
||||
func TestPricingServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(PricingServiceSuite))
|
||||
}
|
||||
@@ -16,10 +16,14 @@ import (
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
type proxyProbeService struct{}
|
||||
|
||||
func NewProxyExitInfoProber() service.ProxyExitInfoProber {
|
||||
return &proxyProbeService{}
|
||||
return &proxyProbeService{ipInfoURL: defaultIPInfoURL}
|
||||
}
|
||||
|
||||
const defaultIPInfoURL = "https://ipinfo.io/json"
|
||||
|
||||
type proxyProbeService struct {
|
||||
ipInfoURL string
|
||||
}
|
||||
|
||||
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
|
||||
@@ -34,7 +38,7 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://ipinfo.io/json", nil)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", s.ipInfoURL, nil)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
121
backend/internal/repository/proxy_probe_service_test.go
Normal file
121
backend/internal/repository/proxy_probe_service_test.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ProxyProbeServiceSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
proxySrv *httptest.Server
|
||||
prober *proxyProbeService
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.prober = &proxyProbeService{ipInfoURL: "http://ipinfo.test/json"}
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TearDownTest() {
|
||||
if s.proxySrv != nil {
|
||||
s.proxySrv.Close()
|
||||
s.proxySrv = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) setupProxyServer(handler http.HandlerFunc) {
|
||||
s.proxySrv = httptest.NewServer(handler)
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_InvalidURL() {
|
||||
_, err := createProxyTransport("://bad")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "invalid proxy URL")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_UnsupportedScheme() {
|
||||
_, err := createProxyTransport("ftp://127.0.0.1:1")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "unsupported proxy protocol")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_Socks5SetsDialer() {
|
||||
tr, err := createProxyTransport("socks5://127.0.0.1:1080")
|
||||
require.NoError(s.T(), err, "createProxyTransport")
|
||||
require.NotNil(s.T(), tr.DialContext, "expected DialContext to be set for socks5")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
|
||||
seen := make(chan string, 1)
|
||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
seen <- r.RequestURI
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"ip":"1.2.3.4","city":"c","region":"r","country":"cc"}`)
|
||||
}))
|
||||
|
||||
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||
require.NoError(s.T(), err, "ProbeProxy")
|
||||
require.GreaterOrEqual(s.T(), latencyMs, int64(0), "unexpected latency")
|
||||
require.Equal(s.T(), "1.2.3.4", info.IP)
|
||||
require.Equal(s.T(), "c", info.City)
|
||||
require.Equal(s.T(), "r", info.Region)
|
||||
require.Equal(s.T(), "cc", info.Country)
|
||||
|
||||
// Verify proxy received the request
|
||||
select {
|
||||
case uri := <-seen:
|
||||
require.Contains(s.T(), uri, "ipinfo.test", "expected request to go through proxy")
|
||||
default:
|
||||
require.Fail(s.T(), "expected proxy to receive request")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_NonOKStatus() {
|
||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}))
|
||||
|
||||
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "status: 503")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidJSON() {
|
||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, "not-json")
|
||||
}))
|
||||
|
||||
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "failed to parse response")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidIPInfoURL() {
|
||||
s.prober.ipInfoURL = "://invalid-url"
|
||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||
require.Error(s.T(), err, "expected error for invalid ipInfoURL")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_ProxyServerClosed() {
|
||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
s.proxySrv.Close()
|
||||
|
||||
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||
require.Error(s.T(), err, "expected error when proxy server is closed")
|
||||
}
|
||||
|
||||
func TestProxyProbeServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(ProxyProbeServiceSuite))
|
||||
}
|
||||
302
backend/internal/repository/proxy_repo_integration_test.go
Normal file
302
backend/internal/repository/proxy_repo_integration_test.go
Normal file
@@ -0,0 +1,302 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ProxyRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
repo *ProxyRepository
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewProxyRepository(s.db)
|
||||
}
|
||||
|
||||
func TestProxyRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(ProxyRepoSuite))
|
||||
}
|
||||
|
||||
// --- Create / GetByID / Update / Delete ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestCreate() {
|
||||
proxy := &model.Proxy{
|
||||
Name: "test-create",
|
||||
Protocol: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Status: model.StatusActive,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, proxy)
|
||||
s.Require().NoError(err, "Create")
|
||||
s.Require().NotZero(proxy.ID, "expected ID to be set")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, proxy.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("test-create", got.Name)
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestUpdate() {
|
||||
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "original"})
|
||||
|
||||
proxy.Name = "updated"
|
||||
err := s.repo.Update(s.ctx, proxy)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, proxy.ID)
|
||||
s.Require().NoError(err, "GetByID after update")
|
||||
s.Require().Equal("updated", got.Name)
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestDelete() {
|
||||
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "to-delete"})
|
||||
|
||||
err := s.repo.Delete(s.ctx, proxy.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
|
||||
_, err = s.repo.GetByID(s.ctx, proxy.ID)
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
}
|
||||
|
||||
// --- List / ListWithFilters ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestList() {
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"})
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"})
|
||||
|
||||
proxies, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List")
|
||||
s.Require().Len(proxies, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestListWithFilters_Protocol() {
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Protocol: "http"})
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Protocol: "socks5"})
|
||||
|
||||
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "socks5", "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(proxies, 1)
|
||||
s.Require().Equal("socks5", proxies[0].Protocol)
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestListWithFilters_Status() {
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Status: model.StatusActive})
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Status: model.StatusDisabled})
|
||||
|
||||
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusDisabled, "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(proxies, 1)
|
||||
s.Require().Equal(model.StatusDisabled, proxies[0].Status)
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestListWithFilters_Search() {
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "production-proxy"})
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "dev-proxy"})
|
||||
|
||||
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "prod")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(proxies, 1)
|
||||
s.Require().Contains(proxies[0].Name, "production")
|
||||
}
|
||||
|
||||
// --- ListActive ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestListActive() {
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "active1", Status: model.StatusActive})
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "inactive1", Status: model.StatusDisabled})
|
||||
|
||||
proxies, err := s.repo.ListActive(s.ctx)
|
||||
s.Require().NoError(err, "ListActive")
|
||||
s.Require().Len(proxies, 1)
|
||||
s.Require().Equal("active1", proxies[0].Name)
|
||||
}
|
||||
|
||||
// --- ExistsByHostPortAuth ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestExistsByHostPortAuth() {
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{
|
||||
Name: "p1",
|
||||
Protocol: "http",
|
||||
Host: "1.2.3.4",
|
||||
Port: 8080,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
})
|
||||
|
||||
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "user", "pass")
|
||||
s.Require().NoError(err, "ExistsByHostPortAuth")
|
||||
s.Require().True(exists)
|
||||
|
||||
notExists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "wrong", "creds")
|
||||
s.Require().NoError(err)
|
||||
s.Require().False(notExists)
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() {
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{
|
||||
Name: "p-noauth",
|
||||
Protocol: "http",
|
||||
Host: "5.6.7.8",
|
||||
Port: 8081,
|
||||
Username: "",
|
||||
Password: "",
|
||||
})
|
||||
|
||||
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "5.6.7.8", 8081, "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(exists)
|
||||
}
|
||||
|
||||
// --- CountAccountsByProxyID ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestCountAccountsByProxyID() {
|
||||
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-count"})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &proxy.ID})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &proxy.ID})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3"}) // no proxy
|
||||
|
||||
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
|
||||
s.Require().NoError(err, "CountAccountsByProxyID")
|
||||
s.Require().Equal(int64(2), count)
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() {
|
||||
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-zero"})
|
||||
|
||||
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(count)
|
||||
}
|
||||
|
||||
// --- GetAccountCountsForProxies ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestGetAccountCountsForProxies() {
|
||||
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"})
|
||||
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"})
|
||||
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
|
||||
|
||||
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
|
||||
s.Require().NoError(err, "GetAccountCountsForProxies")
|
||||
s.Require().Equal(int64(2), counts[p1.ID])
|
||||
s.Require().Equal(int64(1), counts[p2.ID])
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestGetAccountCountsForProxies_Empty() {
|
||||
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Empty(counts)
|
||||
}
|
||||
|
||||
// --- ListActiveWithAccountCount ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestListActiveWithAccountCount() {
|
||||
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{
|
||||
Name: "p1",
|
||||
Status: model.StatusActive,
|
||||
CreatedAt: base.Add(-1 * time.Hour),
|
||||
})
|
||||
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{
|
||||
Name: "p2",
|
||||
Status: model.StatusActive,
|
||||
CreatedAt: base,
|
||||
})
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{
|
||||
Name: "p3-inactive",
|
||||
Status: model.StatusDisabled,
|
||||
})
|
||||
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
|
||||
|
||||
withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx)
|
||||
s.Require().NoError(err, "ListActiveWithAccountCount")
|
||||
s.Require().Len(withCounts, 2, "expected 2 active proxies")
|
||||
|
||||
// Sorted by created_at DESC, so p2 first
|
||||
s.Require().Equal(p2.ID, withCounts[0].ID)
|
||||
s.Require().Equal(int64(1), withCounts[0].AccountCount)
|
||||
s.Require().Equal(p1.ID, withCounts[1].ID)
|
||||
s.Require().Equal(int64(2), withCounts[1].AccountCount)
|
||||
}
|
||||
|
||||
// --- Combined original test ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
|
||||
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{
|
||||
Name: "p1",
|
||||
Protocol: "http",
|
||||
Host: "1.2.3.4",
|
||||
Port: 8080,
|
||||
Username: "u",
|
||||
Password: "p",
|
||||
CreatedAt: time.Now().Add(-1 * time.Hour),
|
||||
UpdatedAt: time.Now().Add(-1 * time.Hour),
|
||||
})
|
||||
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{
|
||||
Name: "p2",
|
||||
Protocol: "http",
|
||||
Host: "5.6.7.8",
|
||||
Port: 8081,
|
||||
Username: "",
|
||||
Password: "",
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
})
|
||||
|
||||
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "u", "p")
|
||||
s.Require().NoError(err, "ExistsByHostPortAuth")
|
||||
s.Require().True(exists, "expected proxy to exist")
|
||||
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
|
||||
|
||||
count1, err := s.repo.CountAccountsByProxyID(s.ctx, p1.ID)
|
||||
s.Require().NoError(err, "CountAccountsByProxyID")
|
||||
s.Require().Equal(int64(2), count1, "expected 2 accounts for p1")
|
||||
|
||||
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
|
||||
s.Require().NoError(err, "GetAccountCountsForProxies")
|
||||
s.Require().Equal(int64(2), counts[p1.ID])
|
||||
s.Require().Equal(int64(1), counts[p2.ID])
|
||||
|
||||
withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx)
|
||||
s.Require().NoError(err, "ListActiveWithAccountCount")
|
||||
s.Require().Len(withCounts, 2, "expected 2 proxies")
|
||||
for _, pc := range withCounts {
|
||||
switch pc.ID {
|
||||
case p1.ID:
|
||||
s.Require().Equal(int64(2), pc.AccountCount, "p1 count mismatch")
|
||||
case p2.ID:
|
||||
s.Require().Equal(int64(1), pc.AccountCount, "p2 count mismatch")
|
||||
default:
|
||||
s.Require().Fail("unexpected proxy id", pc.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
105
backend/internal/repository/redeem_cache_integration_test.go
Normal file
105
backend/internal/repository/redeem_cache_integration_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type RedeemCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache *redeemCache
|
||||
}
|
||||
|
||||
func (s *RedeemCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewRedeemCache(s.rdb).(*redeemCache)
|
||||
}
|
||||
|
||||
func (s *RedeemCacheSuite) TestGetRedeemAttemptCount_Missing() {
|
||||
missingUserID := int64(99999)
|
||||
_, err := s.cache.GetRedeemAttemptCount(s.ctx, missingUserID)
|
||||
require.Error(s.T(), err, "expected redis.Nil for missing rate-limit key")
|
||||
require.True(s.T(), errors.Is(err, redis.Nil))
|
||||
}
|
||||
|
||||
func (s *RedeemCacheSuite) TestIncrementAndGetRedeemAttemptCount() {
|
||||
userID := int64(1)
|
||||
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID), "IncrementRedeemAttemptCount")
|
||||
count, err := s.cache.GetRedeemAttemptCount(s.ctx, userID)
|
||||
require.NoError(s.T(), err, "GetRedeemAttemptCount")
|
||||
require.Equal(s.T(), 1, count, "count mismatch")
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, key).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, redeemRateLimitDuration)
|
||||
}
|
||||
|
||||
func (s *RedeemCacheSuite) TestMultipleIncrements() {
|
||||
userID := int64(2)
|
||||
|
||||
require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID))
|
||||
require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID))
|
||||
require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID))
|
||||
|
||||
count, err := s.cache.GetRedeemAttemptCount(s.ctx, userID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 3, count, "count after 3 increments")
|
||||
}
|
||||
|
||||
func (s *RedeemCacheSuite) TestAcquireAndReleaseRedeemLock() {
|
||||
ok, err := s.cache.AcquireRedeemLock(s.ctx, "CODE", 10*time.Second)
|
||||
require.NoError(s.T(), err, "AcquireRedeemLock")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Second acquire should fail
|
||||
ok, err = s.cache.AcquireRedeemLock(s.ctx, "CODE", 10*time.Second)
|
||||
require.NoError(s.T(), err, "AcquireRedeemLock 2")
|
||||
require.False(s.T(), ok, "expected lock to be held")
|
||||
|
||||
// Release
|
||||
require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "CODE"), "ReleaseRedeemLock")
|
||||
|
||||
// Now acquire should succeed
|
||||
ok, err = s.cache.AcquireRedeemLock(s.ctx, "CODE", 10*time.Second)
|
||||
require.NoError(s.T(), err, "AcquireRedeemLock after release")
|
||||
require.True(s.T(), ok)
|
||||
}
|
||||
|
||||
func (s *RedeemCacheSuite) TestAcquireRedeemLock_TTL() {
|
||||
lockKey := redeemLockKeyPrefix + "CODE2"
|
||||
lockTTL := 15 * time.Second
|
||||
|
||||
ok, err := s.cache.AcquireRedeemLock(s.ctx, "CODE2", lockTTL)
|
||||
require.NoError(s.T(), err, "AcquireRedeemLock CODE2")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, lockKey).Result()
|
||||
require.NoError(s.T(), err, "TTL lock key")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, lockTTL)
|
||||
}
|
||||
|
||||
func (s *RedeemCacheSuite) TestReleaseRedeemLock_Idempotent() {
|
||||
// Release a lock that doesn't exist should not error
|
||||
require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "NONEXISTENT"))
|
||||
|
||||
// Acquire, release, release again
|
||||
ok, err := s.cache.AcquireRedeemLock(s.ctx, "IDEMPOTENT", 10*time.Second)
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "IDEMPOTENT"))
|
||||
require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "IDEMPOTENT"), "second release should be idempotent")
|
||||
}
|
||||
|
||||
func TestRedeemCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(RedeemCacheSuite))
|
||||
}
|
||||
315
backend/internal/repository/redeem_code_repo_integration_test.go
Normal file
315
backend/internal/repository/redeem_code_repo_integration_test.go
Normal file
@@ -0,0 +1,315 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type RedeemCodeRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
repo *RedeemCodeRepository
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewRedeemCodeRepository(s.db)
|
||||
}
|
||||
|
||||
func TestRedeemCodeRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(RedeemCodeRepoSuite))
|
||||
}
|
||||
|
||||
// --- Create / CreateBatch / GetByID / GetByCode ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestCreate() {
|
||||
code := &model.RedeemCode{
|
||||
Code: "TEST-CREATE",
|
||||
Type: model.RedeemTypeBalance,
|
||||
Value: 100,
|
||||
Status: model.StatusUnused,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, code)
|
||||
s.Require().NoError(err, "Create")
|
||||
s.Require().NotZero(code.ID, "expected ID to be set")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, code.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("TEST-CREATE", got.Code)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestCreateBatch() {
|
||||
codes := []model.RedeemCode{
|
||||
{Code: "BATCH-1", Type: model.RedeemTypeBalance, Value: 10, Status: model.StatusUnused},
|
||||
{Code: "BATCH-2", Type: model.RedeemTypeBalance, Value: 20, Status: model.StatusUnused},
|
||||
}
|
||||
|
||||
err := s.repo.CreateBatch(s.ctx, codes)
|
||||
s.Require().NoError(err, "CreateBatch")
|
||||
|
||||
got1, err := s.repo.GetByCode(s.ctx, "BATCH-1")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(float64(10), got1.Value)
|
||||
|
||||
got2, err := s.repo.GetByCode(s.ctx, "BATCH-2")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(float64(20), got2.Value)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestGetByCode() {
|
||||
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "GET-BY-CODE", Type: model.RedeemTypeBalance})
|
||||
|
||||
got, err := s.repo.GetByCode(s.ctx, "GET-BY-CODE")
|
||||
s.Require().NoError(err, "GetByCode")
|
||||
s.Require().Equal("GET-BY-CODE", got.Code)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestGetByCode_NotFound() {
|
||||
_, err := s.repo.GetByCode(s.ctx, "NON-EXISTENT")
|
||||
s.Require().Error(err, "expected error for non-existent code")
|
||||
}
|
||||
|
||||
// --- Delete ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestDelete() {
|
||||
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TO-DELETE", Type: model.RedeemTypeBalance})
|
||||
|
||||
err := s.repo.Delete(s.ctx, code.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
|
||||
_, err = s.repo.GetByID(s.ctx, code.ID)
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
}
|
||||
|
||||
// --- List / ListWithFilters ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestList() {
|
||||
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "LIST-1", Type: model.RedeemTypeBalance})
|
||||
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "LIST-2", Type: model.RedeemTypeBalance})
|
||||
|
||||
codes, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List")
|
||||
s.Require().Len(codes, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListWithFilters_Type() {
|
||||
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TYPE-BAL", Type: model.RedeemTypeBalance})
|
||||
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TYPE-SUB", Type: model.RedeemTypeSubscription})
|
||||
|
||||
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.RedeemTypeSubscription, "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(codes, 1)
|
||||
s.Require().Equal(model.RedeemTypeSubscription, codes[0].Type)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListWithFilters_Status() {
|
||||
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "STAT-UNUSED", Type: model.RedeemTypeBalance, Status: model.StatusUnused})
|
||||
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "STAT-USED", Type: model.RedeemTypeBalance, Status: model.StatusUsed})
|
||||
|
||||
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusUsed, "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(codes, 1)
|
||||
s.Require().Equal(model.StatusUsed, codes[0].Status)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() {
|
||||
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "ALPHA-CODE", Type: model.RedeemTypeBalance})
|
||||
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "BETA-CODE", Type: model.RedeemTypeBalance})
|
||||
|
||||
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alpha")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(codes, 1)
|
||||
s.Require().Contains(codes[0].Code, "ALPHA")
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() {
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-preload"})
|
||||
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
|
||||
Code: "WITH-GROUP",
|
||||
Type: model.RedeemTypeSubscription,
|
||||
GroupID: &group.ID,
|
||||
})
|
||||
|
||||
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(codes, 1)
|
||||
s.Require().NotNil(codes[0].Group, "expected Group preload")
|
||||
s.Require().Equal(group.ID, codes[0].Group.ID)
|
||||
}
|
||||
|
||||
// --- Update ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestUpdate() {
|
||||
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "UPDATE-ME", Type: model.RedeemTypeBalance, Value: 10})
|
||||
|
||||
code.Value = 50
|
||||
err := s.repo.Update(s.ctx, code)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, code.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(float64(50), got.Value)
|
||||
}
|
||||
|
||||
// --- Use ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestUse() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "use@test.com"})
|
||||
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "USE-ME", Type: model.RedeemTypeBalance, Status: model.StatusUnused})
|
||||
|
||||
err := s.repo.Use(s.ctx, code.ID, user.ID)
|
||||
s.Require().NoError(err, "Use")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, code.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(model.StatusUsed, got.Status)
|
||||
s.Require().NotNil(got.UsedBy)
|
||||
s.Require().Equal(user.ID, *got.UsedBy)
|
||||
s.Require().NotNil(got.UsedAt)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "idem@test.com"})
|
||||
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "IDEM-CODE", Type: model.RedeemTypeBalance, Status: model.StatusUnused})
|
||||
|
||||
err := s.repo.Use(s.ctx, code.ID, user.ID)
|
||||
s.Require().NoError(err, "Use first time")
|
||||
|
||||
// Second use should fail
|
||||
err = s.repo.Use(s.ctx, code.ID, user.ID)
|
||||
s.Require().Error(err, "Use expected error on second call")
|
||||
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "already@test.com"})
|
||||
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "ALREADY-USED", Type: model.RedeemTypeBalance, Status: model.StatusUsed})
|
||||
|
||||
err := s.repo.Use(s.ctx, code.ID, user.ID)
|
||||
s.Require().Error(err, "expected error for already used code")
|
||||
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
|
||||
}
|
||||
|
||||
// --- ListByUser ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListByUser() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listby@test.com"})
|
||||
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Create codes with explicit used_at for ordering
|
||||
c1 := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
|
||||
Code: "USER-1",
|
||||
Type: model.RedeemTypeBalance,
|
||||
Status: model.StatusUsed,
|
||||
UsedBy: &user.ID,
|
||||
})
|
||||
s.db.Model(c1).Update("used_at", base)
|
||||
|
||||
c2 := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
|
||||
Code: "USER-2",
|
||||
Type: model.RedeemTypeBalance,
|
||||
Status: model.StatusUsed,
|
||||
UsedBy: &user.ID,
|
||||
})
|
||||
s.db.Model(c2).Update("used_at", base.Add(1*time.Hour))
|
||||
|
||||
codes, err := s.repo.ListByUser(s.ctx, user.ID, 10)
|
||||
s.Require().NoError(err, "ListByUser")
|
||||
s.Require().Len(codes, 2)
|
||||
// Ordered by used_at DESC, so USER-2 first
|
||||
s.Require().Equal("USER-2", codes[0].Code)
|
||||
s.Require().Equal("USER-1", codes[1].Code)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "grp@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listby"})
|
||||
|
||||
c := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
|
||||
Code: "WITH-GRP",
|
||||
Type: model.RedeemTypeSubscription,
|
||||
Status: model.StatusUsed,
|
||||
UsedBy: &user.ID,
|
||||
GroupID: &group.ID,
|
||||
})
|
||||
s.db.Model(c).Update("used_at", time.Now())
|
||||
|
||||
codes, err := s.repo.ListByUser(s.ctx, user.ID, 10)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(codes, 1)
|
||||
s.Require().NotNil(codes[0].Group)
|
||||
s.Require().Equal(group.ID, codes[0].Group.ID)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "deflimit@test.com"})
|
||||
c := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
|
||||
Code: "DEF-LIM",
|
||||
Type: model.RedeemTypeBalance,
|
||||
Status: model.StatusUsed,
|
||||
UsedBy: &user.ID,
|
||||
})
|
||||
s.db.Model(c).Update("used_at", time.Now())
|
||||
|
||||
// limit <= 0 should default to 10
|
||||
codes, err := s.repo.ListByUser(s.ctx, user.ID, 0)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(codes, 1)
|
||||
}
|
||||
|
||||
// --- Combined original test ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "rc@example.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-rc"})
|
||||
|
||||
codes := []model.RedeemCode{
|
||||
{Code: "CODEA", Type: model.RedeemTypeBalance, Value: 1, Status: model.StatusUnused, CreatedAt: time.Now()},
|
||||
{Code: "CODEB", Type: model.RedeemTypeSubscription, Value: 0, Status: model.StatusUnused, GroupID: &group.ID, ValidityDays: 7, CreatedAt: time.Now()},
|
||||
}
|
||||
s.Require().NoError(s.repo.CreateBatch(s.ctx, codes), "CreateBatch")
|
||||
|
||||
list, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.RedeemTypeSubscription, model.StatusUnused, "code")
|
||||
s.Require().NoError(err, "ListWithFilters")
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
s.Require().Len(list, 1)
|
||||
s.Require().NotNil(list[0].Group, "expected Group preload")
|
||||
s.Require().Equal(group.ID, list[0].Group.ID)
|
||||
|
||||
codeB, err := s.repo.GetByCode(s.ctx, "CODEB")
|
||||
s.Require().NoError(err, "GetByCode")
|
||||
s.Require().NoError(s.repo.Use(s.ctx, codeB.ID, user.ID), "Use")
|
||||
err = s.repo.Use(s.ctx, codeB.ID, user.ID)
|
||||
s.Require().Error(err, "Use expected error on second call")
|
||||
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
|
||||
|
||||
codeA, err := s.repo.GetByCode(s.ctx, "CODEA")
|
||||
s.Require().NoError(err, "GetByCode")
|
||||
|
||||
// Use fixed time instead of time.Sleep for deterministic ordering
|
||||
s.db.Model(&model.RedeemCode{}).Where("id = ?", codeB.ID).Update("used_at", time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
s.Require().NoError(s.repo.Use(s.ctx, codeA.ID, user.ID), "Use codeA")
|
||||
s.db.Model(&model.RedeemCode{}).Where("id = ?", codeA.ID).Update("used_at", time.Date(2025, 1, 1, 13, 0, 0, 0, time.UTC))
|
||||
|
||||
used, err := s.repo.ListByUser(s.ctx, user.ID, 10)
|
||||
s.Require().NoError(err, "ListByUser")
|
||||
s.Require().Len(used, 2, "expected 2 used codes")
|
||||
s.Require().Equal("CODEA", used[0].Code, "expected newest used code first")
|
||||
}
|
||||
108
backend/internal/repository/setting_repo_integration_test.go
Normal file
108
backend/internal/repository/setting_repo_integration_test.go
Normal file
@@ -0,0 +1,108 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type SettingRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
repo *SettingRepository
|
||||
}
|
||||
|
||||
func (s *SettingRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewSettingRepository(s.db)
|
||||
}
|
||||
|
||||
func TestSettingRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(SettingRepoSuite))
|
||||
}
|
||||
|
||||
func (s *SettingRepoSuite) TestSetAndGetValue() {
|
||||
s.Require().NoError(s.repo.Set(s.ctx, "k1", "v1"), "Set")
|
||||
got, err := s.repo.GetValue(s.ctx, "k1")
|
||||
s.Require().NoError(err, "GetValue")
|
||||
s.Require().Equal("v1", got, "GetValue mismatch")
|
||||
}
|
||||
|
||||
func (s *SettingRepoSuite) TestSet_Upsert() {
|
||||
s.Require().NoError(s.repo.Set(s.ctx, "k1", "v1"), "Set")
|
||||
s.Require().NoError(s.repo.Set(s.ctx, "k1", "v2"), "Set upsert")
|
||||
got, err := s.repo.GetValue(s.ctx, "k1")
|
||||
s.Require().NoError(err, "GetValue after upsert")
|
||||
s.Require().Equal("v2", got, "upsert mismatch")
|
||||
}
|
||||
|
||||
func (s *SettingRepoSuite) TestGetValue_Missing() {
|
||||
_, err := s.repo.GetValue(s.ctx, "nonexistent")
|
||||
s.Require().Error(err, "expected error for missing key")
|
||||
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
|
||||
}
|
||||
|
||||
func (s *SettingRepoSuite) TestSetMultiple_AndGetMultiple() {
|
||||
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"k2": "v2", "k3": "v3"}), "SetMultiple")
|
||||
m, err := s.repo.GetMultiple(s.ctx, []string{"k2", "k3"})
|
||||
s.Require().NoError(err, "GetMultiple")
|
||||
s.Require().Equal("v2", m["k2"])
|
||||
s.Require().Equal("v3", m["k3"])
|
||||
}
|
||||
|
||||
func (s *SettingRepoSuite) TestGetMultiple_EmptyKeys() {
|
||||
m, err := s.repo.GetMultiple(s.ctx, []string{})
|
||||
s.Require().NoError(err, "GetMultiple with empty keys")
|
||||
s.Require().Empty(m, "expected empty map")
|
||||
}
|
||||
|
||||
func (s *SettingRepoSuite) TestGetMultiple_Subset() {
|
||||
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"a": "1", "b": "2", "c": "3"}))
|
||||
m, err := s.repo.GetMultiple(s.ctx, []string{"a", "c", "nonexistent"})
|
||||
s.Require().NoError(err, "GetMultiple subset")
|
||||
s.Require().Equal("1", m["a"])
|
||||
s.Require().Equal("3", m["c"])
|
||||
_, exists := m["nonexistent"]
|
||||
s.Require().False(exists, "nonexistent key should not be in map")
|
||||
}
|
||||
|
||||
func (s *SettingRepoSuite) TestGetAll() {
|
||||
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"x": "1", "y": "2"}))
|
||||
all, err := s.repo.GetAll(s.ctx)
|
||||
s.Require().NoError(err, "GetAll")
|
||||
s.Require().GreaterOrEqual(len(all), 2, "expected at least 2 settings")
|
||||
s.Require().Equal("1", all["x"])
|
||||
s.Require().Equal("2", all["y"])
|
||||
}
|
||||
|
||||
func (s *SettingRepoSuite) TestDelete() {
|
||||
s.Require().NoError(s.repo.Set(s.ctx, "todelete", "val"))
|
||||
s.Require().NoError(s.repo.Delete(s.ctx, "todelete"), "Delete")
|
||||
_, err := s.repo.GetValue(s.ctx, "todelete")
|
||||
s.Require().Error(err, "expected missing key error after Delete")
|
||||
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
|
||||
}
|
||||
|
||||
func (s *SettingRepoSuite) TestDelete_Idempotent() {
|
||||
// Delete a key that doesn't exist should not error
|
||||
s.Require().NoError(s.repo.Delete(s.ctx, "nonexistent_delete"), "Delete nonexistent should be idempotent")
|
||||
}
|
||||
|
||||
func (s *SettingRepoSuite) TestSetMultiple_Upsert() {
|
||||
s.Require().NoError(s.repo.Set(s.ctx, "upsert_key", "old_value"))
|
||||
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"upsert_key": "new_value", "new_key": "new_val"}))
|
||||
|
||||
got, err := s.repo.GetValue(s.ctx, "upsert_key")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal("new_value", got, "SetMultiple should upsert existing key")
|
||||
|
||||
got2, err := s.repo.GetValue(s.ctx, "new_key")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal("new_val", got2)
|
||||
}
|
||||
@@ -16,6 +16,7 @@ const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/sitev
|
||||
|
||||
type turnstileVerifier struct {
|
||||
httpClient *http.Client
|
||||
verifyURL string
|
||||
}
|
||||
|
||||
func NewTurnstileVerifier() service.TurnstileVerifier {
|
||||
@@ -23,6 +24,7 @@ func NewTurnstileVerifier() service.TurnstileVerifier {
|
||||
httpClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
verifyURL: turnstileVerifyURL,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,7 +36,7 @@ func (v *turnstileVerifier) VerifyToken(ctx context.Context, secretKey, token, r
|
||||
formData.Set("remoteip", remoteIP)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, turnstileVerifyURL, strings.NewReader(formData.Encode()))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, v.verifyURL, strings.NewReader(formData.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
143
backend/internal/repository/turnstile_service_test.go
Normal file
143
backend/internal/repository/turnstile_service_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type TurnstileServiceSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
srv *httptest.Server
|
||||
verifier *turnstileVerifier
|
||||
received chan url.Values
|
||||
}
|
||||
|
||||
func (s *TurnstileServiceSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.received = make(chan url.Values, 1)
|
||||
verifier, ok := NewTurnstileVerifier().(*turnstileVerifier)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.verifier = verifier
|
||||
}
|
||||
|
||||
func (s *TurnstileServiceSuite) TearDownTest() {
|
||||
if s.srv != nil {
|
||||
s.srv.Close()
|
||||
s.srv = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TurnstileServiceSuite) setupServer(handler http.HandlerFunc) {
|
||||
s.srv = httptest.NewServer(handler)
|
||||
s.verifier.verifyURL = s.srv.URL
|
||||
}
|
||||
|
||||
func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Capture form data in main goroutine context later
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
values, _ := url.ParseQuery(string(body))
|
||||
s.received <- values
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true})
|
||||
}))
|
||||
|
||||
resp, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
|
||||
require.NoError(s.T(), err, "VerifyToken")
|
||||
require.NotNil(s.T(), resp)
|
||||
require.True(s.T(), resp.Success, "expected success response")
|
||||
|
||||
// Assert form fields in main goroutine
|
||||
select {
|
||||
case values := <-s.received:
|
||||
require.Equal(s.T(), "sk", values.Get("secret"))
|
||||
require.Equal(s.T(), "token", values.Get("response"))
|
||||
require.Equal(s.T(), "1.1.1.1", values.Get("remoteip"))
|
||||
default:
|
||||
require.Fail(s.T(), "expected server to receive request")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TurnstileServiceSuite) TestVerifyToken_ContentType() {
|
||||
var contentType string
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
contentType = r.Header.Get("Content-Type")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true})
|
||||
}))
|
||||
|
||||
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), strings.HasPrefix(contentType, "application/x-www-form-urlencoded"), "unexpected content-type: %s", contentType)
|
||||
}
|
||||
|
||||
func (s *TurnstileServiceSuite) TestVerifyToken_EmptyRemoteIP_NotSent() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
values, _ := url.ParseQuery(string(body))
|
||||
s.received <- values
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true})
|
||||
}))
|
||||
|
||||
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "")
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
select {
|
||||
case values := <-s.received:
|
||||
require.Equal(s.T(), "", values.Get("remoteip"), "remoteip should be empty or not sent")
|
||||
default:
|
||||
require.Fail(s.T(), "expected server to receive request")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TurnstileServiceSuite) TestVerifyToken_RequestError() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
s.srv.Close()
|
||||
|
||||
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
|
||||
require.Error(s.T(), err, "expected error when server is closed")
|
||||
}
|
||||
|
||||
func (s *TurnstileServiceSuite) TestVerifyToken_InvalidJSON() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, "not-valid-json")
|
||||
}))
|
||||
|
||||
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
|
||||
require.Error(s.T(), err, "expected error for invalid JSON response")
|
||||
}
|
||||
|
||||
func (s *TurnstileServiceSuite) TestVerifyToken_SuccessFalse() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{
|
||||
Success: false,
|
||||
ErrorCodes: []string{"invalid-input-response"},
|
||||
})
|
||||
}))
|
||||
|
||||
resp, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
|
||||
require.NoError(s.T(), err, "VerifyToken should not error on success=false")
|
||||
require.NotNil(s.T(), resp)
|
||||
require.False(s.T(), resp.Success)
|
||||
require.Contains(s.T(), resp.ErrorCodes, "invalid-input-response")
|
||||
}
|
||||
|
||||
func TestTurnstileServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(TurnstileServiceSuite))
|
||||
}
|
||||
73
backend/internal/repository/update_cache_integration_test.go
Normal file
73
backend/internal/repository/update_cache_integration_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type UpdateCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache *updateCache
|
||||
}
|
||||
|
||||
func (s *UpdateCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewUpdateCache(s.rdb).(*updateCache)
|
||||
}
|
||||
|
||||
func (s *UpdateCacheSuite) TestGetUpdateInfo_Missing() {
|
||||
_, err := s.cache.GetUpdateInfo(s.ctx)
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing update info")
|
||||
}
|
||||
|
||||
func (s *UpdateCacheSuite) TestSetAndGetUpdateInfo() {
|
||||
updateTTL := 5 * time.Minute
|
||||
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.2.3", updateTTL), "SetUpdateInfo")
|
||||
|
||||
info, err := s.cache.GetUpdateInfo(s.ctx)
|
||||
require.NoError(s.T(), err, "GetUpdateInfo")
|
||||
require.Equal(s.T(), "v1.2.3", info, "update info mismatch")
|
||||
}
|
||||
|
||||
func (s *UpdateCacheSuite) TestSetUpdateInfo_TTL() {
|
||||
updateTTL := 5 * time.Minute
|
||||
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.2.3", updateTTL))
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, updateCacheKey).Result()
|
||||
require.NoError(s.T(), err, "TTL updateCacheKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, updateTTL)
|
||||
}
|
||||
|
||||
func (s *UpdateCacheSuite) TestSetUpdateInfo_Overwrite() {
|
||||
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.0.0", 5*time.Minute))
|
||||
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v2.0.0", 5*time.Minute))
|
||||
|
||||
info, err := s.cache.GetUpdateInfo(s.ctx)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), "v2.0.0", info, "expected overwritten value")
|
||||
}
|
||||
|
||||
func (s *UpdateCacheSuite) TestSetUpdateInfo_ZeroTTL() {
|
||||
// TTL=0 means persist forever (no expiry) in Redis SET command
|
||||
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v0.0.0", 0))
|
||||
|
||||
info, err := s.cache.GetUpdateInfo(s.ctx)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), "v0.0.0", info)
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, updateCacheKey).Result()
|
||||
require.NoError(s.T(), err)
|
||||
// TTL=-1 means no expiry, TTL=-2 means key doesn't exist
|
||||
require.Equal(s.T(), time.Duration(-1), ttl, "expected TTL=-1 for key with no expiry")
|
||||
}
|
||||
|
||||
func TestUpdateCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(UpdateCacheSuite))
|
||||
}
|
||||
890
backend/internal/repository/usage_log_repo_integration_test.go
Normal file
890
backend/internal/repository/usage_log_repo_integration_test.go
Normal file
@@ -0,0 +1,890 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UsageLogRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
repo *UsageLogRepository
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewUsageLogRepository(s.db)
|
||||
}
|
||||
|
||||
func TestUsageLogRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(UsageLogRepoSuite))
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) createUsageLog(user *model.User, apiKey *model.ApiKey, account *model.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *model.UsageLog {
|
||||
log := &model.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3",
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
TotalCost: cost,
|
||||
ActualCost: cost,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log))
|
||||
return log
|
||||
}
|
||||
|
||||
// --- Create / GetByID ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestCreate() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "create@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-create", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-create"})
|
||||
|
||||
log := &model.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.4,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, log)
|
||||
s.Require().NoError(err, "Create")
|
||||
s.Require().NotZero(log.ID)
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetByID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "getbyid@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-getbyid"})
|
||||
|
||||
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, log.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal(log.ID, got.ID)
|
||||
s.Require().Equal(10, got.InputTokens)
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
}
|
||||
|
||||
// --- Delete ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestDelete() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-delete", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-delete"})
|
||||
|
||||
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
|
||||
err := s.repo.Delete(s.ctx, log.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
|
||||
_, err = s.repo.GetByID(s.ctx, log.ID)
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
}
|
||||
|
||||
// --- ListByUser ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByUser() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyuser@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyuser"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
|
||||
|
||||
logs, page, err := s.repo.ListByUser(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByUser")
|
||||
s.Require().Len(logs, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
|
||||
// --- ListByApiKey ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByApiKey() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyapikey@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyapikey"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
|
||||
|
||||
logs, page, err := s.repo.ListByApiKey(s.ctx, apiKey.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByApiKey")
|
||||
s.Require().Len(logs, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
|
||||
// --- ListByAccount ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByAccount() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyaccount@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyaccount"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
|
||||
logs, page, err := s.repo.ListByAccount(s.ctx, account.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByAccount")
|
||||
s.Require().Len(logs, 1)
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
}
|
||||
|
||||
// --- GetUserStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserStats() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "userstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-userstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-userstats"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
|
||||
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(2 * time.Hour)
|
||||
stats, err := s.repo.GetUserStats(s.ctx, user.ID, startTime, endTime)
|
||||
s.Require().NoError(err, "GetUserStats")
|
||||
s.Require().Equal(int64(2), stats.TotalRequests)
|
||||
s.Require().Equal(int64(25), stats.InputTokens)
|
||||
s.Require().Equal(int64(45), stats.OutputTokens)
|
||||
}
|
||||
|
||||
// --- ListWithFilters ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListWithFilters() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filters@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filters", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filters"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
|
||||
filters := usagestats.UsageLogFilters{UserID: user.ID}
|
||||
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
|
||||
s.Require().NoError(err, "ListWithFilters")
|
||||
s.Require().Len(logs, 1)
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
}
|
||||
|
||||
// --- GetDashboardStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
now := time.Now()
|
||||
todayStart := timezone.Today()
|
||||
|
||||
userToday := mustCreateUser(s.T(), s.db, &model.User{
|
||||
Email: "today@example.com",
|
||||
CreatedAt: maxTime(todayStart.Add(10*time.Second), now.Add(-10*time.Second)),
|
||||
UpdatedAt: now,
|
||||
})
|
||||
userOld := mustCreateUser(s.T(), s.db, &model.User{
|
||||
Email: "old@example.com",
|
||||
CreatedAt: todayStart.Add(-24 * time.Hour),
|
||||
UpdatedAt: todayStart.Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-ul"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: model.StatusDisabled})
|
||||
|
||||
resetAt := now.Add(10 * time.Minute)
|
||||
accNormal := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-normal", Schedulable: true})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-error", Status: model.StatusError, Schedulable: true})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-rl", RateLimitedAt: &now, RateLimitResetAt: &resetAt, Schedulable: true})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-ov", OverloadUntil: &resetAt, Schedulable: true})
|
||||
|
||||
d1, d2, d3 := 100, 200, 300
|
||||
logToday := &model.UsageLog{
|
||||
UserID: userToday.ID,
|
||||
ApiKeyID: apiKey1.ID,
|
||||
AccountID: accNormal.ID,
|
||||
Model: "claude-3",
|
||||
GroupID: &group.ID,
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
CacheCreationTokens: 3,
|
||||
CacheReadTokens: 4,
|
||||
TotalCost: 1.5,
|
||||
ActualCost: 1.2,
|
||||
DurationMs: &d1,
|
||||
CreatedAt: maxTime(todayStart.Add(2*time.Minute), now.Add(-2*time.Minute)),
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, logToday), "Create logToday")
|
||||
|
||||
logOld := &model.UsageLog{
|
||||
UserID: userOld.ID,
|
||||
ApiKeyID: apiKey1.ID,
|
||||
AccountID: accNormal.ID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 5,
|
||||
OutputTokens: 6,
|
||||
TotalCost: 0.7,
|
||||
ActualCost: 0.7,
|
||||
DurationMs: &d2,
|
||||
CreatedAt: todayStart.Add(-1 * time.Hour),
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, logOld), "Create logOld")
|
||||
|
||||
logPerf := &model.UsageLog{
|
||||
UserID: userToday.ID,
|
||||
ApiKeyID: apiKey1.ID,
|
||||
AccountID: accNormal.ID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 1,
|
||||
OutputTokens: 2,
|
||||
TotalCost: 0.1,
|
||||
ActualCost: 0.1,
|
||||
DurationMs: &d3,
|
||||
CreatedAt: now.Add(-30 * time.Second),
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, logPerf), "Create logPerf")
|
||||
|
||||
stats, err := s.repo.GetDashboardStats(s.ctx)
|
||||
s.Require().NoError(err, "GetDashboardStats")
|
||||
|
||||
s.Require().Equal(int64(2), stats.TotalUsers, "TotalUsers mismatch")
|
||||
s.Require().Equal(int64(1), stats.TodayNewUsers, "TodayNewUsers mismatch")
|
||||
s.Require().Equal(int64(1), stats.ActiveUsers, "ActiveUsers mismatch")
|
||||
s.Require().Equal(int64(2), stats.TotalApiKeys, "TotalApiKeys mismatch")
|
||||
s.Require().Equal(int64(1), stats.ActiveApiKeys, "ActiveApiKeys mismatch")
|
||||
s.Require().Equal(int64(4), stats.TotalAccounts, "TotalAccounts mismatch")
|
||||
s.Require().Equal(int64(1), stats.ErrorAccounts, "ErrorAccounts mismatch")
|
||||
s.Require().Equal(int64(1), stats.RateLimitAccounts, "RateLimitAccounts mismatch")
|
||||
s.Require().Equal(int64(1), stats.OverloadAccounts, "OverloadAccounts mismatch")
|
||||
|
||||
s.Require().Equal(int64(3), stats.TotalRequests, "TotalRequests mismatch")
|
||||
s.Require().Equal(int64(16), stats.TotalInputTokens, "TotalInputTokens mismatch")
|
||||
s.Require().Equal(int64(28), stats.TotalOutputTokens, "TotalOutputTokens mismatch")
|
||||
s.Require().Equal(int64(3), stats.TotalCacheCreationTokens, "TotalCacheCreationTokens mismatch")
|
||||
s.Require().Equal(int64(4), stats.TotalCacheReadTokens, "TotalCacheReadTokens mismatch")
|
||||
s.Require().Equal(int64(51), stats.TotalTokens, "TotalTokens mismatch")
|
||||
s.Require().Equal(2.3, stats.TotalCost, "TotalCost mismatch")
|
||||
s.Require().Equal(2.0, stats.TotalActualCost, "TotalActualCost mismatch")
|
||||
s.Require().GreaterOrEqual(stats.TodayRequests, int64(1), "expected TodayRequests >= 1")
|
||||
s.Require().GreaterOrEqual(stats.TodayCost, 0.0, "expected TodayCost >= 0")
|
||||
|
||||
wantRpm, wantTpm := s.repo.getPerformanceStats(s.ctx, 0)
|
||||
s.Require().Equal(wantRpm, stats.Rpm, "Rpm mismatch")
|
||||
s.Require().Equal(wantTpm, stats.Tpm, "Tpm mismatch")
|
||||
}
|
||||
|
||||
// --- GetUserDashboardStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "userdash@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-userdash", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-userdash"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
|
||||
stats, err := s.repo.GetUserDashboardStats(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "GetUserDashboardStats")
|
||||
s.Require().Equal(int64(1), stats.TotalApiKeys)
|
||||
s.Require().Equal(int64(1), stats.TotalRequests)
|
||||
}
|
||||
|
||||
// --- GetAccountTodayStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "acctoday@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-today"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
|
||||
stats, err := s.repo.GetAccountTodayStats(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetAccountTodayStats")
|
||||
s.Require().Equal(int64(1), stats.Requests)
|
||||
s.Require().Equal(int64(30), stats.Tokens)
|
||||
}
|
||||
|
||||
// --- GetBatchUserUsageStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
|
||||
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "batch1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "batch2@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-batch"})
|
||||
|
||||
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
|
||||
s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now())
|
||||
|
||||
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID})
|
||||
s.Require().NoError(err, "GetBatchUserUsageStats")
|
||||
s.Require().Len(stats, 2)
|
||||
s.Require().NotNil(stats[user1.ID])
|
||||
s.Require().NotNil(stats[user2.ID])
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
|
||||
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Empty(stats)
|
||||
}
|
||||
|
||||
// --- GetBatchApiKeyUsageStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "batchkey@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-batchkey"})
|
||||
|
||||
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
|
||||
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
|
||||
|
||||
stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
|
||||
s.Require().NoError(err, "GetBatchApiKeyUsageStats")
|
||||
s.Require().Len(stats, 2)
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
|
||||
stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Empty(stats)
|
||||
}
|
||||
|
||||
// --- GetGlobalStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetGlobalStats() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "global@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-global", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-global"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
|
||||
|
||||
stats, err := s.repo.GetGlobalStats(s.ctx, base.Add(-1*time.Hour), base.Add(2*time.Hour))
|
||||
s.Require().NoError(err, "GetGlobalStats")
|
||||
s.Require().Equal(int64(2), stats.TotalRequests)
|
||||
s.Require().Equal(int64(25), stats.TotalInputTokens)
|
||||
s.Require().Equal(int64(45), stats.TotalOutputTokens)
|
||||
}
|
||||
|
||||
func maxTime(a, b time.Time) time.Time {
|
||||
if a.After(b) {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// --- ListByUserAndTimeRange ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "timerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-timerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-timerange"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
|
||||
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(-24*time.Hour)) // outside range
|
||||
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(2 * time.Hour)
|
||||
logs, _, err := s.repo.ListByUserAndTimeRange(s.ctx, user.ID, startTime, endTime)
|
||||
s.Require().NoError(err, "ListByUserAndTimeRange")
|
||||
s.Require().Len(logs, 2)
|
||||
}
|
||||
|
||||
// --- ListByApiKeyAndTimeRange ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytimerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytimerange"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(30*time.Minute))
|
||||
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(-24*time.Hour)) // outside range
|
||||
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(2 * time.Hour)
|
||||
logs, _, err := s.repo.ListByApiKeyAndTimeRange(s.ctx, apiKey.ID, startTime, endTime)
|
||||
s.Require().NoError(err, "ListByApiKeyAndTimeRange")
|
||||
s.Require().Len(logs, 2)
|
||||
}
|
||||
|
||||
// --- ListByAccountAndTimeRange ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "acctimerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-acctimerange"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(45*time.Minute))
|
||||
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(-24*time.Hour)) // outside range
|
||||
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(2 * time.Hour)
|
||||
logs, _, err := s.repo.ListByAccountAndTimeRange(s.ctx, account.ID, startTime, endTime)
|
||||
s.Require().NoError(err, "ListByAccountAndTimeRange")
|
||||
s.Require().Len(logs, 2)
|
||||
}
|
||||
|
||||
// --- ListByModelAndTimeRange ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modeltimerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modeltimerange"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Create logs with different models
|
||||
log1 := &model.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-opus",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: base,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log1))
|
||||
|
||||
log2 := &model.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-opus",
|
||||
InputTokens: 15,
|
||||
OutputTokens: 25,
|
||||
TotalCost: 0.6,
|
||||
ActualCost: 0.6,
|
||||
CreatedAt: base.Add(30 * time.Minute),
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log2))
|
||||
|
||||
log3 := &model.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-sonnet",
|
||||
InputTokens: 20,
|
||||
OutputTokens: 30,
|
||||
TotalCost: 0.7,
|
||||
ActualCost: 0.7,
|
||||
CreatedAt: base.Add(1 * time.Hour),
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log3))
|
||||
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(2 * time.Hour)
|
||||
logs, _, err := s.repo.ListByModelAndTimeRange(s.ctx, "claude-3-opus", startTime, endTime)
|
||||
s.Require().NoError(err, "ListByModelAndTimeRange")
|
||||
s.Require().Len(logs, 2)
|
||||
}
|
||||
|
||||
// --- GetAccountWindowStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "windowstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-windowstats"})
|
||||
|
||||
now := time.Now()
|
||||
windowStart := now.Add(-10 * time.Minute)
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, now.Add(-5*time.Minute))
|
||||
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, now.Add(-3*time.Minute))
|
||||
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, now.Add(-30*time.Minute)) // outside window
|
||||
|
||||
stats, err := s.repo.GetAccountWindowStats(s.ctx, account.ID, windowStart)
|
||||
s.Require().NoError(err, "GetAccountWindowStats")
|
||||
s.Require().Equal(int64(2), stats.Requests)
|
||||
s.Require().Equal(int64(70), stats.Tokens) // (10+20) + (15+25)
|
||||
}
|
||||
|
||||
// --- GetUserUsageTrendByUserID ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrend"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
|
||||
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(24*time.Hour)) // next day
|
||||
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(48 * time.Hour)
|
||||
trend, err := s.repo.GetUserUsageTrendByUserID(s.ctx, user.ID, startTime, endTime, "day")
|
||||
s.Require().NoError(err, "GetUserUsageTrendByUserID")
|
||||
s.Require().Len(trend, 2) // 2 different days
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrendhourly@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrendhourly"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
|
||||
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(2*time.Hour))
|
||||
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(3 * time.Hour)
|
||||
trend, err := s.repo.GetUserUsageTrendByUserID(s.ctx, user.ID, startTime, endTime, "hour")
|
||||
s.Require().NoError(err, "GetUserUsageTrendByUserID hourly")
|
||||
s.Require().Len(trend, 3) // 3 different hours
|
||||
}
|
||||
|
||||
// --- GetUserModelStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserModelStats() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modelstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modelstats"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Create logs with different models
|
||||
log1 := &model.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-opus",
|
||||
InputTokens: 100,
|
||||
OutputTokens: 200,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: base,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log1))
|
||||
|
||||
log2 := &model.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-sonnet",
|
||||
InputTokens: 50,
|
||||
OutputTokens: 100,
|
||||
TotalCost: 0.2,
|
||||
ActualCost: 0.2,
|
||||
CreatedAt: base.Add(1 * time.Hour),
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log2))
|
||||
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(2 * time.Hour)
|
||||
stats, err := s.repo.GetUserModelStats(s.ctx, user.ID, startTime, endTime)
|
||||
s.Require().NoError(err, "GetUserModelStats")
|
||||
s.Require().Len(stats, 2)
|
||||
|
||||
// Should be ordered by total_tokens DESC
|
||||
s.Require().Equal("claude-3-opus", stats[0].Model)
|
||||
s.Require().Equal(int64(300), stats[0].TotalTokens)
|
||||
}
|
||||
|
||||
// --- GetUsageTrendWithFilters ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "trendfilters@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-trendfilters"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(24*time.Hour))
|
||||
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(48 * time.Hour)
|
||||
|
||||
// Test with user filter
|
||||
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0)
|
||||
s.Require().NoError(err, "GetUsageTrendWithFilters user filter")
|
||||
s.Require().Len(trend, 2)
|
||||
|
||||
// Test with apiKey filter
|
||||
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID)
|
||||
s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter")
|
||||
s.Require().Len(trend, 2)
|
||||
|
||||
// Test with both filters
|
||||
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID)
|
||||
s.Require().NoError(err, "GetUsageTrendWithFilters both filters")
|
||||
s.Require().Len(trend, 2)
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "trendfilters-h@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-trendfilters-h"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
|
||||
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(3 * time.Hour)
|
||||
|
||||
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0)
|
||||
s.Require().NoError(err, "GetUsageTrendWithFilters hourly")
|
||||
s.Require().Len(trend, 2)
|
||||
}
|
||||
|
||||
// --- GetModelStatsWithFilters ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modelfilters@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modelfilters"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
log1 := &model.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-opus",
|
||||
InputTokens: 100,
|
||||
OutputTokens: 200,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: base,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log1))
|
||||
|
||||
log2 := &model.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-sonnet",
|
||||
InputTokens: 50,
|
||||
OutputTokens: 100,
|
||||
TotalCost: 0.2,
|
||||
ActualCost: 0.2,
|
||||
CreatedAt: base.Add(1 * time.Hour),
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log2))
|
||||
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(2 * time.Hour)
|
||||
|
||||
// Test with user filter
|
||||
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0)
|
||||
s.Require().NoError(err, "GetModelStatsWithFilters user filter")
|
||||
s.Require().Len(stats, 2)
|
||||
|
||||
// Test with apiKey filter
|
||||
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0)
|
||||
s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter")
|
||||
s.Require().Len(stats, 2)
|
||||
|
||||
// Test with account filter
|
||||
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID)
|
||||
s.Require().NoError(err, "GetModelStatsWithFilters account filter")
|
||||
s.Require().Len(stats, 2)
|
||||
}
|
||||
|
||||
// --- GetAccountUsageStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "accstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-accstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-accstats"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
// Create logs on different days
|
||||
log1 := &model.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-opus",
|
||||
InputTokens: 100,
|
||||
OutputTokens: 200,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.4,
|
||||
CreatedAt: base.Add(12 * time.Hour),
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log1))
|
||||
|
||||
log2 := &model.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-sonnet",
|
||||
InputTokens: 50,
|
||||
OutputTokens: 100,
|
||||
TotalCost: 0.2,
|
||||
ActualCost: 0.15,
|
||||
CreatedAt: base.Add(36 * time.Hour), // next day
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log2))
|
||||
|
||||
startTime := base
|
||||
endTime := base.Add(72 * time.Hour)
|
||||
|
||||
resp, err := s.repo.GetAccountUsageStats(s.ctx, account.ID, startTime, endTime)
|
||||
s.Require().NoError(err, "GetAccountUsageStats")
|
||||
|
||||
s.Require().Len(resp.History, 2, "expected 2 days of history")
|
||||
s.Require().Equal(int64(2), resp.Summary.TotalRequests)
|
||||
s.Require().Equal(int64(450), resp.Summary.TotalTokens)
|
||||
s.Require().Len(resp.Models, 2)
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-emptystats"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
|
||||
startTime := base
|
||||
endTime := base.Add(72 * time.Hour)
|
||||
|
||||
resp, err := s.repo.GetAccountUsageStats(s.ctx, account.ID, startTime, endTime)
|
||||
s.Require().NoError(err, "GetAccountUsageStats empty")
|
||||
|
||||
s.Require().Len(resp.History, 0)
|
||||
s.Require().Equal(int64(0), resp.Summary.TotalRequests)
|
||||
}
|
||||
|
||||
// --- GetUserUsageTrend ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
|
||||
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend2@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrends"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user1, apiKey1, account, 100, 200, 1.0, base)
|
||||
s.createUsageLog(user2, apiKey2, account, 50, 100, 0.5, base)
|
||||
s.createUsageLog(user1, apiKey1, account, 100, 200, 1.0, base.Add(24*time.Hour))
|
||||
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(48 * time.Hour)
|
||||
|
||||
trend, err := s.repo.GetUserUsageTrend(s.ctx, startTime, endTime, "day", 10)
|
||||
s.Require().NoError(err, "GetUserUsageTrend")
|
||||
s.Require().GreaterOrEqual(len(trend), 2)
|
||||
}
|
||||
|
||||
// --- GetApiKeyUsageTrend ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytrend@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytrends"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base)
|
||||
s.createUsageLog(user, apiKey2, account, 50, 100, 0.5, base)
|
||||
s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base.Add(24*time.Hour))
|
||||
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(48 * time.Hour)
|
||||
|
||||
trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "day", 10)
|
||||
s.Require().NoError(err, "GetApiKeyUsageTrend")
|
||||
s.Require().GreaterOrEqual(len(trend), 2)
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytrendh@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytrendh"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 100, 200, 1.0, base)
|
||||
s.createUsageLog(user, apiKey, account, 50, 100, 0.5, base.Add(1*time.Hour))
|
||||
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(3 * time.Hour)
|
||||
|
||||
trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10)
|
||||
s.Require().NoError(err, "GetApiKeyUsageTrend hourly")
|
||||
s.Require().Len(trend, 2)
|
||||
}
|
||||
|
||||
// --- ListWithFilters (additional filter tests) ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterskey@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterskey"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
|
||||
filters := usagestats.UsageLogFilters{ApiKeyID: apiKey.ID}
|
||||
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
|
||||
s.Require().NoError(err, "ListWithFilters apiKey")
|
||||
s.Require().Len(logs, 1)
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterstime@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterstime"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
|
||||
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(-24*time.Hour)) // outside range
|
||||
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(2 * time.Hour)
|
||||
filters := usagestats.UsageLogFilters{StartTime: &startTime, EndTime: &endTime}
|
||||
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
|
||||
s.Require().NoError(err, "ListWithFilters time range")
|
||||
s.Require().Len(logs, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterscombined@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterscombined"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
|
||||
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(2 * time.Hour)
|
||||
filters := usagestats.UsageLogFilters{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
StartTime: &startTime,
|
||||
EndTime: &endTime,
|
||||
}
|
||||
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
|
||||
s.Require().NoError(err, "ListWithFilters combined")
|
||||
s.Require().Len(logs, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
448
backend/internal/repository/user_repo_integration_test.go
Normal file
448
backend/internal/repository/user_repo_integration_test.go
Normal file
@@ -0,0 +1,448 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/lib/pq"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UserRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
repo *UserRepository
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewUserRepository(s.db)
|
||||
}
|
||||
|
||||
func TestUserRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(UserRepoSuite))
|
||||
}
|
||||
|
||||
// --- Create / GetByID / GetByEmail / Update / Delete ---
|
||||
|
||||
func (s *UserRepoSuite) TestCreate() {
|
||||
user := &model.User{
|
||||
Email: "create@test.com",
|
||||
Username: "testuser",
|
||||
Role: model.RoleUser,
|
||||
Status: model.StatusActive,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, user)
|
||||
s.Require().NoError(err, "Create")
|
||||
s.Require().NotZero(user.ID, "expected ID to be set")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("create@test.com", got.Email)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestGetByEmail() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "byemail@test.com"})
|
||||
|
||||
got, err := s.repo.GetByEmail(s.ctx, user.Email)
|
||||
s.Require().NoError(err, "GetByEmail")
|
||||
s.Require().Equal(user.ID, got.ID)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestGetByEmail_NotFound() {
|
||||
_, err := s.repo.GetByEmail(s.ctx, "nonexistent@test.com")
|
||||
s.Require().Error(err, "expected error for non-existent email")
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestUpdate() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com", Username: "original"})
|
||||
|
||||
user.Username = "updated"
|
||||
err := s.repo.Update(s.ctx, user)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "GetByID after update")
|
||||
s.Require().Equal("updated", got.Username)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestDelete() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
|
||||
|
||||
err := s.repo.Delete(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
|
||||
_, err = s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
}
|
||||
|
||||
// --- List / ListWithFilters ---
|
||||
|
||||
func (s *UserRepoSuite) TestList() {
|
||||
mustCreateUser(s.T(), s.db, &model.User{Email: "list1@test.com"})
|
||||
mustCreateUser(s.T(), s.db, &model.User{Email: "list2@test.com"})
|
||||
|
||||
users, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List")
|
||||
s.Require().Len(users, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_Status() {
|
||||
mustCreateUser(s.T(), s.db, &model.User{Email: "active@test.com", Status: model.StatusActive})
|
||||
mustCreateUser(s.T(), s.db, &model.User{Email: "disabled@test.com", Status: model.StatusDisabled})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.StatusActive, "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(users, 1)
|
||||
s.Require().Equal(model.StatusActive, users[0].Status)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_Role() {
|
||||
mustCreateUser(s.T(), s.db, &model.User{Email: "user@test.com", Role: model.RoleUser})
|
||||
mustCreateUser(s.T(), s.db, &model.User{Email: "admin@test.com", Role: model.RoleAdmin})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.RoleAdmin, "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(users, 1)
|
||||
s.Require().Equal(model.RoleAdmin, users[0].Role)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_Search() {
|
||||
mustCreateUser(s.T(), s.db, &model.User{Email: "alice@test.com", Username: "Alice"})
|
||||
mustCreateUser(s.T(), s.db, &model.User{Email: "bob@test.com", Username: "Bob"})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alice")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(users, 1)
|
||||
s.Require().Contains(users[0].Email, "alice")
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
|
||||
mustCreateUser(s.T(), s.db, &model.User{Email: "u1@test.com", Username: "JohnDoe"})
|
||||
mustCreateUser(s.T(), s.db, &model.User{Email: "u2@test.com", Username: "JaneSmith"})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "john")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(users, 1)
|
||||
s.Require().Equal("JohnDoe", users[0].Username)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() {
|
||||
mustCreateUser(s.T(), s.db, &model.User{Email: "w1@test.com", Wechat: "wx_hello"})
|
||||
mustCreateUser(s.T(), s.db, &model.User{Email: "w2@test.com", Wechat: "wx_world"})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "wx_hello")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(users, 1)
|
||||
s.Require().Equal("wx_hello", users[0].Wechat)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "sub@test.com", Status: model.StatusActive})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sub"})
|
||||
|
||||
_ = mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
})
|
||||
_ = mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusExpired,
|
||||
ExpiresAt: time.Now().Add(-1 * time.Hour),
|
||||
})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "sub@")
|
||||
s.Require().NoError(err, "ListWithFilters")
|
||||
s.Require().Len(users, 1, "expected 1 user")
|
||||
s.Require().Len(users[0].Subscriptions, 1, "expected 1 active subscription")
|
||||
s.Require().NotNil(users[0].Subscriptions[0].Group, "expected subscription group preload")
|
||||
s.Require().Equal(group.ID, users[0].Subscriptions[0].Group.ID, "group ID mismatch")
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
|
||||
mustCreateUser(s.T(), s.db, &model.User{
|
||||
Email: "a@example.com",
|
||||
Username: "Alice",
|
||||
Wechat: "wx_a",
|
||||
Role: model.RoleUser,
|
||||
Status: model.StatusActive,
|
||||
Balance: 10,
|
||||
})
|
||||
target := mustCreateUser(s.T(), s.db, &model.User{
|
||||
Email: "b@example.com",
|
||||
Username: "Bob",
|
||||
Wechat: "wx_b",
|
||||
Role: model.RoleAdmin,
|
||||
Status: model.StatusActive,
|
||||
Balance: 1,
|
||||
})
|
||||
mustCreateUser(s.T(), s.db, &model.User{
|
||||
Email: "c@example.com",
|
||||
Role: model.RoleAdmin,
|
||||
Status: model.StatusDisabled,
|
||||
})
|
||||
|
||||
users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.StatusActive, model.RoleAdmin, "b@")
|
||||
s.Require().NoError(err, "ListWithFilters")
|
||||
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
|
||||
s.Require().Len(users, 1, "ListWithFilters len mismatch")
|
||||
s.Require().Equal(target.ID, users[0].ID, "ListWithFilters result mismatch")
|
||||
}
|
||||
|
||||
// --- Balance operations ---
|
||||
|
||||
func (s *UserRepoSuite) TestUpdateBalance() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "bal@test.com", Balance: 10})
|
||||
|
||||
err := s.repo.UpdateBalance(s.ctx, user.ID, 2.5)
|
||||
s.Require().NoError(err, "UpdateBalance")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(12.5, got.Balance)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestUpdateBalance_Negative() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "balneg@test.com", Balance: 10})
|
||||
|
||||
err := s.repo.UpdateBalance(s.ctx, user.ID, -3)
|
||||
s.Require().NoError(err, "UpdateBalance with negative")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(7.0, got.Balance)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestDeductBalance() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "deduct@test.com", Balance: 10})
|
||||
|
||||
err := s.repo.DeductBalance(s.ctx, user.ID, 5)
|
||||
s.Require().NoError(err, "DeductBalance")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(5.0, got.Balance)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "insuf@test.com", Balance: 5})
|
||||
|
||||
err := s.repo.DeductBalance(s.ctx, user.ID, 999)
|
||||
s.Require().Error(err, "expected error for insufficient balance")
|
||||
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exact@test.com", Balance: 10})
|
||||
|
||||
err := s.repo.DeductBalance(s.ctx, user.ID, 10)
|
||||
s.Require().NoError(err, "DeductBalance exact amount")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(got.Balance)
|
||||
}
|
||||
|
||||
// --- Concurrency ---
|
||||
|
||||
func (s *UserRepoSuite) TestUpdateConcurrency() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "conc@test.com", Concurrency: 5})
|
||||
|
||||
err := s.repo.UpdateConcurrency(s.ctx, user.ID, 3)
|
||||
s.Require().NoError(err, "UpdateConcurrency")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(8, got.Concurrency)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestUpdateConcurrency_Negative() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "concneg@test.com", Concurrency: 5})
|
||||
|
||||
err := s.repo.UpdateConcurrency(s.ctx, user.ID, -2)
|
||||
s.Require().NoError(err, "UpdateConcurrency negative")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(3, got.Concurrency)
|
||||
}
|
||||
|
||||
// --- ExistsByEmail ---
|
||||
|
||||
func (s *UserRepoSuite) TestExistsByEmail() {
|
||||
mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"})
|
||||
|
||||
exists, err := s.repo.ExistsByEmail(s.ctx, "exists@test.com")
|
||||
s.Require().NoError(err, "ExistsByEmail")
|
||||
s.Require().True(exists)
|
||||
|
||||
notExists, err := s.repo.ExistsByEmail(s.ctx, "notexists@test.com")
|
||||
s.Require().NoError(err)
|
||||
s.Require().False(notExists)
|
||||
}
|
||||
|
||||
// --- RemoveGroupFromAllowedGroups ---
|
||||
|
||||
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() {
|
||||
groupID := int64(42)
|
||||
userA := mustCreateUser(s.T(), s.db, &model.User{
|
||||
Email: "a1@example.com",
|
||||
AllowedGroups: pq.Int64Array{groupID, 7},
|
||||
})
|
||||
mustCreateUser(s.T(), s.db, &model.User{
|
||||
Email: "a2@example.com",
|
||||
AllowedGroups: pq.Int64Array{7},
|
||||
})
|
||||
|
||||
affected, err := s.repo.RemoveGroupFromAllowedGroups(s.ctx, groupID)
|
||||
s.Require().NoError(err, "RemoveGroupFromAllowedGroups")
|
||||
s.Require().Equal(int64(1), affected, "expected 1 affected row")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, userA.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
for _, id := range got.AllowedGroups {
|
||||
s.Require().NotEqual(groupID, id, "expected groupID to be removed from allowed_groups")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() {
|
||||
mustCreateUser(s.T(), s.db, &model.User{
|
||||
Email: "nomatch@test.com",
|
||||
AllowedGroups: pq.Int64Array{1, 2, 3},
|
||||
})
|
||||
|
||||
affected, err := s.repo.RemoveGroupFromAllowedGroups(s.ctx, 999)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(affected, "expected no affected rows")
|
||||
}
|
||||
|
||||
// --- GetFirstAdmin ---
|
||||
|
||||
func (s *UserRepoSuite) TestGetFirstAdmin() {
|
||||
admin1 := mustCreateUser(s.T(), s.db, &model.User{
|
||||
Email: "admin1@example.com",
|
||||
Role: model.RoleAdmin,
|
||||
Status: model.StatusActive,
|
||||
})
|
||||
mustCreateUser(s.T(), s.db, &model.User{
|
||||
Email: "admin2@example.com",
|
||||
Role: model.RoleAdmin,
|
||||
Status: model.StatusActive,
|
||||
})
|
||||
|
||||
got, err := s.repo.GetFirstAdmin(s.ctx)
|
||||
s.Require().NoError(err, "GetFirstAdmin")
|
||||
s.Require().Equal(admin1.ID, got.ID, "GetFirstAdmin mismatch")
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() {
|
||||
mustCreateUser(s.T(), s.db, &model.User{
|
||||
Email: "user@example.com",
|
||||
Role: model.RoleUser,
|
||||
Status: model.StatusActive,
|
||||
})
|
||||
|
||||
_, err := s.repo.GetFirstAdmin(s.ctx)
|
||||
s.Require().Error(err, "expected error when no admin exists")
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() {
|
||||
mustCreateUser(s.T(), s.db, &model.User{
|
||||
Email: "disabled@example.com",
|
||||
Role: model.RoleAdmin,
|
||||
Status: model.StatusDisabled,
|
||||
})
|
||||
activeAdmin := mustCreateUser(s.T(), s.db, &model.User{
|
||||
Email: "active@example.com",
|
||||
Role: model.RoleAdmin,
|
||||
Status: model.StatusActive,
|
||||
})
|
||||
|
||||
got, err := s.repo.GetFirstAdmin(s.ctx)
|
||||
s.Require().NoError(err, "GetFirstAdmin")
|
||||
s.Require().Equal(activeAdmin.ID, got.ID, "should return only active admin")
|
||||
}
|
||||
|
||||
// --- Combined original test ---
|
||||
|
||||
func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
|
||||
user1 := mustCreateUser(s.T(), s.db, &model.User{
|
||||
Email: "a@example.com",
|
||||
Username: "Alice",
|
||||
Wechat: "wx_a",
|
||||
Role: model.RoleUser,
|
||||
Status: model.StatusActive,
|
||||
Balance: 10,
|
||||
})
|
||||
user2 := mustCreateUser(s.T(), s.db, &model.User{
|
||||
Email: "b@example.com",
|
||||
Username: "Bob",
|
||||
Wechat: "wx_b",
|
||||
Role: model.RoleAdmin,
|
||||
Status: model.StatusActive,
|
||||
Balance: 1,
|
||||
})
|
||||
_ = mustCreateUser(s.T(), s.db, &model.User{
|
||||
Email: "c@example.com",
|
||||
Role: model.RoleAdmin,
|
||||
Status: model.StatusDisabled,
|
||||
})
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user1.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal(user1.Email, got.Email, "GetByID email mismatch")
|
||||
|
||||
gotByEmail, err := s.repo.GetByEmail(s.ctx, user2.Email)
|
||||
s.Require().NoError(err, "GetByEmail")
|
||||
s.Require().Equal(user2.ID, gotByEmail.ID, "GetByEmail ID mismatch")
|
||||
|
||||
got.Username = "Alice2"
|
||||
s.Require().NoError(s.repo.Update(s.ctx, got), "Update")
|
||||
got2, err := s.repo.GetByID(s.ctx, user1.ID)
|
||||
s.Require().NoError(err, "GetByID after update")
|
||||
s.Require().Equal("Alice2", got2.Username, "Update did not persist")
|
||||
|
||||
s.Require().NoError(s.repo.UpdateBalance(s.ctx, user1.ID, 2.5), "UpdateBalance")
|
||||
got3, err := s.repo.GetByID(s.ctx, user1.ID)
|
||||
s.Require().NoError(err, "GetByID after UpdateBalance")
|
||||
s.Require().Equal(12.5, got3.Balance, "UpdateBalance mismatch")
|
||||
|
||||
s.Require().NoError(s.repo.DeductBalance(s.ctx, user1.ID, 5), "DeductBalance")
|
||||
got4, err := s.repo.GetByID(s.ctx, user1.ID)
|
||||
s.Require().NoError(err, "GetByID after DeductBalance")
|
||||
s.Require().Equal(7.5, got4.Balance, "DeductBalance mismatch")
|
||||
|
||||
err = s.repo.DeductBalance(s.ctx, user1.ID, 999)
|
||||
s.Require().Error(err, "DeductBalance expected error for insufficient balance")
|
||||
s.Require().ErrorIs(err, gorm.ErrRecordNotFound, "DeductBalance unexpected error")
|
||||
|
||||
s.Require().NoError(s.repo.UpdateConcurrency(s.ctx, user1.ID, 3), "UpdateConcurrency")
|
||||
got5, err := s.repo.GetByID(s.ctx, user1.ID)
|
||||
s.Require().NoError(err, "GetByID after UpdateConcurrency")
|
||||
s.Require().Equal(user1.Concurrency+3, got5.Concurrency, "UpdateConcurrency mismatch")
|
||||
|
||||
params := pagination.PaginationParams{Page: 1, PageSize: 10}
|
||||
users, page, err := s.repo.ListWithFilters(s.ctx, params, model.StatusActive, model.RoleAdmin, "b@")
|
||||
s.Require().NoError(err, "ListWithFilters")
|
||||
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
|
||||
s.Require().Len(users, 1, "ListWithFilters len mismatch")
|
||||
s.Require().Equal(user2.ID, users[0].ID, "ListWithFilters result mismatch")
|
||||
}
|
||||
@@ -0,0 +1,733 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UserSubscriptionRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
repo *UserSubscriptionRepository
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewUserSubscriptionRepository(s.db)
|
||||
}
|
||||
|
||||
func TestUserSubscriptionRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(UserSubscriptionRepoSuite))
|
||||
}
|
||||
|
||||
// --- Create / GetByID / Update / Delete ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestCreate() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "sub-create@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-create"})
|
||||
|
||||
sub := &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, sub)
|
||||
s.Require().NoError(err, "Create")
|
||||
s.Require().NotZero(sub.ID, "expected ID to be set")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal(sub.UserID, got.UserID)
|
||||
s.Require().Equal(sub.GroupID, got.GroupID)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestGetByID_WithPreloads() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "preload@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-preload"})
|
||||
admin := mustCreateUser(s.T(), s.db, &model.User{Email: "admin@test.com", Role: model.RoleAdmin})
|
||||
|
||||
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
AssignedBy: &admin.ID,
|
||||
})
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().NotNil(got.User, "expected User preload")
|
||||
s.Require().NotNil(got.Group, "expected Group preload")
|
||||
s.Require().NotNil(got.AssignedByUser, "expected AssignedByUser preload")
|
||||
s.Require().Equal(user.ID, got.User.ID)
|
||||
s.Require().Equal(group.ID, got.Group.ID)
|
||||
s.Require().Equal(admin.ID, got.AssignedByUser.ID)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestUpdate() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-update"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
|
||||
sub.Notes = "updated notes"
|
||||
err := s.repo.Update(s.ctx, sub)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err, "GetByID after update")
|
||||
s.Require().Equal("updated notes", got.Notes)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestDelete() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-delete"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
|
||||
err := s.repo.Delete(s.ctx, sub.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
|
||||
_, err = s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
}
|
||||
|
||||
// --- GetByUserIDAndGroupID / GetActiveByUserIDAndGroupID ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "byuser@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-byuser"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
|
||||
got, err := s.repo.GetByUserIDAndGroupID(s.ctx, user.ID, group.ID)
|
||||
s.Require().NoError(err, "GetByUserIDAndGroupID")
|
||||
s.Require().Equal(sub.ID, got.ID)
|
||||
s.Require().NotNil(got.Group, "expected Group preload")
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID_NotFound() {
|
||||
_, err := s.repo.GetByUserIDAndGroupID(s.ctx, 999999, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent pair")
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "active@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-active"})
|
||||
|
||||
// Create active subscription (future expiry)
|
||||
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(2 * time.Hour),
|
||||
})
|
||||
|
||||
got, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, group.ID)
|
||||
s.Require().NoError(err, "GetActiveByUserIDAndGroupID")
|
||||
s.Require().Equal(active.ID, got.ID)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnored() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "expired@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-expired"})
|
||||
|
||||
// Create expired subscription (past expiry but active status)
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(-2 * time.Hour),
|
||||
})
|
||||
|
||||
_, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, group.ID)
|
||||
s.Require().Error(err, "expected error for expired subscription")
|
||||
}
|
||||
|
||||
// --- ListByUserID / ListActiveByUserID ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestListByUserID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listby@test.com"})
|
||||
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list1"})
|
||||
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list2"})
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: g1.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: g2.ID,
|
||||
Status: model.SubscriptionStatusExpired,
|
||||
ExpiresAt: time.Now().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
subs, err := s.repo.ListByUserID(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "ListByUserID")
|
||||
s.Require().Len(subs, 2)
|
||||
for _, sub := range subs {
|
||||
s.Require().NotNil(sub.Group, "expected Group preload")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestListActiveByUserID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listactive@test.com"})
|
||||
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-act1"})
|
||||
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-act2"})
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: g1.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: g2.ID,
|
||||
Status: model.SubscriptionStatusExpired,
|
||||
ExpiresAt: time.Now().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
subs, err := s.repo.ListActiveByUserID(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "ListActiveByUserID")
|
||||
s.Require().Len(subs, 1)
|
||||
s.Require().Equal(model.SubscriptionStatusActive, subs[0].Status)
|
||||
}
|
||||
|
||||
// --- ListByGroupID ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestListByGroupID() {
|
||||
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "u1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "u2@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listgrp"})
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user1.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user2.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
|
||||
subs, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByGroupID")
|
||||
s.Require().Len(subs, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
for _, sub := range subs {
|
||||
s.Require().NotNil(sub.User, "expected User preload")
|
||||
s.Require().NotNil(sub.Group, "expected Group preload")
|
||||
}
|
||||
}
|
||||
|
||||
// --- List with filters ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "list@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"})
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
|
||||
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "")
|
||||
s.Require().NoError(err, "List")
|
||||
s.Require().Len(subs, 1)
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
|
||||
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "filter1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "filter2@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-filter"})
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user1.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user2.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(subs, 1)
|
||||
s.Require().Equal(user1.ID, subs[0].UserID)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "grpfilter@test.com"})
|
||||
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-f1"})
|
||||
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-f2"})
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: g1.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: g2.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(subs, 1)
|
||||
s.Require().Equal(g1.ID, subs[0].GroupID)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "statfilter@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-stat"})
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusExpired,
|
||||
ExpiresAt: time.Now().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, model.SubscriptionStatusExpired)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(subs, 1)
|
||||
s.Require().Equal(model.SubscriptionStatusExpired, subs[0].Status)
|
||||
}
|
||||
|
||||
// --- Usage tracking ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestIncrementUsage() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usage@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-usage"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
|
||||
err := s.repo.IncrementUsage(s.ctx, sub.ID, 1.25)
|
||||
s.Require().NoError(err, "IncrementUsage")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(1.25, got.DailyUsageUSD)
|
||||
s.Require().Equal(1.25, got.WeeklyUsageUSD)
|
||||
s.Require().Equal(1.25, got.MonthlyUsageUSD)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Accumulates() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "accum@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-accum"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
|
||||
s.Require().NoError(s.repo.IncrementUsage(s.ctx, sub.ID, 1.0))
|
||||
s.Require().NoError(s.repo.IncrementUsage(s.ctx, sub.ID, 2.5))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(3.5, got.DailyUsageUSD)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestActivateWindows() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "activate@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-activate"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
|
||||
activateAt := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
err := s.repo.ActivateWindows(s.ctx, sub.ID, activateAt)
|
||||
s.Require().NoError(err, "ActivateWindows")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(got.DailyWindowStart)
|
||||
s.Require().NotNil(got.WeeklyWindowStart)
|
||||
s.Require().NotNil(got.MonthlyWindowStart)
|
||||
s.Require().True(got.DailyWindowStart.Equal(activateAt))
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetd@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetd"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
DailyUsageUSD: 10.0,
|
||||
WeeklyUsageUSD: 20.0,
|
||||
})
|
||||
|
||||
resetAt := time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC)
|
||||
err := s.repo.ResetDailyUsage(s.ctx, sub.ID, resetAt)
|
||||
s.Require().NoError(err, "ResetDailyUsage")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(got.DailyUsageUSD)
|
||||
s.Require().Equal(20.0, got.WeeklyUsageUSD, "weekly should remain unchanged")
|
||||
s.Require().True(got.DailyWindowStart.Equal(resetAt))
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetw@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetw"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
WeeklyUsageUSD: 15.0,
|
||||
MonthlyUsageUSD: 30.0,
|
||||
})
|
||||
|
||||
resetAt := time.Date(2025, 1, 6, 0, 0, 0, 0, time.UTC)
|
||||
err := s.repo.ResetWeeklyUsage(s.ctx, sub.ID, resetAt)
|
||||
s.Require().NoError(err, "ResetWeeklyUsage")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(got.WeeklyUsageUSD)
|
||||
s.Require().Equal(30.0, got.MonthlyUsageUSD, "monthly should remain unchanged")
|
||||
s.Require().True(got.WeeklyWindowStart.Equal(resetAt))
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetm@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetm"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
MonthlyUsageUSD: 100.0,
|
||||
})
|
||||
|
||||
resetAt := time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC)
|
||||
err := s.repo.ResetMonthlyUsage(s.ctx, sub.ID, resetAt)
|
||||
s.Require().NoError(err, "ResetMonthlyUsage")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(got.MonthlyUsageUSD)
|
||||
s.Require().True(got.MonthlyWindowStart.Equal(resetAt))
|
||||
}
|
||||
|
||||
// --- UpdateStatus / ExtendExpiry / UpdateNotes ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestUpdateStatus() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "status@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-status"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
|
||||
err := s.repo.UpdateStatus(s.ctx, sub.ID, model.SubscriptionStatusExpired)
|
||||
s.Require().NoError(err, "UpdateStatus")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(model.SubscriptionStatusExpired, got.Status)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestExtendExpiry() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "extend@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-extend"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
|
||||
newExpiry := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
err := s.repo.ExtendExpiry(s.ctx, sub.ID, newExpiry)
|
||||
s.Require().NoError(err, "ExtendExpiry")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(got.ExpiresAt.Equal(newExpiry))
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestUpdateNotes() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "notes@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-notes"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
|
||||
err := s.repo.UpdateNotes(s.ctx, sub.ID, "VIP user")
|
||||
s.Require().NoError(err, "UpdateNotes")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal("VIP user", got.Notes)
|
||||
}
|
||||
|
||||
// --- ListExpired / BatchUpdateExpiredStatus ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestListExpired() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listexp@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listexp"})
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
expired, err := s.repo.ListExpired(s.ctx)
|
||||
s.Require().NoError(err, "ListExpired")
|
||||
s.Require().Len(expired, 1)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "batch@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-batch"})
|
||||
|
||||
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
expiredActive := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
affected, err := s.repo.BatchUpdateExpiredStatus(s.ctx)
|
||||
s.Require().NoError(err, "BatchUpdateExpiredStatus")
|
||||
s.Require().Equal(int64(1), affected)
|
||||
|
||||
gotActive, _ := s.repo.GetByID(s.ctx, active.ID)
|
||||
s.Require().Equal(model.SubscriptionStatusActive, gotActive.Status)
|
||||
|
||||
gotExpired, _ := s.repo.GetByID(s.ctx, expiredActive.ID)
|
||||
s.Require().Equal(model.SubscriptionStatusExpired, gotExpired.Status)
|
||||
}
|
||||
|
||||
// --- ExistsByUserIDAndGroupID ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-exists"})
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
|
||||
exists, err := s.repo.ExistsByUserIDAndGroupID(s.ctx, user.ID, group.ID)
|
||||
s.Require().NoError(err, "ExistsByUserIDAndGroupID")
|
||||
s.Require().True(exists)
|
||||
|
||||
notExists, err := s.repo.ExistsByUserIDAndGroupID(s.ctx, user.ID, 999999)
|
||||
s.Require().NoError(err)
|
||||
s.Require().False(notExists)
|
||||
}
|
||||
|
||||
// --- CountByGroupID / CountActiveByGroupID ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestCountByGroupID() {
|
||||
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "cnt1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "cnt2@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"})
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user1.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user2.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusExpired,
|
||||
ExpiresAt: time.Now().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "CountByGroupID")
|
||||
s.Require().Equal(int64(2), count)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() {
|
||||
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "cntact1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "cntact2@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-cntact"})
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user1.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user2.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(-24 * time.Hour), // expired by time
|
||||
})
|
||||
|
||||
count, err := s.repo.CountActiveByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "CountActiveByGroupID")
|
||||
s.Require().Equal(int64(1), count, "only future expiry counts as active")
|
||||
}
|
||||
|
||||
// --- DeleteByGroupID ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delgrp@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-delgrp"})
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusExpired,
|
||||
ExpiresAt: time.Now().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
affected, err := s.repo.DeleteByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "DeleteByGroupID")
|
||||
s.Require().Equal(int64(2), affected)
|
||||
|
||||
count, _ := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
s.Require().Zero(count)
|
||||
}
|
||||
|
||||
// --- Combined original test ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_BatchUpdateExpiredStatus() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "subr@example.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-subr"})
|
||||
|
||||
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(2 * time.Hour),
|
||||
})
|
||||
expiredActive := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(-2 * time.Hour),
|
||||
})
|
||||
|
||||
got, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, group.ID)
|
||||
s.Require().NoError(err, "GetActiveByUserIDAndGroupID")
|
||||
s.Require().Equal(active.ID, got.ID, "expected active subscription")
|
||||
|
||||
activateAt := time.Now().Add(-25 * time.Hour)
|
||||
s.Require().NoError(s.repo.ActivateWindows(s.ctx, active.ID, activateAt), "ActivateWindows")
|
||||
s.Require().NoError(s.repo.IncrementUsage(s.ctx, active.ID, 1.25), "IncrementUsage")
|
||||
|
||||
after, err := s.repo.GetByID(s.ctx, active.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal(1.25, after.DailyUsageUSD, "DailyUsageUSD mismatch")
|
||||
s.Require().Equal(1.25, after.WeeklyUsageUSD, "WeeklyUsageUSD mismatch")
|
||||
s.Require().Equal(1.25, after.MonthlyUsageUSD, "MonthlyUsageUSD mismatch")
|
||||
s.Require().NotNil(after.DailyWindowStart, "expected DailyWindowStart activated")
|
||||
s.Require().NotNil(after.WeeklyWindowStart, "expected WeeklyWindowStart activated")
|
||||
s.Require().NotNil(after.MonthlyWindowStart, "expected MonthlyWindowStart activated")
|
||||
|
||||
resetAt := time.Now().Truncate(time.Microsecond) // truncate to microsecond for DB precision
|
||||
s.Require().NoError(s.repo.ResetDailyUsage(s.ctx, active.ID, resetAt), "ResetDailyUsage")
|
||||
afterReset, err := s.repo.GetByID(s.ctx, active.ID)
|
||||
s.Require().NoError(err, "GetByID after reset")
|
||||
s.Require().Equal(0.0, afterReset.DailyUsageUSD, "expected daily usage reset to 0")
|
||||
s.Require().NotNil(afterReset.DailyWindowStart, "expected DailyWindowStart not nil")
|
||||
s.Require().True(afterReset.DailyWindowStart.Equal(resetAt), "expected daily window start updated")
|
||||
|
||||
affected, err := s.repo.BatchUpdateExpiredStatus(s.ctx)
|
||||
s.Require().NoError(err, "BatchUpdateExpiredStatus")
|
||||
s.Require().Equal(int64(1), affected, "expected 1 affected row")
|
||||
updated, err := s.repo.GetByID(s.ctx, expiredActive.ID)
|
||||
s.Require().NoError(err, "GetByID expired")
|
||||
s.Require().Equal(model.SubscriptionStatusExpired, updated.Status, "expected status expired")
|
||||
}
|
||||
@@ -180,6 +180,7 @@ func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, rep
|
||||
accounts.GET("", h.Admin.Account.List)
|
||||
accounts.GET("/:id", h.Admin.Account.GetByID)
|
||||
accounts.POST("", h.Admin.Account.Create)
|
||||
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
|
||||
accounts.PUT("/:id", h.Admin.Account.Update)
|
||||
accounts.DELETE("/:id", h.Admin.Account.Delete)
|
||||
accounts.POST("/:id/test", h.Admin.Account.Test)
|
||||
@@ -192,6 +193,8 @@ func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, rep
|
||||
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
|
||||
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
|
||||
accounts.POST("/batch", h.Admin.Account.BatchCreate)
|
||||
accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials)
|
||||
accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate)
|
||||
|
||||
// Claude OAuth routes
|
||||
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
|
||||
|
||||
@@ -388,12 +388,14 @@ func createOpenAITestPayload(modelID string, isOAuth bool) map[string]any {
|
||||
"stream": true,
|
||||
}
|
||||
|
||||
// OAuth accounts using ChatGPT internal API require store: false and instructions
|
||||
// OAuth accounts using ChatGPT internal API require store: false
|
||||
if isOAuth {
|
||||
payload["store"] = false
|
||||
payload["instructions"] = openai.DefaultInstructions
|
||||
}
|
||||
|
||||
// All accounts require instructions for Responses API
|
||||
payload["instructions"] = openai.DefaultInstructions
|
||||
|
||||
return payload
|
||||
}
|
||||
|
||||
|
||||
@@ -45,6 +45,7 @@ type AdminService interface {
|
||||
RefreshAccountCredentials(ctx context.Context, id int64) (*model.Account, error)
|
||||
ClearAccountError(ctx context.Context, id int64) (*model.Account, error)
|
||||
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*model.Account, error)
|
||||
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
|
||||
|
||||
// Proxy management
|
||||
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error)
|
||||
@@ -140,6 +141,33 @@ type UpdateAccountInput struct {
|
||||
GroupIDs *[]int64
|
||||
}
|
||||
|
||||
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
|
||||
type BulkUpdateAccountsInput struct {
|
||||
AccountIDs []int64
|
||||
Name string
|
||||
ProxyID *int64
|
||||
Concurrency *int
|
||||
Priority *int
|
||||
Status string
|
||||
GroupIDs *[]int64
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
}
|
||||
|
||||
// BulkUpdateAccountResult captures the result for a single account update.
|
||||
type BulkUpdateAccountResult struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
Success bool `json:"success"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
|
||||
type BulkUpdateAccountsResult struct {
|
||||
Success int `json:"success"`
|
||||
Failed int `json:"failed"`
|
||||
Results []BulkUpdateAccountResult `json:"results"`
|
||||
}
|
||||
|
||||
type CreateProxyInput struct {
|
||||
Name string
|
||||
Protocol string
|
||||
@@ -694,6 +722,65 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// BulkUpdateAccounts updates multiple accounts in one request.
|
||||
// It merges credentials/extra keys instead of overwriting the whole object.
|
||||
func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) {
|
||||
result := &BulkUpdateAccountsResult{
|
||||
Results: make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)),
|
||||
}
|
||||
|
||||
if len(input.AccountIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Prepare bulk updates for columns and JSONB fields.
|
||||
repoUpdates := ports.AccountBulkUpdate{
|
||||
Credentials: input.Credentials,
|
||||
Extra: input.Extra,
|
||||
}
|
||||
if input.Name != "" {
|
||||
repoUpdates.Name = &input.Name
|
||||
}
|
||||
if input.ProxyID != nil {
|
||||
repoUpdates.ProxyID = input.ProxyID
|
||||
}
|
||||
if input.Concurrency != nil {
|
||||
repoUpdates.Concurrency = input.Concurrency
|
||||
}
|
||||
if input.Priority != nil {
|
||||
repoUpdates.Priority = input.Priority
|
||||
}
|
||||
if input.Status != "" {
|
||||
repoUpdates.Status = &input.Status
|
||||
}
|
||||
|
||||
// Run bulk update for column/jsonb fields first.
|
||||
if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Handle group bindings per account (requires individual operations).
|
||||
for _, accountID := range input.AccountIDs {
|
||||
entry := BulkUpdateAccountResult{AccountID: accountID}
|
||||
|
||||
if input.GroupIDs != nil {
|
||||
if err := s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs); err != nil {
|
||||
entry.Success = false
|
||||
entry.Error = err.Error()
|
||||
result.Failed++
|
||||
result.Results = append(result.Results, entry)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
entry.Success = true
|
||||
result.Success++
|
||||
result.Results = append(result.Results, entry)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
|
||||
return s.accountRepo.Delete(ctx, id)
|
||||
}
|
||||
|
||||
961
backend/internal/service/crs_sync_service.go
Normal file
961
backend/internal/service/crs_sync_service.go
Normal file
@@ -0,0 +1,961 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
type CRSSyncService struct {
|
||||
accountRepo ports.AccountRepository
|
||||
proxyRepo ports.ProxyRepository
|
||||
oauthService *OAuthService
|
||||
openaiOAuthService *OpenAIOAuthService
|
||||
}
|
||||
|
||||
func NewCRSSyncService(
|
||||
accountRepo ports.AccountRepository,
|
||||
proxyRepo ports.ProxyRepository,
|
||||
oauthService *OAuthService,
|
||||
openaiOAuthService *OpenAIOAuthService,
|
||||
) *CRSSyncService {
|
||||
return &CRSSyncService{
|
||||
accountRepo: accountRepo,
|
||||
proxyRepo: proxyRepo,
|
||||
oauthService: oauthService,
|
||||
openaiOAuthService: openaiOAuthService,
|
||||
}
|
||||
}
|
||||
|
||||
type SyncFromCRSInput struct {
|
||||
BaseURL string
|
||||
Username string
|
||||
Password string
|
||||
SyncProxies bool
|
||||
}
|
||||
|
||||
type SyncFromCRSItemResult struct {
|
||||
CRSAccountID string `json:"crs_account_id"`
|
||||
Kind string `json:"kind"`
|
||||
Name string `json:"name"`
|
||||
Action string `json:"action"` // created/updated/failed/skipped
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type SyncFromCRSResult struct {
|
||||
Created int `json:"created"`
|
||||
Updated int `json:"updated"`
|
||||
Skipped int `json:"skipped"`
|
||||
Failed int `json:"failed"`
|
||||
Items []SyncFromCRSItemResult `json:"items"`
|
||||
}
|
||||
|
||||
type crsLoginResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Token string `json:"token"`
|
||||
Message string `json:"message"`
|
||||
Error string `json:"error"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
type crsExportResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Error string `json:"error"`
|
||||
Message string `json:"message"`
|
||||
Data struct {
|
||||
ExportedAt string `json:"exportedAt"`
|
||||
ClaudeAccounts []crsClaudeAccount `json:"claudeAccounts"`
|
||||
ClaudeConsoleAccounts []crsConsoleAccount `json:"claudeConsoleAccounts"`
|
||||
OpenAIOAuthAccounts []crsOpenAIOAuthAccount `json:"openaiOAuthAccounts"`
|
||||
OpenAIResponsesAccounts []crsOpenAIResponsesAccount `json:"openaiResponsesAccounts"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
type crsProxy struct {
|
||||
Protocol string `json:"protocol"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type crsClaudeAccount struct {
|
||||
Kind string `json:"kind"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform"`
|
||||
AuthType string `json:"authType"` // oauth/setup-token
|
||||
IsActive bool `json:"isActive"`
|
||||
Schedulable bool `json:"schedulable"`
|
||||
Priority int `json:"priority"`
|
||||
Status string `json:"status"`
|
||||
Proxy *crsProxy `json:"proxy"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
}
|
||||
|
||||
type crsConsoleAccount struct {
|
||||
Kind string `json:"kind"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform"`
|
||||
IsActive bool `json:"isActive"`
|
||||
Schedulable bool `json:"schedulable"`
|
||||
Priority int `json:"priority"`
|
||||
Status string `json:"status"`
|
||||
MaxConcurrentTasks int `json:"maxConcurrentTasks"`
|
||||
Proxy *crsProxy `json:"proxy"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
}
|
||||
|
||||
type crsOpenAIResponsesAccount struct {
|
||||
Kind string `json:"kind"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform"`
|
||||
IsActive bool `json:"isActive"`
|
||||
Schedulable bool `json:"schedulable"`
|
||||
Priority int `json:"priority"`
|
||||
Status string `json:"status"`
|
||||
Proxy *crsProxy `json:"proxy"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
}
|
||||
|
||||
type crsOpenAIOAuthAccount struct {
|
||||
Kind string `json:"kind"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform"`
|
||||
AuthType string `json:"authType"` // oauth
|
||||
IsActive bool `json:"isActive"`
|
||||
Schedulable bool `json:"schedulable"`
|
||||
Priority int `json:"priority"`
|
||||
Status string `json:"status"`
|
||||
Proxy *crsProxy `json:"proxy"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
}
|
||||
|
||||
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
|
||||
baseURL, err := normalizeBaseURL(input.BaseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(input.Username) == "" || strings.TrimSpace(input.Password) == "" {
|
||||
return nil, errors.New("username and password are required")
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 20 * time.Second}
|
||||
|
||||
adminToken, err := crsLogin(ctx, client, baseURL, input.Username, input.Password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
exported, err := crsExportAccounts(ctx, client, baseURL, adminToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
|
||||
result := &SyncFromCRSResult{
|
||||
Items: make(
|
||||
[]SyncFromCRSItemResult,
|
||||
0,
|
||||
len(exported.Data.ClaudeAccounts)+len(exported.Data.ClaudeConsoleAccounts)+len(exported.Data.OpenAIOAuthAccounts)+len(exported.Data.OpenAIResponsesAccounts),
|
||||
),
|
||||
}
|
||||
|
||||
var proxies []model.Proxy
|
||||
if input.SyncProxies {
|
||||
proxies, _ = s.proxyRepo.ListActive(ctx)
|
||||
}
|
||||
|
||||
// Claude OAuth / Setup Token -> sub2api anthropic oauth/setup-token
|
||||
for _, src := range exported.Data.ClaudeAccounts {
|
||||
item := SyncFromCRSItemResult{
|
||||
CRSAccountID: src.ID,
|
||||
Kind: src.Kind,
|
||||
Name: src.Name,
|
||||
}
|
||||
|
||||
targetType := strings.TrimSpace(src.AuthType)
|
||||
if targetType == "" {
|
||||
targetType = "oauth"
|
||||
}
|
||||
if targetType != model.AccountTypeOAuth && targetType != model.AccountTypeSetupToken {
|
||||
item.Action = "skipped"
|
||||
item.Error = "unsupported authType: " + targetType
|
||||
result.Skipped++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
accessToken, _ := src.Credentials["access_token"].(string)
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
item.Action = "failed"
|
||||
item.Error = "missing access_token"
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
proxyID, err := s.mapOrCreateProxy(ctx, input.SyncProxies, &proxies, src.Proxy, fmt.Sprintf("crs-%s", src.Name))
|
||||
if err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "proxy sync failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
credentials := sanitizeCredentialsMap(src.Credentials)
|
||||
// 🔧 Remove /v1 suffix from base_url for Claude accounts
|
||||
cleanBaseURL(credentials, "/v1")
|
||||
// 🔧 Convert expires_at from ISO string to Unix timestamp
|
||||
if expiresAtStr, ok := credentials["expires_at"].(string); ok && expiresAtStr != "" {
|
||||
if t, err := time.Parse(time.RFC3339, expiresAtStr); err == nil {
|
||||
credentials["expires_at"] = t.Unix()
|
||||
}
|
||||
}
|
||||
// 🔧 Add intercept_warmup_requests if not present (defaults to false)
|
||||
if _, exists := credentials["intercept_warmup_requests"]; !exists {
|
||||
credentials["intercept_warmup_requests"] = false
|
||||
}
|
||||
priority := clampPriority(src.Priority)
|
||||
concurrency := 3
|
||||
status := mapCRSStatus(src.IsActive, src.Status)
|
||||
|
||||
// 🔧 Preserve all CRS extra fields and add sync metadata
|
||||
extra := make(map[string]any)
|
||||
if src.Extra != nil {
|
||||
for k, v := range src.Extra {
|
||||
extra[k] = v
|
||||
}
|
||||
}
|
||||
extra["crs_account_id"] = src.ID
|
||||
extra["crs_kind"] = src.Kind
|
||||
extra["crs_synced_at"] = now
|
||||
// Extract org_uuid and account_uuid from CRS credentials to extra
|
||||
if orgUUID, ok := src.Credentials["org_uuid"]; ok {
|
||||
extra["org_uuid"] = orgUUID
|
||||
}
|
||||
if accountUUID, ok := src.Credentials["account_uuid"]; ok {
|
||||
extra["account_uuid"] = accountUUID
|
||||
}
|
||||
|
||||
existing, err := s.accountRepo.GetByCRSAccountID(ctx, src.ID)
|
||||
if err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "db lookup failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
if existing == nil {
|
||||
account := &model.Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: model.PlatformAnthropic,
|
||||
Type: targetType,
|
||||
Credentials: model.JSONB(credentials),
|
||||
Extra: model.JSONB(extra),
|
||||
ProxyID: proxyID,
|
||||
Concurrency: concurrency,
|
||||
Priority: priority,
|
||||
Status: status,
|
||||
Schedulable: src.Schedulable,
|
||||
}
|
||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "create failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
// 🔄 Refresh OAuth token after creation
|
||||
if targetType == model.AccountTypeOAuth {
|
||||
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
|
||||
account.Credentials = refreshedCreds
|
||||
_ = s.accountRepo.Update(ctx, account)
|
||||
}
|
||||
}
|
||||
item.Action = "created"
|
||||
result.Created++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
// Update existing
|
||||
existing.Extra = mergeJSONB(existing.Extra, extra)
|
||||
existing.Name = defaultName(src.Name, src.ID)
|
||||
existing.Platform = model.PlatformAnthropic
|
||||
existing.Type = targetType
|
||||
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
|
||||
if proxyID != nil {
|
||||
existing.ProxyID = proxyID
|
||||
}
|
||||
existing.Concurrency = concurrency
|
||||
existing.Priority = priority
|
||||
existing.Status = status
|
||||
existing.Schedulable = src.Schedulable
|
||||
|
||||
if err := s.accountRepo.Update(ctx, existing); err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "update failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
// 🔄 Refresh OAuth token after update
|
||||
if targetType == model.AccountTypeOAuth {
|
||||
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
|
||||
existing.Credentials = refreshedCreds
|
||||
_ = s.accountRepo.Update(ctx, existing)
|
||||
}
|
||||
}
|
||||
|
||||
item.Action = "updated"
|
||||
result.Updated++
|
||||
result.Items = append(result.Items, item)
|
||||
}
|
||||
|
||||
// Claude Console API Key -> sub2api anthropic apikey
|
||||
for _, src := range exported.Data.ClaudeConsoleAccounts {
|
||||
item := SyncFromCRSItemResult{
|
||||
CRSAccountID: src.ID,
|
||||
Kind: src.Kind,
|
||||
Name: src.Name,
|
||||
}
|
||||
|
||||
apiKey, _ := src.Credentials["api_key"].(string)
|
||||
if strings.TrimSpace(apiKey) == "" {
|
||||
item.Action = "failed"
|
||||
item.Error = "missing api_key"
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
proxyID, err := s.mapOrCreateProxy(ctx, input.SyncProxies, &proxies, src.Proxy, fmt.Sprintf("crs-%s", src.Name))
|
||||
if err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "proxy sync failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
credentials := sanitizeCredentialsMap(src.Credentials)
|
||||
priority := clampPriority(src.Priority)
|
||||
concurrency := 3
|
||||
if src.MaxConcurrentTasks > 0 {
|
||||
concurrency = src.MaxConcurrentTasks
|
||||
}
|
||||
status := mapCRSStatus(src.IsActive, src.Status)
|
||||
|
||||
extra := map[string]any{
|
||||
"crs_account_id": src.ID,
|
||||
"crs_kind": src.Kind,
|
||||
"crs_synced_at": now,
|
||||
}
|
||||
|
||||
existing, err := s.accountRepo.GetByCRSAccountID(ctx, src.ID)
|
||||
if err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "db lookup failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
if existing == nil {
|
||||
account := &model.Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: model.PlatformAnthropic,
|
||||
Type: model.AccountTypeApiKey,
|
||||
Credentials: model.JSONB(credentials),
|
||||
Extra: model.JSONB(extra),
|
||||
ProxyID: proxyID,
|
||||
Concurrency: concurrency,
|
||||
Priority: priority,
|
||||
Status: status,
|
||||
Schedulable: src.Schedulable,
|
||||
}
|
||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "create failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
item.Action = "created"
|
||||
result.Created++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
existing.Extra = mergeJSONB(existing.Extra, extra)
|
||||
existing.Name = defaultName(src.Name, src.ID)
|
||||
existing.Platform = model.PlatformAnthropic
|
||||
existing.Type = model.AccountTypeApiKey
|
||||
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
|
||||
if proxyID != nil {
|
||||
existing.ProxyID = proxyID
|
||||
}
|
||||
existing.Concurrency = concurrency
|
||||
existing.Priority = priority
|
||||
existing.Status = status
|
||||
existing.Schedulable = src.Schedulable
|
||||
|
||||
if err := s.accountRepo.Update(ctx, existing); err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "update failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
item.Action = "updated"
|
||||
result.Updated++
|
||||
result.Items = append(result.Items, item)
|
||||
}
|
||||
|
||||
// OpenAI OAuth -> sub2api openai oauth
|
||||
for _, src := range exported.Data.OpenAIOAuthAccounts {
|
||||
item := SyncFromCRSItemResult{
|
||||
CRSAccountID: src.ID,
|
||||
Kind: src.Kind,
|
||||
Name: src.Name,
|
||||
}
|
||||
|
||||
accessToken, _ := src.Credentials["access_token"].(string)
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
item.Action = "failed"
|
||||
item.Error = "missing access_token"
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
proxyID, err := s.mapOrCreateProxy(
|
||||
ctx,
|
||||
input.SyncProxies,
|
||||
&proxies,
|
||||
src.Proxy,
|
||||
fmt.Sprintf("crs-%s", src.Name),
|
||||
)
|
||||
if err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "proxy sync failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
credentials := sanitizeCredentialsMap(src.Credentials)
|
||||
// Normalize token_type
|
||||
if v, ok := credentials["token_type"].(string); !ok || strings.TrimSpace(v) == "" {
|
||||
credentials["token_type"] = "Bearer"
|
||||
}
|
||||
// 🔧 Convert expires_at from ISO string to Unix timestamp
|
||||
if expiresAtStr, ok := credentials["expires_at"].(string); ok && expiresAtStr != "" {
|
||||
if t, err := time.Parse(time.RFC3339, expiresAtStr); err == nil {
|
||||
credentials["expires_at"] = t.Unix()
|
||||
}
|
||||
}
|
||||
priority := clampPriority(src.Priority)
|
||||
concurrency := 3
|
||||
status := mapCRSStatus(src.IsActive, src.Status)
|
||||
|
||||
// 🔧 Preserve all CRS extra fields and add sync metadata
|
||||
extra := make(map[string]any)
|
||||
if src.Extra != nil {
|
||||
for k, v := range src.Extra {
|
||||
extra[k] = v
|
||||
}
|
||||
}
|
||||
extra["crs_account_id"] = src.ID
|
||||
extra["crs_kind"] = src.Kind
|
||||
extra["crs_synced_at"] = now
|
||||
// Extract email from CRS extra (crs_email -> email)
|
||||
if crsEmail, ok := src.Extra["crs_email"]; ok {
|
||||
extra["email"] = crsEmail
|
||||
}
|
||||
|
||||
existing, err := s.accountRepo.GetByCRSAccountID(ctx, src.ID)
|
||||
if err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "db lookup failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
if existing == nil {
|
||||
account := &model.Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: model.PlatformOpenAI,
|
||||
Type: model.AccountTypeOAuth,
|
||||
Credentials: model.JSONB(credentials),
|
||||
Extra: model.JSONB(extra),
|
||||
ProxyID: proxyID,
|
||||
Concurrency: concurrency,
|
||||
Priority: priority,
|
||||
Status: status,
|
||||
Schedulable: src.Schedulable,
|
||||
}
|
||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "create failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
// 🔄 Refresh OAuth token after creation
|
||||
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
|
||||
account.Credentials = refreshedCreds
|
||||
_ = s.accountRepo.Update(ctx, account)
|
||||
}
|
||||
item.Action = "created"
|
||||
result.Created++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
existing.Extra = mergeJSONB(existing.Extra, extra)
|
||||
existing.Name = defaultName(src.Name, src.ID)
|
||||
existing.Platform = model.PlatformOpenAI
|
||||
existing.Type = model.AccountTypeOAuth
|
||||
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
|
||||
if proxyID != nil {
|
||||
existing.ProxyID = proxyID
|
||||
}
|
||||
existing.Concurrency = concurrency
|
||||
existing.Priority = priority
|
||||
existing.Status = status
|
||||
existing.Schedulable = src.Schedulable
|
||||
|
||||
if err := s.accountRepo.Update(ctx, existing); err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "update failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
// 🔄 Refresh OAuth token after update
|
||||
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
|
||||
existing.Credentials = refreshedCreds
|
||||
_ = s.accountRepo.Update(ctx, existing)
|
||||
}
|
||||
|
||||
item.Action = "updated"
|
||||
result.Updated++
|
||||
result.Items = append(result.Items, item)
|
||||
}
|
||||
|
||||
// OpenAI Responses API Key -> sub2api openai apikey
|
||||
for _, src := range exported.Data.OpenAIResponsesAccounts {
|
||||
item := SyncFromCRSItemResult{
|
||||
CRSAccountID: src.ID,
|
||||
Kind: src.Kind,
|
||||
Name: src.Name,
|
||||
}
|
||||
|
||||
apiKey, _ := src.Credentials["api_key"].(string)
|
||||
if strings.TrimSpace(apiKey) == "" {
|
||||
item.Action = "failed"
|
||||
item.Error = "missing api_key"
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
if baseURL, ok := src.Credentials["base_url"].(string); !ok || strings.TrimSpace(baseURL) == "" {
|
||||
src.Credentials["base_url"] = "https://api.openai.com"
|
||||
}
|
||||
// 🔧 Remove /v1 suffix from base_url for OpenAI accounts
|
||||
cleanBaseURL(src.Credentials, "/v1")
|
||||
|
||||
proxyID, err := s.mapOrCreateProxy(
|
||||
ctx,
|
||||
input.SyncProxies,
|
||||
&proxies,
|
||||
src.Proxy,
|
||||
fmt.Sprintf("crs-%s", src.Name),
|
||||
)
|
||||
if err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "proxy sync failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
credentials := sanitizeCredentialsMap(src.Credentials)
|
||||
priority := clampPriority(src.Priority)
|
||||
concurrency := 3
|
||||
status := mapCRSStatus(src.IsActive, src.Status)
|
||||
|
||||
extra := map[string]any{
|
||||
"crs_account_id": src.ID,
|
||||
"crs_kind": src.Kind,
|
||||
"crs_synced_at": now,
|
||||
}
|
||||
|
||||
existing, err := s.accountRepo.GetByCRSAccountID(ctx, src.ID)
|
||||
if err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "db lookup failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
if existing == nil {
|
||||
account := &model.Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: model.PlatformOpenAI,
|
||||
Type: model.AccountTypeApiKey,
|
||||
Credentials: model.JSONB(credentials),
|
||||
Extra: model.JSONB(extra),
|
||||
ProxyID: proxyID,
|
||||
Concurrency: concurrency,
|
||||
Priority: priority,
|
||||
Status: status,
|
||||
Schedulable: src.Schedulable,
|
||||
}
|
||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "create failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
item.Action = "created"
|
||||
result.Created++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
existing.Extra = mergeJSONB(existing.Extra, extra)
|
||||
existing.Name = defaultName(src.Name, src.ID)
|
||||
existing.Platform = model.PlatformOpenAI
|
||||
existing.Type = model.AccountTypeApiKey
|
||||
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
|
||||
if proxyID != nil {
|
||||
existing.ProxyID = proxyID
|
||||
}
|
||||
existing.Concurrency = concurrency
|
||||
existing.Priority = priority
|
||||
existing.Status = status
|
||||
existing.Schedulable = src.Schedulable
|
||||
|
||||
if err := s.accountRepo.Update(ctx, existing); err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "update failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
item.Action = "updated"
|
||||
result.Updated++
|
||||
result.Items = append(result.Items, item)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// mergeJSONB merges two JSONB maps without removing keys that are absent in updates.
|
||||
func mergeJSONB(existing model.JSONB, updates map[string]any) model.JSONB {
|
||||
out := make(model.JSONB)
|
||||
for k, v := range existing {
|
||||
out[k] = v
|
||||
}
|
||||
for k, v := range updates {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *CRSSyncService) mapOrCreateProxy(ctx context.Context, enabled bool, cached *[]model.Proxy, src *crsProxy, defaultName string) (*int64, error) {
|
||||
if !enabled || src == nil {
|
||||
return nil, nil
|
||||
}
|
||||
protocol := strings.ToLower(strings.TrimSpace(src.Protocol))
|
||||
switch protocol {
|
||||
case "socks":
|
||||
protocol = "socks5"
|
||||
case "socks5h":
|
||||
protocol = "socks5"
|
||||
}
|
||||
host := strings.TrimSpace(src.Host)
|
||||
port := src.Port
|
||||
username := strings.TrimSpace(src.Username)
|
||||
password := strings.TrimSpace(src.Password)
|
||||
|
||||
if protocol == "" || host == "" || port <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if protocol != "http" && protocol != "https" && protocol != "socks5" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Find existing proxy (active only).
|
||||
for _, p := range *cached {
|
||||
if strings.EqualFold(p.Protocol, protocol) &&
|
||||
p.Host == host &&
|
||||
p.Port == port &&
|
||||
p.Username == username &&
|
||||
p.Password == password {
|
||||
id := p.ID
|
||||
return &id, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Create new proxy
|
||||
proxy := &model.Proxy{
|
||||
Name: defaultProxyName(defaultName, protocol, host, port),
|
||||
Protocol: protocol,
|
||||
Host: host,
|
||||
Port: port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Status: model.StatusActive,
|
||||
}
|
||||
if err := s.proxyRepo.Create(ctx, proxy); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
*cached = append(*cached, *proxy)
|
||||
id := proxy.ID
|
||||
return &id, nil
|
||||
}
|
||||
|
||||
func defaultProxyName(base, protocol, host string, port int) string {
|
||||
base = strings.TrimSpace(base)
|
||||
if base == "" {
|
||||
base = "crs"
|
||||
}
|
||||
return fmt.Sprintf("%s (%s://%s:%d)", base, protocol, host, port)
|
||||
}
|
||||
|
||||
func defaultName(name, id string) string {
|
||||
if strings.TrimSpace(name) != "" {
|
||||
return strings.TrimSpace(name)
|
||||
}
|
||||
return "CRS " + id
|
||||
}
|
||||
|
||||
func clampPriority(priority int) int {
|
||||
if priority < 1 || priority > 100 {
|
||||
return 50
|
||||
}
|
||||
return priority
|
||||
}
|
||||
|
||||
func sanitizeCredentialsMap(input map[string]any) map[string]any {
|
||||
if input == nil {
|
||||
return map[string]any{}
|
||||
}
|
||||
out := make(map[string]any, len(input))
|
||||
for k, v := range input {
|
||||
// Avoid nil values to keep JSONB cleaner
|
||||
if v != nil {
|
||||
out[k] = v
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func mapCRSStatus(isActive bool, status string) string {
|
||||
if !isActive {
|
||||
return "inactive"
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(status), "error") {
|
||||
return "error"
|
||||
}
|
||||
return "active"
|
||||
}
|
||||
|
||||
func normalizeBaseURL(raw string) (string, error) {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return "", errors.New("base_url is required")
|
||||
}
|
||||
u, err := url.Parse(trimmed)
|
||||
if err != nil || u.Scheme == "" || u.Host == "" {
|
||||
return "", fmt.Errorf("invalid base_url: %s", trimmed)
|
||||
}
|
||||
u.Path = strings.TrimRight(u.Path, "/")
|
||||
return strings.TrimRight(u.String(), "/"), nil
|
||||
}
|
||||
|
||||
// cleanBaseURL removes trailing suffix from base_url in credentials
|
||||
// Used for both Claude and OpenAI accounts to remove /v1
|
||||
func cleanBaseURL(credentials map[string]any, suffixToRemove string) {
|
||||
if baseURL, ok := credentials["base_url"].(string); ok && baseURL != "" {
|
||||
trimmed := strings.TrimSpace(baseURL)
|
||||
if strings.HasSuffix(trimmed, suffixToRemove) {
|
||||
credentials["base_url"] = strings.TrimSuffix(trimmed, suffixToRemove)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func crsLogin(ctx context.Context, client *http.Client, baseURL, username, password string) (string, error) {
|
||||
payload := map[string]any{
|
||||
"username": username,
|
||||
"password": password,
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/web/auth/login", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
raw, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return "", fmt.Errorf("crs login failed: status=%d body=%s", resp.StatusCode, string(raw))
|
||||
}
|
||||
|
||||
var parsed crsLoginResponse
|
||||
if err := json.Unmarshal(raw, &parsed); err != nil {
|
||||
return "", fmt.Errorf("crs login parse failed: %w", err)
|
||||
}
|
||||
if !parsed.Success || strings.TrimSpace(parsed.Token) == "" {
|
||||
msg := parsed.Message
|
||||
if msg == "" {
|
||||
msg = parsed.Error
|
||||
}
|
||||
if msg == "" {
|
||||
msg = "unknown error"
|
||||
}
|
||||
return "", errors.New("crs login failed: " + msg)
|
||||
}
|
||||
return parsed.Token, nil
|
||||
}
|
||||
|
||||
func crsExportAccounts(ctx context.Context, client *http.Client, baseURL, adminToken string) (*crsExportResponse, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/admin/sync/export-accounts?include_secrets=true", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+adminToken)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
raw, _ := io.ReadAll(io.LimitReader(resp.Body, 5<<20))
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("crs export failed: status=%d body=%s", resp.StatusCode, string(raw))
|
||||
}
|
||||
|
||||
var parsed crsExportResponse
|
||||
if err := json.Unmarshal(raw, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("crs export parse failed: %w", err)
|
||||
}
|
||||
if !parsed.Success {
|
||||
msg := parsed.Message
|
||||
if msg == "" {
|
||||
msg = parsed.Error
|
||||
}
|
||||
if msg == "" {
|
||||
msg = "unknown error"
|
||||
}
|
||||
return nil, errors.New("crs export failed: " + msg)
|
||||
}
|
||||
return &parsed, nil
|
||||
}
|
||||
|
||||
// refreshOAuthToken attempts to refresh OAuth token for a synced account
|
||||
// Returns updated credentials or nil if refresh failed/not applicable
|
||||
func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.Account) model.JSONB {
|
||||
if account.Type != model.AccountTypeOAuth {
|
||||
return nil
|
||||
}
|
||||
|
||||
var newCredentials map[string]any
|
||||
var err error
|
||||
|
||||
switch account.Platform {
|
||||
case model.PlatformAnthropic:
|
||||
if s.oauthService == nil {
|
||||
return nil
|
||||
}
|
||||
tokenInfo, refreshErr := s.oauthService.RefreshAccountToken(ctx, account)
|
||||
if refreshErr != nil {
|
||||
err = refreshErr
|
||||
} else {
|
||||
// Preserve existing credentials
|
||||
newCredentials = make(map[string]any)
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
// Update token fields
|
||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||
newCredentials["token_type"] = tokenInfo.TokenType
|
||||
newCredentials["expires_in"] = tokenInfo.ExpiresIn
|
||||
newCredentials["expires_at"] = tokenInfo.ExpiresAt
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
if tokenInfo.Scope != "" {
|
||||
newCredentials["scope"] = tokenInfo.Scope
|
||||
}
|
||||
}
|
||||
case model.PlatformOpenAI:
|
||||
if s.openaiOAuthService == nil {
|
||||
return nil
|
||||
}
|
||||
tokenInfo, refreshErr := s.openaiOAuthService.RefreshAccountToken(ctx, account)
|
||||
if refreshErr != nil {
|
||||
err = refreshErr
|
||||
} else {
|
||||
newCredentials = s.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
// Preserve non-token settings from existing credentials
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// Log but don't fail the sync - token might still be valid or refreshable later
|
||||
return nil
|
||||
}
|
||||
|
||||
return model.JSONB(newCredentials)
|
||||
}
|
||||
@@ -20,6 +20,8 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -360,8 +362,8 @@ func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Accou
|
||||
|
||||
// 重试相关常量
|
||||
const (
|
||||
maxRetries = 3 // 最大重试次数
|
||||
retryDelay = 2 * time.Second // 重试等待时间
|
||||
maxRetries = 5 // 最大重试次数
|
||||
retryDelay = 6 * time.Second // 重试等待时间
|
||||
)
|
||||
|
||||
// shouldRetryUpstreamError 判断是否应该重试上游错误
|
||||
@@ -390,6 +392,18 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
|
||||
return nil, fmt.Errorf("parse request: %w", err)
|
||||
}
|
||||
|
||||
if !gjson.GetBytes(body, "system").Exists() {
|
||||
body, _ = sjson.SetBytes(body, "system", []any{
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
"cache_control": map[string]string{
|
||||
"type": "ephemeral",
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// 应用模型映射(仅对apikey类型账号)
|
||||
originalModel := req.Model
|
||||
if account.Type == model.AccountTypeApiKey {
|
||||
@@ -622,6 +636,9 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
|
||||
var statusCode int
|
||||
|
||||
switch resp.StatusCode {
|
||||
case 400:
|
||||
c.Data(http.StatusBadRequest, "application/json", body)
|
||||
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
||||
case 401:
|
||||
statusCode = http.StatusBadGateway
|
||||
errType = "upstream_error"
|
||||
|
||||
@@ -827,6 +827,112 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
|
||||
}
|
||||
updates["codex_usage_updated_at"] = snapshot.UpdatedAt
|
||||
|
||||
// Normalize to canonical 5h/7d fields based on window_minutes
|
||||
// This fixes the issue where OpenAI's primary/secondary naming is reversed
|
||||
// Strategy: Compare the two windows and assign the smaller one to 5h, larger one to 7d
|
||||
|
||||
// IMPORTANT: We can only reliably determine window type from window_minutes field
|
||||
// The reset_after_seconds is remaining time, not window size, so it cannot be used for comparison
|
||||
|
||||
var primaryWindowMins, secondaryWindowMins int
|
||||
var hasPrimaryWindow, hasSecondaryWindow bool
|
||||
|
||||
// Only use window_minutes for reliable window size comparison
|
||||
if snapshot.PrimaryWindowMinutes != nil {
|
||||
primaryWindowMins = *snapshot.PrimaryWindowMinutes
|
||||
hasPrimaryWindow = true
|
||||
}
|
||||
|
||||
if snapshot.SecondaryWindowMinutes != nil {
|
||||
secondaryWindowMins = *snapshot.SecondaryWindowMinutes
|
||||
hasSecondaryWindow = true
|
||||
}
|
||||
|
||||
// Determine which is 5h and which is 7d
|
||||
var use5hFromPrimary, use7dFromPrimary bool
|
||||
var use5hFromSecondary, use7dFromSecondary bool
|
||||
|
||||
if hasPrimaryWindow && hasSecondaryWindow {
|
||||
// Both window sizes known: compare and assign smaller to 5h, larger to 7d
|
||||
if primaryWindowMins < secondaryWindowMins {
|
||||
use5hFromPrimary = true
|
||||
use7dFromSecondary = true
|
||||
} else {
|
||||
use5hFromSecondary = true
|
||||
use7dFromPrimary = true
|
||||
}
|
||||
} else if hasPrimaryWindow {
|
||||
// Only primary window size known: classify by absolute threshold
|
||||
if primaryWindowMins <= 360 {
|
||||
use5hFromPrimary = true
|
||||
} else {
|
||||
use7dFromPrimary = true
|
||||
}
|
||||
} else if hasSecondaryWindow {
|
||||
// Only secondary window size known: classify by absolute threshold
|
||||
if secondaryWindowMins <= 360 {
|
||||
use5hFromSecondary = true
|
||||
} else {
|
||||
use7dFromSecondary = true
|
||||
}
|
||||
} else {
|
||||
// No window_minutes available: cannot reliably determine window types
|
||||
// Fall back to legacy assumption (may be incorrect)
|
||||
// Assume primary=7d, secondary=5h based on historical observation
|
||||
if snapshot.SecondaryUsedPercent != nil || snapshot.SecondaryResetAfterSeconds != nil || snapshot.SecondaryWindowMinutes != nil {
|
||||
use5hFromSecondary = true
|
||||
}
|
||||
if snapshot.PrimaryUsedPercent != nil || snapshot.PrimaryResetAfterSeconds != nil || snapshot.PrimaryWindowMinutes != nil {
|
||||
use7dFromPrimary = true
|
||||
}
|
||||
}
|
||||
|
||||
// Write canonical 5h fields
|
||||
if use5hFromPrimary {
|
||||
if snapshot.PrimaryUsedPercent != nil {
|
||||
updates["codex_5h_used_percent"] = *snapshot.PrimaryUsedPercent
|
||||
}
|
||||
if snapshot.PrimaryResetAfterSeconds != nil {
|
||||
updates["codex_5h_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds
|
||||
}
|
||||
if snapshot.PrimaryWindowMinutes != nil {
|
||||
updates["codex_5h_window_minutes"] = *snapshot.PrimaryWindowMinutes
|
||||
}
|
||||
} else if use5hFromSecondary {
|
||||
if snapshot.SecondaryUsedPercent != nil {
|
||||
updates["codex_5h_used_percent"] = *snapshot.SecondaryUsedPercent
|
||||
}
|
||||
if snapshot.SecondaryResetAfterSeconds != nil {
|
||||
updates["codex_5h_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds
|
||||
}
|
||||
if snapshot.SecondaryWindowMinutes != nil {
|
||||
updates["codex_5h_window_minutes"] = *snapshot.SecondaryWindowMinutes
|
||||
}
|
||||
}
|
||||
|
||||
// Write canonical 7d fields
|
||||
if use7dFromPrimary {
|
||||
if snapshot.PrimaryUsedPercent != nil {
|
||||
updates["codex_7d_used_percent"] = *snapshot.PrimaryUsedPercent
|
||||
}
|
||||
if snapshot.PrimaryResetAfterSeconds != nil {
|
||||
updates["codex_7d_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds
|
||||
}
|
||||
if snapshot.PrimaryWindowMinutes != nil {
|
||||
updates["codex_7d_window_minutes"] = *snapshot.PrimaryWindowMinutes
|
||||
}
|
||||
} else if use7dFromSecondary {
|
||||
if snapshot.SecondaryUsedPercent != nil {
|
||||
updates["codex_7d_used_percent"] = *snapshot.SecondaryUsedPercent
|
||||
}
|
||||
if snapshot.SecondaryResetAfterSeconds != nil {
|
||||
updates["codex_7d_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds
|
||||
}
|
||||
if snapshot.SecondaryWindowMinutes != nil {
|
||||
updates["codex_7d_window_minutes"] = *snapshot.SecondaryWindowMinutes
|
||||
}
|
||||
}
|
||||
|
||||
// Update account's Extra field asynchronously
|
||||
go func() {
|
||||
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
|
||||
@@ -11,6 +11,9 @@ import (
|
||||
type AccountRepository interface {
|
||||
Create(ctx context.Context, account *model.Account) error
|
||||
GetByID(ctx context.Context, id int64) (*model.Account, error)
|
||||
// GetByCRSAccountID finds an account previously synced from CRS.
|
||||
// Returns (nil, nil) if not found.
|
||||
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error)
|
||||
Update(ctx context.Context, account *model.Account) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
@@ -35,4 +38,17 @@ type AccountRepository interface {
|
||||
ClearRateLimit(ctx context.Context, id int64) error
|
||||
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
|
||||
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
|
||||
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
|
||||
}
|
||||
|
||||
// AccountBulkUpdate describes the fields that can be updated in a bulk operation.
|
||||
// Nil pointers mean "do not change".
|
||||
type AccountBulkUpdate struct {
|
||||
Name *string
|
||||
ProxyID *int64
|
||||
Concurrency *int
|
||||
Priority *int
|
||||
Status *string
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
}
|
||||
|
||||
@@ -75,6 +75,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewSubscriptionService,
|
||||
NewConcurrencyService,
|
||||
NewIdentityService,
|
||||
NewCRSSyncService,
|
||||
ProvideUpdateService,
|
||||
ProvideTokenRefreshService,
|
||||
|
||||
|
||||
@@ -114,6 +114,8 @@ services:
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
start_period: 10s
|
||||
ports:
|
||||
- 5433:5432
|
||||
|
||||
# ===========================================================================
|
||||
# Redis Cache
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
* Handles AI platform account management for administrators
|
||||
*/
|
||||
|
||||
import { apiClient } from '../client';
|
||||
import { apiClient } from '../client'
|
||||
import type {
|
||||
Account,
|
||||
CreateAccountRequest,
|
||||
@@ -13,7 +13,7 @@ import type {
|
||||
WindowStats,
|
||||
ClaudeModel,
|
||||
AccountUsageStatsResponse,
|
||||
} from '@/types';
|
||||
} from '@/types'
|
||||
|
||||
/**
|
||||
* List all accounts with pagination
|
||||
@@ -26,10 +26,10 @@ export async function list(
|
||||
page: number = 1,
|
||||
pageSize: number = 20,
|
||||
filters?: {
|
||||
platform?: string;
|
||||
type?: string;
|
||||
status?: string;
|
||||
search?: string;
|
||||
platform?: string
|
||||
type?: string
|
||||
status?: string
|
||||
search?: string
|
||||
}
|
||||
): Promise<PaginatedResponse<Account>> {
|
||||
const { data } = await apiClient.get<PaginatedResponse<Account>>('/admin/accounts', {
|
||||
@@ -38,8 +38,8 @@ export async function list(
|
||||
page_size: pageSize,
|
||||
...filters,
|
||||
},
|
||||
});
|
||||
return data;
|
||||
})
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -48,8 +48,8 @@ export async function list(
|
||||
* @returns Account details
|
||||
*/
|
||||
export async function getById(id: number): Promise<Account> {
|
||||
const { data } = await apiClient.get<Account>(`/admin/accounts/${id}`);
|
||||
return data;
|
||||
const { data } = await apiClient.get<Account>(`/admin/accounts/${id}`)
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -58,8 +58,8 @@ export async function getById(id: number): Promise<Account> {
|
||||
* @returns Created account
|
||||
*/
|
||||
export async function create(accountData: CreateAccountRequest): Promise<Account> {
|
||||
const { data } = await apiClient.post<Account>('/admin/accounts', accountData);
|
||||
return data;
|
||||
const { data } = await apiClient.post<Account>('/admin/accounts', accountData)
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -69,8 +69,8 @@ export async function create(accountData: CreateAccountRequest): Promise<Account
|
||||
* @returns Updated account
|
||||
*/
|
||||
export async function update(id: number, updates: UpdateAccountRequest): Promise<Account> {
|
||||
const { data } = await apiClient.put<Account>(`/admin/accounts/${id}`, updates);
|
||||
return data;
|
||||
const { data } = await apiClient.put<Account>(`/admin/accounts/${id}`, updates)
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -79,8 +79,8 @@ export async function update(id: number, updates: UpdateAccountRequest): Promise
|
||||
* @returns Success confirmation
|
||||
*/
|
||||
export async function deleteAccount(id: number): Promise<{ message: string }> {
|
||||
const { data } = await apiClient.delete<{ message: string }>(`/admin/accounts/${id}`);
|
||||
return data;
|
||||
const { data } = await apiClient.delete<{ message: string }>(`/admin/accounts/${id}`)
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -89,11 +89,8 @@ export async function deleteAccount(id: number): Promise<{ message: string }> {
|
||||
* @param status - New status
|
||||
* @returns Updated account
|
||||
*/
|
||||
export async function toggleStatus(
|
||||
id: number,
|
||||
status: 'active' | 'inactive'
|
||||
): Promise<Account> {
|
||||
return update(id, { status });
|
||||
export async function toggleStatus(id: number, status: 'active' | 'inactive'): Promise<Account> {
|
||||
return update(id, { status })
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -102,16 +99,16 @@ export async function toggleStatus(
|
||||
* @returns Test result
|
||||
*/
|
||||
export async function testAccount(id: number): Promise<{
|
||||
success: boolean;
|
||||
message: string;
|
||||
latency_ms?: number;
|
||||
success: boolean
|
||||
message: string
|
||||
latency_ms?: number
|
||||
}> {
|
||||
const { data } = await apiClient.post<{
|
||||
success: boolean;
|
||||
message: string;
|
||||
latency_ms?: number;
|
||||
}>(`/admin/accounts/${id}/test`);
|
||||
return data;
|
||||
success: boolean
|
||||
message: string
|
||||
latency_ms?: number
|
||||
}>(`/admin/accounts/${id}/test`)
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -120,8 +117,8 @@ export async function testAccount(id: number): Promise<{
|
||||
* @returns Updated account
|
||||
*/
|
||||
export async function refreshCredentials(id: number): Promise<Account> {
|
||||
const { data } = await apiClient.post<Account>(`/admin/accounts/${id}/refresh`);
|
||||
return data;
|
||||
const { data } = await apiClient.post<Account>(`/admin/accounts/${id}/refresh`)
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -133,8 +130,8 @@ export async function refreshCredentials(id: number): Promise<Account> {
|
||||
export async function getStats(id: number, days: number = 30): Promise<AccountUsageStatsResponse> {
|
||||
const { data } = await apiClient.get<AccountUsageStatsResponse>(`/admin/accounts/${id}/stats`, {
|
||||
params: { days },
|
||||
});
|
||||
return data;
|
||||
})
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -143,8 +140,8 @@ export async function getStats(id: number, days: number = 30): Promise<AccountUs
|
||||
* @returns Updated account
|
||||
*/
|
||||
export async function clearError(id: number): Promise<Account> {
|
||||
const { data } = await apiClient.post<Account>(`/admin/accounts/${id}/clear-error`);
|
||||
return data;
|
||||
const { data } = await apiClient.post<Account>(`/admin/accounts/${id}/clear-error`)
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -153,8 +150,8 @@ export async function clearError(id: number): Promise<Account> {
|
||||
* @returns Account usage info
|
||||
*/
|
||||
export async function getUsage(id: number): Promise<AccountUsageInfo> {
|
||||
const { data } = await apiClient.get<AccountUsageInfo>(`/admin/accounts/${id}/usage`);
|
||||
return data;
|
||||
const { data } = await apiClient.get<AccountUsageInfo>(`/admin/accounts/${id}/usage`)
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -163,8 +160,10 @@ export async function getUsage(id: number): Promise<AccountUsageInfo> {
|
||||
* @returns Success confirmation
|
||||
*/
|
||||
export async function clearRateLimit(id: number): Promise<{ message: string }> {
|
||||
const { data } = await apiClient.post<{ message: string }>(`/admin/accounts/${id}/clear-rate-limit`);
|
||||
return data;
|
||||
const { data } = await apiClient.post<{ message: string }>(
|
||||
`/admin/accounts/${id}/clear-rate-limit`
|
||||
)
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -177,8 +176,8 @@ export async function generateAuthUrl(
|
||||
endpoint: string,
|
||||
config: { proxy_id?: number }
|
||||
): Promise<{ auth_url: string; session_id: string }> {
|
||||
const { data } = await apiClient.post<{ auth_url: string; session_id: string }>(endpoint, config);
|
||||
return data;
|
||||
const { data } = await apiClient.post<{ auth_url: string; session_id: string }>(endpoint, config)
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -191,8 +190,8 @@ export async function exchangeCode(
|
||||
endpoint: string,
|
||||
exchangeData: { session_id: string; code: string; proxy_id?: number }
|
||||
): Promise<Record<string, unknown>> {
|
||||
const { data } = await apiClient.post<Record<string, unknown>>(endpoint, exchangeData);
|
||||
return data;
|
||||
const { data } = await apiClient.post<Record<string, unknown>>(endpoint, exchangeData)
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -201,16 +200,63 @@ export async function exchangeCode(
|
||||
* @returns Results of batch creation
|
||||
*/
|
||||
export async function batchCreate(accounts: CreateAccountRequest[]): Promise<{
|
||||
success: number;
|
||||
failed: number;
|
||||
results: Array<{ success: boolean; account?: Account; error?: string }>;
|
||||
success: number
|
||||
failed: number
|
||||
results: Array<{ success: boolean; account?: Account; error?: string }>
|
||||
}> {
|
||||
const { data } = await apiClient.post<{
|
||||
success: number;
|
||||
failed: number;
|
||||
results: Array<{ success: boolean; account?: Account; error?: string }>;
|
||||
}>('/admin/accounts/batch', { accounts });
|
||||
return data;
|
||||
success: number
|
||||
failed: number
|
||||
results: Array<{ success: boolean; account?: Account; error?: string }>
|
||||
}>('/admin/accounts/batch', { accounts })
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
* Batch update credentials fields for multiple accounts
|
||||
* @param request - Batch update request containing account IDs, field name, and value
|
||||
* @returns Results of batch update
|
||||
*/
|
||||
export async function batchUpdateCredentials(request: {
|
||||
account_ids: number[]
|
||||
field: string
|
||||
value: any
|
||||
}): Promise<{
|
||||
success: number
|
||||
failed: number
|
||||
results: Array<{ account_id: number; success: boolean; error?: string }>
|
||||
}> {
|
||||
const { data } = await apiClient.post<{
|
||||
success: number
|
||||
failed: number
|
||||
results: Array<{ account_id: number; success: boolean; error?: string }>
|
||||
}>('/admin/accounts/batch-update-credentials', request)
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
* Bulk update multiple accounts
|
||||
* @param accountIds - Array of account IDs
|
||||
* @param updates - Fields to update
|
||||
* @returns Success confirmation
|
||||
*/
|
||||
export async function bulkUpdate(
|
||||
accountIds: number[],
|
||||
updates: Record<string, unknown>
|
||||
): Promise<{
|
||||
success: number
|
||||
failed: number
|
||||
results: Array<{ account_id: number; success: boolean; error?: string }>
|
||||
}> {
|
||||
const { data } = await apiClient.post<{
|
||||
success: number
|
||||
failed: number
|
||||
results: Array<{ account_id: number; success: boolean; error?: string }>
|
||||
}>('/admin/accounts/bulk-update', {
|
||||
account_ids: accountIds,
|
||||
...updates,
|
||||
})
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -219,8 +265,8 @@ export async function batchCreate(accounts: CreateAccountRequest[]): Promise<{
|
||||
* @returns Today's stats (requests, tokens, cost)
|
||||
*/
|
||||
export async function getTodayStats(id: number): Promise<WindowStats> {
|
||||
const { data } = await apiClient.get<WindowStats>(`/admin/accounts/${id}/today-stats`);
|
||||
return data;
|
||||
const { data } = await apiClient.get<WindowStats>(`/admin/accounts/${id}/today-stats`)
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -230,8 +276,10 @@ export async function getTodayStats(id: number): Promise<WindowStats> {
|
||||
* @returns Updated account
|
||||
*/
|
||||
export async function setSchedulable(id: number, schedulable: boolean): Promise<Account> {
|
||||
const { data } = await apiClient.post<Account>(`/admin/accounts/${id}/schedulable`, { schedulable });
|
||||
return data;
|
||||
const { data } = await apiClient.post<Account>(`/admin/accounts/${id}/schedulable`, {
|
||||
schedulable,
|
||||
})
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -240,8 +288,30 @@ export async function setSchedulable(id: number, schedulable: boolean): Promise<
|
||||
* @returns List of available models for this account
|
||||
*/
|
||||
export async function getAvailableModels(id: number): Promise<ClaudeModel[]> {
|
||||
const { data } = await apiClient.get<ClaudeModel[]>(`/admin/accounts/${id}/models`);
|
||||
return data;
|
||||
const { data } = await apiClient.get<ClaudeModel[]>(`/admin/accounts/${id}/models`)
|
||||
return data
|
||||
}
|
||||
|
||||
export async function syncFromCrs(params: {
|
||||
base_url: string
|
||||
username: string
|
||||
password: string
|
||||
sync_proxies?: boolean
|
||||
}): Promise<{
|
||||
created: number
|
||||
updated: number
|
||||
skipped: number
|
||||
failed: number
|
||||
items: Array<{
|
||||
crs_account_id: string
|
||||
kind: string
|
||||
name: string
|
||||
action: string
|
||||
error?: string
|
||||
}>
|
||||
}> {
|
||||
const { data } = await apiClient.post('/admin/accounts/sync/crs', params)
|
||||
return data
|
||||
}
|
||||
|
||||
export const accountsAPI = {
|
||||
@@ -263,6 +333,9 @@ export const accountsAPI = {
|
||||
generateAuthUrl,
|
||||
exchangeCode,
|
||||
batchCreate,
|
||||
};
|
||||
batchUpdateCredentials,
|
||||
bulkUpdate,
|
||||
syncFromCrs,
|
||||
}
|
||||
|
||||
export default accountsAPI;
|
||||
export default accountsAPI
|
||||
|
||||
@@ -69,21 +69,21 @@
|
||||
<!-- OpenAI OAuth accounts: show Codex usage from extra field -->
|
||||
<template v-else-if="account.platform === 'openai' && account.type === 'oauth'">
|
||||
<div v-if="hasCodexUsage" class="space-y-1">
|
||||
<!-- 5h Window (Secondary) -->
|
||||
<!-- 5h Window -->
|
||||
<UsageProgressBar
|
||||
v-if="codexSecondaryUsedPercent !== null"
|
||||
v-if="codex5hUsedPercent !== null"
|
||||
label="5h"
|
||||
:utilization="codexSecondaryUsedPercent"
|
||||
:resets-at="codexSecondaryResetAt"
|
||||
:utilization="codex5hUsedPercent"
|
||||
:resets-at="codex5hResetAt"
|
||||
color="indigo"
|
||||
/>
|
||||
|
||||
<!-- Weekly Window (Primary) -->
|
||||
<!-- 7d Window -->
|
||||
<UsageProgressBar
|
||||
v-if="codexPrimaryUsedPercent !== null"
|
||||
v-if="codex7dUsedPercent !== null"
|
||||
label="7d"
|
||||
:utilization="codexPrimaryUsedPercent"
|
||||
:resets-at="codexPrimaryResetAt"
|
||||
:utilization="codex7dUsedPercent"
|
||||
:resets-at="codex7dResetAt"
|
||||
color="emerald"
|
||||
/>
|
||||
</div>
|
||||
@@ -125,35 +125,123 @@ const showUsageWindows = computed(() =>
|
||||
const hasCodexUsage = computed(() => {
|
||||
const extra = props.account.extra
|
||||
return extra && (
|
||||
// Check for new canonical fields first
|
||||
extra.codex_5h_used_percent !== undefined ||
|
||||
extra.codex_7d_used_percent !== undefined ||
|
||||
// Fallback to legacy fields
|
||||
extra.codex_primary_used_percent !== undefined ||
|
||||
extra.codex_secondary_used_percent !== undefined
|
||||
)
|
||||
})
|
||||
|
||||
const codexPrimaryUsedPercent = computed(() => {
|
||||
// 5h window usage (prefer canonical field)
|
||||
const codex5hUsedPercent = computed(() => {
|
||||
const extra = props.account.extra
|
||||
if (!extra || extra.codex_primary_used_percent === undefined) return null
|
||||
return extra.codex_primary_used_percent
|
||||
if (!extra) return null
|
||||
|
||||
// Prefer canonical field
|
||||
if (extra.codex_5h_used_percent !== undefined) {
|
||||
return extra.codex_5h_used_percent
|
||||
}
|
||||
|
||||
// Fallback: detect from legacy fields using window_minutes
|
||||
if (extra.codex_primary_window_minutes !== undefined && extra.codex_primary_window_minutes <= 360) {
|
||||
return extra.codex_primary_used_percent ?? null
|
||||
}
|
||||
if (extra.codex_secondary_window_minutes !== undefined && extra.codex_secondary_window_minutes <= 360) {
|
||||
return extra.codex_secondary_used_percent ?? null
|
||||
}
|
||||
|
||||
// Legacy assumption: secondary = 5h (may be incorrect)
|
||||
return extra.codex_secondary_used_percent ?? null
|
||||
})
|
||||
|
||||
const codexSecondaryUsedPercent = computed(() => {
|
||||
const codex5hResetAt = computed(() => {
|
||||
const extra = props.account.extra
|
||||
if (!extra || extra.codex_secondary_used_percent === undefined) return null
|
||||
return extra.codex_secondary_used_percent
|
||||
if (!extra) return null
|
||||
|
||||
// Prefer canonical field
|
||||
if (extra.codex_5h_reset_after_seconds !== undefined) {
|
||||
const resetTime = new Date(Date.now() + extra.codex_5h_reset_after_seconds * 1000)
|
||||
return resetTime.toISOString()
|
||||
}
|
||||
|
||||
// Fallback: detect from legacy fields using window_minutes
|
||||
if (extra.codex_primary_window_minutes !== undefined && extra.codex_primary_window_minutes <= 360) {
|
||||
if (extra.codex_primary_reset_after_seconds !== undefined) {
|
||||
const resetTime = new Date(Date.now() + extra.codex_primary_reset_after_seconds * 1000)
|
||||
return resetTime.toISOString()
|
||||
}
|
||||
}
|
||||
if (extra.codex_secondary_window_minutes !== undefined && extra.codex_secondary_window_minutes <= 360) {
|
||||
if (extra.codex_secondary_reset_after_seconds !== undefined) {
|
||||
const resetTime = new Date(Date.now() + extra.codex_secondary_reset_after_seconds * 1000)
|
||||
return resetTime.toISOString()
|
||||
}
|
||||
}
|
||||
|
||||
// Legacy assumption: secondary = 5h
|
||||
if (extra.codex_secondary_reset_after_seconds !== undefined) {
|
||||
const resetTime = new Date(Date.now() + extra.codex_secondary_reset_after_seconds * 1000)
|
||||
return resetTime.toISOString()
|
||||
}
|
||||
|
||||
return null
|
||||
})
|
||||
|
||||
const codexPrimaryResetAt = computed(() => {
|
||||
// 7d window usage (prefer canonical field)
|
||||
const codex7dUsedPercent = computed(() => {
|
||||
const extra = props.account.extra
|
||||
if (!extra || extra.codex_primary_reset_after_seconds === undefined) return null
|
||||
const resetTime = new Date(Date.now() + extra.codex_primary_reset_after_seconds * 1000)
|
||||
return resetTime.toISOString()
|
||||
if (!extra) return null
|
||||
|
||||
// Prefer canonical field
|
||||
if (extra.codex_7d_used_percent !== undefined) {
|
||||
return extra.codex_7d_used_percent
|
||||
}
|
||||
|
||||
// Fallback: detect from legacy fields using window_minutes
|
||||
if (extra.codex_primary_window_minutes !== undefined && extra.codex_primary_window_minutes >= 10000) {
|
||||
return extra.codex_primary_used_percent ?? null
|
||||
}
|
||||
if (extra.codex_secondary_window_minutes !== undefined && extra.codex_secondary_window_minutes >= 10000) {
|
||||
return extra.codex_secondary_used_percent ?? null
|
||||
}
|
||||
|
||||
// Legacy assumption: primary = 7d (may be incorrect)
|
||||
return extra.codex_primary_used_percent ?? null
|
||||
})
|
||||
|
||||
const codexSecondaryResetAt = computed(() => {
|
||||
const codex7dResetAt = computed(() => {
|
||||
const extra = props.account.extra
|
||||
if (!extra || extra.codex_secondary_reset_after_seconds === undefined) return null
|
||||
const resetTime = new Date(Date.now() + extra.codex_secondary_reset_after_seconds * 1000)
|
||||
return resetTime.toISOString()
|
||||
if (!extra) return null
|
||||
|
||||
// Prefer canonical field
|
||||
if (extra.codex_7d_reset_after_seconds !== undefined) {
|
||||
const resetTime = new Date(Date.now() + extra.codex_7d_reset_after_seconds * 1000)
|
||||
return resetTime.toISOString()
|
||||
}
|
||||
|
||||
// Fallback: detect from legacy fields using window_minutes
|
||||
if (extra.codex_primary_window_minutes !== undefined && extra.codex_primary_window_minutes >= 10000) {
|
||||
if (extra.codex_primary_reset_after_seconds !== undefined) {
|
||||
const resetTime = new Date(Date.now() + extra.codex_primary_reset_after_seconds * 1000)
|
||||
return resetTime.toISOString()
|
||||
}
|
||||
}
|
||||
if (extra.codex_secondary_window_minutes !== undefined && extra.codex_secondary_window_minutes >= 10000) {
|
||||
if (extra.codex_secondary_reset_after_seconds !== undefined) {
|
||||
const resetTime = new Date(Date.now() + extra.codex_secondary_reset_after_seconds * 1000)
|
||||
return resetTime.toISOString()
|
||||
}
|
||||
}
|
||||
|
||||
// Legacy assumption: primary = 7d
|
||||
if (extra.codex_primary_reset_after_seconds !== undefined) {
|
||||
const resetTime = new Date(Date.now() + extra.codex_primary_reset_after_seconds * 1000)
|
||||
return resetTime.toISOString()
|
||||
}
|
||||
|
||||
return null
|
||||
})
|
||||
|
||||
const loadUsage = async () => {
|
||||
|
||||
892
frontend/src/components/account/BulkEditAccountModal.vue
Normal file
892
frontend/src/components/account/BulkEditAccountModal.vue
Normal file
@@ -0,0 +1,892 @@
|
||||
<template>
|
||||
<Modal :show="show" :title="t('admin.accounts.bulkEdit.title')" size="lg" @close="handleClose">
|
||||
<form class="space-y-5" @submit.prevent="handleSubmit">
|
||||
<!-- Info -->
|
||||
<div class="rounded-lg bg-blue-50 dark:bg-blue-900/20 p-4">
|
||||
<p class="text-sm text-blue-700 dark:text-blue-400">
|
||||
<svg class="w-5 h-5 inline mr-1.5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
|
||||
/>
|
||||
</svg>
|
||||
{{ t('admin.accounts.bulkEdit.selectionInfo', { count: accountIds.length }) }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Base URL (API Key only) -->
|
||||
<div class="border-t border-gray-200 dark:border-dark-600 pt-4">
|
||||
<div class="flex items-center justify-between mb-3">
|
||||
<label class="input-label mb-0">{{ t('admin.accounts.baseUrl') }}</label>
|
||||
<input
|
||||
v-model="enableBaseUrl"
|
||||
type="checkbox"
|
||||
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
|
||||
/>
|
||||
</div>
|
||||
<input
|
||||
v-model="baseUrl"
|
||||
type="text"
|
||||
:disabled="!enableBaseUrl"
|
||||
class="input"
|
||||
:class="!enableBaseUrl && 'opacity-50 cursor-not-allowed'"
|
||||
:placeholder="t('admin.accounts.bulkEdit.baseUrlPlaceholder')"
|
||||
/>
|
||||
<p class="input-hint">
|
||||
{{ t('admin.accounts.bulkEdit.baseUrlNotice') }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Model restriction -->
|
||||
<div class="border-t border-gray-200 dark:border-dark-600 pt-4">
|
||||
<div class="flex items-center justify-between mb-3">
|
||||
<label class="input-label mb-0">{{ t('admin.accounts.modelRestriction') }}</label>
|
||||
<input
|
||||
v-model="enableModelRestriction"
|
||||
type="checkbox"
|
||||
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div :class="!enableModelRestriction && 'opacity-50 pointer-events-none'">
|
||||
<!-- Mode Toggle -->
|
||||
<div class="flex gap-2 mb-4">
|
||||
<button
|
||||
type="button"
|
||||
:class="[
|
||||
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
|
||||
modelRestrictionMode === 'whitelist'
|
||||
? 'bg-primary-100 text-primary-700 dark:bg-primary-900/30 dark:text-primary-400'
|
||||
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500',
|
||||
]"
|
||||
@click="modelRestrictionMode = 'whitelist'"
|
||||
>
|
||||
<svg
|
||||
class="w-4 h-4 inline mr-1.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M9 12l2 2 4-4m6 2a9 9 0 11-18 0 9 9 0 0118 0z"
|
||||
/>
|
||||
</svg>
|
||||
{{ t('admin.accounts.modelWhitelist') }}
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
:class="[
|
||||
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
|
||||
modelRestrictionMode === 'mapping'
|
||||
? 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400'
|
||||
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500',
|
||||
]"
|
||||
@click="modelRestrictionMode = 'mapping'"
|
||||
>
|
||||
<svg
|
||||
class="w-4 h-4 inline mr-1.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M8 7h12m0 0l-4-4m4 4l-4 4m0 6H4m0 0l4 4m-4-4l4-4"
|
||||
/>
|
||||
</svg>
|
||||
{{ t('admin.accounts.modelMapping') }}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Whitelist Mode -->
|
||||
<div v-if="modelRestrictionMode === 'whitelist'">
|
||||
<div class="mb-3 rounded-lg bg-blue-50 dark:bg-blue-900/20 p-3">
|
||||
<p class="text-xs text-blue-700 dark:text-blue-400">
|
||||
<svg
|
||||
class="w-4 h-4 inline mr-1"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
|
||||
/>
|
||||
</svg>
|
||||
{{ t('admin.accounts.selectAllowedModels') }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Model Checkbox List -->
|
||||
<div class="grid grid-cols-2 gap-2 mb-3">
|
||||
<label
|
||||
v-for="model in allModels"
|
||||
:key="model.value"
|
||||
class="flex cursor-pointer items-center rounded-lg border p-3 transition-all hover:bg-gray-50 dark:border-dark-600 dark:hover:bg-dark-700"
|
||||
:class="
|
||||
allowedModels.includes(model.value)
|
||||
? 'border-primary-500 bg-primary-50 dark:bg-primary-900/20'
|
||||
: 'border-gray-200'
|
||||
"
|
||||
>
|
||||
<input
|
||||
v-model="allowedModels"
|
||||
type="checkbox"
|
||||
:value="model.value"
|
||||
class="mr-2 rounded border-gray-300 text-primary-600 focus:ring-primary-500"
|
||||
/>
|
||||
<span class="text-sm text-gray-700 dark:text-gray-300">{{ model.label }}</span>
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<p class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
|
||||
<span v-if="allowedModels.length === 0">{{
|
||||
t('admin.accounts.supportsAllModels')
|
||||
}}</span>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Mapping Mode -->
|
||||
<div v-else>
|
||||
<div class="mb-3 rounded-lg bg-purple-50 dark:bg-purple-900/20 p-3">
|
||||
<p class="text-xs text-purple-700 dark:text-purple-400">
|
||||
<svg
|
||||
class="w-4 h-4 inline mr-1"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
|
||||
/>
|
||||
</svg>
|
||||
{{ t('admin.accounts.mapRequestModels') }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Model Mapping List -->
|
||||
<div v-if="modelMappings.length > 0" class="space-y-2 mb-3">
|
||||
<div
|
||||
v-for="(mapping, index) in modelMappings"
|
||||
:key="index"
|
||||
class="flex items-center gap-2"
|
||||
>
|
||||
<input
|
||||
v-model="mapping.from"
|
||||
type="text"
|
||||
class="input flex-1"
|
||||
:placeholder="t('admin.accounts.requestModel')"
|
||||
/>
|
||||
<svg
|
||||
class="w-4 h-4 text-gray-400 flex-shrink-0"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M14 5l7 7m0 0l-7 7m7-7H3"
|
||||
/>
|
||||
</svg>
|
||||
<input
|
||||
v-model="mapping.to"
|
||||
type="text"
|
||||
class="input flex-1"
|
||||
:placeholder="t('admin.accounts.actualModel')"
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
class="p-2 text-red-500 hover:text-red-600 hover:bg-red-50 dark:hover:bg-red-900/20 rounded-lg transition-colors"
|
||||
@click="removeModelMapping(index)"
|
||||
>
|
||||
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<button
|
||||
type="button"
|
||||
class="w-full rounded-lg border-2 border-dashed border-gray-300 dark:border-dark-500 px-4 py-2 text-gray-600 dark:text-gray-400 transition-colors hover:border-gray-400 hover:text-gray-700 dark:hover:border-dark-400 dark:hover:text-gray-300 mb-3"
|
||||
@click="addModelMapping"
|
||||
>
|
||||
<svg
|
||||
class="w-4 h-4 inline mr-1"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M12 4v16m8-8H4"
|
||||
/>
|
||||
</svg>
|
||||
{{ t('admin.accounts.addMapping') }}
|
||||
</button>
|
||||
|
||||
<!-- Quick Add Buttons -->
|
||||
<div class="flex flex-wrap gap-2">
|
||||
<button
|
||||
v-for="preset in presetMappings"
|
||||
:key="preset.label"
|
||||
type="button"
|
||||
:class="['rounded-lg px-3 py-1 text-xs transition-colors', preset.color]"
|
||||
@click="addPresetMapping(preset.from, preset.to)"
|
||||
>
|
||||
+ {{ preset.label }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Custom error codes -->
|
||||
<div class="border-t border-gray-200 dark:border-dark-600 pt-4">
|
||||
<div class="flex items-center justify-between mb-3">
|
||||
<div>
|
||||
<label class="input-label mb-0">{{ t('admin.accounts.customErrorCodes') }}</label>
|
||||
<p class="text-xs text-gray-500 dark:text-gray-400 mt-1">
|
||||
{{ t('admin.accounts.customErrorCodesHint') }}
|
||||
</p>
|
||||
</div>
|
||||
<input
|
||||
v-model="enableCustomErrorCodes"
|
||||
type="checkbox"
|
||||
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div v-if="enableCustomErrorCodes" class="space-y-3">
|
||||
<div class="rounded-lg bg-amber-50 dark:bg-amber-900/20 p-3">
|
||||
<p class="text-xs text-amber-700 dark:text-amber-400">
|
||||
<svg
|
||||
class="w-4 h-4 inline mr-1"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z"
|
||||
/>
|
||||
</svg>
|
||||
{{ t('admin.accounts.customErrorCodesWarning') }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Error Code Buttons -->
|
||||
<div class="flex flex-wrap gap-2">
|
||||
<button
|
||||
v-for="code in commonErrorCodes"
|
||||
:key="code.value"
|
||||
type="button"
|
||||
:class="[
|
||||
'rounded-lg px-3 py-1.5 text-sm font-medium transition-colors',
|
||||
selectedErrorCodes.includes(code.value)
|
||||
? 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-400 ring-1 ring-red-500'
|
||||
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500',
|
||||
]"
|
||||
@click="toggleErrorCode(code.value)"
|
||||
>
|
||||
{{ code.value }} {{ code.label }}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Manual input -->
|
||||
<div class="flex items-center gap-2">
|
||||
<input
|
||||
v-model="customErrorCodeInput"
|
||||
type="number"
|
||||
min="100"
|
||||
max="599"
|
||||
class="input flex-1"
|
||||
:placeholder="t('admin.accounts.enterErrorCode')"
|
||||
@keyup.enter="addCustomErrorCode"
|
||||
/>
|
||||
<button type="button" class="btn btn-secondary px-3" @click="addCustomErrorCode">
|
||||
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M12 4v16m8-8H4"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Selected codes summary -->
|
||||
<div class="flex flex-wrap gap-1.5">
|
||||
<span
|
||||
v-for="code in selectedErrorCodes.sort((a, b) => a - b)"
|
||||
:key="code"
|
||||
class="inline-flex items-center gap-1 rounded-full bg-red-100 dark:bg-red-900/30 px-2.5 py-0.5 text-sm font-medium text-red-700 dark:text-red-400"
|
||||
>
|
||||
{{ code }}
|
||||
<button
|
||||
type="button"
|
||||
class="hover:text-red-900 dark:hover:text-red-300"
|
||||
@click="removeErrorCode(code)"
|
||||
>
|
||||
<svg class="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M6 18L18 6M6 6l12 12"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</span>
|
||||
<span v-if="selectedErrorCodes.length === 0" class="text-xs text-gray-400">
|
||||
{{ t('admin.accounts.noneSelectedUsesDefault') }}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Intercept warmup requests (Anthropic only) -->
|
||||
<div class="border-t border-gray-200 dark:border-dark-600 pt-4">
|
||||
<div class="flex items-center justify-between">
|
||||
<div class="flex-1 pr-4">
|
||||
<label class="input-label mb-0">{{
|
||||
t('admin.accounts.interceptWarmupRequests')
|
||||
}}</label>
|
||||
<p class="text-xs text-gray-500 dark:text-gray-400 mt-1">
|
||||
{{ t('admin.accounts.interceptWarmupRequestsDesc') }}
|
||||
</p>
|
||||
</div>
|
||||
<input
|
||||
v-model="enableInterceptWarmup"
|
||||
type="checkbox"
|
||||
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
|
||||
/>
|
||||
</div>
|
||||
<div v-if="enableInterceptWarmup" class="mt-3">
|
||||
<button
|
||||
type="button"
|
||||
:class="[
|
||||
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
|
||||
interceptWarmupRequests ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600',
|
||||
]"
|
||||
@click="interceptWarmupRequests = !interceptWarmupRequests"
|
||||
>
|
||||
<span
|
||||
:class="[
|
||||
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
|
||||
interceptWarmupRequests ? 'translate-x-5' : 'translate-x-0',
|
||||
]"
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Proxy -->
|
||||
<div class="border-t border-gray-200 dark:border-dark-600 pt-4">
|
||||
<div class="flex items-center justify-between mb-3">
|
||||
<label class="input-label mb-0">{{ t('admin.accounts.proxy') }}</label>
|
||||
<input
|
||||
v-model="enableProxy"
|
||||
type="checkbox"
|
||||
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
|
||||
/>
|
||||
</div>
|
||||
<div :class="!enableProxy && 'opacity-50 pointer-events-none'">
|
||||
<ProxySelector v-model="proxyId" :proxies="proxies" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Concurrency & Priority -->
|
||||
<div class="grid grid-cols-2 gap-4 border-t border-gray-200 dark:border-dark-600 pt-4">
|
||||
<div>
|
||||
<div class="flex items-center justify-between mb-3">
|
||||
<label class="input-label mb-0">{{ t('admin.accounts.concurrency') }}</label>
|
||||
<input
|
||||
v-model="enableConcurrency"
|
||||
type="checkbox"
|
||||
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
|
||||
/>
|
||||
</div>
|
||||
<input
|
||||
v-model.number="concurrency"
|
||||
type="number"
|
||||
min="1"
|
||||
:disabled="!enableConcurrency"
|
||||
class="input"
|
||||
:class="!enableConcurrency && 'opacity-50 cursor-not-allowed'"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<div class="flex items-center justify-between mb-3">
|
||||
<label class="input-label mb-0">{{ t('admin.accounts.priority') }}</label>
|
||||
<input
|
||||
v-model="enablePriority"
|
||||
type="checkbox"
|
||||
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
|
||||
/>
|
||||
</div>
|
||||
<input
|
||||
v-model.number="priority"
|
||||
type="number"
|
||||
min="1"
|
||||
:disabled="!enablePriority"
|
||||
class="input"
|
||||
:class="!enablePriority && 'opacity-50 cursor-not-allowed'"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Status -->
|
||||
<div class="border-t border-gray-200 dark:border-dark-600 pt-4">
|
||||
<div class="flex items-center justify-between mb-3">
|
||||
<label class="input-label mb-0">{{ t('common.status') }}</label>
|
||||
<input
|
||||
v-model="enableStatus"
|
||||
type="checkbox"
|
||||
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
|
||||
/>
|
||||
</div>
|
||||
<div :class="!enableStatus && 'opacity-50 pointer-events-none'">
|
||||
<Select v-model="status" :options="statusOptions" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Groups -->
|
||||
<div class="border-t border-gray-200 dark:border-dark-600 pt-4">
|
||||
<div class="flex items-center justify-between mb-3">
|
||||
<label class="input-label mb-0">{{ t('nav.groups') }}</label>
|
||||
<input
|
||||
v-model="enableGroups"
|
||||
type="checkbox"
|
||||
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
|
||||
/>
|
||||
</div>
|
||||
<div :class="!enableGroups && 'opacity-50 pointer-events-none'">
|
||||
<GroupSelector v-model="groupIds" :groups="groups" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Action buttons -->
|
||||
<div class="flex justify-end gap-3 pt-4">
|
||||
<button type="button" class="btn btn-secondary" @click="handleClose">
|
||||
{{ t('common.cancel') }}
|
||||
</button>
|
||||
<button type="submit" :disabled="submitting" class="btn btn-primary">
|
||||
<svg
|
||||
v-if="submitting"
|
||||
class="animate-spin -ml-1 mr-2 h-4 w-4"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
>
|
||||
<circle
|
||||
class="opacity-25"
|
||||
cx="12"
|
||||
cy="12"
|
||||
r="10"
|
||||
stroke="currentColor"
|
||||
stroke-width="4"
|
||||
/>
|
||||
<path
|
||||
class="opacity-75"
|
||||
fill="currentColor"
|
||||
d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"
|
||||
/>
|
||||
</svg>
|
||||
{{
|
||||
submitting ? t('admin.accounts.bulkEdit.updating') : t('admin.accounts.bulkEdit.submit')
|
||||
}}
|
||||
</button>
|
||||
</div>
|
||||
</form>
|
||||
</Modal>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, watch, computed } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { useAppStore } from '@/stores/app'
|
||||
import { adminAPI } from '@/api/admin'
|
||||
import type { Proxy, Group } from '@/types'
|
||||
import Modal from '@/components/common/Modal.vue'
|
||||
import Select from '@/components/common/Select.vue'
|
||||
import ProxySelector from '@/components/common/ProxySelector.vue'
|
||||
import GroupSelector from '@/components/common/GroupSelector.vue'
|
||||
|
||||
interface Props {
|
||||
show: boolean
|
||||
accountIds: number[]
|
||||
proxies: Proxy[]
|
||||
groups: Group[]
|
||||
}
|
||||
|
||||
const props = defineProps<Props>()
|
||||
const emit = defineEmits<{
|
||||
close: []
|
||||
updated: []
|
||||
}>()
|
||||
|
||||
const { t } = useI18n()
|
||||
const appStore = useAppStore()
|
||||
|
||||
// Model mapping type
|
||||
interface ModelMapping {
|
||||
from: string
|
||||
to: string
|
||||
}
|
||||
|
||||
// State - field enable flags
|
||||
const enableBaseUrl = ref(false)
|
||||
const enableModelRestriction = ref(false)
|
||||
const enableCustomErrorCodes = ref(false)
|
||||
const enableInterceptWarmup = ref(false)
|
||||
const enableProxy = ref(false)
|
||||
const enableConcurrency = ref(false)
|
||||
const enablePriority = ref(false)
|
||||
const enableStatus = ref(false)
|
||||
const enableGroups = ref(false)
|
||||
|
||||
// State - field values
|
||||
const submitting = ref(false)
|
||||
const baseUrl = ref('')
|
||||
const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
|
||||
const allowedModels = ref<string[]>([])
|
||||
const modelMappings = ref<ModelMapping[]>([])
|
||||
const selectedErrorCodes = ref<number[]>([])
|
||||
const customErrorCodeInput = ref<number | null>(null)
|
||||
const interceptWarmupRequests = ref(false)
|
||||
const proxyId = ref<number | null>(null)
|
||||
const concurrency = ref(1)
|
||||
const priority = ref(1)
|
||||
const status = ref<'active' | 'inactive'>('active')
|
||||
const groupIds = ref<number[]>([])
|
||||
|
||||
// All models list (combined Anthropic + OpenAI)
|
||||
const allModels = [
|
||||
{ value: 'claude-opus-4-5-20251101', label: 'Claude Opus 4.5' },
|
||||
{ value: 'claude-sonnet-4-20250514', label: 'Claude Sonnet 4' },
|
||||
{ value: 'claude-sonnet-4-5-20250929', label: 'Claude Sonnet 4.5' },
|
||||
{ value: 'claude-3-5-haiku-20241022', label: 'Claude 3.5 Haiku' },
|
||||
{ value: 'claude-haiku-4-5-20251001', label: 'Claude Haiku 4.5' },
|
||||
{ value: 'claude-3-opus-20240229', label: 'Claude 3 Opus' },
|
||||
{ value: 'claude-3-5-sonnet-20241022', label: 'Claude 3.5 Sonnet' },
|
||||
{ value: 'claude-3-haiku-20240307', label: 'Claude 3 Haiku' },
|
||||
{ value: 'gpt-5.2-2025-12-11', label: 'GPT-5.2' },
|
||||
{ value: 'gpt-5.2-codex', label: 'GPT-5.2 Codex' },
|
||||
{ value: 'gpt-5.1-codex-max', label: 'GPT-5.1 Codex Max' },
|
||||
{ value: 'gpt-5.1-codex', label: 'GPT-5.1 Codex' },
|
||||
{ value: 'gpt-5.1-2025-11-13', label: 'GPT-5.1' },
|
||||
{ value: 'gpt-5.1-codex-mini', label: 'GPT-5.1 Codex Mini' },
|
||||
{ value: 'gpt-5-2025-08-07', label: 'GPT-5' },
|
||||
]
|
||||
|
||||
// Preset mappings (combined Anthropic + OpenAI)
|
||||
const presetMappings = [
|
||||
{
|
||||
label: 'Sonnet 4',
|
||||
from: 'claude-sonnet-4-20250514',
|
||||
to: 'claude-sonnet-4-20250514',
|
||||
color: 'bg-blue-100 text-blue-700 hover:bg-blue-200 dark:bg-blue-900/30 dark:text-blue-400',
|
||||
},
|
||||
{
|
||||
label: 'Sonnet 4.5',
|
||||
from: 'claude-sonnet-4-5-20250929',
|
||||
to: 'claude-sonnet-4-5-20250929',
|
||||
color:
|
||||
'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400',
|
||||
},
|
||||
{
|
||||
label: 'Opus 4.5',
|
||||
from: 'claude-opus-4-5-20251101',
|
||||
to: 'claude-opus-4-5-20251101',
|
||||
color:
|
||||
'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400',
|
||||
},
|
||||
{
|
||||
label: 'Opus->Sonnet',
|
||||
from: 'claude-opus-4-5-20251101',
|
||||
to: 'claude-sonnet-4-5-20250929',
|
||||
color:
|
||||
'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400',
|
||||
},
|
||||
{
|
||||
label: 'GPT-5.2',
|
||||
from: 'gpt-5.2-2025-12-11',
|
||||
to: 'gpt-5.2-2025-12-11',
|
||||
color:
|
||||
'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 dark:text-green-400',
|
||||
},
|
||||
{
|
||||
label: 'GPT-5.2 Codex',
|
||||
from: 'gpt-5.2-codex',
|
||||
to: 'gpt-5.2-codex',
|
||||
color: 'bg-blue-100 text-blue-700 hover:bg-blue-200 dark:bg-blue-900/30 dark:text-blue-400',
|
||||
},
|
||||
{
|
||||
label: 'Max->Codex',
|
||||
from: 'gpt-5.1-codex-max',
|
||||
to: 'gpt-5.1-codex',
|
||||
color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400',
|
||||
},
|
||||
]
|
||||
|
||||
// Common HTTP error codes
|
||||
const commonErrorCodes = [
|
||||
{ value: 401, label: 'Unauthorized' },
|
||||
{ value: 403, label: 'Forbidden' },
|
||||
{ value: 429, label: 'Rate Limit' },
|
||||
{ value: 500, label: 'Server Error' },
|
||||
{ value: 502, label: 'Bad Gateway' },
|
||||
{ value: 503, label: 'Unavailable' },
|
||||
{ value: 529, label: 'Overloaded' },
|
||||
]
|
||||
|
||||
const statusOptions = computed(() => [
|
||||
{ value: 'active', label: t('common.active') },
|
||||
{ value: 'inactive', label: t('common.inactive') },
|
||||
])
|
||||
|
||||
// Model mapping helpers
|
||||
const addModelMapping = () => {
|
||||
modelMappings.value.push({ from: '', to: '' })
|
||||
}
|
||||
|
||||
const removeModelMapping = (index: number) => {
|
||||
modelMappings.value.splice(index, 1)
|
||||
}
|
||||
|
||||
const addPresetMapping = (from: string, to: string) => {
|
||||
const exists = modelMappings.value.some((m) => m.from === from)
|
||||
if (exists) {
|
||||
appStore.showInfo(t('admin.accounts.mappingExists', { model: from }))
|
||||
return
|
||||
}
|
||||
modelMappings.value.push({ from, to })
|
||||
}
|
||||
|
||||
// Error code helpers
|
||||
const toggleErrorCode = (code: number) => {
|
||||
const index = selectedErrorCodes.value.indexOf(code)
|
||||
if (index === -1) {
|
||||
selectedErrorCodes.value.push(code)
|
||||
} else {
|
||||
selectedErrorCodes.value.splice(index, 1)
|
||||
}
|
||||
}
|
||||
|
||||
const addCustomErrorCode = () => {
|
||||
const code = customErrorCodeInput.value
|
||||
if (code === null || code < 100 || code > 599) {
|
||||
appStore.showError(t('admin.accounts.invalidErrorCode'))
|
||||
return
|
||||
}
|
||||
if (selectedErrorCodes.value.includes(code)) {
|
||||
appStore.showInfo(t('admin.accounts.errorCodeExists'))
|
||||
return
|
||||
}
|
||||
selectedErrorCodes.value.push(code)
|
||||
customErrorCodeInput.value = null
|
||||
}
|
||||
|
||||
const removeErrorCode = (code: number) => {
|
||||
const index = selectedErrorCodes.value.indexOf(code)
|
||||
if (index !== -1) {
|
||||
selectedErrorCodes.value.splice(index, 1)
|
||||
}
|
||||
}
|
||||
|
||||
const buildModelMappingObject = (): Record<string, string> | null => {
|
||||
const mapping: Record<string, string> = {}
|
||||
|
||||
if (modelRestrictionMode.value === 'whitelist') {
|
||||
for (const model of allowedModels.value) {
|
||||
mapping[model] = model
|
||||
}
|
||||
} else {
|
||||
for (const m of modelMappings.value) {
|
||||
const from = m.from.trim()
|
||||
const to = m.to.trim()
|
||||
if (from && to) {
|
||||
mapping[from] = to
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Object.keys(mapping).length > 0 ? mapping : null
|
||||
}
|
||||
|
||||
const buildUpdatePayload = (): Record<string, unknown> | null => {
|
||||
const updates: Record<string, unknown> = {}
|
||||
const credentials: Record<string, unknown> = {}
|
||||
let credentialsChanged = false
|
||||
|
||||
if (enableProxy.value) {
|
||||
updates.proxy_id = proxyId.value
|
||||
}
|
||||
|
||||
if (enableConcurrency.value) {
|
||||
updates.concurrency = concurrency.value
|
||||
}
|
||||
|
||||
if (enablePriority.value) {
|
||||
updates.priority = priority.value
|
||||
}
|
||||
|
||||
if (enableStatus.value) {
|
||||
updates.status = status.value
|
||||
}
|
||||
|
||||
if (enableGroups.value) {
|
||||
updates.group_ids = groupIds.value
|
||||
}
|
||||
|
||||
if (enableBaseUrl.value) {
|
||||
const baseUrlValue = baseUrl.value.trim()
|
||||
if (baseUrlValue) {
|
||||
credentials.base_url = baseUrlValue
|
||||
credentialsChanged = true
|
||||
}
|
||||
}
|
||||
|
||||
if (enableModelRestriction.value) {
|
||||
const modelMapping = buildModelMappingObject()
|
||||
if (modelMapping) {
|
||||
credentials.model_mapping = modelMapping
|
||||
credentialsChanged = true
|
||||
}
|
||||
}
|
||||
|
||||
if (enableCustomErrorCodes.value) {
|
||||
credentials.custom_error_codes_enabled = true
|
||||
credentials.custom_error_codes = [...selectedErrorCodes.value]
|
||||
credentialsChanged = true
|
||||
}
|
||||
|
||||
if (enableInterceptWarmup.value) {
|
||||
credentials.intercept_warmup_requests = interceptWarmupRequests.value
|
||||
credentialsChanged = true
|
||||
}
|
||||
|
||||
if (credentialsChanged) {
|
||||
updates.credentials = credentials
|
||||
}
|
||||
|
||||
return Object.keys(updates).length > 0 ? updates : null
|
||||
}
|
||||
|
||||
const handleClose = () => {
|
||||
emit('close')
|
||||
}
|
||||
|
||||
const handleSubmit = async () => {
|
||||
if (props.accountIds.length === 0) {
|
||||
appStore.showError(t('admin.accounts.bulkEdit.noSelection'))
|
||||
return
|
||||
}
|
||||
|
||||
const hasAnyFieldEnabled =
|
||||
enableBaseUrl.value ||
|
||||
enableModelRestriction.value ||
|
||||
enableCustomErrorCodes.value ||
|
||||
enableInterceptWarmup.value ||
|
||||
enableProxy.value ||
|
||||
enableConcurrency.value ||
|
||||
enablePriority.value ||
|
||||
enableStatus.value ||
|
||||
enableGroups.value
|
||||
|
||||
if (!hasAnyFieldEnabled) {
|
||||
appStore.showError(t('admin.accounts.bulkEdit.noFieldsSelected'))
|
||||
return
|
||||
}
|
||||
|
||||
const updates = buildUpdatePayload()
|
||||
if (!updates) {
|
||||
appStore.showError(t('admin.accounts.bulkEdit.noFieldsSelected'))
|
||||
return
|
||||
}
|
||||
|
||||
submitting.value = true
|
||||
|
||||
try {
|
||||
const res = await adminAPI.accounts.bulkUpdate(props.accountIds, updates)
|
||||
const success = res.success || 0
|
||||
const failed = res.failed || 0
|
||||
|
||||
if (success > 0 && failed === 0) {
|
||||
appStore.showSuccess(t('admin.accounts.bulkEdit.success', { count: success }))
|
||||
} else if (success > 0) {
|
||||
appStore.showError(t('admin.accounts.bulkEdit.partialSuccess', { success, failed }))
|
||||
} else {
|
||||
appStore.showError(t('admin.accounts.bulkEdit.failed'))
|
||||
}
|
||||
|
||||
if (success > 0) {
|
||||
emit('updated')
|
||||
handleClose()
|
||||
}
|
||||
} catch (error: any) {
|
||||
appStore.showError(error.response?.data?.detail || t('admin.accounts.bulkEdit.failed'))
|
||||
console.error('Error bulk updating accounts:', error)
|
||||
} finally {
|
||||
submitting.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// Reset form when modal closes
|
||||
watch(
|
||||
() => props.show,
|
||||
(newShow) => {
|
||||
if (!newShow) {
|
||||
// Reset all enable flags
|
||||
enableBaseUrl.value = false
|
||||
enableModelRestriction.value = false
|
||||
enableCustomErrorCodes.value = false
|
||||
enableInterceptWarmup.value = false
|
||||
enableProxy.value = false
|
||||
enableConcurrency.value = false
|
||||
enablePriority.value = false
|
||||
enableStatus.value = false
|
||||
enableGroups.value = false
|
||||
|
||||
// Reset all values
|
||||
baseUrl.value = ''
|
||||
modelRestrictionMode.value = 'whitelist'
|
||||
allowedModels.value = []
|
||||
modelMappings.value = []
|
||||
selectedErrorCodes.value = []
|
||||
customErrorCodeInput.value = null
|
||||
interceptWarmupRequests.value = false
|
||||
proxyId.value = null
|
||||
concurrency.value = 1
|
||||
priority.value = 1
|
||||
status.value = 'active'
|
||||
groupIds.value = []
|
||||
}
|
||||
}
|
||||
)
|
||||
</script>
|
||||
170
frontend/src/components/account/SyncFromCrsModal.vue
Normal file
170
frontend/src/components/account/SyncFromCrsModal.vue
Normal file
@@ -0,0 +1,170 @@
|
||||
<template>
|
||||
<Modal
|
||||
:show="show"
|
||||
:title="t('admin.accounts.syncFromCrsTitle')"
|
||||
size="lg"
|
||||
close-on-click-outside
|
||||
@close="handleClose"
|
||||
>
|
||||
<div class="space-y-4">
|
||||
<div class="text-sm text-gray-600 dark:text-dark-300">
|
||||
{{ t('admin.accounts.syncFromCrsDesc') }}
|
||||
</div>
|
||||
<div class="text-xs text-gray-500 dark:text-dark-400 bg-gray-50 dark:bg-dark-700/60 rounded-lg p-3">
|
||||
已有账号仅同步 CRS 返回的字段,缺失字段保持原值;凭据按键合并,不会清空未下发的键;未勾选"同步代理"时保留原有代理。
|
||||
</div>
|
||||
<div class="text-xs text-amber-600 dark:text-amber-400 bg-amber-50 dark:bg-amber-900/20 border border-amber-200 dark:border-amber-800 rounded-lg p-3">
|
||||
{{ t('admin.accounts.crsVersionRequirement') }}
|
||||
</div>
|
||||
|
||||
<div class="grid grid-cols-1 gap-4">
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.crsBaseUrl') }}</label>
|
||||
<input
|
||||
v-model="form.base_url"
|
||||
type="text"
|
||||
class="input"
|
||||
:placeholder="t('admin.accounts.crsBaseUrlPlaceholder')"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div class="grid grid-cols-1 gap-4 sm:grid-cols-2">
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.crsUsername') }}</label>
|
||||
<input
|
||||
v-model="form.username"
|
||||
type="text"
|
||||
class="input"
|
||||
autocomplete="username"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.crsPassword') }}</label>
|
||||
<input
|
||||
v-model="form.password"
|
||||
type="password"
|
||||
class="input"
|
||||
autocomplete="current-password"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<label class="flex items-center gap-2 text-sm text-gray-700 dark:text-dark-300">
|
||||
<input v-model="form.sync_proxies" type="checkbox" class="rounded border-gray-300 dark:border-dark-600" />
|
||||
{{ t('admin.accounts.syncProxies') }}
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<div v-if="result" class="rounded-xl border border-gray-200 dark:border-dark-700 p-4 space-y-2">
|
||||
<div class="text-sm font-medium text-gray-900 dark:text-white">
|
||||
{{ t('admin.accounts.syncResult') }}
|
||||
</div>
|
||||
<div class="text-sm text-gray-700 dark:text-dark-300">
|
||||
{{ t('admin.accounts.syncResultSummary', result) }}
|
||||
</div>
|
||||
|
||||
<div v-if="errorItems.length" class="mt-2">
|
||||
<div class="text-sm font-medium text-red-600 dark:text-red-400">
|
||||
{{ t('admin.accounts.syncErrors') }}
|
||||
</div>
|
||||
<div class="mt-2 max-h-48 overflow-auto rounded-lg bg-gray-50 dark:bg-dark-800 p-3 text-xs font-mono">
|
||||
<div v-for="(item, idx) in errorItems" :key="idx" class="whitespace-pre-wrap">
|
||||
{{ item.kind }} {{ item.crs_account_id }} — {{ item.action }}{{ item.error ? `: ${item.error}` : '' }}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<template #footer>
|
||||
<div class="flex justify-end gap-3">
|
||||
<button class="btn btn-secondary" :disabled="syncing" @click="handleClose">
|
||||
{{ t('common.cancel') }}
|
||||
</button>
|
||||
<button class="btn btn-primary" :disabled="syncing" @click="handleSync">
|
||||
{{ syncing ? t('admin.accounts.syncing') : t('admin.accounts.syncNow') }}
|
||||
</button>
|
||||
</div>
|
||||
</template>
|
||||
</Modal>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed, reactive, ref, watch } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import Modal from '@/components/common/Modal.vue'
|
||||
import { useAppStore } from '@/stores/app'
|
||||
import { adminAPI } from '@/api/admin'
|
||||
|
||||
interface Props {
|
||||
show: boolean
|
||||
}
|
||||
|
||||
interface Emits {
|
||||
(e: 'close'): void
|
||||
(e: 'synced'): void
|
||||
}
|
||||
|
||||
const props = defineProps<Props>()
|
||||
const emit = defineEmits<Emits>()
|
||||
|
||||
const { t } = useI18n()
|
||||
const appStore = useAppStore()
|
||||
|
||||
const syncing = ref(false)
|
||||
const result = ref<Awaited<ReturnType<typeof adminAPI.accounts.syncFromCrs>> | null>(null)
|
||||
|
||||
const form = reactive({
|
||||
base_url: '',
|
||||
username: '',
|
||||
password: '',
|
||||
sync_proxies: true
|
||||
})
|
||||
|
||||
const errorItems = computed(() => {
|
||||
if (!result.value?.items) return []
|
||||
return result.value.items.filter((i) => i.action === 'failed' || i.action === 'skipped')
|
||||
})
|
||||
|
||||
watch(
|
||||
() => props.show,
|
||||
(open) => {
|
||||
if (open) {
|
||||
result.value = null
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
const handleClose = () => {
|
||||
emit('close')
|
||||
}
|
||||
|
||||
const handleSync = async () => {
|
||||
if (!form.base_url.trim() || !form.username.trim() || !form.password.trim()) {
|
||||
appStore.showError(t('admin.accounts.syncMissingFields'))
|
||||
return
|
||||
}
|
||||
|
||||
syncing.value = true
|
||||
try {
|
||||
const res = await adminAPI.accounts.syncFromCrs({
|
||||
base_url: form.base_url.trim(),
|
||||
username: form.username.trim(),
|
||||
password: form.password,
|
||||
sync_proxies: form.sync_proxies
|
||||
})
|
||||
result.value = res
|
||||
|
||||
if (res.failed > 0) {
|
||||
appStore.showError(t('admin.accounts.syncCompletedWithErrors', res))
|
||||
} else {
|
||||
appStore.showSuccess(t('admin.accounts.syncCompleted', res))
|
||||
emit('synced')
|
||||
}
|
||||
} catch (error: any) {
|
||||
appStore.showError(error?.message || t('admin.accounts.syncFailed'))
|
||||
} finally {
|
||||
syncing.value = false
|
||||
}
|
||||
}
|
||||
</script>
|
||||
@@ -1,5 +1,6 @@
|
||||
export { default as CreateAccountModal } from './CreateAccountModal.vue'
|
||||
export { default as EditAccountModal } from './EditAccountModal.vue'
|
||||
export { default as BulkEditAccountModal } from './BulkEditAccountModal.vue'
|
||||
export { default as ReAuthAccountModal } from './ReAuthAccountModal.vue'
|
||||
export { default as OAuthAuthorizationFlow } from './OAuthAuthorizationFlow.vue'
|
||||
export { default as AccountStatusIndicator } from './AccountStatusIndicator.vue'
|
||||
@@ -8,3 +9,4 @@ export { default as UsageProgressBar } from './UsageProgressBar.vue'
|
||||
export { default as AccountStatsModal } from './AccountStatsModal.vue'
|
||||
export { default as AccountTestModal } from './AccountTestModal.vue'
|
||||
export { default as AccountTodayStatsCell } from './AccountTodayStatsCell.vue'
|
||||
export { default as SyncFromCrsModal } from './SyncFromCrsModal.vue'
|
||||
|
||||
@@ -17,11 +17,14 @@ export default {
|
||||
},
|
||||
features: {
|
||||
unifiedGateway: 'Unified API Gateway',
|
||||
unifiedGatewayDesc: 'Convert Claude subscriptions to API endpoints. Access AI capabilities through standard /v1/messages interface.',
|
||||
unifiedGatewayDesc:
|
||||
'Convert Claude subscriptions to API endpoints. Access AI capabilities through standard /v1/messages interface.',
|
||||
multiAccount: 'Multi-Account Pool',
|
||||
multiAccountDesc: 'Manage multiple upstream accounts with smart load balancing. Support OAuth and API Key authentication.',
|
||||
multiAccountDesc:
|
||||
'Manage multiple upstream accounts with smart load balancing. Support OAuth and API Key authentication.',
|
||||
balanceQuota: 'Balance & Quota',
|
||||
balanceQuotaDesc: 'Token-based billing with precise usage tracking. Manage quotas and recharge with redeem codes.',
|
||||
balanceQuotaDesc:
|
||||
'Token-based billing with precise usage tracking. Manage quotas and recharge with redeem codes.',
|
||||
},
|
||||
providers: {
|
||||
title: 'Supported Providers',
|
||||
@@ -235,7 +238,8 @@ export default {
|
||||
useKey: 'Use Key',
|
||||
useKeyModal: {
|
||||
title: 'Use API Key',
|
||||
description: 'Add the following environment variables to your terminal profile or run directly in terminal to configure API access.',
|
||||
description:
|
||||
'Add the following environment variables to your terminal profile or run directly in terminal to configure API access.',
|
||||
copy: 'Copy',
|
||||
copied: 'Copied',
|
||||
note: 'These environment variables will be active in the current terminal session. For permanent configuration, add them to ~/.bashrc, ~/.zshrc, or the appropriate configuration file.',
|
||||
@@ -517,7 +521,8 @@ export default {
|
||||
failedToLoadApiKeys: 'Failed to load user API keys',
|
||||
deleteConfirm: "Are you sure you want to delete '{email}'? This action cannot be undone.",
|
||||
setAllowedGroups: 'Set Allowed Groups',
|
||||
allowedGroupsHint: 'Select which standard groups this user can use. Subscription groups are managed separately.',
|
||||
allowedGroupsHint:
|
||||
'Select which standard groups this user can use. Subscription groups are managed separately.',
|
||||
noStandardGroups: 'No standard groups available',
|
||||
allowAllGroups: 'Allow All Groups',
|
||||
allowAllGroupsHint: 'User can use any non-exclusive group',
|
||||
@@ -529,8 +534,10 @@ export default {
|
||||
depositAmount: 'Deposit Amount',
|
||||
withdrawAmount: 'Withdraw Amount',
|
||||
currentBalance: 'Current Balance',
|
||||
depositNotesPlaceholder: 'e.g., New user registration bonus, promotional credit, compensation, etc.',
|
||||
withdrawNotesPlaceholder: 'e.g., Service issue refund, incorrect charge reversal, account closure refund, etc.',
|
||||
depositNotesPlaceholder:
|
||||
'e.g., New user registration bonus, promotional credit, compensation, etc.',
|
||||
withdrawNotesPlaceholder:
|
||||
'e.g., Service issue refund, incorrect charge reversal, account closure refund, etc.',
|
||||
notesOptional: 'Notes are optional but helpful for record keeping',
|
||||
amountHint: 'Please enter a positive amount',
|
||||
newBalance: 'New Balance',
|
||||
@@ -597,12 +604,15 @@ export default {
|
||||
failedToCreate: 'Failed to create group',
|
||||
failedToUpdate: 'Failed to update group',
|
||||
failedToDelete: 'Failed to delete group',
|
||||
deleteConfirm: "Are you sure you want to delete '{name}'? All associated API keys will no longer belong to any group.",
|
||||
deleteConfirmSubscription: "Are you sure you want to delete subscription group '{name}'? This will invalidate all API keys bound to this subscription and delete all related subscription records. This action cannot be undone.",
|
||||
deleteConfirm:
|
||||
"Are you sure you want to delete '{name}'? All associated API keys will no longer belong to any group.",
|
||||
deleteConfirmSubscription:
|
||||
"Are you sure you want to delete subscription group '{name}'? This will invalidate all API keys bound to this subscription and delete all related subscription records. This action cannot be undone.",
|
||||
subscription: {
|
||||
title: 'Subscription Settings',
|
||||
type: 'Billing Type',
|
||||
typeHint: 'Standard billing deducts from user balance. Subscription mode uses quota limits instead.',
|
||||
typeHint:
|
||||
'Standard billing deducts from user balance. Subscription mode uses quota limits instead.',
|
||||
typeNotEditable: 'Billing type cannot be changed after group creation.',
|
||||
standard: 'Standard (Balance)',
|
||||
subscription: 'Subscription (Quota)',
|
||||
@@ -674,7 +684,8 @@ export default {
|
||||
failedToAssign: 'Failed to assign subscription',
|
||||
failedToExtend: 'Failed to extend subscription',
|
||||
failedToRevoke: 'Failed to revoke subscription',
|
||||
revokeConfirm: "Are you sure you want to revoke the subscription for '{user}'? This action cannot be undone.",
|
||||
revokeConfirm:
|
||||
"Are you sure you want to revoke the subscription for '{user}'? This action cannot be undone.",
|
||||
},
|
||||
|
||||
// Accounts
|
||||
@@ -682,6 +693,26 @@ export default {
|
||||
title: 'Account Management',
|
||||
description: 'Manage AI platform accounts and credentials',
|
||||
createAccount: 'Create Account',
|
||||
syncFromCrs: 'Sync from CRS',
|
||||
syncFromCrsTitle: 'Sync Accounts from CRS',
|
||||
syncFromCrsDesc:
|
||||
'Sync accounts from claude-relay-service (CRS) into this system (CRS is called server-to-server).',
|
||||
crsVersionRequirement: '⚠️ Note: CRS version must be ≥ v1.1.240 to support this feature',
|
||||
crsBaseUrl: 'CRS Base URL',
|
||||
crsBaseUrlPlaceholder: 'e.g. http://127.0.0.1:3000',
|
||||
crsUsername: 'Username',
|
||||
crsPassword: 'Password',
|
||||
syncProxies: 'Also sync proxies (match by host/port/auth or create)',
|
||||
syncNow: 'Sync Now',
|
||||
syncing: 'Syncing...',
|
||||
syncMissingFields: 'Please fill base URL, username and password',
|
||||
syncResult: 'Sync Result',
|
||||
syncResultSummary: 'Created {created}, updated {updated}, skipped {skipped}, failed {failed}',
|
||||
syncErrors: 'Errors / Skipped Details',
|
||||
syncCompleted: 'Sync completed: created {created}, updated {updated}',
|
||||
syncCompletedWithErrors:
|
||||
'Sync completed with errors: failed {failed} (created {created}, updated {updated})',
|
||||
syncFailed: 'Sync failed',
|
||||
editAccount: 'Edit Account',
|
||||
deleteAccount: 'Delete Account',
|
||||
searchAccounts: 'Search accounts...',
|
||||
@@ -726,6 +757,32 @@ export default {
|
||||
tokenRefreshed: 'Token refreshed successfully',
|
||||
accountDeleted: 'Account deleted successfully',
|
||||
rateLimitCleared: 'Rate limit cleared successfully',
|
||||
bulkActions: {
|
||||
selected: '{count} account(s) selected',
|
||||
selectCurrentPage: 'Select this page',
|
||||
clear: 'Clear selection',
|
||||
edit: 'Bulk Edit',
|
||||
delete: 'Bulk Delete',
|
||||
},
|
||||
bulkEdit: {
|
||||
title: 'Bulk Edit Accounts',
|
||||
selectionInfo:
|
||||
'{count} account(s) selected. Only checked or filled fields will be updated; others stay unchanged.',
|
||||
baseUrlPlaceholder: 'https://api.anthropic.com or https://api.openai.com',
|
||||
baseUrlNotice: 'Applies to API Key accounts only; leave empty to keep existing value',
|
||||
submit: 'Update Accounts',
|
||||
updating: 'Updating...',
|
||||
success: 'Updated {count} account(s)',
|
||||
partialSuccess: 'Partially updated: {success} succeeded, {failed} failed',
|
||||
failed: 'Bulk update failed',
|
||||
noSelection: 'Please select accounts to edit',
|
||||
noFieldsSelected: 'Select at least one field to update',
|
||||
},
|
||||
bulkDeleteTitle: 'Bulk Delete Accounts',
|
||||
bulkDeleteConfirm: 'Delete the selected {count} account(s)? This action cannot be undone.',
|
||||
bulkDeleteSuccess: 'Deleted {count} account(s)',
|
||||
bulkDeletePartial: 'Partially deleted: {success} succeeded, {failed} failed',
|
||||
bulkDeleteFailed: 'Bulk delete failed',
|
||||
resetStatus: 'Reset Status',
|
||||
statusReset: 'Account status reset successfully',
|
||||
failedToResetStatus: 'Failed to reset account status',
|
||||
@@ -753,7 +810,8 @@ export default {
|
||||
modelWhitelist: 'Model Whitelist',
|
||||
modelMapping: 'Model Mapping',
|
||||
selectAllowedModels: 'Select allowed models. Leave empty to support all models.',
|
||||
mapRequestModels: 'Map request models to actual models. Left is the requested model, right is the actual model sent to API.',
|
||||
mapRequestModels:
|
||||
'Map request models to actual models. Left is the requested model, right is the actual model sent to API.',
|
||||
selectedModels: 'Selected {count} model(s)',
|
||||
supportsAllModels: '(supports all models)',
|
||||
requestModel: 'Request model',
|
||||
@@ -762,14 +820,16 @@ export default {
|
||||
mappingExists: 'Mapping for {model} already exists',
|
||||
customErrorCodes: 'Custom Error Codes',
|
||||
customErrorCodesHint: 'Only stop scheduling for selected error codes',
|
||||
customErrorCodesWarning: 'Only selected error codes will stop scheduling. Other errors will return 500.',
|
||||
customErrorCodesWarning:
|
||||
'Only selected error codes will stop scheduling. Other errors will return 500.',
|
||||
selectedErrorCodes: 'Selected',
|
||||
noneSelectedUsesDefault: 'None selected (uses default policy)',
|
||||
enterErrorCode: 'Enter error code (100-599)',
|
||||
invalidErrorCode: 'Please enter a valid HTTP error code (100-599)',
|
||||
errorCodeExists: 'This error code is already selected',
|
||||
interceptWarmupRequests: 'Intercept Warmup Requests',
|
||||
interceptWarmupRequestsDesc: 'When enabled, warmup requests like title generation will return mock responses without consuming upstream tokens',
|
||||
interceptWarmupRequestsDesc:
|
||||
'When enabled, warmup requests like title generation will return mock responses without consuming upstream tokens',
|
||||
proxy: 'Proxy',
|
||||
noProxy: 'No Proxy',
|
||||
concurrency: 'Concurrency',
|
||||
@@ -792,11 +852,13 @@ export default {
|
||||
authMethod: 'Authorization Method',
|
||||
manualAuth: 'Manual Authorization',
|
||||
cookieAutoAuth: 'Cookie Auto-Auth',
|
||||
cookieAutoAuthDesc: 'Use claude.ai sessionKey to automatically complete OAuth authorization without manually opening browser.',
|
||||
cookieAutoAuthDesc:
|
||||
'Use claude.ai sessionKey to automatically complete OAuth authorization without manually opening browser.',
|
||||
sessionKey: 'sessionKey',
|
||||
keysCount: '{count} keys',
|
||||
batchCreateAccounts: 'Will batch create {count} accounts',
|
||||
sessionKeyPlaceholder: 'One sessionKey per line, e.g.:\nsk-ant-sid01-xxxxx...\nsk-ant-sid01-yyyyy...',
|
||||
sessionKeyPlaceholder:
|
||||
'One sessionKey per line, e.g.:\nsk-ant-sid01-xxxxx...\nsk-ant-sid01-yyyyy...',
|
||||
sessionKeyPlaceholderSingle: 'sk-ant-sid01-xxxxx...',
|
||||
howToGetSessionKey: 'How to get sessionKey',
|
||||
step1: 'Login to <strong>claude.ai</strong> in your browser',
|
||||
@@ -814,10 +876,13 @@ export default {
|
||||
generating: 'Generating...',
|
||||
regenerate: 'Regenerate',
|
||||
step2OpenUrl: 'Open the URL in your browser and complete authorization',
|
||||
openUrlDesc: 'Open the authorization URL in a new tab, log in to your Claude account and authorize.',
|
||||
proxyWarning: '<strong>Note:</strong> If you configured a proxy, make sure your browser uses the same proxy to access the authorization page.',
|
||||
openUrlDesc:
|
||||
'Open the authorization URL in a new tab, log in to your Claude account and authorize.',
|
||||
proxyWarning:
|
||||
'<strong>Note:</strong> If you configured a proxy, make sure your browser uses the same proxy to access the authorization page.',
|
||||
step3EnterCode: 'Enter the Authorization Code',
|
||||
authCodeDesc: 'After authorization is complete, the page will display an <strong>Authorization Code</strong>. Copy and paste it below:',
|
||||
authCodeDesc:
|
||||
'After authorization is complete, the page will display an <strong>Authorization Code</strong>. Copy and paste it below:',
|
||||
authCode: 'Authorization Code',
|
||||
authCodePlaceholder: 'Paste the Authorization Code from Claude page...',
|
||||
authCodeHint: 'Paste the Authorization Code copied from the Claude page',
|
||||
@@ -835,13 +900,18 @@ export default {
|
||||
step1GenerateUrl: 'Click the button below to generate the authorization URL',
|
||||
generateAuthUrl: 'Generate Auth URL',
|
||||
step2OpenUrl: 'Open the URL in your browser and complete authorization',
|
||||
openUrlDesc: 'Open the authorization URL in a new tab, log in to your OpenAI account and authorize.',
|
||||
importantNotice: '<strong>Important:</strong> The page may take a while to load after authorization. Please wait patiently. When the browser address bar changes to <code>http://localhost...</code>, the authorization is complete.',
|
||||
openUrlDesc:
|
||||
'Open the authorization URL in a new tab, log in to your OpenAI account and authorize.',
|
||||
importantNotice:
|
||||
'<strong>Important:</strong> The page may take a while to load after authorization. Please wait patiently. When the browser address bar changes to <code>http://localhost...</code>, the authorization is complete.',
|
||||
step3EnterCode: 'Enter Authorization URL or Code',
|
||||
authCodeDesc: 'After authorization is complete, when the page URL becomes <code>http://localhost:xxx/auth/callback?code=...</code>:',
|
||||
authCodeDesc:
|
||||
'After authorization is complete, when the page URL becomes <code>http://localhost:xxx/auth/callback?code=...</code>:',
|
||||
authCode: 'Authorization URL or Code',
|
||||
authCodePlaceholder: 'Option 1: Copy the complete URL\n(http://localhost:xxx/auth/callback?code=...)\nOption 2: Copy only the code parameter value',
|
||||
authCodeHint: 'You can copy the entire URL or just the code parameter value, the system will auto-detect',
|
||||
authCodePlaceholder:
|
||||
'Option 1: Copy the complete URL\n(http://localhost:xxx/auth/callback?code=...)\nOption 2: Copy only the code parameter value',
|
||||
authCodeHint:
|
||||
'You can copy the entire URL or just the code parameter value, the system will auto-detect',
|
||||
},
|
||||
},
|
||||
// Re-Auth Modal
|
||||
@@ -941,8 +1011,10 @@ export default {
|
||||
standardAdd: 'Standard Add',
|
||||
batchAdd: 'Quick Add',
|
||||
batchInput: 'Proxy List',
|
||||
batchInputPlaceholder: "Enter one proxy per line in the following formats:\nsocks5://user:pass{'@'}192.168.1.1:1080\nhttp://192.168.1.1:8080\nhttps://user:pass{'@'}proxy.example.com:443",
|
||||
batchInputHint: "Supports http, https, socks5 protocols. Format: protocol://[user:pass{'@'}]host:port",
|
||||
batchInputPlaceholder:
|
||||
"Enter one proxy per line in the following formats:\nsocks5://user:pass{'@'}192.168.1.1:1080\nhttp://192.168.1.1:8080\nhttps://user:pass{'@'}proxy.example.com:443",
|
||||
batchInputHint:
|
||||
"Supports http, https, socks5 protocols. Format: protocol://[user:pass{'@'}]host:port",
|
||||
parsedCount: '{count} valid',
|
||||
invalidCount: '{count} invalid',
|
||||
duplicateCount: '{count} duplicate',
|
||||
@@ -965,7 +1037,8 @@ export default {
|
||||
failedToUpdate: 'Failed to update proxy',
|
||||
failedToDelete: 'Failed to delete proxy',
|
||||
failedToTest: 'Failed to test proxy',
|
||||
deleteConfirm: "Are you sure you want to delete '{name}'? Accounts using this proxy will have their proxy removed.",
|
||||
deleteConfirm:
|
||||
"Are you sure you want to delete '{name}'? Accounts using this proxy will have their proxy removed.",
|
||||
},
|
||||
|
||||
// Redeem Codes
|
||||
@@ -994,8 +1067,10 @@ export default {
|
||||
exportCsv: 'Export CSV',
|
||||
deleteAllUnused: 'Delete All Unused Codes',
|
||||
deleteCode: 'Delete Redeem Code',
|
||||
deleteCodeConfirm: 'Are you sure you want to delete this redeem code? This action cannot be undone.',
|
||||
deleteAllUnusedConfirm: 'Are you sure you want to delete all unused (active) redeem codes? This action cannot be undone.',
|
||||
deleteCodeConfirm:
|
||||
'Are you sure you want to delete this redeem code? This action cannot be undone.',
|
||||
deleteAllUnusedConfirm:
|
||||
'Are you sure you want to delete all unused (active) redeem codes? This action cannot be undone.',
|
||||
deleteAll: 'Delete All',
|
||||
generateCodesTitle: 'Generate Redeem Codes',
|
||||
generatedSuccessfully: 'Generated Successfully',
|
||||
@@ -1075,7 +1150,8 @@ export default {
|
||||
siteSubtitle: 'Site Subtitle',
|
||||
siteSubtitleHint: 'Displayed on login and register pages',
|
||||
apiBaseUrl: 'API Base URL',
|
||||
apiBaseUrlHint: 'Used for "Use Key" and "Import to CC Switch" features. Leave empty to use current site URL.',
|
||||
apiBaseUrlHint:
|
||||
'Used for "Use Key" and "Import to CC Switch" features. Leave empty to use current site URL.',
|
||||
contactInfo: 'Contact Info',
|
||||
contactInfoPlaceholder: 'e.g., QQ: 123456789',
|
||||
contactInfoHint: 'Customer support contact info, displayed on redeem page, profile, etc.',
|
||||
@@ -1125,7 +1201,8 @@ export default {
|
||||
create: 'Create Key',
|
||||
creating: 'Creating...',
|
||||
regenerateConfirm: 'Are you sure? The current key will be immediately invalidated.',
|
||||
deleteConfirm: 'Are you sure you want to delete the admin API key? External integrations will stop working.',
|
||||
deleteConfirm:
|
||||
'Are you sure you want to delete the admin API key? External integrations will stop working.',
|
||||
keyGenerated: 'New admin API key generated',
|
||||
keyDeleted: 'Admin API key deleted',
|
||||
copyKey: 'Copy Key',
|
||||
@@ -1191,7 +1268,8 @@ export default {
|
||||
title: 'My Subscriptions',
|
||||
description: 'View your subscription plans and usage',
|
||||
noActiveSubscriptions: 'No Active Subscriptions',
|
||||
noActiveSubscriptionsDesc: 'You don\'t have any active subscriptions. Contact administrator to get one.',
|
||||
noActiveSubscriptionsDesc:
|
||||
"You don't have any active subscriptions. Contact administrator to get one.",
|
||||
status: {
|
||||
active: 'Active',
|
||||
expired: 'Expired',
|
||||
|
||||
@@ -620,7 +620,8 @@ export default {
|
||||
editGroup: '编辑分组',
|
||||
deleteGroup: '删除分组',
|
||||
deleteConfirm: "确定要删除分组 '{name}' 吗?所有关联的 API 密钥将不再属于任何分组。",
|
||||
deleteConfirmSubscription: "确定要删除订阅分组 '{name}' 吗?此操作会让所有绑定此订阅的用户的 API Key 失效,并删除所有相关的订阅记录。此操作无法撤销。",
|
||||
deleteConfirmSubscription:
|
||||
"确定要删除订阅分组 '{name}' 吗?此操作会让所有绑定此订阅的用户的 API Key 失效,并删除所有相关的订阅记录。此操作无法撤销。",
|
||||
columns: {
|
||||
name: '名称',
|
||||
platform: '平台',
|
||||
@@ -780,6 +781,25 @@ export default {
|
||||
title: '账号管理',
|
||||
description: '管理 AI 平台账号和 Cookie',
|
||||
createAccount: '添加账号',
|
||||
syncFromCrs: '从 CRS 同步',
|
||||
syncFromCrsTitle: '从 CRS 同步账号',
|
||||
syncFromCrsDesc:
|
||||
'将 claude-relay-service(CRS)中的账号同步到当前系统(不会在浏览器侧直接请求 CRS)。',
|
||||
crsVersionRequirement: '⚠️ 注意:CRS 版本必须 ≥ v1.1.240 才支持此功能',
|
||||
crsBaseUrl: 'CRS 服务地址',
|
||||
crsBaseUrlPlaceholder: '例如:http://127.0.0.1:3000',
|
||||
crsUsername: '用户名',
|
||||
crsPassword: '密码',
|
||||
syncProxies: '同时同步代理(按 host/port/账号匹配或自动创建)',
|
||||
syncNow: '开始同步',
|
||||
syncing: '同步中...',
|
||||
syncMissingFields: '请填写服务地址、用户名和密码',
|
||||
syncResult: '同步结果',
|
||||
syncResultSummary: '创建 {created},更新 {updated},跳过 {skipped},失败 {failed}',
|
||||
syncErrors: '错误/跳过详情',
|
||||
syncCompleted: '同步完成:创建 {created},更新 {updated}',
|
||||
syncCompletedWithErrors: '同步完成但有错误:失败 {failed}(创建 {created},更新 {updated})',
|
||||
syncFailed: '同步失败',
|
||||
editAccount: '编辑账号',
|
||||
deleteAccount: '删除账号',
|
||||
deleteConfirmMessage: "确定要删除账号 '{name}' 吗?",
|
||||
@@ -859,6 +879,31 @@ export default {
|
||||
accountCreatedSuccess: '账号添加成功',
|
||||
accountUpdatedSuccess: '账号更新成功',
|
||||
accountDeletedSuccess: '账号删除成功',
|
||||
bulkActions: {
|
||||
selected: '已选择 {count} 个账号',
|
||||
selectCurrentPage: '本页全选',
|
||||
clear: '清除选择',
|
||||
edit: '批量编辑账号',
|
||||
delete: '批量删除',
|
||||
},
|
||||
bulkEdit: {
|
||||
title: '批量编辑账号',
|
||||
selectionInfo: '已选择 {count} 个账号。只更新您勾选或填写的字段,未勾选的字段保持不变。',
|
||||
baseUrlPlaceholder: 'https://api.anthropic.com 或 https://api.openai.com',
|
||||
baseUrlNotice: '仅适用于 API Key 账号,留空则不修改',
|
||||
submit: '批量更新',
|
||||
updating: '更新中...',
|
||||
success: '成功更新 {count} 个账号',
|
||||
partialSuccess: '部分更新成功:成功 {success} 个,失败 {failed} 个',
|
||||
failed: '批量更新失败',
|
||||
noSelection: '请选择要编辑的账号',
|
||||
noFieldsSelected: '请至少选择一个要更新的字段',
|
||||
},
|
||||
bulkDeleteTitle: '批量删除账号',
|
||||
bulkDeleteConfirm: '确定要删除选中的 {count} 个账号吗?此操作无法撤销。',
|
||||
bulkDeleteSuccess: '成功删除 {count} 个账号',
|
||||
bulkDeletePartial: '部分删除成功:成功 {success} 个,失败 {failed} 个',
|
||||
bulkDeleteFailed: '批量删除失败',
|
||||
resetStatus: '重置状态',
|
||||
statusReset: '账号状态已重置',
|
||||
failedToResetStatus: '重置账号状态失败',
|
||||
@@ -931,7 +976,8 @@ export default {
|
||||
sessionKey: 'sessionKey',
|
||||
keysCount: '{count} 个密钥',
|
||||
batchCreateAccounts: '将批量创建 {count} 个账号',
|
||||
sessionKeyPlaceholder: '每行一个 sessionKey,例如:\nsk-ant-sid01-xxxxx...\nsk-ant-sid01-yyyyy...',
|
||||
sessionKeyPlaceholder:
|
||||
'每行一个 sessionKey,例如:\nsk-ant-sid01-xxxxx...\nsk-ant-sid01-yyyyy...',
|
||||
sessionKeyPlaceholderSingle: 'sk-ant-sid01-xxxxx...',
|
||||
howToGetSessionKey: '如何获取 sessionKey',
|
||||
step1: '在浏览器中登录 <strong>claude.ai</strong>',
|
||||
@@ -950,7 +996,8 @@ export default {
|
||||
regenerate: '重新生成',
|
||||
step2OpenUrl: '在浏览器中打开 URL 并完成授权',
|
||||
openUrlDesc: '在新标签页中打开授权 URL,登录您的 Claude 账号并授权。',
|
||||
proxyWarning: '<strong>注意:</strong>如果您配置了代理,请确保浏览器使用相同的代理访问授权页面。',
|
||||
proxyWarning:
|
||||
'<strong>注意:</strong>如果您配置了代理,请确保浏览器使用相同的代理访问授权页面。',
|
||||
step3EnterCode: '输入授权码',
|
||||
authCodeDesc: '授权完成后,页面会显示一个 <strong>授权码</strong>。复制并粘贴到下方:',
|
||||
authCode: '授权码',
|
||||
@@ -971,11 +1018,14 @@ export default {
|
||||
generateAuthUrl: '生成授权链接',
|
||||
step2OpenUrl: '在浏览器中打开链接并完成授权',
|
||||
openUrlDesc: '请在新标签页中打开授权链接,登录您的 OpenAI 账户并授权。',
|
||||
importantNotice: '<strong>重要提示:</strong>授权后页面可能会加载较长时间,请耐心等待。当浏览器地址栏变为 <code>http://localhost...</code> 开头时,表示授权已完成。',
|
||||
importantNotice:
|
||||
'<strong>重要提示:</strong>授权后页面可能会加载较长时间,请耐心等待。当浏览器地址栏变为 <code>http://localhost...</code> 开头时,表示授权已完成。',
|
||||
step3EnterCode: '输入授权链接或 Code',
|
||||
authCodeDesc: '授权完成后,当页面地址变为 <code>http://localhost:xxx/auth/callback?code=...</code> 时:',
|
||||
authCodeDesc:
|
||||
'授权完成后,当页面地址变为 <code>http://localhost:xxx/auth/callback?code=...</code> 时:',
|
||||
authCode: '授权链接或 Code',
|
||||
authCodePlaceholder: '方式1:复制完整的链接\n(http://localhost:xxx/auth/callback?code=...)\n方式2:仅复制 code 参数的值',
|
||||
authCodePlaceholder:
|
||||
'方式1:复制完整的链接\n(http://localhost:xxx/auth/callback?code=...)\n方式2:仅复制 code 参数的值',
|
||||
authCodeHint: '您可以直接复制整个链接或仅复制 code 参数值,系统会自动识别',
|
||||
},
|
||||
},
|
||||
@@ -1111,7 +1161,8 @@ export default {
|
||||
standardAdd: '标准添加',
|
||||
batchAdd: '快捷添加',
|
||||
batchInput: '代理列表',
|
||||
batchInputPlaceholder: "每行输入一个代理,支持以下格式:\nsocks5://user:pass{'@'}192.168.1.1:1080\nhttp://192.168.1.1:8080\nhttps://user:pass{'@'}proxy.example.com:443",
|
||||
batchInputPlaceholder:
|
||||
"每行输入一个代理,支持以下格式:\nsocks5://user:pass{'@'}192.168.1.1:1080\nhttp://192.168.1.1:8080\nhttps://user:pass{'@'}proxy.example.com:443",
|
||||
batchInputHint: "支持 http、https、socks5 协议,格式:协议://[用户名:密码{'@'}]主机:端口",
|
||||
parsedCount: '有效 {count} 个',
|
||||
invalidCount: '无效 {count} 个',
|
||||
|
||||
@@ -366,13 +366,24 @@ export interface AccountUsageInfo {
|
||||
|
||||
// OpenAI Codex usage snapshot (from response headers)
|
||||
export interface CodexUsageSnapshot {
|
||||
codex_primary_used_percent?: number; // Weekly limit usage percentage
|
||||
codex_primary_reset_after_seconds?: number; // Seconds until weekly reset
|
||||
codex_primary_window_minutes?: number; // Weekly window in minutes
|
||||
codex_secondary_used_percent?: number; // 5h limit usage percentage
|
||||
codex_secondary_reset_after_seconds?: number; // Seconds until 5h reset
|
||||
codex_secondary_window_minutes?: number; // 5h window in minutes
|
||||
// Legacy fields (kept for backwards compatibility)
|
||||
// NOTE: The naming is ambiguous - actual window type is determined by window_minutes value
|
||||
codex_primary_used_percent?: number; // Usage percentage (check window_minutes for actual window type)
|
||||
codex_primary_reset_after_seconds?: number; // Seconds until reset
|
||||
codex_primary_window_minutes?: number; // Window in minutes
|
||||
codex_secondary_used_percent?: number; // Usage percentage (check window_minutes for actual window type)
|
||||
codex_secondary_reset_after_seconds?: number; // Seconds until reset
|
||||
codex_secondary_window_minutes?: number; // Window in minutes
|
||||
codex_primary_over_secondary_percent?: number; // Overflow ratio
|
||||
|
||||
// Canonical fields (normalized by backend, use these preferentially)
|
||||
codex_5h_used_percent?: number; // 5-hour window usage percentage
|
||||
codex_5h_reset_after_seconds?: number; // Seconds until 5h window reset
|
||||
codex_5h_window_minutes?: number; // 5h window in minutes (should be ~300)
|
||||
codex_7d_used_percent?: number; // 7-day window usage percentage
|
||||
codex_7d_reset_after_seconds?: number; // Seconds until 7d window reset
|
||||
codex_7d_window_minutes?: number; // 7d window in minutes (should be ~10080)
|
||||
|
||||
codex_usage_updated_at?: string; // Last update timestamp
|
||||
}
|
||||
|
||||
|
||||
@@ -16,6 +16,15 @@
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M16.023 9.348h4.992v-.001M2.985 19.644v-4.992m0 0h4.992m-4.993 0l3.181 3.183a8.25 8.25 0 0013.803-3.7M4.031 9.865a8.25 8.25 0 0113.803-3.7l3.181 3.182m0-4.991v4.99" />
|
||||
</svg>
|
||||
</button>
|
||||
<button
|
||||
@click="showCrsSyncModal = true"
|
||||
class="btn btn-secondary"
|
||||
title="从 CRS 同步"
|
||||
>
|
||||
<svg class="w-5 h-5" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="1.5">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M20.25 6.375c0 2.278-3.694 4.125-8.25 4.125S3.75 8.653 3.75 6.375m16.5 0c0-2.278-3.694-4.125-8.25-4.125S3.75 4.097 3.75 6.375m16.5 0v11.25c0 2.278-3.694 4.125-8.25 4.125s-8.25-1.847-8.25-4.125V6.375m16.5 0v3.75m-16.5-3.75v3.75m16.5 0v3.75C20.25 16.153 16.556 18 12 18s-8.25-1.847-8.25-4.125v-3.75m16.5 0c0 2.278-3.694 4.125-8.25 4.125s-8.25-1.847-8.25-4.125" />
|
||||
</svg>
|
||||
</button>
|
||||
<button
|
||||
@click="showCreateModal = true"
|
||||
class="btn btn-primary"
|
||||
@@ -66,9 +75,63 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Bulk Actions Bar -->
|
||||
<div v-if="selectedAccountIds.length > 0" class="card bg-primary-50 dark:bg-primary-900/20 border-primary-200 dark:border-primary-800 px-4 py-3">
|
||||
<div class="flex flex-wrap items-center justify-between gap-3">
|
||||
<div class="flex flex-wrap items-center gap-2">
|
||||
<span class="text-sm font-medium text-primary-900 dark:text-primary-100">
|
||||
{{ t('admin.accounts.bulkActions.selected', { count: selectedAccountIds.length }) }}
|
||||
</span>
|
||||
<button
|
||||
@click="selectCurrentPageAccounts"
|
||||
class="text-xs font-medium text-primary-700 hover:text-primary-800 dark:text-primary-300 dark:hover:text-primary-200"
|
||||
>
|
||||
{{ t('admin.accounts.bulkActions.selectCurrentPage') }}
|
||||
</button>
|
||||
<span class="text-gray-300 dark:text-primary-800">•</span>
|
||||
<button
|
||||
@click="selectedAccountIds = []"
|
||||
class="text-xs font-medium text-primary-700 hover:text-primary-800 dark:text-primary-300 dark:hover:text-primary-200"
|
||||
>
|
||||
{{ t('admin.accounts.bulkActions.clear') }}
|
||||
</button>
|
||||
</div>
|
||||
<div class="flex items-center gap-2">
|
||||
<button
|
||||
@click="handleBulkDelete"
|
||||
class="btn btn-danger btn-sm"
|
||||
>
|
||||
<svg class="w-4 h-4 mr-1.5" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="1.5">
|
||||
<path stroke-linecap="round" stroke-linejoin="round"
|
||||
d="M14.74 9l-.346 9m-4.788 0L9.26 9m9.968-3.21c.342.052.682.107 1.022.166m-1.022-.165L18.16 19.673a2.25 2.25 0 01-2.244 2.077H8.084a2.25 2.25 0 01-2.244-2.077L4.772 5.79m14.456 0a48.108 48.108 0 00-3.478-.397m-12 .562c.34-.059.68-.114 1.022-.165m0 0a48.11 48.11 0 013.478-.397m7.5 0v-.916c0-1.18-.91-2.164-2.09-2.201a51.964 51.964 0 00-3.32 0c-1.18.037-2.09 1.022-2.09 2.201v.916m7.5 0a48.667 48.667 0 00-7.5 0" />
|
||||
</svg>
|
||||
{{ t('admin.accounts.bulkActions.delete') }}
|
||||
</button>
|
||||
<button
|
||||
@click="showBulkEditModal = true"
|
||||
class="btn btn-primary btn-sm"
|
||||
>
|
||||
<svg class="w-4 h-4 mr-1.5" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="1.5">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M16.862 4.487l1.687-1.688a1.875 1.875 0 112.652 2.652L10.582 16.07a4.5 4.5 0 01-1.897 1.13L6 18l.8-2.685a4.5 4.5 0 011.13-1.897l8.932-8.931zm0 0L19.5 7.125M18 14v4.75A2.25 2.25 0 0115.75 21H5.25A2.25 2.25 0 013 18.75V8.25A2.25 2.25 0 015.25 6H10" />
|
||||
</svg>
|
||||
{{ t('admin.accounts.bulkActions.edit') }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Accounts Table -->
|
||||
<div class="card overflow-hidden">
|
||||
<DataTable :columns="columns" :data="accounts" :loading="loading">
|
||||
<template #cell-select="{ row }">
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="selectedAccountIds.includes(row.id)"
|
||||
@change="toggleAccountSelection(row.id)"
|
||||
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
|
||||
/>
|
||||
</template>
|
||||
|
||||
<template #cell-name="{ value }">
|
||||
<span class="font-medium text-gray-900 dark:text-white">{{ value }}</span>
|
||||
</template>
|
||||
@@ -315,6 +378,32 @@
|
||||
@confirm="confirmDelete"
|
||||
@cancel="showDeleteDialog = false"
|
||||
/>
|
||||
<ConfirmDialog
|
||||
:show="showBulkDeleteDialog"
|
||||
:title="t('admin.accounts.bulkDeleteTitle')"
|
||||
:message="t('admin.accounts.bulkDeleteConfirm', { count: selectedAccountIds.length })"
|
||||
:confirm-text="t('common.delete')"
|
||||
:cancel-text="t('common.cancel')"
|
||||
:danger="true"
|
||||
@confirm="confirmBulkDelete"
|
||||
@cancel="showBulkDeleteDialog = false"
|
||||
/>
|
||||
|
||||
<SyncFromCrsModal
|
||||
:show="showCrsSyncModal"
|
||||
@close="showCrsSyncModal = false"
|
||||
@synced="handleCrsSynced"
|
||||
/>
|
||||
|
||||
<!-- Bulk Edit Account Modal -->
|
||||
<BulkEditAccountModal
|
||||
:show="showBulkEditModal"
|
||||
:account-ids="selectedAccountIds"
|
||||
:proxies="proxies"
|
||||
:groups="groups"
|
||||
@close="showBulkEditModal = false"
|
||||
@updated="handleBulkUpdated"
|
||||
/>
|
||||
</AppLayout>
|
||||
</template>
|
||||
|
||||
@@ -331,7 +420,7 @@ import Pagination from '@/components/common/Pagination.vue'
|
||||
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
|
||||
import EmptyState from '@/components/common/EmptyState.vue'
|
||||
import Select from '@/components/common/Select.vue'
|
||||
import { CreateAccountModal, EditAccountModal, ReAuthAccountModal, AccountStatsModal } from '@/components/account'
|
||||
import { CreateAccountModal, EditAccountModal, BulkEditAccountModal, ReAuthAccountModal, AccountStatsModal, SyncFromCrsModal } from '@/components/account'
|
||||
import AccountStatusIndicator from '@/components/account/AccountStatusIndicator.vue'
|
||||
import AccountUsageCell from '@/components/account/AccountUsageCell.vue'
|
||||
import AccountTodayStatsCell from '@/components/account/AccountTodayStatsCell.vue'
|
||||
@@ -345,6 +434,7 @@ const appStore = useAppStore()
|
||||
|
||||
// Table columns
|
||||
const columns = computed<Column[]>(() => [
|
||||
{ key: 'select', label: '', sortable: false },
|
||||
{ key: 'name', label: t('admin.accounts.columns.name'), sortable: true },
|
||||
{ key: 'platform_type', label: t('admin.accounts.columns.platformType'), sortable: false },
|
||||
{ key: 'concurrency', label: t('admin.accounts.columns.concurrencyStatus'), sortable: false },
|
||||
@@ -402,14 +492,26 @@ const showCreateModal = ref(false)
|
||||
const showEditModal = ref(false)
|
||||
const showReAuthModal = ref(false)
|
||||
const showDeleteDialog = ref(false)
|
||||
const showBulkDeleteDialog = ref(false)
|
||||
const showTestModal = ref(false)
|
||||
const showStatsModal = ref(false)
|
||||
const showCrsSyncModal = ref(false)
|
||||
const showBulkEditModal = ref(false)
|
||||
const editingAccount = ref<Account | null>(null)
|
||||
const reAuthAccount = ref<Account | null>(null)
|
||||
const deletingAccount = ref<Account | null>(null)
|
||||
const testingAccount = ref<Account | null>(null)
|
||||
const statsAccount = ref<Account | null>(null)
|
||||
const togglingSchedulable = ref<number | null>(null)
|
||||
const bulkDeleting = ref(false)
|
||||
|
||||
// Bulk selection
|
||||
const selectedAccountIds = ref<number[]>([])
|
||||
const selectCurrentPageAccounts = () => {
|
||||
const pageIds = accounts.value.map(account => account.id)
|
||||
const merged = new Set([...selectedAccountIds.value, ...pageIds])
|
||||
selectedAccountIds.value = Array.from(merged)
|
||||
}
|
||||
|
||||
// Rate limit / Overload helpers
|
||||
const isRateLimited = (account: Account): boolean => {
|
||||
@@ -480,6 +582,11 @@ const handlePageChange = (page: number) => {
|
||||
loadAccounts()
|
||||
}
|
||||
|
||||
const handleCrsSynced = () => {
|
||||
showCrsSyncModal.value = false
|
||||
loadAccounts()
|
||||
}
|
||||
|
||||
// Edit modal
|
||||
const handleEdit = (account: Account) => {
|
||||
editingAccount.value = account
|
||||
@@ -535,6 +642,38 @@ const confirmDelete = async () => {
|
||||
}
|
||||
}
|
||||
|
||||
const handleBulkDelete = () => {
|
||||
if (selectedAccountIds.value.length === 0) return
|
||||
showBulkDeleteDialog.value = true
|
||||
}
|
||||
|
||||
const confirmBulkDelete = async () => {
|
||||
if (bulkDeleting.value || selectedAccountIds.value.length === 0) return
|
||||
|
||||
bulkDeleting.value = true
|
||||
const ids = [...selectedAccountIds.value]
|
||||
try {
|
||||
const results = await Promise.allSettled(ids.map(id => adminAPI.accounts.delete(id)))
|
||||
const success = results.filter(result => result.status === 'fulfilled').length
|
||||
const failed = results.length - success
|
||||
|
||||
if (failed === 0) {
|
||||
appStore.showSuccess(t('admin.accounts.bulkDeleteSuccess', { count: success }))
|
||||
} else {
|
||||
appStore.showError(t('admin.accounts.bulkDeletePartial', { success, failed }))
|
||||
}
|
||||
|
||||
showBulkDeleteDialog.value = false
|
||||
selectedAccountIds.value = []
|
||||
loadAccounts()
|
||||
} catch (error: any) {
|
||||
appStore.showError(error.response?.data?.detail || t('admin.accounts.bulkDeleteFailed'))
|
||||
console.error('Error deleting accounts:', error)
|
||||
} finally {
|
||||
bulkDeleting.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// Clear rate limit
|
||||
const handleClearRateLimit = async (account: Account) => {
|
||||
try {
|
||||
@@ -608,6 +747,23 @@ const closeStatsModal = () => {
|
||||
statsAccount.value = null
|
||||
}
|
||||
|
||||
// Bulk selection toggle
|
||||
const toggleAccountSelection = (accountId: number) => {
|
||||
const index = selectedAccountIds.value.indexOf(accountId)
|
||||
if (index === -1) {
|
||||
selectedAccountIds.value.push(accountId)
|
||||
} else {
|
||||
selectedAccountIds.value.splice(index, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// Bulk update handler
|
||||
const handleBulkUpdated = () => {
|
||||
showBulkEditModal.value = false
|
||||
selectedAccountIds.value = []
|
||||
loadAccounts()
|
||||
}
|
||||
|
||||
// Initialize
|
||||
onMounted(() => {
|
||||
loadAccounts()
|
||||
|
||||
Reference in New Issue
Block a user